/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.example.localEstimator;

import com.intel.analytics.bigdl.dllib.estimator.LocalEstimator;
import com.intel.analytics.bigdl.dllib.example.localEstimator.Cifar10DataLoader$;
import com.intel.analytics.bigdl.dllib.example.localEstimator.ImageProcessing$;
import com.intel.analytics.bigdl.dllib.example.localEstimator.ResnetLocalEstimatorParams;
import com.intel.analytics.bigdl.dllib.example.localEstimator.ResnetLocalEstimatorParams$;
import com.intel.analytics.bigdl.dllib.feature.dataset.image.LabeledBGRImage;
import com.intel.analytics.bigdl.dllib.models.resnet.ResNet$;
import com.intel.analytics.bigdl.dllib.models.resnet.ResNet$ShortcutType$A$;
import com.intel.analytics.bigdl.dllib.nn.CrossEntropyCriterion;
import com.intel.analytics.bigdl.dllib.nn.CrossEntropyCriterion$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.optim.Adam$;
import com.intel.analytics.bigdl.dllib.optim.Adam$mcF$sp;
import com.intel.analytics.bigdl.dllib.optim.Loss$;
import com.intel.analytics.bigdl.dllib.optim.Loss$mcF$sp;
import com.intel.analytics.bigdl.dllib.optim.Top1Accuracy;
import com.intel.analytics.bigdl.dllib.optim.Top5Accuracy;
import com.intel.analytics.bigdl.dllib.optim.ValidationMethod;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.dllib.utils.T$;
import com.intel.analytics.bigdl.package$;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Function1;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scopt.OptionParser;
import scopt.Read$;

/**
 * Singleton holder (decompiled Scala {@code object}) for the ResNet local-estimator example.
 *
 * <p>{@link #main(String[])} parses the command-line options, builds a ResNet graph, wraps it in a
 * {@code LocalEstimator} and fits it on the images returned by {@code Cifar10DataLoader$}.
 */
public final class ResnetLocalEstimator$ {
    /** The single instance; assigned by the private constructor, which the static initializer invokes. */
    public static final ResnetLocalEstimator$ MODULE$;
    /** SLF4J logger obtained for this class in the constructor. */
    private final Logger logger;

    static {
        // Constructing the instance publishes it through MODULE$ (see the constructor).
        new ResnetLocalEstimator$();
    }

    /**
     * Returns the logger used by this example.
     *
     * @return the logger created in the constructor
     */
    public Logger logger() {
        return this.logger;
    }

    /**
     * Parses the command-line options and, when the parse result is mapped, trains the model.
     *
     * <p>Four options are declared, all marked required: {@code d}/{@code imageDirPath} (string),
     * {@code b}/{@code batchSize}, {@code e}/{@code epoch} and {@code t}/{@code threadNum} (ints).
     * Each option's action copies the current {@code ResnetLocalEstimatorParams} with only the
     * corresponding field replaced.
     *
     * @param args raw command-line arguments handed to the scopt parser
     */
    public void main(String[] args) {
        OptionParser<ResnetLocalEstimatorParams> parser = new OptionParser<ResnetLocalEstimatorParams>(){
            {
                // The path is later handed to Cifar10DataLoader$, so the help text names CIFAR-10.
                this.opt('d', "imageDirPath", Read$.MODULE$.stringRead()).required().text("The directory of CIFAR-10 dataset").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    // Replaces field 1 (imageDirPath) and keeps fields 2-4.
                    public final ResnetLocalEstimatorParams apply(String x, ResnetLocalEstimatorParams c) {
                        return c.copy(x, c.copy$default$2(), c.copy$default$3(), c.copy$default$4());
                    }
                });
                this.opt('b', "batchSize", Read$.MODULE$.intRead()).required().text("The number of batchSize").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    // Replaces field 2 (batchSize) and keeps fields 1, 3 and 4.
                    public final ResnetLocalEstimatorParams apply(int x, ResnetLocalEstimatorParams c) {
                        int x$1 = x;
                        String x$2 = c.copy$default$1();
                        int x$3 = c.copy$default$3();
                        int x$4 = c.copy$default$4();
                        return c.copy(x$2, x$1, x$3, x$4);
                    }
                });
                this.opt('e', "epoch", Read$.MODULE$.intRead()).required().text("The number of epoch").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    // Replaces field 3 (epoch) and keeps fields 1, 2 and 4.
                    public final ResnetLocalEstimatorParams apply(int x, ResnetLocalEstimatorParams c) {
                        int x$5 = x;
                        String x$6 = c.copy$default$1();
                        int x$7 = c.copy$default$2();
                        int x$8 = c.copy$default$4();
                        return c.copy(x$6, x$7, x$5, x$8);
                    }
                });
                this.opt('t', "threadNum", Read$.MODULE$.intRead()).required().text("The number of threadNum").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    // Replaces field 4 (threadNum) and keeps fields 1-3.
                    public final ResnetLocalEstimatorParams apply(int x, ResnetLocalEstimatorParams c) {
                        int x$9 = x;
                        String x$10 = c.copy$default$1();
                        int x$11 = c.copy$default$2();
                        int x$12 = c.copy$default$3();
                        return c.copy(x$10, x$11, x$12, x$9);
                    }
                });
            }
        };
        // Parsing starts from a params object built entirely from the declared defaults; the
        // training callback is applied through map on the parse result.
        // NOTE(review): presumably the result is an Option that is empty when parsing fails - confirm.
        parser.parse((Seq<String>)Predef$.MODULE$.wrapRefArray((Object[])args), new ResnetLocalEstimatorParams(ResnetLocalEstimatorParams$.MODULE$.apply$default$1(), ResnetLocalEstimatorParams$.MODULE$.apply$default$2(), ResnetLocalEstimatorParams$.MODULE$.apply$default$3(), ResnetLocalEstimatorParams$.MODULE$.apply$default$4())).map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            // Builds the model, estimator and data sets from the parsed params, then runs fit.
            public final void apply(ResnetLocalEstimatorParams params) {
                ResnetLocalEstimator$.MODULE$.logger().info("params parsed as {}", params);
                String imageDirPath = params.imageDirPath();
                int batchSize = params.batchSize();
                int epoch = params.epoch();
                int threadNum = params.threadNum();
                // ResNet graph configured with shortcutType = A, depth = 50, optnet = false.
                // NOTE(review): the leading 10 is presumably the number of output classes - confirm.
                AbstractModule<Activity, Activity, Object> model = ResNet$.MODULE$.graph(10, T$.MODULE$.apply((Tuple2<Object, Object>)Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"shortcutType"), (Object)ResNet$ShortcutType$A$.MODULE$), (Seq<Tuple2<Object, Object>>)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"depth"), (Object)BoxesRunTime.boxToInteger((int)50)), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"optnet"), (Object)BoxesRunTime.boxToBoolean((boolean)false))})));
                // The first constructor default is evaluated but its result is discarded; null is
                // passed below instead. Kept as decompiled to preserve the original call sequence.
                CrossEntropyCriterion$.MODULE$.$lessinit$greater$default$1();
                CrossEntropyCriterion<Object> criterion = new CrossEntropyCriterion<Object>(null, CrossEntropyCriterion$.MODULE$.$lessinit$greater$default$2(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                // Adam optimizer with every hyper-parameter left at its declared default.
                Adam$mcF$sp adam = new Adam$mcF$sp(Adam$.MODULE$.$lessinit$greater$default$1(), Adam$.MODULE$.$lessinit$greater$default$2(), Adam$.MODULE$.$lessinit$greater$default$3(), Adam$.MODULE$.$lessinit$greater$default$4(), Adam$.MODULE$.$lessinit$greater$default$5(), (ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                // Validation methods: Top1Accuracy, Top5Accuracy and Loss.
                ValidationMethod[] validationMethodArray = new ValidationMethod[3];
                validationMethodArray[0] = new Top1Accuracy<Object>(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                validationMethodArray[1] = new Top5Accuracy<Object>(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                // Same pattern as above: default evaluated, result discarded, null passed.
                Loss$.MODULE$.$lessinit$greater$default$1();
                validationMethodArray[2] = new Loss$mcF$sp(null, (ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                ValidationMethod[] validations = (ValidationMethod[])((Object[])validationMethodArray);
                LocalEstimator localEstimator = new LocalEstimator(model, package$.MODULE$.convCriterion(criterion), adam, validations, threadNum);
                ResnetLocalEstimator$.MODULE$.logger().info("LocalEstimator loaded as {}", localEstimator);
                // Train and test images are both loaded from the same directory.
                LabeledBGRImage[] trainData = Cifar10DataLoader$.MODULE$.loadTrainData(imageDirPath);
                LabeledBGRImage[] testData = Cifar10DataLoader$.MODULE$.loadTestData(imageDirPath);
                localEstimator.fit(trainData, testData, ImageProcessing$.MODULE$.labeledBGRImageToMiniBatchTransformer(), batchSize, epoch, ClassTag$.MODULE$.apply(LabeledBGRImage.class));
            }
        });
    }

    /** Publishes this instance as {@link #MODULE$} and creates the logger for this class. */
    private ResnetLocalEstimator$() {
        MODULE$ = this;
        this.logger = LoggerFactory.getLogger(this.getClass());
    }
}

