/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.keras.models;

import com.intel.analytics.bigdl.dllib.feature.AbstractFeatureSet;
import com.intel.analytics.bigdl.dllib.feature.dataset.MiniBatch;
import com.intel.analytics.bigdl.dllib.keras.models.InternalDistriOptimizer$$anonfun$18$;
import com.intel.analytics.bigdl.dllib.keras.models.InternalOptimizerUtil$;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer$;
import com.intel.analytics.bigdl.dllib.optim.ValidationMethod;
import com.intel.analytics.bigdl.dllib.optim.ValidationResult;
import com.intel.analytics.bigdl.dllib.optim.parameters.AllReduceParameter;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.visualization.Summary;
import com.intel.analytics.bigdl.dllib.visualization.ValidationSummary;
import java.io.File;
import java.io.FilenameFilter;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.TaskContext$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.ZippedPartitionsWithLocalityRDD$;
import scala.Array$;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.generic.CanBuildFrom;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Map;
import scala.collection.immutable.MapLike;
import scala.collection.parallel.ParIterableLike;
import scala.collection.parallel.mutable.ParArray$;
import scala.collection.parallel.package$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.LongRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;

/**
 * Singleton module holding helpers for distributed optimization in the Keras-style model API.
 *
 * <p>It provides:
 * <ul>
 *   <li>distributed validation of cached model replicas;</li>
 *   <li>lookup of the most recently modified file with a given name prefix;</li>
 *   <li>release of cached models;</li>
 *   <li>collection of the distributed training state back into a driver-side model.</li>
 * </ul>
 *
 * <p>NOTE(review): this is CFR-decompiled output of what appears to be a Scala {@code object}.
 * The {@code new Serializable(...){...}} anonymous classes that take constructor arguments are
 * compiler-generated closures, not valid hand-written Java.
 */
public final class InternalDistriOptimizer$ {
    /** The singleton instance; assigned by the private constructor during static initialization. */
    public static final InternalDistriOptimizer$ MODULE$;
    /** Logger for this class. {@link #validate} logs through DistriOptimizer's logger instead. */
    private final Logger logger;

    static {
        new InternalDistriOptimizer$();
    }

    /** Returns the logger created for this class. */
    public Logger logger() {
        return this.logger;
    }

    /**
     * Evaluates the cached local models against a distributed validation set and aggregates the
     * result of every validation method across all partitions.
     *
     * <p>How the work is distributed:
     * <ul>
     *   <li>Each partition of {@code models} is zipped with the matching partition of the
     *       validation data.</li>
     *   <li>Within a partition, the local models of the first cache entry are switched to
     *       evaluation mode.</li>
     *   <li>Every mini-batch is split across those models and processed via a parallel
     *       collection.</li>
     *   <li>Per-method results are combined with {@code ValidationResult.+}: first across the
     *       sub-batches of one batch, then across all batches and partitions.</li>
     * </ul>
     *
     * <p>Each aggregated result is logged. When a summary is present and {@code step > 0}, it is
     * also recorded as a scalar at iteration {@code step - 1}.
     *
     * @param validationFeatureSet validation data, read via {@code toDistributed().data(false)}
     * @param validationMethods    methods used to label and key the aggregated results. The
     *                             computation itself uses the per-model methods cached in
     *                             {@code models}, and the two are paired positionally.
     *                             NOTE(review): assumes both are in the same order; confirm.
     * @param models               RDD of cached model replicas. Only the first cache entry of
     *                             each partition is read.
     * @param step                 current training step; gates and offsets summary writing
     * @param validationSummary    optional summary that receives one scalar per method
     * @return map from each validation method to its aggregated result
     */
    public <T> Map<ValidationMethod<T>, ValidationResult> validate(AbstractFeatureSet<MiniBatch<T>, ?> validationFeatureSet, ValidationMethod<T>[] validationMethods, RDD<DistriOptimizer.CacheV1<T>> models, int step, Option<ValidationSummary> validationSummary) {
        ValidationMethod<T>[] vMethods = validationMethods;
        RDD validateRDD = (RDD)validationFeatureSet.toDistributed().data(false);
        // Zip models with data partition-by-partition, evaluate each batch locally, then reduce
        // all per-batch result arrays into one array and pair it with vMethods.
        Tuple2[] results = (Tuple2[])Predef$.MODULE$.refArrayOps((Object[])ZippedPartitionsWithLocalityRDD$.MODULE$.apply(models, validateRDD, ZippedPartitionsWithLocalityRDD$.MODULE$.apply$default$3(), new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Iterator<ValidationResult[]> apply(Iterator<DistriOptimizer.CacheV1<T>> modelIter, Iterator<MiniBatch<T>> dataIter) {
                // Only the first cache entry of the partition is used.
                DistriOptimizer.CacheV1 cached = (DistriOptimizer.CacheV1)modelIter.next();
                AbstractModule[] workingModels = cached.localModels();
                Option<ValidationMethod<T>[]>[] localVMethods = cached.localMethods();
                // Put every local replica into evaluation mode before running inference.
                Predef$.MODULE$.refArrayOps((Object[])workingModels).foreach((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final AbstractModule<Activity, Activity, T> apply(AbstractModule<Activity, Activity, T> x$12) {
                        return x$12.evaluate();
                    }
                });
                int _subModelNumber = workingModels.length;
                return dataIter.map((Function1)new Serializable(this, workingModels, localVMethods, _subModelNumber){
                    public static final long serialVersionUID = 0L;
                    public final AbstractModule[] workingModels$1;
                    public final Option[] localVMethods$1;
                    private final int _subModelNumber$2;

                    public final ValidationResult[] apply(MiniBatch<T> batch) {
                        // Split the batch as evenly as possible over the local models. The first
                        // extraSize models take one extra sample. If the batch is smaller than
                        // the number of models (stackSize == 0), only extraSize models are used.
                        int stackSize = batch.size() / this._subModelNumber$2;
                        int extraSize = batch.size() % this._subModelNumber$2;
                        int parallelism = stackSize == 0 ? extraSize : this._subModelNumber$2;
                        return (ValidationResult[])((ParIterableLike)package$.MODULE$.CollectionsHaveToParArray((Object)RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), parallelism), (Function1)Predef$.MODULE$.$conforms()).toParArray().map((Function1)new Serializable(this, stackSize, extraSize, batch){
                            public static final long serialVersionUID = 0L;
                            private final /* synthetic */ anonfun$18$$anonfun$apply$7 $outer;
                            private final int stackSize$1;
                            private final int extraSize$1;
                            private final MiniBatch batch$1;

                            public final ValidationResult[] apply(int b) {
                                // The trailing +1 suggests MiniBatch.slice takes a 1-based
                                // offset (TODO confirm against MiniBatch).
                                int offset = b * this.stackSize$1 + scala.math.package$.MODULE$.min(b, this.extraSize$1) + 1;
                                int length = this.stackSize$1 + (b < this.extraSize$1 ? 1 : 0);
                                MiniBatch<T> miniBatch = this.batch$1.slice(offset, length);
                                Activity input = miniBatch.getInput();
                                Activity target = miniBatch.getTarget();
                                B output = this.$outer.workingModels$1[b].forward(input);
                                // Option.get(): throws if this replica has no cached methods.
                                ValidationMethod[] validatMethods = (ValidationMethod[])this.$outer.localVMethods$1[b].get();
                                return (ValidationResult[])Predef$.MODULE$.refArrayOps((Object[])validatMethods).map((Function1)new Serializable(this, target, (Activity)output){
                                    public static final long serialVersionUID = 0L;
                                    private final Activity target$1;
                                    private final Activity output$1;

                                    public final ValidationResult apply(ValidationMethod<T> validation) {
                                        return validation.apply(this.output$1, this.target$1);
                                    }
                                    {
                                        this.target$1 = target$1;
                                        this.output$1 = output$1;
                                    }
                                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ValidationResult.class)));
                            }
                            {
                                if ($outer == null) {
                                    throw null;
                                }
                                this.$outer = $outer;
                                this.stackSize$1 = stackSize$1;
                                this.extraSize$1 = extraSize$1;
                                this.batch$1 = batch$1;
                            }
                        }, (CanBuildFrom)ParArray$.MODULE$.canBuildFrom())).reduce((Function2)new Serializable(this){
                            public static final long serialVersionUID = 0L;

                            // Element-wise sum of two per-sub-model result arrays.
                            public final ValidationResult[] apply(ValidationResult[] left, ValidationResult[] right) {
                                return (ValidationResult[])Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])left).zip((GenIterable)Predef$.MODULE$.wrapRefArray((Object[])right), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).map((Function1)new Serializable(this){
                                    public static final long serialVersionUID = 0L;

                                    public final ValidationResult apply(Tuple2<ValidationResult, ValidationResult> x0$4) {
                                        Tuple2<ValidationResult, ValidationResult> tuple2 = x0$4;
                                        if (tuple2 != null) {
                                            ValidationResult l = (ValidationResult)tuple2._1();
                                            ValidationResult r = (ValidationResult)tuple2._2();
                                            ValidationResult validationResult = l.$plus(r);
                                            return validationResult;
                                        }
                                        throw new MatchError(tuple2);
                                    }
                                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ValidationResult.class)));
                            }
                        });
                    }
                    {
                        this.workingModels$1 = workingModels$1;
                        this.localVMethods$1 = localVMethods$1;
                        this._subModelNumber$2 = _subModelNumber$2;
                    }
                });
            }
        }, ClassTag$.MODULE$.apply(DistriOptimizer.CacheV1.class), ClassTag$.MODULE$.apply(MiniBatch.class), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(ValidationResult.class))).reduce((Function2)new Serializable(){
            public static final long serialVersionUID = 0L;

            // Element-wise sum of the per-batch result arrays coming from all partitions.
            public final ValidationResult[] apply(ValidationResult[] left, ValidationResult[] right) {
                return (ValidationResult[])Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])left).zip((GenIterable)Predef$.MODULE$.wrapRefArray((Object[])right), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final ValidationResult apply(Tuple2<ValidationResult, ValidationResult> x0$5) {
                        Tuple2<ValidationResult, ValidationResult> tuple2 = x0$5;
                        if (tuple2 != null) {
                            ValidationResult l = (ValidationResult)tuple2._1();
                            ValidationResult r = (ValidationResult)tuple2._2();
                            ValidationResult validationResult = l.$plus(r);
                            return validationResult;
                        }
                        throw new MatchError(tuple2);
                    }
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ValidationResult.class)));
            }
        })).zip((GenIterable)Predef$.MODULE$.wrapRefArray((Object[])vMethods), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));
        // Log "<method> is <result>" for every aggregated result.
        Predef$.MODULE$.refArrayOps((Object[])results).foreach((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final void apply(Tuple2<ValidationResult, ValidationMethod<T>> r) {
                DistriOptimizer$.MODULE$.logger().info(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"", " is ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{r._2(), r._1()})));
            }
        });
        if (validationSummary.isDefined() && step > 0) {
            // Record the first component of each result() tuple (presumably the metric value;
            // TODO confirm) under the method's toString(), at iteration step - 1.
            Predef$.MODULE$.refArrayOps((Object[])results).foreach((Function1)new Serializable(step, validationSummary){
                public static final long serialVersionUID = 0L;
                private final int step$1;
                private final Option validationSummary$1;

                public final ValidationSummary apply(Tuple2<ValidationResult, ValidationMethod<T>> r) {
                    Tuple2<Object, Object> result2 = ((ValidationResult)r._1()).result();
                    return (ValidationSummary)((Summary)this.validationSummary$1.get()).addScalar(((ValidationMethod)r._2()).toString(), BoxesRunTime.unboxToFloat((Object)result2._1()), this.step$1 - 1);
                }
                {
                    this.step$1 = step$1;
                    this.validationSummary$1 = validationSummary$1;
                }
            });
        }
        // Flip (result, method) pairs into a method -> result map.
        return Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])results).map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Tuple2<ValidationMethod<T>, ValidationResult> apply(Tuple2<ValidationResult, ValidationMethod<T>> a) {
                return new Tuple2(a._2(), a._1());
            }
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).toMap(Predef$.MODULE$.$conforms());
    }

    /**
     * Returns the path of the most recently modified entry directly inside {@code path} whose
     * name starts with {@code fileName}.
     *
     * <p>Ties on modification time go to the entry that {@link File#listFiles(FilenameFilter)}
     * reports first; that order is unspecified.
     *
     * @param path     directory to search (not recursive)
     * @param fileName required name prefix
     * @return the matching entry's path, or {@code null} if nothing matches or if {@code path}
     *         cannot be listed (it is not a directory, or an I/O error occurred)
     */
    public String getLatestFile(String path, String fileName) {
        final String prefix = fileName;
        File[] files = new File(path).listFiles(new FilenameFilter(){
            @Override
            public boolean accept(File dir, String name) {
                return name.startsWith(prefix);
            }
        });
        // listFiles returns null rather than an empty array when the directory cannot be
        // listed. Treat that like "no match" instead of throwing NullPointerException.
        if (files == null) {
            return null;
        }
        long lastMod = Long.MIN_VALUE;
        String choice = null;
        for (File file : files) {
            // Read the timestamp once, so the comparison and the stored value agree.
            long modified = file.lastModified();
            if (modified > lastMod) {
                choice = file.getPath();
                lastMod = modified;
            }
        }
        return choice;
    }

    /**
     * Calls {@code release()} on every local model in every cache entry of {@code models}, then
     * unpersists the RDD.
     *
     * <p>The trailing {@code count()} only forces the release job to run. Each partition's
     * iterator has already been drained by {@code foreach}, so the count itself is meaningless.
     *
     * @param models      RDD of cached model replicas to release
     * @param evidence$15 class tag for the numeric type (not used in this body)
     */
    public <T> void unpersistCachedModel(RDD<DistriOptimizer.CacheV1<T>> models, ClassTag<T> evidence$15) {
        models.mapPartitions((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Iterator<DistriOptimizer.CacheV1<T>> apply(Iterator<DistriOptimizer.CacheV1<T>> iter) {
                iter.foreach((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final void apply(DistriOptimizer.CacheV1<T> arrayModels) {
                        Predef$.MODULE$.refArrayOps((Object[])arrayModels.localModels()).foreach((Function1)new Serializable(this){
                            public static final long serialVersionUID = 0L;

                            public final void apply(AbstractModule<Activity, Activity, T> x$13) {
                                x$13.release();
                            }
                        });
                    }
                });
                return iter;
            }
        }, models.mapPartitions$default$2(), ClassTag$.MODULE$.apply(DistriOptimizer.CacheV1.class)).count();
        models.unpersist(models.unpersist$default$1());
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    /**
     * Copies the distributed training state back into {@code trainingModel} and returns it.
     *
     * <p>When {@code trainingModel.isTensorFlow()} is true, this method:
     * <ol>
     *   <li>calls {@code beforeGetModel()} on the first local model of every partition;</li>
     *   <li>copies the extra parameters of the first local model of the RDD's first cache
     *       entry into {@code trainingModel} via {@code setExtraParameter};</li>
     *   <li>resizes each gradient tensor of {@code trainingModel} to its weight tensor;</li>
     *   <li>collects, from each partition, a copy of its local weight slice and its
     *       {@code gradientPartition()}, keyed by Spark partition id;</li>
     *   <li>copies each slice into the flattened weight/gradient tensors at
     *       {@code paramOffset + pid * taskSize + min(pid, extraSize)}.</li>
     * </ol>
     * Otherwise the work is delegated to {@code InternalOptimizerUtil.getModel}.
     *
     * <p>NOTE(review): step 5 assumes Spark partition {@code pid} owns the {@code pid}-th slice
     * of {@code parameters2}; confirm against AllReduceParameter.
     *
     * @param models        RDD of cached model replicas. Each partition is read via one
     *                      {@code iter.next()}.
     * @param parameters2   all-reduce parameter describing how the flat weights are partitioned
     * @param trainingModel driver-side model to populate
     * @param evidence$16   class tag for the numeric type
     * @param ev            tensor numeric for the numeric type
     * @return {@code trainingModel}, updated in place
     */
    public <T> AbstractModule<Activity, Activity, T> getModel(RDD<DistriOptimizer.CacheV1<T>> models, AllReduceParameter<T> parameters2, AbstractModule<Activity, Activity, T> trainingModel, ClassTag<T> evidence$16, TensorNumericMath.TensorNumeric<T> ev) {
        // `object` only receives discarded results (likely a decompiler artifact).
        Object object;
        if (trainingModel.isTensorFlow()) {
            Tuple2 tuple2;
            Tuple2 tuple22;
            int partitionNum = models.partitions().length;
            // Run beforeGetModel() on each partition's first local model; the reduce only
            // forces the job to execute.
            models.mapPartitions((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final Iterator<Object> apply(Iterator<DistriOptimizer.CacheV1<T>> iter) {
                    ((AbstractModule)Predef$.MODULE$.refArrayOps((Object[])((DistriOptimizer.CacheV1)iter.next()).localModels()).head()).beforeGetModel();
                    return scala.package$.MODULE$.Iterator().single((Object)BoxesRunTime.boxToInteger((int)1));
                }
            }, models.mapPartitions$default$2(), ClassTag$.MODULE$.Int()).reduce((Function2)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final int apply(int x$14, int x$15) {
                    return this.apply$mcIII$sp(x$14, x$15);
                }

                public int apply$mcIII$sp(int x$14, int x$15) {
                    return x$14 + x$15;
                }
            });
            // Fetch extra parameters one by one from the RDD's first element.
            int extraParamLength = BoxesRunTime.unboxToInt((Object)models.map((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final int apply(DistriOptimizer.CacheV1<T> x$16) {
                    return ((AbstractModule)Predef$.MODULE$.refArrayOps((Object[])x$16.localModels()).head()).getExtraParameter().length;
                }
            }, ClassTag$.MODULE$.Int()).first());
            Tensor[] extraState = new Tensor[extraParamLength];
            RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), extraParamLength).foreach$mVc$sp((Function1)new Serializable(models, extraState){
                public static final long serialVersionUID = 0L;
                private final RDD models$1;
                private final Tensor[] extraState$1;

                public final void apply(int i) {
                    this.apply$mcVI$sp(i);
                }

                public void apply$mcVI$sp(int i) {
                    this.extraState$1[i] = (Tensor)this.models$1.map((Function1)new Serializable(this, i){
                        public static final long serialVersionUID = 0L;
                        private final int i$1;

                        public final Tensor<T> apply(DistriOptimizer.CacheV1<T> x$17) {
                            return ((AbstractModule)Predef$.MODULE$.refArrayOps((Object[])x$17.localModels()).head()).getExtraParameter()[this.i$1];
                        }
                        {
                            this.i$1 = i$1;
                        }
                    }, ClassTag$.MODULE$.apply(Tensor.class)).first();
                }
                {
                    this.models$1 = models$1;
                    this.extraState$1 = extraState$1;
                }
            });
            trainingModel.setExtraParameter(extraState);
            // Make each gradient tensor the same size as its weight tensor before copying.
            Tuple2<Tensor<T>[], Tensor<T>[]> parameterArray = trainingModel.parameters();
            RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), ((Tensor[])parameterArray._2()).length).foreach((Function1)new Serializable(parameterArray){
                public static final long serialVersionUID = 0L;
                private final Tuple2 parameterArray$1;

                public final Tensor<T> apply(int i) {
                    return ((Tensor[])this.parameterArray$1._2())[i].resizeAs(((Tensor[])this.parameterArray$1._1())[i]);
                }
                {
                    this.parameterArray$1 = parameterArray$1;
                }
            });
            Tuple2<Tensor<T>, Tensor<T>> tuple23 = InternalOptimizerUtil$.MODULE$.getParametersFromModel(trainingModel, evidence$16);
            if (tuple23 == null) throw new MatchError(tuple23);
            Tensor parameter = (Tensor)tuple23._1();
            Tensor gradientParameter = (Tensor)tuple23._2();
            Tuple2 tuple24 = tuple22 = new Tuple2((Object)parameter, (Object)gradientParameter);
            Tensor parameter2 = (Tensor)tuple24._1();
            Tensor gradientParameter2 = (Tensor)tuple24._2();
            // Collect (partitionId -> weight slice, partitionId -> gradient slice) maps from
            // all partitions and merge them on the driver.
            Tuple2 tuple25 = (Tuple2)models.mapPartitions((Function1)new Serializable(parameters2, evidence$16, ev){
                public static final long serialVersionUID = 0L;
                private final AllReduceParameter parameters$1;
                private final ClassTag evidence$16$1;
                private final TensorNumericMath.TensorNumeric ev$1;

                public final Iterator<Tuple2<Map<Object, Tensor<T>>, Map<Object, Tensor<T>>>> apply(Iterator<DistriOptimizer.CacheV1<T>> iter) {
                    DistriOptimizer.CacheV1 cached = (DistriOptimizer.CacheV1)iter.next();
                    int curPartitionId = TaskContext$.MODULE$.getPartitionId();
                    Tuple2<Object, Object> tuple2 = InternalOptimizerUtil$.MODULE$.getLocalPartitionRangeFromParameters(this.parameters$1, this.evidence$16$1);
                    if (tuple2 != null) {
                        Tuple2.mcII.sp sp2;
                        int offset = tuple2._1$mcI$sp();
                        int size = tuple2._2$mcI$sp();
                        Tuple2.mcII.sp sp3 = sp2 = new Tuple2.mcII.sp(offset, size);
                        int offset2 = sp3._1$mcI$sp();
                        int size2 = sp3._2$mcI$sp();
                        // Copy (not view) this partition's slice of the first weight tensor.
                        Tensor<T> weightTensor = Tensor$.MODULE$.apply(size2, this.evidence$16$1, this.ev$1);
                        weightTensor.copy(((Tensor)Predef$.MODULE$.refArrayOps((Object[])cached.modelWeights()).head()).narrow(1, offset2, size2));
                        return scala.package$.MODULE$.Iterator().single((Object)new Tuple2((Object)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)BoxesRunTime.boxToInteger((int)curPartitionId)), weightTensor)})), (Object)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)BoxesRunTime.boxToInteger((int)curPartitionId)), this.parameters$1.gradientPartition())}))));
                    }
                    throw new MatchError(tuple2);
                }
                {
                    this.parameters$1 = parameters$1;
                    this.evidence$16$1 = evidence$16$1;
                    this.ev$1 = ev$1;
                }
            }, models.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).reduce((Function2)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final Tuple2<Map<Object, Tensor<T>>, Map<Object, Tensor<T>>> apply(Tuple2<Map<Object, Tensor<T>>, Map<Object, Tensor<T>>> a, Tuple2<Map<Object, Tensor<T>>, Map<Object, Tensor<T>>> b) {
                    return new Tuple2((Object)((MapLike)a._1()).$plus$plus((GenTraversableOnce)b._1()), (Object)((MapLike)a._2()).$plus$plus((GenTraversableOnce)b._2()));
                }
            });
            if (tuple25 == null) throw new MatchError((Object)tuple25);
            Map weights = (Map)tuple25._1();
            Map gradients = (Map)tuple25._2();
            Tuple2 tuple26 = tuple2 = new Tuple2((Object)weights, (Object)gradients);
            Map weights2 = (Map)tuple26._1();
            Map gradients2 = (Map)tuple26._2();
            // Same even split as the partitioning: the first extraSize partitions own one
            // extra element. Fails if there are fewer parameters than partitions.
            int taskSize = parameters2.size() / partitionNum;
            Log4Error$.MODULE$.invalidOperationError(taskSize != 0, "parameter length should not less than partition number", Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
            int extraSize = parameters2.size() % partitionNum;
            // Write every partition's slices into the flattened weight/gradient tensors.
            object = RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), partitionNum).map((Function1)new Serializable(parameters2, parameter2, gradientParameter2, weights2, gradients2, taskSize, extraSize){
                public static final long serialVersionUID = 0L;
                private final AllReduceParameter parameters$1;
                private final Tensor parameter$1;
                private final Tensor gradientParameter$1;
                private final Map weights$1;
                private final Map gradients$1;
                private final int taskSize$1;
                private final int extraSize$2;

                public final Tensor<T> apply(int pid) {
                    int start2 = this.parameters$1.paramOffset() + pid * this.taskSize$1 + scala.math.package$.MODULE$.min(pid, this.extraSize$2);
                    int length = this.taskSize$1 + (pid < this.extraSize$2 ? 1 : 0);
                    this.parameter$1.narrow(1, start2, length).copy((Tensor)this.weights$1.apply((Object)BoxesRunTime.boxToInteger((int)pid)));
                    return this.gradientParameter$1.narrow(1, start2, length).copy((Tensor)this.gradients$1.apply((Object)BoxesRunTime.boxToInteger((int)pid)));
                }
                {
                    this.parameters$1 = parameters$1;
                    this.parameter$1 = parameter$1;
                    this.gradientParameter$1 = gradientParameter$1;
                    this.weights$1 = weights$1;
                    this.gradients$1 = gradients$1;
                    this.taskSize$1 = taskSize$1;
                    this.extraSize$2 = extraSize$2;
                }
            }, IndexedSeq$.MODULE$.canBuildFrom());
            return trainingModel;
        } else {
            // NOTE(review): decompiled form of a call to InternalOptimizerUtil.getModel with
            // (models, parameters, trainingModel); verify against the Scala source.
            InternalOptimizerUtil$.MODULE$.getModel((Seq<Object>)Predef$.MODULE$.wrapRefArray(new Object[]{models, parameters2, trainingModel}), evidence$16, ev);
            object = BoxedUnit.UNIT;
        }
        return trainingModel;
    }

    /** Registers this instance as {@link #MODULE$} and creates the class logger. */
    private InternalDistriOptimizer$() {
        MODULE$ = this;
        this.logger = LogManager.getLogger(this.getClass());
    }
}

