/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNode;
import biz.k11i.xgboost.tree.RegTreeNodeStat;
import biz.k11i.xgboost.util.FVec;
import hex.DataInfo;
import hex.FeatureInteractions;
import hex.FeatureInteractionsCollector;
import hex.FriedmanPopescusHCollector;
import hex.KeyValue;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.genmodel.algos.tree.ContributionComposer;
import hex.genmodel.algos.tree.ConvertTreeOptions;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.algos.xgboost.XGBoostJavaMojoModel;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.FriedmanPopescusH;
import hex.tree.PlattScalingHelper;
import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.XGBoost;
import hex.tree.xgboost.XGBoostModelInfo;
import hex.tree.xgboost.XGBoostMojoWriter;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.XGBoostPojoWriter;
import hex.tree.xgboost.XGBoostUtils;
import hex.tree.xgboost.predict.AssignLeafNodeTask;
import hex.tree.xgboost.predict.AuxNodeWeights;
import hex.tree.xgboost.predict.MutableOneHotEncoderFVec;
import hex.tree.xgboost.predict.PredictTreeSHAPSortingTask;
import hex.tree.xgboost.predict.PredictTreeSHAPTask;
import hex.tree.xgboost.predict.PredictorFactory;
import hex.tree.xgboost.predict.UpdateAuxTreeWeightsTask;
import hex.tree.xgboost.predict.XGBoostBigScorePredict;
import hex.tree.xgboost.predict.XGBoostJavaBigScorePredict;
import hex.tree.xgboost.predict.XGBoostJavaVariableImportance;
import hex.tree.xgboost.predict.XGBoostModelMetrics;
import hex.tree.xgboost.predict.XGBoostNativeBigScorePredict;
import hex.tree.xgboost.predict.XGBoostNativeVariableImportance;
import hex.tree.xgboost.predict.XGBoostVariableImportance;
import hex.tree.xgboost.util.GpuUtils;
import hex.tree.xgboost.util.PredictConfiguration;
import hex.util.EffectiveParametersUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.StringJoiner;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.apache.log4j.Logger;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.H2ONode;
import water.IcedUtils;
import water.Job;
import water.JobUpdatePostMap;
import water.Key;
import water.Keyed;
import water.codegen.CodeGeneratorPipeline;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.JCodeGen;
import water.util.SBPrintStream;
import water.util.TwoDimTable;

public class XGBoostModel
extends Model<XGBoostModel, XGBoostParameters, XGBoostOutput>
implements SharedTreeGraphConverter,
Model.LeafNodeAssignment,
Model.Contributions,
FeatureInteractionsCollector,
Model.UpdateAuxTreeWeights,
FriedmanPopescusHCollector {
    /** Logger shared by all XGBoostModel code paths (parameter resolution, scoring, metrics). */
    private static final Logger LOG = Logger.getLogger(XGBoostModel.class);
    // System property whose value, when set, is passed to XGBoost as "verbosity" (see createParamsMap);
    // when unset, the "silent" flag from _quiet_mode is used instead.
    // NOTE(review): the key contains a double dot ("h2o..xgboost"). This looks like a prefix-concatenation
    // artifact, but it is the key actually read at runtime, so it is intentionally left unchanged.
    private static final String PROP_VERBOSITY = "sys.ai.h2o..xgboost.verbosity";
    // System property that, when set, overrides the computed maximum thread count (see getMaxNThread).
    private static final String PROP_NTHREAD = "sys.ai.h2o.xgboost.nthreadMax";
    // Holder for the booster bytes, DataInfo and aux-node-weights keys backing this model.
    private XGBoostModelInfo model_info;

    /**
     * Exposes the model-info holder (booster bytes, DataInfo, aux node weights) of this model.
     *
     * @return the model info; the same instance on every call
     */
    public XGBoostModelInfo model_info() {
        final XGBoostModelInfo info = model_info;
        return info;
    }

    /**
     * Creates the metric builder that matches this model's category.
     *
     * @param domain response domain (class labels) for classification models
     * @return a binomial, multinomial or regression metric builder
     * @throws RuntimeException from {@code H2O.unimpl()} for any other model category
     */
    @Override
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        final XGBoostOutput output = _output;
        switch (output.getModelCategory()) {
            case Binomial:
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            case Multinomial:
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(output.nclasses(), domain, _parms._auc_type);
            case Regression:
                return new ModelMetricsRegression.MetricBuilderRegression();
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Creates a model, builds its DataInfo from the training/validation frames, publishes that
     * DataInfo to the DKV, and copies its column layout into the model output.
     *
     * @param selfKey key under which the model is stored
     * @param parms   model parameters
     * @param output  model output to populate
     * @param train   training frame
     * @param valid   validation frame (may be null)
     */
    public XGBoostModel(Key<XGBoostModel> selfKey, XGBoostParameters parms, XGBoostOutput output, Frame train, Frame valid) {
        super(selfKey, parms, output);
        final int nclasses = output.nclasses();
        final DataInfo dataInfo = XGBoost.makeDataInfo(train, valid, _parms, nclasses);
        DKV.put(dataInfo);
        setDataInfoToOutput(dataInfo);
        model_info = new XGBoostModelInfo(parms, dataInfo);
    }

    /**
     * Resolves "auto" parameter values into the concrete values actually used.
     * The backend is resolved first because tree-method resolution depends on it.
     */
    @Override
    public void initActualParamValues() {
        super.initActualParamValues();
        EffectiveParametersUtils.initFoldAssignment(_parms);
        final XGBoostParameters effective = _parms;
        effective._backend = getActualBackend(effective, true);
        effective._tree_method = getActualTreeMethod(effective);
    }

    /**
     * Training-frame row count at or above which {@code tree_method=auto} resolves to
     * {@code approx} instead of {@code exact} on a single-node cloud (0x400000 = 4,194,304 rows).
     */
    private static final long AUTO_APPROX_MIN_ROWS = 0x400000L;

    /**
     * Resolves the tree method XGBoost will actually use.
     *
     * <p>A non-{@code auto} choice is returned unchanged. For {@code auto}:
     * <ul>
     *   <li>{@code hist} when the requested backend is GPU;</li>
     *   <li>on a multi-node cloud, {@code hist} when monotone constraints are set and the booster
     *       is not {@code gblinear}, otherwise {@code approx};</li>
     *   <li>on a single node, {@code approx} when the training frame has at least
     *       {@link #AUTO_APPROX_MIN_ROWS} rows, otherwise {@code exact}.</li>
     * </ul>
     *
     * @param p2 model parameters (not modified)
     * @return the effective tree method
     */
    public static XGBoostParameters.TreeMethod getActualTreeMethod(XGBoostParameters p2) {
        if (p2._tree_method != XGBoostParameters.TreeMethod.auto) {
            return p2._tree_method;
        }
        if (p2._backend == XGBoostParameters.Backend.gpu) {
            return XGBoostParameters.TreeMethod.hist;
        }
        if (H2O.getCloudSize() > 1) {
            // The GPU backend was handled above, so only constraints and booster matter here.
            boolean monotoneTreeBooster = p2._monotone_constraints != null
                    && p2._booster != XGBoostParameters.Booster.gblinear;
            return monotoneTreeBooster ? XGBoostParameters.TreeMethod.hist : XGBoostParameters.TreeMethod.approx;
        }
        Frame train = p2.train();
        if (train != null && train.numRows() >= AUTO_APPROX_MIN_ROWS) {
            return XGBoostParameters.TreeMethod.approx;
        }
        return XGBoostParameters.TreeMethod.exact;
    }

    /**
     * Resolves parameters that depend on the model output: stopping metric, categorical
     * encoding (defaulting to OneHotInternal), distribution, and the DMatrix type, which
     * follows whether the output was marked sparse.
     *
     * @param isClassifier whether the model is a classifier
     * @param nclasses     number of response classes
     */
    public void initActualParamValuesAfterOutputSetup(boolean isClassifier, int nclasses) {
        final XGBoostParameters effective = _parms;
        EffectiveParametersUtils.initStoppingMetric(effective, isClassifier);
        EffectiveParametersUtils.initCategoricalEncoding(effective, Model.Parameters.CategoricalEncodingScheme.OneHotInternal);
        EffectiveParametersUtils.initDistribution(effective, nclasses);
        if (_output._sparse) {
            effective._dmatrix_type = XGBoostParameters.DMatrixType.sparse;
        } else {
            effective._dmatrix_type = XGBoostParameters.DMatrixType.dense;
        }
    }

    /**
     * Decides whether XGBoost runs on the GPU or the CPU backend and logs the reason.
     *
     * <p>Only {@code auto} and {@code gpu} requests can resolve to GPU. They fall back to CPU when:
     * <ul>
     *   <li>the cloud has more than one node, {@code build_tree_one_node} is off, and multi-GPU
     *       is not allowed;</li>
     *   <li>GPU-incompatible parameters are set;</li>
     *   <li>no GPU with the requested id is found on the first cloud member.</li>
     * </ul>
     *
     * @param p2      model parameters (not modified)
     * @param verbose when true the decision is logged at INFO level, otherwise at DEBUG level
     * @return the effective backend: {@code gpu} or {@code cpu}
     */
    public static XGBoostParameters.Backend getActualBackend(XGBoostParameters p2, boolean verbose) {
        // The log level is chosen once; both branches must assign the consumer.
        final Consumer<String> log = verbose ? (msg -> LOG.info(msg)) : (msg -> LOG.debug(msg));
        if (p2._backend != XGBoostParameters.Backend.auto && p2._backend != XGBoostParameters.Backend.gpu) {
            log.accept("Using CPU backend.");
            return XGBoostParameters.Backend.cpu;
        }
        if (H2O.getCloudSize() > 1 && !p2._build_tree_one_node && !XGBoost.allowMultiGPU()) {
            log.accept("GPU backend not supported in distributed mode. Using CPU backend.");
            return XGBoostParameters.Backend.cpu;
        }
        if (!p2.gpuIncompatibleParams().isEmpty()) {
            log.accept("GPU backend not supported for the choice of parameters (" + p2.gpuIncompatibleParams() + "). Using CPU backend.");
            return XGBoostParameters.Backend.cpu;
        }
        if (GpuUtils.hasGPU(H2O.CLOUD.members()[0], p2._gpu_id)) {
            log.accept("Using GPU backend (gpu_id: " + Arrays.toString(p2._gpu_id) + ").");
            return XGBoostParameters.Backend.gpu;
        }
        log.accept("No GPU (gpu_id: " + Arrays.toString(p2._gpu_id) + ") found. Using CPU backend.");
        return XGBoostParameters.Backend.cpu;
    }

    /**
     * Translates H2O XGBoost parameters into the parameter map handed to native XGBoost.
     *
     * <p>Several settings exist under both an H2O name and an XGBoost name:
     * ntrees/n_estimators, learn_rate/eta, sample_rate/subsample,
     * col_sample_rate_per_tree/colsample_bytree, col_sample_rate/colsample_bylevel,
     * max_abs_leafnode_pred/max_delta_step, min_rows/min_child_weight and
     * min_split_improvement/gamma. The XGBoost name wins when it differs from its default.
     * Both fields of {@code p2} are then set to the chosen value, so this method MUTATES {@code p2}.
     *
     * @param p2        model parameters; aliased fields are synchronized in place
     * @param nClasses  2 selects binary logistic, 1 selects a regression objective based on the
     *                  distribution, any other value selects multi-class softprob
     * @param coefNames expanded feature (coefficient) names; monotone and interaction constraints
     *                  are encoded by position in this array
     * @return an unmodifiable view of the parameter map
     * @throws UnsupportedOperationException if the regression distribution has no XGBoost objective
     * @throws IllegalArgumentException      if interaction constraints are invalid, or are combined
     *                                       with a categorical encoding other than OneHotInternal
     */
    public static Map<String, Object> createParamsMap(XGBoostParameters p2, int nClasses, String[] coefNames) {
        HashMap<String, Object> params = new HashMap<String, Object>();
        if (p2._n_estimators != 0) {
            LOG.info("Using user-provided parameter n_estimators instead of ntrees.");
            params.put("nround", p2._n_estimators);
            p2._ntrees = p2._n_estimators;
        } else {
            params.put("nround", p2._ntrees);
            p2._n_estimators = p2._ntrees;
        }
        if (p2._eta != 0.3) {
            params.put("eta", p2._eta);
            p2._learn_rate = p2._eta;
        } else {
            params.put("eta", p2._learn_rate);
            p2._eta = p2._learn_rate;
        }
        params.put("max_depth", p2._max_depth);
        String verbosity = System.getProperty(PROP_VERBOSITY);
        if (verbosity != null) {
            params.put("verbosity", verbosity);
        } else {
            params.put("silent", p2._quiet_mode);
        }
        if (p2._subsample != 1.0) {
            params.put("subsample", p2._subsample);
            p2._sample_rate = p2._subsample;
        } else {
            params.put("subsample", p2._sample_rate);
            p2._subsample = p2._sample_rate;
        }
        if (p2._colsample_bytree != 1.0) {
            params.put("colsample_bytree", p2._colsample_bytree);
            p2._col_sample_rate_per_tree = p2._colsample_bytree;
        } else {
            params.put("colsample_bytree", p2._col_sample_rate_per_tree);
            p2._colsample_bytree = p2._col_sample_rate_per_tree;
        }
        if (p2._colsample_bylevel != 1.0) {
            params.put("colsample_bylevel", p2._colsample_bylevel);
            p2._col_sample_rate = p2._colsample_bylevel;
        } else {
            params.put("colsample_bylevel", p2._col_sample_rate);
            p2._colsample_bylevel = p2._col_sample_rate;
        }
        if (p2._colsample_bynode != 1.0) {
            params.put("colsample_bynode", p2._colsample_bynode);
        }
        if (p2._max_delta_step != 0.0f) {
            params.put("max_delta_step", Float.valueOf(p2._max_delta_step));
            p2._max_abs_leafnode_pred = p2._max_delta_step;
        } else {
            params.put("max_delta_step", Float.valueOf(p2._max_abs_leafnode_pred));
            p2._max_delta_step = p2._max_abs_leafnode_pred;
        }
        // XGBoost takes an int seed; fold the long seed into int range.
        params.put("seed", (int) (p2._seed % Integer.MAX_VALUE));
        params.put("grow_policy", p2._grow_policy.toString());
        if (p2._grow_policy == XGBoostParameters.GrowPolicy.lossguide) {
            params.put("max_bin", p2._max_bins);
            params.put("max_leaves", p2._max_leaves);
        }
        params.put("booster", p2._booster.toString());
        if (p2._booster == XGBoostParameters.Booster.dart) {
            params.put("sample_type", p2._sample_type.toString());
            params.put("normalize_type", p2._normalize_type.toString());
            params.put("rate_drop", Float.valueOf(p2._rate_drop));
            params.put("one_drop", p2._one_drop ? "1" : "0");
            params.put("skip_drop", Float.valueOf(p2._skip_drop));
        }
        putUpdaterParams(params, p2);
        if (p2._min_child_weight != 1.0) {
            LOG.info("Using user-provided parameter min_child_weight instead of min_rows.");
            params.put("min_child_weight", p2._min_child_weight);
            p2._min_rows = p2._min_child_weight;
        } else {
            params.put("min_child_weight", p2._min_rows);
            p2._min_child_weight = p2._min_rows;
        }
        if (p2._gamma != 0.0f) {
            LOG.info("Using user-provided parameter gamma instead of min_split_improvement.");
            params.put("gamma", Float.valueOf(p2._gamma));
            p2._min_split_improvement = p2._gamma;
        } else {
            params.put("gamma", Float.valueOf(p2._min_split_improvement));
            p2._gamma = p2._min_split_improvement;
        }
        params.put("lambda", Float.valueOf(p2._reg_lambda));
        params.put("alpha", Float.valueOf(p2._reg_alpha));
        if (p2._scale_pos_weight != 1.0f) {
            params.put("scale_pos_weight", Float.valueOf(p2._scale_pos_weight));
        }
        putObjectiveParams(params, p2, nClasses);
        assert XGBoostMojoModel.ObjectiveType.fromXGBoost((String) params.get("objective")) != null;
        int nthreadMax = getMaxNThread();
        int nthread = p2._nthread != -1 ? Math.min(p2._nthread, nthreadMax) : nthreadMax;
        if (nthread < p2._nthread) {
            LOG.warn("Requested nthread=" + p2._nthread + " but the cluster has only " + nthreadMax
                    + " available. Training will use nthread=" + nthread + " instead of the user specified value.");
        }
        params.put("nthread", nthread);
        Map<String, Integer> monotoneConstraints = p2.monotoneConstraints();
        if (!monotoneConstraints.isEmpty()) {
            params.put("monotone_constraints", formatMonotoneConstraints(monotoneConstraints, coefNames));
        }
        String[][] interactionConstraints = p2._interaction_constraints;
        if (interactionConstraints != null && interactionConstraints.length > 0) {
            if (p2._categorical_encoding != Model.Parameters.CategoricalEncodingScheme.OneHotInternal) {
                throw new IllegalArgumentException("No support interaction constraint for categorical encoding = " + p2._categorical_encoding.toString() + ". Constraint interactions are available only for ``AUTO`` (``one_hot_internal`` or ``OneHotInternal``) categorical encoding.");
            }
            params.put("interaction_constraints", createInteractions(interactionConstraints, coefNames, p2));
        }
        LOG.info("XGBoost Parameters:");
        for (Map.Entry<String, Object> entry : params.entrySet()) {
            LOG.info(" " + entry.getKey() + " = " + entry.getValue());
        }
        LOG.info("");
        return Collections.unmodifiableMap(params);
    }

    /**
     * Resolves backend and tree method, then puts the matching "gpu_id", "updater",
     * "tree_method" and "max_bin" entries into {@code params}.
     */
    private static void putUpdaterParams(Map<String, Object> params, XGBoostParameters p) {
        XGBoostParameters.Backend actualBackend = getActualBackend(p, true);
        XGBoostParameters.TreeMethod actualTreeMethod = getActualTreeMethod(p);
        if (actualBackend == XGBoostParameters.Backend.gpu) {
            params.put("gpu_id", p._gpu_id != null && p._gpu_id.length > 0 ? p._gpu_id[0] : 0);
            if (p._booster == XGBoostParameters.Booster.gblinear) {
                LOG.info("Using gpu_coord_descent updater.");
                params.put("updater", "gpu_coord_descent");
            } else {
                LOG.info("Using gpu_hist tree method.");
                params.put("max_bin", p._max_bins);
                params.put("updater", "grow_gpu_hist");
            }
        } else if (p._booster == XGBoostParameters.Booster.gblinear) {
            LOG.info("Using coord_descent updater.");
            params.put("updater", "coord_descent");
        } else if (H2O.CLOUD.size() > 1 && p._tree_method == XGBoostParameters.TreeMethod.auto && p._monotone_constraints != null) {
            LOG.info("Using hist tree method for distributed computation with monotone_constraints.");
            params.put("tree_method", actualTreeMethod.toString());
            params.put("max_bin", p._max_bins);
        } else {
            // Log the resolved method (not the possibly-"auto" request) and set max_bin whenever
            // hist is actually used, including when "auto" resolved to hist.
            LOG.info("Using " + actualTreeMethod.toString() + " tree method.");
            params.put("tree_method", actualTreeMethod.toString());
            if (actualTreeMethod == XGBoostParameters.TreeMethod.hist) {
                params.put("max_bin", p._max_bins);
            }
        }
    }

    /**
     * Puts the "objective" entry (plus "tweedie_variance_power" or "num_class" when required)
     * for the given number of classes and, for regression, the configured distribution.
     *
     * @throws UnsupportedOperationException for a regression distribution other than
     *         gamma, tweedie, poisson, gaussian or AUTO
     */
    private static void putObjectiveParams(Map<String, Object> params, XGBoostParameters p, int nClasses) {
        if (nClasses == 2) {
            params.put("objective", XGBoostMojoModel.ObjectiveType.BINARY_LOGISTIC.getId());
            return;
        }
        if (nClasses != 1) {
            params.put("objective", XGBoostMojoModel.ObjectiveType.MULTI_SOFTPROB.getId());
            params.put("num_class", nClasses);
            return;
        }
        if (p._distribution == DistributionFamily.gamma) {
            params.put("objective", XGBoostMojoModel.ObjectiveType.REG_GAMMA.getId());
        } else if (p._distribution == DistributionFamily.tweedie) {
            params.put("objective", XGBoostMojoModel.ObjectiveType.REG_TWEEDIE.getId());
            params.put("tweedie_variance_power", p._tweedie_power);
        } else if (p._distribution == DistributionFamily.poisson) {
            params.put("objective", XGBoostMojoModel.ObjectiveType.COUNT_POISSON.getId());
        } else if (p._distribution == DistributionFamily.gaussian || p._distribution == DistributionFamily.AUTO) {
            params.put("objective", XGBoostMojoModel.ObjectiveType.REG_SQUAREDERROR.getId());
        } else {
            throw new UnsupportedOperationException("No support for distribution=" + p._distribution.toString());
        }
    }

    /**
     * Formats monotone constraints as XGBoost's positional tuple, e.g. {@code "(0,1,-1)"}:
     * one entry per coefficient name, {@code 0} for unconstrained coefficients.
     * An empty {@code coefNames} yields {@code "()"}.
     */
    private static String formatMonotoneConstraints(Map<String, Integer> constraints, String[] coefNames) {
        StringBuilder sb = new StringBuilder("(");
        int constraintsUsed = 0;
        for (int i = 0; i < coefNames.length; i++) {
            if (i > 0) {
                sb.append(',');
            }
            String coef = coefNames[i];
            if (constraints.containsKey(coef)) {
                sb.append(constraints.get(coef).toString());
                ++constraintsUsed;
            } else {
                sb.append('0');
            }
        }
        sb.append(')');
        // Every constraint is expected to name an existing coefficient.
        assert constraintsUsed == constraints.size();
        return sb.toString();
    }

    /**
     * Encodes user interaction constraints as XGBoost's nested list of feature indices,
     * e.g. {@code [[0,1],[2,3,4]]}. An empty group is encoded as {@code []}.
     *
     * <p>Each column name is resolved with {@code ArrayUtils.findWithPrefix} against
     * {@code coefNames}:
     * <ul>
     *   <li>a non-negative result is used as the index directly;</li>
     *   <li>a result {@code r < -1} is decoded as {@code -r - 2}, and every consecutive coefficient
     *       from there whose name starts with the column name is included (expanded levels).</li>
     * </ul>
     *
     * @param interaction_constraints groups of column names allowed to interact
     * @param coefNames               expanded coefficient names of the model
     * @param params                  parameters used to reject special/ignored columns
     * @return the constraint string in XGBoost's list-of-lists syntax
     * @throws IllegalArgumentException if a column is the response, weights, fold or an ignored
     *                                  column, or does not match any coefficient name
     */
    private static String createInteractions(String[][] interaction_constraints, String[] coefNames, XGBoostParameters params) {
        StringJoiner groups = new StringJoiner(",", "[", "]");
        for (String[] list : interaction_constraints) {
            // A joiner renders an empty group as "[]" instead of corrupting the brackets.
            StringJoiner indices = new StringJoiner(",", "[", "]");
            for (String item : list) {
                if (item.equals(params._response_column)) {
                    throw new IllegalArgumentException("'interaction_constraints': Column with the name '" + item + "' is used as response column and cannot be used in interaction.");
                }
                if (item.equals(params._weights_column)) {
                    throw new IllegalArgumentException("'interaction_constraints': Column with the name '" + item + "' is used as weights column and cannot be used in interaction.");
                }
                if (item.equals(params._fold_column)) {
                    throw new IllegalArgumentException("'interaction_constraints': Column with the name '" + item + "' is used as fold column and cannot be used in interaction.");
                }
                if (params._ignored_columns != null && ArrayUtils.find(params._ignored_columns, item) != -1) {
                    throw new IllegalArgumentException("'interaction_constraints': Column with the name '" + item + "' is set in ignored columns and cannot be used in interaction.");
                }
                int start = ArrayUtils.findWithPrefix(coefNames, item);
                if (start == -1) {
                    throw new IllegalArgumentException("'interaction_constraints': Column with name '" + item + "' is not in the frame.");
                }
                if (start >= 0) {
                    indices.add(Integer.toString(start));
                    continue;
                }
                start = -start - 2;
                assert coefNames[start].startsWith(item) : "The column name should be found correctly.";
                // NOTE(review): plain prefix matching can also pick up an adjacent coefficient of a
                // different column whose name starts with this one (e.g. "a" vs "ab.x"); confirm.
                for (int end = start; end < coefNames.length && coefNames[end].startsWith(item); ++end) {
                    indices.add(Integer.toString(end));
                }
            }
            groups.add(indices.toString());
        }
        return groups.toString();
    }

    /**
     * Builds native booster parameters from H2O parameters via {@link #createParamsMap}.
     * Like {@code createParamsMap}, this synchronizes aliased fields of {@code p2} in place.
     *
     * @param p2        model parameters
     * @param nClasses  number of response classes
     * @param coefNames expanded coefficient names
     * @return booster parameters wrapping the generated map
     */
    public static BoosterParms createParams(XGBoostParameters p2, int nClasses, String[] coefNames) {
        final Map<String, Object> nativeParams = createParamsMap(p2, nClasses, coefNames);
        return BoosterParms.fromMap(nativeParams);
    }

    /**
     * Deep-copies this model under a new key and drops all metrics from the copy's output.
     *
     * @param result key for the copy
     * @return the copied model, with model metrics cleared and training/validation metrics set to null
     */
    protected XGBoostModel deepClone(Key<XGBoostModel> result) {
        final XGBoostModel copy = IcedUtils.deepCopy(this);
        copy._key = result;
        final XGBoostOutput copiedOutput = copy._output;
        copiedOutput.clearModelMetrics(false);
        copiedOutput._training_metrics = null;
        copiedOutput._validation_metrics = null;
        return copy;
    }

    /**
     * Returns the maximum number of threads XGBoost may use on one node.
     *
     * <p>If the {@code PROP_NTHREAD} system property holds an integer (decimal, hex or octal,
     * as accepted by {@link Integer#decode}), that value is returned as-is. Otherwise the
     * configured H2O thread count is divided by the largest number of cloud members sharing one
     * IP address, with a floor of 1.
     *
     * @return thread cap for XGBoost training
     */
    static int getMaxNThread() {
        String configured = System.getProperty(PROP_NTHREAD);
        if (configured != null) {
            try {
                return Integer.decode(configured.trim());
            } catch (NumberFormatException e) {
                // Previously a malformed value surfaced as a NullPointerException on unboxing.
                LOG.warn("Ignoring invalid value '" + configured + "' of system property " + PROP_NTHREAD + "; expected an integer.", e);
            }
        }
        // Count cloud members per host IP in one pass.
        Map<String, Integer> nodesPerHost = new HashMap<>();
        for (H2ONode node : H2O.CLOUD.members()) {
            nodesPerHost.merge(node.getIp(), 1, Integer::sum);
        }
        int maxNodesPerHost = 1;
        for (int count : nodesPerHost.values()) {
            maxNodesPerHost = Math.max(maxNodesPerHost, count);
        }
        return Math.max(1, H2O.ARGS.nthreads / maxNodesPerHost);
    }

    /**
     * Serializes this model together with the objects referenced by its model info.
     * Puts the DataInfo key and then the aux-node-weights key into the buffer before delegating
     * to the superclass. NOTE(review): presumably AutoBuffer.putKey also writes the value stored
     * under each key; confirm. readAll_impl must read the keys back in exactly this order.
     */
    @Override
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        ab.putKey(this.model_info.getDataInfoKey());
        ab.putKey(this.model_info.getAuxNodeWeightsKey());
        return super.writeAll_impl(ab);
    }

    /**
     * Restores the objects written by writeAll_impl: the DataInfo key, then the aux-node-weights
     * key (same order as written), then delegates to the superclass.
     * NOTE(review): presumably AutoBuffer.getKey(key, fs) re-installs the serialized value under
     * the key, with pending work tracked in {@code fs}; confirm.
     */
    @Override
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        ab.getKey(this.model_info.getDataInfoKey(), fs);
        ab.getKey(this.model_info.getAuxNodeWeightsKey(), fs);
        return super.readAll_impl(ab, fs);
    }

    /**
     * Creates the MOJO writer for this model.
     *
     * @return a new writer bound to this model
     */
    @Override
    public XGBoostMojoWriter getMojo() {
        final XGBoostMojoWriter writer = new XGBoostMojoWriter(this);
        return writer;
    }

    /**
     * Computes model metrics on a (possibly adapted) frame.
     *
     * @param data         frame to score
     * @param originalData the frame before adaptation
     * @param isTrain      whether {@code data} is the training frame
     * @param description  human-readable label, logged at DEBUG level
     * @return the computed metrics
     */
    private ModelMetrics makeMetrics(Frame data, Frame originalData, boolean isTrain, String description) {
        LOG.debug("Making metrics: " + description);
        final XGBoostModelMetrics calculator = new XGBoostModelMetrics(_output, data, originalData, isTrain, this);
        return calculator.compute();
    }

    /**
     * Scores the training frame and, if provided, the validation frame. The resulting metrics are
     * stored in the output, copied into the scoring-history slot at index {@code _ntrees}, and
     * registered with the model.
     *
     * @param _train     adapted training frame
     * @param _trainOrig original training frame
     * @param _valid     adapted validation frame, or null to skip validation scoring
     * @param _validOrig original validation frame
     */
    final void doScoring(Frame _train, Frame _trainOrig, Frame _valid, Frame _validOrig) {
        final XGBoostOutput out = _output;
        final ModelMetrics trainMetrics = makeMetrics(_train, _trainOrig, true, "Metrics reported on training frame");
        out._training_metrics = trainMetrics;
        out._scored_train[out._ntrees].fillFrom(trainMetrics);
        addModelMetrics(trainMetrics);
        if (_valid == null) {
            return;
        }
        final ModelMetrics validMetrics = makeMetrics(_valid, _validOrig, false, "Metrics reported on validation frame");
        out._validation_metrics = validMetrics;
        out._scored_valid[out._ntrees].fillFrom(validMetrics);
        addModelMetrics(validMetrics);
    }

    /**
     * Applies Platt-scaling calibration to the predictions when the output carries a calibration
     * model (delegated to PlattScalingHelper).
     */
    @Override
    protected Frame postProcessPredictions(Frame adaptedFrame, Frame predictFr, Job j2) {
        final PlattScalingHelper.OutputWithCalibration calibrated =
                (PlattScalingHelper.OutputWithCalibration) (Object) _output;
        return PlattScalingHelper.postProcessPredictions(predictFr, j2, calibrated);
    }

    /** Scores a single row with a zero offset; see {@code score0(double[], double[], double)}. */
    @Override
    protected double[] score0(double[] data, double[] preds) {
        final double noOffset = 0.0;
        return score0(data, preds, noOffset);
    }

    /**
     * Scores a single row with the Java predictor.
     *
     * <p>The offset is passed to the predictor only when the output reports {@code hasOffset()};
     * otherwise a non-zero offset is rejected. A predictor is built on every call, so this is
     * intended for row-at-a-time use.
     *
     * @param data   raw input row
     * @param preds  output buffer for predictions
     * @param offset row offset
     * @return {@code preds} filled via {@code XGBoostMojoModel.toPreds}
     * @throws UnsupportedOperationException if {@code offset != 0} and the output has no offset
     */
    @Override
    public double[] score0(double[] data, double[] preds, double offset) {
        final DataInfo dataInfo = model_info.dataInfo();
        assert dataInfo != null;
        final MutableOneHotEncoderFVec features = new MutableOneHotEncoderFVec(dataInfo, _output._sparse);
        features.setInput(data);
        final Predictor predictor = makePredictor(true);
        final float[] rawPreds;
        if (_output.hasOffset()) {
            rawPreds = predictor.predict((FVec) features, (float) offset);
        } else if (offset != 0.0) {
            throw new UnsupportedOperationException("Unsupported: offset != 0");
        } else {
            rawPreds = predictor.predict(features);
        }
        return XGBoostMojoModel.toPreds(data, rawPreds, preds, _output.nclasses(), _output._priorClassDist, defaultThreshold());
    }

    /** Sets up bulk scoring for non-training data; the BigScore argument is not consulted. */
    @Override
    protected XGBoostBigScorePredict setupBigScorePredict(Model.BigScore bs) {
        final boolean scoringTrainingFrame = false;
        return setupBigScorePredict(scoringTrainingFrame);
    }

    /**
     * Creates the bulk predictor, Java-based or native depending on PredictConfiguration.
     *
     * @param isTrain whether the training-time scoring info should be used
     * @return the bulk-scoring predictor
     */
    public XGBoostBigScorePredict setupBigScorePredict(boolean isTrain) {
        final DataInfo scoringInfo = model_info().scoringInfo(isTrain);
        if (!PredictConfiguration.useJavaScoring()) {
            return setupBigScorePredictNative(scoringInfo);
        }
        return setupBigScorePredictJava(scoringInfo);
    }

    /** Creates a native-booster bulk predictor; regenerates booster parameters (mutating _parms aliases). */
    private XGBoostBigScorePredict setupBigScorePredictNative(DataInfo di) {
        final XGBoostParameters parms = _parms;
        final XGBoostOutput output = _output;
        final BoosterParms boosterParms = createParams(parms, output.nclasses(), di.coefNames());
        return new XGBoostNativeBigScorePredict(model_info, parms, output, di, boosterParms, defaultThreshold());
    }

    /** Creates a pure-Java bulk predictor. */
    private XGBoostBigScorePredict setupBigScorePredictJava(DataInfo di) {
        final double threshold = defaultThreshold();
        return new XGBoostJavaBigScorePredict(model_info, _output, di, _parms, threshold);
    }

    /**
     * Creates the variable-importance calculator matching the configured scoring mode.
     *
     * @return the Java-based calculator when Java scoring is enabled, else the native one
     */
    public XGBoostVariableImportance setupVarImp() {
        if (!PredictConfiguration.useJavaScoring()) {
            return new XGBoostNativeVariableImportance(_key, model_info.getFeatureMap());
        }
        return new XGBoostJavaVariableImportance(model_info);
    }

    /** Computes SHAP contributions with default options and no job for progress reporting. */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key) {
        final Model.Contributions.ContributionsOptions defaults = new Model.Contributions.ContributionsOptions();
        return scoreContributions(frame, destination_key, null, defaults);
    }

    /**
     * Computes per-row SHAP feature contributions plus a "BiasTerm" column.
     *
     * <p>With the Compact output format, contributions are reported per original feature;
     * otherwise they are reported per expanded coefficient. When the options require sorting
     * (top/bottom N), a sorting task produces name/value column pairs instead.
     *
     * @param frame           input frame (adapted to the training layout internally)
     * @param destination_key key for the result frame
     * @param j2              job used for progress updates (may be null)
     * @param options         contribution options
     * @return the contributions frame
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j2, Model.Contributions.ContributionsOptions options) {
        final Frame adaptFrm = new Frame(frame);
        adaptTestForTrain(adaptFrm, true, false);
        final DataInfo di = model_info().dataInfo();
        assert di != null;
        final boolean compact = Model.Contributions.ContributionsOutputFormat.Compact.equals(options._outputFormat);
        final String[] featureContribNames = compact ? _output.features() : di.coefNames();
        final String[] outputNames = ArrayUtils.append(featureContribNames, "BiasTerm");
        if (!options.isSortingRequired()) {
            final PredictTreeSHAPTask plainTask = new PredictTreeSHAPTask(di, model_info(), _output, options);
            // (byte) 3: all output columns share one type; presumably Vec.T_NUM (numeric).
            final PredictTreeSHAPTask finished = (PredictTreeSHAPTask) plainTask
                    .withPostMapAction(JobUpdatePostMap.forJob(j2))
                    .doAll(outputNames.length, (byte) 3, adaptFrm);
            return finished.outputFrame(destination_key, outputNames, null);
        }
        final ContributionComposer composer = new ContributionComposer();
        final int topN = composer.checkAndAdjustInput(options._topN, featureContribNames.length);
        final int bottomN = composer.checkAndAdjustInput(options._bottomN, featureContribNames.length);
        // Each selected contribution occupies a name column and a value column.
        final int outputSize = Math.min((topN + bottomN) * 2, featureContribNames.length * 2);
        final String[] names = new String[outputSize + 1];
        final byte[] types = new byte[outputSize + 1];
        final String[][] domains = new String[outputSize + 1][outputNames.length];
        composeScoreContributionTaskMetadata(names, types, domains, featureContribNames, options);
        final PredictTreeSHAPTask sortedTask = (PredictTreeSHAPTask) new PredictTreeSHAPSortingTask(di, model_info(), _output, options)
                .withPostMapAction(JobUpdatePostMap.forJob(j2))
                .doAll(types, adaptFrm);
        return sortedTask.outputFrame(destination_key, names, domains);
    }

    /**
     * Recomputes the per-node auxiliary weights of every tree from {@code frame} and its weights
     * column, and stores them in the DKV under the model's aux-node-weights key.
     *
     * @param frame         frame with the model's feature columns and the weights column
     * @param weightsColumn name of the weights column in {@code frame}
     * @return a report whose {@code _warn_trees} lists trees with at least one zero-weight node;
     *         the matching {@code _warn_classes} entries are always 0
     * @throws IllegalArgumentException if {@code weightsColumn} is null or missing from {@code frame}
     */
    @Override
    public Model.UpdateAuxTreeWeights.UpdateAuxTreeWeightsReport updateAuxTreeWeights(Frame frame, String weightsColumn) {
        if (weightsColumn == null) {
            throw new IllegalArgumentException("Weights column name is not defined");
        }
        final Frame adaptFrm = new Frame(frame);
        final Vec weights = adaptFrm.remove(weightsColumn);
        if (weights == null) {
            throw new IllegalArgumentException("Input frame doesn't contain weights column `" + weightsColumn + "`");
        }
        adaptTestForTrain(adaptFrm, true, false);
        // Take the features from the ADAPTED frame so categorical domains match the training
        // DataInfo; reading them from the raw input would discard the adaptation.
        final String[] features = _output.features();
        final Frame featureFrm = new Frame(features, adaptFrm.vecs(features));
        featureFrm.add(weightsColumn, weights);
        final DataInfo di = model_info().dataInfo();
        assert di != null;
        final double[][] nodeWeights = ((UpdateAuxTreeWeightsTask) new UpdateAuxTreeWeightsTask(_parms._distribution, di, model_info(), _output)
                .doAll(featureFrm)).getNodeWeights();
        DKV.put(new AuxNodeWeights(model_info().getAuxNodeWeightsKey(), nodeWeights));
        final Model.UpdateAuxTreeWeights.UpdateAuxTreeWeightsReport report = new Model.UpdateAuxTreeWeights.UpdateAuxTreeWeightsReport();
        report._warn_classes = new int[0];
        report._warn_trees = new int[0];
        for (int treeId = 0; treeId < nodeWeights.length; ++treeId) {
            final double[] treeWeights = nodeWeights[treeId];
            if (treeWeights != null && Arrays.stream(treeWeights).anyMatch(w -> w == 0.0)) {
                report._warn_trees = ArrayUtils.append(report._warn_trees, treeId);
                report._warn_classes = ArrayUtils.append(report._warn_classes, 0);
            }
        }
        return report;
    }

    /**
     * Computes, for every row, the leaf node each tree assigns it to.
     *
     * @param frame           input frame (adapted to the training layout internally)
     * @param type            representation of the leaf assignment
     * @param destination_key key for the result frame
     * @return the leaf-assignment frame
     */
    @Override
    public Frame scoreLeafNodeAssignment(Frame frame, Model.LeafNodeAssignment.LeafNodeAssignmentType type, Key<Frame> destination_key) {
        final DataInfo scoringInfo = model_info.scoringInfo(false);
        final AssignLeafNodeTask task = AssignLeafNodeTask.make(scoringInfo, _output, model_info._boosterBytes, type);
        final Frame adapted = new Frame(frame);
        adaptTestForTrain(adapted, true, false);
        return task.execute(adapted, destination_key);
    }

    /**
     * Copies the column layout of the DataInfo's adapted frame into the model output: names,
     * types, domains, numeric/categorical counts, categorical offsets and useAllFactorLevels.
     */
    private void setDataInfoToOutput(DataInfo dinfo) {
        final XGBoostOutput out = _output;
        final Frame adapted = dinfo._adaptedFrame;
        out.setNames(adapted.names(), adapted.typesStr());
        out._domains = adapted.domains();
        out._nums = dinfo._nums;
        out._cats = dinfo._cats;
        out._catOffsets = dinfo._catOffsets;
        out._useAllFactorLevels = dinfo._useAllFactorLevels;
    }

    /**
     * Removes the objects owned by this model before delegating to the superclass: the DataInfo,
     * the aux node weights and the calibration model, each only if present.
     */
    @Override
    protected Futures remove_impl(Futures fs, boolean cascade) {
        final XGBoostModelInfo info = model_info();
        final DataInfo dataInfo = info.dataInfo();
        if (dataInfo != null) {
            dataInfo.remove(fs);
        }
        final AuxNodeWeights auxWeights = info.auxNodeWeights();
        if (auxWeights != null) {
            auxWeights.remove(fs);
        }
        final XGBoostOutput out = _output;
        if (out._calib_model != null) {
            out._calib_model.remove(fs);
        }
        return super.remove_impl(fs, cascade);
    }

    /**
     * Converts one XGBoost tree into a SharedTreeGraph with a single subgraph.
     *
     * @param treeNumber    index of the tree within the class group
     * @param treeClassName class whose tree group is used (resolved by getXGBoostClassIndex)
     * @return graph containing the converted tree
     * @throws IllegalArgumentException if the booster is not tree-based, or the class or tree
     *                                  index is out of range
     */
    @Override
    public SharedTreeGraph convert(int treeNumber, String treeClassName) {
        GradBooster booster = XGBoostJavaMojoModel.makePredictor(model_info._boosterBytes, model_info.auxNodeWeightBytes()).getBooster();
        if (!(booster instanceof GBTree)) {
            throw new IllegalArgumentException("XGBoost model is not backed by a tree-based booster. Booster class is " + booster.getClass().getCanonicalName());
        }
        RegTree[][] groupedTrees = ((GBTree) booster).getGroupedTrees();
        int treeClass = getXGBoostClassIndex(treeClassName);
        // Reject negative indices too, rather than failing with ArrayIndexOutOfBoundsException.
        if (treeClass < 0 || treeClass >= groupedTrees.length) {
            throw new IllegalArgumentException(String.format("Given XGBoost model does not have given class '%s'.", treeClassName));
        }
        RegTree[] treesInGroup = groupedTrees[treeClass];
        if (treeNumber < 0 || treeNumber >= treesInGroup.length) {
            throw new IllegalArgumentException(String.format("There is no such tree number for given class. Total number of trees is %d.", treesInGroup.length));
        }
        RegTree tree = treesInGroup[treeNumber];
        RegTreeNode[] treeNodes = tree.getNodes();
        RegTreeNodeStat[] treeNodeStats = tree.getStats();
        assert treeNodes.length >= 1;
        SharedTreeGraph sharedTreeGraph = new SharedTreeGraph();
        // Training metrics can be null (deepClone clears them); don't fail just to name the subgraph.
        ModelMetrics trainingMetrics = _output._training_metrics;
        String subgraphName = trainingMetrics != null ? trainingMetrics._description : null;
        SharedTreeSubgraph sharedTreeSubgraph = sharedTreeGraph.makeSubgraph(subgraphName);
        XGBoostUtils.FeatureProperties featureProperties = XGBoostUtils.assembleFeatureNames(model_info.dataInfo());
        constructSubgraph(treeNodes, treeNodeStats, sharedTreeSubgraph.makeRootNode(), 0, sharedTreeSubgraph, featureProperties, true);
        return sharedTreeGraph;
    }

    /**
     * Recursively copies the XGBoost subtree rooted at {@code nodeIndex} into {@code sharedTreeSubgraph}.
     *
     * <p>{@code sharedTreeNode} is filled from {@code xgBoostNodes[nodeIndex]} and its statistics. For a
     * non-leaf node, the left child and then the right child are created in the subgraph and converted
     * recursively. Creation order is left before right.
     *
     * @param xgBoostNodes       all nodes of the XGBoost tree, addressed by node index
     * @param xgBoostNodeStats   per-node statistics (gain and cover), indexed like {@code xgBoostNodes}
     * @param sharedTreeNode     already-created shared node that mirrors {@code xgBoostNodes[nodeIndex]}
     * @param nodeIndex          index of the XGBoost node being converted
     * @param sharedTreeSubgraph subgraph used to create the child nodes
     * @param featureProperties  feature names and one-hot flags, indexed by XGBoost split index
     * @param inclusiveNA        NA-inclusiveness flag for this node, chosen by the parent
     *                           (left child gets {@code default_left()}, right child its negation;
     *                           presumably XGBoost's default direction for missing values)
     */
    private static void constructSubgraph(RegTreeNode[] xgBoostNodes, RegTreeNodeStat[] xgBoostNodeStats, SharedTreeNode sharedTreeNode, int nodeIndex, SharedTreeSubgraph sharedTreeSubgraph, XGBoostUtils.FeatureProperties featureProperties, boolean inclusiveNA) {
        RegTreeNode xgBoostNode = xgBoostNodes[nodeIndex];
        RegTreeNodeStat xgBoostNodeStat = xgBoostNodeStats[nodeIndex];
        // One-hot encoded features get a fixed split value of 1.0f instead of XGBoost's split condition.
        // NOTE(review): presumably because indicator columns only take 0/1; confirm against SharedTreeNode's
        // split semantics. This lookup also runs for leaf nodes; confirm their split index is in bounds.
        if (featureProperties._oneHotEncoded[xgBoostNode.getSplitIndex()]) {
            sharedTreeNode.setSplitValue(1.0f);
        } else {
            sharedTreeNode.setSplitValue(xgBoostNode.getSplitCondition());
        }
        sharedTreeNode.setPredValue(xgBoostNode.getLeafValue());
        sharedTreeNode.setInclusiveNa(inclusiveNA);
        sharedTreeNode.setNodeNumber(nodeIndex);
        sharedTreeNode.setGain(xgBoostNodeStat.getGain());
        // XGBoost's cover statistic is stored as the shared node's weight.
        sharedTreeNode.setWeight(xgBoostNodeStat.getCover());
        if (!xgBoostNode.isLeaf()) {
            sharedTreeNode.setCol(xgBoostNode.getSplitIndex(), featureProperties._names[xgBoostNode.getSplitIndex()]);
            // Left child first, then right: child creation order determines the subgraph layout.
            XGBoostModel.constructSubgraph(xgBoostNodes, xgBoostNodeStats, sharedTreeSubgraph.makeLeftChildNode(sharedTreeNode), xgBoostNode.getLeftChildIndex(), sharedTreeSubgraph, featureProperties, xgBoostNode.default_left());
            XGBoostModel.constructSubgraph(xgBoostNodes, xgBoostNodeStats, sharedTreeSubgraph.makeRightChildNode(sharedTreeNode), xgBoostNode.getRightChildIndex(), sharedTreeSubgraph, featureProperties, !xgBoostNode.default_left());
        }
    }

    /**
     * Converts one XGBoost tree into a {@link SharedTreeGraph}.
     *
     * <p>{@code options} is ignored; this variant delegates to {@link #convert(int, String)}.
     *
     * @param treeNumber zero-based index of the tree within the class group
     * @param treeClass  class whose tree group is used
     * @param options    unused by the XGBoost implementation
     * @return the converted tree graph
     */
    @Override
    public SharedTreeGraph convert(int treeNumber, String treeClass, ConvertTreeOptions options) {
        final SharedTreeGraph converted = convert(treeNumber, treeClass);
        return converted;
    }

    /**
     * Resolves a user-supplied class name to the index of its tree group in the XGBoost booster.
     *
     * <p>Rules:
     * <ul>
     *   <li>Regression models take no class, and always use group 0.</li>
     *   <li>Binomial models use group 0 when no class is given. Only the first domain level is accepted
     *       by name.</li>
     *   <li>Any other model category requires an explicit class.</li>
     * </ul>
     *
     * @param treeClass class name from the response domain; may be null or empty
     * @return index of the tree group for {@code treeClass}
     * @throws IllegalArgumentException if a class is given for regression, omitted where required,
     *                                  unknown, or not the single built class of a binomial model
     */
    private int getXGBoostClassIndex(String treeClass) {
        final XGBoostOutput output = (XGBoostOutput) this._output;
        final ModelCategory modelCategory = output.getModelCategory();
        final boolean classGiven = treeClass != null && !treeClass.isEmpty();
        if (ModelCategory.Regression.equals(modelCategory) && classGiven) {
            throw new IllegalArgumentException("There should be no tree class specified for regression.");
        }
        if (!classGiven) {
            // Binomial and regression models have a single tree group, so it is selected implicitly.
            switch (modelCategory) {
                case Binomial:
                case Regression:
                    return 0;
                default:
                    throw new IllegalArgumentException(String.format("Model category '%s' requires tree class to be specified.", modelCategory));
            }
        }
        final String[][] domains = output._domains;
        final String[] domain = domains[domains.length - 1];
        // Guard against a response without a domain instead of failing inside ArrayUtils.find.
        final int treeClassIndex = domain == null ? -1 : ArrayUtils.find(domain, treeClass);
        // Check for an unknown class first, so binomial models report the real problem
        // rather than the "only one tree" message.
        if (treeClassIndex < 0) {
            throw new IllegalArgumentException(String.format("No such class '%s' in tree.", treeClass));
        }
        if (ModelCategory.Binomial.equals(modelCategory) && treeClassIndex != 0) {
            throw new IllegalArgumentException(String.format("For binomial XGBoost model, only one tree for class %s has been built.", domain[0]));
        }
        return treeClassIndex;
    }

    /**
     * Tells whether a feature has a non-zero variable importance in this model.
     *
     * <p>A name that appears directly in the variable-importance table is used if its importance is
     * non-zero. Otherwise, when the model has categorical columns, an {@code "Enum"} column counts as
     * used if any importance entry named {@code featureName + "."...} is non-zero.
     *
     * @param featureName name of the original (input) feature
     * @return true if the feature contributes to predictions according to variable importance
     */
    @Override
    public boolean isFeatureUsedInPredict(String featureName) {
        final XGBoostOutput output = (XGBoostOutput) this._output;
        final String[] varimpNames = output._varimp._names;
        final int directIdx = ArrayUtils.find(varimpNames, featureName);
        if (directIdx != -1) {
            return (double) output._varimp._varimp[directIdx] != 0.0;
        }
        if (output._catOffsets.length <= 1) {
            // No categorical columns: an unknown name cannot be an expanded categorical.
            return false;
        }
        final int columnIdx = ArrayUtils.find(output._names, featureName);
        if (columnIdx == -1 || !output._column_types[columnIdx].equals("Enum")) {
            return false;
        }
        for (int k = 0; k < varimpNames.length; k++) {
            final boolean isLevelOfFeature = varimpNames[k].startsWith(featureName.concat("."));
            if (isLevelOfFeature && output._varimp._varimp[k] != 0.0f) {
                return true;
            }
        }
        return false;
    }

    /**
     * Decides whether POJO generation should be refused because the model is too large.
     *
     * @return true if there is no output yet, or if built trees times {@code max_depth} exceeds 1000
     */
    @Override
    protected boolean toJavaCheckTooBig() {
        if (this._output == null) {
            return true;
        }
        final int sizeEstimate = ((XGBoostOutput) this._output)._ntrees * ((XGBoostParameters) this._parms)._max_depth;
        return sizeEstimate > 1000;
    }

    /**
     * Emits the generated POJO's {@code isSupervised()} and {@code nclasses()} methods.
     *
     * @param sb      stream the generated code is written to
     * @param fileCtx file-level code generation pipeline (unused here)
     * @return {@code sb}, for chaining
     */
    @Override
    protected SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
        final int classCount = ((XGBoostOutput) this._output).nclasses();
        sb.nl();
        sb.ip("public boolean isSupervised() { return true; }").nl();
        sb.ip("public int nclasses() { return ")
                .p(classCount)
                .p("; }")
                .nl();
        return sb;
    }

    /**
     * Renders the POJO's prediction body.
     *
     * <p>Builds a full (non scoring-only) predictor and delegates rendering to {@link XGBoostPojoWriter}.
     *
     * @param sb          stream the generated code is written to
     * @param classCtx    class-level code generation pipeline (unused here)
     * @param fileCtx     file-level code generation pipeline passed to the writer
     * @param verboseCode unused here
     */
    @Override
    protected void toJavaPredictBody(SBPrintStream sb, CodeGeneratorPipeline classCtx, CodeGeneratorPipeline fileCtx, boolean verboseCode) {
        final String javaIdPrefix = JCodeGen.toJavaId(this._key.toString());
        final Predictor fullPredictor = this.makePredictor(false);
        final XGBoostOutput output = (XGBoostOutput) this._output;
        XGBoostPojoWriter
                .make(fullPredictor, javaIdPrefix, output, this.defaultThreshold())
                .renderJavaPredictBody(sb, fileCtx);
    }

    /**
     * Collects feature interactions from every tree built for this model and merges them.
     *
     * <p>Trees are converted with a null class, which selects the single tree group of binomial and
     * regression models. For multinomial models, {@link #convert(int, String)} rejects a null class.
     *
     * @param maxInteractionDepth maximum number of features in one interaction
     * @param maxTreeDepth        maximum tree depth searched for interactions
     * @param maxDeepening        forwarded to {@link FeatureInteractions#collectFeatureInteractions}
     * @return the merged feature interactions of all trees
     */
    public FeatureInteractions getFeatureInteractions(int maxInteractionDepth, int maxTreeDepth, int maxDeepening) {
        FeatureInteractions featureInteractions = new FeatureInteractions();
        // Use the number of trees actually built. Early stopping can end training before
        // _parms._ntrees rounds, and convert() rejects tree numbers that do not exist.
        final int builtTrees = ((XGBoostOutput) this._output)._ntrees;
        for (int treeIdx = 0; treeIdx < builtTrees; ++treeIdx) {
            FeatureInteractions currentTreeFeatureInteractions = new FeatureInteractions();
            SharedTreeGraph sharedTreeGraph = this.convert(treeIdx, null);
            assert (sharedTreeGraph.subgraphArray.size() == 1);
            SharedTreeSubgraph tree = sharedTreeGraph.subgraphArray.get(0);
            ArrayList<SharedTreeNode> interactionPath = new ArrayList<SharedTreeNode>();
            HashSet<String> memo = new HashSet<String>();
            FeatureInteractions.collectFeatureInteractions(tree.rootNode, interactionPath, 0.0, 0.0, 1.0, 0, 0, currentTreeFeatureInteractions, memo, maxInteractionDepth, maxTreeDepth, maxDeepening, treeIdx, false);
            featureInteractions.mergeWith(currentTreeFeatureInteractions);
        }
        return featureInteractions;
    }

    /**
     * Computes feature interactions and formats them as tables.
     *
     * @param maxInteractionDepth maximum number of features in one interaction
     * @param maxTreeDepth        maximum tree depth searched for interactions
     * @param maxDeepening        forwarded to the interaction collector
     * @return tables produced by {@link FeatureInteractions#getFeatureInteractionsTable}
     */
    @Override
    public TwoDimTable[][] getFeatureInteractionsTable(int maxInteractionDepth, int maxTreeDepth, int maxDeepening) {
        final FeatureInteractions interactions = getFeatureInteractions(maxInteractionDepth, maxTreeDepth, maxDeepening);
        return FeatureInteractions.getFeatureInteractionsTable(interactions);
    }

    /**
     * Builds a Java predictor from this model's serialized booster and auxiliary node weights.
     *
     * @param scoringOnly forwarded to {@link PredictorFactory#makePredictor}
     * @return a new predictor instance
     */
    Predictor makePredictor(boolean scoringOnly) {
        final XGBoostModelInfo info = this.model_info;
        return PredictorFactory.makePredictor(info._boosterBytes, info.auxNodeWeightBytes(), scoringOnly);
    }

    /**
     * Computes Friedman and Popescu's H statistic for the interaction of {@code vars} on {@code frame}.
     *
     * <p>Multinomial models contribute one tree per class and round. Binomial and regression models
     * contribute a single tree per round.
     *
     * @param frame data used to evaluate partial dependence
     * @param vars  names of the variables whose interaction strength is measured
     * @return the H statistic computed by {@link FriedmanPopescusH#h}
     */
    @Override
    public double getFriedmanPopescusH(Frame frame, String[] vars) {
        final XGBoostOutput output = (XGBoostOutput) this._output;
        final boolean multinomial = output.nclasses() > 2;
        final int nclasses = multinomial ? output.nclasses() : 1;
        // Use the number of trees actually built. Early stopping can end training before
        // _parms._ntrees rounds, and convert() rejects tree numbers that do not exist.
        final int builtTrees = output._ntrees;
        // Regression responses have no class names. For binomial/regression, a null class
        // selects the single tree group (see getXGBoostClassIndex).
        final String[] classNames = multinomial ? output.classNames() : null;
        SharedTreeSubgraph[][] sharedTreeSubgraphs = new SharedTreeSubgraph[builtTrees][nclasses];
        for (int treeIdx = 0; treeIdx < builtTrees; ++treeIdx) {
            for (int classIdx = 0; classIdx < nclasses; ++classIdx) {
                final String treeClass = classNames == null ? null : classNames[classIdx];
                SharedTreeGraph graph = this.convert(treeIdx, treeClass);
                assert (graph.subgraphArray.size() == 1);
                sharedTreeSubgraphs[treeIdx][classIdx] = graph.subgraphArray.get(0);
            }
        }
        return FriedmanPopescusH.h(frame, vars, ((XGBoostParameters) this._parms)._learn_rate, sharedTreeSubgraphs);
    }

    /**
     * Training parameters of an H2O XGBoost model.
     *
     * <p>Field names, types, defaults and declaration order are kept exactly as they are; only the helper
     * methods carry logic. Some fields look like XGBoost-native spellings of H2O-style names (for example
     * {@code _eta} next to {@code _learn_rate}). NOTE(review): this class does not show how such pairs
     * are reconciled.
     */
    public static class XGBoostParameters extends Model.Parameters implements Model.GetNTrees, PlattScalingHelper.ParamsWithCalibration {
        public boolean _quiet_mode = true;
        public int _ntrees = 50;
        public int _n_estimators;
        public int _max_depth = 6;
        public double _min_rows = 1.0;
        public double _min_child_weight = 1.0;
        public double _learn_rate = 0.3;
        public double _eta = 0.3;
        public double _learn_rate_annealing = 1.0;
        public double _sample_rate = 1.0;
        public double _subsample = 1.0;
        public double _col_sample_rate = 1.0;
        public double _colsample_bylevel = 1.0;
        public double _colsample_bynode = 1.0;
        public double _col_sample_rate_per_tree = 1.0;
        public double _colsample_bytree = 1.0;
        public KeyValue[] _monotone_constraints;
        public String[][] _interaction_constraints;
        public float _max_abs_leafnode_pred = 0.0f;
        public float _max_delta_step = 0.0f;
        public int _score_tree_interval = 0;
        public int _initial_score_interval = 4000;
        public int _score_interval = 4000;
        public float _min_split_improvement = 0.0f;
        public float _gamma;
        public int _nthread = -1;
        public String _save_matrix_directory;
        public boolean _build_tree_one_node = false;
        public int _max_bins = 256;
        public int _max_leaves = 0;
        public TreeMethod _tree_method = TreeMethod.auto;
        public GrowPolicy _grow_policy = GrowPolicy.depthwise;
        public Booster _booster = Booster.gbtree;
        public DMatrixType _dmatrix_type = DMatrixType.auto;
        public float _reg_lambda = 1.0f;
        public float _reg_alpha = 0.0f;
        public float _scale_pos_weight = 1.0f;
        public boolean _calibrate_model;
        public Key<Frame> _calibration_frame;
        public DartSampleType _sample_type = DartSampleType.uniform;
        public DartNormalizeType _normalize_type = DartNormalizeType.tree;
        public float _rate_drop = 0.0f;
        public boolean _one_drop = false;
        public float _skip_drop = 0.0f;
        public int[] _gpu_id;
        public Backend _backend = Backend.auto;
        // Parameter field names listed as non-modifiable when continuing from a checkpoint (per the name;
        // NOTE(review): enforcement happens outside this class).
        static String[] CHECKPOINT_NON_MODIFIABLE_FIELDS = {"_tree_method", "_grow_policy", "_booster", "_sample_rate", "_max_depth", "_min_rows"};

        /** @return the short algorithm name, {@code "XGBoost"} */
        @Override
        public String algoName() {
            return "XGBoost";
        }

        /** @return the human-readable algorithm name, {@code "XGBoost"} */
        @Override
        public String fullName() {
            return "XGBoost";
        }

        /** @return the fully qualified class name of {@link XGBoostModel} */
        @Override
        public String javaName() {
            final Class<?> modelClass = XGBoostModel.class;
            return modelClass.getName();
        }

        /** @return the number of progress units for a training job, equal to {@code _ntrees} */
        @Override
        public long progressUnits() {
            final long units = this._ntrees;
            return units;
        }

        /**
         * Lists the settings that are incompatible with the GPU backend.
         *
         * <p>The checks cover:
         * <ul>
         *   <li>a tree method other than auto/hist, unless the booster is gblinear;</li>
         *   <li>a {@code max_depth} outside 1..15;</li>
         *   <li>the lossguide grow policy.</li>
         * </ul>
         *
         * @return parameter name to an explanatory message or offending value; empty if everything is compatible
         */
        Map<String, Object> gpuIncompatibleParams() {
            final Map<String, Object> incompatible = new HashMap<String, Object>();
            final boolean gpuTreeMethod = this._tree_method == TreeMethod.auto || this._tree_method == TreeMethod.hist;
            if (!gpuTreeMethod && this._booster != Booster.gblinear) {
                incompatible.put("tree_method", "Only auto and hist are supported tree_method on GPU backend.");
            }
            final boolean gpuDepth = this._max_depth >= 1 && this._max_depth <= 15;
            if (!gpuDepth) {
                incompatible.put("max_depth", this._max_depth + " . Max depth must be greater than 0 and lower than 16 for GPU backend.");
            }
            if (GrowPolicy.lossguide == this._grow_policy) {
                incompatible.put("grow_policy", GrowPolicy.lossguide);
            }
            return incompatible;
        }

        /**
         * Translates {@code _monotone_constraints} into a map from feature name to direction.
         *
         * <p>A negative value maps to -1 and any other non-zero value maps to 1. Zero-valued entries are
         * skipped.
         *
         * @return feature name to direction (-1 or 1); an empty map when no constraints are set
         * @throws IllegalStateException if a feature has more than one non-zero constraint
         */
        Map<String, Integer> monotoneConstraints() {
            final KeyValue[] declared = this._monotone_constraints;
            if (declared == null || declared.length == 0) {
                return Collections.emptyMap();
            }
            final Map<String, Integer> directions = new HashMap<String, Integer>(declared.length);
            for (int idx = 0; idx < declared.length; idx++) {
                final KeyValue entry = declared[idx];
                final double value = entry.getValue();
                if (value != 0.0) {
                    final String feature = entry.getKey();
                    if (directions.containsKey(feature)) {
                        throw new IllegalStateException("Duplicate definition of constraint for feature '" + feature + "'.");
                    }
                    directions.put(feature, value < 0.0 ? -1 : 1);
                }
            }
            return directions;
        }

        /** @return the requested number of trees, {@code _ntrees} */
        @Override
        public int getNTrees() {
            return this._ntrees;
        }

        /** @return the calibration frame referenced by {@code _calibration_frame}, or null if no key is set */
        @Override
        public Frame getCalibrationFrame() {
            if (this._calibration_frame == null) {
                return null;
            }
            return this._calibration_frame.get();
        }

        /** @return whether model calibration was requested */
        @Override
        public boolean calibrateModel() {
            return this._calibrate_model;
        }

        /** @return this parameters object */
        @Override
        public Model.Parameters getParams() {
            return this;
        }

        // Enum constant order is preserved exactly, because ordinals may be relied upon elsewhere.

        public enum Backend {
            auto,
            gpu,
            cpu
        }

        public enum DMatrixType {
            auto,
            dense,
            sparse
        }

        public enum DartNormalizeType {
            tree,
            forest
        }

        public enum DartSampleType {
            uniform,
            weighted
        }

        public enum Booster {
            gbtree,
            gblinear,
            dart
        }

        public enum GrowPolicy {
            depthwise,
            lossguide
        }

        public enum TreeMethod {
            auto,
            exact,
            approx,
            hist
        }
    }
}

