/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.tree;

import ai.h2o.algos.tree.INode;
import ai.h2o.algos.tree.INodeStat;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import java.io.Serializable;

/**
 * Per-tree SHAP contribution calculator.
 *
 * <p>The recursion follows the path-dependent Tree SHAP algorithm (Lundberg et al., "Consistent
 * Individualized Feature Attribution for Tree Ensembles", Algorithm 2) and has the same structure
 * as XGBoost's {@code RegTree::TreeShap}. Node cover is read from {@code INodeStat.getWeight()},
 * and the branch a row follows at an internal node is given by {@code INode.next(row)}.
 *
 * <p>The tree arrays are only read. All scratch state lives in the {@link PathPointer} workspace
 * created by {@link #makeWorkspace()}, and that workspace is overwritten on every call.
 *
 * @param <R> row type passed to {@code INode.next}
 * @param <N> node type
 * @param <S> per-node statistics type
 */
public class TreeSHAP<R, N extends INode<R>, S extends INodeStat>
implements TreeSHAPPredictor<R> {
    /** Index of the root node in {@link #nodes}; traversal, depth and mean value all start here. */
    private final int rootNodeId;
    /** Tree nodes; child indices returned by a node index into this same array. */
    private final N[] nodes;
    /** Per-node statistics, indexed like {@link #nodes}. */
    private final S[] stats;
    /** Cover-weighted mean leaf value of the tree, added as the bias term. */
    private final float expectedTreeValue;

    /**
     * @param nodes      tree nodes, addressed by the child indices the nodes return
     * @param stats      per-node statistics indexed like {@code nodes}; {@code getWeight()} is the cover
     * @param rootNodeId index of the root node in {@code nodes}
     */
    public TreeSHAP(N[] nodes, S[] stats, int rootNodeId) {
        this.rootNodeId = rootNodeId;
        this.nodes = nodes;
        this.stats = stats;
        this.expectedTreeValue = this.treeMeanValue();
    }

    /**
     * Appends a feature at position {@code uniqueDepth} of the path and updates the permutation
     * weights of all earlier elements for the longer path.
     */
    private void extendPath(PathPointer uniquePath, int uniqueDepth, float zeroFraction, float oneFraction, int featureIndex) {
        PathElement added = uniquePath.get(uniqueDepth);
        added.feature_index = featureIndex;
        added.zero_fraction = zeroFraction;
        added.one_fraction = oneFraction;
        added.pweight = uniqueDepth == 0 ? 1.0f : 0.0f;
        for (int i = uniqueDepth - 1; i >= 0; --i) {
            PathElement cur = uniquePath.get(i);
            PathElement next = uniquePath.get(i + 1);
            next.pweight += oneFraction * cur.pweight * (float) (i + 1) / (float) (uniqueDepth + 1);
            cur.pweight = zeroFraction * cur.pweight * (float) (uniqueDepth - i) / (float) (uniqueDepth + 1);
        }
    }

    /**
     * Removes element {@code pathIndex} from a path of length {@code uniqueDepth + 1}. This
     * reverses {@link #extendPath} for that element's fractions and shifts later features down.
     */
    private void unwindPath(PathPointer uniquePath, int uniqueDepth, int pathIndex) {
        PathElement removed = uniquePath.get(pathIndex);
        float oneFraction = removed.one_fraction;
        float zeroFraction = removed.zero_fraction;
        float nextOnePortion = uniquePath.get(uniqueDepth).pweight;
        for (int i = uniqueDepth - 1; i >= 0; --i) {
            PathElement el = uniquePath.get(i);
            if (oneFraction != 0.0f) {
                float tmp = el.pweight;
                el.pweight = nextOnePortion * (float) (uniqueDepth + 1) / ((float) (i + 1) * oneFraction);
                nextOnePortion = tmp - el.pweight * zeroFraction * (float) (uniqueDepth - i) / (float) (uniqueDepth + 1);
            } else {
                el.pweight = el.pweight * (float) (uniqueDepth + 1) / (zeroFraction * (float) (uniqueDepth - i));
            }
        }
        // The loop above already recomputed pweight for positions 0..uniqueDepth-1. Only the
        // per-feature fields are shifted down over the removed element.
        for (int i = pathIndex; i < uniqueDepth; ++i) {
            PathElement dst = uniquePath.get(i);
            PathElement src = uniquePath.get(i + 1);
            dst.feature_index = src.feature_index;
            dst.zero_fraction = src.zero_fraction;
            dst.one_fraction = src.one_fraction;
        }
    }

    /**
     * Returns the total permutation weight the path would have with element {@code pathIndex}
     * removed, without modifying the path.
     *
     * @throws IllegalStateException if an element has both fractions zero but a non-zero weight
     */
    private float unwoundPathSum(PathPointer uniquePath, int uniqueDepth, int pathIndex) {
        PathElement removed = uniquePath.get(pathIndex);
        float oneFraction = removed.one_fraction;
        float zeroFraction = removed.zero_fraction;
        float nextOnePortion = uniquePath.get(uniqueDepth).pweight;
        float total = 0.0f;
        for (int i = uniqueDepth - 1; i >= 0; --i) {
            float pweight = uniquePath.get(i).pweight;
            if (oneFraction != 0.0f) {
                float tmp = nextOnePortion * (float) (uniqueDepth + 1) / ((float) (i + 1) * oneFraction);
                total += tmp;
                nextOnePortion = pweight - tmp * zeroFraction * ((float) (uniqueDepth - i) / (float) (uniqueDepth + 1));
            } else if (zeroFraction != 0.0f) {
                total += pweight / zeroFraction / ((float) (uniqueDepth - i) / (float) (uniqueDepth + 1));
            } else if (pweight != 0.0f) {
                throw new IllegalStateException("Unique path " + i + " must have zero weight");
            }
        }
        return total;
    }

    /**
     * Recursively distributes the leaf values reachable from {@code node} into {@code phi}.
     *
     * @param feat               row being explained
     * @param phi                contribution accumulator, indexed by feature index
     * @param node               current node
     * @param nodeStat           statistics of {@code node}
     * @param uniqueDepth        index at which the feature that led here is appended to the path
     * @param parentUniquePath   parent's path view; its elements 0..uniqueDepth are only read
     * @param parentZeroFraction cover ratio carried into this node for the feature that led here
     * @param parentOneFraction  one-fraction for that feature (0 on the branch the row does not take)
     * @param parentFeatureIndex feature the parent split on ({@code -1} at the root)
     * @param condition          0 for plain SHAP; see the public {@code calculateContributions}
     * @param conditionFeature   feature being conditioned on when {@code condition != 0}
     * @param conditionFraction  share of the leaf values still attributed along this branch
     */
    private void treeShap(R feat, float[] phi, N node, S nodeStat, int uniqueDepth,
            PathPointer parentUniquePath, float parentZeroFraction, float parentOneFraction,
            int parentFeatureIndex, int condition, int conditionFeature, float conditionFraction) {
        // Nothing is attributed along a branch that carries no conditioning weight.
        if (conditionFraction == 0.0f) {
            return;
        }
        // Open a new frame just past the parent's elements 0..uniqueDepth and copy all of them.
        // All uniqueDepth + 1 elements are needed when the parent split on the conditioned feature:
        // this node then does not extend the path, so element uniqueDepth must already hold the
        // parent's last element. The frame starts after the parent's elements, so the sibling
        // call still sees them unchanged.
        PathPointer uniquePath = parentUniquePath.move(uniqueDepth + 1);
        // The conditioned feature is never put on the path.
        if (condition == 0 || conditionFeature != parentFeatureIndex) {
            this.extendPath(uniquePath, uniqueDepth, parentZeroFraction, parentOneFraction, parentFeatureIndex);
        }

        if (node.isLeaf()) {
            // Element 0 is the root entry (feature index -1) and receives no attribution.
            float leafValue = node.getLeafValue();
            for (int i = 1; i <= uniqueDepth; ++i) {
                PathElement el = uniquePath.get(i);
                float w = this.unwoundPathSum(uniquePath, uniqueDepth, i);
                phi[el.feature_index] += w * (el.one_fraction - el.zero_fraction) * leafValue * conditionFraction;
            }
            return;
        }

        int splitIndex = node.getSplitIndex();
        int hotIndex = node.next(feat);
        int coldIndex = hotIndex == node.getLeftChildIndex() ? node.getRightChildIndex() : node.getLeftChildIndex();
        // Cover of each child relative to this node.
        float nodeWeight = nodeStat.getWeight();
        float hotZeroFraction = this.stats[hotIndex].getWeight() / nodeWeight;
        float coldZeroFraction = this.stats[coldIndex].getWeight() / nodeWeight;
        float incomingZeroFraction = 1.0f;
        float incomingOneFraction = 1.0f;
        int depth = uniqueDepth;

        // If this split's feature is already on the path, remove it and carry its fractions into
        // the children, so each feature appears on the path at most once.
        int pathIndex = 0;
        while (pathIndex <= depth && uniquePath.get(pathIndex).feature_index != splitIndex) {
            ++pathIndex;
        }
        if (pathIndex <= depth) {
            PathElement previous = uniquePath.get(pathIndex);
            incomingZeroFraction = previous.zero_fraction;
            incomingOneFraction = previous.one_fraction;
            this.unwindPath(uniquePath, depth, pathIndex);
            --depth;
        }

        // Split on the conditioned feature. condition > 0 sends everything down the row's branch;
        // condition < 0 weights both branches by cover. The children will not extend the path for
        // this feature, so they keep the current path length.
        float hotConditionFraction = conditionFraction;
        float coldConditionFraction = conditionFraction;
        if (splitIndex == conditionFeature) {
            if (condition > 0) {
                coldConditionFraction = 0.0f;
                --depth;
            } else if (condition < 0) {
                hotConditionFraction *= hotZeroFraction;
                coldConditionFraction *= coldZeroFraction;
                --depth;
            }
        }

        this.treeShap(feat, phi, this.nodes[hotIndex], this.stats[hotIndex], depth + 1, uniquePath,
                hotZeroFraction * incomingZeroFraction, incomingOneFraction,
                splitIndex, condition, conditionFeature, hotConditionFraction);
        this.treeShap(feat, phi, this.nodes[coldIndex], this.stats[coldIndex], depth + 1, uniquePath,
                coldZeroFraction * incomingZeroFraction, 0.0f,
                splitIndex, condition, conditionFeature, coldConditionFraction);
    }

    /**
     * Adds the unconditioned SHAP contributions of {@code feat} to {@code out_contribs}, using a
     * freshly allocated workspace.
     */
    @Override
    public float[] calculateContributions(R feat, float[] out_contribs) {
        return this.calculateContributions(feat, out_contribs, 0, -1, this.makeWorkspace());
    }

    /**
     * Adds SHAP contributions of {@code feat} to {@code out_contribs}.
     *
     * <p>Contributions are accumulated at each feature's index. With {@code condition == 0} the
     * tree's expected value is also added to the last slot. With {@code condition > 0}, splits on
     * {@code condition_feature} send all weight down the row's branch. With {@code condition < 0},
     * such splits weight both branches by cover. In both nonzero cases that feature is never
     * added to the path.
     *
     * @param workspace a {@link PathPointer} from {@link #makeWorkspace()}; overwritten by this call
     * @return {@code out_contribs}
     */
    @Override
    public float[] calculateContributions(R feat, float[] out_contribs, int condition, int condition_feature, Object workspace) {
        if (condition == 0) {
            out_contribs[out_contribs.length - 1] += this.expectedTreeValue;
        }
        PathPointer uniquePathWorkspace = (PathPointer) workspace;
        uniquePathWorkspace.reset();
        this.treeShap(feat, out_contribs, this.nodes[this.rootNodeId], this.stats[this.rootNodeId], 0,
                uniquePathWorkspace, 1.0f, 1.0f, -1, condition, condition_feature, 1.0f);
        return out_contribs;
    }

    /** Allocates a workspace of {@link #getWorkspaceSize()} path elements. */
    @Override
    public PathPointer makeWorkspace() {
        PathElement[] elements = new PathElement[this.getWorkspaceSize()];
        for (int i = 0; i < elements.length; ++i) {
            elements[i] = new PathElement();
        }
        return new PathPointer(elements);
    }

    /**
     * Returns the number of path elements a workspace needs. Every recursion level opens a frame
     * after its parent's elements, so usage grows quadratically with depth. The triangular number
     * of {@code depth + 2} bounds it.
     */
    @Override
    public int getWorkspaceSize() {
        int maxd = this.treeDepth() + 2;
        return maxd * (maxd + 1) / 2;
    }

    /** Number of node levels below and including the root (a single leaf has depth 1). */
    private int treeDepth() {
        return nodeDepth(this.nodes, this.rootNodeId);
    }

    private static int nodeDepth(INode<?>[] nodes, int node) {
        INode<?> n = nodes[node];
        if (n.isLeaf()) {
            return 1;
        }
        return 1 + Math.max(nodeDepth(nodes, n.getLeftChildIndex()), nodeDepth(nodes, n.getRightChildIndex()));
    }

    /** Cover-weighted mean leaf value of the tree rooted at {@link #rootNodeId}. */
    private float treeMeanValue() {
        return nodeMeanValue(this.nodes, this.stats, this.rootNodeId);
    }

    /**
     * Mean of the subtree at {@code node}: each child's mean is weighted by its cover, and the sum
     * is divided by this node's cover. NOTE(review): this assumes a node's weight equals the sum
     * of its children's weights — confirm.
     */
    private static float nodeMeanValue(INode<?>[] nodes, INodeStat[] stats, int node) {
        INode<?> n = nodes[node];
        if (n.isLeaf()) {
            return n.getLeafValue();
        }
        int left = n.getLeftChildIndex();
        int right = n.getRightChildIndex();
        return (stats[left].getWeight() * nodeMeanValue(nodes, stats, left)
                + stats[right].getWeight() * nodeMeanValue(nodes, stats, right)) / stats[node].getWeight();
    }

    /**
     * A view into a shared {@link PathElement} array starting at {@code position}. Each recursion
     * level works on its own view.
     */
    public static class PathPointer {
        PathElement[] path;
        int position;

        PathPointer(PathElement[] path) {
            this.path = path;
        }

        PathPointer(PathElement[] path, int position) {
            this.path = path;
            this.position = position;
        }

        /** Returns element {@code i} of this view. */
        PathElement get(int i) {
            return this.path[this.position + i];
        }

        /**
         * Copies this view's first {@code len} elements into the {@code len} slots right after
         * them, and returns a view starting at those copies.
         */
        PathPointer move(int len) {
            for (int i = 0; i < len; ++i) {
                PathElement src = this.path[this.position + i];
                PathElement dst = this.path[this.position + len + i];
                dst.feature_index = src.feature_index;
                dst.zero_fraction = src.zero_fraction;
                dst.one_fraction = src.one_fraction;
                dst.pweight = src.pweight;
            }
            return new PathPointer(this.path, this.position + len);
        }

        /** Clears the first slot of the underlying array. */
        void reset() {
            this.path[0].reset();
        }
    }

    /** One path entry: a feature, its zero/one fractions and a permutation weight. */
    private static class PathElement
    implements Serializable {
        /** Feature split on; -1 for the entry added at the root. */
        int feature_index;
        /** Product of the cover ratios (child / parent weight) of the branches taken for this feature. */
        float zero_fraction;
        /** 1 when the row follows this path for the feature, 0 when it does not. */
        float one_fraction;
        /** Permutation weight at this position of the path. */
        float pweight;

        private PathElement() {
        }

        void reset() {
            this.feature_index = 0;
            this.zero_fraction = 0.0f;
            this.one_fraction = 0.0f;
            this.pweight = 0.0f;
        }
    }
}

