/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.error.hmlc;

import java.io.PrintWriter;
import java.util.Arrays;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.error.common.ClusError;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.ext.hierarchical.ClassHierarchy;
import si.ijs.kt.clus.ext.hierarchical.ClassTerm;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.util.format.ClusNumberFormat;

/**
 * Hierarchical multi-label classification error measure that reports accuracy per level of a
 * {@link ClassHierarchy}, together with overall accuracy ("Acc"), recall ("Rec") and accuracy
 * over all examples ("AccAll").
 *
 * <p>NOTE(review): {@link #addExample(DataTuple, ClusStatistic)} and {@link #addInvalid(DataTuple)}
 * are no-ops in this class, so the accumulated counters only change through
 * {@link #add(ClusError)} or through direct writes to the protected fields (e.g. by a subclass).
 * Confirm this is intended.
 */
public class HierLevelAccuracy
extends ClusError {
    public static final long serialVersionUID = 1L;
    /** Class hierarchy whose levels are evaluated. */
    protected ClassHierarchy m_Hier;
    /** Per-level numerator of {@link #getErrorComp(int)}; length {@code m_Dim}. */
    protected double[] m_CorrectLevel;
    /** Per-level denominator of {@link #getErrorComp(int)}; length {@code m_Dim}. */
    protected double[] m_CountLevel;
    /** Accumulated correct count; numerator of {@link #getAccuracy()} and {@link #getOverallAccuracy()}. */
    protected double m_Correct;
    /** Accumulated predicted count; denominator of {@link #getAccuracy()}, numerator of {@link #getRecall()}. */
    protected double m_Predicted;
    /**
     * Per-class "actual label" flags indexed by {@code ClassTerm.getIndex()}, sized to the total
     * number of classes in the hierarchy. Read by {@link #update}; not filled by this class.
     */
    protected boolean[] m_ActualArr;
    /**
     * Per-depth flags set by {@link #update} when prediction and actual label disagree at that
     * depth. Never cleared by this class (not even by {@link #reset()}).
     */
    protected boolean[] m_PredLevelErr;
    /** Deepest depth seen by {@link #update} at which a class was predicted or actually present. */
    protected int m_MaxDepth;

    /**
     * Creates a per-level accuracy measure for the given hierarchy.
     *
     * @param par  the error list this measure belongs to
     * @param hier the class hierarchy; its maximum depth is passed to the superclass as the
     *             dimension, and the per-level arrays are sized by {@code m_Dim} (presumably set
     *             by the superclass from that argument — confirm in {@code ClusError})
     */
    public HierLevelAccuracy(ClusErrorList par, ClassHierarchy hier) {
        super(par, hier.getMaxDepth());
        this.m_Hier = hier;
        this.m_CorrectLevel = new double[this.m_Dim];
        this.m_CountLevel = new double[this.m_Dim];
        this.m_ActualArr = new boolean[hier.getTotal()];
        this.m_PredLevelErr = new boolean[this.m_Dim];
    }

    /**
     * Recursively walks the subtree rooted at {@code node}, comparing predicted and actual labels.
     * A class counts as predicted when its score in {@code predarr} is at least 0.5. For every
     * node where prediction and actual label differ, the flag for that node's depth in
     * {@link #m_PredLevelErr} is set; {@link #m_MaxDepth} tracks the deepest depth with a
     * predicted or actual class.
     *
     * @param node    root of the subtree to visit
     * @param depth   depth assigned to {@code node}; children are visited at {@code depth + 1}
     * @param predarr per-class prediction scores indexed by {@code ClassTerm.getIndex()}
     */
    public void update(ClassTerm node, int depth, double[] predarr) {
        boolean has_pred = predarr[node.getIndex()] >= 0.5;
        boolean has_actual = this.m_ActualArr[node.getIndex()];
        if ((has_pred || has_actual) && depth > this.m_MaxDepth) {
            this.m_MaxDepth = depth;
        }
        if (has_pred != has_actual) {
            this.m_PredLevelErr[depth] = true;
        }
        for (int i = 0; i < node.getNbChildren(); ++i) {
            this.update((ClassTerm)node.getChild(i), depth + 1, predarr);
        }
    }

    /** Intentionally a no-op: this measure does not accumulate anything per example here. */
    @Override
    public void addExample(DataTuple tuple, ClusStatistic pred) {
    }

    /** Intentionally a no-op: invalid examples are ignored by this measure. */
    @Override
    public void addInvalid(DataTuple tuple) {
    }

    /**
     * Returns {@code 1 - m_Correct / nbExamples}, or 0.0 when no examples have been counted.
     *
     * @return the model error derived from the overall correct count
     */
    @Override
    public double getModelError() {
        int nb = this.getNbExamples();
        return nb == 0 ? 0.0 : 1.0 - this.m_Correct / (double)nb;
    }

    /**
     * Returns the ratio of correct to counted predictions at hierarchy level {@code i}.
     *
     * <p>NOTE(review): despite the method name, the value is an accuracy (higher is better),
     * not an error.
     *
     * @param i hierarchy level, in {@code [0, m_Dim)}
     * @return {@code m_CorrectLevel[i] / m_CountLevel[i]}, or 0.0 when the level count is 0
     */
    public double getErrorComp(int i) {
        double nb = this.m_CountLevel[i];
        return nb == 0.0 ? 0.0 : this.m_CorrectLevel[i] / nb;
    }

    /**
     * @return {@code m_Correct / m_Predicted}, or 0.0 when nothing was predicted
     */
    public double getAccuracy() {
        return this.m_Predicted == 0.0 ? 0.0 : this.m_Correct / this.m_Predicted;
    }

    /**
     * @return {@code m_Predicted / nbExamples}, or 0.0 when no examples have been counted
     */
    public double getRecall() {
        int nb = this.getNbExamples();
        return nb == 0 ? 0.0 : this.m_Predicted / (double)nb;
    }

    /**
     * @return {@code m_Correct / nbExamples}, or 0.0 when no examples have been counted
     */
    public double getOverallAccuracy() {
        int nb = this.getNbExamples();
        return nb == 0 ? 0.0 : this.m_Correct / (double)nb;
    }

    @Override
    public String getName() {
        return "Hierarchical accuracy by level";
    }

    /** Returns a fresh, empty measure over the same hierarchy, attached to {@code par}. */
    @Override
    public ClusError getErrorClone(ClusErrorList par) {
        return new HierLevelAccuracy(par, this.m_Hier);
    }

    /**
     * Clears the accumulated counters. The scratch state used by {@link #update}
     * ({@code m_ActualArr}, {@code m_PredLevelErr}, {@code m_MaxDepth}) is left untouched.
     */
    @Override
    public void reset() {
        Arrays.fill(this.m_CorrectLevel, 0.0);
        Arrays.fill(this.m_CountLevel, 0.0);
        this.m_Correct = 0.0;
        this.m_Predicted = 0.0;
    }

    /**
     * Adds the counters of {@code other} into this measure.
     *
     * @param other must be a {@code HierLevelAccuracy}, presumably over a hierarchy with the same
     *              depth (its per-level arrays are read up to this instance's {@code m_Dim})
     * @throws ClassCastException if {@code other} is not a {@code HierLevelAccuracy}
     */
    @Override
    public void add(ClusError other) {
        HierLevelAccuracy acc = (HierLevelAccuracy)other;
        this.m_Correct += acc.m_Correct;
        this.m_Predicted += acc.m_Predicted;
        for (int i = 0; i < this.m_Dim; ++i) {
            this.m_CorrectLevel[i] += acc.m_CorrectLevel[i];
            this.m_CountLevel[i] += acc.m_CountLevel[i];
        }
    }

    /**
     * Prints one line of the form {@code [l0,l1,...], Acc: a, Rec: r, AccAll: o}, where the
     * bracketed list holds {@link #getErrorComp(int)} for every level.
     *
     * @param out    destination writer
     * @param detail unused by this measure
     */
    @Override
    public void showModelError(PrintWriter out, int detail) {
        ClusNumberFormat fr = this.getFormat();
        // Method-local buffer: no synchronization needed, so StringBuilder over StringBuffer.
        StringBuilder buf = new StringBuilder();
        buf.append("[");
        for (int i = 0; i < this.m_Dim; ++i) {
            if (i != 0) {
                buf.append(",");
            }
            buf.append(fr.format(this.getErrorComp(i)));
        }
        buf.append("]");
        buf.append(", Acc: ");
        buf.append(fr.format(this.getAccuracy()));
        buf.append(", Rec: ");
        buf.append(fr.format(this.getRecall()));
        buf.append(", AccAll: ");
        buf.append(fr.format(this.getOverallAccuracy()));
        out.println(buf.toString());
    }

    /** Returns false, presumably meaning higher values of this measure are better. */
    @Override
    public boolean shouldBeLow() {
        return false;
    }
}

