/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.nn.mkldnn;

import com.intel.analytics.bigdl.dllib.nn.MklInt8Convertible;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.HeapData;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.HeapData$;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.MemoryData;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.MemoryOwner;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.MklDnnModule;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.MklDnnRuntime;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.Phase$InferencePhase$;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.ReorderMemory;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.ReorderMemory$;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.SpatialConvolution;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.T$;
import com.intel.analytics.bigdl.dllib.utils.Table;
import scala.Array$;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/**
 * Singleton collection of helpers used around MKL-DNN memory descriptors: propagating
 * quantization scales/masks between {@link MemoryData} instances, choosing a default layout
 * code for a descriptor, converting activities to that default layout, and driving scale
 * calculation on {@link MklInt8Convertible} modules.
 *
 * <p>The {@code MODULE$} field and the {@code $default$N} accessors follow the naming that the
 * Scala compiler generates for an {@code object}; they are kept so existing callers still link.
 */
public final class Utils$ {
    /** The single shared instance; created eagerly when the class is initialized. */
    public static final Utils$ MODULE$ = new Utils$();

    /**
     * Copies the quantization scales and mask from one descriptor to another.
     *
     * <p>Nothing happens when either argument is {@code null} or when {@code to} already has a
     * non-empty scales array. The scales array is cloned, so the two descriptors do not share it.
     *
     * @param from2 descriptor to read scales and mask from; may be {@code null}
     * @param to    descriptor to update; may be {@code null}
     */
    public void copyMaskAndScales(MemoryData from2, MemoryData to) {
        if (from2 == null || to == null) {
            return;
        }
        if (to.scales().length == 0) {
            to.setScales((float[])from2.scales().clone());
            to.setMask(from2.mask());
        }
    }

    /**
     * Copies quantization scales and masks between two groups of descriptors.
     *
     * <p>The copy only happens when all of the following hold:
     * <ul>
     *   <li>neither array is {@code null} and they are not the same array instance;</li>
     *   <li>the lengths are compatible: equal, or one of the two arrays has exactly one element;</li>
     *   <li>every source has non-empty scales and every target has empty scales.</li>
     * </ul>
     * Depending on the lengths the copy is pairwise (equal lengths), a merge of all sources into
     * the single target (element-wise maximum of the scales), or a broadcast of the single source
     * to every target.
     *
     * @param from2 source descriptors; may be {@code null}
     * @param to    target descriptors; may be {@code null}
     */
    public void copyMaskAndScales(MemoryData[] from2, MemoryData[] to) {
        if (from2 == null || to == null) {
            return;
        }
        boolean valid = from2.length == 1 || to.length == 1 || from2.length == to.length;
        boolean needCopy = from2 != to && Utils$.allScalesNonEmpty(from2) && Utils$.allScalesEmpty(to);
        if (!valid || !needCopy) {
            return;
        }
        if (from2.length == to.length) {
            Utils$.copyPairwise(from2, to);
        } else if (to.length == 1) {
            Utils$.mergeIntoSingle(from2, to[0]);
        } else if (to.length > 1) {
            // Lengths differ and the target is not a single element, so `valid` implies
            // that there is exactly one source here.
            Utils$.broadcastFirst(from2[0], to);
        }
    }

    /** Returns true when every descriptor has at least one scale (true for an empty array). */
    private static boolean allScalesNonEmpty(MemoryData[] items) {
        for (MemoryData item : items) {
            if (item.scales().length == 0) {
                return false;
            }
        }
        return true;
    }

    /** Returns true when every descriptor has no scales (true for an empty array). */
    private static boolean allScalesEmpty(MemoryData[] items) {
        for (MemoryData item : items) {
            if (item.scales().length != 0) {
                return false;
            }
        }
        return true;
    }

    /** Copies scales and mask from sources[i] to targets[i] for targets whose scales are still empty. */
    private static void copyPairwise(MemoryData[] sources, MemoryData[] targets) {
        for (int i = 0; i < targets.length; ++i) {
            MemoryData target = targets[i];
            // Re-checked per element: the same descriptor may appear more than once in `targets`.
            if (target.scales().length == 0) {
                MemoryData source = sources[i];
                // NOTE(review): the array is shared, not cloned, unlike the single-descriptor
                // overload. Kept as in the original; confirm whether aliasing is intended.
                target.setScales(source.scales());
                target.setMask(source.mask());
            }
        }
    }

    /**
     * Stores into the target the element-wise maximum of all source scales, then the common mask.
     * The number of merged scales is the length of the first source's scales array; a later
     * source with more scales than the first causes an ArrayIndexOutOfBoundsException, while a
     * shorter one simply does not contribute to the trailing positions.
     */
    private static void mergeIntoSingle(MemoryData[] sources, MemoryData target) {
        float[][] rows = new float[sources.length][];
        for (int i = 0; i < sources.length; ++i) {
            rows[i] = sources[i].scales();
        }
        float[] maxima = rows.length == 0 ? new float[0] : (float[])rows[0].clone();
        for (int r = 1; r < rows.length; ++r) {
            float[] row = rows[r];
            for (int c = 0; c < row.length; ++c) {
                float candidate = row[c];
                // Written as ">= ? current : candidate" on purpose so NaN handling stays the
                // same as a left fold that keeps the current value only when it is >= the next.
                maxima[c] = maxima[c] >= candidate ? maxima[c] : candidate;
            }
        }
        target.setScales(maxima);

        int[] masks = new int[sources.length];
        for (int i = 0; i < sources.length; ++i) {
            masks[i] = sources[i].mask();
        }
        boolean singleMask = masks.length > 0;
        for (int i = 1; i < masks.length; ++i) {
            if (masks[i] != masks[0]) {
                singleMask = false;
            }
        }
        // NOTE(review): presumably invalidInputError raises when the condition is false - confirm.
        Log4Error$.MODULE$.invalidInputError(singleMask, "only support the same mask", Log4Error$.MODULE$.invalidInputError$default$3());
        target.setMask(masks[0]);
    }

    /** Assigns the source's scales to every target, then the source's mask to every target. */
    private static void broadcastFirst(MemoryData source, MemoryData[] targets) {
        for (MemoryData target : targets) {
            // NOTE(review): all targets receive whatever array scales() returns (no clone),
            // as in the original.
            target.setScales(source.scales());
        }
        for (MemoryData target : targets) {
            target.setMask(source.mask());
        }
    }

    /**
     * Picks the default layout code for a descriptor from the number of dimensions in its shape.
     *
     * <p>4-D shapes map to {@code 7} (or {@code 16} when {@code isInOrOut} is false); 2-D shapes
     * map to {@code 4} (or {@code 12}). Any other rank is reported through
     * {@code Log4Error.invalidOperationError}; if that call returns, {@code 0} is returned.
     * NOTE(review): the numeric codes are presumably MKL-DNN memory format constants - confirm
     * against the Memory.Format definitions.
     *
     * @param memoryData descriptor whose shape is inspected
     * @param isInOrOut  selects between the two codes available for each rank
     * @return the layout code, or {@code 0} for an unsupported rank
     */
    public int getDefaultFormat(MemoryData memoryData, boolean isInOrOut) {
        int rank = memoryData.shape().length;
        if (rank == 4) {
            return isInOrOut ? 7 : 16;
        }
        if (rank == 2) {
            return isInOrOut ? 4 : 12;
        }
        // Print the shape contents; concatenating the array itself would only show its identity hash.
        String shapeText = java.util.Arrays.toString(memoryData.shape());
        Log4Error$.MODULE$.invalidOperationError(false, "unexpected shape " + shapeText, "Linear only supports 2-D or 4-D", Log4Error$.MODULE$.invalidOperationError$default$4());
        return 0;
    }

    /** Default for the second parameter of {@link #getDefaultFormat}: {@code true}. */
    public boolean getDefaultFormat$default$2() {
        return true;
    }

    /**
     * Reorders one tensor from the given descriptor's layout into a heap-side descriptor that
     * has the same shape and the layout chosen by {@link #getDefaultFormat}.
     *
     * @param format2   descriptor describing the tensor's current layout
     * @param tensor    tensor to convert
     * @param isInOrOut forwarded to {@link #getDefaultFormat}
     * @param runtime   runtime handed to the reorder primitive
     * @return the reorder output viewed as a float tensor
     */
    private Tensor<Object> denseTensor(MemoryData format2, Tensor<Object> tensor, boolean isInOrOut, MklDnnRuntime runtime) {
        int layout = this.getDefaultFormat(format2, isInOrOut);
        HeapData target = new HeapData(format2.shape(), layout, HeapData$.MODULE$.apply$default$3());
        MemoryData gradTarget = ReorderMemory$.MODULE$.apply$default$2();
        MemoryOwner owner = ReorderMemory$.MODULE$.apply$default$3(target, gradTarget);
        ReorderMemory reorder = ReorderMemory$.MODULE$.apply(target, gradTarget, owner);
        reorder.setRuntime(runtime);
        reorder.initFwdPrimitives(new MemoryData[]{format2}, Phase$InferencePhase$.MODULE$);
        return reorder.forward(tensor).toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
    }

    /** Default for the third parameter of {@code denseTensor}: {@code true}. */
    private boolean denseTensor$default$3() {
        return true;
    }

    /**
     * Converts an activity to the default layout, one tensor per descriptor.
     *
     * <p>With more than one descriptor the activity is read as a table with 1-based integer keys
     * and a new table with the converted tensors under the same keys is returned. Otherwise the
     * activity is read as a single tensor and converted with {@code formats[0]}.
     *
     * @param formats   descriptors, one per tensor in the activity
     * @param activity  tensor or table of tensors to convert
     * @param isInOrOut forwarded to {@code denseTensor}
     * @param runtime   runtime forwarded to {@code denseTensor}
     * @return the converted tensor, or a table of converted tensors
     */
    private Activity denseActivity(MemoryData[] formats, Activity activity, boolean isInOrOut, MklDnnRuntime runtime) {
        if (formats.length <= 1) {
            Tensor<Object> single = activity.toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
            return this.denseTensor(formats[0], single, isInOrOut, runtime);
        }
        Log4Error$.MODULE$.invalidInputError(formats.length == activity.toTable().length(), "formats should be the same as activity", Log4Error$.MODULE$.invalidInputError$default$3());
        Table converted = T$.MODULE$.apply();
        for (int key = 1; key <= formats.length; ++key) {
            MemoryData format2 = formats[key - 1];
            // The table is untyped, so the element type cannot be checked at compile time.
            @SuppressWarnings("unchecked")
            Tensor<Object> element = (Tensor<Object>)activity.toTable().get(BoxesRunTime.boxToInteger((int)key)).get();
            converted.update(BoxesRunTime.boxToInteger((int)key), this.denseTensor(format2, element, isInOrOut, runtime));
        }
        return converted;
    }

    /** Default for the third parameter of {@code denseActivity}: {@code true}. */
    private boolean denseActivity$default$3() {
        return true;
    }

    /**
     * Returns the input converted to the default layout when the module is an
     * {@link MklDnnModule} (using its input formats and runtime); otherwise returns the input
     * unchanged.
     *
     * @param module module the input belongs to
     * @param input  activity to convert
     * @return the converted activity, or {@code input} itself
     */
    public Activity getDenseIn(MklInt8Convertible module, Activity input) {
        if (!(module instanceof MklDnnModule)) {
            return input;
        }
        MklDnnModule dnnModule = (MklDnnModule)((Object)module);
        return this.denseActivity(dnnModule.inputFormats(), input, true, dnnModule.getRuntime());
    }

    /**
     * Returns the output converted to the default layout when the module is an
     * {@link MklDnnModule} (using its output formats and runtime); otherwise returns the output
     * unchanged.
     *
     * @param module module the output belongs to
     * @param output activity to convert
     * @return the converted activity, or {@code output} itself
     */
    public Activity getDenseOut(MklInt8Convertible module, Activity output) {
        if (!(module instanceof MklDnnModule)) {
            return output;
        }
        MklDnnModule dnnModule = (MklDnnModule)((Object)module);
        return this.denseActivity(dnnModule.outputFormats(), output, true, dnnModule.getRuntime());
    }

    /**
     * For a {@link SpatialConvolution}, clears its negative-input flag when the smallest value
     * of the (dense) input is not negative. The flag is never set to {@code true} here, and
     * other module types are ignored.
     */
    private void setConvNegativeInput(MklInt8Convertible module, Activity input) {
        if (!(module instanceof SpatialConvolution)) {
            return;
        }
        SpatialConvolution conv = (SpatialConvolution)module;
        Activity denseIn = this.getDenseIn(module, input);
        float smallest = BoxesRunTime.unboxToFloat(denseIn.toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).min());
        if (smallest >= 0.0f) {
            conv.negativeInput_$eq(false);
        }
    }

    /**
     * Asks the module to calculate its scales from the given input and then updates the
     * negative-input flag of convolutions. Modules that are not {@link MklInt8Convertible}
     * are left untouched.
     *
     * @param module module to calibrate
     * @param input  input activity used for the calculation
     */
    public void calcScales(AbstractModule<?, ?, ?> module, Activity input) {
        if (module instanceof MklInt8Convertible) {
            MklInt8Convertible convertible = (MklInt8Convertible)((Object)module);
            convertible.calcScales(input);
            this.setConvNegativeInput(convertible, input);
        }
    }

    /**
     * Returns the module's current output. The module type makes no difference to the result,
     * and {@code input} is not used; the parameter is kept for signature compatibility.
     *
     * @param module module whose output is read
     * @param input  unused
     * @return the value of {@code module.output()}
     */
    public Activity getOutput(AbstractModule<?, ?, ?> module, Activity input) {
        return (Activity)module.output();
    }

    /** Not instantiable from outside; use {@link #MODULE$}. */
    private Utils$() {
    }
}

