/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.AUC2;
import hex.ConfusionMatrix;
import hex.CustomMetric;
import hex.GainsLift;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsSupervised;
import hex.genmodel.GenModel;
import hex.genmodel.utils.DistributionFamily;
import java.util.Arrays;
import java.util.Optional;
import water.Key;
import water.MRTask;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.C8DVolatileChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.MathUtils;

/**
 * Metrics for binomial (two-class) classification: AUC / PR-AUC (via {@link AUC2}), logloss,
 * mean per-class error of the default-threshold confusion matrix and an optional gains/lift table.
 */
public class ModelMetricsBinomial
extends ModelMetricsSupervised {
    /** ROC/PR summary; source of the confusion matrices and the default threshold. May be null. */
    public final AUC2 _auc;
    /** Logloss as passed to the constructor (the builder passes the weight-normalized sum). */
    public final double _logloss;
    /** Mean per-class error of {@link #cm()}, or NaN when no confusion matrix is available. */
    public double _mean_per_class_error;
    /** Gains/lift table; null when it was not computed. */
    public final GainsLift _gainsLift;

    public ModelMetricsBinomial(Model model, Frame frame, long nobs, double mse, String[] domain, double sigma, AUC2 auc, double logloss, GainsLift gainsLift, CustomMetric customMetric) {
        super(model, frame, nobs, mse, domain, sigma, customMetric);
        this._auc = auc;
        this._logloss = logloss;
        this._gainsLift = gainsLift;
        // NOTE(review): cm() is overridable and runs here before any subclass constructor body.
        ConfusionMatrix defaultCm = this.cm();
        this._mean_per_class_error = defaultCm == null ? Double.NaN : defaultCm.mean_per_class_error();
    }

    /**
     * Looks up the metrics stored for the given model/frame pair.
     *
     * @throws H2OIllegalArgumentException if the stored metrics are missing or not binomial
     */
    public static ModelMetricsBinomial getFromDKV(Model model, Frame frame) {
        ModelMetrics found = ModelMetrics.getFromDKV(model, frame);
        if (!(found instanceof ModelMetricsBinomial)) {
            throw new H2OIllegalArgumentException("Expected to find a Binomial ModelMetrics for model: " + model._key.toString() + " and frame: " + frame._key.toString(), "Expected to find a ModelMetricsBinomial for model: " + model._key.toString() + " and frame: " + frame._key.toString() + " but found a: " + (found == null ? null : found.getClass()));
        }
        return (ModelMetricsBinomial) found;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        if (this._auc != null) {
            sb.append(" AUC: ").append((float) this._auc._auc).append('\n');
            sb.append(" pr_auc: ").append((float) this._auc.pr_auc()).append('\n');
        }
        sb.append(" logloss: ").append((float) this._logloss).append('\n');
        sb.append(" mean_per_class_error: ").append((float) this._mean_per_class_error).append('\n');
        // Kept as float (like the fields above); widening back to double would print noise digits.
        float threshold = this._auc == null ? 0.5f : (float) this._auc.defaultThreshold();
        sb.append(" default threshold: ").append(threshold).append('\n');
        ConfusionMatrix defaultCm = this.cm();
        if (defaultCm != null) {
            sb.append(" CM: ").append(defaultCm.toASCII());
        }
        if (this._gainsLift != null) {
            sb.append(this._gainsLift);
        }
        return sb.toString();
    }

    public double logloss() {
        return this._logloss;
    }

    public double mean_per_class_error() {
        return this._mean_per_class_error;
    }

    @Override
    public AUC2 auc_obj() {
        return this._auc;
    }

    /** Confusion matrix at the default threshold, or null when no AUC data / matrix is available. */
    @Override
    public ConfusionMatrix cm() {
        if (this._auc == null) {
            return null;
        }
        double[][] cm = this._auc.defaultCM();
        return cm == null ? null : new ConfusionMatrix(cm, this._domain);
    }

    /** Confusion matrix at the threshold chosen by {@code criterion}, or null when unavailable. */
    public ConfusionMatrix cm(AUC2.ThresholdCriterion criterion) {
        if (this._auc == null) {
            return null;
        }
        double[][] cm = this._auc.cmByCriterion(criterion);
        return cm == null ? null : new ConfusionMatrix(cm, this._domain);
    }

    public GainsLift gainsLift() {
        return this._gainsLift;
    }

    public double auc() {
        return this.auc_obj()._auc;
    }

    public double pr_auc() {
        return this.auc_obj()._pr_auc;
    }

    /** Alias of {@link #pr_auc()}. */
    public double aucpr() {
        return this.auc_obj()._pr_auc;
    }

    /** Response rate of the first gains/lift group divided by the average response rate. */
    public double lift_top_group() {
        return this.gainsLift().response_rates[0] / this.gainsLift().avg_response_rate;
    }

    /** Computes binomial metrics using the label vec's own domain. */
    public static ModelMetricsBinomial make(Vec targetClassProbs, Vec actualLabels) {
        // Null labels are reported by the 4-arg overload with an IllegalArgumentException.
        return ModelMetricsBinomial.make(targetClassProbs, actualLabels, actualLabels == null ? null : actualLabels.domain());
    }

    /** Computes unweighted binomial metrics. */
    public static ModelMetricsBinomial make(Vec targetClassProbs, Vec actualLabels, String[] domain) {
        return ModelMetricsBinomial.make(targetClassProbs, actualLabels, null, domain);
    }

    /**
     * Computes binomial metrics from user-given probabilities of the second class and actual labels.
     *
     * @param targetClassProbs numeric vec with values in [0, 1]
     * @param actualLabels     actual labels; converted to categorical and adapted to {@code domain}
     * @param weights          optional row weights, may be null
     * @param domain           two class labels; defaults to the categorical labels' domain when null
     * @throws IllegalArgumentException on missing inputs, non-numeric or out-of-range probabilities,
     *                                  or a domain that does not have exactly two labels
     */
    public static ModelMetricsBinomial make(Vec targetClassProbs, Vec actualLabels, Vec weights, String[] domain) {
        if (actualLabels == null || targetClassProbs == null) {
            throw new IllegalArgumentException("Missing actualLabels or predictedProbs for binomial metrics!");
        }
        Scope.enter();
        try {
            Vec labels = actualLabels.toCategoricalVec();
            if (labels == null) {
                throw new IllegalArgumentException("Missing actualLabels or predictedProbs for binomial metrics!");
            }
            if (domain == null) {
                domain = labels.domain();
            }
            if (!targetClassProbs.isNumeric()) {
                throw new IllegalArgumentException("Predicted probabilities must be numeric per-class probabilities for binomial metrics.");
            }
            if (targetClassProbs.min() < 0.0 || targetClassProbs.max() > 1.0) {
                throw new IllegalArgumentException("Predicted probabilities must be between 0 and 1 for binomial metrics.");
            }
            if (domain.length != 2) {
                throw new IllegalArgumentException("Domain must have 2 class labels, but is " + Arrays.toString(domain) + " for binomial metrics.");
            }
            Vec adapted = labels.adaptTo(domain);
            try {
                if (adapted.cardinality() != 2) {
                    throw new IllegalArgumentException("Adapted domain must have 2 class labels, but is " + Arrays.toString(adapted.domain()) + " for binomial metrics.");
                }
                Frame fr = new Frame(targetClassProbs);
                fr.add("labels", adapted);
                if (weights != null) {
                    fr.add("weights", weights);
                }
                MetricBuilderBinomial mb = new BinomialMetrics(adapted.domain()).doAll(fr)._mb;
                Frame preds = new Frame(targetClassProbs);
                // Use the vecs from the frame so the gains/lift inputs share the frame's row layout.
                ModelMetricsBinomial mm = (ModelMetricsBinomial) mb.makeModelMetrics(null, fr, preds, fr.vec("labels"), fr.vec("weights"));
                mm._description = "Computed on user-given predictions and labels, using F1-optimal threshold: " + mm.auc_obj().defaultThreshold() + ".";
                return mm;
            } finally {
                // The adapted labels are temporary; drop them only after their last reader (gains/lift).
                adapted.remove();
            }
        }
        finally {
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Accumulates per-row binomial statistics (squared error, logloss, AUC histogram) and turns
     * them into a {@link ModelMetricsBinomial}.
     */
    public static class MetricBuilderBinomial<T extends MetricBuilderBinomial<T>>
    extends ModelMetricsSupervised.MetricBuilderSupervised<T> {
        /** Weighted sum of per-row logloss. */
        protected double _logloss;
        protected AUC2.AUCBuilder _auc;

        public MetricBuilderBinomial() {
        }

        public MetricBuilderBinomial(String[] domain) {
            super(2, domain);
            // NOTE(review): 400 is presumably the AUC histogram bin count — confirm in AUC2.
            this._auc = new AUC2.AUCBuilder(400);
        }

        public double auc() {
            return new AUC2(this._auc)._auc;
        }

        public double pr_auc() {
            return new AUC2(this._auc)._pr_auc;
        }

        @Override
        public double[] perRow(double[] ds, float[] yact, Model m4) {
            return this.perRow(ds, yact, 1.0, 0.0, m4);
        }

        /**
         * Adds one row. {@code ds[1]}/{@code ds[2]} hold the probabilities of the first/second class.
         * Rows with a missing response, NaN predictions, or zero/NaN weight are ignored, as are
         * non-quasibinomial rows whose label index is not 0 or 1.
         *
         * @return {@code ds}, unchanged
         */
        @Override
        public double[] perRow(double[] ds, float[] yact, double weight, double offset, Model model) {
            if (Float.isNaN(yact[0]) || ArrayUtils.hasNaNs(ds) || weight == 0.0 || Double.isNaN(weight)) {
                return ds;
            }
            int iact = (int) yact[0];
            boolean quasibinomial = model != null && ((Model.Parameters) model._parms)._distribution == DistributionFamily.quasibinomial;
            if (quasibinomial) {
                // Map the raw response value onto a class index for the AUC builder.
                if (yact[0] != 0.0f) {
                    iact = this._domain[0].equals(String.valueOf((int) yact[0])) ? 0 : 1;
                }
                this._wY += weight * (double) yact[0];
                this._wYY += weight * (double) yact[0] * (double) yact[0];
                double err = (double) yact[0] - ds[iact + 1];
                this._sumsqe += weight * err * err;
                // Probabilities are clamped at 1e-15 so log() stays finite.
                this._logloss += -weight * ((double) yact[0] * Math.log(Math.max(1.0E-15, ds[2])) + (double) (1.0f - yact[0]) * Math.log(Math.max(1.0E-15, ds[1])));
            } else {
                if (iact != 0 && iact != 1) {
                    return ds;
                }
                this._wY += weight * (double) iact;
                this._wYY += weight * (double) iact * (double) iact;
                // err = 1 - P(actual class); a missing probability column counts as err = 1.
                double err = iact + 1 < ds.length ? 1.0 - ds[iact + 1] : 1.0;
                this._sumsqe += weight * err * err;
                this._logloss += weight * MathUtils.logloss(err);
            }
            ++this._count;
            this._wcount += weight;
            assert (!Double.isNaN(this._sumsqe));
            this._auc.perRow(ds[2], iact, weight);
            return ds;
        }

        @Override
        public void reduce(T mb) {
            super.reduce(mb);
            MetricBuilderBinomial other = (MetricBuilderBinomial) mb;
            this._logloss += other._logloss;
            this._auc.reduce(other._auc);
        }

        /**
         * Builds the metrics, computing gains/lift when predictions and a response vec are available.
         * Without a model, the last column of {@code f2} is used as the response if it is
         * categorical; otherwise gains/lift is skipped (previously this dereferenced the null model).
         */
        @Override
        public ModelMetrics makeModelMetrics(Model m4, Frame f2, Frame frameWithWeights, Frame preds) {
            Vec resp = null;
            Vec weight = null;
            if (this._wcount > 0.0 && preds != null) {
                Frame source = frameWithWeights == null ? f2 : frameWithWeights;
                if (m4 == null) {
                    Vec last = source.vec(f2.numCols() - 1);
                    resp = last.isCategorical() ? last : null;
                } else {
                    Model.Parameters parms = (Model.Parameters) m4._parms;
                    resp = source.vec(parms._response_column);
                    if (resp != null) {
                        weight = source.vec(parms._weights_column);
                    }
                }
            }
            return this.makeModelMetrics(m4, f2, preds, resp, weight);
        }

        /** Computes gains/lift (when possible) and delegates to the final metrics construction. */
        private ModelMetrics makeModelMetrics(Model m4, Frame f2, Frame preds, Vec resp, Vec weight) {
            GainsLift gl = null;
            if (this._wcount > 0.0 && preds != null && resp != null) {
                gl = this.calculateGainsLift(m4, preds, resp, weight).orElse(null);
            }
            return this.makeModelMetrics(m4, f2, gl);
        }

        /** Normalizes accumulated sums by the total weight and registers the metrics with the model. */
        private ModelMetrics makeModelMetrics(Model m4, Frame f2, GainsLift gl) {
            AUC2 auc;
            double mse = Double.NaN;
            double logloss = Double.NaN;
            double sigma = Double.NaN;
            if (this._wcount > 0.0) {
                sigma = this.weightedSigma();
                mse = this._sumsqe / this._wcount;
                logloss = this._logloss / this._wcount;
                auc = new AUC2(this._auc);
            } else {
                auc = new AUC2();
            }
            ModelMetricsBinomial mm = new ModelMetricsBinomial(m4, f2, this._count, mse, this._domain, sigma, auc, logloss, gl, this._customMetric);
            if (m4 != null) {
                m4.addModelMetrics(mm);
            }
            return mm;
        }

        /**
         * Computes the gains/lift table on the last prediction column. With a model, its
         * {@code _gainslift_bins} controls the group count: 0 disables the table, -1 or a positive value is
         * used as-is.
         *
         * @throws IllegalArgumentException if {@code _gainslift_bins < -1}
         */
        private Optional<GainsLift> calculateGainsLift(Model m4, Frame preds, Vec resp, Vec weights) {
            Integer groups = null;
            if (m4 != null) {
                int bins = ((Model.Parameters) m4._parms)._gainslift_bins;
                if (bins < -1) {
                    throw new IllegalArgumentException("Number of G/L bins must be greater or equal than -1.");
                }
                if (bins == 0) {
                    return Optional.empty();
                }
                groups = bins;
            }
            GainsLift gl = new GainsLift(preds.lastVec(), resp, weights);
            if (groups != null) {
                gl._groups = groups;
            }
            gl.exec(m4 != null ? ((Model.Output) m4._output)._job : null);
            return Optional.of(gl);
        }

        /** Caches only the second-class probability, one double per row. */
        @Override
        public Frame makePredictionCache(Model m4, Vec response) {
            return new Frame(response.makeVolatileDoubles(1));
        }

        @Override
        public void cachePrediction(double[] cdist, Chunk[] chks, int row, int cacheChunkIdx, Model m4) {
            assert (cdist.length == 3);
            ((C8DVolatileChunk) chks[cacheChunkIdx]).getValues()[row] = cdist[cdist.length - 1];
        }

        @Override
        public String toString() {
            if (this._wcount == 0.0) {
                return "empty, no rows";
            }
            return "auc = " + MathUtils.roundToNDigits(this.auc(), 3) + ", logloss = " + this._logloss / this._wcount;
        }
    }

    /**
     * MRTask over a frame laid out as [second-class probability, labels, (optional) weights] that
     * feeds each row into a {@link MetricBuilderBinomial}.
     */
    private static class BinomialMetrics
    extends MRTask<BinomialMetrics> {
        String[] domain;
        public MetricBuilderBinomial _mb;

        public BinomialMetrics(String[] domain) {
            this.domain = domain;
        }

        @Override
        public void map(Chunk[] chks) {
            this._mb = new MetricBuilderBinomial(this.domain);
            Chunk probs = chks[0];
            Chunk actuals = chks[1];
            Chunk weights = chks.length == 3 ? chks[2] : null;
            double[] ds = new double[3];
            float[] acts = new float[1];
            for (int row = 0; row < probs._len; ++row) {
                // ds = [predicted label, P(first class), P(second class)]
                ds[2] = probs.atd(row);
                ds[1] = 1.0 - ds[2];
                ds[0] = GenModel.getPrediction(ds, null, ds, Double.NaN);
                acts[0] = (float) actuals.atd(row);
                double weight = weights != null ? weights.atd(row) : 1.0;
                this._mb.perRow(ds, acts, weight, 0.0, null);
            }
        }

        @Override
        public void reduce(BinomialMetrics mrt) {
            this._mb.reduce(mrt._mb);
        }
    }
}

