/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.error.hmlc;

import java.io.PrintWriter;
import java.util.Arrays;
import si.ijs.kt.clus.algo.kNN.KnnClassifier;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.error.common.ClusError;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.ext.hierarchical.ClassHierarchy;
import si.ijs.kt.clus.ext.hierarchical.ClassTerm;
import si.ijs.kt.clus.ext.hierarchical.ClassesTuple;
import si.ijs.kt.clus.ext.hierarchical.ClassesValue;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.statistic.WHTDStatistic;
import si.ijs.kt.clus.util.format.ClusFormat;
import si.ijs.kt.clus.util.format.ClusNumberFormat;

/**
 * Error measure that keeps, for every class of a {@link ClassHierarchy}, the number of
 * positive predictions, the number of true positives and the number of actual positives,
 * and derives precision / recall style figures from those counters.
 *
 * <p>All three counter arrays have {@code m_Dim} entries (the value passed to the super
 * constructor, i.e. {@code hier.getTotal()}) and are indexed by class index. Aggregated
 * figures only take into account the classes for which {@link #isEvalClass(int)} is true.
 */
public class HierClassWiseAccuracy extends ClusError {
    public static final long serialVersionUID = 1L;

    /**
     * When true, classes whose actual-positive count equals the total example count are
     * excluded from evaluation. Nothing in this class ever changes it from {@code false}.
     */
    protected boolean m_NoTrivialClasses = false;
    /** Hierarchy whose classes are being evaluated. */
    protected ClassHierarchy m_Hier;
    /** Per class: how many times the class was predicted. */
    protected double[] m_NbPosPredictions;
    /** Per class: how many times the class was predicted and actually present. */
    protected double[] m_TP;
    /** Per class: how many times the class was actually present. */
    protected double[] m_NbPosActual;
    /** Per class: whether the class takes part in the evaluation (as reported by the hierarchy). */
    protected boolean[] m_EvalClass;
    /** Object-array index inside a {@link DataTuple} at which the class labels are stored. */
    private int hierarchyIndex = 0;

    /**
     * Creates an empty measure for the given hierarchy.
     *
     * @param par  the error list this measure belongs to
     * @param hier the class hierarchy to evaluate; also determines the counter dimension
     */
    public HierClassWiseAccuracy(ClusErrorList par, ClassHierarchy hier) {
        super(par, hier.getTotal());
        this.m_Hier = hier;
        this.m_EvalClass = hier.getEvalClassesVector();
        this.m_TP = new double[this.m_Dim];
        this.m_NbPosPredictions = new double[this.m_Dim];
        this.m_NbPosActual = new double[this.m_Dim];
        this.hierarchyIndex = hier.getType().getArrayIndex();
    }

    /**
     * Accounts for one example: every predicted class increases the positive-prediction
     * counter, and additionally the true-positive counter when the tuple really has that
     * class. The tuple's actual classes are then added to the actual-positive counters.
     *
     * @param tuple the example; its label set is read at {@code hierarchyIndex}
     * @param pred  the prediction; must be a {@link WHTDStatistic}
     */
    @Override
    public void addExample(DataTuple tuple, ClusStatistic pred) {
        ClassesTuple actual = (ClassesTuple)tuple.getObjVal(this.hierarchyIndex);
        boolean[] predicted = ((WHTDStatistic)pred).getDiscretePred();
        for (int i = 0; i < this.m_Dim; ++i) {
            if (predicted[i]) {
                this.m_NbPosPredictions[i] += 1.0;
                if (actual.hasClass(i)) {
                    this.m_TP[i] += 1.0;
                }
            }
        }
        // NOTE(review): presumably adds the given weight to the entry of each class the
        // tuple belongs to - confirm against ClassesTuple.updateDistribution.
        actual.updateDistribution(this.m_NbPosActual, 1.0);
    }

    /**
     * Adds the tuple's actual classes to the actual-positive counters without touching the
     * prediction counters.
     *
     * <p>The label set is read at {@code hierarchyIndex}, exactly as in
     * {@link #addExample(DataTuple, ClusStatistic)}; a hard-coded index of 0 would read the
     * wrong object whenever the hierarchy attribute is not the first object attribute.
     *
     * @param tuple the example whose labels are counted
     */
    @Override
    public void addInvalid(DataTuple tuple) {
        ClassesTuple actual = (ClassesTuple)tuple.getObjVal(this.hierarchyIndex);
        actual.updateDistribution(this.m_NbPosActual, 1.0);
    }

    /**
     * Tells whether this measure is computed for the model with the given name.
     *
     * @param name the model name
     * @return {@code false} for the models named exactly "Default" or "Original",
     *         {@code true} for every other name
     */
    @Override
    public boolean isComputeForModel(String name) {
        // Any other name, including names that merely start with "Original", is accepted.
        return !name.equals("Default") && !name.equals("Original");
    }

    /**
     * @return whether trivial classes are excluded from evaluation
     */
    public boolean isNoTrivialClasses() {
        return this.m_NoTrivialClasses;
    }

    /**
     * Tells whether the class with the given index contributes to the aggregated figures.
     *
     * @param idx the class index
     * @return {@code false} when trivial classes are excluded and the class's actual-positive
     *         count equals {@code getNbTotal()}; otherwise the hierarchy's evaluation flag
     */
    public boolean isEvalClass(int idx) {
        boolean trivial = this.m_NbPosActual[idx] == (double)this.getNbTotal();
        if (this.isNoTrivialClasses() && trivial) {
            return false;
        }
        return this.m_EvalClass[idx];
    }

    /**
     * @return true positives divided by positive predictions over all evaluated classes,
     *         or 0 when there are no positive predictions
     */
    public double getPrecision() {
        double truePositives = this.getTP();
        double predicted = this.getSumNbPosPredicted();
        if (predicted == 0.0) {
            return 0.0;
        }
        return truePositives / predicted;
    }

    /**
     * @return true positives divided by actual positives over all evaluated classes,
     *         or 0 when there are no actual positives
     */
    public double getRecall() {
        double truePositives = this.getTP();
        double actual = this.getSumNbPosActual();
        if (actual == 0.0) {
            return 0.0;
        }
        return truePositives / actual;
    }

    /**
     * @return the true-positive count summed over all evaluated classes
     */
    public int getTP() {
        return this.sumEvaluatedCounts(this.m_TP);
    }

    /**
     * @return positive predictions minus true positives, over all evaluated classes
     */
    public int getFP() {
        return this.getSumNbPosPredicted() - this.getTP();
    }

    /**
     * @return actual positives minus true positives, over all evaluated classes
     */
    public int getFN() {
        return this.getSumNbPosActual() - this.getTP();
    }

    /**
     * @return the actual-positive count summed over all evaluated classes
     */
    public int getSumNbPosActual() {
        return this.sumEvaluatedCounts(this.m_NbPosActual);
    }

    /**
     * @return the positive-prediction count summed over all evaluated classes
     */
    public int getSumNbPosPredicted() {
        return this.sumEvaluatedCounts(this.m_NbPosPredictions);
    }

    /**
     * @return {@code getTP() + getFN()}
     */
    public int getNbPosExamplesCheck() {
        return this.getTP() + this.getFN();
    }

    /**
     * Averages the per-class precision over the evaluated classes that were predicted at
     * least once.
     *
     * @return the mean per-class precision, or 0 when no class qualifies
     */
    public double getMacroAvgPrecision() {
        int nbClasses = 0;
        double precisionSum = 0.0;
        for (int i = 0; i < this.m_Dim; ++i) {
            if (this.m_NbPosPredictions[i] != 0.0 && this.isEvalClass(i)) {
                precisionSum += this.m_TP[i] / this.m_NbPosPredictions[i];
                ++nbClasses;
            }
        }
        if (nbClasses == 0) {
            return 0.0;
        }
        return precisionSum / (double)nbClasses;
    }

    /** Sets all three counter arrays back to zero. */
    @Override
    public void reset() {
        Arrays.fill(this.m_NbPosActual, 0.0);
        Arrays.fill(this.m_NbPosPredictions, 0.0);
        Arrays.fill(this.m_TP, 0.0);
    }

    /**
     * Adds the counters of another measure to this one, class by class.
     *
     * @param other the measure to add; must be a {@code HierClassWiseAccuracy}
     */
    @Override
    public void add(ClusError other) {
        HierClassWiseAccuracy that = (HierClassWiseAccuracy)other;
        for (int i = 0; i < this.m_Dim; ++i) {
            this.m_TP[i] += that.m_TP[i];
            this.m_NbPosPredictions[i] += that.m_NbPosPredictions[i];
            this.m_NbPosActual[i] += that.m_NbPosActual[i];
        }
    }

    /**
     * Overwrites this measure's actual-positive counters with those of the given measure;
     * the other counters are left alone.
     *
     * @param global the measure to copy from; must be a {@code HierClassWiseAccuracy}
     */
    @Override
    public void updateFromGlobalMeasure(ClusError global) {
        HierClassWiseAccuracy source = (HierClassWiseAccuracy)global;
        System.arraycopy(source.m_NbPosActual, 0, this.m_NbPosActual, 0, this.m_NbPosActual.length);
    }

    /**
     * Prints one line of statistics for {@code node} if it is an evaluated class that was
     * predicted at least once, then recurses into its children. Each class index is handled
     * at most once per traversal.
     *
     * @param fr      number format for the printed figures
     * @param out     destination writer
     * @param node    the term to start from
     * @param printed per-class flags marking terms already visited; updated in place
     */
    public void printNonZeroAccuraciesRec(ClusNumberFormat fr, PrintWriter out, ClassTerm node, boolean[] printed) {
        int idx = node.getIndex();
        if (printed[idx]) {
            // Already visited in this traversal.
            return;
        }
        printed[idx] = true;
        if (this.m_NbPosPredictions[idx] != 0.0 && this.isEvalClass(idx)) {
            int total = this.getNbTotal();
            double def = total == 0 ? 0.0 : this.m_NbPosActual[idx] / (double)total;
            // The guard above guarantees a non-zero denominator here.
            double prec = this.m_TP[idx] / this.m_NbPosPredictions[idx];
            double rec = this.m_NbPosActual[idx] == 0.0 ? 0.0 : this.m_TP[idx] / this.m_NbPosActual[idx];
            int truePos = (int)this.m_TP[idx];
            int falsePos = (int)(this.m_NbPosPredictions[idx] - this.m_TP[idx]);
            int nbPos = (int)this.m_NbPosActual[idx];
            ClassesValue val = new ClassesValue(node);
            out.print("      " + val.toStringWithDepths(this.m_Hier));
            out.print(", def: " + fr.format(def));
            out.print(", prec: " + fr.format(prec));
            out.print(", rec: " + fr.format(rec));
            out.print(", TP: " + fr.format(truePos) + ", FP: " + fr.format(falsePos) + ", nbPos: " + fr.format(nbPos));
            out.println();
        }
        int nbChildren = node.getNbChildren();
        for (int i = 0; i < nbChildren; ++i) {
            this.printNonZeroAccuraciesRec(fr, out, (ClassTerm)node.getChild(i), printed);
        }
    }

    /**
     * Prints the per-class statistics for every term below the root of {@code hier}.
     *
     * @param fr   number format for the printed figures
     * @param out  destination writer
     * @param hier the hierarchy to traverse
     */
    public void printNonZeroAccuracies(ClusNumberFormat fr, PrintWriter out, ClassHierarchy hier) {
        boolean[] printed = new boolean[hier.getTotal()];
        ClassTerm root = hier.getRoot();
        int nbChildren = root.getNbChildren();
        for (int i = 0; i < nbChildren; ++i) {
            this.printNonZeroAccuraciesRec(fr, out, (ClassTerm)root.getChild(i), printed);
        }
    }

    /**
     * Prints the aggregated figures on one line, followed by the per-class lines.
     *
     * @param out    destination writer
     * @param detail not used by this implementation
     */
    @Override
    public void showModelError(PrintWriter out, int detail) {
        ClusNumberFormat perClassFormat = this.getFormat();
        ClusNumberFormat summaryFormat = ClusFormat.SIX_AFTER_DOT;
        out.print("precision: " + summaryFormat.format(this.getPrecision()));
        out.print(", recall: " + summaryFormat.format(this.getRecall()));
        out.print(", coverage: " + summaryFormat.format(this.getCoverage()));
        out.print(", TP: " + this.getTP() + ", FP: " + this.getFP() + ", nbPos: " + this.getSumNbPosActual());
        out.println();
        this.printNonZeroAccuracies(perClassFormat, out, this.m_Hier);
    }

    @Override
    public String getName() {
        return "Hierarchical accuracy by class";
    }

    /**
     * Creates a fresh, zeroed measure over the same hierarchy.
     *
     * @param par the error list the new measure belongs to
     * @return a new {@code HierClassWiseAccuracy}
     */
    @Override
    public ClusError getErrorClone(ClusErrorList par) {
        return new HierClassWiseAccuracy(par, this.m_Hier);
    }

    /**
     * Accounts for a single (class, prediction, truth) observation.
     *
     * @param cls                 the class index
     * @param predicted_class     whether the class was predicted
     * @param actually_has_class  whether the example really has the class
     */
    public void nextPrediction(int cls, boolean predicted_class, boolean actually_has_class) {
        if (predicted_class) {
            this.m_NbPosPredictions[cls] += 1.0;
        }
        if (predicted_class && actually_has_class) {
            this.m_TP[cls] += 1.0;
        }
        if (actually_has_class) {
            this.m_NbPosActual[cls] += 1.0;
        }
    }

    @Override
    public boolean shouldBeLow() {
        return false;
    }

    /**
     * Sums the entries of a per-class counter array over the evaluated classes.
     * The running total is an int, so each addition is truncated toward zero.
     */
    private int sumEvaluatedCounts(double[] counts) {
        int total = 0;
        for (int i = 0; i < this.m_Dim; ++i) {
            if (this.isEvalClass(i)) {
                // Compound assignment performs the implicit (int) narrowing cast.
                total += counts[i];
            }
        }
        return total;
    }
}

