/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.codegen.template;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeOuterProduct;
import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
import org.apache.sysds.hops.codegen.template.TemplateBase;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
import org.apache.sysds.runtime.matrix.data.Pair;

/**
 * Code generation template (type {@code OUTER}) that is opened on an outer-product-like
 * matrix multiplication or an outer binary operation with a large output, fuses
 * element-wise operations on top of it, and is closed by a full sum aggregate or a
 * subsequent (non-outer) matrix multiplication.
 *
 * <p>The constructed template orders its inputs as X (the matrix operand of a sparse-safe
 * binary operation), U and V (the operands of the outer-product multiply, with transposes
 * skipped), followed by all remaining side inputs.
 */
public class TemplateOuterProduct extends TemplateBase {
    /**
     * Transpose-self type of the outer-product matrix multiply, captured while constructing
     * the cplan and passed to the {@link CNodeOuterProduct}.
     */
    MMTSJ.MMTSJType mmtsj = MMTSJ.MMTSJType.NONE;

    public TemplateOuterProduct() {
        super(TemplateBase.TemplateType.OUTER);
    }

    public TemplateOuterProduct(TemplateBase.CloseType ctype) {
        super(TemplateBase.TemplateType.OUTER, ctype);
    }

    /**
     * Opens the template on an outer-product-like matrix multiply or an outer binary
     * operation, but only if the output has more than 256 rows and more than 256 columns.
     */
    @Override
    public boolean open(Hop hop) {
        return (HopRewriteUtils.isOuterProductLikeMM(hop) || HopRewriteUtils.isOuterBinary(hop))
            && hop.getDim1() > 256L && hop.getDim2() > 256L;
    }

    /**
     * Decides whether {@code hop} can be fused on top of {@code input} into an open template.
     * Accepted are supported unary ops, supported binary ops of matrix-vector/scalar/matrix
     * shape, a transpose over a non-outer matrix multiply, a non-outer matrix multiply whose
     * operands contain the outer product, and a full (RowCol) sum aggregate.
     */
    @Override
    public boolean fuse(Hop hop, Hop input) {
        if (isClosed()) {
            return false;
        }
        return (hop instanceof UnaryOp && TemplateUtils.isOperationSupported(hop))
            || (hop instanceof BinaryOp && TemplateUtils.isOperationSupported(hop)
                && (TemplateUtils.isBinaryMatrixColVector(hop)
                    || HopRewriteUtils.isBinaryMatrixScalarOperation(hop)
                    || HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)
                    || TemplateUtils.isBinaryMatrixRowVector(hop)))
            || (HopRewriteUtils.isTransposeOperation(hop) && input instanceof AggBinaryOp
                && !HopRewriteUtils.isOuterProductLikeMM(input))
            || (hop instanceof AggBinaryOp && !HopRewriteUtils.isOuterProductLikeMM(hop)
                && TemplateUtils.containsOuterProduct(input, HopRewriteUtils.getOtherInput(hop, input)))
            // only full SUM aggregates can be constructed by rConstructCplan; any other
            // aggregate would otherwise yield a null output node
            || isFullSumAggregate(hop);
    }

    /**
     * Decides whether {@code input} can be merged as an additional operand of {@code hop}:
     * row-vector or scalar binary operations, or a multiplication with a sparse-safe input
     * that does not itself contain an outer product.
     */
    @Override
    public boolean merge(Hop hop, Hop input) {
        if (isClosed()) {
            return false;
        }
        return TemplateUtils.isBinaryMatrixRowVector(hop)
            || HopRewriteUtils.isBinaryMatrixScalarOperation(hop)
            || (HopRewriteUtils.isBinary(hop, Types.OpOp2.MULT)
                && HopRewriteUtils.isBinarySparseSafe(input)
                && !TemplateUtils.containsOuterProduct(input));
    }

    /**
     * Determines the close status after fusing {@code hop}.
     *
     * @return {@code CLOSED_INVALID} for aggregates/multiplies over another outer product or
     *         over non-sparse-safe inputs; {@code CLOSED_VALID} for aggregates, non-outer
     *         multiplies not consumed by a transpose, and transposes over non-outer
     *         multiplies; {@code OPEN_VALID} for matrix-matrix MULT/DIV; otherwise
     *         {@code OPEN_INVALID}.
     */
    @Override
    public TemplateBase.CloseType close(Hop hop) {
        boolean invalidAgg = hop instanceof AggUnaryOp
            && (HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(0))
                || !HopRewriteUtils.isBinarySparseSafe(hop.getInput().get(0)));
        boolean invalidMM = hop instanceof AggBinaryOp
            && (HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(0))
                || HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(1))
                || (!HopRewriteUtils.isOuterProductLikeMM(hop)
                    && !HopRewriteUtils.isBinarySparseSafe(HopRewriteUtils.getLargestInput(hop))));
        if (invalidAgg || invalidMM) {
            return TemplateBase.CloseType.CLOSED_INVALID;
        }

        if (hop instanceof AggUnaryOp
            || (hop instanceof AggBinaryOp && !HopRewriteUtils.isOuterProductLikeMM(hop)
                && !isConsumedByTranspose(hop))
            || (HopRewriteUtils.isTransposeOperation(hop)
                && hop.getInput().get(0) instanceof AggBinaryOp
                && !HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(0)))) {
            return TemplateBase.CloseType.CLOSED_VALID;
        }

        if (isBinaryMatrixMatrixMultOrDiv(hop)) {
            return TemplateBase.CloseType.OPEN_VALID;
        }
        return TemplateBase.CloseType.OPEN_INVALID;
    }

    /**
     * Constructs the outer-product CNode template rooted at {@code hop}.
     *
     * @param hop             root hop of the fused plan
     * @param memo            memo table providing the best OUTER/CELL entry per hop
     * @param compileLiterals passed through to {@code TemplateUtils.createCNodeData}
     * @return the ordered input hops (X, U, V first if present, then the remaining inputs)
     *         paired with the constructed template; the hop array is positionally aligned
     *         with the template's CNode inputs
     */
    @Override
    public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
        // mmtsj is only assigned when an outer-product MM is visited; reset it so state from
        // an earlier construction on this instance cannot leak into this template
        mmtsj = MMTSJ.MMTSJType.NONE;

        HashSet<Hop> inHops = new HashSet<>();
        HashMap<String, Hop> inHops2 = new HashMap<>();
        HashMap<Long, CNode> tmp = new HashMap<>();
        hop.resetVisitStatus();
        rConstructCplan(hop, memo, tmp, inHops, inHops2, compileLiterals);
        hop.resetVisitStatus();

        long outputHopID = hop.getHopID();
        if (hop instanceof BinaryOp) {
            outputHopID = TemplateUtils.skipConditionalInOuterProduct(hop, tmp, inHops);
        }

        Hop X = inHops2.get("_X");
        Hop U = inHops2.get("_U");
        Hop V = inHops2.get("_V");

        // all removes must happen before all adds: if U and V are the same hop
        // (e.g. a transpose-self multiply) it must appear in both the U and V positions
        LinkedList<Hop> sinHops = new LinkedList<>(inHops);
        sinHops.remove(V);
        sinHops.remove(U);
        sinHops.remove(X);
        sinHops.addFirst(V);
        sinHops.addFirst(U);
        sinHops.addFirst(X);
        // drop absent roles once, so the returned hops and the CNode inputs stay aligned
        sinHops.removeIf(h -> h == null);

        ArrayList<CNode> inputs = new ArrayList<>(sinHops.size());
        for (Hop in : sinHops) {
            inputs.add(tmp.get(in.getHopID()));
        }

        CNode output = tmp.get(outputHopID);
        CNodeOuterProduct tpl = new CNodeOuterProduct(inputs, output, mmtsj);
        tpl.setOutProdType(TemplateUtils.getOuterProductType(X, U, V, hop));
        // left outer products are emitted transposed unless the root is itself a transpose
        tpl.setTransposeOutput(!HopRewriteUtils.isTransposeOperation(hop)
            && tpl.getOutProdType() == SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT);
        tpl.setBeginLine(hop.getBeginLine());
        return new Pair<Hop[], CNodeTpl>(sinHops.toArray(new Hop[0]), tpl);
    }

    /**
     * Recursively builds CNodes for {@code hop} and its plan-referenced inputs into
     * {@code tmp} (keyed by hop ID). Non-referenced inputs become CNodeData leaves and are
     * collected in {@code inHops}; the roles X, U and V are recorded in {@code inHops2}.
     */
    private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp,
        HashSet<Hop> inHops, HashMap<String, Hop> inHops2, boolean compileLiterals) {
        if (tmp.containsKey(hop.getHopID())) {
            return; // already constructed via another path of the DAG
        }

        // recursively construct plan-referenced inputs, materialize all others as data leaves
        CPlanMemoTable.MemoTableEntry me = memo.getBest(hop.getHopID(),
            TemplateBase.TemplateType.OUTER, TemplateBase.TemplateType.CELL);
        for (int i = 0; i < hop.getInput().size(); ++i) {
            Hop c = hop.getInput().get(i);
            if (me.isPlanRef(i)) {
                rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
            } else {
                CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
                tmp.put(c.getHopID(), cdata);
                inHops.add(c);
            }
        }

        CNode out = null;
        if (hop instanceof UnaryOp) {
            CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
            String primitiveOpName = ((UnaryOp) hop).getOp().name();
            out = new CNodeUnary(cdata1, CNodeUnary.UnaryType.valueOf(primitiveOpName));
        } else if (hop instanceof BinaryOp) {
            Hop in1 = hop.getInput().get(0);
            Hop in2 = hop.getInput().get(1);
            CNode cdata1 = tmp.get(in1.getHopID());
            CNode cdata2 = tmp.get(in2.getHopID());
            String primitiveOpName = ((BinaryOp) hop).getOp().name();
            // X is the matrix data input of a sparse-safe binary op; this check looks at the
            // input nodes before any lookup wrapping below (if both qualify, in2 wins)
            if (HopRewriteUtils.isBinarySparseSafe(hop)) {
                if (TemplateUtils.isMatrix(in1) && cdata1 instanceof CNodeData) {
                    inHops2.put("_X", in1);
                }
                if (TemplateUtils.isMatrix(in2) && cdata2 instanceof CNodeData) {
                    inHops2.put("_X", in2);
                }
            }
            cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, in1);
            cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, in2);
            out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.valueOf(primitiveOpName));
        } else if (hop instanceof AggBinaryOp) {
            Hop in1 = hop.getInput().get(0);
            Hop in2 = hop.getInput().get(1);
            CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(in1.getHopID()), in1, tmp, compileLiterals);
            CNode cdata2 = TemplateUtils.skipTranspose(tmp.get(in2.getHopID()), in2, tmp, compileLiterals);
            if (HopRewriteUtils.isOuterProductLikeMM(hop)) {
                // U and V are recorded without their transposes
                inHops2.put("_U", HopRewriteUtils.isTransposeOperation(in1) ? in1.getInput().get(0) : in1);
                inHops2.put("_V", HopRewriteUtils.isTransposeOperation(in2) ? in2.getInput().get(0) : in2);
                mmtsj = ((AggBinaryOp) hop).checkTransposeSelf();
                out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.DOT_PRODUCT);
            } else if (cdata1.getDataType().isScalar()) {
                // keep a scalar operand in the second position of VECT_MULT_ADD
                out = new CNodeBinary(cdata2, cdata1, CNodeBinary.BinType.VECT_MULT_ADD);
            } else {
                out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.VECT_MULT_ADD);
            }
        } else if (HopRewriteUtils.isTransposeOperation(hop)) {
            // transposes are pass-through in the cplan
            out = tmp.get(hop.getInput().get(0).getHopID());
        } else if (isFullSumAggregate(hop)) {
            // the full sum aggregate is pass-through in the cplan
            out = tmp.get(hop.getInput().get(0).getHopID());
        }
        tmp.put(hop.getHopID(), out);
    }

    /**
     * Chooses between two alternative memo entries that each reference exactly one child
     * via a plan reference, at different input positions. If both referenced children have
     * an OUTER memo entry, returns the entry whose child is a matrix-matrix MULT/DIV
     * ({@code me1} is checked first). Per the method name, the returned entry is presumably
     * the one the caller drops — confirm against callers.
     *
     * @return {@code me1}, {@code me2}, or {@code null} if neither condition applies
     */
    public static CPlanMemoTable.MemoTableEntry dropAlternativePlan(CPlanMemoTable memo,
        CPlanMemoTable.MemoTableEntry me1, CPlanMemoTable.MemoTableEntry me2) {
        if (me1.countPlanRefs() != 1 || me2.countPlanRefs() != 1
            || me1.getPlanRefIndex() == me2.getPlanRefIndex()) {
            return null;
        }
        Hop c1 = memo._hopRefs.get(me1.input(me1.getPlanRefIndex()));
        Hop c2 = memo._hopRefs.get(me2.input(me2.getPlanRefIndex()));
        if (!memo.contains(c1.getHopID(), TemplateBase.TemplateType.OUTER)
            || !memo.contains(c2.getHopID(), TemplateBase.TemplateType.OUTER)) {
            return null;
        }
        if (isBinaryMatrixMatrixMultOrDiv(c1)) {
            return me1;
        }
        if (isBinaryMatrixMatrixMultOrDiv(c2)) {
            return me2;
        }
        return null;
    }

    /** True if {@code hop} is a full (RowCol) SUM aggregate, the only aggregate this template constructs. */
    private static boolean isFullSumAggregate(Hop hop) {
        return hop instanceof AggUnaryOp
            && ((AggUnaryOp) hop).getDirection() == Types.Direction.RowCol
            && ((AggUnaryOp) hop).getOp() == Types.AggOp.SUM;
    }

    /** True if {@code hop} is a matrix-matrix binary MULT or DIV. */
    private static boolean isBinaryMatrixMatrixMultOrDiv(Hop hop) {
        return HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)
            && HopRewriteUtils.isBinary(hop, Types.OpOp2.MULT, Types.OpOp2.DIV);
    }

    /** True if the first parent of {@code hop} is a transpose; false if it has no parents. */
    private static boolean isConsumedByTranspose(Hop hop) {
        return !hop.getParent().isEmpty()
            && HopRewriteUtils.isTransposeOperation(hop.getParent().get(0));
    }
}

