/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.advmath;

import java.math.BigDecimal;
import java.math.MathContext;
import java.util.Objects;
import water.Key;
import water.MRTask;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Merge;
import water.rapids.ast.prims.advmath.AstCorrelation;
import water.util.FrameUtils;

public class SpearmanCorrelation {
    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public static Frame calculate(Frame frameX, Frame frameY, AstCorrelation.Mode mode) {
        Objects.requireNonNull(frameX);
        Objects.requireNonNull(frameY);
        SpearmanCorrelation.checkCorrelationDoable(frameX, frameY, mode);
        Frame correlationMatrix = SpearmanCorrelation.createCorrelationMatrix(frameX, frameY);
        boolean framesAreEqual = !AstCorrelation.Mode.Everything.equals((Object)mode) && SpearmanCorrelation.framesContainSameVecs(frameX, frameY);
        for (int vecIdX = 0; vecIdX < frameX.numCols(); ++vecIdX) {
            for (int vecIdY = 0; vecIdY < frameY.numCols(); ++vecIdY) {
                Scope.enter();
                try {
                    if (framesAreEqual && vecIdX == vecIdY) {
                        correlationMatrix.vec(vecIdX).set((long)vecIdY, 1.0);
                        continue;
                    }
                    if (SpearmanCorrelation.isNaNCorrelation(frameX.vec(vecIdX), frameY.vec(vecIdY), mode)) {
                        correlationMatrix.vec(vecIdX).set((long)vecIdY, Double.NaN);
                        continue;
                    }
                    SpearmanRankedVectors rankedVectors = SpearmanCorrelation.rankedVectors(frameX, frameY, vecIdX, vecIdY, mode);
                    double[] means = SpearmanCorrelation.calculateMeans(rankedVectors._x, rankedVectors._y);
                    SpearmanCorrelationCoefficientTask spearman = (SpearmanCorrelationCoefficientTask)new SpearmanCorrelationCoefficientTask(means[0], means[1]).doAll(rankedVectors._x, rankedVectors._y);
                    correlationMatrix.vec(vecIdX).set((long)vecIdY, spearman.getSpearmanCorrelationCoefficient());
                    continue;
                }
                finally {
                    Scope.exit(new Key[0]);
                }
            }
        }
        return correlationMatrix;
    }

    private static boolean framesContainSameVecs(Frame frameX, Frame frameY) {
        Vec[] vecsY;
        Vec[] vecsX = frameX.vecs();
        if (vecsX.length != (vecsY = frameY.vecs()).length) {
            return false;
        }
        for (int i2 = 0; i2 < vecsX.length; ++i2) {
            if (vecsX[i2]._key.equals(vecsY[i2]._key)) continue;
            return false;
        }
        return true;
    }

    private static void checkCorrelationDoable(Frame frameX, Frame frameY, AstCorrelation.Mode mode) throws IllegalArgumentException {
        if (!AstCorrelation.Mode.AllObs.equals((Object)mode)) {
            return;
        }
        if (frameX.numCols() == 0) {
            throw new IllegalArgumentException("First given frame for Spearman calculation has no columnns.");
        }
        if (frameY.numCols() == 0) {
            throw new IllegalArgumentException("Second given frame for Spearman calculation has no columnns.");
        }
        Vec[] vecsX = frameX.vecs();
        Vec[] vecsY = frameY.vecs();
        for (int i2 = 0; i2 < vecsX.length; ++i2) {
            if (vecsX[i2].naCnt() == 0L && vecsY[i2].naCnt() == 0L) continue;
            throw new IllegalArgumentException("Mode is 'AllObs' but NAs are present");
        }
    }

    private static boolean isNaNCorrelation(Vec vecX, Vec vecY, AstCorrelation.Mode mode) {
        return AstCorrelation.Mode.Everything.equals((Object)mode) && (vecX.naCnt() > 0L || vecY.naCnt() > 0L);
    }

    private static Frame createCorrelationMatrix(Frame frameX, Frame frameY) {
        Vec[] correlationVecs = new Vec[frameX.numCols()];
        int height = frameY.numCols();
        for (int width = 0; width < frameX.numCols(); ++width) {
            correlationVecs[width] = Vec.makeCon(Double.NaN, (long)height);
        }
        return new Frame(Key.make(), correlationVecs);
    }

    private static SpearmanRankedVectors rankedVectors(Frame frameX, Frame frameY, int vecIdX, int vecIdY, AstCorrelation.Mode mode) {
        Frame comparedVecsWithNas = new Frame(frameX.vec(vecIdX).makeCopy(), frameY.vec(vecIdY).makeCopy());
        Frame unsortedVecs = AstCorrelation.Mode.CompleteObs.equals((Object)mode) ? comparedVecsWithNas : ((Merge.RemoveNAsTask)new Merge.RemoveNAsTask(0, 1).doAll(comparedVecsWithNas.types(), comparedVecsWithNas)).outputFrame(comparedVecsWithNas.names(), comparedVecsWithNas.domains());
        Frame sortedX = new Frame(unsortedVecs.vec(0).makeCopy());
        Scope.track(sortedX);
        Frame sortedY = new Frame(unsortedVecs.vec(1).makeCopy());
        Scope.track(sortedY);
        boolean xIsOrdered = SpearmanCorrelation.needsOrdering(sortedX.vec(0));
        boolean yIsOrdered = SpearmanCorrelation.needsOrdering(sortedY.vec(0));
        if (xIsOrdered) {
            FrameUtils.labelRows(sortedX, "label");
            sortedX = sortedX.sort(new int[]{0});
            Scope.track(sortedX);
        }
        if (yIsOrdered) {
            FrameUtils.labelRows(sortedY, "label");
            sortedY = sortedY.sort(new int[]{0});
            Scope.track(sortedY);
        }
        assert (sortedX.numRows() == sortedY.numRows());
        Vec orderX = SpearmanCorrelation.needsOrdering(sortedX.vec(0)) ? Vec.makeZero(sortedX.numRows()) : frameX.vec(vecIdX);
        Vec orderY = SpearmanCorrelation.needsOrdering(sortedY.vec(0)) ? Vec.makeZero(sortedY.numRows()) : frameY.vec(vecIdY);
        Vec xLabel = sortedX.vec("label") == null ? sortedX.vec(0) : sortedX.vec("label");
        Vec xValue = sortedX.vec(0);
        Vec yLabel = sortedY.vec("label") == null ? sortedY.vec(0) : sortedY.vec("label");
        Vec yValue = sortedY.vec(0);
        Scope.track(xLabel);
        Scope.track(yLabel);
        Vec.Writer orderXWriter = orderX.open();
        Vec.Writer orderYWriter = orderY.open();
        Vec.Reader xValueReader = xValue.new Vec.Reader();
        Vec.Reader yValueReader = yValue.new Vec.Reader();
        Vec.Reader xLabelReader = xLabel.new Vec.Reader();
        Vec.Reader yLabelReader = yLabel.new Vec.Reader();
        double lastX = Double.NaN;
        double lastY = Double.NaN;
        long skippedX = 0L;
        long skippedY = 0L;
        int i2 = 0;
        while ((long)i2 < orderX.length()) {
            if (xIsOrdered) {
                skippedX = lastX == xValueReader.at(i2) ? ++skippedX : 0L;
                lastX = xValueReader.at(i2);
                orderXWriter.set(xLabelReader.at8(i2) - 1L, (long)i2 - skippedX);
            }
            if (yIsOrdered) {
                skippedY = lastY == yValueReader.at(i2) ? ++skippedY : 0L;
                lastY = yValueReader.at(i2);
                orderYWriter.set(yLabelReader.at8(i2) - 1L, (long)i2 - skippedY);
            }
            ++i2;
        }
        orderXWriter.close();
        orderYWriter.close();
        Frame sameChunkLayoutFrame = new Frame(new String[]{"X"}, new Vec[]{orderX});
        sameChunkLayoutFrame.add("Y", orderY);
        return new SpearmanRankedVectors(sameChunkLayoutFrame.vec("X"), sameChunkLayoutFrame.vec("Y"));
    }

    private static boolean needsOrdering(Vec vec) {
        return !vec.isCategorical();
    }

    private static double[] calculateMeans(Vec ... vecs) throws IllegalArgumentException {
        if (vecs.length < 1) {
            throw new IllegalArgumentException("There are no vectors to calculate means from.");
        }
        long referenceVectorLength = vecs[0].length();
        for (int i2 = 0; i2 < vecs.length; ++i2) {
            if (!vecs[i2].isCategorical() && !vecs[i2].isNumeric()) {
                throw new IllegalArgumentException(String.format("Given vector '%s' is not numerical or categorical.", vecs[i2]._key.toString()));
            }
            if (referenceVectorLength == vecs[i2].length()) continue;
            throw new IllegalArgumentException("Vectors to calculate means from do not have the same length." + String.format(" Vector '%s' is of length '%d'", vecs[i2]._key.toString(), vecs[i2].length()));
        }
        return ((MeanTask)new MeanTask().doAll(vecs))._means;
    }

    private static class MeanTask
    extends MRTask<MeanTask> {
        private double[] _means;
        private long _linesVisited = 0L;

        private MeanTask() {
        }

        @Override
        public void map(Chunk[] cs) {
            int i2;
            BigDecimal[] averages = new BigDecimal[cs.length];
            for (i2 = 0; i2 < averages.length; ++i2) {
                averages[i2] = new BigDecimal(0, MathContext.DECIMAL128);
            }
            block1: for (int row = 0; row < cs[0].len(); ++row) {
                double[] values = new double[cs.length];
                for (int col = 0; col < cs.length; ++col) {
                    values[col] = cs[col].atd(row);
                    if (Double.isNaN(values[col])) break block1;
                }
                ++this._linesVisited;
                for (int i3 = 0; i3 < values.length; ++i3) {
                    averages[i3] = averages[i3].add(new BigDecimal(values[i3], MathContext.DECIMAL128), MathContext.DECIMAL128);
                }
            }
            this._means = new double[cs.length];
            for (i2 = 0; i2 < averages.length; ++i2) {
                this._means[i2] = averages[i2].divide(new BigDecimal(this._linesVisited), MathContext.DECIMAL64).doubleValue();
            }
        }

        @Override
        public void reduce(MeanTask mrt) {
            int numChunks = this._means.length;
            for (int i2 = 0; i2 < numChunks; ++i2) {
                this._means[i2] = (this._means[i2] * (double)this._linesVisited + mrt._means[i2] * (double)mrt._linesVisited) / (double)(this._linesVisited + mrt._linesVisited);
            }
            this._linesVisited += mrt._linesVisited;
        }
    }

    private static class SpearmanCorrelationCoefficientTask
    extends MRTask<SpearmanCorrelationCoefficientTask> {
        private final double _xMean;
        private final double _yMean;
        private double spearmanCorrelationCoefficient;
        private double _xDiffSquared = 0.0;
        private double _yDiffSquared = 0.0;
        private double _xyMul = 0.0;
        private long _linesVisited;

        private SpearmanCorrelationCoefficientTask(double xMean, double yMean) {
            this._xMean = xMean;
            this._yMean = yMean;
        }

        @Override
        public void map(Chunk[] chunks) {
            assert (chunks.length == 2);
            Chunk xChunk = chunks[0];
            Chunk yChunk = chunks[1];
            for (int row = 0; row < chunks[0].len(); ++row) {
                double x2 = xChunk.atd(row);
                double y2 = yChunk.atd(row);
                ++this._linesVisited;
                this._xyMul += x2 * y2;
                double xDiffFromMean = x2 - this._xMean;
                double yDiffFromMean = y2 - this._yMean;
                this._xDiffSquared += Math.pow(xDiffFromMean, 2.0);
                this._yDiffSquared += Math.pow(yDiffFromMean, 2.0);
            }
        }

        @Override
        public void reduce(SpearmanCorrelationCoefficientTask mrt) {
            this._xDiffSquared += mrt._xDiffSquared;
            this._yDiffSquared += mrt._yDiffSquared;
            this._linesVisited += mrt._linesVisited;
            this._xyMul += mrt._xyMul;
        }

        @Override
        protected void postGlobal() {
            double xStdDev = Math.sqrt(this._xDiffSquared / (double)this._linesVisited);
            double yStdDev = Math.sqrt(this._yDiffSquared / (double)this._linesVisited);
            this.spearmanCorrelationCoefficient = (this._xyMul - (double)this._linesVisited * this._xMean * this._yMean) / ((double)this._linesVisited * xStdDev * yStdDev);
        }

        public double getSpearmanCorrelationCoefficient() {
            return this.spearmanCorrelationCoefficient;
        }
    }

    private static class SpearmanRankedVectors {
        private final Vec _x;
        private final Vec _y;

        public SpearmanRankedVectors(Vec x2, Vec y2) {
            this._x = x2;
            this._y = y2;
        }
    }
}

