/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.hive.common.ObjectPair;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.metastore.api.Order;
import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
import org.apache.hadoop.hive.ql.exec.FilterOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.SMBMapJoinOperator;
import org.apache.hadoop.hive.ql.exec.SelectOperator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.io.AcidUtils;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.metadata.Partition;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.PrunedPartitionList;
import org.apache.hadoop.hive.ql.parse.SemanticAnalyzer;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.FileSinkDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hadoop.hive.ql.plan.SMBJoinDesc;
import org.apache.hadoop.hive.ql.plan.SelectDesc;
import org.apache.hadoop.hive.ql.plan.TableScanDesc;
import org.apache.hadoop.hive.shims.ShimLoader;

public class BucketingSortingReduceSinkOptimizer
implements Transform {
    @Override
    public ParseContext transform(ParseContext pctx) throws SemanticException {
        LinkedHashMap<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
        opRules.put(new RuleRegExp("R1", ReduceSinkOperator.getOperatorName() + "%" + SelectOperator.getOperatorName() + "%" + FileSinkOperator.getOperatorName() + "%"), this.getBucketSortReduceSinkProc(pctx));
        DefaultRuleDispatcher disp = new DefaultRuleDispatcher(this.getDefaultProc(), opRules, null);
        DefaultGraphWalker ogw = new DefaultGraphWalker(disp);
        ArrayList<Node> topNodes = new ArrayList<Node>();
        topNodes.addAll(pctx.getTopOps().values());
        ogw.startWalking(topNodes, null);
        return pctx;
    }

    private NodeProcessor getDefaultProc() {
        return new NodeProcessor(){

            @Override
            public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
                return null;
            }
        };
    }

    private NodeProcessor getBucketSortReduceSinkProc(ParseContext pctx) {
        return new BucketSortReduceSinkProcessor(pctx);
    }

    public class BucketSortReduceSinkProcessor
    implements NodeProcessor {
        protected ParseContext pGraphContext;

        public BucketSortReduceSinkProcessor(ParseContext pGraphContext) {
            this.pGraphContext = pGraphContext;
        }

        private List<Integer> getBucketPositions(List<String> tabBucketCols, List<FieldSchema> tabCols) {
            ArrayList<Integer> posns = new ArrayList<Integer>();
            block0: for (String bucketCol : tabBucketCols) {
                int pos = 0;
                for (FieldSchema tabCol : tabCols) {
                    if (bucketCol.equals(tabCol.getName())) {
                        posns.add(pos);
                        continue block0;
                    }
                    ++pos;
                }
            }
            return posns;
        }

        private ObjectPair<List<Integer>, List<Integer>> getSortPositionsOrder(List<Order> tabSortCols, List<FieldSchema> tabCols) {
            ArrayList<Integer> sortPositions = new ArrayList<Integer>();
            ArrayList<Integer> sortOrders = new ArrayList<Integer>();
            block0: for (Order sortCol : tabSortCols) {
                int pos = 0;
                for (FieldSchema tabCol : tabCols) {
                    if (sortCol.getCol().equals(tabCol.getName())) {
                        sortPositions.add(pos);
                        sortOrders.add(sortCol.getOrder());
                        continue block0;
                    }
                    ++pos;
                }
            }
            return new ObjectPair(sortPositions, sortOrders);
        }

        private boolean checkPartition(Partition partition, List<Integer> bucketPositionsDest, List<Integer> sortPositionsDest, List<Integer> sortOrderDest, int numBucketsDest) {
            int numBuckets = partition.getBucketCount();
            if (numBucketsDest != numBuckets) {
                return false;
            }
            List<Integer> partnBucketPositions = this.getBucketPositions(partition.getBucketCols(), partition.getTable().getCols());
            ObjectPair<List<Integer>, List<Integer>> partnSortPositionsOrder = this.getSortPositionsOrder(partition.getSortCols(), partition.getTable().getCols());
            return bucketPositionsDest.equals(partnBucketPositions) && sortPositionsDest.equals(partnSortPositionsOrder.getFirst()) && sortOrderDest.equals(partnSortPositionsOrder.getSecond());
        }

        private boolean checkTable(Table table, List<Integer> bucketPositionsDest, List<Integer> sortPositionsDest, List<Integer> sortOrderDest, int numBucketsDest) {
            int numBuckets = table.getNumBuckets();
            if (numBucketsDest != numBuckets) {
                return false;
            }
            List<Integer> tableBucketPositions = this.getBucketPositions(table.getBucketCols(), table.getCols());
            ObjectPair<List<Integer>, List<Integer>> tableSortPositionsOrder = this.getSortPositionsOrder(table.getSortCols(), table.getCols());
            return bucketPositionsDest.equals(tableBucketPositions) && sortPositionsDest.equals(tableSortPositionsOrder.getFirst()) && sortOrderDest.equals(tableSortPositionsOrder.getSecond());
        }

        private void storeBucketPathMapping(TableScanOperator tsOp, FileStatus[] srcs) {
            HashMap<String, Integer> bucketFileNameMapping = new HashMap<String, Integer>();
            for (int pos = 0; pos < srcs.length; ++pos) {
                if (ShimLoader.getHadoopShims().isDirectory(srcs[pos])) {
                    throw new RuntimeException("Was expecting '" + srcs[pos].getPath() + "' to be bucket file.");
                }
                bucketFileNameMapping.put(srcs[pos].getPath().getName(), pos);
            }
            ((TableScanDesc)tsOp.getConf()).setBucketFileNameMapping(bucketFileNameMapping);
        }

        private void removeReduceSink(ReduceSinkOperator rsOp, TableScanOperator tsOp, FileSinkOperator fsOp, FileStatus[] srcs) {
            if (srcs == null) {
                return;
            }
            this.removeReduceSink(rsOp, tsOp, fsOp);
            this.storeBucketPathMapping(tsOp, srcs);
        }

        private void removeReduceSink(ReduceSinkOperator rsOp, TableScanOperator tsOp, FileSinkOperator fsOp) {
            Operator<OperatorDesc> parRSOp = rsOp.getParentOperators().get(0);
            parRSOp.getChildOperators().set(0, fsOp);
            fsOp.getParentOperators().set(0, parRSOp);
            ((FileSinkDesc)fsOp.getConf()).setMultiFileSpray(false);
            ((FileSinkDesc)fsOp.getConf()).setTotalFiles(1);
            ((FileSinkDesc)fsOp.getConf()).setNumFiles(1);
            ((FileSinkDesc)fsOp.getConf()).setRemovedReduceSinkBucketSort(true);
            tsOp.setUseBucketizedHiveInputFormat(true);
        }

        private int findColumnPosition(List<FieldSchema> cols, String colName) {
            int pos = 0;
            for (FieldSchema col : cols) {
                if (colName.equals(col.getName())) {
                    return pos;
                }
                ++pos;
            }
            return -1;
        }

        private boolean validateSMBJoinKeys(SMBJoinDesc smbJoinDesc, List<ExprNodeColumnDesc> sourceTableBucketCols, List<ExprNodeColumnDesc> sourceTableSortCols, List<Integer> sortOrder) {
            if (!sourceTableBucketCols.equals(sourceTableSortCols)) {
                return false;
            }
            Byte[] tagOrder = smbJoinDesc.getTagOrder();
            Map<Byte, List<Integer>> retainList = smbJoinDesc.getRetainList();
            int totalNumberColumns = 0;
            for (Byte tag : tagOrder) {
                totalNumberColumns += retainList.get(tag).size();
            }
            byte[] columnTableMappings = new byte[totalNumberColumns];
            int[] columnNumberMappings = new int[totalNumberColumns];
            int currentColumnPosition = 0;
            for (Byte tag : tagOrder) {
                int pos = 0;
                while (pos < retainList.get(tag).size()) {
                    columnTableMappings[currentColumnPosition] = tag;
                    columnNumberMappings[currentColumnPosition] = pos++;
                    ++currentColumnPosition;
                }
            }
            List<String> outputColumnNames = smbJoinDesc.getOutputColumnNames();
            byte tableTag = -1;
            int[] columnNumbersExprList = new int[sourceTableBucketCols.size()];
            int currentColPosition = 0;
            for (ExprNodeColumnDesc bucketCol : sourceTableBucketCols) {
                String colName = bucketCol.getColumn();
                int colNumber = outputColumnNames.indexOf(colName);
                if (colNumber < 0) {
                    return false;
                }
                if (tableTag < 0) {
                    tableTag = columnTableMappings[colNumber];
                } else if (tableTag != columnTableMappings[colNumber]) {
                    return false;
                }
                columnNumbersExprList[currentColPosition++] = columnNumberMappings[colNumber];
            }
            List<ExprNodeDesc> allExprs = smbJoinDesc.getExprs().get(tableTag);
            List<ExprNodeDesc> keysSelectedTable = smbJoinDesc.getKeys().get(tableTag);
            currentColPosition = 0;
            for (ExprNodeDesc keySelectedTable : keysSelectedTable) {
                if (!(keySelectedTable instanceof ExprNodeColumnDesc)) {
                    return false;
                }
                if (allExprs.get(columnNumbersExprList[currentColPosition++]).isSame(keySelectedTable)) continue;
                return false;
            }
            return true;
        }

        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            TableScanOperator tso;
            FileSinkOperator fsOp = (FileSinkOperator)nd;
            ReduceSinkOperator rsOp = (ReduceSinkOperator)fsOp.getParentOperators().get(0).getParentOperators().get(0);
            List<ReduceSinkOperator> rsOps = this.pGraphContext.getReduceSinkOperatorsAddedByEnforceBucketingSorting();
            if (rsOps != null && !rsOps.contains(rsOp)) {
                return null;
            }
            if (this.pGraphContext.getContext().getAcidOperation() == AcidUtils.Operation.UPDATE || this.pGraphContext.getContext().getAcidOperation() == AcidUtils.Operation.DELETE) {
                return null;
            }
            if (stack.get(0) instanceof TableScanOperator && SemanticAnalyzer.isAcidTable(((TableScanDesc)(tso = (TableScanOperator)stack.get(0)).getConf()).getTableMetadata())) {
                return null;
            }
            if (((FileSinkDesc)fsOp.getConf()).getDynPartCtx() != null) {
                return null;
            }
            for (ExprNodeDesc keyCol : ((ReduceSinkDesc)rsOp.getConf()).getKeyCols()) {
                if (keyCol instanceof ExprNodeColumnDesc) continue;
                return null;
            }
            Table destTable = ((FileSinkDesc)fsOp.getConf()).getTable();
            if (destTable == null) {
                return null;
            }
            int numBucketsDestination = destTable.getNumBuckets();
            List<Integer> bucketPositions = this.getBucketPositions(destTable.getBucketCols(), destTable.getCols());
            ObjectPair<List<Integer>, List<Integer>> sortOrderPositions = this.getSortPositionsOrder(destTable.getSortCols(), destTable.getCols());
            List sortPositions = (List)sortOrderPositions.getFirst();
            List sortOrder = (List)sortOrderPositions.getSecond();
            boolean useBucketSortPositions = true;
            Operator op = rsOp;
            ArrayList<ExprNodeColumnDesc> sourceTableBucketCols = new ArrayList<ExprNodeColumnDesc>();
            ArrayList<ExprNodeColumnDesc> sourceTableSortCols = new ArrayList<ExprNodeColumnDesc>();
            op = op.getParentOperators().get(0);
            while (op instanceof TableScanOperator || op instanceof FilterOperator || op instanceof SelectOperator || op instanceof SMBMapJoinOperator) {
                String colName;
                if (op instanceof SMBMapJoinOperator) {
                    if (!bucketPositions.equals(sortPositions)) {
                        return null;
                    }
                    SMBMapJoinOperator smbOp = (SMBMapJoinOperator)op;
                    SMBJoinDesc smbJoinDesc = (SMBJoinDesc)smbOp.getConf();
                    int posBigTable = smbJoinDesc.getPosBigTable();
                    List<ExprNodeDesc> keysBigTable = smbJoinDesc.getKeys().get((byte)posBigTable);
                    if (keysBigTable.size() != bucketPositions.size()) {
                        return null;
                    }
                    if (!this.validateSMBJoinKeys(smbJoinDesc, sourceTableBucketCols, sourceTableSortCols, sortOrder)) {
                        return null;
                    }
                    sourceTableBucketCols.clear();
                    sourceTableSortCols.clear();
                    useBucketSortPositions = false;
                    for (ExprNodeDesc keyBigTable : keysBigTable) {
                        if (!(keyBigTable instanceof ExprNodeColumnDesc)) {
                            return null;
                        }
                        sourceTableBucketCols.add((ExprNodeColumnDesc)keyBigTable);
                        sourceTableSortCols.add((ExprNodeColumnDesc)keyBigTable);
                    }
                    op = op.getParentOperators().get(posBigTable);
                    continue;
                }
                if (op instanceof TableScanOperator) {
                    assert (!useBucketSortPositions);
                    TableScanOperator ts = (TableScanOperator)op;
                    Table srcTable = ((TableScanDesc)ts.getConf()).getTableMetadata();
                    ArrayList<Integer> newBucketPositions = new ArrayList<Integer>();
                    for (int pos = 0; pos < bucketPositions.size(); ++pos) {
                        ExprNodeColumnDesc col = (ExprNodeColumnDesc)sourceTableBucketCols.get(pos);
                        colName = col.getColumn();
                        int bucketPos = this.findColumnPosition(srcTable.getCols(), colName);
                        if (bucketPos < 0) {
                            return null;
                        }
                        newBucketPositions.add(bucketPos);
                    }
                    ArrayList<Integer> newSortPositions = new ArrayList<Integer>();
                    for (int pos = 0; pos < sortPositions.size(); ++pos) {
                        ExprNodeColumnDesc col = (ExprNodeColumnDesc)sourceTableSortCols.get(pos);
                        String colName2 = col.getColumn();
                        int sortPos = this.findColumnPosition(srcTable.getCols(), colName2);
                        if (sortPos < 0) {
                            return null;
                        }
                        newSortPositions.add(sortPos);
                    }
                    if (srcTable.isPartitioned()) {
                        PrunedPartitionList prunedParts = this.pGraphContext.getPrunedPartitions(srcTable.getTableName(), ts);
                        List<Partition> partitions = prunedParts.getNotDeniedPartns();
                        if (partitions == null || partitions.isEmpty() || partitions.size() > 1) {
                            return null;
                        }
                        for (Partition partition : partitions) {
                            if (this.checkPartition(partition, newBucketPositions, newSortPositions, sortOrder, numBucketsDestination)) continue;
                            return null;
                        }
                        this.removeReduceSink(rsOp, (TableScanOperator)op, fsOp, partitions.get(0).getSortedPaths());
                        return null;
                    }
                    if (!this.checkTable(srcTable, newBucketPositions, newSortPositions, sortOrder, numBucketsDestination)) {
                        return null;
                    }
                    this.removeReduceSink(rsOp, (TableScanOperator)op, fsOp, srcTable.getSortedPaths());
                    return null;
                }
                if (op instanceof SelectOperator) {
                    ExprNodeDesc selectColList;
                    SelectOperator selectOp = (SelectOperator)op;
                    SelectDesc selectDesc = (SelectDesc)selectOp.getConf();
                    if (!useBucketSortPositions) {
                        int colPos;
                        bucketPositions.clear();
                        sortPositions.clear();
                        List<String> outputColumnNames = selectDesc.getOutputColumnNames();
                        for (ExprNodeColumnDesc col : sourceTableBucketCols) {
                            colName = col.getColumn();
                            colPos = outputColumnNames.indexOf(colName);
                            if (colPos < 0) {
                                return null;
                            }
                            bucketPositions.add(colPos);
                        }
                        for (ExprNodeColumnDesc col : sourceTableSortCols) {
                            colName = col.getColumn();
                            colPos = outputColumnNames.indexOf(colName);
                            if (colPos < 0) {
                                return null;
                            }
                            sortPositions.add(colPos);
                        }
                    }
                    sourceTableBucketCols.clear();
                    sourceTableSortCols.clear();
                    for (int pos : bucketPositions) {
                        selectColList = selectDesc.getColList().get(pos);
                        if (!(selectColList instanceof ExprNodeColumnDesc)) {
                            return null;
                        }
                        sourceTableBucketCols.add((ExprNodeColumnDesc)selectColList);
                    }
                    for (int pos : sortPositions) {
                        selectColList = selectDesc.getColList().get(pos);
                        if (!(selectColList instanceof ExprNodeColumnDesc)) {
                            return null;
                        }
                        sourceTableSortCols.add((ExprNodeColumnDesc)selectColList);
                    }
                    useBucketSortPositions = false;
                }
                op = op.getParentOperators().get(0);
            }
            return null;
        }
    }
}

