/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.pruning;

import si.ijs.kt.clus.algo.tdidt.ClusNode;
import si.ijs.kt.clus.data.attweights.ClusAttributeWeights;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.pruning.PruneTree;
import si.ijs.kt.clus.statistic.RegressionStat;

/**
 * Bottom-up pruner for trees whose clustering statistic is a {@link RegressionStat}.
 *
 * <p>{@link #prune(ClusNode)} first stores the per-attribute values returned by
 * {@code getRootScaledVariances} for the node it is given, then walks the tree
 * children-first. Each internal node is turned into a leaf when
 * {@link #allAccurate(RegressionStat)} or
 * {@link #allBetterThanTree(ClusNode, RegressionStat, int)} holds for it.
 */
public class M5PrunerMulti extends PruneTree {
    /** Fraction of the stored global value below which a leaf error counts as accurate. */
    double m_F = 1.0E-5;
    /** Multiplier applied to the parameter count in the pruning factor; set by the constructor. */
    double m_PruningMult = 2.0;
    /** Per-attribute reference values, filled in by {@link #prune(ClusNode)}. */
    double[] m_GlobalRMSE;
    /** Attribute weights passed to every scaled-variance computation. */
    ClusAttributeWeights m_TargetWeights;
    /** Data handed over through {@link #setTrainingData(RowData)}; not read inside this class. */
    RowData m_TrainingData;

    /**
     * Creates a pruner.
     *
     * @param prod attribute weights used when scaling variances
     * @param mult multiplier for the parameter count in the pruning factor
     */
    public M5PrunerMulti(ClusAttributeWeights prod, double mult) {
        this.m_TargetWeights = prod;
        this.m_PruningMult = mult;
    }

    /**
     * Prunes the tree rooted at {@code node} in place.
     *
     * @param node root of the tree; its clustering statistic must be a {@link RegressionStat}
     */
    @Override
    public void prune(ClusNode node) {
        RegressionStat rootStat = (RegressionStat) node.getClusteringStat();
        // Remember the root-level values; allAccurate compares every node against them.
        this.m_GlobalRMSE = rootStat.getRootScaledVariances(this.m_TargetWeights);
        this.pruneRecursive(node);
    }

    /**
     * @return always 1
     */
    @Override
    public int getNbResults() {
        return 1;
    }

    /**
     * Computes the multiplicative penalty {@code (n + mult * p) / (n - p)}.
     *
     * @param num_instances instance weight {@code n}
     * @param num_params    parameter count {@code p}
     * @return the penalty, or 10.0 when {@code n <= p}
     */
    private double pruningFactor(double num_instances, int num_params) {
        double params = (double) num_params;
        if (num_instances <= params) {
            return 10.0;
        }
        double numerator = num_instances + this.m_PruningMult * params;
        double denominator = num_instances - params;
        return numerator / denominator;
    }

    /**
     * Leaf error of {@code stat} for one attribute: its root scaled variance multiplied by
     * the pruning factor for a single parameter.
     */
    private double penalizedLeafError(RegressionStat stat, int attr) {
        double rootVariance = stat.getRootScaledVariance(attr, this.m_TargetWeights);
        double penalty = this.pruningFactor(stat.getTotalWeight(), 1);
        return rootVariance * penalty;
    }

    /**
     * Square root of the subtree's summed scaled SS divided by the total weight of
     * {@code tree}'s clustering statistic.
     *
     * @param tree  root of the subtree
     * @param attr  attribute index
     * @param scale attribute weights used for scaling
     * @return {@code sqrt(estimateScaledVariance(tree, attr, scale) / totalWeight)}
     */
    public static double estimateRootScaledVariance(ClusNode tree, int attr, ClusAttributeWeights scale) {
        double weight = tree.getClusteringStat().getTotalWeight();
        double summedSS = M5PrunerMulti.estimateScaledVariance(tree, attr, scale);
        return Math.sqrt(summedSS / weight);
    }

    /**
     * Sums {@code getScaledSS(attr, scale)} over all bottom-level nodes below (or equal to)
     * {@code tree}.
     *
     * @param tree  root of the subtree
     * @param attr  attribute index
     * @param scale attribute weights used for scaling
     * @return the summed scaled SS of the subtree's bottom-level nodes
     */
    public static double estimateScaledVariance(ClusNode tree, int attr, ClusAttributeWeights scale) {
        if (!tree.atBottomLevel()) {
            double sum = 0.0;
            for (int c = 0; c < tree.getNbChildren(); ++c) {
                ClusNode subtree = (ClusNode) tree.getChild(c);
                sum += M5PrunerMulti.estimateScaledVariance(subtree, attr, scale);
            }
            return sum;
        }
        RegressionStat leafStat = (RegressionStat) tree.getClusteringStat();
        return leafStat.getScaledSS(attr, scale);
    }

    /**
     * Tells whether the leaf error of {@code stat} is below {@code m_GlobalRMSE[i] * m_F}
     * for every attribute.
     *
     * @param stat statistic of the node under consideration
     * @return {@code false} as soon as one attribute's leaf error reaches the threshold,
     *         {@code true} otherwise
     */
    public boolean allAccurate(RegressionStat stat) {
        for (int attr = 0; attr < stat.getNbAttributes(); ++attr) {
            double leafError = this.penalizedLeafError(stat, attr);
            if (leafError >= this.m_GlobalRMSE[attr] * this.m_F) {
                return false;
            }
        }
        return true;
    }

    /**
     * Tells whether, for every attribute, the penalized leaf error of {@code stat} does not
     * exceed the penalized error estimated from the subtree under {@code node}.
     *
     * @param node      root of the subtree being compared against
     * @param stat      statistic used for the leaf error and for the instance weight
     * @param modelsize parameter count used to penalize the subtree error
     * @return {@code false} as soon as one attribute's leaf error is strictly greater than
     *         the subtree error, {@code true} otherwise
     */
    public boolean allBetterThanTree(ClusNode node, RegressionStat stat, int modelsize) {
        for (int attr = 0; attr < stat.getNbAttributes(); ++attr) {
            double leafError = this.penalizedLeafError(stat, attr);
            double subtreeVariance = M5PrunerMulti.estimateRootScaledVariance(node, attr, this.m_TargetWeights);
            double treeError = subtreeVariance * this.pruningFactor(stat.getTotalWeight(), modelsize);
            if (leafError > treeError) {
                return false;
            }
        }
        return true;
    }

    /**
     * Prunes the subtree under {@code node}, children first, then decides whether
     * {@code node} itself becomes a leaf.
     *
     * @param node subtree root; bottom-level nodes are left untouched
     */
    public void pruneRecursive(ClusNode node) {
        if (node.atBottomLevel()) {
            return;
        }
        for (int c = 0; c < node.getNbChildren(); ++c) {
            this.pruneRecursive((ClusNode) node.getChild(c));
        }
        RegressionStat nodeStat = (RegressionStat) node.getClusteringStat();
        if (this.allAccurate(nodeStat)) {
            node.makeLeaf();
        }
        // The second test always runs; the node count is read after any collapse above.
        int subtreeSize = node.getNbNodes();
        if (this.allBetterThanTree(node, nodeStat, subtreeSize)) {
            node.makeLeaf();
        }
    }

    /**
     * Stores the training data reference.
     *
     * @param data training data to keep
     */
    @Override
    public void setTrainingData(RowData data) {
        this.m_TrainingData = data;
    }
}

