/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.example.localEstimator;

import com.intel.analytics.bigdl.dllib.estimator.LocalEstimator;
import com.intel.analytics.bigdl.dllib.example.localEstimator.ImageProcessing$;
import com.intel.analytics.bigdl.dllib.example.localEstimator.LenetLocalEstimatorParams;
import com.intel.analytics.bigdl.dllib.example.localEstimator.LenetLocalEstimatorParams$;
import com.intel.analytics.bigdl.dllib.example.localEstimator.MnistDataLoader$;
import com.intel.analytics.bigdl.dllib.feature.dataset.image.LabeledGreyImage;
import com.intel.analytics.bigdl.dllib.keras.objectives.ZooClassNLLCriterion;
import com.intel.analytics.bigdl.dllib.keras.objectives.ZooClassNLLCriterion$;
import com.intel.analytics.bigdl.dllib.models.lenet.LeNet5$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.optim.Adam$;
import com.intel.analytics.bigdl.dllib.optim.Adam$mcF$sp;
import com.intel.analytics.bigdl.dllib.optim.Loss$;
import com.intel.analytics.bigdl.dllib.optim.Loss$mcF$sp;
import com.intel.analytics.bigdl.dllib.optim.Top1Accuracy;
import com.intel.analytics.bigdl.dllib.optim.ValidationMethod;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.package$;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.collection.Seq;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scopt.OptionParser;
import scopt.Read$;

/**
 * Decompiled form of a Scala {@code object LenetLocalEstimator}: an example entry point that trains a
 * LeNet5 model on MNIST-style image data with BigDL's {@link LocalEstimator}.
 *
 * <p>NOTE(review): this is CFR output (see file header), not hand-written Java. Several constructs are
 * not legal Java source as written. Examples: {@code new Serializable(this){...}} passes an argument to an
 * anonymous implementation of an interface, and the {@code static final MODULE$} is assigned inside
 * the instance constructor. Make changes in the original Scala source rather than here.
 */
public final class LenetLocalEstimator$ {
    /** Singleton instance of the Scala object; assigned by the private constructor. */
    public static final LenetLocalEstimator$ MODULE$;
    /** SLF4J logger named after this class (see constructor). */
    private final Logger logger;

    // Scala object initialization: constructing the instance stores it into MODULE$.
    static {
        new LenetLocalEstimator$();
    }

    /** Returns the SLF4J logger used to report parsed parameters and the created estimator. */
    public Logger logger() {
        return this.logger;
    }

    /**
     * Parses command-line options, then builds and trains a LeNet5 model with a {@link LocalEstimator}.
     *
     * <p>Required options (scopt short/long form):
     * <ul>
     *   <li>{@code -d/--imageDirPath} (string): directory passed to the MNIST train/test loaders</li>
     *   <li>{@code -b/--batchSize} (int): batch size passed to {@code fit}</li>
     *   <li>{@code -e/--epoch} (int): epoch count passed to {@code fit}</li>
     *   <li>{@code -t/--threadNum} (int): thread count passed to the {@link LocalEstimator} constructor</li>
     * </ul>
     *
     * @param args raw command-line arguments
     */
    public void main(String[] args) {
        // Option parser for LenetLocalEstimatorParams. Each option's action returns a copy of the params
        // with exactly one field replaced. copy$default$N are the Scala case-class copy() defaults,
        // i.e. the current values of the untouched fields.
        OptionParser<LenetLocalEstimatorParams> parser = new OptionParser<LenetLocalEstimatorParams>(){
            {
                // -d/--imageDirPath -> 1st field (imageDirPath)
                this.opt('d', "imageDirPath", Read$.MODULE$.stringRead()).required().text("The directory of mnist dataset").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final LenetLocalEstimatorParams apply(String x, LenetLocalEstimatorParams c) {
                        return c.copy(x, c.copy$default$2(), c.copy$default$3(), c.copy$default$4());
                    }
                });
                // -b/--batchSize -> 2nd field (batchSize)
                this.opt('b', "batchSize", Read$.MODULE$.intRead()).required().text("The number of batchSize").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final LenetLocalEstimatorParams apply(int x, LenetLocalEstimatorParams c) {
                        int x$1 = x;
                        String x$2 = c.copy$default$1();
                        int x$3 = c.copy$default$3();
                        int x$4 = c.copy$default$4();
                        return c.copy(x$2, x$1, x$3, x$4);
                    }
                });
                // -e/--epoch -> 3rd field (epoch)
                this.opt('e', "epoch", Read$.MODULE$.intRead()).required().text("The number of epoch").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final LenetLocalEstimatorParams apply(int x, LenetLocalEstimatorParams c) {
                        int x$5 = x;
                        String x$6 = c.copy$default$1();
                        int x$7 = c.copy$default$2();
                        int x$8 = c.copy$default$4();
                        return c.copy(x$6, x$7, x$5, x$8);
                    }
                });
                // -t/--threadNum -> 4th field (threadNum)
                this.opt('t', "threadNum", Read$.MODULE$.intRead()).required().text("The number of threadNum").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final LenetLocalEstimatorParams apply(int x, LenetLocalEstimatorParams c) {
                        int x$9 = x;
                        String x$10 = c.copy$default$1();
                        int x$11 = c.copy$default$2();
                        int x$12 = c.copy$default$3();
                        return c.copy(x$10, x$11, x$12, x$9);
                    }
                });
            }
        };
        // Parse args starting from a params instance built from the companion's apply() defaults.
        // The training body below runs inside .map(...). Presumably parse() returns a Scala Option, so
        // nothing is trained when parsing fails. TODO confirm against the scopt version in use.
        parser.parse((Seq<String>)Predef$.MODULE$.wrapRefArray((Object[])args), new LenetLocalEstimatorParams(LenetLocalEstimatorParams$.MODULE$.apply$default$1(), LenetLocalEstimatorParams$.MODULE$.apply$default$2(), LenetLocalEstimatorParams$.MODULE$.apply$default$3(), LenetLocalEstimatorParams$.MODULE$.apply$default$4())).map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            /** Builds the model, loss, optimizer and validation methods, loads data, and runs training. */
            public final void apply(LenetLocalEstimatorParams params) {
                // Decompiled form of the Scala interpolation s"params parsed as $params".
                LenetLocalEstimator$.MODULE$.logger().info(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"params parsed as ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{params})));
                String imageDirPath = params.imageDirPath();
                int batchSize = params.batchSize();
                int epoch = params.epoch();
                int threadNum = params.threadNum();
                // LeNet5 built with argument 10. Presumably this is the number of output classes
                // (MNIST digits); verify against LeNet5.apply.
                AbstractModule<Activity, Activity, Object> model = LeNet5$.MODULE$.apply(10);
                // Decompiler artifact: the first default argument is evaluated and its result discarded.
                // The call below passes null in that position instead (presumably the inlined default).
                ZooClassNLLCriterion$.MODULE$.apply$default$1();
                // NLL criterion with default settings, constructed for Float (ClassTag Float + NumericFloat).
                ZooClassNLLCriterion<Object> criterion = ZooClassNLLCriterion$.MODULE$.apply$mFc$sp(null, ZooClassNLLCriterion$.MODULE$.apply$default$2(), ZooClassNLLCriterion$.MODULE$.apply$default$3(), ZooClassNLLCriterion$.MODULE$.apply$default$4(), (ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                // Adam with all constructor defaults ($lessinit$greater encodes <init>), Float-typed.
                Adam$mcF$sp adam = new Adam$mcF$sp(Adam$.MODULE$.$lessinit$greater$default$1(), Adam$.MODULE$.$lessinit$greater$default$2(), Adam$.MODULE$.$lessinit$greater$default$3(), Adam$.MODULE$.$lessinit$greater$default$4(), Adam$.MODULE$.$lessinit$greater$default$5(), (ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                // Validation metrics reported by the estimator: top-1 accuracy and loss.
                ValidationMethod[] validationMethodArray = new ValidationMethod[2];
                validationMethodArray[0] = new Top1Accuracy<Object>(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                // Decompiler artifact: Loss's first constructor default is evaluated and discarded;
                // null is passed explicitly in its place on the next line.
                Loss$.MODULE$.$lessinit$greater$default$1();
                validationMethodArray[1] = new Loss$mcF$sp(null, (ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                // No-op cast left over from decompiling the Scala Array construction.
                ValidationMethod[] validations = (ValidationMethod[])((Object[])validationMethodArray);
                // The criterion goes through package$.convCriterion to match the estimator's expected type.
                LocalEstimator localEstimator = new LocalEstimator(model, package$.MODULE$.convCriterion(criterion), adam, validations, threadNum);
                LenetLocalEstimator$.MODULE$.logger().info(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"LocalEstimator loaded as ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{localEstimator})));
                // Train and test images are both loaded from the same imageDirPath.
                LabeledGreyImage[] trainData = MnistDataLoader$.MODULE$.loadTrainData(imageDirPath);
                LabeledGreyImage[] testData = MnistDataLoader$.MODULE$.loadTestData(imageDirPath);
                // Train for `epoch` epochs with `batchSize`. The transformer converts the labeled
                // grey images into mini-batches for the estimator.
                localEstimator.fit(trainData, testData, ImageProcessing$.MODULE$.labeledGreyImageToMiniBatchTransformer(), batchSize, epoch, ClassTag$.MODULE$.apply(LabeledGreyImage.class));
            }
        });
    }

    /** Private singleton constructor: publishes the instance via MODULE$ and creates the class logger. */
    private LenetLocalEstimator$() {
        MODULE$ = this;
        this.logger = LoggerFactory.getLogger(this.getClass());
    }
}

