/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.beamsearch;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import si.ijs.kt.clus.Clus;
import si.ijs.kt.clus.algo.ClusInductionAlgorithm;
import si.ijs.kt.clus.algo.ClusInductionAlgorithmType;
import si.ijs.kt.clus.algo.split.CurrentBestTestAndHeuristic;
import si.ijs.kt.clus.algo.split.FindBestTest;
import si.ijs.kt.clus.algo.tdidt.ClusNode;
import si.ijs.kt.clus.algo.tdidt.ConstraintDFInduce;
import si.ijs.kt.clus.algo.tdidt.processor.BasicExampleCollector;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.rows.DataTuple;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.primitive.NominalAttrType;
import si.ijs.kt.clus.data.type.primitive.NumericAttrType;
import si.ijs.kt.clus.ext.beamsearch.ClusBeam;
import si.ijs.kt.clus.ext.beamsearch.ClusBeamHeuristic;
import si.ijs.kt.clus.ext.beamsearch.ClusBeamInduce;
import si.ijs.kt.clus.ext.beamsearch.ClusBeamModel;
import si.ijs.kt.clus.ext.beamsearch.ClusBeamModelDistance;
import si.ijs.kt.clus.ext.beamsearch.ClusBeamSyntacticConstraint;
import si.ijs.kt.clus.ext.beamsearch.ClusBeamTreeElem;
import si.ijs.kt.clus.ext.constraint.ClusConstraintFile;
import si.ijs.kt.clus.heuristic.ClusHeuristic;
import si.ijs.kt.clus.main.ClusRun;
import si.ijs.kt.clus.main.ClusStatManager;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.main.settings.section.SettingsBeamSearch;
import si.ijs.kt.clus.main.settings.section.SettingsTree;
import si.ijs.kt.clus.model.ClusModel;
import si.ijs.kt.clus.model.ClusModelInfo;
import si.ijs.kt.clus.model.io.ClusModelCollectionIO;
import si.ijs.kt.clus.model.test.NodeTest;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;
import si.ijs.kt.clus.util.jeans.io.MyFile;
import si.ijs.kt.clus.util.jeans.math.SingleStat;
import si.ijs.kt.clus.util.jeans.util.MyArray;
import si.ijs.kt.clus.util.jeans.util.StringUtils;
import si.ijs.kt.clus.util.jeans.util.cmdline.CMDLineArgs;

/**
 * Beam-search driver for tree induction.
 *
 * <p>The search keeps a {@link ClusBeam} of candidate trees. In each step every
 * model that is neither refined nor finished is expanded by trying, for every
 * leaf and every descriptive attribute, the best test found by
 * {@link FindBestTest}; candidates are then offered to the beam. The search
 * stops after the first step in which the beam did not change.
 *
 * <p>Instances carry mutable search state (current beam, statistics, current
 * model index) and are reset at the start of every {@link #beamSearch(ClusRun)}.
 */
public class ClusBeamSearch extends ClusInductionAlgorithmType {
    /** Heuristic identifier; not referenced inside this class. */
    public static final int HEURISTIC_ERROR = 0;
    /** Heuristic identifier; not referenced inside this class. */
    public static final int HEURISTIC_SS = 1;

    /** Model processor used to route training tuples to the leaves of a tree. */
    protected BasicExampleCollector m_Coll = new BasicExampleCollector();
    /** Depth-first inducer used for statistics, stop criteria and test search. */
    protected ConstraintDFInduce m_Induce;
    /** Induction algorithm created by {@link #createInduce}. */
    protected ClusBeamInduce m_BeamInduce;
    /** Set to {@code true} as soon as a refinement step modifies the beam. */
    protected boolean m_BeamChanged;
    /** Index (within the beam snapshot) of the model currently being refined. */
    protected int m_CurrentModel;
    /** Maximal tree size; negative values disable the size check. */
    protected int m_MaxTreeSize;
    /** Total weight of the root clustering statistic. */
    protected double m_TotalWeight;
    /** One entry per search level; each entry is a list as built by {@link #estimateBeamStats}. */
    protected ArrayList m_BeamStats;
    /** Beam stored at the end of the last search. */
    protected ClusBeam m_Beam;
    protected boolean m_BeamPostPruning;
    protected ClusBeamHeuristic m_Heuristic;
    /** Optional attribute heuristic; {@code null} when the default is configured. */
    protected ClusHeuristic m_AttrHeuristic;
    /** When set, progress markers are printed to standard output during refinement. */
    protected boolean m_Verbose;
    protected ClusBeamModelDistance m_BeamModelDistance;
    /** Only created when {@code SettingsBeamSearch.BEAM_SYNT_DIST_CONSTR} is set. */
    protected ClusBeamSyntacticConstraint m_BeamSyntConstr;

    public ClusBeamSearch(Clus clus) throws ClusException, IOException {
        super(clus);
    }

    /** Clears all per-search state so that a new beam search can start. */
    public void reset() {
        m_Beam = null;
        m_BeamChanged = false;
        m_CurrentModel = -1;
        m_TotalWeight = 0.0;
        m_BeamStats = new ArrayList();
    }

    /**
     * Creates the beam induction algorithm and flags its statistics manager
     * as running in beam-search mode.
     *
     * @return the newly created {@link ClusBeamInduce}
     */
    @Override
    public ClusInductionAlgorithm createInduce(ClusSchema schema, Settings sett, CMDLineArgs cargs) throws ClusException, IOException {
        schema.addIndices(0);
        m_BeamInduce = new ClusBeamInduce(schema, sett, this);
        m_BeamInduce.getStatManager().setBeamSearch(true);
        return m_BeamInduce;
    }

    /**
     * Reads the beam-search settings (maximal tree size, post-pruning flag) and
     * installs the beam heuristic; a separate attribute heuristic is created
     * only when the configured one is not {@code Default}.
     */
    public void initializeHeuristic() {
        ClusStatManager statManager = m_BeamInduce.getStatManager();
        Settings settings = statManager.getSettings();
        SettingsBeamSearch beamSettings = settings.getBeamSearch();
        m_MaxTreeSize = beamSettings.getBeamTreeMaxSize();
        ClusLogger.info("BeamSearch : the maximal size of the trees is " + m_MaxTreeSize);
        m_BeamPostPruning = settings.getBeamSearch().isBeamPostPrune();
        m_Heuristic = (ClusBeamHeuristic)statManager.getHeuristic();
        SettingsTree.Heuristic attrHeuristic = settings.getBeamSearch().getBeamAttrHeuristic();
        if (!attrHeuristic.equals(SettingsTree.Heuristic.Default)) {
            m_AttrHeuristic = statManager.createHeuristic(attrHeuristic);
            m_Heuristic.setAttrHeuristic(m_AttrHeuristic);
        }
    }

    public final boolean isBeamPostPrune() {
        return m_BeamPostPruning;
    }

    /** Delegates to the beam heuristic's {@code computeLeafAdd}. */
    public double computeLeafAdd(ClusNode leaf) {
        return m_Heuristic.computeLeafAdd(leaf);
    }

    /** Delegates to the beam heuristic's {@code estimateBeamMeasure}. */
    public double estimateBeamMeasure(ClusNode tree) {
        return m_Heuristic.estimateBeamMeasure(tree);
    }

    /** Installs the attribute heuristic on {@code sel} if one was configured. */
    public void initSelector(CurrentBestTestAndHeuristic sel) {
        if (hasAttrHeuristic()) {
            sel.setHeuristic(m_AttrHeuristic);
        }
    }

    public final boolean hasAttrHeuristic() {
        return m_AttrHeuristic != null;
    }

    /**
     * Builds the initial beam containing a single root model.
     *
     * <p>The root is either a fresh node or, when a constraint file other than
     * {@code "None"} is configured, a clone of the tree from that file with its
     * statistics and tests filled in from the training data.
     *
     * @param run the run supplying the training set
     * @return the beam holding the root model
     */
    public ClusBeam initializeBeam(ClusRun run) throws Exception {
        ClusStatManager statManager = m_BeamInduce.getStatManager();
        Settings settings = statManager.getSettings();
        ClusBeam beam = new ClusBeam(settings.getBeamSearch().getBeamWidth(), settings.getBeamSearch().getBeamRemoveEqualHeur());
        RowData train = (RowData)run.getTrainingSet();
        train.addIndices();
        ClusStatistic rootStat = m_Induce.createTotalClusteringStat(train);
        rootStat.calcMean();
        m_Induce.initSelectorAndSplit(rootStat);
        initSelector(m_Induce.getBestTest());
        ClusLogger.info("Root statistic: " + rootStat);

        String constraintFile = settings.getConstraints().getConstraintFile();
        boolean unconstrained = StringUtils.unCaseCompare(constraintFile, "None");
        ClusNode root;
        if (unconstrained) {
            root = new ClusNode();
        } else {
            root = ClusConstraintFile.getInstance().getClone(constraintFile);
        }
        root.setClusteringStat(rootStat);
        if (!unconstrained) {
            // The cloned constraint tree still needs statistics/tests from the data.
            m_Induce.fillInStatsAndTests(root, train);
        }

        root.initTargetStat(getStatManager(), train);
        root.getTargetStat().calcMean();
        root.getClusteringStat().setBeam(beam);
        root.getTargetStat().setBeam(beam);
        setTotalWeight(root.getClusteringStat().getTotalWeight());
        beam.addModel(new ClusBeamModel(estimateBeamMeasure(root), root));
        m_BeamModelDistance = new ClusBeamModelDistance(run, beam);
        if (SettingsBeamSearch.BEAM_SYNT_DIST_CONSTR) {
            m_BeamSyntConstr = new ClusBeamSyntacticConstraint(run);
        }
        return beam;
    }

    /**
     * Tries to refine one leaf of {@code root}'s tree with the best test on each
     * attribute and offers every resulting tree to the beam.
     *
     * @param leaf  leaf to refine; its visitor must hold the examples collected
     *              for it (a {@link MyArray})
     * @param root  beam model whose tree contains {@code leaf}
     * @param beam  beam receiving the candidates
     * @param attrs descriptive attributes to try; anything that is not a
     *              {@link NominalAttrType} is cast to {@link NumericAttrType}
     */
    public void refineGivenLeaf(ClusNode leaf, ClusBeamModel root, ClusBeam beam, ClusAttrType[] attrs) throws Exception {
        MyArray examples = (MyArray)leaf.getVisitor();
        RowData data = new RowData(examples.getObjects(), examples.size());
        if (m_Induce.initSelectorAndStopCrit(leaf, data)) {
            // Stop criterion reached: this leaf is not refined.
            return;
        }
        CurrentBestTestAndHeuristic sel = m_Induce.getBestTest();
        FindBestTest find = m_Induce.getFindBestTest();
        double baseValue = root.getValue();
        double leafAdd = m_Heuristic.computeLeafAdd(leaf);
        m_Heuristic.setTreeOffset(baseValue - leafAdd);

        for (ClusAttrType attr : attrs) {
            sel.resetBestTest();
            // Seed the selector with the beam's current minimum value.
            double beamMinValue = beam.getMinValue();
            sel.setBestHeur(beamMinValue);
            if (attr instanceof NominalAttrType) {
                find.findNominal((NominalAttrType)attr, data, null);
            } else {
                find.findNumeric((NumericAttrType)attr, data, null);
            }
            if (!sel.hasBestTest()) {
                continue;
            }

            ClusNode refinedLeaf = (ClusNode)leaf.cloneNode();
            refinedLeaf.testToNode(sel);
            if (getSettings().getGeneral().getVerbose() > 0) {
                ClusLogger.info("Test: " + refinedLeaf.getTestString() + " -> " + sel.m_BestHeur + " (" + refinedLeaf.getTest().getPosFreq() + ")");
            }
            attachChildren(refinedLeaf, data);

            ClusNode rootTree = (ClusNode)root.getModel();
            ClusNode refinedTree = (ClusNode)rootTree.cloneTree(leaf, refinedLeaf);
            double newHeur = sanityCheck(sel.m_BestHeur, refinedTree);
            ClusBeamModel candidate = new ClusBeamModel(newHeur, refinedTree);
            candidate.setParentModelIndex(getCurrentModel());

            if (SettingsBeamSearch.BEAM_SIMILARITY != 0.0 && !SettingsBeamSearch.BEAM_SYNT_DIST_CONSTR) {
                // Similarity-aware insertion: the beam decides, based on model
                // distances, whether the candidate replaces an existing member.
                candidate.setModelPredictions(m_BeamModelDistance.getPredictions(candidate.getModel()));
                if (beam.modelAlreadyIn(candidate)) {
                    continue;
                }
                m_BeamModelDistance.addDistToCandOpt(beam, candidate);
                if (beam.removeMinUpdatedOpt(candidate, m_BeamModelDistance) == 1) {
                    setBeamChanged(true);
                }
                continue;
            }

            if (SettingsBeamSearch.BEAM_SYNT_DIST_CONSTR) {
                ClusLogger.info("OLD HEUR = " + newHeur);
                candidate.setModelPredictions(m_BeamModelDistance.getPredictions(candidate.getModel()));
                // NOTE(review): only the local value is penalised; the candidate
                // was constructed with the unpenalised heuristic - confirm intent.
                newHeur -= SettingsBeamSearch.BEAM_SIMILARITY * m_BeamModelDistance.getDistToConstraint(candidate, m_BeamSyntConstr);
                ClusLogger.info("UPDT HEUR = " + newHeur);
            }
            if (newHeur > beamMinValue) {
                beam.addModel(candidate);
                setBeamChanged(true);
            }
        }
    }

    /** Creates one child per outcome of {@code refinedLeaf}'s test and initialises its statistics. */
    private void attachChildren(ClusNode refinedLeaf, RowData data) throws Exception {
        ClusStatManager statManager = m_Induce.getStatManager();
        int arity = refinedLeaf.updateArity();
        NodeTest test = refinedLeaf.getTest();
        for (int branch = 0; branch < arity; ++branch) {
            ClusNode child = new ClusNode();
            refinedLeaf.setChild(child, branch);
            RowData subset = data.applyWeighted(test, branch);
            child.initClusteringStat(statManager, subset);
            child.initTargetStat(statManager, subset);
            child.getTargetStat().calcMean();
        }
    }

    /** Walks {@code tree} recursively and calls {@link #refineGivenLeaf} on every leaf. */
    public void refineEachLeaf(ClusNode tree, ClusBeamModel root, ClusBeam beam, ClusAttrType[] attrs) throws Exception {
        int nbChildren = tree.getNbChildren();
        if (nbChildren == 0) {
            refineGivenLeaf(tree, root, beam, attrs);
            return;
        }
        for (int c = 0; c < nbChildren; ++c) {
            refineEachLeaf((ClusNode)tree.getChild(c), root, beam, attrs);
        }
    }

    /**
     * Refines all leaves of one beam model.
     *
     * <p>Nothing happens when a size limit is set ({@code m_MaxTreeSize >= 0})
     * and the tree's node count plus two would exceed it. Otherwise the
     * training tuples are passed through the tree with the example collector,
     * every leaf is refined, and the visitors are cleared afterwards.
     */
    public void refineModel(ClusBeamModel model, ClusBeam beam, ClusRun run) throws Exception {
        ClusNode tree = (ClusNode)model.getModel();
        if (m_MaxTreeSize >= 0 && tree.getNbNodes() + 2 > m_MaxTreeSize) {
            return;
        }
        RowData train = (RowData)run.getTrainingSet();
        m_Coll.initialize(tree, null);
        int nbRows = train.getNbRows();
        for (int r = 0; r < nbRows; ++r) {
            DataTuple tuple = train.getTuple(r);
            tree.applyModelProcessor(tuple, m_Coll);
        }
        ClusAttrType[] attrs = train.getSchema().getDescriptiveAttributes();
        refineEachLeaf(tree, model, beam, attrs);
        tree.clearVisitors();
    }

    /**
     * Performs one search step: refines every model of the beam snapshot that
     * is neither refined nor finished. The changed-flag is cleared first, so
     * {@link #isBeamChanged()} afterwards reports whether this step altered
     * the beam.
     */
    public void refineBeam(ClusBeam beam, ClusRun run) throws Exception {
        setBeamChanged(false);
        // Snapshot: models added during this step are not refined in this step.
        ArrayList models = beam.toArray();
        for (int idx = 0; idx < models.size(); ++idx) {
            setCurrentModel(idx);
            ClusBeamModel model = (ClusBeamModel)models.get(idx);
            if (!model.isRefined() && !model.isFinished()) {
                if (m_Verbose) {
                    System.out.print("[*]");
                }
                refineModel(model, beam, run);
                model.setRefined(true);
                model.setParentModelIndex(-1);
            }
            if (m_Verbose) {
                if (model.isRefined()) {
                    System.out.print("[R]");
                }
                if (model.isFinished()) {
                    System.out.print("[F]");
                }
            }
        }
    }

    @Override
    public Settings getSettings() {
        return m_Clus.getSettings();
    }

    /**
     * Appends one statistics record for the current beam to {@code m_BeamStats}.
     * The record holds, in order: heuristic values, counts per beam element,
     * model sizes (all as {@link SingleStat}) and the number of distinct
     * non-null root tests (as {@link Integer}).
     */
    public void estimateBeamStats(ClusBeam beam) {
        SingleStat heuristicStat = new SingleStat();
        SingleStat sizeStat = new SingleStat();
        SingleStat sameHeurStat = new SingleStat();
        HashSet<NodeTest> rootTests = new HashSet<NodeTest>();

        ArrayList models = beam.toArray();
        for (int idx = 0; idx < models.size(); ++idx) {
            ClusBeamModel model = (ClusBeamModel)models.get(idx);
            heuristicStat.addFloat(model.getValue());
            sizeStat.addFloat(model.getModel().getModelSize());
            NodeTest rootTest = ((ClusNode)model.getModel()).getTest();
            if (rootTest != null) {
                // HashSet.add already ignores duplicates.
                rootTests.add(rootTest);
            }
        }
        Iterator elems = beam.getIterator();
        while (elems.hasNext()) {
            ClusBeamTreeElem elem = (ClusBeamTreeElem)elems.next();
            sameHeurStat.addFloat(elem.getCount());
        }

        ArrayList<Object> record = new ArrayList<Object>();
        record.add(heuristicStat);
        record.add(sameHeurStat);
        record.add(sizeStat);
        record.add(Integer.valueOf(rootTests.size()));
        m_BeamStats.add(record);
    }

    /**
     * Formats the statistics record of level {@code i} as a single line:
     * {@code "Level: i"} followed by {@code ", mean,range"} for every
     * {@link SingleStat} entry and {@code ", value"} for any other entry.
     *
     * @throws IndexOutOfBoundsException if no record exists for level {@code i}
     */
    public String getLevelStat(int i) {
        ArrayList record = (ArrayList)m_BeamStats.get(i);
        StringBuilder line = new StringBuilder();
        line.append("Level: ").append(i);
        for (int j = 0; j < record.size(); ++j) {
            Object entry = record.get(j);
            line.append(", ");
            if (entry instanceof SingleStat) {
                SingleStat st = (SingleStat)entry;
                line.append(st.getMean()).append(",").append(st.getRange());
            } else {
                line.append(entry.toString());
            }
        }
        return line.toString();
    }

    /** Logs the statistics line of the given level. */
    public void printBeamStats(int level) {
        ClusLogger.info(getLevelStat(level));
    }

    /** Writes one statistics line per level to {@code <appName>.bmstats}. */
    public void saveBeamStats() {
        MyFile stats = new MyFile(getSettings().getGeneric().getAppName() + ".bmstats");
        try {
            for (int level = 0; level < m_BeamStats.size(); ++level) {
                stats.log(getLevelStat(level));
            }
        } finally {
            // Release the file even if formatting or logging a level fails.
            stats.close();
        }
    }

    /**
     * Saves the beam statistics and adds every beam model to {@code strm}.
     * Models are added in reverse order of {@code getBeam().toArray()} and
     * named {@code "B<pos>: <value>"} with {@code pos} starting at 1.
     */
    public void writeModel(ClusModelCollectionIO strm) throws IOException, ClusException {
        saveBeamStats();
        ArrayList models = getBeam().toArray();
        for (int idx = 0; idx < models.size(); ++idx) {
            ClusNode node = (ClusNode)((ClusBeamModel)models.get(idx)).getModel();
            node.updateTree();
            node.clearVisitors();
        }
        int pos = 1;
        for (int idx = models.size() - 1; idx >= 0; --idx, ++pos) {
            ClusBeamModel model = (ClusBeamModel)models.get(idx);
            ClusModelInfo info = new ClusModelInfo("B" + pos + ": " + model.getValue());
            info.setScore(model.getValue());
            info.setModel(model.getModel());
            strm.addModel(info);
        }
    }

    public void setVerbose(boolean verb) {
        m_Verbose = verb;
    }

    /**
     * Runs the complete beam search on the training set of {@code run}.
     *
     * @return the tree of the beam's best-and-smallest model
     */
    public ClusNode beamSearch(ClusRun run) throws Exception {
        reset();
        ClusLogger.info("Starting beam search");
        m_Induce = new ConstraintDFInduce(m_BeamInduce);
        ClusBeam beam = initializeBeam(run);
        int step = 0;
        while (true) {
            ClusLogger.info("Step: " + step);
            refineBeam(beam, run);
            if (!isBeamChanged()) {
                break;
            }
            estimateBeamStats(beam);
            ++step;
        }
        setBeam(beam);
        double best = beam.getBestModel().getValue();
        double worst = beam.getWorstModel().getValue();
        ClusLogger.info("Worst = " + worst + " Best = " + best);
        // Statistics are only recorded for steps that changed the beam; when the
        // very first step changes nothing there is no level to print.
        if (step > 0) {
            printBeamStats(step - 1);
        }
        return (ClusNode)beam.getBestAndSmallestModel().getModel();
    }

    public void setBeam(ClusBeam beam) {
        m_Beam = beam;
    }

    public ClusBeam getBeam() {
        return m_Beam;
    }

    public boolean isBeamChanged() {
        return m_BeamChanged;
    }

    public void setBeamChanged(boolean change) {
        m_BeamChanged = change;
    }

    public int getCurrentModel() {
        return m_CurrentModel;
    }

    public void setCurrentModel(int model) {
        m_CurrentModel = model;
    }

    public void setTotalWeight(double weight) {
        m_TotalWeight = weight;
    }

    /**
     * Verifies that an incrementally computed heuristic value agrees with the
     * value recomputed from scratch for {@code tree}.
     *
     * @param value heuristic value obtained during test selection
     * @param tree  tree the value is supposed to describe
     * @return the recomputed value
     * @throws ClusException if both values differ by more than {@code 1e-6};
     *                       the tree is printed to standard output first
     */
    public double sanityCheck(double value, ClusNode tree) throws ClusException {
        double expected = estimateBeamMeasure(tree);
        if (Math.abs(value - expected) > 1.0E-6) {
            ClusLogger.info("Bug in heurisitc: " + value + " <> " + expected);
            PrintWriter wrt = new PrintWriter(System.out);
            tree.printModel(wrt);
            // Flush only: closing the writer would close System.out itself.
            wrt.flush();
            System.out.flush();
            throw new ClusException("Bug in heuristic: " + value + " <> " + expected);
        }
        return expected;
    }

    /** Writes {@code txt}, a separator and the beam to {@code log} if logging is enabled. */
    public void tryLogBeam(MyFile log, ClusBeam beam, String txt) {
        if (!log.isEnabled()) {
            return;
        }
        log.log(txt);
        log.log("*********************************************");
        beam.print(log.getWriter(), m_Clus.getSettings().getBeamSearch().getBeamBestN());
        log.log();
    }

    /** Intentionally empty in this class. */
    @Override
    public void pruneAll(ClusRun cr) throws ClusException, IOException {
    }

    /** Returns {@code model} unchanged. */
    @Override
    public ClusModel pruneSingle(ClusModel model, ClusRun cr) throws ClusException, IOException {
        return model;
    }

    /** Intentionally empty in this class. */
    @Override
    public void postProcess(ClusRun cr) throws ClusException, IOException {
    }
}

