/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.operators.join.stream.keyselector;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.operators.join.stream.keyselector.JoinKeyExtractor;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;

/**
 * {@link JoinKeyExtractor} that derives join keys from a map of equi-join attribute references.
 *
 * <p>{@code joinAttributeMap} maps an input id to the list of join conditions registered for
 * that input. Each {@link ConditionAttributeRef} pairs a left attribute (input id + field index)
 * with a right attribute (input id + field index). From this map the extractor precomputes:
 *
 * <ul>
 *   <li>per input, the sorted, distinct field indices that form its own join key;
 *   <li>per join depth, the extractors that read the matching left-side key out of the row
 *       produced by concatenating all previous inputs;
 *   <li>per input, the extractors for the "common" join key, i.e. attributes belonging to an
 *       equivalence class (computed with a union-find over all conditions) that is touched by
 *       every entry of {@code joinAttributeMap}.
 * </ul>
 */
public class AttributeBasedJoinKeyExtractor implements JoinKeyExtractor, Serializable {
    private static final long serialVersionUID = 1L;

    /** Join conditions keyed by the input id they were registered for. */
    private final Map<Integer, List<ConditionAttributeRef>> joinAttributeMap;

    /** Row type of every input, indexed by input id. Never null. */
    private final List<RowType> inputTypes;

    /** Cache: join depth -> extractors reading the left-side key from the joined row. */
    private final Map<Integer, List<KeyExtractor>> inputIdToExtractorsMap;

    /** Cache: input id -> sorted distinct key field indices inside that input's own row. */
    private final Map<Integer, List<Integer>> inputKeyFieldIndices;

    /** Cache: input id -> extractors reading the common join key from that input's own row. */
    private final Map<Integer, List<KeyExtractor>> commonJoinKeyExtractors;

    /** Type of the common join key, derived from input 0; null when there is no common key. */
    private RowType commonJoinKeyType;

    /**
     * Creates the extractor and eagerly builds all caches.
     *
     * @param joinAttributeMap join conditions keyed by input id
     * @param inputTypes row type of each input, indexed by input id; must not be null
     * @throws NullPointerException if {@code inputTypes} is null
     * @throws IllegalStateException if a left attribute does not reference an earlier input, or
     *     if an input has no attribute in any common equivalence class
     * @throws IndexOutOfBoundsException if a referenced field index does not exist in its input
     */
    public AttributeBasedJoinKeyExtractor(Map<Integer, List<ConditionAttributeRef>> joinAttributeMap, List<RowType> inputTypes) {
        this.joinAttributeMap = joinAttributeMap;
        // A null list could never produce a usable instance (it previously failed later with an
        // AssertionError or a bare NPE depending on -ea), so fail fast with a clear message.
        this.inputTypes = Objects.requireNonNull(inputTypes, "inputTypes must not be null");
        this.inputIdToExtractorsMap = new HashMap<>();
        this.inputKeyFieldIndices = new HashMap<>();
        this.commonJoinKeyExtractors = new HashMap<>();
        this.initializeCaches();
        this.initializeCommonJoinKeyStructures();
    }

    /**
     * Extracts the join key of {@code row} for the given input.
     *
     * @param row a row of input {@code inputId}
     * @param inputId id of the input the row belongs to
     * @return the key row, or null for input 0 or when no key fields are registered for the input
     */
    @Override
    @Nullable
    public RowData getJoinKey(RowData row, int inputId) {
        if (inputId == 0) {
            return null;
        }
        List<ConditionAttributeRef> attributeMapping = this.joinAttributeMap.get(inputId);
        if (attributeMapping == null || attributeMapping.isEmpty()) {
            return null;
        }
        List<Integer> keyFieldIndices = this.inputKeyFieldIndices.get(inputId);
        if (keyFieldIndices == null || keyFieldIndices.isEmpty()) {
            return null;
        }
        return this.buildKeyRow(row, inputId, keyFieldIndices);
    }

    /**
     * Extracts, from the row joined so far, the key that is matched against input {@code depth}.
     *
     * @param depth id of the input being joined next
     * @param joinedRowData concatenation of the rows of all inputs with an id below {@code depth}
     * @return the key row, or null for depth 0 or when no extractors exist for the depth
     */
    @Override
    @Nullable
    public RowData getLeftSideJoinKey(int depth, RowData joinedRowData) {
        if (depth == 0) {
            return null;
        }
        List<KeyExtractor> keyExtractors = this.inputIdToExtractorsMap.get(depth);
        if (keyExtractors == null || keyExtractors.isEmpty()) {
            return null;
        }
        return this.buildKeyRow(keyExtractors, joinedRowData);
    }

    /**
     * Returns the row type of the join key of the given input.
     *
     * @param inputId id of the input
     * @return the key type, or null for input 0 or when the input has no key fields
     */
    @Override
    @Nullable
    public RowType getJoinKeyType(int inputId) {
        if (inputId == 0) {
            return null;
        }
        List<Integer> keyFieldIndices = this.createJoinKeyFieldInputExtractors(inputId);
        if (keyFieldIndices.isEmpty()) {
            return null;
        }
        return this.buildJoinKeyType(inputId, keyFieldIndices);
    }

    /**
     * Returns the cached key field indices of the given input.
     *
     * @param inputId id of the input
     * @return the indices in ascending order, or an empty array when nothing is cached
     */
    @Override
    public int[] getJoinKeyIndices(int inputId) {
        List<Integer> keyFieldIndices = this.inputKeyFieldIndices.get(inputId);
        if (keyFieldIndices == null) {
            return new int[0];
        }
        return keyFieldIndices.stream().mapToInt(Integer::intValue).toArray();
    }

    /**
     * Returns the type of the common join key.
     *
     * @return the common key type, or an empty row type when no common key was derived
     */
    @Override
    public RowType getCommonJoinKeyType() {
        return Objects.requireNonNullElseGet(this.commonJoinKeyType, () -> RowType.of(new LogicalType[0]));
    }

    /**
     * Extracts the common join key from a row of the given input.
     *
     * @param row a row of input {@code inputId}
     * @param inputId id of the input the row belongs to
     * @return the common key row, or null when the input has no common key extractors
     */
    @Override
    @Nullable
    public RowData getCommonJoinKey(RowData row, int inputId) {
        List<KeyExtractor> extractors = this.commonJoinKeyExtractors.get(inputId);
        if (extractors == null || extractors.isEmpty()) {
            return null;
        }
        return this.buildCommonJoinKey(row, extractors);
    }

    /**
     * Returns the field indices, inside the given input's own row, that form the common join key.
     *
     * @param inputId id of the input
     * @return the indices, or an empty array when the input has no common key extractors
     */
    @Override
    public int[] getCommonJoinKeyIndices(int inputId) {
        List<KeyExtractor> extractors = this.commonJoinKeyExtractors.get(inputId);
        if (extractors == null || extractors.isEmpty()) {
            return new int[0];
        }
        return extractors.stream().mapToInt(KeyExtractor::getFieldIndexInSourceRow).toArray();
    }

    /** Fills the per-input caches for left-side extractors and key field indices. */
    private void initializeCaches() {
        for (int inputId = 0; inputId < this.inputTypes.size(); ++inputId) {
            this.inputIdToExtractorsMap.put(inputId, this.createLeftJoinKeyFieldExtractors(inputId));
            this.inputKeyFieldIndices.put(inputId, this.createJoinKeyFieldInputExtractors(inputId));
        }
    }

    /**
     * Builds the extractors that read the left-side key for the join at {@code depth}, ordered by
     * input id and then by field index within that input.
     */
    private List<KeyExtractor> createLeftJoinKeyFieldExtractors(int depth) {
        if (depth == 0) {
            return Collections.emptyList();
        }
        List<ConditionAttributeRef> attributeMapping = this.joinAttributeMap.get(depth);
        if (attributeMapping == null || attributeMapping.isEmpty()) {
            return Collections.emptyList();
        }
        List<KeyExtractor> keyExtractors = new ArrayList<>();
        for (ConditionAttributeRef condition : attributeMapping) {
            AttributeRef leftAttrRef = getLeftAttributeRef(depth, condition);
            keyExtractors.add(this.createKeyExtractor(leftAttrRef));
        }
        keyExtractors.sort(Comparator.comparingInt(KeyExtractor::getInputIdToAccess).thenComparingInt(KeyExtractor::getFieldIndexInSourceRow));
        return keyExtractors;
    }

    /**
     * Returns the left attribute of {@code entry}, verifying that it refers to an input that
     * precedes {@code depth}.
     *
     * @throws IllegalStateException if the left input id is not smaller than {@code depth}
     */
    private static AttributeRef getLeftAttributeRef(int depth, ConditionAttributeRef entry) {
        AttributeRef leftAttrRef = new AttributeRef(entry.leftInputId, entry.leftFieldIndex);
        if (leftAttrRef.inputId >= depth) {
            throw new IllegalStateException("Invalid joinAttributeMap configuration for depth " + depth + ". Left attribute " + leftAttrRef + " does not reference a previous input (< " + depth + ").");
        }
        return leftAttrRef;
    }

    /**
     * Creates an extractor for {@code attrRef}. The absolute index is the field index shifted by
     * the field counts of all inputs with a smaller id, i.e. its position in the joined row.
     */
    private KeyExtractor createKeyExtractor(AttributeRef attrRef) {
        RowType rowType = this.inputTypes.get(attrRef.inputId);
        this.validateFieldIndex(attrRef.inputId, attrRef.fieldIndex, rowType);
        LogicalType fieldType = rowType.getTypeAt(attrRef.fieldIndex);
        int absoluteFieldIndex = attrRef.fieldIndex;
        for (int i = 0; i < attrRef.inputId; ++i) {
            absoluteFieldIndex += this.inputTypes.get(i).getFieldCount();
        }
        return new KeyExtractor(attrRef.inputId, attrRef.fieldIndex, absoluteFieldIndex, fieldType);
    }

    /**
     * Collects the sorted, distinct right-side field indices of the conditions registered for
     * {@code inputId} whose right input is {@code inputId} itself.
     */
    private List<Integer> createJoinKeyFieldInputExtractors(int inputId) {
        List<ConditionAttributeRef> attributeMapping = this.joinAttributeMap.get(inputId);
        if (attributeMapping == null) {
            return Collections.emptyList();
        }
        return attributeMapping.stream()
                .filter(condition -> condition.rightInputId == inputId)
                .map(condition -> condition.rightFieldIndex)
                .distinct()
                .sorted()
                .collect(Collectors.toList());
    }

    /** Builds a key row by applying each extractor's left-side read to the joined row. */
    @Nullable
    private RowData buildKeyRow(List<KeyExtractor> keyExtractors, RowData joinedRowData) {
        if (keyExtractors.isEmpty()) {
            return null;
        }
        GenericRowData keyRow = new GenericRowData(keyExtractors.size());
        for (int i = 0; i < keyExtractors.size(); ++i) {
            keyRow.setField(i, keyExtractors.get(i).getLeftSideKey(joinedRowData));
        }
        return keyRow;
    }

    /** Builds a key row by copying the given fields out of a row of input {@code inputId}. */
    private GenericRowData buildKeyRow(RowData sourceRow, int inputId, List<Integer> keyFieldIndices) {
        GenericRowData keyRow = new GenericRowData(keyFieldIndices.size());
        RowType rowType = this.inputTypes.get(inputId);
        for (int i = 0; i < keyFieldIndices.size(); ++i) {
            int fieldIndex = keyFieldIndices.get(i);
            this.validateFieldIndex(inputId, fieldIndex, rowType);
            LogicalType fieldType = rowType.getTypeAt(fieldIndex);
            // NOTE(review): a new getter is created per field on every call; caching the getters
            // per input would avoid this if the method turns out to be on a hot path.
            RowData.FieldGetter fieldGetter = RowData.createFieldGetter(fieldType, fieldIndex);
            keyRow.setField(i, fieldGetter.getFieldOrNull(sourceRow));
        }
        return keyRow;
    }

    /** Builds the common key row by applying each extractor's right-side read to {@code row}. */
    private RowData buildCommonJoinKey(RowData row, List<KeyExtractor> extractors) {
        GenericRowData commonJoinKeyRow = new GenericRowData(extractors.size());
        for (int i = 0; i < extractors.size(); ++i) {
            commonJoinKeyRow.setField(i, extractors.get(i).getRightSideKey(row));
        }
        return commonJoinKeyRow;
    }

    /** Builds the key row type; each key field is named after its source field plus "_key". */
    private RowType buildJoinKeyType(int inputId, List<Integer> keyFieldIndices) {
        RowType originalRowType = this.inputTypes.get(inputId);
        LogicalType[] keyTypes = new LogicalType[keyFieldIndices.size()];
        String[] keyNames = new String[keyFieldIndices.size()];
        for (int i = 0; i < keyFieldIndices.size(); ++i) {
            int fieldIndex = keyFieldIndices.get(i);
            this.validateFieldIndex(inputId, fieldIndex, originalRowType);
            keyTypes[i] = originalRowType.getTypeAt(fieldIndex);
            keyNames[i] = originalRowType.getFieldNames().get(fieldIndex) + "_key";
        }
        return RowType.of(keyTypes, keyNames);
    }

    /**
     * Derives the common join key: unions the two sides of every condition, keeps the equivalence
     * classes touched by every entry of {@code joinAttributeMap}, and registers per-input
     * extractors for them. Does nothing beyond registering empty extractor lists when there are
     * no inputs, no conditions, or no referenced attributes.
     */
    private void initializeCommonJoinKeyStructures() {
        this.commonJoinKeyType = null;
        for (int inputId = 0; inputId < this.inputTypes.size(); ++inputId) {
            this.commonJoinKeyExtractors.put(inputId, Collections.emptyList());
        }
        if (this.inputTypes.isEmpty() || this.joinAttributeMap.isEmpty()) {
            return;
        }
        Set<AttributeRef> allAttrRefs = this.collectAllAttributeRefs();
        if (allAttrRefs.isEmpty()) {
            return;
        }
        Map<AttributeRef, AttributeRef> parent = new HashMap<>();
        Map<AttributeRef, Integer> rank = new HashMap<>();
        this.initializeDisjointSets(parent, rank, allAttrRefs);
        this.processJoinConditions(parent, rank);
        Map<AttributeRef, Set<AttributeRef>> equivalenceSets = this.buildEquivalenceSets(parent, allAttrRefs);
        List<Set<AttributeRef>> commonConceptualAttributeSets = this.findCommonConceptualAttributeSets(equivalenceSets);
        this.processCommonAttributes(commonConceptualAttributeSets);
    }

    /** Returns every attribute that appears on either side of any condition. */
    private Set<AttributeRef> collectAllAttributeRefs() {
        Set<AttributeRef> allAttrRefs = new HashSet<>();
        for (List<ConditionAttributeRef> mapping : this.joinAttributeMap.values()) {
            for (ConditionAttributeRef condition : mapping) {
                allAttrRefs.add(new AttributeRef(condition.leftInputId, condition.leftFieldIndex));
                allAttrRefs.add(new AttributeRef(condition.rightInputId, condition.rightFieldIndex));
            }
        }
        return allAttrRefs;
    }

    /** Puts every attribute into its own singleton set with rank 0. */
    private void initializeDisjointSets(Map<AttributeRef, AttributeRef> parent, Map<AttributeRef, Integer> rank, Set<AttributeRef> allAttrRefs) {
        for (AttributeRef attrRef : allAttrRefs) {
            parent.put(attrRef, attrRef);
            rank.put(attrRef, 0);
        }
    }

    /** Unions the left and right attribute of every condition. */
    private void processJoinConditions(Map<AttributeRef, AttributeRef> parent, Map<AttributeRef, Integer> rank) {
        for (List<ConditionAttributeRef> mapping : this.joinAttributeMap.values()) {
            for (ConditionAttributeRef condition : mapping) {
                AttributeRef left = new AttributeRef(condition.leftInputId, condition.leftFieldIndex);
                AttributeRef right = new AttributeRef(condition.rightInputId, condition.rightFieldIndex);
                unionAttributeSets(parent, rank, left, right);
            }
        }
    }

    /** Groups all attributes by the root of the set they belong to. */
    private Map<AttributeRef, Set<AttributeRef>> buildEquivalenceSets(Map<AttributeRef, AttributeRef> parent, Set<AttributeRef> allAttrRefs) {
        Map<AttributeRef, Set<AttributeRef>> equivalenceSets = new HashMap<>();
        for (AttributeRef attrRef : allAttrRefs) {
            AttributeRef root = findAttributeSet(parent, attrRef);
            equivalenceSets.computeIfAbsent(root, k -> new HashSet<>()).add(attrRef);
        }
        return equivalenceSets;
    }

    /** Keeps only the equivalence sets that qualify as common (see the predicate below). */
    private List<Set<AttributeRef>> findCommonConceptualAttributeSets(Map<AttributeRef, Set<AttributeRef>> equivalenceSets) {
        List<Set<AttributeRef>> commonConceptualAttributeSets = new ArrayList<>();
        for (Set<AttributeRef> eqSet : equivalenceSets.values()) {
            if (this.isCommonConceptualAttributeSet(eqSet)) {
                commonConceptualAttributeSets.add(eqSet);
            }
        }
        return commonConceptualAttributeSets;
    }

    /**
     * Returns true when every entry of {@code joinAttributeMap} has a non-empty condition list
     * containing at least one condition whose left or right attribute is a member of
     * {@code eqSet}.
     */
    private boolean isCommonConceptualAttributeSet(Set<AttributeRef> eqSet) {
        if (this.joinAttributeMap.isEmpty()) {
            return false;
        }
        for (List<ConditionAttributeRef> conditionsForStep : this.joinAttributeMap.values()) {
            if (conditionsForStep.isEmpty()) {
                return false;
            }
            boolean foundInThisStep = false;
            for (ConditionAttributeRef condition : conditionsForStep) {
                boolean leftInSet = eqSet.contains(new AttributeRef(condition.leftInputId, condition.leftFieldIndex));
                if (leftInSet || eqSet.contains(new AttributeRef(condition.rightInputId, condition.rightFieldIndex))) {
                    foundInThisStep = true;
                    break;
                }
            }
            if (!foundInThisStep) {
                return false;
            }
        }
        return true;
    }

    /**
     * Registers common key extractors for every input.
     *
     * @throws IllegalStateException if an input has no attribute in any common equivalence set
     */
    private void processCommonAttributes(List<Set<AttributeRef>> commonConceptualAttributeSets) {
        for (int currentInputId = 0; currentInputId < this.inputTypes.size(); ++currentInputId) {
            List<AttributeRef> commonAttrsForThisInput = this.findCommonAttributesForInput(currentInputId, commonConceptualAttributeSets);
            if (commonAttrsForThisInput.isEmpty()) {
                throw new IllegalStateException("All inputs in a multi-way join must share a common join key. Input #" + currentInputId + " does not share a join key with the other inputs. Please ensure all join conditions connect all inputs with a common key. Support for multiple independent join key groups is tracked under FLINK-37890.");
            }
            this.processInputCommonAttributes(currentInputId, commonAttrsForThisInput);
        }
    }

    /**
     * Picks, from each common equivalence set, the first attribute encountered that belongs to
     * {@code currentInputId} (at most one per set) and returns them sorted by field index.
     */
    private List<AttributeRef> findCommonAttributesForInput(int currentInputId, List<Set<AttributeRef>> commonConceptualAttributeSets) {
        List<AttributeRef> commonAttrsForThisInput = new ArrayList<>();
        for (Set<AttributeRef> eqSet : commonConceptualAttributeSets) {
            for (AttributeRef attrRef : eqSet) {
                if (attrRef.inputId == currentInputId) {
                    commonAttrsForThisInput.add(attrRef);
                    // Only one representative per equivalence set.
                    break;
                }
            }
        }
        commonAttrsForThisInput.sort(Comparator.comparingInt(attr -> attr.fieldIndex));
        return commonAttrsForThisInput;
    }

    /**
     * Creates and caches the common key extractors of one input. The common key type is taken
     * from input 0, with each field named after its source field plus "_common".
     */
    private void processInputCommonAttributes(int currentInputId, List<AttributeRef> commonAttrsForThisInput) {
        List<KeyExtractor> extractors = new ArrayList<>();
        LogicalType[] keyFieldTypes = new LogicalType[commonAttrsForThisInput.size()];
        String[] keyFieldNames = new String[commonAttrsForThisInput.size()];
        RowType originalRowType = this.inputTypes.get(currentInputId);
        for (int i = 0; i < commonAttrsForThisInput.size(); ++i) {
            AttributeRef attr = commonAttrsForThisInput.get(i);
            this.validateFieldIndex(currentInputId, attr.fieldIndex, originalRowType);
            LogicalType fieldType = originalRowType.getTypeAt(attr.fieldIndex);
            // Common keys are only read from the input's own row, so the absolute index is
            // deliberately the same as the source-row index here.
            extractors.add(new KeyExtractor(currentInputId, attr.fieldIndex, attr.fieldIndex, fieldType));
            keyFieldTypes[i] = fieldType;
            keyFieldNames[i] = originalRowType.getFieldNames().get(attr.fieldIndex) + "_common";
        }
        this.commonJoinKeyExtractors.put(currentInputId, extractors);
        if (currentInputId == 0 && !extractors.isEmpty()) {
            this.commonJoinKeyType = RowType.of(keyFieldTypes, keyFieldNames);
        }
    }

    /**
     * Verifies that {@code fieldIndex} addresses an existing field of {@code rowType}.
     *
     * @throws IndexOutOfBoundsException if the index is negative or not below the field count
     */
    private void validateFieldIndex(int inputId, int fieldIndex, RowType rowType) {
        if (fieldIndex < 0 || fieldIndex >= rowType.getFieldCount()) {
            throw new IndexOutOfBoundsException("joinAttributeMap references field index " + fieldIndex + " which is out of bounds for inputId " + inputId + " with type " + rowType);
        }
    }

    /**
     * Union-find "find": returns the root of the set containing {@code item} and re-points every
     * node visited on the way directly at that root.
     */
    private static AttributeRef findAttributeSet(Map<AttributeRef, AttributeRef> parent, AttributeRef item) {
        AttributeRef currentParent = parent.get(item);
        if (currentParent.equals(item)) {
            return currentParent;
        }
        AttributeRef root = findAttributeSet(parent, currentParent);
        parent.put(item, root);
        return root;
    }

    /**
     * Union-find "union" by rank: attaches the lower-ranked root under the higher-ranked one; on
     * equal ranks the root of {@code b} goes under the root of {@code a}, whose rank grows by one.
     */
    private static void unionAttributeSets(Map<AttributeRef, AttributeRef> parent, Map<AttributeRef, Integer> rank, AttributeRef a, AttributeRef b) {
        AttributeRef rootA = findAttributeSet(parent, a);
        AttributeRef rootB = findAttributeSet(parent, b);
        if (rootA.equals(rootB)) {
            return;
        }
        int rankA = rank.get(rootA);
        int rankB = rank.get(rootB);
        if (rankA < rankB) {
            parent.put(rootA, rootB);
            return;
        }
        parent.put(rootB, rootA);
        if (rankA == rankB) {
            rank.put(rootA, rankA + 1);
        }
    }

    /**
     * One equi-join condition: the attribute {@code leftFieldIndex} of input {@code leftInputId}
     * equals the attribute {@code rightFieldIndex} of input {@code rightInputId}.
     */
    public static final class ConditionAttributeRef implements Serializable {
        public int leftInputId;
        public int leftFieldIndex;
        public int rightInputId;
        public int rightFieldIndex;

        /** Creates a reference with all fields left at 0. */
        public ConditionAttributeRef() {
        }

        public ConditionAttributeRef(int leftInputId, int leftFieldIndex, int rightInputId, int rightFieldIndex) {
            this.leftInputId = leftInputId;
            this.leftFieldIndex = leftFieldIndex;
            this.rightInputId = rightInputId;
            this.rightFieldIndex = rightFieldIndex;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            ConditionAttributeRef that = (ConditionAttributeRef) o;
            return this.leftInputId == that.leftInputId
                    && this.leftFieldIndex == that.leftFieldIndex
                    && this.rightInputId == that.rightInputId
                    && this.rightFieldIndex == that.rightFieldIndex;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.leftInputId, this.leftFieldIndex, this.rightInputId, this.rightFieldIndex);
        }

        @Override
        public String toString() {
            return "LeftInputId:" + this.leftInputId + ";LeftFieldIndex:" + this.leftFieldIndex + ";RightInputId:" + this.rightInputId + ";RightFieldIndex:" + this.rightFieldIndex + ";";
        }
    }

    /**
     * Reference to field {@code fieldIndex} of input {@code inputId}. Instances are used as hash
     * keys inside this class, so the public fields must not be changed while an instance is
     * stored in a map or set.
     */
    public static final class AttributeRef implements Serializable {
        public int inputId;
        public int fieldIndex;

        /** Creates a reference with both fields left at 0. */
        public AttributeRef() {
        }

        public AttributeRef(int inputId, int fieldIndex) {
            this.inputId = inputId;
            this.fieldIndex = fieldIndex;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            AttributeRef that = (AttributeRef) o;
            return this.inputId == that.inputId && this.fieldIndex == that.fieldIndex;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.inputId, this.fieldIndex);
        }

        @Override
        public String toString() {
            return "InputId:" + this.inputId + ";FieldIndex:" + this.fieldIndex + ";";
        }
    }

    /**
     * Reads one key field either from a row of its own input (right side, at
     * {@code fieldIndexInSourceRow}) or from the joined row (left side, at
     * {@code absoluteFieldIndex}).
     *
     * <p>The two positions differ for every input other than input 0, so each side has its own
     * getter. Both getters are transient and are rebuilt after deserialization.
     */
    private static final class KeyExtractor implements Serializable {
        private static final long serialVersionUID = 1L;
        private final int inputIdToAccess;
        private final int fieldIndexInSourceRow;
        private final int absoluteFieldIndex;
        private final LogicalType fieldType;
        /** Getter positioned at {@code fieldIndexInSourceRow}. */
        private transient RowData.FieldGetter rightFieldGetter;
        /** Getter positioned at {@code absoluteFieldIndex}. */
        private transient RowData.FieldGetter leftFieldGetter;

        public KeyExtractor(int inputIdToAccess, int fieldIndexInSourceRow, int absoluteFieldIndex, LogicalType fieldType) {
            this.inputIdToAccess = inputIdToAccess;
            this.fieldIndexInSourceRow = fieldIndexInSourceRow;
            this.absoluteFieldIndex = absoluteFieldIndex;
            this.fieldType = fieldType;
            this.rightFieldGetter = RowData.createFieldGetter(fieldType, fieldIndexInSourceRow);
            this.leftFieldGetter = RowData.createFieldGetter(fieldType, absoluteFieldIndex);
        }

        /**
         * Reads the key field from a row of this extractor's own input.
         *
         * @param joinedRowData the source row; despite the name, a row of the single input
         * @return the field value, or null when the row is null or the field is null
         */
        public Object getRightSideKey(RowData joinedRowData) {
            if (joinedRowData == null) {
                return null;
            }
            if (this.rightFieldGetter == null) {
                this.rightFieldGetter = RowData.createFieldGetter(this.fieldType, this.fieldIndexInSourceRow);
            }
            return this.rightFieldGetter.getFieldOrNull(joinedRowData);
        }

        /**
         * Reads the key field from the joined row, using the absolute field index.
         *
         * @param joinedRowData the concatenated row of the preceding inputs
         * @return the field value, or null when the row is null or the field is null
         */
        public Object getLeftSideKey(RowData joinedRowData) {
            if (joinedRowData == null) {
                return null;
            }
            if (this.leftFieldGetter == null) {
                this.leftFieldGetter = RowData.createFieldGetter(this.fieldType, this.absoluteFieldIndex);
            }
            return this.leftFieldGetter.getFieldOrNull(joinedRowData);
        }

        public int getInputIdToAccess() {
            return this.inputIdToAccess;
        }

        public int getFieldIndexInSourceRow() {
            return this.fieldIndexInSourceRow;
        }

        /** Restores the transient getters after default deserialization. */
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.defaultReadObject();
            if (this.fieldType != null) {
                this.rightFieldGetter = RowData.createFieldGetter(this.fieldType, this.fieldIndexInSourceRow);
                this.leftFieldGetter = RowData.createFieldGetter(this.fieldType, this.absoluteFieldIndex);
            }
        }
    }
}

