/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.addon.hmc.HMCAverageSingleClass;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Date;
import si.ijs.kt.clus.Clus;
import si.ijs.kt.clus.addon.hmc.HMCAverageSingleClass.HMCAverageNodeWiseModels;
import si.ijs.kt.clus.addon.hmc.HMCAverageSingleClass.HMCAverageTreeModel;
import si.ijs.kt.clus.algo.tdidt.ClusDecisionTree;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.io.ARFFFile;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.primitive.NumericAttrType;
import si.ijs.kt.clus.data.type.primitive.StringAttrType;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.error.hmlc.HierClassWiseAccuracy;
import si.ijs.kt.clus.ext.hierarchical.ClassHierarchy;
import si.ijs.kt.clus.ext.hierarchical.ClassTerm;
import si.ijs.kt.clus.ext.hierarchical.ClassesTuple;
import si.ijs.kt.clus.ext.hierarchical.ClassesValue;
import si.ijs.kt.clus.ext.hierarchical.HierClassTresholdPruner;
import si.ijs.kt.clus.main.ClusOutput;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.model.ClusModelInfo;
import si.ijs.kt.clus.model.io.ClusModelCollectionIO;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.statistic.RegressionStat;
import si.ijs.kt.clus.statistic.WHTDStatistic;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;
import si.ijs.kt.clus.util.jeans.io.ini.INIFileNominalOrDoubleOrVector;
import si.ijs.kt.clus.util.jeans.util.FileUtil;
import si.ijs.kt.clus.util.jeans.util.array.StringTable;
import si.ijs.kt.clus.util.jeans.util.cmdline.CMDLineArgs;
import si.ijs.kt.clus.util.jeans.util.cmdline.CMDLineArgsProvider;

/**
 * Command-line driver that combines per-class models into one hierarchical multi-label (HMC)
 * prediction and evaluates it.
 *
 * <p>Supported options (see {@link #getOptionArgs()}):
 * <ul>
 * <li>{@code -models dir}: load every model file in {@code dir}; the class predicted by each
 * file is derived from its file name (see {@link #getClassStr(String)}).</li>
 * <li>{@code -hsc}: delegate model loading and error updates to {@link HMCAverageNodeWiseModels}.</li>
 * <li>{@code -stats}: write hierarchy statistics to {@code <appname>-hmcstat.arff} and exit.</li>
 * <li>{@code -loadPredictions file}: evaluate predictions read from an ARFF file against the
 * test set.</li>
 * </ul>
 * The single main argument is the application name passed to the settings.
 */
public class HMCAverageSingleClass
implements CMDLineArgsProvider {
    /** Names of the supported command-line options. */
    private static final String[] g_Options = new String[]{"models", "hsc", "stats", "loadPredictions"};
    /** Number of values each option in {@link #g_Options} takes (same order). */
    private static final int[] g_OptionArities = new int[]{1, 0, 0, 1};
    protected Clus m_Clus;
    /** Not referenced in this class; kept for compatibility with other users (unverified). */
    protected StringTable m_Table = new StringTable();
    /**
     * Error lists indexed [data set (0 = train, 1 = test)][threshold index]; only allocated when
     * the classification thresholds are a vector.
     */
    protected ClusErrorList[][] m_EvalArray;
    /**
     * Predicted values indexed [data set][row][class index]. In -loadPredictions mode only
     * index 0 exists and it holds the test-set predictions.
     */
    protected double[][][] m_PredProb;
    /** Number of loaded sub-models. */
    protected int m_NbModels;
    /** Sum of the sizes of the loaded sub-models. */
    protected int m_TotSize;

    /**
     * Parses the command line, initialises Clus and executes the requested mode.
     *
     * @param args command-line arguments
     * @throws ClusException when no supported mode is selected or evaluation fails
     */
    public void run(String[] args) throws IOException, ClusException, ClassNotFoundException, InterruptedException {
        this.m_Clus = new Clus();
        Settings sett = this.m_Clus.getSettings();
        CMDLineArgs cargs = new CMDLineArgs(this);
        cargs.process(args);
        if (!cargs.allOK()) {
            return;
        }
        sett.getGeneric().setDate(new Date());
        sett.getGeneric().setAppName(cargs.getMainArg(0));
        this.m_Clus.initSettings(cargs);
        ClusDecisionTree clss = new ClusDecisionTree(this.m_Clus);
        this.m_Clus.initialize(cargs, clss);
        WHTDStatistic target = this.createTargetStat();
        ((ClusStatistic)target).calcMean();
        if (cargs.hasOption("stats")) {
            this.computeStats();
            // NOTE(review): terminates the JVM from library code; kept for backward compatibility.
            System.exit(0);
        }
        if (cargs.hasOption("models") || cargs.hasOption("hsc")) {
            this.runCombinedModels(cargs, sett, target);
        } else if (cargs.hasOption("loadPredictions")) {
            this.runLoadedPredictions(cargs, sett, target);
        } else {
            throw new ClusException("Must specify e.g., -models dirname");
        }
    }

    /** Handles -models / -hsc: loads sub-models, evaluates the combination on train and test. */
    private void runCombinedModels(CMDLineArgs cargs, Settings sett, WHTDStatistic target) throws IOException, ClusException, ClassNotFoundException, InterruptedException {
        boolean hsc = cargs.hasOption("hsc");
        this.m_Clus.getSettings().getGeneric().setSuffix(hsc ? ".hsc.combined" : ".sc.combined");
        ClusRun cr = this.m_Clus.partitionData();
        cr.combineTrainAndValidSets();
        ClassHierarchy hier = this.getStatManager().getHier();
        this.m_PredProb = new double[2][][];
        for (int i = 0; i <= 1; ++i) {
            RowData data = cr.getDataSet(i);
            // The test set may be absent (cf. the null check on getTestSet() below).
            int size = data == null ? 0 : data.getNbRows();
            this.m_PredProb[i] = new double[size][hier.getTotal()];
            for (double[] row : this.m_PredProb[i]) {
                // Sentinel for "no prediction"; presumably interpreted by HMCAverageTreeModel (verify).
                Arrays.fill(row, Double.MAX_VALUE);
            }
        }
        INIFileNominalOrDoubleOrVector class_thr = this.getSettings().getHMLC().getClassificationThresholds();
        if (class_thr.isVector()) {
            this.m_EvalArray = this.createEvalArray(hier, true);
        }
        if (hsc) {
            HMCAverageNodeWiseModels avg = new HMCAverageNodeWiseModels(this, this.m_PredProb);
            avg.processModels(cr);
            this.m_NbModels = avg.getNbModels();
            this.m_TotSize = avg.getTotalSize();
            if (this.m_EvalArray != null) {
                avg.updateErrorMeasures(cr);
            }
        } else {
            this.loadModelPerModel(cargs.getOptionValue("models"), cr);
        }
        HMCAverageTreeModel orig_model = this.addDefaultAndCombinedModels(cr, target);
        RowData train = (RowData)cr.getTrainingSet();
        train.addIndices();
        orig_model.setDataSet(0);
        this.m_Clus.calcError(train.getIterator(), 0, cr);
        RowData test = cr.getTestSet();
        if (test != null) {
            test.addIndices();
            orig_model.setDataSet(1);
            this.m_Clus.calcError(test.getIterator(), 1, cr);
        }
        if (class_thr.isVector()) {
            this.addThresholdModelInfos(cr, true);
        }
        this.writeResults(cr, sett);
    }

    /** Handles -loadPredictions: reads predictions from an ARFF file and evaluates them on the test set. */
    private void runLoadedPredictions(CMDLineArgs cargs, Settings sett, WHTDStatistic target) throws IOException, ClusException, ClassNotFoundException, InterruptedException {
        this.m_Clus.getSettings().getGeneric().setSuffix(".evaluatedPredictions");
        ClusRun cr = this.m_Clus.partitionData();
        ClassHierarchy hier = this.getStatManager().getHier();
        RowData testset = cr.getDataSet(1);
        if (testset == null) {
            throw new ClusException("Error: no test set available to evaluate the loaded predictions against.");
        }
        int size = testset.getNbRows();
        this.m_PredProb = new double[1][size][hier.getTotal()];
        String file = cargs.getOptionValue("loadPredictions");
        RowData rw = ARFFFile.readArff(file);
        ClusSchema schema = rw.getSchema();
        NumericAttrType[] na = schema.getNumericAttrUse(ClusAttrType.AttributeUseType.All);
        int[] mapping_classes = this.mapColumnsToClasses(na, hier);
        ClusLogger.info("Number of rows in predictions-file: " + rw.getNbRows());
        ClusLogger.info("Number of rows in test-file: " + testset.getNbRows());
        ClusLogger.info("Number of classes: " + hier.getTotal());
        // Predictions are matched to test rows by position, so the counts must agree; otherwise
        // extra rows overflow m_PredProb and missing rows would silently be scored as 0.
        if (rw.getNbRows() != size) {
            throw new ClusException("Error: predictions-file has " + rw.getNbRows() + " rows, but test-file has " + size + " rows.");
        }
        for (int x = 0; x < rw.getNbRows(); ++x) {
            DataTuple tuple = rw.getTuple(x);
            for (int y = 0; y < na.length; ++y) {
                this.m_PredProb[0][x][mapping_classes[y]] = na[y].getNumeric(tuple);
            }
        }
        INIFileNominalOrDoubleOrVector class_thr = this.getSettings().getHMLC().getClassificationThresholds();
        if (class_thr.isVector()) {
            this.m_EvalArray = this.createEvalArray(hier, false);
        }
        this.addDefaultAndCombinedModels(cr, target);
        RowData test = cr.getTestSet();
        if (test != null) {
            test.addIndices();
            this.m_Clus.calcError(test.getIterator(), 1, cr);
        }
        if (class_thr.isVector()) {
            this.addThresholdModelInfos(cr, false);
        }
        this.writeResults(cr, sett);
    }

    /**
     * Maps each numeric column of the predictions file to the hierarchy class with the same
     * human-readable name (the last matching class wins, as in the original code).
     *
     * @throws ClusException when a column name does not match any class
     */
    private int[] mapColumnsToClasses(NumericAttrType[] na, ClassHierarchy hier) throws IOException, ClusException, ClassNotFoundException, InterruptedException {
        int[] mapping = new int[na.length];
        for (int y = 0; y < na.length; ++y) {
            String label = na[y].getName();
            int match = -1;
            for (int a = 0; a < hier.getTotal(); ++a) {
                if (hier.getTermAt(a).toStringHuman(hier).equals(label)) {
                    match = a;
                }
            }
            if (match < 0) {
                throw new ClusException("Error: class " + label + " not found.");
            }
            mapping[y] = match;
        }
        return mapping;
    }

    /**
     * Allocates one error list per threshold of the tree pruner, for the test set and
     * (when {@code includeTrain}) also the training set.
     */
    private ClusErrorList[][] createEvalArray(ClassHierarchy hier, boolean includeTrain) throws IOException, ClusException, ClassNotFoundException, InterruptedException {
        HierClassTresholdPruner pruner = (HierClassTresholdPruner)this.getStatManager().getTreePruner(null);
        int nbResults = pruner.getNbResults();
        ClusErrorList[][] evalArray = new ClusErrorList[2][nbResults];
        int firstSet = includeTrain ? 0 : 1;
        for (int i = 0; i < nbResults; ++i) {
            for (int j = firstSet; j <= 1; ++j) {
                ClusErrorList errors = new ClusErrorList();
                errors.addError(new HierClassWiseAccuracy(errors, hier));
                // NOTE(review): a null error entry is appended here as in the original; confirm intent.
                errors.addError(null);
                evalArray[j][i] = errors;
            }
        }
        return evalArray;
    }

    /** Registers the default model (index 0) and the combined model (index 1) with {@code cr}. */
    private HMCAverageTreeModel addDefaultAndCombinedModels(ClusRun cr, WHTDStatistic target) throws IOException, ClusException, ClassNotFoundException, InterruptedException {
        ClusModelInfo def_model = cr.addModelInfo(0);
        def_model.setModel(ClusDecisionTree.induceDefault(cr));
        ClusModelInfo orig_model_inf = cr.addModelInfo(1);
        HMCAverageTreeModel orig_model = new HMCAverageTreeModel(target, this.m_PredProb, this.m_NbModels, this.m_TotSize);
        orig_model_inf.setModel(orig_model);
        cr.copyAllModelsMIs();
        return orig_model;
    }

    /** Adds one model info per pruner threshold carrying the accumulated threshold errors. */
    private void addThresholdModelInfos(ClusRun cr, boolean includeTrain) throws IOException, ClusException, ClassNotFoundException, InterruptedException {
        HierClassTresholdPruner pruner = (HierClassTresholdPruner)this.getStatManager().getTreePruner(null);
        for (int i = 0; i < pruner.getNbResults(); ++i) {
            ClusModelInfo pruned_info = cr.addModelInfo(pruner.getPrunedName(i));
            pruned_info.setShouldWritePredictions(false);
            if (includeTrain) {
                pruned_info.setTrainError(this.m_EvalArray[0][i]);
            }
            pruned_info.setTestError(this.m_EvalArray[1][i]);
        }
    }

    /** Writes the {@code .out} report; the output is closed even if writing fails. */
    private void writeResults(ClusRun cr, Settings sett) throws IOException, ClusException, ClassNotFoundException, InterruptedException {
        ClusOutput output = new ClusOutput(sett.getGeneric().getAppNameWithSuffix() + ".out", this.m_Clus.getSchema(), sett);
        try {
            output.writeHeader();
            output.writeOutput(cr, true, this.getSettings().getOutput().isOutTrainError());
        } finally {
            output.close();
        }
    }

    public ClusStatManager getStatManager() {
        return this.m_Clus.getStatManager();
    }

    public Settings getSettings() {
        return this.m_Clus.getSettings();
    }

    public Clus getClus() {
        return this.m_Clus;
    }

    /**
     * @param traintest 0 for the training set, 1 for the test set
     * @param j threshold index
     * @return the error list accumulated for that data set and threshold
     */
    public ClusErrorList getEvalArray(int traintest, int j) {
        return this.m_EvalArray[traintest][j];
    }

    /** Creates a fresh target statistic from the stat manager, cast to {@link WHTDStatistic}. */
    public WHTDStatistic createTargetStat() {
        return (WHTDStatistic)this.m_Clus.getStatManager().createStatistic(ClusAttrType.AttributeUseType.Target);
    }

    /**
     * Derives a class path from a model file name: the part after the last '_' of the bare
     * file name, with '-' separators replaced by '/'.
     */
    public String getClassStr(String file) {
        String value = FileUtil.getName(FileUtil.removePath(file));
        String[] cmps = value.split("_");
        return String.join("/", cmps[cmps.length - 1].split("-"));
    }

    /** Returns the hierarchy index of the class encoded in {@code file}'s name. */
    public int getClassIndex(String file) throws ClusException {
        String class_str = this.getClassStr(file);
        ClassHierarchy hier = this.getStatManager().getHier();
        ClassesValue val = new ClassesValue(class_str, hier.getType().getTable());
        return hier.getClassTerm(val).getIndex();
    }

    /**
     * Loads the model named "Original" from {@code file} and updates the model count and
     * total size.
     *
     * @throws ClusException when the file contains no model named "Original"
     */
    public ClusModel loadModel(String file) throws IOException, ClusException, ClassNotFoundException {
        String class_str = this.getClassStr(file);
        ClusLogger.info("Loading: " + file + " class: " + class_str);
        ClusModelCollectionIO io = ClusModelCollectionIO.load(file);
        ClusModel sub_model = io.getModel("Original");
        if (sub_model == null) {
            throw new ClusException("Error: .model file does not contain model named 'Original'");
        }
        ++this.m_NbModels;
        this.m_TotSize += sub_model.getModelSize();
        return sub_model;
    }

    /**
     * Loads every model file in {@code dir}, records its predictions for the training and
     * test sets, and finalises the example counts of the threshold error lists.
     */
    public void loadModelPerModel(String dir, ClusRun cr) throws IOException, ClusException, ClassNotFoundException, InterruptedException {
        String[] files = FileUtil.dirList(dir, "model");
        for (String name : files) {
            ClusModel model = this.loadModel(FileUtil.cmbPath(dir, name));
            int class_idx = this.getClassIndex(name);
            for (int j = 0; j <= 1; ++j) {
                this.evaluateModelAndUpdateErrors(j, class_idx, model, cr);
            }
        }
        INIFileNominalOrDoubleOrVector class_thr = this.getSettings().getHMLC().getClassificationThresholds();
        if (class_thr.isVector()) {
            HierClassTresholdPruner pruner = (HierClassTresholdPruner)this.getStatManager().getTreePruner(null);
            for (int j = 0; j < pruner.getNbResults(); ++j) {
                for (int traintest = 0; traintest <= 1; ++traintest) {
                    RowData data = cr.getDataSet(traintest);
                    if (data == null) {
                        continue;
                    }
                    ClusErrorList error = this.getEvalArray(traintest, j);
                    error.setNbExamples(data.getNbRows(), data.getNbRows());
                }
            }
        }
    }

    /**
     * Stores the model's prediction for class {@code class_idx} on every row of the chosen data
     * set and, when thresholds are a vector, updates the per-threshold class-wise accuracy.
     * Does nothing for a data set that is absent.
     *
     * @param train_or_test 0 for the training set, 1 for the test set
     */
    public void evaluateModelAndUpdateErrors(int train_or_test, int class_idx, ClusModel model, ClusRun cr) throws ClusException, IOException, InterruptedException {
        RowData data = cr.getDataSet(train_or_test);
        this.m_Clus.getSchema().attachModel(model);
        if (data == null) {
            return;
        }
        INIFileNominalOrDoubleOrVector class_thr = this.getSettings().getHMLC().getClassificationThresholds();
        HierClassTresholdPruner pruner = class_thr.isVector() ? (HierClassTresholdPruner)this.getStatManager().getTreePruner(null) : null;
        // Same label-tuple lookup as countClasses(), instead of assuming object slot 0.
        int sidx = this.getStatManager().getHier().getType().getArrayIndex();
        for (int i = 0; i < data.getNbRows(); ++i) {
            DataTuple tuple = data.getTuple(i);
            // Predict once per tuple and reuse the value for both the stored prediction and the errors.
            double[] predicted_distr = model.predictWeighted(tuple).getNumericPred();
            double prob = predicted_distr[0];
            this.m_PredProb[train_or_test][i][class_idx] = prob;
            if (pruner == null) {
                continue;
            }
            ClassesTuple tp = (ClassesTuple)tuple.getObjVal(sidx);
            boolean actually_has_class = tp.hasClass(class_idx);
            for (int j = 0; j < pruner.getNbResults(); ++j) {
                // Thresholds are divided by 100, i.e. treated as percentages.
                boolean predicted_class = prob >= pruner.getThreshold(j) / 100.0;
                HierClassWiseAccuracy acc = (HierClassWiseAccuracy)this.m_EvalArray[train_or_test][j].getError(0);
                acc.nextPrediction(class_idx, predicted_class, actually_has_class);
            }
        }
    }

    @Override
    public String[] getOptionArgs() {
        return g_Options;
    }

    @Override
    public int[] getOptionArgArities() {
        return g_OptionArities;
    }

    @Override
    public int getNbMainArgs() {
        return 1;
    }

    @Override
    public void showHelp() {
    }

    /**
     * Adds to {@code counts[0]} the number of labels per example including ancestors, and to
     * {@code counts[1]} the number of labels left after {@code removeParentNodes}.
     */
    public void countClasses(RowData data, double[] counts) {
        ClassHierarchy hier = this.getStatManager().getHier();
        int sidx = hier.getType().getArrayIndex();
        boolean[] arr = new boolean[hier.getTotal()];
        for (int i = 0; i < data.getNbRows(); ++i) {
            DataTuple tuple = data.getTuple(i);
            ClassesTuple tp = (ClassesTuple)tuple.getObjVal(sidx);
            Arrays.fill(arr, false);
            tp.fillBoolArrayNodeAndAncestors(arr);
            counts[0] += countTrue(arr);
            hier.removeParentNodes(arr);
            counts[1] += countTrue(arr);
        }
    }

    /** Number of {@code true} entries in {@code flags}. */
    private static int countTrue(boolean[] flags) {
        int n = 0;
        for (boolean f : flags) {
            if (f) {
                ++n;
            }
        }
        return n;
    }

    /**
     * Writes dataset and hierarchy statistics (one ARFF row per class) to
     * {@code <appname>-hmcstat.arff}.
     */
    public void computeStats() throws ClusException, IOException, InterruptedException {
        ClusRun cr = this.m_Clus.partitionData();
        RegressionStat stat = (RegressionStat)this.getStatManager().createStatistic(ClusAttrType.AttributeUseType.Target);
        RowData train = (RowData)cr.getTrainingSet();
        RowData valid = (RowData)cr.getPruneSet();
        RowData test = cr.getTestSet();
        train.calcTotalStat(stat);
        if (valid != null) {
            valid.calcTotalStat(stat);
        }
        if (test != null) {
            test.calcTotalStat(stat);
        }
        stat.calcMean();
        ClassHierarchy hier = this.getStatManager().getHier();
        ClusSchema schema = new ClusSchema("HMC-Statistics");
        schema.addAttrType(new StringAttrType("Class"));
        schema.addAttrType(new NumericAttrType("Weight"));
        schema.addAttrType(new NumericAttrType("MinDepth"));
        schema.addAttrType(new NumericAttrType("MaxDepth"));
        schema.addAttrType(new NumericAttrType("Frequency"));
        double total = stat.getTotalWeight();
        double[] classCounts = new double[2];
        this.countClasses(train, classCounts);
        if (valid != null) {
            this.countClasses(valid, classCounts);
        }
        if (test != null) {
            this.countClasses(test, classCounts);
        }
        int nbDescriptiveAttrs = this.m_Clus.getSchema().getNbDescriptiveAttributes();
        // try-with-resources: the writer is closed even if writing the header fails.
        try (PrintWriter wrt = this.getSettings().getGeneric().getFileAbsoluteWriter(this.getSettings().getGeneric().getAppName() + "-hmcstat.arff")) {
            wrt.println();
            wrt.println("% Number of examples: " + total);
            wrt.println("% Number of descriptive attributes: " + nbDescriptiveAttrs);
            wrt.println("% Number of classes: " + hier.getTotal());
            wrt.println("% Avg number of labels/example: " + classCounts[0] / total + " (most specific: " + classCounts[1] / total + ")");
            wrt.println("% Hierarchy depth: " + hier.getDepth());
            wrt.println();
            ARFFFile.writeArffHeader(wrt, schema);
            wrt.println("@DATA");
            for (int i = 0; i < hier.getTotal(); ++i) {
                ClassTerm term = hier.getTermAt(i);
                int index = term.getIndex();
                wrt.print(term.toStringHuman(hier));
                wrt.print("," + hier.getWeight(index));
                wrt.print("," + term.getMinDepth());
                wrt.print("," + term.getMaxDepth());
                wrt.print("," + stat.getSumValues(index));
                wrt.println();
            }
        }
    }

    /** Entry point; errors are logged rather than propagated. */
    public static void main(String[] args) {
        try {
            HMCAverageSingleClass avg = new HMCAverageSingleClass();
            avg.run(args);
        }
        catch (IOException io) {
            ClusLogger.info("IO Error: " + io.getMessage());
        }
        catch (ClusException cl) {
            ClusLogger.info("Error: " + cl.getMessage());
        }
        catch (ClassNotFoundException cn) {
            ClusLogger.info("Error: " + cn.getMessage());
        }
        catch (Exception e) {
            ClusLogger.info(e.toString());
        }
    }
}

