/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.error.hmlc;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.zip.GZIPOutputStream;
import si.ijs.kt.clus.algo.kNN.KnnClassifier;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.error.BinaryPredictionList;
import si.ijs.kt.clus.error.ROCAndPRCurve;
import si.ijs.kt.clus.error.common.ClusError;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.ext.hierarchical.ClassHierarchy;
import si.ijs.kt.clus.ext.hierarchical.ClassTerm;
import si.ijs.kt.clus.ext.hierarchical.ClassesTuple;
import si.ijs.kt.clus.ext.hierarchical.ClassesValue;
import si.ijs.kt.clus.main.settings.section.SettingsHMLC;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.statistic.WHTDStatistic;
import si.ijs.kt.clus.util.format.ClusFormat;
import si.ijs.kt.clus.util.format.ClusNumberFormat;

/**
 * Error measure for a class hierarchy that keeps one {@link BinaryPredictionList}
 * and one {@link ROCAndPRCurve} per hierarchy term and derives the following
 * summary values from them (see {@link #computeAll()}):
 * <ul>
 *   <li>the average area under the ROC curve over all evaluated classes,</li>
 *   <li>the average and the frequency-weighted average area under the PR curve,</li>
 *   <li>the area under the PR curve of the pooled predictions,</li>
 *   <li>optionally, the average precision at each configured recall value.</li>
 * </ul>
 * When curve writing is enabled, the individual PR and ROC curve points are
 * additionally written to CSV files (optionally gzip-compressed).
 */
public class HierErrorMeasures
extends ClusError {
    public static final long serialVersionUID = 1L;
    /** Hierarchy whose terms are evaluated. */
    protected ClassHierarchy m_Hier;
    /** Per-term flag taken from {@code hier.getEvalClassesVector()}. */
    protected boolean[] m_EvalClass;
    /** Per-term list of (actual, predicted) pairs, indexed by term index. */
    protected BinaryPredictionList[] m_ClassWisePredictions;
    /** Per-term curve object, each backed by the prediction list at the same index. */
    protected ROCAndPRCurve[] m_ROCAndPRCurves;
    /** Selects which summary value {@link #getModelError()} returns. */
    protected SettingsHMLC.HierarchyMeasures m_OptimizeMeasure;
    /** Whether {@link #showModelError(PrintWriter, String, int)} writes the curve CSV files. */
    protected boolean m_WriteCurves;
    /** Recall values at which precision is reported; may be {@code null}. */
    protected double[] m_RecallValues;
    /** Average precision per recall value; only assigned when {@link #m_RecallValues} is non-null. */
    protected double[] m_AvgPrecisionAtRecall;
    protected double m_AverageAUROC;
    protected double m_AverageAUPRC;
    protected double m_WAvgAUPRC;
    protected double m_PooledAUPRC;
    /** Open PR-curve CSV writer, or {@code null} when no file is being written. */
    protected transient PrintWriter m_PRCurves;
    /** Open ROC-curve CSV writer, or {@code null} when no file is being written. */
    protected transient PrintWriter m_ROCCurves;
    /** Whether the curve CSV files are gzip-compressed (and get a {@code .gz} suffix). */
    private boolean m_IsGzipOutput;

    /**
     * Creates the measure with one empty prediction list and one curve per hierarchy term.
     *
     * @param par          parent error list, passed to the superclass
     * @param hier         class hierarchy to evaluate
     * @param recalls      recall values at which precision is reported, or {@code null}
     * @param optimize     summary value returned by {@link #getModelError()}
     * @param wrCurves     whether curve CSV files are written when a base name is available
     * @param isGzipOutput whether the curve CSV files are gzip-compressed
     */
    public HierErrorMeasures(ClusErrorList par, ClassHierarchy hier, double[] recalls, SettingsHMLC.HierarchyMeasures optimize, boolean wrCurves, boolean isGzipOutput) {
        super(par, hier.getTotal());
        int total = hier.getTotal();
        this.m_Hier = hier;
        this.m_OptimizeMeasure = optimize;
        this.m_WriteCurves = wrCurves;
        this.m_RecallValues = recalls;
        this.m_EvalClass = hier.getEvalClassesVector();
        this.m_ClassWisePredictions = new BinaryPredictionList[total];
        this.m_ROCAndPRCurves = new ROCAndPRCurve[total];
        for (int i = 0; i < total; ++i) {
            BinaryPredictionList predictions = new BinaryPredictionList();
            this.m_ClassWisePredictions[i] = predictions;
            this.m_ROCAndPRCurves[i] = new ROCAndPRCurve(predictions);
        }
        this.m_IsGzipOutput = isGzipOutput;
    }

    /**
     * Records, for every term, the actual membership of the tuple and the numeric prediction.
     *
     * @param tuple example whose class set is read from the hierarchy's target attribute
     * @param pred  prediction; must be a {@link WHTDStatistic}
     */
    @Override
    public void addExample(DataTuple tuple, ClusStatistic pred) {
        ClassesTuple classes = (ClassesTuple)tuple.getObjVal(this.m_Hier.getType().getArrayIndex());
        double[] predicted = ((WHTDStatistic)pred).getNumericPred();
        boolean[] actual = classes.getVectorBooleanNodeAndAncestors(this.m_Hier);
        for (int i = 0; i < this.m_Dim; ++i) {
            this.m_ClassWisePredictions[i].addExample(actual[i], predicted[i]);
        }
    }

    /**
     * Records the tuple's actual membership for every term as an invalid
     * (prediction-less) example.
     *
     * @param tuple example for which no prediction is supplied
     */
    @Override
    public void addInvalid(DataTuple tuple) {
        ClassesTuple classes = (ClassesTuple)tuple.getObjVal(this.m_Hier.getType().getArrayIndex());
        boolean[] actual = classes.getVectorBooleanNodeAndAncestors(this.m_Hier);
        for (int i = 0; i < this.m_Dim; ++i) {
            this.m_ClassWisePredictions[i].addInvalid(actual[i]);
        }
    }

    /**
     * Decides by model name whether this measure is computed.
     *
     * @param name model name
     * @return {@code true} for "Original", "Pruned", "Default", for k-NN model names,
     *         and for "Forest with ..." names that do not contain {@code "T = "}
     */
    @Override
    public boolean isComputeForModel(String name) {
        if (name.equals("Original") || name.equals("Pruned") || name.equals("Default")) {
            return true;
        }
        boolean isKnnModel = (name.startsWith("Original ") && name.contains("-nn model with "))
                || name.equals(KnnClassifier.DEFAULT_MODEL_NAME_WITH_CONSTANT_WEIGHTS);
        if (isKnnModel) {
            return true;
        }
        return name.startsWith("Forest with ") && !name.contains("T = ");
    }

    /** @return {@code false}: larger values of this measure are better. */
    @Override
    public boolean shouldBeLow() {
        return false;
    }

    /**
     * Recomputes all summary values and returns the one selected by the optimize measure.
     *
     * @return the selected summary value, or {@code 0.0} if the optimize measure is
     *         none of the four handled constants
     */
    @Override
    public double getModelError() {
        this.computeAll();
        switch (this.m_OptimizeMeasure) {
            case AverageAUROC:
                return this.m_AverageAUROC;
            case AverageAUPRC:
                return this.m_AverageAUPRC;
            case WeightedAverageAUPRC:
                return this.m_WAvgAUPRC;
            case PooledAUPRC:
                return this.m_PooledAUPRC;
            default:
                return 0.0;
        }
    }

    /**
     * @param idx term index
     * @return {@code true} if the term is flagged for evaluation and has at least one
     *         positive example
     */
    public boolean isEvalClass(int idx) {
        return this.m_EvalClass[idx] && this.includeZeroFreqClasses(idx);
    }

    /** Clears every per-term prediction list. */
    @Override
    public void reset() {
        for (int i = 0; i < this.m_Dim; ++i) {
            this.m_ClassWisePredictions[i].clear();
        }
    }

    /**
     * Merges the per-term prediction lists of another measure into this one.
     *
     * @param other measure to merge; must be a {@code HierErrorMeasures}
     */
    @Override
    public void add(ClusError other) {
        BinaryPredictionList[] otherLists = ((HierErrorMeasures)other).m_ClassWisePredictions;
        for (int i = 0; i < this.m_Dim; ++i) {
            this.m_ClassWisePredictions[i].add(otherLists[i]);
        }
    }

    /**
     * Calls {@code copyActual} on every per-term list with the matching list of the
     * global measure.
     *
     * @param global global measure; must be a {@code HierErrorMeasures}
     */
    @Override
    public void updateFromGlobalMeasure(ClusError global) {
        BinaryPredictionList[] globalLists = ((HierErrorMeasures)global).m_ClassWisePredictions;
        for (int i = 0; i < this.m_Dim; ++i) {
            this.m_ClassWisePredictions[i].copyActual(globalLists[i]);
        }
    }

    /**
     * Prints the per-class line for {@code node} (if it is an evaluated class) and then
     * recurses into its children. Each term is visited at most once.
     *
     * @param fr      number format for the printed values
     * @param out     destination
     * @param node    term to print
     * @param printed per-term "already visited" flags, updated in place
     */
    public void printResultsRec(ClusNumberFormat fr, PrintWriter out, ClassTerm node, boolean[] printed) {
        int idx = node.getIndex();
        if (printed[idx]) {
            // A term can be reached more than once; print it only the first time.
            return;
        }
        printed[idx] = true;
        if (this.isEvalClass(idx)) {
            ROCAndPRCurve curve = this.m_ROCAndPRCurves[idx];
            ClassesValue val = new ClassesValue(node);
            out.print("      " + idx + ": " + val.toStringWithDepths(this.m_Hier));
            out.print(", AUROC: " + fr.format(curve.getAreaROC()));
            out.print(", AUPRC: " + fr.format(curve.getAreaPR()));
            out.print(", Freq: " + fr.format(this.m_ClassWisePredictions[idx].getFrequency()));
            if (this.m_RecallValues != null) {
                for (int i = 0; i < this.m_RecallValues.length; ++i) {
                    // Recall as a rounded percentage, used only in the label.
                    int rec = (int)Math.floor(100.0 * this.m_RecallValues[i] + 0.5);
                    out.print(", P" + rec + "R: " + fr.format(100.0 * curve.getPrecisionAtRecall(i)));
                }
            }
            out.println();
        }
        for (int i = 0; i < node.getNbChildren(); ++i) {
            this.printResultsRec(fr, out, (ClassTerm)node.getChild(i), printed);
        }
    }

    /**
     * Prints the per-class results for every term below the root of {@code hier}.
     *
     * @param fr   number format for the printed values
     * @param out  destination
     * @param hier hierarchy to walk
     */
    public void printResults(ClusNumberFormat fr, PrintWriter out, ClassHierarchy hier) {
        ClassTerm root = hier.getRoot();
        boolean[] printed = new boolean[hier.getTotal()];
        for (int i = 0; i < root.getNbChildren(); ++i) {
            this.printResultsRec(fr, out, (ClassTerm)root.getChild(i), printed);
        }
    }

    @Override
    public boolean isMultiLine() {
        return true;
    }

    /**
     * Sets the thresholds of every given curve, and of the pooled curve, to {@code null}.
     *
     * @param curves per-class curves
     * @param pooled pooled curve
     */
    public void compatibility(ROCAndPRCurve[] curves, ROCAndPRCurve pooled) {
        double[] thr = null;
        for (ROCAndPRCurve curve : curves) {
            curve.setThresholds(thr);
        }
        pooled.setThresholds(thr);
    }

    /**
     * @param idx term index
     * @return {@code true} if the term has at least one positive example
     */
    public boolean includeZeroFreqClasses(int idx) {
        return this.m_ClassWisePredictions[idx].getNbPos() > 0;
    }

    /**
     * Computes the curves for every evaluated class and for the pooled predictions,
     * writes the curve points to the open CSV writers (if any), and stores the summary
     * values in the {@code m_Average*}, {@code m_WAvgAUPRC}, {@code m_PooledAUPRC} and
     * {@code m_AvgPrecisionAtRecall} fields.
     * <p>
     * The averages are plain divisions by the number of evaluated classes (or by the
     * summed frequency), so they are NaN when no class is evaluated.
     * <p>
     * NOTE(review): this method calls {@code clearData()} on each evaluated prediction
     * list; whether a second invocation without new examples yields the same values
     * depends on {@link BinaryPredictionList} and is not visible here.
     */
    public void computeAll() {
        BinaryPredictionList pooled = new BinaryPredictionList();
        ROCAndPRCurve pooledCurve = new ROCAndPRCurve(pooled);
        this.compatibility(this.m_ROCAndPRCurves, pooledCurve);

        // Pass 1: per-class curves; predictions are pooled before their data is released.
        for (int i = 0; i < this.m_Dim; ++i) {
            if (!this.isEvalClass(i)) continue;
            BinaryPredictionList predictions = this.m_ClassWisePredictions[i];
            ROCAndPRCurve curve = this.m_ROCAndPRCurves[i];
            predictions.sort();
            curve.computeCurves();
            curve.computePrecisions(this.m_RecallValues);
            this.outputPRCurve(i, curve);
            this.outputROCCurve(i, curve);
            curve.clear();
            pooled.add(predictions);
            predictions.clearData();
        }

        pooled.sort();
        pooledCurve.computeCurves();
        // Class index -1 labels the pooled curve as "ALL" in the CSV output.
        this.outputPRCurve(-1, pooledCurve);
        this.outputROCCurve(-1, pooledCurve);
        pooledCurve.clear();

        // Pass 2: aggregate the per-class areas.
        int cnt = 0;
        double sumAUROC = 0.0;
        double sumAUPRC = 0.0;
        double sumAUPRCw = 0.0;
        double sumFrequency = 0.0;
        for (int i = 0; i < this.m_Dim; ++i) {
            if (!this.isEvalClass(i)) continue;
            double freq = this.m_ClassWisePredictions[i].getFrequency();
            double areaPR = this.m_ROCAndPRCurves[i].getAreaPR();
            sumAUROC += this.m_ROCAndPRCurves[i].getAreaROC();
            sumAUPRC += areaPR;
            sumAUPRCw += freq * areaPR;
            sumFrequency += freq;
            ++cnt;
        }
        this.m_AverageAUROC = sumAUROC / (double)cnt;
        this.m_AverageAUPRC = sumAUPRC / (double)cnt;
        this.m_WAvgAUPRC = sumAUPRCw / sumFrequency;
        this.m_PooledAUPRC = pooledCurve.getAreaPR();

        if (this.m_RecallValues != null) {
            int nbRecalls = this.m_RecallValues.length;
            this.m_AvgPrecisionAtRecall = new double[nbRecalls];
            for (int j = 0; j < nbRecalls; ++j) {
                int nbClass = 0;
                double sumPrecision = 0.0;
                for (int i = 0; i < this.m_Dim; ++i) {
                    if (!this.isEvalClass(i)) continue;
                    sumPrecision += this.m_ROCAndPRCurves[i].getPrecisionAtRecall(j);
                    ++nbClass;
                }
                this.m_AvgPrecisionAtRecall[j] = sumPrecision / (double)nbClass;
            }
        }
    }

    /**
     * Writes one CSV line {@code name,x,y} per curve point.
     *
     * @param ci     term index, or {@code -1} for the pooled curve (written as {@code ALL})
     * @param points curve points; every element must be a {@code double[]} with at least
     *               two entries
     * @param curves destination writer
     */
    public void ouputCurve(int ci, ArrayList<?> points, PrintWriter curves) {
        String clName = "ALL";
        if (ci != -1) {
            ClassTerm cl = this.m_Hier.getTermAt(ci);
            clName = "\"" + cl.toStringHuman(this.m_Hier) + "\"";
        }
        for (Object point : points) {
            double[] pt = (double[])point;
            curves.println(clName + "," + pt[0] + "," + pt[1]);
        }
    }

    /**
     * Writes the PR curve of the given class to the PR CSV writer, if one is open.
     *
     * @param i     term index, or {@code -1} for the pooled curve
     * @param curve curve whose PR points are written
     */
    public void outputPRCurve(int i, ROCAndPRCurve curve) {
        if (this.m_PRCurves != null) {
            ArrayList<double[]> points = curve.getPRCurve();
            this.ouputCurve(i, points, this.m_PRCurves);
        }
    }

    /**
     * Writes the ROC curve of the given class to the ROC CSV writer, if one is open.
     *
     * @param i     term index, or {@code -1} for the pooled curve
     * @param curve curve whose ROC points are written
     */
    public void outputROCCurve(int i, ROCAndPRCurve curve) {
        if (this.m_ROCCurves != null) {
            ArrayList<double[]> points = curve.getROCCurve();
            this.ouputCurve(i, points, this.m_ROCCurves);
        }
    }

    /**
     * Opens the PR-curve CSV file and writes its header line. A writer that is still
     * open from an earlier call is closed first.
     *
     * @param fname file name; {@code .gz} is appended when gzip output is enabled
     * @throws IOException if the file cannot be opened
     */
    public void writeCSVFilesPR(String fname) throws IOException {
        if (this.m_PRCurves != null) {
            this.m_PRCurves.close();
            this.m_PRCurves = null;
        }
        this.m_PRCurves = this.openCurveWriter(fname, "Class,Recall,Precision");
    }

    /**
     * Opens the ROC-curve CSV file and writes its header line. A writer that is still
     * open from an earlier call is closed first.
     *
     * @param fname file name; {@code .gz} is appended when gzip output is enabled
     * @throws IOException if the file cannot be opened
     */
    public void writeCSVFilesROC(String fname) throws IOException {
        if (this.m_ROCCurves != null) {
            this.m_ROCCurves.close();
            this.m_ROCCurves = null;
        }
        this.m_ROCCurves = this.openCurveWriter(fname, "Class,FP,TP");
    }

    /** Opens a (possibly gzip-compressed) CSV writer and writes the header line to it. */
    private PrintWriter openCurveWriter(String fname, String header) throws IOException {
        PrintWriter writer;
        if (this.m_IsGzipOutput) {
            FileOutputStream file = new FileOutputStream(fname + ".gz");
            boolean opened = false;
            try {
                writer = new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(file)));
                opened = true;
            } finally {
                if (!opened) {
                    // Wrapping failed, so nothing else owns the file stream yet.
                    file.close();
                }
            }
        } else {
            writer = new PrintWriter(fname);
        }
        writer.println(header);
        return writer;
    }

    /** Closes and forgets both curve writers, if open. */
    private void closeCurveWriters() {
        if (this.m_PRCurves != null) {
            this.m_PRCurves.close();
            this.m_PRCurves = null;
        }
        if (this.m_ROCCurves != null) {
            this.m_ROCCurves.close();
            this.m_ROCCurves = null;
        }
    }

    /**
     * Computes all measures and prints the summary (and, unless {@code detail == 1},
     * the per-class results). When curve writing is enabled and {@code bName} is not
     * {@code null}, the curves are written to {@code bName + ".pr.csv"} and
     * {@code bName + ".roc.csv"}. Any open curve writer is closed before this method
     * returns, also when it fails.
     *
     * @param out    destination for the textual report
     * @param bName  base name for the curve CSV files, or {@code null} to write none
     * @param detail detail level; {@code 1} suppresses the per-class results
     * @throws IOException if a curve CSV file cannot be opened
     */
    @Override
    public void showModelError(PrintWriter out, String bName, int detail) throws IOException {
        try {
            if (this.m_WriteCurves && bName != null) {
                this.writeCSVFilesPR(bName + ".pr.csv");
                this.writeCSVFilesROC(bName + ".roc.csv");
            }
            ClusNumberFormat fr1 = ClusFormat.SIX_AFTER_DOT;
            this.computeAll();
            out.println();
            out.println("      Average AUROC:            " + this.m_AverageAUROC);
            out.println("      Average AUPRC:            " + this.m_AverageAUPRC);
            out.println("      Average AUPRC (weighted): " + this.m_WAvgAUPRC);
            out.println("      Pooled AUPRC:             " + this.m_PooledAUPRC);
            if (this.m_RecallValues != null) {
                for (int i = 0; i < this.m_RecallValues.length; ++i) {
                    int rec = (int)Math.floor(100.0 * this.m_RecallValues[i] + 0.5);
                    out.println("      P" + rec + "R: " + 100.0 * this.m_AvgPrecisionAtRecall[i]);
                }
            }
            if (detail != 1) {
                this.printResults(fr1, out, this.m_Hier);
            }
        } finally {
            // Always release the files so buffered (and gzip trailer) data is flushed
            // and no handle leaks when opening, computing or printing fails.
            this.closeCurveWriters();
        }
    }

    @Override
    public String getName() {
        return "Hierarchical error measures";
    }

    /**
     * Creates a new measure with the same configuration and fresh, empty prediction lists.
     *
     * @param par parent error list of the new instance
     * @return the new measure
     */
    @Override
    public ClusError getErrorClone(ClusErrorList par) {
        return new HierErrorMeasures(par, this.m_Hier, this.m_RecallValues, this.m_OptimizeMeasure, this.m_WriteCurves, this.m_IsGzipOutput);
    }

    /**
     * Prints the report without opening any curve CSV file (the base name is {@code null}).
     *
     * @param out    destination for the textual report
     * @param detail detail level; {@code 1} suppresses the per-class results
     */
    @Override
    public void showModelError(PrintWriter out, int detail) {
        try {
            this.showModelError(out, null, detail);
        }
        catch (IOException e) {
            // This overload declares no checked exception, so the failure is only reported.
            e.printStackTrace();
        }
    }
}

