/*
 * Decompiled with CFR 0.152.
 */
package eval;

import data.instance.Instance;
import data.value.Value;
import eval.EvaluationStats;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.stream.Collectors;

public class ClassificationStats
extends EvaluationStats {
    private HashMap<Value, HashMap<Value, Double>> confusionMatrix;

    @Override
    public void computeStats(HashMap<Instance, Value> predictions) {
        this.confusionMatrix = new HashMap();
        HashSet<Value> classes = new HashSet<Value>();
        classes.addAll(predictions.keySet().stream().map(inst -> inst.getLabel()).collect(Collectors.toSet()));
        classes.addAll(predictions.values());
        HashMap<Value, Double> actualCount = new HashMap<Value, Double>();
        HashMap<Value, Double> predictedCount = new HashMap<Value, Double>();
        HashMap<Value, Double> tnr = new HashMap<Value, Double>();
        HashMap<Value, Double> tpr = new HashMap<Value, Double>();
        HashMap<Value, Double> fpr = new HashMap<Value, Double>();
        HashMap<Value, Double> fnr = new HashMap<Value, Double>();
        HashMap<Value, Double> npv = new HashMap<Value, Double>();
        HashMap<Value, Double> ppv = new HashMap<Value, Double>();
        HashMap<Value, Double> fomr = new HashMap<Value, Double>();
        HashMap<Value, Double> fdr = new HashMap<Value, Double>();
        HashMap<Value, Double> fmeasure = new HashMap<Value, Double>();
        HashMap<Value, Double> plr = new HashMap<Value, Double>();
        HashMap<Value, Double> nlr = new HashMap<Value, Double>();
        HashMap<Value, Double> dor = new HashMap<Value, Double>();
        HashMap<Value, Double> diagonal = new HashMap<Value, Double>();
        for (Value value : classes) {
            HashMap<Value, Double> toAdd = new HashMap<Value, Double>();
            for (Value val2 : classes) {
                toAdd.put(val2, 0.0);
            }
            this.confusionMatrix.put(value, toAdd);
            actualCount.put(value, 0.0);
            predictedCount.put(value, 0.0);
            tnr.put(value, 0.0);
            tpr.put(value, 0.0);
            fpr.put(value, 0.0);
            fnr.put(value, 0.0);
            npv.put(value, 0.0);
            ppv.put(value, 0.0);
            fomr.put(value, 0.0);
            fdr.put(value, 0.0);
            diagonal.put(value, 0.0);
        }
        for (Map.Entry entry : predictions.entrySet()) {
            this.confusionMatrix.get(entry.getValue()).put(((Instance)entry.getKey()).getLabel(), this.confusionMatrix.get(entry.getValue()).get(((Instance)entry.getKey()).getLabel()) + ((Instance)entry.getKey()).getWeight());
        }
        double d = 0.0;
        double sumDiag = 0.0;
        for (Value pred : this.confusionMatrix.keySet()) {
            for (Value actual : this.confusionMatrix.get(pred).keySet()) {
                double elem = this.confusionMatrix.get(pred).get(actual);
                actualCount.put(actual, (Double)actualCount.get(actual) + elem);
                predictedCount.put(pred, (Double)predictedCount.get(pred) + elem);
                d += elem;
                if (!pred.equals(actual)) continue;
                diagonal.put(pred, elem);
                sumDiag += elem;
            }
        }
        this.statistics.put("accuracy", sumDiag / d);
        this.statistics.put("stderr", sumDiag * (d - sumDiag) / (d * d * d));
        this.statistics.put("error-rate", (d - sumDiag) / d);
        this.statistics.put("correct", sumDiag);
        this.statistics.put("incorrect", d - sumDiag);
        this.statistics.put("total", d);
        double chanceAgreement = 0.0;
        for (Value val : this.confusionMatrix.keySet()) {
            chanceAgreement += (Double)actualCount.get(val) * (Double)predictedCount.get(val);
            tpr.put(val, (Double)diagonal.get(val) / (Double)actualCount.get(val));
            fpr.put(val, ((Double)predictedCount.get(val) - (Double)diagonal.get(val)) / (d - (Double)actualCount.get(val)));
            fnr.put(val, ((Double)actualCount.get(val) - (Double)diagonal.get(val)) / (Double)actualCount.get(val));
            tnr.put(val, (d + (Double)diagonal.get(val) - (Double)actualCount.get(val) - (Double)predictedCount.get(val)) / (d - (Double)actualCount.get(val)));
            ppv.put(val, (Double)diagonal.get(val) / (Double)predictedCount.get(val));
            fdr.put(val, ((Double)predictedCount.get(val) - (Double)diagonal.get(val)) / (Double)predictedCount.get(val));
            fomr.put(val, ((Double)actualCount.get(val) - (Double)diagonal.get(val)) / (d - (Double)predictedCount.get(val)));
            npv.put(val, (d + (Double)diagonal.get(val) - (Double)actualCount.get(val) - (Double)predictedCount.get(val)) / (d - (Double)predictedCount.get(val)));
            fmeasure.put(val, 2.0 * (Double)ppv.get(val) * (Double)tpr.get(val) / ((Double)ppv.get(val) + (Double)tpr.get(val)));
            plr.put(val, (Double)tpr.get(val) / (Double)fpr.get(val));
            nlr.put(val, (Double)fnr.get(val) / (Double)tnr.get(val));
            dor.put(val, (Double)plr.get(val) / (Double)nlr.get(val));
        }
        this.statistics.put("tpr", tpr);
        this.statistics.put("fpr", fpr);
        this.statistics.put("fnr", fnr);
        this.statistics.put("tnr", tnr);
        this.statistics.put("ppv", ppv);
        this.statistics.put("fdr", fdr);
        this.statistics.put("for", fomr);
        this.statistics.put("npv", npv);
        this.statistics.put("f-measure", fmeasure);
        this.statistics.put("plr", plr);
        this.statistics.put("nlr", nlr);
        this.statistics.put("dor", dor);
        this.statistics.put("kappa", (sumDiag / d - (chanceAgreement /= d * d)) / (1.0 - chanceAgreement));
    }

    public String printConfusionMatrix() {
        StringBuilder res = new StringBuilder();
        int maxDigits = 3;
        ArrayList<Value> classes = new ArrayList<Value>(this.confusionMatrix.keySet());
        for (Value val1 : this.confusionMatrix.keySet()) {
            if (val1.getStringValue().length() > maxDigits) {
                maxDigits = val1.getStringValue().length();
            }
            for (Value val2 : this.confusionMatrix.get(val1).keySet()) {
                if (Double.toString(this.confusionMatrix.get(val1).get(val2)).length() <= maxDigits) continue;
                maxDigits = Double.toString(this.confusionMatrix.get(val1).get(val2)).length();
            }
        }
        int i = 0;
        while (i < maxDigits - 3) {
            res.append(" ");
            ++i;
        }
        res.append("A\\P|");
        for (Value val : classes) {
            res.append(val.getStringValue());
            int i2 = val.getStringValue().length();
            while (i2 < maxDigits) {
                res.append(" ");
                ++i2;
            }
            res.append("|");
        }
        res.append("\n");
        int width = (maxDigits + 1) * (this.confusionMatrix.size() + 1);
        int i3 = 0;
        while (i3 < width) {
            res.append("-");
            ++i3;
        }
        res.append("\n");
        for (Value val2 : classes) {
            res.append(val2.getStringValue());
            int i4 = val2.getStringValue().length();
            while (i4 < maxDigits) {
                res.append(" ");
                ++i4;
            }
            res.append("|");
            for (Value val1 : classes) {
                Double elem = this.confusionMatrix.get(val1).get(val2);
                res.append(elem.toString());
                int i5 = elem.toString().length();
                while (i5 < maxDigits) {
                    res.append(" ");
                    ++i5;
                }
                res.append("|");
            }
            res.append("\n");
            i = 0;
            while (i < width) {
                res.append("-");
                ++i;
            }
            res.append("\n");
        }
        return res.toString();
    }

    public String toString() {
        StringBuilder res = new StringBuilder();
        res.append(this.printConfusionMatrix());
        res.append("\n");
        for (Map.Entry ent : this.statistics.entrySet()) {
            res.append(String.valueOf((String)ent.getKey()) + " : " + ent.getValue().toString() + "\n");
        }
        return res.toString();
    }
}

