/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.util;

import java.text.SimpleDateFormat;
import java.time.YearMonth;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Calendar;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.TimeZone;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlNumericLiteral;
import org.apache.calcite.sql.dialect.CalciteSqlDialect;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.util.SqlVisitor;
import org.apache.calcite.util.Litmus;
import org.apache.kylin.common.util.Pair;
import org.apache.kylin.metadata.model.tool.CalciteParser;
import org.apache.kylin.query.IQueryTransformer;
import org.apache.kylin.query.util.AbstractSqlVisitor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Rewrites WHERE-clause filters that compare numeric date-part expressions with integer literals into
 * filters on {@code cast(col as date)}. Recognised expressions (addends in any order, multiplier on
 * either side of {@code *}):
 * <ul>
 *   <li>{@code YEAR(col)} compared with a 4-digit literal such as {@code 2020};</li>
 *   <li>{@code MONTH(col) + YEAR(col) * 100} compared with a 6-digit literal such as {@code 202001};</li>
 *   <li>{@code DAYOFMONTH(col) + MONTH(col) * 100 + YEAR(col) * 10000} compared with an 8-digit literal
 *       such as {@code 20200131}.</li>
 * </ul>
 * Operators handled: {@code =, !=, <>, <, <=, >, >=} (date-part expression on either side) and
 * {@code [NOT] BETWEEN}, {@code [NOT] IN} (date-part expression on the left). Literals that do not denote
 * a real year / month / day leave the filter unchanged, and any failure returns the original SQL.
 */
public class DateNumberFilterTransformer implements IQueryTransformer {
    private static final Logger logger = LoggerFactory.getLogger(DateNumberFilterTransformer.class);

    /**
     * Returns {@code originSql} with every recognised date-number filter rewritten.
     *
     * @param originSql the SQL text to transform
     * @param project unused by this transformer
     * @param defaultSchema unused by this transformer
     * @return the rewritten SQL, or {@code originSql} when nothing matched or anything failed
     */
    @Override
    public String transform(String originSql, String project, String defaultSchema) {
        try {
            SqlTimeFilterMatcher matcher = new SqlTimeFilterMatcher(originSql);
            SqlNode sqlNode = getSqlNode(originSql);
            sqlNode.accept(matcher);
            List<Pair<String, Pair<Integer, Integer>>> positions = matcher.getTimeFilterPositions();
            if (positions.isEmpty()) {
                return originSql;
            }
            logger.debug("'DateNumberFilterTransformer' will be used to transform SQL");
            return replaceTimeFilter(positions, originSql);
        } catch (Exception e) {
            // Best effort: any parse or rewrite failure falls back to the untouched query.
            logger.warn("Something unexpected in DateNumberFilterTransformer, return original query", e);
            return originSql;
        }
    }

    /**
     * Parses {@code sql} with {@link CalciteParser}.
     *
     * @throws IllegalStateException wrapping the {@link SqlParseException} when the SQL cannot be parsed
     */
    private SqlNode getSqlNode(String sql) {
        try {
            return CalciteParser.parse(sql);
        } catch (SqlParseException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Splices each replacement text over its {@code [start, end)} span of the SQL. Spans are applied from
     * the largest start offset to the smallest so that earlier offsets remain valid while later spans
     * change length.
     *
     * @param positions pairs of (replacement text, (start, end) offsets into {@code originSql}); sorted in place
     * @param originSql the SQL the offsets refer to
     * @return the rewritten SQL, trimmed
     */
    private String replaceTimeFilter(List<Pair<String, Pair<Integer, Integer>>> positions, String originSql) {
        positions.sort((o1, o2) -> Integer.compare(o2.getSecond().getFirst(), o1.getSecond().getFirst()));
        // NOTE(review): the padding space presumably tolerates an end offset one past the last character
        // of the SQL; confirm against CalciteParser.getReplacePos before removing it.
        String sql = originSql + " ";
        for (Pair<String, Pair<Integer, Integer>> pos : positions) {
            Pair<Integer, Integer> span = pos.getSecond();
            sql = sql.substring(0, span.getFirst()) + pos.getFirst() + sql.substring(span.getSecond());
        }
        return sql.trim();
    }

    /**
     * Visitor that records, for every supported date-number comparison found in a WHERE condition, the
     * replacement SQL text together with the source span (from {@code CalciteParser.getReplacePos}) that
     * it replaces.
     */
    static class SqlTimeFilterMatcher extends AbstractSqlVisitor {
        private static final List<String> SUPPORT_FUN = Arrays.asList("=", "IN", "NOT IN", "BETWEEN", "NOT BETWEEN", "<", ">", "<=", ">=", "!=", "<>");
        private static final List<String> YEAR_FUN = Arrays.asList("YEAR", "{fn YEAR}");
        private static final List<String> MONTH_FUN = Arrays.asList("MONTH", "{fn MONTH}");
        private static final List<String> DAY_FUN = Arrays.asList("DAYOFMONTH", "{fn DAYOFMONTH}");

        private final List<Pair<String, Pair<Integer, Integer>>> timeFilterPositions = new ArrayList<>();

        public SqlTimeFilterMatcher(String originSql) {
            super(originSql);
        }

        /**
         * Walks a WHERE condition breadth-first. Calls whose operator is not a supported comparison
         * (AND, OR, NOT, ...) are expanded into their operands; supported comparisons are handed to
         * {@link #fetchTimeFilter} and not searched further.
         */
        @Override
        public void visitInSqlWhere(SqlNode call) {
            LinkedList<SqlNode> conditions = new LinkedList<>();
            conditions.add(call);
            while (!conditions.isEmpty()) {
                SqlNode cond = conditions.poll();
                if (!(cond instanceof SqlBasicCall)) {
                    continue;
                }
                SqlBasicCall node = (SqlBasicCall) cond;
                if (SUPPORT_FUN.contains(node.getOperator().toString())) {
                    fetchTimeFilter(node);
                } else {
                    conditions.addAll(node.getOperandList());
                }
            }
        }

        /** Dispatches a supported comparison to the handler for its operator family. */
        public void fetchTimeFilter(SqlBasicCall call) {
            switch (call.getOperator().toString()) {
                case "=":
                case "!=":
                case ">":
                case ">=":
                case "<":
                case "<=":
                case "<>":
                    fetchTimeFilterInNormalCondition(call);
                    break;
                case "BETWEEN":
                case "NOT BETWEEN":
                    fetchTimeFilterInBetweenCondition(call);
                    break;
                case "IN":
                case "NOT IN":
                    fetchTimeFilterOfInCondition(call);
                    break;
                default:
                    break;
            }
        }

        /**
         * Rewrites {@code expr [NOT] IN (lit, ...)} as an OR (for IN) / AND (for NOT IN) of per-literal
         * equality / inequality rewrites. All-or-nothing: one non-rewritable element keeps the list as is.
         */
        private void fetchTimeFilterOfInCondition(SqlBasicCall call) {
            if (!(call.operand(0) instanceof SqlBasicCall) || !(call.operand(1) instanceof SqlNodeList)) {
                return;
            }
            String operatorName = call.getOperator().toString();
            String operator;
            String delimiter;
            if ("IN".equals(operatorName)) {
                operator = "=";
                delimiter = " OR ";
            } else if ("NOT IN".equals(operatorName)) {
                operator = "!=";
                delimiter = " AND ";
            } else {
                return;
            }
            TimeExpression timeExpression = new TimeExpression((SqlBasicCall) call.operand(0));
            SqlNodeList timeList = (SqlNodeList) call.operand(1);
            List<String> targets = new ArrayList<>();
            for (SqlNode node : timeList) {
                String subTarget = node instanceof SqlNumericLiteral
                        ? rewriteFilter(operator, (SqlNumericLiteral) node, timeExpression)
                        : null;
                if (subTarget == null) {
                    return;
                }
                targets.add(subTarget);
            }
            if (!targets.isEmpty()) {
                timeFilterFound(call, String.format(Locale.ROOT, "(%s)", String.join(delimiter, targets)));
            }
        }

        /**
         * Rewrites a binary comparison between a date-part expression and a numeric literal, in either
         * operand order.
         */
        private void fetchTimeFilterInNormalCondition(SqlBasicCall call) {
            String op = call.getOperator().toString();
            SqlBasicCall expression;
            SqlNumericLiteral time;
            if (call.operand(0) instanceof SqlBasicCall && call.operand(1) instanceof SqlNumericLiteral) {
                expression = (SqlBasicCall) call.operand(0);
                time = (SqlNumericLiteral) call.operand(1);
            } else if (call.operand(1) instanceof SqlBasicCall && call.operand(0) instanceof SqlNumericLiteral) {
                expression = (SqlBasicCall) call.operand(1);
                time = (SqlNumericLiteral) call.operand(0);
                // "2020 < YEAR(col)" means "YEAR(col) > 2020". The rewrite always puts the column on the
                // left, so the operator must be mirrored to keep the filter's meaning.
                op = mirrorComparison(op);
            } else {
                return;
            }
            TimeExpression timeExpression = new TimeExpression(expression);
            String target = rewriteFilter(op, time, timeExpression);
            if (target != null) {
                timeFilterFound(call, target);
            }
        }

        /** Returns the comparison operator equivalent to {@code op} once its two operands are swapped. */
        private static String mirrorComparison(String op) {
            switch (op) {
                case "<":
                    return ">";
                case ">":
                    return "<";
                case "<=":
                    return ">=";
                case ">=":
                    return "<=";
                default:
                    // =, != and <> are symmetric.
                    return op;
            }
        }

        /** Rewrites {@code expr [NOT] BETWEEN lit1 AND lit2}; both bounds must be numeric literals. */
        private void fetchTimeFilterInBetweenCondition(SqlBasicCall call) {
            if (!(call.operand(0) instanceof SqlBasicCall && call.operand(1) instanceof SqlNumericLiteral
                    && call.operand(2) instanceof SqlNumericLiteral)) {
                return;
            }
            TimeExpression timeExpression = new TimeExpression((SqlBasicCall) call.operand(0));
            // rewriteFilter validates that both bounds are real values of the expression's granularity.
            String target = rewriteFilter(call.getOperator().toString(), (SqlNumericLiteral) call.operand(1),
                    (SqlNumericLiteral) call.operand(2), timeExpression);
            if (target != null) {
                timeFilterFound(call, target);
            }
        }

        /**
         * Builds the rewrite for {@code expr op leftTime AND rightTime} ({@code op} is BETWEEN or NOT BETWEEN):
         * the lower bound maps to the first day of its period and the upper bound to the last day.
         *
         * @return the replacement SQL, or {@code null} when either bound is not a valid value of the
         *         expression's granularity
         */
        String rewriteFilter(String op, SqlNumericLiteral leftTime, SqlNumericLiteral rightTime, TimeExpression timeExpression) {
            if (timeExpression.isYear(leftTime) && timeExpression.isYear(rightTime)) {
                return String.format(Locale.ROOT, "cast(%s as date) %s %s and %s", timeExpression.colNameString(), op,
                        yearToDate(intValue(leftTime), false), yearToDate(intValue(rightTime), true));
            }
            if (timeExpression.isMonthYear(leftTime) && timeExpression.isMonthYear(rightTime)) {
                return String.format(Locale.ROOT, "cast(%s as date) %s %s and %s", timeExpression.colNameString(), op,
                        monthToDate(intValue(leftTime), false), monthToDate(intValue(rightTime), true));
            }
            if (timeExpression.isDayMonthYear(leftTime) && timeExpression.isDayMonthYear(rightTime)) {
                return String.format(Locale.ROOT, "cast(%s as date) %s %s and %s", timeExpression.colNameString(), op,
                        dayMonthYearToDate(leftTime.toString()), dayMonthYearToDate(rightTime.toString()));
            }
            return null;
        }

        /**
         * Builds the rewrite for {@code expr op time}, with the date-part expression on the left.
         *
         * @return the replacement SQL, or {@code null} when {@code time} is not a valid value of the
         *         expression's granularity or {@code op} is unsupported
         */
        String rewriteFilter(String op, SqlNumericLiteral time, TimeExpression timeExpression) {
            switch (op) {
                case "=":
                    if (timeExpression.isYear(time)) {
                        return String.format(Locale.ROOT, "cast(%s as date) BETWEEN %s", timeExpression.colNameString(), yearToRange(intValue(time)));
                    }
                    if (timeExpression.isMonthYear(time)) {
                        return String.format(Locale.ROOT, "cast(%s as date) BETWEEN %s", timeExpression.colNameString(), monthYearToRange(intValue(time)));
                    }
                    if (timeExpression.isDayMonthYear(time)) {
                        return String.format(Locale.ROOT, "cast(%s as date) = %s", timeExpression.colNameString(), dayMonthYearToDate(time.toString()));
                    }
                    return null;
                case "!=":
                case "<>":
                    if (timeExpression.isYear(time)) {
                        return String.format(Locale.ROOT, "cast(%s as date) NOT BETWEEN %s", timeExpression.colNameString(), yearToRange(intValue(time)));
                    }
                    if (timeExpression.isMonthYear(time)) {
                        return String.format(Locale.ROOT, "cast(%s as date) NOT BETWEEN %s", timeExpression.colNameString(), monthYearToRange(intValue(time)));
                    }
                    if (timeExpression.isDayMonthYear(time)) {
                        return String.format(Locale.ROOT, "cast(%s as date) <> %s", timeExpression.colNameString(), dayMonthYearToDate(time.toString()));
                    }
                    return null;
                case ">":
                case "<":
                case ">=":
                case "<=": {
                    // "> period" and "<= period" compare against the period's last day;
                    // "< period" and ">= period" compare against its first day.
                    boolean useEnd = op.equals(">") || op.equals("<=");
                    if (timeExpression.isYear(time)) {
                        return String.format(Locale.ROOT, "cast(%s as date) %s %s", timeExpression.colNameString(), op, yearToDate(intValue(time), useEnd));
                    }
                    if (timeExpression.isMonthYear(time)) {
                        return String.format(Locale.ROOT, "cast(%s as date) %s %s", timeExpression.colNameString(), op, monthToDate(intValue(time), useEnd));
                    }
                    if (timeExpression.isDayMonthYear(time)) {
                        return String.format(Locale.ROOT, "cast(%s as date) %s %s", timeExpression.colNameString(), op, dayMonthYearToDate(time.toString()));
                    }
                    return null;
                }
                default:
                    return null;
            }
        }

        /** Reads an integer-valued numeric literal as an {@code int}. */
        private static int intValue(SqlNumericLiteral literal) {
            return literal.getValueAs(Integer.class);
        }

        /** Returns {@code 'yyyy-01-01' and 'yyyy-12-31'} for the given year. */
        String yearToRange(int year) {
            return yearToDate(year, false) + " and " + yearToDate(year, true);
        }

        /** Returns the quoted first and last day of the {@code yyyyMM} month, joined by {@code " and "}. */
        String monthYearToRange(int monthYear) {
            return monthToDate(monthYear, false) + " and " + monthToDate(monthYear, true);
        }

        /** Returns the quoted first ({@code end == false}) or last ({@code end == true}) day of {@code year}. */
        String yearToDate(int year, boolean end) {
            return end ? String.format(Locale.ROOT, "'%d-12-31'", year) : String.format(Locale.ROOT, "'%d-01-01'", year);
        }

        /**
         * Returns the quoted ISO date of the first ({@code end == false}) or last ({@code end == true}) day
         * of the month encoded as {@code yyyyMM}. Uses the ISO calendar, independent of the JVM locale.
         *
         * @throws java.time.DateTimeException if the month part is not 1-12; callers validate with
         *         {@link TimeExpression#isMonthYear} first
         */
        String monthToDate(int monthYear, boolean end) {
            YearMonth yearMonth = YearMonth.of(monthYear / 100, monthYear % 100);
            return String.format(Locale.ROOT, "'%s'", end ? yearMonth.atEndOfMonth() : yearMonth.atDay(1));
        }

        /** Converts an 8-character {@code yyyyMMdd} string to a quoted {@code 'yyyy-MM-dd'} literal. */
        String dayMonthYearToDate(String dayMonthYear) {
            return String.format(Locale.ROOT, "'%s-%s-%s'", dayMonthYear.substring(0, 4), dayMonthYear.substring(4, 6), dayMonthYear.substring(6));
        }

        /** Records that the source span of {@code filter} should be replaced by {@code target}. */
        public void timeFilterFound(SqlNode filter, String target) {
            Pair<Integer, Integer> pos = CalciteParser.getReplacePos(filter, this.originSql);
            this.timeFilterPositions.add(new Pair<>(target, pos));
        }

        /** Returns the recorded (replacement text, (start, end) span) pairs; the live, mutable list. */
        public List<Pair<String, Pair<Integer, Integer>>> getTimeFilterPositions() {
            return this.timeFilterPositions;
        }

        /**
         * Decomposes the non-literal side of a comparison into its {@code +} addends and classifies it:
         * timeType 1 = {@code YEAR(col)}, 2 = {@code MONTH(col) + YEAR(col) * 100},
         * 3 = {@code DAYOFMONTH(col) + MONTH(col) * 100 + YEAR(col) * 10000}, 0 = unsupported.
         * All date parts must reference the same column.
         */
        static class TimeExpression {
            private final List<SqlNode> addedExpression = new ArrayList<>();
            private int timeType = 0;
            private SqlNode colName = null;
            private SqlBasicCall yearExpression = null;
            private SqlBasicCall monthExpression = null;
            private SqlBasicCall dayExpression = null;

            public TimeExpression(SqlBasicCall expression) {
                extraSubExpression(expression);
                if (this.addedExpression.size() == 1) {
                    initYearTime();
                } else if (this.addedExpression.size() == 2) {
                    initMonthYearTime();
                } else if (this.addedExpression.size() == 3) {
                    initDayMonthYearTime();
                }
            }

            /** Classifies a single addend as {@code YEAR(col)}. */
            private void initYearTime() {
                if (!(this.addedExpression.get(0) instanceof SqlBasicCall)) {
                    return;
                }
                this.yearExpression = (SqlBasicCall) this.addedExpression.get(0);
                if (YEAR_FUN.contains(this.yearExpression.getOperator().toString())) {
                    this.timeType = 1;
                    this.colName = this.yearExpression.operand(0);
                }
            }

            /** Classifies two addends as {@code MONTH(col)} and {@code YEAR(col) * 100}. */
            private void initMonthYearTime() {
                if (!(this.addedExpression.get(0) instanceof SqlBasicCall) || !(this.addedExpression.get(1) instanceof SqlBasicCall)) {
                    return;
                }
                for (SqlNode node : this.addedExpression) {
                    SqlBasicCall subExp = (SqlBasicCall) node;
                    if (MONTH_FUN.contains(subExp.getOperator().toString())) {
                        this.monthExpression = subExp;
                        if (this.colName == null) {
                            this.colName = subExp.operand(0);
                        } else if (!this.colName.equalsDeep(subExp.operand(0), Litmus.IGNORE)) {
                            return; // date parts on different columns
                        }
                    } else if (subExp.getOperator().toString().equals("*")) {
                        checkMultiplicationExpression(subExp);
                    }
                }
                if (this.monthExpression != null && this.yearExpression != null) {
                    this.timeType = 2;
                }
            }

            /** Classifies three addends as {@code DAYOFMONTH(col)}, {@code MONTH(col) * 100}, {@code YEAR(col) * 10000}. */
            private void initDayMonthYearTime() {
                if (!(this.addedExpression.get(0) instanceof SqlBasicCall && this.addedExpression.get(1) instanceof SqlBasicCall
                        && this.addedExpression.get(2) instanceof SqlBasicCall)) {
                    return;
                }
                for (SqlNode node : this.addedExpression) {
                    SqlBasicCall subExp = (SqlBasicCall) node;
                    if (DAY_FUN.contains(subExp.getOperator().toString())) {
                        this.dayExpression = subExp;
                        if (this.colName == null) {
                            this.colName = subExp.operand(0);
                        } else if (!this.colName.equalsDeep(subExp.operand(0), Litmus.IGNORE)) {
                            return; // date parts on different columns
                        }
                    } else if (subExp.getOperator().toString().equals("*")) {
                        checkMultiplicationExpression(subExp);
                    }
                }
                if (this.dayExpression != null && this.monthExpression != null && this.yearExpression != null) {
                    this.timeType = 3;
                }
            }

            /**
             * Recognises {@code MONTH(col) * 100} (three-addend form) or a YEAR multiplication (see
             * {@link #isYearMultiplicationExpression}) on the same column, with the literal on either side.
             */
            void checkMultiplicationExpression(SqlBasicCall subExp) {
                SqlBasicCall tmpExpression;
                SqlNode multiplier;
                if (subExp.operand(0) instanceof SqlBasicCall) {
                    tmpExpression = (SqlBasicCall) subExp.operand(0);
                    multiplier = subExp.operand(1);
                } else if (subExp.operand(1) instanceof SqlBasicCall) {
                    tmpExpression = (SqlBasicCall) subExp.operand(1);
                    multiplier = subExp.operand(0);
                } else {
                    return;
                }
                if (!(multiplier instanceof SqlNumericLiteral)) {
                    return;
                }
                boolean sameColumn = this.colName == null || this.colName.equalsDeep(tmpExpression.operand(0), Litmus.IGNORE);
                if (multiplier.toString().equals("100") && this.addedExpression.size() == 3
                        && MONTH_FUN.contains(tmpExpression.getOperator().toString())) {
                    if (sameColumn) {
                        this.colName = tmpExpression.operand(0);
                        this.monthExpression = tmpExpression;
                    }
                } else if (isYearMultiplicationExpression(multiplier, tmpExpression) && sameColumn) {
                    this.colName = tmpExpression.operand(0);
                    this.yearExpression = tmpExpression;
                }
            }

            /** Unparses the matched column with Calcite's default dialect. */
            public String colNameString() {
                return this.colName.toSqlString(CalciteSqlDialect.DEFAULT).toString();
            }

            /** Recursively flattens nested {@code +} calls into {@link #addedExpression}. */
            void extraSubExpression(SqlNode node) {
                if (node instanceof SqlBasicCall && ((SqlBasicCall) node).getOperator().toString().equals("+")) {
                    for (SqlNode subNode : ((SqlBasicCall) node).getOperandList()) {
                        extraSubExpression(subNode);
                    }
                } else {
                    this.addedExpression.add(node);
                }
            }

            /** True for YEAR multiplied by 10000 (three addends) or by 100 (two addends). */
            public boolean isYearMultiplicationExpression(SqlNode multiplier, SqlBasicCall expression) {
                return (multiplier.toString().equals("10000") && this.addedExpression.size() == 3
                        || multiplier.toString().equals("100") && this.addedExpression.size() == 2)
                        && YEAR_FUN.contains(expression.getOperator().toString());
            }

            /** True when this is a YEAR expression and {@code time} is a plain 4-digit integer. */
            public boolean isYear(SqlNumericLiteral time) {
                return this.timeType == 1 && isDigits(time.toString(), 4);
            }

            /** True when this is a month-year expression and {@code time} is a {@code yyyyMM} value with month 1-12. */
            public boolean isMonthYear(SqlNumericLiteral time) {
                if (this.timeType != 2) {
                    return false;
                }
                String text = time.toString();
                return isDigits(text, 6) && isValidMonth(Integer.parseInt(text.substring(4, 6)));
            }

            /** True when this is a day-month-year expression and {@code time} is a real {@code yyyyMMdd} date. */
            public boolean isDayMonthYear(SqlNumericLiteral time) {
                if (this.timeType != 3) {
                    return false;
                }
                String text = time.toString();
                if (!isDigits(text, 8)) {
                    return false;
                }
                int month = Integer.parseInt(text.substring(4, 6));
                return isValidMonth(month)
                        && YearMonth.of(Integer.parseInt(text.substring(0, 4)), month).isValidDay(Integer.parseInt(text.substring(6, 8)));
            }

            /** True when {@code text} has exactly {@code length} ASCII digits (no sign, no decimal point). */
            private static boolean isDigits(String text, int length) {
                if (text.length() != length) {
                    return false;
                }
                for (int i = 0; i < length; i++) {
                    char c = text.charAt(i);
                    if (c < '0' || c > '9') {
                        return false;
                    }
                }
                return true;
            }

            private static boolean isValidMonth(int month) {
                return month >= 1 && month <= 12;
            }
        }
    }
}

