/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.algo.rules.tune;

import java.io.IOException;
import java.util.Random;
import si.ijs.kt.clus.algo.ClusInductionAlgorithm;
import si.ijs.kt.clus.algo.ClusInductionAlgorithmType;
import si.ijs.kt.clus.algo.rules.ClusRuleClassifier;
import si.ijs.kt.clus.algo.tdidt.ClusDecisionTree;
import si.ijs.kt.clus.data.ClusData;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.primitive.NominalAttrType;
import si.ijs.kt.clus.data.type.primitive.NumericAttrType;
import si.ijs.kt.clus.error.Accuracy;
import si.ijs.kt.clus.error.RMSError;
import si.ijs.kt.clus.error.common.ClusError;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.ClusSummary;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.main.settings.section.SettingsEnsemble;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.model.ClusModelInfo;
import si.ijs.kt.clus.selection.ClusSelection;
import si.ijs.kt.clus.selection.XValRandomSelection;
import si.ijs.kt.clus.selection.XValSelection;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.ClusRandom;
import si.ijs.kt.clus.util.exception.ClusException;
import si.ijs.kt.clus.util.jeans.util.cmdline.CMDLineArgs;

/**
 * Tuning wrapper for FIRE-ROS (Fitted Rule Ensemble with Random Output Selections).
 *
 * <p>Before delegating induction to the wrapped algorithm, {@link #induceAll(ClusRun)} tries every
 * {@link SettingsEnsemble.EnsembleROSVotingType} combined with every candidate ROS subspace size in
 * {@link #m_ROSSubspaceSizes}. It scores each combination with {@link #doParamXVal(RowData, RowData)}
 * and writes the best combination back into the settings.
 */
public class FIREROSTune extends ClusRuleClassifier {
    /** A candidate must beat the current best error by more than this margin to replace it. */
    private static final double IMPROVEMENT_EPSILON = 1.0E-16;
    /** Number of bags used while tuning; the original value is restored afterwards. */
    private static final int TUNING_NB_BAGS = 10;
    /** Gradient-descent iteration limit used while tuning; the original value is restored afterwards. */
    private static final int TUNING_OPT_GD_MAX_ITER = 4000;

    /** The wrapped algorithm that performs the actual induction. */
    protected ClusInductionAlgorithmType m_Class;
    /**
     * Candidate values passed to {@code setNbRandomTargetAttrString} during tuning.
     * Remains {@code null} when the single-argument constructor is used.
     */
    protected String[] m_ROSSubspaceSizes;

    /**
     * Creates a tuner around {@code clss} without candidate subspace sizes.
     *
     * @param clss the wrapped induction algorithm type
     */
    public FIREROSTune(ClusInductionAlgorithmType clss) {
        super(clss.getClus());
        this.m_Class = clss;
    }

    /**
     * Creates a tuner around {@code clss} that will try the given ROS subspace sizes.
     *
     * @param clss             the wrapped induction algorithm type
     * @param ROSSubspaceSizes candidate subspace-size strings to evaluate
     */
    public FIREROSTune(ClusInductionAlgorithmType clss, String[] ROSSubspaceSizes) {
        this(clss);
        this.m_ROSSubspaceSizes = ROSSubspaceSizes;
    }

    /** Delegates creation of the induction algorithm to the wrapped algorithm type. */
    @Override
    public ClusInductionAlgorithm createInduce(ClusSchema schema, Settings sett, CMDLineArgs cargs) throws ClusException, IOException {
        return this.m_Class.createInduce(schema, sett, cargs);
    }

    /** Logs a description of this tuner and the heuristic in use. */
    @Override
    public void printInfo() {
        ClusLogger.info("Fitted Rule Ensemble with Random Output Selections (FIRE-ROS) (Tuning)");
        ClusLogger.info("Heuristic: " + this.getStatManager().getHeuristicName());
    }

    /** Logs the 1-based fold number, preceded by a separator for every fold after the first. */
    private void showFold(int i) {
        if (i != 0) {
            ClusLogger.fine(" ");
        }
        ClusLogger.fine(String.valueOf(i + 1));
    }

    /**
     * Builds the error list used to score tuning candidates.
     * It contains Accuracy when there are nominal targets, then RMSError when there are numeric targets.
     *
     * @param mgr statistics manager whose schema defines the targets
     * @return the tuning error list
     */
    public ClusErrorList createTuneError(ClusStatManager mgr) {
        ClusErrorList parent = new ClusErrorList();
        NumericAttrType[] num = mgr.getSchema().getNumericAttrUse(ClusAttrType.AttributeUseType.Target);
        NominalAttrType[] nom = mgr.getSchema().getNominalAttrUse(ClusAttrType.AttributeUseType.Target);
        if (nom.length != 0) {
            parent.addError(new Accuracy(parent, nom));
        }
        if (num.length != 0) {
            parent.addError(new RMSError(parent, num));
        }
        return parent;
    }

    /**
     * Creates a run on a clone of {@code data}, applying the selection {@code sel} when it is non-null.
     *
     * <p>When the selection changes the distribution, it updates the training set in place.
     * Otherwise the selected rows become the test set.
     */
    public final ClusRun partitionDataBasic(ClusData data, ClusSelection sel, ClusSummary summary, int idx) throws IOException, ClusException, InterruptedException {
        ClusRun cr = new ClusRun(data.cloneData(), summary);
        if (sel != null) {
            if (sel.changesDistribution()) {
                ((RowData) cr.getTrainingSet()).update(sel);
            } else {
                ClusData val = cr.getTrainingSet().select(sel);
                cr.setTestSet(((RowData) val).getIterator());
            }
        }
        cr.setIndex(idx);
        cr.copyTrainingData();
        return cr;
    }

    /**
     * Estimates the tuning error of the current settings.
     *
     * <p>If {@code pruneset} is non-null, the method trains once on {@code trset} and evaluates on {@code pruneset}.
     * Otherwise it cross-validates on {@code trset}, using the configured number of tune folds and a fixed seed (0).
     * Verbosity is set to 0 during the evaluation and is always restored, even if evaluation throws.
     *
     * @return the model error of the first tuning error measure (see {@link #createTuneError})
     */
    public double doParamXVal(RowData trset, RowData pruneset) throws Exception {
        int prevVerb = this.getSettings().getGeneral().enableVerbose(0);
        try {
            ClusStatManager mgr = this.getStatManager();
            ClusSummary summ = new ClusSummary();
            summ.setStatManager(this.getStatManager());
            summ.addModelInfo(1).setTestError(this.createTuneError(mgr));
            ClusRandom.initialize(this.getSettings());
            if (pruneset != null) {
                ClusRun cr = new ClusRun(trset.cloneData(), summ);
                ClusModel model = this.m_Class.induceSingleUnpruned(cr);
                cr.addModelInfo(1).setModel(model);
                cr.addModelInfo(1).setTestError(this.createTuneError(mgr));
                this.m_Clus.calcError(pruneset.getIterator(), 1, cr, null);
                summ.addSummary(cr);
            } else {
                // Fixed seed so that every candidate setting is scored on identical folds.
                Random random = new Random(0L);
                int nbfolds = Integer.parseInt(this.getSettings().getModel().getTuneFolds());
                XValRandomSelection sel = new XValRandomSelection(trset.getNbRows(), nbfolds, random);
                ClusModel dummy = null;
                for (int i = 0; i < nbfolds; ++i) {
                    this.showFold(i);
                    XValSelection msel = new XValSelection(sel, i);
                    ClusRun cr = this.partitionDataBasic(trset, msel, summ, i + 1);
                    ClusModelInfo defInfo = cr.addModelInfo(0);
                    // The default model is induced once, on the first fold, and reused for all folds.
                    if (dummy == null) {
                        dummy = ClusDecisionTree.induceDefault(cr);
                    }
                    defInfo.setModel(dummy);
                    ClusModel model = this.m_Class.induceSingleUnpruned(cr);
                    cr.addModelInfo(1).setModel(model);
                    cr.addModelInfo(1).setTestError(this.createTuneError(mgr));
                    this.m_Clus.calcError(cr.getTestIter(), 1, cr, null);
                    summ.addSummary(cr);
                }
                ClusLogger.fine();
            }
            ClusError err = summ.getModelInfo(1).getTestError().getFirstError();
            return err.getModelError();
        } finally {
            this.getSettings().getGeneral().enableVerbose(prevVerb);
        }
    }

    /**
     * Returns true when {@code candidate} beats {@code best} by more than {@link #IMPROVEMENT_EPSILON}.
     * A NaN candidate never counts as an improvement.
     */
    private static boolean isImprovement(double candidate, double best, boolean lowerIsBetter) {
        return lowerIsBetter
                ? candidate < best - IMPROVEMENT_EPSILON
                : candidate > best + IMPROVEMENT_EPSILON;
    }

    /**
     * Searches all ROS voting types and candidate subspace sizes, then stores the best pair in the settings.
     *
     * <p>During the search, the number of bags and the gradient-descent iteration limit are temporarily
     * lowered. Both are restored in a {@code finally} block, so a failed candidate cannot leak them.
     *
     * @throws ClusException if no candidate subspace sizes are configured, or if no candidate produced a
     *                       comparable (non-NaN) error
     * @throws Exception     propagated from {@link #doParamXVal(RowData, RowData)}
     */
    public void findBestROSParameters(RowData trset, RowData pruneset) throws Exception {
        if (this.m_ROSSubspaceSizes == null || this.m_ROSSubspaceSizes.length == 0) {
            throw new ClusException("FIRE-ROS tuning requires at least one candidate ROS subspace size");
        }
        boolean errorShouldBeLow = this.createTuneError(this.getStatManager()).getFirstError().shouldBeLow();
        double bestError = errorShouldBeLow ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
        int idxBestSubspace = -1;
        SettingsEnsemble.EnsembleROSVotingType bestVotingType = null;

        int originalMaxIter = this.getSettings().getRules().getOptGDMaxIter();
        int originalEnsembleSize = this.getSettings().getEnsemble().getNbBaggingSets().getInt();
        this.getSettings().getEnsemble().setNbBags(TUNING_NB_BAGS);
        this.getSettings().getRules().setOptGDMaxIter(TUNING_OPT_GD_MAX_ITER);
        try {
            for (SettingsEnsemble.EnsembleROSVotingType vt : SettingsEnsemble.EnsembleROSVotingType.values()) {
                for (int i = 0; i < this.m_ROSSubspaceSizes.length; ++i) {
                    this.getSettings().getEnsemble().setEnsembleROSVotingType(vt);
                    this.getSettings().getEnsemble().setNbRandomTargetAttrString(this.m_ROSSubspaceSizes[i]);
                    ClusLogger.fine("Try for ROS subspace size = " + this.m_ROSSubspaceSizes[i] + " with voting: " + vt);
                    double err = this.doParamXVal(trset, pruneset);
                    ClusLogger.fine("-> " + err);
                    if (isImprovement(err, bestError, errorShouldBeLow)) {
                        bestError = err;
                        idxBestSubspace = i;
                        bestVotingType = vt;
                        ClusLogger.fine(" *");
                    }
                }
            }
            // Without this check, the old code crashed with an NPE when no candidate ever improved (e.g. all-NaN errors).
            if (idxBestSubspace < 0) {
                throw new ClusException("FIRE-ROS tuning could not select a setting: no candidate produced a comparable error");
            }
            this.getSettings().getEnsemble().setNbRandomTargetAttrString(this.m_ROSSubspaceSizes[idxBestSubspace]);
            this.getSettings().getEnsemble().setEnsembleROSVotingType(bestVotingType);
            ClusLogger.fine("Best FIRE-ROS setting is: " + this.m_ROSSubspaceSizes[idxBestSubspace] + " with " + bestVotingType);
        } finally {
            this.getSettings().getEnsemble().setNbBags(originalEnsembleSize);
            this.getSettings().getRules().setOptGDMaxIter(originalMaxIter);
        }
    }

    /**
     * Tunes the ROS parameters on the training (and optional validation) set, then induces the final model
     * with the wrapped algorithm.
     *
     * <p>Failures are logged rather than propagated, so the run may be left without a final model.
     * If the thread was interrupted, its interrupt status is restored.
     */
    @Override
    public void induceAll(ClusRun cr) throws ClusException, IOException {
        try {
            RowData valid = (RowData) cr.getPruneSet();
            RowData train = (RowData) cr.getTrainingSet();
            this.findBestROSParameters(train, valid);
            ClusLogger.info();
            cr.combineTrainAndValidSets();
            ClusRandom.initialize(this.getSettings());
            this.m_Class.induceAll(cr);
        } catch (InterruptedException e) {
            // Preserve the interrupt so callers further up can still observe it.
            Thread.currentThread().interrupt();
            ClusLogger.severe(e.toString());
        } catch (Exception e) {
            ClusLogger.severe(e.toString());
        }
    }
}

