/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.algo.tdidt.tune;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;
import si.ijs.kt.clus.algo.ClusInductionAlgorithmType;
import si.ijs.kt.clus.algo.tdidt.ClusDecisionTree;
import si.ijs.kt.clus.algo.tdidt.ClusNode;
import si.ijs.kt.clus.data.ClusData;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.attweights.ClusAttributeWeights;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.MemoryTupleIterator;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.error.common.ClusError;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.ext.hierarchical.HierRemoveInsigClasses;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.ClusSummary;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.model.ClusModelInfo;
import si.ijs.kt.clus.pruning.PruneTree;
import si.ijs.kt.clus.pruning.SizeConstraintPruning;
import si.ijs.kt.clus.pruning.TreeErrorComputer;
import si.ijs.kt.clus.selection.XValDataSelection;
import si.ijs.kt.clus.selection.XValMainSelection;
import si.ijs.kt.clus.selection.XValRandomSelection;
import si.ijs.kt.clus.selection.XValSelection;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;
import si.ijs.kt.clus.util.jeans.io.MyFile;
import si.ijs.kt.clus.util.jeans.math.SingleStatList;

/**
 * Decision-tree learner that tunes the size constraint used for pruning before building the final model.
 *
 * <p>{@link #induceAll(ClusRun)} runs an internal cross-validation on the training set. For each fold
 * it induces an unpruned tree with the wrapped algorithm and evaluates copies pruned to several
 * candidate sizes. It then picks a size, stores it as the size-constraint setting, and calls the
 * wrapped algorithm's {@code induceAll} on the full run. The user's original setting is restored
 * afterwards.
 */
public class CDTuneSizeConstrPruning extends ClusDecisionTree {
    /** The wrapped induction algorithm whose size constraint is tuned. */
    protected ClusInductionAlgorithmType m_Class;
    protected ClusSchema m_Schema;
    protected ClusStatistic m_TotalStat;
    /** Whether the training schema has missing values; selects how fold errors are computed. */
    protected boolean m_HasMissing;
    /** Number of training rows; used as the example count of the pooled error. */
    protected int m_NbExamples;
    /** Size constraint configured by the user; {@code -1} is treated as "no upper bound" here. */
    protected int m_OrigSize;
    /** Two (relative) errors closer than this are treated as equal. */
    protected double m_RelErrAcc = 0.01;
    /** (size, error) points from the last tuning run, sorted by size; written by saveInformation. */
    protected ArrayList<SingleStatList> m_Graph;
    /** Size selected by the last tuning run. */
    protected int m_Optimal;
    /** Largest candidate size considered in the last tuning run. */
    protected int m_MaxSize;
    protected ClusAttributeWeights m_TargetWeights;
    /** When true, errors produced by computeTreeError are divided by {@link #m_RelativeScale}. */
    protected boolean m_Relative;
    protected double m_RelativeScale;

    public CDTuneSizeConstrPruning(ClusInductionAlgorithmType clss) {
        super(clss.getClus());
        this.m_Class = clss;
    }

    @Override
    public void printInfo() {
        ClusLogger.info("TDIDT (Tuning Size Constraint)");
        ClusLogger.info("Heuristic: " + this.getStatManager().getHeuristicName());
    }

    /** Prints the 1-based fold number as a progress indicator (space-separated after the first). */
    private final void showFold(int i) {
        if (i != 0) {
            System.out.print(" ");
        }
        System.out.print(String.valueOf(i + 1));
        System.out.flush();
    }

    /**
     * Enables or disables reporting errors relative to a reference value.
     *
     * @param enable whether errors are divided by {@code value}
     * @param value  the divisor used when enabled
     */
    public void setRelativeMeasure(boolean enable, double value) {
        this.m_Relative = enable;
        this.m_RelativeScale = value;
    }

    /** Applies the relative scaling configured by {@link #setRelativeMeasure}, if enabled. */
    private double scaleError(double error) {
        return this.m_Relative ? error / this.m_RelativeScale : error;
    }

    /**
     * Initializes the test-error bookkeeping on each fold's tree. Then routes every test tuple of
     * that fold through the tree with a shared {@link TreeErrorComputer}.
     */
    public void computeTestStatistics(ClusRun[] runs, int model, ClusError error) throws IOException, ClusException {
        TreeErrorComputer comp = new TreeErrorComputer();
        for (ClusRun run : runs) {
            ClusNode tree = (ClusNode)run.getModelInfo(model).getModel();
            TreeErrorComputer.initializeTestErrors(tree, error);
            MemoryTupleIterator test = (MemoryTupleIterator)run.getTestIter();
            test.init();
            for (DataTuple tuple = test.readTuple(); tuple != null; tuple = test.readTuple()) {
                tree.applyModelProcessor(tuple, comp);
            }
        }
    }

    /**
     * Adds the weighted prediction of {@code tree} for every test tuple of {@code run} to the first
     * test error of that run's model info.
     */
    public void computeErrorStandard(ClusNode tree, int model, ClusRun run) throws ClusException, IOException {
        ClusModelInfo mi = run.getModelInfo(model);
        ClusError err = mi.getTestError().getFirstError();
        MemoryTupleIterator test = (MemoryTupleIterator)run.getTestIter();
        test.init();
        for (DataTuple tuple = test.readTuple(); tuple != null; tuple = test.readTuple()) {
            ClusStatistic pred = tree.predictWeighted(tuple);
            err.addExample(tuple, pred);
        }
    }

    /**
     * Evaluates the fold trees restricted to {@code size}.
     *
     * <p>For {@code size == 1} a single-node copy of each fold tree is used. For a size smaller
     * than the fold tree, a copy is pruned with that fold's {@link SizeConstraintPruning}. In
     * hierarchical mode, insignificant classes are additionally removed using the fold's prune set.
     *
     * @return per-fold errors as the list values, and the pooled error over all folds as Y
     *         (both scaled when the relative measure is enabled)
     */
    public SingleStatList computeTreeError(ClusRun[] runs, SizeConstraintPruning[] pruners, int model, ClusSummary summ, int size) throws ClusException, IOException {
        ClusModelInfo summ_mi = summ.getModelInfo(model);
        ClusError summ_err = summ_mi.getTestError().getFirstError();
        summ_err.reset();
        SingleStatList res = new SingleStatList(runs.length);
        boolean hierarchical = this.getStatManager().getTargetMode() == ClusStatManager.Mode.HIERARCHICAL;
        for (int i = 0; i < runs.length; ++i) {
            ClusModelInfo mi = runs[i].getModelInfo(model);
            ClusNode tree = (ClusNode)mi.getModel();
            if (size == 1) {
                tree = tree.cloneNodeWithVisitor();
            } else if (size < tree.getModelSize()) {
                tree = tree.cloneTreeWithVisitors();
                pruners[i].pruneExecute(tree, size);
            } else if (hierarchical) {
                // The class pruning below modifies the tree it receives (its result is not used).
                // Copy first so the fold's stored unpruned tree stays intact for other sizes.
                tree = tree.cloneTreeWithVisitors();
            }
            if (hierarchical) {
                PruneTree pruner = new PruneTree();
                boolean bonf = this.getSettings().getHMLC().isUseBonferroni();
                HierRemoveInsigClasses hierpruner = new HierRemoveInsigClasses(runs[i].getPruneSet(), pruner, bonf, this.getStatManager().getHier());
                hierpruner.setSignificance(this.getSettings().getHMLC().getHierPruneInSig());
                hierpruner.prune(tree);
            }
            ClusError err = mi.getTestError().getFirstError();
            err.reset();
            if (this.m_HasMissing) {
                this.computeErrorStandard(tree, model, runs[i]);
            } else {
                // Uses the test statistics accumulated by computeTestStatistics.
                TreeErrorComputer.computeErrorSimple(tree, err);
            }
            summ_err.add(err);
            MemoryTupleIterator test = (MemoryTupleIterator)runs[i].getTestIter();
            mi.getTestError().setNbExamples(test.getNbExamples());
            res.addFloat(this.scaleError(err.getModelError()));
        }
        summ_mi.getTestError().setNbExamples(this.m_NbExamples);
        res.setY(this.scaleError(summ_err.getModelError()));
        return res;
    }

    /**
     * Inserts the evaluation for {@code size} into {@code points}, keeping the list sorted by size.
     *
     * @return the new point, or {@code null} if a point for {@code size} already exists (nothing
     *         is computed in that case)
     */
    public SingleStatList addPoint(ArrayList<SingleStatList> points, int size, ClusRun[] runs, SizeConstraintPruning[] pruners, int model, ClusSummary summ) throws ClusException, IOException {
        int pos = 0;
        while (pos < points.size() && points.get(pos).getX() < (double)size) {
            ++pos;
        }
        if (pos < points.size() && points.get(pos).getX() == (double)size) {
            return null;
        }
        SingleStatList point = this.computeTreeError(runs, pruners, model, summ, size);
        point.setX(size);
        points.add(pos, point);
        return point;
    }

    /** Returns the absolute difference between the largest and smallest Y value in {@code graph}. */
    public double getRange(ArrayList<SingleStatList> graph) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (SingleStatList elem : graph) {
            if (elem.getY() < min) {
                min = elem.getY();
            }
            if (elem.getY() > max) {
                max = elem.getY();
            }
        }
        return Math.abs(max - min);
    }

    /**
     * Adds intermediate sizes between neighbouring points whose errors differ by more than
     * {@link #m_RelErrAcc}. At most one point is added per pass. Refinement stops when a pass does
     * not change the number of points.
     */
    public void refineGraph(ArrayList<SingleStatList> graph, ClusRun[] runs, SizeConstraintPruning[] pruners, int model, ClusSummary summ) throws ClusException, IOException {
        int prevsize = -1;
        while (true) {
            // NOTE(review): the bound skips the last pair (the two largest sizes); confirm intended.
            for (int i = 0; i < graph.size() - 2; ++i) {
                SingleStatList e1 = graph.get(i);
                SingleStatList e2 = graph.get(i + 1);
                if (!(Math.abs(e1.getY() - e2.getY()) > this.m_RelErrAcc)) {
                    continue;
                }
                int s1 = (int)e1.getX();
                int s2 = (int)e2.getX();
                // Largest odd size not exceeding the integer midpoint of s1 and s2.
                int smean = 2 * (((s1 + s2) / 2 - 1) / 2) + 1;
                if (smean == s1 || smean == s2) {
                    continue;
                }
                if (this.m_OrigSize != -1 && smean >= this.m_OrigSize) {
                    continue;
                }
                this.addPoint(graph, smean, runs, pruners, model, summ);
                System.out.print("#");
                System.out.flush();
                break;
            }
            // A pass that found no candidate, or whose size already existed, leaves the size unchanged.
            if (graph.size() == prevsize) {
                return;
            }
            prevsize = graph.size();
        }
    }

    /**
     * Picks the size to use from {@code graph}, which is expected to be sorted by size (as
     * addPoint keeps it).
     *
     * <p>First the best Y is located (minimum if {@code shouldBeLow}, else maximum). Among the
     * smaller sizes that are at least 3, the smallest one whose Y is within {@link #m_RelErrAcc}
     * of the best is preferred.
     *
     * @return the chosen size, or 1 if no point has a comparable Y value (e.g. empty graph)
     */
    public int findOptimalSize(ArrayList<SingleStatList> graph, boolean shouldBeLow) {
        double best_value = shouldBeLow ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
        int best_index = -1;
        for (int i = 0; i < graph.size(); ++i) {
            double y = graph.get(i).getY();
            boolean better = shouldBeLow ? y < best_value : y > best_value;
            if (better) {
                best_value = y;
                best_index = i;
            }
        }
        if (best_index == -1) {
            return 1;
        }
        SingleStatList best_elem = graph.get(best_index);
        System.out.print("[" + best_elem.getX() + "," + best_elem.getY() + "]");
        SingleStatList result = best_elem;
        // Walk towards smaller sizes; the last match found is the smallest acceptable size.
        for (int pos = best_index - 1; pos >= 0; --pos) {
            SingleStatList prev_elem = graph.get(pos);
            if (prev_elem.getX() >= 3.0 && Math.abs(prev_elem.getY() - best_elem.getY()) < this.m_RelErrAcc) {
                result = prev_elem;
                System.out.print(" < " + prev_elem.getX());
            }
        }
        return (int)result.getX();
    }

    /**
     * Builds the fold assignment for tuning.
     *
     * <p>If the tune-folds setting starts with a digit, it is parsed as a fold count and used for a
     * random split with fixed seed 0. Otherwise it is treated as the name of a folds file.
     *
     * @throws ClusException if the setting starts with a digit but is not a valid integer
     */
    public final XValMainSelection getXValSelection(Settings sett, int nbrows) throws IOException, ClusException {
        String value = sett.getModel().getTuneFolds();
        if (value.length() > 0 && Character.isDigit(value.charAt(0))) {
            try {
                int nbfolds = Integer.parseInt(value);
                Random random = new Random(0L);
                return new XValRandomSelection(nbrows, nbfolds, random);
            }
            catch (NumberFormatException e) {
                throw new ClusException("Illegal number of folds: " + value);
            }
        }
        return XValDataSelection.readFoldsFile(value, nbrows);
    }

    /**
     * Selects a size constraint by cross-validation on {@code trset} and stores it in the settings.
     * The selection is also recorded via {@link #setFinalResult}.
     *
     * <p>Verbosity is set to 0 while tuning and is always restored afterwards, including on the
     * early {@code maxsize == 1} exit and on exceptions.
     */
    public void findBestSize(ClusData trset) throws Exception {
        int prevVerb = this.getSettings().getGeneral().enableVerbose(0);
        try {
            ClusStatManager mgr = this.getStatManager();
            ClusSummary summ = new ClusSummary();
            ClusErrorList errorpar = mgr.createDefaultError();
            errorpar.setWeights(this.m_TargetWeights);
            summ.setTestError(errorpar);
            int model = 1;
            XValMainSelection sel = this.getXValSelection(this.getSettings(), trset.getNbRows());
            int nbfolds = sel.getNbFolds();

            // Induce one unpruned tree per fold.
            ClusRun[] runs = new ClusRun[nbfolds];
            for (int i = 0; i < nbfolds; ++i) {
                this.showFold(i);
                XValSelection msel = new XValSelection(sel, i);
                ClusRun cr = this.m_Clus.partitionDataBasic(trset, msel, summ, i + 1);
                ClusModel unpruned = this.m_Class.induceSingleUnpruned(cr);
                cr.getModelInfo(model).setModel(unpruned);
                runs[i] = cr;
            }

            // Prepare one size-constraint pruner per fold (capped by the user's size, if any).
            int maxsize = 0;
            SizeConstraintPruning[] pruners = new SizeConstraintPruning[nbfolds];
            for (int i = 0; i < nbfolds; ++i) {
                ClusNode foldTree = (ClusNode)runs[i].getModelInfo(model).getModel();
                int foldSize = foldTree.getModelSize();
                if (this.m_OrigSize != -1 && foldSize > this.m_OrigSize) {
                    foldSize = this.m_OrigSize;
                }
                maxsize = Math.max(maxsize, foldSize);
                SizeConstraintPruning pruner = new SizeConstraintPruning(foldSize, mgr.getClusteringWeights());
                pruner.pruneInitialize(foldTree, foldSize);
                pruners[i] = pruner;
            }
            if (maxsize == 1) {
                ClusLogger.info("Optimal size (maxsize = 1) = 1");
                // Record the trivial outcome so saveInformation() reflects this run.
                this.setFinalResult(new ArrayList<SingleStatList>(), 1, maxsize);
                this.m_Class.getSettings().getConstraints().setSizeConstraintPruning(1);
                return;
            }

            ClusError error = summ.getModelInfo(model).getTestError().getFirstError();
            if (!this.m_HasMissing) {
                // Accumulate test statistics on each fold tree once. computeTreeError can then
                // score pruned copies without re-reading the test tuples.
                this.computeTestStatistics(runs, model, error);
            }

            // The single-node tree's error is the reference for all relative errors below.
            ArrayList<SingleStatList> graph = new ArrayList<SingleStatList>();
            this.setRelativeMeasure(false, 0.0);
            SingleStatList point = this.computeTreeError(runs, pruners, model, summ, 1);
            this.setRelativeMeasure(true, point.getY());
            System.out.print(" ");
            System.out.print("<" + point.getY() + ">");
            this.addPoint(graph, 1, runs, pruners, model, summ);
            this.addPoint(graph, maxsize, runs, pruners, model, summ);
            boolean shouldBeLow = error.shouldBeLow();

            // Coarse scan at sizes 2^n + 1 (3, 5, 9, 17, ...). Stop at maxsize, at an existing size,
            // or once there are more than 5 points and the new one is 10% worse than size 1.
            for (int n = 1; ; ++n) {
                int size = (int)(Math.pow(2.0, n) + 1.0);
                if (size > maxsize || (this.m_OrigSize != -1 && size > this.m_OrigSize)) {
                    break;
                }
                SingleStatList newPoint = this.addPoint(graph, size, runs, pruners, model, summ);
                if (newPoint == null) {
                    break;
                }
                boolean clearlyWorse = shouldBeLow ? newPoint.getY() > 1.1 : newPoint.getY() < 0.9;
                if (graph.size() > 5 && clearlyWorse) {
                    break;
                }
                System.out.print("*");
                System.out.flush();
            }

            this.refineGraph(graph, runs, pruners, model, summ);
            int optimalSize = this.findOptimalSize(graph, shouldBeLow);
            ClusLogger.info(" Best = " + optimalSize);
            this.setFinalResult(graph, optimalSize, maxsize);
            this.getSettings().getConstraints().setSizeConstraintPruning(optimalSize);
        } finally {
            this.getSettings().getGeneral().enableVerbose(prevVerb);
        }
    }

    /**
     * Writes the last tuning result to {@code fname + ".dat"}. The first line holds the optimal and
     * maximum sizes; each following line holds one (size, error) point.
     */
    @Override
    public void saveInformation(String fname) {
        ClusLogger.info("Saving: " + fname + ".dat");
        MyFile file = new MyFile(fname + ".dat");
        try {
            file.log("" + this.m_Optimal + "\t" + this.m_MaxSize);
            if (this.m_Graph != null) {
                for (SingleStatList elem : this.m_Graph) {
                    file.log("" + elem.getX() + "\t" + elem.getY());
                }
            }
        } finally {
            file.close();
        }
    }

    /** Records the outcome of a tuning run for later use by {@link #saveInformation}. */
    public void setFinalResult(ArrayList<SingleStatList> graph, int optimal, int maxsize) {
        this.m_Graph = graph;
        this.m_Optimal = optimal;
        this.m_MaxSize = maxsize;
    }

    /** Not supported by this learner: logs an error and returns {@code null}. */
    @Override
    public ClusModel induceSingle(ClusRun cr) {
        ClusLogger.info(">>> Error: induceSingle/1 not implemented");
        return null;
    }

    /** Computes the clustering statistic of the wrapped algorithm over all rows of {@code data}. */
    public ClusStatistic createTotalStat(RowData data) throws ClusException {
        ClusStatistic stat = this.m_Class.getStatManager().createClusteringStat();
        data.calcTotalStatBitVector(stat);
        return stat;
    }

    /**
     * Tunes the size constraint on the training set of {@code cr}, then runs the wrapped algorithm
     * with it. The user's size-constraint setting is restored afterwards, even if tuning or
     * induction fails. {@link ClusException} and {@link IOException} are reported on stderr rather
     * than propagated (existing behavior).
     */
    @Override
    public void induceAll(ClusRun cr) throws Exception {
        try {
            long start_time = System.currentTimeMillis();
            this.m_OrigSize = this.getSettings().getConstraints().getSizeConstraintPruning(0);
            if (this.getSettings().getConstraints().getSizeConstraintPruningNumber() > 1) {
                throw new ClusException("Only one value is allowed for MaxSize if -tunesize is given");
            }
            try {
                RowData train = (RowData)cr.getTrainingSet();
                this.m_Schema = train.getSchema();
                this.m_HasMissing = this.m_Schema.hasMissing();
                this.m_TotalStat = this.createTotalStat(train);
                this.m_NbExamples = train.getNbRows();
                ClusLogger.info("Has missing values: " + this.m_HasMissing);
                this.m_TargetWeights = this.m_Class.getStatManager().getClusteringWeights();
                this.findBestSize(train);
                ClusLogger.info();
                this.m_Class.induceAll(cr);
            } finally {
                // findBestSize may have replaced the setting with the tuned size; put the user's back.
                this.getSettings().getConstraints().setSizeConstraintPruning(this.m_OrigSize);
            }
            long time = System.currentTimeMillis() - start_time;
            if (this.getSettings().getGeneral().getVerbose() > 0) {
                ClusLogger.info("Time: " + (double)time / 1000.0 + " sec");
            }
            cr.setInductionTime(time);
        }
        catch (ClusException e) {
            System.err.println("Error: " + e);
        }
        catch (IOException e) {
            System.err.println("IO Error: " + e);
        }
    }
}

