/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import water.DKV;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Rapids;
import water.rapids.Val;
import water.util.ArrayUtils;
import water.util.VecUtils;

/**
 * Computes Friedman and Popescu's H statistic, a measure of how strongly a set of variables
 * interacts in a tree ensemble.
 *
 * <p>For every non-empty subset of the variables, a centered partial-dependence ("F") function is
 * evaluated on the unique rows of that subset. The statistic is then assembled from those values
 * with an inclusion-exclusion sum, weighted by how often each unique row occurs.
 */
public class FriedmanPopescusH {

    /**
     * Computes the H statistic for the interaction between {@code vars}.
     *
     * @param frame frame containing the columns named in {@code vars}
     * @param vars names of the variables whose joint interaction is measured
     * @param learnRate factor applied to every leaf prediction
     * @param sharedTreeSubgraphs trees indexed as {@code [tree][class]}
     * @return the H statistic, or {@code NaN} when its numerator is not smaller than its denominator
     * @throws RuntimeException if a name in {@code vars} is not a column of {@code frame}
     */
    public static double h(Frame frame, String[] vars, double learnRate, SharedTreeSubgraph[][] sharedTreeSubgraphs) {
        // Resolve the names before filterFrame looks the columns up, so an unknown column
        // fails with an explicit message.
        int[] modelIds = FriedmanPopescusH.getModelIds(frame.names(), vars);
        Frame filteredFrame = FriedmanPopescusH.filterFrame(frame, vars);
        int numCols = filteredFrame.numCols();
        int[] colIds = new int[numCols];
        for (int c = 0; c < numCols; ++c) {
            colIds[c] = c;
        }
        Map<String, Frame> fValues = new HashMap<>();
        for (int size = numCols; size > 0; --size) {
            for (int[] combination : FriedmanPopescusH.combinations(colIds, size)) {
                String[] cols = FriedmanPopescusH.getCurrCombinationCols(combination, vars);
                int[] combinationModelIds = FriedmanPopescusH.getCurrentCombinationModelIds(combination, modelIds);
                fValues.put(Arrays.toString(combination),
                        FriedmanPopescusH.computeFValues(combinationModelIds, filteredFrame, cols, learnRate, sharedTreeSubgraphs));
            }
        }
        return FriedmanPopescusH.computeHValue(fValues, filteredFrame, colIds);
    }

    /**
     * Maps positions within {@code filteredFrame} to the corresponding model column ids.
     *
     * @param currCombination positions into {@code modelIds}
     * @param modelIds model column id for every position
     * @return {@code modelIds[currCombination[k]]} for each {@code k}
     */
    static int[] getCurrentCombinationModelIds(int[] currCombination, int[] modelIds) {
        int[] selected = new int[currCombination.length];
        for (int k = 0; k < currCombination.length; ++k) {
            selected[k] = modelIds[currCombination[k]];
        }
        return selected;
    }

    /**
     * Combines the per-subset F values into the H statistic.
     *
     * <p>For each unique row of {@code filteredFrame}, the numerator element is the alternating sum
     * of the F values of all subsets of {@code inds}. The full set counts +1, the next smaller size
     * counts -1, and so on. The denominator element is the F value of the full set. Both are
     * squared and weighted by the row counts before being summed.
     *
     * @param fValues F-value frames keyed by {@code Arrays.toString(subset)}; each frame holds the
     *     F values in column 0 followed by the subset's unique rows
     * @param filteredFrame frame restricted to the variables of interest; a key is assigned if absent
     * @param inds column positions of all variables in {@code filteredFrame}
     * @return {@code sqrt(numer / denom)}, or {@code NaN} when {@code numer >= denom}
     */
    static double computeHValue(Map<String, Frame> fValues, Frame filteredFrame, int[] inds) {
        if (filteredFrame._key == null) {
            filteredFrame._key = Key.make();
        }
        Frame uniqueWithCounts = FriedmanPopescusH.uniqueRowsWithCounts(filteredFrame);
        long uniqHeight = uniqueWithCounts.numRows();

        // The subsets and their F-value frames do not depend on the row, so resolve them once.
        List<int[]> subsets = new ArrayList<>();
        for (int size = inds.length; size > 0; --size) {
            subsets.addAll(FriedmanPopescusH.combinations(inds, size));
        }
        Frame[] subsetFValues = new Frame[subsets.size()];
        for (int s = 0; s < subsetFValues.length; ++s) {
            subsetFValues[s] = fValues.get(Arrays.toString(subsets.get(s)));
        }
        Frame jointFValues = fValues.get(Arrays.toString(inds));

        Vec numerEls = Vec.makeZero(uniqHeight);
        Vec denomEls = Vec.makeZero(uniqHeight);
        try {
            for (long row = 0L; row < uniqHeight; ++row) {
                double numerEl = 0.0;
                for (int s = 0; s < subsetFValues.length; ++s) {
                    int[] subset = subsets.get(s);
                    // Inclusion-exclusion sign: +1 for the full set, flipping with each size step.
                    int sign = (inds.length - subset.length) % 2 == 0 ? 1 : -1;
                    // Values are read from the unique rows, so row `row` lines up with the
                    // "nrow" weight and with the joint F value used for the denominator.
                    numerEl += sign * FriedmanPopescusH.findFValue(row, subset, subsetFValues[s], uniqueWithCounts);
                }
                numerEls.set(row, numerEl);
                denomEls.set(row, jointFValues.vec(0).at(row));
            }
            Vec counts = uniqueWithCounts.vec("nrow");
            double numer = new Transform(2).doAll(numerEls, counts).result;
            double denom = new Transform(2).doAll(denomEls, counts).result;
            return numer < denom ? Math.sqrt(numer / denom) : Double.NaN;
        } finally {
            // Scratch vectors are only needed for the two reductions above.
            numerEls.remove();
            denomEls.remove();
        }
    }

    /**
     * Reads the values of the given columns at one row.
     *
     * @param currCombination column positions to read
     * @param filteredFrame frame to read from
     * @param i2 row index
     * @return the value of each selected column at row {@code i2}
     */
    static double[] getValueToFindFValueFor(int[] currCombination, Frame filteredFrame, long i2) {
        double[] values = new double[currCombination.length];
        for (int k = 0; k < currCombination.length; ++k) {
            values[k] = filteredFrame.vec(currCombination[k]).at(i2);
        }
        return values;
    }

    /**
     * Looks up the F value of {@code currCombination} at the feature values found in row {@code i2}
     * of {@code filteredFrame}.
     *
     * <p>Matching compares every column within an absolute tolerance of {@code 1e-5}. The value of
     * the first matching row of {@code currFValues} is returned.
     *
     * @param i2 row of {@code filteredFrame} holding the feature values to look up
     * @param currCombination column positions (in {@code filteredFrame}) of the subset
     * @param currFValues F values of the subset in column 0, followed by its unique rows
     * @param filteredFrame frame supplying the feature values and column names
     * @return the matching F value
     * @throws RuntimeException if no row of {@code currFValues} matches
     */
    static double findFValue(long i2, int[] currCombination, Frame currFValues, Frame filteredFrame) {
        double[] target = FriedmanPopescusH.getValueToFindFValueFor(currCombination, filteredFrame, i2);
        String[] currNames = FriedmanPopescusH.getCurrCombinationNames(currCombination, filteredFrame.names());
        FindFValue task = new FindFValue(target, currNames, currFValues._names, 1.0E-5);
        Frame result = task.doAll(Vec.T_NUM, currFValues).outputFrame();
        try {
            if (result.numRows() == 0L) {
                throw new RuntimeException("FValue was not found!" + Arrays.toString(currCombination) + "value: " + Arrays.toString(target));
            }
            return result.vec(0).at(0L);
        } finally {
            // The task's output frame is a temporary; release its vectors.
            result.remove();
        }
    }

    /**
     * Selects the names at the given positions.
     *
     * @param currCombination positions into {@code names}
     * @param names all column names
     * @return {@code names[currCombination[k]]} for each {@code k}
     */
    static String[] getCurrCombinationNames(int[] currCombination, String[] names) {
        String[] selected = new String[currCombination.length];
        for (int k = 0; k < currCombination.length; ++k) {
            selected[k] = names[currCombination[k]];
        }
        return selected;
    }

    /**
     * Selects the variable names at the given positions.
     *
     * @param currCombination positions into {@code vars}
     * @param vars all variable names
     * @return {@code vars[currCombination[k]]} for each {@code k}
     */
    static String[] getCurrCombinationCols(int[] currCombination, String[] vars) {
        return FriedmanPopescusH.getCurrCombinationNames(currCombination, vars);
    }

    /**
     * Returns the index of the first numeric column of {@code frame}, or -1 if there is none.
     */
    static int findFirstNumericalColumn(Frame frame) {
        for (int c = 0; c < frame.names().length; ++c) {
            if (frame.vec(c).isNumeric()) {
                return c;
            }
        }
        return -1;
    }

    /**
     * Returns the distinct rows of {@code frame} with a trailing {@code "nrow"} count column.
     *
     * <p>The rows are grouped by every column using the Rapids {@code GB} "nrow" aggregate. That
     * aggregate needs a numeric column, so when {@code frame} has none the rows are <em>not</em>
     * deduplicated. Instead a constant-1 {@code "nrow"} column is appended to {@code frame} itself,
     * and {@code frame} is returned.
     *
     * @param frame frame to group; must have a non-null key when it contains a numeric column
     * @return a frame of unique rows plus the {@code "nrow"} column
     */
    static Frame uniqueRowsWithCounts(Frame frame) {
        int countCol = FriedmanPopescusH.findFirstNumericalColumn(frame);
        if (countCol == -1) {
            frame.add("nrow", Vec.makeOne(frame.numRows()));
            return frame;
        }
        String[] cols = frame.names();
        StringBuilder sb = new StringBuilder("(GB ");
        sb.append(frame._key.toString());
        sb.append(" [");
        for (int c = 0; c < cols.length; ++c) {
            if (c != 0) {
                sb.append(",");
            }
            sb.append(c);
        }
        sb.append("] ");
        sb.append(" nrow ").append(countCol).append(" \"all\")");
        // Rapids resolves the frame by key, so it must be in DKV only for the duration of the call.
        DKV.put(frame);
        try {
            Val val = Rapids.exec(sb.toString());
            return val.getFrame();
        } finally {
            DKV.remove(frame._key);
        }
    }

    /**
     * Computes the centered F values of the subset {@code cols} on its unique rows.
     *
     * <p>Only the partial dependence of class 0 is used. It is centered by subtracting its mean
     * weighted by the unique-row counts over all rows.
     *
     * @param modelIds model column ids of {@code cols}, in the same order
     * @param filteredFrame frame containing {@code cols}
     * @param cols names of the subset's columns
     * @param learnRate factor applied to every leaf prediction
     * @param sharedTreeSubgraphs trees indexed as {@code [tree][class]}
     * @return a frame with the centered F values in column 0 followed by the unique rows and counts
     */
    static Frame computeFValues(int[] modelIds, Frame filteredFrame, String[] cols, double learnRate, SharedTreeSubgraph[][] sharedTreeSubgraphs) {
        Frame subsetFrame = FriedmanPopescusH.filterFrame(filteredFrame, cols);
        subsetFrame = new Frame(Key.make(), subsetFrame.names(), subsetFrame.vecs());
        Frame uniqueWithCounts = FriedmanPopescusH.uniqueRowsWithCounts(subsetFrame);
        Frame pdp = FriedmanPopescusH.partialDependence(modelIds, uniqueWithCounts, learnRate, sharedTreeSubgraphs);
        // Only class 0 is used below; release the vectors computed for the other classes.
        for (int c = 1; c < pdp.numCols(); ++c) {
            pdp.vec(c).remove();
        }
        Frame centered = new Frame(pdp.vec(0));
        Vec fVec = centered.vec(0);
        double weightedSum = new VecUtils.DotProduct().doAll(uniqueWithCounts.vec("nrow"), fVec).result;
        double mean = weightedSum / (double) subsetFrame.numRows();
        long rows = centered.numRows();
        for (long row = 0L; row < rows; ++row) {
            fVec.set(row, fVec.at(row) - mean);
        }
        return centered.add(uniqueWithCounts);
    }

    /**
     * Sums the per-tree partial dependence over all trees, separately for every class.
     *
     * @param modelIds model column ids of the grid's leading columns, in order
     * @param uniqueWithCounts grid of points whose leading columns follow {@code modelIds}
     * @param learnRate factor applied to every leaf prediction
     * @param sharedTreeSubgraphs trees indexed as {@code [tree][class]}
     * @return a frame with one column {@code "pdp_C<class>"} per class
     */
    static Frame partialDependence(int[] modelIds, Frame uniqueWithCounts, double learnRate, SharedTreeSubgraph[][] sharedTreeSubgraphs) {
        Frame result = new Frame(new Vec[0]);
        int nclasses = sharedTreeSubgraphs[0].length;
        int ntrees = sharedTreeSubgraphs.length;
        long nrows = uniqueWithCounts.numRows();
        for (int treeClass = 0; treeClass < nclasses; ++treeClass) {
            Vec pdp = Vec.makeZero(nrows);
            for (int tree = 0; tree < ntrees; ++tree) {
                Vec treePdp = FriedmanPopescusH.partialDependenceTree(sharedTreeSubgraphs[tree][treeClass], modelIds, learnRate, uniqueWithCounts);
                for (long row = 0L; row < nrows; ++row) {
                    pdp.set(row, pdp.at(row) + treePdp.at(row));
                }
                // The per-tree vector is folded into pdp and no longer needed.
                treePdp.remove();
            }
            result.add("pdp_C" + treeClass, pdp);
        }
        return result;
    }

    /**
     * Element-wise sum of two arrays, truncated to the shorter length.
     */
    public static double[] add(double[] first, double[] second) {
        int length = Math.min(first.length, second.length);
        double[] result = new double[length];
        for (int k = 0; k < length; ++k) {
            result[k] = first[k] + second[k];
        }
        return result;
    }

    /**
     * Returns a new keyless frame holding the columns {@code cols} of {@code frame} (vectors are shared).
     */
    static Frame filterFrame(Frame frame, String[] cols) {
        Frame subset = new Frame(new Vec[0]);
        subset.add(cols, frame.vecs(cols));
        return subset;
    }

    /**
     * Finds the index of every name of {@code vars} in {@code frameNames}.
     *
     * <p>If a name occurs more than once, the last occurrence wins.
     *
     * @throws RuntimeException if a name is not present in {@code frameNames}
     */
    static int[] getModelIds(String[] frameNames, String[] vars) {
        int[] modelIds = new int[vars.length];
        Arrays.fill(modelIds, -1);
        for (int v = 0; v < vars.length; ++v) {
            for (int f = 0; f < frameNames.length; ++f) {
                if (vars[v].equals(frameNames[f])) {
                    modelIds[v] = f;
                }
            }
            if (modelIds[v] == -1) {
                throw new RuntimeException("Column " + vars[v] + " is not present in the input frame!");
            }
        }
        return modelIds;
    }

    /**
     * Returns all {@code combinationSize}-element combinations of {@code vals}.
     *
     * <p>Elements keep their relative order from {@code vals}, and the combinations are in
     * lexicographic order of positions.
     */
    static List<int[]> combinations(int[] vals, int combinationSize) {
        List<int[]> overallResult = new ArrayList<>();
        FriedmanPopescusH.combinations(vals, combinationSize, 0, new int[combinationSize], overallResult);
        return overallResult;
    }

    /** Recursive helper: fills the remaining {@code len} slots of {@code result} from {@code arr[startPosition..]}. */
    private static void combinations(int[] arr, int len, int startPosition, int[] result, List<int[]> overallResult) {
        if (len == 0) {
            overallResult.add(result.clone());
            return;
        }
        for (int k = startPosition; k <= arr.length - len; ++k) {
            result[result.length - len] = arr[k];
            FriedmanPopescusH.combinations(arr, len - 1, k + 1, result, overallResult);
        }
    }

    /**
     * Computes one tree's partial dependence at each row of {@code grid}.
     *
     * <p>At a split on a target feature, the row's value selects the child (left when
     * {@code value <= splitValue}). At any other split, both children are visited and the current
     * weight is divided in proportion to the children's {@code getWeight()}. Leaf predictions are
     * multiplied by the accumulated weight and {@code learnRate}.
     *
     * @param tree tree to traverse
     * @param targetFeature model column ids; {@code grid} column {@code k} holds feature {@code targetFeature[k]}
     * @param learnRate factor applied to every leaf prediction
     * @param grid points at which to evaluate
     * @return a new vector with one value per row of {@code grid}
     * @throws RuntimeException if the leaf weights of a row do not sum to approximately 1
     */
    static Vec partialDependenceTree(SharedTreeSubgraph tree, int[] targetFeature, double learnRate, Frame grid) {
        Vec outVec = Vec.makeZero(grid.numRows());
        int capacity = tree.nodesArray.size() * 2;
        SharedTreeNode[] nodeStack = new SharedTreeNode[capacity];
        double[] weightStack = new double[capacity];
        for (long row = 0L; row < grid.numRows(); ++row) {
            int stackSize = 1;
            nodeStack[0] = tree.rootNode;
            weightStack[0] = 1.0;
            double totalWeight = 0.0;
            double prediction = 0.0;
            while (stackSize > 0) {
                --stackSize;
                SharedTreeNode node = nodeStack[stackSize];
                double weight = weightStack[stackSize];
                if (node.isLeaf()) {
                    prediction += weight * (double) node.getPredValue() * learnRate;
                    totalWeight += weight;
                    continue;
                }
                int featureId = ArrayUtils.find(targetFeature, node.getColId());
                if (featureId >= 0) {
                    // Target feature: follow the single branch chosen by the grid value, keeping the weight.
                    nodeStack[stackSize] = grid.vec(featureId).at(row) <= (double) node.getSplitValue()
                            ? node.getLeftChild() : node.getRightChild();
                    weightStack[stackSize] = weight;
                    ++stackSize;
                    continue;
                }
                // Other feature: visit both children, splitting the weight by their sample fractions.
                double leftFraction = node.getLeftChild().getWeight() / node.getWeight();
                nodeStack[stackSize] = node.getLeftChild();
                weightStack[stackSize] = weight * leftFraction;
                ++stackSize;
                nodeStack[stackSize] = node.getRightChild();
                weightStack[stackSize] = weight * (1.0 - leftFraction);
                ++stackSize;
            }
            outVec.set(row, prediction);
            if (!(0.999 < totalWeight && totalWeight < 1.001)) {
                throw new RuntimeException("Total weight should be 1.0 but was " + totalWeight);
            }
        }
        return outVec;
    }

    /**
     * Emits, into a single numeric output column, the column-0 value of every row whose columns
     * named {@code currNames} all equal {@code valueToFindFValueFor} within {@code eps}.
     */
    static class FindFValue
    extends MRTask<FindFValue> {
        double[] valueToFindFValueFor;
        String[] currNames;
        String[] currFValuesNames;
        double eps;

        FindFValue(double[] valueToFindFValueFor, String[] currNames, String[] currFValuesNames, double eps) {
            this.valueToFindFValueFor = valueToFindFValueFor;
            this.currNames = currNames;
            this.currFValuesNames = currFValuesNames;
            this.eps = eps;
        }

        @Override
        public void map(Chunk[] cs, NewChunk[] nc) {
            // Column positions do not depend on the row; resolve them once per chunk.
            int[] colIds = new int[this.valueToFindFValueFor.length];
            for (int k = 0; k < colIds.length; ++k) {
                colIds[k] = ArrayUtils.find(this.currFValuesNames, this.currNames[k]);
            }
            for (int iRow = 0; iRow < cs[0].len(); ++iRow) {
                if (this.rowMatches(cs, colIds, iRow)) {
                    nc[0].addNum(cs[0].atd(iRow));
                }
            }
        }

        /** True when every selected column of row {@code iRow} is within {@code eps} of its target. */
        private boolean rowMatches(Chunk[] cs, int[] colIds, int iRow) {
            for (int k = 0; k < colIds.length; ++k) {
                // Written as !(x < eps) so that NaN never counts as a match.
                if (!(Math.abs(this.valueToFindFValueFor[k] - cs[colIds[k]].atd(iRow)) < this.eps)) {
                    return false;
                }
            }
            return true;
        }
    }

    /**
     * Computes {@code sum(x^power * w)} over two vectors {@code (x, w)}.
     */
    private static class Transform
    extends MRTask<Transform> {
        double result;
        int power;

        Transform(int power) {
            this.power = power;
        }

        @Override
        public void map(Chunk[] bvs) {
            double sum = 0.0;
            int len = bvs[0]._len;
            for (int row = 0; row < len; ++row) {
                sum += Math.pow(bvs[0].atd(row), this.power) * bvs[1].atd(row);
            }
            this.result = sum;
        }

        @Override
        public void reduce(Transform mrt) {
            this.result += mrt.result;
        }
    }
}

