/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.DataInfo;
import java.util.Arrays;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.util.ArrayUtils;
import water.util.RandomBase;
import water.util.RandomUtils;

/**
 * Base {@link MRTask} that walks the rows of each chunk as described by a {@link DataInfo}.
 * Each row, possibly sampled, shuffled, repeated or importance-sampled by observation weight,
 * is handed to one of the {@code processRow} overloads.
 *
 * <p>Subclasses override the {@code processRow} variant they need, and optionally the
 * {@link #chunkInit()}, {@link #chunkDone(long)}, {@link #skipRow(long)},
 * {@link #getMiniBatchSize()} and {@link #processMiniBatch(long, double[], double[], int)} hooks.
 *
 * @param <T> the concrete task type (self type, as required by {@link MRTask})
 */
public abstract class FrameTask<T extends FrameTask<T>>
extends MRTask<T> {
    /** If true, rows come from {@code DataInfo.extractSparseRows}; otherwise one dense Row is reused. */
    protected boolean _sparse;
    /** Node-local DataInfo: assigned in {@link #setupLocal()} and cleared in {@link #closeLocal()}. */
    protected transient DataInfo _dinfo;
    /** DKV key from which {@link #_dinfo} is fetched on each node. */
    final Key _dinfoKey;
    /** If non-null, {@link #_dinfo} is restricted to these columns via {@code filterExpandedColumns}. */
    final int[] _activeCols;
    /** Optional job key; a stop request on that job cancels {@link #map(Chunk[], NewChunk[])}. */
    protected final Key<Job> _jobKey;
    /**
     * Scales the number of row visits per chunk. Values above 1 yield several passes:
     * {@code repeats = ceil(_useFraction * relative_chunk_weight)}.
     */
    protected float _useFraction = 1.0f;
    /** Mixed (with chunk offset and iteration) into the per-chunk RNG seed. */
    private final long _seed;
    /** Forces a shuffled visiting order even when no sub-sampling is needed. */
    protected boolean _shuffle = false;
    /** Mixed into the per-chunk RNG seed so different iterations sample differently. */
    private final int _iteration;

    /** @return the node-local DataInfo (null outside setupLocal/closeLocal) */
    public DataInfo dinfo() {
        return this._dinfo;
    }

    /**
     * Creates a task with the default seed, iteration -1 and dense rows.
     * The seed {@code -557122578L} is the int literal {@code 0xDECAFBEE} sign-extended to long.
     */
    public FrameTask(Key<Job> jobKey, DataInfo dinfo) {
        this(jobKey, dinfo, -557122578L, -1, false);
    }

    /** Creates a task without a completer. A null {@code dinfo} yields a null key and null active columns. */
    public FrameTask(Key<Job> jobKey, DataInfo dinfo, long seed, int iteration, boolean sparse) {
        this(jobKey, dinfo, seed, iteration, sparse, null);
    }

    /** Creates a task with an explicit completer. A null {@code dinfo} yields a null key and null active columns. */
    public FrameTask(Key<Job> jobKey, DataInfo dinfo, long seed, int iteration, boolean sparse, H2O.H2OCountedCompleter cmp) {
        this(jobKey, dinfo == null ? null : dinfo._key, dinfo == null ? null : dinfo._activeCols, seed, iteration, sparse, cmp);
    }

    private FrameTask(Key<Job> jobKey, Key dinfoKey, int[] activeCols, long seed, int iteration, boolean sparse, H2O.H2OCountedCompleter cmp) {
        super(cmp);
        this._jobKey = jobKey;
        this._dinfoKey = dinfoKey;
        this._activeCols = activeCols;
        this._seed = seed;
        this._iteration = iteration;
        this._sparse = sparse;
    }

    /**
     * Loads the DataInfo from the DKV. If {@link #_activeCols} is non-null, the DataInfo is
     * narrowed with {@code filterExpandedColumns}.
     */
    @Override
    protected void setupLocal() {
        // NOTE(review): no null check on the DKV lookup; a missing key surfaces as an NPE here.
        DataInfo dinfo = (DataInfo)DKV.get(this._dinfoKey).get();
        this._dinfo = this._activeCols == null ? dinfo : dinfo.filterExpandedColumns(this._activeCols);
    }

    /** Drops the node-local DataInfo reference. */
    @Override
    protected void closeLocal() {
        this._dinfo = null;
    }

    /**
     * Processes one row (used when there are no outputs and no mini-batching).
     *
     * @param gid {@code chunkStart + rep * chunkLen + rowInChunk}; equals the global row id on the first pass
     * @param row the extracted row, with positive weight and not bad
     */
    protected void processRow(long gid, DataInfo.Row row) {
        throw new RuntimeException("should've been overridden!");
    }

    /** Processes one row, writing into {@code outputs} (used when {@code map} receives outputs). */
    protected void processRow(long gid, DataInfo.Row row, NewChunk[] outputs) {
        throw new RuntimeException("should've been overridden!");
    }

    /** Processes one row as entry {@code mb} of the current mini-batch (used when {@link #getMiniBatchSize()} is greater than 0). */
    protected void processRow(long gid, DataInfo.Row row, int mb) {
        throw new RuntimeException("should've been overridden!");
    }

    /**
     * Lets subclasses exclude a row by its gid. Skipped rows are counted but not processed.
     *
     * @return true to skip; the default never skips
     */
    protected boolean skipRow(long gid) {
        return false;
    }

    /**
     * Called when a mini-batch is full, and once more at the end of the chunk for a partial batch.
     *
     * @param seed gid of the last row visited
     * @param responses first response of each batched row (0.0 if the row has no response)
     * @param offsets offset of each batched row
     * @param n number of valid entries in {@code responses}/{@code offsets}
     */
    protected void processMiniBatch(long seed, double[] responses, double[] offsets, int n) {
    }

    /** @return mini-batch size; 0 (the default) disables mini-batching */
    protected int getMiniBatchSize() {
        return 0;
    }

    /** Per-chunk setup hook. Returning false skips the chunk entirely; {@link #chunkDone(long)} is then not called. */
    protected boolean chunkInit() {
        return true;
    }

    /** Per-chunk completion hook, called with the number of rows that were actually processed. */
    protected void chunkDone(long n) {
    }

    /**
     * Visits the rows of one chunk and dispatches them to {@code processRow}.
     *
     * <p>The number of passes is {@code ceil(_useFraction * relative_chunk_weight)}. Within each
     * pass, every row is kept with probability {@code fraction} (random order when sampling).
     * With non-constant, non-binary observation weights, every other processed row is instead
     * drawn by inverse-CDF sampling over the chunk's cumulative weights.
     */
    @Override
    public void map(Chunk[] chunks, NewChunk[] outputs) {
        // Single lookup: a second get() could race with job removal and NPE.
        if (this._jobKey != null) {
            Job job = this._jobKey.get();
            if (job != null && job.stop_requested()) {
                throw new Job.JobCancelledException();
            }
        }
        final int nrows = chunks[0]._len;
        final long offset = chunks[0].start();
        if (!this.chunkInit()) {
            return;
        }
        // Constant weights carry no information and binary weights are equivalent to
        // sub-sampling, so only other weight columns trigger importance sampling.
        final boolean obs_weights = this._dinfo._weights
                && !this._fr.vecs()[this._dinfo.weightChunkId()].isConst()
                && !this._fr.vecs()[this._dinfo.weightChunkId()].isBinary();
        final double global_weight_sum = obs_weights
                ? (double)Math.round(this._fr.vecs()[this._dinfo.weightChunkId()].mean() * (double)this._fr.numRows())
                : 0.0;
        DataInfo.Row row = null;
        DataInfo.Row[] rows = null;
        if (this._sparse) {
            rows = this._dinfo.extractSparseRows(chunks);
        } else {
            row = this._dinfo.newDenseRow();
        }
        double[] weight_map = null;
        double relative_chunk_weight = 1.0;
        if (obs_weights) {
            // Cumulative weights of this chunk, normalized below to end at 1 (an empirical CDF).
            weight_map = new double[nrows];
            double weight_sum = 0.0;
            for (int i = 0; i < nrows; ++i) {
                row = this._sparse ? rows[i] : this._dinfo.extractDenseRow(chunks, i, row);
                weight_sum += row.weight;
                weight_map[i] = weight_sum;
                assert (i == 0 || row.weight == 0.0 || weight_map[i] > weight_map[i - 1]);
            }
            if (!(weight_sum > 0.0)) {
                return; // all rows have zero (or NaN) total weight: nothing to do
            }
            ArrayUtils.div(weight_map, weight_sum);
            relative_chunk_weight = global_weight_sum * (double)nrows / (double)this._fr.numRows() / weight_sum;
        }
        // e.g. useFraction 0.8 -> 1 pass at 0.8; 1.1 -> 2 passes at 0.55; 3.0 -> 3 passes at 1.0
        final int repeats = (int)Math.ceil((double)this._useFraction * relative_chunk_weight);
        if (repeats <= 0) {
            // Nothing to visit; avoids fraction = 0/0 = NaN tripping the assertion below.
            this.chunkDone(0L);
            return;
        }
        final float fraction = (float)((double)this._useFraction * relative_chunk_weight) / (float)repeats;
        assert ((double)fraction <= 1.0);
        final boolean sample = (double)fraction < 0.999 || obs_weights || this._shuffle;
        final long chunkSeed = (-8704322056524490956L + this._seed + offset) * ((long)this._iteration + -7484065362112007133L);
        final RandomBase skip_rng = sample ? RandomUtils.getRNG(chunkSeed) : null;
        int[] shufIdx = null;
        if (skip_rng != null) {
            shufIdx = new int[nrows];
            for (int i = 0; i < nrows; ++i) {
                shufIdx[i] = i;
            }
            ArrayUtils.shuffleArray(shufIdx, skip_rng);
        }
        final int miniBatchSize = this.getMiniBatchSize();
        final double[] responses = new double[miniBatchSize];
        final double[] offsets = new double[miniBatchSize];
        long seed = 0L;
        long num_processed_rows = 0L;
        long num_skipped_rows = 0L;
        int miniBatchCounter = 0;
        for (int rep = 0; rep < repeats; ++rep) {
            for (int row_idx = 0; row_idx < nrows; ++row_idx) {
                // Bernoulli sub-sampling (weighted mode uses the weights instead).
                if (sample && !obs_weights && skip_rng.nextDouble() > (double)fraction) {
                    continue;
                }
                int r;
                if (obs_weights && num_processed_rows % 2L == 0L) {
                    // Every second processed row is importance-sampled so rare rows are not lost.
                    r = sampleByWeight(weight_map, nrows, skip_rng.nextDouble());
                    assert (r == 0 || weight_map[r] > weight_map[r - 1]);
                } else if (sample) {
                    r = shufIdx[row_idx];
                    // With weights, redraw uniformly until a row with non-zero weight is hit.
                    while (obs_weights && hasZeroWeight(weight_map, r)) {
                        r = skip_rng.nextInt(nrows);
                    }
                } else {
                    assert (!obs_weights);
                    r = row_idx; // plain linear scan
                }
                assert (r >= 0 && r < nrows);
                // Widen before multiplying: rep * nrows can overflow int for many repeats.
                seed = offset + (long)rep * (long)nrows + (long)r;
                if (this.skipRow(seed)) {
                    ++num_skipped_rows;
                    continue;
                }
                row = this._sparse ? rows[r] : this._dinfo.extractDenseRow(chunks, r, row);
                if (row.isBad() || row.weight == 0.0) {
                    ++num_skipped_rows;
                    continue;
                }
                assert (row.weight > 0.0);
                if (outputs != null && outputs.length > 0) {
                    assert (miniBatchSize == 0);
                    this.processRow(seed, row, outputs);
                } else if (miniBatchSize > 0) {
                    this.processRow(seed, row, miniBatchCounter);
                    // Rows without a response get a 0.0 placeholder.
                    responses[miniBatchCounter] = row.response != null && row.response.length > 0 ? row.response(0) : 0.0;
                    offsets[miniBatchCounter] = row.offset;
                    ++miniBatchCounter;
                } else {
                    this.processRow(seed, row);
                }
                ++num_processed_rows;
                if (miniBatchCounter > 0 && miniBatchCounter % miniBatchSize == 0) {
                    this.processMiniBatch(seed, responses, offsets, miniBatchCounter);
                    miniBatchCounter = 0;
                }
            }
        }
        if (miniBatchCounter > 0) {
            this.processMiniBatch(seed, responses, offsets, miniBatchCounter); // trailing partial batch
        }
        assert (fraction != 1.0f || num_processed_rows + num_skipped_rows == (long)repeats * (long)nrows);
        this.chunkDone(num_processed_rows);
    }

    /**
     * Inverse-CDF lookup: returns the first index whose normalized cumulative weight exceeds
     * {@code key}, i.e. the row owning the interval {@code [weight_map[i-1], weight_map[i])}.
     * An exact binary-search hit is advanced past equal cumulative values, so the result is
     * never a zero-weight row.
     */
    private static int sampleByWeight(double[] weight_map, int nrows, double key) {
        int r = Arrays.binarySearch(weight_map, 0, nrows, key);
        if (r < 0) {
            return -r - 1;
        }
        while (r < nrows - 1 && weight_map[r] <= key) {
            ++r;
        }
        return r;
    }

    /** True if row {@code r} adds nothing to the cumulative weights, i.e. its weight is zero. */
    private static boolean hasZeroWeight(double[] weight_map, int r) {
        return r == 0 ? weight_map[0] == 0.0 : weight_map[r] == weight_map[r - 1];
    }

    /**
     * MRTask that extracts a single dense row, selected by global row index, into {@link #_row}.
     * Only the chunk containing that row produces a non-null result.
     */
    public static class ExtractDenseRow
    extends MRTask<ExtractDenseRow> {
        private final DataInfo _di;
        private final long _gid;
        /** The extracted row, or null if this task's chunks did not contain {@code _gid}. */
        public DataInfo.Row _row;

        public ExtractDenseRow(DataInfo di, long globalRowId) {
            this._di = di;
            this._gid = globalRowId;
        }

        @Override
        public void map(Chunk[] cs) {
            long start = cs[0].start();
            if (start <= this._gid && start + (long)cs[0].len() > this._gid) {
                this._row = this._di.newDenseRow();
                this._di.extractDenseRow(cs, (int)(this._gid - start), this._row);
            }
        }

        /** Keeps whichever side found the row; at most one side should have one. */
        @Override
        public void reduce(ExtractDenseRow mrt) {
            if (mrt._row != null) {
                assert (this._row == null);
                this._row = mrt._row;
            }
        }
    }
}

