/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.exec.GroupByOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.Partition;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.optimizer.ppr.PartitionPruner;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.PrunedPartitionList;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeFieldDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeNullDesc;
import org.apache.hadoop.hive.ql.plan.GroupByDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.util.StringUtils;

/**
 * Optimizer pass that flags hash-mode group-by operators as "bucket group-by" when every input
 * they read (table, or each pruned partition of a partitioned table) is bucketed or sorted on
 * the group-by columns.
 *
 * <p>The pass walks the operator tree from every top operator and fires
 * {@link BucketGroupByProcessor} on each {@code GBY%RS%GBY%} chain (group-by, reduce sink,
 * group-by). Every other operator is handled by a no-op processor. The only effect of the pass
 * is calling {@code GroupByDesc#setBucketGroup(true)} on qualifying group-by descriptors.
 */
public class GroupByOptimizer
implements Transform {
    private static final Log LOG = LogFactory.getLog(GroupByOptimizer.class.getName());

    /**
     * Walks all top operators of {@code pctx} and applies {@link BucketGroupByProcessor} to each
     * GroupBy -> ReduceSink -> GroupBy chain.
     *
     * @param pctx the parse context whose operator tree is inspected; group-by descriptors in it
     *     may be updated in place
     * @return {@code pctx} itself
     * @throws SemanticException propagated from the walk, e.g. when partition pruning fails
     */
    @Override
    public ParseContext transform(ParseContext pctx) throws SemanticException {
        LinkedHashMap<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
        opRules.put(new RuleRegExp("R1", "GBY%RS%GBY%"), this.getMapAggreSortedGroupbyProc(pctx));
        GroupByOptProcCtx groupByOptimizeCtx = new GroupByOptProcCtx();
        DefaultRuleDispatcher disp = new DefaultRuleDispatcher(this.getDefaultProc(), opRules, groupByOptimizeCtx);
        DefaultGraphWalker ogw = new DefaultGraphWalker(disp);
        ArrayList<Node> topNodes = new ArrayList<Node>();
        topNodes.addAll(pctx.getTopOps().values());
        ogw.startWalking(topNodes, null);
        return pctx;
    }

    /**
     * Returns the processor used for operators that match no rule; it does nothing and
     * returns {@code null}.
     */
    private NodeProcessor getDefaultProc() {
        return new NodeProcessor(){

            @Override
            public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
                return null;
            }
        };
    }

    /** Returns the processor fired on {@code GBY%RS%GBY%} chains. */
    private NodeProcessor getMapAggreSortedGroupbyProc(ParseContext pctx) {
        return new BucketGroupByProcessor(pctx);
    }

    /** Processor context for this pass; it carries no state. */
    public class GroupByOptProcCtx
    implements NodeProcessorCtx {
    }

    /**
     * Decides whether the group-by that precedes a reduce sink can be marked as a bucket
     * group-by, based on the bucketing/sorting metadata of its input tables or partitions.
     */
    public class BucketGroupByProcessor
    implements NodeProcessor {
        protected ParseContext pGraphContext;

        /**
         * @param pGraphContext parse context used to look up input tables, table scans,
         *     table metadata and (cached) pruned partitions
         */
        public BucketGroupByProcessor(ParseContext pGraphContext) {
            this.pGraphContext = pGraphContext;
        }

        /**
         * Inspects the first group-by of the matched {@code GBY%RS%GBY%} chain.
         *
         * @return always {@code null}
         * @throws SemanticException if partition pruning for an input table fails
         */
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            // The rule matched GBY, RS, GBY at the top of the walk stack; stack.size() - 3 is the
            // group-by that feeds the reduce sink.
            GroupByOperator op = (GroupByOperator)stack.get(stack.size() - 3);
            this.checkBucketGroupBy(op);
            return null;
        }

        /**
         * Sets {@code bucketGroup} on {@code curr}'s descriptor when all of the following hold:
         * its mode is {@code HASH}; its input tables are known; every group-by key reduces to
         * plain columns (through deterministic functions, field accesses and constants); and
         * every input table, or every pruned partition of a partitioned input, is bucketed or
         * sorted compatibly with those columns. Otherwise the descriptor is left untouched.
         *
         * @throws SemanticException if partition pruning for an input table fails
         */
        private void checkBucketGroupBy(GroupByOperator curr) throws SemanticException {
            GroupByDesc desc = (GroupByDesc)curr.getConf();
            if (desc.getMode() != GroupByDesc.Mode.HASH) {
                return;
            }
            Set<String> tblNames = this.pGraphContext.getGroupOpToInputTables().get(curr);
            if (tblNames == null || tblNames.isEmpty()) {
                return;
            }
            List<String> groupByCols = this.collectGroupByColumns(desc);
            if (groupByCols == null || groupByCols.isEmpty()) {
                return;
            }
            for (String alias : tblNames) {
                if (!this.inputMatchesGroupBy(alias, groupByCols)) {
                    return;
                }
            }
            desc.setBucketGroup(true);
        }

        /**
         * Flattens the group-by key expressions into the distinct column names they reference,
         * in first-seen order. Duplicates are dropped so that keys such as {@code a, a + 1}
         * count as the single column {@code a} in the size/prefix comparisons below.
         *
         * @return the distinct referenced columns, or {@code null} if some key is a
         *     non-deterministic function or an expression kind this pass does not handle
         */
        private List<String> collectGroupByColumns(GroupByDesc desc) {
            LinkedHashSet<String> cols = new LinkedHashSet<String>();
            LinkedList<ExprNodeDesc> pending = new LinkedList<ExprNodeDesc>();
            pending.addAll(desc.getKeys());
            while (!pending.isEmpty()) {
                ExprNodeDesc node = pending.removeFirst();
                if (node instanceof ExprNodeColumnDesc) {
                    cols.addAll(node.getCols());
                } else if (node instanceof ExprNodeConstantDesc || node instanceof ExprNodeNullDesc) {
                    // Constants contribute no columns.
                    continue;
                } else if (node instanceof ExprNodeFieldDesc) {
                    pending.addFirst(((ExprNodeFieldDesc)node).getDesc());
                } else if (node instanceof ExprNodeGenericFuncDesc) {
                    ExprNodeGenericFuncDesc udfNode = (ExprNodeGenericFuncDesc)node;
                    GenericUDF udf = udfNode.getGenericUDF();
                    if (!FunctionRegistry.isDeterministic(udf)) {
                        return null;
                    }
                    // Expand the arguments in place, preserving their order.
                    pending.addAll(0, udfNode.getChildExprs());
                } else {
                    return null;
                }
            }
            return new ArrayList<String>(cols);
        }

        /**
         * Checks one input of the group-by. The key is looked up in
         * {@code getTopOps()}; presumably it is a table alias — TODO confirm.
         *
         * @return {@code true} if the input is a table scan with known metadata whose table
         *     (unpartitioned) or every pruned partition (partitioned) matches {@code groupByCols}
         * @throws SemanticException if partition pruning fails
         */
        private boolean inputMatchesGroupBy(String alias, List<String> groupByCols) throws SemanticException {
            Operator<? extends Serializable> topOp = this.pGraphContext.getTopOps().get(alias);
            if (!(topOp instanceof TableScanOperator)) {
                return false;
            }
            TableScanOperator ts = (TableScanOperator)topOp;
            Table destTable = this.pGraphContext.getTopToTable().get(ts);
            if (destTable == null) {
                return false;
            }
            if (!destTable.isPartitioned()) {
                return this.matchBucketOrSortedColumns(groupByCols, destTable.getBucketCols(),
                        Utilities.getColumnNamesFromSortCols(destTable.getSortCols()));
            }
            PrunedPartitionList partsList = this.prunedPartitionsFor(alias, ts, destTable);
            ArrayList<Partition> parts = new ArrayList<Partition>();
            parts.addAll(partsList.getConfirmedPartns());
            parts.addAll(partsList.getUnknownPartns());
            for (Partition part : parts) {
                if (!this.matchBucketOrSortedColumns(groupByCols, part.getBucketCols(), part.getSortColNames())) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Returns the pruned partition list cached for {@code ts}, computing it with
         * {@link PartitionPruner#prune} and caching it on a miss.
         *
         * @throws SemanticException wrapping any {@link HiveException} raised by pruning
         */
        private PrunedPartitionList prunedPartitionsFor(String alias, TableScanOperator ts, Table destTable) throws SemanticException {
            try {
                PrunedPartitionList partsList = this.pGraphContext.getOpToPartList().get(ts);
                if (partsList == null) {
                    partsList = PartitionPruner.prune(destTable, this.pGraphContext.getOpToPartPruner().get(ts), this.pGraphContext.getConf(), alias, this.pGraphContext.getPrunedPartitions());
                    this.pGraphContext.getOpToPartList().put(ts, partsList);
                }
                return partsList;
            }
            catch (HiveException e) {
                LOG.error(StringUtils.stringifyException(e));
                throw new SemanticException(e.getMessage(), e);
            }
        }

        /**
         * Decides whether data laid out with the given bucket/sort columns matches the
         * (distinct) group-by columns.
         *
         * <p>If there are no sort columns, the bucket columns must be exactly the group-by
         * columns. Otherwise the group-by columns must be exactly the leading
         * {@code groupByCols.size()} sort columns, in any order. For example, sorted by
         * {@code (a, b, c)} accepts {@code {a}} and {@code {b, a}}, but rejects {@code {b}},
         * {@code {b, c}}, and any column that is not in the sort order at all.
         *
         * @param groupByCols distinct group-by column names
         * @param bucketCols bucket column names; may be {@code null}
         * @param sortCols sort column names in sort order; may be {@code null}
         */
        private boolean matchBucketOrSortedColumns(List<String> groupByCols, List<String> bucketCols, List<String> sortCols) {
            if (sortCols == null || sortCols.isEmpty()) {
                return this.matchBucketColumns(groupByCols, bucketCols);
            }
            int num = groupByCols.size();
            if (num == 0 || sortCols.size() < num) {
                return false;
            }
            for (String col : groupByCols) {
                int idx = sortCols.indexOf(col);
                // indexOf returns -1 for a column outside the sort order; that must not match.
                if (idx < 0 || idx >= num) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Returns {@code true} if the (distinct) group-by columns are exactly the bucket columns,
         * in any order.
         */
        private boolean matchBucketColumns(List<String> grpCols, List<String> tblBucketCols) {
            if (tblBucketCols == null || tblBucketCols.isEmpty() || grpCols.isEmpty() || grpCols.size() != tblBucketCols.size()) {
                return false;
            }
            return tblBucketCols.containsAll(grpCols);
        }
    }
}

