/*
 * Decompiled with CFR 0.152.
 */
package water.util;

import java.util.Arrays;
import java.util.Random;
import water.Key;
import water.MRTask;
import water.fvec.C16Chunk;
import water.fvec.CStrChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.util.ArrayUtils;
import water.util.IcedAtomicInt;
import water.util.IcedDouble;
import water.util.IcedHashMap;
import water.util.Log;
import water.util.PrettyPrint;
import water.util.RandomBase;
import water.util.RandomUtils;

public class MRUtils {
    /**
     * Samples a frame without a weight column.
     *
     * <p>Delegates to {@link #sampleFrame(Frame, long, String, long)}, passing {@code null} as the
     * weight column name.
     *
     * @param fr   frame to sample from
     * @param rows requested number of rows
     * @param seed seed forwarded to the weighted overload
     * @return the result of the weighted overload for an unweighted request
     */
    public static Frame sampleFrame(Frame fr, long rows, long seed) {
        final String noWeightColumn = null;
        return sampleFrame(fr, rows, noWeightColumn, seed);
    }

    /**
     * Builds a new frame in which the rows of every chunk are permuted locally.
     *
     * <p>Rows never move between chunks. Every chunk creates its RNG from the same {@code seed}.
     * String columns are copied through {@code addStr}, UUID columns through {@code addUUID}
     * (preserving missing values), and every other column through {@code addNum(atd(row))}.
     *
     * @param fr   frame to shuffle; its names, types and domains are carried over to the result
     * @param seed seed used to create the per-chunk random number generator
     * @return a new frame holding the per-chunk shuffled rows
     */
    public static Frame shuffleFramePerChunk(Frame fr, final long seed) {
        return ((MRTask)new MRTask(){

            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                // Start from the identity permutation of this chunk's local row indices.
                int[] order = new int[cs[0]._len];
                for (int r = 0; r < order.length; ++r) {
                    order[r] = r;
                }
                ArrayUtils.shuffleArray(order, RandomUtils.getRNG(seed));
                for (int row : order) {
                    for (int c = 0; c < ncs.length; ++c) {
                        if (cs[c] instanceof CStrChunk) {
                            // The long-index overload is used with chunk start + local row, as before.
                            ncs[c].addStr(cs[c], cs[c].start() + (long)row);
                        } else if (cs[c] instanceof C16Chunk) {
                            // Same UUID handling as sampleFrame: copy both 64-bit halves
                            // instead of reading the value as a double.
                            if (cs[c].isNA(row)) {
                                ncs[c].addNA();
                            } else {
                                ncs[c].addUUID(cs[c].at16l(row), cs[c].at16h(row));
                            }
                        } else {
                            ncs[c].addNum(cs[c].atd(row));
                        }
                    }
                }
            }
        }.doAll(fr.types(), fr)).outputFrame(fr.names(), fr.domains());
    }

    /**
     * Stratified sampling entry point without a quasibinomial domain.
     *
     * <p>Forwards every argument to the nine-argument overload and passes {@code null} for
     * {@code quasibinomialDomain}.
     *
     * @param fr                frame to sample from
     * @param label             label column
     * @param weights           optional weights column, may be {@code null}
     * @param sampling_ratios   optional per-class sampling ratios
     * @param maxrows           upper bound on the number of rows in the result
     * @param seed              random seed
     * @param allowOversampling whether sampling ratios above 1 are permitted
     * @param verbose           whether to log per-class details
     * @return the result of the nine-argument overload
     */
    public static Frame sampleFrameStratified(Frame fr, Vec label, Vec weights, float[] sampling_ratios, long maxrows, long seed, boolean allowOversampling, boolean verbose) {
        final String[] noQuasibinomialDomain = null;
        return sampleFrameStratified(fr, label, weights, sampling_ratios, maxrows, seed, allowOversampling, verbose, noQuasibinomialDomain);
    }

    /**
     * Computes per-class sampling ratios and then performs stratified sampling.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Raise {@code maxrows} to the number of label levels if it is smaller.</li>
     *   <li>Compute the (optionally weighted) class distribution of {@code label}.</li>
     *   <li>If all given ratios are zero (or none were given), derive ratios from the class
     *       distribution and normalise them by their minimum.</li>
     *   <li>Cap ratios at 1 when oversampling is not allowed.</li>
     *   <li>Scale the ratios so that the projected row count does not exceed {@code maxrows}.</li>
     *   <li>Delegate to the ratio-based overload.</li>
     * </ol>
     *
     * <p>The caller's {@code sampling_ratios} array is never modified; a copy is used.
     *
     * @param fr                  frame to sample from; {@code null} yields {@code null}
     * @param label               categorical label column
     * @param weights             optional weights column, may be {@code null}
     * @param sampling_ratios     optional per-class ratios; {@code null} or all zeros means "derive them"
     * @param maxrows             upper bound on the number of rows in the result
     * @param seed                random seed
     * @param allowOversampling   whether ratios above 1 are permitted
     * @param verbose             whether to log per-class details
     * @param quasibinomialDomain two-level domain for quasibinomial labels, or {@code null}
     * @return the sampled frame produced by the ratio-based overload
     * @throws IllegalArgumentException if the projected number of sampled rows is NaN
     */
    public static Frame sampleFrameStratified(Frame fr, Vec label, Vec weights, float[] sampling_ratios, long maxrows, long seed, boolean allowOversampling, boolean verbose, String[] quasibinomialDomain) {
        if (fr == null) {
            return null;
        }
        assert (label.isCategorical());
        if (maxrows < (long)label.domain().length) {
            Log.warn("Attempting to do stratified sampling to fewer samples than there are class labels - automatically increasing to #rows == #labels (" + label.domain().length + ").");
            maxrows = label.domain().length;
        }
        final double[] dist;
        if (quasibinomialDomain != null) {
            dist = weights != null ? ((ClassDistQuasibinomial)new ClassDistQuasibinomial(quasibinomialDomain).doAll(label, weights)).dist() : ((ClassDistQuasibinomial)new ClassDistQuasibinomial(quasibinomialDomain).doAll(label)).dist();
        } else {
            dist = weights != null ? ((ClassDist)new ClassDist(label).doAll(label, weights)).dist() : ((ClassDist)new ClassDist(label).doAll(label)).dist();
        }
        assert (dist.length > 0);
        Log.info("Doing stratified sampling for data set containing " + fr.numRows() + " rows from " + dist.length + " classes. Oversampling: " + (allowOversampling ? "on" : "off"));
        if (verbose) {
            for (int i = 0; i < dist.length; ++i) {
                Log.info("Class " + label.factor(i) + ": count: " + dist[i] + " prior: " + (float)dist[i] / (float)fr.numRows());
            }
        }
        // Work on a private copy so the caller's array is left untouched.
        sampling_ratios = sampling_ratios == null ? new float[dist.length] : sampling_ratios.clone();
        assert (sampling_ratios.length == dist.length);
        if (ArrayUtils.minValue(sampling_ratios) == 0.0f && ArrayUtils.maxValue(sampling_ratios) == 0.0f) {
            // No ratios supplied: aim for equal class sizes, then normalise by the smallest ratio.
            for (int i = 0; i < dist.length; ++i) {
                sampling_ratios[i] = (float)fr.numRows() / (float)label.domain().length / (float)dist[i];
            }
            float inv_scale = ArrayUtils.minValue(sampling_ratios);
            if (!Float.isNaN(inv_scale) && !Float.isInfinite(inv_scale)) {
                ArrayUtils.div(sampling_ratios, inv_scale);
            }
        }
        if (!allowOversampling) {
            for (int i = 0; i < sampling_ratios.length; ++i) {
                sampling_ratios[i] = Math.min(1.0f, sampling_ratios[i]);
            }
        }
        // Projected number of rows after sampling with the current ratios.
        float numrows = 0.0f;
        for (int i = 0; i < sampling_ratios.length; ++i) {
            numrows += sampling_ratios[i] * dist[i];
        }
        if (Float.isNaN(numrows)) {
            Log.err("Total number of sampled rows was NaN. Sampling ratios: " + Arrays.toString(sampling_ratios) + "; Dist: " + Arrays.toString(dist));
            throw new IllegalArgumentException("Error during sampling - too few points?");
        }
        // Math.round(float) returns an int and would clamp projections above Integer.MAX_VALUE
        // before they are compared with the long maxrows; round as a double to get a long.
        long actualnumrows = Math.min(maxrows, Math.round((double)numrows));
        assert (actualnumrows >= 0L);
        Log.info("Stratified sampling to a total of " + String.format("%,d", actualnumrows) + " rows" + ((float)actualnumrows < numrows ? " (limited by max_after_balance_size)." : "."));
        if ((float)actualnumrows != numrows) {
            ArrayUtils.mult(sampling_ratios, (float)actualnumrows / numrows);
            if (verbose) {
                Log.info("Downsampling majority class by " + (float)actualnumrows / numrows + " to limit number of rows to " + String.format("%,d", maxrows));
            }
        }
        for (int i = 0; i < label.domain().length; ++i) {
            Log.info("Class '" + label.domain()[i] + "' sampling ratio: " + sampling_ratios[i]);
        }
        return MRUtils.sampleFrameStratified(fr, label, weights, sampling_ratios, seed, verbose, quasibinomialDomain);
    }

    /**
     * Stratified sampling with explicit per-class sampling ratios.
     *
     * <p>Delegates to the private implementation with the attempt counter set to zero.
     *
     * @param fr                  frame to sample from
     * @param label               label column
     * @param weights             optional weights column, may be {@code null}
     * @param sampling_ratios     per-class sampling ratios
     * @param seed                random seed
     * @param debug               whether to log the resulting class distribution
     * @param quasibinomialDomain two-level domain for quasibinomial labels, or {@code null}
     * @return the result of the private implementation
     */
    public static Frame sampleFrameStratified(Frame fr, Vec label, Vec weights, float[] sampling_ratios, long seed, boolean debug, String[] quasibinomialDomain) {
        final int firstAttempt = 0;
        return sampleFrameStratified(fr, label, weights, sampling_ratios, seed, debug, firstAttempt, quasibinomialDomain);
    }

    /**
     * Performs one attempt of stratified sampling and retries with a new seed if a class is missing.
     *
     * <p>Each row with a non-missing label is replicated {@code floor(ratio)} times, plus once more
     * with probability equal to the fractional part of its class ratio. The per-row RNG is reseeded
     * from the row's absolute index plus {@code seed}. If any class ends up with a zero count
     * and fewer than 10 retries have been made, the attempt is discarded and repeated with
     * {@code seed + 1}. Otherwise the sampled frame is shuffled within each chunk and returned.
     *
     * @param fr                  frame to sample from; {@code null} yields {@code null}
     * @param label               categorical label column contained in {@code fr}
     * @param weights             optional weights column, may be {@code null}
     * @param sampling_ratios     per-class sampling ratios, one per label level
     * @param seed                random seed for this attempt
     * @param debug               whether to log the resulting class distribution
     * @param count2              number of retries already performed
     * @param quasibinomialDomain two-level domain for quasibinomial labels, or {@code null}
     * @return the shuffled sampled frame, or {@code fr} itself if no class distribution could be computed
     */
    private static Frame sampleFrameStratified(Frame fr, Vec label, Vec weights, final float[] sampling_ratios, final long seed, boolean debug, int count2, String[] quasibinomialDomain) {
        if (fr == null) {
            return null;
        }
        assert (label.isCategorical());
        assert (sampling_ratios != null && sampling_ratios.length == label.domain().length);
        final int labelidx = fr.find(label);
        assert (labelidx >= 0);
        int weightsidx = fr.find(weights);
        Frame r2 = ((MRTask)new MRTask(){

            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                RandomBase rng = RandomUtils.getRNG(seed);
                for (int row = 0; row < cs[0]._len; ++row) {
                    if (cs[labelidx].isNA(row)) continue;
                    // Reseed per row from the absolute row index plus the seed.
                    rng.setSeed(cs[0].start() + (long)row + seed);
                    int cls = (int)cs[labelidx].at8(row);
                    assert (sampling_ratios.length > cls && cls >= 0);
                    // Integer part = guaranteed copies; fractional part = chance of one extra copy.
                    float remainder = sampling_ratios[cls] - (float)((int)sampling_ratios[cls]);
                    int sampling_reps = (int)sampling_ratios[cls] + (rng.nextFloat() < remainder ? 1 : 0);
                    for (int c = 0; c < ncs.length; ++c) {
                        if (cs[c] instanceof CStrChunk) {
                            for (int j = 0; j < sampling_reps; ++j) {
                                ncs[c].addStr(cs[c], cs[0].start() + (long)row);
                            }
                        } else if (cs[c] instanceof C16Chunk) {
                            // Same UUID handling as sampleFrame: copy both 64-bit halves
                            // instead of reading the value as a double.
                            for (int j = 0; j < sampling_reps; ++j) {
                                if (cs[c].isNA(row)) {
                                    ncs[c].addNA();
                                } else {
                                    ncs[c].addUUID(cs[c].at16l(row), cs[c].at16h(row));
                                }
                            }
                        } else {
                            for (int j = 0; j < sampling_reps; ++j) {
                                ncs[c].addNum(cs[c].atd(row));
                            }
                        }
                    }
                }
            }
        }.doAll(fr.types(), fr)).outputFrame(fr.names(), fr.domains());
        // Verify the class distribution of the sampled frame.
        Vec lab = r2.vecs()[labelidx];
        Vec wei = weightsidx != -1 ? r2.vecs()[weightsidx] : null;
        final double[] dist;
        if (quasibinomialDomain != null) {
            dist = wei != null ? ((ClassDistQuasibinomial)new ClassDistQuasibinomial(quasibinomialDomain).doAll(lab, wei)).dist() : ((ClassDistQuasibinomial)new ClassDistQuasibinomial(quasibinomialDomain).doAll(lab)).dist();
        } else {
            dist = wei != null ? ((ClassDist)new ClassDist(lab).doAll(lab, wei)).dist() : ((ClassDist)new ClassDist(lab).doAll(lab)).dist();
        }
        if (dist == null) {
            // The intermediate frame is not returned on this path; release it like the other exits do.
            r2.remove();
            return fr;
        }
        if (debug) {
            double sumdist = ArrayUtils.sum(dist);
            Log.info("After stratified sampling: " + sumdist + " rows.");
            for (int i = 0; i < dist.length; ++i) {
                Log.info("Class " + r2.vecs()[labelidx].factor(i) + ": count: " + dist[i] + " sampling ratio: " + sampling_ratios[i] + " actual relative frequency: " + (double)((float)dist[i]) / sumdist * (double)dist.length);
            }
        }
        if (ArrayUtils.minValue(dist) == 0.0 && count2 < 10) {
            Log.info("Re-doing stratified sampling because not all classes were represented (unlucky draw).");
            r2.remove();
            return MRUtils.sampleFrameStratified(fr, label, weights, sampling_ratios, seed + 1L, debug, count2 + 1, quasibinomialDomain);
        }
        Frame shuffled = MRUtils.shuffleFramePerChunk(r2, seed + 92339987L);
        r2.remove();
        return shuffled;
    }

    /**
     * Samples rows from a frame, optionally scaling each row's selection probability by a weight column.
     *
     * <p>The selection threshold is {@code rows / numRows} without weights, or
     * {@code rows / (numRows * meanWeight)} multiplied by the row's weight when a weight column is
     * found. If the threshold fraction is at least 1 (or {@code rows <= 0}), the input frame itself
     * is returned. Within each chunk, if no row has been selected by the time the last row is
     * reached, that last row is taken regardless of the draw.
     *
     * @param fr           frame to sample from; {@code null} yields {@code null}
     * @param rows         requested number of rows; non-positive means "all"
     * @param weightColumn name of the weight column, or {@code null} for unweighted sampling
     * @param seed         random seed; the per-row RNG is reseeded from seed + row offset
     * @return {@code fr} itself when no sampling is needed, otherwise a new sampled frame
     */
    public static Frame sampleFrame(Frame fr, long rows, String weightColumn, final long seed) {
        if (fr == null) {
            return null;
        }
        if (fr.numRows() == 0L) {
            // Nothing to sample. Returning early also keeps the zero-row retry below from
            // recursing forever on an empty frame.
            return fr;
        }
        final int weightIdx = fr.find(weightColumn);
        final double fractionOfWeights;
        if (weightIdx < 0) {
            fractionOfWeights = rows > 0L ? (double)rows / (double)fr.numRows() : 1.0;
        } else {
            double meanWeight = fr.vec(weightIdx).mean();
            fractionOfWeights = rows > 0L ? (double)rows / ((double)fr.numRows() * meanWeight) : 1.0;
        }
        if (fractionOfWeights >= 1.0) {
            return fr;
        }
        Key newKey = fr._key != null ? Key.make(fr._key.toString() + (fr._key.toString().contains("temporary") ? ".sample." : ".temporary.sample.") + PrettyPrint.formatPct(fractionOfWeights).replace(" ", "")) : null;
        Frame r2 = ((MRTask)new MRTask(){

            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                RandomBase rng = RandomUtils.getRNG(0L);
                BufferedString bStr = new BufferedString();
                int taken = 0;
                for (int row = 0; row < cs[0]._len; ++row) {
                    rng.setSeed(seed + (long)row + cs[0].start());
                    double threshold = weightIdx < 0 ? fractionOfWeights : fractionOfWeights * cs[weightIdx].atd(row);
                    boolean drawn = (double)rng.nextFloat() < threshold;
                    // Take the chunk's last row if nothing has been taken from it so far.
                    boolean forcedLast = taken == 0 && row == cs[0]._len - 1;
                    if (!drawn && !forcedLast) continue;
                    ++taken;
                    for (int c = 0; c < ncs.length; ++c) {
                        if (cs[c].isNA(row)) {
                            ncs[c].addNA();
                        } else if (cs[c] instanceof CStrChunk) {
                            ncs[c].addStr(cs[c].atStr(bStr, row));
                        } else if (cs[c] instanceof C16Chunk) {
                            ncs[c].addUUID(cs[c].at16l(row), cs[c].at16h(row));
                        } else {
                            ncs[c].addNum(cs[c].atd(row));
                        }
                    }
                }
            }
        }.doAll(fr.types(), fr)).outputFrame(newKey, fr.names(), fr.domains());
        if (r2.numRows() == 0L) {
            Log.warn("You asked for " + rows + " rows (out of " + fr.numRows() + "), but you got none (seed=" + seed + ").");
            Log.warn("Let's try again. You've gotta ask yourself a question: \"Do I feel lucky?\"");
            // Discard the empty result and retry with the SAME weight column; the three-argument
            // overload would silently switch to unweighted sampling.
            r2.remove();
            return MRUtils.sampleFrame(fr, rows, weightColumn, seed + 1L);
        }
        return r2;
    }

    /**
     * Samples a small number of rows using an RNG created from {@code seed}.
     *
     * <p>Creates the generator with {@code RandomUtils.getRNG(seed)} and delegates to
     * {@link #sampleFrameSmall(Frame, int, Random)}.
     *
     * @param fr   frame to sample from
     * @param rows number of rows requested
     * @param seed seed for the random number generator
     * @return the result of the {@code Random}-based overload
     */
    public static Frame sampleFrameSmall(Frame fr, int rows, long seed) {
        final RandomBase rng = RandomUtils.getRNG(seed);
        return sampleFrameSmall(fr, rows, rng);
    }

    /**
     * Samples {@code rows} rows by slicing the frame at row indices obtained from
     * {@code ArrayUtils.distinctLongs}.
     *
     * <p>If the frame has no more than {@code rows} rows, the frame itself is returned.
     *
     * @param fr   frame to sample from
     * @param rows number of rows requested
     * @param rand random number generator used to pick the row indices
     * @return {@code fr} itself when it is small enough, otherwise the result of {@code deepSlice}
     */
    public static Frame sampleFrameSmall(Frame fr, int rows, Random rand) {
        final long totalRows = fr.numRows();
        if (totalRows <= (long)rows) {
            return fr;
        }
        return fr.deepSlice(ArrayUtils.distinctLongs(rows, totalRows, rand), null);
    }

    /**
     * Map/reduce task that counts how often each distinct non-missing value occurs in a column.
     *
     * <p>Counts are kept in a map from value to counter. {@link #dist()} and {@link #keys()} read
     * the counters and the values from that same map.
     */
    public static class Dist
    extends MRTask<Dist> {
        private IcedHashMap<IcedDouble, IcedAtomicInt> _dist;

        /**
         * Counts the occurrences of each non-missing value in one chunk.
         *
         * @param ys chunk to scan
         */
        @Override
        public void map(Chunk ys) {
            this._dist = new IcedHashMap<>();
            // Reusable lookup key so no key object is allocated for values already seen.
            final IcedDouble probe = new IcedDouble(0.0);
            for (int row = 0; row < ys._len; ++row) {
                if (!ys.isNA(row)) {
                    probe._val = ys.atd(row);
                    IcedAtomicInt counter = (IcedAtomicInt)this._dist.get(probe);
                    if (counter == null) {
                        // First sighting: insert a fresh key with a count of one.
                        counter = this._dist.putIfAbsent(new IcedDouble(probe._val), new IcedAtomicInt(1));
                    }
                    if (counter != null) {
                        counter.incrementAndGet();
                    }
                }
            }
        }

        /**
         * Merges the counts of another task into this one, folding the smaller map into the larger.
         *
         * @param mrt task whose counts are merged in; its map reference is cleared afterwards
         */
        @Override
        public void reduce(Dist mrt) {
            if (this._dist == mrt._dist) {
                return;
            }
            IcedHashMap<IcedDouble, IcedAtomicInt> larger = this._dist;
            IcedHashMap<IcedDouble, IcedAtomicInt> smaller = mrt._dist;
            if (larger.size() < smaller.size()) {
                larger = mrt._dist;
                smaller = this._dist;
            }
            for (IcedDouble key : smaller.keySet()) {
                IcedAtomicInt incoming = (IcedAtomicInt)smaller.get(key);
                IcedAtomicInt existing = larger.putIfAbsent(key, incoming);
                if (existing != null) {
                    existing.addAndGet(incoming.get());
                }
            }
            this._dist = larger;
            mrt._dist = null;
        }

        /**
         * @return the counts, in the iteration order of the map's values
         */
        public double[] dist() {
            final double[] counts = new double[this._dist.size()];
            int next = 0;
            for (IcedAtomicInt counter : this._dist.values()) {
                counts[next] = counter.get();
                ++next;
            }
            return counts;
        }

        /**
         * @return the distinct values, in the iteration order of the map's key set
         */
        public double[] keys() {
            final double[] values = new double[this._dist.size()];
            int next = 0;
            for (IcedDouble key : this._dist.keySet()) {
                values[next] = key._val;
                ++next;
            }
            return values;
        }
    }

    /**
     * Map/reduce task that computes a two-bucket class distribution for a numeric label.
     *
     * <p>A row falls into bucket 0 when its value equals the first domain entry parsed as a double,
     * and into bucket 1 otherwise. Rows with a missing label are skipped.
     */
    public static class ClassDistQuasibinomial
    extends MRTask<ClassDistQuasibinomial> {
        final int _nclass;
        private double[] _ys;
        private String[] _domain;
        private double _firstDoubleDomain;

        /**
         * @param domain label domain; its first entry must be parseable as a double
         */
        public ClassDistQuasibinomial(String[] domain) {
            this._nclass = 2;
            this._domain = domain;
            this._firstDoubleDomain = Double.parseDouble(domain[0]);
        }

        /**
         * @return the accumulated per-bucket totals (the internal array, not a copy)
         */
        public final double[] dist() {
            return this._ys;
        }

        /**
         * @return the totals divided by their sum, or the raw totals array if the sum is zero
         */
        public final double[] relDist() {
            final double total = ArrayUtils.sum(this._ys);
            if (total == 0.0) {
                return this._ys;
            }
            final double[] scaled = Arrays.copyOf(this._ys, this._ys.length);
            return ArrayUtils.div(scaled, total);
        }

        /**
         * @return the domain passed to the constructor
         */
        public final String[] domains() {
            return this._domain;
        }

        /**
         * Unweighted pass: adds one to the bucket of every non-missing label.
         *
         * @param ys label chunk
         */
        @Override
        public void map(Chunk ys) {
            this._ys = new double[this._nclass];
            for (int row = 0; row < ys._len; ++row) {
                if (ys.isNA(row)) continue;
                if (ys.atd(row) == this._firstDoubleDomain) {
                    this._ys[0] += 1.0;
                } else {
                    this._ys[1] += 1.0;
                }
            }
        }

        /**
         * Weighted pass: adds the row's weight to the bucket of every non-missing label.
         *
         * @param ys label chunk
         * @param ws weight chunk
         */
        @Override
        public void map(Chunk ys, Chunk ws) {
            this._ys = new double[this._nclass];
            for (int row = 0; row < ys._len; ++row) {
                if (ys.isNA(row)) continue;
                if (ys.atd(row) == this._firstDoubleDomain) {
                    this._ys[0] += ws.atd(row);
                } else {
                    this._ys[1] += ws.atd(row);
                }
            }
        }

        /**
         * Adds the other task's totals to this task's totals.
         *
         * @param that task to merge in
         */
        @Override
        public void reduce(ClassDistQuasibinomial that) {
            ArrayUtils.add(this._ys, that._ys);
        }
    }

    /**
     * Map/reduce task that computes the class distribution of a label column.
     *
     * <p>Each non-missing label value is used as an index into an array of {@code _nclass} totals.
     */
    public static class ClassDist
    extends MRTask<ClassDist> {
        final int _nclass;
        protected double[] _ys;

        /**
         * @param label label column; the number of classes is the length of its domain
         */
        public ClassDist(Vec label) {
            this._nclass = label.domain().length;
        }

        /**
         * @param n2 number of classes
         */
        public ClassDist(int n2) {
            this._nclass = n2;
        }

        /**
         * @return the accumulated per-class totals (the internal array, not a copy)
         */
        public final double[] dist() {
            return this._ys;
        }

        /**
         * @return the totals divided by their sum, or the raw totals array if the sum is zero
         */
        public final double[] relDist() {
            final double total = ArrayUtils.sum(this._ys);
            if (total == 0.0) {
                return this._ys;
            }
            final double[] scaled = Arrays.copyOf(this._ys, this._ys.length);
            return ArrayUtils.div(scaled, total);
        }

        /**
         * Unweighted pass: adds one to the total of every non-missing label's class.
         *
         * @param ys label chunk
         */
        @Override
        public void map(Chunk ys) {
            this._ys = new double[this._nclass];
            for (int row = 0; row < ys._len; ++row) {
                if (!ys.isNA(row)) {
                    this._ys[(int)ys.at8(row)] += 1.0;
                }
            }
        }

        /**
         * Weighted pass: adds the row's weight to the total of every non-missing label's class.
         *
         * @param ys label chunk
         * @param ws weight chunk
         */
        @Override
        public void map(Chunk ys, Chunk ws) {
            this._ys = new double[this._nclass];
            for (int row = 0; row < ys._len; ++row) {
                if (!ys.isNA(row)) {
                    this._ys[(int)ys.at8(row)] += ws.atd(row);
                }
            }
        }

        /**
         * Adds the other task's totals to this task's totals.
         *
         * @param that task to merge in
         */
        @Override
        public void reduce(ClassDist that) {
            ArrayUtils.add(this._ys, that._ys);
        }
    }
}

