/*
 * Decompiled with CFR 0.152.
 */
package hex.util;

import Jama.EigenvalueDecomposition;
import Jama.Matrix;
import Jama.QRDecomposition;
import hex.DataInfo;
import hex.FrameTask;
import hex.Interaction;
import hex.ToEigenVec;
import hex.gam.MatrixFrameUtils.TriDiagonalMatrix;
import hex.gram.Gram;
import hex.util.EigenPair;
import java.util.ArrayList;
import java.util.Arrays;
import jsr166y.ForkJoinTask;
import jsr166y.RecursiveAction;
import org.apache.commons.lang.ArrayUtils;
import water.DKV;
import water.Job;
import water.Key;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.Log;

public class LinearAlgebraUtils {
    /**
     * Shared {@link ToEigenVec} instance whose {@code toEigenVec} simply delegates to the
     * static {@code LinearAlgebraUtils.toEigen(Vec)} method of this class.
     */
    public static ToEigenVec toEigen = new ToEigenVec(){

        @Override
        public Vec toEigenVec(Vec src) {
            return LinearAlgebraUtils.toEigen(src);
        }
    };

    /**
     * Solves the lower-triangular system {@code L * x = b} by forward substitution.
     * Only the entries of {@code L2} on and below the diagonal are read.
     *
     * @param L2 coefficient matrix with as many rows as {@code b2} has entries
     * @param b2 right-hand side; it is not modified
     * @return a newly allocated solution vector of length {@code b2.length}
     */
    public static double[] forwardSolve(double[][] L2, double[] b2) {
        assert (L2 != null && L2.length == b2.length);
        final int n = b2.length;
        final double[] x = new double[n];
        for (int row = 0; row < n; ++row) {
            final double[] coeffs = L2[row];
            // Start from b[row] and remove the contribution of the already-solved unknowns.
            double acc = b2[row];
            for (int col = 0; col < row; ++col) {
                acc -= coeffs[col] * x[col];
            }
            x[row] = acc / coeffs[row];
        }
        return x;
    }

    /**
     * Returns the square roots of the main-diagonal entries of {@code aMat}.
     *
     * @param aMat matrix whose {@code [i][i]} entries are read for every row {@code i}
     * @return a new array with one entry per row of {@code aMat}
     */
    public static double[] sqrtDiag(double[][] aMat) {
        final double[] roots = new double[aMat.length];
        for (int i = 0; i < roots.length; ++i) {
            roots[i] = Math.sqrt(aMat[i][i]);
        }
        return roots;
    }

    /**
     * Computes an inverse from a triangular Cholesky factor in two parallel phases:
     * each unit vector is first run through {@link #forwardSolve} against the lower factor,
     * and the result is then run through {@link #backwardSolve} against the upper factor.
     *
     * @param cholR      the triangular factor
     * @param upperTriag {@code true} if {@code cholR} is the upper factor (it is transposed to
     *                   obtain the lower one); {@code false} if it is the lower factor
     * @return a new matrix whose entry {@code i} is the solution obtained for unit vector {@code i}
     */
    public static double[][] chol2Inv(double[][] cholR, boolean upperTriag) {
        final int matrixSize = cholR.length;
        final double[][] cholL = upperTriag ? water.util.ArrayUtils.transpose(cholR) : cholR;
        final double[][] inverted = new double[matrixSize][];
        final RecursiveAction[] ras = new RecursiveAction[matrixSize];
        for (int index = 0; index < matrixSize; ++index) {
            final int i2 = index;
            ras[i2] = new RecursiveAction(){

                @Override
                protected void compute() {
                    // Build the unit vector inside the task so only running tasks hold one.
                    double[] unit = new double[matrixSize];
                    unit[i2] = 1.0;
                    // forwardSolve returns a fresh array, so no defensive copy is needed.
                    inverted[i2] = LinearAlgebraUtils.forwardSolve(cholL, unit);
                }
            };
        }
        ForkJoinTask.invokeAll(ras);
        final double[][] cholRNew = upperTriag ? cholR : water.util.ArrayUtils.transpose(cholR);
        for (int index = 0; index < matrixSize; ++index) {
            final int i3 = index;
            ras[i3] = new RecursiveAction(){

                @Override
                protected void compute() {
                    // backwardSolve fills and returns the buffer passed as its third argument.
                    inverted[i3] = LinearAlgebraUtils.backwardSolve(cholRNew, inverted[i3], new double[matrixSize]);
                }
            };
        }
        ForkJoinTask.invokeAll(ras);
        return inverted;
    }

    /**
     * Convenience overload of {@link #chol2Inv(double[][], boolean)} that treats
     * {@code cholR} as the upper-triangular factor.
     */
    public static double[][] chol2Inv(double[][] cholR) {
        final boolean upperTriag = true;
        return chol2Inv(cholR, upperTriag);
    }

    /**
     * Builds the lower-triangular storage of a symmetric tridiagonal matrix from {@code hj}.
     * Row {@code i} has {@code i + 1} entries: the diagonal {@code (hj[i] + hj[i+1]) / 3} and,
     * for {@code i > 0}, the sub-diagonal {@code hj[i] / 6}; all other entries are zero.
     *
     * @param hj input values; the result has {@code hj.length - 1} rows
     * @return the jagged lower-triangular rows
     */
    public static double[][] generateTriDiagMatrix(double[] hj) {
        final int matrixSize = hj.length - 1;
        final double oneO3 = 1.0 / 3.0;
        final double oneO6 = 1.0 / 6.0;
        final double[][] lowDiag = new double[matrixSize][];
        final RecursiveAction[] ras = new RecursiveAction[matrixSize];
        for (int index = 0; index < matrixSize; ++index) {
            final int i2 = index;
            final double hjIndex = hj[index];
            final double hjIndexP1 = hj[index + 1];
            final double[] rowDiag = MemoryManager.malloc8d(index + 1);
            ras[i2] = new RecursiveAction(){

                @Override
                protected void compute() {
                    rowDiag[i2] = (hjIndex + hjIndexP1) * oneO3;
                    if (i2 > 0) {
                        rowDiag[i2 - 1] = hjIndex * oneO6;
                    }
                    // The row buffer is private to this task, so it is published directly.
                    lowDiag[i2] = rowDiag;
                }
            };
        }
        ForkJoinTask.invokeAll(ras);
        return lowDiag;
    }

    /**
     * Generates {@code numBasis} vectors by drawing seeded Gaussian vectors, subtracting from
     * each its components along the columns of {@code orthMat}, and then applying
     * {@link #applyGramSchmit} to the set.
     *
     * NOTE(review): {@code starT} is never read in this method.
     * NOTE(review): {@code orthMatCompT2}/{@code orthMatCompT3} are computed and
     * orthonormalized but never contribute to the returned value — looks like dead work;
     * confirm before removing.
     *
     * @param orthMat  matrix whose columns are the vectors to be orthogonal to
     * @param starT    unused
     * @param numBasis number of vectors to generate
     * @param seed     base seed; vector {@code i} is drawn with {@code seed + i}
     * @return a {@code numBasis x orthMat.length} array, one generated vector per row
     */
    public static double[][] generateOrthogonalComplement(double[][] orthMat, double[][] starT, int numBasis, long seed) {
        int index;
        int numOrthVec = orthMat[0].length;
        int vecSize = orthMat.length;
        double[][] orthMatT = water.util.ArrayUtils.transpose(orthMat);
        double[][] orthMatCompT = new double[numBasis][vecSize];
        double[][] orthMatCompT2 = new double[numBasis][vecSize];
        double[] innerProd = new double[numOrthVec];
        double[] scaleProd = new double[vecSize];
        // I - orthMat * orthMat^T; its first numBasis rows are copied and orthonormalized below.
        double[][] orthMatCompT3 = water.util.ArrayUtils.subtract(LinearAlgebraUtils.generateIdentityMat(vecSize), water.util.ArrayUtils.multArrArr(orthMat, orthMatT));
        for (index = 0; index < numBasis; ++index) {
            System.arraycopy(orthMatCompT3[index], 0, orthMatCompT2[index], 0, vecSize);
        }
        LinearAlgebraUtils.applyGramSchmit(orthMatCompT2);
        for (index = 0; index < numBasis; ++index) {
            // Draw a Gaussian vector, then remove its projection onto every column of orthMat.
            orthMatCompT[index] = water.util.ArrayUtils.gaussianVector(seed + (long)index, orthMatCompT[index]);
            LinearAlgebraUtils.genInnerProduct(orthMatT, orthMatCompT[index], innerProd);
            for (int basisInd = 0; basisInd < numOrthVec; ++basisInd) {
                System.arraycopy(orthMatT[basisInd], 0, scaleProd, 0, vecSize);
                water.util.ArrayUtils.mult(scaleProd, innerProd[basisInd]);
                water.util.ArrayUtils.subtract(orthMatCompT[index], scaleProd, orthMatCompT[index]);
            }
        }
        LinearAlgebraUtils.applyGramSchmit(orthMatCompT);
        return orthMatCompT;
    }

    /**
     * Creates a {@code size x size} identity matrix.
     *
     * @param size number of rows and columns
     * @return a new matrix with ones on the diagonal and zeros elsewhere
     */
    public static double[][] generateIdentityMat(int size) {
        final double[][] eye = new double[size][size];
        int diag = 0;
        for (double[] row : eye) {
            row[diag++] = 1.0;
        }
        return eye;
    }

    /**
     * Runs a Jama QR decomposition on {@code starT} and returns the array of its Q factor.
     *
     * @param starT matrix to decompose
     * @return the array returned by {@code getQ().getArray()}
     */
    public static double[][] generateQR(double[][] starT) {
        final QRDecomposition decomposition = new QRDecomposition(new Matrix(starT));
        final Matrix q = decomposition.getQ();
        return q.getArray();
    }

    /**
     * Stores the inner product of every row of {@code mat} with {@code vector} into
     * {@code innerProd}.
     *
     * @param mat       rows to multiply
     * @param vector    vector paired with each row
     * @param innerProd output; entry {@code i} receives the product for row {@code i}
     */
    public static void genInnerProduct(double[][] mat, double[] vector, double[] innerProd) {
        int slot = 0;
        for (double[] rowVec : mat) {
            innerProd[slot++] = water.util.ArrayUtils.innerProduct(rowVec, vector);
        }
    }

    /**
     * Orthonormalizes the rows of {@code matT} in place. For each row, the inner products
     * with all rows are taken first, then the scaled earlier rows are subtracted, and finally
     * the row is multiplied by the reciprocal of its L2 norm.
     *
     * @param matT vectors stored one per row; modified in place
     */
    public static void applyGramSchmit(double[][] matT) {
        final int vectorCount = matT.length;
        final int dim = matT[0].length;
        final double[] projections = new double[vectorCount];
        final double[] scaled = new double[dim];
        for (int cur = 0; cur < vectorCount; ++cur) {
            final double[] current = matT[cur];
            // Projections are computed once, before any subtraction touches the current row.
            LinearAlgebraUtils.genInnerProduct(matT, current, projections);
            for (int prev = 0; prev < cur; ++prev) {
                System.arraycopy(matT[prev], 0, scaled, 0, dim);
                water.util.ArrayUtils.mult(scaled, projections[prev]);
                water.util.ArrayUtils.subtract(current, scaled, current);
            }
            final double inverseNorm = 1.0 / water.util.ArrayUtils.l2norm(current);
            water.util.ArrayUtils.mult(current, inverseNorm);
        }
    }

    /**
     * Copies a lower-triangular matrix into rows of full width. Row {@code i} of the result
     * holds entries {@code 0..i} of {@code cholL[i]}; the remaining entries are left as
     * allocated. Rows are processed as parallel fork-join tasks.
     *
     * @param cholL lower-triangular input rows
     * @return a new {@code cholL.length}-row matrix with every row of length {@code cholL.length}
     */
    public static double[][] expandLowTrian2Ful(final double[][] cholL) {
        final int numRows = cholL.length;
        final double[][] result = new double[numRows][];
        final RecursiveAction[] tasks = new RecursiveAction[numRows];
        for (int r = 0; r < numRows; ++r) {
            final int row = r;
            final double[] scratch = MemoryManager.malloc8d(numRows);
            tasks[row] = new RecursiveAction(){

                @Override
                protected void compute() {
                    System.arraycopy(cholL[row], 0, scratch, 0, row + 1);
                    result[row] = Arrays.copyOf(scratch, numRows);
                }
            };
        }
        ForkJoinTask.invokeAll(tasks);
        return result;
    }

    /**
     * Multiplies {@code A2} with every row of {@code B2} in parallel. Result row {@code i} is
     * whatever {@code water.util.ArrayUtils.multArrVec(A2, B2[i], buffer)} writes into a
     * buffer of length {@code A2[0].length}.
     *
     * The task list is sized by {@code B2.length}, matching both the result and the
     * {@code B2[i]} lookups. Sizing it by {@code A2.length} fails whenever the two lengths
     * differ (index out of bounds, or result rows left {@code null}).
     *
     * @param A2 left operand, passed unchanged to {@code multArrVec}
     * @param B2 right operand; one result row is produced per row of {@code B2}
     * @return a new array with {@code B2.length} rows
     */
    public static double[][] matrixMultiply(final double[][] A2, final double[][] B2) {
        final int arow = A2[0].length;
        final int bcol = B2.length;
        final double[][] result = new double[bcol][];
        final RecursiveAction[] ras = new RecursiveAction[bcol];
        for (int index = 0; index < bcol; ++index) {
            final int i2 = index;
            ras[i2] = new RecursiveAction(){

                @Override
                protected void compute() {
                    double[] product = new double[arow];
                    water.util.ArrayUtils.multArrVec(A2, B2[i2], product);
                    result[i2] = product;
                }
            };
        }
        ForkJoinTask.invokeAll(ras);
        return result;
    }

    /**
     * Multiplies {@code A2} by the banded matrix described by {@code B2}, one column of
     * {@code B2} per parallel task. For each of the {@code B2._size + 2} columns, a dense
     * column vector of length {@code B2._size} is assembled from {@code _first_diag},
     * {@code _second_diag} and {@code _third_diag} and passed to
     * {@code water.util.ArrayUtils.multArrVec}.
     *
     * @param A2              left operand; the per-column result has {@code A2.length} entries
     * @param B2              banded right operand
     * @param transposeResult if {@code true} the per-column results are transposed before
     *                        being returned, otherwise they are returned one column per row
     * @return the product in the requested orientation
     */
    public static double[][] matrixMultiplyTriagonal(final double[][] A2, final TriDiagonalMatrix B2, boolean transposeResult) {
        final int arow = A2.length;
        final int bcol = B2._size + 2;
        final int lastCol = bcol - 1;
        final int secondLastCol = bcol - 2;
        final int kMinus1 = bcol - 3;
        final int kMinus2 = bcol - 4;
        final double[][] result = new double[bcol][];
        final RecursiveAction[] ras = new RecursiveAction[bcol];
        for (int index = 0; index < bcol; ++index) {
            final int i2 = index;
            final double[] tempResult = new double[arow];
            final double[] bColVec = new double[B2._size];
            ras[i2] = new RecursiveAction(){

                @Override
                protected void compute() {
                    // The first two and last two columns carry fewer than three non-zeros.
                    if (i2 == 0) {
                        bColVec[0] = B2._first_diag[0];
                    } else if (i2 == 1) {
                        bColVec[0] = B2._second_diag[0];
                        if (B2._first_diag.length > 1) {
                            bColVec[1] = B2._first_diag[1];
                        }
                    } else if (i2 == lastCol) {
                        bColVec[kMinus1] = B2._third_diag[kMinus1];
                    } else if (i2 == secondLastCol) {
                        bColVec[kMinus2] = B2._third_diag[kMinus2];
                        bColVec[kMinus1] = B2._second_diag[kMinus1];
                    } else {
                        bColVec[i2 - 2] = B2._third_diag[i2 - 2];
                        bColVec[i2 - 1] = B2._second_diag[i2 - 1];
                        bColVec[i2] = B2._first_diag[i2];
                    }
                    water.util.ArrayUtils.multArrVec(A2, bColVec, tempResult);
                    result[i2] = Arrays.copyOf(tempResult, arow);
                }
            };
        }
        ForkJoinTask.invokeAll(ras);
        // Both branches are double[][]; no cast is needed (a cast to Object would not compile).
        if (transposeResult) {
            return water.util.ArrayUtils.transpose(result);
        }
        return result;
    }

    /**
     * Solves the upper-triangular system {@code U * x = b} by backward substitution.
     * Only the entries of {@code L2} on and above the diagonal are read.
     *
     * @param L2  square coefficient matrix with as many rows as {@code b2} has entries
     * @param b2  right-hand side
     * @param res optional output buffer; when {@code null} a new array is allocated
     * @return the solution, which is {@code res} itself when {@code res} was non-null
     */
    public static double[] backwardSolve(double[][] L2, double[] b2, double[] res) {
        assert (L2 != null && L2.length == L2[0].length && L2.length == b2.length);
        final double[] x = (res == null) ? new double[b2.length] : res;
        final int last = b2.length - 1;
        for (int row = last; row >= 0; --row) {
            final double[] coeffs = L2[row];
            // Start from b[row] and remove the contribution of the unknowns solved so far.
            double acc = b2[row];
            for (int col = last; col > row; --col) {
                acc -= coeffs[col] * x[col];
            }
            x[row] = acc / coeffs[row];
        }
        return x;
    }

    /**
     * Applies the imputation and normalization settings of {@code dinfo} to one numeric value.
     * A NaN is replaced by {@code dinfo._numNAFill[col]} when {@code dinfo._imputeMissing} is
     * set; the value is then shifted and scaled if both {@code _normSub} and {@code _normMul}
     * are present.
     *
     * @param x2    raw value
     * @param col   index into the numeric fill and normalization arrays
     * @param dinfo source of the imputation and normalization settings
     * @return the adjusted value
     */
    private static double modifyNumeric(double x2, int col, DataInfo dinfo) {
        double value = x2;
        if (Double.isNaN(x2) && dinfo._imputeMissing) {
            value = dinfo._numNAFill[col];
        }
        if (dinfo._normSub == null || dinfo._normMul == null) {
            return value;
        }
        return (value - dinfo._normSub[col]) * dinfo._normMul[col];
    }

    /**
     * Expands one raw row into {@code tmp} using the layout described by {@code dinfo}.
     * The first {@code dinfo._cats} entries of {@code row} are treated as categoricals: each
     * selects one slot of {@code tmp}, which is set to {@code 1.0}, or to the raw value when
     * the column occupies exactly one slot. The following {@code dinfo._nums} entries are
     * copied to {@code tmp} starting at {@code dinfo.numStart()}.
     *
     * @param row            raw row, categoricals first, then numerics
     * @param dinfo          layout and missing-value settings
     * @param tmp            output buffer; only the selected slots are written
     * @param modify_numeric whether numerics go through {@code modifyNumeric} before storing
     * @return {@code tmp}
     */
    public static double[] expandRow(double[] row, DataInfo dinfo, double[] tmp, boolean modify_numeric) {
        for (int col = 0; col < dinfo._cats; ++col) {
            final boolean singleSlot = dinfo._catOffsets[col + 1] - dinfo._catOffsets[col] == 1;
            final int cidx;
            if (!Double.isNaN(row[col])) {
                cidx = singleSlot ? dinfo.getCategoricalId(col, 0) : dinfo.getCategoricalId(col, (int)row[col]);
            } else if (dinfo._imputeMissing) {
                cidx = dinfo.catNAFill()[col];
            } else if (dinfo._catMissing[col]) {
                // Missing values occupy the last slot of this column's range.
                cidx = dinfo._catOffsets[col + 1] - 1;
            } else {
                continue;
            }
            if (cidx >= 0) {
                tmp[cidx] = singleSlot ? row[col] : 1.0;
            }
        }
        final int firstNumeric = dinfo._cats;
        final int numStart = dinfo.numStart();
        for (int num = 0; num < dinfo._nums; ++num) {
            final double raw = row[firstNumeric + num];
            tmp[numStart + num] = modify_numeric ? LinearAlgebraUtils.modifyNumeric(raw, num, dinfo) : raw;
        }
        return tmp;
    }

    /**
     * Reshapes a flat array into an {@code m4 x n2} matrix, filling rows in order.
     *
     * @param arr flat source holding at least {@code m4 * n2} values
     * @param m4  number of rows
     * @param n2  number of columns
     * @return a new matrix whose row {@code i} is {@code arr[i*n2 .. i*n2 + n2 - 1]}
     */
    public static double[][] reshape1DArray(double[] arr, int m4, int n2) {
        final double[][] grid = new double[m4][n2];
        int offset = 0;
        for (double[] gridRow : grid) {
            System.arraycopy(arr, offset, gridRow, 0, n2);
            offset += n2;
        }
        return grid;
    }

    /**
     * Pairs each eigenvalue with the eigenvector at the same index and sorts the pairs with
     * {@link Arrays#sort(Object[])}, i.e. by the natural ordering of {@code EigenPair}.
     *
     * @param eigenvalues  eigenvalues
     * @param eigenvectors eigenvectors; entry {@code i} belongs to {@code eigenvalues[i]}
     * @return a new sorted array of pairs
     */
    public static EigenPair[] createSortedEigenpairs(double[] eigenvalues, double[][] eigenvectors) {
        final int count = eigenvalues.length;
        // Typed as EigenPair[]: an Object[] local cannot be returned as EigenPair[].
        final EigenPair[] eigenPairs = new EigenPair[count];
        for (int i = 0; i < count; ++i) {
            eigenPairs[i] = new EigenPair(eigenvalues[i], eigenvectors[i]);
        }
        Arrays.sort(eigenPairs);
        return eigenPairs;
    }

    /**
     * Same as {@link #createSortedEigenpairs(double[], double[][])} but with the sorted
     * order reversed.
     *
     * @param eigenvalues  eigenvalues
     * @param eigenvectors eigenvectors; entry {@code i} belongs to {@code eigenvalues[i]}
     * @return a new array of pairs in reverse sorted order
     */
    public static EigenPair[] createReverseSortedEigenpairs(double[] eigenvalues, double[][] eigenvectors) {
        // Typed as EigenPair[]: an Object[] local cannot be returned as EigenPair[].
        final EigenPair[] eigenPairs = LinearAlgebraUtils.createSortedEigenpairs(eigenvalues, eigenvectors);
        ArrayUtils.reverse((Object[])eigenPairs);
        return eigenPairs;
    }

    /**
     * Collects the {@code eigenvalue} field of every pair, preserving order.
     *
     * @param eigenPairs source pairs
     * @return a new array with one eigenvalue per pair
     */
    public static double[] extractEigenvaluesFromEigenpairs(EigenPair[] eigenPairs) {
        final double[] values = new double[eigenPairs.length];
        int next = 0;
        for (EigenPair pair : eigenPairs) {
            values[next++] = pair.eigenvalue;
        }
        return values;
    }

    /**
     * Collects the {@code eigenvector} field of every pair, preserving order. The vectors
     * are referenced, not copied.
     *
     * @param eigenPairs source pairs
     * @return a new outer array with one eigenvector reference per pair
     */
    public static double[][] extractEigenvectorsFromEigenpairs(EigenPair[] eigenPairs) {
        final double[][] vectors = new double[eigenPairs.length][];
        int next = 0;
        for (EigenPair pair : eigenPairs) {
            vectors[next++] = pair.eigenvector;
        }
        return vectors;
    }

    /**
     * Overwrites the diagonal and first sub-diagonal of {@code xx} in place, row by row.
     * The sub-diagonal entry of row {@code r} becomes
     * {@code (xx[r][r-1] - xx[r][r-2]) / xx[r-1][r-1]} (the {@code xx[r][r-2]} term is only
     * used from row 2 on), and the diagonal becomes
     * {@code sqrt(xx[r][r] - xx[r][r-1]^2)} using the updated sub-diagonal.
     *
     * @param xx matrix to update; must have at least one row
     */
    public static void choleskySymDiagMat(double[][] xx) {
        xx[0][0] = Math.sqrt(xx[0][0]);
        for (int row = 1; row < xx.length; ++row) {
            final int prev = row - 1;
            double numerator = xx[row][prev];
            if (prev > 0) {
                numerator -= xx[row][prev - 1];
            }
            final double subDiag = numerator / xx[prev][prev];
            xx[row][prev] = subDiag;
            xx[row][row] = Math.sqrt(xx[row][row] - subDiag * subDiag);
        }
    }

    /**
     * Runs a {@code Gram.GramTask} over {@code yinfo._adaptedFrame}, takes the Cholesky
     * factor {@code L} of the resulting Gram matrix and scales it by {@code sqrt(nobs)}.
     *
     * @param jobKey    job key handed to the Gram task
     * @param yinfo     data description whose adapted frame is processed
     * @param transpose when {@code true} the scaled {@code L} is returned as is; when
     *                  {@code false} its transpose is returned
     * @return the scaled factor in the requested orientation
     */
    public static double[][] computeR(Key<Job> jobKey, DataInfo yinfo, boolean transpose) {
        final Gram.GramTask gramTask = new Gram.GramTask(jobKey, yinfo);
        gramTask.doAll(yinfo._adaptedFrame);
        final Gram.Cholesky cholesky = gramTask._gram.cholesky(null);
        final double[][] cholL = cholesky.getL();
        water.util.ArrayUtils.mult(cholL, Math.sqrt(gramTask._nobs));
        if (transpose) {
            return cholL;
        }
        return water.util.ArrayUtils.transpose(cholL);
    }

    /**
     * Computes the factor via {@code computeR(jobKey, yinfo, true)}, runs a
     * {@link ForwardSolve} task with it over {@code ywfrm}, and returns the task's
     * {@code _sse}.
     *
     * NOTE(review): the incoming {@code xx} is never read and the caller's array is never
     * written; the factor only lives in a local. Confirm whether callers expect it filled.
     *
     * @param jobKey job key handed to {@code computeR}
     * @param yinfo  data description used for the factor and the solve
     * @param ywfrm  frame the {@link ForwardSolve} task runs over
     * @param xx     ignored
     * @return the {@code _sse} of the finished {@link ForwardSolve} task
     */
    public static double computeQ(Key<Job> jobKey, DataInfo yinfo, Frame ywfrm, double[][] xx) {
        final double[][] factor = LinearAlgebraUtils.computeR(jobKey, yinfo, true);
        final ForwardSolve solver = new ForwardSolve(yinfo, factor);
        solver.doAll(ywfrm);
        return solver._sse;
    }

    /**
     * Computes the factor via {@code computeR(jobKey, yinfo, true)}, runs a
     * {@link ForwardSolve} task with it over {@code ywfrm}, and returns the factor.
     *
     * @param jobKey job key handed to {@code computeR}
     * @param yinfo  data description used for the factor and the solve
     * @param ywfrm  frame the {@link ForwardSolve} task runs over
     * @return the factor that was used for the solve
     */
    public static double[][] computeQ(Key<Job> jobKey, DataInfo yinfo, Frame ywfrm) {
        final double[][] factor = LinearAlgebraUtils.computeR(jobKey, yinfo, true);
        new ForwardSolve(yinfo, factor).doAll(ywfrm);
        return factor;
    }

    /**
     * Computes the factor via {@code computeR(jobKey, yinfo, true)} and runs a
     * {@link ForwardSolveInPlace} task with it over {@code yinfo._adaptedFrame}, which
     * overwrites that frame's values.
     *
     * @param jobKey job key handed to {@code computeR}
     * @param yinfo  data description whose adapted frame is overwritten
     * @return the factor that was used for the solve
     */
    public static double[][] computeQInPlace(Key<Job> jobKey, DataInfo yinfo) {
        final double[][] factor = LinearAlgebraUtils.computeR(jobKey, yinfo, true);
        new ForwardSolveInPlace(yinfo, factor).doAll(yinfo._adaptedFrame);
        return factor;
    }

    /**
     * Counts the columns of {@code fr} after categorical expansion. A categorical vec with a
     * non-null domain counts as its domain length (minus one when
     * {@code useAllFactorLevels} is {@code false}); every other vec counts as one.
     *
     * @param fr                 frame to inspect
     * @param useAllFactorLevels whether every level of a categorical gets its own column
     * @return the expanded column count
     */
    public static int numColsExp(Frame fr, boolean useAllFactorLevels) {
        final int droppedLevels = useAllFactorLevels ? 0 : 1;
        int total = 0;
        for (Vec vec : fr.vecs()) {
            if (vec.isCategorical() && vec.domain() != null) {
                total += vec.domain().length - droppedLevels;
            } else {
                total += 1;
            }
        }
        return total;
    }

    /**
     * Scales {@code diagYY} in place by {@code nTot}, builds a matrix from the scaled
     * values, and returns the entry of Jama's {@code getV().getArray()} at the index of the
     * largest real eigenvalue. NaN entries of the matrix are replaced by zero before the
     * decomposition.
     *
     * NOTE(review): the returned array is row {@code maxIndex} of {@code V}; Jama documents
     * eigenvectors as columns of {@code V} — confirm this is intended.
     *
     * @param diagYY diagonal values; overwritten with {@code diagYY[i] * nTot}
     * @param nTot   multiplier for the diagonal and divisor of the cross term
     * @param nVars  divisor applied to every matrix entry
     * @return the selected row of the eigenvector matrix
     */
    static double[] multiple(double[] diagYY, int nTot, int nVars) {
        final int ny = diagYY.length;
        for (int i = 0; i < ny; ++i) {
            diagYY[i] *= nTot;
        }
        final double[][] uu = new double[ny][ny];
        for (int r = 0; r < ny; ++r) {
            for (int c = 0; c < ny; ++c) {
                final double yyij = (r == c) ? diagYY[r] : 0.0;
                final double entry = (yyij - diagYY[r] * diagYY[c] / (double)nTot) / ((double)nVars * Math.sqrt(diagYY[r] * diagYY[c]));
                uu[r][c] = Double.isNaN(entry) ? 0.0 : entry;
            }
        }
        final EigenvalueDecomposition eigen = new EigenvalueDecomposition(new Matrix(uu));
        final double[] eigenvalues = eigen.getRealEigenvalues();
        final double[][] eigenvectors = eigen.getV().getArray();
        return eigenvectors[water.util.ArrayUtils.maxIndex(eigenvalues)];
    }

    /**
     * Computes the per-level projection array for a categorical column by running a
     * {@code Gram.GramTask} on a one-column frame and passing the Gram diagonal to
     * {@link #multiple(double[], int, int)}.
     *
     * @param src the column to process
     * @return the array returned by {@code multiple}
     */
    public static double[] toEigenArray(Vec src) {
        Key<Frame> source = Key.make();
        Key<Frame> dest = Key.make();
        Frame train = new Frame(source, new String[]{"enum"}, new Vec[]{src});
        int maxLevels = 1024;
        boolean created = false;
        // Columns with more than maxLevels levels are first reduced through an Interaction;
        // only in that case is the wrapper frame published to the DKV.
        if (src.cardinality() > maxLevels) {
            DKV.put(train);
            created = true;
            Log.info("Reducing the cardinality of a categorical column with " + src.cardinality() + " levels to " + maxLevels);
            train = Interaction.getInteraction(train._key, train.names(), maxLevels).execImpl(dest).get();
        }
        DataInfo dinfo = new DataInfo(train, null, 0, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, false, true, false, false, false, false, false);
        DKV.put(dinfo);
        Gram.GramTask gtsk = (Gram.GramTask)new Gram.GramTask(null, dinfo).doAll(dinfo._adaptedFrame);
        double[] rounded = new double[gtsk._gram._diag.length];
        for (int i2 = 0; i2 < rounded.length; ++i2) {
            // The cast narrows each diagonal entry to float precision before it is used.
            rounded[i2] = (float)gtsk._gram._diag[i2];
        }
        dinfo.remove();
        double[] array = LinearAlgebraUtils.multiple(rounded, (int)gtsk._nobs, 1);
        // At this point 'train' is the Interaction result; the wrapper frame is removed by key.
        if (created) {
            train.remove();
            DKV.remove(source);
        }
        return array;
    }

    /**
     * Maps a categorical column to a numeric column by replacing every level with the
     * matching entry of {@link #toEigenArray(Vec)}, using a {@link ProjectOntoEigenVector}
     * task with one output column.
     *
     * @param src the column to convert
     * @return a vec of the task's output frame
     */
    public static Vec toEigen(Vec src) {
        Key<Frame> source = Key.make();
        Key<Frame> dest = Key.make();
        Frame train = new Frame(source, new String[]{"enum"}, new Vec[]{src});
        int maxLevels = 1024;
        boolean created = false;
        // Columns with more than maxLevels levels are first reduced through an Interaction;
        // only in that case is the wrapper frame published to the DKV.
        if (src.cardinality() > maxLevels) {
            DKV.put(train);
            created = true;
            Log.info("Reducing the cardinality of a categorical column with " + src.cardinality() + " levels to " + maxLevels);
            train = Interaction.getInteraction(train._key, train.names(), maxLevels).execImpl(dest).get();
        }
        // The projection array is computed from the original 'src', while the task runs over 'train'.
        Vec v2 = ((ProjectOntoEigenVector)new ProjectOntoEigenVector(LinearAlgebraUtils.toEigenArray(src)).doAll(1, (byte)3, train)).outputFrame().anyVec();
        // At this point 'train' is the Interaction result; the wrapper frame is removed by key.
        if (created) {
            train.remove();
            DKV.remove(source);
        }
        return v2;
    }

    /**
     * Concatenates the {@link #toEigenArray(Vec)} results of all categorical columns of
     * {@code _origTrain}, in column order.
     *
     * @param _origTrain frame whose categorical columns are processed
     * @param _train     frame compared by reference against {@code _origTrain}
     * @param expensive  whether the computation is allowed
     * @return the concatenated projections, or {@code null} when {@code expensive} is
     *         {@code false}, {@code _origTrain} is {@code null}, or both frames are the
     *         same object
     */
    public static double[] toEigenProjectionArray(Frame _origTrain, Frame _train, boolean expensive) {
        if (!expensive || _origTrain == null || _origTrain == _train) {
            return null;
        }
        final ArrayList<double[]> perColumn = new ArrayList<double[]>();
        int totalLength = 0;
        for (int c = 0; c < _origTrain.numCols(); ++c) {
            final Vec column = _origTrain.vec(c);
            if (column.isCategorical()) {
                final double[] projection = LinearAlgebraUtils.toEigenArray(column);
                perColumn.add(projection);
                totalLength += projection.length;
            }
        }
        final double[] flat = new double[totalLength];
        int offset = 0;
        for (double[] projection : perColumn) {
            System.arraycopy(projection, 0, flat, offset, projection.length);
            offset += projection.length;
        }
        return flat;
    }

    /**
     * Renders a rectangular matrix as text: one line per row, every value formatted with
     * {@code "%.4f"} followed by a tab, and strictly positive values prefixed with a space.
     *
     * @param matrix matrix to render
     * @return the rendered text; {@code ""} for a matrix without rows, or
     *         {@code "Stacked matrix!"} when the rows do not all have the same length
     */
    public static String getMatrixInString(double[][] matrix) {
        if (matrix.length <= 0) {
            return "";
        }
        final int width = matrix[0].length;
        for (int r = 1; r < matrix.length; ++r) {
            if (matrix[r].length != width) {
                return "Stacked matrix!";
            }
        }
        final StringBuilder text = new StringBuilder();
        for (double[] rowValues : matrix) {
            for (double value : rowValues) {
                if (value > 0.0) {
                    text.append(' ');
                }
                text.append(String.format("%.4f\t", value));
            }
            text.append('\n');
        }
        return text.toString();
    }

    /**
     * Task that replaces each value of the first input column with the entry of
     * {@code _yCoord} at that value's index, narrowed to float precision. NA rows stay NA.
     */
    static class ProjectOntoEigenVector extends MRTask<ProjectOntoEigenVector> {
        /** Lookup table indexed by the integer value read from the input column. */
        final double[] _yCoord;

        ProjectOntoEigenVector(double[] yCoord) {
            this._yCoord = yCoord;
        }

        @Override
        public void map(Chunk[] cs, NewChunk[] nc) {
            final Chunk levels = cs[0];
            final NewChunk projected = nc[0];
            for (int row = 0; row < levels._len; ++row) {
                if (!levels.isNA(row)) {
                    final int level = (int)levels.at8(row);
                    projected.addNum((float)this._yCoord[level]);
                } else {
                    projected.addNA();
                }
            }
        }
    }

    /**
     * Task that, for every usable row, expands the row through {@code _ainfo}, solves
     * {@code _L * q = row} with {@link LinearAlgebraUtils#forwardSolve}, and writes
     * {@code q} back over the same columns. Rows reported as bad are left untouched.
     */
    public static class ForwardSolveInPlace extends MRTask<ForwardSolveInPlace> {
        final DataInfo _ainfo;
        final int _ncols;
        final double[][] _L;

        /**
         * @param ainfo data description; its adapted frame fixes the column count
         * @param L2    square factor with one row per column of the adapted frame
         */
        public ForwardSolveInPlace(DataInfo ainfo, double[][] L2) {
            assert (L2 != null && L2.length == L2[0].length && L2.length == ainfo._adaptedFrame.numCols());
            this._ainfo = ainfo;
            this._ncols = ainfo._adaptedFrame.numCols();
            this._L = L2;
        }

        @Override
        public void map(Chunk[] cs) {
            assert (this._ncols == cs.length);
            final Chunk[] inputChunks = new Chunk[this._ncols];
            System.arraycopy(cs, 0, inputChunks, 0, this._ncols);
            for (int row = 0; row < cs[0]._len; ++row) {
                final DataInfo.Row denseRow = this._ainfo.newDenseRow();
                this._ainfo.extractDenseRow(inputChunks, row, denseRow);
                if (!denseRow.isBad()) {
                    final double[] expanded = denseRow.expandCats();
                    final double[] solved = LinearAlgebraUtils.forwardSolve(this._L, expanded);
                    assert (solved.length == this._ncols);
                    for (int col = 0; col < this._ncols; ++col) {
                        cs[col].set(row, solved[col]);
                    }
                }
            }
        }
    }

    /**
     * Task over a frame with {@code 2 * _ncols} columns. For every usable row, the first
     * {@code _ncols} columns are expanded through {@code _ainfo} and solved against
     * {@code _L} with {@link LinearAlgebraUtils#forwardSolve}. The solution overwrites the
     * last {@code _ncols} columns, and the squared differences between the old and new
     * values are accumulated in {@link #_sse}.
     */
    public static class ForwardSolve
    extends MRTask<ForwardSolve> {
        final DataInfo _ainfo;
        final int _ncols;
        final double[][] _L;
        /** Output: sum of squared differences between the previous and the new solution values. */
        public double _sse;

        /**
         * @param ainfo data description; its adapted frame fixes the column count
         * @param L2    square factor with one row per column of the adapted frame
         */
        public ForwardSolve(DataInfo ainfo, double[][] L2) {
            assert (L2 != null && L2.length == L2[0].length && L2.length == ainfo._adaptedFrame.numCols());
            this._ainfo = ainfo;
            this._ncols = ainfo._adaptedFrame.numCols();
            this._L = L2;
            this._sse = 0.0;
        }

        @Override
        public void map(Chunk[] cs) {
            assert (2 * this._ncols == cs.length);
            Chunk[] achks = new Chunk[this._ncols];
            System.arraycopy(cs, 0, achks, 0, this._ncols);
            for (int row = 0; row < cs[0]._len; ++row) {
                DataInfo.Row arow = this._ainfo.newDenseRow();
                this._ainfo.extractDenseRow(achks, row, arow);
                if (arow.isBad()) continue;
                double[] aexp = arow.expandCats();
                double[] qrow = LinearAlgebraUtils.forwardSolve(this._L, aexp);
                int i2 = 0;
                for (int d2 = this._ncols; d2 < 2 * this._ncols; ++d2) {
                    double qold = cs[d2].atd(row);
                    double diff = qrow[i2] - qold;
                    this._sse += diff * diff;
                    cs[d2].set(row, qrow[i2++]);
                }
                assert (i2 == qrow.length);
            }
        }

        /**
         * Combines partial results: without this, the squared-error sums accumulated by
         * other task instances would never be added to this one.
         */
        @Override
        public void reduce(ForwardSolve other) {
            this._sse += other._sse;
        }
    }

    /**
     * Task over a frame holding {@code _ncolA} data columns followed by {@code _ncolQ}
     * further columns. It accumulates, into {@link #_atq}, the sum over rows of each
     * expanded data column times each of the trailing columns. Categorical data columns
     * contribute the trailing value to the slot chosen by {@code _ainfo}; numeric data
     * columns contribute value times trailing value after {@code modifyNumeric}.
     */
    public static class SMulTask extends MRTask<SMulTask> {
        final DataInfo _ainfo;
        final int _ncolA;
        final int _ncolExp;
        final int _ncolQ;
        /** Output: {@code _ncolExp x _ncolQ} accumulator, summed across tasks in {@code reduce}. */
        public double[][] _atq;

        /**
         * @param ainfo data description of the leading columns
         * @param ncolQ number of trailing columns
         */
        public SMulTask(DataInfo ainfo, int ncolQ) {
            this._ainfo = ainfo;
            this._ncolA = ainfo._adaptedFrame.numCols();
            this._ncolExp = LinearAlgebraUtils.numColsExp(ainfo._adaptedFrame, true);
            this._ncolQ = ncolQ;
        }

        /**
         * @param ainfo   data description of the leading columns
         * @param ncolQ   number of trailing columns
         * @param ncolExp number of rows of the accumulator
         */
        public SMulTask(DataInfo ainfo, int ncolQ, int ncolExp) {
            this._ainfo = ainfo;
            this._ncolA = ainfo._adaptedFrame.numCols();
            this._ncolExp = ncolExp;
            this._ncolQ = ncolQ;
        }

        @Override
        public void map(Chunk[] cs) {
            assert (this._ncolA + this._ncolQ == cs.length);
            this._atq = new double[this._ncolExp][this._ncolQ];
            for (int qIdx = 0; qIdx < this._ncolQ; ++qIdx) {
                final int qChunkIdx = this._ncolA + qIdx;

                // Categorical columns: add the trailing value to the selected slot.
                for (int cat = 0; cat < this._ainfo._cats; ++cat) {
                    for (int row = 0; row < cs[0]._len; ++row) {
                        if (cs[cat].isNA(row) && this._ainfo._skipMissing) {
                            continue;
                        }
                        final double qValue = cs[qChunkIdx].atd(row);
                        final double level = cs[cat].atd(row);
                        final int cidx;
                        if (!Double.isNaN(level)) {
                            cidx = this._ainfo.getCategoricalId(cat, (int)level);
                        } else if (this._ainfo._imputeMissing) {
                            cidx = this._ainfo.catNAFill()[cat];
                        } else if (this._ainfo._catMissing[cat]) {
                            cidx = this._ainfo._catOffsets[cat + 1] - 1;
                        } else {
                            continue;
                        }
                        if (cidx >= 0) {
                            this._atq[cidx][qIdx] += qValue;
                        }
                    }
                }

                // Numeric columns: add value times trailing value to consecutive slots.
                int numIdx = 0;
                int expIdx = this._ainfo.numStart();
                for (int col = this._ainfo._cats; col < this._ncolA; ++col) {
                    for (int row = 0; row < cs[0]._len; ++row) {
                        if (cs[col].isNA(row) && this._ainfo._skipMissing) {
                            continue;
                        }
                        final double qValue = cs[qChunkIdx].atd(row);
                        final double aValue = LinearAlgebraUtils.modifyNumeric(cs[col].atd(row), numIdx, this._ainfo);
                        this._atq[expIdx][qIdx] += qValue * aValue;
                    }
                    ++expIdx;
                    ++numIdx;
                }
                assert (expIdx == this._atq.length);
            }
        }

        @Override
        public void reduce(SMulTask other) {
            water.util.ArrayUtils.add(this._atq, other._atq);
        }
    }

    /**
     * Task over a frame holding {@code _ncolX} data columns followed by one column per row
     * of {@code _yt}. For every usable data row, it overwrites the trailing columns with the
     * row's inner product against each row of {@code _yt}. When
     * {@code _originalImplementation} is {@code false}, the entry of that {@code _yt} row at
     * index {@code _ncolX - 1} is subtracted from the product.
     */
    public static class BMulInPlaceTask extends MRTask<BMulInPlaceTask> {
        final DataInfo _xinfo;
        final double[][] _yt;
        final int _ncolX;
        public boolean _originalImplementation = true;

        /**
         * @param xinfo    data description of the leading columns
         * @param yt       vectors to multiply with; each must have {@code nColsExp} entries
         * @param nColsExp expected length of every row of {@code yt}
         */
        public BMulInPlaceTask(DataInfo xinfo, double[][] yt, int nColsExp) {
            this(xinfo, yt, nColsExp, true);
        }

        /**
         * @param xinfo       data description of the leading columns
         * @param yt          vectors to multiply with; each must have {@code nColsExp} entries
         * @param nColsExp    expected length of every row of {@code yt}
         * @param originalWay value for {@code _originalImplementation}
         */
        public BMulInPlaceTask(DataInfo xinfo, double[][] yt, int nColsExp, boolean originalWay) {
            assert (yt != null && yt[0].length == nColsExp);
            this._xinfo = xinfo;
            this._ncolX = xinfo._adaptedFrame.numCols();
            this._yt = yt;
            this._originalImplementation = originalWay;
        }

        @Override
        public void map(Chunk[] cs) {
            assert (cs.length == this._ncolX + this._yt.length);
            final int lastColInd = this._ncolX - 1;
            final Chunk[] dataChunks = new Chunk[this._ncolX];
            final DataInfo.Row dataRow = this._xinfo.newDenseRow();
            System.arraycopy(cs, 0, dataChunks, 0, this._ncolX);
            for (int row = 0; row < cs[0]._len; ++row) {
                this._xinfo.extractDenseRow(dataChunks, row, dataRow);
                if (!dataRow.isBad()) {
                    int target = this._ncolX;
                    for (double[] ps : this._yt) {
                        double product = dataRow.innerProduct(ps);
                        if (!this._originalImplementation) {
                            product = product - ps[lastColInd];
                        }
                        cs[target].set(row, product);
                        ++target;
                    }
                    assert (target == cs.length);
                }
            }
        }
    }

    /**
     * Task that accumulates a product with the frame {@code _y} into the trailing columns of
     * the frame it runs over. For every chunk of {@code _y}, the leading columns selected by
     * that chunk's start offset are multiplied with the chunk's rows and added to the
     * trailing {@code _y.numCols()} columns.
     */
    public static class BMulTaskMatrices extends MRTask<BMulTaskMatrices> {
        final Frame _y;
        final int _nyChunks;
        final int _yColNum;

        /**
         * @param y2 right-hand frame; its chunk count and column count are cached
         */
        public BMulTaskMatrices(Frame y2) {
            this._y = y2;
            this._nyChunks = this._y.anyVec().nChunks();
            this._yColNum = this._y.numCols();
        }

        /** Adds the contribution of one chunk-row of {@code _y} to the trailing columns. */
        private void mulResultPerYChunk(Chunk[] xChunk, Chunk[] yChunk) {
            final int xRows = xChunk[0].len();
            final int yCols = yChunk.length;
            final int yRows = yChunk[0].len();
            final int firstResultCol = xChunk.length - yCols;
            final int firstInputCol = (int)yChunk[0].start();
            for (int yCol = 0; yCol < yCols; ++yCol) {
                final Chunk target = xChunk[yCol + firstResultCol];
                for (int xRow = 0; xRow < xRows; ++xRow) {
                    double sum = target.atd(xRow);
                    for (int k = 0; k < yRows; ++k) {
                        sum += xChunk[k + firstInputCol].atd(xRow) * yChunk[yCol].atd(k);
                    }
                    target.set(xRow, sum);
                }
            }
        }

        @Override
        public void map(Chunk[] xChunk) {
            final Chunk[] yChunks = new Chunk[this._y.numCols()];
            for (int chunkIdx = 0; chunkIdx < this._nyChunks; ++chunkIdx) {
                for (int col = 0; col < this._yColNum; ++col) {
                    yChunks[col] = this._y.vec(col).chunkForChunkIdx(chunkIdx);
                }
                this.mulResultPerYChunk(xChunk, yChunks);
            }
        }
    }

    /**
     * Frame task that emits, for every processed row, the row's inner product with each
     * row of {@code _yt} into the output chunk of the same index.
     */
    public static class BMulTask extends FrameTask<BMulTask> {
        final double[][] _yt;

        /**
         * @param jobKey job key handed to the {@code FrameTask} constructor
         * @param dinfo  data description handed to the {@code FrameTask} constructor
         * @param yt     vectors to multiply every row with
         */
        public BMulTask(Key<Job> jobKey, DataInfo dinfo, double[][] yt) {
            super(jobKey, dinfo);
            this._yt = yt;
        }

        @Override
        protected void processRow(long gid, DataInfo.Row row, NewChunk[] outputs) {
            int out = 0;
            for (double[] yRow : this._yt) {
                final double product = row.innerProduct(yRow);
                outputs[out].addNum(product);
                ++out;
            }
        }
    }

    /**
     * Task that copies the second half of the columns over the first half, row by row:
     * column {@code c} receives the value of column {@code c + cs.length / 2}.
     */
    public static class CopyQtoQMatrix extends MRTask<CopyQtoQMatrix> {
        @Override
        public void map(Chunk[] cs) {
            final int half = cs.length / 2;
            final int rows = cs[0].len();
            for (int row = 0; row < rows; ++row) {
                for (int dst = 0; dst < half; ++dst) {
                    final double value = cs[dst + half].atd(row);
                    cs[dst].set(row, value);
                }
            }
        }
    }

    /**
     * Task that searches column {@code _colIndex} for rows whose value equals
     * {@code _maxValue} and reports a global row index in {@link #_maxIndex}
     * ({@code -1} when no row matched).
     */
    public static class FindMaxIndex
    extends MRTask<FindMaxIndex> {
        /** Output: global index of a matching row, or {@code -1} if none was found. */
        public long _maxIndex = -1L;
        int _colIndex;
        double _maxValue;

        /**
         * @param colOfInterest index of the column to search
         * @param maxValue      value to look for (compared with {@code ==})
         */
        public FindMaxIndex(int colOfInterest, double maxValue) {
            this._colIndex = colOfInterest;
            this._maxValue = maxValue;
        }

        @Override
        public void map(Chunk[] cs) {
            int rowLen = cs[0].len();
            long startRowIndex = cs[0].start();
            for (int rowIndex = 0; rowIndex < rowLen; ++rowIndex) {
                double rowVal = cs[this._colIndex].atd(rowIndex);
                if (rowVal != this._maxValue) continue;
                // NOTE(review): a later match in the same chunk overwrites an earlier one,
                // while reduce keeps the smaller index across tasks — confirm intent.
                this._maxIndex = startRowIndex + (long)rowIndex;
            }
        }

        /**
         * Keeps the smallest non-negative index of the two tasks. The other task's
         * {@code -1} "not found" sentinel must never replace an index found by this task.
         */
        @Override
        public void reduce(FindMaxIndex other) {
            if (other._maxIndex < 0L) {
                return;
            }
            if (this._maxIndex < 0L || other._maxIndex < this._maxIndex) {
                this._maxIndex = other._maxIndex;
            }
        }
    }
}

