/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.semisupervised.confidence.regression;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.ext.ensemble.ClusForest;
import si.ijs.kt.clus.ext.semisupervised.Helper;
import si.ijs.kt.clus.ext.semisupervised.confidence.PredictionConfidence;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.section.SettingsSSL;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.statistic.RegressionStat;
import si.ijs.kt.clus.statistic.WHTDStatistic;
import si.ijs.kt.clus.util.exception.ClusException;

/**
 * Prediction-confidence score derived from the individual ensemble members' predictions
 * ("votes") for a tuple.
 *
 * <p>For regression targets, the per-target score is the standard deviation of the members'
 * predicted means, as computed by {@code Helper.getStdDev}. Hierarchical (HMC) targets are
 * handled by a separate path; see the note on that helper.
 *
 * <p>An ensemble is required, and the constructor rejects settings that are not in ensemble mode.
 */
public class VarianceScore
extends PredictionConfidence {
    // NOTE(review): neither field is read or written within this class. Both are package-private,
    // so they may be used elsewhere in the package; confirm before removing.
    Map<Integer, double[]> stdDevs;
    int nb_unlabeled;

    /**
     * Creates the score and verifies that ensemble mode is enabled.
     *
     * @param statManager       statistics manager; its settings must have ensemble mode enabled
     * @param normalizationType normalization passed through to {@link PredictionConfidence}
     * @param aggregationType   aggregation passed through to {@link PredictionConfidence}
     * @throws ClusException if the settings are not in ensemble mode
     */
    public VarianceScore(ClusStatManager statManager, SettingsSSL.SSLNormalization normalizationType, SettingsSSL.SSLAggregation aggregationType) throws ClusException {
        super(statManager, normalizationType, aggregationType);
        if (!statManager.getSettings().getEnsemble().isEnsembleMode()) {
            throw new ClusException("If prediction confidence uses Standard Deviation then ensembles have to be used. Please use -forest option");
        }
    }

    /**
     * Runs a weighted prediction of {@code tuple} through the forest, then scores the
     * resulting member votes.
     *
     * @param model the ensemble; must be a {@link ClusForest}
     * @param tuple the tuple to score
     * @return one score per target attribute
     * @throws ClusException if {@code model} is not a {@link ClusForest}, or if prediction fails
     */
    @Override
    public double[] calculatePerTargetScores(ClusModel model, DataTuple tuple) throws ClusException, InterruptedException {
        // Check up front so a wrong model type fails with a clear message
        // instead of a ClassCastException after the prediction work is done.
        if (!(model instanceof ClusForest)) {
            String actual = model == null ? "null" : model.getClass().getName();
            throw new ClusException("VarianceScore requires a ClusForest model, got: " + actual);
        }
        ClusForest forest = (ClusForest) model;
        forest.predictWeighted(tuple);
        return this.processVotes(forest.getVotes());
    }

    /**
     * Scores {@code tuple} using the votes returned by {@link ClusForest#getOOBVotes}.
     *
     * @param model the forest whose out-of-bag votes are used
     * @param tuple the tuple to score
     * @return one score per target attribute
     */
    @Override
    public double[] calculatePerTargetOOBScores(ClusForest model, DataTuple tuple) throws InterruptedException {
        return this.processVotes(model.getOOBVotes(tuple));
    }

    /** Dispatches on the target mode: HIERARCHICAL uses the HMC path, everything else regression. */
    private double[] processVotes(List<?> votes) {
        switch (this.m_StatManager.getTargetMode()) {
            case HIERARCHICAL:
                return this.processVotesHMC(votes);
            case REGRESSION:
            default:
                // Any target mode other than HIERARCHICAL is treated as regression.
                return this.processVotesRegression(votes);
        }
    }

    /**
     * Collects each vote's per-target mean and returns the per-target spread across votes.
     * Each vote is cast to {@link RegressionStat}.
     */
    private double[] processVotesRegression(List<?> votes) {
        int nbVotes = votes.size();
        // Row t holds every vote's prediction for target t.
        double[][] predictions = new double[this.getNbTargetAttributes()][nbVotes];
        for (int v = 0; v < nbVotes; ++v) {
            RegressionStat stat = (RegressionStat) votes.get(v);
            // NOTE(review): rows are sized by getNbTargetAttributes() but filled up to
            // stat.getNbAttributes(); these two counts are assumed to agree.
            for (int t = 0; t < stat.getNbAttributes(); ++t) {
                predictions[t][v] = stat.getMean(t);
            }
        }
        return this.calcStDev(predictions);
    }

    /**
     * Averages {@code WHTDStatistic.m_Means} across all votes, per target. With no votes, an
     * all-zero array is returned.
     *
     * <p>NOTE(review): unlike the regression path, this returns the members' mean prediction
     * rather than a dispersion measure, which conflicts with the class name. Confirm whether
     * this is intended.
     */
    private double[] processVotesHMC(List<?> votes) {
        double[] means = new double[this.getNbTargetAttributes()];
        int nbVotes = votes.size();
        for (int v = 0; v < nbVotes; ++v) {
            WHTDStatistic vote = (WHTDStatistic) votes.get(v);
            for (int t = 0; t < means.length; ++t) {
                means[t] += vote.m_Means[t] / (double) nbVotes;
            }
        }
        return means;
    }

    /** Applies {@code Helper.getStdDev} to each row of {@code values}. */
    private double[] calcStDev(double[][] values) {
        double[] result = new double[values.length];
        for (int i = 0; i < result.length; ++i) {
            // With zero or one vote there is no spread to measure, so report 0.0
            // without calling the helper.
            result[i] = values[i].length <= 1 ? 0.0 : Helper.getStdDev(values[i]);
        }
        return result;
    }
}

