/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.nn.mkldnn;

import com.intel.analytics.bigdl.dllib.nn.Container;
import com.intel.analytics.bigdl.dllib.nn.CrossEntropyCriterion;
import com.intel.analytics.bigdl.dllib.nn.CrossEntropyCriterion$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.DnnGraph;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.HeapData;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.HeapData$;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.MemoryData;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.MklDnnContainer;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.Perf$$anonfun$main$1$;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.Phase$InferencePhase$;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.Phase$TrainingPhase$;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.ResNet$;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.ResNet$DatasetType$ImageNet$;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.ResNet50PerfParams;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.ResNet50PerfParams$;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.Sequential;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.models.Vgg_16$;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.dllib.utils.Engine$;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.RandomGenerator$;
import com.intel.analytics.bigdl.dllib.utils.T$;
import com.intel.analytics.bigdl.package$;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scopt.OptionParser;
import scopt.Read$;

/**
 * Compiled form of the Scala {@code object Perf}: a command-line micro-benchmark that runs an
 * MKL-DNN model (VGG-16 or ResNet-50, either as a container or as a graph) on randomly generated
 * input and logs the wall-clock time and throughput of every iteration.
 *
 * <p>NOTE(review): this file is CFR decompiler output (see the file header). Several constructs
 * here are not valid Java source as written (anonymous {@code Serializable} implementations that
 * take constructor arguments, the undeclared type {@code B} in the timed task, synthetic
 * {@code anonfun$...} outer types, a {@code static final} field assigned in the constructor), so
 * real changes should be made in the original Scala source.
 */
public final class Perf$ {
    /** Singleton instance; assigned by the private constructor, which the static initializer runs. */
    public static final Perf$ MODULE$;
    /** Logger used to report per-iteration timing and throughput. */
    private final Logger logger;
    /** scopt parser that maps command-line options onto {@link ResNet50PerfParams}. */
    private final OptionParser<ResNet50PerfParams> parser;

    static {
        new Perf$();
    }

    /** @return the logger used for benchmark output */
    public Logger logger() {
        return this.logger;
    }

    /** @return the command-line parser for {@link ResNet50PerfParams} */
    public OptionParser<ResNet50PerfParams> parser() {
        return this.parser;
    }

    /**
     * Benchmark entry point.
     *
     * <p>Sets the {@code bigdl.mkldnn.fusion.*}, {@code bigdl.localMode} and
     * {@code bigdl.engineType} system properties, initialises the engine, and parses
     * {@code argv}. For the parsed parameters (the parse result is iterated with
     * {@code foreach}) it then:
     * <ol>
     *   <li>builds the selected model;</li>
     *   <li>compiles it once for the training or inference phase on the
     *       {@code dnnComputing} pool;</li>
     *   <li>runs {@code iteration} timed passes (forward only, or forward + criterion +
     *       backward when training);</li>
     *   <li>logs the elapsed seconds and images per second after each pass.</li>
     * </ol>
     *
     * @param argv command-line arguments; see the options registered in the constructor
     *             ({@code -m/--model}, {@code -b/--batchSize}, {@code -i/--iteration},
     *             {@code -t/--training})
     */
    public void main(String[] argv) {
        // These are set before Engine init, presumably so that init picks them up.
        System.setProperty("bigdl.mkldnn.fusion.convbn", "true");
        System.setProperty("bigdl.mkldnn.fusion.bnrelu", "true");
        System.setProperty("bigdl.mkldnn.fusion.convrelu", "true");
        System.setProperty("bigdl.mkldnn.fusion.convsum", "true");
        System.setProperty("bigdl.localMode", "true");
        System.setProperty("bigdl.engineType", "mkldnn");
        Engine$.MODULE$.init();
        this.parser().parse((Seq<String>)Predef$.MODULE$.wrapRefArray((Object[])argv), new ResNet50PerfParams(ResNet50PerfParams$.MODULE$.$lessinit$greater$default$1(), ResNet50PerfParams$.MODULE$.$lessinit$greater$default$2(), ResNet50PerfParams$.MODULE$.$lessinit$greater$default$3(), ResNet50PerfParams$.MODULE$.$lessinit$greater$default$4())).foreach((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            /** Runs the benchmark for one set of parsed parameters. */
            public final void apply(ResNet50PerfParams params) {
                Container container;
                int batchSize = params.batchSize();
                boolean training2 = params.training();
                int iterations = params.iteration();
                int classNum = 1000;
                // NOTE(review): 7 is passed as the HeapData layout for the input; presumably an
                // MKL-DNN memory-format tag for NCHW data - confirm.
                int inputFormat = 7;
                int[] inputShape = new int[]{batchSize, 3, 224, 224};
                // Random input of shape (batchSize, 3, 224, 224).
                Tensor<Object> input = Tensor$.MODULE$.apply$mFc$sp(inputShape, (ClassTag<Object>)ClassTag$.MODULE$.Float(), (TensorNumericMath.TensorNumeric<Object>)TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).rand();
                // One label per sample: ceil(uniform(0, 1) * 1000), i.e. roughly a 1-based class index up to 1000.
                Tensor<Object> label = Tensor$.MODULE$.apply$mFc$sp(batchSize, (ClassTag<Object>)ClassTag$.MODULE$.Float(), (TensorNumericMath.TensorNumeric<Object>)TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).apply1((Function1<Object, Object>)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final float apply(float x$1) {
                        return this.apply$mcFF$sp(x$1);
                    }

                    public float apply$mcFF$sp(float x$1) {
                        return (float)Math.ceil(RandomGenerator$.MODULE$.RNG().uniform(0.0, 1.0) * (double)1000);
                    }
                });
                // Select the model by name; the *_graph variants build a graph instead of a container.
                String string = params.model();
                if ("vgg16".equals(string)) {
                    container = Vgg_16$.MODULE$.apply(batchSize, classNum, true);
                } else if ("resnet50".equals(string)) {
                    container = ResNet$.MODULE$.apply(batchSize, classNum, T$.MODULE$.apply((Tuple2<Object, Object>)Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"depth"), (Object)BoxesRunTime.boxToInteger((int)50)), (Seq<Tuple2<Object, Object>>)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"dataSet"), (Object)ResNet$DatasetType$ImageNet$.MODULE$)})));
                } else if ("vgg16_graph".equals(string)) {
                    container = Vgg_16$.MODULE$.graph(batchSize, classNum, true);
                } else if ("resnet50_graph".equals(string)) {
                    container = ResNet$.MODULE$.graph(batchSize, classNum, T$.MODULE$.apply((Tuple2<Object, Object>)Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"depth"), (Object)BoxesRunTime.boxToInteger((int)50)), (Seq<Tuple2<Object, Object>>)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"dataSet"), (Object)ResNet$DatasetType$ImageNet$.MODULE$)})));
                } else {
                    // Reports the unknown model name; presumably invalidInputError(false, ...) throws,
                    // so the null assignment below is never used - confirm.
                    Log4Error$.MODULE$.invalidInputError(false, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Unknown model ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{params.model()})), "only support vgg16, resnet50, vgg16_graph, resnet50_graph");
                    container = null;
                }
                Sequential model = container;
                CrossEntropyCriterion$.MODULE$.apply$default$1();
                CrossEntropyCriterion<Object> criterion = CrossEntropyCriterion$.MODULE$.apply$mFc$sp(null, CrossEntropyCriterion$.MODULE$.apply$default$2(), (ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                // Compile once, as a single task on the dnnComputing pool, waiting for it to finish.
                Engine$.MODULE$.dnnComputing().invokeAndWait2(Predef$.MODULE$.wrapRefArray((Object[])Predef$.MODULE$.intArrayOps(new int[]{1}).map((Function1)new Serializable(this, training2, inputFormat, inputShape, model){
                    public static final long serialVersionUID = 0L;
                    public final boolean training$1;
                    public final int inputFormat$1;
                    public final int[] inputShape$1;
                    public final Container model$1;

                    public final Function0<BoxedUnit> apply(int x$2) {
                        return new Serializable(this){
                            public static final long serialVersionUID = 0L;
                            private final /* synthetic */ anonfun$main$1$$anonfun$apply$2 $outer;

                            public final void apply() {
                                this.apply$mcV$sp();
                            }

                            /**
                             * Switches the model to training or evaluation mode and compiles it for
                             * the matching phase. An MklDnnContainer is given a HeapData descriptor
                             * of the input; a DnnGraph is compiled without one.
                             */
                            public void apply$mcV$sp() {
                                if (this.$outer.training$1) {
                                    this.$outer.model$1.training();
                                    if (this.$outer.model$1 instanceof MklDnnContainer) {
                                        ((MklDnnContainer)((Object)this.$outer.model$1)).compile(Phase$TrainingPhase$.MODULE$, (MemoryData[])((Object[])new MemoryData[]{new HeapData(this.$outer.inputShape$1, this.$outer.inputFormat$1, HeapData$.MODULE$.apply$default$3())}));
                                    } else if (this.$outer.model$1 instanceof DnnGraph) {
                                        ((DnnGraph)this.$outer.model$1).compile(Phase$TrainingPhase$.MODULE$);
                                    }
                                } else {
                                    this.$outer.model$1.evaluate();
                                    if (this.$outer.model$1 instanceof MklDnnContainer) {
                                        ((MklDnnContainer)((Object)this.$outer.model$1)).compile(Phase$InferencePhase$.MODULE$, (MemoryData[])((Object[])new MemoryData[]{new HeapData(this.$outer.inputShape$1, this.$outer.inputFormat$1, HeapData$.MODULE$.apply$default$3())}));
                                    } else if (this.$outer.model$1 instanceof DnnGraph) {
                                        ((DnnGraph)this.$outer.model$1).compile(Phase$InferencePhase$.MODULE$);
                                    }
                                }
                            }
                            {
                                if ($outer == null) {
                                    throw null;
                                }
                                this.$outer = $outer;
                            }
                        };
                    }
                    {
                        this.training$1 = training$1;
                        this.inputFormat$1 = inputFormat$1;
                        this.inputShape$1 = inputShape$1;
                        this.model$1 = model$1;
                    }
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Function0.class)))), Engine$.MODULE$.dnnComputing().invokeAndWait2$default$2(), Engine$.MODULE$.dnnComputing().invokeAndWait2$default$3());
                // Timed loop: each iteration runs one task on the dnnComputing pool and logs its cost.
                for (int iteration2 = 0; iteration2 < iterations; ++iteration2) {
                    long start2 = System.nanoTime();
                    Engine$.MODULE$.dnnComputing().invokeAndWait2(Predef$.MODULE$.wrapRefArray((Object[])Predef$.MODULE$.intArrayOps(new int[]{1}).map((Function1)new Serializable(this, training2, input, label, model, criterion){
                        public static final long serialVersionUID = 0L;
                        public final boolean training$1;
                        public final Tensor input$1;
                        public final Tensor label$1;
                        public final Container model$1;
                        public final CrossEntropyCriterion criterion$1;

                        public final Function0<Object> apply(int x$3) {
                            return new Serializable(this){
                                public static final long serialVersionUID = 0L;
                                private final /* synthetic */ anonfun$main$1$$anonfun$apply$3 $outer;

                                /**
                                 * Runs a forward pass. When training, it also runs the criterion
                                 * forward and backward and then the model backward. Returns the
                                 * backward result, or BoxedUnit.UNIT for inference.
                                 */
                                public final Object apply() {
                                    Object object;
                                    B output = this.$outer.model$1.forward(this.$outer.input$1);
                                    if (this.$outer.training$1) {
                                        // The loss value is computed but not used.
                                        float _loss = BoxesRunTime.unboxToFloat(package$.MODULE$.convCriterion(this.$outer.criterion$1).forward((Activity)output, this.$outer.label$1));
                                        Tensor<Object> errors = package$.MODULE$.convCriterion(this.$outer.criterion$1).backward((Activity)output, this.$outer.label$1).toTensor(TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                                        object = this.$outer.model$1.backward(this.$outer.input$1, errors);
                                    } else {
                                        object = BoxedUnit.UNIT;
                                    }
                                    return object;
                                }
                                {
                                    if ($outer == null) {
                                        throw null;
                                    }
                                    this.$outer = $outer;
                                }
                            };
                        }
                        {
                            this.training$1 = training$1;
                            this.input$1 = input$1;
                            this.label$1 = label$1;
                            this.model$1 = model$1;
                            this.criterion$1 = criterion$1;
                        }
                    }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Function0.class)))), Engine$.MODULE$.dnnComputing().invokeAndWait2$default$2(), Engine$.MODULE$.dnnComputing().invokeAndWait2$default$3());
                    long takes = System.nanoTime() - start2;
                    // nanoTime deltas are nanoseconds; convert once so the logged value matches its "s" label.
                    double takesSeconds = (double)takes / 1.0E9;
                    String throughput = new StringOps(Predef$.MODULE$.augmentString("%.2f")).format((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)((double)batchSize / takesSeconds))}));
                    Perf$.MODULE$.logger().info(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Iteration ", ", takes ", " s, throughput is ", " imgs/sec"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)iteration2), BoxesRunTime.boxToDouble((double)takesSeconds), throughput})));
                }
            }
        });
    }

    /**
     * Creates the singleton: publishes MODULE$, creates the logger and registers the
     * command-line options. Each option's action returns a copy of the params with one field
     * replaced. The copy argument order is (batchSize, iteration, training, model).
     */
    private Perf$() {
        MODULE$ = this;
        this.logger = LogManager.getLogger(this.getClass());
        this.parser = new OptionParser<ResNet50PerfParams>(){
            {
                // -m / --model: replaces the 4th copy argument (model name).
                this.opt('m', "model", Read$.MODULE$.stringRead()).text("model you want, vgg16 | resnet50 | vgg16_graph | resnet50_graph").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final ResNet50PerfParams apply(String v, ResNet50PerfParams p) {
                        String x$12 = v;
                        int x$13 = p.copy$default$1();
                        int x$14 = p.copy$default$2();
                        boolean x$15 = p.copy$default$3();
                        return p.copy(x$13, x$14, x$15, x$12);
                    }
                });
                // -b / --batchSize: replaces the 1st copy argument.
                this.opt('b', "batchSize", Read$.MODULE$.intRead()).text("Batch size of input data").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final ResNet50PerfParams apply(int v, ResNet50PerfParams p) {
                        return p.copy(v, p.copy$default$2(), p.copy$default$3(), p.copy$default$4());
                    }
                });
                // -i / --iteration: replaces the 2nd copy argument.
                // NOTE(review): the help text mentions averaging, but main() only logs each
                // iteration separately and never averages.
                this.opt('i', "iteration", Read$.MODULE$.intRead()).text("Iteration of perf test. The result will be average of each iteration time cost").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final ResNet50PerfParams apply(int v, ResNet50PerfParams p) {
                        int x$16 = v;
                        int x$17 = p.copy$default$1();
                        boolean x$18 = p.copy$default$3();
                        String x$19 = p.copy$default$4();
                        return p.copy(x$17, x$16, x$18, x$19);
                    }
                });
                // -t / --training: replaces the 3rd copy argument (true = train, false = inference).
                this.opt('t', "training", Read$.MODULE$.booleanRead()).text(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Perf test training or testing"})).s((Seq)Nil$.MODULE$)).action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final ResNet50PerfParams apply(boolean v, ResNet50PerfParams p) {
                        boolean x$20 = v;
                        int x$21 = p.copy$default$1();
                        int x$22 = p.copy$default$2();
                        String x$23 = p.copy$default$4();
                        return p.copy(x$21, x$22, x$20, x$23);
                    }
                });
            }
        };
    }
}

