/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.drf;

import hex.Model;
import hex.genmodel.CategoricalEncoding;
import hex.tree.CompressedTree;
import hex.tree.SharedTreePojoWriter;
import hex.tree.drf.DRFModel;
import water.util.SBPrintStream;

/**
 * POJO code writer for Distributed Random Forest models.
 *
 * <p>Relies on {@link SharedTreePojoWriter} for the shared tree-emission logic and only
 * customizes {@link #toJavaUnifyPreds(SBPrintStream)}, i.e. the generated Java statements that
 * turn the values accumulated in {@code preds} into the final prediction row.
 */
class DrfPojoWriter extends SharedTreePojoWriter {

    /**
     * When set, the generated code calls {@code GenModel.correctProbabilities} with
     * {@code PRIOR_CLASS_DISTRIB} and {@code MODEL_CLASS_DISTRIB} after normalization
     * (presumably undoing the class re-weighting done during training — confirm).
     */
    private final boolean _balance_classes;

    /**
     * Builds a writer directly from a trained DRF model.
     *
     * @param model the DRF model; its encoding, binomial option, tree statistics and
     *              {@code _balance_classes} setting are read from it
     * @param trees the compressed trees to render
     */
    DrfPojoWriter(DRFModel model, CompressedTree[][] trees) {
        super(model._key,
              model._output,
              model.getGenModelEncoding(),
              model.binomialOpt(),
              trees,
              ((DRFModel.DRFOutput) model._output)._treeStats);
        this._balance_classes = ((DRFModel.DRFParameters) model._parms)._balance_classes;
    }

    /**
     * Builds a writer from a generic model with explicitly supplied settings.
     * No tree statistics are passed to the superclass ({@code null}).
     *
     * @param model          the source model (key and output are taken from it)
     * @param encoding       categorical encoding used by the generated code
     * @param binomialOpt    whether the binomial single-output shortcut is enabled
     * @param trees          the compressed trees to render
     * @param balanceClasses whether generated code applies probability correction
     */
    DrfPojoWriter(Model<?, ?, ?> model, CategoricalEncoding encoding, boolean binomialOpt, CompressedTree[][] trees, boolean balanceClasses) {
        super(model._key, (Model.Output) model._output, encoding, binomialOpt, trees, null);
        this._balance_classes = balanceClasses;
    }

    /**
     * Emits the statements that combine raw tree outputs held in {@code preds}.
     *
     * <ul>
     *   <li>One class (presumably regression): {@code preds[0]} is divided by
     *       {@code _trees.length}.</li>
     *   <li>Otherwise the class slots {@code preds[1..]} are turned into probabilities, optionally
     *       corrected, and {@code preds[0]} receives the label from
     *       {@code GenModel.getPrediction} using the output's default threshold.</li>
     * </ul>
     *
     * @param body stream receiving the generated Java source
     */
    @Override
    protected void toJavaUnifyPreds(SBPrintStream body) {
        if (this._output.nclasses() == 1) {
            body.ip("preds[0] /= " + this._trees.length + ";").nl();
            return;
        }

        boolean binomialShortcut = this._output.nclasses() == 2 && this._binomialOpt;
        writeDrfClassProbabilityNormalization(body, binomialShortcut);

        if (this._balance_classes) {
            body.ip("hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);").nl();
        }
        final String labelCall = "preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, "
                + this._output.defaultThreshold() + ");";
        body.ip(labelCall).nl();
    }

    /**
     * Emits code that converts the class slots {@code preds[1..]} into probabilities.
     *
     * @param body             stream receiving the generated Java source
     * @param binomialShortcut if true, only {@code preds[1]} is divided by {@code _trees.length}
     *                         and {@code preds[2]} is set to its complement; otherwise every slot
     *                         from index 1 on is divided by their sum (skipped when the sum is not
     *                         positive)
     */
    private void writeDrfClassProbabilityNormalization(SBPrintStream body, boolean binomialShortcut) {
        if (binomialShortcut) {
            body.ip("preds[1] /= " + this._trees.length + ";").nl();
            body.ip("preds[2] = 1.0 - preds[1];").nl();
            return;
        }
        body.ip("double sum = 0;").nl();
        body.ip("for(int i=1; i<preds.length; i++) { sum += preds[i]; }").nl();
        body.ip("if (sum>0) for(int i=1; i<preds.length; i++) { preds[i] /= sum; }").nl();
    }
}

