/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost.predict;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNode;
import hex.DataInfo;
import hex.LinkFunction;
import hex.LinkFunctionFactory;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;
import hex.tree.xgboost.XGBoostModelInfo;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.predict.MutableOneHotEncoderFVec;
import hex.tree.xgboost.predict.PredictorFactory;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.util.ArrayUtils;

/**
 * Map/reduce task that recomputes per-node weights for every tree of an XGBoost model by
 * scoring the rows of a frame.
 *
 * <p>Expected chunk layout (see {@link #map}): the last column holds the row weight; all preceding
 * columns are the model inputs fed into a {@link MutableOneHotEncoderFVec}. Rows whose weight is
 * {@code 0} or {@code NaN} are skipped.
 *
 * <p>For the leaf each row lands in, the task accumulates:
 * <ul>
 *   <li>{@code gaussian}: the row weight;</li>
 *   <li>{@code bernoulli}: {@code weight * p * (1 - p)}, where {@code p} is the inverse logit
 *       link of the margin accumulated from the base score and the leaf values of all
 *       preceding trees.</li>
 * </ul>
 * After the global reduce, each internal node's weight is set to the sum of its two children's
 * weights. Only gaussian and bernoulli models with at most two classes are supported.
 */
public class UpdateAuxTreeWeightsTask extends MRTask<UpdateAuxTreeWeightsTask> {
    private final DistributionFamily _dist;
    private final Predictor _p;
    private final DataInfo _di;
    private final boolean _sparse;
    /** Accumulated weights indexed as {@code [treeIndex][nodeIndex]}; {@code null} until map()/postGlobal() runs. */
    private double[][] _nodeWeights;

    /**
     * Creates the task and validates that the model is supported.
     *
     * @param dist      distribution of the model; must be gaussian or bernoulli
     * @param di        data layout used to encode the input columns
     * @param modelInfo model whose booster bytes are used to build the predictor
     * @param output    model output; supplies the sparse-input flag
     * @throws UnsupportedOperationException if the model has more than two classes, or if the
     *         distribution is neither gaussian nor bernoulli
     */
    public UpdateAuxTreeWeightsTask(DistributionFamily dist, DataInfo di, XGBoostModelInfo modelInfo, XGBoostOutput output) {
        _dist = dist;
        _p = PredictorFactory.makePredictor(modelInfo._boosterBytes, null, false);
        _di = di;
        _sparse = output._sparse;
        if (_p.getNumClass() > 2) {
            throw new UnsupportedOperationException("Updating tree weights is currently not supported for multinomial models.");
        }
        if (_dist != DistributionFamily.gaussian && _dist != DistributionFamily.bernoulli) {
            throw new UnsupportedOperationException("Updating tree weights is currently not supported for distribution " + _dist + ".");
        }
    }

    /** Returns the trees of the first (and, for supported models, only) tree group of the booster. */
    private RegTree[] groupedTrees() {
        return ((GBTree) _p.getBooster()).getGroupedTrees()[0];
    }

    /** Allocates a zero-filled weight array per tree, sized to that tree's node-stat count. */
    private double[][] initNodeWeights() {
        final RegTree[] trees = groupedTrees();
        final double[][] nodeWeights = new double[trees.length][];
        for (int t = 0; t < trees.length; ++t) {
            nodeWeights[t] = new double[trees[t].getStats().length];
        }
        return nodeWeights;
    }

    /**
     * Scores every weighted row of the chunk and adds its contribution to the leaf it reaches in
     * each tree.
     *
     * @param chks input chunks; the last one holds row weights, the rest are model inputs
     * @param idx  unused
     */
    @Override
    public void map(Chunk[] chks, NewChunk[] idx) {
        // Start from a fresh accumulator rather than reusing any existing array. Presumably each
        // map() call runs on its own task copy that is merged in reduce(); a shared array would
        // double count.
        _nodeWeights = initNodeWeights();
        final LinkFunction logit = LinkFunctionFactory.getLinkFunction(LinkFunctionType.logit);
        final RegTree[] trees = groupedTrees();
        final MutableOneHotEncoderFVec inputVec = new MutableOneHotEncoderFVec(_di, _sparse);
        final int weightIndex = chks.length - 1;         // last column: row weight
        final double[] input = new double[weightIndex];  // all other columns: model inputs
        final int ntrees = _nodeWeights.length;
        for (int row = 0; row < chks[0]._len; ++row) {
            final double weight = chks[weightIndex].atd(row);
            if (weight == 0.0 || Double.isNaN(weight)) {
                continue;  // row does not contribute to any node weight
            }
            for (int col = 0; col < input.length; ++col) {
                input[col] = chks[col].atd(row);
            }
            inputVec.setInput(input);
            final int[] leafIdx = _p.getBooster().predictLeaf(inputVec, ntrees);
            assert leafIdx.length == ntrees : "Leaf indices (#idx=" + leafIdx.length + ") were not returned for all trees (#trees=" + ntrees + ").";
            if (_dist == DistributionFamily.gaussian) {
                for (int t = 0; t < leafIdx.length; ++t) {
                    _nodeWeights[t][leafIdx[t]] += weight;
                }
            } else {
                assert _dist == DistributionFamily.bernoulli;
                // Margin before tree t is added. NOTE(review): it starts from the *negated* base
                // score reported by the predictor; confirm this sign convention against
                // Predictor.getBaseScore().
                double margin = -_p.getBaseScore();
                for (int t = 0; t < leafIdx.length; ++t) {
                    final double prob = logit.linkInv(margin);
                    final double hessian = prob * (1.0 - prob);
                    _nodeWeights[t][leafIdx[t]] += weight * hessian;
                    margin += trees[t].getNodes()[leafIdx[t]].getLeafValue();
                }
            }
        }
    }

    /**
     * Merges another task's accumulated weights into this one. Either side may not have run
     * {@link #map} and so may hold {@code null}.
     */
    @Override
    public void reduce(UpdateAuxTreeWeightsTask mrt) {
        if (mrt._nodeWeights == null) {
            return;  // nothing to merge
        }
        if (_nodeWeights == null) {
            _nodeWeights = mrt._nodeWeights;
            return;
        }
        // Assign the result so correctness does not depend on ArrayUtils.add working in place.
        _nodeWeights = ArrayUtils.add(_nodeWeights, mrt._nodeWeights);
    }

    /**
     * Rolls leaf weights up the trees. Iterating node indices from high to low, each node's
     * parent weight is set to the sum of the parent's left and right children. This is valid
     * because child indices are asserted to be greater than their parent's index.
     */
    @Override
    protected void postGlobal() {
        if (_nodeWeights == null) {
            // No map() call contributed data: report all-zero weights instead of failing with an NPE.
            _nodeWeights = initNodeWeights();
        }
        final RegTree[] trees = groupedTrees();
        for (int t = 0; t < trees.length; ++t) {
            final RegTreeNode[] nodes = trees[t].getNodes();
            final double[] weights = _nodeWeights[t];
            for (int j = nodes.length - 1; j >= 0; --j) {
                final int parentId = nodes[j].getParentIndex();
                if (parentId < 0) {
                    continue;  // root has no parent
                }
                assert parentId < j : "Broken tree #" + t + ". Tree rollups assume parentId (=" + parentId + ") < childId (=" + j + ").";
                final RegTreeNode parent = nodes[parentId];
                weights[parentId] = weights[parent.getLeftChildIndex()] + weights[parent.getRightChildIndex()];
            }
        }
    }

    /**
     * @return accumulated weights indexed as {@code [treeIndex][nodeIndex]}; {@code null} if the
     *         task has not executed yet
     */
    public double[][] getNodeWeights() {
        return _nodeWeights;
    }
}

