/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.beamsearch;

import java.io.IOException;
import java.util.ArrayList;
import si.ijs.kt.clus.Clus;
import si.ijs.kt.clus.algo.ClusInductionAlgorithm;
import si.ijs.kt.clus.algo.split.NominalSplit;
import si.ijs.kt.clus.algo.tdidt.ClusDecisionTree;
import si.ijs.kt.clus.algo.tdidt.ClusNode;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.ext.beamsearch.ClusBeamModel;
import si.ijs.kt.clus.ext.beamsearch.ClusBeamSearch;
import si.ijs.kt.clus.ext.beamsearch.ClusBeamSimilarityOutput;
import si.ijs.kt.clus.ext.ensemble.ClusForest;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.model.ClusModelInfo;
import si.ijs.kt.clus.model.io.ClusModelCollectionIO;
import si.ijs.kt.clus.pruning.PruneTree;
import si.ijs.kt.clus.util.exception.ClusException;

/**
 * Induction algorithm that hands tree construction to a {@link ClusBeamSearch} and
 * registers the trees left in the final beam as models on a {@link ClusRun}.
 */
public class ClusBeamInduce
extends ClusInductionAlgorithm {
    /** Not read or written anywhere in this class. */
    protected NominalSplit m_Split;

    /** Beam search strategy that induction, heuristic setup and model writing are delegated to. */
    protected ClusBeamSearch m_Search;

    /**
     * Creates the induction algorithm around an existing beam search.
     *
     * @param schema data schema forwarded to the superclass
     * @param sett   settings forwarded to the superclass
     * @param search beam search used by every induction entry point of this class
     */
    public ClusBeamInduce(ClusSchema schema, Settings sett, ClusBeamSearch search) throws ClusException, IOException {
        super(schema, sett);
        m_Search = search;
    }

    /** Forwards heuristic initialisation to the beam search. */
    @Override
    public void initializeHeuristic() {
        m_Search.initializeHeuristic();
    }

    /** Always reports that this algorithm writes its own models (see {@link #writeModel}). */
    @Override
    public boolean isModelWriter() {
        return true;
    }

    /** Forwards model serialisation to the beam search. */
    @Override
    public void writeModel(ClusModelCollectionIO strm) throws IOException, ClusException {
        m_Search.writeModel(strm);
    }

    /**
     * Runs the beam search and returns the tree it produces after refreshing it via
     * {@code updateTree()}.
     */
    @Override
    public ClusModel induceSingleUnpruned(ClusRun cr) throws Exception {
        ClusNode bestTree = m_Search.beamSearch(cr);
        bestTree.updateTree();
        return bestTree;
    }

    /**
     * Runs the beam search and registers its results on {@code cr}.
     *
     * <p>Model slot 0 gets the default model from {@code ClusDecisionTree.induceDefault}.
     * Slots 1..n get the beam trees, named "Beam 1".."Beam n". The last entry of the beam
     * list goes to slot 1 and the first entry goes to slot n. When beam-to-forest is
     * enabled, all beam trees are also collected into a {@link ClusForest} in slot n + 1,
     * named "BeamToForest".
     *
     * <p>NOTE(review): when {@link #sortModels} runs it orders the list by ascending
     * {@code calcModelError}. The reverse registration below therefore gives "Beam 1" the
     * largest training error. Confirm this is intended.
     */
    @Override
    public void induceAll(ClusRun cr) throws Exception {
        m_Search.beamSearch(cr);

        ClusModelInfo defaultInfo = cr.addModelInfo(0);
        defaultInfo.setModel(ClusDecisionTree.induceDefault(cr));
        defaultInfo.setName("Default");

        ArrayList beamModels = m_Search.getBeam().toArray();
        updateAllPredictions(beamModels);

        // NOTE(review): presumably a max size <= -1 means no size bound was applied during
        // the search, so trees are post-pruned instead. Confirm. postPruneBeamModels
        // calls updateAllPredictions again before pruning.
        if (getSettings().getBeamSearch().getBeamTreeMaxSize() <= -1) {
            postPruneBeamModels(cr, beamModels);
        }
        if (getSettings().getBeamSearch().getBeamSortOnTrainParameter()) {
            sortModels(cr, beamModels);
        }

        new ClusBeamSimilarityOutput(getSettings()).appendToFile(beamModels, cr);

        boolean buildForest = cr.getStatManager().getSettings().getBeamSearch().isBeamToForest();
        ClusForest forest = new ClusForest(getStatManager(), null);
        int beamSize = beamModels.size();
        for (int rank = 1; rank <= beamSize; ++rank) {
            // Walk the beam list from its last element to its first.
            ClusBeamModel beamModel = (ClusBeamModel) beamModels.get(beamSize - rank);
            ClusModelInfo info = cr.addModelInfo(rank);
            ClusNode beamTree = (ClusNode) beamModel.getModel();
            info.setModel(beamTree);
            info.setName("Beam " + rank);
            info.clearAll();
            if (buildForest) {
                forest.addModelToForest(beamTree);
            }
        }

        if (buildForest) {
            ClusModelInfo forestInfo = cr.addModelInfo(beamSize + 1);
            forestInfo.setModel(forest);
            forestInfo.setName("BeamToForest");
        }
    }

    /**
     * Refreshes every beam tree and then prunes each one in place on the run's training set.
     * A new pruner is requested from the stat manager for every tree.
     *
     * @param cr  run supplying the training data
     * @param arr beam entries; each element is expected to be a {@link ClusBeamModel} wrapping a {@link ClusNode}
     */
    public void postPruneBeamModels(ClusRun cr, ArrayList arr) throws ClusException, InterruptedException {
        updateAllPredictions(arr);
        int total = arr.size();
        int idx = 0;
        while (idx < total) {
            PruneTree treePruner = getStatManager().getTreePruner(null);
            treePruner.setTrainingData((RowData) cr.getTrainingSet());
            ClusBeamModel entry = (ClusBeamModel) arr.get(idx);
            treePruner.prune((ClusNode) entry.getModel());
            ++idx;
        }
    }

    /**
     * Calls {@code updateTree()} on the {@link ClusNode} held by every beam entry.
     *
     * @param arr beam entries whose models are {@link ClusNode} trees
     */
    public void updateAllPredictions(ArrayList<ClusBeamModel> arr) throws ClusException {
        for (ClusBeamModel entry : arr) {
            ((ClusNode) entry.getModel()).updateTree();
        }
    }

    /**
     * Reorders {@code arr} in place. Entries are ranked by ascending value of
     * {@code Clus.calcModelError} on the training set. When those values are equal, the
     * entry with the larger {@code getValue()} comes first.
     *
     * <p>The ranking uses a pairwise exchange sort: for every pair (left, right) with
     * left &lt; right, the two entries are swapped whenever they are out of order.
     *
     * @param cr  run supplying the stat manager and training set
     * @param arr beam entries; each element is expected to be a {@link ClusBeamModel}
     */
    public void sortModels(ClusRun cr, ArrayList arr) throws ClusException, IOException, InterruptedException {
        final int count = arr.size();
        final ClusBeamModel[] ranked = new ClusBeamModel[count];
        final double[] trainErr = new double[count];
        final double[] heuristic = new double[count];

        int pos = 0;
        while (pos < count) {
            ClusBeamModel candidate = (ClusBeamModel) arr.get(pos);
            ranked[pos] = candidate;
            trainErr[pos] = Clus.calcModelError(cr.getStatManager(), (RowData) cr.getTrainingSet(), candidate.getModel());
            heuristic[pos] = candidate.getValue();
            ++pos;
        }

        for (int left = 0; left < count - 1; ++left) {
            for (int right = left + 1; right < count; ++right) {
                // Both conditions read the values before any swap for this pair.
                boolean higherError = trainErr[left] > trainErr[right];
                boolean tieButWeakerHeuristic =
                        trainErr[left] == trainErr[right] && heuristic[left] < heuristic[right];
                if (higherError || tieButWeakerHeuristic) {
                    swapEntries(ranked, trainErr, heuristic, left, right);
                }
            }
        }

        arr.clear();
        for (ClusBeamModel model : ranked) {
            arr.add(model);
        }
    }

    /** Swaps positions {@code a} and {@code b} in all three parallel arrays. */
    private static void swapEntries(ClusBeamModel[] models, double[] err, double[] heur, int a, int b) {
        ClusBeamModel heldModel = models[a];
        models[a] = models[b];
        models[b] = heldModel;

        double heldErr = err[a];
        err[a] = err[b];
        err[b] = heldErr;

        double heldHeur = heur[a];
        heur[a] = heur[b];
        heur[b] = heldHeur;
    }
}

