/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Pair;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;

public class HiveJoinToMultiJoinRule
extends RelOptRule {
    public static final HiveJoinToMultiJoinRule INSTANCE = new HiveJoinToMultiJoinRule(Join.class);

    public HiveJoinToMultiJoinRule(Class<? extends Join> clazz) {
        super(HiveJoinToMultiJoinRule.operand(clazz, (RelOptRuleOperand)HiveJoinToMultiJoinRule.operand(RelNode.class, (RelOptRuleOperandChildren)HiveJoinToMultiJoinRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[]{HiveJoinToMultiJoinRule.operand(RelNode.class, (RelOptRuleOperandChildren)HiveJoinToMultiJoinRule.any())}));
    }

    public void onMatch(RelOptRuleCall call) {
        boolean combinable;
        Join join = (Join)call.rel(0);
        RelNode left = call.rel(1);
        RelNode right = call.rel(2);
        RexBuilder rexBuilder = join.getCluster().getRexBuilder();
        if (join.getJoinType() != JoinRelType.INNER) {
            return;
        }
        ArrayList newInputs = Lists.newArrayList();
        ArrayList newJoinFilters = Lists.newArrayList();
        newJoinFilters.add(join.getCondition());
        ArrayList joinSpecs = Lists.newArrayList();
        ArrayList projFields = Lists.newArrayList();
        if (left instanceof Join || left instanceof MultiJoin) {
            RexNode leftCondition = left instanceof Join ? ((Join)left).getCondition() : ((MultiJoin)left).getJoinFilter();
            combinable = HiveJoinToMultiJoinRule.isCombinablePredicate(join, join.getCondition(), leftCondition);
            if (combinable) {
                newJoinFilters.add(leftCondition);
                for (RelNode input : left.getInputs()) {
                    projFields.add(null);
                    joinSpecs.add(Pair.of((Object)JoinRelType.INNER, (Object)null));
                    newInputs.add(input);
                }
            } else {
                projFields.add(null);
                joinSpecs.add(Pair.of((Object)JoinRelType.INNER, (Object)null));
                newInputs.add(left);
            }
        } else {
            projFields.add(null);
            joinSpecs.add(Pair.of((Object)JoinRelType.INNER, (Object)null));
            newInputs.add(left);
        }
        if (right instanceof Join || right instanceof MultiJoin) {
            RexNode rightCondition = right instanceof Join ? this.shiftRightFilter(join, left, right, ((Join)right).getCondition()) : this.shiftRightFilter(join, left, right, ((MultiJoin)right).getJoinFilter());
            combinable = HiveJoinToMultiJoinRule.isCombinablePredicate(join, join.getCondition(), rightCondition);
            if (combinable) {
                newJoinFilters.add(rightCondition);
                for (RelNode input : right.getInputs()) {
                    projFields.add(null);
                    joinSpecs.add(Pair.of((Object)JoinRelType.INNER, (Object)null));
                    newInputs.add(input);
                }
            } else {
                projFields.add(null);
                joinSpecs.add(Pair.of((Object)JoinRelType.INNER, (Object)null));
                newInputs.add(right);
            }
        } else {
            projFields.add(null);
            joinSpecs.add(Pair.of((Object)JoinRelType.INNER, (Object)null));
            newInputs.add(right);
        }
        if (newJoinFilters.size() == 1) {
            return;
        }
        RexNode newCondition = RexUtil.flatten((RexBuilder)rexBuilder, (RexNode)RexUtil.composeConjunction((RexBuilder)rexBuilder, (Iterable)newJoinFilters, (boolean)false));
        ImmutableMap<Integer, ImmutableIntList> newJoinFieldRefCountsMap = this.addOnJoinFieldRefCounts(newInputs, join.getRowType().getFieldCount(), newCondition);
        List<RexNode> newPostJoinFilters = this.combinePostJoinFilters(join, left, right);
        MultiJoin multiJoin = new MultiJoin(join.getCluster(), (List)newInputs, newCondition, join.getRowType(), false, Pair.right((List)joinSpecs), Pair.left((List)joinSpecs), (List)projFields, newJoinFieldRefCountsMap, RexUtil.composeConjunction((RexBuilder)rexBuilder, newPostJoinFilters, (boolean)true));
        call.transformTo((RelNode)multiJoin);
    }

    private static boolean isCombinablePredicate(Join join, RexNode condition, RexNode otherCondition) {
        HiveCalciteUtil.JoinPredicateInfo joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(join, condition);
        HiveCalciteUtil.JoinPredicateInfo otherJoinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(join, otherCondition);
        if (joinPredInfo.getProjsFromLeftPartOfJoinKeysInJoinSchema().equals(otherJoinPredInfo.getProjsFromLeftPartOfJoinKeysInJoinSchema())) {
            return false;
        }
        return !joinPredInfo.getProjsFromRightPartOfJoinKeysInJoinSchema().equals(otherJoinPredInfo.getProjsFromRightPartOfJoinKeysInJoinSchema());
    }

    private RexNode shiftRightFilter(Join joinRel, RelNode left, RelNode right, RexNode rightFilter) {
        if (rightFilter == null) {
            return null;
        }
        int nFieldsOnLeft = left.getRowType().getFieldList().size();
        int nFieldsOnRight = right.getRowType().getFieldList().size();
        int[] adjustments = new int[nFieldsOnRight];
        for (int i = 0; i < nFieldsOnRight; ++i) {
            adjustments[i] = nFieldsOnLeft;
        }
        rightFilter = (RexNode)rightFilter.accept((RexVisitor)new RelOptUtil.RexInputConverter(joinRel.getCluster().getRexBuilder(), right.getRowType().getFieldList(), joinRel.getRowType().getFieldList(), adjustments));
        return rightFilter;
    }

    private ImmutableMap<Integer, ImmutableIntList> addOnJoinFieldRefCounts(List<RelNode> multiJoinInputs, int nTotalFields, RexNode joinCondition) {
        int[] joinCondRefCounts = new int[nTotalFields];
        joinCondition.accept((RexVisitor)new InputReferenceCounter(joinCondRefCounts));
        HashMap refCountsMap = Maps.newHashMap();
        int nInputs = multiJoinInputs.size();
        int currInput = -1;
        int startField = 0;
        int nFields = 0;
        for (int i = 0; i < nTotalFields; ++i) {
            if (joinCondRefCounts[i] == 0) continue;
            while (i >= startField + nFields) {
                startField += nFields;
                assert (++currInput < nInputs);
                nFields = multiJoinInputs.get(currInput).getRowType().getFieldCount();
            }
            int[] refCounts = (int[])refCountsMap.get(currInput);
            if (refCounts == null) {
                refCounts = new int[nFields];
                refCountsMap.put(currInput, refCounts);
            }
            int n = i - startField;
            refCounts[n] = refCounts[n] + joinCondRefCounts[i];
        }
        ImmutableMap.Builder builder = ImmutableMap.builder();
        for (Map.Entry entry : refCountsMap.entrySet()) {
            builder.put(entry.getKey(), (Object)ImmutableIntList.of((int[])((int[])entry.getValue())));
        }
        return builder.build();
    }

    private List<RexNode> combinePostJoinFilters(Join joinRel, RelNode left, RelNode right) {
        ArrayList filters = Lists.newArrayList();
        if (right instanceof MultiJoin) {
            MultiJoin multiRight = (MultiJoin)right;
            filters.add(this.shiftRightFilter(joinRel, left, (RelNode)multiRight, multiRight.getPostJoinFilter()));
        }
        if (left instanceof MultiJoin) {
            filters.add(((MultiJoin)left).getPostJoinFilter());
        }
        return filters;
    }

    private class InputReferenceCounter
    extends RexVisitorImpl<Void> {
        private final int[] refCounts;

        public InputReferenceCounter(int[] refCounts) {
            super(true);
            this.refCounts = refCounts;
        }

        public Void visitInputRef(RexInputRef inputRef) {
            int n = inputRef.getIndex();
            this.refCounts[n] = this.refCounts[n] + 1;
            return null;
        }
    }
}

