/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.algo.tdidt.tune;

import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.Random;
import si.ijs.kt.clus.algo.ClusInductionAlgorithm;
import si.ijs.kt.clus.algo.ClusInductionAlgorithmType;
import si.ijs.kt.clus.algo.tdidt.ClusDecisionTree;
import si.ijs.kt.clus.data.ClusData;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.primitive.NominalAttrType;
import si.ijs.kt.clus.data.type.primitive.NumericAttrType;
import si.ijs.kt.clus.error.Accuracy;
import si.ijs.kt.clus.error.RMSError;
import si.ijs.kt.clus.error.common.ClusError;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.error.hmlc.HierErrorMeasures;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.ClusSummary;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.main.settings.section.SettingsHMLC;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.model.ClusModelInfo;
import si.ijs.kt.clus.selection.ClusSelection;
import si.ijs.kt.clus.selection.XValRandomSelection;
import si.ijs.kt.clus.selection.XValSelection;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.ClusRandom;
import si.ijs.kt.clus.util.exception.ClusException;
import si.ijs.kt.clus.util.jeans.util.cmdline.CMDLineArgs;

/**
 * Decision-tree learner that tunes the F-test value before inducing the final model.
 *
 * <p>Each candidate value in {@link #m_FTests} is evaluated by inducing an unpruned tree. If a
 * prune (validation) set is available, the tree is scored on that set. Otherwise it is scored by
 * internal cross-validation on the training set. The candidate with the best tuning error is
 * written back into the tree settings. The wrapped algorithm type then induces the final model(s)
 * on the combined training and validation data.
 */
public class CDTTuneFTest
extends ClusDecisionTree {
    /**
     * Minimum improvement required before a candidate replaces the current best. Ties keep the
     * earlier candidate.
     */
    private static final double TUNE_EPSILON = 1.0E-16;

    /** Algorithm type that performs the actual induction (tuning runs and final model). */
    protected ClusInductionAlgorithmType m_Class;
    /** Candidate F-test values; {@code null} when built with the one-argument constructor. */
    protected double[] m_FTests;

    /**
     * Creates a tuner without candidate F-test values.
     *
     * @param clss algorithm type used for induction
     */
    public CDTTuneFTest(ClusInductionAlgorithmType clss) {
        super(clss.getClus());
        this.m_Class = clss;
    }

    /**
     * Creates a tuner that tries the given F-test values.
     *
     * @param clss algorithm type used for induction
     * @param ftests candidate F-test values, tried in array order
     */
    public CDTTuneFTest(ClusInductionAlgorithmType clss, double[] ftests) {
        this(clss);
        this.m_FTests = ftests;
    }

    /** Delegates induction-algorithm creation to the wrapped algorithm type. */
    @Override
    public ClusInductionAlgorithm createInduce(ClusSchema schema, Settings sett, CMDLineArgs cargs) throws ClusException, IOException {
        return this.m_Class.createInduce(schema, sett, cargs);
    }

    /** Logs the algorithm name and the heuristic in use. */
    @Override
    public void printInfo() {
        ClusLogger.info("TDIDT (Tuning F-Test)");
        ClusLogger.info("Heuristic: " + this.getStatManager().getHeuristicName());
    }

    /**
     * Prints the 1-based fold number as a progress indicator.
     *
     * @param i zero-based fold index
     * @param verbosity the verbosity level to test; callers pass the user's level saved before
     *     tuning muted the output, otherwise this would never print
     */
    private final void showFold(int i, int verbosity) {
        if (verbosity > 1) {
            if (i != 0) {
                System.out.print(" ");
            }
            System.out.print(String.valueOf(i + 1));
            System.out.flush();
        }
    }

    /** Returns the current general verbosity level. */
    private int verbosity() {
        return this.getSettings().getGeneral().getVerbose();
    }

    /**
     * Builds the error list used to compare F-test values. Only the first error in the list
     * drives the selection.
     *
     * <p>In HIERARCHICAL mode this is the configured hierarchical error measure. Otherwise it is
     * accuracy for nominal targets followed by RMSE for numeric targets. When both target kinds
     * are present, accuracy is therefore the deciding measure.
     *
     * @param mgr statistics manager providing the target mode, schema and hierarchy
     * @return a fresh error list
     */
    public ClusErrorList createTuneError(ClusStatManager mgr) {
        ClusErrorList parent = new ClusErrorList();
        if (mgr.getTargetMode() == ClusStatManager.Mode.HIERARCHICAL) {
            SettingsHMLC.HierarchyMeasures optimize = this.getSettings().getHMLC().getHierOptimizeErrorMeasure();
            parent.addError(new HierErrorMeasures(parent, mgr.getHier(), null, optimize, false, this.getSettings().getOutput().isGzipOutput()));
            return parent;
        }
        NumericAttrType[] num = mgr.getSchema().getNumericAttrUse(ClusAttrType.AttributeUseType.Target);
        NominalAttrType[] nom = mgr.getSchema().getNominalAttrUse(ClusAttrType.AttributeUseType.Target);
        if (nom.length != 0) {
            parent.addError(new Accuracy(parent, nom));
        }
        if (num.length != 0) {
            parent.addError(new RMSError(parent, num));
        }
        return parent;
    }

    /**
     * Creates a run on a clone of {@code data} for one cross-validation fold.
     *
     * <p>If {@code sel} changes the distribution, the training set is updated with it. Otherwise
     * the rows obtained via {@code select(sel)} become the run's test set.
     *
     * @param data data to clone
     * @param sel fold selection, or {@code null} for no partitioning
     * @param summary summary the run reports into
     * @param idx run index (1-based fold number)
     * @return the prepared run
     */
    public final ClusRun partitionDataBasic(ClusData data, ClusSelection sel, ClusSummary summary, int idx) throws IOException, ClusException, InterruptedException {
        ClusRun cr = new ClusRun(data.cloneData(), summary);
        if (sel != null) {
            if (sel.changesDistribution()) {
                ((RowData)cr.getTrainingSet()).update(sel);
            } else {
                ClusData val = cr.getTrainingSet().select(sel);
                cr.setTestSet(((RowData)val).getIterator());
            }
        }
        cr.setIndex(idx);
        cr.copyTrainingData();
        return cr;
    }

    /**
     * Estimates the tuning error of the currently configured F-test value.
     *
     * <p>If {@code pruneset} is non-null, one unpruned tree is induced on a clone of
     * {@code trset} and evaluated on {@code pruneset}. Otherwise k-fold cross-validation on
     * {@code trset} is used, with k taken from the tune-folds setting and fold assignment drawn
     * from a {@code Random} seeded with 0.
     *
     * <p>Verbosity is forced to 0 while models are induced. It is always restored afterwards,
     * even when induction fails.
     *
     * @param trset training data
     * @param pruneset validation data, or {@code null} to use cross-validation
     * @return the model error of the first tuning error measure
     * @throws Exception propagated from induction or evaluation
     */
    public double doParamXVal(RowData trset, RowData pruneset) throws Exception {
        int prevVerb = this.getSettings().getGeneral().enableVerbose(0);
        ClusStatManager mgr = this.getStatManager();
        double avgSize = 0.0;
        ClusModelInfo mi;
        try {
            ClusSummary summ = new ClusSummary();
            summ.setStatManager(mgr);
            summ.addModelInfo(1).setTestError(this.createTuneError(mgr));
            ClusRandom.initialize(this.getSettings());
            if (pruneset != null) {
                ClusRun cr = new ClusRun(trset.cloneData(), summ);
                ClusModel model = this.m_Class.induceSingleUnpruned(cr);
                avgSize = model.getModelSize();
                cr.addModelInfo(1).setModel(model);
                cr.addModelInfo(1).setTestError(this.createTuneError(mgr));
                this.m_Clus.calcError(pruneset.getIterator(), 1, cr, null);
                summ.addSummary(cr);
            } else {
                // Fixed seed: every candidate F-test value is evaluated on identical folds.
                Random random = new Random(0L);
                int nbfolds = Integer.parseInt(this.getSettings().getModel().getTuneFolds());
                XValRandomSelection sel = new XValRandomSelection(trset.getNbRows(), nbfolds, random);
                for (int i = 0; i < nbfolds; ++i) {
                    this.showFold(i, prevVerb);
                    XValSelection msel = new XValSelection(sel, i);
                    ClusRun cr = this.partitionDataBasic(trset, msel, summ, i + 1);
                    ClusModel model = this.m_Class.induceSingleUnpruned(cr);
                    avgSize += (double)model.getModelSize();
                    cr.addModelInfo(1).setModel(model);
                    cr.addModelInfo(1).setTestError(this.createTuneError(mgr));
                    this.m_Clus.calcError(cr.getTestIter(), 1, cr, null);
                    summ.addSummary(cr);
                }
                avgSize /= (double)nbfolds;
                // Terminates the fold-progress line printed by showFold (same verbosity test).
                if (prevVerb > 1) {
                    ClusLogger.info();
                }
            }
            mi = summ.getModelInfo(1);
        } finally {
            this.getSettings().getGeneral().enableVerbose(prevVerb);
        }
        ClusError err = mi.getTestError().getFirstError();
        if (this.verbosity() > 1) {
            PrintWriter wrt = new PrintWriter(new OutputStreamWriter(System.out));
            wrt.print("Size: " + avgSize + ", ");
            wrt.print("Error: ");
            err.showModelError(wrt, 3);
            // Flush only: closing would also close System.out.
            wrt.flush();
        }
        return err.getModelError();
    }

    /**
     * Tries every candidate in {@link #m_FTests} and leaves the best one set in the tree
     * settings.
     *
     * <p>Whether lower or higher error is better follows {@code shouldBeLow()} of the first
     * tuning error. A candidate replaces the current best only if it is better by more than
     * {@link #TUNE_EPSILON}, so ties keep the earlier value. If there are no candidates
     * ({@code null} or empty array), the current F-test setting is left unchanged.
     *
     * @param trset training data
     * @param pruneset validation data, or {@code null} to use cross-validation
     * @throws Exception propagated from {@link #doParamXVal}
     */
    public void findBestFTest(RowData trset, RowData pruneset) throws Exception {
        if (this.m_FTests == null || this.m_FTests.length == 0) {
            ClusLogger.info("No F-test values to tune; keeping the current F-test setting.");
            return;
        }
        boolean low = this.createTuneError(this.getStatManager()).getFirstError().shouldBeLow();
        double bestError = low ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
        int bestIndex = 0;
        for (int i = 0; i < this.m_FTests.length; ++i) {
            this.getSettings().getTree().setFTest(this.m_FTests[i], this.verbosity());
            if (this.verbosity() > 0) {
                ClusLogger.info("Try for F-test value = " + this.m_FTests[i]);
            }
            double err = this.doParamXVal(trset, pruneset);
            if (this.verbosity() > 1) {
                System.out.print("-> " + err);
            }
            boolean improved = low ? err < bestError - TUNE_EPSILON : err > bestError + TUNE_EPSILON;
            if (improved) {
                bestError = err;
                bestIndex = i;
            }
            if (this.verbosity() > 1) {
                // Mark the new best candidate with a star; otherwise just end the line.
                if (improved) {
                    ClusLogger.info(" *");
                } else {
                    ClusLogger.info();
                }
            }
            if (this.verbosity() > 0) {
                ClusLogger.info();
            }
        }
        this.getSettings().getTree().setFTest(this.m_FTests[bestIndex], this.verbosity());
        if (this.verbosity() > 0) {
            ClusLogger.info("Best F-test value is: " + this.m_FTests[bestIndex]);
        }
    }

    /**
     * Tunes the F-test value on the training set, scoring on the prune set when one exists.
     * Then merges training and validation sets and induces the final model(s) with the wrapped
     * algorithm type.
     *
     * <p>{@link ClusException} and {@link IOException} are propagated to the caller. An
     * interruption restores the thread's interrupt flag. Other exceptions are logged.
     */
    @Override
    public void induceAll(ClusRun cr) throws ClusException, IOException {
        try {
            RowData valid = (RowData)cr.getPruneSet();
            RowData train = (RowData)cr.getTrainingSet();
            this.findBestFTest(train, valid);
            ClusLogger.info();
            cr.combineTrainAndValidSets();
            // NOTE(review): presumably re-seeds Clus' RNGs so the final induction does not
            // depend on random draws consumed during tuning; confirm against ClusRandom.
            ClusRandom.initialize(this.getSettings());
            this.m_Class.induceAll(cr);
        }
        catch (ClusException e) {
            throw e;
        }
        catch (IOException e) {
            throw e;
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            ClusLogger.severe(e.toString());
        }
        catch (Exception e) {
            ClusLogger.severe(e.toString());
        }
    }
}

