/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.optim;

import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractCriterion;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.optim.LossWithElapsedTime;
import com.intel.analytics.bigdl.dllib.optim.OptimMethod;
import com.intel.analytics.bigdl.dllib.optim.ParamSegments;
import com.intel.analytics.bigdl.dllib.optim.StateEntry$;
import com.intel.analytics.bigdl.dllib.optim.TrainingContext$;
import com.intel.analytics.bigdl.dllib.tensor.ConvertableFrom$ConvertableFromDouble$;
import com.intel.analytics.bigdl.dllib.tensor.ConvertableTo$ConvertableToDouble$;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Engine$;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.Table;
import com.intel.analytics.bigdl.dllib.utils.ThreadPool;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.apache.logging.log4j.LogManager;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.TraversableLike;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.Buffer;
import scala.collection.mutable.Buffer$;
import scala.collection.mutable.StringBuilder;
import scala.concurrent.duration.Duration;
import scala.math.package$;
import scala.reflect.ClassTag;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

@ScalaSignature(bytes="\u0006\u0001\u0005\u001dh\u0001B\u0001\u0003\u0001=\u0011q\u0002\u0016:bS:LgnZ\"p]R,\u0007\u0010\u001e\u0006\u0003\u0007\u0011\tQa\u001c9uS6T!!\u0002\u0004\u0002\u000b\u0011dG.\u001b2\u000b\u0005\u001dA\u0011!\u00022jO\u0012d'BA\u0005\u000b\u0003%\tg.\u00197zi&\u001c7O\u0003\u0002\f\u0019\u0005)\u0011N\u001c;fY*\tQ\"A\u0002d_6\u001c\u0001!\u0006\u0002\u0011wM\u0019\u0001!E\f\u0011\u0005I)R\"A\n\u000b\u0003Q\tQa]2bY\u0006L!AF\n\u0003\r\u0005s\u0017PU3g!\t\u0011\u0002$\u0003\u0002\u001a'\ta1+\u001a:jC2L'0\u00192mK\"A1\u0004\u0001BC\u0002\u0013\u0005A$\u0001\btk\nlu\u000eZ3m\u001dVl'-\u001a:\u0016\u0003u\u0001\"A\u0005\u0010\n\u0005}\u0019\"aA%oi\"A\u0011\u0005\u0001B\u0001B\u0003%Q$A\btk\nlu\u000eZ3m\u001dVl'-\u001a:!\u0011!\u0019\u0003A!b\u0001\n\u0003a\u0012A\u00038v[N\u000bW\u000e\u001d7fg\"AQ\u0005\u0001B\u0001B\u0003%Q$A\u0006ok6\u001c\u0016-\u001c9mKN\u0004\u0003\u0002C\u0014\u0001\u0005\u000b\u0007I\u0011\u0001\u0015\u0002\u000bM$\u0018\r^3\u0016\u0003%\u0002\"AK\u0017\u000e\u0003-R!\u0001\f\u0003\u0002\u000bU$\u0018\u000e\\:\n\u00059Z#!\u0002+bE2,\u0007\u0002\u0003\u0019\u0001\u0005\u0003\u0005\u000b\u0011B\u0015\u0002\rM$\u0018\r^3!\u0011!\u0011\u0004AaA!\u0002\u0017\u0019\u0014aC3wS\u0012,gnY3%cE\u00022\u0001N\u001c:\u001b\u0005)$B\u0001\u001c\u0014\u0003\u001d\u0011XM\u001a7fGRL!\u0001O\u001b\u0003\u0011\rc\u0017m]:UC\u001e\u0004\"AO\u001e\r\u0001\u0011)A\b\u0001b\u0001{\t\tA+\u0005\u0002?\u0003B\u0011!cP\u0005\u0003\u0001N\u0011qAT8uQ&tw\r\u0005\u0002\u0013\u0005&\u00111i\u0005\u0002\u0004\u0003:L\b\"B#\u0001\t\u00031\u0015A\u0002\u001fj]&$h\b\u0006\u0003H\u00172kEC\u0001%K!\rI\u0005!O\u0007\u0002\u0005!)!\u0007\u0012a\u0002g!)1\u0004\u0012a\u0001;!)1\u0005\u0012a\u0001;!)q\u0005\u0012a\u0001S!)q\n\u0001C\u0001!\u0006)\u0002.Y:D_6\u0004H.\u001a;f\u00032d7+Y7qY\u0016\u001cHcA)U-B\u0011!CU\u0005\u0003'N\u0011qAQ8pY\u0016\fg\u000eC\u0003V\u001d\u0002\u0007Q$\u0001\tsK\u000e|'\u000fZ:Qe>\u001cWm]:fI\")qK\u0014a\u00011\u0006)Qn\u001c3fYB\
u0019\u0011lZ\u001d\u000f\u0005i+gBA.e\u001d\ta6M\u0004\u0002^E:\u0011a,Y\u0007\u0002?*\u0011\u0001MD\u0001\u0007yI|w\u000e\u001e \n\u00035I!a\u0003\u0007\n\u0005%Q\u0011BA\u0004\t\u0013\t1g!A\u0004qC\u000e\\\u0017mZ3\n\u0005!L'AB'pIVdWM\u0003\u0002g\r!)1\u000e\u0001C\u0001Y\u0006Qa-\u001a;dQ\n\u000bGo\u00195\u0016\u00055\\HC\u00018\u0000)\tyG\u0010E\u0002\u0013aJL!!]\n\u0003\u000b\u0005\u0013(/Y=\u0011\u0007MD(0D\u0001u\u0015\t)h/A\u0004eCR\f7/\u001a;\u000b\u0005]$\u0011a\u00024fCR,(/Z\u0005\u0003sR\u0014\u0011\"T5oS\n\u000bGo\u00195\u0011\u0005iZH!\u0002\u001fk\u0005\u0004i\u0004bB?k\u0003\u0003\u0005\u001dA`\u0001\fKZLG-\u001a8dK\u0012\n$\u0007E\u00025oiDq!!\u0001k\u0001\u0004\t\u0019!\u0001\u0003eCR\f\u0007#BA\u0003\u0003\u001b\u0011h\u0002BA\u0004\u0003\u0017q1AXA\u0005\u0013\u0005!\u0012B\u00014\u0014\u0013\u0011\ty!!\u0005\u0003\u0011%#XM]1u_JT!AZ\n\t\u000f\u0005U\u0001\u0001\"\u0001\u0002\u0018\u0005)AO]1j]V!\u0011\u0011DA\u0019)!\tY\"!\u0015\u0002X\u0005}CCBA\u000f\u0003S\t\u0019\u0004\u0005\u0004\u0002\u0006\u0005}\u00111E\u0005\u0005\u0003C\t\tBA\u0002TKF\u00042!SA\u0013\u0013\r\t9C\u0001\u0002\u0014\u0019>\u001c8oV5uQ\u0016c\u0017\r]:fIRKW.\u001a\u0005\u000b\u0003W\t\u0019\"!AA\u0004\u00055\u0012aC3wS\u0012,gnY3%cM\u0002B\u0001N\u001c\u00020A\u0019!(!\r\u0005\rq\n\u0019B1\u0001>\u0011!\t)$a\u0005A\u0004\u0005]\u0012AA3w!\u0019\tI$a\u0013\u000209!\u00111HA#\u001d\u0011\ti$!\u0011\u000f\u0007i\u000by$\u0003\u0002\u0006\r%\u0019\u00111\t\u0003\u0002\rQ,gn]8s\u0013\u0011\t9%!\u0013\u0002#Q+gn]8s\u001dVlWM]5d\u001b\u0006$\bNC\u0002\u0002D\u0011IA!!\u0014\u0002P\tiA+\u001a8t_JtU/\\3sS\u000eTA!a\u0012\u0002J!A\u0011\u0011AA\n\u0001\u0004\t\u0019\u0006\u0005\u0003\u0013a\u0006U\u0003\u0003B:y\u0003_A\u0001\"!\u0017\u0002\u0014\u0001\u0007\u00111L\u0001\u0007[>$W\r\\:\u0011\tI\u0001\u0018Q\f\t\u00053\u001e\fy\u0003\u0003\u0005\u0002b\u0005M\u0001\u0019AA2\u0003%\u0019'/\u001b;fe&|g\u000e\u0005\u0003\u0013a\u0006\u0015\u0004#B-\u0002h\u0005=\u0012bAA5S\nI1I]5uKJLw
N\u001c\u0005\b\u0003[\u0002A\u0011AA8\u0003\u0019)\b\u000fZ1uKV!\u0011\u0011OAB))\t\u0019(!#\u0002(\u0006M\u0016q\u0017\u000b\u0007\u0003k\nY(!\"\u0011\u0007I\t9(C\u0002\u0002zM\u0011A!\u00168ji\"Q\u0011QPA6\u0003\u0003\u0005\u001d!a \u0002\u0017\u00154\u0018\u000eZ3oG\u0016$\u0013\u0007\u000e\t\u0005i]\n\t\tE\u0002;\u0003\u0007#a\u0001PA6\u0005\u0004i\u0004\u0002CA\u001b\u0003W\u0002\u001d!a\"\u0011\r\u0005e\u00121JAA\u0011!\tY)a\u001bA\u0002\u00055\u0015!D8qi&l7+Z4nK:$8\u000f\u0005\u0005\u0002\u0010\u0006U\u00151TAQ\u001d\r\u0011\u0012\u0011S\u0005\u0004\u0003'\u001b\u0012A\u0002)sK\u0012,g-\u0003\u0003\u0002\u0018\u0006e%aA'ba*\u0019\u00111S\n\u0011\t\u0005=\u0015QT\u0005\u0005\u0003?\u000bIJ\u0001\u0004TiJLgn\u001a\t\u0006\u0013\u0006\r\u0016\u0011Q\u0005\u0004\u0003K\u0013!!\u0004)be\u0006l7+Z4nK:$8\u000f\u0003\u0005\u0002*\u0006-\u0004\u0019AAV\u0003\u00199X-[4iiB1\u0011QVAX\u0003\u0003k!!!\u0013\n\t\u0005E\u0016\u0011\n\u0002\u0007)\u0016t7o\u001c:\t\u0011\u0005U\u00161\u000ea\u0001\u0003W\u000b\u0001b\u001a:bI&,g\u000e\u001e\u0005\t\u0003s\u000bY\u00071\u0001\u0002<\u0006Y\u0011M^3sC\u001e,Gj\\:t!\r\u0011\u0012QX\u0005\u0004\u0003\u007f\u001b\"A\u0002#pk\ndW\rC\u0004\u0002D\u0002!\t!!2\u0002\u0013\u0005<wM]3hCR,W\u0003BAd\u0003\u001f$B!!3\u0002XR!\u00111ZAi!\u0019\ti+a,\u0002NB\u0019!(a4\u0005\rq\n\tM1\u0001>\u0011)\t\u0019.!1\u0002\u0002\u0003\u000f\u0011Q[\u0001\fKZLG-\u001a8dK\u0012\nT\u0007\u0005\u00035o\u00055\u0007\u0002CAm\u0003\u0003\u0004\r!a7\u0002\u0013\u001d\u0014\u0018\rZ5f]R\u001c\b\u0003\u0002\nq\u0003\u0017Dq!a8\u0001\t\u000b\t\t/A\u0005m_\u0006$7\u000b^1uKR!\u00111]As\u001b\u0005\u0001\u0001BB\u0014\u0002^\u0002\u0007\u0011\u0006")
/**
 * Training helper that splits a mini-batch across {@link #subModelNumber()} sub-models, runs
 * forward/backward on each sub-model in parallel, sums the sub-models' gradients into the first
 * one, and applies per-segment optimizers to the weights.
 *
 * <p>NOTE(review): this file is CFR output of compiled Scala (see the file header). The
 * anonymous {@code new Serializable(...) { ... }} blocks are the decompiler's rendering of Scala
 * closures. They are not valid Java as written and are kept structurally unchanged here.
 *
 * @param <T> numeric element type of the model parameters
 */
public class TrainingContext<T>
implements Serializable {
    /** Number of sub-models a mini-batch is split across (called "core number" in messages). */
    private final int subModelNumber;
    /** Sample count compared against by {@link #hasCompleteAllSamples}. */
    private final int numSamples;
    /** State table written by {@link #loadState}. */
    private final Table state;

    /** @return the number of sub-models a mini-batch is split across */
    public int subModelNumber() {
        return this.subModelNumber;
    }

    /** @return the sample count used by {@link #hasCompleteAllSamples} */
    public int numSamples() {
        return this.numSamples;
    }

    /** @return this context's state table */
    public Table state() {
        return this.state;
    }

    /**
     * Reports whether at least {@link #numSamples()} records have been processed.
     *
     * @param recordsProcessed number of records processed so far
     * @param model the model being trained; not consulted by this implementation
     * @return {@code true} once {@code recordsProcessed >= numSamples()}
     */
    public boolean hasCompleteAllSamples(int recordsProcessed, AbstractModule<Activity, Activity, T> model) {
        return recordsProcessed >= this.numSamples();
    }

    /**
     * Pulls the next mini-batch from {@code data2} and splits it into {@link #subModelNumber()}
     * equally sized slices, one per sub-model.
     *
     * <p>The batch size is validated through {@code Log4Error.invalidOperationError}. It must be
     * at least {@code subModelNumber()} and divisible by it. A warning is logged when the batch
     * holds fewer than two samples per sub-model.
     *
     * @param data2 iterator of mini-batches; exactly one element is consumed
     * @param evidence$12 class tag for {@code T} (Scala implicit)
     * @return an array of length {@code subModelNumber()}. Element {@code b} is
     *     {@code batch.slice(b * stackSize + 1, stackSize)}; the {@code + 1} suggests
     *     {@code slice} takes a 1-based offset.
     */
    public <T> MiniBatch<T>[] fetchBatch(Iterator<MiniBatch<T>> data2, ClassTag<T> evidence$12) {
        int parts = this.subModelNumber();
        @SuppressWarnings("unchecked")
        MiniBatch<T>[] miniBatchBuffer = (MiniBatch<T>[]) new MiniBatch[parts];
        MiniBatch<T> batch = data2.next();
        int batchSize = batch.size();
        int stackSize = batchSize / parts;
        Log4Error$.MODULE$.invalidOperationError(
            batchSize >= parts && batchSize % parts == 0,
            "total batch size: " + batchSize + " should be divided by total core number: " + parts,
            Log4Error$.MODULE$.invalidOperationError$default$3(),
            Log4Error$.MODULE$.invalidOperationError$default$4());
        if (batchSize < parts * 2) {
            LogManager.getLogger(this.getClass()).warn(
                "Warning: for better training speed, total batch size is recommended to be "
                    + "at least two times of core number " + parts + ". "
                    + "please tune your batch size accordingly");
        }
        for (int b = 0; b < parts; ++b) {
            miniBatchBuffer[b] = batch.slice(b * stackSize + 1, stackSize);
        }
        // The decompiled body returned an uninitialised `void` local here; the filled buffer
        // is the intended result.
        return miniBatchBuffer;
    }

    /**
     * Runs one forward/backward pass per sub-model in parallel on the engine's default pool.
     *
     * <p>Each task {@code i} (one per index of {@code models}) does the following:
     * <ul>
     *   <li>switches {@code models[i]} to training mode and forwards
     *       {@code data2[i].getInput()};</li>
     *   <li>evaluates {@code criterion[i]} against {@code data2[i].getTarget()};</li>
     *   <li>back-propagates the criterion's gradient through the model.</li>
     * </ul>
     * {@code data2} and {@code criterion} are indexed with the same {@code i}. They are therefore
     * assumed to be at least as long as {@code models}.
     *
     * @param data2 one mini-batch slice per sub-model, e.g. from {@link #fetchBatch}
     * @param models the sub-models
     * @param criterion one criterion per sub-model
     * @param evidence$13 class tag for {@code T} (Scala implicit)
     * @param ev numeric operations used to convert each loss to {@code double}
     * @return loss and elapsed {@code System.nanoTime()} delta for every task whose future was
     *     not cancelled
     */
    public <T> Seq<LossWithElapsedTime> train(MiniBatch<T>[] data2, AbstractModule<Activity, Activity, T>[] models, AbstractCriterion<Activity, Activity, T>[] criterion, ClassTag<T> evidence$13, TensorNumericMath.TensorNumeric<T> ev) {
        ThreadPool pool = Engine$.MODULE$.default();
        IndexedSeq tasks = (IndexedSeq)Predef$.MODULE$.refArrayOps((Object[])models).indices().map((Function1)new Serializable(this, (MiniBatch[])data2, (AbstractModule[])models, (AbstractCriterion[])criterion, ev){
            public static final long serialVersionUID = 0L;
            public final MiniBatch[] data$2;
            public final AbstractModule[] models$2;
            public final AbstractCriterion[] criterion$1;
            public final TensorNumericMath.TensorNumeric ev$6;

            public final Function0<LossWithElapsedTime> apply(int i) {
                return new Serializable(this, i){
                    public static final long serialVersionUID = 0L;
                    private final /* synthetic */ $anonfun$20 $outer;
                    private final int i$2;

                    public final LossWithElapsedTime apply() {
                        long begin = System.nanoTime();
                        AbstractModule localModel = this.$outer.models$2[this.i$2];
                        AbstractCriterion localCriterion = this.$outer.criterion$1[this.i$2];
                        MiniBatch slice = this.$outer.data$2[this.i$2];
                        Activity input = slice.getInput();
                        Activity target = slice.getTarget();
                        localModel.training();
                        Activity output = (Activity)localModel.forward(input);
                        double loss2 = BoxesRunTime.unboxToDouble((Object)this.$outer.ev$6.toType(localCriterion.forward(output, target), ConvertableTo$ConvertableToDouble$.MODULE$));
                        Activity errors = (Activity)localCriterion.backward(output, target);
                        localModel.backward(input, errors);
                        long end = System.nanoTime();
                        return new LossWithElapsedTime(this.i$2, loss2, end - begin);
                    }
                    {
                        if ($outer == null) {
                            throw null;
                        }
                        this.$outer = $outer;
                        this.i$2 = i$2;
                    }
                };
            }
            {
                this.data$2 = data$2;
                this.models$2 = models$2;
                this.criterion$1 = criterion$1;
                this.ev$6 = ev$6;
            }
        }, IndexedSeq$.MODULE$.canBuildFrom());
        // Effectively unbounded wait for all training tasks.
        Buffer trainingThreads = pool.invokeAndWait2(tasks, Long.MAX_VALUE, pool.invokeAndWait2$default$3());
        return (Seq)((TraversableLike)trainingThreads.filter((Function1)new Serializable(this){
            public static final long serialVersionUID = 0L;

            public final boolean apply(Future<LossWithElapsedTime> x$23) {
                return !x$23.isCancelled();
            }
        })).map((Function1)new Serializable(this){
            public static final long serialVersionUID = 0L;

            public final LossWithElapsedTime apply(Future<LossWithElapsedTime> x$24) {
                return x$24.get();
            }
        }, Buffer$.MODULE$.canBuildFrom());
    }

    /**
     * Applies each segment's optimizer to its slice of the weights.
     *
     * <p>For every {@code (name, segment)} entry with {@code segment.length() > 0}, the method
     * calls {@code segment.method().optimize(feval, weight.narrow(1, start, length))}. The
     * {@code feval} closure returns {@code averageLoss} converted to {@code T} together with
     * {@code gradient.narrow(1, start, length)}. Segments of length 0 are skipped.
     *
     * @param optimSegments optimizer and parameter range per segment name
     * @param weight flattened weights to update
     * @param gradient flattened gradients matching {@code weight}
     * @param averageLoss loss value handed to every optimizer
     * @param evidence$14 class tag for {@code T} (Scala implicit)
     * @param ev numeric operations used to convert {@code averageLoss} to {@code T}
     * @throws MatchError if an entry or its segment is {@code null}
     */
    public <T> void update(Map<String, ParamSegments<T>> optimSegments, Tensor<T> weight, Tensor<T> gradient, double averageLoss, ClassTag<T> evidence$14, TensorNumericMath.TensorNumeric<T> ev) {
        optimSegments.foreach((Function1)new Serializable(this, weight, gradient, averageLoss, ev){
            public static final long serialVersionUID = 0L;
            private final Tensor weight$1;
            public final Tensor gradient$1;
            public final double averageLoss$1;
            public final TensorNumericMath.TensorNumeric ev$7;

            public final Object apply(Tuple2<String, ParamSegments<T>> x0$5) {
                ParamSegments paramSegments = x0$5 == null ? null : (ParamSegments)x0$5._2();
                if (paramSegments == null) {
                    throw new MatchError(x0$5);
                }
                int start2 = paramSegments.start();
                int length = paramSegments.length();
                OptimMethod<T> method = paramSegments.method();
                if (length <= 0) {
                    // Scala `if` without `else` evaluates to Unit.
                    return BoxedUnit.UNIT;
                }
                return method.optimize(new Serializable(this, start2, length){
                    public static final long serialVersionUID = 0L;
                    private final /* synthetic */ $anonfun$update$1 $outer;
                    private final int start$1;
                    private final int length$1;

                    public final Tuple2<T, Tensor<T>> apply(Tensor<T> x$25) {
                        return new Tuple2(this.$outer.ev$7.fromType(BoxesRunTime.boxToDouble((double)this.$outer.averageLoss$1), ConvertableFrom$ConvertableFromDouble$.MODULE$), this.$outer.gradient$1.narrow(1, this.start$1, this.length$1));
                    }
                    {
                        if ($outer == null) {
                            throw null;
                        }
                        this.$outer = $outer;
                        this.start$1 = start$1;
                        this.length$1 = length$1;
                    }
                }, this.weight$1.narrow(1, start2, length));
            }
            {
                this.weight$1 = weight$1;
                this.gradient$1 = gradient$1;
                this.averageLoss$1 = averageLoss$1;
                this.ev$7 = ev$7;
            }
        });
    }

    /**
     * Sums {@code gradients[1..]} element-wise into {@code gradients[0]}, in parallel.
     *
     * <p>The element range of the first gradient is split into contiguous chunks. There are
     * {@code subModelNumber()} chunks, or one per element when the tensor has fewer elements
     * than that. The first {@code extraTask} chunks each take one extra element.
     *
     * <p>NOTE(review): chunk offsets start at {@code gradients[0].storageOffset()} but are passed
     * to {@code narrow}. This presumes the gradients are non-view tensors whose storage offset
     * equals their first {@code narrow} index. Confirm against the {@code Tensor.narrow}
     * contract.
     *
     * @param gradients one flattened gradient per sub-model, all of the same size; must be
     *     non-empty
     * @param evidence$15 class tag for {@code T} (Scala implicit)
     * @return {@code gradients[0]}, now holding the sum
     */
    public <T> Tensor<T> aggregate(Tensor<T>[] gradients, ClassTag<T> evidence$15) {
        Tensor first = (Tensor)Predef$.MODULE$.refArrayOps((Object[])gradients).head();
        int start2 = first.storageOffset();
        int length = first.nElement();
        int taskSize = length / this.subModelNumber();
        int extraTask = length % this.subModelNumber();
        int parallelNum = taskSize == 0 ? extraTask : this.subModelNumber();
        // Skip only when there is nothing to add. A `parallelNum != 1` guard would also skip a
        // single-element gradient with several sub-models (taskSize == 0, extraTask == 1),
        // leaving gradients[0] un-aggregated.
        if (gradients.length > 1 && parallelNum > 0) {
            ThreadPool pool = Engine$.MODULE$.default();
            IndexedSeq tasks = (IndexedSeq)RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), parallelNum).map((Function1)new Serializable(this, (Tensor[])gradients, start2, taskSize, extraTask){
                public static final long serialVersionUID = 0L;
                public final Tensor[] gradients$3;
                public final int start$2;
                public final int taskSize$2;
                public final int extraTask$1;

                public final Function0<BoxedUnit> apply(int tid) {
                    return new Serializable(this, tid){
                        public static final long serialVersionUID = 0L;
                        private final /* synthetic */ $anonfun$21 $outer;
                        private final int tid$1;

                        public final void apply() {
                            this.apply$mcV$sp();
                        }

                        public void apply$mcV$sp() {
                            // Chunk `tid`: the first `extraTask` chunks are one element longer.
                            int offset = this.$outer.start$2 + this.tid$1 * this.$outer.taskSize$2 + package$.MODULE$.min(this.tid$1, this.$outer.extraTask$1);
                            int length = this.$outer.taskSize$2 + (this.tid$1 < this.$outer.extraTask$1 ? 1 : 0);
                            for (int i = 1; i < this.$outer.gradients$3.length; ++i) {
                                Tensor<T> target = this.$outer.gradients$3[0].narrow(1, offset, length);
                                Tensor<T> source = this.$outer.gradients$3[i].narrow(1, offset, length);
                                target.add(source);
                            }
                        }
                        {
                            if ($outer == null) {
                                throw null;
                            }
                            this.$outer = $outer;
                            this.tid$1 = tid$1;
                        }
                    };
                }
                {
                    this.gradients$3 = gradients$3;
                    this.start$2 = start$2;
                    this.taskSize$2 = taskSize$2;
                    this.extraTask$1 = extraTask$1;
                }
            }, IndexedSeq$.MODULE$.canBuildFrom());
            Duration timeout = pool.invokeAndWait$default$2();
            pool.invokeAndWait(tasks, timeout);
        }
        return gradients[0];
    }

    /**
     * Restores training progress from {@code state} into this context's state table.
     *
     * <p>Copies the EPOCH, NEVAL, LOSS, SCORE and RECORDS_PROCESSED entries. PARALLELISM is not
     * copied; it is set to {@link #subModelNumber()}.
     *
     * @param state table to read the entries from
     * @return this context
     */
    public final TrainingContext<T> loadState(Table state) {
        this.state().update(StateEntry$.MODULE$.EPOCH(), state.apply(StateEntry$.MODULE$.EPOCH()));
        this.state().update(StateEntry$.MODULE$.NEVAL(), state.apply(StateEntry$.MODULE$.NEVAL()));
        this.state().update(StateEntry$.MODULE$.LOSS(), state.apply(StateEntry$.MODULE$.LOSS()));
        this.state().update(StateEntry$.MODULE$.SCORE(), state.apply(StateEntry$.MODULE$.SCORE()));
        this.state().update(StateEntry$.MODULE$.PARALLELISM(), BoxesRunTime.boxToInteger((int)this.subModelNumber()));
        this.state().update(StateEntry$.MODULE$.RECORDS_PROCESSED(), state.apply(StateEntry$.MODULE$.RECORDS_PROCESSED()));
        return this;
    }

    /**
     * Creates a training context.
     *
     * @param subModelNumber number of sub-models each mini-batch is split across
     * @param numSamples sample count used by {@link #hasCompleteAllSamples}
     * @param state state table this context reads and writes
     * @param evidence$11 class tag for {@code T} (Scala implicit; unused here)
     */
    public TrainingContext(int subModelNumber, int numSamples, Table state, ClassTag<T> evidence$11) {
        this.subModelNumber = subModelNumber;
        this.numSamples = numSamples;
        this.state = state;
    }
}

