/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.predict;

import hex.DataInfo;
import hex.Model;
import hex.genmodel.algos.tree.ContributionComposer;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.tree.xgboost.XGBoostModelInfo;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.predict.MutableOneHotEncoderFVec;
import hex.tree.xgboost.predict.PredictTreeSHAPTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.NewChunk;

/**
 * TreeSHAP contribution task that, for every row, ranks the per-feature contributions with
 * {@link ContributionComposer} (using the configured top-N / bottom-N / compare-abs options)
 * and writes only the selected features to the output frame.
 *
 * <p>Output layout per row: consecutive column pairs {@code (featureId, contribution)} for each
 * selected feature, followed by one final column holding the last element of the raw
 * contributions array (presumably the bias term — TODO confirm against the superclass).
 */
public class PredictTreeSHAPSortingTask
extends PredictTreeSHAPTask {
    private final boolean _outputAggregated;
    private final int _topN;
    private final int _bottomN;
    private final boolean _compareAbs;

    /**
     * @param di        data info describing the model's input columns
     * @param modelInfo XGBoost model info
     * @param output    XGBoost model output
     * @param options   contribution options. {@code _outputFormat == Compact} switches to the
     *                  aggregated output buffer. {@code _topN}, {@code _bottomN} and
     *                  {@code _compareAbs} are forwarded to {@link ContributionComposer}.
     */
    public PredictTreeSHAPSortingTask(DataInfo di, XGBoostModelInfo modelInfo, XGBoostOutput output, Model.Contributions.ContributionsOptions options) {
        super(di, modelInfo, output, options);
        this._outputAggregated = Model.Contributions.ContributionsOutputFormat.Compact.equals(options._outputFormat);
        this._topN = options._topN;
        this._bottomN = options._bottomN;
        this._compareAbs = options._compareAbs;
    }

    /**
     * Fills the row input via the superclass and resets {@code contribNameIds} to the identity
     * permutation {@code 0..length-1}, so each row is ranked from a fresh id list.
     */
    protected void fillInput(Chunk[] chks, int row, double[] input, float[] contribs, int[] contribNameIds) {
        super.fillInput(chks, row, input, contribs);
        for (int i2 = 0; i2 < contribNameIds.length; ++i2) {
            contribNameIds[i2] = i2;
        }
    }

    /**
     * Computes contributions row by row, ranks them and appends the selected
     * {@code (id, value)} pairs plus the trailing bias column to {@code nc}.
     */
    @Override
    public void map(Chunk[] chks, NewChunk[] nc) {
        MutableOneHotEncoderFVec rowFVec = new MutableOneHotEncoderFVec(this._di, this._output._sparse);
        double[] input = MemoryManager.malloc8d(chks.length);
        float[] contribs = MemoryManager.malloc4f(this._di.fullN() + 1);
        // In Compact format handleOutputFormat fills a separate buffer sized by the number of input
        // columns; otherwise the raw contributions array itself is ranked and emitted.
        float[] output = this._outputAggregated ? MemoryManager.malloc4f(chks.length) : contribs;
        int[] contribNameIds = MemoryManager.malloc4(output.length);
        TreeSHAPPredictor.Workspace workspace = this._mojo.makeContributionsWorkspace();
        for (int row = 0; row < chks[0]._len; ++row) {
            this.fillInput(chks, row, input, contribs, contribNameIds);
            rowFVec.setInput(input);
            this._mojo.calculateContributions(rowFVec, contribs, workspace);
            this.handleOutputFormat(rowFVec, contribs, output);
            ContributionComposer contributionComposer = new ContributionComposer();
            int[] contribNameIdsSorted = contributionComposer.composeContributions(contribNameIds, output, this._topN, this._bottomN, this._compareAbs);
            // The sorted ids index into `output` (the array that was ranked), so values must be
            // read from `output` as well. Reading them from `contribs` returned unrelated values
            // in Compact mode. The trailing column still comes from the raw contributions.
            this.addContribToNewChunk(output, contribs[contribs.length - 1], contribNameIdsSorted, nc);
        }
    }

    /**
     * Appends one row using {@code contribs} both as the value source and for the trailing
     * column (its last element). Kept for backward compatibility.
     *
     * @param contribs           contribution values indexed by feature id
     * @param contribNamesSorted selected feature ids, in output order
     * @param nc                 output chunks: id/value pairs followed by one trailing column
     */
    protected void addContribToNewChunk(float[] contribs, int[] contribNamesSorted, NewChunk[] nc) {
        this.addContribToNewChunk(contribs, contribs[contribs.length - 1], contribNamesSorted, nc);
    }

    /**
     * Appends one row: for each column pair {@code (2k, 2k+1)} writes the k-th selected id and
     * {@code values[id]}, then writes {@code bias} to the last column.
     *
     * @param values    values indexed by the same ids that were ranked
     * @param bias      value written to the last output column
     * @param sortedIds selected feature ids, in output order; must contain at least
     *                  {@code (nc.length - 1) / 2} entries
     * @param nc        output chunks
     */
    protected void addContribToNewChunk(float[] values, float bias, int[] sortedIds, NewChunk[] nc) {
        int last = nc.length - 1;
        for (int col = 0, k = 0; col < last; col += 2, ++k) {
            int id = sortedIds[k];
            nc[col].addNum(id);
            nc[col + 1].addNum(values[id]);
        }
        nc[last].addNum(bias);
    }
}

