/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.ext.beamsearch;

import si.ijs.kt.clus.algo.tdidt.ClusNode;
import si.ijs.kt.clus.ext.beamsearch.ClusBeamHeuristic;
import si.ijs.kt.clus.main.settings.Settings;
import si.ijs.kt.clus.main.settings.section.SettingsBeamSearch;
import si.ijs.kt.clus.main.settings.section.SettingsTree;
import si.ijs.kt.clus.statistic.ClusStatistic;
import si.ijs.kt.clus.util.ClusLogger;

/**
 * Beam-search heuristic that scores trees and candidate splits by their training error
 * divided by {@code m_NbTrain}, minus a size penalty of
 * {@link SettingsBeamSearch#SIZE_PENALTY} per tree node.
 *
 * <p>NOTE(review): despite the class name, the m-value ({@code m_MValue}) is only reported by
 * {@link #getName()}, and the prior ({@code m_Prior}) is only computed and logged by
 * {@link #setRootStatistic}. Neither is read by the scoring methods in this class. Confirm
 * whether subclasses or other code rely on them, or whether the m-estimate correction was
 * never implemented.
 */
public class ClusBeamHeuristicMEstimate extends ClusBeamHeuristic {

    /**
     * Missing-value weight at or below which a split is treated as having no missing values,
     * so the branch errors can be read directly from the statistics.
     */
    private static final double MISSING_WEIGHT_EPSILON = 1.0E-9;

    /** {@code (totalWeight - error) / totalWeight} of the root statistic; set by {@link #setRootStatistic}. */
    protected double m_Prior;

    /** The m parameter supplied at construction; reported by {@link #getName()}. */
    protected double m_MValue;

    /**
     * Creates the heuristic.
     *
     * @param stat   statistic forwarded to the {@link ClusBeamHeuristic} constructor
     * @param mvalue the m parameter, stored in {@code m_MValue}
     * @param sett   settings forwarded to the {@link ClusBeamHeuristic} constructor
     */
    public ClusBeamHeuristicMEstimate(ClusStatistic stat, double mvalue, Settings sett) {
        super(stat, sett);
        this.m_MValue = mvalue;
    }

    /**
     * Scores a candidate binary split of the current node.
     *
     * <p>The split is rejected with {@link Double#NEGATIVE_INFINITY} when either branch would
     * carry less than {@link SettingsTree#MINIMAL_WEIGHT}.
     *
     * <p>If the missing-value weight is negligible, the branch errors are taken directly from
     * the statistics. Otherwise the missing-value weight is spread over both branches in
     * proportion to the positive branch's share of {@code c_tstat}'s weight. That path
     * overwrites the {@code m_Pos} and {@code m_Neg} scratch statistics.
     *
     * @param c_tstat statistic covering both branches (the negative branch is
     *                {@code c_tstat - c_pstat})
     * @param c_pstat statistic of the positive branch
     * @param missing statistic whose weight is shared between both branches (presumably the
     *                examples with a missing test value; confirm against callers)
     * @return the split score, or {@link Double#NEGATIVE_INFINITY} if the split is unacceptable
     */
    @Override
    public double calcHeuristic(ClusStatistic c_tstat, ClusStatistic c_pstat, ClusStatistic missing) {
        final double totalWeight = c_tstat.getTotalWeight();
        final double posWeight = c_pstat.getTotalWeight();
        final double negWeight = totalWeight - posWeight;
        if (posWeight < SettingsTree.MINIMAL_WEIGHT || negWeight < SettingsTree.MINIMAL_WEIGHT) {
            return Double.NEGATIVE_INFINITY;
        }
        if (missing.getTotalWeight() <= MISSING_WEIGHT_EPSILON) {
            // Arguments are evaluated left to right: positive error first, then negative.
            return this.scoreSplitFromErrors(c_pstat.getError(), c_tstat.getErrorDiff(c_pstat));
        }
        // Spread the missing-value weight over both branches, proportionally to the
        // positive branch's share of the known weight.
        final double posFraction = posWeight / totalWeight;
        this.m_Pos.copy(c_pstat);
        this.m_Neg.copy(c_tstat);
        this.m_Neg.subtractFromThis(c_pstat);
        this.m_Pos.addScaled(posFraction, missing);
        this.m_Neg.addScaled(1.0 - posFraction, missing);
        return this.scoreSplitFromErrors(this.m_Pos.getError(), this.m_Neg.getError());
    }

    /**
     * Combines the errors of the two branches of a candidate split into a split score.
     *
     * <p>The {@code 2.0 * SIZE_PENALTY} term is consistent with {@link #estimateBeamMeasure},
     * which charges {@code SIZE_PENALTY} once per node: splitting a leaf adds two child nodes.
     *
     * <p>NOTE(review): {@code m_TreeOffset} is maintained outside this class. It presumably
     * holds the measure of the rest of the tree; confirm in {@link ClusBeamHeuristic}.
     *
     * @param posError error of the positive branch
     * @param negError error of the negative branch
     * @return the split score
     */
    private double scoreSplitFromErrors(double posError, double negError) {
        return this.m_TreeOffset - (posError + negError) / this.m_NbTrain - 2.0 * SettingsBeamSearch.SIZE_PENALTY;
    }

    /**
     * Estimates the beam measure of a (sub)tree.
     *
     * <p>A bottom-level node contributes its negated clustering error divided by
     * {@code m_NbTrain}. An internal node contributes the sum of its children's measures.
     * Every node, whether leaf or internal, is additionally charged one
     * {@link SettingsBeamSearch#SIZE_PENALTY}.
     *
     * @param tree root of the subtree to evaluate
     * @return the estimated measure of {@code tree}
     */
    @Override
    public double estimateBeamMeasure(ClusNode tree) {
        if (tree.atBottomLevel()) {
            ClusStatistic total = tree.getClusteringStat();
            return -total.getError() / this.m_NbTrain - SettingsBeamSearch.SIZE_PENALTY;
        }
        double childrenMeasure = 0.0;
        for (int i = 0; i < tree.getNbChildren(); ++i) {
            ClusNode child = (ClusNode) tree.getChild(i);
            childrenMeasure += this.estimateBeamMeasure(child);
        }
        return childrenMeasure - SettingsBeamSearch.SIZE_PENALTY;
    }

    /**
     * Returns the contribution of a single leaf: its negated clustering error divided by
     * {@code m_NbTrain}. Unlike {@link #estimateBeamMeasure}, no size penalty is applied here.
     *
     * @param leaf the leaf node
     * @return the leaf's contribution to the measure
     */
    @Override
    public double computeLeafAdd(ClusNode leaf) {
        return -leaf.getClusteringStat().getError() / this.m_NbTrain;
    }

    /**
     * Computes {@code m_Prior = (totalWeight - error) / totalWeight} from the given root
     * statistic and logs it.
     *
     * <p>NOTE(review): a zero-weight statistic makes the prior non-finite (NaN or infinite
     * under IEEE floating-point division). Presumably callers never pass one; confirm.
     *
     * @param stat statistic of the root node
     */
    @Override
    public void setRootStatistic(ClusStatistic stat) {
        this.m_Prior = (stat.getTotalWeight() - stat.getError()) / stat.getTotalWeight();
        ClusLogger.info("Setting prior: " + this.m_Prior);
    }

    /**
     * Returns a display name containing the m-value, followed by the text produced by
     * {@code getAttrHeuristicString()}.
     *
     * @return the heuristic's display name
     */
    @Override
    public String getName() {
        return "Beam Heuristic (MEstimate = " + this.m_MValue + ")" + this.getAttrHeuristicString();
    }
}

