/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.ipa;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataGenOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.ipa.IPAPass;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

/**
 * IPA pass that removes element-wise multiplications with a known constant matrix of ones.
 * <p>
 * The pass walks the top-level statement blocks of the program in order and maintains a
 * set of candidate variables. A variable becomes a candidate when a non-control-flow block
 * transiently writes it from a {@code RAND} {@link DataGenOp} for which
 * {@code hasConstantValue(1.0)} holds (presumably {@code matrix(1, rows, cols)}).
 * <p>
 * In every later block, each non-outer {@code MULT} {@link BinaryOp} is rewritten when both
 * of these hold:
 * <ul>
 *   <li>its left input is a matrix, and</li>
 *   <li>its right input is a {@link DataOp} named like a candidate.</li>
 * </ul>
 * The rewrite replaces that right input with the literal {@code 1}.
 * <p>
 * Before a block is rewritten, any variable that the block updates is dropped from the
 * candidate set.
 */
public class IPAPassRemoveConstantBinaryOps
extends IPAPass {
    /**
     * This pass is unconditionally applicable.
     *
     * @param fgraph function call graph (not inspected)
     * @return always {@code true}
     */
    @Override
    public boolean isApplicable(FunctionCallGraph fgraph) {
        return true;
    }

    /**
     * Rewrites multiplications with known matrices of ones in the top-level statement
     * blocks of {@code prog} (including blocks nested inside if/while/for bodies).
     *
     * @param prog       the program whose HOP DAGs are modified in place
     * @param fgraph     function call graph (not inspected)
     * @param fcallSizes function call size info (not inspected)
     * @return always {@code false}
     */
    @Override
    public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        // Candidate variable name -> the DataGenOp that produced the matrix of ones.
        HashMap<String, Hop> mOnes = new HashMap<>();
        for (StatementBlock sb : prog.getStatementBlocks()) {
            // A variable updated anywhere in this block can no longer be assumed to hold ones.
            for (String var : sb.variablesUpdated().getVariableNames()) {
                mOnes.remove(var);
            }
            if (!mOnes.isEmpty()) {
                rRemoveConstantBinaryOp(sb, mOnes);
            }
            // Candidates are only collected from straight-line blocks (NOTE(review): presumably
            // because writes inside if/while/for bodies are not guaranteed to execute exactly once).
            if (!isControlFlowBlock(sb)) {
                collectMatrixOfOnes(sb.getHops(), mOnes);
            }
        }
        return false;
    }

    /** Returns whether {@code sb} is an if, while, or for statement block. */
    private static boolean isControlFlowBlock(StatementBlock sb) {
        return sb instanceof IfStatementBlock
            || sb instanceof WhileStatementBlock
            || sb instanceof ForStatementBlock;
    }

    /**
     * Records every root that transiently writes a constant matrix of ones.
     *
     * @param roots DAG roots of a statement block; {@code null} is ignored
     * @param mOnes candidate map to add {@code name -> DataGenOp} entries to
     */
    private static void collectMatrixOfOnes(List<Hop> roots, Map<String, Hop> mOnes) {
        if (roots == null) {
            return;
        }
        for (Hop root : roots) {
            if (isTransientWriteOfOnes(root)) {
                mOnes.put(root.getName(), root.getInput().get(0));
            }
        }
    }

    /**
     * Returns whether {@code root} is a transient write whose first input is a RAND
     * {@link DataGenOp} with constant value 1.
     */
    private static boolean isTransientWriteOfOnes(Hop root) {
        if (!(root instanceof DataOp) || ((DataOp) root).getOp() != Types.OpOpData.TRANSIENTWRITE) {
            return false;
        }
        Hop input = root.getInput().get(0);
        return input instanceof DataGenOp
            && ((DataGenOp) input).getOp() == Types.OpOpDG.RAND
            && ((DataGenOp) input).hasConstantValue(1.0);
    }

    /**
     * Applies the rewrite to a statement block, recursing into the bodies of if/else,
     * while, and for blocks. Predicates of control-flow blocks are not traversed.
     *
     * @param sb    the statement block to process
     * @param mOnes names of variables currently known to hold a matrix of ones
     */
    private static void rRemoveConstantBinaryOp(StatementBlock sb, Map<String, Hop> mOnes) {
        if (sb instanceof IfStatementBlock) {
            IfStatement istmt = (IfStatement) ((IfStatementBlock) sb).getStatement(0);
            for (StatementBlock c : istmt.getIfBody()) {
                rRemoveConstantBinaryOp(c, mOnes);
            }
            if (istmt.getElseBody() != null) {
                for (StatementBlock c : istmt.getElseBody()) {
                    rRemoveConstantBinaryOp(c, mOnes);
                }
            }
        } else if (sb instanceof WhileStatementBlock) {
            WhileStatement wstmt = (WhileStatement) ((WhileStatementBlock) sb).getStatement(0);
            for (StatementBlock c : wstmt.getBody()) {
                rRemoveConstantBinaryOp(c, mOnes);
            }
        } else if (sb instanceof ForStatementBlock) {
            ForStatement fstmt = (ForStatement) ((ForStatementBlock) sb).getStatement(0);
            for (StatementBlock c : fstmt.getBody()) {
                rRemoveConstantBinaryOp(c, mOnes);
            }
        } else if (sb.getHops() != null) {
            // Clear visit flags so every DAG node of this block is processed exactly once.
            Hop.resetVisitStatus(sb.getHops());
            for (Hop hop : sb.getHops()) {
                rRemoveConstantBinaryOp(hop, mOnes);
            }
        }
    }

    /**
     * Depth-first traversal of a HOP DAG that replaces {@code X * ones} by {@code X * 1}.
     *
     * @param hop   current node; skipped if already visited
     * @param mOnes names of variables currently known to hold a matrix of ones
     */
    private static void rRemoveConstantBinaryOp(Hop hop, Map<String, Hop> mOnes) {
        if (hop.isVisited()) {
            return;
        }
        if (isMultWithOnes(hop, mOnes)) {
            // Swap the ones operand at position 1 for the scalar literal 1.
            // NOTE(review): presumably the resulting X * 1 is later removed by algebraic
            // simplification rewrites; confirm.
            HopRewriteUtils.removeChildReferenceByPos(hop, hop.getInput().get(1), 1);
            HopRewriteUtils.addChildReference(hop, new LiteralOp(1L), 1);
        }
        for (Hop c : hop.getInput()) {
            rRemoveConstantBinaryOp(c, mOnes);
        }
        hop.setVisited();
    }

    /**
     * Returns whether {@code hop} is a non-outer element-wise multiplication.
     * The left input must be a matrix, and the right input must be a {@link DataOp}
     * whose name is a candidate in {@code mOnes}.
     */
    private static boolean isMultWithOnes(Hop hop, Map<String, Hop> mOnes) {
        if (!(hop instanceof BinaryOp)) {
            return false;
        }
        BinaryOp bop = (BinaryOp) hop;
        if (bop.getOp() != Types.OpOp2.MULT || bop.isOuter()
            || hop.getInput().get(0).getDataType() != Types.DataType.MATRIX) {
            return false;
        }
        Hop right = hop.getInput().get(1);
        return right instanceof DataOp && mOnes.containsKey(right.getName());
    }
}

