/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.optim;

import com.intel.analytics.bigdl.dllib.feature.dataset.AbstractDataSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.DistributedDataSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.models.utils.ModelBroadcast;
import com.intel.analytics.bigdl.dllib.models.utils.ModelBroadcast$;
import com.intel.analytics.bigdl.dllib.nn.Container;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractCriterion;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.DnnGraph;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.MklDnnContainer;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.Phase$TrainingPhase$;
import com.intel.analytics.bigdl.dllib.optim.AbstractOptimizer;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$$anonfun$11$$anonfun$12$;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$$anonfun$4$;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$$anonfun$4$$anonfun$5$;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$$anonfun$4$$anonfun$5$$anonfun$apply$2$;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$$anonfun$optimize$4$;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$CacheV1$;
import com.intel.analytics.bigdl.dllib.optim.Metrics;
import com.intel.analytics.bigdl.dllib.optim.OptimMethod;
import com.intel.analytics.bigdl.dllib.optim.Optimizer$;
import com.intel.analytics.bigdl.dllib.optim.Trigger;
import com.intel.analytics.bigdl.dllib.optim.ValidationMethod;
import com.intel.analytics.bigdl.dllib.optim.parameters.AllReduceParameter;
import com.intel.analytics.bigdl.dllib.optim.parameters.FutureResult;
import com.intel.analytics.bigdl.dllib.optim.parameters.ParameterProcessor;
import com.intel.analytics.bigdl.dllib.tensor.ConvertableFrom$ConvertableFromDouble$;
import com.intel.analytics.bigdl.dllib.tensor.ConvertableTo$ConvertableToDouble$;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Engine$;
import com.intel.analytics.bigdl.dllib.utils.EngineType;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.MklBlas$;
import com.intel.analytics.bigdl.dllib.utils.MklDnn$;
import com.intel.analytics.bigdl.dllib.utils.T$;
import com.intel.analytics.bigdl.dllib.utils.Table;
import com.intel.analytics.bigdl.dllib.utils.ThreadPool;
import com.intel.analytics.bigdl.dllib.utils.Util$;
import com.intel.analytics.bigdl.dllib.utils.intermediate.ConversionUtils$;
import com.intel.analytics.bigdl.dllib.utils.intermediate.IRGraph;
import com.intel.analytics.bigdl.dllib.visualization.TrainSummary;
import com.intel.analytics.bigdl.dllib.visualization.ValidationSummary;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import org.apache.spark.util.DoubleAccumulator;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.StringContext;
import scala.Tuple2;
import scala.Tuple4;
import scala.Tuple6;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterator;
import scala.collection.Iterator$;
import scala.collection.Seq;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Iterable$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.MapLike;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.Buffer;
import scala.collection.mutable.Buffer$;
import scala.collection.mutable.StringBuilder;
import scala.concurrent.duration.Duration;
import scala.math.package$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.LongRef;
import scala.runtime.Nothing$;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;

public final class DistriOptimizer$
extends AbstractOptimizer {
    /**
     * The single shared instance of this module class. It is not assigned in the visible code;
     * NOTE(review): presumably the constructor (outside this view) sets it — confirm there.
     */
    public static final DistriOptimizer$ MODULE$;
    /**
     * Logger the optimization loop writes progress ("Count dataset", "Shuffle data") and
     * batch-size warnings to. It is initialized outside this view (presumably in the constructor).
     */
    private final Logger logger;

    // Class initializer: eagerly constructs the singleton when the class is first loaded.
    // The constructed instance is discarded here, so the constructor is expected to publish
    // itself via MODULE$ — TODO confirm against the constructor.
    static {
        new DistriOptimizer$();
    }

    /**
     * Exposes the logger used by the distributed optimizer.
     *
     * <p>The training loop calls this to emit informational messages (dataset counting,
     * shuffling, config dump). Worker-side closures also reach it through {@code MODULE$} to
     * warn about small batch sizes.
     *
     * @return the {@link Logger} stored in this instance's {@code logger} field
     */
    public Logger logger() {
        return logger;
    }

    public <T> void optimize(AbstractModule<Activity, Activity, T> trainingModel, DistributedDataSet<MiniBatch<T>> dataset, int coresPerNode, Table state, Trigger endWhen, Metrics metrics, RDD<DistriOptimizer.CacheV1<T>> models, Map<String, OptimMethod<T>> optimMethods, AllReduceParameter<T> parameters2, Map<String, Tuple2<Object, Object>> parameterSplits, Option<Trigger> validationTrigger, Option<AbstractDataSet<MiniBatch<T>, ?>> validationDataSet, Option<ValidationMethod<T>[]> validationMethods, Option<Trigger> cacheTrigger, Option<String> cachePath, Option<TrainSummary> trainSummary, Option<ValidationSummary> validationSummary, boolean isOverWrite, ParameterProcessor[] parameterProcessers, ClassTag<T> evidence$1, TensorNumericMath.TensorNumeric<T> ev) {
        EngineType engineType;
        block12: {
            int n;
            long lastEpochTime;
            long wallClockTime;
            int partitionNum;
            SparkContext sc;
            block11: {
                block10: {
                    sc = dataset.originRDD().sparkContext();
                    partitionNum = dataset.originRDD().partitions().length;
                    wallClockTime = 0L;
                    lastEpochTime = 0L;
                    optimMethods.values().foreach((Function1)new Serializable(){
                        public static final long serialVersionUID = 0L;

                        public final Object apply(OptimMethod<T> optimMethod) {
                            Object object = optimMethod.state().contains("epoch") ? BoxedUnit.UNIT : optimMethod.state().update("epoch", BoxesRunTime.boxToInteger((int)1));
                            Object object2 = optimMethod.state().contains("neval") ? BoxedUnit.UNIT : optimMethod.state().update("neval", BoxesRunTime.boxToInteger((int)1));
                            Object object3 = optimMethod.state().contains("Loss") ? BoxedUnit.UNIT : optimMethod.state().update("Loss", BoxesRunTime.boxToFloat((float)Float.POSITIVE_INFINITY));
                            Object object4 = optimMethod.state().contains("score") ? BoxedUnit.UNIT : optimMethod.state().update("score", BoxesRunTime.boxToFloat((float)0.0f));
                            return optimMethod.state().contains("recordsProcessedThisEpoch") ? BoxedUnit.UNIT : optimMethod.state().update("recordsProcessedThisEpoch", BoxesRunTime.boxToInteger((int)0));
                        }
                    });
                    engineType = Engine$.MODULE$.getEngineType();
                    if (!MklBlas$.MODULE$.equals(engineType)) break block10;
                    n = coresPerNode;
                    break block11;
                }
                if (!MklDnn$.MODULE$.equals(engineType)) break block12;
                n = 1;
            }
            int _subModelNumber = n;
            Table driverState = T$.MODULE$.apply((Tuple2<Object, Object>)Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"epoch"), ((OptimMethod)optimMethods.values().head()).state().apply("epoch")), (Seq<Tuple2<Object, Object>>)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"neval"), ((OptimMethod)optimMethods.values().head()).state().apply("neval")), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"Loss"), ((OptimMethod)optimMethods.values().head()).state().apply("Loss")), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"score"), ((OptimMethod)optimMethods.values().head()).state().apply("score")), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"parallelism"), (Object)BoxesRunTime.boxToInteger((int)_subModelNumber))}));
            this.logger().info("Count dataset");
            long countBefore = System.nanoTime();
            int numSamples = BoxesRunTime.unboxToInt((Object)((RDD)dataset.data(false)).map((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final int apply(MiniBatch<T> x$1) {
                    return x$1.size();
                }
            }, ClassTag$.MODULE$.Int()).reduce((Function2)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final int apply(int x$2, int x$3) {
                    return this.apply$mcIII$sp(x$2, x$3);
                }

                public int apply$mcIII$sp(int x$2, int x$3) {
                    return x$2 + x$3;
                }
            }));
            long countAfter = System.nanoTime();
            this.logger().info(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Count dataset complete. Time elapsed: ", "s"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)((double)(countAfter - countBefore) / 1.0E9))})));
            if ((long)numSamples != dataset.size()) {
                this.logger().warn("If the dataset is built directly from RDD[Minibatch], the data in each minibatch is fixed, and a single minibatch is randomly selected in each partition. If the dataset is transformed from RDD[Sample], each minibatch will be constructed on the fly from random samples, which is better for convergence.");
            }
            this.logger().info(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"config ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{state})));
            IntRef recordsProcessedThisEpoch = IntRef.create((int)BoxesRunTime.unboxToInt(((OptimMethod)optimMethods.values().head()).state().apply("recordsProcessedThisEpoch")));
            if (recordsProcessedThisEpoch.elem == 0) {
                long shuffleBefore = System.nanoTime();
                this.logger().info("Shuffle data");
                dataset.shuffle();
                long shuffleEnd = System.nanoTime();
                this.logger().info(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Shuffle data complete. Takes ", "s"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)((double)(shuffleEnd - shuffleBefore) / 1.0E9))})));
            }
            ObjectRef tasks = ObjectRef.create((Object)new ArrayBuffer());
            LongRef threshold = LongRef.create((long)Long.MAX_VALUE);
            LongRef timeout = LongRef.create((long)Long.MAX_VALUE);
            IntRef iteration2 = IntRef.create((int)0);
            double dropPercentage = BoxesRunTime.unboxToDouble((Object)state.get("dropPercentage").get());
            int warmupIterationNum = BoxesRunTime.unboxToInt((Object)state.get("warmupIterationNum").get());
            int computeThresholdbatchSize = BoxesRunTime.unboxToInt((Object)state.get("computeThresholdbatchSize").get());
            double maxDropPercentage = BoxesRunTime.unboxToDouble((Object)state.get("maxDropPercentage").get());
            int driverSubModelNum = partitionNum * _subModelNumber;
            int dropModelNumBatch = 0;
            ObjectRef lossArray = ObjectRef.create((Object)new double[_subModelNumber]);
            long epochStart = System.nanoTime();
            RDD dataRDD = (RDD)dataset.data(true);
            while (true) {
                if (endWhen.apply(driverState)) {
                    return;
                }
                DoubleAccumulator lossSum = sc.doubleAccumulator("loss sum");
                DoubleAccumulator recordsNum = sc.doubleAccumulator("record number");
                metrics.set("computing time for each node", (ArrayBuffer<Object>)((ArrayBuffer)ArrayBuffer$.MODULE$.apply((Seq)Nil$.MODULE$)), sc);
                metrics.set("get weights for each node", (ArrayBuffer<Object>)((ArrayBuffer)ArrayBuffer$.MODULE$.apply((Seq)Nil$.MODULE$)), sc);
                metrics.set("computing time average", 0.0, sc, partitionNum);
                metrics.set("aggregate gradient time", 0.0, sc, partitionNum);
                metrics.set("get weights average", 0.0, sc, partitionNum);
                metrics.set("put gradient", 0.0, sc, Engine$.MODULE$.nodeNumber());
                metrics.set("aggregrateGradientParition average executor", 0.0, sc, Engine$.MODULE$.nodeNumber());
                metrics.set("compute weight average", 0.0, sc, Engine$.MODULE$.nodeNumber());
                metrics.set("send weights average", 0.0, sc, Engine$.MODULE$.nodeNumber());
                Metrics driverMetrics = metrics;
                long start2 = System.nanoTime();
                int numFinishedModelUpdates = BoxesRunTime.unboxToInt((Object)dataRDD.zipPartitions(models, true, (Function2)new Serializable(parameters2, ev, _subModelNumber, tasks, threshold, timeout, iteration2, dropPercentage, warmupIterationNum, computeThresholdbatchSize, lossArray, lossSum, recordsNum, driverMetrics){
                    public static final long serialVersionUID = 0L;
                    private final AllReduceParameter parameters$1;
                    public final TensorNumericMath.TensorNumeric ev$2;
                    public final int _subModelNumber$1;
                    private final ObjectRef tasks$1;
                    private final LongRef threshold$1;
                    private final LongRef timeout$1;
                    private final IntRef iteration$1;
                    private final double dropPercentage$1;
                    private final int warmupIterationNum$1;
                    private final int computeThresholdbatchSize$1;
                    public final ObjectRef lossArray$1;
                    private final DoubleAccumulator lossSum$1;
                    private final DoubleAccumulator recordsNum$1;
                    private final Metrics driverMetrics$1;

                    public final Iterator<Object> apply(Iterator<MiniBatch<T>> data2, Iterator<DistriOptimizer.CacheV1<T>> modelIter) {
                        Metrics metrics;
                        DistriOptimizer.CacheV1 cached = (DistriOptimizer.CacheV1)modelIter.next();
                        long syWStart = System.nanoTime();
                        FutureResult<Object> weightsResults = this.parameters$1.getWeights(((Tensor)Predef$.MODULE$.refArrayOps((Object[])cached.modelWeights()).head()).narrow(1, this.parameters$1.paramOffset(), this.parameters$1.size()));
                        MiniBatch[] miniBatchBuffer = new MiniBatch[this._subModelNumber$1];
                        MiniBatch batch = (MiniBatch)data2.next();
                        int stackSize = batch.size() / this._subModelNumber$1;
                        ((ArrayBuffer)this.tasks$1.elem).$plus$eq(Engine$.MODULE$.default().invoke(new Serializable(this, miniBatchBuffer, batch, stackSize){
                            public static final long serialVersionUID = 0L;
                            private final /* synthetic */ anonfun.4 $outer;
                            private final MiniBatch[] miniBatchBuffer$1;
                            private final MiniBatch batch$1;
                            private final int stackSize$1;

                            public final void apply() {
                                this.apply$mcV$sp();
                            }

                            public void apply$mcV$sp() {
                                Log4Error$.MODULE$.invalidOperationError(this.batch$1.size() >= this.$outer._subModelNumber$1 && this.batch$1.size() % this.$outer._subModelNumber$1 == 0, new StringBuilder().append((Object)"total batch size: ").append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"", " should be divided by total core number: ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.batch$1.size()), BoxesRunTime.boxToInteger((int)this.$outer._subModelNumber$1)}))).toString(), Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
                                if (this.batch$1.size() < this.$outer._subModelNumber$1 * 2) {
                                    DistriOptimizer$.MODULE$.logger().warn(new StringBuilder().append((Object)"Warning: for better training speed, total batch size is recommended to be at least two times of core number").append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"", ", please tune your batch size accordingly"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.$outer._subModelNumber$1)}))).toString());
                                }
                                for (int b = 0; b < this.$outer._subModelNumber$1; ++b) {
                                    this.miniBatchBuffer$1[b] = this.batch$1.slice(b * this.stackSize$1 + 1, this.stackSize$1);
                                }
                            }
                            {
                                if ($outer == null) {
                                    throw null;
                                }
                                this.$outer = $outer;
                                this.miniBatchBuffer$1 = miniBatchBuffer$1;
                                this.batch$1 = batch$1;
                                this.stackSize$1 = stackSize$1;
                            }
                        }));
                        ThreadPool qual$1 = Engine$.MODULE$.default();
                        ArrayBuffer x$31 = (ArrayBuffer)this.tasks$1.elem;
                        Duration x$32 = qual$1.sync$default$2();
                        qual$1.sync((Seq<scala.concurrent.Future<?>>)x$31, x$32);
                        weightsResults.waitResult();
                        long weightSyncTime = System.nanoTime() - syWStart;
                        this.driverMetrics$1.add("get weights average", weightSyncTime);
                        this.driverMetrics$1.add("get weights for each node", weightSyncTime);
                        ((ArrayBuffer)this.tasks$1.elem).clear();
                        long time = System.nanoTime();
                        if (this.dropPercentage$1 > 0.0 && this.iteration$1.elem > this.warmupIterationNum$1 + this.computeThresholdbatchSize$1 - 1) {
                            this.timeout$1.elem = this.threshold$1.elem - weightSyncTime;
                        }
                        int pre = this.iteration$1.elem % this.computeThresholdbatchSize$1 * this._subModelNumber$1;
                        ThreadPool qual$2 = Engine$.MODULE$.default();
                        IndexedSeq x$33 = (IndexedSeq)RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), this._subModelNumber$1).map((Function1)new Serializable(this, cached, miniBatchBuffer, weightSyncTime, pre){
                            public static final long serialVersionUID = 0L;
                            private final /* synthetic */ anonfun.4 $outer;
                            public final DistriOptimizer.CacheV1 cached$1;
                            public final MiniBatch[] miniBatchBuffer$1;
                            public final long weightSyncTime$1;
                            public final int pre$1;

                            public final Function0<Object> apply(int i) {
                                return new Serializable(this, i){
                                    public static final long serialVersionUID = 0L;
                                    private final /* synthetic */ anonfun$4$$anonfun$5 $outer;
                                    public final int i$1;

                                    public final int apply() {
                                        return this.apply$mcI$sp();
                                    }

                                    public int apply$mcI$sp() {
                                        Object object;
                                        long trainStart = System.nanoTime();
                                        AbstractModule<Activity, Activity, T> localModel = this.$outer.cached$1.localModels()[this.i$1];
                                        localModel.training();
                                        AbstractCriterion<Activity, Activity, T> localCriterion = this.$outer.cached$1.localCriterions()[this.i$1];
                                        Activity input = this.$outer.miniBatchBuffer$1[this.i$1].getInput();
                                        Activity target = this.$outer.miniBatchBuffer$1[this.i$1].getTarget();
                                        EngineType engineType = Engine$.MODULE$.getEngineType();
                                        MklBlas$ mklBlas$ = MklBlas$.MODULE$;
                                        if (!(engineType != null ? !engineType.equals(mklBlas$) : mklBlas$ != null)) {
                                            Activity output = localModel.forward(input);
                                            ((double[])this.$outer.com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$$outer().lossArray$1.elem)[this.i$1] = BoxesRunTime.unboxToDouble((Object)this.$outer.com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$$outer().ev$2.toType(localCriterion.forward(output, target), ConvertableTo$ConvertableToDouble$.MODULE$));
                                            Activity errors = localCriterion.backward(output, target);
                                            object = localModel.backward(input, errors);
                                        } else if (localModel instanceof IRGraph) {
                                            Activity output = localModel.forward(input);
                                            Engine$.MODULE$.dnnComputing().invokeAndWait2(Predef$.MODULE$.wrapRefArray((Object[])Predef$.MODULE$.intArrayOps(new int[]{0}).map((Function1)new Serializable(this, localCriterion, target, output){
                                                public static final long serialVersionUID = 0L;
                                                private final /* synthetic */ anonfun$4$$anonfun$5$$anonfun$apply$2 $outer;
                                                public final AbstractCriterion localCriterion$1;
                                                public final Activity target$1;
                                                public final Activity output$1;

                                                public final Function0<Activity> apply(int x$4) {
                                                    return new Serializable(this){
                                                        public static final long serialVersionUID = 0L;
                                                        private final /* synthetic */ anonfun$4$$anonfun$5$$anonfun$apply$2$$anonfun$apply$mcI$sp$1 $outer;

                                                        public final Activity apply() {
                                                            ((double[])this.$outer.com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$anonfun$$anonfun$$$outer().com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$anonfun$$$outer().com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$$outer().lossArray$1.elem)[this.$outer.com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$anonfun$$anonfun$$$outer().i$1] = BoxesRunTime.unboxToDouble((Object)this.$outer.com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$anonfun$$anonfun$$$outer().com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$anonfun$$$outer().com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$$outer().ev$2.toType(this.$outer.localCriterion$1.forward(this.$outer.output$1, this.$outer.target$1), ConvertableTo$ConvertableToDouble$.MODULE$));
                                                            return this.$outer.localCriterion$1.backward(this.$outer.output$1, this.$outer.target$1);
                                                        }
                                                        {
                                                            if ($outer == null) {
                                                                throw null;
                                                            }
                                                            this.$outer = $outer;
                                                        }
                                                    };
                                                }

                                                public /* synthetic */ anonfun$4$$anonfun$5$$anonfun$apply$2 com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$anonfun$$anonfun$$$outer() {
                                                    return this.$outer;
                                                }
                                                {
                                                    if ($outer == null) {
                                                        throw null;
                                                    }
                                                    this.$outer = $outer;
                                                    this.localCriterion$1 = localCriterion$1;
                                                    this.target$1 = target$1;
                                                    this.output$1 = output$1;
                                                }
                                            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Function0.class)))), Engine$.MODULE$.dnnComputing().invokeAndWait2$default$2(), Engine$.MODULE$.dnnComputing().invokeAndWait2$default$3());
                                            object = localModel.backward(input, localCriterion.gradInput());
                                        } else {
                                            object = Engine$.MODULE$.dnnComputing().invokeAndWait2(Predef$.MODULE$.wrapRefArray((Object[])Predef$.MODULE$.intArrayOps(new int[]{0}).map((Function1)new Serializable(this, localModel, localCriterion, input, target){
                                                public static final long serialVersionUID = 0L;
                                                private final /* synthetic */ anonfun$4$$anonfun$5$$anonfun$apply$2 $outer;
                                                public final AbstractModule localModel$1;
                                                public final AbstractCriterion localCriterion$1;
                                                public final Activity input$1;
                                                public final Activity target$1;

                                                public final Function0<Activity> apply(int x$5) {
                                                    return new Serializable(this){
                                                        public static final long serialVersionUID = 0L;
                                                        private final /* synthetic */ anonfun$4$$anonfun$5$$anonfun$apply$2$$anonfun$apply$mcI$sp$2 $outer;

                                                        public final Activity apply() {
                                                            B output = this.$outer.localModel$1.forward(this.$outer.input$1);
                                                            ((double[])this.$outer.com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$anonfun$$anonfun$$$outer().com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$anonfun$$$outer().com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$$outer().lossArray$1.elem)[this.$outer.com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$anonfun$$anonfun$$$outer().i$1] = BoxesRunTime.unboxToDouble((Object)this.$outer.com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$anonfun$$anonfun$$$outer().com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$anonfun$$$outer().com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$$outer().ev$2.toType(this.$outer.localCriterion$1.forward(output, this.$outer.target$1), ConvertableTo$ConvertableToDouble$.MODULE$));
                                                            B errors = this.$outer.localCriterion$1.backward(output, this.$outer.target$1);
                                                            return this.$outer.localModel$1.backward(this.$outer.input$1, errors);
                                                        }
                                                        {
                                                            if ($outer == null) {
                                                                throw null;
                                                            }
                                                            this.$outer = $outer;
                                                        }
                                                    };
                                                }

                                                public /* synthetic */ anonfun$4$$anonfun$5$$anonfun$apply$2 com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$anonfun$$anonfun$$$outer() {
                                                    return this.$outer;
                                                }
                                                {
                                                    if ($outer == null) {
                                                        throw null;
                                                    }
                                                    this.$outer = $outer;
                                                    this.localModel$1 = localModel$1;
                                                    this.localCriterion$1 = localCriterion$1;
                                                    this.input$1 = input$1;
                                                    this.target$1 = target$1;
                                                }
                                            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Function0.class)))), Engine$.MODULE$.dnnComputing().invokeAndWait2$default$2(), Engine$.MODULE$.dnnComputing().invokeAndWait2$default$3());
                                        }
                                        this.$outer.cached$1.moduleTimeList()[this.i$1 + this.$outer.pre$1] = System.nanoTime() - trainStart + this.$outer.weightSyncTime$1;
                                        return this.i$1;
                                    }

                                    public /* synthetic */ anonfun$4$$anonfun$5 com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$anonfun$$$outer() {
                                        return this.$outer;
                                    }
                                    {
                                        if ($outer == null) {
                                            throw null;
                                        }
                                        this.$outer = $outer;
                                        this.i$1 = i$1;
                                    }
                                };
                            }

                            public /* synthetic */ anonfun.4 com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$$outer() {
                                return this.$outer;
                            }
                            {
                                if ($outer == null) {
                                    throw null;
                                }
                                this.$outer = $outer;
                                this.cached$1 = cached$1;
                                this.miniBatchBuffer$1 = miniBatchBuffer$1;
                                this.weightSyncTime$1 = weightSyncTime$1;
                                this.pre$1 = pre$1;
                            }
                        }, IndexedSeq$.MODULE$.canBuildFrom());
                        long x$34 = this.timeout$1.elem;
                        TimeUnit x$35 = qual$2.invokeAndWait2$default$3();
                        Buffer<Future<T>> trainingThreads = qual$2.invokeAndWait2(x$33, x$34, x$35);
                        long computingTime = System.nanoTime() - time;
                        this.driverMetrics$1.add("computing time average", computingTime);
                        this.driverMetrics$1.add("computing time for each node", computingTime);
                        Buffer finishedThreads = (Buffer)((TraversableLike)trainingThreads.filter((Function1)new Serializable(this){
                            public static final long serialVersionUID = 0L;

                            public final boolean apply(Future<Object> x$6) {
                                return !x$6.isCancelled();
                            }
                        })).map((Function1)new Serializable(this){
                            public static final long serialVersionUID = 0L;

                            public final int apply(Future<Object> x$7) {
                                return BoxesRunTime.unboxToInt((Object)x$7.get());
                            }
                        }, Buffer$.MODULE$.canBuildFrom());
                        this.recordsNum$1.add((double)(finishedThreads.size() * stackSize));
                        for (int i = 0; i < finishedThreads.size(); ++i) {
                            this.lossSum$1.add(((double[])this.lossArray$1.elem)[BoxesRunTime.unboxToInt((Object)finishedThreads.apply(i))]);
                        }
                        if (finishedThreads.nonEmpty()) {
                            Object object;
                            int parallelNum;
                            Buffer finishedGradients = (Buffer)finishedThreads.map((Function1)new Serializable(this, cached){
                                public static final long serialVersionUID = 0L;
                                private final DistriOptimizer.CacheV1 cached$1;

                                public final Tensor<T> apply(int x$8) {
                                    return this.cached$1.modelGradients()[x$8];
                                }
                                {
                                    this.cached$1 = cached$1;
                                }
                            }, Buffer$.MODULE$.canBuildFrom());
                            time = System.nanoTime();
                            int pOffset = this.parameters$1.paramOffset();
                            int pLength = this.parameters$1.size();
                            int taskSize = pLength / this._subModelNumber$1;
                            int extraTask = pLength % this._subModelNumber$1;
                            int n = parallelNum = taskSize == 0 ? extraTask : this._subModelNumber$1;
                            if (parallelNum != 1) {
                                ThreadPool qual$3 = Engine$.MODULE$.default();
                                IndexedSeq x$36 = (IndexedSeq)RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), parallelNum).map((Function1)new Serializable(this, finishedGradients, pOffset, taskSize, extraTask){
                                    public static final long serialVersionUID = 0L;
                                    public final Buffer finishedGradients$1;
                                    public final int pOffset$1;
                                    public final int taskSize$1;
                                    public final int extraTask$1;

                                    /*
                                     * Builds the aggregation task for worker `tid`.
                                     *
                                     * The parameter range handled by this node is split into contiguous
                                     * slices. Slice `tid` starts at
                                     *   pOffset + tid * taskSize + min(tid, extraTask)
                                     * and has length
                                     *   taskSize + (tid < extraTask ? 1 : 0).
                                     * So the first `extraTask` workers take one extra element, and
                                     * consecutive slices abut each other without overlap.
                                     *
                                     * The returned task adds the slice of every finished gradient
                                     * (indices 1..n-1) into the same slice of finishedGradients(0).
                                     * Once all tasks have run (invokeAndWait in the caller), gradient 0
                                     * holds the element-wise sum over the covered range.
                                     *
                                     * NOTE(review): the disjoint slices are presumably what lets these
                                     * tasks mutate gradient 0 concurrently without locking. Confirm that
                                     * Tensor.narrow returns a view rather than a copy.
                                     */
                                    public final Function0<BoxedUnit> apply(int tid) {
                                        return new Serializable(this, tid){
                                            public static final long serialVersionUID = 0L;
                                            private final /* synthetic */ anonfun$4$$anonfun$9 $outer;
                                            private final int tid$1;

                                            public final void apply() {
                                                this.apply$mcV$sp();
                                            }

                                            public void apply$mcV$sp() {
                                                // Start of this worker's slice; the min() term accounts for the
                                                // extra element given to each earlier worker with tid < extraTask.
                                                int offset = this.$outer.pOffset$1 + this.tid$1 * this.$outer.taskSize$1 + package$.MODULE$.min(this.tid$1, this.$outer.extraTask$1);
                                                int length = this.$outer.taskSize$1 + (this.tid$1 < this.$outer.extraTask$1 ? 1 : 0);
                                                // Accumulate gradients 1..n-1 into gradient 0, restricted to this slice.
                                                for (int i = 1; i < this.$outer.finishedGradients$1.length(); ++i) {
                                                    ((Tensor)this.$outer.finishedGradients$1.apply(0)).narrow(1, offset, length).add(((Tensor)this.$outer.finishedGradients$1.apply(i)).narrow(1, offset, length));
                                                }
                                            }
                                            {
                                                if ($outer == null) {
                                                    throw null;
                                                }
                                                this.$outer = $outer;
                                                this.tid$1 = tid$1;
                                            }
                                        };
                                    }
                                    {
                                        this.finishedGradients$1 = finishedGradients$1;
                                        this.pOffset$1 = pOffset$1;
                                        this.taskSize$1 = taskSize$1;
                                        this.extraTask$1 = extraTask$1;
                                    }
                                }, IndexedSeq$.MODULE$.canBuildFrom());
                                Duration x$37 = qual$3.invokeAndWait$default$2();
                                qual$3.invokeAndWait(x$36, x$37);
                                object = this.driverMetrics$1.add("aggregate gradient time", System.nanoTime() - time);
                            } else {
                                object = BoxedUnit.UNIT;
                            }
                            long putG = System.nanoTime();
                            this.parameters$1.putGradients(((Tensor)finishedGradients.apply(0)).narrow(1, pOffset, pLength));
                            metrics = this.driverMetrics$1.add("put gradient", System.nanoTime() - putG);
                        } else {
                            long putG = System.nanoTime();
                            cached.modelGradients()[0].zero();
                            this.parameters$1.putGradients(cached.modelGradients()[0].narrow(1, this.parameters$1.paramOffset(), this.parameters$1.size()));
                            metrics = this.driverMetrics$1.add("put gradient", System.nanoTime() - putG);
                        }
                        ((ArrayBuffer)this.tasks$1.elem).$plus$plus$eq(Engine$.MODULE$.default().invoke((Seq)RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), this._subModelNumber$1).map((Function1)new Serializable(this, cached){
                            public static final long serialVersionUID = 0L;
                            public final DistriOptimizer.CacheV1 cached$1;

                            public final Function0<BoxedUnit> apply(int i) {
                                return new Serializable(this, i){
                                    public static final long serialVersionUID = 0L;
                                    private final /* synthetic */ anonfun$4$$anonfun$apply$8 $outer;
                                    private final int i$2;

                                    public final void apply() {
                                        this.apply$mcV$sp();
                                    }

                                    public void apply$mcV$sp() {
                                        this.$outer.cached$1.localModels()[this.i$2].training();
                                        this.$outer.cached$1.localModels()[this.i$2].zeroGradParameters();
                                    }
                                    {
                                        if ($outer == null) {
                                            throw null;
                                        }
                                        this.$outer = $outer;
                                        this.i$2 = i$2;
                                    }
                                };
                            }
                            {
                                this.cached$1 = cached$1;
                            }
                        }, IndexedSeq$.MODULE$.canBuildFrom())));
                        return scala.package$.MODULE$.Iterator().single((Object)BoxesRunTime.boxToInteger((int)finishedThreads.size()));
                    }
                    {
                        this.parameters$1 = parameters$1;
                        this.ev$2 = ev$2;
                        this._subModelNumber$1 = _subModelNumber$1;
                        this.tasks$1 = tasks$1;
                        this.threshold$1 = threshold$1;
                        this.timeout$1 = timeout$1;
                        this.iteration$1 = iteration$1;
                        this.dropPercentage$1 = dropPercentage$1;
                        this.warmupIterationNum$1 = warmupIterationNum$1;
                        this.computeThresholdbatchSize$1 = computeThresholdbatchSize$1;
                        this.lossArray$1 = lossArray$1;
                        this.lossSum$1 = lossSum$1;
                        this.recordsNum$1 = recordsNum$1;
                        this.driverMetrics$1 = driverMetrics$1;
                    }
                }, ClassTag$.MODULE$.apply(DistriOptimizer.CacheV1.class), ClassTag$.MODULE$.Int()).reduce((Function2)new Serializable(){
                    public static final long serialVersionUID = 0L;

                    public final int apply(int x$9, int x$10) {
                        return this.apply$mcIII$sp(x$9, x$10);
                    }

                    public int apply$mcIII$sp(int x$9, int x$10) {
                        return x$9 + x$10;
                    }
                }));
                dropModelNumBatch += driverSubModelNum - numFinishedModelUpdates;
                if (dropPercentage == 0.0 || (double)numFinishedModelUpdates >= (double)driverSubModelNum * (1.0 - maxDropPercentage)) {
                    double value2 = Predef$.MODULE$.Double2double(lossSum.value()) / (double)numFinishedModelUpdates;
                    driverState.update("numFinishedModel", BoxesRunTime.boxToInteger((int)numFinishedModelUpdates));
                    driverState.update("isGradientUpdated", BoxesRunTime.boxToBoolean((boolean)false));
                    Predef$.MODULE$.refArrayOps((Object[])parameterProcessers).foreach((Function1)new Serializable(metrics, models, parameters2, ev, driverState){
                        public static final long serialVersionUID = 0L;
                        private final Metrics metrics$1;
                        private final RDD models$1;
                        private final AllReduceParameter parameters$1;
                        private final TensorNumericMath.TensorNumeric ev$2;
                        private final Table driverState$1;

                        public final void apply(ParameterProcessor x$11) {
                            x$11.collectGlobalData(this.models$1, this.parameters$1, this.metrics$1, this.driverState$1, this.ev$2);
                        }
                        {
                            this.metrics$1 = metrics$1;
                            this.models$1 = models$1;
                            this.parameters$1 = parameters$1;
                            this.ev$2 = ev$2;
                            this.driverState$1 = driverState$1;
                        }
                    });
                    boolean isGradientUpdated = BoxesRunTime.unboxToBoolean(driverState.apply("isGradientUpdated"));
                    Broadcast stateBroadcast = sc.broadcast((Object)driverState, ClassTag$.MODULE$.apply(Table.class));
                    models.mapPartitions((Function1)new Serializable(parameters2, parameterSplits, validationMethods, parameterProcessers, ev, driverState, driverMetrics, numFinishedModelUpdates, value2, isGradientUpdated){
                        public static final long serialVersionUID = 0L;
                        public final AllReduceParameter parameters$1;
                        public final Map parameterSplits$1;
                        public final Option validationMethods$1;
                        private final ParameterProcessor[] parameterProcessers$1;
                        public final TensorNumericMath.TensorNumeric ev$2;
                        public final Table driverState$1;
                        private final Metrics driverMetrics$1;
                        private final int numFinishedModelUpdates$1;
                        public final double value$1;
                        private final boolean isGradientUpdated$1;

                        /*
                         * Per-partition weight update step, run inside models.mapPartitions(...).count().
                         *
                         * Steps:
                         * 1. Take the single cached model replica from the iterator.
                         * 2. Unless a ParameterProcessor already updated the gradient
                         *    (isGradientUpdated$1), aggregate this node's gradient partition
                         *    across numFinishedModelUpdates$1 contributions.
                         * 3. Let every ParameterProcessor post-process the parameters.
                         * 4. For each named OptimMethod, copy epoch/neval/Loss (and score when
                         *    validation is configured) from driverState into the method's state.
                         *    Then optimize the part of the local weight partition that overlaps
                         *    that method's parameter split.
                         * 5. Send the updated weight partition.
                         *
                         * Timing: "compute weight average" covers step 4 and "send weights average"
                         * covers step 5. Previously both clocks were started right before the
                         * metric was recorded, so they always reported about 0 ns.
                         *
                         * Returns an empty iterator; the result only drives count().
                         * Throws MatchError if localPartitionRange() yields null.
                         */
                        public final Iterator<Nothing$> apply(Iterator<DistriOptimizer.CacheV1<T>> modelIter) {
                            Tuple2<Object, Object> tuple2 = this.parameters$1.localPartitionRange();
                            if (tuple2 != null) {
                                Object object;
                                Tuple2.mcII.sp sp2;
                                int paramLocalStart = tuple2._1$mcI$sp();
                                int paramLocalLen = tuple2._2$mcI$sp();
                                Tuple2.mcII.sp sp3 = sp2 = new Tuple2.mcII.sp(paramLocalStart, paramLocalLen);
                                int paramLocalStart2 = sp3._1$mcI$sp();
                                int paramLocalLen2 = sp3._2$mcI$sp();
                                DistriOptimizer.CacheV1 modelCache = (DistriOptimizer.CacheV1)modelIter.next();
                                // Skip aggregation when a ParameterProcessor has already produced the gradient.
                                if (this.isGradientUpdated$1) {
                                    object = BoxedUnit.UNIT;
                                } else {
                                    long getG = System.nanoTime();
                                    this.parameters$1.aggregateGradientPartition(this.numFinishedModelUpdates$1);
                                    object = this.driverMetrics$1.add("aggregrateGradientParition average executor", System.nanoTime() - getG);
                                }
                                Predef$.MODULE$.refArrayOps((Object[])this.parameterProcessers$1).foreach((Function1)new Serializable(this, modelCache){
                                    public static final long serialVersionUID = 0L;
                                    private final /* synthetic */ anonfun.optimize.4 $outer;
                                    private final DistriOptimizer.CacheV1 modelCache$1;

                                    public final void apply(ParameterProcessor x$13) {
                                        x$13.processParameters(this.$outer.parameters$1, this.modelCache$1, this.$outer.driverState$1, this.$outer.ev$2);
                                    }
                                    {
                                        if ($outer == null) {
                                            throw null;
                                        }
                                        this.$outer = $outer;
                                        this.modelCache$1 = modelCache$1;
                                    }
                                });
                                // Start the clock before the weight update so "compute weight average"
                                // actually measures the optimMethod.optimize work below.
                                long time = System.nanoTime();
                                modelCache.optimMethods().foreach((Function1)new Serializable(this, paramLocalStart2, paramLocalLen2){
                                    public static final long serialVersionUID = 0L;
                                    private final /* synthetic */ anonfun.optimize.4 $outer;
                                    public final int paramLocalStart$1;
                                    private final int paramLocalLen$1;

                                    public final Object apply(Tuple2<String, OptimMethod<T>> x0$1) {
                                        Tuple2<String, OptimMethod<T>> tuple2 = x0$1;
                                        if (tuple2 != null) {
                                            String name = (String)tuple2._1();
                                            OptimMethod optimMethod = (OptimMethod)tuple2._2();
                                            optimMethod.state().update("epoch", this.$outer.driverState$1.apply("epoch"));
                                            optimMethod.state().update("neval", this.$outer.driverState$1.apply("neval"));
                                            optimMethod.state().update("Loss", this.$outer.driverState$1.apply("Loss"));
                                            Object object = this.$outer.validationMethods$1.isDefined() ? optimMethod.state().update("score", this.$outer.driverState$1.apply("score")) : BoxedUnit.UNIT;
                                            Tuple2 p = (Tuple2)this.$outer.parameterSplits$1.apply((Object)name);
                                            // Intersect this node's local partition with the method's parameter split.
                                            int startIdx = Math.max(this.paramLocalStart$1, p._1$mcI$sp());
                                            int endIdx = Math.min(this.paramLocalStart$1 + this.paramLocalLen$1, p._1$mcI$sp() + p._2$mcI$sp());
                                            Tuple2<Tensor<T>, Object> tuple22 = endIdx > startIdx ? optimMethod.optimize(new Serializable(this, startIdx, endIdx){
                                                public static final long serialVersionUID = 0L;
                                                private final /* synthetic */ anonfun$optimize$4$$anonfun$apply$10 $outer;
                                                private final int startIdx$1;
                                                private final int endIdx$1;

                                                public final Tuple2<T, Tensor<T>> apply(Tensor<T> x$14) {
                                                    return new Tuple2(this.$outer.com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$$outer().ev$2.fromType(BoxesRunTime.boxToDouble((double)this.$outer.com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$$outer().value$1), ConvertableFrom$ConvertableFromDouble$.MODULE$), this.$outer.com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$$outer().parameters$1.gradientPartition().narrow(1, this.startIdx$1 - this.$outer.paramLocalStart$1 + 1, this.endIdx$1 - this.startIdx$1));
                                                }
                                                {
                                                    if ($outer == null) {
                                                        throw null;
                                                    }
                                                    this.$outer = $outer;
                                                    this.startIdx$1 = startIdx$1;
                                                    this.endIdx$1 = endIdx$1;
                                                }
                                            }, this.$outer.parameters$1.weightPartition().narrow(1, startIdx - this.paramLocalStart$1 + 1, endIdx - startIdx)) : BoxedUnit.UNIT;
                                            return tuple22;
                                        }
                                        throw new MatchError(tuple2);
                                    }

                                    public /* synthetic */ anonfun.optimize.4 com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$anonfun$$anonfun$$$outer() {
                                        return this.$outer;
                                    }
                                    {
                                        if ($outer == null) {
                                            throw null;
                                        }
                                        this.$outer = $outer;
                                        this.paramLocalStart$1 = paramLocalStart$1;
                                        this.paramLocalLen$1 = paramLocalLen$1;
                                    }
                                });
                                this.driverMetrics$1.add("compute weight average", System.nanoTime() - time);
                                // Re-arm the clock so "send weights average" covers the send itself.
                                time = System.nanoTime();
                                this.parameters$1.sendWeightPartition();
                                this.driverMetrics$1.add("send weights average", System.nanoTime() - time);
                                return scala.package$.MODULE$.Iterator().empty();
                            }
                            throw new MatchError(tuple2);
                        }
                        {
                            this.parameters$1 = parameters$1;
                            this.parameterSplits$1 = parameterSplits$1;
                            this.validationMethods$1 = validationMethods$1;
                            this.parameterProcessers$1 = parameterProcessers$1;
                            this.ev$2 = ev$2;
                            this.driverState$1 = driverState$1;
                            this.driverMetrics$1 = driverMetrics$1;
                            this.numFinishedModelUpdates$1 = numFinishedModelUpdates$1;
                            this.value$1 = value$1;
                            this.isGradientUpdated$1 = isGradientUpdated$1;
                        }
                    }, models.mapPartitions$default$2(), evidence$1).count();
                    stateBroadcast.destroy();
                    recordsProcessedThisEpoch.elem += (int)Predef$.MODULE$.Double2double(recordsNum.value());
                    long end = System.nanoTime();
                    wallClockTime += end - start2;
                    driverState.update("isGradientUpdated", BoxesRunTime.boxToBoolean((boolean)true));
                    driverState.update("Loss", BoxesRunTime.boxToFloat((float)((float)Predef$.MODULE$.Double2double(lossSum.value()) / (float)numFinishedModelUpdates)));
                    optimMethods.foreach((Function1)new Serializable(){
                        public static final long serialVersionUID = 0L;

                        public final void apply(Tuple2<String, OptimMethod<T>> v) {
                            ((OptimMethod)v._2()).updateHyperParameter();
                        }
                    });
                    driverState.update(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"LearningRate"})).s((Seq)Nil$.MODULE$), BoxesRunTime.boxToFloat((float)((float)((OptimMethod)((Tuple2)optimMethods.head())._2()).getLearningRate())));
                    driverState.update("Throughput", BoxesRunTime.boxToFloat((float)((float)Predef$.MODULE$.Double2double(recordsNum.value()) / (float)((double)(end - start2) / 1.0E9))));
                    String _header = Optimizer$.MODULE$.header(BoxesRunTime.unboxToInt(driverState.apply("epoch")), recordsProcessedThisEpoch.elem, numSamples, BoxesRunTime.unboxToInt(driverState.apply("neval")), wallClockTime);
                    this.logger().info(new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"", " Trained ", " records in ", " "})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{_header, recordsNum.value(), BoxesRunTime.boxToDouble((double)((double)(end - start2) / 1.0E9))}))).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"seconds. Throughput is ", " records/second. Loss is ", ". ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{driverState.apply("Throughput"), driverState.apply("Loss"), Optimizer$.MODULE$.getHyperParameterLog(optimMethods)}))).toString());
                    this.logger().debug(new StringBuilder().append((Object)"\n").append((Object)metrics.summary(metrics.summary$default$1(), metrics.summary$default$2())).toString());
                    this.logger().debug(new StringBuilder().append((Object)"Dropped modules: ").append((Object)BoxesRunTime.boxToInteger((int)(driverSubModelNum - numFinishedModelUpdates))).toString());
                    lossArray.elem = new double[_subModelNumber];
                    ++iteration2.elem;
                    if (dropPercentage > 0.0 && iteration2.elem > warmupIterationNum && iteration2.elem % computeThresholdbatchSize == 0) {
                        long[] moduleTimeList = (long[])models.mapPartitions((Function1)new Serializable(){
                            public static final long serialVersionUID = 0L;

                            public final Iterator<Object> apply(Iterator<DistriOptimizer.CacheV1<T>> iter) {
                                return Predef$.MODULE$.longArrayOps(((DistriOptimizer.CacheV1)iter.next()).moduleTimeList()).iterator();
                            }
                        }, models.mapPartitions$default$2(), ClassTag$.MODULE$.Long()).collect();
                        int k = (int)(dropPercentage * (double)computeThresholdbatchSize * (double)driverSubModelNum);
                        threshold.elem = k > dropModelNumBatch ? Util$.MODULE$.kthLargest(moduleTimeList, 0, moduleTimeList.length - 1, k - dropModelNumBatch) : (long)((double)threshold.elem * 1.01);
                        this.logger().info(new StringBuilder().append((Object)"threshold: ").append((Object)BoxesRunTime.boxToLong((long)threshold.elem)).toString());
                        models.mapPartitions((Function1)new Serializable(){
                            public static final long serialVersionUID = 0L;

                            public final Iterator<Nothing$> apply(Iterator<DistriOptimizer.CacheV1<T>> iter) {
                                long[] timeList = ((DistriOptimizer.CacheV1)iter.next()).moduleTimeList();
                                for (int i = 0; i < timeList.length; ++i) {
                                    timeList[i] = 0L;
                                }
                                return scala.package$.MODULE$.Iterator().empty();
                            }
                        }, models.mapPartitions$default$2(), evidence$1).count();
                        dropModelNumBatch = 0;
                    }
                    driverState.update("neval", BoxesRunTime.boxToInteger((int)(BoxesRunTime.unboxToInt(driverState.apply("neval")) + 1)));
                    if (recordsProcessedThisEpoch.elem >= numSamples) {
                        long epochEnd = System.nanoTime();
                        lastEpochTime = wallClockTime = lastEpochTime + epochEnd - epochStart;
                        epochStart = System.nanoTime();
                        this.logger().info(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"", " Epoch finished. Wall clock time is ", " ms"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{_header, BoxesRunTime.boxToDouble((double)((double)wallClockTime / 1000000.0))})));
                        driverState.update("epoch", BoxesRunTime.boxToInteger((int)(BoxesRunTime.unboxToInt(driverState.apply("epoch")) + 1)));
                        dataset.shuffle();
                        dataRDD = (RDD)dataset.data(true);
                        recordsProcessedThisEpoch.elem = 0;
                    }
                    optimMethods.map((Function1)new Serializable(validationMethods, driverState, recordsProcessedThisEpoch){
                        public static final long serialVersionUID = 0L;
                        private final Option validationMethods$1;
                        private final Table driverState$1;
                        private final IntRef recordsProcessedThisEpoch$1;

                        public final Object apply(Tuple2<String, OptimMethod<T>> x0$2) {
                            Tuple2<String, OptimMethod<T>> tuple2 = x0$2;
                            if (tuple2 != null) {
                                OptimMethod optimMethod = (OptimMethod)tuple2._2();
                                optimMethod.state().update("recordsProcessedThisEpoch", BoxesRunTime.boxToInteger((int)this.recordsProcessedThisEpoch$1.elem));
                                optimMethod.state().update("epoch", this.driverState$1.apply("epoch"));
                                optimMethod.state().update("neval", this.driverState$1.apply("neval"));
                                optimMethod.state().update("Loss", this.driverState$1.apply("Loss"));
                                BoxedUnit boxedUnit = this.validationMethods$1.isDefined() ? optimMethod.state().update("score", this.driverState$1.apply("score")) : BoxedUnit.UNIT;
                                return boxedUnit;
                            }
                            throw new MatchError(tuple2);
                        }
                        {
                            this.validationMethods$1 = validationMethods$1;
                            this.driverState$1 = driverState$1;
                            this.recordsProcessedThisEpoch$1 = recordsProcessedThisEpoch$1;
                        }
                    }, Iterable$.MODULE$.canBuildFrom());
                    this.validate(validationTrigger, validationDataSet, validationMethods, coresPerNode, models, driverState, validationSummary, _header, parameters2);
                    trainSummary.foreach((Function1)new Serializable(trainingModel, models, parameters2, evidence$1, ev, driverState){
                        public static final long serialVersionUID = 0L;
                        private final AbstractModule trainingModel$1;
                        private final RDD models$1;
                        private final AllReduceParameter parameters$1;
                        private final ClassTag evidence$1$1;
                        private final TensorNumericMath.TensorNumeric ev$2;
                        private final Table driverState$1;

                        public final void apply(TrainSummary summary2) {
                            DistriOptimizer$.MODULE$.saveSummary(summary2, this.models$1, this.driverState$1, this.parameters$1, this.trainingModel$1, this.evidence$1$1, this.ev$2);
                        }
                        {
                            this.trainingModel$1 = trainingModel$1;
                            this.models$1 = models$1;
                            this.parameters$1 = parameters$1;
                            this.evidence$1$1 = evidence$1$1;
                            this.ev$2 = ev$2;
                            this.driverState$1 = driverState$1;
                        }
                    });
                    this.checkpoint(cacheTrigger, cachePath, isOverWrite, wallClockTime, models, driverState, parameters2, optimMethods, trainingModel, evidence$1, ev);
                    continue;
                }
                this.logger().info(new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Warning! Not enough training samples were successfully processed in this "})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"iteration due to some slow tasks. The gradients computed in this iteration will be "})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"discarded. Only ", "/", " threads successfully "})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)numFinishedModelUpdates), BoxesRunTime.boxToInteger((int)driverSubModelNum)}))).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"completed training."})).s((Seq)Nil$.MODULE$)).toString());
            }
        }
        throw new MatchError((Object)engineType);
    }

    /**
     * Builds the per-partition cache of thread-local model replicas used for distributed training.
     *
     * <p>The criterion, state table, validation methods and optim methods are broadcast together as one
     * {@code Tuple4}. The model is converted with {@code ConversionUtils.convert}, has
     * {@code getParameters()} called on it, and is broadcast through {@link ModelBroadcast}. Each partition
     * of the dataset's origin RDD then builds a single {@code CacheV1}. That cache holds
     * {@code _subModelNumber} replicas (model, weights, gradients, criterion, state, validation methods)
     * plus one clone of every optim method. The resulting RDD is persisted, named "Thread Model RDD", and
     * materialized with {@code count()}.
     *
     * @param model               model to replicate on every partition
     * @param dataset             training data set; its origin RDD defines the partitions
     * @param criterion           loss criterion; cloned once per replica
     * @param state               training state table; must contain "computeThresholdbatchSize"
     *                            ({@code Option.get} is called on it)
     * @param nodeNumber          required number of partitions of the origin RDD
     * @param coresPerNode        replicas per partition when the engine type is MklBlas
     * @param checkSingleton      if true, raise an error when {@code Engine.checkSingleton()} is false on an
     *                            executor; otherwise only log a warning
     * @param allReduceParameter  initialized on each partition from a slice of the first replica's weights
     * @param parameterSplits     not referenced in this method body
     * @param validationMethods   optional validation methods; cloned per replica
     * @param optimMethod         optim methods keyed by name; cloned once per partition
     * @param parameterProcessors not referenced in this method body
     * @param evidence$2          class tag for {@code T}
     * @param ev                  numeric type class for {@code T}
     * @return the persisted RDD of per-partition caches together with the model broadcast
     */
    public <T> Tuple2<RDD<DistriOptimizer.CacheV1<T>>, ModelBroadcast<T>> initThreadModels(AbstractModule<Activity, Activity, T> model, DistributedDataSet<MiniBatch<T>> dataset, AbstractCriterion<Activity, Activity, T> criterion, Table state, int nodeNumber, int coresPerNode, boolean checkSingleton, AllReduceParameter<T> allReduceParameter, Map<String, Tuple2<Object, Object>> parameterSplits, Option<ValidationMethod<T>[]> validationMethods, Map<String, OptimMethod<T>> optimMethod, ArrayBuffer<ParameterProcessor> parameterProcessors, ClassTag<T> evidence$2, TensorNumericMath.TensorNumeric<T> ev) {
        int n;
        SparkContext sc = dataset.originRDD().sparkContext();
        // Ship criterion, state, validation methods and optim methods to the executors in a single broadcast.
        Broadcast broadcast = sc.broadcast((Object)new Tuple4(criterion, (Object)state, validationMethods, optimMethod), ClassTag$.MODULE$.apply(Tuple4.class));
        AbstractModule<Activity, Activity, T> convertedModel = ConversionUtils$.MODULE$.convert(model, evidence$2);
        // Result is discarded. Presumably called for its side effect on the model's parameter storage
        // before broadcasting. TODO confirm.
        convertedModel.getParameters();
        ModelBroadcast<T> modelBroadcast = ModelBroadcast$.MODULE$.apply(evidence$2, ev).broadcast(sc, convertedModel);
        // Replicas per partition: coresPerNode for MklBlas, a single replica for MklDnn.
        EngineType engineType = Engine$.MODULE$.getEngineType();
        if (MklBlas$.MODULE$.equals(engineType)) {
            n = coresPerNode;
        } else if (MklDnn$.MODULE$.equals(engineType)) {
            n = 1;
        } else {
            Log4Error$.MODULE$.invalidInputError(false, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"unexpected engine type ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{Engine$.MODULE$.getEngineType()})), "only support MklBlas and MklDnn");
            // Only reached if invalidInputError(false, ...) does not throw; keeps n definitely assigned.
            n = 0;
        }
        int _subModelNumber = n;
        // The origin RDD must have exactly one partition per configured node.
        Log4Error$.MODULE$.invalidOperationError(dataset.originRDD().partitions().length == nodeNumber, new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Passed in rdd partition number ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)dataset.originRDD().partitions().length)}))).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{" is not equal to configured node number ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)nodeNumber)}))).toString(), Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
        int computeThresholdbatchSize = BoxesRunTime.unboxToInt((Object)state.get("computeThresholdbatchSize").get());
        // Driver-side Engine settings, re-applied on each executor inside the partition function.
        int nExecutor = Engine$.MODULE$.nodeNumber();
        int executorCores = Engine$.MODULE$.coreNumber();
        RDD<?> qual$4 = dataset.originRDD();
        // Partition function: builds one CacheV1 per partition (the partition's data iterator is ignored).
        Serializable x$38 = new Serializable(checkSingleton, allReduceParameter, evidence$2, ev, broadcast, modelBroadcast, _subModelNumber, computeThresholdbatchSize, nExecutor, executorCores){
            public static final long serialVersionUID = 0L;
            private final boolean checkSingleton$1;
            private final AllReduceParameter allReduceParameter$1;
            public final ClassTag evidence$2$1;
            private final TensorNumericMath.TensorNumeric ev$1;
            private final Broadcast broadcast$1;
            public final ModelBroadcast modelBroadcast$1;
            public final int _subModelNumber$2;
            private final int computeThresholdbatchSize$2;
            private final int nExecutor$1;
            private final int executorCores$1;

            public final Iterator<DistriOptimizer.CacheV1<T>> apply(Iterator<Object> x$15) {
                int partitionId = TaskContext$.MODULE$.getPartitionId();
                Tuple4 tuple4 = (Tuple4)this.broadcast$1.value();
                if (tuple4 != null) {
                    Tuple4 tuple42;
                    AbstractCriterion broadcastCriterion = (AbstractCriterion)tuple4._1();
                    Table broadcastState = (Table)tuple4._2();
                    Option broadcastMethod = (Option)tuple4._3();
                    Map broadcastOptim = (Map)tuple4._4();
                    Tuple4 tuple43 = tuple42 = new Tuple4((Object)broadcastCriterion, (Object)broadcastState, (Object)broadcastMethod, (Object)broadcastOptim);
                    AbstractCriterion broadcastCriterion2 = (AbstractCriterion)tuple43._1();
                    Table broadcastState2 = (Table)tuple43._2();
                    Option broadcastMethod2 = (Option)tuple43._3();
                    Map broadcastOptim2 = (Map)tuple43._4();
                    // Partition-distribution check: fatal when checkSingleton is set, otherwise a warning.
                    if (!Engine$.MODULE$.checkSingleton()) {
                        if (this.checkSingleton$1) {
                            Log4Error$.MODULE$.invalidOperationError(Engine$.MODULE$.checkSingleton(), "Partitions of the training data are not evenly distributed across the executors in the Spark cluster; are there sufficient training data to be distributed? Set property \"bigdl.check.singleton\" to false to skip this check", Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
                        } else {
                            DistriOptimizer$.MODULE$.logger().warn("Partitions of the training data are not evenly distributed across the executors in the Spark cluster; are there sufficient training data to be distributed?");
                        }
                    }
                    Engine$.MODULE$.setNodeAndCore(this.nExecutor$1, this.executorCores$1);
                    // Build _subModelNumber replicas: (model, weights, grads, criterion, state, validation methods).
                    Tuple6[] cached = (Tuple6[])((TraversableOnce)RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), this._subModelNumber$2).map((Function1)new Serializable(this, partitionId, broadcastCriterion2, broadcastState2, broadcastMethod2){
                        public static final long serialVersionUID = 0L;
                        private final /* synthetic */ anonfun.11 $outer;
                        private final int partitionId$1;
                        private final AbstractCriterion broadcastCriterion$1;
                        private final Table broadcastState$1;
                        private final Option broadcastMethod$1;

                        public final Tuple6<AbstractModule<Activity, Activity, T>, Tensor<T>, Tensor<T>, AbstractCriterion<Activity, Activity, T>, Table, Option<ValidationMethod<T>[]>> apply(int x$17) {
                            AbstractModule<Activity, Activity, T> localModel = this.$outer.modelBroadcast$1.value(true, this.$outer.modelBroadcast$1.value$default$2());
                            // For MklDnn (and not an IRGraph), compile MklDnnContainer/DnnGraph replicas for
                            // the training phase on the dnnComputing pool.
                            EngineType engineType = Engine$.MODULE$.getEngineType();
                            MklDnn$ mklDnn$ = MklDnn$.MODULE$;
                            Buffer<Future<T>> buffer = !(engineType != null ? !engineType.equals(mklDnn$) : mklDnn$ != null) && !(localModel instanceof IRGraph) ? Engine$.MODULE$.dnnComputing().invokeAndWait2((Seq)RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), this.$outer._subModelNumber$2).map((Function1)new Serializable(this, localModel){
                                public static final long serialVersionUID = 0L;
                                public final AbstractModule localModel$2;

                                public final Function0<BoxedUnit> apply(int i) {
                                    return new Serializable(this){
                                        public static final long serialVersionUID = 0L;
                                        private final /* synthetic */ anonfun$11$$anonfun$12$$anonfun$apply$12 $outer;

                                        public final void apply() {
                                            this.apply$mcV$sp();
                                        }

                                        public void apply$mcV$sp() {
                                            AbstractModule abstractModule = this.$outer.localModel$2;
                                            if (abstractModule instanceof MklDnnContainer) {
                                                MklDnnContainer mklDnnContainer = (MklDnnContainer)((Object)abstractModule);
                                                mklDnnContainer.compile(Phase$TrainingPhase$.MODULE$);
                                                BoxedUnit boxedUnit = BoxedUnit.UNIT;
                                            } else if (abstractModule instanceof DnnGraph) {
                                                DnnGraph dnnGraph = (DnnGraph)abstractModule;
                                                dnnGraph.compile(Phase$TrainingPhase$.MODULE$);
                                                BoxedUnit boxedUnit = BoxedUnit.UNIT;
                                            } else {
                                                BoxedUnit boxedUnit = BoxedUnit.UNIT;
                                            }
                                        }
                                        {
                                            if ($outer == null) {
                                                throw null;
                                            }
                                            this.$outer = $outer;
                                        }
                                    };
                                }
                                {
                                    this.localModel$2 = localModel$2;
                                }
                            }, IndexedSeq$.MODULE$.canBuildFrom()), Engine$.MODULE$.dnnComputing().invokeAndWait2$default$2(), Engine$.MODULE$.dnnComputing().invokeAndWait2$default$3()) : BoxedUnit.UNIT;
                            DistriOptimizer$.MODULE$.com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$setModelId(localModel, this.partitionId$1, this.$outer.evidence$2$1);
                            // Each replica gets its own criterion, state and validation methods.
                            AbstractCriterion<A, B, T> localCriterion = this.broadcastCriterion$1.cloneCriterion();
                            Table localState = this.broadcastState$1.clone();
                            None$ localMethod = this.broadcastMethod$1.isDefined() ? new Some(Predef$.MODULE$.refArrayOps((Object[])this.broadcastMethod$1.get()).map((Function1)new Serializable(this){
                                public static final long serialVersionUID = 0L;

                                public final ValidationMethod<T> apply(ValidationMethod<T> x$18) {
                                    return x$18.clone();
                                }
                            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ValidationMethod.class)))) : None$.MODULE$;
                            Tuple2<Tensor<T>, Tensor<T>> tuple2 = localModel.getParameters();
                            if (tuple2 != null) {
                                Tuple2 tuple22;
                                Tensor weights = (Tensor)tuple2._1();
                                Tensor grads = (Tensor)tuple2._2();
                                Tuple2 tuple23 = tuple22 = new Tuple2((Object)weights, (Object)grads);
                                Tensor weights2 = (Tensor)tuple23._1();
                                Tensor grads2 = (Tensor)tuple23._2();
                                return new Tuple6(localModel, (Object)weights2, (Object)grads2, localCriterion, (Object)localState, (Object)localMethod);
                            }
                            throw new MatchError(tuple2);
                        }
                        {
                            if ($outer == null) {
                                throw null;
                            }
                            this.$outer = $outer;
                            this.partitionId$1 = partitionId$1;
                            this.broadcastCriterion$1 = broadcastCriterion$1;
                            this.broadcastState$1 = broadcastState$1;
                            this.broadcastMethod$1 = broadcastMethod$1;
                        }
                    }, IndexedSeq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(Tuple6.class));
                    DistriOptimizer$.MODULE$.logger().info(new StringBuilder().append((Object)"model thread pool size is ").append((Object)BoxesRunTime.boxToInteger((int)Engine$.MODULE$.model().getPoolSize())).toString());
                    // Initialize the all-reduce parameter from this partition's slice of the first replica's weights.
                    Tensor weights = (Tensor)((Tuple6)Predef$.MODULE$.refArrayOps((Object[])cached).head())._2();
                    this.allReduceParameter$1.init(weights.narrow(1, this.allReduceParameter$1.paramOffset(), this.allReduceParameter$1.size()), this.ev$1);
                    Iterator$ iterator$ = scala.package$.MODULE$.Iterator();
                    // Unzip the replica tuples into the parallel arrays expected by CacheV1.
                    AbstractModule[] abstractModuleArray = (AbstractModule[])Predef$.MODULE$.refArrayOps((Object[])cached).map((Function1)new Serializable(this){
                        public static final long serialVersionUID = 0L;

                        public final AbstractModule<Activity, Activity, T> apply(Tuple6<AbstractModule<Activity, Activity, T>, Tensor<T>, Tensor<T>, AbstractCriterion<Activity, Activity, T>, Table, Option<ValidationMethod<T>[]>> x$20) {
                            return (AbstractModule)x$20._1();
                        }
                    }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(AbstractModule.class)));
                    Tensor[] tensorArray = (Tensor[])Predef$.MODULE$.refArrayOps((Object[])cached).map((Function1)new Serializable(this){
                        public static final long serialVersionUID = 0L;

                        public final Tensor<T> apply(Tuple6<AbstractModule<Activity, Activity, T>, Tensor<T>, Tensor<T>, AbstractCriterion<Activity, Activity, T>, Table, Option<ValidationMethod<T>[]>> x$21) {
                            return (Tensor)x$21._2();
                        }
                    }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tensor.class)));
                    Tensor[] tensorArray2 = (Tensor[])Predef$.MODULE$.refArrayOps((Object[])cached).map((Function1)new Serializable(this){
                        public static final long serialVersionUID = 0L;

                        public final Tensor<T> apply(Tuple6<AbstractModule<Activity, Activity, T>, Tensor<T>, Tensor<T>, AbstractCriterion<Activity, Activity, T>, Table, Option<ValidationMethod<T>[]>> x$22) {
                            return (Tensor)x$22._3();
                        }
                    }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tensor.class)));
                    AbstractCriterion[] abstractCriterionArray = (AbstractCriterion[])Predef$.MODULE$.refArrayOps((Object[])cached).map((Function1)new Serializable(this){
                        public static final long serialVersionUID = 0L;

                        public final AbstractCriterion<Activity, Activity, T> apply(Tuple6<AbstractModule<Activity, Activity, T>, Tensor<T>, Tensor<T>, AbstractCriterion<Activity, Activity, T>, Table, Option<ValidationMethod<T>[]>> x$23) {
                            return (AbstractCriterion)x$23._4();
                        }
                    }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(AbstractCriterion.class)));
                    Table[] tableArray = (Table[])Predef$.MODULE$.refArrayOps((Object[])cached).map((Function1)new Serializable(this){
                        public static final long serialVersionUID = 0L;

                        public final Table apply(Tuple6<AbstractModule<Activity, Activity, T>, Tensor<T>, Tensor<T>, AbstractCriterion<Activity, Activity, T>, Table, Option<ValidationMethod<T>[]>> x$24) {
                            return (Table)x$24._5();
                        }
                    }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Table.class)));
                    // Sized _subModelNumber * computeThresholdbatchSize; presumably per-record timing slots. TODO confirm.
                    long[] lArray = new long[this._subModelNumber$2 * this.computeThresholdbatchSize$2];
                    Option[] optionArray = (Option[])Predef$.MODULE$.refArrayOps((Object[])cached).map((Function1)new Serializable(this){
                        public static final long serialVersionUID = 0L;

                        public final Option<ValidationMethod<T>[]> apply(Tuple6<AbstractModule<Activity, Activity, T>, Tensor<T>, Tensor<T>, AbstractCriterion<Activity, Activity, T>, Table, Option<ValidationMethod<T>[]>> x$25) {
                            return (Option)x$25._6();
                        }
                    }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Option.class)));
                    // One clone of each optim method per partition, keyed by the same name.
                    Map map = (Map)broadcastOptim2.map((Function1)new Serializable(this){
                        public static final long serialVersionUID = 0L;

                        public final Tuple2<String, OptimMethod<T>> apply(Tuple2<String, OptimMethod<T>> v) {
                            return new Tuple2(v._1(), ((OptimMethod)v._2()).clone());
                        }
                    }, Map$.MODULE$.canBuildFrom());
                    // Decompiler artifact: the 9th-argument default is evaluated and discarded; null is passed below.
                    DistriOptimizer$CacheV1$.MODULE$.apply$default$9();
                    return iterator$.single(new DistriOptimizer.CacheV1<T>(abstractModuleArray, tensorArray, tensorArray2, abstractCriterionArray, tableArray, lArray, optionArray, map, null));
                }
                throw new MatchError((Object)tuple4);
            }
            {
                this.checkSingleton$1 = checkSingleton$1;
                this.allReduceParameter$1 = allReduceParameter$1;
                this.evidence$2$1 = evidence$2$1;
                this.ev$1 = ev$1;
                this.broadcast$1 = broadcast$1;
                this.modelBroadcast$1 = modelBroadcast$1;
                this._subModelNumber$2 = _subModelNumber$2;
                this.computeThresholdbatchSize$2 = computeThresholdbatchSize$2;
                this.nExecutor$1 = nExecutor$1;
                this.executorCores$1 = executorCores$1;
            }
        };
        boolean x$39 = qual$4.mapPartitions$default$2();
        RDD models = qual$4.mapPartitions((Function1)x$38, x$39, ClassTag$.MODULE$.apply(DistriOptimizer.CacheV1.class)).persist();
        models.setName("Thread Model RDD");
        this.logger().info("Cache thread models...");
        // Action that forces the caches to be built and persisted now.
        models.count();
        this.logger().info("Cache thread models... done");
        return new Tuple2((Object)models, modelBroadcast);
    }

    /**
     * Assigns {@code partitionId} to {@code model} through {@code setId}. If the model is a
     * {@link Container}, the method recurses into every sub-module and tags it the same way.
     *
     * @param model       root module to tag
     * @param partitionId id handed to {@code setId} on every visited module
     * @param evidence$3  class tag for {@code T}, forwarded to the recursive calls
     */
    public <T> void com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$setModelId(AbstractModule<Activity, Activity, T> model, int partitionId, ClassTag<T> evidence$3) {
        model.setId(partitionId);
        if (!(model instanceof Container)) {
            // Leaf module: nothing further to tag.
            return;
        }
        Container container = (Container)model;
        // Recurse into each child module with the same partition id.
        container.modules().foreach((Function1)new Serializable(partitionId, evidence$3){
            public static final long serialVersionUID = 0L;
            private final int partitionId$2;
            private final ClassTag evidence$3$1;

            public final void apply(AbstractModule<Activity, Activity, T> sub2) {
                DistriOptimizer$.MODULE$.com$intel$analytics$bigdl$dllib$optim$DistriOptimizer$$setModelId(sub2, this.partitionId$2, this.evidence$3$1);
            }
            {
                this.partitionId$2 = partitionId$2;
                this.evidence$3$1 = evidence$3$1;
            }
        });
    }

    /**
     * Gathers the trained state held in the cached per-partition models back into {@code trainingModel}.
     *
     * <p>Steps performed:
     * <ol>
     *   <li>Call {@code beforeGetModel()} on the first local model of every partition. The
     *       {@code mapPartitions(...).reduce(...)} pair exists only to run this as a Spark action.</li>
     *   <li>Call {@code Util.setExtraParametersFromModelRDD} with size argument {@code 500000000}.</li>
     *   <li>Resize every gradient tensor of {@code trainingModel} to the shape of its weight tensor.</li>
     *   <li>Collect each partition's {@code weightPartition()} / {@code gradientPartition()}, keyed by
     *       partition id. Copy them into the model's flattened parameter and gradient tensors.</li>
     * </ol>
     *
     * @param models        cached per-partition training state
     * @param parameters2   all-reduce parameter describing the offset and total size of the sharded parameters
     * @param trainingModel driver-side model that receives the gathered parameters
     * @param evidence$4    class tag for {@code T}
     * @param ev            numeric type class for {@code T}
     * @return {@code trainingModel}, updated in place
     */
    @Override
    public <T> AbstractModule<Activity, Activity, T> getModel(RDD<DistriOptimizer.Cache<T>> models, AllReduceParameter<T> parameters2, AbstractModule<Activity, Activity, T> trainingModel, ClassTag<T> evidence$4, TensorNumericMath.TensorNumeric<T> ev) {
        int partitionNum = models.partitions().length;
        // Run beforeGetModel() on every partition's first local model; the summed result is unused.
        models.mapPartitions((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Iterator<Object> apply(Iterator<DistriOptimizer.Cache<T>> iter) {
                ((AbstractModule)Predef$.MODULE$.refArrayOps((Object[])((DistriOptimizer.Cache)iter.next()).localModels()).head()).beforeGetModel();
                return scala.package$.MODULE$.Iterator().single((Object)BoxesRunTime.boxToInteger((int)1));
            }
        }, models.mapPartitions$default$2(), ClassTag$.MODULE$.Int()).reduce((Function2)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final int apply(int x$26, int x$27) {
                return this.apply$mcIII$sp(x$26, x$27);
            }

            public int apply$mcIII$sp(int x$26, int x$27) {
                return x$26 + x$27;
            }
        });
        Util$.MODULE$.setExtraParametersFromModelRDD(models, trainingModel, 500000000, evidence$4, ev);
        // Make each gradient tensor the same shape as its weight tensor before flattening.
        Tuple2<Tensor<T>[], Tensor<T>[]> parameterArray = trainingModel.parameters();
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), ((Tensor[])parameterArray._2()).length).foreach((Function1)new Serializable(parameterArray){
            public static final long serialVersionUID = 0L;
            private final Tuple2 parameterArray$1;

            public final Tensor<T> apply(int i) {
                return ((Tensor[])this.parameterArray$1._2())[i].resizeAs(((Tensor[])this.parameterArray$1._1())[i]);
            }
            {
                this.parameterArray$1 = parameterArray$1;
            }
        });
        Tuple2<Tensor<T>, Tensor<T>> tuple2 = trainingModel.getParameters();
        if (tuple2 != null) {
            Tuple2 tuple22;
            Tensor parameter = (Tensor)tuple2._1();
            Tensor gradientParameter = (Tensor)tuple2._2();
            Tuple2 tuple23 = tuple22 = new Tuple2((Object)parameter, (Object)gradientParameter);
            Tensor parameter2 = (Tensor)tuple23._1();
            Tensor gradientParameter2 = (Tensor)tuple23._2();
            // Build (partitionId -> weight slice, partitionId -> gradient slice) maps and merge them across partitions.
            Tuple2 tuple24 = (Tuple2)models.mapPartitions((Function1)new Serializable(parameters2){
                public static final long serialVersionUID = 0L;
                private final AllReduceParameter parameters$2;

                public final Iterator<Tuple2<Map<Object, Tensor<T>>, Map<Object, Tensor<T>>>> apply(Iterator<DistriOptimizer.Cache<T>> iter) {
                    DistriOptimizer.Cache cached = (DistriOptimizer.Cache)iter.next();
                    int curPartitionId = TaskContext$.MODULE$.getPartitionId();
                    return scala.package$.MODULE$.Iterator().single((Object)new Tuple2((Object)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)BoxesRunTime.boxToInteger((int)curPartitionId)), this.parameters$2.weightPartition())})), (Object)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)BoxesRunTime.boxToInteger((int)curPartitionId)), this.parameters$2.gradientPartition())}))));
                }
                {
                    this.parameters$2 = parameters$2;
                }
            }, models.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).reduce((Function2)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final Tuple2<Map<Object, Tensor<T>>, Map<Object, Tensor<T>>> apply(Tuple2<Map<Object, Tensor<T>>, Map<Object, Tensor<T>>> a, Tuple2<Map<Object, Tensor<T>>, Map<Object, Tensor<T>>> b) {
                    return new Tuple2((Object)((MapLike)a._1()).$plus$plus((GenTraversableOnce)b._1()), (Object)((MapLike)a._2()).$plus$plus((GenTraversableOnce)b._2()));
                }
            });
            if (tuple24 != null) {
                Tuple2 tuple25;
                Map weights = (Map)tuple24._1();
                Map gradients = (Map)tuple24._2();
                Tuple2 tuple26 = tuple25 = new Tuple2((Object)weights, (Object)gradients);
                Map weights2 = (Map)tuple26._1();
                Map gradients2 = (Map)tuple26._2();
                // Even split of parameters2.size() across partitions; the first extraSize partitions get one extra element.
                int taskSize = parameters2.size() / partitionNum;
                Log4Error$.MODULE$.invalidOperationError(taskSize != 0, "parameter length should not be less than partition number", Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
                int extraSize = parameters2.size() % partitionNum;
                // Partition pid's slice starts at paramOffset + pid * taskSize + min(pid, extraSize).
                RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), partitionNum).map((Function1)new Serializable(parameters2, parameter2, gradientParameter2, weights2, gradients2, taskSize, extraSize){
                    public static final long serialVersionUID = 0L;
                    private final AllReduceParameter parameters$2;
                    private final Tensor parameter$1;
                    private final Tensor gradientParameter$1;
                    private final Map weights$1;
                    private final Map gradients$1;
                    private final int taskSize$2;
                    private final int extraSize$1;

                    public final Tensor<T> apply(int pid) {
                        int start2 = this.parameters$2.paramOffset() + pid * this.taskSize$2 + package$.MODULE$.min(pid, this.extraSize$1);
                        int length = this.taskSize$2 + (pid < this.extraSize$1 ? 1 : 0);
                        this.parameter$1.narrow(1, start2, length).copy((Tensor)this.weights$1.apply((Object)BoxesRunTime.boxToInteger((int)pid)));
                        return this.gradientParameter$1.narrow(1, start2, length).copy((Tensor)this.gradients$1.apply((Object)BoxesRunTime.boxToInteger((int)pid)));
                    }
                    {
                        this.parameters$2 = parameters$2;
                        this.parameter$1 = parameter$1;
                        this.gradientParameter$1 = gradientParameter$1;
                        this.weights$1 = weights$1;
                        this.gradients$1 = gradients$1;
                        this.taskSize$2 = taskSize$2;
                        this.extraSize$1 = extraSize$1;
                    }
                }, IndexedSeq$.MODULE$.canBuildFrom());
                return trainingModel;
            }
            throw new MatchError((Object)tuple24);
        }
        throw new MatchError(tuple2);
    }

    /**
     * Creates the singleton instance and its logger.
     * The {@code MODULE$} field is assigned before the logger is created.
     */
    private DistriOptimizer$() {
        MODULE$ = this;
        logger = LogManager.getLogger(getClass());
    }
}

