/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.task;

import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.genmodel.GenModel;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.predict.XGBoostBigScorePredict;
import hex.tree.xgboost.predict.XGBoostPredict;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.NewChunk;

/**
 * Scoring task for an {@link XGBoostModel} that produces predictions and model metrics in one pass.
 *
 * <p>For every mapped chunk the XGBoost predictor is run, the predictions are appended to the
 * output {@link NewChunk}s, and each row is fed to a {@link ModelMetrics.MetricBuilder}. The
 * output column layout depends on the number of response classes:
 * <ul>
 *   <li>1 class (regression): {@code ncs[0]} holds the raw prediction;</li>
 *   <li>2 classes (binomial): {@code ncs[0]} the predicted label, {@code ncs[1]} = {@code 1 - p}
 *       and {@code ncs[2]} = {@code p}, where {@code p} is the predictor's single output;</li>
 *   <li>otherwise (multinomial): {@code ncs[0]} the predicted label and {@code ncs[1..]} the
 *       per-class values returned by the predictor.</li>
 * </ul>
 */
public class XGBoostScoreTask
extends MRTask<XGBoostScoreTask> {
    private final XGBoostOutput _output;
    /** Index of the weights column among the mapped chunks, or -1 when rows are unweighted. */
    private final int _weightsChunkId;
    private final XGBoostModel _model;
    /** Forwarded to {@link XGBoostModel#setupBigScorePredict} when preparing the predictor. */
    private final boolean _isTrain;
    /** Threshold passed to {@link GenModel#getPrediction} when choosing a label. */
    private final double _threshold;
    /** Metrics accumulated by this task; merged across tasks in {@link #reduce}. */
    public ModelMetrics.MetricBuilder _metricBuilder;
    /** Predictor built in {@link #setupLocal()}; transient, so it is not shipped with the task. */
    private transient XGBoostBigScorePredict _predict;

    /**
     * Creates a scoring task.
     *
     * @param output         model output describing the response (class count, names, index)
     * @param weightsChunkId index of the weights chunk, or -1 if every row has weight 1.0
     * @param isTrain        forwarded to the model when setting up the predictor
     * @param model          model to score; also supplies the default decision threshold
     */
    public XGBoostScoreTask(XGBoostOutput output, int weightsChunkId, boolean isTrain, XGBoostModel model) {
        this._output = output;
        this._weightsChunkId = weightsChunkId;
        this._model = model;
        this._isTrain = isTrain;
        this._threshold = model.defaultThreshold();
    }

    /**
     * Picks a metric builder matching the response type.
     *
     * @param responseClassesNum 1 for regression, 2 for binomial, anything else is multinomial
     * @param responseDomain     class labels of the response
     * @return a fresh, empty metric builder
     */
    private ModelMetrics.MetricBuilder createMetricsBuilder(int responseClassesNum, String[] responseDomain) {
        switch (responseClassesNum) {
            case 1: {
                return new ModelMetricsRegression.MetricBuilderRegression();
            }
            case 2: {
                return new ModelMetricsBinomial.MetricBuilderBinomial(responseDomain);
            }
        }
        return new ModelMetricsMultinomial.MetricBuilderMultinomial(responseClassesNum, responseDomain, ((XGBoostModel.XGBoostParameters)this._model._parms)._auc_type);
    }

    /** Prepares the transient big-scoring predictor before any chunk is mapped locally. */
    @Override
    protected void setupLocal() {
        this._predict = this._model.setupBigScorePredict(this._isTrain);
    }

    /**
     * Scores one set of chunks, writes the predictions into {@code ncs} and records per-row metrics.
     *
     * @param cs  input chunks (features, response and optionally weights)
     * @param ncs output chunks; their layout is described in the class documentation
     */
    @Override
    public void map(Chunk[] cs, NewChunk[] ncs) {
        int nclasses = this._output.nclasses();
        this._metricBuilder = this.createMetricsBuilder(nclasses, this._output.classNames());
        XGBoostPredict predictor = this._predict.initMap(this._fr, cs);
        float[][] preds = predictor.predict(cs);
        if (preds.length == 0) {
            return;
        }
        assert (preds.length == cs[0]._len);
        Chunk responseChunk = cs[this._output.responseIdx()];
        if (nclasses == 1) {
            this.scoreRegression(preds, cs, ncs, responseChunk);
        } else if (nclasses == 2) {
            this.scoreBinomial(preds, cs, ncs, responseChunk);
        } else {
            this.scoreMultinomial(preds, cs, ncs, responseChunk);
        }
    }

    /** Emits the raw regression prediction per row and records it in the metrics. */
    private void scoreRegression(float[][] preds, Chunk[] cs, NewChunk[] ncs, Chunk responseChunk) {
        double[] currentPred = new double[1];
        float[] yact = new float[1];
        // Single pass bounded by the chunk length, consistent with the classification branches.
        for (int row = 0; row < cs[0]._len; ++row) {
            ncs[0].addNum(preds[row][0]);
            currentPred[0] = preds[row][0];
            yact[0] = (float)responseChunk.atd(row);
            this._metricBuilder.perRow(currentPred, yact, this.weightAt(cs, row), 0.0, this._model);
        }
    }

    /** Emits label and two-class probabilities per row and records them in the metrics. */
    private void scoreBinomial(float[][] preds, Chunk[] cs, NewChunk[] ncs, Chunk responseChunk) {
        double[] row = new double[3];
        float[] yact = new float[1];
        for (int i = 0; i < cs[0]._len; ++i) {
            double p = preds[i][0];
            row[1] = 1.0 - p;
            row[2] = p;
            row[0] = GenModel.getPrediction(row, this._output._priorClassDist, null, this._threshold);
            ncs[0].addNum(row[0]);
            ncs[1].addNum(row[1]);
            ncs[2].addNum(row[2]);
            yact[0] = (float)responseChunk.atd(i);
            this._metricBuilder.perRow(row, yact, this.weightAt(cs, i), 0.0, this._model);
        }
    }

    /** Emits label and per-class values per row and records them in the metrics. */
    private void scoreMultinomial(float[][] preds, Chunk[] cs, NewChunk[] ncs, Chunk responseChunk) {
        float[] yact = new float[1];
        // row[0] is the label slot; row[1..] mirror the output columns ncs[1..].
        double[] row = MemoryManager.malloc8d(ncs.length);
        for (int i = 0; i < cs[0]._len; ++i) {
            for (int c = 1; c < row.length; ++c) {
                double val = preds[i][c - 1];
                ncs[c].addNum(val);
                row[c] = val;
            }
            row[0] = GenModel.getPrediction(row, this._output._priorClassDist, null, this._threshold);
            ncs[0].addNum(row[0]);
            yact[0] = (float)responseChunk.atd(i);
            this._metricBuilder.perRow(row, yact, this.weightAt(cs, i), 0.0, this._model);
        }
    }

    /** Returns the weight of {@code row}, or 1.0 when no weights column is configured. */
    private double weightAt(Chunk[] cs, int row) {
        return this._weightsChunkId != -1 ? cs[this._weightsChunkId].atd(row) : 1.0;
    }

    /**
     * Merges another task's results into this one.
     *
     * <p>Tolerates a side whose metric builder was never created, adopting the other side's
     * builder instead of dereferencing {@code null}.
     */
    @Override
    public void reduce(XGBoostScoreTask mrt) {
        super.reduce(mrt);
        if (this._metricBuilder == null) {
            this._metricBuilder = mrt._metricBuilder;
        } else if (mrt._metricBuilder != null) {
            this._metricBuilder.reduce(mrt._metricBuilder);
        }
    }
}

