/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList;
import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable;
import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery;
import org.apache.flink.table.planner.plan.rules.logical.ImmutableAggregateReduceGroupingRule;
import org.immutables.value.Value;

@Value.Enclosing
public class AggregateReduceGroupingRule
extends RelRule<AggregateReduceGroupingRuleConfig> {
    public static final AggregateReduceGroupingRule INSTANCE = AggregateReduceGroupingRuleConfig.DEFAULT.toRule();

    protected AggregateReduceGroupingRule(AggregateReduceGroupingRuleConfig config) {
        super(config);
    }

    @Override
    public boolean matches(RelOptRuleCall call) {
        Aggregate agg = (Aggregate)call.rel(0);
        return agg.getGroupCount() > 1 && agg.getGroupType() == Aggregate.Group.SIMPLE;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        FlinkRelMetadataQuery fmq;
        ImmutableBitSet newGrouping;
        Aggregate agg = (Aggregate)call.rel(0);
        RelDataType aggRowType = agg.getRowType();
        RelNode input = agg.getInput();
        RelDataType inputRowType = input.getRowType();
        ImmutableBitSet originalGrouping = agg.getGroupSet();
        ImmutableBitSet uselessGrouping = originalGrouping.except(newGrouping = (fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery())).getUniqueGroups(input, originalGrouping));
        if (uselessGrouping.isEmpty()) {
            return;
        }
        HashMap<Integer, Integer> indexOldToNewMap = new HashMap<Integer, Integer>();
        List<Integer> newGroupingList = newGrouping.toList();
        int idxOfNewGrouping = 0;
        int idxOfAggCallsForDroppedGrouping = newGroupingList.size();
        int index = 0;
        for (int column2 : originalGrouping) {
            if (newGroupingList.contains(column2)) {
                indexOldToNewMap.put(index, idxOfNewGrouping);
                ++idxOfNewGrouping;
            } else {
                indexOldToNewMap.put(index, idxOfAggCallsForDroppedGrouping);
                ++idxOfAggCallsForDroppedGrouping;
            }
            ++index;
        }
        assert (indexOldToNewMap.size() == originalGrouping.cardinality());
        for (int i2 = originalGrouping.cardinality(); i2 < aggRowType.getFieldCount(); ++i2) {
            indexOldToNewMap.put(i2, i2);
        }
        List<AggregateCall> aggCallsForDroppedGrouping = uselessGrouping.asList().stream().map(column -> {
            RelDataType fieldType = inputRowType.getFieldList().get((int)column).getType();
            String fieldName = inputRowType.getFieldNames().get((int)column);
            return AggregateCall.create(FlinkSqlOperatorTable.AUXILIARY_GROUP, false, false, false, ImmutableList.of(), ImmutableList.of(column), -1, null, RelCollations.EMPTY, fieldType, fieldName);
        }).collect(Collectors.toList());
        aggCallsForDroppedGrouping.addAll(agg.getAggCallList());
        Aggregate newAgg = agg.copy(agg.getTraitSet(), input, newGrouping, ImmutableList.of(newGrouping), aggCallsForDroppedGrouping);
        RelBuilder builder = call.builder();
        builder.push(newAgg);
        List projects = IntStream.range(0, newAgg.getRowType().getFieldCount()).mapToObj(i -> {
            Integer refIndex = (Integer)indexOldToNewMap.get(i);
            if (refIndex == null) {
                throw new IllegalArgumentException("Illegal index: " + i);
            }
            return builder.field(refIndex);
        }).collect(Collectors.toList());
        builder.project(projects, aggRowType.getFieldNames());
        call.transformTo(builder.build());
    }

    @Value.Immutable(singleton=false)
    public static interface AggregateReduceGroupingRuleConfig
    extends RelRule.Config {
        public static final AggregateReduceGroupingRuleConfig DEFAULT = ImmutableAggregateReduceGroupingRule.AggregateReduceGroupingRuleConfig.builder().relBuilderFactory(RelFactories.LOGICAL_BUILDER).operandSupplier(b0 -> b0.operand(Aggregate.class).anyInputs()).description("AggregateReduceGroupingRule").build();

        @Override
        default public AggregateReduceGroupingRule toRule() {
            return new AggregateReduceGroupingRule(this);
        }
    }
}

