/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.addon.hmc.HMCAverageSingleClass;

import java.io.IOException;
import java.util.ArrayList;
import si.ijs.kt.clus.Clus;
import si.ijs.kt.clus.addon.hmc.HMCAverageSingleClass.HMCAverageSingleClass;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.error.hmlc.HierClassWiseAccuracy;
import si.ijs.kt.clus.ext.hierarchical.ClassHierarchy;
import si.ijs.kt.clus.ext.hierarchical.ClassTerm;
import si.ijs.kt.clus.ext.hierarchical.ClassesTuple;
import si.ijs.kt.clus.ext.hierarchical.HierClassTresholdPruner;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.model.io.ClusModelCollectionIO;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;

public class HMCAverageNodeWiseModels {
    protected int m_NbModels;
    protected int m_TotSize;
    protected HMCAverageSingleClass m_Cls;
    protected double[][][] m_PredProb;

    public HMCAverageNodeWiseModels(HMCAverageSingleClass cls, double[][][] predprop) {
        this.m_Cls = cls;
        this.m_PredProb = predprop;
    }

    public int getNbModels() {
        return this.m_NbModels;
    }

    public int getTotalSize() {
        return this.m_TotSize;
    }

    public ClusStatManager getStatManager() {
        return this.m_Cls.getStatManager();
    }

    public Settings getSettings() {
        return this.m_Cls.getSettings();
    }

    public Clus getClus() {
        return this.m_Cls.getClus();
    }

    public boolean allParentsOk(ClassTerm term, boolean[] computed) {
        for (int j = 0; j < term.getNbParents(); ++j) {
            ClassTerm parent = term.getParent(j);
            if (parent.getIndex() == -1 || computed[parent.getIndex()]) continue;
            return false;
        }
        return true;
    }

    public void processModels(ClusRun cr) throws ClusException, IOException, ClassNotFoundException, InterruptedException {
        ClassHierarchy hier = this.getStatManager().getHier();
        boolean[] prob_computed = new boolean[hier.getTotal()];
        ArrayList<ClassTerm> todo = new ArrayList<ClassTerm>();
        for (int i = 0; i < hier.getTotal(); ++i) {
            ClassTerm term = hier.getTermAt(i);
            todo.add(term);
        }
        int nb_done = 0;
        while (nb_done < hier.getTotal()) {
            for (int i = todo.size() - 1; i >= 0; --i) {
                ClassTerm term = (ClassTerm)todo.get(i);
                if (!this.allParentsOk(term, prob_computed)) continue;
                this.doOneClass(term, cr);
                prob_computed[term.getIndex()] = true;
                todo.remove(i);
                ++nb_done;
            }
        }
    }

    public void updateErrorMeasures(ClusRun cr) throws ClusException, IOException, InterruptedException {
        ClassHierarchy hier = this.getStatManager().getHier();
        HierClassTresholdPruner pruner = (HierClassTresholdPruner)this.getStatManager().getTreePruner(null);
        for (int traintest = 0; traintest <= 1; ++traintest) {
            RowData data = cr.getDataSet(traintest);
            for (int exid = 0; exid < data.getNbRows(); ++exid) {
                DataTuple tuple = data.getTuple(exid);
                ClassesTuple tp = (ClassesTuple)tuple.getObjVal(0);
                for (int clidx = 0; clidx < hier.getTotal(); ++clidx) {
                    double predicted_weight = this.m_PredProb[traintest][exid][clidx];
                    boolean actually_has_class = tp.hasClass(clidx);
                    for (int j = 0; j < pruner.getNbResults(); ++j) {
                        boolean predicted_class = predicted_weight >= pruner.getThreshold(j) / 100.0;
                        HierClassWiseAccuracy acc = (HierClassWiseAccuracy)this.m_Cls.getEvalArray(traintest, j).getError(0);
                        acc.nextPrediction(clidx, predicted_class, actually_has_class);
                    }
                }
            }
            for (int j = 0; j < pruner.getNbResults(); ++j) {
                ClusErrorList error = this.m_Cls.getEvalArray(traintest, j);
                error.setNbExamples(data.getNbRows(), data.getNbRows());
            }
        }
    }

    public void doOneClass(ClassTerm term, ClusRun cr) throws IOException, ClassNotFoundException, ClusException, InterruptedException {
        String childName = term.toPathString("=");
        for (int j = 0; j < term.getNbParents(); ++j) {
            ClassTerm parent = term.getParent(j);
            String nodeName = parent.toPathString("=");
            String name = this.getSettings().getGeneric().getAppName() + "-" + nodeName + "-" + childName;
            String toload = "hsc/model/" + name + ".model";
            ClusLogger.info("Loading: " + toload);
            ClusModelCollectionIO io = ClusModelCollectionIO.load(toload);
            ClusModel model = io.getModel("Original");
            if (model == null) {
                throw new ClusException("Error: .model file does not contain model named 'Original'");
            }
            ++this.m_NbModels;
            this.m_TotSize += model.getModelSize();
            this.getClus().getSchema().attachModel(model);
            for (int traintest = 0; traintest <= 1; ++traintest) {
                RowData data = cr.getDataSet(traintest);
                for (int exid = 0; exid < data.getNbRows(); ++exid) {
                    this.updatePrediction(data, exid, traintest, model, parent, term);
                }
            }
        }
        int child_idx = term.getIndex();
        for (int traintest = 0; traintest <= 1; ++traintest) {
            RowData data = cr.getDataSet(traintest);
            for (int exid = 0; exid < data.getNbRows(); ++exid) {
                double[] dArray = this.m_PredProb[traintest][exid];
                int n = child_idx;
                dArray[n] = dArray[n] / (double)term.getNbParents();
            }
        }
    }

    public void updatePrediction(RowData data, int exid, int traintest, ClusModel model, ClassTerm parent, ClassTerm term) throws ClusException, InterruptedException {
        int child_idx;
        DataTuple tuple = data.getTuple(exid);
        ClusStatistic prediction = model.predictWeighted(tuple);
        double[] predicted_distr = prediction.getNumericPred();
        double predicted_prob = predicted_distr[0];
        int parent_idx = parent.getIndex();
        double parent_prob = parent_idx == -1 ? 1.0 : this.m_PredProb[traintest][exid][parent_idx];
        double child_prob = parent_prob * predicted_prob;
        if (child_prob < this.m_PredProb[traintest][exid][child_idx = term.getIndex()]) {
            this.m_PredProb[traintest][exid][child_idx] = child_prob;
        }
    }
}

