/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.learner.ObjFunction;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.TreeSHAPHelper;
import biz.k11i.xgboost.util.FVec;
import hex.genmodel.PredictContributions;
import hex.genmodel.PredictContributionsFactory;
import hex.genmodel.algos.tree.ContributionsPredictor;
import hex.genmodel.algos.tree.ConvertTreeOptions;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.algos.tree.TreeSHAPEnsemble;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.genmodel.algos.xgboost.OneHotEncoderFactory;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.ArrayList;

/**
 * XGBoost MOJO model scored with the pure-Java {@link Predictor}.
 *
 * <p>Input rows are converted to {@link FVec} instances through a
 * {@link OneHotEncoderFactory} that is created in {@link #postReadInit()}; scoring,
 * leaf assignment and decision-path methods therefore require {@code postReadInit()}
 * to have been called first. Feature contributions (TreeSHAP) are available either
 * eagerly (constructor flag) or lazily through {@link #makeContributionsPredictor()}.
 */
public final class XGBoostJavaMojoModel
extends XGBoostMojoModel
implements PredictContributionsFactory {
    // Not final: all three references are cleared by close().
    private Predictor _predictor;
    // Null unless TreeSHAP was enabled at construction time.
    private TreeSHAPPredictor<FVec> _treeSHAPPredictor;
    // Assigned in postReadInit(), not in the constructor.
    private OneHotEncoderFactory _1hotFactory;

    /**
     * Creates a model without an eagerly built TreeSHAP predictor.
     *
     * @param boosterBytes   serialized booster, parsed by {@link #makePredictor(byte[])}
     * @param columns        column names, forwarded to the superclass
     * @param domains        per-column categorical domains, forwarded to the superclass
     * @param responseColumn response column name, forwarded to the superclass
     */
    public XGBoostJavaMojoModel(byte[] boosterBytes, String[] columns, String[][] domains, String responseColumn) {
        this(boosterBytes, columns, domains, responseColumn, false);
    }

    /**
     * Creates a model and optionally builds the TreeSHAP predictor up front.
     *
     * @param boosterBytes   serialized booster, parsed by {@link #makePredictor(byte[])}
     * @param columns        column names, forwarded to the superclass
     * @param domains        per-column categorical domains, forwarded to the superclass
     * @param responseColumn response column name, forwarded to the superclass
     * @param enableTreeSHAP when {@code true} the TreeSHAP predictor is created immediately
     * @throws IllegalStateException         if the booster bytes cannot be parsed
     * @throws UnsupportedOperationException if {@code enableTreeSHAP} is set and the booster
     *                                       does not support contributions
     */
    public XGBoostJavaMojoModel(byte[] boosterBytes, String[] columns, String[][] domains, String responseColumn, boolean enableTreeSHAP) {
        super(columns, domains, responseColumn);
        this._predictor = XGBoostJavaMojoModel.makePredictor(boosterBytes);
        this._treeSHAPPredictor = enableTreeSHAP ? XGBoostJavaMojoModel.makeTreeSHAPPredictor(this._predictor) : null;
    }

    /**
     * Builds the one-hot encoder factory from the model's encoding fields.
     * NOTE(review): presumably invoked by the MOJO reader once those fields are populated - confirm.
     */
    @Override
    public void postReadInit() {
        this._1hotFactory = new OneHotEncoderFactory(this.backwardsCompatibility10(), this._sparse, this._catOffsets, this._cats, this._nums, this._useAllFactorLevels);
    }

    /**
     * @return {@code true} for MOJO version 1.0 models whose booster type is not {@code "gbtree"}
     */
    private boolean backwardsCompatibility10() {
        return this._mojo_version == 1.0 && !"gbtree".equals(this._boosterType);
    }

    /**
     * Deserializes a {@link Predictor} from the raw booster bytes.
     *
     * @param boosterBytes serialized booster
     * @return the loaded predictor
     * @throws IllegalStateException if reading the bytes fails with an {@link IOException}
     */
    public static Predictor makePredictor(byte[] boosterBytes) {
        try (ByteArrayInputStream is = new ByteArrayInputStream(boosterBytes)) {
            return new Predictor(is);
        }
        catch (IOException e) {
            throw new IllegalStateException("Failed to load predictor.", e);
        }
    }

    /**
     * Builds a TreeSHAP ensemble from the trees of the first tree group of the booster.
     *
     * @param predictor loaded predictor
     * @return ensemble of per-tree SHAP predictors initialized with the predictor's base score
     * @throws UnsupportedOperationException for models with more than two classes, or when the
     *                                       booster is not a {@link GBTree}
     */
    private static TreeSHAPPredictor<FVec> makeTreeSHAPPredictor(Predictor predictor) {
        if (predictor.getNumClass() > 2) {
            throw new UnsupportedOperationException("Calculating contributions is currently not supported for multinomial models.");
        }
        GradBooster booster = predictor.getBooster();
        // Guard the downcast: a non-tree booster would otherwise fail with a bare ClassCastException.
        if (!(booster instanceof GBTree)) {
            throw new UnsupportedOperationException("Calculating contributions is only supported for tree-based boosters.");
        }
        RegTree[] trees = ((GBTree)booster).getGroupedTrees()[0];
        ArrayList<TreeSHAPPredictor<FVec>> predictors = new ArrayList<TreeSHAPPredictor<FVec>>(trees.length);
        for (RegTree tree : trees) {
            predictors.add(TreeSHAPHelper.makePredictor(tree));
        }
        float initPred = predictor.getBaseScore();
        return new TreeSHAPEnsemble<FVec>(predictors, initPred);
    }

    /**
     * Scores a single row.
     *
     * <p>In the 1.0 backwards-compatibility mode the input may be shorter than
     * {@code _cats + _nums}; it is then copied into a zero-padded array of that length.
     * A longer input is rejected.
     *
     * @param doubles input row
     * @param offset  offset passed to the predictor; must be {@code 0} when the model has no offset
     * @param preds   output array handed to {@code toPreds}
     * @return the result of {@code toPreds}
     * @throws ArrayIndexOutOfBoundsException in compatibility mode when too many values are supplied
     * @throws UnsupportedOperationException  if a non-zero offset is given to a model without offset
     */
    @Override
    public final double[] score0(double[] doubles, double offset, double[] preds) {
        float[] out;
        if (this.backwardsCompatibility10()) {
            int expected = this._cats + this._nums;
            if (doubles.length > expected) {
                throw new ArrayIndexOutOfBoundsException("Too many input values.");
            }
            if (doubles.length < expected) {
                double[] tmp = new double[expected];
                System.arraycopy(doubles, 0, tmp, 0, doubles.length);
                doubles = tmp;
            }
        }
        FVec row = this._1hotFactory.fromArray(doubles);
        if (this._hasOffset) {
            out = this._predictor.predict(row, (float)offset);
        } else {
            if (offset != 0.0) {
                throw new UnsupportedOperationException("Unsupported: offset != 0");
            }
            out = this._predictor.predict(row);
        }
        return XGBoostJavaMojoModel.toPreds(doubles, out, preds, this._nclasses, this._priorClassDistrib, this._defaultThreshold);
    }

    /**
     * Returns the eagerly built TreeSHAP predictor or fails with a descriptive error.
     *
     * @throws IllegalStateException if TreeSHAP was not enabled at construction or the model was closed
     */
    private TreeSHAPPredictor<FVec> requireTreeSHAPPredictor() {
        TreeSHAPPredictor<FVec> p = this._treeSHAPPredictor;
        if (p == null) {
            throw new IllegalStateException("TreeSHAP predictor is not available; construct the model with enableTreeSHAP=true or use makeContributionsPredictor().");
        }
        return p;
    }

    /**
     * @return a new workspace object created by the TreeSHAP predictor
     * @throws IllegalStateException if the TreeSHAP predictor is not available
     */
    public final Object makeContributionsWorkspace() {
        return this.requireTreeSHAPPredictor().makeWorkspace();
    }

    /**
     * Calculates contributions for an already encoded row.
     *
     * @param row          encoded input row
     * @param out_contribs array the contributions are written into
     * @param workspace    workspace obtained from {@link #makeContributionsWorkspace()}
     * @return {@code out_contribs}
     * @throws IllegalStateException if the TreeSHAP predictor is not available
     */
    public final float[] calculateContributions(FVec row, float[] out_contribs, Object workspace) {
        // NOTE(review): 0 and -1 are passed straight to TreeSHAPPredictor.calculateContributions;
        // confirm their meaning against that API.
        this.requireTreeSHAPPredictor().calculateContributions(row, out_contribs, 0, -1, workspace);
        return out_contribs;
    }

    /**
     * Creates a contributions predictor, reusing the eagerly built TreeSHAP predictor when
     * present and building a fresh one from the booster otherwise.
     *
     * @throws UnsupportedOperationException if the booster does not support contributions
     */
    @Override
    public final PredictContributions makeContributionsPredictor() {
        TreeSHAPPredictor<FVec> treeSHAPPredictor = this._treeSHAPPredictor != null ? this._treeSHAPPredictor : XGBoostJavaMojoModel.makeTreeSHAPPredictor(this._predictor);
        return new XGBoostContributionsPredictor(this, treeSHAPPredictor);
    }

    /**
     * Looks up an objective function by name; delegates to {@link ObjFunction#fromName(String)}.
     */
    static ObjFunction getObjFunction(String name) {
        return ObjFunction.fromName(name);
    }

    /**
     * Drops the references to the predictor, the TreeSHAP predictor and the encoder factory.
     * The model must not be used for scoring afterwards.
     */
    @Override
    public void close() {
        this._predictor = null;
        this._treeSHAPPredictor = null;
        this._1hotFactory = null;
    }

    /**
     * Converts the given tree of the booster into a {@link SharedTreeGraph} via {@code computeGraph}.
     * The {@code treeClass} argument is not used by this implementation.
     */
    @Override
    public SharedTreeGraph convert(int treeNumber, String treeClass) {
        GradBooster booster = this._predictor.getBooster();
        return this.computeGraph(booster, treeNumber);
    }

    /**
     * Same as {@link #convert(int, String)}; {@code options} is ignored by this implementation.
     */
    @Override
    public SharedTreeGraph convert(int treeNumber, String treeClass, ConvertTreeOptions options) {
        return this.convert(treeNumber, treeClass);
    }

    /**
     * @return the predictor's base score
     */
    @Override
    public double getInitF() {
        return this._predictor.getBaseScore();
    }

    /**
     * Encodes the row and returns both the leaf paths and the leaf node ids predicted for it.
     */
    @Override
    public SharedTreeMojoModel.LeafNodeAssignments getLeafNodeAssignments(double[] doubles) {
        FVec row = this._1hotFactory.fromArray(doubles);
        SharedTreeMojoModel.LeafNodeAssignments result = new SharedTreeMojoModel.LeafNodeAssignments();
        result._paths = this._predictor.predictLeafPath(row);
        result._nodeIds = this._predictor.predictLeaf(row);
        return result;
    }

    /**
     * Encodes the row and returns the leaf paths predicted for it.
     */
    @Override
    public String[] getDecisionPath(double[] doubles) {
        FVec row = this._1hotFactory.fromArray(doubles);
        return this._predictor.predictLeafPath(row);
    }

    /**
     * Builds the names of the encoded features used for contributions.
     *
     * <p>A feature without a domain contributes its own name. A feature with a domain
     * contributes one {@code feature.level} entry per domain level followed by a
     * {@code feature.missing(NA)} entry.
     *
     * @param m2 model providing features, domains and encoding sizes
     * @return array of {@code _nums + _catOffsets[_cats]} names
     */
    private static String[] makeFeatureContributionNames(XGBoostMojoModel m2) {
        String[] names = new String[m2._nums + m2._catOffsets[m2._cats]];
        String[] features = m2.features();
        int i2 = 0;
        for (int c2 = 0; c2 < features.length; ++c2) {
            if (m2._domains[c2] == null) {
                names[i2++] = features[c2];
                continue;
            }
            for (String d2 : m2._domains[c2]) {
                names[i2++] = features[c2] + "." + d2;
            }
            names[i2++] = features[c2] + ".missing(NA)";
        }
        assert (names.length == i2);
        return names;
    }

    /**
     * Contributions predictor that encodes raw input rows with the enclosing model's
     * one-hot encoder factory.
     */
    private final class XGBoostContributionsPredictor
    extends ContributionsPredictor<FVec> {
        private XGBoostContributionsPredictor(XGBoostMojoModel model, TreeSHAPPredictor<FVec> treeSHAPPredictor) {
            // The first argument is one larger than the number of feature names.
            // NOTE(review): presumably the extra slot holds the bias term - confirm in ContributionsPredictor.
            super(XGBoostJavaMojoModel.this._nums + XGBoostJavaMojoModel.this._catOffsets[XGBoostJavaMojoModel.this._cats] + 1, XGBoostJavaMojoModel.makeFeatureContributionNames(model), treeSHAPPredictor);
        }

        @Override
        protected FVec toInputRow(double[] input) {
            return XGBoostJavaMojoModel.this._1hotFactory.fromArray(input);
        }
    }
}

