/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.pruning;

import si.ijs.kt.clus.algo.tdidt.ClusNode;
import si.ijs.kt.clus.data.ClusSchema;
import si.ijs.kt.clus.data.attweights.ClusAttributeWeights;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.data.type.ClusAttrType;
import si.ijs.kt.clus.data.type.primitive.NominalAttrType;
import si.ijs.kt.clus.data.type.primitive.NumericAttrType;
import si.ijs.kt.clus.error.ClusSumError;
import si.ijs.kt.clus.error.MSError;
import si.ijs.kt.clus.error.MSNominalError;
import si.ijs.kt.clus.error.MisclassificationError;
import si.ijs.kt.clus.error.common.ClusError;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.pruning.CartVisitor;
import si.ijs.kt.clus.pruning.PruneTree;
import si.ijs.kt.clus.pruning.TreeErrorComputer;
import si.ijs.kt.clus.util.exception.ClusException;

public class CartPruning
extends PruneTree {
    protected int[] m_MaxSize;
    protected ClusAttributeWeights m_Weights;
    protected double m_U1;
    protected double m_U2;
    protected boolean m_IsMSENominal;
    protected ClusError m_ErrorMeasure;

    public CartPruning(ClusAttributeWeights weights, boolean isMSENominal) {
        this.m_Weights = weights;
        this.m_IsMSENominal = isMSENominal;
    }

    public CartPruning(int[] maxsize, ClusAttributeWeights weights) {
        this.m_Weights = weights;
        this.m_MaxSize = maxsize;
    }

    @Override
    public int getNbResults() {
        return Math.max(1, this.m_MaxSize.length);
    }

    @Override
    public void prune(int result, ClusNode node) throws ClusException {
        int size = this.m_MaxSize[result];
        TreeErrorComputer.recursiveInitialize(node, new CartVisitor());
        this.internalInitialize(node);
        while (node.getNbNodes() > size) {
            this.internalSequenceNext(node);
        }
    }

    @Override
    public ClusErrorList createErrorMeasure(RowData data, ClusAttributeWeights weights) {
        ClusSchema schema = data.getSchema();
        ClusErrorList parent = new ClusErrorList();
        NumericAttrType[] num = schema.getNumericAttrUse(ClusAttrType.AttributeUseType.Clustering);
        NominalAttrType[] nom = schema.getNominalAttrUse(ClusAttrType.AttributeUseType.Clustering);
        if (nom.length != 0 && num.length != 0) {
            MSError numErr = new MSError(parent, num, weights);
            MSNominalError nomErr = new MSNominalError(parent, nom, weights);
            ClusSumError error = new ClusSumError(parent);
            error.addComponent(numErr);
            error.addComponent(nomErr);
            this.m_ErrorMeasure = error;
            parent.addError(this.m_ErrorMeasure);
        } else {
            if (nom.length != 0) {
                if (this.m_IsMSENominal) {
                    this.m_ErrorMeasure = new MSNominalError(parent, nom, weights);
                    parent.addError(this.m_ErrorMeasure);
                } else {
                    this.m_ErrorMeasure = new MisclassificationError(parent, nom);
                    parent.addError(this.m_ErrorMeasure);
                }
            }
            if (num.length != 0) {
                this.m_ErrorMeasure = new MSError(parent, num, weights);
                parent.addError(this.m_ErrorMeasure);
            }
        }
        parent.setWeights(weights);
        return parent;
    }

    @Override
    public void sequenceInitialize(ClusNode node) {
        TreeErrorComputer.recursiveInitialize(node, new CartVisitor());
        this.setOriginalTree(node);
    }

    @Override
    public void sequenceReset() {
        this.setCurrentTree(null);
    }

    @Override
    public ClusNode sequenceNext() throws ClusException {
        ClusNode result = this.getCurrentTree();
        if (result == null) {
            result = this.getOriginalTree().cloneTreeWithVisitors();
            this.internalInitialize(result);
        } else {
            if (result.atBottomLevel()) {
                return null;
            }
            this.internalSequenceNext(result);
        }
        this.setCurrentTree(result);
        return result;
    }

    @Override
    public void sequenceToElemK(ClusNode node, int k) {
        this.internalInitialize(node);
        for (int i = 0; i < k; ++i) {
            this.internalSequenceNext(node);
        }
    }

    public void initU(ClusNode node) {
        CartVisitor cart = (CartVisitor)node.getVisitor();
        this.m_U1 = 1.0 + cart.delta_u1;
        this.m_U2 = this.m_ErrorMeasure.computeLeafError(node.getClusteringStat()) + cart.delta_u2;
    }

    public static final double getLambda(ClusNode node) {
        CartVisitor cart = (CartVisitor)node.getVisitor();
        return cart.lambda;
    }

    public static final double getLambdaMin(ClusNode node) {
        CartVisitor cart = (CartVisitor)node.getVisitor();
        return cart.lambda_min;
    }

    public static final void updateLambdaMin(ClusNode node) {
        CartVisitor cart = (CartVisitor)node.getVisitor();
        cart.lambda_min = cart.lambda;
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClusNode ch = (ClusNode)node.getChild(i);
            cart.lambda_min = Math.min(cart.lambda_min, CartPruning.getLambdaMin(ch));
        }
    }

    public static final void updateLambda(ClusNode node) {
        CartVisitor cart = (CartVisitor)node.getVisitor();
        cart.lambda = -cart.delta_u2 / cart.delta_u1;
    }

    public static final void subtractDeltaU(ClusNode node, double d_u1, double d_u2) {
        CartVisitor cart = (CartVisitor)node.getVisitor();
        cart.delta_u1 -= d_u1;
        cart.delta_u2 -= d_u2;
    }

    public void internalSequenceNext(ClusNode node) {
        ClusNode cr_node_t = node;
        double lambda_min_t0 = CartPruning.getLambdaMin(node);
        while (CartPruning.getLambda(cr_node_t) > lambda_min_t0) {
            ClusNode ch1 = (ClusNode)cr_node_t.getChild(0);
            ClusNode ch2 = (ClusNode)cr_node_t.getChild(1);
            if (CartPruning.getLambdaMin(ch1) == lambda_min_t0) {
                cr_node_t = ch1;
                continue;
            }
            cr_node_t = ch2;
        }
        cr_node_t.makeLeaf();
        CartVisitor cart_t = (CartVisitor)cr_node_t.getVisitor();
        double delta_u1 = cart_t.delta_u1;
        double delta_u2 = cart_t.delta_u2;
        cart_t.lambda_min = Double.POSITIVE_INFINITY;
        while (!cr_node_t.atTopLevel()) {
            cr_node_t = (ClusNode)cr_node_t.getParent();
            CartPruning.subtractDeltaU(cr_node_t, delta_u1, delta_u2);
            CartPruning.updateLambda(cr_node_t);
            CartPruning.updateLambdaMin(cr_node_t);
        }
        this.m_U1 -= delta_u1;
        this.m_U2 -= delta_u2;
    }

    public void internalInitialize(ClusNode node) {
        this.internalRecursiveInitialize(node);
        this.initU(node);
    }

    public void internalRecursiveInitialize(ClusNode node) {
        int nb_c = node.getNbChildren();
        for (int i = 0; i < nb_c; ++i) {
            this.internalRecursiveInitialize((ClusNode)node.getChild(i));
        }
        CartVisitor cart = (CartVisitor)node.getVisitor();
        if (nb_c == 0) {
            cart.delta_u1 = 0.0;
            cart.delta_u2 = 0.0;
            cart.lambda_min = Double.POSITIVE_INFINITY;
        } else {
            cart.delta_u1 = node.computeNodesLeavesDepth()[1] - 1;
            double leaf_err = this.m_ErrorMeasure.computeLeafError(node.getClusteringStat());
            double tree_err = this.m_ErrorMeasure.computeTreeErrorClusteringAbsolute(node);
            cart.delta_u2 = tree_err - leaf_err;
            CartPruning.updateLambda(node);
            CartPruning.updateLambdaMin(node);
        }
    }
}

