/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.xgboost;

import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNode;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.tree.ConvertTreeOptions;
import hex.genmodel.algos.tree.PlattScalingMojoHelper;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.algos.tree.TreeBackedMojoModel;
import java.io.Closeable;
import java.util.Arrays;

/**
 * Common base for XGBoost MOJO models.
 *
 * <p>Holds the metadata shared by the concrete XGBoost scorers, converts raw booster output into
 * H2O prediction arrays, and converts tree-based boosters into {@link SharedTreeGraph}s.
 */
public abstract class XGBoostMojoModel
extends MojoModel
implements TreeBackedMojoModel,
SharedTreeGraphConverter,
PlattScalingMojoHelper.MojoModelWithCalibration,
Closeable {
    /** Separator between the tokens of a single feature-map line. */
    private static final String SPACE = " ";

    public String _boosterType;
    /** Number of tree groups; returned by {@link #getNTreeGroups()}. */
    public int _ntrees;
    public int _nums;
    public int _cats;
    public int[] _catOffsets;
    public boolean _useAllFactorLevels;
    public boolean _sparse;
    /**
     * Newline-separated feature map. Each line holds space-separated tokens: token 1 is the
     * feature name and token 2 its type, where {@code "i"} marks a one-hot indicator column.
     */
    public String _featureMap;
    /** Whether scoring requires an offset; see {@link #requiresOffset()} and {@link #score0(double[], double[])}. */
    public boolean _hasOffset;
    /** Platt-scaling coefficients exposed via {@link #getCalibGlmBeta()}; {@code null} unless assigned. */
    protected double[] _calib_glm_beta;

    public XGBoostMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /** Hook invoked after the model has been read; the base implementation does nothing. */
    public void postReadInit() {
    }

    @Override
    public boolean requiresOffset() {
        return this._hasOffset;
    }

    /**
     * Scores a row without an offset (offset 0.0).
     *
     * @throws IllegalStateException if the model was trained with an offset column
     */
    @Override
    public final double[] score0(double[] row, double[] preds) {
        if (this._hasOffset) {
            throw new IllegalStateException("Model was trained with offset, use score0 with offset");
        }
        return this.score0(row, 0.0, preds);
    }

    /**
     * Copies raw booster output into an H2O-style prediction array.
     *
     * <ul>
     *   <li>{@code nclasses > 2}: {@code out[i]} goes to {@code preds[1 + i]}, and {@code preds[0]}
     *       is the predicted label chosen by {@link GenModel#getPrediction}.</li>
     *   <li>{@code nclasses == 2}: {@code out[0]} goes to {@code preds[2]} and its complement to
     *       {@code preds[1]}; {@code preds[0]} is chosen by {@link GenModel#getPrediction}.</li>
     *   <li>otherwise (regression): {@code preds[0] = out[0]}.</li>
     * </ul>
     *
     * @param in the input row, forwarded to {@link GenModel#getPrediction}
     * @param out raw booster output
     * @param preds destination array; it must hold at least {@code out.length + 1} (multinomial)
     *     or 3 (binomial) elements
     * @return {@code preds}
     */
    public static double[] toPreds(double[] in, float[] out, double[] preds, int nclasses, double[] priorClassDistrib, double defaultThreshold) {
        if (nclasses > 2) {
            // float -> double widening prevents a bulk System.arraycopy here
            for (int i = 0; i < out.length; ++i) {
                preds[1 + i] = out[i];
            }
            preds[0] = GenModel.getPrediction(preds, priorClassDistrib, in, defaultThreshold);
        } else if (nclasses == 2) {
            // Kept in float arithmetic to match the historical scoring results exactly.
            preds[1] = 1.0f - out[0];
            preds[2] = out[0];
            preds[0] = GenModel.getPrediction(preds, priorClassDistrib, in, defaultThreshold);
        } else {
            preds[0] = out[0];
        }
        return preds;
    }

    @Override
    public int getNTreeGroups() {
        return this._ntrees;
    }

    /** One tree per class for multinomial models, a single tree otherwise. */
    @Override
    public int getNTreesPerGroup() {
        return this._nclasses > 2 ? this._nclasses : 1;
    }

    @Override
    public double[] getCalibGlmBeta() {
        return this._calib_glm_beta;
    }

    @Override
    public boolean calibrateClassProbabilities(double[] preds) {
        return PlattScalingMojoHelper.calibrateClassProbabilities(this, preds);
    }

    /**
     * Recursively copies the XGBoost node at {@code nodeIndex}, and its descendants, into
     * {@code sharedTreeNode}.
     *
     * @param xgBoostNodes all nodes of the XGBoost tree
     * @param sharedTreeNode the target node to populate
     * @param nodeIndex index of the source node in {@code xgBoostNodes}
     * @param sharedTreeSubgraph subgraph used to allocate child nodes
     * @param oneHotEncodedMap per-feature flag; flagged features get split value {@code 1.0f}
     * @param inclusiveNA NA-inclusiveness assigned to this node; a left child inherits the parent's
     *     {@code default_left()}, a right child its negation
     * @param features feature-map lines; token 1 of each line is used as the column name
     */
    protected void constructSubgraph(RegTreeNode[] xgBoostNodes, SharedTreeNode sharedTreeNode, int nodeIndex, SharedTreeSubgraph sharedTreeSubgraph, boolean[] oneHotEncodedMap, boolean inclusiveNA, String[] features) {
        RegTreeNode xgBoostNode = xgBoostNodes[nodeIndex];
        int splitIndex = xgBoostNode.getSplitIndex();
        if (oneHotEncodedMap[splitIndex]) {
            sharedTreeNode.setSplitValue(1.0f);
        } else {
            sharedTreeNode.setSplitValue(xgBoostNode.getSplitCondition());
        }
        sharedTreeNode.setPredValue(xgBoostNode.getLeafValue());
        sharedTreeNode.setCol(splitIndex, features[splitIndex].split(SPACE)[1]);
        sharedTreeNode.setInclusiveNa(inclusiveNA);
        sharedTreeNode.setNodeNumber(nodeIndex);
        if (xgBoostNode.getLeftChildIndex() != -1) {
            this.constructSubgraph(xgBoostNodes, sharedTreeSubgraph.makeLeftChildNode(sharedTreeNode), xgBoostNode.getLeftChildIndex(), sharedTreeSubgraph, oneHotEncodedMap, xgBoostNode.default_left(), features);
        }
        if (xgBoostNode.getRightChildIndex() != -1) {
            this.constructSubgraph(xgBoostNodes, sharedTreeSubgraph.makeRightChildNode(sharedTreeNode), xgBoostNode.getRightChildIndex(), sharedTreeSubgraph, oneHotEncodedMap, !xgBoostNode.default_left(), features);
        }
    }

    /**
     * Splits {@link #_featureMap} into lines, truncating at the first blank or whitespace-only
     * line.
     */
    private String[] constructFeatureMap() {
        String[] featureMapTokens = this._featureMap.split("\n");
        int nonEmptyTokenRange = featureMapTokens.length;
        for (int i = 0; i < featureMapTokens.length; ++i) {
            if (featureMapTokens[i].trim().isEmpty()) {
                // Exclusive end: the blank token itself must not be kept, otherwise parsing
                // its (missing) tokens downstream fails with an out-of-bounds access.
                nonEmptyTokenRange = i;
                break;
            }
        }
        return Arrays.copyOfRange(featureMapTokens, 0, nonEmptyTokenRange);
    }

    /**
     * Flags the leading run of feature-map lines whose type token (index 2) is {@code "i"}.
     * Every line from the first non-{@code "i"} entry onwards is left unflagged.
     *
     * @param featureMap feature-map lines with at least three space-separated tokens each
     * @return one flag per line, {@code true} for the leading indicator columns
     */
    protected boolean[] markOneHotEncodedCategoricals(String[] featureMap) {
        int firstNonIndicator = featureMap.length;
        for (int i = 0; i < featureMap.length; ++i) {
            String[] tokens = featureMap[i].split(SPACE);
            assert (tokens.length >= 3);
            if (!tokens[2].equals("i")) {
                firstNonIndicator = i;
                break;
            }
        }
        boolean[] categorical = new boolean[featureMap.length];
        Arrays.fill(categorical, 0, firstNonIndicator, true);
        return categorical;
    }

    /**
     * Builds a {@link SharedTreeGraph} from a tree-based booster.
     *
     * @param booster the booster; must be a {@link GBTree}
     * @param treeToPrint a tree-group index to convert only that group, or a negative value to
     *     convert every group
     * @throws IllegalArgumentException if the booster is not a {@link GBTree} or
     *     {@code treeToPrint >= getNTreeGroups()}
     */
    SharedTreeGraph computeGraph(GradBooster booster, int treeToPrint) {
        if (!(booster instanceof GBTree)) {
            throw new IllegalArgumentException(String.format("Given XGBoost model is not backed by a tree-based booster. Booster class is %s", booster.getClass().getCanonicalName()));
        }
        int ntreeGroups = this.getNTreeGroups();
        int ntreePerGroup = this.getNTreesPerGroup();
        if (treeToPrint >= ntreeGroups) {
            throw new IllegalArgumentException("Tree " + treeToPrint + " does not exist (max " + ntreeGroups + ")");
        }
        String[] features = this.constructFeatureMap();
        boolean[] oneHotEncodedMap = this.markOneHotEncodedCategoricals(features);
        RegTree[][] treesAndClasses = ((GBTree) booster).getGroupedTrees();
        // Loop-invariant: the response domain is the same for every tree.
        String[] domainValues = this.isSupervised() ? this.getDomainValues(this.getResponseIdx()) : null;
        SharedTreeGraph graph = new SharedTreeGraph();
        for (int group = Math.max(treeToPrint, 0); group < ntreeGroups; ++group) {
            for (int cls = 0; cls < ntreePerGroup; ++cls) {
                RegTree[] classTrees = treesAndClasses[cls];
                if (group >= classTrees.length || classTrees[group] == null) {
                    continue;
                }
                RegTreeNode[] treeNodes = classTrees[group].getNodes();
                assert (treeNodes.length >= 1);
                String treeName = SharedTreeMojoModel.treeName(group, cls, domainValues);
                SharedTreeSubgraph subgraph = graph.makeSubgraph(treeName);
                this.constructSubgraph(treeNodes, subgraph.makeRootNode(), 0, subgraph, oneHotEncodedMap, true, features);
            }
            if (treeToPrint >= 0) {
                // A specific tree group was requested; stop after converting it.
                break;
            }
        }
        return graph;
    }

    /** Ignores {@code options} and delegates to the two-argument {@code convert}. */
    @Override
    public SharedTreeGraph convert(int treeNumber, String treeClass, ConvertTreeOptions options) {
        return this.convert(treeNumber, treeClass);
    }

    /** XGBoost objective identifiers known to this MOJO. */
    public static enum ObjectiveType {
        BINARY_LOGISTIC("binary:logistic"),
        REG_GAMMA("reg:gamma"),
        REG_TWEEDIE("reg:tweedie"),
        COUNT_POISSON("count:poisson"),
        REG_SQUAREDERROR("reg:squarederror"),
        REG_LINEAR("reg:linear"),
        MULTI_SOFTPROB("multi:softprob"),
        RANK_PAIRWISE("rank:pairwise");

        private final String _id;

        private ObjectiveType(String id) {
            this._id = id;
        }

        /** Returns the XGBoost string identifier of this objective. */
        public String getId() {
            return this._id;
        }

        /**
         * Looks up an objective by its XGBoost identifier.
         *
         * @return the matching constant, or {@code null} if none matches
         */
        public static ObjectiveType fromXGBoost(String type) {
            for (ObjectiveType candidate : ObjectiveType.values()) {
                if (candidate.getId().equals(type)) {
                    return candidate;
                }
            }
            return null;
        }
    }
}

