/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.error;

import java.io.PrintWriter;
import si.ijs.kt.clus.data.attweights.ClusAttributeWeights;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.type.primitive.NumericAttrType;
import si.ijs.kt.clus.error.common.ClusError;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.error.common.ClusNumericError;
import si.ijs.kt.clus.error.common.ComponentError;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.statistic.RegressionStat;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.format.ClusNumberFormat;

/**
 * Mean squared error (MSE) over a set of numeric attributes.
 *
 * <p>Squared errors are accumulated separately per component (attribute). For a given example,
 * a component whose squared error is infinite or NaN (for example when either value is NaN) is
 * skipped and does not count towards that component's number of examples. When attribute
 * weights are set, each component's error is scaled by its weight.
 */
public class MSError extends ClusNumericError implements ComponentError {

    public static final long serialVersionUID = 1L;

    /** Number of examples with a finite squared error, per component. */
    protected int[] m_nbEx;

    /** Sum of squared errors, per component. */
    protected double[] m_SumErr;

    /** Sum of the squares of the squared errors, per component (used for the standard error). */
    protected double[] m_SumSqErr;

    /** Optional per-attribute weights; {@code null} means every component has weight 1. */
    protected ClusAttributeWeights m_Weights;

    /** Whether {@link #showModelError} lists the individual per-component errors. */
    protected boolean m_PrintAllComps;

    /**
     * Sum of the true values, per component. Only updated by the {@code addExample} overloads
     * when {@code isRelative} is {@code true}; not read in this class (presumably used by
     * subclasses — verify).
     */
    protected double[] m_SumTrueValues;

    /** Sum of the squared true values, per component; updated under the same rule as the sums. */
    protected double[] m_SumSquaredTrueValues;

    public MSError(ClusErrorList par, NumericAttrType[] num) {
        this(par, num, null, true, "");
    }

    public MSError(ClusErrorList par, NumericAttrType[] num, String info) {
        this(par, num, null, true, info);
    }

    public MSError(ClusErrorList par, NumericAttrType[] num, ClusAttributeWeights weights) {
        this(par, num, weights, true, "");
    }

    public MSError(ClusErrorList par, NumericAttrType[] num, ClusAttributeWeights weights, String info) {
        this(par, num, weights, true, info);
    }

    public MSError(ClusErrorList par, NumericAttrType[] num, ClusAttributeWeights weights, boolean printall) {
        this(par, num, weights, printall, "");
    }

    /**
     * Creates an MSE measure.
     *
     * @param par parent error list, passed to the superclass
     * @param num numeric attributes being evaluated, passed to the superclass
     * @param weights optional attribute weights; may be {@code null}
     * @param printall whether {@link #showModelError} lists every component
     * @param info additional text appended to the name by {@link #getName()}
     */
    public MSError(ClusErrorList par, NumericAttrType[] num, ClusAttributeWeights weights, boolean printall, String info) {
        super(par, num);
        m_nbEx = new int[m_Dim];
        m_SumErr = new double[m_Dim];
        m_SumSqErr = new double[m_Dim];
        m_Weights = weights;
        m_PrintAllComps = printall;
        m_SumTrueValues = new double[m_Dim];
        m_SumSquaredTrueValues = new double[m_Dim];
        setAdditionalInfo(info);
    }

    /** Clears all accumulated sums and counts. */
    @Override
    public void reset() {
        for (int i = 0; i < m_Dim; ++i) {
            m_SumErr[i] = 0.0;
            m_SumSqErr[i] = 0.0;
            m_nbEx[i] = 0;
            m_SumTrueValues[i] = 0.0;
            m_SumSquaredTrueValues[i] = 0.0;
        }
    }

    @Override
    public void setWeights(ClusAttributeWeights weights) {
        m_Weights = weights;
    }

    /** Returns the weight of component {@code i}, or 1.0 when no weights are set. */
    private double componentWeight(int i) {
        return m_Weights != null ? m_Weights.getWeight(getAttr(i)) : 1.0;
    }

    /**
     * Returns the (weighted) mean squared error of component {@code i}.
     *
     * @param i component index
     * @return sum of squared errors divided by the example count, times the weight; 0 if no examples
     */
    @Override
    public double getModelErrorComponent(int i) {
        int nb = m_nbEx[i];
        double err = nb != 0 ? m_SumErr[i] / nb : 0.0;
        // Multiplying by 1.0 (no weights) is exact, so this matches the unweighted value.
        return err * componentWeight(i);
    }

    /**
     * Returns the (weighted) sum of squared errors over all components divided by the total
     * number of counted component values, or 0 when nothing was counted.
     */
    @Override
    public double getModelError() {
        double ssTree = 0.0;
        int nb = 0;
        for (int i = 0; i < m_Dim; ++i) {
            nb += m_nbEx[i];
            ssTree += m_SumErr[i] * componentWeight(i);
        }
        return nb != 0 ? ssTree / nb : 0.0;
    }

    /**
     * Returns the standard error of the mean (weighted) squared error. Each counted component
     * value is treated as one observation.
     *
     * @return the standard error, or {@link Double#POSITIVE_INFINITY} when fewer than two
     *     observations were counted
     */
    @Override
    public double getModelErrorStandardError() {
        double sumErr = 0.0;
        double sumSqErr = 0.0;
        double n = 0.0;
        for (int i = 0; i < m_Dim; ++i) {
            // Observation x = w * err, so sum(x) = w * SumErr and sum(x^2) = w^2 * SumSqErr.
            double w = componentWeight(i);
            sumErr += m_SumErr[i] * w;
            sumSqErr += m_SumSqErr[i] * sqr(w);
            n += m_nbEx[i];
        }
        if (n <= 1.0) {
            return Double.POSITIVE_INFINITY;
        }
        // Single-pass sample variance. Clamp at zero: cancellation can make it slightly negative.
        double variance = Math.max(0.0, (n * sumSqErr - sqr(sumErr)) / (n * (n - 1.0)));
        return Math.sqrt(variance / n);
    }

    /** Returns {@code value * value}. */
    public static final double sqr(double value) {
        return value * value;
    }

    /**
     * Adds one real/predicted pair for component {@code i}. A non-finite squared error is
     * ignored. When {@code isRelative} is set, the true value is also accumulated.
     */
    private void addComponentError(int i, double real, double predicted, boolean isRelative) {
        double err = sqr(real - predicted);
        if (Double.isInfinite(err) || Double.isNaN(err)) {
            return;
        }
        m_SumErr[i] += err;
        m_SumSqErr[i] += sqr(err);
        m_nbEx[i]++;
        if (isRelative) {
            m_SumTrueValues[i] += real;
            m_SumSquaredTrueValues[i] += real * real;
        }
    }

    /**
     * Adds one example given as arrays of true and predicted values (indexed by component).
     *
     * @param isRelative whether to also accumulate the true values
     */
    public void addExample(double[] real, double[] predicted, boolean isRelative) {
        for (int i = 0; i < m_Dim; ++i) {
            addComponentError(i, real[i], predicted[i], isRelative);
        }
    }

    @Override
    public void addExample(double[] real, double[] predicted) {
        addExample(real, predicted, false);
    }

    /**
     * Adds one example whose predictions are booleans, mapped to 1.0 ({@code true}) or 0.0
     * ({@code false}).
     *
     * @param isRelative whether to also accumulate the true values
     */
    public void addExample(double[] real, boolean[] predicted, boolean isRelative) {
        for (int i = 0; i < m_Dim; ++i) {
            addComponentError(i, real[i], predicted[i] ? 1.0 : 0.0, isRelative);
        }
    }

    public void addExample(double[] real, boolean[] predicted) {
        addExample(real, predicted, false);
    }

    /**
     * Adds one example. The true values are read from {@code tuple} and the predictions come
     * from {@code pred.getNumericPred()}.
     *
     * @param isRelative whether to also accumulate the true values
     */
    public void addExample(DataTuple tuple, ClusStatistic pred, boolean isRelative) {
        double[] predicted = pred.getNumericPred();
        for (int i = 0; i < m_Dim; ++i) {
            addComponentError(i, getAttr(i).getNumeric(tuple), predicted[i], isRelative);
        }
    }

    @Override
    public void addExample(DataTuple tuple, ClusStatistic pred) {
        addExample(tuple, pred, false);
    }

    /**
     * Adds one example. Both the true and the predicted values are read from tuples using the
     * component's attribute.
     *
     * @param isRelative whether to also accumulate the true values
     */
    public void addExample(DataTuple real, DataTuple pred, boolean isRelative) {
        for (int i = 0; i < m_Dim; ++i) {
            addComponentError(i, getAttr(i).getNumeric(real), getAttr(i).getNumeric(pred), isRelative);
        }
    }

    @Override
    public void addExample(DataTuple real, DataTuple pred) {
        addExample(real, pred, false);
    }

    /** Invalid examples do not contribute to this error. */
    @Override
    public void addInvalid(DataTuple tuple) {
    }

    /**
     * Merges the accumulated sums and counts of another {@code MSError} into this one.
     *
     * @param other must be an {@code MSError} with the same number of components
     */
    @Override
    public void add(ClusError other) {
        MSError oe = (MSError) other;
        for (int i = 0; i < m_Dim; ++i) {
            m_SumErr[i] += oe.m_SumErr[i];
            m_SumSqErr[i] += oe.m_SumSqErr[i];
            m_nbEx[i] += oe.m_nbEx[i];
            // Keep the true-value sums consistent with the merged counts.
            m_SumTrueValues[i] += oe.m_SumTrueValues[i];
            m_SumSquaredTrueValues[i] += oe.m_SumSquaredTrueValues[i];
        }
    }

    /**
     * Prints the error on one line. When {@code m_PrintAllComps} is set, the per-component
     * errors are printed as {@code [a,b,...]}. The overall error follows when there is more
     * than one component or when components are not listed.
     */
    @Override
    public void showModelError(PrintWriter out, int detail) {
        ClusNumberFormat fr = getFormat();
        StringBuilder buf = new StringBuilder();
        if (m_PrintAllComps) {
            buf.append("[");
            for (int i = 0; i < m_Dim; ++i) {
                if (i != 0) {
                    buf.append(",");
                }
                buf.append(fr.format(getModelErrorComponent(i)));
            }
            buf.append(m_Dim > 1 ? "]: " : "]");
        }
        if (m_Dim > 1 || !m_PrintAllComps) {
            buf.append(fr.format(getModelError()));
        }
        out.println(buf.toString());
    }

    /** Prints the overall MSE with the configured prefix; {@code detail} is not used. */
    public void showSummaryError(PrintWriter out, boolean detail) {
        ClusNumberFormat fr = getFormat();
        out.println(getPrefix() + "Mean over components MSE: " + fr.format(getModelError()));
    }

    @Override
    public String getName() {
        if (m_Weights == null) {
            return "Mean squared error (MSE)" + getAdditionalInfoFormatted();
        }
        return "Weighted mean squared error (MSE) (" + m_Weights.getName(m_Attrs) + ")" + getAdditionalInfoFormatted();
    }

    /** Returns a fresh (empty) MSError with the same attributes, weights and settings. */
    @Override
    public ClusError getErrorClone(ClusErrorList par) {
        return new MSError(par, m_Attrs, m_Weights, m_PrintAllComps, getAdditionalInfo());
    }

    /**
     * Returns {@code getSVarS(m_Weights)} of the statistic multiplied by its number of
     * attributes.
     *
     * @param stat must be a {@link RegressionStat}
     */
    @Override
    public double computeLeafError(ClusStatistic stat) {
        RegressionStat rstat = (RegressionStat) stat;
        return rstat.getSVarS(m_Weights) * rstat.getNbAttributes();
    }

    /** Lower MSE is better. */
    @Override
    public boolean shouldBeLow() {
        return true;
    }
}

