/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.legacy.rewriter.join;

import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLQueryExpr;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource;
import com.alibaba.druid.sql.ast.statement.SQLSelectQuery;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.function.Consumer;
import org.opensearch.sql.legacy.esdomain.LocalClusterState;
import org.opensearch.sql.legacy.esdomain.mapping.FieldMappings;
import org.opensearch.sql.legacy.rewriter.RewriteRule;
import org.opensearch.sql.legacy.rewriter.matchtoterm.VerificationException;
import org.opensearch.sql.legacy.utils.StringUtils;
import shaded.com.google.common.collect.ArrayListMultimap;

/**
 * Rewrite rule applied to JOIN queries. It makes sure every table source carries an alias
 * (generating one when the query did not supply it) and qualifies column identifiers with the
 * alias of the table they belong to.
 *
 * <p>The rule keeps a running counter for generated aliases in instance state, so an instance
 * is not safe to share between threads without external synchronization.
 */
public class JoinRewriteRule
implements RewriteRule<SQLQueryExpr> {
    private static final String DOT = ".";

    /** Suffix appended to the next generated alias; incremented each time an alias is created. */
    private int aliasSuffix = 0;

    /** Source of the field mappings used to decide which table a bare column name belongs to. */
    private final LocalClusterState clusterState;

    /**
     * @param clusterState cluster state queried for the field mappings of each joined table
     */
    public JoinRewriteRule(LocalClusterState clusterState) {
        this.clusterState = clusterState;
    }

    /**
     * @param root parsed query
     * @return {@code true} when the query's FROM clause is a join other than a comma join
     */
    @Override
    public boolean match(SQLQueryExpr root) {
        return this.isJoin(root);
    }

    /**
     * Checks whether the top-level select is a MySQL query block whose FROM clause is a
     * {@link SQLJoinTableSource} with a join type other than {@code COMMA}.
     */
    private boolean isJoin(SQLQueryExpr sqlExpr) {
        SQLSelectQuery sqlSelectQuery = sqlExpr.getSubQuery().getQuery();
        if (!(sqlSelectQuery instanceof MySqlSelectQueryBlock)) {
            return false;
        }
        MySqlSelectQueryBlock query = (MySqlSelectQueryBlock)sqlSelectQuery;
        return query.getFrom() instanceof SQLJoinTableSource
            && ((SQLJoinTableSource)query.getFrom()).getJoinType() != SQLJoinTableSource.JoinType.COMMA;
    }

    /**
     * Rewrites the query in place in two passes: first every table source is given an alias and
     * its mapped field names are recorded; then every column identifier is prefixed with (or has
     * its table-name prefix replaced by) the alias of the owning table.
     *
     * @param root parsed query, mutated in place
     * @throws VerificationException if the same table is referenced twice with no alias on either
     *     reference, or if a bare column name is mapped in more than one of the visited tables
     */
    @Override
    public void rewrite(SQLQueryExpr root) {
        ArrayListMultimap<String, Table> tableByFieldName = ArrayListMultimap.create();
        HashMap<String, String> tableNameToAlias = new HashMap<String, String>();
        // Aliases generated by this rule for table sources that had none in the query.
        HashSet<String> generatedAliases = new HashSet<String>();

        this.visitTable(root, tableExpr -> {
            // Strip spaces and keep only the part before the first '/'.
            String tableName = tableExpr.getExpr().toString().replace(" ", "").split("/")[0];
            if (tableExpr.getAlias() == null) {
                String alias = this.createAlias(tableName);
                tableExpr.setAlias(alias);
                generatedAliases.add(alias);
            }
            Table table = new Table(tableName, tableExpr.getAlias());
            tableNameToAlias.put(table.getName(), table.getAlias());
            FieldMappings fieldMappings = (FieldMappings)this.clusterState.getFieldMappings(new String[]{tableName}).firstMapping();
            fieldMappings.flat((fieldName, type) -> tableByFieldName.put(fieldName, table));
        });

        // Every visited table source refers to the same table name.
        if (tableNameToAlias.size() == 1) {
            String tableName = tableNameToAlias.keySet().iterator().next();
            if (generatedAliases.size() == 2) {
                // Two references without an alias: a table-name prefix cannot tell them apart.
                throw new VerificationException(StringUtils.format("Not unique table/alias: [%s]", tableName));
            }
            if (generatedAliases.size() == 1) {
                // Only one reference lacked an alias, so the bare table name maps to the generated one.
                tableNameToAlias.put(tableName, generatedAliases.iterator().next());
            }
        }

        this.visitColumnName(root, idExpr -> {
            String columnName = idExpr.getName();
            Collection<Table> tables = tableByFieldName.get(columnName);
            if (tables.size() > 1) {
                throw new VerificationException(StringUtils.format("Field name [%s] is ambiguous", columnName));
            }
            if (tables.isEmpty()) {
                // Not a mapped field name as written: if it is prefixed with a table name,
                // swap that leading prefix (and only the leading prefix) for the table's alias.
                tableNameToAlias.forEach((tableName, alias) -> {
                    if (columnName.startsWith(tableName + DOT)) {
                        idExpr.setName(JoinRewriteRule.replaceTablePrefix(columnName, tableName, alias));
                    }
                });
            } else {
                // Mapped in exactly one table: qualify the bare name with that table's alias.
                Table table = tables.iterator().next();
                idExpr.setName(String.join(DOT, table.getAlias(), columnName));
            }
        });
    }

    /**
     * Replaces the leading {@code tableName + "."} of {@code columnName} with {@code alias + "."}.
     * The caller must have checked that {@code columnName} starts with that prefix. Later
     * occurrences of the same text inside the column name are left untouched.
     */
    private static String replaceTablePrefix(String columnName, String tableName, String alias) {
        return alias + DOT + columnName.substring(tableName.length() + DOT.length());
    }

    /** Invokes {@code visit} for each {@link SQLExprTableSource} when the visitor finishes it. */
    private void visitTable(SQLQueryExpr root, final Consumer<SQLExprTableSource> visit) {
        root.accept(new MySqlASTVisitorAdapter() {

            @Override
            public void endVisit(SQLExprTableSource tableExpr) {
                visit.accept(tableExpr);
            }
        });
    }

    /**
     * Invokes {@code visit} for each {@link SQLIdentifierExpr} reached by the visitor.
     */
    private void visitColumnName(SQLQueryExpr expr, final Consumer<SQLIdentifierExpr> visit) {
        expr.accept(new MySqlASTVisitorAdapter() {

            @Override
            public boolean visit(SQLExprTableSource x) {
                // NOTE(review): presumably returning false stops descent into the table source so
                // table-name identifiers are not handled as columns - confirm against Druid's visitor contract.
                return false;
            }

            @Override
            public void endVisit(SQLIdentifierExpr idExpr) {
                visit.accept(idExpr);
            }
        });
    }

    /** Builds a generated alias of the form {@code <name>_<counter>}. */
    private String createAlias(String alias) {
        return String.format("%s_%d", alias, this.next());
    }

    /** Returns the current alias counter value and then increments it. */
    private int next() {
        return this.aliasSuffix++;
    }

    /** Immutable pairing of a table name with the alias it carries in the query. */
    private static class Table {
        private final String name;
        private final String alias;

        public String getName() {
            return this.name;
        }

        public String getAlias() {
            return this.alias;
        }

        Table(String name, String alias) {
            this.name = name;
            this.alias = alias;
        }

        @Override
        public String toString() {
            return this.name + "-->" + this.alias;
        }
    }
}

