/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.advmath;

import java.util.Random;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.RandomUtils;
import water.util.VecUtils;

/**
 * Rapids primitive {@code kfold_column}: produces a column of fold ids for k-fold
 * cross-validation, with static helpers for random, modulo and stratified fold assignment.
 */
public class AstKFold extends AstPrimitive {
    @Override
    public String[] args() {
        return new String[]{"ary", "nfolds", "seed"};
    }

    /**
     * Arity of this primitive in the AST. {@link #apply} reads {@code asts[1..3]}; presumably
     * {@code asts[0]} is the primitive itself -- TODO confirm.
     */
    @Override
    public int nargs() {
        return 4;
    }

    @Override
    public String str() {
        return "kfold_column";
    }

    /**
     * Rejects fold counts that cannot be used as a modulus.
     *
     * @param nfolds requested number of folds
     * @throws IllegalArgumentException if {@code nfolds <= 0}
     */
    private static void checkNFolds(int nfolds) {
        if (nfolds <= 0) {
            throw new IllegalArgumentException("nfolds must be a positive integer. Got: " + nfolds);
        }
    }

    /**
     * Draws a pseudo-random fold id in {@code [0, nfolds)} from an RNG seeded with {@code rngSeed}.
     * The remainder is taken before {@code Math.abs}: {@code Math.abs(Integer.MIN_VALUE)} is still
     * negative, which previously leaked negative fold ids. For every other draw the result equals
     * {@code Math.abs(draw) % nfolds}, so fold assignments for existing seeds are unchanged.
     */
    private static int randomFold(long rngSeed, int nfolds) {
        int draw = RandomUtils.getRNG(rngSeed).nextInt();
        return Math.abs(draw % nfolds);
    }

    /**
     * Overwrites every element of {@code v2} in place with a pseudo-random fold id in
     * {@code [0, nfolds)}. The RNG for a row is seeded from the row's absolute index plus
     * {@code seed}, so the result does not depend on the chunk layout.
     *
     * @param v2     vector to fill (mutated in place)
     * @param nfolds number of folds, must be positive
     * @param seed   base seed
     * @return {@code v2}, after being filled
     * @throws IllegalArgumentException if {@code nfolds <= 0}
     */
    public static Vec kfoldColumn(Vec v2, final int nfolds, final long seed) {
        AstKFold.checkNFolds(nfolds);
        new MRTask(){

            @Override
            public void map(Chunk c2) {
                long start = c2.start();
                for (int i2 = 0; i2 < c2._len; ++i2) {
                    c2.set(i2, AstKFold.randomFold(start + seed + (long)i2, nfolds));
                }
            }
        }.doAll(v2);
        return v2;
    }

    /**
     * Overwrites every element of {@code v2} in place with {@code absoluteRow % nfolds}
     * (round-robin fold assignment).
     *
     * @param v2     vector to fill (mutated in place)
     * @param nfolds number of folds, must be positive
     * @return {@code v2}, after being filled
     * @throws IllegalArgumentException if {@code nfolds <= 0}
     */
    public static Vec moduloKfoldColumn(Vec v2, final int nfolds) {
        AstKFold.checkNFolds(nfolds);
        new MRTask(){

            @Override
            public void map(Chunk c2) {
                long start = c2.start();
                for (int i2 = 0; i2 < c2._len; ++i2) {
                    c2.set(i2, (int)((start + (long)i2) % (long)nfolds));
                }
            }
        }.doAll(v2);
        return v2;
    }

    /**
     * Returns the class values observed in {@code y2}. If the collector yields no domain, it
     * falls back to the categorical codes {@code 0..k-1}, mirroring the original
     * {@code classes == null} branch.
     */
    private static long[] classValues(Vec y2) {
        long[] observed = new VecUtils.CollectIntegerDomain().doAll(y2).domain();
        if (observed != null) {
            return observed;
        }
        int k = y2.isCategorical() ? y2.domain().length : 0;
        long[] codes = new long[k];
        for (int i2 = 0; i2 < k; ++i2) {
            codes[i2] = i2;
        }
        return codes;
    }

    /** Returns the index of {@code value} in {@code values}, or -1 if it is absent. */
    private static int indexOf(long[] values, long value) {
        for (int i2 = 0; i2 < values.length; ++i2) {
            if (values[i2] == value) {
                return i2;
            }
        }
        return -1;
    }

    /**
     * Builds a new fold column stratified by the integer/categorical response {@code y2}.
     * Each class gets its own seed derived from {@code seed + classIndex}. Within a class, a
     * row's fold is drawn from an RNG seeded with {@code absoluteRow + classSeed}, which spreads
     * each class independently over the folds. Rows with a missing response are assigned
     * round-robin ({@code absoluteRow % nfolds}). Rows whose value is not among the collected
     * classes keep fold 0.
     *
     * <p>A single pass per chunk is enough: every row matches at most one (class, fold) pair.
     * The result is therefore identical to looping over every fold and class separately.
     *
     * @param y2     response column; must be categorical or integer-valued numeric
     * @param nfolds number of folds, must be positive
     * @param seed   base seed
     * @return a new vector (created via {@code makeZero}) holding the fold ids
     * @throws IllegalArgumentException if {@code y2} has an unsupported type or {@code nfolds <= 0}
     */
    public static Vec stratifiedKFoldColumn(Vec y2, final int nfolds, long seed) {
        if (!(y2.isCategorical() || y2.isNumeric() && y2.isInt())) {
            throw new IllegalArgumentException("stratification only applies to integer and categorical columns. Got: " + y2.get_type_str());
        }
        AstKFold.checkNFolds(nfolds);
        // Size everything from the observed values: sizing from the categorical domain could
        // index past the end of `classes` when some level never occurs.
        final long[] classes = AstKFold.classValues(y2);
        final long[] seeds = new long[classes.length];
        for (int i2 = 0; i2 < seeds.length; ++i2) {
            seeds[i2] = RandomUtils.getRNG(seed + (long)i2).nextLong();
        }
        Frame fr = new Frame(new Vec[]{y2, y2.makeZero()});
        return new MRTask(){

            @Override
            public void map(Chunk[] cs) {
                Chunk response = cs[0];
                Chunk folds = cs[1];
                long start = response.start();
                for (int row = 0; row < response._len; ++row) {
                    long absoluteRow = start + (long)row;
                    if (response.isNA(row)) {
                        // missing responses get spread around round-robin
                        folds.set(row, (int)(absoluteRow % (long)nfolds));
                        continue;
                    }
                    int classLabel = AstKFold.indexOf(classes, response.at8(row));
                    if (classLabel >= 0) {
                        folds.set(row, AstKFold.randomFold(absoluteRow + seeds[classLabel], nfolds));
                    }
                }
            }
        }.doAll(fr)._fr.vec(1);
    }

    /**
     * Evaluates {@code (kfold_column ary nfolds seed)}. It creates a zero vector shaped like
     * {@code ary} and fills it with random fold ids. A seed of {@code -1} requests a fresh random
     * seed.
     *
     * @throws IllegalArgumentException if {@code nfolds <= 0}
     */
    @Override
    public ValFrame apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame fr = stk.track(asts[1].exec(env)).getFrame();
        int nfolds = (int)asts[2].exec(env).getNum();
        long seed = (long)asts[3].exec(env).getNum();
        // Validate before allocating, so a bad nfolds does not leave an orphaned zero vec behind.
        AstKFold.checkNFolds(nfolds);
        if (seed == -1L) {
            seed = new Random().nextLong();
        }
        Vec foldVec = fr.anyVec().makeZero();
        return new ValFrame(new Frame(AstKFold.kfoldColumn(foldVec, nfolds, seed)));
    }
}

