/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.models.rnn;

import com.intel.analytics.bigdl.dllib.feature.dataset.AbstractDataSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.DataSet$;
import com.intel.analytics.bigdl.dllib.feature.dataset.FixedLength;
import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.feature.dataset.PaddingParam;
import com.intel.analytics.bigdl.dllib.feature.dataset.Sample;
import com.intel.analytics.bigdl.dllib.feature.dataset.SampleToMiniBatch$;
import com.intel.analytics.bigdl.dllib.feature.dataset.text.Dictionary;
import com.intel.analytics.bigdl.dllib.feature.dataset.text.Dictionary$;
import com.intel.analytics.bigdl.dllib.feature.dataset.text.LabeledSentence;
import com.intel.analytics.bigdl.dllib.feature.dataset.text.LabeledSentenceToSample$;
import com.intel.analytics.bigdl.dllib.feature.dataset.text.TextToLabeledSentence$;
import com.intel.analytics.bigdl.dllib.feature.dataset.text.utils.SentenceToken$;
import com.intel.analytics.bigdl.dllib.models.rnn.SequencePreprocess$;
import com.intel.analytics.bigdl.dllib.models.rnn.SimpleRNN$;
import com.intel.analytics.bigdl.dllib.models.rnn.Utils;
import com.intel.analytics.bigdl.dllib.models.rnn.Utils$;
import com.intel.analytics.bigdl.dllib.models.rnn.Utils$TrainParams$;
import com.intel.analytics.bigdl.dllib.nn.CrossEntropyCriterion$;
import com.intel.analytics.bigdl.dllib.nn.Module$;
import com.intel.analytics.bigdl.dllib.nn.TimeDistributedCriterion$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.optim.Loss$mcF$sp;
import com.intel.analytics.bigdl.dllib.optim.OptimMethod;
import com.intel.analytics.bigdl.dllib.optim.OptimMethod$;
import com.intel.analytics.bigdl.dllib.optim.Optimizer;
import com.intel.analytics.bigdl.dllib.optim.Optimizer$;
import com.intel.analytics.bigdl.dllib.optim.SGD;
import com.intel.analytics.bigdl.dllib.optim.SGD$;
import com.intel.analytics.bigdl.dllib.optim.SGD$mcF$sp;
import com.intel.analytics.bigdl.dllib.optim.Trigger;
import com.intel.analytics.bigdl.dllib.optim.Trigger$;
import com.intel.analytics.bigdl.dllib.optim.ValidationMethod;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.dllib.utils.Engine$;
import com.intel.analytics.bigdl.dllib.utils.OptimizerV1$;
import com.intel.analytics.bigdl.dllib.utils.OptimizerV2$;
import com.intel.analytics.bigdl.dllib.utils.T$;
import com.intel.analytics.bigdl.package$;
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.core.config.Configurator;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.rdd.RDD;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.StringContext;
import scala.collection.Seq;
import scala.collection.mutable.StringBuilder;
import scala.math.Ordering;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

public final class Train$ {
    public static final Train$ MODULE$;
    private final Logger logger;

    static {
        new Train$();
    }

    public Logger logger() {
        return this.logger;
    }

    public void main(String[] args) {
        Utils$.MODULE$.trainParser().parse((Seq<String>)Predef$.MODULE$.wrapRefArray((Object[])args), new Utils.TrainParams(Utils$TrainParams$.MODULE$.$lessinit$greater$default$1(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$2(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$3(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$4(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$5(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$6(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$7(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$8(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$9(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$10(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$11(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$12(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$13(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$14(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$15(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$16(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$17(), Utils$TrainParams$.MODULE$.$lessinit$greater$default$18())).map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final void apply(Utils.TrainParams param) {
                OptimMethod<Object> optimMethod;
                AbstractModule<Activity, Activity, Object> model;
                AbstractModule<Activity, Activity, Object> abstractModule;
                SparkConf conf = Engine$.MODULE$.createSparkConf(Engine$.MODULE$.createSparkConf$default$1()).setAppName("Train rnn on text").set("spark.task.maxFailures", "1");
                SparkContext sc = new SparkContext(conf);
                Engine$.MODULE$.init();
                RDD<String[]> tokens = SequencePreprocess$.MODULE$.apply(new StringBuilder().append((Object)param.dataFolder()).append((Object)"/train.txt").toString(), sc, param.sentFile(), param.tokenFile());
                Dictionary dictionary = Dictionary$.MODULE$.apply(tokens, param.vocabSize());
                dictionary.save(param.saveFolder());
                int maxTrainLength = BoxesRunTime.unboxToInt((Object)tokens.map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final int apply(String[] x) {
                        return x.length;
                    }
                }, ClassTag$.MODULE$.Int()).max((Ordering)Ordering.Int$.MODULE$));
                RDD<String[]> valtokens = SequencePreprocess$.MODULE$.apply(new StringBuilder().append((Object)param.dataFolder()).append((Object)"/val.txt").toString(), sc, param.sentFile(), param.tokenFile());
                int maxValLength = BoxesRunTime.unboxToInt((Object)valtokens.map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final int apply(String[] x) {
                        return x.length;
                    }
                }, ClassTag$.MODULE$.Int()).max((Ordering)Ordering.Int$.MODULE$));
                Train$.MODULE$.logger().info(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"maxTrain length = ", ", maxVal = ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)maxTrainLength), BoxesRunTime.boxToInteger((int)maxValLength)})));
                int totalVocabLength = dictionary.getVocabSize() + 1;
                int startIdx = dictionary.getIndex(SentenceToken$.MODULE$.start());
                int endIdx = dictionary.getIndex(SentenceToken$.MODULE$.end());
                Tensor<Object> padFeature = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).resize(totalVocabLength);
                padFeature.setValue(endIdx + 1, BoxesRunTime.boxToFloat((float)1.0f));
                Tensor<Object> padLabel = Tensor$.MODULE$.apply(T$.MODULE$.apply(BoxesRunTime.boxToFloat((float)((float)startIdx + 1.0f)), (Seq<Object>)Predef$.MODULE$.genericWrapArray((Object)new Object[0])), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                PaddingParam<T> featurePadding = new PaddingParam<T>((Option<Tensor<T>[]>)new Some((Object)new Tensor[]{padFeature}), new FixedLength(new int[]{maxTrainLength}), ClassTag$.MODULE$.Float());
                PaddingParam<T> labelPadding = new PaddingParam<T>((Option<Tensor<T>[]>)new Some((Object)new Tensor[]{padLabel}), new FixedLength(new int[]{maxTrainLength}), ClassTag$.MODULE$.Float());
                AbstractDataSet<C, ?> trainSet = DataSet$.MODULE$.rdd(tokens, DataSet$.MODULE$.rdd$default$2(), DataSet$.MODULE$.rdd$default$3(), DataSet$.MODULE$.rdd$default$4(), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(String.class))).transform(TextToLabeledSentence$.MODULE$.apply(dictionary, ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), ClassTag$.MODULE$.apply(LabeledSentence.class)).transform(LabeledSentenceToSample$.MODULE$.apply(totalVocabLength, LabeledSentenceToSample$.MODULE$.apply$default$2(), LabeledSentenceToSample$.MODULE$.apply$default$3(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), ClassTag$.MODULE$.apply(Sample.class)).transform(SampleToMiniBatch$.MODULE$.apply(param.batchSize(), new Some(featurePadding), new Some(labelPadding), SampleToMiniBatch$.MODULE$.apply$default$4(), SampleToMiniBatch$.MODULE$.apply$default$5(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), ClassTag$.MODULE$.apply(MiniBatch.class));
                AbstractDataSet<C, ?> validationSet = DataSet$.MODULE$.rdd(valtokens, DataSet$.MODULE$.rdd$default$2(), DataSet$.MODULE$.rdd$default$3(), DataSet$.MODULE$.rdd$default$4(), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(String.class))).transform(TextToLabeledSentence$.MODULE$.apply(dictionary, ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), ClassTag$.MODULE$.apply(LabeledSentence.class)).transform(LabeledSentenceToSample$.MODULE$.apply(totalVocabLength, LabeledSentenceToSample$.MODULE$.apply$default$2(), LabeledSentenceToSample$.MODULE$.apply$default$3(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), ClassTag$.MODULE$.apply(Sample.class)).transform(SampleToMiniBatch$.MODULE$.apply(param.batchSize(), new Some(featurePadding), new Some(labelPadding), SampleToMiniBatch$.MODULE$.apply$default$4(), SampleToMiniBatch$.MODULE$.apply$default$5(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), ClassTag$.MODULE$.apply(MiniBatch.class));
                if (param.modelSnapshot().isDefined()) {
                    abstractModule = Module$.MODULE$.loadModule((String)param.modelSnapshot().get(), Module$.MODULE$.loadModule$default$2(), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                } else {
                    AbstractModule<Activity, Activity, Object> curModel = SimpleRNN$.MODULE$.apply(totalVocabLength, param.hiddenSize(), totalVocabLength);
                    curModel.reset();
                    abstractModule = model = curModel;
                }
                if (param.optimizerVersion().isDefined()) {
                    String string = ((String)param.optimizerVersion().get()).toLowerCase();
                    if ("optimizerv1".equals(string)) {
                        Engine$.MODULE$.setOptimizerVersion(OptimizerV1$.MODULE$);
                        BoxedUnit boxedUnit = BoxedUnit.UNIT;
                    } else if ("optimizerv2".equals(string)) {
                        Engine$.MODULE$.setOptimizerVersion(OptimizerV2$.MODULE$);
                        BoxedUnit boxedUnit = BoxedUnit.UNIT;
                    } else {
                        throw new MatchError((Object)string);
                    }
                }
                if (param.stateSnapshot().isDefined()) {
                    optimMethod = OptimMethod$.MODULE$.load((String)param.stateSnapshot().get(), ClassTag$.MODULE$.Float());
                } else {
                    double d = param.learningRate();
                    double d2 = param.weightDecay();
                    double d3 = param.momentum();
                    double d4 = param.dampening();
                    boolean bl = SGD$.MODULE$.$lessinit$greater$default$6();
                    SGD.LearningRateSchedule learningRateSchedule = SGD$.MODULE$.$lessinit$greater$default$7();
                    SGD$.MODULE$.$lessinit$greater$default$8();
                    SGD$.MODULE$.$lessinit$greater$default$9();
                    optimMethod = new SGD$mcF$sp(d, 0.0, d2, d3, d4, bl, learningRateSchedule, null, null, (ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                }
                OptimMethod<Object> optimMethod2 = optimMethod;
                CrossEntropyCriterion$.MODULE$.apply$default$1();
                Optimizer<Object, C> optimizer = Optimizer$.MODULE$.apply(model, trainSet, package$.MODULE$.convCriterion(TimeDistributedCriterion$.MODULE$.apply$mFc$sp(CrossEntropyCriterion$.MODULE$.apply$mFc$sp(null, CrossEntropyCriterion$.MODULE$.apply$default$2(), (ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), true, TimeDistributedCriterion$.MODULE$.apply$default$3(), (ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)), ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                Object object = param.checkpoint().isDefined() ? optimizer.setCheckpoint((String)param.checkpoint().get(), Trigger$.MODULE$.everyEpoch()) : BoxedUnit.UNIT;
                Object object2 = param.overWriteCheckpoint() ? optimizer.overWriteCheckpoint() : BoxedUnit.UNIT;
                Trigger trigger = Trigger$.MODULE$.everyEpoch();
                ValidationMethod[] validationMethodArray = new ValidationMethod[1];
                CrossEntropyCriterion$.MODULE$.apply$default$1();
                validationMethodArray[0] = new Loss$mcF$sp(package$.MODULE$.convCriterion(TimeDistributedCriterion$.MODULE$.apply$mFc$sp(CrossEntropyCriterion$.MODULE$.apply$mFc$sp(null, CrossEntropyCriterion$.MODULE$.apply$default$2(), (ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$), true, TimeDistributedCriterion$.MODULE$.apply$default$3(), (ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$)), (ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
                optimizer.setValidation(trigger, validationSet, (ValidationMethod[])((Object[])validationMethodArray)).setOptimMethod(optimMethod2).setEndWhen(Trigger$.MODULE$.maxEpoch(param.nEpochs())).setCheckpoint((String)param.checkpoint().get(), Trigger$.MODULE$.everyEpoch()).optimize();
                sc.stop();
            }
        });
    }

    private Train$() {
        MODULE$ = this;
        Configurator.setLevel("org", Level.ERROR);
        Configurator.setLevel("akka", Level.ERROR);
        Configurator.setLevel("breeze", Level.ERROR);
        this.logger = LogManager.getLogger(this.getClass());
    }
}

