/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.mojos.runtime;

import ai.h2o.mojos.runtime.OriginalMatrix;
import ai.h2o.mojos.runtime.PipelineWiring;
import ai.h2o.mojos.runtime.ShapBuffers;
import ai.h2o.mojos.runtime.a.a;
import ai.h2o.mojos.runtime.a.b;
import ai.h2o.mojos.runtime.api.MojoColumnMeta;
import ai.h2o.mojos.runtime.frame.MojoColumn;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.transforms.C;
import ai.h2o.mojos.runtime.transforms.L;
import ai.h2o.mojos.runtime.transforms.MojoTransform;
import ai.h2o.mojos.runtime.transforms.MojoTransformExecPipeBuilder;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Allocates SHAP contribution columns ({@code contrib_*}) in the shared global column list and
 * accumulates per-transform SHAP contributions into those columns of a {@link MojoFrame}.
 *
 * <p>{@link #computeShap(MojoFrame, MojoTransform)} reads state that is populated by
 * {@link #prepareShapColumns(PipelineWiring)} (column indices per transform and scalers per output
 * column). The latter therefore has to run first.
 * NOTE(review): call ordering is inferred from that data dependency — confirm with callers.
 */
public class ShapBlender {
    private static final Logger log = LoggerFactory.getLogger(ShapBlender.class);
    private final MojoTransformExecPipeBuilder root;
    /** Shared, mutable column metadata list; newly created SHAP columns are appended to it. */
    private final List<MojoColumnMeta> globalColumns;
    /**
     * When true, {@link #prepareShapColumns} returns columns named after the pipeline's input
     * columns and {@link #computeShap} also forwards each contribution to the transform's
     * {@link OriginalMatrix}.
     */
    private final boolean shapOriginal;
    /** SHAP column indices (into {@link #globalColumns}) assigned to each SHAP transform. */
    private final Map<L, List<Integer>> pcIndicesByTransform = new LinkedHashMap<L, List<Integer>>();
    /** Index of every SHAP column created so far, keyed by name, so identical names are reused. */
    private final Map<String, Integer> shapColumnsByName = new LinkedHashMap<String, Integer>();
    /** Scaler applied to raw contributions, keyed by the producing transform's output column index. */
    private Map<Integer, a> scalerByOutputColumn;

    public ShapBlender(List<MojoColumnMeta> globalColumns, MojoTransformExecPipeBuilder root, boolean shapOriginal) {
        this.root = root;
        this.globalColumns = globalColumns;
        this.shapOriginal = shapOriginal;
    }

    /**
     * Registers the SHAP output columns for every SHAP transform of {@code wiring} and prepares the
     * per-output-column scalers used later by {@link #computeShap}.
     *
     * @param wiring pipeline wiring providing the SHAP transforms and their grouped input columns
     * @return indices (into the global column list) of the SHAP columns the caller should expose;
     *     when {@code shapOriginal} is set, these are the columns built from the pipeline's input
     *     column names instead of the per-transform ones
     */
    public Set<Integer> prepareShapColumns(PipelineWiring wiring) {
        Set<Integer> allShapColumns = new LinkedHashSet<Integer>();
        int cTransformOrdinal = 0;
        for (MojoTransform transform : wiring.shapTransforms) {
            Set<String> inputNames = wiring.getGroupInputColumns(transform.getTransformationGroup(), transform.iindices);
            int outputCount = this.root.pipelineMeta.probabilityComplementDetected
                    ? 1
                    : (this.root.pipelineMeta.outputClassLabels == null
                            ? this.root.oindices.length
                            : this.root.pipelineMeta.outputClassLabels.size());
            List<Integer> columns = buildShapColumns(this.shapColumnsByName, inputNames, outputCount);
            if (transform instanceof C) {
                // Each successive C transform is assigned a single column from its built list,
                // cycling by the order in which C transforms are encountered.
                int pick = cTransformOrdinal % columns.size();
                this.pcIndicesByTransform.put((L) (Object) transform, columns.subList(pick, pick + 1));
                ++cTransformOrdinal;
            } else {
                this.pcIndicesByTransform.put((L) (Object) transform, columns);
            }
            allShapColumns.addAll(columns);
        }

        int scalerClassCount = this.root.pipelineMeta.outputClassLabels == null
                ? 1
                : this.root.pipelineMeta.outputClassLabels.size();
        this.scalerByOutputColumn = b.a(wiring, this.root.oindices, scalerClassCount);
        if (this.scalerByOutputColumn.isEmpty()) {
            // No scalers were derived from the wiring: share one default scaler instance across
            // every output column of every SHAP transform.
            this.scalerByOutputColumn = new LinkedHashMap<Integer, a>();
            ai.h2o.mojos.runtime.c.b defaultScaler = new ai.h2o.mojos.runtime.c.b();
            for (L shapTransform : this.pcIndicesByTransform.keySet()) {
                for (int outputColumn : ((MojoTransform) (Object) shapTransform).oindices) {
                    this.scalerByOutputColumn.put(outputColumn, defaultScaler);
                }
            }
        }

        if (this.shapOriginal) {
            return this.switchToShapOriginalColumns();
        }
        return allShapColumns;
    }

    /**
     * Builds (or reuses) SHAP columns named after the pipeline's own input columns.
     *
     * @return the indices of those columns, in build order
     */
    private Set<Integer> switchToShapOriginalColumns() {
        Set<String> originalInputNames = new LinkedHashSet<String>();
        for (int inputIndex : this.root.iindices) {
            originalInputNames.add(this.globalColumns.get(inputIndex).getColumnName());
        }
        // NOTE(review): unlike prepareShapColumns, this uses oindices.length even when
        // outputClassLabels is present — confirm the two counts always agree.
        int outputCount = this.root.pipelineMeta.probabilityComplementDetected ? 1 : this.root.oindices.length;
        List<Integer> columns = buildShapColumns(this.shapColumnsByName, originalInputNames, outputCount);
        log.trace("Original SHAP column indices are: {}", (Object) columns);
        return new LinkedHashSet<Integer>(columns);
    }

    /**
     * Creates or reuses the contribution columns for the given inputs.
     *
     * <p>For each of {@code ocnt} outputs, one {@code contrib_<input>} column is produced per input
     * name, followed by one {@code contrib_bias} column. When {@code ocnt > 1}, every name gets a
     * {@code .<classLabel>} suffix taken from the pipeline's output class labels.
     *
     * @return column indices, grouped by output and ordered as described above
     */
    private List<Integer> buildShapColumns(Map<String, Integer> shapColumnsByName, Set<String> inputColNames, int ocnt) {
        List<Integer> columns = new ArrayList<Integer>(ocnt * (inputColNames.size() + 1));
        for (int outputIdx = 0; outputIdx < ocnt; ++outputIdx) {
            // NOTE(review): throws NPE if outputClassLabels is null while ocnt > 1 — confirm
            // callers rule that combination out.
            String suffix = ocnt > 1 ? "." + this.root.pipelineMeta.outputClassLabels.get(outputIdx) : "";
            for (String inputName : inputColNames) {
                columns.add(shapColumn(shapColumnsByName, "contrib_" + inputName + suffix));
            }
            columns.add(shapColumn(shapColumnsByName, "contrib_bias" + suffix));
        }
        return columns;
    }

    /**
     * Returns the global index of the Float64 column {@code name}. If no such column has been
     * created yet, it is appended to {@link #globalColumns} and recorded in
     * {@code shapColumnsByName}.
     */
    private int shapColumn(Map<String, Integer> shapColumnsByName, String name) {
        Integer existing = shapColumnsByName.get(name);
        if (existing != null) {
            return existing;
        }
        int index = this.globalColumns.size();
        MojoColumnMeta meta = MojoColumnMeta.create(name, MojoColumn.Type.Float64);
        shapColumnsByName.put(name, index);
        this.globalColumns.add(meta);
        return index;
    }

    /**
     * Looks up the scaler for a transform output column.
     *
     * @throws IllegalStateException if no scaler was registered for {@code oindex}
     */
    private a getScaler(int oindex) {
        a scaler = this.scalerByOutputColumn.get(oindex);
        if (scaler == null) {
            MojoColumnMeta meta = this.globalColumns.get(oindex);
            throw new IllegalStateException(String.format("Error in blender - no scaler found for column %d('%s')", oindex, meta.getColumnName()));
        }
        return scaler;
    }

    /**
     * Returns the SHAP column indices assigned to {@code transform2} by {@link #prepareShapColumns}.
     *
     * @throws IllegalStateException if the transform was not registered
     */
    private List<Integer> getShapColumnIndices(L transform2) {
        List<Integer> indices = this.pcIndicesByTransform.get(transform2);
        if (indices == null) {
            throw new IllegalStateException(String.format("Shap indices not available for transform: %s", transform2));
        }
        // Only build the debug name list when it will actually be logged.
        if (log.isTraceEnabled()) {
            log.trace("pcIndices ~ {}", (Object) MojoFrameMeta.debugIndicesToNames(this.globalColumns, indices));
        }
        return indices;
    }

    /**
     * Computes SHAP contributions of {@code transform2} for every row of {@code globalFrame} and
     * adds the scaled values onto the transform's SHAP columns, which must be {@code double[]}-backed.
     *
     * <p>Contribution values are consumed in output-major order and written to consecutive SHAP
     * columns. When {@code shapOriginal} is set, each scaled value is also passed to the
     * transform's {@link OriginalMatrix}.
     *
     * @throws UnsupportedOperationException if original SHAP is requested but the transform
     *     provides no original matrix
     * @throws IllegalStateException if a contribution is NaN, or no scaler or column indices exist
     * @throws IndexOutOfBoundsException if the transform emits more values than columns allocated
     */
    public void computeShap(MojoFrame globalFrame, MojoTransform transform2) {
        L shapTransform = (L) ((Object) transform2);
        OriginalMatrix originalMatrix = shapTransform.a();
        if (this.shapOriginal && originalMatrix == null) {
            throw new UnsupportedOperationException("Missing original matrix - cannot compute original SHAP for " + transform2);
        }

        List<Integer> shapColumns = this.getShapColumnIndices(shapTransform);
        int shapColumnCount = shapColumns.size();
        // Resolve target arrays and names once; both are invariant across rows.
        double[][] columnData = new double[shapColumnCount][];
        String[] columnNames = new String[shapColumnCount];
        for (int k = 0; k < shapColumnCount; ++k) {
            int globalIndex = shapColumns.get(k);
            columnData[k] = (double[]) globalFrame.getColumnData(globalIndex);
            columnNames[k] = globalFrame.getColumnName(globalIndex);
        }

        ShapBuffers buffers = new ShapBuffers(transform2);
        int rowCount = globalFrame.getNrows();
        for (int row = 0; row < rowCount; ++row) {
            double[] inputs = buffers.prepareShapInputs(globalFrame, row);
            double[][] outputs = buffers.prepareShapOutputs();
            shapTransform.a(inputs, outputs);

            int target = 0;
            for (int out = 0; out < outputs.length; ++out) {
                a scaler = this.getScaler(transform2.oindices[out]);
                double[] contributions = outputs[out];
                int last = contributions.length - 1;
                for (int j = 0; j < contributions.length; ++j) {
                    if (target >= shapColumnCount) {
                        throw new IndexOutOfBoundsException(String.format(
                                "Row %d: %s(%s) produced more SHAP values than the %d allocated columns",
                                row, transform2.getId(), transform2.getClass().getName(), shapColumnCount));
                    }
                    double raw = contributions[j];
                    String columnName = columnNames[target];
                    if (Double.isNaN(raw)) {
                        throw new IllegalStateException(String.format("Row %d: %s(%s) did not compute shapOutput[%d][%d] : `%s`", row, transform2.getId(), transform2.getClass().getName(), out, j, columnName));
                    }
                    // The flag marks the final value of each output's contribution array.
                    // buildShapColumns appends the bias column last, so this is presumably
                    // the bias term; confirm against the scaler implementations.
                    double scaled = scaler.a(raw, globalFrame, row, j == last);
                    columnData[target][row] += scaled;
                    if (this.shapOriginal) {
                        originalMatrix.incrementOrigShap(globalFrame, row, columnName, scaled);
                    }
                    ++target;
                }
            }
        }
    }
}

