/*
 * Decompiled with CFR 0.152.
 */
package si.ijs.kt.clus.pruning;

import si.ijs.kt.clus.algo.tdidt.ClusNode;
import si.ijs.kt.clus.data.attweights.ClusAttributeWeights;
import si.ijs.kt.clus.data.rows.RowData;
import si.ijs.kt.clus.error.common.ClusError;
import si.ijs.kt.clus.error.common.ClusErrorList;
import si.ijs.kt.clus.model.test.NodeTest;
import si.ijs.kt.clus.pruning.PruneTree;
import si.ijs.kt.clus.pruning.SizeConstraintVisitor;
import si.ijs.kt.clus.pruning.TreeErrorComputer;
import si.ijs.kt.clus.util.ClusLogger;
import si.ijs.kt.clus.util.exception.ClusException;

/**
 * Tree pruner that reduces a tree to at most a given number of nodes and,
 * optionally, to the smallest such size whose error stays within given bounds.
 *
 * <p>The pruner attaches a {@link SizeConstraintVisitor} to every node. The
 * visitor stores the node's own error ({@code error}) and, per node budget
 * {@code l}, the best cost found ({@code cost[l]}), the share of the budget
 * given to the first child ({@code left[l]}) and a memoisation flag
 * ({@code computed[l]}).
 *
 * <p>NOTE(review): {@link #computeCosts} and {@link #pruneToSizeK} only look at
 * children 0 and 1 of a node, so they presumably expect binary trees — confirm.
 */
public class SizeConstraintPruning extends PruneTree {
    /** Training data used to (re)compute errors; set via {@link #setTrainingData}. */
    public RowData m_Data;
    /** Error bound(s) for {@link #pruneMaxError}; {@code null} selects plain size-based pruning. */
    public double[] m_MaxError;
    /** Error measure cloned and evaluated for every candidate size in {@link #pruneMaxError}. */
    public ClusErrorList m_ErrorMeasure;
    /** Optional additive error; when non-null it is used instead of the node statistics. */
    public ClusErrorList m_AdditiveError;
    /** Size constraints, one per pruning result. May be empty when only an error bound is used. */
    public int[] m_MaxSize;
    /** Attribute weights passed to the node statistic when computing a node's error. */
    public ClusAttributeWeights m_TargetWeights;
    /** Size that the next call to {@link #sequenceNext} will prune to. */
    public int m_CrIndex;
    /** Largest size of the pruning sequence, fixed by {@link #sequenceInitialize}. */
    public int m_MaxIndex;

    /**
     * Creates a pruner with a single size constraint.
     *
     * @param maxsize the size constraint (number of nodes)
     * @param prod    attribute weights used for the statistic-based node error
     */
    public SizeConstraintPruning(int maxsize, ClusAttributeWeights prod) {
        this.m_MaxSize = new int[] {maxsize};
        this.m_TargetWeights = prod;
    }

    /**
     * Creates a pruner with one size constraint per pruning result.
     *
     * @param maxsize the size constraints; the array is stored, not copied
     * @param prod    attribute weights used for the statistic-based node error
     */
    public SizeConstraintPruning(int[] maxsize, ClusAttributeWeights prod) {
        this.m_MaxSize = maxsize;
        this.m_TargetWeights = prod;
    }

    /**
     * Returns the first size constraint.
     *
     * @return {@code m_MaxSize[0]}, or {@code -1} when no size constraint is
     *         configured; {@code -1} is the value {@link #sequenceInitialize}
     *         treats as "no limit"
     */
    public int getMaxSize() {
        // An empty size array is a supported configuration (see prune(int, ClusNode)
        // and getNbResults()), so it must not blow up here.
        if (this.m_MaxSize.length == 0) {
            return -1;
        }
        return this.m_MaxSize[0];
    }

    /** Stores the data set that errors are computed on. */
    @Override
    public void setTrainingData(RowData data) {
        this.m_Data = data;
    }

    /** @return the attribute weights given at construction time */
    public ClusAttributeWeights getTargetWeights() {
        return this.m_TargetWeights;
    }

    /**
     * Attaches fresh visitors to the whole tree and fills in each node's error.
     *
     * <p>The error comes from the additive error measure on the training data
     * when one is set, and from the node's target statistic otherwise.
     *
     * @param node root of the tree to prepare
     * @param size node budget passed to every new {@link SizeConstraintVisitor}
     * @throws ClusException if the additive error computation fails
     */
    public void pruneInitialize(ClusNode node, int size) throws ClusException {
        SizeConstraintPruning.recursiveInitialize(node, size);
        if (this.isUsingAdditiveError()) {
            this.recursiveInitializeError(node, this.m_Data);
        } else {
            this.recursiveInitializeErrorFromStatistic(node);
        }
    }

    /**
     * Computes the cost table for budget {@code size} and prunes the tree
     * accordingly. The tree must already carry initialised visitors.
     *
     * @param node root of the tree to prune (modified in place)
     * @param size node budget
     */
    public void pruneExecute(ClusNode node, int size) {
        this.computeCosts(node, size);
        SizeConstraintPruning.pruneToSizeK(node, size);
    }

    /** Prunes using the first configured result (index 0). */
    @Override
    public void prune(ClusNode node) throws ClusException {
        this.prune(0, node);
    }

    /** @return the number of size constraints, but at least 1 */
    @Override
    public int getNbResults() {
        return Math.max(1, this.m_MaxSize.length);
    }

    /**
     * Returns the label of the {@code i}-th pruning result.
     *
     * <p>NOTE(review): indexes {@code m_MaxSize} directly, so this throws for
     * {@code i == 0} when no size constraint is configured even though
     * {@link #getNbResults()} reports one result — confirm the intended label.
     */
    @Override
    public String getPrunedName(int i) {
        return "S(" + this.m_MaxSize[i] + ")";
    }

    /**
     * Prunes {@code node} in place for the given result index and then clears
     * the visitors from the tree.
     *
     * <p>Without an error bound the tree is pruned to {@code m_MaxSize[result]}.
     * With an error bound, {@link #pruneMaxError} is used, limited by
     * {@code m_MaxSize[result]} or by the tree's own node count when no size
     * constraint is configured.
     *
     * @param result index into the size constraints
     * @param node   root of the tree to prune
     * @throws ClusException if error computation fails
     */
    @Override
    public void prune(int result, ClusNode node) throws ClusException {
        if (this.m_MaxError == null) {
            int size = this.m_MaxSize[result];
            int orig = node.getNbNodes();
            ClusLogger.info("Pruning to size (" + orig + "): " + size);
            this.pruneInitialize(node, size);
            this.pruneExecute(node, size);
        } else if (this.m_MaxSize.length == 0) {
            this.pruneMaxError(node, node.getNbNodes());
        } else {
            this.pruneMaxError(node, this.m_MaxSize[result]);
        }
        node.clearVisitors();
    }

    /**
     * Prepares a pruning sequence that starts at the largest allowed odd size
     * and registers {@code node} as the original tree.
     *
     * <p>The start size is the tree's node count, capped by
     * {@link #getMaxSize()} unless that is {@code -1}, and lowered by one when
     * even.
     *
     * <p>NOTE(review): only the visitors are created here; the per-node
     * {@code error} values are not filled in as {@link #pruneInitialize} does.
     * Verify that callers initialise them before {@link #sequenceNext()}.
     */
    @Override
    public void sequenceInitialize(ClusNode node) {
        int max_size = node.getNbNodes();
        int abs_max = this.getMaxSize();
        if (abs_max != -1 && max_size > abs_max) {
            max_size = abs_max;
        }
        if (max_size % 2 == 0) {
            --max_size;
        }
        this.m_MaxIndex = max_size;
        this.m_CrIndex = max_size;
        SizeConstraintPruning.recursiveInitialize(node, max_size);
        this.setOriginalTree(node);
    }

    /** Rewinds the pruning sequence to its largest size. */
    @Override
    public void sequenceReset() {
        this.m_CrIndex = this.m_MaxIndex;
    }

    /**
     * Returns the next tree of the sequence: a clone of the original tree
     * pruned to the current size, after which the current size drops by 2.
     *
     * @return the pruned clone, or {@code null} once the current size is no
     *         longer positive
     */
    @Override
    public ClusNode sequenceNext() throws ClusException {
        if (this.m_CrIndex <= 0) {
            return null;
        }
        ClusNode cloned = this.getOriginalTree().cloneTreeWithVisitors();
        this.pruneExecute(cloned, this.m_CrIndex);
        this.m_CrIndex -= 2;
        return cloned;
    }

    /**
     * Prunes {@code node} to the size of the {@code k}-th element of the
     * sequence, i.e. {@code m_MaxIndex - 2 * k}.
     *
     * <p>NOTE(review): no guard against a negative result for large {@code k}.
     */
    @Override
    public void sequenceToElemK(ClusNode node, int k) {
        this.pruneExecute(node, this.m_MaxIndex - 2 * k);
    }

    /**
     * Prunes {@code node} to the smallest odd size not exceeding
     * {@code maxsize} whose error satisfies the configured bound(s). Falls
     * back to {@code maxsize} when no candidate size qualifies.
     *
     * <p>Each candidate size is tried on a clone of the tree; the error is
     * computed on the training data with a clone of the error measure.
     *
     * @param node    root of the tree to prune (modified in place)
     * @param maxsize upper limit for the tree size
     * @throws ClusException if error computation fails
     */
    public void pruneMaxError(ClusNode node, int maxsize) throws ClusException {
        this.pruneInitialize(node, maxsize);
        int constr_ok_size = maxsize;
        for (int crsize = 1; crsize <= maxsize; crsize += 2) {
            ClusNode copy = node.cloneTreeWithVisitors();
            this.pruneExecute(copy, crsize);
            ClusErrorList cr_err = this.m_ErrorMeasure.getErrorClone();
            ClusError err = cr_err.getFirstError();
            TreeErrorComputer.computeErrorStandard(copy, this.m_Data, err);
            cr_err.setNbExamples(this.m_Data.getNbRows());
            if (this.satisfiesMaxError(err)) {
                constr_ok_size = crsize;
                break;
            }
        }
        this.pruneExecute(node, constr_ok_size);
    }

    /**
     * Checks a computed error against {@code m_MaxError}.
     *
     * <p>With exactly one bound the overall model error is compared. Otherwise
     * each component is compared with its own bound, and bounds that are NaN
     * are skipped.
     */
    private boolean satisfiesMaxError(ClusError err) {
        if (this.m_MaxError.length == 1) {
            return err.getModelError() <= this.m_MaxError[0];
        }
        for (int i = 0; i < this.m_MaxError.length; ++i) {
            double bound = this.m_MaxError[i];
            if (!Double.isNaN(bound) && err.getModelErrorComponent(i) > bound) {
                // One violated component is enough; no need to inspect the rest.
                return false;
            }
        }
        return true;
    }

    /** Attaches a new {@link SizeConstraintVisitor} for budget {@code size} to every node. */
    private static void recursiveInitialize(ClusNode node, int size) {
        SizeConstraintVisitor visitor = new SizeConstraintVisitor(size);
        node.setVisitor(visitor);
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClusNode child = (ClusNode)node.getChild(i);
            SizeConstraintPruning.recursiveInitialize(child, size);
        }
    }

    /**
     * Stores in every node's visitor the additive error of that node on the
     * data reaching it. The data is split per child with the node's test.
     */
    private void recursiveInitializeError(ClusNode node, RowData data) throws ClusException {
        SizeConstraintVisitor visitor = (SizeConstraintVisitor)node.getVisitor();
        ClusErrorList parent = this.getAdditiveError();
        ClusError err = parent.getFirstError();
        // The same error object is reused for every node, hence the reset.
        parent.reset();
        TreeErrorComputer.computeErrorNode(node, data, err);
        parent.setNbExamples(data.getNbRows());
        visitor.error = err.getModelErrorAdditive();
        NodeTest tst = node.getTest();
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClusNode child = (ClusNode)node.getChild(i);
            RowData subset = data.applyWeighted(tst, i);
            this.recursiveInitializeError(child, subset);
        }
    }

    /** Stores in every node's visitor the error of the node's target statistic. */
    private void recursiveInitializeErrorFromStatistic(ClusNode node) {
        SizeConstraintVisitor visitor = (SizeConstraintVisitor)node.getVisitor();
        visitor.error = node.getTargetStat().getError(this.m_TargetWeights);
        for (int i = 0; i < node.getNbChildren(); ++i) {
            ClusNode child = (ClusNode)node.getChild(i);
            this.recursiveInitializeErrorFromStatistic(child);
        }
    }

    /** @return the error stored in the node's visitor, i.e. the cost of making it a leaf */
    public double computeNodeCost(ClusNode node) {
        SizeConstraintVisitor visitor = (SizeConstraintVisitor)node.getVisitor();
        return visitor.error;
    }

    /**
     * Computes, with memoisation in the node's visitor, the lowest cost of the
     * subtree rooted at {@code node} when it may use at most {@code l} nodes.
     *
     * <p>The baseline is the node's own cost (node kept as a leaf). For
     * {@code l >= 3} on a non-bottom node, every split of the remaining
     * {@code l - 1} nodes between child 0 ({@code k1}) and child 1
     * ({@code l - k1 - 1}) is tried. A strictly cheaper split replaces the
     * baseline and its {@code k1} is recorded in {@code left[l]}.
     *
     * @param node subtree root carrying a {@link SizeConstraintVisitor}
     * @param l    node budget; must be a valid index into the visitor's arrays
     * @return the lowest cost found for budget {@code l}
     */
    public double computeCosts(ClusNode node, int l) {
        SizeConstraintVisitor visitor = (SizeConstraintVisitor)node.getVisitor();
        if (visitor.computed[l]) {
            return visitor.cost[l];
        }
        visitor.cost[l] = this.computeNodeCost(node);
        if (l >= 3 && !node.atBottomLevel()) {
            ClusNode ch1 = (ClusNode)node.getChild(0);
            ClusNode ch2 = (ClusNode)node.getChild(1);
            for (int k1 = 1; k1 <= l - 2; ++k1) {
                int k2 = l - k1 - 1;
                double cost = this.computeCosts(ch1, k1) + this.computeCosts(ch2, k2);
                if (cost < visitor.cost[l]) {
                    visitor.cost[l] = cost;
                    visitor.left[l] = k1;
                }
            }
        }
        visitor.computed[l] = true;
        return visitor.cost[l];
    }

    /**
     * Applies the decisions recorded by {@link #computeCosts} for budget
     * {@code l}. A node becomes a leaf when the budget is below 3 or no split
     * was recorded ({@code left[l] == 0}); otherwise the budget is divided
     * between child 0 and child 1 as recorded.
     *
     * @param node subtree root (modified in place)
     * @param l    node budget for this subtree
     */
    public static void pruneToSizeK(ClusNode node, int l) {
        if (node.atBottomLevel()) {
            return;
        }
        SizeConstraintVisitor visitor = (SizeConstraintVisitor)node.getVisitor();
        if (l < 3 || visitor.left[l] == 0) {
            node.makeLeaf();
            return;
        }
        int k1 = visitor.left[l];
        int k2 = l - k1 - 1;
        SizeConstraintPruning.pruneToSizeK((ClusNode)node.getChild(0), k1);
        SizeConstraintPruning.pruneToSizeK((ClusNode)node.getChild(1), k2);
    }

    /** Sets the error bound(s); {@code null} switches back to plain size-based pruning. */
    public void setMaxError(double[] max_err) {
        this.m_MaxError = max_err;
    }

    /** Sets the error measure evaluated by {@link #pruneMaxError}. */
    public void setErrorMeasure(ClusErrorList parent) {
        this.m_ErrorMeasure = parent;
    }

    /** Sets the additive error used to initialise node errors; {@code null} disables it. */
    public void setAdditiveError(ClusErrorList parent) {
        this.m_AdditiveError = parent;
    }

    /** @return the additive error list, or {@code null} if none is set */
    public ClusErrorList getAdditiveError() {
        return this.m_AdditiveError;
    }

    /** @return {@code true} when an additive error list has been set */
    public boolean isUsingAdditiveError() {
        return this.m_AdditiveError != null;
    }
}

