// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Coalesce;
import org.apache.doris.nereids.trees.expressions.functions.scalar.NullIf;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nullable;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.StringType;

import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link SimplifyConditionalFunction}, which rewrites {@code coalesce}, {@code nvl} and
 * {@code nullif} calls whose arguments are null literals or non-nullable expressions.
 */
public class SimplifyConditionalFunctionTest extends ExpressionRewriteTestHelper {

    /** Configures the shared executor to apply only {@link SimplifyConditionalFunction}, bottom-up. */
    private void useSimplifyConditionalFunction() {
        executor = new ExpressionRuleExecutor(ImmutableList.of(bottomUp(SimplifyConditionalFunction.INSTANCE)));
    }

    @Test
    public void testCoalesce() {
        useSimplifyConditionalFunction();
        SlotReference slot = new SlotReference("a", StringType.INSTANCE, true);
        SlotReference nonNullableSlot = new SlotReference("b", StringType.INSTANCE, false);

        // coalesce(null, null, nullable_slot) -> nullable_slot
        assertRewrite(new Coalesce(NullLiteral.INSTANCE, NullLiteral.INSTANCE, slot), slot);

        // coalesce(null, null, nullable_slot, nullable_slot) -> coalesce(nullable_slot, nullable_slot)
        assertRewrite(new Coalesce(NullLiteral.INSTANCE, NullLiteral.INSTANCE, slot, slot),
                new Coalesce(slot, slot));

        // coalesce(null, null, non-nullable_slot, nullable_slot) -> non-nullable_slot
        assertRewrite(new Coalesce(NullLiteral.INSTANCE, NullLiteral.INSTANCE, nonNullableSlot, slot),
                nonNullableSlot);

        // coalesce(non-nullable_slot, ...) -> non-nullable_slot
        assertRewrite(new Coalesce(nonNullableSlot, NullLiteral.INSTANCE, nonNullableSlot, slot),
                nonNullableSlot);

        // coalesce(nullable_slot, non-nullable_slot) -> unchanged
        assertRewrite(new Coalesce(slot, nonNullableSlot), new Coalesce(slot, nonNullableSlot));

        // coalesce(null, null) -> null
        assertRewrite(new Coalesce(NullLiteral.INSTANCE, NullLiteral.INSTANCE), NullLiteral.INSTANCE);

        // coalesce(null) -> null
        assertRewrite(new Coalesce(NullLiteral.INSTANCE), NullLiteral.INSTANCE);

        // coalesce(non-nullable_slot) -> non-nullable_slot
        assertRewrite(new Coalesce(nonNullableSlot), nonNullableSlot);

        // coalesce(nullable_slot) -> nullable_slot
        assertRewrite(new Coalesce(slot), slot);

        // coalesce(null, nullable_slot, non-nullable_slot) -> coalesce(nullable_slot, non-nullable_slot)
        // the leading null is dropped, and the remaining arguments are kept because the first one is nullable
        assertRewrite(new Coalesce(NullLiteral.INSTANCE, slot, nonNullableSlot),
                new Coalesce(slot, nonNullableSlot));

        SlotReference datetimeSlot = new SlotReference("dt", DateTimeV2Type.of(0), false);
        // coalesce(null_datetime(6), non-nullable_slot_datetime(0)) -> cast(slot as datetime(6))
        assertRewrite(
                new Coalesce(new NullLiteral(DateTimeV2Type.of(6)), datetimeSlot),
                new Cast(datetimeSlot, DateTimeV2Type.of(6))
        );
        // coalesce(non-nullable_slot_datetime(0), null_datetime(6)) -> cast(slot as datetime(6))
        assertRewrite(
                new Coalesce(datetimeSlot, new NullLiteral(DateTimeV2Type.of(6))),
                new Cast(datetimeSlot, DateTimeV2Type.of(6))
        );
    }

    @Test
    public void testNvl() {
        useSimplifyConditionalFunction();
        SlotReference slot = new SlotReference("a", StringType.INSTANCE, true);
        SlotReference nonNullableSlot = new SlotReference("b", StringType.INSTANCE, false);
        // nvl(null, nullable_slot) -> nullable_slot
        assertRewrite(new Nvl(NullLiteral.INSTANCE, slot), slot);

        // nvl(null, non-nullable_slot) -> non-nullable_slot
        assertRewrite(new Nvl(NullLiteral.INSTANCE, nonNullableSlot), nonNullableSlot);

        // nvl(nullable_slot, non-nullable_slot) -> unchanged
        assertRewrite(new Nvl(slot, nonNullableSlot), new Nvl(slot, nonNullableSlot));

        // nvl(non-nullable_slot, null) -> non-nullable_slot
        assertRewrite(new Nvl(nonNullableSlot, NullLiteral.INSTANCE), nonNullableSlot);

        // nvl(null, null) -> null
        assertRewrite(new Nvl(NullLiteral.INSTANCE, NullLiteral.INSTANCE), NullLiteral.INSTANCE);

        SlotReference datetimeSlot = new SlotReference("dt", DateTimeV2Type.of(0), false);
        // nvl(null_datetime(6), non-nullable_slot_datetime(0)) -> cast(slot as datetime(6))
        assertRewrite(
                new Nvl(new NullLiteral(DateTimeV2Type.of(6)), datetimeSlot),
                new Cast(datetimeSlot, DateTimeV2Type.of(6))
        );
        // nvl(non-nullable_slot_datetime(0), null_datetime(6)) -> cast(slot as datetime(6))
        assertRewrite(
                new Nvl(datetimeSlot, new NullLiteral(DateTimeV2Type.of(6))),
                new Cast(datetimeSlot, DateTimeV2Type.of(6))
        );
    }

    @Test
    public void testNullIf() {
        useSimplifyConditionalFunction();
        SlotReference slot = new SlotReference("a", StringType.INSTANCE, true);
        SlotReference nonNullableSlot = new SlotReference("b", StringType.INSTANCE, false);
        // nullif(null, nullable_slot) -> nullable(null_string)
        assertRewrite(new NullIf(NullLiteral.INSTANCE, slot),
                new Nullable(new NullLiteral(StringType.INSTANCE)));

        // nullif(nullable_slot, null) -> nullable(nullable_slot)
        assertRewrite(new NullIf(slot, NullLiteral.INSTANCE), new Nullable(slot));

        // nullif(non-nullable_slot, null) -> nullable(non-nullable_slot)
        assertRewrite(new NullIf(nonNullableSlot, NullLiteral.INSTANCE), new Nullable(nonNullableSlot));

        // nullif(null_datetime(0), null_datetime(6)) -> cast(nullable(null_datetime(0)) as datetime(6))
        assertRewrite(
                new NullIf(
                        new NullLiteral(DateTimeV2Type.of(0)),
                        new NullLiteral(DateTimeV2Type.of(6))
                ),
                new Cast(new Nullable(new NullLiteral(DateTimeV2Type.of(0))), DateTimeV2Type.of(6))
        );
    }

}
