/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import ai.h2o.algos.tree.INode;
import hex.Model;
import hex.genmodel.algos.tree.ContributionComposer;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.algos.tree.TreeSHAP;
import hex.genmodel.algos.tree.TreeSHAPEnsemble;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.tree.SharedTreeModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import water.Job;
import water.JobUpdatePostMap;
import water.Key;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.util.ArrayUtils;

/**
 * Base class for {@link SharedTreeModel}s that can compute per-row feature contributions using
 * TreeSHAP over the model's trees.
 *
 * <p>Contributions are only supported for models with at most two classes; models with more
 * classes are rejected with {@link UnsupportedOperationException}.
 *
 * @param <M> concrete model type
 * @param <P> model parameters type
 * @param <O> model output type
 */
public abstract class SharedTreeModelWithContributions<M extends SharedTreeModel<M, P, O>, P extends SharedTreeModel.SharedTreeParameters, O extends SharedTreeModel.SharedTreeOutput>
extends SharedTreeModel<M, P, O>
implements Model.Contributions {

    /** Name of the trailing output column that holds the bias contribution. */
    private static final String BIAS_TERM = "BiasTerm";

    public SharedTreeModelWithContributions(Key<M> selfKey, P parms, O output) {
        super(selfKey, parms, output);
    }

    /**
     * Computes contributions for every row of {@code frame} without progress reporting.
     *
     * @param frame           frame to score
     * @param destination_key key for the resulting frame
     * @return frame with one contribution column per input feature plus a bias column
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key) {
        return scoreContributions(frame, destination_key, null);
    }

    /**
     * Adapts a copy of {@code frame} to the training layout and drops the response, fold, weights
     * and offset columns, leaving the columns that are fed to TreeSHAP.
     */
    private Frame removeSpecialColumns(Frame frame) {
        Frame adaptFrm = new Frame(frame);
        adaptTestForTrain(adaptFrm, true, false);
        adaptFrm.remove(_parms._response_column);
        adaptFrm.remove(_parms._fold_column);
        adaptFrm.remove(_parms._weights_column);
        adaptFrm.remove(_parms._offset_column);
        return adaptFrm;
    }

    /**
     * Rejects models with more than two classes, for which contributions are not implemented.
     *
     * @throws UnsupportedOperationException if the model has more than two classes
     */
    private void checkContributionsSupported() {
        if (_output.nclasses() > 2) {
            throw new UnsupportedOperationException("Calculating contributions is currently not supported for multinomial models.");
        }
    }

    /**
     * Computes contributions for every row of {@code frame} in the original (unsorted) layout:
     * one numeric column per input feature followed by a {@code BiasTerm} column.
     *
     * @param frame           frame to score
     * @param destination_key key for the resulting frame
     * @param j2              job to report progress to; may be {@code null}
     * @return the contributions frame
     * @throws UnsupportedOperationException for models with more than two classes
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j2) {
        checkContributionsSupported();
        Frame adaptFrm = removeSpecialColumns(frame);
        String[] outputNames = ArrayUtils.append(adaptFrm.names(), BIAS_TERM);
        // All output columns share one type; (byte) 3 is assumed to be Vec.T_NUM (numeric) - confirm.
        return getScoreContributionsTask(this)
                .withPostMapAction(JobUpdatePostMap.forJob(j2))
                .doAll(outputNames.length, (byte) 3, adaptFrm)
                .outputFrame(destination_key, outputNames, null);
    }

    /**
     * Computes contributions, optionally keeping only the top/bottom N features per row.
     *
     * <p>When no sorting is requested this delegates to
     * {@link #scoreContributions(Frame, Key, Job)}. Otherwise each selected feature produces a pair
     * of columns (feature id, contribution value), followed by one trailing bias column.
     *
     * @param frame           frame to score
     * @param destination_key key for the resulting frame
     * @param j2              job to report progress to; may be {@code null}
     * @param options         output format and top/bottom-N selection options
     * @return the contributions frame
     * @throws UnsupportedOperationException for models with more than two classes, or when the
     *                                       compact output format is requested
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> j2, Model.Contributions.ContributionsOptions options) {
        checkContributionsSupported();
        if (options._outputFormat == Model.Contributions.ContributionsOutputFormat.Compact) {
            throw new UnsupportedOperationException("Only output_format \"Original\" is supported for this model.");
        }
        if (!options.isSortingRequired()) {
            return scoreContributions(frame, destination_key, j2);
        }
        Frame adaptFrm = removeSpecialColumns(frame);
        String[] featureNames = adaptFrm.names();
        String[] contribNames = ArrayUtils.append(featureNames, BIAS_TERM);
        ContributionComposer contributionComposer = new ContributionComposer();
        int topNAdjusted = contributionComposer.checkAndAdjustInput(options._topN, featureNames.length);
        int bottomNAdjusted = contributionComposer.checkAndAdjustInput(options._bottomN, featureNames.length);
        // Two columns (feature id, value) per selected feature, never more than all features;
        // the extra "+ 1" slot below is the trailing bias column.
        int outputSize = Math.min((topNAdjusted + bottomNAdjusted) * 2, featureNames.length * 2);
        String[] names = new String[outputSize + 1];
        byte[] types = new byte[outputSize + 1];
        String[][] domains = new String[outputSize + 1][contribNames.length];
        composeScoreContributionTaskMetadata(names, types, domains, featureNames, options);
        return getScoreContributionsSoringTask(this, options)
                .withPostMapAction(JobUpdatePostMap.forJob(j2))
                .doAll(types, adaptFrm)
                .outputFrame(destination_key, names, domains);
    }

    /** Creates the task that produces unsorted contributions for {@code var1}. */
    protected abstract ScoreContributionsTask getScoreContributionsTask(SharedTreeModel var1);

    /**
     * Creates the task that produces sorted (top/bottom-N) contributions for {@code var1}.
     * The misspelled name is kept because subclasses override it.
     */
    protected abstract ScoreContributionsTask getScoreContributionsSoringTask(SharedTreeModel var1, Model.Contributions.ContributionsOptions var2);

    /**
     * Variant of {@link ScoreContributionsTask} that, for each row, orders the contributions with
     * {@link ContributionComposer} and emits (feature id, value) column pairs plus the bias term.
     */
    public class ScoreContributionsSortingTask
    extends ScoreContributionsTask {
        private final int _topN;
        private final int _bottomN;
        private final boolean _compareAbs;

        public ScoreContributionsSortingTask(SharedTreeModel model, Model.Contributions.ContributionsOptions options) {
            super(model);
            this._topN = options._topN;
            this._bottomN = options._bottomN;
            this._compareAbs = options._compareAbs;
        }

        /**
         * Loads row {@code row} into {@code input}, zeroes {@code contribs} and resets
         * {@code contribNameIds} to the identity mapping {@code 0..length-1}.
         */
        protected void fillInput(Chunk[] chks, int row, double[] input, float[] contribs, int[] contribNameIds) {
            super.fillInput(chks, row, input, contribs);
            for (int i2 = 0; i2 < contribNameIds.length; ++i2) {
                contribNameIds[i2] = i2;
            }
        }

        @Override
        public void map(Chunk[] chks, NewChunk[] nc) {
            double[] input = MemoryManager.malloc8d(chks.length);
            // One slot per input column plus one for the bias term (last).
            float[] contribs = MemoryManager.malloc4f(chks.length + 1);
            int[] contribNameIds = MemoryManager.malloc4(chks.length + 1);
            TreeSHAPPredictor.Workspace workspace = this._treeSHAP.makeWorkspace();
            for (int row = 0; row < chks[0]._len; ++row) {
                this.fillInput(chks, row, input, contribs, contribNameIds);
                this._treeSHAP.calculateContributions(input, contribs, 0, -1, workspace);
                this.doModelSpecificComputation(contribs);
                ContributionComposer contributionComposer = new ContributionComposer();
                int[] contribNameIdsSorted = contributionComposer.composeContributions(contribNameIds, contribs, this._topN, this._bottomN, this._compareAbs);
                this.addContribToNewChunk(contribs, contribNameIdsSorted, nc);
            }
        }

        /**
         * Writes one row: alternating (feature id, contribution) pairs in the order given by
         * {@code contribNameIdsSorted}, then the bias term (last entry of {@code contribs}) into
         * the final chunk.
         */
        protected void addContribToNewChunk(float[] contribs, int[] contribNameIdsSorted, NewChunk[] nc) {
            int i2 = 0;
            int inputPointer = 0;
            while (i2 < nc.length - 1) {
                nc[i2].addNum(contribNameIdsSorted[inputPointer]);
                nc[i2 + 1].addNum(contribs[contribNameIdsSorted[inputPointer]]);
                i2 += 2;
                ++inputPointer;
            }
            nc[nc.length - 1].addNum(contribs[contribs.length - 1]);
        }
    }

    /**
     * MRTask that computes TreeSHAP contributions for every row and writes one output column per
     * input column plus a trailing bias column.
     */
    public class ScoreContributionsTask
    extends MRTask<ScoreContributionsTask> {
        protected final Key<SharedTreeModel> _modelKey;
        // Transient state rebuilt on each node in setupLocal(). Note: _output shadows the enclosing
        // model's _output, hence the explicit "this." qualifiers in this class.
        protected transient SharedTreeModel _model;
        protected transient SharedTreeModel.SharedTreeOutput _output;
        protected transient TreeSHAPPredictor<double[]> _treeSHAP;

        public ScoreContributionsTask(SharedTreeModel model) {
            this._modelKey = model._key;
        }

        /**
         * Fetches the model by key and builds a TreeSHAP ensemble from all of its non-null trees,
         * seeded with the model's initial prediction {@code _init_f}.
         */
        @Override
        protected void setupLocal() {
            this._model = this._modelKey.get();
            assert (this._model != null);
            this._output = (SharedTreeModel.SharedTreeOutput) this._model._output;
            assert (this._output != null);
            List<TreeSHAPPredictor<double[]>> treeSHAPs = new ArrayList<>(this._output._ntrees);
            for (int treeIdx = 0; treeIdx < this._output._ntrees; ++treeIdx) {
                for (int treeClass = 0; treeClass < this._output._treeKeys[treeIdx].length; ++treeClass) {
                    if (this._output._treeKeys[treeIdx][treeClass] == null) {
                        continue;
                    }
                    SharedTreeSubgraph tree = this._model.getSharedTreeSubgraph(treeIdx, treeClass);
                    INode[] nodes = tree.getNodes();
                    // Raw constructor: node type parameters are not visible here (unchecked but safe
                    // for the double[] row representation used throughout this task).
                    treeSHAPs.add(new TreeSHAP(nodes));
                }
            }
            // Expects exactly one non-null tree per iteration (no multinomial per-class trees).
            assert (treeSHAPs.size() == this._output._ntrees);
            this._treeSHAP = new TreeSHAPEnsemble<double[]>(treeSHAPs, (float) this._output._init_f);
        }

        /** Copies row {@code row} of {@code chks} into {@code input} and zeroes {@code contribs}. */
        protected void fillInput(Chunk[] chks, int row, double[] input, float[] contribs) {
            for (int i2 = 0; i2 < chks.length; ++i2) {
                input[i2] = chks[i2].atd(row);
            }
            Arrays.fill(contribs, 0.0f);
        }

        @Override
        public void map(Chunk[] chks, NewChunk[] nc) {
            // Output has one column per input column plus the trailing bias column.
            assert (chks.length == nc.length - 1);
            double[] input = MemoryManager.malloc8d(chks.length);
            float[] contribs = MemoryManager.malloc4f(nc.length);
            TreeSHAPPredictor.Workspace workspace = this._treeSHAP.makeWorkspace();
            for (int row = 0; row < chks[0]._len; ++row) {
                this.fillInput(chks, row, input, contribs);
                this._treeSHAP.calculateContributions(input, contribs, 0, -1, workspace);
                this.doModelSpecificComputation(contribs);
                this.addContribToNewChunk(contribs, nc);
            }
        }

        /**
         * Hook for subclasses to post-process one row's contributions in place after TreeSHAP has
         * filled them; the default does nothing.
         */
        protected void doModelSpecificComputation(float[] contribs) {
        }

        /** Appends {@code contribs[i]} to output chunk {@code nc[i]} for every output column. */
        protected void addContribToNewChunk(float[] contribs, NewChunk[] nc) {
            for (int i2 = 0; i2 < nc.length; ++i2) {
                nc[i2].addNum(contribs[i2]);
            }
        }
    }
}

