/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
import org.apache.sysds.runtime.compress.lib.CLALibUtils;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Static library for right matrix multiplication where the left-hand side is a
 * {@link CompressedMatrixBlock}, i.e. {@code m1 %*% m2} with {@code m1} compressed.
 *
 * <p>Three strategies are implemented and selected in
 * {@link #rightMultByMatrix(CompressedMatrixBlock, MatrixBlock, MatrixBlock, int, boolean)}:
 * a decompressing multiply into a dense output, a per-column-group multiply that is
 * decompressed into a dense output, and a per-column-group multiply whose result
 * stays compressed (possibly with overlapping column groups).
 */
public final class CLALibRightMultBy {
    private static final Log LOG = LogFactory.getLog(CLALibRightMultBy.class.getName());

    private CLALibRightMultBy() {
        // static utility class; not instantiable
    }

    /**
     * Right-multiplies a compressed matrix by {@code m2}, reading whether a compressed
     * overlapping result is allowed from the {@code sysds.compressed.overlapping}
     * configuration value.
     *
     * @param m1  compressed left-hand side
     * @param m2  right-hand side
     * @param ret optional output block; only used when one of the inputs is empty
     * @param k   degree of parallelism
     * @return the product
     */
    public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
        final boolean allowOverlap = ConfigurationManager.getDMLConfig().getBooleanValue("sysds.compressed.overlapping");
        return rightMultByMatrix(m1, m2, ret, k, allowOverlap);
    }

    /**
     * Right-multiplies a compressed matrix by {@code m2}.
     *
     * <p>Selection order: an empty input yields an empty result; a compressed
     * {@code m2} is uncompressed first; if {@link #betterIfDecompressed} holds the
     * decompressing multiply is used; otherwise the result is either decompressed
     * ({@code allowOverlap == false}) or returned as a compressed block.
     *
     * @param m1           compressed left-hand side
     * @param m2           right-hand side
     * @param ret          optional output block; it is only reset and returned when one
     *                     of the inputs is empty, all other paths allocate a new result
     * @param k            degree of parallelism
     * @param allowOverlap whether the result may be a compressed block
     * @return the product
     * @throws RuntimeException wrapping any exception raised during the multiplication
     */
    public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k, boolean allowOverlap) {
        try {
            final int rr = m1.getNumRows();
            final int rc = m2.getNumColumns();
            if (m1.isEmpty() || m2.isEmpty()) {
                LOG.trace("Empty right multiply");
                if (ret == null) {
                    ret = new MatrixBlock(rr, rc, 0L);
                } else {
                    ret.reset(rr, rc, 0L);
                }
                return ret;
            }
            if (m2 instanceof CompressedMatrixBlock) {
                m2 = ((CompressedMatrixBlock) m2).getUncompressed("Uncompressed right side of right MM", k);
            }
            if (betterIfDecompressed(m1)) {
                return decompressingMatrixMult(m1, m2, k);
            }
            if (!allowOverlap) {
                LOG.trace("Overlapping output not allowed in call to Right MM");
                return RMM(m1, m2, k);
            }
            final CompressedMatrixBlock retC = RMMOverlapping(m1, m2, k);
            if (retC.isEmpty()) {
                return retC;
            }
            if (retC.isOverlapping()) {
                // Overlapping output: the non-zero count is set to the full cell count
                // instead of being recomputed.
                retC.setNonZeros((long) rr * (long) rc);
            } else {
                retC.recomputeNonZeros(k);
            }
            return retC;
        }
        catch (Exception e) {
            throw new RuntimeException("Failed Right MM", e);
        }
    }

    /**
     * Multiplies by letting every column group write its contribution directly into a
     * dense output, tiled over rows and columns and executed on a thread pool.
     *
     * @param m1 compressed left-hand side
     * @param m2 right-hand side
     * @param k  degree of parallelism (also used to size the tiles)
     * @return a newly allocated result block with its non-zero count set
     * @throws Exception if any tile task fails
     */
    private static MatrixBlock decompressingMatrixMult(CompressedMatrixBlock m1, MatrixBlock m2, int k) throws Exception {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final int rl = m1.getNumRows();
            final int cr = m2.getNumColumns();
            final MatrixBlock ret = new MatrixBlock(rl, cr, false);
            ret.allocateBlock();
            final List<Future<Long>> tasks = new ArrayList<>();
            final List<AColGroup> groups = m1.getColGroups();
            // Row tile: an even split over k, but never fewer than 16 rows.
            final int blkI = Math.max((int) Math.ceil((double) rl / (double) k), 16);
            // Column tile: full width when the row tile is above the minimum, otherwise
            // split columns as well (at least 512 wide).
            final int blkJ = blkI > 16 ? cr : Math.max(cr / k, 512);
            for (int i = 0; i < rl; i += blkI) {
                final int startI = i;
                final int endI = Math.min(i + blkI, rl);
                for (int j = 0; j < cr; j += blkJ) {
                    final int startJ = j;
                    final int endJ = Math.min(j + blkJ, cr);
                    tasks.add(pool.submit(() -> {
                        for (AColGroup g : groups) {
                            g.rightDecompressingMult(m2, ret, startI, endI, rl, startJ, endJ);
                        }
                        // recomputeNonZeros is called with inclusive upper bounds.
                        return ret.recomputeNonZeros(startI, endI - 1, startJ, endJ - 1);
                    }));
                }
            }
            long nnz = 0L;
            for (Future<Long> task : tasks) {
                nnz += task.get();
            }
            ret.setNonZeros(nnz);
            ret.examSparsity();
            return ret;
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Heuristic deciding whether the decompressing multiply should be used.
     *
     * @param m the compressed matrix
     * @return {@code true} if any column group that is not a {@link ColGroupUncompressed}
     *         reports at least half as many values ({@code getNumValues()}) as the
     *         matrix has rows
     */
    private static boolean betterIfDecompressed(CompressedMatrixBlock m) {
        for (AColGroup g : m.getColGroups()) {
            if (!(g instanceof ColGroupUncompressed) && g.getNumValues() * 2 >= m.getNumRows()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Multiplies every column group by {@code that} and keeps the resulting groups in a
     * compressed output block.
     *
     * @param m1   compressed left-hand side
     * @param that right-hand side
     * @param k    degree of parallelism
     * @return compressed result; flagged as overlapping when more than one result group
     *         was produced
     * @throws Exception if the parallel multiplication fails
     */
    private static CompressedMatrixBlock RMMOverlapping(CompressedMatrixBlock m1, MatrixBlock that, int k) throws Exception {
        final int rl = m1.getNumRows();
        final int cr = that.getNumColumns();
        final int rr = that.getNumRows();
        final List<AColGroup> colGroups = m1.getColGroups();
        final List<AColGroup> retCg = new ArrayList<>();
        final CompressedMatrixBlock ret = new CompressedMatrixBlock(rl, cr);

        final List<AColGroup> filteredGroups;
        final double[] constV;
        if (CLALibUtils.shouldPreFilter(colGroups)) {
            // constV is filled by filterGroups.
            constV = new double[rr];
            filteredGroups = CLALibUtils.filterGroups(colGroups, constV);
        } else {
            filteredGroups = colGroups;
            constV = null;
        }

        if (k == 1 || filteredGroups.size() == 1) {
            RMMSingle(filteredGroups, that, retCg);
        } else {
            RMMParallel(filteredGroups, that, retCg, k);
        }

        if (constV != null) {
            // Multiply the filtered-out constant row separately and add it back as a
            // constant column group.
            final MatrixBlock cb = new MatrixBlock(1, constV.length, constV);
            final MatrixBlock cbRet = new MatrixBlock(1, that.getNumColumns(), false);
            LibMatrixMult.matrixMult(cb, that, cbRet);
            if (!cbRet.isEmpty()) {
                addConstant(cbRet, retCg);
            }
        }

        ret.allocateColGroupList(retCg);
        if (retCg.size() > 1) {
            ret.setOverlapping(true);
        }
        CLALibUtils.addEmptyColumn(retCg, cr);
        return ret;
    }

    /**
     * Appends a constant column group built from a single-row block.
     *
     * @param constantRow row holding the constant values; converted to dense in place
     * @param out         list receiving the new {@link ColGroupConst}
     */
    private static void addConstant(MatrixBlock constantRow, List<AColGroup> out) {
        constantRow.sparseToDense();
        out.add(ColGroupConst.create(constantRow.getDenseBlockValues()));
    }

    /**
     * Multiplies every column group by {@code that} and decompresses the resulting
     * groups into a dense output block.
     *
     * @param m1   compressed left-hand side
     * @param that right-hand side
     * @param k    degree of parallelism
     * @return dense result block
     * @throws Exception if the parallel multiplication or the asynchronous allocation fails
     */
    private static MatrixBlock RMM(CompressedMatrixBlock m1, MatrixBlock that, int k) throws Exception {
        final int rl = m1.getNumRows();
        final int cr = that.getNumColumns();
        final int rr = that.getNumRows();
        final List<AColGroup> colGroups = m1.getColGroups();
        final boolean shouldFilter = CLALibUtils.shouldPreFilter(colGroups);
        MatrixBlock ret = new MatrixBlock(rl, cr, false);
        // Start allocating the output while the column groups are multiplied.
        final Future<MatrixBlock> f = ret.allocateBlockAsync();

        final List<AColGroup> filteredGroups;
        double[] constV;
        if (shouldFilter) {
            if (CLALibUtils.alreadyPreFiltered(colGroups, cr)) {
                filteredGroups = new ArrayList<>(colGroups.size() - 1);
                constV = CLALibUtils.filterGroupsAndSplitPreAggOneConst(colGroups, filteredGroups);
            } else {
                constV = new double[rr];
                filteredGroups = CLALibUtils.filterGroups(colGroups, constV);
            }
        } else {
            filteredGroups = colGroups;
            constV = null;
        }

        final List<AColGroup> retCg = new ArrayList<>(filteredGroups.size());
        if (k == 1) {
            RMMSingle(filteredGroups, that, retCg);
        } else {
            RMMParallel(filteredGroups, that, retCg, k);
        }

        if (constV != null) {
            final MatrixBlock constVMB = new MatrixBlock(1, constV.length, constV);
            final MatrixBlock mmTemp = new MatrixBlock(1, cr, false);
            LibMatrixMult.matrixMult(constVMB, that, mmTemp);
            if (mmTemp.isEmpty()) {
                constV = null;
            } else {
                // Ensure a dense representation before reading the dense values, exactly
                // as addConstant does on the overlapping path; otherwise a product that
                // came back in sparse format would not expose its values here.
                mmTemp.sparseToDense();
                constV = mmTemp.getDenseBlockValues();
            }
        }

        ret = f.get();
        CLALibDecompress.decompressDense(ret, retCg, constV, 0.0, k, true);
        return ret;
    }

    /**
     * Multiplies each column group by {@code that} sequentially.
     *
     * @param filteredGroups column groups to multiply
     * @param that           right-hand side
     * @param retCg          list receiving every non-null result group
     * @return {@code true} if at least one group produced a {@code null} result
     */
    private static boolean RMMSingle(List<AColGroup> filteredGroups, MatrixBlock that, List<AColGroup> retCg) {
        boolean containsNull = false;
        final IColIndex allCols = ColIndexFactory.create(that.getNumColumns());
        for (AColGroup g : filteredGroups) {
            final AColGroup retG = g.rightMultByMatrix(that, allCols, 1);
            if (retG == null) {
                containsNull = true;
            } else {
                retCg.add(retG);
            }
        }
        return containsNull;
    }

    /**
     * Multiplies each column group by {@code that} as one task per group on a thread pool.
     *
     * @param filteredGroups column groups to multiply
     * @param that           right-hand side
     * @param retCg          list receiving every non-null result group, in task order
     * @param k              degree of parallelism; also forwarded to each group multiply
     * @return {@code true} if at least one group produced a {@code null} result
     * @throws Exception if any task fails
     */
    private static boolean RMMParallel(List<AColGroup> filteredGroups, MatrixBlock that, List<AColGroup> retCg, int k) throws Exception {
        final ExecutorService pool = CommonThreadPool.get(k);
        boolean containsNull = false;
        try {
            final IColIndex allCols = ColIndexFactory.create(that.getNumColumns());
            final List<RightMatrixMultTask> tasks = new ArrayList<>(filteredGroups.size());
            for (AColGroup group : filteredGroups) {
                tasks.add(new RightMatrixMultTask(group, that, allCols, k));
            }
            for (Future<AColGroup> future : pool.invokeAll(tasks)) {
                final AColGroup g = future.get();
                if (g == null) {
                    containsNull = true;
                } else {
                    retCg.add(g);
                }
            }
        }
        finally {
            pool.shutdown();
        }
        return containsNull;
    }

    /** Task multiplying a single column group by the right-hand side matrix. */
    private static class RightMatrixMultTask
    implements Callable<AColGroup> {
        private final AColGroup _colGroup;
        private final MatrixBlock _b;
        private final IColIndex _allCols;
        private final int _k;

        protected RightMatrixMultTask(AColGroup colGroup, MatrixBlock b, IColIndex allCols, int k) {
            this._colGroup = colGroup;
            this._b = b;
            this._allCols = allCols;
            this._k = k;
        }

        @Override
        public AColGroup call() throws Exception {
            return this._colGroup.rightMultByMatrix(this._b, this._allCols, this._k);
        }
    }
}

