/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.models.resnet;

import com.intel.analytics.bigdl.dllib.feature.dataset.AbstractDataSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.models.resnet.ImageNetDataSet$;
import com.intel.analytics.bigdl.dllib.models.resnet.ResNet$;
import com.intel.analytics.bigdl.dllib.models.resnet.ResNet$DatasetType$ImageNet$;
import com.intel.analytics.bigdl.dllib.models.resnet.ResNet$ShortcutType$B$;
import com.intel.analytics.bigdl.dllib.models.resnet.Utils;
import com.intel.analytics.bigdl.dllib.models.resnet.Utils$;
import com.intel.analytics.bigdl.dllib.models.resnet.Utils$TrainParams$;
import com.intel.analytics.bigdl.dllib.nn.BatchNormalization;
import com.intel.analytics.bigdl.dllib.nn.Container;
import com.intel.analytics.bigdl.dllib.nn.CrossEntropyCriterion;
import com.intel.analytics.bigdl.dllib.nn.CrossEntropyCriterion$;
import com.intel.analytics.bigdl.dllib.nn.Module$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.optim.OptimMethod$;
import com.intel.analytics.bigdl.dllib.optim.Optimizer;
import com.intel.analytics.bigdl.dllib.optim.Optimizer$;
import com.intel.analytics.bigdl.dllib.optim.SGD;
import com.intel.analytics.bigdl.dllib.optim.SGD$;
import com.intel.analytics.bigdl.dllib.optim.SGD$mcF$sp;
import com.intel.analytics.bigdl.dllib.optim.Top1Accuracy;
import com.intel.analytics.bigdl.dllib.optim.Top5Accuracy;
import com.intel.analytics.bigdl.dllib.optim.Trigger$;
import com.intel.analytics.bigdl.dllib.optim.ValidationMethod;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.dllib.utils.Engine$;
import com.intel.analytics.bigdl.dllib.utils.EngineType;
import com.intel.analytics.bigdl.dllib.utils.LoggerFilter$;
import com.intel.analytics.bigdl.dllib.utils.MklBlas$;
import com.intel.analytics.bigdl.dllib.utils.OptimizerV1$;
import com.intel.analytics.bigdl.dllib.utils.OptimizerV2$;
import com.intel.analytics.bigdl.dllib.utils.T$;
import com.intel.analytics.bigdl.dllib.visualization.TrainSummary;
import com.intel.analytics.bigdl.dllib.visualization.TrainSummary$;
import com.intel.analytics.bigdl.dllib.visualization.ValidationSummary;
import com.intel.analytics.bigdl.dllib.visualization.ValidationSummary$;
import com.intel.analytics.bigdl.package$;
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.core.config.Configurator;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import scala.Function1;
import scala.MatchError;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.Tuple4;
import scala.collection.Seq;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/**
 * Spark driver that trains ResNet on ImageNet-2012 using BigDL.
 *
 * <p>This is the CFR-decompiled form of a Scala {@code object}: the singleton instance is
 * {@link #MODULE$}, the {@code $}-suffixed types ({@code Engine$}, {@code ResNet$}, ...) are Scala
 * companion objects, and anonymous {@code Serializable} classes stand in for Scala closures.
 */
public final class TrainImageNet$ {
    public static final TrainImageNet$ MODULE$;

    /**
     * Number of images in one pass over the training data; divided by the batch size to obtain the
     * number of iterations per epoch for the warm-up schedule.
     */
    private static final int IMAGENET_TRAIN_IMAGES = 1281167;

    private final Logger logger;

    static {
        new TrainImageNet$();
    }

    /** Returns the Log4j logger used by this driver. */
    public Logger logger() {
        return this.logger;
    }

    /**
     * Step-decay exponent for the ImageNet learning-rate schedule.
     *
     * <p>Passed as the decay function of {@code SGD.EpochDecayWithWarmUp} in {@link #main}.
     *
     * @param epoch current training epoch
     * @return 0.0 for epochs below 30, 1.0 for 30-59, 2.0 for 60-79 and 3.0 from epoch 80 onwards
     */
    public double imageNetDecay(int epoch) {
        return epoch >= 80 ? 3.0 : (epoch >= 60 ? 2.0 : (epoch >= 30 ? 1.0 : 0.0));
    }

    /**
     * Number of warm-up iterations: {@code ceil(IMAGENET_TRAIN_IMAGES / batchSize) * warmupEpoch}.
     *
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    private static int warmUpIterations(int batchSize, int warmupEpoch) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        // Divide in floating point so Math.ceil can actually round up a partial last batch;
        // integer division would truncate before the ceiling is applied.
        int iterationsPerEpoch = (int) Math.ceil((double) IMAGENET_TRAIN_IMAGES / batchSize);
        return iterationsPerEpoch * warmupEpoch;
    }

    /**
     * Per-iteration learning-rate increment that ramps {@code baseLr} to {@code maxLr} over
     * {@code warmUpIteration} iterations. Returns 0 when there is no warm-up, instead of the
     * infinite/NaN value a division by zero would produce.
     */
    private static double warmUpDelta(double baseLr, double maxLr, int warmUpIteration) {
        return warmUpIteration > 0 ? (maxLr - baseLr) / (double) warmUpIteration : 0.0;
    }

    /**
     * Entry point: parses {@code args} with {@code Utils.trainParser()} and, if parsing yields
     * parameters, trains ResNet on ImageNet.
     *
     * <p>For parsed parameters this method:
     * <ol>
     *   <li>creates a SparkContext (app name "Train ResNet on ImageNet2012",
     *       {@code spark.rpc.message.maxSize=200}) and initialises the BigDL engine;</li>
     *   <li>builds 224-pixel train/validation datasets from {@code <folder>/train} and
     *       {@code <folder>/val};</li>
     *   <li>loads the model from {@code modelSnapshot}, or builds a fresh ResNet with shortcut
     *       type B (sharing gradInput when {@code optnet} is set, and setting BatchNormalization
     *       parallelism when the engine type is MKL-BLAS);</li>
     *   <li>optionally selects optimizer V1/V2 from {@code optimizerVersion};</li>
     *   <li>loads SGD state from {@code stateSnapshot} or creates a new SGD; both use a warm-up
     *       plus {@link #imageNetDecay} epoch-decay schedule;</li>
     *   <li>optimises with cross-entropy loss and Top-1/Top-5 validation every epoch (with an
     *       optional per-epoch checkpoint) for {@code nepochs} epochs, then stops the
     *       SparkContext, also when training fails.</li>
     * </ol>
     *
     * @param args command-line arguments
     * @throws MatchError if {@code optimizerVersion} is set to something other than
     *     "optimizerV1"/"optimizerV2" (case-insensitive)
     */
    public void main(String[] args) {
        Utils$.MODULE$.trainParser().parse((Seq<String>)Predef$.MODULE$.wrapRefArray((Object[])args), new Utils.TrainParams(Utils$TrainParams$.MODULE$.$lessinit$greater$default$1(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$2(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$3(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$4(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$5(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$6(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$7(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$8(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$9(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$10(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$11(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$12(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$13(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$14(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$15(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$16(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$17(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$18(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$19())).map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final void apply(Utils.TrainParams param) {
                SparkConf conf = Engine$.MODULE$.createSparkConf(Engine$.MODULE$.createSparkConf$default$1())
                        .setAppName("Train ResNet on ImageNet2012")
                        .set("spark.rpc.message.maxSize", "200");
                SparkContext sc = new SparkContext(conf);
                // Stop the SparkContext even if setup or training throws.
                try {
                    Engine$.MODULE$.init();
                    int batchSize = param.batchSize();

                    // Fixed ImageNet settings (destructured from a constant tuple in the Scala source).
                    int imageSize = 224;
                    ResNet$DatasetType$ImageNet$ dataSetType = ResNet$DatasetType$ImageNet$.MODULE$;
                    int maxEpoch = param.nepochs();
                    ImageNetDataSet$ dataSet = ImageNetDataSet$.MODULE$;

                    AbstractDataSet<MiniBatch<Object>, ?> trainDataSet = dataSet.trainDataSet(param.folder() + "/train", sc, imageSize, batchSize);
                    AbstractDataSet<MiniBatch<Object>, ?> validateSet = dataSet.valDataSet(param.folder() + "/val", sc, imageSize, batchSize);

                    // Model: resume from a snapshot, or build and initialise a fresh ResNet.
                    ResNet$ShortcutType$B$ shortcut = ResNet$ShortcutType$B$.MODULE$;
                    AbstractModule<Activity, Activity, Object> model;
                    if (param.modelSnapshot().isDefined()) {
                        model = Module$.MODULE$.loadModule((String)param.modelSnapshot().get(), Module$.MODULE$.loadModule$default$2(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                    } else {
                        model = ResNet$.MODULE$.apply(param.classes(), T$.MODULE$.apply((Tuple2<Object, Object>)Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"shortcutType"), (Object)shortcut), (Seq<Tuple2<Object, Object>>)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"depth"), (Object)BoxesRunTime.boxToInteger((int)param.depth())), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"optnet"), (Object)BoxesRunTime.boxToBoolean((boolean)param.optnet())), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"dataSet"), (Object)dataSetType)})));
                        if (param.optnet()) {
                            ResNet$.MODULE$.shareGradInput(model);
                        }
                        ResNet$.MODULE$.modelInit(model);
                        EngineType engineType = Engine$.MODULE$.getEngineType();
                        MklBlas$ mklBlas = MklBlas$.MODULE$;
                        // Null-safe equality: BatchNormalization parallelism is only set for MKL-BLAS.
                        boolean isMklBlas = engineType == null ? mklBlas == null : engineType.equals(mklBlas);
                        if (isMklBlas) {
                            TrainImageNet$.MODULE$.com$intel$analytics$bigdl$dllib$models$resnet$TrainImageNet$$setParallism(model, Engine$.MODULE$.coreNumber());
                        }
                    }
                    Predef$.MODULE$.println(model);

                    if (param.optimizerVersion().isDefined()) {
                        // Locale.ROOT keeps the match independent of the JVM default locale
                        // (e.g. Turkish dotless-i lowercasing of "OPTIMIZERV1").
                        String version = ((String)param.optimizerVersion().get()).toLowerCase(java.util.Locale.ROOT);
                        if ("optimizerv1".equals(version)) {
                            Engine$.MODULE$.setOptimizerVersion(OptimizerV1$.MODULE$);
                        } else if ("optimizerv2".equals(version)) {
                            Engine$.MODULE$.setOptimizerVersion(OptimizerV2$.MODULE$);
                        } else {
                            throw new MatchError((Object)version);
                        }
                    }

                    // Warm-up: ramp from learningRate to maxLr over warmupEpoch epochs, then epoch decay.
                    double baseLr = param.learningRate();
                    double maxLr = param.maxLr();
                    int warmUpIteration = TrainImageNet$.warmUpIterations(batchSize, param.warmupEpoch());
                    double delta = TrainImageNet$.warmUpDelta(baseLr, maxLr, warmUpIteration);
                    SGD.EpochDecayWithWarmUp schedule = new SGD.EpochDecayWithWarmUp(warmUpIteration, delta, (Function1<Object, Object>)new Serializable(this){
                        public static final long serialVersionUID = 0L;

                        public final double apply(int epoch) {
                            return this.apply$mcDI$sp(epoch);
                        }

                        public double apply$mcDI$sp(int epoch) {
                            return TrainImageNet$.MODULE$.imageNetDecay(epoch);
                        }
                    });

                    SGD optimMethod;
                    if (param.stateSnapshot().isDefined()) {
                        // Resume optimizer state but always install the current LR schedule.
                        optimMethod = (SGD)OptimMethod$.MODULE$.load((String)param.stateSnapshot().get(), ClassTag$.MODULE$.Float());
                        optimMethod.learningRateSchedule_$eq(schedule);
                    } else {
                        TrainImageNet$.MODULE$.logger().info("warmUpIteraion: " + warmUpIteration + ", startLr: " + param.learningRate() + ", " + "maxLr: " + maxLr + ", " + "delta: " + delta + ", nesterov: " + param.nesterov());
                        // Evaluation of the Scala default arguments 8 and 9; null is passed for both below.
                        SGD$.MODULE$.$lessinit$greater$default$8();
                        SGD$.MODULE$.$lessinit$greater$default$9();
                        optimMethod = new SGD$mcF$sp(param.learningRate(), 0.0, param.weightDecay(), param.momentum(), param.dampening(), param.nesterov(), (SGD.LearningRateSchedule)schedule, null, null, (ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                    }

                    // Evaluation of the Scala default argument 1; null is passed as the first argument below.
                    CrossEntropyCriterion$.MODULE$.$lessinit$greater$default$1();
                    Optimizer<Object, MiniBatch<Object>> optimizer = Optimizer$.MODULE$.apply(model, trainDataSet, package$.MODULE$.convCriterion(new CrossEntropyCriterion<Object>(null, CrossEntropyCriterion$.MODULE$.$lessinit$greater$default$2(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                    if (param.checkpoint().isDefined()) {
                        optimizer.setCheckpoint((String)param.checkpoint().get(), Trigger$.MODULE$.everyEpoch());
                    }

                    String logdir = "resnet-imagenet";
                    String appName = String.valueOf(sc.applicationId());
                    TrainSummary trainSummary = TrainSummary$.MODULE$.apply(logdir, appName);
                    trainSummary.setSummaryTrigger("LearningRate", Trigger$.MODULE$.severalIteration(1));
                    trainSummary.setSummaryTrigger("Parameters", Trigger$.MODULE$.severalIteration(10));
                    ValidationSummary validationSummary = ValidationSummary$.MODULE$.apply(logdir, appName);
                    // NOTE(review): nothing in this method passes trainSummary / validationSummary to the
                    // optimizer (no setTrainSummary / setValidationSummary call), so the triggers above
                    // are not used by this run. Left unattached to avoid enabling new log output - confirm intent.

                    optimizer.setOptimMethod(optimMethod).setValidation(Trigger$.MODULE$.everyEpoch(), validateSet, (ValidationMethod[])((Object[])new ValidationMethod[]{new Top1Accuracy<Object>(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), new Top5Accuracy<Object>(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)})).setEndWhen(Trigger$.MODULE$.maxEpoch(maxEpoch)).optimize();
                } finally {
                    sc.stop();
                }
            }
        });
    }

    /**
     * Recursively sets the parallelism of every {@link BatchNormalization} in {@code model}.
     *
     * <p>If {@code model} is a {@link Container}, the call recurses into each of its modules.
     * (The long mangled name is the Scala compiler's encoding of a private method.)
     *
     * @param model module or container to visit
     * @param parallism value passed to {@code BatchNormalization.setParallism}
     */
    public void com$intel$analytics$bigdl$dllib$models$resnet$TrainImageNet$$setParallism(AbstractModule<?, ?, Object> model, int parallism) {
        if (model instanceof BatchNormalization) {
            ((BatchNormalization)model).setParallism(parallism);
        }
        if (model instanceof Container) {
            ((Container)model).modules().foreach((Function1)new Serializable(parallism){
                public static final long serialVersionUID = 0L;
                private final int parallism$1;

                public final void apply(AbstractModule<Activity, Activity, Object> sub2) {
                    TrainImageNet$.MODULE$.com$intel$analytics$bigdl$dllib$models$resnet$TrainImageNet$$setParallism(sub2, this.parallism$1);
                }
                {
                    this.parallism$1 = parallism$1;
                }
            });
        }
    }

    /**
     * Singleton constructor: publishes {@link #MODULE$}, redirects Spark INFO logs, sets the
     * {@code com.intel.analytics.bigdl.dllib.optim} logger to INFO and creates this class's logger.
     */
    private TrainImageNet$() {
        MODULE$ = this;
        LoggerFilter$.MODULE$.redirectSparkInfoLogs(LoggerFilter$.MODULE$.redirectSparkInfoLogs$default$1());
        Configurator.setLevel("com.intel.analytics.bigdl.dllib.optim", Level.INFO);
        this.logger = LogManager.getLogger(this.getClass());
    }
}

