/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.PojoWriter;
import hex.ScoreKeeper;
import hex.ToEigenVec;
import hex.VarImp;
import hex.genmodel.CategoricalEncoding;
import hex.genmodel.GenModel;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.glm.GLMModel;
import hex.tree.CompressedTree;
import hex.tree.DTree;
import hex.tree.PlattScalingHelper;
import hex.tree.Score;
import hex.tree.SharedTree;
import hex.tree.SharedTreePojoWriter;
import hex.tree.TreeStats;
import hex.util.LinearAlgebraUtils;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.log4j.Logger;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.IcedUtils;
import water.Job;
import water.Key;
import water.Keyed;
import water.LocalMR;
import water.MRTask;
import water.MrFun;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.util.ArrayUtils;
import water.util.TwoDimTable;
import water.util.VecUtils;

public abstract class SharedTreeModel<M extends SharedTreeModel<M, P, O>, P extends SharedTreeParameters, O extends SharedTreeOutput>
extends Model<M, P, O>
implements Model.LeafNodeAssignment,
Model.GetMostImportantFeatures,
Model.FeatureFrequencies,
Model.UpdateAuxTreeWeights {
    private static final Logger LOG = Logger.getLogger(SharedTreeModel.class);

    /**
     * Returns the first {@code n2} row headers (feature names) of the variable-importance table.
     * The table is presumably ordered by decreasing importance — confirm against the producer.
     *
     * @param n2 maximum number of feature names to return; capped at the number of table rows
     * @return the feature names, or {@code null} if the model has no output or no
     *         variable-importance table
     */
    @Override
    public String[] getMostImportantFeatures(int n2) {
        if (_output == null || _output._variable_importances == null) {
            return null;
        }
        final String[] ranked = _output._variable_importances.getRowHeaders();
        return Arrays.copyOf(ranked, Math.min(n2, ranked.length));
    }

    /**
     * Supplies the {@link ToEigenVec} helper for this model; always {@link LinearAlgebraUtils#toEigen}.
     */
    @Override
    public ToEigenVec getToEigenVec() {
        final ToEigenVec converter = LinearAlgebraUtils.toEigen;
        return converter;
    }

    /**
     * Creates an empty metric accumulator that matches this model's category.
     *
     * @param domain response domain, forwarded to the binomial/multinomial builders
     * @return a binomial, multinomial or regression metric builder
     * @throws RuntimeException the exception produced by {@code H2O.unimpl()} for any other category
     */
    @Override
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        final SharedTreeOutput out = _output;
        switch (out.getModelCategory()) {
            case Binomial:
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            case Multinomial:
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(out.nclasses(), domain, _parms._auc_type);
            case Regression:
                return new ModelMetricsRegression.MetricBuilderRegression();
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Creates a shared-tree model; all state handling is delegated to the base {@code Model}.
     *
     * @param selfKey key of this model
     * @param parms   parameters the model was built with
     * @param output  output holder (trees, scoring history, variable importances)
     */
    public SharedTreeModel(Key<M> selfKey, P parms, O output) {
        super(selfKey, parms, output);
    }

    /**
     * Builds one column name per non-null tree, in (iteration, class) order.
     *
     * <p>Names are 1-based: {@code "T<i>"} when an iteration holds a single tree slot, otherwise
     * {@code "T<i>.C<c>"}. The result size assumes every iteration has the same number of non-null
     * class trees as the first one. With no trees at all, reading iteration 0 throws
     * {@link ArrayIndexOutOfBoundsException}.
     *
     * @return the column names
     */
    protected String[] makeAllTreeColumnNames() {
        final Key<CompressedTree>[][] allKeys = _output._treeKeys;
        int treesPerIteration = 0;
        for (Key<CompressedTree> key : allKeys[0]) {
            if (key != null) {
                treesPerIteration++;
            }
        }
        final String[] columnNames = new String[allKeys.length * treesPerIteration];
        int next = 0;
        for (int t = 0; t < allKeys.length; t++) {
            final Key<CompressedTree>[] classKeys = allKeys[t];
            final boolean singleSlot = classKeys.length == 1;
            for (int c = 0; c < classKeys.length; c++) {
                if (classKeys[c] == null) {
                    continue;
                }
                final String suffix = singleSlot ? "" : ".C" + (c + 1);
                columnNames[next++] = "T" + (t + 1) + suffix;
            }
        }
        return columnNames;
    }

    /**
     * Assigns every row of {@code frame} to a leaf of every tree.
     *
     * @param frame           input frame; a shallow copy is adapted to the training layout
     * @param type            whether to output decision paths or leaf node ids
     * @param destination_key key of the result frame (may be {@code null})
     * @return one column per non-null tree, named as in {@link #makeAllTreeColumnNames()}
     */
    @Override
    public Frame scoreLeafNodeAssignment(Frame frame, Model.LeafNodeAssignment.LeafNodeAssignmentType type, Key<Frame> destination_key) {
        final Frame adapted = new Frame(frame);
        adaptTestForTrain(adapted, true, false);
        final String[] columnNames = makeAllTreeColumnNames();
        final AssignLeafNodeTaskBase assigner = AssignLeafNodeTaskBase.make(_output, type);
        return assigner.execute(adapted, columnNames, destination_key);
    }

    /**
     * Recomputes the node weights stored in the auxiliary trees.
     *
     * <p>Each row of {@code frame} is routed through every tree. The row's weight is added to the
     * leaf it reaches, and the updated aux trees are written back to the DKV.
     *
     * @param frame         frame containing the model features and the weights column
     * @param weightsColumn name of the weights column in {@code frame}
     * @return report of the (tree, class) pairs whose updated aux tree has a zero-weight node
     * @throws IllegalArgumentException if {@code weightsColumn} is null or missing from {@code frame}
     */
    @Override
    public Model.UpdateAuxTreeWeights.UpdateAuxTreeWeightsReport updateAuxTreeWeights(Frame frame, String weightsColumn) {
        if (weightsColumn == null) {
            throw new IllegalArgumentException("Weights column name is not defined");
        }
        Frame adaptFrm = new Frame(frame);
        Vec weights = adaptFrm.remove(weightsColumn);
        if (weights == null) {
            throw new IllegalArgumentException("Input frame doesn't contain weights column `" + weightsColumn + "`");
        }
        adaptTestForTrain(adaptFrm, true, false);
        // Take the feature vecs from the ADAPTED frame, as the other scoring entry points of this
        // class do. Previously they were read from the raw input frame, which silently discarded
        // the adaptation just performed (NOTE(review): adaptTestForTrain presumably remaps
        // categorical domains / fills missing columns — confirm).
        final String[] features = _output.features();
        Frame featureFrm = new Frame(features, adaptFrm.vecs(features));
        // UpdateAuxTreeWeightsTask expects the weights as the last column
        featureFrm.add(weightsColumn, weights);
        UpdateAuxTreeWeightsTask t2 = (UpdateAuxTreeWeightsTask) new UpdateAuxTreeWeightsTask(_output).doAll(featureFrm);
        Model.UpdateAuxTreeWeights.UpdateAuxTreeWeightsReport report = new Model.UpdateAuxTreeWeights.UpdateAuxTreeWeightsReport();
        report._warn_trees = t2._warnTrees;
        report._warn_classes = t2._warnClasses;
        return report;
    }

    /**
     * Counts, for every row and every feature, how many splits on that feature the row passes
     * through across all trees.
     *
     * @param frame           input frame; a shallow copy is adapted and stripped of special columns
     * @param destination_key key of the result frame (may be {@code null})
     * @return numeric frame with one column per model feature
     */
    @Override
    public Frame scoreFeatureFrequencies(Frame frame, Key<Frame> destination_key) {
        final Frame features = new Frame(frame);
        adaptTestForTrain(features, true, false);
        // drop response/fold/weights/offset (and treatment, when set) so only features remain
        final P p = _parms;
        features.remove(p._response_column);
        features.remove(p._fold_column);
        features.remove(p._weights_column);
        features.remove(p._offset_column);
        if (p._treatment_column != null) {
            features.remove(p._treatment_column);
        }
        assert features.numCols() == _output.nfeatures();
        final int ncols = features.numCols();
        final ScoreFeatureFrequenciesTask task =
            (ScoreFeatureFrequenciesTask) new ScoreFeatureFrequenciesTask(_output).doAll(ncols, Vec.T_NUM, features);
        return task.outputFrame(destination_key, features.names(), null);
    }

    /**
     * Post-processes raw predictions via {@link PlattScalingHelper}. The model output serves as the
     * calibration source; calibration presumably applies only when configured.
     */
    @Override
    protected Frame postProcessPredictions(Frame adaptedFrame, Frame predictFr, Job j2) {
        final PlattScalingHelper.OutputWithCalibration calibrationSource = _output;
        return PlattScalingHelper.postProcessPredictions(predictFr, j2, calibrationSource);
    }

    /**
     * Incremental-scoring hook. This base implementation ignores {@code sii} and scores the row
     * from scratch via the chunk-based {@code score0}.
     */
    protected double[] score0Incremental(Score.ScoreIncInfo sii, Chunk[] chks, double offset, int row_in_chunk, double[] tmp, double[] preds) {
        final double[] scored = score0(chks, offset, row_in_chunk, tmp, preds);
        return scored;
    }

    /** Scores one row using every tree of the model. */
    @Override
    protected double[] score0(double[] data, double[] preds, double offset) {
        final int allTrees = _output._treeKeys.length;
        return score0(data, preds, offset, allTrees);
    }

    /** Scores one row using every tree, without an offset. */
    @Override
    protected double[] score0(double[] data, double[] preds) {
        final double noOffset = 0.0;
        return score0(data, preds, noOffset);
    }

    /**
     * Scores one row using trees {@code [0, ntrees)}. {@code preds} is zeroed first because the
     * per-tree contributions are accumulated into it.
     */
    protected double[] score0(double[] data, double[] preds, double offset, int ntrees) {
        for (int i = 0; i < preds.length; i++) {
            preds[i] = 0.0;
        }
        return score0(data, preds, offset, 0, ntrees);
    }

    /**
     * Adds the contributions of trees {@code [startTree, ntrees)} to {@code preds}. Note that
     * {@code ntrees} is an exclusive end index, not a count. {@code offset} is not used here.
     *
     * @return {@code preds}
     */
    protected double[] score0(double[] data, double[] preds, double offset, int startTree, int ntrees) {
        int tree = startTree;
        while (tree < ntrees) {
            score0(data, preds, tree);
            tree++;
        }
        return preds;
    }

    /**
     * Adds the contribution of every class tree of iteration {@code treeIdx} to {@code preds}.
     * Single-slot iterations write to {@code preds[0]}; otherwise class {@code c} writes to
     * {@code preds[c + 1]}.
     */
    private void score0(double[] data, double[] preds, int treeIdx) {
        final Key<CompressedTree>[] classTrees = _output._treeKeys[treeIdx];
        final boolean singleSlot = classTrees.length == 1;
        for (int cls = 0; cls < classTrees.length; cls++) {
            final Key<CompressedTree> key = classTrees[cls];
            if (key != null) {
                final CompressedTree tree = (CompressedTree) DKV.get(key).get();
                final double contribution = tree.score(data, _output._domains);
                assert !Double.isInfinite(contribution);
                preds[singleSlot ? 0 : cls + 1] += contribution;
            }
        }
    }

    /**
     * Creates an independent copy of this model under {@code result}.
     *
     * <p>Training and validation metrics are cleared on the copy. Every main tree is deep-copied
     * and stored in the DKV under a fresh key. Every aux tree is copied too, under a key derived
     * from its (new) main tree key. The copy therefore shares no tree objects with this model.
     *
     * @param result key for the cloned model
     * @return the cloned model
     */
    protected M deepClone(Key<M> result) {
        final M clone = IcedUtils.deepCopy(self());
        clone._key = result;
        final SharedTreeOutput out = clone._output;
        out.clearModelMetrics(false);
        out._training_metrics = null;
        out._validation_metrics = null;

        final Key<CompressedTree>[][] mainKeys = out._treeKeys;
        for (int t = 0; t < mainKeys.length; t++) {
            for (int c = 0; c < mainKeys[t].length; c++) {
                if (mainKeys[t][c] != null) {
                    final CompressedTree copy = IcedUtils.deepCopy((CompressedTree) DKV.get(mainKeys[t][c]).get());
                    copy._key = CompressedTree.makeTreeKey(t, c);
                    mainKeys[t][c] = copy._key;
                    DKV.put(copy._key, copy);
                }
            }
        }

        final Key<CompressedTree>[][] auxKeys = out._treeKeysAux;
        if (auxKeys == null) {
            return clone;
        }
        for (int t = 0; t < auxKeys.length; t++) {
            for (int c = 0; c < auxKeys[t].length; c++) {
                if (auxKeys[t][c] != null) {
                    final CompressedTree copy = IcedUtils.deepCopy((CompressedTree) DKV.get(auxKeys[t][c]).get());
                    // the aux key is derived from the already re-keyed main tree
                    copy._key = Key.make(GenModel.createAuxKey(mainKeys[t][c].toString()));
                    auxKeys[t][c] = copy._key;
                    DKV.put(copy._key, copy);
                }
            }
        }
        return clone;
    }

    /**
     * Removes all main and aux trees (with cascade) and the calibration model, if present, and
     * then delegates to the base implementation.
     */
    @Override
    protected Futures remove_impl(Futures fs, boolean cascade) {
        for (Key<CompressedTree>[] classKeys : _output._treeKeys) {
            for (Key<CompressedTree> key : classKeys) {
                Keyed.remove(key, fs, true);
            }
        }
        for (Key<CompressedTree>[] classKeys : _output._treeKeysAux) {
            for (Key<CompressedTree> key : classKeys) {
                Keyed.remove(key, fs, true);
            }
        }
        if (_output._calib_model != null) {
            _output._calib_model.remove(fs);
        }
        return super.remove_impl(fs, cascade);
    }

    /**
     * Writes the referenced trees into {@code ab}: all main trees first, then all aux trees. The
     * order must match {@link #readAll_impl}.
     */
    @Override
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        for (Key<CompressedTree>[] classKeys : _output._treeKeys) {
            for (Key<CompressedTree> key : classKeys) {
                ab.putKey(key);
            }
        }
        for (Key<CompressedTree>[] classKeys : _output._treeKeysAux) {
            for (Key<CompressedTree> key : classKeys) {
                ab.putKey(key);
            }
        }
        return super.writeAll_impl(ab);
    }

    /**
     * Reads the referenced trees back from {@code ab} in the order written by
     * {@link #writeAll_impl}: main trees first, then aux trees.
     */
    @Override
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        for (Key<CompressedTree>[] classKeys : _output._treeKeys) {
            for (Key<CompressedTree> key : classKeys) {
                ab.getKey(key, fs);
            }
        }
        for (Key<CompressedTree>[] classKeys : _output._treeKeysAux) {
            for (Key<CompressedTree> key : classKeys) {
                ab.getKey(key, fs);
            }
        }
        return super.readAll_impl(ab, fs);
    }

    /** Returns {@code this} typed as the concrete model type {@code M}. */
    @SuppressWarnings("unchecked") // M is, by construction, the runtime type of this model
    private M self() {
        final Object me = this;
        return (M) me;
    }

    /**
     * Builds a navigable graph representation of one tree.
     *
     * @param tidx tree (iteration) index, in {@code [0, _ntrees)}
     * @param cls  class index within iteration {@code tidx}
     * @return the subgraph, or {@code null} when iteration {@code tidx} has no tree for {@code cls}
     * @throws IllegalArgumentException if {@code tidx} or {@code cls} is out of range
     */
    public SharedTreeSubgraph getSharedTreeSubgraph(int tidx, int cls) {
        final SharedTreeOutput out = _output;
        if (tidx < 0 || tidx >= out._ntrees) {
            throw new IllegalArgumentException("Invalid tree index: " + tidx + ". Tree index must be in range [0, " + (out._ntrees - 1) + "].");
        }
        final Key<CompressedTree>[] auxKeys = out._treeKeysAux[tidx];
        // validate the class index as well; previously this surfaced as a bare
        // ArrayIndexOutOfBoundsException
        if (cls < 0 || cls >= auxKeys.length) {
            throw new IllegalArgumentException("Invalid tree class: " + cls + ". Tree class must be in range [0, " + (auxKeys.length - 1) + "].");
        }
        final Key<CompressedTree> treeKey = auxKeys[cls];
        if (treeKey == null) {
            return null;
        }
        final CompressedTree auxCompressedTree = treeKey.get();
        return out._treeKeys[tidx][cls].get().toSharedTreeSubgraph(auxCompressedTree, out._names, out._domains);
    }

    /**
     * A feature is used in prediction if it is not the response and its variable importance is
     * non-zero. Assumes variable importances are available ({@code _varimp} is dereferenced
     * without a null check).
     */
    @Override
    public boolean isFeatureUsedInPredict(String featureName) {
        final SharedTreeOutput out = _output;
        if (featureName.equals(out.responseName())) {
            return false;
        }
        final int idx = ArrayUtils.find(out._varimp._names, featureName);
        if (idx == -1) {
            return false;
        }
        return (double) out._varimp._varimp[idx] != 0.0;
    }

    /** Whether the binomial optimization is enabled; always {@code true} in this base class. */
    public boolean binomialOpt() {
        final boolean enabled = true;
        return enabled;
    }

    /**
     * Maps the training categorical-encoding scheme to its genmodel counterpart.
     *
     * @return the genmodel encoding, or {@code null} if the scheme is not supported by POJO/MOJO
     */
    @Override
    public CategoricalEncoding getGenModelEncoding() {
        final CategoricalEncoding mapped;
        switch (_parms._categorical_encoding) {
            case AUTO:
            case Enum:
            case SortByResponse:
                mapped = CategoricalEncoding.AUTO;
                break;
            case OneHotExplicit:
                mapped = CategoricalEncoding.OneHotExplicit;
                break;
            case Binary:
                mapped = CategoricalEncoding.Binary;
                break;
            case EnumLimited:
                mapped = CategoricalEncoding.EnumLimited;
                break;
            case Eigen:
                mapped = CategoricalEncoding.Eigen;
                break;
            case LabelEncoder:
                mapped = CategoricalEncoding.LabelEncoder;
                break;
            default:
                mapped = null;
                break;
        }
        return mapped;
    }

    /**
     * Creates the algorithm-specific POJO writer. This default implementation always fails;
     * algorithms that support POJO export override it.
     *
     * @throws UnsupportedOperationException always
     */
    protected SharedTreePojoWriter makeTreePojoWriter() {
        final String algo = _parms.algoName();
        throw new UnsupportedOperationException("POJO is not supported for model " + algo + ".");
    }

    /**
     * Creates the POJO writer, after checking that the categorical encoding can be reproduced
     * by genmodel.
     *
     * @throws IllegalArgumentException if {@link #getGenModelEncoding()} returns {@code null}
     */
    @Override
    protected final PojoWriter makePojoWriter() {
        if (getGenModelEncoding() != null) {
            return makeTreePojoWriter();
        }
        throw new IllegalArgumentException("Only default, SortByResponse, EnumLimited and 1-hot explicit scheme is supported for POJO/MOJO");
    }

    /**
     * For every row, counts how many internal nodes on the row's decision paths (across all trees)
     * split on each column.
     *
     * <p>{@code counts[i]} is indexed by the node's column id; presumably this is the index of the
     * input column, since the input names are passed to the graph builder.
     */
    private static class ScoreFeatureFrequenciesTask extends MRTask<ScoreFeatureFrequenciesTask> {
        final Key<CompressedTree>[][] _treeKeys;
        final Key<CompressedTree>[][] _auxTreeKeys;
        final String[][] _domains;
        // per-node graph representation of every tree, rebuilt in setupLocal (not serialized)
        transient SharedTreeSubgraph[][] _trees;

        ScoreFeatureFrequenciesTask(SharedTreeOutput output) {
            _treeKeys = output._treeKeys;
            _auxTreeKeys = output._treeKeysAux;
            _domains = output._domains;
        }

        @Override
        protected void setupLocal() {
            final int ntrees = _treeKeys.length;
            _trees = new SharedTreeSubgraph[ntrees][];
            for (int t = 0; t < ntrees; t++) {
                _trees[t] = new SharedTreeSubgraph[_treeKeys[t].length];
            }
            final ComputeSharedTreesFun builder =
                new ComputeSharedTreesFun(_trees, _treeKeys, _auxTreeKeys, _fr.names(), _domains);
            // build all tree graphs in parallel on this node and wait until done
            H2O.submitTask(new LocalMR(builder, _trees.length)).join();
        }

        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            final double[] row = new double[cs.length];
            final int[] counts = new int[ncs.length];
            final int nrows = cs[0]._len;
            for (int r = 0; r < nrows; r++) {
                for (int col = 0; col < cs.length; col++) {
                    row[col] = cs[col].atd(r);
                }
                Arrays.fill(counts, 0);
                for (int t = 0; t < _treeKeys.length; t++) {
                    final Key<CompressedTree>[] classKeys = _treeKeys[t];
                    for (int c = 0; c < classKeys.length; c++) {
                        if (classKeys[c] == null) {
                            continue;
                        }
                        final double encodedPath = SharedTreeMojoModel.scoreTree(classKeys[c].get()._bits, row, true, _domains);
                        final String path = SharedTreeMojoModel.getDecisionPath(encodedPath);
                        final SharedTreeNode leaf = _trees[t][c].walkNodes(path);
                        updateStats(leaf, counts);
                    }
                }
                for (int col = 0; col < ncs.length; col++) {
                    ncs[col].addNum(counts[col]);
                }
            }
        }

        /** Increments the counter of the split column of every ancestor of {@code leaf}. */
        private void updateStats(SharedTreeNode leaf, int[] stats) {
            SharedTreeNode node = leaf.getParent();
            while (node != null) {
                stats[node.getColId()]++;
                node = node.getParent();
            }
        }
    }

    /**
     * Parallel function that converts the compressed trees of one iteration (all classes) into
     * {@link SharedTreeSubgraph}s, storing them into the caller-provided {@code _trees} array.
     */
    private static class ComputeSharedTreesFun extends MrFun<ComputeSharedTreesFun> {
        final Key<CompressedTree>[][] _treeKeys;
        final Key<CompressedTree>[][] _auxTreeKeys;
        final String[] _names;
        final String[][] _domains;
        // output array shared with the caller; filled in place
        transient SharedTreeSubgraph[][] _trees;

        ComputeSharedTreesFun(SharedTreeSubgraph[][] trees, Key<CompressedTree>[][] treeKeys, Key<CompressedTree>[][] auxTreeKeys, String[] names, String[][] domains) {
            _names = names;
            _domains = domains;
            _treeKeys = treeKeys;
            _auxTreeKeys = auxTreeKeys;
            _trees = trees;
        }

        @Override
        protected void map(int treeIdx) {
            final Key<CompressedTree>[] mainKeys = _treeKeys[treeIdx];
            for (int c = 0; c < mainKeys.length; c++) {
                if (mainKeys[c] != null) {
                    _trees[treeIdx][c] = SharedTreeMojoModel.computeTreeGraph(
                        0, "T", mainKeys[c].get()._bits, _auxTreeKeys[treeIdx][c].get()._bits, _names, _domains);
                }
            }
        }
    }

    /**
     * Sums row weights per aux-tree node id for every (tree, class) pair. After the reduce, it
     * writes updated aux trees back to the DKV.
     *
     * <p>Input layout: the model features, followed by the weights column as the LAST column.
     * Rows with zero or NaN weight are ignored.
     */
    private static class UpdateAuxTreeWeightsTask extends MRTask<UpdateAuxTreeWeightsTask> {
        private final Key<CompressedTree>[][] _treeKeys;
        private final Key<CompressedTree>[][] _auxTreeKeys;
        private final String[][] _domains;
        // per node: max node id of each aux tree, or -1 when the tree is absent (built in setupLocal)
        private transient int[][] _maxNodeIds;
        // accumulated weights indexed [tree][class][nodeId]; null for absent trees
        private double[][][] _leafNodeWeights;
        // filled in postGlobal: (tree, class) pairs whose updated aux tree has a zero-weight node
        private int[] _warnTrees;
        private int[] _warnClasses;

        private UpdateAuxTreeWeightsTask(SharedTreeOutput output) {
            _domains = output._domains;
            _treeKeys = output._treeKeys;
            _auxTreeKeys = output._treeKeysAux;
        }

        @Override
        protected void setupLocal() {
            final int ntrees = _auxTreeKeys.length;
            _maxNodeIds = new int[ntrees][];
            for (int t = 0; t < ntrees; ++t) {
                final Key<CompressedTree>[] auxKeys = _auxTreeKeys[t];
                final int[] maxIds = new int[auxKeys.length];
                for (int c = 0; c < auxKeys.length; ++c) {
                    if (auxKeys[c] == null) {
                        maxIds[c] = -1;
                    } else {
                        final CompressedTree aux = auxKeys[c].get();
                        assert aux != null;
                        maxIds[c] = aux.findMaxNodeId();
                    }
                }
                _maxNodeIds[t] = maxIds;
            }
        }

        /** Allocates zeroed weight accumulators sized by the max node id of each aux tree. */
        protected void initMap() {
            final int ntrees = _maxNodeIds.length;
            _leafNodeWeights = new double[ntrees][][];
            for (int t = 0; t < ntrees; ++t) {
                final int[] maxIds = _maxNodeIds[t];
                final double[][] perClass = new double[maxIds.length][];
                for (int c = 0; c < maxIds.length; ++c) {
                    if (maxIds[c] >= 0) {
                        perClass[c] = new double[maxIds[c] + 1];
                    }
                }
                _leafNodeWeights[t] = perClass;
            }
        }

        @Override
        public void map(Chunk[] chks) {
            final int weightsIdx = chks.length - 1;
            final double[] input = new double[weightsIdx];
            initMap();
            final int nrows = chks[0]._len;
            for (int row = 0; row < nrows; ++row) {
                final double weight = chks[weightsIdx].atd(row);
                if (weight == 0.0 || Double.isNaN(weight)) {
                    continue;
                }
                for (int col = 0; col < weightsIdx; ++col) {
                    input[col] = chks[col].atd(row);
                }
                for (int tidx = 0; tidx < _treeKeys.length; ++tidx) {
                    final Key<CompressedTree>[] keys = _treeKeys[tidx];
                    for (int cls = 0; cls < keys.length; ++cls) {
                        if (keys[cls] == null) {
                            continue;
                        }
                        final CompressedTree tree = (CompressedTree) DKV.get(keys[cls]).get();
                        final CompressedTree auxTree = _auxTreeKeys[tidx][cls].get();
                        assert auxTree != null;
                        final double encodedPath = SharedTreeMojoModel.scoreTree(tree._bits, input, true, _domains);
                        final int nodeId = SharedTreeMojoModel.getLeafNodeId(encodedPath, auxTree._bits);
                        _leafNodeWeights[tidx][cls][nodeId] += weight;
                    }
                }
            }
        }

        @Override
        public void reduce(UpdateAuxTreeWeightsTask mrt) {
            ArrayUtils.add(_leafNodeWeights, mrt._leafNodeWeights);
        }

        @Override
        protected void postGlobal() {
            _warnTrees = new int[0];
            _warnClasses = new int[0];
            final Futures pending = new Futures();
            for (int t = 0; t < _leafNodeWeights.length; ++t) {
                final double[][] perClass = _leafNodeWeights[t];
                for (int c = 0; c < perClass.length; ++c) {
                    final double[] nodeWeights = perClass[c];
                    if (nodeWeights == null) {
                        continue;
                    }
                    final CompressedTree aux = _auxTreeKeys[t][c].get();
                    assert aux != null;
                    final CompressedTree updated = aux.updateLeafNodeWeights(nodeWeights);
                    assert aux._key.equals(updated._key);
                    DKV.put(updated, pending);
                    if (updated.hasZeroWeight()) {
                        _warnTrees = ArrayUtils.append(_warnTrees, t);
                        _warnClasses = ArrayUtils.append(_warnClasses, c);
                    }
                }
            }
            pending.blockForPending();
            assert _warnTrees.length == _warnClasses.length;
        }
    }

    /** Leaf assignment that outputs the numeric leaf node id (taken from the aux tree) per tree. */
    private static class AssignLeafNodeIdTask extends AssignLeafNodeTaskBase {
        final Key<CompressedTree>[][] _auxTreeKeys;

        private AssignLeafNodeIdTask(SharedTreeOutput output) {
            super(output);
            _auxTreeKeys = output._treeKeysAux;
        }

        @Override
        protected void initMap() {
            // no per-chunk state needed
        }

        @Override
        protected void assignNode(int tidx, int cls, CompressedTree tree, double[] input, NewChunk nc) {
            final CompressedTree aux = _auxTreeKeys[tidx][cls].get();
            assert aux != null;
            final double encodedPath = SharedTreeMojoModel.scoreTree(tree._bits, input, true, _domains);
            final int leafId = SharedTreeMojoModel.getLeafNodeId(encodedPath, aux._bits);
            nc.addNum(leafId, 0);
        }

        @Override
        protected Frame execute(Frame adaptFrm, String[] names, Key<Frame> destKey) {
            final AssignLeafNodeTaskBase done = doAll(names.length, Vec.T_NUM, adaptFrm);
            final Frame result = done.outputFrame(destKey, names, null);
            // a negative id marks a row that could not be assigned (checked on the first column only)
            if (result.vec(0).min() < 0.0) {
                LOG.warn("Some of the observations were not assigned a Leaf Node ID (-1), only tree-paths up to length 64 are supported.");
            }
            return result;
        }
    }

    /**
     * Leaf assignment that outputs the decision path (a string of 'L'/'R' turns) per tree. The
     * result columns are converted to categoricals.
     */
    private static class AssignTreePathTask extends AssignLeafNodeTaskBase {
        // per-chunk reusable path buffer, recreated in initMap (not serialized)
        private transient BufStringDecisionPathTracker _tr;

        private AssignTreePathTask(SharedTreeOutput output) {
            super(output);
        }

        @Override
        protected void initMap() {
            _tr = new BufStringDecisionPathTracker();
        }

        @Override
        protected void assignNode(int tidx, int cls, CompressedTree tree, double[] input, NewChunk nc) {
            final BufferedString path = tree.getDecisionPath(input, _domains, _tr);
            nc.addStr(path);
        }

        @Override
        protected Frame execute(Frame adaptFrm, String[] names, Key<Frame> destKey) {
            final Frame paths = doAll(names.length, Vec.T_STR, adaptFrm).outputFrame(destKey, names, null);
            final Vec[] stringVecs = paths.vecs();
            final Vec[] categorical = new Vec[stringVecs.length];
            boolean anyNA = false;
            for (int i = 0; i < stringVecs.length; i++) {
                try {
                    if (!anyNA) {
                        anyNA = stringVecs[i].naCnt() > 0L;
                    }
                    categorical[i] = stringVecs[i].toCategoricalVec();
                } catch (Exception e) {
                    // release the categorical vecs created so far before propagating
                    VecUtils.deleteVecs(categorical, i);
                    throw e;
                }
            }
            paths.delete();
            final Frame res = new Frame(destKey, names, categorical);
            if (destKey != null) {
                DKV.put(res);
            }
            if (anyNA) {
                LOG.warn("Some of the leaf node assignments were skipped (NA), only tree-paths up to length 64 are supported.");
            }
            return res;
        }
    }

    /**
     * Base MRTask for leaf-node assignment. It produces one output column per non-null tree, in
     * (iteration, class) order; subclasses decide what value is written per tree.
     */
    private static abstract class AssignLeafNodeTaskBase extends MRTask<AssignLeafNodeTaskBase> {
        final Key<CompressedTree>[][] _treeKeys;
        final String[][] _domains;

        AssignLeafNodeTaskBase(SharedTreeOutput output) {
            _domains = output._domains;
            _treeKeys = output._treeKeys;
        }

        /** Prepares per-chunk state; invoked once at the start of every {@link #map}. */
        protected abstract void initMap();

        /** Appends the assignment of the current row for tree ({@code tidx}, {@code cls}) to {@code nc}. */
        protected abstract void assignNode(int tidx, int cls, CompressedTree tree, double[] input, NewChunk nc);

        @Override
        public void map(Chunk[] chks, NewChunk[] ncs) {
            final double[] input = new double[chks.length];
            initMap();
            for (int row = 0; row < chks[0]._len; row++) {
                for (int c = 0; c < chks.length; c++) {
                    input[c] = chks[c].atd(row);
                }
                int outCol = 0;
                for (int tidx = 0; tidx < _treeKeys.length; tidx++) {
                    final Key<CompressedTree>[] keys = _treeKeys[tidx];
                    for (int cls = 0; cls < keys.length; cls++) {
                        if (keys[cls] == null) {
                            continue;
                        }
                        final CompressedTree tree = (CompressedTree) DKV.get(keys[cls]).get();
                        assignNode(tidx, cls, tree, input, ncs[outCol++]);
                    }
                }
                assert outCol == ncs.length;
            }
        }

        /** Runs the task over {@code adaptFrm} and builds the result frame named {@code names}. */
        protected abstract Frame execute(Frame adaptFrm, String[] names, Key<Frame> destKey);

        private static AssignLeafNodeTaskBase make(SharedTreeOutput modelOutput, Model.LeafNodeAssignment.LeafNodeAssignmentType type) {
            switch (type) {
                case Path:
                    return new AssignTreePathTask(modelOutput);
                case Node_ID:
                    return new AssignLeafNodeIdTask(modelOutput);
                default:
                    throw new UnsupportedOperationException("Unknown leaf node assignment type: " + type);
            }
        }
    }

    /**
     * Decision-path tracker that records each turn as an 'L' or 'R' byte in a reusable 64-byte
     * buffer. It returns the path as a {@link BufferedString} view over that buffer.
     *
     * <p>{@link #terminate()} truncates the path just before the last recorded 'R'.
     * NOTE(review): this presumably means the encoding ends every path with a marker right-turn
     * bit — confirm against {@code SharedTreeMojoModel}.
     */
    public static class BufStringDecisionPathTracker
    implements SharedTreeMojoModel.DecisionPathTracker<BufferedString> {
        private final byte[] _buf = new byte[64];
        // view over _buf handed out by terminate(); only its length changes between paths
        private final BufferedString _bs = new BufferedString(this._buf, 0, 0);
        // depth of the most recent right turn; used as the path length
        private int _pos = 0;

        @Override
        public boolean go(int depth, boolean right) {
            _buf[depth] = right ? (byte) 'R' : (byte) 'L';
            if (right) {
                _pos = depth;
            }
            return true;
        }

        @Override
        public BufferedString terminate() {
            final int length = _pos;
            _pos = 0;
            _bs.setLen(length);
            return _bs;
        }

        @Override
        public BufferedString invalidPath() {
            // signals an un-representable path; the caller stores it as a missing value
            return null;
        }
    }

    public static abstract class SharedTreeOutput
    extends Model.Output
    implements Model.GetNTrees,
    PlattScalingHelper.OutputWithCalibration {
        public double _init_f;
        public int _ntrees = 0;
        public final TreeStats _treeStats;
        public Key<CompressedTree>[][] _treeKeys;
        public Key<CompressedTree>[][] _treeKeysAux;
        public ScoreKeeper[] _scored_train;
        public ScoreKeeper[] _scored_valid;
        public long[] _training_time_ms = new long[]{System.currentTimeMillis()};
        public TwoDimTable _variable_importances;
        public VarImp _varimp;
        public GLMModel _calib_model;

        public ScoreKeeper[] scoreKeepers() {
            ScoreKeeper[] ska;
            ArrayList<ScoreKeeper> skl = new ArrayList<ScoreKeeper>();
            for (ScoreKeeper sk : ska = this._validation_metrics != null ? this._scored_valid : this._scored_train) {
                if (sk.isEmpty()) continue;
                skl.add(sk);
            }
            return skl.toArray(new ScoreKeeper[skl.size()]);
        }

        @Override
        public TwoDimTable getVariableImportances() {
            return this._variable_importances;
        }

        public SharedTreeOutput(SharedTree b2) {
            super(b2);
            this._treeKeys = new Key[this._ntrees][];
            this._treeKeysAux = new Key[this._ntrees][];
            this._treeStats = new TreeStats();
            this._scored_train = new ScoreKeeper[]{new ScoreKeeper(Double.NaN)};
            this._scored_valid = new ScoreKeeper[]{new ScoreKeeper(Double.NaN)};
            this._modelClassDist = this._priorClassDist;
        }

        @Override
        public TwoDimTable createInputFramesInformationTable(ModelBuilder modelBuilder) {
            SharedTreeParameters params = (SharedTreeParameters)modelBuilder._parms;
            TwoDimTable table = super.createInputFramesInformationTable(modelBuilder);
            table.set(2, 0, "calibration_frame");
            table.set(2, 1, params.getCalibrationFrame() != null ? params.getCalibrationFrame().checksum() : -1L);
            table.set(2, 2, params.getCalibrationFrame() != null ? Arrays.toString(params.getCalibrationFrame().anyVec().espc()) : Integer.valueOf(-1));
            return table;
        }

        /**
         * Reports the number of rows in the input-frames table: the base rows plus one for the
         * calibration frame.
         *
         * @return the superclass row count incremented by one
         */
        @Override
        public int getInformationTableNumRows() {
            final int baseRows = super.getInformationTableNumRows();
            return baseRows + 1;
        }

        /**
         * Appends one iteration's trees (one slot per class) to this output.
         *
         * <p>Each non-null tree is compressed and stored in the DKV under its own key, and is passed
         * to {@code _treeStats.updateBy}. Its auxiliary buffer is stored as a second
         * {@link CompressedTree} under a key derived from the main key via
         * {@code GenModel.createAuxKey}. Null entries are skipped and leave {@code null} keys.
         * Afterwards {@code _ntrees} is incremented, and the scoring histories and
         * {@code _training_time_ms} are grown to {@code _ntrees + 1} entries. The method blocks
         * until all pending DKV puts have completed.
         *
         * @param trees per-class trees for this iteration; the length must equal {@code nclasses()}
         *     (enforced only by an assertion)
         */
        public void addKTrees(DTree[] trees) {
            assert (this.nclasses() == trees.length);
            // Grow the key tables by one row; the current _ntrees is the index of the new row.
            this._treeKeys = (Key[][])Arrays.copyOf(this._treeKeys, this._ntrees + 1);
            this._treeKeysAux = (Key[][])Arrays.copyOf(this._treeKeysAux, this._ntrees + 1);
            this._treeKeys[this._ntrees] = new Key[trees.length];
            Key[] keys = this._treeKeys[this._ntrees];
            this._treeKeysAux[this._ntrees] = new Key[trees.length];
            Key[] keysAux = this._treeKeysAux[this._ntrees];
            // Collects the asynchronous DKV puts so they can be awaited at the end.
            Futures fs = new Futures();
            for (int i2 = 0; i2 < this.nclasses(); ++i2) {
                if (trees[i2] == null) continue;
                // compress() receives the iteration index and class index of this tree.
                CompressedTree ct = trees[i2].compress(this._ntrees, i2, this._domains);
                keys[i2] = ct._key;
                DKV.put(keys[i2], ct, fs);
                this._treeStats.updateBy(trees[i2]);
                // NOTE(review): the -1L/-1/-1 arguments look like placeholders; their meaning is not visible here.
                CompressedTree ctAux = new CompressedTree(trees[i2]._abAux.buf(), -1L, -1, -1);
                // The aux key is derived from the main tree key so the two can be paired later.
                keysAux[i2] = ctAux._key = Key.make(GenModel.createAuxKey(ct._key.toString()));
                DKV.put(ctAux, fs);
            }
            ++this._ntrees;
            // Histories hold _ntrees + 1 entries (the constructor seeds them with one entry).
            // NOTE(review): the same ScoreKeeper instance would be shared if more than one slot were added.
            this._scored_train = ArrayUtils.copyAndFillOf(this._scored_train, this._ntrees + 1, new ScoreKeeper());
            this._scored_valid = this._scored_valid != null ? ArrayUtils.copyAndFillOf(this._scored_valid, this._ntrees + 1, new ScoreKeeper()) : null;
            this._training_time_ms = ArrayUtils.copyAndFillOf(this._training_time_ms, this._ntrees + 1, System.currentTimeMillis());
            fs.blockForPending();
        }

        /**
         * Discards every tree from index {@code ntrees} onwards and shrinks the per-iteration
         * bookkeeping to match.
         *
         * <p>The removed trees' keys and auxiliary keys are deleted from the DKV, and the method
         * blocks until those removals finish. The key tables are truncated to {@code ntrees} rows.
         * The scoring histories and training timestamps are truncated to {@code ntrees + 1} entries,
         * and the model summary is rebuilt.
         *
         * <p>NOTE(review): {@code _treeStats} is not recomputed here. The rebuilt summary may
         * therefore still reflect statistics gathered from the removed trees.
         *
         * @param ntrees number of leading trees to keep; must lie in {@code [0, getNTrees()]}
         * @throws IllegalArgumentException if {@code ntrees} is negative or larger than the current
         *     tree count (previously this silently padded the tree and history arrays with
         *     {@code null}, breaking later scoring)
         */
        public void trimTo(int ntrees) {
            if (ntrees < 0 || ntrees > this._ntrees) {
                throw new IllegalArgumentException("Cannot trim to " + ntrees + " trees: the model has " + this._ntrees + " trees");
            }
            Futures fs = new Futures();
            for (int i2 = ntrees; i2 < this._treeKeys.length; ++i2) {
                for (int tc = 0; tc < this._treeKeys[i2].length; ++tc) {
                    // Class slots without a tree (null key) have nothing stored to remove.
                    if (this._treeKeys[i2][tc] == null) continue;
                    DKV.remove(this._treeKeys[i2][tc], fs);
                    DKV.remove(this._treeKeysAux[i2][tc], fs);
                }
            }
            this._ntrees = ntrees;
            this._treeKeys = (Key[][])Arrays.copyOf(this._treeKeys, this._ntrees);
            this._treeKeysAux = (Key[][])Arrays.copyOf(this._treeKeysAux, this._ntrees);
            // Histories keep one more entry than there are trees.
            this._scored_train = Arrays.copyOf(this._scored_train, this._ntrees + 1);
            this._scored_valid = this._scored_valid != null ? Arrays.copyOf(this._scored_valid, this._ntrees + 1) : null;
            this._training_time_ms = Arrays.copyOf(this._training_time_ms, this._ntrees + 1);
            this._model_summary = SharedTree.createModelSummaryTable(this._ntrees, this._treeStats);
            fs.blockForPending();
        }

        /**
         * Reports how many trees (iterations) this output currently holds.
         *
         * @return the value of {@code _ntrees}
         */
        @Override
        public int getNTrees() {
            final int count = _ntrees;
            return count;
        }

        /**
         * Exposes the stored calibration model.
         *
         * @return the model held in {@code _calib_model}, possibly {@code null}
         */
        @Override
        public GLMModel calibrationModel() {
            final GLMModel calibrator = _calib_model;
            return calibrator;
        }

        /**
         * Fetches a compressed tree by resolving its stored key.
         *
         * @param tnum tree (iteration) index into {@code _treeKeys}
         * @param knum class index within that iteration
         * @return the tree obtained from the key's {@code get()}
         */
        public CompressedTree ctree(int tnum, int knum) {
            final Key<CompressedTree>[] perClass = this._treeKeys[tnum];
            final Key<CompressedTree> key = perClass[knum];
            return key.get();
        }

        /**
         * Renders one compressed tree as text, using this output for context.
         *
         * @param tnum tree (iteration) index
         * @param knum class index within that iteration
         * @return the tree's {@code toString(this)} rendering
         */
        public String toStringTree(int tnum, int knum) {
            final CompressedTree tree = this.ctree(tnum, knum);
            return tree.toString(this);
        }
    }

    /**
     * Parameters shared by the tree-based model builders.
     *
     * <p>The public fields keep their original declaration order. NOTE(review): this class is
     * presumably serialized field-by-field, so reordering fields may affect compatibility; confirm
     * before moving them.
     */
    public static abstract class SharedTreeParameters
    extends Model.Parameters
    implements Model.GetNTrees,
    PlattScalingHelper.ParamsWithCalibration {
        /** Requested number of trees; reported by {@link #getNTrees()} and counted by {@link #progressUnits()}. */
        public int _ntrees = 50;
        public int _max_depth = 5;
        public double _min_rows = 10.0;
        public int _nbins = 20;
        public int _nbins_cats = 1024;
        public double _min_split_improvement = 1.0E-5;
        /** Histogram type; QuantilesGlobal and RoundRobin add one unit to {@link #progressUnits()}. */
        public HistogramType _histogram_type = HistogramType.AUTO;
        public double _r2_stopping = Double.MAX_VALUE;
        public int _nbins_top_level = 1024;
        public boolean _build_tree_one_node = false;
        public int _score_tree_interval = 0;
        public int _initial_score_interval = 4000;
        public int _score_interval = 4000;
        /** Row sample rate; values below 1.0 enable {@link #useRowSampling()}. */
        public double _sample_rate = 0.632;
        /** Per-class row sample rates; any non-null value enables {@link #useRowSampling()}. */
        public double[] _sample_rate_per_class;
        /** Returned by {@link #calibrateModel()}. */
        public boolean _calibrate_model = false;
        /** Key of the frame returned by {@link #getCalibrationFrame()}. */
        public Key<Frame> _calibration_frame;
        /** Together with {@code _col_sample_rate_per_tree}, any value other than 1.0 enables {@link #useColSampling()}. */
        public double _col_sample_rate_change_per_level = 1.0;
        public double _col_sample_rate_per_tree = 1.0;
        public boolean _parallel_main_model_building = false;
        public boolean _use_best_cv_iteration = true;
        /** Names of parameter fields listed as non-modifiable; presumably checked on checkpoint restarts. */
        static final String[] CHECKPOINT_NON_MODIFIABLE_FIELDS = new String[]{"_build_tree_one_node", "_sample_rate", "_max_depth", "_min_rows", "_nbins", "_nbins_cats", "_nbins_top_level"};

        /**
         * Tells whether rows are sub-sampled.
         *
         * @return {@code true} when per-class rates are given or the global rate is below 1.0
         */
        public boolean useRowSampling() {
            if (this._sample_rate_per_class != null) {
                return true;
            }
            return this._sample_rate < 1.0;
        }

        /**
         * Counts the progress units of a build: one per tree, plus one extra unit for the
         * QuantilesGlobal and RoundRobin histogram types.
         */
        @Override
        public long progressUnits() {
            final boolean extraUnit = this._histogram_type == HistogramType.QuantilesGlobal
                    || this._histogram_type == HistogramType.RoundRobin;
            return extraUnit ? this._ntrees + 1 : this._ntrees;
        }

        /**
         * Tells whether columns are sub-sampled.
         *
         * @return {@code false} only when both column sample rates are exactly 1.0
         */
        public boolean useColSampling() {
            return !(this._col_sample_rate_change_per_level == 1.0 && this._col_sample_rate_per_tree == 1.0);
        }

        /** Tells whether any row or column sampling is enabled. */
        public boolean isStochastic() {
            if (this.useRowSampling()) {
                return true;
            }
            return this.useColSampling();
        }

        @Override
        public int getNTrees() {
            final int requested = _ntrees;
            return requested;
        }

        /**
         * Resolves the calibration frame.
         *
         * @return the frame behind {@code _calibration_frame}, or {@code null} when no key is set
         */
        @Override
        public Frame getCalibrationFrame() {
            if (this._calibration_frame == null) {
                return null;
            }
            return this._calibration_frame.get();
        }

        @Override
        public boolean calibrateModel() {
            final boolean calibrate = _calibrate_model;
            return calibrate;
        }

        /** Returns this parameter object itself. */
        @Override
        public Model.Parameters getParams() {
            final Model.Parameters self = this;
            return self;
        }

        /** Returns {@code false} by default; overridable by subclasses. */
        public boolean forceStrictlyReproducibleHistograms() {
            final boolean forced = false;
            return forced;
        }

        /** Selectable histogram types; constants are kept in their original declaration order. */
        public static enum HistogramType {
            AUTO,
            UniformAdaptive,
            Random,
            QuantilesGlobal,
            RoundRobin,
            UniformRobust;

            /** Presumably the types cycled through when RoundRobin is selected; excludes RoundRobin and UniformRobust. */
            public static HistogramType[] ROUND_ROBIN_CANDIDATES = new HistogramType[]{AUTO, UniformAdaptive, Random, QuantilesGlobal};
        }
    }
}

