/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.optim;

import com.intel.analytics.bigdl.dllib.feature.dataset.AbstractDataSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.DistributedDataSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.models.utils.ModelBroadcast;
import com.intel.analytics.bigdl.dllib.models.utils.ModelBroadcast$;
import com.intel.analytics.bigdl.dllib.nn.Container;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractCriterion;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.MklDnnContainer;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.Phase$TrainingPhase$;
import com.intel.analytics.bigdl.dllib.optim.AbstractOptimizer;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer;
import com.intel.analytics.bigdl.dllib.optim.Metrics;
import com.intel.analytics.bigdl.dllib.optim.OptimMethod;
import com.intel.analytics.bigdl.dllib.optim.Optimizer$;
import com.intel.analytics.bigdl.dllib.optim.ParallelOptimizer$;
import com.intel.analytics.bigdl.dllib.optim.ParallelOptimizer$$anonfun$optimize$5$;
import com.intel.analytics.bigdl.dllib.optim.Trigger;
import com.intel.analytics.bigdl.dllib.optim.ValidationMethod;
import com.intel.analytics.bigdl.dllib.optim.parameters.AllReduceParameter;
import com.intel.analytics.bigdl.dllib.tensor.ConvertableTo$ConvertableToDouble$;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.BlockManagerParameterSynchronizer;
import com.intel.analytics.bigdl.dllib.utils.DistriParameterSynchronizer;
import com.intel.analytics.bigdl.dllib.utils.Engine$;
import com.intel.analytics.bigdl.dllib.utils.EngineType;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.MklBlas$;
import com.intel.analytics.bigdl.dllib.utils.MklDnn$;
import com.intel.analytics.bigdl.dllib.utils.T$;
import com.intel.analytics.bigdl.dllib.utils.Table;
import com.intel.analytics.bigdl.dllib.utils.ThreadPool;
import com.intel.analytics.bigdl.dllib.utils.Util$;
import com.intel.analytics.bigdl.dllib.visualization.TrainSummary;
import com.intel.analytics.bigdl.dllib.visualization.ValidationSummary;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.StringContext;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
import scala.Tuple6;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Iterable$;
import scala.collection.immutable.Map$;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.Buffer;
import scala.collection.mutable.Buffer$;
import scala.collection.mutable.HashMap;
import scala.collection.mutable.Map;
import scala.collection.mutable.StringBuilder;
import scala.collection.mutable.WrappedArray;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.IntRef;
import scala.runtime.LongRef;
import scala.runtime.Nothing$;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;

public final class ParallelOptimizer$
extends AbstractOptimizer {
    /**
     * Singleton instance of this class.
     * NOTE(review): no assignment is visible in this chunk; in decompiled Scala-object
     * code it is presumably set by the constructor (invoked from the static initializer) — confirm.
     */
    public static final ParallelOptimizer$ MODULE$;
    /**
     * Logger returned by {@link #logger()}; {@code optimize} uses it for progress, timing,
     * warning and debug messages. Its initialization is not visible in this chunk.
     */
    private final Logger logger;

    /*
     * Static initializer: constructs one instance of this class when the class is initialized.
     * The reference is discarded here; NOTE(review): the constructor (outside this chunk)
     * presumably stores `this` into MODULE$, as is usual for a compiled Scala `object` — confirm.
     */
    static {
        new ParallelOptimizer$();
    }

    /**
     * Returns the log4j {@link Logger} this optimizer writes its progress and diagnostics to.
     *
     * @return the value of the {@code logger} field
     */
    public Logger logger() {
        return logger;
    }

    public <T> void optimize(AbstractModule<Activity, Activity, T> trainingModel, DistributedDataSet<MiniBatch<T>> dataset, int coresPerNode, Table state, Trigger endWhen, Metrics metrics, RDD<DistriOptimizer.Cache<T>> models, scala.collection.immutable.Map<String, OptimMethod<T>> optimMethods, Option<Trigger> validationTrigger, Option<AbstractDataSet<MiniBatch<T>, ?>> validationDataSet, Option<ValidationMethod<T>[]> validationMethods, Option<Trigger> cacheTrigger, Option<String> cachePath, Option<TrainSummary> trainSummary, Option<ValidationSummary> validationSummary, boolean isOverWrite, ClassTag<T> evidence$1, TensorNumericMath.TensorNumeric<T> ev) {
        EngineType engineType;
        block14: {
            Tuple3 tuple3;
            int n;
            long lastEpochTime;
            LongRef wallClockTime;
            int partitionNum;
            SparkContext sc;
            block13: {
                block12: {
                    sc = dataset.originRDD().sparkContext();
                    partitionNum = dataset.originRDD().partitions().length;
                    wallClockTime = LongRef.create((long)0L);
                    lastEpochTime = 0L;
                    optimMethods.values().foreach((Function1)new Serializable(){
                        public static final long serialVersionUID = 0L;

                        public final Object apply(OptimMethod<T> optimMethod) {
                            Object object = optimMethod.state().contains("epoch") ? BoxedUnit.UNIT : optimMethod.state().update("epoch", BoxesRunTime.boxToInteger((int)1));
                            Object object2 = optimMethod.state().contains("neval") ? BoxedUnit.UNIT : optimMethod.state().update("neval", BoxesRunTime.boxToInteger((int)1));
                            Object object3 = optimMethod.state().contains("Loss") ? BoxedUnit.UNIT : optimMethod.state().update("Loss", BoxesRunTime.boxToFloat((float)Float.POSITIVE_INFINITY));
                            Object object4 = optimMethod.state().contains("score") ? BoxedUnit.UNIT : optimMethod.state().update("score", BoxesRunTime.boxToFloat((float)0.0f));
                            return optimMethod.state().contains("recordsProcessedThisEpoch") ? BoxedUnit.UNIT : optimMethod.state().update("recordsProcessedThisEpoch", BoxesRunTime.boxToInteger((int)0));
                        }
                    });
                    engineType = Engine$.MODULE$.getEngineType();
                    if (!MklBlas$.MODULE$.equals(engineType)) break block12;
                    n = coresPerNode;
                    break block13;
                }
                if (!MklDnn$.MODULE$.equals(engineType)) break block14;
                n = 1;
            }
            int _subModelNumber = n;
            Log4Error$.MODULE$.invalidOperationError(_subModelNumber == 1, "currently only single model supported especially for mkldnn", Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
            Table driverState = T$.MODULE$.apply((Tuple2<Object, Object>)Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"epoch"), ((OptimMethod)optimMethods.values().head()).state().apply("epoch")), (Seq<Tuple2<Object, Object>>)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"neval"), ((OptimMethod)optimMethods.values().head()).state().apply("neval")), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"Loss"), ((OptimMethod)optimMethods.values().head()).state().apply("Loss")), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"score"), ((OptimMethod)optimMethods.values().head()).state().apply("score")), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"parallelism"), (Object)BoxesRunTime.boxToInteger((int)_subModelNumber))}));
            this.logger().info("Count dataset");
            long countBefore = System.nanoTime();
            int numSamples = BoxesRunTime.unboxToInt((Object)((RDD)dataset.data(false)).map((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final int apply(MiniBatch<T> x$1) {
                    return x$1.size();
                }
            }, ClassTag$.MODULE$.Int()).reduce((Function2)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final int apply(int x$2, int x$3) {
                    return this.apply$mcIII$sp(x$2, x$3);
                }

                public int apply$mcIII$sp(int x$2, int x$3) {
                    return x$2 + x$3;
                }
            }));
            long countAfter = System.nanoTime();
            this.logger().info(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Count dataset complete. Time elapsed: ", "s"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)((double)(countAfter - countBefore) / 1.0E9))})));
            if ((long)numSamples != dataset.size()) {
                this.logger().warn("If the dataset is built directly from RDD[Minibatch], the data in each minibatch is fixed, and a single minibatch is randomly selected in each partition. If the dataset is transformed from RDD[Sample], each minibatch will be constructed on the fly from random samples, which is better for convergence.");
            }
            this.logger().info(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"config ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{state})));
            IntRef recordsProcessedThisEpoch = IntRef.create((int)BoxesRunTime.unboxToInt(((OptimMethod)optimMethods.values().head()).state().apply("recordsProcessedThisEpoch")));
            if (recordsProcessedThisEpoch.elem == 0) {
                long shuffleBefore = System.nanoTime();
                this.logger().info("Shuffle data");
                dataset.shuffle();
                long shuffleEnd = System.nanoTime();
                this.logger().info(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Shuffle data complete. Takes ", "s"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)((double)(shuffleEnd - shuffleBefore) / 1.0E9))})));
            }
            ArrayBuffer tasks = new ArrayBuffer();
            LongRef threshold = LongRef.create((long)Long.MAX_VALUE);
            LongRef timeout = LongRef.create((long)Long.MAX_VALUE);
            IntRef iteration2 = IntRef.create((int)0);
            double dropPercentage = BoxesRunTime.unboxToDouble((Object)state.get("dropPercentage").get());
            int warmupIterationNum = BoxesRunTime.unboxToInt((Object)state.get("warmupIterationNum").get());
            int computeThresholdbatchSize = BoxesRunTime.unboxToInt((Object)state.get("computeThresholdbatchSize").get());
            double maxDropPercentage = BoxesRunTime.unboxToDouble((Object)state.get("maxDropPercentage").get());
            int iterationPerTime = new StringOps(Predef$.MODULE$.augmentString(System.getProperty("bigdl.parallelOptimizer.iterationPerTime", "1"))).toInt();
            int driverSubModelNum = partitionNum * _subModelNumber * iterationPerTime;
            int dropModelNumBatch = 0;
            ObjectRef lossArray = ObjectRef.create((Object)new double[_subModelNumber]);
            long epochStart = System.nanoTime();
            RDD dataRDD = (RDD)dataset.data(true);
            while (true) {
                Tuple3 tuple32;
                if (endWhen.apply(driverState)) {
                    return;
                }
                DoubleRef lossSum = DoubleRef.create((double)0.0);
                IntRef recordsNum = IntRef.create((int)0);
                metrics.set("computing time for each node", (ArrayBuffer<Object>)((ArrayBuffer)ArrayBuffer$.MODULE$.apply((Seq)Nil$.MODULE$)), sc);
                metrics.set("computing time average", 0.0, sc, partitionNum);
                Metrics driverMetrics = metrics;
                long start2 = System.nanoTime();
                tuple3 = (Tuple3)dataRDD.zipPartitions(models, true, (Function2)new Serializable(ev, wallClockTime, _subModelNumber, threshold, timeout, iteration2, dropPercentage, warmupIterationNum, computeThresholdbatchSize, iterationPerTime, lossArray, lossSum, recordsNum, driverMetrics, start2){
                    public static final long serialVersionUID = 0L;
                    public final TensorNumericMath.TensorNumeric ev$1;
                    private final LongRef wallClockTime$1;
                    private final int _subModelNumber$1;
                    private final LongRef threshold$1;
                    private final LongRef timeout$1;
                    private final IntRef iteration$1;
                    private final double dropPercentage$1;
                    private final int warmupIterationNum$1;
                    private final int computeThresholdbatchSize$1;
                    private final int iterationPerTime$1;
                    public final ObjectRef lossArray$1;
                    private final DoubleRef lossSum$1;
                    private final IntRef recordsNum$1;
                    private final Metrics driverMetrics$1;
                    private final long start$1;

                    public final Iterator<Tuple3<Object, Object, Object>> apply(Iterator<MiniBatch<T>> data2, Iterator<DistriOptimizer.Cache<T>> modelIter) {
                        int finishedThreadSize = 0;
                        DistriOptimizer.Cache cached = (DistriOptimizer.Cache)modelIter.next();
                        ObjectRef miniBatch = ObjectRef.create(null);
                        for (int count2 = 0; count2 < this.iterationPerTime$1; ++count2) {
                            long syWStart = System.nanoTime();
                            miniBatch.elem = (MiniBatch)data2.next();
                            long time = System.nanoTime();
                            if (this.dropPercentage$1 > 0.0 && this.iteration$1.elem > this.warmupIterationNum$1 + this.computeThresholdbatchSize$1 - 1) {
                                this.timeout$1.elem = this.threshold$1.elem;
                            }
                            int pre = this.iteration$1.elem % this.computeThresholdbatchSize$1 * this._subModelNumber$1;
                            ThreadPool qual$1 = Engine$.MODULE$.default();
                            Seq x$18 = (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Function0[]{new Serializable(this, cached, miniBatch, pre){
                                public static final long serialVersionUID = 0L;
                                private final /* synthetic */ anonfun.4 $outer;
                                private final DistriOptimizer.Cache cached$1;
                                private final ObjectRef miniBatch$1;
                                private final int pre$1;

                                public final int apply() {
                                    return this.apply$mcI$sp();
                                }

                                public int apply$mcI$sp() {
                                    long trainStart = System.nanoTime();
                                    AbstractModule<Activity, Activity, T> localModel = this.cached$1.localModels()[0];
                                    localModel.training();
                                    AbstractCriterion<Activity, Activity, T> localCriterion = this.cached$1.localCriterions()[0];
                                    Activity input = ((MiniBatch)this.miniBatch$1.elem).getInput();
                                    Activity target = ((MiniBatch)this.miniBatch$1.elem).getTarget();
                                    Activity output = localModel.forward(input);
                                    ((double[])this.$outer.lossArray$1.elem)[0] = BoxesRunTime.unboxToDouble((Object)this.$outer.ev$1.toType(localCriterion.forward(output, target), ConvertableTo$ConvertableToDouble$.MODULE$));
                                    Activity errors = localCriterion.backward(output, target);
                                    localModel.backward(input, errors);
                                    this.cached$1.moduleTimeList()[0 + this.pre$1] = System.nanoTime() - trainStart;
                                    return 0;
                                }
                                {
                                    if ($outer == null) {
                                        throw null;
                                    }
                                    this.$outer = $outer;
                                    this.cached$1 = cached$1;
                                    this.miniBatch$1 = miniBatch$1;
                                    this.pre$1 = pre$1;
                                }
                            }}));
                            long x$19 = this.timeout$1.elem;
                            TimeUnit x$20 = qual$1.invokeAndWait2$default$3();
                            Buffer<Future<T>> trainingThreads = qual$1.invokeAndWait2(x$18, x$19, x$20);
                            long computingTime = System.nanoTime() - time;
                            this.driverMetrics$1.add("computing time average", computingTime);
                            this.driverMetrics$1.add("computing time for each node", computingTime);
                            Buffer finishedThreads = (Buffer)((TraversableLike)trainingThreads.filter((Function1)new Serializable(this){
                                public static final long serialVersionUID = 0L;

                                public final boolean apply(Future<Object> x$4) {
                                    return !x$4.isCancelled();
                                }
                            })).map((Function1)new Serializable(this){
                                public static final long serialVersionUID = 0L;

                                public final int apply(Future<Object> x$5) {
                                    return BoxesRunTime.unboxToInt((Object)x$5.get());
                                }
                            }, Buffer$.MODULE$.canBuildFrom());
                            int currFinishedSize = finishedThreads.size();
                            finishedThreadSize += currFinishedSize;
                            this.recordsNum$1.elem += currFinishedSize * ((MiniBatch)miniBatch.elem).size();
                            for (int i = 0; i < currFinishedSize; ++i) {
                                this.lossSum$1.elem += ((double[])this.lossArray$1.elem)[BoxesRunTime.unboxToInt((Object)finishedThreads.apply(i))];
                            }
                        }
                        long end = System.nanoTime();
                        this.wallClockTime$1.elem += end - this.start$1;
                        return scala.package$.MODULE$.Iterator().single((Object)new Tuple3((Object)BoxesRunTime.boxToInteger((int)finishedThreadSize), (Object)BoxesRunTime.boxToDouble((double)this.lossSum$1.elem), (Object)BoxesRunTime.boxToInteger((int)this.recordsNum$1.elem)));
                    }
                    {
                        this.ev$1 = ev$1;
                        this.wallClockTime$1 = wallClockTime$1;
                        this._subModelNumber$1 = _subModelNumber$1;
                        this.threshold$1 = threshold$1;
                        this.timeout$1 = timeout$1;
                        this.iteration$1 = iteration$1;
                        this.dropPercentage$1 = dropPercentage$1;
                        this.warmupIterationNum$1 = warmupIterationNum$1;
                        this.computeThresholdbatchSize$1 = computeThresholdbatchSize$1;
                        this.iterationPerTime$1 = iterationPerTime$1;
                        this.lossArray$1 = lossArray$1;
                        this.lossSum$1 = lossSum$1;
                        this.recordsNum$1 = recordsNum$1;
                        this.driverMetrics$1 = driverMetrics$1;
                        this.start$1 = start$1;
                    }
                }, ClassTag$.MODULE$.apply(DistriOptimizer.Cache.class), ClassTag$.MODULE$.apply(Tuple3.class)).reduce((Function2)new Serializable(){
                    public static final long serialVersionUID = 0L;

                    public final Tuple3<Object, Object, Object> apply(Tuple3<Object, Object, Object> a, Tuple3<Object, Object, Object> b) {
                        return new Tuple3((Object)BoxesRunTime.boxToInteger((int)(BoxesRunTime.unboxToInt((Object)a._1()) + BoxesRunTime.unboxToInt((Object)b._1()))), (Object)BoxesRunTime.boxToDouble((double)(BoxesRunTime.unboxToDouble((Object)a._2()) + BoxesRunTime.unboxToDouble((Object)b._2()))), (Object)BoxesRunTime.boxToInteger((int)(BoxesRunTime.unboxToInt((Object)a._3()) + BoxesRunTime.unboxToInt((Object)b._3()))));
                    }
                });
                if (tuple3 == null) break;
                int numFinishedModelUpdates = BoxesRunTime.unboxToInt((Object)tuple3._1());
                double localLossSum = BoxesRunTime.unboxToDouble((Object)tuple3._2());
                int localRecordsNum = BoxesRunTime.unboxToInt((Object)tuple3._3());
                Tuple3 tuple33 = tuple32 = new Tuple3((Object)BoxesRunTime.boxToInteger((int)numFinishedModelUpdates), (Object)BoxesRunTime.boxToDouble((double)localLossSum), (Object)BoxesRunTime.boxToInteger((int)localRecordsNum));
                int numFinishedModelUpdates2 = BoxesRunTime.unboxToInt((Object)tuple33._1());
                double localLossSum2 = BoxesRunTime.unboxToDouble((Object)tuple33._2());
                int localRecordsNum2 = BoxesRunTime.unboxToInt((Object)tuple33._3());
                dropModelNumBatch += driverSubModelNum - numFinishedModelUpdates2;
                if (dropPercentage == 0.0 || (double)numFinishedModelUpdates2 >= (double)driverSubModelNum * (1.0 - maxDropPercentage)) {
                    Object object;
                    driverState.update("numFinishedModel", BoxesRunTime.boxToInteger((int)numFinishedModelUpdates2));
                    recordsProcessedThisEpoch.elem += localRecordsNum2;
                    long end = System.nanoTime();
                    wallClockTime.elem += end - start2;
                    driverState.update("Loss", BoxesRunTime.boxToDouble((double)(localLossSum2 / (double)numFinishedModelUpdates2)));
                    optimMethods.foreach((Function1)new Serializable(){
                        public static final long serialVersionUID = 0L;

                        public final void apply(Tuple2<String, OptimMethod<T>> v) {
                            ((OptimMethod)v._2()).updateHyperParameter();
                        }
                    });
                    driverState.update(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"LearningRate"})).s((Seq)Nil$.MODULE$), BoxesRunTime.boxToFloat((float)((float)((OptimMethod)((Tuple2)optimMethods.head())._2()).getLearningRate())));
                    driverState.update("Throughput", BoxesRunTime.boxToFloat((float)((float)localRecordsNum2 / (float)((double)(end - start2) / 1.0E9))));
                    String _header = Optimizer$.MODULE$.header(BoxesRunTime.unboxToInt(driverState.apply("epoch")), recordsProcessedThisEpoch.elem, numSamples, BoxesRunTime.unboxToInt(driverState.apply("neval")), wallClockTime.elem);
                    this.logger().info(new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"", " Trained ", " records in ", " "})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{_header, BoxesRunTime.boxToInteger((int)localRecordsNum2), BoxesRunTime.boxToDouble((double)((double)(end - start2) / 1.0E9))}))).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"seconds. Throughput is ", " records/second. Loss is ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{driverState.apply("Throughput"), driverState.apply("Loss")}))).toString());
                    this.logger().debug(new StringBuilder().append((Object)"\n").append((Object)metrics.summary(metrics.summary$default$1(), metrics.summary$default$2())).toString());
                    this.logger().debug(new StringBuilder().append((Object)"Dropped modules: ").append((Object)BoxesRunTime.boxToInteger((int)(driverSubModelNum - numFinishedModelUpdates2))).toString());
                    lossArray.elem = new double[_subModelNumber];
                    ++iteration2.elem;
                    if (dropPercentage > 0.0 && iteration2.elem > warmupIterationNum && iteration2.elem % computeThresholdbatchSize == 0) {
                        long[] moduleTimeList = (long[])models.mapPartitions((Function1)new Serializable(){
                            public static final long serialVersionUID = 0L;

                            public final Iterator<Object> apply(Iterator<DistriOptimizer.Cache<T>> iter) {
                                return Predef$.MODULE$.longArrayOps(((DistriOptimizer.Cache)iter.next()).moduleTimeList()).iterator();
                            }
                        }, models.mapPartitions$default$2(), ClassTag$.MODULE$.Long()).collect();
                        int k = (int)(dropPercentage * (double)computeThresholdbatchSize * (double)driverSubModelNum);
                        threshold.elem = k > dropModelNumBatch ? Util$.MODULE$.kthLargest(moduleTimeList, 0, moduleTimeList.length - 1, k - dropModelNumBatch) : (long)((double)threshold.elem * 1.01);
                        this.logger().info(new StringBuilder().append((Object)"threshold: ").append((Object)BoxesRunTime.boxToLong((long)threshold.elem)).toString());
                        models.mapPartitions((Function1)new Serializable(){
                            public static final long serialVersionUID = 0L;

                            public final Iterator<Nothing$> apply(Iterator<DistriOptimizer.Cache<T>> iter) {
                                long[] timeList = ((DistriOptimizer.Cache)iter.next()).moduleTimeList();
                                for (int i = 0; i < timeList.length; ++i) {
                                    timeList[i] = 0L;
                                }
                                return scala.package$.MODULE$.Iterator().empty();
                            }
                        }, models.mapPartitions$default$2(), evidence$1).count();
                        dropModelNumBatch = 0;
                    }
                    driverState.update("neval", BoxesRunTime.boxToInteger((int)(BoxesRunTime.unboxToInt(driverState.apply("neval")) + iterationPerTime)));
                    if (recordsProcessedThisEpoch.elem >= numSamples) {
                        long epochEnd = System.nanoTime();
                        lastEpochTime = wallClockTime.elem = lastEpochTime + epochEnd - epochStart;
                        epochStart = System.nanoTime();
                        this.logger().info(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"", " Epoch finished. Wall clock time is ", " ms"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{_header, BoxesRunTime.boxToDouble((double)((double)wallClockTime.elem / 1000000.0))})));
                        driverState.update("epoch", BoxesRunTime.boxToInteger((int)(BoxesRunTime.unboxToInt(driverState.apply("epoch")) + 1)));
                        dataset.shuffle();
                        dataRDD = (RDD)dataset.data(true);
                        recordsProcessedThisEpoch.elem = 0;
                    }
                    optimMethods.map((Function1)new Serializable(validationMethods, driverState, recordsProcessedThisEpoch){
                        public static final long serialVersionUID = 0L;
                        private final Option validationMethods$1;
                        private final Table driverState$1;
                        private final IntRef recordsProcessedThisEpoch$1;

                        public final Object apply(Tuple2<String, OptimMethod<T>> x0$1) {
                            Tuple2<String, OptimMethod<T>> tuple2 = x0$1;
                            if (tuple2 != null) {
                                OptimMethod optimMethod = (OptimMethod)tuple2._2();
                                optimMethod.state().update("recordsProcessedThisEpoch", BoxesRunTime.boxToInteger((int)this.recordsProcessedThisEpoch$1.elem));
                                optimMethod.state().update("epoch", this.driverState$1.apply("epoch"));
                                optimMethod.state().update("neval", this.driverState$1.apply("neval"));
                                optimMethod.state().update("Loss", this.driverState$1.apply("Loss"));
                                BoxedUnit boxedUnit = this.validationMethods$1.isDefined() ? optimMethod.state().update("score", this.driverState$1.apply("score")) : BoxedUnit.UNIT;
                                return boxedUnit;
                            }
                            throw new MatchError(tuple2);
                        }
                        {
                            this.validationMethods$1 = validationMethods$1;
                            this.driverState$1 = driverState$1;
                            this.recordsProcessedThisEpoch$1 = recordsProcessedThisEpoch$1;
                        }
                    }, Iterable$.MODULE$.canBuildFrom());
                    if (endWhen.apply(driverState)) {
                        this.logger().info(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"training finished, updating all layers parameters"})).s((Seq)Nil$.MODULE$));
                        object = models.mapPartitions((Function1)new Serializable(evidence$1){
                            public static final long serialVersionUID = 0L;
                            public final ClassTag evidence$1$1;

                            public final Iterator<Nothing$> apply(Iterator<DistriOptimizer.Cache<T>> modelIter) {
                                AbstractModule<Activity, Activity, T>[] localModels = ((DistriOptimizer.Cache)modelIter.next()).localModels();
                                Function0[] updateTaskes = (Function0[])Predef$.MODULE$.refArrayOps((Object[])localModels).map((Function1)new Serializable(this){
                                    public static final long serialVersionUID = 0L;
                                    private final /* synthetic */ anonfun.optimize.5 $outer;

                                    public final Function0<BoxedUnit> apply(AbstractModule<Activity, Activity, T> localModel) {
                                        return new Serializable(this, localModel){
                                            public static final long serialVersionUID = 0L;
                                            private final /* synthetic */ anonfun$optimize$5$$anonfun$9 $outer;
                                            private final AbstractModule localModel$1;

                                            public final void apply() {
                                                this.apply$mcV$sp();
                                            }

                                            public void apply$mcV$sp() {
                                                ParallelOptimizer$.MODULE$.com$intel$analytics$bigdl$dllib$optim$ParallelOptimizer$$updateLayerParameters(this.localModel$1, this.$outer.com$intel$analytics$bigdl$dllib$optim$ParallelOptimizer$$anonfun$$anonfun$$$outer().evidence$1$1);
                                            }
                                            {
                                                if ($outer == null) {
                                                    throw null;
                                                }
                                                this.$outer = $outer;
                                                this.localModel$1 = localModel$1;
                                            }
                                        };
                                    }

                                    public /* synthetic */ anonfun.optimize.5 com$intel$analytics$bigdl$dllib$optim$ParallelOptimizer$$anonfun$$anonfun$$$outer() {
                                        return this.$outer;
                                    }
                                    {
                                        if ($outer == null) {
                                            throw null;
                                        }
                                        this.$outer = $outer;
                                    }
                                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Function0.class)));
                                ThreadPool qual$2 = Engine$.MODULE$.default();
                                WrappedArray x$21 = Predef$.MODULE$.wrapRefArray((Object[])updateTaskes);
                                long x$22 = qual$2.invokeAndWait2$default$2();
                                TimeUnit x$23 = qual$2.invokeAndWait2$default$3();
                                qual$2.invokeAndWait2(x$21, x$22, x$23);
                                return scala.package$.MODULE$.Iterator().empty();
                            }
                            {
                                this.evidence$1$1 = evidence$1$1;
                            }
                        }, models.mapPartitions$default$2(), evidence$1).collect();
                    } else {
                        object = BoxedUnit.UNIT;
                    }
                    this.validate$default$9();
                    this.validate(validationTrigger, validationDataSet, validationMethods, coresPerNode, models, driverState, validationSummary, _header, null);
                    trainSummary.foreach((Function1)new Serializable(trainingModel, models, evidence$1, ev, driverState){
                        public static final long serialVersionUID = 0L;
                        private final AbstractModule trainingModel$1;
                        private final RDD models$1;
                        private final ClassTag evidence$1$1;
                        private final TensorNumericMath.TensorNumeric ev$1;
                        private final Table driverState$1;

                        public final void apply(TrainSummary summary2) {
                            ParallelOptimizer$.MODULE$.saveSummary(summary2, this.models$1, this.driverState$1, null, this.trainingModel$1, this.evidence$1$1, this.ev$1);
                        }
                        {
                            this.trainingModel$1 = trainingModel$1;
                            this.models$1 = models$1;
                            this.evidence$1$1 = evidence$1$1;
                            this.ev$1 = ev$1;
                            this.driverState$1 = driverState$1;
                        }
                    });
                    this.checkpoint(cacheTrigger, cachePath, isOverWrite, wallClockTime.elem, models, driverState, null, optimMethods, trainingModel, evidence$1, ev);
                    continue;
                }
                this.logger().info(new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Warning! Not enough training samples were successfully processed in this "})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"iteration due to some slow tasks. The gradients computed in this iteration will be "})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"discarded. Only ", "/", " threads successfully "})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)numFinishedModelUpdates2), BoxesRunTime.boxToInteger((int)driverSubModelNum)}))).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"completed training."})).s((Seq)Nil$.MODULE$)).toString());
            }
            throw new MatchError((Object)tuple3);
        }
        throw new MatchError((Object)engineType);
    }

    /**
     * Calls {@code updateParameter()} on {@code module} and then, when it is a
     * {@link Container}, recurses into each child in the order the container lists them.
     * The parent is always updated before any of its children.
     *
     * @param module     root of the module tree to update
     * @param evidence$2 class tag of the element type, forwarded to the recursive calls
     */
    public <T> void com$intel$analytics$bigdl$dllib$optim$ParallelOptimizer$$updateLayerParameters(AbstractModule<Activity, Activity, T> module, ClassTag<T> evidence$2) {
        module.updateParameter();
        if (!(module instanceof Container)) {
            return;
        }
        ArrayBuffer children = ((Container)module).modules();
        for (int idx = 0, count = children.length(); idx < count; ++idx) {
            AbstractModule<Activity, Activity, T> child = (AbstractModule<Activity, Activity, T>)children.apply(idx);
            this.com$intel$analytics$bigdl$dllib$optim$ParallelOptimizer$$updateLayerParameters(child, evidence$2);
        }
    }

    /**
     * Builds, persists and materializes one {@code DistriOptimizer.CacheV1} per partition of
     * {@code dataset.originRDD()}. Each cache holds per-thread replicas of {@code model} and
     * their per-replica training state.
     *
     * <p>On the driver, the criterion, state table, validation methods and optim methods are
     * broadcast together as a single {@code Tuple4}. The model is broadcast through
     * {@code ModelBroadcast}. The number of replicas per partition depends on the engine type:
     * {@code coresPerNode} for {@code MklBlas} and 1 for {@code MklDnn}. Any other engine type
     * ends in a {@code MatchError}. The RDD's partition count must equal {@code nodeNumber}.
     *
     * <p>The returned RDD is persisted, named "Thread Model RDD", and forced with {@code count()}.
     *
     * @param model             model to replicate on every partition
     * @param dataset           training data; only {@code originRDD()} (its partitioning and
     *                          SparkContext) is used here
     * @param criterion         loss; cloned per replica via {@code cloneCriterion()}
     * @param state             state table; must contain "computeThresholdbatchSize"; cloned per replica
     * @param nodeNumber        expected number of RDD partitions
     * @param coresPerNode      replica count under MklBlas; also passed to {@code Engine.setNodeAndCore}
     * @param checkSingleton    if true, a failed {@code Engine.checkSingleton()} is an error;
     *                          otherwise it is only logged as a warning
     * @param validationMethods optional validation methods; cloned per replica
     * @param optimMethod       optim methods by name; cloned once per partition
     * @param priorities        not referenced in this method body
     * @param evidence$3        class tag of the element type
     * @param ev                tensor numeric for the element type
     * @return the persisted RDD with one cache per partition
     */
    public <T> RDD<DistriOptimizer.CacheV1<T>> com$intel$analytics$bigdl$dllib$optim$ParallelOptimizer$$initThreadModels(AbstractModule<Activity, Activity, T> model, DistributedDataSet<MiniBatch<T>> dataset, AbstractCriterion<Activity, Activity, T> criterion, Table state, int nodeNumber, int coresPerNode, boolean checkSingleton, Option<ValidationMethod<T>[]> validationMethods, scala.collection.immutable.Map<String, OptimMethod<T>> optimMethod, Map<String, Object> priorities, ClassTag<T> evidence$3, TensorNumericMath.TensorNumeric<T> ev) {
        EngineType engineType;
        // Decompiled pattern match on the engine type, expressed as labeled blocks:
        // MklBlas -> n = coresPerNode, MklDnn -> n = 1. Any other value breaks out of
        // block4 and reaches the MatchError at the end of the method.
        block4: {
            int n;
            ModelBroadcast<T> modelBroadcast;
            Broadcast broadcast;
            block3: {
                block2: {
                    SparkContext sc = dataset.originRDD().sparkContext();
                    // Criterion, state, validation methods and optim methods are shipped as one broadcast tuple.
                    broadcast = sc.broadcast((Object)new Tuple4(criterion, (Object)state, validationMethods, optimMethod), ClassTag$.MODULE$.apply(Tuple4.class));
                    modelBroadcast = ModelBroadcast$.MODULE$.apply(evidence$3, ev).broadcast(sc, model);
                    // Result discarded. NOTE(review): presumably called for its side effect on the
                    // model's parameter storage; confirm against AbstractModule.getParameters.
                    model.getParameters();
                    engineType = Engine$.MODULE$.getEngineType();
                    if (!MklBlas$.MODULE$.equals(engineType)) break block2;
                    n = coresPerNode;
                    break block3;
                }
                if (!MklDnn$.MODULE$.equals(engineType)) break block4;
                n = 1;
            }
            int _subModelNumber = n;
            // Each partition is expected to correspond to one node.
            Log4Error$.MODULE$.invalidOperationError(dataset.originRDD().partitions().length == nodeNumber, new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Passed in rdd partition number ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)dataset.originRDD().partitions().length)}))).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{" is not equal to configured node number ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)nodeNumber)}))).toString(), Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
            // Must be present in the state table (the Option is unwrapped with get()).
            int computeThresholdbatchSize = BoxesRunTime.unboxToInt((Object)state.get("computeThresholdbatchSize").get());
            int nExecutor = Engine$.MODULE$.nodeNumber();
            // Number of slices used to group layers for parameter synchronization; default "10".
            int parameterBlocks = new StringOps(Predef$.MODULE$.augmentString(System.getProperty("bigdl.parallelOptimizer.parameterBlocks", "10"))).toInt();
            RDD<?> qual$3 = dataset.originRDD();
            // Executor-side closure: builds the CacheV1 for one partition.
            Serializable x$24 = new Serializable(coresPerNode, checkSingleton, evidence$3, ev, broadcast, modelBroadcast, _subModelNumber, computeThresholdbatchSize, nExecutor, parameterBlocks){
                public static final long serialVersionUID = 0L;
                private final int coresPerNode$1;
                private final boolean checkSingleton$1;
                public final ClassTag evidence$3$1;
                public final TensorNumericMath.TensorNumeric ev$2;
                private final Broadcast broadcast$1;
                public final ModelBroadcast modelBroadcast$1;
                private final int _subModelNumber$2;
                private final int computeThresholdbatchSize$2;
                private final int nExecutor$1;
                public final int parameterBlocks$1;

                public final Iterator<DistriOptimizer.CacheV1<T>> apply(Iterator<Object> x$7) {
                    // Used both as the synchronizer's partition id and as every replica's module id.
                    int partitionId = TaskContext$.MODULE$.getPartitionId();
                    Tuple4 tuple4 = (Tuple4)this.broadcast$1.value();
                    if (tuple4 != null) {
                        Tuple4 tuple42;
                        AbstractCriterion broadcastCriterion = (AbstractCriterion)tuple4._1();
                        Table broadcastState = (Table)tuple4._2();
                        Option broadcastMethod = (Option)tuple4._3();
                        scala.collection.immutable.Map broadcastOptim = (scala.collection.immutable.Map)tuple4._4();
                        // Decompiler artifact: the tuple is rebuilt and destructured again.
                        Tuple4 tuple43 = tuple42 = new Tuple4((Object)broadcastCriterion, (Object)broadcastState, (Object)broadcastMethod, (Object)broadcastOptim);
                        AbstractCriterion broadcastCriterion2 = (AbstractCriterion)tuple43._1();
                        Table broadcastState2 = (Table)tuple43._2();
                        Option broadcastMethod2 = (Option)tuple43._3();
                        scala.collection.immutable.Map broadcastOptim2 = (scala.collection.immutable.Map)tuple43._4();
                        // A failed singleton check is fatal only when checkSingleton was requested.
                        if (!Engine$.MODULE$.checkSingleton()) {
                            if (this.checkSingleton$1) {
                                Log4Error$.MODULE$.invalidOperationError(Engine$.MODULE$.checkSingleton(), "Partitions of the training data are not evenly distributed across the executors in the Spark cluster; are there sufficient trainingdata to be distributed? Set property \"bigdl.check.singleton\" to false to skip this check", Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
                            } else {
                                ParallelOptimizer$.MODULE$.logger().warn("Partitions of the training data are not evenlydistributed across the executors in the Spark cluster; are there sufficient trainingdata to be distributed?");
                            }
                        }
                        Engine$.MODULE$.setNodeAndCore(this.nExecutor$1, this.coresPerNode$1);
                        // One synchronizer per partition, shared by all replicas created below.
                        BlockManagerParameterSynchronizer<T> synchronizer = new BlockManagerParameterSynchronizer<T>(partitionId, this.nExecutor$1, this.evidence$3$1, this.ev$2);
                        // Build _subModelNumber replicas, each as (model, placeholder, placeholder, criterion, state, validation methods).
                        Tuple6[] cached = (Tuple6[])((TraversableOnce)RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), this._subModelNumber$2).map((Function1)new Serializable(this, partitionId, broadcastCriterion2, broadcastState2, broadcastMethod2, synchronizer){
                            public static final long serialVersionUID = 0L;
                            private final /* synthetic */ anonfun.10 $outer;
                            private final int partitionId$1;
                            private final AbstractCriterion broadcastCriterion$1;
                            private final Table broadcastState$1;
                            private final Option broadcastMethod$1;
                            private final BlockManagerParameterSynchronizer synchronizer$1;

                            public final Tuple6<AbstractModule<Activity, Activity, T>, Tensor<T>, Tensor<T>, AbstractCriterion<Activity, Activity, T>, Table, Option<ValidationMethod<T>[]>> apply(int x$9) {
                                // NOTE(review): the meaning of the (true, false) flags is defined by ModelBroadcast.value, which is not visible here.
                                AbstractModule<Activity, Activity, T> localModel = this.$outer.modelBroadcast$1.value(true, false);
                                AbstractModule<Activity, Activity, T> abstractModule = localModel;
                                // MKL-DNN containers are compiled for the training phase before use.
                                if (abstractModule instanceof MklDnnContainer) {
                                    MklDnnContainer mklDnnContainer = (MklDnnContainer)((Object)abstractModule);
                                    mklDnnContainer.compile(Phase$TrainingPhase$.MODULE$);
                                    BoxedUnit boxedUnit = BoxedUnit.UNIT;
                                } else {
                                    BoxedUnit boxedUnit = BoxedUnit.UNIT;
                                }
                                ParallelOptimizer$.MODULE$.com$intel$analytics$bigdl$dllib$optim$ParallelOptimizer$$setModelId(localModel, this.partitionId$1, this.$outer.evidence$3$1);
                                // A fresh barrier map per replica, so each replica registers its own layers with the shared synchronizer.
                                ParallelOptimizer$.MODULE$.com$intel$analytics$bigdl$dllib$optim$ParallelOptimizer$$setDistriPartitionsynchronizer(localModel, this.synchronizer$1, (Map<Object, Object>)new HashMap(), this.$outer.parameterBlocks$1, this.$outer.evidence$3$1);
                                AbstractCriterion<A, B, T> localCriterion = this.broadcastCriterion$1.cloneCriterion();
                                Table localState = this.broadcastState$1.clone();
                                None$ localMethod = this.broadcastMethod$1.isDefined() ? new Some(Predef$.MODULE$.refArrayOps((Object[])this.broadcastMethod$1.get()).map((Function1)new Serializable(this){
                                    public static final long serialVersionUID = 0L;

                                    public final ValidationMethod<T> apply(ValidationMethod<T> x$10) {
                                        return x$10.clone();
                                    }
                                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ValidationMethod.class)))) : None$.MODULE$;
                                // Slots 2 and 3 are Tensor(0) placeholders; presumably empty tensors, not the model's weights/grads.
                                return new Tuple6(localModel, Tensor$.MODULE$.apply(0, this.$outer.evidence$3$1, this.$outer.ev$2), Tensor$.MODULE$.apply(0, this.$outer.evidence$3$1, this.$outer.ev$2), localCriterion, (Object)localState, (Object)localMethod);
                            }
                            {
                                if ($outer == null) {
                                    throw null;
                                }
                                this.$outer = $outer;
                                this.partitionId$1 = partitionId$1;
                                this.broadcastCriterion$1 = broadcastCriterion$1;
                                this.broadcastState$1 = broadcastState$1;
                                this.broadcastMethod$1 = broadcastMethod$1;
                                this.synchronizer$1 = synchronizer$1;
                            }
                        }, IndexedSeq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(Tuple6.class));
                        ParallelOptimizer$.MODULE$.logger().info(new StringBuilder().append((Object)"model thread pool size is ").append((Object)BoxesRunTime.boxToInteger((int)Engine$.MODULE$.model().getPoolSize())).toString());
                        // Not used afterwards.
                        Tensor weights = (Tensor)((Tuple6)Predef$.MODULE$.refArrayOps((Object[])cached).head())._2();
                        // Transpose the replica tuples into per-column arrays for CacheV1. Also passed:
                        // a long[] of length _subModelNumber * computeThresholdbatchSize (purpose defined by
                        // CacheV1, not visible here), one clone of every optim method for this partition,
                        // and the partition's synchronizer.
                        return scala.package$.MODULE$.Iterator().single(new DistriOptimizer.CacheV1<T>((AbstractModule[])Predef$.MODULE$.refArrayOps((Object[])cached).map((Function1)new Serializable(this){
                            public static final long serialVersionUID = 0L;

                            public final AbstractModule<Activity, Activity, T> apply(Tuple6<AbstractModule<Activity, Activity, T>, Tensor<T>, Tensor<T>, AbstractCriterion<Activity, Activity, T>, Table, Option<ValidationMethod<T>[]>> x$11) {
                                return (AbstractModule)x$11._1();
                            }
                        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(AbstractModule.class))), (Tensor[])Predef$.MODULE$.refArrayOps((Object[])cached).map((Function1)new Serializable(this){
                            public static final long serialVersionUID = 0L;

                            public final Tensor<T> apply(Tuple6<AbstractModule<Activity, Activity, T>, Tensor<T>, Tensor<T>, AbstractCriterion<Activity, Activity, T>, Table, Option<ValidationMethod<T>[]>> x$12) {
                                return (Tensor)x$12._2();
                            }
                        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tensor.class))), (Tensor[])Predef$.MODULE$.refArrayOps((Object[])cached).map((Function1)new Serializable(this){
                            public static final long serialVersionUID = 0L;

                            public final Tensor<T> apply(Tuple6<AbstractModule<Activity, Activity, T>, Tensor<T>, Tensor<T>, AbstractCriterion<Activity, Activity, T>, Table, Option<ValidationMethod<T>[]>> x$13) {
                                return (Tensor)x$13._3();
                            }
                        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tensor.class))), (AbstractCriterion[])Predef$.MODULE$.refArrayOps((Object[])cached).map((Function1)new Serializable(this){
                            public static final long serialVersionUID = 0L;

                            public final AbstractCriterion<Activity, Activity, T> apply(Tuple6<AbstractModule<Activity, Activity, T>, Tensor<T>, Tensor<T>, AbstractCriterion<Activity, Activity, T>, Table, Option<ValidationMethod<T>[]>> x$14) {
                                return (AbstractCriterion)x$14._4();
                            }
                        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(AbstractCriterion.class))), (Table[])Predef$.MODULE$.refArrayOps((Object[])cached).map((Function1)new Serializable(this){
                            public static final long serialVersionUID = 0L;

                            public final Table apply(Tuple6<AbstractModule<Activity, Activity, T>, Tensor<T>, Tensor<T>, AbstractCriterion<Activity, Activity, T>, Table, Option<ValidationMethod<T>[]>> x$15) {
                                return (Table)x$15._5();
                            }
                        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Table.class))), new long[this._subModelNumber$2 * this.computeThresholdbatchSize$2], (Option[])Predef$.MODULE$.refArrayOps((Object[])cached).map((Function1)new Serializable(this){
                            public static final long serialVersionUID = 0L;

                            public final Option<ValidationMethod<T>[]> apply(Tuple6<AbstractModule<Activity, Activity, T>, Tensor<T>, Tensor<T>, AbstractCriterion<Activity, Activity, T>, Table, Option<ValidationMethod<T>[]>> x$16) {
                                return (Option)x$16._6();
                            }
                        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Option.class))), (scala.collection.immutable.Map)broadcastOptim2.map((Function1)new Serializable(this){
                            public static final long serialVersionUID = 0L;

                            public final Tuple2<String, OptimMethod<T>> apply(Tuple2<String, OptimMethod<T>> v) {
                                return new Tuple2(v._1(), ((OptimMethod)v._2()).clone());
                            }
                        }, Map$.MODULE$.canBuildFrom()), synchronizer));
                    }
                    throw new MatchError((Object)tuple4);
                }
                {
                    this.coresPerNode$1 = coresPerNode$1;
                    this.checkSingleton$1 = checkSingleton$1;
                    this.evidence$3$1 = evidence$3$1;
                    this.ev$2 = ev$2;
                    this.broadcast$1 = broadcast$1;
                    this.modelBroadcast$1 = modelBroadcast$1;
                    this._subModelNumber$2 = _subModelNumber$2;
                    this.computeThresholdbatchSize$2 = computeThresholdbatchSize$2;
                    this.nExecutor$1 = nExecutor$1;
                    this.parameterBlocks$1 = parameterBlocks$1;
                }
            };
            boolean x$25 = qual$3.mapPartitions$default$2();
            RDD models = qual$3.mapPartitions((Function1)x$24, x$25, ClassTag$.MODULE$.apply(DistriOptimizer.CacheV1.class)).persist();
            models.setName("Thread Model RDD");
            this.logger().info("Cache thread models...");
            // count() forces every partition to build and cache its replicas now.
            models.count();
            this.logger().info("Cache thread models... done");
            return models;
        }
        throw new MatchError((Object)engineType);
    }

    /*
     * WARNING - void declaration
     */
    public <T> ArrayBuffer<AbstractModule<Activity, Activity, T>> com$intel$analytics$bigdl$dllib$optim$ParallelOptimizer$$getExecutionOrder(AbstractModule<Activity, Activity, T> module, ClassTag<T> evidence$4) {
        void var3_3;
        Object object;
        ArrayBuffer res = new ArrayBuffer();
        if (module instanceof Container) {
            ArrayBuffer subModules = ((Container)module).modules();
            subModules.foreach((Function1)new Serializable(evidence$4, res){
                public static final long serialVersionUID = 0L;
                private final ClassTag evidence$4$1;
                private final ArrayBuffer res$1;

                public final ArrayBuffer<AbstractModule<Activity, Activity, T>> apply(AbstractModule<Activity, Activity, T> sub2) {
                    return this.res$1.$plus$plus$eq(ParallelOptimizer$.MODULE$.com$intel$analytics$bigdl$dllib$optim$ParallelOptimizer$$getExecutionOrder(sub2, this.evidence$4$1));
                }
                {
                    this.evidence$4$1 = evidence$4$1;
                    this.res$1 = res$1;
                }
            });
            object = BoxedUnit.UNIT;
        } else {
            object = module.parameters() == null ? BoxedUnit.UNIT : res.$plus$eq(module);
        }
        return var3_3;
    }

    /**
     * Registers parameter-owning layers of {@code model} with {@code parameterSynchronizer},
     * bucketing them into at most {@code slices} slices of the model's flattened parameter vector.
     *
     * <p>Layers from {@code getExecutionOrder(model)} are walked from last to first. A layer's
     * slice index is derived from the storage offset of its {@code getParameters()._1()} tensor.
     * Only the first layer encountered for each slice index is registered; later ones with the
     * same index, or with an index {@code >= slices}, are skipped. Each registered layer gets the
     * span {@code [offSet, lastOffSet)} of the global weights and gradients, where
     * {@code lastOffSet} is the offset of the previously registered layer (initially
     * {@code totalSize}). Each registered layer is then given the synchronizer via
     * {@code setParameterSynchronizer}.
     *
     * @param model                 model whose layers are registered
     * @param parameterSynchronizer synchronizer to initialise per registered layer
     * @param barrierLayers         mutable map filled with slice index -> offset (boxed ints); indices
     *                              already present are not registered again
     * @param slices                maximum number of slices
     * @param evidence$5            class tag of the element type
     */
    public <T> void com$intel$analytics$bigdl$dllib$optim$ParallelOptimizer$$setDistriPartitionsynchronizer(AbstractModule<Activity, Activity, T> model, DistriParameterSynchronizer<T> parameterSynchronizer, Map<Object, Object> barrierLayers, int slices, ClassTag<T> evidence$5) {
        Tensor globalWeights = (Tensor)model.getParameters()._1();
        Tensor globalGrads = (Tensor)model.getParameters()._2();
        int totalSize = globalGrads.nElement();
        ArrayBuffer<AbstractModule<Activity, Activity, T>> executorOrders = this.com$intel$analytics$bigdl$dllib$optim$ParallelOptimizer$$getExecutionOrder(model, evidence$5);
        // NOTE(review): size is 0 when totalSize / slices == 1 and negative when totalSize < slices.
        // In those cases the (offSet - 1) / size below divides by zero or yields negative indices.
        // Confirm callers guarantee totalSize >= 2 * slices.
        int size = totalSize / slices - 1;
        // Computed but never used.
        int extraSize = totalSize - size * (slices - 1);
        int lastOffSet = totalSize;
        for (int i = executorOrders.length() - 1; i >= 0; --i) {
            AbstractModule currModule = (AbstractModule)executorOrders.apply(i);
            if (currModule.parameters() == null) continue;
            // Despite the name, this is getParameters()._1(); only its storage offset is used.
            Tensor grads = (Tensor)currModule.getParameters()._1();
            // 0-based offset of this layer within the shared parameter storage.
            int offSet = grads.storageOffset() - 1;
            int index = offSet == 0 ? 0 : (offSet - 1) / size + 1;
            // Span from this layer up to the previously registered layer; this includes any layers skipped in between.
            int currParSize = lastOffSet - offSet;
            if (index >= slices || barrierLayers.contains((Object)BoxesRunTime.boxToInteger((int)index))) continue;
            barrierLayers.put((Object)BoxesRunTime.boxToInteger((int)index), (Object)BoxesRunTime.boxToInteger((int)offSet));
            // narrow() is 1-based, hence offSet + 1.
            Tensor weightsPar = globalWeights.narrow(1, offSet + 1, currParSize);
            Tensor gradsPar = globalGrads.narrow(1, offSet + 1, currParSize);
            // Third argument grows as i shrinks. NOTE(review): presumably a priority/ordering key; confirm in DistriParameterSynchronizer.init.
            parameterSynchronizer.init(currModule.getName(), currParSize, executorOrders.length() - i, weightsPar, gradsPar);
            currModule.setParameterSynchronizer(parameterSynchronizer);
            lastOffSet = offSet;
        }
    }

    /**
     * Calls {@code setId(partitionId)} on {@code model} and, when it is a {@link Container},
     * on every module nested beneath it. Parents are visited before their children, and
     * children in list order.
     *
     * @param model       root of the module tree to label
     * @param partitionId id assigned to every visited module
     * @param evidence$6  class tag of the element type, forwarded to the recursive calls
     */
    public <T> void com$intel$analytics$bigdl$dllib$optim$ParallelOptimizer$$setModelId(AbstractModule<Activity, Activity, T> model, int partitionId, ClassTag<T> evidence$6) {
        model.setId(partitionId);
        if (!(model instanceof Container)) {
            return;
        }
        ArrayBuffer children = ((Container)model).modules();
        final int count = children.length();
        int idx = 0;
        while (idx < count) {
            AbstractModule<Activity, Activity, T> child = (AbstractModule<Activity, Activity, T>)children.apply(idx);
            this.com$intel$analytics$bigdl$dllib$optim$ParallelOptimizer$$setModelId(child, partitionId, evidence$6);
            ++idx;
        }
    }

    /**
     * Reassembles the trained weights from the cached per-partition models into
     * {@code trainingModel} and returns it.
     *
     * <p>Steps:
     * <ol>
     *   <li>Copy the extra parameters ({@code getExtraParameter()}) of the first local model of
     *       the first partition into {@code trainingModel}.</li>
     *   <li>Resize every gradient tensor to match its weight tensor.</li>
     *   <li>Split the flattened parameter storage into {@code partitionNum} contiguous chunks.
     *       The first {@code size % partitionNum} chunks get one extra element.</li>
     *   <li>Have each partition copy its chunk out of its first local model's weights, keyed by
     *       the synchronizer's partition id.</li>
     *   <li>Merge those chunks on the driver and copy each one back into {@code trainingModel}'s
     *       parameters.</li>
     * </ol>
     *
     * @param models        cached per-partition models
     * @param parameters2   not referenced in this method body
     * @param trainingModel driver-side model to receive the weights
     * @param evidence$7    class tag of the element type
     * @param ev            tensor numeric for the element type
     * @return {@code trainingModel}, updated in place
     */
    @Override
    public <T> AbstractModule<Activity, Activity, T> getModel(RDD<DistriOptimizer.Cache<T>> models, AllReduceParameter<T> parameters2, AbstractModule<Activity, Activity, T> trainingModel, ClassTag<T> evidence$7, TensorNumericMath.TensorNumeric<T> ev) {
        int partitionNum = models.partitions().length;
        Tensor[] extraState = (Tensor[])models.map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Tensor<T>[] apply(DistriOptimizer.Cache<T> x$17) {
                return ((AbstractModule)Predef$.MODULE$.refArrayOps((Object[])x$17.localModels()).head()).getExtraParameter();
            }
        }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Tensor.class))).first();
        trainingModel.setExtraParameter(extraState);
        // Make each gradient tensor the same size as its weight tensor.
        Tuple2<Tensor<T>[], Tensor<T>[]> parameterArray = trainingModel.parameters();
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), ((Tensor[])parameterArray._2()).length).foreach((Function1)new Serializable(parameterArray){
            public static final long serialVersionUID = 0L;
            private final Tuple2 parameterArray$1;

            public final Tensor<T> apply(int i) {
                return ((Tensor[])this.parameterArray$1._2())[i].resizeAs(((Tensor[])this.parameterArray$1._1())[i]);
            }
            {
                this.parameterArray$1 = parameterArray$1;
            }
        });
        Tensor parameter = (Tensor)trainingModel.getParameters()._1();
        // Not used below.
        ClassTag _classTag = package$.MODULE$.classTag(evidence$7);
        // Chunk sizes come from the whole backing storage array, not from parameter.nElement().
        int size = ScalaRunTime$.MODULE$.array_length(parameter.storage().array());
        int taskSize = size / partitionNum;
        int extraSize = size % partitionNum;
        scala.collection.immutable.Map weights = (scala.collection.immutable.Map)models.mapPartitions((Function1)new Serializable(evidence$7, ev, taskSize, extraSize){
            public static final long serialVersionUID = 0L;
            private final ClassTag evidence$7$1;
            private final TensorNumericMath.TensorNumeric ev$3;
            private final int taskSize$1;
            private final int extraSize$1;

            public final Iterator<scala.collection.immutable.Map<Object, Tensor<T>>> apply(Iterator<DistriOptimizer.Cache<T>> iter) {
                // Only the first cache of the partition is read (presumably exactly one per partition).
                DistriOptimizer.Cache localCache = (DistriOptimizer.Cache)iter.next();
                AbstractModule<Activity, Activity, T>[] localModels = localCache.localModels();
                Tensor localWeights = (Tensor)((AbstractModule)Predef$.MODULE$.refArrayOps((Object[])localModels).head()).getParameters()._1();
                BlockManagerParameterSynchronizer synchronizer = (BlockManagerParameterSynchronizer)localCache.parameterSynchronizer();
                // NOTE(review): assumes partitionID() ranges over 0..partitionNum-1, matching the driver-side loop below.
                int partitionId = synchronizer.partitionID();
                int start2 = partitionId * this.taskSize$1 + scala.math.package$.MODULE$.min(partitionId, this.extraSize$1);
                int length = this.taskSize$1 + (partitionId < this.extraSize$1 ? 1 : 0);
                Tensor<T> partitionWeight = Tensor$.MODULE$.apply(length, this.evidence$7$1, this.ev$3);
                // 0-based start2 -> 1-based narrow index; the offset is relative to localWeights itself.
                partitionWeight.copy(localWeights.narrow(1, start2 + 1, length));
                return scala.package$.MODULE$.Iterator().single((Object)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)BoxesRunTime.boxToInteger((int)partitionId)), partitionWeight)})));
            }
            {
                this.evidence$7$1 = evidence$7$1;
                this.ev$3 = ev$3;
                this.taskSize$1 = taskSize$1;
                this.extraSize$1 = extraSize$1;
            }
        }, models.mapPartitions$default$2(), ClassTag$.MODULE$.apply(scala.collection.immutable.Map.class)).reduce((Function2)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final scala.collection.immutable.Map<Object, Tensor<T>> apply(scala.collection.immutable.Map<Object, Tensor<T>> a, scala.collection.immutable.Map<Object, Tensor<T>> b) {
                return a.$plus$plus(b);
            }
        });
        // map() is used only for its copy side effects; the resulting collection is discarded.
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), partitionNum).map((Function1)new Serializable(parameter, taskSize, extraSize, weights){
            public static final long serialVersionUID = 0L;
            private final Tensor parameter$1;
            private final int taskSize$1;
            private final int extraSize$1;
            private final scala.collection.immutable.Map weights$1;

            public final Tensor<T> apply(int pid) {
                // NOTE(review): this adds storageOffset() but the executor side above does not.
                // The two agree only when parameter.storageOffset() == 1; confirm getParameters() guarantees that.
                int start2 = this.parameter$1.storageOffset() + pid * this.taskSize$1 + scala.math.package$.MODULE$.min(pid, this.extraSize$1);
                int length = this.taskSize$1 + (pid < this.extraSize$1 ? 1 : 0);
                return this.parameter$1.narrow(1, start2, length).copy((Tensor)this.weights$1.apply((Object)BoxesRunTime.boxToInteger((int)pid)));
            }
            {
                this.parameter$1 = parameter$1;
                this.taskSize$1 = taskSize$1;
                this.extraSize$1 = extraSize$1;
                this.weights$1 = weights$1;
            }
        }, IndexedSeq$.MODULE$.canBuildFrom());
        return trainingModel;
    }

    /**
     * Private constructor of the singleton object. It stores this instance in {@code MODULE$}
     * and initialises {@code logger} from {@code LogManager} for this instance's runtime class.
     */
    private ParallelOptimizer$() {
        MODULE$ = this;
        this.logger = LogManager.getLogger(this.getClass());
    }
}

