/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
import org.apache.sysds.runtime.compress.lib.CLALibUtils;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.utils.DMLCompressionStatistics;

/**
 * Right matrix multiplication of a compressed left-hand side with an (uncompressed) right-hand side, i.e.
 * {@code m1 %*% m2} where {@code m1} is a {@link CompressedMatrixBlock}.
 *
 * <p>The result is either an uncompressed {@link MatrixBlock} or, when overlapping output is allowed, a
 * {@link CompressedMatrixBlock} built from the per-column-group products.</p>
 */
public final class CLALibRightMultBy {
    private static final Log LOG = LogFactory.getLog(CLALibRightMultBy.class.getName());

    private CLALibRightMultBy() {
        // static utility class, not meant to be instantiated
    }

    /**
     * Right multiply {@code m1 %*% m2}, deciding whether an overlapping compressed output is allowed from the
     * {@code sysds.compressed.overlapping} configuration flag.
     *
     * @param m1  compressed left-hand side
     * @param m2  right-hand side (decompressed first if it is compressed)
     * @param ret output block; only reused when one of the inputs is empty
     * @param k   degree of parallelism
     * @return the product
     */
    public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
        boolean allowOverlap = ConfigurationManager.getDMLConfig().getBooleanValue("sysds.compressed.overlapping");
        return rightMultByMatrix(m1, m2, ret, k, allowOverlap);
    }

    /**
     * Right multiply {@code m1 %*% m2}.
     *
     * <p>If either input is empty, {@code ret} (allocated when {@code null}, otherwise reset) is returned as an
     * {@code m1.rows x m2.cols} block with zero non-zeros. Otherwise {@code ret} is ignored and a new block is
     * returned:</p>
     * <ul>
     * <li>uncompressed when {@code allowOverlap} is false;</li>
     * <li>compressed otherwise. Its non-zero count is set to the full cell count when the column groups overlap,
     * and recomputed exactly when they do not.</li>
     * </ul>
     *
     * @param m1           compressed left-hand side
     * @param m2           right-hand side (decompressed first if it is compressed)
     * @param ret          output block; only used for the empty-input shortcut
     * @param k            degree of parallelism
     * @param allowOverlap whether a compressed (possibly overlapping) result may be returned
     * @return the product
     */
    public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k,
        boolean allowOverlap) {
        int rr = m1.getNumRows();
        int rc = m2.getNumColumns();
        if (m1.isEmpty() || m2.isEmpty()) {
            LOG.trace("Empty right multiply");
            if (ret == null) {
                ret = new MatrixBlock(rr, rc, 0L);
            } else {
                ret.reset(rr, rc, 0L);
            }
            return ret;
        }
        if (m2 instanceof CompressedMatrixBlock) {
            m2 = ((CompressedMatrixBlock) m2).getUncompressed("Uncompressed right side of right MM", k);
        }
        if (!allowOverlap) {
            LOG.trace("Overlapping output not allowed in call to Right MM");
            return RMM(m1, m2, k);
        }
        CompressedMatrixBlock retC = RMMOverlapping(m1, m2, k);
        if (retC.isEmpty()) {
            return retC;
        }
        if (retC.isOverlapping()) {
            // counting the exact nnz of overlapping groups would require summing them; use the upper bound
            retC.setNonZeros((long) rr * (long) rc);
        } else {
            retC.recomputeNonZeros();
        }
        return retC;
    }

    /**
     * Computes {@code m1 %*% that} as a compressed block made of the per-group products. A pre-filtered constant
     * contribution is folded into one of the result groups.
     */
    private static CompressedMatrixBlock RMMOverlapping(CompressedMatrixBlock m1, MatrixBlock that, int k) {
        int rl = m1.getNumRows();
        int cr = that.getNumColumns();
        int rr = that.getNumRows();
        List<AColGroup> colGroups = m1.getColGroups();
        List<AColGroup> retCg = new ArrayList<>();
        CompressedMatrixBlock ret = new CompressedMatrixBlock(rl, cr);

        boolean shouldFilter = CLALibUtils.shouldPreFilter(colGroups);
        double[] constV = shouldFilter ? new double[rr] : null;
        List<AColGroup> filteredGroups = CLALibUtils.filterGroups(colGroups, constV);
        if (colGroups == filteredGroups) {
            // nothing was filtered out, so there is no constant contribution to add
            constV = null;
        }

        multiplyGroups(filteredGroups, that, retCg, k);

        if (constV != null) {
            MatrixBlock cbRet = multiplyConstantRow(constV, that);
            if (!cbRet.isEmpty()) {
                addConstant(cbRet, retCg);
            }
        }

        ret.allocateColGroupList(retCg);
        if (retCg.size() > 1) {
            ret.setOverlapping(true);
        }
        // NOTE(review): this runs after allocateColGroupList and therefore assumes ret keeps a reference to
        // retCg rather than a copy -- confirm.
        CLALibUtils.addEmptyColumn(retCg, cr);
        return ret;
    }

    /**
     * Adds a constant row to the result groups. The row is merged into the {@link ColGroupDDC} that spans all
     * columns and has the fewest distinct values. If no such group exists, a new {@link ColGroupConst} is
     * appended.
     *
     * @param constantRow a 1 x nCol row to add to every row of the result
     * @param out         result column groups, modified in place
     */
    private static void addConstant(MatrixBlock constantRow, List<AColGroup> out) {
        int nCol = constantRow.getNumColumns();
        int bestCandidate = -1;
        int bestCandidateValuesSize = Integer.MAX_VALUE;
        for (int i = 0; i < out.size(); ++i) {
            AColGroup g = out.get(i);
            if (g instanceof ColGroupDDC && g.getNumCols() == nCol && g.getNumValues() < bestCandidateValuesSize) {
                bestCandidate = i;
                // track the running minimum so the smallest dictionary is chosen, not just the last match
                bestCandidateValuesSize = g.getNumValues();
            }
        }
        constantRow.sparseToDense();
        double[] constValues = constantRow.getDenseBlockValues();
        if (bestCandidate != -1) {
            AColGroup merged = out.get(bestCandidate)
                .binaryRowOpRight(new BinaryOperator(Plus.getPlusFnObject(), 1), constValues, true);
            out.set(bestCandidate, merged);
        } else {
            out.add(ColGroupConst.create(constValues));
        }
    }

    /**
     * Computes {@code m1 %*% that} into a new uncompressed block. The output allocation is started asynchronously
     * while the column groups are multiplied. The group products (and the constant row, if any) are then
     * decompressed into it.
     */
    private static MatrixBlock RMM(CompressedMatrixBlock m1, MatrixBlock that, int k) {
        int rl = m1.getNumRows();
        int cr = that.getNumColumns();
        int rr = that.getNumRows();
        List<AColGroup> colGroups = m1.getColGroups();
        List<AColGroup> retCg = new ArrayList<>();
        boolean shouldFilter = CLALibUtils.shouldPreFilter(colGroups);
        MatrixBlock ret = new MatrixBlock(rl, cr, false);
        Future<MatrixBlock> f = ret.allocateBlockAsync();

        double[] constV = shouldFilter ? new double[rr] : null;
        List<AColGroup> filteredGroups = CLALibUtils.filterGroups(colGroups, constV);
        if (colGroups == filteredGroups) {
            // nothing was filtered out, so there is no constant contribution to add
            constV = null;
        }

        multiplyGroups(filteredGroups, that, retCg, k);

        if (constV != null) {
            MatrixBlock mmTemp = multiplyConstantRow(constV, that);
            if (mmTemp.isEmpty()) {
                constV = null;
            } else {
                // the product is not guaranteed to be dense; densify before taking the dense values
                // (as addConstant does) so the constant contribution is not lost
                mmTemp.sparseToDense();
                constV = mmTemp.getDenseBlockValues();
            }
        }

        Timing time = new Timing(true);
        ret = asyncRet(f);
        CLALibDecompress.decompressDenseMultiThread(ret, retCg, constV, 0.0, k, true);
        if (DMLScript.STATISTICS) {
            double t = time.stop();
            DMLCompressionStatistics.addDecompressTime(t, k);
        }
        return ret;
    }

    /**
     * Multiplies every group in {@code filteredGroups} with {@code that}, single-threaded when {@code k == 1} and
     * through a thread pool otherwise. Non-null results are appended to {@code retCg}.
     */
    private static void multiplyGroups(List<AColGroup> filteredGroups, MatrixBlock that, List<AColGroup> retCg,
        int k) {
        if (k == 1) {
            RMMSingle(filteredGroups, that, retCg);
        } else {
            RMMParallel(filteredGroups, that, retCg, k);
        }
    }

    /** Multiplies the 1 x rr constant row {@code constV} with {@code that}, giving a 1 x that.cols block. */
    private static MatrixBlock multiplyConstantRow(double[] constV, MatrixBlock that) {
        MatrixBlock constRow = new MatrixBlock(1, constV.length, constV);
        MatrixBlock product = new MatrixBlock(1, that.getNumColumns(), false);
        LibMatrixMult.matrixMult(constRow, that, product);
        return product;
    }

    /**
     * Waits for {@code in} and returns its value. Failures are wrapped in a {@link DMLRuntimeException}, and the
     * interrupt status is restored if the wait was interrupted.
     */
    private static <T> T asyncRet(Future<T> in) {
        try {
            return in.get();
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Single-threaded multiplication of each group with {@code that} over all of its columns.
     *
     * @return true if any group returned {@code null} (such results are not added to {@code retCg})
     */
    private static boolean RMMSingle(List<AColGroup> filteredGroups, MatrixBlock that, List<AColGroup> retCg) {
        boolean containsNull = false;
        IColIndex allCols = ColIndexFactory.create(that.getNumColumns());
        for (AColGroup g : filteredGroups) {
            AColGroup retG = g.rightMultByMatrix(that, allCols);
            if (retG != null) {
                retCg.add(retG);
            } else {
                containsNull = true;
            }
        }
        return containsNull;
    }

    /**
     * Parallel multiplication of each group with {@code that}, using one task per group on a pool of {@code k}
     * threads. Results are appended in the order of {@code filteredGroups}.
     *
     * @return true if any group returned {@code null} (such results are not added to {@code retCg})
     */
    private static boolean RMMParallel(List<AColGroup> filteredGroups, MatrixBlock that, List<AColGroup> retCg,
        int k) {
        ExecutorService pool = CommonThreadPool.get(k);
        boolean containsNull = false;
        try {
            IColIndex allCols = ColIndexFactory.create(that.getNumColumns());
            List<RightMatrixMultTask> tasks = new ArrayList<>(filteredGroups.size());
            for (AColGroup g : filteredGroups) {
                tasks.add(new RightMatrixMultTask(g, that, allCols));
            }
            for (Future<AColGroup> future : pool.invokeAll(tasks)) {
                AColGroup g = future.get();
                if (g != null) {
                    retCg.add(g);
                } else {
                    containsNull = true;
                }
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            pool.shutdown();
        }
        return containsNull;
    }

    /** Task computing {@code colGroup.rightMultByMatrix(b, allCols)}. */
    private static class RightMatrixMultTask implements Callable<AColGroup> {
        private final AColGroup _colGroup;
        private final MatrixBlock _b;
        private final IColIndex _allCols;

        protected RightMatrixMultTask(AColGroup colGroup, MatrixBlock b, IColIndex allCols) {
            this._colGroup = colGroup;
            this._b = b;
            this._allCols = allCols;
        }

        @Override
        public AColGroup call() {
            try {
                return this._colGroup.rightMultByMatrix(this._b, this._allCols);
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
        }
    }
}

