/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.drf;

import hex.Model;
import hex.tree.CompressedForest;
import hex.tree.SharedTreeModel;
import hex.tree.SharedTreeModelWithContributions;
import hex.tree.SharedTreePojoWriter;
import hex.tree.drf.DRF;
import hex.tree.drf.DrfMojoWriter;
import hex.tree.drf.DrfPojoWriter;
import hex.util.EffectiveParametersUtils;
import water.Key;
import water.fvec.NewChunk;
import water.util.MathUtils;

/**
 * Distributed Random Forest model.
 *
 * <p>Scoring delegates to {@link SharedTreeModel} and then post-processes the accumulated
 * per-tree output (averaging for single-class output and the optimized binomial path,
 * normalization otherwise). It also supplies DRF-specific Shapley-contribution tasks and
 * POJO/MOJO writers.
 */
public class DRFModel
extends SharedTreeModelWithContributions<DRFModel, DRFParameters, DRFOutput> {
    public DRFModel(Key<DRFModel> selfKey, DRFParameters parms, DRFOutput output) {
        super(selfKey, parms, output);
    }

    /**
     * Resolves "auto"-style parameters to their effective values: fold assignment,
     * histogram type, and categorical encoding (defaulting to {@code Enum}).
     */
    @Override
    public void initActualParamValues() {
        super.initActualParamValues();
        EffectiveParametersUtils.initFoldAssignment(_parms);
        EffectiveParametersUtils.initHistogramType((SharedTreeModel.SharedTreeParameters) _parms);
        EffectiveParametersUtils.initCategoricalEncoding(_parms, Model.Parameters.CategoricalEncodingScheme.Enum);
    }

    /**
     * Resolves the effective stopping metric, which depends on whether the model is a classifier.
     *
     * @param isClassifier whether the trained model is a classifier
     */
    public void initActualParamValuesAfterOutputSetup(boolean isClassifier) {
        EffectiveParametersUtils.initStoppingMetric(_parms, isClassifier);
    }

    /**
     * Creates the contributions task for this model.
     *
     * <p>NOTE(review): the {@code model} argument is ignored; the task is always built for
     * {@code this}.
     */
    @Override
    protected SharedTreeModelWithContributions.ScoreContributionsTask getScoreContributionsTask(SharedTreeModel model) {
        return new ScoreContributionsTaskDRF(this);
    }

    /**
     * Creates the sorting variant of the contributions task for this model.
     *
     * <p>NOTE(review): the {@code model} argument is ignored; the task is always built for
     * {@code this}.
     */
    @Override
    protected SharedTreeModelWithContributions.ScoreContributionsTask getScoreContributionsSoringTask(SharedTreeModel model, Model.Contributions.ContributionsOptions options) {
        return new ScoreContributionsSoringTaskDRF(this, options);
    }

    /**
     * @return {@code true} unless {@link DRFParameters#_binomial_double_trees} is set, i.e.
     *     whether the single-tree-per-iteration binomial scoring path is used
     */
    @Override
    public boolean binomialOpt() {
        return !_parms._binomial_double_trees;
    }

    /**
     * Scores one row, then turns the accumulated tree outputs from {@code super.score0}
     * (presumably per-tree sums) into final predictions.
     *
     * <ul>
     *   <li>{@code nclasses() == 1}: {@code preds[0]} is divided by the tree count.
     *   <li>binomial with {@link #binomialOpt()}: {@code preds[1]} is divided by the tree count
     *       and {@code preds[2]} is set to its complement.
     *   <li>otherwise: {@code preds} is normalized to sum to 1 when its sum is positive.
     * </ul>
     *
     * <p>NOTE(review): averaging uses {@code _output._ntrees}, not the {@code ntrees} argument;
     * confirm this is intended when scoring with a subset of trees.
     */
    @Override
    protected double[] score0(double[] data, double[] preds, double offset, int ntrees) {
        super.score0(data, preds, offset, ntrees);
        final int builtTrees = _output._ntrees;
        final int nclasses = _output.nclasses();
        if (nclasses == 1) {
            if (builtTrees >= 1) {
                preds[0] /= builtTrees;
            }
        } else if (nclasses == 2 && binomialOpt()) {
            if (builtTrees >= 1) {
                preds[1] /= builtTrees;
            }
            preds[2] = 1.0 - preds[1];
        } else {
            double sum = MathUtils.sum(preds);
            if (sum > 0.0) {
                MathUtils.div(preds, sum);
            }
        }
        return preds;
    }

    /** Builds a POJO writer over the locally fetched trees of this model. */
    @Override
    protected SharedTreePojoWriter makeTreePojoWriter() {
        CompressedForest compressedForest = new CompressedForest(_output._treeKeys, _output._domains);
        CompressedForest.LocalCompressedForest localCompressedForest = compressedForest.fetch();
        return new DrfPojoWriter(this, localCompressedForest._trees);
    }

    @Override
    public DrfMojoWriter getMojo() {
        return new DrfMojoWriter(this);
    }

    /**
     * Applies the DRF-specific rescaling to one raw contribution value. The single-class case
     * divides by the tree count; otherwise the result is {@code featurePlusBiasRatio} minus the
     * averaged value.
     *
     * @param contrib raw contribution accumulated over trees
     * @param singleClass whether the model output has exactly one class
     * @param ntrees tree count as a float (the divisor)
     * @param featurePlusBiasRatio {@code 1 / (nfeatures + 1)}; ignored when {@code singleClass}
     */
    private static float scaleContribution(float contrib, boolean singleClass, float ntrees, float featurePlusBiasRatio) {
        return singleClass ? contrib / ntrees : featurePlusBiasRatio - contrib / ntrees;
    }

    /**
     * Sorting contributions task with DRF rescaling applied in place.
     * (The "Soring" spelling is part of the public name and is kept for compatibility.)
     */
    public class ScoreContributionsSoringTaskDRF
    extends SharedTreeModelWithContributions.ScoreContributionsSortingTask {
        public ScoreContributionsSoringTaskDRF(SharedTreeModel model, Model.Contributions.ContributionsOptions options) {
            // The enclosing instance (DRFModel.this) is bound implicitly.
            super(model, options);
        }

        /**
         * Rescales every element of {@code contribs} in place.
         * NOTE(review): no guard for {@code _ntrees == 0}, unlike {@code score0}.
         */
        @Override
        public void doModelSpecificComputation(float[] contribs) {
            final boolean singleClass = this._output.nclasses() == 1;
            final float ntrees = (float) this._output._ntrees;
            // "+ 1" accounts for the bias term alongside the features.
            final float featurePlusBiasRatio = singleClass ? 0.0f : 1.0f / (float) (this._output.nfeatures() + 1);
            for (int i = 0; i < contribs.length; ++i) {
                contribs[i] = scaleContribution(contribs[i], singleClass, ntrees, featurePlusBiasRatio);
            }
        }
    }

    /** Contributions task that writes DRF-rescaled contributions into output chunks. */
    public class ScoreContributionsTaskDRF
    extends SharedTreeModelWithContributions.ScoreContributionsTask {
        public ScoreContributionsTaskDRF(SharedTreeModel model) {
            // The enclosing instance (DRFModel.this) is bound implicitly.
            super(model);
        }

        /**
         * Appends one rescaled contribution to each chunk; {@code contribs[i]} goes to
         * {@code nc[i]}. Assumes {@code contribs.length >= nc.length}.
         * NOTE(review): no guard for {@code _ntrees == 0}, unlike {@code score0}.
         */
        @Override
        public void addContribToNewChunk(float[] contribs, NewChunk[] nc) {
            final boolean singleClass = this._output.nclasses() == 1;
            final float ntrees = (float) this._output._ntrees;
            // "+ 1" accounts for the bias term alongside the features.
            final float featurePlusBiasRatio = singleClass ? 0.0f : 1.0f / (float) (this._output.nfeatures() + 1);
            for (int i = 0; i < nc.length; ++i) {
                nc[i].addNum(scaleContribution(contribs[i], singleClass, ntrees, featurePlusBiasRatio));
            }
        }
    }

    /** Output of a DRF model; adds nothing beyond {@link SharedTreeModel.SharedTreeOutput}. */
    public static class DRFOutput
    extends SharedTreeModel.SharedTreeOutput {
        public DRFOutput(DRF builder) {
            super(builder);
        }
    }

    /** DRF hyper-parameters. Defaults: {@code _max_depth = 20}, {@code _min_rows = 1.0}. */
    public static class DRFParameters
    extends SharedTreeModel.SharedTreeParameters {
        /** When {@code true}, {@link DRFModel#binomialOpt()} returns {@code false}. */
        public boolean _binomial_double_trees = false;
        /** Columns sampled per split; -1 presumably selects an algorithm default — confirm in DRF. */
        public int _mtries = -1;

        public DRFParameters() {
            this._max_depth = 20;
            this._min_rows = 1.0;
        }

        @Override
        public String algoName() {
            return "DRF";
        }

        @Override
        public String fullName() {
            return "Distributed Random Forest";
        }

        @Override
        public String javaName() {
            return DRFModel.class.getName();
        }
    }
}

