/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.rewrite.HopRewriteRule;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysml.parser.Expression;

/**
 * Hop rewrite rule that removes cast operators which have no effect.
 *
 * <p>Two patterns are handled for every {@link UnaryOp} reachable from the
 * given roots:
 * <ul>
 *   <li>a value-type cast whose input already has the same (known) value
 *       type as the cast output;</li>
 *   <li>a {@code CAST_AS_MATRIX} applied directly on top of a
 *       {@code CAST_AS_SCALAR}, or vice versa.</li>
 * </ul>
 * In both cases the consumers of the cast are rewired to read from the
 * operand below the cast(s) instead.
 */
public class RewriteRemoveUnnecessaryCasts
extends HopRewriteRule {
    /**
     * Applies the rule to every DAG in {@code roots}.
     *
     * @param roots root hops of the DAGs to rewrite; may be {@code null}
     * @param state rewrite status; not used by this rule
     * @return the same {@code roots} list (rewritten in place), or
     *         {@code null} if {@code roots} was {@code null}
     */
    @Override
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
        if (roots == null) {
            return null;
        }
        for (Hop h : roots) {
            this.rule_RemoveUnnecessaryCasts(h);
        }
        return roots;
    }

    /**
     * Applies the rule to the single DAG rooted at {@code root}.
     *
     * @param root root hop of the DAG to rewrite; may be {@code null}
     * @param state rewrite status; not used by this rule
     * @return the same {@code root} reference that was passed in
     */
    @Override
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
        if (root == null) {
            return root;
        }
        this.rule_RemoveUnnecessaryCasts(root);
        return root;
    }

    /**
     * Recursively processes the inputs of {@code hop} first, then tries both
     * cast-removal rewrites on {@code hop} itself, and finally marks it as
     * visited so that shared sub-DAGs are only processed once.
     *
     * @param hop the hop to process; must not be {@code null}
     */
    private void rule_RemoveUnnecessaryCasts(Hop hop) {
        if (hop.isVisited()) {
            return;
        }
        // Bottom-up traversal: rewrite the operands before the operator.
        ArrayList<Hop> inputs = hop.getInput();
        for (int i = 0; i < inputs.size(); ++i) {
            this.rule_RemoveUnnecessaryCasts(inputs.get(i));
        }
        this.removeRedundantValueTypeCast(hop);
        this.removeRedundantDataTypeCastPair(hop);
        hop.setVisited();
    }

    /**
     * Bypasses {@code hop} if it is a value-type cast whose input and output
     * value types are identical and not {@code UNKNOWN}.
     */
    private void removeRedundantValueTypeCast(Hop hop) {
        if (!(hop instanceof UnaryOp) || !HopRewriteUtils.isValueTypeCast(((UnaryOp)hop).getOp())) {
            return;
        }
        Hop in = hop.getInput().get(0);
        Expression.ValueType vtIn = in.getValueType();
        Expression.ValueType vtOut = hop.getValueType();
        if (vtIn != vtOut || vtIn == Expression.ValueType.UNKNOWN) {
            return;
        }
        // Live reference to the cast's parent list; it is cleared once all
        // consumers have been rewired.
        ArrayList<Hop> parents = hop.getParent();
        for (int i = 0; i < parents.size(); ++i) {
            Hop p = parents.get(i);
            ArrayList<Hop> pin = p.getInput();
            for (int j = 0; j < pin.size(); ++j) {
                if (pin.get(j) != hop) {
                    continue;
                }
                // Replace the cast by its input at the same operand position
                // and keep the parent links of the input consistent.
                pin.remove(j);
                pin.add(j, in);
                in.getParent().remove(hop);
                in.getParent().add(p);
            }
        }
        parents.clear();
    }

    /**
     * Bypasses a {@code CAST_AS_MATRIX}/{@code CAST_AS_SCALAR} pair (in either
     * order) formed by {@code hop} and its first input, by rewiring all
     * parents of {@code hop} to the input of the inner cast.
     */
    private void removeRedundantDataTypeCastPair(Hop hop) {
        if (!(hop instanceof UnaryOp) || !(hop.getInput().get(0) instanceof UnaryOp)) {
            return;
        }
        UnaryOp uop1 = (UnaryOp)hop;
        UnaryOp uop2 = (UnaryOp)hop.getInput().get(0);
        boolean matrixOfScalar = uop1.getOp() == Hop.OpOp1.CAST_AS_MATRIX && uop2.getOp() == Hop.OpOp1.CAST_AS_SCALAR;
        boolean scalarOfMatrix = uop1.getOp() == Hop.OpOp1.CAST_AS_SCALAR && uop2.getOp() == Hop.OpOp1.CAST_AS_MATRIX;
        if (!matrixOfScalar && !scalarOfMatrix) {
            return;
        }
        Hop input = uop2.getInput().get(0);
        // Iterate over a snapshot of the parent list rather than the live
        // list, since the rewiring call receives both the parent and hop.
        ArrayList<Hop> parents = new ArrayList<Hop>(hop.getParent());
        for (Hop p : parents) {
            HopRewriteUtils.replaceChildReference(p, hop, input);
        }
    }
}

