/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.error.mlc;

import java.io.PrintWriter;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.type.primitive.NominalAttrType;
import si.ijs.kt.clus.error.BinaryPredictionList;
import si.ijs.kt.clus.error.ROCAndPRCurve;
import si.ijs.kt.clus.error.common.ClusError;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.error.common.ClusNominalError;
import si.ijs.kt.clus.statistic.ClassificationStat;
import si.ijs.kt.clus.statistic.ClusStatistic;

public abstract class MLROCAndPRCurve
extends ClusNominalError {
    protected static final int averageAUROC = 0;
    protected static final int averageAUPRC = 1;
    protected static final int weightedAUPRC = 2;
    protected static final int pooledAUPRC = 3;
    public static final long serialVersionUID = 1L;
    protected double m_AreaROC;
    protected double m_AreaPR;
    protected double[] m_Thresholds;
    protected transient boolean m_ExtendPR;
    protected transient BinaryPredictionList m_Values;
    protected BinaryPredictionList[] m_ClassWisePredictions = new BinaryPredictionList[this.m_Dim];
    protected ROCAndPRCurve[] m_ROCAndPRCurves = new ROCAndPRCurve[this.m_Dim];
    protected double m_AverageAUROC = -1.0;
    protected double m_AverageAUPRC = -1.0;
    protected double m_WAvgAUPRC = -1.0;
    protected double m_PooledAUPRC = -1.0;

    public MLROCAndPRCurve(ClusErrorList par, NominalAttrType[] nom) {
        super(par, nom);
        for (int i = 0; i < this.m_Dim; ++i) {
            BinaryPredictionList predlist;
            this.m_ClassWisePredictions[i] = predlist = new BinaryPredictionList();
            this.m_ROCAndPRCurves[i] = new ROCAndPRCurve(predlist);
        }
    }

    public BinaryPredictionList[] getClassWisePredictions() {
        return this.m_ClassWisePredictions;
    }

    @Override
    public void add(ClusError other) {
        MLROCAndPRCurve castedOther = (MLROCAndPRCurve)other;
        BinaryPredictionList[] otherCWP = castedOther.getClassWisePredictions();
        for (int i = 0; i < this.m_Dim; ++i) {
            this.m_ClassWisePredictions[i].add(otherCWP[i]);
        }
    }

    @Override
    public boolean shouldBeLow() {
        return false;
    }

    @Override
    public void reset() {
        this.m_AreaROC = -1.0;
        this.m_AreaPR = -1.0;
        this.m_Values.clear();
        this.m_ClassWisePredictions = new BinaryPredictionList[this.m_Dim];
        this.m_ROCAndPRCurves = new ROCAndPRCurve[this.m_Dim];
        for (int i = 0; i < this.m_Dim; ++i) {
            BinaryPredictionList predlist;
            this.m_ClassWisePredictions[i] = predlist = new BinaryPredictionList();
            this.m_ROCAndPRCurves[i] = new ROCAndPRCurve(predlist);
        }
        this.m_AverageAUROC = -1.0;
        this.m_AverageAUPRC = -1.0;
        this.m_WAvgAUPRC = -1.0;
        this.m_PooledAUPRC = -1.0;
    }

    public void showSummaryError(PrintWriter out, boolean detail) {
        this.showModelError(out, detail ? 1 : 0);
    }

    @Override
    public double getModelError() {
        throw new RuntimeException("This must be implemented by a subclas.");
    }

    public double getModelError(int typeOfCurve) {
        this.computeAll();
        switch (typeOfCurve) {
            case 0: {
                return this.m_AverageAUROC;
            }
            case 1: {
                return this.m_AverageAUPRC;
            }
            case 2: {
                return this.m_WAvgAUPRC;
            }
            case 3: {
                return this.m_PooledAUPRC;
            }
        }
        throw new RuntimeException("Unknown type of curve: typeOfCurve" + typeOfCurve);
    }

    public void computeAll() {
        BinaryPredictionList pooled = new BinaryPredictionList();
        ROCAndPRCurve pooledCurve = new ROCAndPRCurve(pooled);
        for (int i = 0; i < this.m_Dim; ++i) {
            this.m_ClassWisePredictions[i].sort();
            this.m_ROCAndPRCurves[i].computeCurves();
            this.m_ROCAndPRCurves[i].clear();
            pooled.add(this.m_ClassWisePredictions[i]);
            this.m_ClassWisePredictions[i].clearData();
        }
        pooled.sort();
        pooledCurve.computeCurves();
        pooledCurve.clear();
        int cnt = 0;
        double sumAUROC = 0.0;
        double sumAUPRC = 0.0;
        double sumAUPRCw = 0.0;
        double sumFrequency = 0.0;
        for (int i = 0; i < this.m_Dim; ++i) {
            double freq = this.m_ClassWisePredictions[i].getFrequency();
            sumAUROC += this.m_ROCAndPRCurves[i].getAreaROC();
            sumAUPRC += this.m_ROCAndPRCurves[i].getAreaPR();
            sumAUPRCw += freq * this.m_ROCAndPRCurves[i].getAreaPR();
            sumFrequency += freq;
            ++cnt;
        }
        this.m_AverageAUROC = sumAUROC / (double)cnt;
        this.m_AverageAUPRC = sumAUPRC / (double)cnt;
        this.m_WAvgAUPRC = sumAUPRCw / sumFrequency;
        this.m_PooledAUPRC = pooledCurve.getAreaPR();
    }

    @Override
    public abstract void showModelError(PrintWriter var1, int var2);

    @Override
    public String getName() {
        return "MLROCAndPRCurve";
    }

    @Override
    public void addExample(DataTuple tuple, ClusStatistic pred) {
        double[][] probabilities = ((ClassificationStat)pred).getProbabilityPrediction();
        for (int i = 0; i < this.m_Dim; ++i) {
            NominalAttrType attr = this.getAttr(i);
            if (attr.isMissing(tuple)) continue;
            boolean groundTruth = attr.getNominal(tuple) == 0;
            this.m_ClassWisePredictions[i].addExample(groundTruth, probabilities[i][0]);
        }
    }

    @Override
    public void addExample(DataTuple tuple, DataTuple pred) {
        throw new RuntimeException("Not implemented!");
    }

    @Override
    public void addInvalid(DataTuple tuple) {
    }

    @Override
    public abstract ClusError getErrorClone(ClusErrorList var1);
}

