/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.nn;

import com.intel.analytics.bigdl.dllib.nn.CAddTable;
import com.intel.analytics.bigdl.dllib.nn.ConcatTable;
import com.intel.analytics.bigdl.dllib.nn.Container;
import com.intel.analytics.bigdl.dllib.nn.DynamicContainer;
import com.intel.analytics.bigdl.dllib.nn.Graph;
import com.intel.analytics.bigdl.dllib.nn.MklInt8Convertible;
import com.intel.analytics.bigdl.dllib.nn.ReLU;
import com.intel.analytics.bigdl.dllib.nn.SpatialBatchNormalization;
import com.intel.analytics.bigdl.dllib.nn.Utils$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.Linear;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.Sequential;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.SpatialConvolution;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.Node;
import com.intel.analytics.bigdl.dllib.utils.Table;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.StringBuilder;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/**
 * Static implementation bodies for the {@code MklInt8Convertible} trait: collection of the
 * scales used for Int8 conversion of a module.
 *
 * <p>State lives on the trait instance ({@code $this}):
 * <ul>
 *   <li>three integer dimension masks: input, output and weight;</li>
 *   <li>three {@code ArrayBuffer<float[]>} scale buffers: input, output and weight.</li>
 * </ul>
 * {@code $init$} sets all masks to 0 and all buffers to empty.
 *
 * <p>NOTE(review): this file is CFR decompiler output. The {@code $class} suffix, the explicit
 * {@code $this} first parameter and the {@code $init$} method look like the Scala 2.11
 * trait-implementation-class encoding. Several constructs below are not valid Java source, e.g.
 * anonymous {@code Serializable} subclasses with constructor arguments (decompiled closures) and
 * the unresolved type variable {@code T}. Behavioral changes presumably belong in the original
 * Scala trait rather than here.
 */
public abstract class MklInt8Convertible$class {
    /**
     * Computes and records scales for this module.
     *
     * <p>Does nothing when {@code inputActvt} is null. Otherwise the module's current
     * {@code output()} is used as the output activity. (NOTE(review): presumably this is meant to
     * be called after a forward pass so that {@code output()} is populated; confirm with callers.)
     *
     * <p>Work is dispatched on the module's runtime class:
     * <ul>
     *   <li>{@code Graph}: {@code calcGraphScales};</li>
     *   <li>{@code Linear} / {@code SpatialConvolution} (dllib.nn and mkldnn variants):
     *       input/output scales plus weight scales of the tensor returned by {@code getWeight};</li>
     *   <li>{@code ReLU}, {@code CAddTable}, {@code SpatialBatchNormalization} (both variants):
     *       input/output scales only;</li>
     *   <li>{@code Sequential} / {@code ConcatTable} (both variants): the container's own scales,
     *       then recursion into convertible children;</li>
     *   <li>any other module: reported via {@code Log4Error.invalidOperationError(false, ...)}.</li>
     * </ul>
     *
     * <p>Scales are appended to the buffers, not replacing them, so repeated calls accumulate
     * entries.
     *
     * @param $this the trait instance; it is cast to {@code AbstractModule}
     * @param inputActvt the activity fed to the module; null means "skip"
     */
    public static void calcScales(MklInt8Convertible $this, Activity inputActvt) {
        if (inputActvt != null) {
            AbstractModule module = (AbstractModule)((Object)$this);
            Object outputActvt = module.output();
            AbstractModule abstractModule = module;
            if (abstractModule instanceof Graph) {
                MklInt8Convertible$class.calcGraphScales($this, inputActvt, outputActvt);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (abstractModule instanceof com.intel.analytics.bigdl.dllib.nn.Linear) {
                com.intel.analytics.bigdl.dllib.nn.Linear linear = (com.intel.analytics.bigdl.dllib.nn.Linear)abstractModule;
                MklInt8Convertible$class.calcModuleScales($this, inputActvt, outputActvt, MklInt8Convertible$class.getWeight($this, linear));
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (abstractModule instanceof com.intel.analytics.bigdl.dllib.nn.SpatialConvolution) {
                com.intel.analytics.bigdl.dllib.nn.SpatialConvolution spatialConvolution = (com.intel.analytics.bigdl.dllib.nn.SpatialConvolution)abstractModule;
                MklInt8Convertible$class.calcModuleScales($this, inputActvt, outputActvt, MklInt8Convertible$class.getWeight($this, spatialConvolution));
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (abstractModule instanceof ReLU) {
                MklInt8Convertible$class.calcModuleScales($this, inputActvt, outputActvt);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (abstractModule instanceof CAddTable) {
                MklInt8Convertible$class.calcModuleScales($this, inputActvt, outputActvt);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (abstractModule instanceof SpatialBatchNormalization) {
                MklInt8Convertible$class.calcModuleScales($this, inputActvt, outputActvt);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (abstractModule instanceof com.intel.analytics.bigdl.dllib.nn.Sequential) {
                MklInt8Convertible$class.calcSequentialScales($this, inputActvt, outputActvt);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (abstractModule instanceof ConcatTable) {
                MklInt8Convertible$class.calcConcatTableScales($this, inputActvt, outputActvt);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            // The remaining branches handle the mkldnn variants (imported unqualified, or
            // spelled out with the mkldnn package).
            } else if (abstractModule instanceof Linear) {
                Linear linear = (Linear)abstractModule;
                MklInt8Convertible$class.calcModuleScales($this, inputActvt, outputActvt, MklInt8Convertible$class.getWeight($this, linear));
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (abstractModule instanceof SpatialConvolution) {
                SpatialConvolution spatialConvolution = (SpatialConvolution)abstractModule;
                MklInt8Convertible$class.calcModuleScales($this, inputActvt, outputActvt, MklInt8Convertible$class.getWeight($this, spatialConvolution));
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (abstractModule instanceof Sequential) {
                MklInt8Convertible$class.calcSequentialScales($this, inputActvt, outputActvt);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (abstractModule instanceof com.intel.analytics.bigdl.dllib.nn.mkldnn.ConcatTable) {
                MklInt8Convertible$class.calcConcatTableScales($this, inputActvt, outputActvt);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (abstractModule instanceof com.intel.analytics.bigdl.dllib.nn.mkldnn.ReLU) {
                MklInt8Convertible$class.calcModuleScales($this, inputActvt, outputActvt);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (abstractModule instanceof com.intel.analytics.bigdl.dllib.nn.mkldnn.SpatialBatchNormalization) {
                MklInt8Convertible$class.calcModuleScales($this, inputActvt, outputActvt);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (abstractModule instanceof com.intel.analytics.bigdl.dllib.nn.mkldnn.CAddTable) {
                MklInt8Convertible$class.calcModuleScales($this, inputActvt, outputActvt);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else {
                // Unsupported module type. NOTE(review): presumably invalidOperationError(false, ...)
                // throws; if it does not, this method silently records nothing.
                Log4Error$.MODULE$.invalidOperationError(false, new StringBuilder().append((Object)"Int8 conversion is not supported for module: ").append((Object)module.getName()).toString(), Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            }
        }
    }

    /**
     * Replaces the weight scales with a single entry computed from {@code weight}.
     *
     * <p>The weight-scale buffer is cleared, then the scale of {@code weight} under the current
     * weight dimension mask is appended.
     */
    public static void flushWeightScales(MklInt8Convertible $this, Tensor weight) {
        $this.weightScalesBuffer().clear();
        MklInt8Convertible$class.com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$appendWeightScales($this, MklInt8Convertible$class.com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$calcTensorScale($this, weight, $this.weightDimMask()));
    }

    /**
     * Appends input and output activity scales for this module.
     *
     * <p>Each non-null activity is first converted with mkldnn {@code Utils.getDenseIn} /
     * {@code getDenseOut}. Its scales are then computed with the input/output dimension mask, and
     * every resulting {@code float[]} is appended to the matching buffer.
     */
    private static void calcModuleScales(MklInt8Convertible $this, Activity inputActvt, Activity outputActvt) {
        if (inputActvt != null) {
            Activity denseIn = com.intel.analytics.bigdl.dllib.nn.mkldnn.Utils$.MODULE$.getDenseIn($this, inputActvt);
            // Decompiled closure: for each scale array, appendInputScales(scale).
            Predef$.MODULE$.refArrayOps((Object[])MklInt8Convertible$class.calcActivityScales($this, denseIn, $this.inputDimMask())).foreach((Function1)new Serializable($this){
                public static final long serialVersionUID = 0L;
                private final /* synthetic */ MklInt8Convertible $outer;

                public final void apply(float[] scale) {
                    MklInt8Convertible$class.com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$appendInputScales(this.$outer, scale);
                }
                {
                    if ($outer == null) {
                        throw null;
                    }
                    this.$outer = $outer;
                }
            });
        }
        if (outputActvt != null) {
            Activity denseOut = com.intel.analytics.bigdl.dllib.nn.mkldnn.Utils$.MODULE$.getDenseOut($this, outputActvt);
            // Decompiled closure: for each scale array, appendOutputScales(scale).
            Predef$.MODULE$.refArrayOps((Object[])MklInt8Convertible$class.calcActivityScales($this, denseOut, $this.outputDimMask())).foreach((Function1)new Serializable($this){
                public static final long serialVersionUID = 0L;
                private final /* synthetic */ MklInt8Convertible $outer;

                public final void apply(float[] scale) {
                    MklInt8Convertible$class.com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$appendOutputScales(this.$outer, scale);
                }
                {
                    if ($outer == null) {
                        throw null;
                    }
                    this.$outer = $outer;
                }
            });
        }
    }

    /**
     * Appends input/output activity scales, then one weight-scale entry computed from
     * {@code weightTensor} with the weight dimension mask.
     */
    private static void calcModuleScales(MklInt8Convertible $this, Activity inActivity, Activity outActivity, Tensor weightTensor) {
        MklInt8Convertible$class.calcModuleScales($this, inActivity, outActivity);
        MklInt8Convertible$class.com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$appendWeightScales($this, MklInt8Convertible$class.com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$calcTensorScale($this, weightTensor, $this.weightDimMask()));
    }

    /**
     * Computes one scale array per tensor contained in {@code activity}.
     *
     * <ul>
     *   <li>A {@code Tensor} yields a single-element result.</li>
     *   <li>A {@code Table} yields one entry per table element. The element is cast to
     *       {@code Tensor}, and the table key is ignored.</li>
     *   <li>Anything else is reported via {@code Log4Error.unKnowExceptionError}, and null is
     *       returned if that call does not throw.</li>
     * </ul>
     */
    private static float[][] calcActivityScales(MklInt8Convertible $this, Activity activity, int mask) {
        float[][] fArray;
        Activity activity2 = activity;
        if (activity2 instanceof Tensor) {
            fArray = (float[][])((Object[])new float[][]{MklInt8Convertible$class.com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$calcTensorScale($this, activity.toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), mask)});
        } else if (activity2 instanceof Table) {
            // Decompiled closure: (index, tensor) => calcTensorScale(tensor, mask).
            fArray = (float[][])activity.toTable().map(new Serializable($this, mask){
                public static final long serialVersionUID = 0L;
                private final /* synthetic */ MklInt8Convertible $outer;
                private final int mask$1;

                public final float[] apply(Tuple2<Object, Object> elem) {
                    Object index = elem._1();
                    Tensor tensor = (Tensor)elem._2();
                    return MklInt8Convertible$class.com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$calcTensorScale(this.$outer, tensor, this.mask$1);
                }
                {
                    if ($outer == null) {
                        throw null;
                    }
                    this.$outer = $outer;
                    this.mask$1 = mask$1;
                }
            }).toArray(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE)));
            // NOTE(review): the order of entries follows Table.map iteration order; confirm it
            // matches the order the scales are consumed in.
        } else {
            Log4Error$.MODULE$.unKnowExceptionError(false, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Invalid activity ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{activity})), "only support Tensor and Table", Log4Error$.MODULE$.unKnowExceptionError$default$4());
            fArray = null;
        }
        return fArray;
    }

    /**
     * Computes the absolute-value scale(s) of a tensor for the given dimension mask.
     *
     * <ul>
     *   <li>{@code mask == 0}: a single value, the maximum absolute element.</li>
     *   <li>{@code mask == 2^dim - 1} (every dimension selected): the absolute value of every
     *       element, read from the storage of a cloned tensor.</li>
     *   <li>Otherwise: delegated to {@code dllib.nn.Utils.calcScales(tensor, mask)}.</li>
     * </ul>
     *
     * <p>The input tensor is not modified; {@code abs()} is applied to a clone.
     * NOTE(review): the all-dims branch assumes the clone's storage holds exactly the tensor's
     * elements, with no offset or extra capacity. Confirm against {@code Tensor.clone}.
     */
    public static float[] com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$calcTensorScale(MklInt8Convertible $this, Tensor tensor, int mask) {
        float[] fArray;
        if (mask == 0) {
            float[] fArray2 = new float[1];
            fArray = fArray2;
            fArray2[0] = BoxesRunTime.unboxToFloat(tensor.clone().abs().max());
        } else {
            fArray = package$.MODULE$.pow(2.0, (double)tensor.dim()) - 1.0 == (double)mask ? (float[])tensor.clone().abs().storage().toArray(ClassTag$.MODULE$.Float()) : Utils$.MODULE$.calcScales(tensor, mask);
        }
        return fArray;
    }

    /**
     * Scale collection for a dllib.nn or mkldnn {@code Sequential}.
     *
     * <p>Records the container's own input/output scales. It then walks the children in order:
     * each child implementing {@code MklInt8Convertible} gets {@code calcScales} with the
     * previous child's {@code output()}, and the first child gets {@code inputActvt}.
     * Non-convertible children are skipped, but their output still feeds the next child.
     */
    private static void calcSequentialScales(MklInt8Convertible $this, Activity inputActvt, Activity outputActvt) {
        Log4Error$.MODULE$.invalidOperationError($this instanceof com.intel.analytics.bigdl.dllib.nn.Sequential || $this instanceof Sequential, new StringBuilder().append((Object)$this.getClass().getName()).append((Object)" is not an instance of Sequential.").toString(), Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
        DynamicContainer module = (DynamicContainer)((Object)$this);
        Activity prevOutputActivity = inputActvt;
        MklInt8Convertible$class.calcModuleScales($this, inputActvt, outputActvt);
        for (AbstractModule currModule : module.modules()) {
            if (currModule instanceof MklInt8Convertible) {
                MklInt8Convertible cvtbModule = (MklInt8Convertible)((Object)currModule);
                cvtbModule.calcScales(prevOutputActivity);
            }
            prevOutputActivity = currModule.output();
        }
    }

    /**
     * Scale collection for a dllib.nn or mkldnn {@code ConcatTable}.
     *
     * <p>Records the container's own input/output scales. Every convertible child is then given
     * the same {@code inputActvt}, since all branches receive the container's input.
     */
    private static void calcConcatTableScales(MklInt8Convertible $this, Activity inputActvt, Activity outputActvt) {
        Log4Error$.MODULE$.invalidOperationError($this instanceof ConcatTable || $this instanceof com.intel.analytics.bigdl.dllib.nn.mkldnn.ConcatTable, new StringBuilder().append((Object)$this.getClass().getName()).append((Object)" is not an instance of ConcatTable.").toString(), Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
        DynamicContainer module = (DynamicContainer)((Object)$this);
        MklInt8Convertible$class.calcModuleScales($this, inputActvt, outputActvt);
        for (AbstractModule currModule : module.modules()) {
            if (!(currModule instanceof MklInt8Convertible)) continue;
            MklInt8Convertible cvtbModule = (MklInt8Convertible)((Object)currModule);
            cvtbModule.calcScales(inputActvt);
        }
    }

    /**
     * Scale collection for a {@code Graph}.
     *
     * <p>Records the graph's own input/output scales. It then visits the nodes returned by
     * {@code getForwardExecutions()} in order. For each node whose element is convertible, it
     * calls {@code calcScales} with that node's input as resolved by
     * {@code Graph.findInput(node, inputActvt)}.
     * (The {@code T} type variable below is a decompiler artifact.)
     */
    private static void calcGraphScales(MklInt8Convertible $this, Activity inputActvt, Activity outputActvt) {
        Log4Error$.MODULE$.invalidOperationError($this instanceof Graph, new StringBuilder().append((Object)$this.getClass().getName()).append((Object)" is not an instance of Graph[Float]").toString(), Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
        MklInt8Convertible$class.calcModuleScales($this, inputActvt, outputActvt);
        Graph module = (Graph)$this;
        Node<AbstractModule<Activity, Activity, T>>[] outputNodes = module.getForwardExecutions();
        for (int i = 0; i < outputNodes.length; ++i) {
            Node currNode = outputNodes[i];
            Activity currInputActvt = module.findInput(currNode, inputActvt);
            if (!(currNode.element() instanceof MklInt8Convertible)) continue;
            ((MklInt8Convertible)((Object)currNode.element())).calcScales(currInputActvt);
        }
    }

    /**
     * Returns the weight tensor used for weight-scale computation.
     *
     * <ul>
     *   <li>A null module yields null.</li>
     *   <li>Otherwise the result is the first tensor of {@code module.parameters()._1()}.</li>
     *   <li>For a dllib.nn {@code SpatialConvolution} whose weight has size 1 in dimension 1,
     *       that leading dimension is dropped via {@code select(1, 1)}. NOTE(review): presumably
     *       this is the group dimension; confirm.</li>
     * </ul>
     */
    private static Tensor getWeight(MklInt8Convertible $this, AbstractModule module) {
        Tensor tensor;
        if (module == null) {
            tensor = null;
        } else {
            Tensor weight = ((Tensor[])module.parameters()._1())[0];
            tensor = module instanceof com.intel.analytics.bigdl.dllib.nn.SpatialConvolution && weight.size(1) == 1 ? weight.select(1, 1) : weight;
        }
        return tensor;
    }

    /** Returns the input dimension mask. */
    public static int getInputDimMask(MklInt8Convertible $this) {
        return $this.inputDimMask();
    }

    /**
     * Sets the input dimension mask.
     *
     * <p>When {@code overrideSubmodules} is true and this is a {@code Container}, the same call is
     * forwarded to every convertible child with the same flag. The mask therefore propagates
     * recursively through nested containers.
     */
    public static void setInputDimMask(MklInt8Convertible $this, int mask, boolean overrideSubmodules) {
        $this.inputDimMask_$eq(mask);
        if ($this instanceof Container && overrideSubmodules) {
            Container container = (Container)((Object)$this);
            ArrayBuffer modules = container.modules();
            // Decompiled closure: child => child.setInputDimMask(mask, overrideSubmodules).
            modules.foreach((Function1)new Serializable($this, mask, overrideSubmodules){
                public static final long serialVersionUID = 0L;
                private final int mask$2;
                private final boolean overrideSubmodules$1;

                public final void apply(AbstractModule<Activity, Activity, Object> module) {
                    if (module instanceof MklInt8Convertible) {
                        ((MklInt8Convertible)((Object)module)).setInputDimMask(this.mask$2, this.overrideSubmodules$1);
                    }
                }
                {
                    this.mask$2 = mask$2;
                    this.overrideSubmodules$1 = overrideSubmodules$1;
                }
            });
        }
    }

    /** Default value ({@code false}) of {@code overrideSubmodules} for {@code setInputDimMask}. */
    public static boolean setInputDimMask$default$2(MklInt8Convertible $this) {
        return false;
    }

    /** Returns the output dimension mask. */
    public static int getOutputDimMask(MklInt8Convertible $this) {
        return $this.outputDimMask();
    }

    /**
     * Sets the output dimension mask.
     *
     * <p>When {@code overrideSubmodules} is true and this is a {@code Container}, the same call is
     * forwarded, recursively, to every convertible child.
     */
    public static void setOutputDimMask(MklInt8Convertible $this, int mask, boolean overrideSubmodules) {
        $this.outputDimMask_$eq(mask);
        if ($this instanceof Container && overrideSubmodules) {
            Container container = (Container)((Object)$this);
            ArrayBuffer modules = container.modules();
            // Decompiled closure: child => child.setOutputDimMask(mask, overrideSubmodules).
            modules.foreach((Function1)new Serializable($this, mask, overrideSubmodules){
                public static final long serialVersionUID = 0L;
                private final int mask$3;
                private final boolean overrideSubmodules$2;

                public final void apply(AbstractModule<Activity, Activity, Object> module) {
                    if (module instanceof MklInt8Convertible) {
                        ((MklInt8Convertible)((Object)module)).setOutputDimMask(this.mask$3, this.overrideSubmodules$2);
                    }
                }
                {
                    this.mask$3 = mask$3;
                    this.overrideSubmodules$2 = overrideSubmodules$2;
                }
            });
        }
    }

    /** Default value ({@code false}) of {@code overrideSubmodules} for {@code setOutputDimMask}. */
    public static boolean setOutputDimMask$default$2(MklInt8Convertible $this) {
        return false;
    }

    /** Returns the weight dimension mask. */
    public static int getWeightDimMask(MklInt8Convertible $this) {
        return $this.weightDimMask();
    }

    /**
     * Sets the weight dimension mask.
     *
     * <p>When {@code overrideSubmodules} is true and this is a {@code Container}, the same call is
     * forwarded, recursively, to every convertible child.
     */
    public static void setWeightDimMask(MklInt8Convertible $this, int mask, boolean overrideSubmodules) {
        $this.weightDimMask_$eq(mask);
        if ($this instanceof Container && overrideSubmodules) {
            Container container = (Container)((Object)$this);
            ArrayBuffer modules = container.modules();
            // Decompiled closure: child => child.setWeightDimMask(mask, overrideSubmodules).
            modules.foreach((Function1)new Serializable($this, mask, overrideSubmodules){
                public static final long serialVersionUID = 0L;
                private final int mask$4;
                private final boolean overrideSubmodules$3;

                public final void apply(AbstractModule<Activity, Activity, Object> module) {
                    if (module instanceof MklInt8Convertible) {
                        ((MklInt8Convertible)((Object)module)).setWeightDimMask(this.mask$4, this.overrideSubmodules$3);
                    }
                }
                {
                    this.mask$4 = mask$4;
                    this.overrideSubmodules$3 = overrideSubmodules$3;
                }
            });
        }
    }

    /** Default value ({@code false}) of {@code overrideSubmodules} for {@code setWeightDimMask}. */
    public static boolean setWeightDimMask$default$2(MklInt8Convertible $this) {
        return false;
    }

    /**
     * Returns the input scales as a new outer array. The inner {@code float[]} arrays are the
     * same objects held in the buffer; they are not copied.
     */
    public static float[][] getInputScales(MklInt8Convertible $this) {
        return (float[][])$this.inputScalesBuffer().toArray(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE)));
    }

    /**
     * Replaces the input scales with the entries of {@code inScales}. The arrays are stored by
     * reference, not copied.
     */
    public static void setInputScales(MklInt8Convertible $this, float[][] inScales) {
        $this.inputScalesBuffer().clear();
        // Decompiled closure: for each scale array, appendInputScales(scale).
        Predef$.MODULE$.refArrayOps((Object[])inScales).foreach((Function1)new Serializable($this){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ MklInt8Convertible $outer;

            public final void apply(float[] scale) {
                MklInt8Convertible$class.com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$appendInputScales(this.$outer, scale);
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
            }
        });
    }

    /**
     * Returns the output scales as a new outer array. The inner {@code float[]} arrays are the
     * same objects held in the buffer; they are not copied.
     */
    public static float[][] getOutputScales(MklInt8Convertible $this) {
        return (float[][])$this.outputScalesBuffer().toArray(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE)));
    }

    /**
     * Replaces the output scales with the entries of {@code outScales}. The arrays are stored by
     * reference, not copied.
     */
    public static void setOutputScales(MklInt8Convertible $this, float[][] outScales) {
        $this.outputScalesBuffer().clear();
        // Decompiled closure: for each scale array, appendOutputScales(scale).
        Predef$.MODULE$.refArrayOps((Object[])outScales).foreach((Function1)new Serializable($this){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ MklInt8Convertible $outer;

            public final void apply(float[] scale) {
                MklInt8Convertible$class.com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$appendOutputScales(this.$outer, scale);
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
            }
        });
    }

    /**
     * Returns the weight scales as a new outer array. The inner {@code float[]} arrays are the
     * same objects held in the buffer; they are not copied.
     */
    public static float[][] getWeightScales(MklInt8Convertible $this) {
        return (float[][])$this.weightScalesBuffer().toArray(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE)));
    }

    /**
     * Replaces the weight scales with the entries of {@code weightScales}. The arrays are stored
     * by reference, not copied.
     */
    public static void setWeightScales(MklInt8Convertible $this, float[][] weightScales) {
        $this.weightScalesBuffer().clear();
        // Decompiled closure: for each scale array, appendWeightScales(scale).
        Predef$.MODULE$.refArrayOps((Object[])weightScales).foreach((Function1)new Serializable($this){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ MklInt8Convertible $outer;

            public final void apply(float[] scale) {
                MklInt8Convertible$class.com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$appendWeightScales(this.$outer, scale);
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
            }
        });
    }

    /** Appends one scale array (by reference) to the input-scale buffer. */
    public static void com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$appendInputScales(MklInt8Convertible $this, float[] scale) {
        $this.inputScalesBuffer().append((Seq)Predef$.MODULE$.wrapRefArray((Object[])new float[][]{scale}));
    }

    /** Appends one scale array (by reference) to the output-scale buffer. */
    public static void com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$appendOutputScales(MklInt8Convertible $this, float[] scale) {
        $this.outputScalesBuffer().append((Seq)Predef$.MODULE$.wrapRefArray((Object[])new float[][]{scale}));
    }

    /** Appends one scale array (by reference) to the weight-scale buffer. */
    public static void com$intel$analytics$bigdl$dllib$nn$MklInt8Convertible$$appendWeightScales(MklInt8Convertible $this, float[] scale) {
        $this.weightScalesBuffer().append((Seq)Predef$.MODULE$.wrapRefArray((Object[])new float[][]{scale}));
    }

    /**
     * Element-wise max-merges {@code scale} into input-scale entry {@code index}. It delegates to
     * {@code updateScalesHelper}. NOTE(review): not referenced elsewhere in this file.
     */
    private static void updateInputScales(MklInt8Convertible $this, float[] scale, int index) {
        MklInt8Convertible$class.updateScalesHelper($this, $this.inputScalesBuffer(), scale, index);
    }

    /**
     * Element-wise max-merges {@code scale} into output-scale entry {@code index}. It delegates
     * to {@code updateScalesHelper}. NOTE(review): not referenced elsewhere in this file.
     */
    private static void updateOutputScales(MklInt8Convertible $this, float[] scale, int index) {
        MklInt8Convertible$class.updateScalesHelper($this, $this.outputScalesBuffer(), scale, index);
    }

    /**
     * Element-wise max-merges {@code scale} into weight-scale entry {@code index}. It delegates
     * to {@code updateScalesHelper}. NOTE(review): not referenced elsewhere in this file.
     */
    private static void updateWeightScales(MklInt8Convertible $this, float[] scale, int index) {
        MklInt8Convertible$class.updateScalesHelper($this, $this.weightScalesBuffer(), scale, index);
    }

    /**
     * Merges {@code scale} into {@code scales(index)} by keeping the element-wise maximum.
     *
     * <p>If {@code index} is past the last entry, {@code scale} is appended first.
     * NOTE(review): issues to confirm with callers:
     * <ul>
     *   <li>The append always lands at position {@code scales.length()}, so any
     *       {@code index > length()} makes the subsequent {@code scales.apply(index)} fail.
     *       Callers presumably only pass {@code index <= length()}.</li>
     *   <li>The appended array is the caller's {@code scale} object itself, so later merges mutate
     *       it in place.</li>
     *   <li>The loop runs over the stored entry's indices. A {@code scale} shorter than that entry
     *       would be indexed out of bounds.</li>
     * </ul>
     */
    private static void updateScalesHelper(MklInt8Convertible $this, ArrayBuffer scales, float[] scale, int index) {
        if (scales.length() - 1 < index) {
            scales.append((Seq)Predef$.MODULE$.wrapRefArray((Object[])new float[][]{scale}));
        }
        // Decompiled closure: for i in scales(index).indices, raise scales(index)(i) to scale(i)
        // when scale(i) is larger.
        Predef$.MODULE$.floatArrayOps((float[])scales.apply(index)).indices().foreach$mVc$sp((Function1)new Serializable($this, scales, scale, index){
            public static final long serialVersionUID = 0L;
            private final ArrayBuffer scales$1;
            private final float[] scale$1;
            private final int index$1;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            public void apply$mcVI$sp(int i) {
                if (this.scale$1[i] > ((float[])this.scales$1.apply(this.index$1))[i]) {
                    ((float[])this.scales$1.apply((int)this.index$1))[i] = this.scale$1[i];
                }
            }
            {
                this.scales$1 = scales$1;
                this.scale$1 = scale$1;
                this.index$1 = index$1;
            }
        });
    }

    /**
     * Trait field initializer: sets all three dimension masks to 0 and all three scale buffers
     * to fresh empty {@code ArrayBuffer}s.
     */
    public static void $init$(MklInt8Convertible $this) {
        $this.inputDimMask_$eq(0);
        $this.outputDimMask_$eq(0);
        $this.weightDimMask_$eq(0);
        $this.inputScalesBuffer_$eq((ArrayBuffer<float[]>)((ArrayBuffer)ArrayBuffer$.MODULE$.empty()));
        $this.outputScalesBuffer_$eq((ArrayBuffer<float[]>)((ArrayBuffer)ArrayBuffer$.MODULE$.empty()));
        $this.weightScalesBuffer_$eq((ArrayBuffer<float[]>)((ArrayBuffer)ArrayBuffer$.MODULE$.empty()));
    }
}

