/*
 * Decompiled with CFR 0.152.
 */
package ca.ubc.cs.beta.models.fastrf;

import ca.ubc.cs.beta.models.fastrf.Regtree;
import ca.ubc.cs.beta.models.fastrf.utils.Utils;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

/**
 * Forward (prediction-time) traversal routines for a single {@link Regtree}.
 *
 * <p>Node encoding as used by this class: {@code tree.var[n] == 0} marks a leaf; a positive
 * {@code var} is a 1-based numeric column, and values {@code <= tree.cut[n]} go to the left child;
 * a negative {@code var} is a 1-based categorical column, for which {@code (int) tree.cut[n]}
 * selects a row of {@code tree.catsplit} whose entry for the category decides the child
 * ({@code 0} = left, {@code 1} = right, anything else is rejected as a missing value).
 *
 * <p>In the marginal routines the first {@code thetacols} predictor columns are configuration
 * (Theta) columns and the remaining {@code npred - thetacols} columns are instance (X) columns.
 */
public strictfp class RegtreeFwd {

    /**
     * Routes every row of {@code X} from the root down to exactly one leaf.
     *
     * @param tree the tree to evaluate
     * @param X    data matrix whose rows hold {@code tree.npred} values each (only the width of
     *             row 0 is checked up front)
     * @return for each row of {@code X}, the index of the leaf it lands in; an empty array when
     *         {@code X} has no rows
     * @throws IllegalArgumentException if row 0 of {@code X} does not have {@code tree.npred}
     *                                  columns
     * @throws RuntimeException if the tree arrays have inconsistent lengths, a split value is NaN,
     *                          a categorical value is below 1, or a categorical split has no
     *                          direction assigned
     */
    public static int[] fwd(Regtree tree, double[][] X) {
        int numdata = X.length;
        validateTree(tree, false);
        if (numdata == 0) {
            // Nothing to route; previously this crashed on X[0].
            return new int[0];
        }
        if (X[0].length != tree.npred) {
            throw new IllegalArgumentException("X should be a matrix with " + tree.npred + " columns, not " + X[0].length);
        }
        int[] result = new int[numdata];
        for (int i = 0; i < numdata; ++i) {
            int node = 0;
            while (tree.var[node] != 0) {
                node = childFor(tree, node, X[i][Math.abs(tree.var[node]) - 1], i, "fwd");
            }
            result[i] = node;
        }
        return result;
    }

    /**
     * Computes, for each configuration row of {@code Theta}, the marginal prediction over the
     * instances used to preprocess the tree.
     *
     * <p>Theta splits are routed on the row's value; instance splits are marginalized by following
     * both children. The result for a row is the sum of {@code weightedpred} over every leaf reached,
     * which (given the weighting done in {@link #preprocess_inst_splits}) amounts to averaging the
     * leaf predictions over the instances in {@code X}.
     *
     * <p>NOTE(review): if {@code tree} is not yet preprocessed, a preprocessed copy is built on every
     * call and not stored back; callers that call repeatedly may want to preprocess once themselves.
     *
     * @param tree  the tree to evaluate
     * @param Theta configuration rows; its column count defines {@code thetacols}
     * @param X     instance matrix, used only when {@code tree} is not already preprocessed
     * @return {@code new Object[]{double[] means, double[] vars}}, one entry per Theta row.
     *         NOTE(review): {@code vars} is never computed and is always all zeros.
     * @throws RuntimeException if {@code Theta} is null/empty, the tree is empty or inconsistent, a
     *                          routed value is NaN or an invalid categorical, or a categorical split
     *                          has no direction assigned
     * @throws IllegalStateException if the tree is not preprocessed and {@code X} is null
     */
    public static Object[] marginalFwd(Regtree tree, double[][] Theta, double[][] X) {
        if (Theta == null || Theta.length == 0) {
            throw new RuntimeException("Theta must not be empty");
        }
        int thetarows = Theta.length;
        int thetacols = Theta[0].length;
        validateTree(tree, true);
        if (!tree.preprocessed) {
            tree = RegtreeFwd.preprocess_inst_splits(tree, X);
        }
        double[] result = new double[thetarows];
        double[] vars = new double[thetarows];
        for (int i = 0; i < thetarows; ++i) {
            // Leaves come back in traversal order, so the summation order matches the original.
            for (int leaf : reachableLeaves(tree, Theta[i], i, thetacols, true, "marginalFwd")) {
                result[i] += tree.weightedpred[leaf];
            }
        }
        return new Object[]{result, vars};
    }

    /**
     * Collects every leaf reachable by any row of {@code Theta} when instance splits are
     * marginalized (both children followed) and Theta splits are routed on the row's value.
     *
     * @param tree  the tree to evaluate
     * @param Theta configuration rows; its column count defines {@code thetacols}
     * @param X     instance matrix, used only when {@code tree} is not already preprocessed
     * @return the set of reached leaf indices, taken from the preprocessed tree (whose collapsed
     *         nodes keep their original indices)
     * @throws RuntimeException under the same conditions as {@link #marginalFwd}
     */
    public static Set<Integer> marginalFwdNodes(Regtree tree, double[][] Theta, double[][] X) {
        if (Theta == null || Theta.length == 0) {
            throw new RuntimeException("Theta must not be empty");
        }
        int thetarows = Theta.length;
        int thetacols = Theta[0].length;
        validateTree(tree, true);
        if (!tree.preprocessed) {
            tree = RegtreeFwd.preprocess_inst_splits(tree, X);
        }
        Set<Integer> locations = new HashSet<Integer>();
        for (int i = 0; i < thetarows; ++i) {
            // The error-message tag stays "marginalFwd", as in the original implementation.
            locations.addAll(reachableLeaves(tree, Theta[i], i, thetacols, true, "marginalFwd"));
        }
        return locations;
    }

    /**
     * Returns a preprocessed copy of {@code tree} whose leaf predictions are weighted by the
     * fraction of instances in {@code X} that can reach them.
     *
     * <p>For each instance row, instance splits are routed on the row's value and Theta splits are
     * followed both ways; every leaf reached gets +1. The counts are divided by the number of
     * instances, {@code weightedpred} is scaled by that weight and {@code weightedvar} by its square.
     * Subtrees that split only on instance columns, or that no instance reaches, are then collapsed
     * into leaves.
     *
     * @param tree the tree to preprocess (not modified; a copy is made via {@code new Regtree(tree)})
     * @param X    instance matrix; its column count determines
     *             {@code thetacols = tree.npred - X[0].length}
     * @return the preprocessed copy, with {@code preprocessed == true}
     * @throws IllegalStateException    if {@code X} is null and {@code tree} is not preprocessed
     * @throws IllegalArgumentException if {@code X} has no rows or more columns than
     *                                  {@code tree.npred}
     * @throws RuntimeException if the tree is empty or inconsistent, a routed value is NaN or an
     *                          invalid categorical, or a categorical split has no direction assigned
     */
    public static Regtree preprocess_inst_splits(Regtree tree, double[][] X) {
        tree = new Regtree(tree);
        validateTree(tree, true);
        int numnodes = tree.node.length;
        tree.weights = new double[numnodes];
        tree.weightedpred = new double[numnodes];
        tree.weightedvar = new double[numnodes];
        for (int node = 0; node < numnodes; ++node) {
            if (tree.var[node] == 0) {
                tree.weightedpred[node] = tree.nodepred[node];
                tree.weightedvar[node] = tree.nodevar[node];
            }
        }
        if (X == null) {
            // NOTE(review): for an already-preprocessed tree this returns the copy with weights
            // reset to 0 and weightedpred/weightedvar reset to the raw leaf values, discarding the
            // earlier weighting. Confirm whether an early return before the reset was intended.
            if (tree.preprocessed) {
                return tree;
            }
            throw new IllegalStateException("No X Matrix passed, but tree is not preprocessed");
        }
        int numinsts = X.length;
        if (numinsts == 0) {
            // Without instances the weights below would be 0/0 = NaN.
            throw new IllegalArgumentException("X must contain at least one instance row.");
        }
        int thetacols = tree.npred - X[0].length;
        if (thetacols < 0) {
            throw new IllegalArgumentException("X has " + X[0].length + " columns but the tree only has " + tree.npred + " predictors.");
        }
        for (int i = 0; i < numinsts; ++i) {
            for (int leaf : reachableLeaves(tree, X[i], i, thetacols, false, "preprocess_inst_splits")) {
                tree.weights[leaf] += 1.0;
            }
        }
        for (int node = 0; node < numnodes; ++node) {
            tree.weights[node] /= numinsts;
            tree.weightedpred[node] *= tree.weights[node];
            tree.weightedvar[node] *= tree.weights[node] * tree.weights[node];
        }
        RegtreeFwd.cut_instance_leaf_split_helper(tree, thetacols, 0);
        tree.preprocessed = true;
        return tree;
    }

    /**
     * Post-order pass that collapses prunable subtrees into leaves and propagates weights upward.
     *
     * <p>A node is collapsed when both children are (or became) leaves and it splits on an instance
     * column ({@code |var| > thetacols}), or when its accumulated weight is exactly zero.
     *
     * @return 1 if {@code thisnode} is, or has become, a leaf; 0 otherwise
     */
    private static int cut_instance_leaf_split_helper(Regtree tree, int thetacols, int thisnode) {
        if (tree.var[thisnode] == 0) {
            return 1;
        }
        int left_kid = tree.children[thisnode][0];
        int right_kid = tree.children[thisnode][1];
        // Both subtrees must always be processed, so both calls are evaluated (no short-circuit).
        int leftIsLeaf = RegtreeFwd.cut_instance_leaf_split_helper(tree, thetacols, left_kid);
        int rightIsLeaf = RegtreeFwd.cut_instance_leaf_split_helper(tree, thetacols, right_kid);
        int ret = 0;
        if (leftIsLeaf + rightIsLeaf == 2 && Math.abs(tree.var[thisnode]) > thetacols) {
            RegtreeFwd.make_into_leaf(tree, thisnode);
            ret = 1;
        }
        tree.weights[thisnode] = tree.weights[left_kid] + tree.weights[right_kid];
        if (ret == 0 && tree.weights[thisnode] == 0.0) {
            RegtreeFwd.make_into_leaf(tree, thisnode);
            ret = 1;
        }
        return ret;
    }

    /**
     * Turns {@code thisnode} into a leaf. The children indices are cleared and {@code var} is set
     * to 0. The node's weighted prediction and variance become the sums of its children's values.
     */
    private static void make_into_leaf(Regtree tree, int thisnode) {
        int left_kid = tree.children[thisnode][0];
        int right_kid = tree.children[thisnode][1];
        tree.children[thisnode][0] = 0;
        tree.children[thisnode][1] = 0;
        tree.var[thisnode] = 0;
        tree.weightedpred[thisnode] = tree.weightedpred[left_kid] + tree.weightedpred[right_kid];
        tree.weightedvar[thisnode] = tree.weightedvar[left_kid] + tree.weightedvar[right_kid];
    }

    /**
     * Stores, for every leaf, {@code Utils.mode(tree.ysub[leaf])} in {@code tree.bestClasses} and
     * marks the tree as preprocessed for classification. Mutates {@code tree} in place.
     *
     * @throws RuntimeException if the tree was not built with {@code resultsStoredInLeaves}
     */
    public static void preprocess_for_classification(Regtree tree) {
        if (!tree.resultsStoredInLeaves) {
            throw new RuntimeException("Classification can only be done if the tree was built with the resultsStoredInLeaves flag on.");
        }
        tree.bestClasses = new double[tree.numNodes][];
        for (int i = 0; i < tree.numNodes; ++i) {
            if (tree.var[i] != 0) {
                continue;
            }
            tree.bestClasses[i] = Utils.mode(tree.ysub[i]);
        }
        tree.preprocessed_for_classification = true;
    }

    /**
     * Checks that the per-node arrays of {@code tree} agree in length with {@code tree.node}.
     *
     * @param strict when true, also rejects an empty tree and checks {@code tree.parent}
     */
    private static void validateTree(Regtree tree, boolean strict) {
        int numnodes = tree.node.length;
        if (strict && numnodes == 0) {
            throw new RuntimeException("Tree must exist.");
        }
        if (tree.cut.length != numnodes) {
            throw new RuntimeException("cut must be Nx1 vector.");
        }
        if (tree.nodepred.length != numnodes) {
            throw new RuntimeException("nodepred must be Nx1 vector.");
        }
        if (strict && tree.parent.length != numnodes) {
            throw new RuntimeException("parent must be Nx1 matrix.");
        }
        if (tree.children.length != numnodes) {
            throw new RuntimeException("children must be Nx2 matrix.");
        }
    }

    /**
     * Returns the child of internal node {@code node} selected by {@code value}, the data value in
     * that node's split column.
     *
     * @param row   index of the data row, used only in error messages
     * @param where caller tag used in error messages ("fwd", "marginalFwd", ...)
     * @throws RuntimeException if {@code value} is NaN, a categorical value is below 1, or the
     *                          categorical split entry is neither 0 nor 1
     */
    private static int childFor(Regtree tree, int node, double value, int row, String where) {
        int splitvar = tree.var[node];
        if (Double.isNaN(value)) {
            throw new RuntimeException("In " + where + ", trying to split on variable " + splitvar + " (1-based, negative means categorical), but data point number " + row + " is NaN for that.");
        }
        double cutoff = tree.cut[node];
        int leftKid = tree.children[node][0];
        int rightKid = tree.children[node][1];
        if (splitvar > 0) {
            return value <= cutoff ? leftKid : rightKid;
        }
        int category = (int) value;
        if (category <= 0) {
            throw new RuntimeException("Input error in Regtree." + where + ": categoricals have to be integers >= 1");
        }
        int split = tree.catsplit[(int) cutoff][category - 1];
        if (split == 0) {
            return leftKid;
        }
        if (split == 1) {
            return rightKid;
        }
        throw new RuntimeException("Missing value -- not allowed in this implementation.");
    }

    /**
     * Lists, in FIFO traversal order, every leaf reachable from the root for one data row. Splits
     * on the row's own columns are routed by its value; splits on the other group of columns are
     * marginalized by following both children (right child queued, left child walked first).
     *
     * @param rowData   the row's values (only dereferenced when a routed split is met)
     * @param row       index of the row, used only in error messages
     * @param thetacols number of leading Theta columns among the tree's predictors
     * @param rowIsTheta true if {@code rowData} holds Theta columns (index {@code |var|-1}); false
     *                  if it holds instance columns (index {@code |var|-1-thetacols})
     * @param where     caller tag used in error messages
     */
    private static List<Integer> reachableLeaves(Regtree tree, double[] rowData, int row, int thetacols, boolean rowIsTheta, String where) {
        List<Integer> leaves = new ArrayList<Integer>();
        Deque<Integer> pending = new ArrayDeque<Integer>();
        pending.add(0);
        while (!pending.isEmpty()) {
            int node = pending.poll();
            while (tree.var[node] != 0) {
                int column = Math.abs(tree.var[node]) - 1;
                boolean isThetaVar = column < thetacols;
                if (isThetaVar != rowIsTheta) {
                    // Not this row's variable: explore both branches.
                    pending.add(tree.children[node][1]);
                    node = tree.children[node][0];
                } else {
                    int index = rowIsTheta ? column : column - thetacols;
                    node = childFor(tree, node, rowData[index], row, where);
                }
            }
            leaves.add(node);
        }
        return leaves;
    }
}

