/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost;

import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNode;
import hex.CVModelBuilder;
import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ScoreKeeper;
import hex.genmodel.algos.xgboost.XGBoostJavaMojoModel;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMTask;
import hex.tree.PlattScalingHelper;
import hex.tree.SharedTree;
import hex.tree.TreeUtils;
import hex.tree.xgboost.XGBoostExtension;
import hex.tree.xgboost.XGBoostExtensionCheck;
import hex.tree.xgboost.XGBoostExternalCVModelBuilder;
import hex.tree.xgboost.XGBoostGPUCVModelBuilder;
import hex.tree.xgboost.XGBoostGPULock;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostModelInfo;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.XGBoostUtils;
import hex.tree.xgboost.XgbVarImp;
import hex.tree.xgboost.exec.LocalXGBoostExecutor;
import hex.tree.xgboost.exec.RemoteXGBoostExecutor;
import hex.tree.xgboost.exec.XGBoostExecutor;
import hex.tree.xgboost.predict.XGBoostVariableImportance;
import hex.tree.xgboost.remote.SteamExecutorStarter;
import hex.tree.xgboost.util.FeatureScore;
import hex.tree.xgboost.util.GpuUtils;
import hex.util.CheckpointUtils;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.log4j.Logger;
import water.DKV;
import water.ExtensionManager;
import water.H2O;
import water.Job;
import water.Key;
import water.Scope;
import water.Value;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.RebalanceDataSet;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Timer;
import water.util.TwoDimTable;

public class XGBoost
extends ModelBuilder<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput>
implements PlattScalingHelper.ModelBuilderWithCalibration<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput> {
    /** Logger used by this builder and by its nested driver. */
    private static final Logger LOG = Logger.getLogger(XGBoost.class);
    /** Fill-ratio cutoff for sparse-matrix detection; isTrainDatasetSparse() compares against this same value (0.25). */
    private static final double FILL_RATIO_THRESHOLD = 0.25;
    /** Number of trees the driver loop builds; assigned in init() (requested trees, minus the checkpoint model's trees on a checkpoint restart). */
    private int _ntrees;
    /** Calibration frame exposed through getCalibrationFrame()/setCalibrationFrame(); transient, so not serialized. */
    private transient Frame _calib;

    /**
     * Always returns {@code true}.
     * NOTE(review): presumably tells the framework that models from this builder can be exported as MOJO - confirm against ModelBuilder.
     */
    @Override
    public boolean haveMojo() {
        return true;
    }

    /**
     * Always returns {@code true}.
     * NOTE(review): presumably tells the framework that models from this builder can be exported as POJO - confirm against ModelBuilder.
     */
    @Override
    public boolean havePojo() {
        return true;
    }

    /**
     * Chooses the visibility level advertised for this builder.
     *
     * @return {@code Stable} when the extension manager reports the XGBoost core extension
     *         as enabled, {@code Experimental} otherwise
     */
    @Override
    public ModelBuilder.BuilderVisibility builderVisibility() {
        boolean extensionEnabled = ExtensionManager.getInstance().isCoreExtensionsEnabled(XGBoostExtension.NAME);
        return extensionEnabled
                ? ModelBuilder.BuilderVisibility.Stable
                : ModelBuilder.BuilderVisibility.Experimental;
    }

    /**
     * Lists the model categories this builder declares it can produce.
     *
     * @return a fresh array holding Regression, Binomial and Multinomial, in that order
     */
    @Override
    public ModelCategory[] can_build() {
        ModelCategory[] supported = {
            ModelCategory.Regression,
            ModelCategory.Binomial,
            ModelCategory.Multinomial
        };
        return supported;
    }

    /**
     * Creates a builder for the given parameters and immediately runs the cheap
     * ({@code expensive == false}) validation pass of {@link #init(boolean)}.
     *
     * @param parms XGBoost parameters handed to the superclass constructor
     */
    public XGBoost(XGBoostModel.XGBoostParameters parms) {
        super(parms);
        this.init(false);
    }

    /**
     * Creates a builder for the given parameters and model key, then runs the cheap
     * ({@code expensive == false}) validation pass of {@link #init(boolean)}.
     *
     * @param parms XGBoost parameters handed to the superclass constructor
     * @param key   model key handed to the superclass constructor
     */
    public XGBoost(XGBoostModel.XGBoostParameters parms, Key<XGBoostModel> key) {
        super(parms, key);
        this.init(false);
    }

    /**
     * Creates a builder with default parameters; unlike the other constructors it does not call init().
     * NOTE(review): looks like the registration-time ("startup once") constructor - confirm against ModelBuilder.
     *
     * @param startup_once flag forwarded unchanged to the superclass constructor
     */
    public XGBoost(boolean startup_once) {
        super(new XGBoostModel.XGBoostParameters(), startup_once);
    }

    /**
     * Always returns {@code true}: this builder declares itself supervised.
     */
    @Override
    public boolean isSupervised() {
        return true;
    }

    /**
     * Decides how many cross-validation models may be trained concurrently.
     * <p>
     * With the GPU backend the answer is the number of explicitly configured GPU ids, or,
     * when none are configured, the GPU count reported for the first cloud member.
     * For any other backend the decision is delegated to the two-argument overload with 2.
     *
     * @param folds number of cross-validation folds
     * @return number of models to build in parallel
     */
    @Override
    protected int nModelsInParallel(int folds) {
        XGBoostModel.XGBoostParameters params = (XGBoostModel.XGBoostParameters)this._parms;
        boolean gpuBackend = XGBoostModel.getActualBackend(params, false) == XGBoostModel.XGBoostParameters.Backend.gpu;
        if (!gpuBackend) {
            return this.nModelsInParallel(folds, 2);
        }
        boolean explicitGpuIds = params._gpu_id != null && params._gpu_id.length > 0;
        if (explicitGpuIds) {
            return params._gpu_id.length;
        }
        return GpuUtils.numGPUs(H2O.CLOUD.members()[0]);
    }

    /**
     * Supplies the driver that performs the training.
     *
     * @return a new {@code XGBoostDriver} bound to this builder
     */
    @Override
    protected XGBoostDriver trainModelImpl() {
        return new XGBoostDriver();
    }

    /**
     * Validates the XGBoost parameters, recording problems through {@code error()} / {@code warn()}.
     * <p>
     * Besides validation this method also: computes {@code _ntrees} (taking a checkpoint model
     * into account), rewrites {@code _max_depth == 0} to {@code Integer.MAX_VALUE}, and finishes
     * by initializing calibration via {@code PlattScalingHelper}.
     *
     * @param expensive when {@code true}, additionally checks the response for NAs and that the
     *                  XGBoost extension is enabled on all nodes, and throws if errors have been
     *                  collected by the time the max-depth checks are done
     * @throws H2OIllegalArgumentException on a multi-node cluster with security enabled unless
     *         {@code allow_insecure_xgboost} is set
     * @throws H2OModelBuilderIllegalArgumentException when {@code expensive} and errors were recorded
     *         before the backend checks
     */
    @Override
    public void init(boolean expensive) {
        DistributionFamily[] allowed_distributions;
        super.init(expensive);
        // Multi-node + security: refuse to run unless the user explicitly allowed insecure XGBoost.
        if (H2O.CLOUD.size() > 1 && H2O.SELF.getSecurityManager().securityEnabled) {
            if (H2O.ARGS.allow_insecure_xgboost) {
                LOG.info((Object)"Executing XGBoost on an secured cluster might compromise security.");
            } else {
                throw new H2OIllegalArgumentException("Cannot run XGBoost on an SSL enabled cluster larger than 1 node. XGBoost does not support SSL encryption.");
            }
        }
        if (H2O.ARGS.client && ((XGBoostModel.XGBoostParameters)this._parms)._build_tree_one_node) {
            this.error("_build_tree_one_node", "Cannot run on a single node in client mode.");
        }
        // Checks that are only performed on the expensive pass.
        if (expensive) {
            if (this._response.naCnt() > 0L) {
                this.error("_response_column", "Response contains missing values (NAs) - not supported by XGBoost.");
            }
            if (!((XGBoostExtensionCheck)new XGBoostExtensionCheck().doAllNodes()).enabled) {
                this.error("XGBoost", "XGBoost is not available on all nodes!");
            }
        }
        // Number of trees to build: on a checkpoint restart only the difference to the checkpoint model.
        // Note: when a checkpoint is set but its key is not found in DKV, _ntrees is left unchanged here.
        if (((XGBoostModel.XGBoostParameters)this._parms).hasCheckpoint()) {
            Value cv = DKV.get(((XGBoostModel.XGBoostParameters)this._parms)._checkpoint);
            if (cv != null) {
                XGBoostModel checkpointModel = CheckpointUtils.getAndValidateCheckpointModel(this, XGBoostModel.XGBoostParameters.CHECKPOINT_NON_MODIFIABLE_FIELDS, cv);
                this._ntrees = ((XGBoostModel.XGBoostParameters)this._parms)._ntrees - ((XGBoostOutput)checkpointModel._output)._ntrees;
            }
        } else {
            this._ntrees = ((XGBoostModel.XGBoostParameters)this._parms)._ntrees;
        }
        if (((XGBoostModel.XGBoostParameters)this._parms)._max_depth < 0) {
            this.error("_max_depth", "_max_depth must be >= 0.");
        }
        // A max depth of 0 is rewritten in place to Integer.MAX_VALUE.
        if (((XGBoostModel.XGBoostParameters)this._parms)._max_depth == 0) {
            ((XGBoostModel.XGBoostParameters)this._parms)._max_depth = Integer.MAX_VALUE;
        }
        // On the expensive pass, bail out now if anything above failed; later checks only record errors.
        if (expensive && this.error_count() > 0) {
            throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
        }
        // Checks applied only when the GPU backend was requested explicitly.
        if (((XGBoostModel.XGBoostParameters)this._parms)._backend == XGBoostModel.XGBoostParameters.Backend.gpu) {
            Map<String, Object> incompats;
            if (!GpuUtils.hasGPU(((XGBoostModel.XGBoostParameters)this._parms)._gpu_id)) {
                this.error("_backend", "GPU backend (gpu_id: " + Arrays.toString(((XGBoostModel.XGBoostParameters)this._parms)._gpu_id) + ") is not functional. Check CUDA_PATH and/or GPU installation.");
            }
            if (H2O.getCloudSize() > 1 && !((XGBoostModel.XGBoostParameters)this._parms)._build_tree_one_node && !XGBoost.allowMultiGPU()) {
                this.error("_backend", "GPU backend is not supported in distributed mode.");
            }
            if (!(incompats = ((XGBoostModel.XGBoostParameters)this._parms).gpuIncompatibleParams()).isEmpty()) {
                for (Map.Entry<String, Object> incompat : incompats.entrySet()) {
                    this.error("_backend", "GPU backend is not available for parameter setting '" + incompat.getKey() + " = " + incompat.getValue() + "'. Use CPU backend instead.");
                }
            }
        }
        // Distribution must be one of the families in this allow-list.
        if (!ArrayUtils.contains(allowed_distributions = new DistributionFamily[]{DistributionFamily.AUTO, DistributionFamily.bernoulli, DistributionFamily.multinomial, DistributionFamily.gaussian, DistributionFamily.poisson, DistributionFamily.gamma, DistributionFamily.tweedie}, ((XGBoostModel.XGBoostParameters)this._parms)._distribution)) {
            this.error("_distribution", ((XGBoostModel.XGBoostParameters)this._parms)._distribution.name() + " is not supported for XGBoost in current H2O.");
        }
        if (this.unsupportedCategoricalEncoding()) {
            this.error("_categorical_encoding", (Object)((Object)((XGBoostModel.XGBoostParameters)this._parms)._categorical_encoding) + " encoding is not supported for XGBoost in current H2O.");
        }
        // Per-distribution consistency between the distribution and the response type (_nclass / isClassifier()).
        switch (((XGBoostModel.XGBoostParameters)this._parms)._distribution) {
            case bernoulli: {
                if (this._nclass == 2) break;
                this.error("_distribution", H2O.technote(2, "Binomial requires the response to be a 2-class categorical"));
                break;
            }
            case modified_huber: {
                if (this._nclass == 2) break;
                this.error("_distribution", H2O.technote(2, "Modified Huber requires the response to be a 2-class categorical."));
                break;
            }
            case multinomial: {
                if (this.isClassifier()) break;
                this.error("_distribution", H2O.technote(2, "Multinomial requires an categorical response."));
                break;
            }
            case huber: {
                if (!this.isClassifier()) break;
                this.error("_distribution", H2O.technote(2, "Huber requires the response to be numeric."));
                break;
            }
            case poisson: {
                if (!this.isClassifier()) break;
                this.error("_distribution", H2O.technote(2, "Poisson requires the response to be numeric."));
                break;
            }
            case gamma: {
                if (!this.isClassifier()) break;
                this.error("_distribution", H2O.technote(2, "Gamma requires the response to be numeric."));
                break;
            }
            case tweedie: {
                if (!this.isClassifier()) break;
                this.error("_distribution", H2O.technote(2, "Tweedie requires the response to be numeric."));
                break;
            }
            case gaussian: {
                if (!this.isClassifier()) break;
                this.error("_distribution", H2O.technote(2, "Gaussian requires the response to be numeric."));
                break;
            }
            case laplace: {
                if (!this.isClassifier()) break;
                this.error("_distribution", H2O.technote(2, "Laplace requires the response to be numeric."));
                break;
            }
            case quantile: {
                if (!this.isClassifier()) break;
                this.error("_distribution", H2O.technote(2, "Quantile requires the response to be numeric."));
                break;
            }
            case AUTO: {
                break;
            }
            default: {
                this.error("_distribution", "Invalid distribution: " + (Object)((Object)((XGBoostModel.XGBoostParameters)this._parms)._distribution));
            }
        }
        // Every rate-like parameter must lie in (0, 1].
        this.checkPositiveRate("learn_rate", ((XGBoostModel.XGBoostParameters)this._parms)._learn_rate);
        this.checkPositiveRate("sample_rate", ((XGBoostModel.XGBoostParameters)this._parms)._sample_rate);
        this.checkPositiveRate("subsample", ((XGBoostModel.XGBoostParameters)this._parms)._subsample);
        this.checkPositiveRate("col_sample_rate", ((XGBoostModel.XGBoostParameters)this._parms)._col_sample_rate);
        this.checkPositiveRate("col_sample_rate_per_tree", ((XGBoostModel.XGBoostParameters)this._parms)._col_sample_rate_per_tree);
        this.checkPositiveRate("colsample_bylevel", ((XGBoostModel.XGBoostParameters)this._parms)._colsample_bylevel);
        this.checkPositiveRate("colsample_bynode", ((XGBoostModel.XGBoostParameters)this._parms)._colsample_bynode);
        this.checkPositiveRate("colsample_bytree", ((XGBoostModel.XGBoostParameters)this._parms)._colsample_bytree);
        // Parameter/alias pairs: error when both differ from the default and from each other (last argument is the default).
        this.checkColumnAlias("col_sample_rate", ((XGBoostModel.XGBoostParameters)this._parms)._col_sample_rate, "colsample_bylevel", ((XGBoostModel.XGBoostParameters)this._parms)._colsample_bylevel, 1.0);
        this.checkColumnAlias("col_sample_rate_per_tree", ((XGBoostModel.XGBoostParameters)this._parms)._col_sample_rate_per_tree, "colsample_bytree", ((XGBoostModel.XGBoostParameters)this._parms)._colsample_bytree, 1.0);
        this.checkColumnAlias("sample_rate", ((XGBoostModel.XGBoostParameters)this._parms)._sample_rate, "subsample", ((XGBoostModel.XGBoostParameters)this._parms)._subsample, 1.0);
        this.checkColumnAlias("learn_rate", ((XGBoostModel.XGBoostParameters)this._parms)._learn_rate, "eta", ((XGBoostModel.XGBoostParameters)this._parms)._eta, 0.3);
        this.checkColumnAlias("max_abs_leafnode_pred", ((XGBoostModel.XGBoostParameters)this._parms)._max_abs_leafnode_pred, "max_delta_step", ((XGBoostModel.XGBoostParameters)this._parms)._max_delta_step, 0.0);
        this.checkColumnAlias("ntrees", ((XGBoostModel.XGBoostParameters)this._parms)._ntrees, "n_estimators", ((XGBoostModel.XGBoostParameters)this._parms)._n_estimators, 0.0);
        // Tree-method compatibility checks.
        if (((XGBoostModel.XGBoostParameters)this._parms)._tree_method.equals((Object)XGBoostModel.XGBoostParameters.TreeMethod.approx) && (((XGBoostModel.XGBoostParameters)this._parms)._col_sample_rate < 1.0 || ((XGBoostModel.XGBoostParameters)this._parms)._colsample_bylevel < 1.0)) {
            this.error("_tree_method", "approx is not supported with _col_sample_rate or _colsample_bylevel, use exact/hist instead or disable column sampling.");
        }
        // scale_pos_weight is validated only when it differs from 1.
        if (((XGBoostModel.XGBoostParameters)this._parms)._scale_pos_weight != 1.0f) {
            if (this._nclass != 2) {
                this.error("_scale_pos_weight", "scale_pos_weight can only be used for binary classification");
            }
            if (((XGBoostModel.XGBoostParameters)this._parms)._scale_pos_weight <= 0.0f) {
                this.error("_scale_pos_weight", "scale_pos_weight must be a positive number");
            }
        }
        if (((XGBoostModel.XGBoostParameters)this._parms)._grow_policy == XGBoostModel.XGBoostParameters.GrowPolicy.lossguide && ((XGBoostModel.XGBoostParameters)this._parms)._tree_method != XGBoostModel.XGBoostParameters.TreeMethod.hist) {
            this.error("_grow_policy", "must use tree_method=hist for grow_policy=lossguide");
        }
        // Monotone constraints are checked only once a training frame is available.
        if (this._train != null && !((XGBoostModel.XGBoostParameters)this._parms).monotoneConstraints().isEmpty()) {
            if (((XGBoostModel.XGBoostParameters)this._parms)._tree_method == XGBoostModel.XGBoostParameters.TreeMethod.approx) {
                this.error("_tree_method", "approx is not supported with _monotone_constraints, use auto/exact/hist instead");
            } else assert (((XGBoostModel.XGBoostParameters)this._parms)._tree_method == XGBoostModel.XGBoostParameters.TreeMethod.auto || ((XGBoostModel.XGBoostParameters)this._parms)._tree_method == XGBoostModel.XGBoostParameters.TreeMethod.exact || ((XGBoostModel.XGBoostParameters)this._parms)._tree_method == XGBoostModel.XGBoostParameters.TreeMethod.hist) : "Unexpected tree method used " + (Object)((Object)((XGBoostModel.XGBoostParameters)this._parms)._tree_method);
            TreeUtils.checkMonotoneConstraints(this, this._train, ((XGBoostModel.XGBoostParameters)this._parms)._monotone_constraints);
        }
        if (this._train != null && H2O.CLOUD.size() > 1 && ((XGBoostModel.XGBoostParameters)this._parms)._tree_method == XGBoostModel.XGBoostParameters.TreeMethod.exact && !((XGBoostModel.XGBoostParameters)this._parms)._build_tree_one_node) {
            this.error("_tree_method", "exact is not supported in distributed environment, set build_tree_one_node to true to use exact");
        }
        // Calibration setup is delegated to the shared Platt-scaling helper.
        PlattScalingHelper.initCalibration(this, (PlattScalingHelper.ParamsWithCalibration)((Object)this._parms), expensive);
    }

    /**
     * Records a validation error unless the rate lies in the interval (0, 1].
     * <p>
     * The test is written as "not in range" rather than "below or above" so that a NaN
     * value is rejected too: every ordered comparison with NaN is false, so the
     * out-of-range form would have let NaN through.
     *
     * @param paramName parameter name without the leading underscore; used both for the
     *                  error field ({@code "_" + paramName}) and in the message
     * @param rateValue value to validate
     */
    private void checkPositiveRate(String paramName, double rateValue) {
        boolean withinRange = rateValue > 0.0 && rateValue <= 1.0;
        if (!withinRange) {
            this.error("_" + paramName, paramName + " must be between 0 (exclusive) and 1 (inclusive)");
        }
    }

    /**
     * Reconciles a parameter with its alias.
     * <ul>
     *   <li>Both differ from the default and from each other: records an error on the parameter.</li>
     *   <li>Otherwise, if the alias differs from the default: records a warning that the alias is used.</li>
     *   <li>Otherwise: nothing is reported.</li>
     * </ul>
     * The warning text previously ended with a stray double quote; it now ends with a plain period.
     *
     * @param paramName    primary parameter name (without the leading underscore)
     * @param paramValue   value of the primary parameter
     * @param aliasName    alias parameter name
     * @param aliasValue   value of the alias parameter
     * @param defaultValue default shared by both; a value equal to it counts as "not set"
     */
    private void checkColumnAlias(String paramName, double paramValue, String aliasName, double aliasValue, double defaultValue) {
        boolean paramCustomized = paramValue != defaultValue;
        boolean aliasCustomized = aliasValue != defaultValue;
        if (paramCustomized && aliasCustomized && paramValue != aliasValue) {
            this.error("_" + paramName, paramName + " and its alias " + aliasName + " are both set to different value than default value. Set " + aliasName + " to default value (" + defaultValue + "), to use " + paramName + " actual value.");
        } else if (aliasCustomized) {
            this.warn("_" + paramName, "Using user-provided parameter " + aliasName + " instead of " + paramName + ".");
        }
    }

    /**
     * Warns that early stopping is not reproducible when neither a scoring interval
     * ({@code _score_tree_interval != 0}) nor {@code _score_each_iteration} is configured.
     */
    @Override
    protected void checkEarlyStoppingReproducibility() {
        XGBoostModel.XGBoostParameters params = (XGBoostModel.XGBoostParameters)this._parms;
        boolean scoringScheduleDefined = params._score_tree_interval != 0 || params._score_each_iteration;
        if (scoringScheduleDefined) {
            return;
        }
        this.warn("_stopping_rounds", "early stopping is enabled but neither score_tree_interval or score_each_iteration are defined. Early stopping will not be reproducible!");
    }

    /**
     * Reads the boolean H2O system property {@code "xgboost.multinode.gpu.enabled"} (default {@code false}).
     * init() uses it to permit the GPU backend on a multi-node cloud.
     *
     * @return the property value as resolved by {@code H2O.getSysBoolProperty}
     */
    static boolean allowMultiGPU() {
        return H2O.getSysBoolProperty("xgboost.multinode.gpu.enabled", false);
    }

    /**
     * Reads the boolean H2O system property {@code "xgboost.external.cv.prestart"} (default {@code false}).
     * makeCVModelBuilder() consults it when external XGBoost is in use.
     *
     * @return the property value as resolved by {@code H2O.getSysBoolProperty}
     */
    static boolean prestartExternalClusterForCV() {
        return H2O.getSysBoolProperty("xgboost.external.cv.prestart", false);
    }

    /**
     * Returns this builder itself.
     */
    @Override
    public XGBoost getModelBuilder() {
        return this;
    }

    /**
     * @return the calibration frame last stored via {@link #setCalibrationFrame(Frame)}, or {@code null} if none was set
     */
    @Override
    public Frame getCalibrationFrame() {
        return this._calib;
    }

    /**
     * Stores the calibration frame returned by {@link #getCalibrationFrame()}.
     *
     * @param f2 frame to remember; stored as-is, no copy is made
     */
    @Override
    public void setCalibrationFrame(Frame f2) {
        this._calib = f2;
    }

    /**
     * Always returns {@code true}.
     * NOTE(review): presumably tells the framework that NA predictor values need no special handling for this algorithm - confirm against ModelBuilder.
     */
    @Override
    protected boolean canLearnFromNAs() {
        return true;
    }

    /**
     * Builds the {@code DataInfo} describing the training (and optional validation) frame.
     * <p>
     * A {@code YMUTask} is run over the adapted frame; when a weights column is present and no
     * offset column is, its weighted predictor statistics (and, for {@code nClasses == 1}, the
     * response statistics) are written back into the {@code DataInfo}. When both weights and
     * offset are present only a warning is logged. Coefficient names and original column
     * indices are computed before returning.
     *
     * @param train    training frame
     * @param valid    validation frame, passed through to {@code DataInfo}
     * @param parms    parameters; only the weights/offset/fold column names are consulted
     * @param nClasses number of response classes
     * @return the populated {@code DataInfo}
     */
    static DataInfo makeDataInfo(Frame train, Frame valid, XGBoostModel.XGBoostParameters parms, int nClasses) {
        final boolean hasWeights = parms._weights_column != null;
        final boolean hasOffset = parms._offset_column != null;
        final boolean hasFold = parms._fold_column != null;
        final boolean singleClass = nClasses == 1;

        DataInfo info = new DataInfo(train, valid, 1, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, false, false, true, hasWeights, hasOffset, hasFold);
        GLMTask.YMUTask stats = (GLMTask.YMUTask)new GLMTask.YMUTask(info, nClasses, singleClass, false, true, true).doAll(info._adaptedFrame);

        if (hasWeights) {
            if (hasOffset) {
                LOG.warn((Object)"Combination of offset and weights can lead to slight differences because Rollupstats aren't weighted - need to re-calculate weighted mean/sigma of the response including offset terms.");
            } else {
                info.updateWeightedSigmaAndMean(stats.predictorSDs(), stats.predictorMeans());
                if (singleClass) {
                    info.updateWeightedSigmaAndMeanForResponse(stats.responseSDs(), stats.responseMeans());
                }
            }
        }

        // Called for their side effect of populating _coefNames / _coefOriginalIndices (asserted below).
        info.coefNames();
        info.coefOriginalColumnIndices();
        assert (info._coefNames != null && info._coefOriginalIndices != null);
        return info;
    }

    /**
     * Rebalances a frame before training.
     * <p>
     * When {@code _build_tree_one_node} is set, a multi-chunk frame is rebalanced into a single
     * chunk under the key {@code name + ".1chk"} and the new frame is tracked in the current
     * {@code Scope}; a frame that already has one chunk is returned unchanged. Otherwise the
     * superclass implementation decides.
     * <p>
     * The log message uses the last (up to) five characters of {@code name}; the start index is
     * clamped so that a short name can no longer raise {@code StringIndexOutOfBoundsException}.
     *
     * @param original_fr frame to rebalance; {@code null} yields {@code null}
     * @param local       forwarded to the superclass implementation
     * @param name        base name for the rebalanced frame's key
     * @return the (possibly new) frame to train on
     */
    @Override
    protected Frame rebalance(Frame original_fr, boolean local, String name) {
        if (original_fr == null) {
            return null;
        }
        if (!((XGBoostModel.XGBoostParameters)this._parms)._build_tree_one_node) {
            return super.rebalance(original_fr, local, name);
        }
        int original_chunks = original_fr.anyVec().nChunks();
        if (original_chunks == 1) {
            return original_fr;
        }
        // Only used for logging - never fail the build because the name is short.
        String nameSuffix = name.substring(Math.max(0, name.length() - 5));
        LOG.info((Object)("Rebalancing " + nameSuffix + " dataset onto a single node."));
        Key newKey = Key.make(name + ".1chk");
        RebalanceDataSet rb = new RebalanceDataSet(original_fr, newKey, 1);
        H2O.submitTask(rb).join();
        Frame singleChunkFr = (Frame)DKV.get(newKey).get();
        Scope.track(singleChunkFr);
        return singleChunkFr;
    }

    /**
     * Builds the variable-importance table via {@code ModelMetrics.calcVarImp}.
     *
     * @param name       optional suffix for the table title; {@code null} leaves the title as "Variable Importances"
     * @param rel_imp    relative importances, passed through unchanged
     * @param coef_names names matching {@code rel_imp}, passed through unchanged
     * @return the table produced by {@code calcVarImp}
     */
    private static TwoDimTable createVarImpTable(String name, double[] rel_imp, String[] coef_names) {
        String title = "Variable Importances";
        if (name != null) {
            title = title + " - " + name;
        }
        String[] columnHeaders = {"Relative Importance", "Scaled Importance", "Percentage"};
        return ModelMetrics.calcVarImp(rel_imp, coef_names, title, columnHeaders);
    }

    /**
     * Flattens a feature-score map into the parallel arrays held by {@code XgbVarImp}.
     * Array positions follow the iteration order of the map's entry set.
     *
     * @param varimp feature name to score
     * @return the assembled {@code XgbVarImp}, or {@code null} when the map is empty
     */
    private static XgbVarImp computeVarImp(Map<String, FeatureScore> varimp) {
        if (varimp.isEmpty()) {
            return null;
        }
        final int featureCount = varimp.size();
        String[] featureNames = new String[featureCount];
        float[] gainValues = new float[featureCount];
        float[] coverValues = new float[featureCount];
        int[] frequencies = new int[featureCount];

        int pos = 0;
        for (Map.Entry<String, FeatureScore> entry : varimp.entrySet()) {
            FeatureScore score = entry.getValue();
            featureNames[pos] = entry.getKey();
            gainValues[pos] = score._gain;
            coverValues[pos] = score._cover;
            frequencies[pos] = score._frequency;
            pos++;
        }
        return new XgbVarImp(featureNames, gainValues, coverValues, frequencies);
    }

    /**
     * Picks the cross-validation model builder.
     * <ol>
     *   <li>GPU backend with parallelization above 1: {@code XGBoostGPUCVModelBuilder} using the configured GPU ids.</li>
     *   <li>External XGBoost with the pre-start property enabled: {@code XGBoostExternalCVModelBuilder}.</li>
     *   <li>Otherwise: the superclass default.</li>
     * </ol>
     *
     * @param modelBuilders   per-fold builders
     * @param parallelization number of models to train concurrently
     * @return the builder that will run the cross-validation models
     */
    @Override
    protected CVModelBuilder makeCVModelBuilder(ModelBuilder<?, ?, ?>[] modelBuilders, int parallelization) {
        XGBoostModel.XGBoostParameters params = (XGBoostModel.XGBoostParameters)this._parms;
        boolean gpuBackend = XGBoostModel.getActualBackend(params, false) == XGBoostModel.XGBoostParameters.Backend.gpu;
        if (gpuBackend && parallelization > 1) {
            return new XGBoostGPUCVModelBuilder(this._job, modelBuilders, parallelization, params._gpu_id);
        }
        boolean prestartExternal = H2O.ARGS.use_external_xgboost && XGBoost.prestartExternalClusterForCV();
        if (prestartExternal) {
            return new XGBoostExternalCVModelBuilder(this._job, modelBuilders, parallelization, SteamExecutorStarter.getInstance());
        }
        return super.makeCVModelBuilder(modelBuilders, parallelization);
    }

    /**
     * Derives the main model's settings from the finished cross-validation models.
     * <p>
     * Does nothing when neither early stopping ({@code _stopping_rounds}) nor a runtime limit
     * ({@code _max_runtime_secs}) is configured. Otherwise it disables early stopping, calls
     * {@code setMaxRuntimeSecsForMainModel()}, sets {@code _ntrees} to the truncated mean of the
     * CV models' tree counts, and records three warnings describing the changes.
     *
     * @param cvModelBuilders builders of the cross-validation models; each model is looked up in DKV by {@code dest()}
     */
    @Override
    public void cv_computeAndSetOptimalParameters(ModelBuilder<XGBoostModel, XGBoostModel.XGBoostParameters, XGBoostOutput>[] cvModelBuilders) {
        XGBoostModel.XGBoostParameters params = (XGBoostModel.XGBoostParameters)this._parms;
        boolean stoppingConfigured = params._stopping_rounds != 0 || params._max_runtime_secs != 0.0;
        if (!stoppingConfigured) {
            return;
        }
        params._stopping_rounds = 0;
        this.setMaxRuntimeSecsForMainModel();

        int totalTrees = 0;
        for (int fold = 0; fold < cvModelBuilders.length; fold++) {
            Model cvModel = (Model)DKV.getGet(cvModelBuilders[fold].dest());
            totalTrees += ((XGBoostOutput)cvModel._output)._ntrees;
        }
        params._ntrees = (int)((double)totalTrees / (double)cvModelBuilders.length);

        this.warn("_ntrees", "Setting optimal _ntrees to " + params._ntrees + " for cross-validation main model based on early stopping of cross-validation models.");
        this.warn("_stopping_rounds", "Disabling convergence-based early stopping for cross-validation main model.");
        this.warn("_max_runtime_secs", "Disabling maximum allowed runtime for cross-validation main model.");
    }

    /**
     * @return {@code true} when the configured categorical encoding is {@code Enum} or {@code Eigen},
     *         the two schemes init() rejects
     */
    private boolean unsupportedCategoricalEncoding() {
        Model.Parameters.CategoricalEncodingScheme scheme = ((XGBoostModel.XGBoostParameters)this._parms)._categorical_encoding;
        if (scheme == Model.Parameters.CategoricalEncodingScheme.Enum) {
            return true;
        }
        return scheme == Model.Parameters.CategoricalEncodingScheme.Eigen;
    }

    class XGBoostDriver
    extends ModelBuilder.Driver {
        // Timestamps used for scoring bookkeeping; all start at zero.
        long _firstScore = 0L;
        long _timeLastScoreStart = 0L;
        long _timeLastScoreEnd = 0L;

        /** Creates a driver bound to the enclosing builder. */
        XGBoostDriver() {
            super(XGBoost.this);
        }

        /**
         * Driver entry point: runs the expensive validation pass on the enclosing builder,
         * aborts if any error was recorded, and otherwise builds the model.
         *
         * @throws H2OModelBuilderIllegalArgumentException when validation recorded errors
         */
        @Override
        public void computeImpl() {
            final XGBoost builder = XGBoost.this;
            builder.init(true);
            boolean validationFailed = builder.error_count() > 0;
            if (validationFailed) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(builder);
            }
            this.buildModel();
        }

        /**
         * Builds the model, holding the GPU lock for the duration when the GPU path is taken.
         * <p>
         * The GPU path is taken when the backend is {@code auto} or {@code gpu}, a GPU is
         * available for the configured ids, the cloud has a single node (or multi-node GPU is
         * allowed), and no GPU-incompatible parameters are set. The locked GPUs are released
         * whether or not training throws.
         */
        final void buildModel() {
            final XGBoostModel.XGBoostParameters params = (XGBoostModel.XGBoostParameters)XGBoost.this._parms;
            final boolean gpuPath =
                    (XGBoostModel.XGBoostParameters.Backend.auto.equals((Object)params._backend)
                            || XGBoostModel.XGBoostParameters.Backend.gpu.equals((Object)params._backend))
                    && GpuUtils.hasGPU(params._gpu_id)
                    && (H2O.getCloudSize() == 1 || XGBoost.allowMultiGPU())
                    && params.gpuIncompatibleParams().isEmpty();
            if (!gpuPath) {
                this.buildModelImpl();
                return;
            }
            int[] lockedGpus = null;
            try {
                lockedGpus = XGBoostGPULock.lock(params._gpu_id);
                this.buildModelImpl();
            } finally {
                if (lockedGpus != null) {
                    XGBoostGPULock.unlock(lockedGpus);
                }
            }
        }

        /**
         * Chooses the executor that will run the training.
         * <ol>
         *   <li>{@code use_external_xgboost} set: remote executor obtained from {@code SteamExecutorStarter}.</li>
         *   <li>System property {@code "xgboost.external.address"} set: {@code RemoteXGBoostExecutor}
         *       with the optional {@code "xgboost.external.user"} / {@code "xgboost.external.password"} properties.</li>
         *   <li>Otherwise: {@code LocalXGBoostExecutor}.</li>
         * </ol>
         *
         * @param model model being trained
         * @return the executor to use
         * @throws IOException declared for executor creation failures
         */
        private XGBoostExecutor makeExecutor(XGBoostModel model) throws IOException {
            if (H2O.ARGS.use_external_xgboost) {
                return SteamExecutorStarter.getInstance().getRemoteExecutor(model, XGBoost.this._train, XGBoost.this._job);
            }
            final String externalAddress = H2O.getSysProperty("xgboost.external.address", null);
            if (externalAddress != null) {
                final String user = H2O.getSysProperty("xgboost.external.user", null);
                final String secret = H2O.getSysProperty("xgboost.external.password", null);
                return new RemoteXGBoostExecutor(model, XGBoost.this._train, externalAddress, user, secret);
            }
            return new LocalXGBoostExecutor(model, XGBoost.this._train);
        }

        /**
         * Creates (or restores from a checkpoint) the model, prepares it, and runs training.
         * <p>
         * Steps: obtain a locked model; decide the sparse flag from {@code _dmatrix_type}
         * (falling back to {@code isTrainDatasetSparse()} for any value other than sparse/dense);
         * optionally resolve auto parameters; create the feature map and variable-importance
         * helper; then train through an executor. Variable-importance cleanup and the model
         * unlock always run, even when training fails.
         *
         * @throws RuntimeException wrapping any exception raised while creating the executor or training
         */
        final void buildModelImpl() {
            final XGBoostModel.XGBoostParameters params = (XGBoostModel.XGBoostParameters)XGBoost.this._parms;

            final XGBoostModel model;
            if (params.hasCheckpoint()) {
                XGBoostModel stored = (XGBoostModel)DKV.get(params._checkpoint).get();
                XGBoostModel restored = stored.deepClone(XGBoost.this._result);
                restored._parms = XGBoost.this._parms;
                model = (XGBoostModel)restored.delete_and_lock(XGBoost.this._job);
            } else {
                model = new XGBoostModel(XGBoost.this._result, params, new XGBoostOutput(XGBoost.this), XGBoost.this._train, XGBoost.this._valid);
                model.write_lock(XGBoost.this._job);
            }

            final boolean sparse;
            if (params._dmatrix_type == XGBoostModel.XGBoostParameters.DMatrixType.sparse) {
                sparse = true;
            } else if (params._dmatrix_type == XGBoostModel.XGBoostParameters.DMatrixType.dense) {
                sparse = false;
            } else {
                sparse = this.isTrainDatasetSparse();
            }
            ((XGBoostOutput)model._output)._sparse = sparse;

            if (model.evalAutoParamsEnabled) {
                model.initActualParamValuesAfterOutputSetup(XGBoost.this.isClassifier(), XGBoost.this._nclass);
            }
            XGBoostUtils.createFeatureMap(model, XGBoost.this._train);
            XGBoostVariableImportance variableImportance = model.setupVarImp();

            try (XGBoostExecutor executor = this.makeExecutor(model)) {
                model.model_info().updateBoosterBytes(executor.setup());
                this.scoreAndBuildTrees(model, executor, variableImportance);
            } catch (Exception cause) {
                throw new RuntimeException("Error while training XGBoost model", cause);
            } finally {
                variableImportance.cleanup();
                model.unlock(XGBoost.this._job);
            }
        }

        /**
         * Estimates whether the training data should be treated as sparse.
         * <p>
         * Response, weights, fold and offset columns are skipped. A categorical column contributes
         * one non-zero per row and {@code cardinality()} columns; a non-categorical column
         * contributes its {@code nzCnt()} non-zeros and one column. The fill ratio is
         * non-zeros divided by (columns * rows).
         *
         * @return {@code true} when the fill ratio is below {@code FILL_RATIO_THRESHOLD}, or when
         *         rows * columns exceeds {@code Integer.MAX_VALUE}
         */
        private boolean isTrainDatasetSparse() {
            final Frame train = XGBoost.this._train;
            final XGBoostModel.XGBoostParameters params = (XGBoostModel.XGBoostParameters)XGBoost.this._parms;

            long nonZeros = 0L;
            long expandedCategoricalColumns = 0L;
            int plainColumns = 0;

            final int columnCount = train.numCols();
            for (int col = 0; col < columnCount; col++) {
                final String columnName = train.name(col);
                final boolean specialColumn = columnName.equals(params._response_column)
                        || columnName.equals(params._weights_column)
                        || columnName.equals(params._fold_column)
                        || columnName.equals(params._offset_column);
                if (specialColumn) {
                    continue;
                }
                final Vec column = train.vec(col);
                if (column.isCategorical()) {
                    nonZeros += train.numRows();
                    expandedCategoricalColumns += (long)column.cardinality();
                } else {
                    nonZeros += column.nzCnt();
                    plainColumns++;
                }
            }

            final long totalColumns = expandedCategoricalColumns + (long)plainColumns;
            final double cellCount = (double)totalColumns * (double)train.numRows();
            final double fillRatio = (double)nonZeros / cellCount;
            LOG.info((Object)("fill ratio: " + fillRatio));
            if (fillRatio < FILL_RATIO_THRESHOLD) {
                return true;
            }
            return train.numRows() * totalColumns > Integer.MAX_VALUE;
        }

        /**
         * Main training loop: alternates scoring and tree building, then runs the optional
         * constraint checks and the final scoring.
         *
         * @param model  model being trained; its output's tree count and score arrays are updated in place
         * @param exec   executor used to build each tree and to fetch the booster bytes
         * @param varImp variable-importance helper handed to {@code doScoring}
         * @throws Job.JobCancelledException when a stop was requested for a reason other than timeout
         */
        private void scoreAndBuildTrees(XGBoostModel model, XGBoostExecutor exec, XGBoostVariableImportance varImp) {
            Map<String, Integer> monotoneConstraints;
            // Loop until _ntrees trees are built; a job stop request ends the loop only after the first iteration (tid > 0).
            for (int tid = 0; !(tid >= XGBoost.this._ntrees || XGBoost.this._job.stop_requested() && tid > 0); ++tid) {
                // Non-final scoring pass; the return value says whether scoring actually happened.
                boolean scored = this.doScoring(model, exec, varImp, false);
                // Early stopping is evaluated only right after a scoring pass.
                if (scored && ScoreKeeper.stopEarly(((XGBoostOutput)model._output).scoreKeepers(), ((XGBoostModel.XGBoostParameters)XGBoost.this._parms)._stopping_rounds, ScoreKeeper.ProblemType.forSupervised(XGBoost.this._nclass > 1), ((XGBoostModel.XGBoostParameters)XGBoost.this._parms)._stopping_metric, ((XGBoostModel.XGBoostParameters)XGBoost.this._parms)._stopping_tolerance, "model's last", true)) {
                    LOG.info((Object)"Early stopping triggered - stopping XGBoost training");
                    break;
                }
                Timer kb_timer = new Timer();
                // Build tree number tid through the executor.
                exec.update(tid);
                LOG.info((Object)(tid + 1 + ". tree was built in " + kb_timer.toString()));
                XGBoost.this._job.update(1L);
                ++((XGBoostOutput)model._output)._ntrees;
                // Grow the per-tree bookkeeping arrays to _ntrees + 1 entries.
                // NOTE(review): copyAndFillOf presumably fills the new slots with the given value - confirm in ArrayUtils.
                ((XGBoostOutput)model._output)._scored_train = ArrayUtils.copyAndFillOf(((XGBoostOutput)model._output)._scored_train, ((XGBoostOutput)model._output)._ntrees + 1, new ScoreKeeper());
                ((XGBoostOutput)model._output)._scored_valid = ((XGBoostOutput)model._output)._scored_valid != null ? ArrayUtils.copyAndFillOf(((XGBoostOutput)model._output)._scored_valid, ((XGBoostOutput)model._output)._ntrees + 1, new ScoreKeeper()) : null;
                ((XGBoostOutput)model._output)._training_time_ms = ArrayUtils.copyAndFillOf(((XGBoostOutput)model._output)._training_time_ms, ((XGBoostOutput)model._output)._ntrees + 1, System.currentTimeMillis());
                // A stop request that is not a timeout is a user cancellation: abort with an exception.
                if (XGBoost.this.stop_requested() && !XGBoost.this.timeout()) {
                    throw new Job.JobCancelledException();
                }
                // A timeout ends training gracefully; the model built so far is kept.
                if (!XGBoost.this.timeout()) continue;
                LOG.info((Object)"Stopping XGBoost training because of timeout");
                break;
            }
            // Monotonicity check: only with constraints set, a booster other than gblinear, and the check enabled.
            if (!(monotoneConstraints = ((XGBoostModel.XGBoostParameters)XGBoost.this._parms).monotoneConstraints()).isEmpty() && ((XGBoostModel.XGBoostParameters)XGBoost.this._parms)._booster != XGBoostModel.XGBoostParameters.Booster.gblinear && this.monotonicityConstraintCheckEnabled()) {
                XGBoost.this._job.update(0L, "Checking monotonicity constraints on the final model");
                model.model_info().updateBoosterBytes(exec.updateBooster());
                this.checkMonotonicityConstraints(model.model_info(), monotoneConstraints);
            }
            // Interaction-constraint check: only with constraints set and the check enabled.
            if (((XGBoostModel.XGBoostParameters)XGBoost.this._parms)._interaction_constraints != null && this.interactionConstraintCheckEnabled()) {
                XGBoost.this._job.update(0L, "Checking interaction constraints on the final model");
                model.model_info().updateBoosterBytes(exec.updateBooster());
                this.checkInteractionConstraints(model.model_info(), ((XGBoostModel.XGBoostParameters)XGBoost.this._parms)._interaction_constraints);
            }
            XGBoost.this._job.update(0L, "Scoring the final model");
            // Final scoring pass (last argument true).
            this.doScoring(model, exec, varImp, true);
            // Report the difference between requested and built trees to the job.
            // NOTE(review): presumably completes the progress bar after an early stop - confirm against Job.update.
            XGBoost.this._job.update(((XGBoostModel.XGBoostParameters)XGBoost.this._parms)._ntrees - ((XGBoostOutput)model._output)._ntrees);
        }

        /**
         * Tells whether the monotonicity check on the built model should be performed.
         *
         * @return the boolean parsed from the value that {@code getSysProperty} yields for
         *         {@code xgboost.monotonicity.checkEnabled}, with {@code "true"} passed as the default
         */
        private boolean monotonicityConstraintCheckEnabled() {
            final String flag = XGBoost.this.getSysProperty("xgboost.monotonicity.checkEnabled", "true");
            return Boolean.parseBoolean(flag);
        }

        /**
         * Tells whether the interaction-constraint check on the built model should be performed.
         *
         * @return the boolean parsed from the value that {@code getSysProperty} yields for
         *         {@code xgboost.interactions.checkEnabled}, with {@code "true"} passed as the default
         */
        private boolean interactionConstraintCheckEnabled() {
            final String flag = XGBoost.this.getSysProperty("xgboost.interactions.checkEnabled", "true");
            return Boolean.parseBoolean(flag);
        }

        /**
         * Checks the monotone constraints against every tree of the booster stored in {@code model_info}.
         *
         * <p>A predictor is rebuilt from the serialized booster bytes, its grouped trees are walked,
         * and each non-null tree is handed to the per-tree overload.
         *
         * @param model_info          model info providing the booster bytes and the data info
         * @param monotoneConstraints constraint per column name; the per-tree check treats a positive
         *                            value as increasing and a negative value as decreasing
         * @throws IllegalStateException if the booster is not a {@code GBTree}, or if the per-tree
         *                               check detects a violation
         */
        private void checkMonotonicityConstraints(XGBoostModelInfo model_info, Map<String, Integer> monotoneConstraints) {
            GradBooster gradBooster = XGBoostJavaMojoModel.makePredictor(model_info._boosterBytes, null).getBooster();
            boolean treeBased = gradBooster instanceof GBTree;
            if (!treeBased) {
                throw new IllegalStateException("Expected booster object to be GBTree instead it is " + gradBooster.getClass().getName());
            }
            RegTree[][] treesByGroup = ((GBTree)gradBooster).getGroupedTrees();
            XGBoostUtils.FeatureProperties props = XGBoostUtils.assembleFeatureNames(model_info.dataInfo());
            for (RegTree[] groupTrees : treesByGroup) {
                for (RegTree regTree : groupTrees) {
                    // Slots in the grouped-tree array may be empty; those are skipped.
                    if (regTree != null) {
                        this.checkMonotonicityConstraints(regTree.getNodes(), monotoneConstraints, props);
                    }
                }
            }
        }

        /**
         * Checks a single tree against the monotone constraints.
         *
         * <p>{@code rollupMinMaxPreds} is first run from node 0 to record, for every node it reaches,
         * the smallest and largest leaf value below it. Then every non-leaf node that splits on a
         * constrained column is examined:
         * <ul>
         *   <li>positive constraint: fails when max(left subtree) &gt; min(right subtree);</li>
         *   <li>negative constraint: fails when min(left subtree) &lt; max(right subtree);</li>
         *   <li>a constraint of zero is never reported.</li>
         * </ul>
         *
         * @param tree                nodes of the tree; index 0 is used as the root
         * @param monotoneConstraints constraint per column name
         * @param featureProperties   supplies {@code _names}, indexed by a node's split index
         * @throws IllegalStateException on the first violated constraint, naming the offending nodes
         */
        private void checkMonotonicityConstraints(RegTreeNode[] tree, Map<String, Integer> monotoneConstraints, XGBoostUtils.FeatureProperties featureProperties) {
            final int nodeCount = tree.length;
            float[] lowest = new float[nodeCount];
            int[] lowestLeaf = new int[nodeCount];
            float[] highest = new float[nodeCount];
            int[] highestLeaf = new int[nodeCount];
            this.rollupMinMaxPreds(tree, 0, lowest, lowestLeaf, highest, highestLeaf);
            for (int i = 0; i < nodeCount; ++i) {
                RegTreeNode current = tree[i];
                if (current.isLeaf()) {
                    continue;
                }
                String column = featureProperties._names[current.getSplitIndex()];
                if (!monotoneConstraints.containsKey(column)) {
                    continue;
                }
                int direction = monotoneConstraints.get(column);
                int l = current.getLeftChildIndex();
                int r = current.getRightChildIndex();
                if (direction > 0 && highest[l] > lowest[r]) {
                    throw new IllegalStateException("Monotonicity constraint " + direction + " violated on column '" + column + "' (max(left) > min(right)): " + highest[l] + " > " + lowest[r] + "\nNode: " + current + "\nLeft Node (max): " + tree[highestLeaf[l]] + "\nRight Node (min): " + tree[lowestLeaf[r]]);
                }
                if (direction < 0 && lowest[l] < highest[r]) {
                    throw new IllegalStateException("Monotonicity constraint " + direction + " violated on column '" + column + "' (min(left) < max(right)): " + lowest[l] + " < " + highest[r] + "\nNode: " + current + "\nLeft Node (min): " + tree[lowestLeaf[l]] + "\nRight Node (max): " + tree[highestLeaf[r]]);
                }
            }
        }

        /**
         * Recursively fills in, for the subtree rooted at {@code nid}, the minimum and maximum leaf
         * value and the ids of the leaves that carry them.
         *
         * <p>A leaf records its own value and id in all four arrays. An inner node first processes
         * both children and then copies the entries of whichever child holds the smaller minimum
         * (respectively the larger maximum); when the comparison is not strictly true, the right
         * child's entry is taken.
         *
         * @param tree    nodes of the tree
         * @param nid     index of the subtree root to process
         * @param mins    output: minimum leaf value per node
         * @param min_ids output: id of the leaf holding that minimum
         * @param maxs    output: maximum leaf value per node
         * @param max_ids output: id of the leaf holding that maximum
         */
        private void rollupMinMaxPreds(RegTreeNode[] tree, int nid, float[] mins, int[] min_ids, float[] maxs, int[] max_ids) {
            RegTreeNode current = tree[nid];
            if (!current.isLeaf()) {
                int l = current.getLeftChildIndex();
                int r = current.getRightChildIndex();
                this.rollupMinMaxPreds(tree, l, mins, min_ids, maxs, max_ids);
                this.rollupMinMaxPreds(tree, r, mins, min_ids, maxs, max_ids);

                // Default to the right child; switch to the left only on a strict win.
                int lowerChild = r;
                if (mins[l] < mins[r]) {
                    lowerChild = l;
                }
                mins[nid] = mins[lowerChild];
                min_ids[nid] = min_ids[lowerChild];

                int higherChild = r;
                if (maxs[l] > maxs[r]) {
                    higherChild = l;
                }
                maxs[nid] = maxs[higherChild];
                max_ids[nid] = max_ids[higherChild];
            } else {
                mins[nid] = current.getLeafValue();
                min_ids[nid] = nid;
                maxs[nid] = current.getLeafValue();
                max_ids[nid] = nid;
            }
        }

        /**
         * Checks the interaction constraints against every tree of the booster stored in
         * {@code model_info}.
         *
         * <p>Each constraint group is translated to indices through
         * {@code mapOriginalNamesToIndices}. For every index, the union of all groups it belongs to
         * is collected; that map is then used to validate each non-null tree starting at node 0.
         *
         * @param model_info             model info providing the booster bytes and the data info
         * @param interactionConstraints groups of column names that are allowed to interact
         * @throws IllegalStateException if the booster is not a {@code GBTree}, or if a tree
         *                               violates the constraints
         */
        private void checkInteractionConstraints(XGBoostModelInfo model_info, String[][] interactionConstraints) {
            GradBooster gradBooster = XGBoostJavaMojoModel.makePredictor(model_info._boosterBytes, null).getBooster();
            boolean treeBased = gradBooster instanceof GBTree;
            if (!treeBased) {
                throw new IllegalStateException("Expected booster object to be GBTree instead it is " + gradBooster.getClass().getName());
            }
            RegTree[][] treesByGroup = ((GBTree)gradBooster).getGroupedTrees();
            XGBoostUtils.FeatureProperties props = XGBoostUtils.assembleFeatureNames(model_info.dataInfo());

            Map<Integer, Set<Integer>> partnersByIndex = new HashMap<Integer, Set<Integer>>();
            for (String[] constraintGroup : interactionConstraints) {
                Integer[] groupIndices = props.mapOriginalNamesToIndices(constraintGroup);
                for (int featureIndex : groupIndices) {
                    Set<Integer> partners = partnersByIndex.get(featureIndex);
                    if (partners == null) {
                        partners = new HashSet<Integer>();
                        partnersByIndex.put(featureIndex, partners);
                    }
                    // Every member of the group may interact with every other member.
                    partners.addAll(Arrays.asList(groupIndices));
                }
            }

            for (RegTree[] groupTrees : treesByGroup) {
                for (RegTree regTree : groupTrees) {
                    if (regTree == null) {
                        continue;
                    }
                    RegTreeNode[] nodes = regTree.getNodes();
                    this.checkInteractionConstraints(nodes, nodes[0], partnersByIndex, props);
                }
            }
        }

        /**
         * Recursively verifies that each non-leaf child splits on a column its parent is allowed to
         * interact with.
         *
         * <p>A child split is accepted when it uses the same split index as its parent, or when its
         * original column index is contained in the parent's interaction union. Both children of
         * {@code node} are validated before the recursion descends into them.
         *
         * @param tree              nodes of the tree
         * @param node              subtree root to validate; leaves are accepted immediately
         * @param interactionUnions allowed partner indices keyed by original column index
         * @param featureProperties supplies {@code _originalColumnIndices} and {@code _originalNames}
         * @throws IllegalStateException on the first child split that is not allowed
         */
        private void checkInteractionConstraints(RegTreeNode[] tree, RegTreeNode node, Map<Integer, Set<Integer>> interactionUnions, XGBoostUtils.FeatureProperties featureProperties) {
            if (node.isLeaf()) {
                return;
            }
            int parentSplit = node.getSplitIndex();
            int parentOriginal = featureProperties._originalColumnIndices[parentSplit];
            Set<Integer> allowed = interactionUnions.get(parentOriginal);

            RegTreeNode leftChild = tree[node.getLeftChildIndex()];
            this.requireAllowedChildSplit(leftChild, parentSplit, parentOriginal, allowed, featureProperties);
            RegTreeNode rightChild = tree[node.getRightChildIndex()];
            this.requireAllowedChildSplit(rightChild, parentSplit, parentOriginal, allowed, featureProperties);

            this.checkInteractionConstraints(tree, leftChild, interactionUnions, featureProperties);
            this.checkInteractionConstraints(tree, rightChild, interactionUnions, featureProperties);
        }

        /**
         * Throws if {@code child} is a non-leaf whose split column may not follow the parent's split.
         *
         * @param child             child node to validate; nothing is checked for a leaf
         * @param parentSplit       split index of the parent node
         * @param parentOriginal    original column index of the parent's split
         * @param allowed           original indices the parent may interact with, or {@code null}
         *                          when the parent has no entry in the interaction map
         * @param featureProperties supplies {@code _originalColumnIndices} and {@code _originalNames}
         * @throws IllegalStateException if the child's split is not allowed
         */
        private void requireAllowedChildSplit(RegTreeNode child, int parentSplit, int parentOriginal, Set<Integer> allowed, XGBoostUtils.FeatureProperties featureProperties) {
            if (child.isLeaf()) {
                return;
            }
            int childSplit = child.getSplitIndex();
            int childOriginal = featureProperties._originalColumnIndices[childSplit];
            if (childSplit == parentSplit || (allowed != null && allowed.contains(childOriginal))) {
                return;
            }
            String parentName = featureProperties._originalNames[parentOriginal];
            String allowedColumns = this.generateInteractionConstraintUnionString(featureProperties._originalNames, parentOriginal, allowed);
            String childName = featureProperties._originalNames[childOriginal];
            throw new IllegalStateException("Interaction constraint violated on column '" + childName + "': The parent column '" + parentName + "' can interact only with " + allowedColumns + " columns.");
        }

        /**
         * Renders the columns a parent split column may interact with, for use in error messages.
         *
         * @param originalNames      column names indexed by original column index
         * @param splitIndexOriginal original column index of the parent's split
         * @param interactionUnion   original indices the parent may interact with, or {@code null}
         *                           when no interaction set is known for the parent
         * @return {@code ['parent']} when {@code interactionUnion} is {@code null}; otherwise the
         *         names of the set members in the set's iteration order, comma-separated and wrapped
         *         in square brackets ({@code []} for an empty set)
         */
        private String generateInteractionConstraintUnionString(String[] originalNames, int splitIndexOriginal, Set<Integer> interactionUnion) {
            String parentOriginalName = originalNames[splitIndexOriginal];
            if (interactionUnion == null) {
                return "['" + parentOriginalName + "']";
            }
            // Emit the separator before every element but the first, so there is no trailing comma
            // to patch afterwards. Patching the last character turned an empty set into "]" because
            // it overwrote the opening bracket.
            StringBuilder sb = new StringBuilder("[");
            String separator = "";
            for (Integer columnIndex : interactionUnion) {
                sb.append(separator).append(originalNames[columnIndex]);
                separator = ",";
            }
            return sb.append("]").toString();
        }

        /**
         * Scores the model if scoring is due and refreshes the derived outputs (variable
         * importances, model summary, scoring history).
         *
         * <p>Scoring happens when any of the following holds:
         * <ul>
         *   <li>{@code _score_each_iteration} is set;</li>
         *   <li>{@code finalScoring} is {@code true};</li>
         *   <li>{@code _score_tree_interval} is 0 and the time-based rule fires: either less than
         *       {@code _initial_score_interval} has elapsed since the first call, or more than
         *       {@code _score_interval} has elapsed since the last scoring started and the last
         *       scoring took less than 10% of that span;</li>
         *   <li>{@code _score_tree_interval} is positive and divides the current tree count.</li>
         * </ul>
         * On final scoring, a calibration model is additionally built when calibration is enabled
         * and the model is not a CV model.
         *
         * @param model        model being trained; its output is updated in place
         * @param exec         executor from which the current booster is fetched
         * @param varImp       source of per-feature scores computed from the booster bytes
         * @param finalScoring {@code true} to force scoring and allow calibration
         * @return {@code true} if the model was scored during this call
         */
        private boolean doScoring(XGBoostModel model, XGBoostExecutor exec, XGBoostVariableImportance varImp, boolean finalScoring) {
            final XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters)XGBoost.this._parms;
            final long now = System.currentTimeMillis();
            if (this._firstScore == 0L) {
                this._firstScore = now;
            }
            final long sinceLastScore = now - this._timeLastScoreStart;
            final int builtTrees = ((XGBoostOutput)model._output)._ntrees;
            XGBoost.this._job.update(0L, "Built " + builtTrees + " trees so far (out of " + parms._ntrees + ").");

            // Time-based rule: always within the initial window; afterwards only when the interval
            // has passed and the previous scoring was cheap relative to the time since it started.
            boolean timeToScore = now - this._firstScore < (long)parms._initial_score_interval;
            if (!timeToScore && sinceLastScore > (long)parms._score_interval) {
                double lastScoringShare = (double)(this._timeLastScoreEnd - this._timeLastScoreStart) / (double)sinceLastScore;
                timeToScore = lastScoringShare < 0.1;
            }
            boolean manualInterval = parms._score_tree_interval > 0 && builtTrees % parms._score_tree_interval == 0;
            boolean scoreNow = parms._score_each_iteration || finalScoring || timeToScore && parms._score_tree_interval == 0 || manualInterval;

            boolean scored = false;
            if (scoreNow) {
                this._timeLastScoreStart = now;
                model.model_info().updateBoosterBytes(exec.updateBooster());
                model.doScoring(XGBoost.this._train, parms.train(), XGBoost.this._valid, parms.valid());
                this._timeLastScoreEnd = System.currentTimeMillis();

                XGBoostOutput out = (XGBoostOutput)model._output;
                Map<String, FeatureScore> featureScores = varImp.getFeatureScores(model.model_info()._boosterBytes);
                out._varimp = XGBoost.computeVarImp(featureScores);
                out._model_summary = SharedTree.createModelSummaryTable(out._ntrees, null);
                boolean hasCustomMetric = parms._custom_metric_func != null;
                out._scoring_history = SharedTree.createScoringHistoryTable(out, out._scored_train, out._scored_valid, XGBoost.this._job, out._training_time_ms, hasCustomMetric, false);
                if (out._varimp != null) {
                    String[] names = out._varimp._names;
                    out._variable_importances = XGBoost.createVarImpTable(null, ArrayUtils.toDouble(out._varimp._varimp), names);
                    out._variable_importances_cover = XGBoost.createVarImpTable("Cover", ArrayUtils.toDouble(out._varimp._covers), names);
                    out._variable_importances_frequency = XGBoost.createVarImpTable("Frequency", ArrayUtils.toDouble(out._varimp._freqs), names);
                }
                model.update(XGBoost.this._job);
                LOG.info((Object)model);
                scored = true;
            }

            if (finalScoring && parms.calibrateModel() && !parms._is_cv_model) {
                ((XGBoostOutput)model._output)._calib_model = PlattScalingHelper.buildCalibrationModel(XGBoost.this, (PlattScalingHelper.ParamsWithCalibration)((Object)XGBoost.this._parms), XGBoost.this._job, model);
                model.update(XGBoost.this._job);
            }
            return scored;
        }
    }
}

