/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.distance.primitive.relief;

import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.primitive.NominalAttrType;
import si.ijs.kt.clus.main.settings.section.SettingsRelief;

/**
 * Distance between two tuples, computed over a fixed set of nominal label attributes.
 *
 * <p>The concrete measure is chosen by a {@link SettingsRelief.MultilabelDistance} value passed to
 * the constructor. Missing label values are not skipped; they are handled through the per-label
 * probabilities supplied at construction time, which this class treats as the probability that a
 * missing value equals {@link #m_RelevantLabelIndex}.
 */
public class MultiLabelDistance {
    /** Values below this threshold are treated as zero (denominators and probability factors). */
    private static final double EPSILON = 1.0E-9;

    /**
     * Index of the nominal value that is interpreted as "label is relevant". Kept public, static
     * and mutable for backward compatibility with existing callers.
     */
    public static int m_RelevantLabelIndex = 0;

    /** Which of the supported measures {@link #calculateDist} dispatches to. */
    private final SettingsRelief.MultilabelDistance m_DistanceType;
    /** Label attributes, in the order given to the constructor. */
    private final NominalAttrType[] m_Labels;
    /** Per-label probability used in place of a missing value; indexed like {@link #m_Labels}. */
    private final double[] m_LabelProbabilities;

    /**
     * Creates a distance over the given label attributes.
     *
     * @param type the measure to use in {@link #calculateDist}
     * @param attrs the label attributes; every element is cast to {@link NominalAttrType}
     * @param labelProbabilities per-label probabilities, indexed like {@code attrs}; the array is
     *     stored by reference, not copied
     * @throws ClassCastException if an element of {@code attrs} is not a {@link NominalAttrType}
     */
    public MultiLabelDistance(SettingsRelief.MultilabelDistance type, ClusAttrType[] attrs, double[] labelProbabilities) {
        this.m_DistanceType = type;
        this.m_Labels = new NominalAttrType[attrs.length];
        for (int i = 0; i < attrs.length; ++i) {
            this.m_Labels[i] = (NominalAttrType)attrs[i];
        }
        this.m_LabelProbabilities = labelProbabilities;
    }

    /**
     * Computes the distance between two tuples using the measure chosen at construction time.
     *
     * @param t1 first tuple
     * @param t2 second tuple
     * @return the distance according to the configured measure
     * @throws RuntimeException if the configured measure is not one of the handled constants
     */
    public double calculateDist(DataTuple t1, DataTuple t2) {
        switch (this.m_DistanceType) {
            case HammingLoss: {
                return this.hamming_loss(t1, t2);
            }
            case MLAccuracy: {
                return this.mlAccuracy(t1, t2);
            }
            case MLFOne: {
                return this.F1(t1, t2);
            }
            case SubsetAccuracy: {
                return this.subsetAccuracy(t1, t2);
            }
            default: {
                throw new RuntimeException("Unknown distance type");
            }
        }
    }

    /**
     * Returns the (possibly fractional) membership of a label in a tuple's label set.
     *
     * <p>A missing value yields the stored probability for that label. A known value yields
     * {@code 1.0} when it equals {@link #m_RelevantLabelIndex} and {@code 0.0} otherwise, so this
     * helper agrees with {@link #areValuesEqualProbability} about which value is the relevant one.
     *
     * @param attrInd index of the label in {@link #m_Labels} / {@link #m_LabelProbabilities}
     * @param attr the label attribute
     * @param t the tuple to read the label from
     * @return membership value as described above
     */
    private double correctedAttributeValue(int attrInd, NominalAttrType attr, DataTuple t) {
        if (attr.isMissing(t)) {
            return this.m_LabelProbabilities[attrInd];
        }
        return attr.getNominal(t) == m_RelevantLabelIndex ? 1.0 : 0.0;
    }

    /**
     * Returns the probability that two tuples carry the same value for one label.
     *
     * <p>With {@code p} the stored probability for the label and {@code q = 1 - p}:
     * both missing gives {@code p*p + q*q}; exactly one missing gives {@code p} if the known value
     * equals {@link #m_RelevantLabelIndex} and {@code q} otherwise; both known gives {@code 1.0}
     * for equal values and {@code 0.0} for different ones.
     *
     * @param attrInd index of the label in {@link #m_LabelProbabilities}
     * @param attr the label attribute
     * @param t1 first tuple
     * @param t2 second tuple
     * @return probability of equality as described above
     */
    private double areValuesEqualProbability(int attrInd, NominalAttrType attr, DataTuple t1, DataTuple t2) {
        double p = this.m_LabelProbabilities[attrInd];
        double q = 1.0 - p;
        boolean firstMissing = attr.isMissing(t1);
        boolean secondMissing = attr.isMissing(t2);
        if (firstMissing && secondMissing) {
            return p * p + q * q;
        }
        if (firstMissing) {
            return attr.getNominal(t2) == m_RelevantLabelIndex ? p : q;
        }
        if (secondMissing) {
            return attr.getNominal(t1) == m_RelevantLabelIndex ? p : q;
        }
        return attr.getNominal(t1) == attr.getNominal(t2) ? 1.0 : 0.0;
    }

    /**
     * Average, over all labels, of the probability that the two tuples differ on that label.
     *
     * @param t1 first tuple
     * @param t2 second tuple
     * @return the averaged mismatch probability; {@code 0.0} when there are no labels
     */
    private double hamming_loss(DataTuple t1, DataTuple t2) {
        int labelCount = this.m_Labels.length;
        if (labelCount == 0) {
            // Without this guard the result would be 0.0 / 0.0 = NaN, while every other
            // measure in this class yields 0.0 for an empty label set.
            return 0.0;
        }
        double dist = 0.0;
        for (int i = 0; i < labelCount; ++i) {
            dist += 1.0 - this.areValuesEqualProbability(i, this.m_Labels[i], t1, t2);
        }
        return dist / (double)labelCount;
    }

    /**
     * One minus the ratio of the accumulated "intersection" to the accumulated "union" of the two
     * tuples' label memberships.
     *
     * <p>Per label, with memberships {@code v1} and {@code v2}, the intersection term is
     * {@code v1 * v2} and the union term is {@code v1 + v2 - v1 * v2}. When the accumulated union
     * is not above {@link #EPSILON} the similarity is taken to be {@code 1.0}.
     *
     * @param t1 first tuple
     * @param t2 second tuple
     * @return {@code 1 - intersection / union}
     */
    private double mlAccuracy(DataTuple t1, DataTuple t2) {
        double cap = 0.0;
        double cup = 0.0;
        for (int i = 0; i < this.m_Labels.length; ++i) {
            NominalAttrType attr = this.m_Labels[i];
            double value1 = this.correctedAttributeValue(i, attr, t1);
            double value2 = this.correctedAttributeValue(i, attr, t2);
            double both = value1 * value2;
            cap += both;
            cup += value1 + value2 - both;
        }
        double similarity = cup > EPSILON ? cap / cup : 1.0;
        return 1.0 - similarity;
    }

    /**
     * One minus {@code 2 * intersection / (size1 + size2)}, where the intersection accumulates
     * {@code v1 * v2} per label and the sizes accumulate {@code v1} and {@code v2} respectively.
     *
     * <p>When the summed sizes are not above {@link #EPSILON} the similarity is taken to be
     * {@code 1.0}.
     *
     * @param t1 first tuple
     * @param t2 second tuple
     * @return {@code 1 - 2 * intersection / (size1 + size2)}
     */
    private double F1(DataTuple t1, DataTuple t2) {
        double cap = 0.0;
        double first = 0.0;
        double second = 0.0;
        for (int i = 0; i < this.m_Labels.length; ++i) {
            NominalAttrType attr = this.m_Labels[i];
            double value1 = this.correctedAttributeValue(i, attr, t1);
            double value2 = this.correctedAttributeValue(i, attr, t2);
            cap += value1 * value2;
            first += value1;
            second += value2;
        }
        double total = first + second;
        double similarity = total > EPSILON ? 2.0 * cap / total : 1.0;
        return 1.0 - similarity;
    }

    /**
     * One minus the product, over all labels, of the per-label equality probabilities.
     *
     * <p>The product is short-circuited to {@code 0.0} as soon as one factor falls below
     * {@link #EPSILON}.
     *
     * @param t1 first tuple
     * @param t2 second tuple
     * @return {@code 1 - product of per-label equality probabilities}
     */
    public double subsetAccuracy(DataTuple t1, DataTuple t2) {
        double pEqualSets = 1.0;
        for (int i = 0; i < this.m_Labels.length; ++i) {
            double factor = this.areValuesEqualProbability(i, this.m_Labels[i], t1, t2);
            if (factor < EPSILON) {
                pEqualSets = 0.0;
                break;
            }
            pEqualSets *= factor;
        }
        return 1.0 - pEqualSets;
    }

    /**
     * Returns {@code "<ClassName>(<distance type>)"}, where the class name is the runtime class
     * name with its package prefix removed.
     */
    @Override
    public String toString() {
        // getName() carries no "class " prefix, so cutting after the last '.' is reliable.
        String fullClassName = this.getClass().getName();
        int lastIndex = fullClassName.lastIndexOf('.') + 1;
        return String.format("%s(%s)", fullClassName.substring(lastIndex), this.m_DistanceType.toString());
    }

    /**
     * Returns the string form of the configured distance type.
     *
     * @return {@code toString()} of the distance type given to the constructor
     */
    public String distanceName() {
        return this.m_DistanceType.toString();
    }
}

