/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.lops.rewrite;

import java.util.ArrayList;
import java.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.MatMultCP;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.rewrite.LopRewriteRule;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixNative;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/**
 * Lop rewrite that walks the lop DAGs of a statement block and switches selected
 * GPU operators to {@code Types.ExecType.CP}, based on properties of their inputs
 * and outputs.
 *
 * <p>The rewrite is a no-op unless {@code ConfigurationManager.isRuleBasedGPUPlacement()}
 * returns {@code true} and at least one lop of the block reports {@code isExecGPU()}.
 */
public class RewriteUpdateGPUPlacements extends LopRewriteRule {

    /**
     * Applies the placement rules to every lop reachable from the roots of {@code sb}.
     *
     * @param sb the statement block whose lops are inspected and possibly re-typed in place
     * @return a single-element list containing {@code sb}
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb) {
        if (ConfigurationManager.isRuleBasedGPUPlacement()
                && containsGPUOperator(OperatorOrderingUtils.getLopList(sb))) {
            List<Lop> dagRoots = sb.getLops();
            for (Lop dagRoot : dagRoots) {
                rUpdateExecType(dagRoot);
            }
            // Clear the visit marks set by the traversal above.
            for (Lop dagRoot : dagRoots) {
                dagRoot.resetVisitStatus();
            }
        }
        return List.of(sb);
    }

    /**
     * Multi-block variant; this rule does nothing here.
     *
     * @param sbs the statement blocks
     * @return {@code sbs}, unchanged
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
        return sbs;
    }

    /**
     * Tells whether any lop in the given list is marked for GPU execution.
     *
     * @param candidates lops to scan; may be {@code null}
     * @return {@code true} if the list is non-null and some element reports {@code isExecGPU()}
     */
    private static boolean containsGPUOperator(List<Lop> candidates) {
        if (candidates == null) {
            return false;
        }
        for (Lop candidate : candidates) {
            if (candidate.isExecGPU()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Decides for a single GPU lop whether it should be switched to CP.
     *
     * <p>Checks, in order:
     * <ol>
     *   <li>Non-GPU lops are left alone.</li>
     *   <li>If any input has a known nnz (&gt;= 0) and
     *       {@code MatrixBlock.evalSparseFormatInMemory} holds for it, the lop goes to CP.</li>
     *   <li>A {@code MatMultCP} lop for which {@code LibMatrixNative.isMatMultMemoryBound}
     *       is {@code false} is left on the GPU.</li>
     *   <li>Otherwise the decision depends on the number of inputs and on where the
     *       non-{@code Data} inputs execute; a lop whose outputs are all on the GPU
     *       ({@code isAllOutputsGPU()}) is never moved by this step.</li>
     * </ol>
     *
     * @param lop the operator to inspect; its exec type may be changed in place
     */
    private void updateExecTypeGPU2CP(Lop lop) {
        if (!lop.isExecGPU()) {
            return;
        }

        // An input that evaluates to sparse in-memory format sends the operator to CP.
        for (Lop operand : lop.getInputs()) {
            long nonZeros = operand.getNnz();
            if (nonZeros >= 0L
                    && MatrixBlock.evalSparseFormatInMemory(operand.getNumRows(), operand.getNumCols(), nonZeros)) {
                lop.setExecType(Types.ExecType.CP);
                return;
            }
        }

        // Matrix multiplications that are not memory bound keep their GPU placement.
        // NOTE(review): the dimensions are narrowed to int for this call — confirm they
        // cannot exceed Integer.MAX_VALUE here.
        if (lop instanceof MatMultCP
                && !LibMatrixNative.isMatMultMemoryBound(
                        (int)lop.getInput(0).getNumRows(),
                        (int)lop.getInput(0).getNumCols(),
                        (int)lop.getInput(1).getNumCols())) {
            return;
        }

        int arity = lop.getInputs().size();
        boolean inputsFavorCP;
        if (arity == 1) {
            inputsFavorCP = isNonDataOffGPU(lop.getInput(0));
        }
        else if (arity == 2) {
            inputsFavorCP = largerInputOffGPU(lop.getInput(0), lop.getInput(1));
        }
        else if (arity > 2) {
            inputsFavorCP = hasMoreCPThanGPUInputs(lop);
        }
        else {
            inputsFavorCP = false;
        }

        // Operators whose outputs are all on the GPU are not moved.
        if (inputsFavorCP && !lop.isAllOutputsGPU()) {
            lop.setExecType(Types.ExecType.CP);
        }
    }

    /**
     * @param in an input lop
     * @return {@code true} if {@code in} is not a {@code Data} lop and is not GPU-executed
     */
    private static boolean isNonDataOffGPU(Lop in) {
        return !(in instanceof Data) && !in.isExecGPU();
    }

    /**
     * @param in an input lop
     * @return the value of {@code MatrixBlock.estimateSizeInMemory} for the rows,
     *         columns and nnz reported by {@code in}
     */
    private static long estimatedSize(Lop in) {
        return MatrixBlock.estimateSizeInMemory(in.getNumRows(), in.getNumCols(), in.getNnz());
    }

    /**
     * Binary-operator check: looks at the input with the larger estimated size
     * (or at both inputs when the estimates are equal).
     *
     * @param first  input 0 of the operator
     * @param second input 1 of the operator
     * @return {@code true} if the larger input (both, on a tie) is neither a
     *         {@code Data} lop nor GPU-executed
     */
    private static boolean largerInputOffGPU(Lop first, Lop second) {
        long firstSize = estimatedSize(first);
        long secondSize = estimatedSize(second);
        if (firstSize > secondSize) {
            return isNonDataOffGPU(first);
        }
        if (secondSize > firstSize) {
            return isNonDataOffGPU(second);
        }
        return isNonDataOffGPU(first) && isNonDataOffGPU(second);
    }

    /**
     * Counts, among the non-{@code Data} inputs of {@code lop}, those reporting
     * {@code isExecGPU()} and those reporting {@code isExecCP()}.
     *
     * @param lop the operator whose inputs are counted
     * @return {@code true} if strictly more inputs are CP-executed than GPU-executed
     */
    private static boolean hasMoreCPThanGPUInputs(Lop lop) {
        int onGPU = 0;
        int onCP = 0;
        for (Lop operand : lop.getInputs()) {
            if (operand instanceof Data) {
                continue;
            }
            if (operand.isExecGPU()) {
                onGPU++;
            }
            if (operand.isExecCP()) {
                onCP++;
            }
        }
        return onCP > onGPU;
    }

    /**
     * Depth-first traversal: processes all non-{@code Data} inputs of {@code root}
     * before {@code root} itself, then marks {@code root} as visited so it is
     * handled only once.
     *
     * @param root the lop to start from
     */
    private void rUpdateExecType(Lop root) {
        if (!root.isVisited()) {
            for (Lop child : root.getInputs()) {
                if (!(child instanceof Data)) {
                    rUpdateExecType(child);
                }
            }
            updateExecTypeGPU2CP(root);
            root.setVisited();
        }
    }
}

