/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.optim.parameters;

import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.optim.DistriOptimizer;
import com.intel.analytics.bigdl.dllib.optim.Metrics;
import com.intel.analytics.bigdl.dllib.optim.parameters.AllReduceParameter;
import com.intel.analytics.bigdl.dllib.optim.parameters.ParameterProcessor;
import com.intel.analytics.bigdl.dllib.optim.parameters.ParameterProcessor$class;
import com.intel.analytics.bigdl.dllib.optim.parameters.Util$;
import com.intel.analytics.bigdl.dllib.tensor.ConvertableFrom$ConvertableFromDouble$;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Table;
import org.apache.spark.rdd.RDD;
import scala.Function1;
import scala.Function2;
import scala.Serializable;
import scala.collection.Iterator;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u0001\u0005mb!B\u0001\u0003\u0001!\u0001\"a\u0006'3\u001d>\u0014Xn\u00117jaBLgn\u001a)s_\u000e,7o]8s\u0015\t\u0019A!\u0001\u0006qCJ\fW.\u001a;feNT!!\u0002\u0004\u0002\u000b=\u0004H/[7\u000b\u0005\u001dA\u0011!\u00023mY&\u0014'BA\u0005\u000b\u0003\u0015\u0011\u0017n\u001a3m\u0015\tYA\"A\u0005b]\u0006d\u0017\u0010^5dg*\u0011QBD\u0001\u0006S:$X\r\u001c\u0006\u0002\u001f\u0005\u00191m\\7\u0014\u0007\u0001\tr\u0003\u0005\u0002\u0013+5\t1CC\u0001\u0015\u0003\u0015\u00198-\u00197b\u0013\t12C\u0001\u0004B]f\u0014VM\u001a\t\u00031ei\u0011AA\u0005\u00035\t\u0011!\u0003U1sC6,G/\u001a:Qe>\u001cWm]:pe\"AA\u0004\u0001B\u0001B\u0003%a$A\bme9{'/\u001c+ie\u0016\u001c\bn\u001c7e\u0007\u0001\u0001\"AE\u0010\n\u0005\u0001\u001a\"A\u0002#pk\ndW\rC\u0003#\u0001\u0011\u00051%\u0001\u0004=S:LGO\u0010\u000b\u0003I\u0015\u0002\"\u0001\u0007\u0001\t\u000bq\t\u0003\u0019\u0001\u0010\t\u000b\u001d\u0002A\u0011\t\u0015\u0002#\r|G\u000e\\3di\u001ecwNY1m\t\u0006$\u0018-\u0006\u0002*\u0015R)!f\u00156oiR\u00111F\f\t\u0003%1J!!L\n\u0003\tUs\u0017\u000e\u001e\u0005\u0006_\u0019\u0002\u001d\u0001M\u0001\u0003KZ\u00042!M#I\u001d\t\u0011$I\u0004\u00024\u0001:\u0011Ag\u0010\b\u0003kyr!AN\u001f\u000f\u0005]bdB\u0001\u001d<\u001b\u0005I$B\u0001\u001e\u001e\u0003\u0019a$o\\8u}%\tq\"\u0003\u0002\u000e\u001d%\u00111\u0002D\u0005\u0003\u0013)I!a\u0002\u0005\n\u0005\u00053\u0011A\u0002;f]N|'/\u0003\u0002D\t\u0006\tB+\u001a8t_JtU/\\3sS\u000el\u0015\r\u001e5\u000b\u0005\u00053\u0011B\u0001$H\u00055!VM\\:pe:+X.\u001a:jG*\u00111\t\u0012\t\u0003\u0013*c\u0001\u0001B\u0003LM\t\u0007AJA\u0001U#\ti\u0005\u000b\u0005\u0002\u0013\u001d&\u0011qj\u0005\u0002\b\u001d>$\b.\u001b8h!\t\u0011\u0012+\u0003\u0002S'\t\u0019\u0011I\\=\t\u000bQ3\u0003\u0019A+\u0002\r5|G-\u001a7t!\r1v,Y\u0007\u0002/*\u0011\u0001,W\u0001\u0004e\u0012$'B\u0001.\\\u0003\u0015\u0019\b/\u0019:l\u0015\taV,\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002=\u0006\u0019qN]4\n\u0005\u0001<&a\u0001*E\tB\u0019!m\u001a%\u000f\u0005\
r,gBA\u001ae\u0013\t)a!\u0003\u0002g\t\u0005yA)[:ue&|\u0005\u000f^5nSj,'/\u0003\u0002iS\n)1)Y2iK*\u0011a\r\u0002\u0005\u0006\u0007\u0019\u0002\ra\u001b\t\u000411D\u0015BA7\u0003\u0005I\tE\u000e\u001c*fIV\u001cW\rU1sC6,G/\u001a:\t\u000b=4\u0003\u0019\u00019\u0002\u000f5,GO]5dgB\u0011\u0011O]\u0007\u0002\t%\u00111\u000f\u0002\u0002\b\u001b\u0016$(/[2t\u0011\u0015)h\u00051\u0001w\u0003\u0015\u0019H/\u0019;f!\t9(0D\u0001y\u0015\tIh!A\u0003vi&d7/\u0003\u0002|q\n)A+\u00192mK\")Q\u0010\u0001C!}\u0006\t\u0002O]8dKN\u001c\b+\u0019:b[\u0016$XM]:\u0016\u0007}\fI\u0001\u0006\u0005\u0002\u0002\u0005-\u0011qBA\u000b)\rY\u00131\u0001\u0005\u0007_q\u0004\u001d!!\u0002\u0011\tE*\u0015q\u0001\t\u0004\u0013\u0006%A!B&}\u0005\u0004a\u0005BB\u0002}\u0001\u0004\ti\u0001\u0005\u0003\u0019Y\u0006\u001d\u0001bBA\ty\u0002\u0007\u00111C\u0001\u000b[>$W\r\\\"bG\",\u0007\u0003\u00022h\u0003\u000fAQ!\u001e?A\u0002YDa! \u0001\u0005B\u0005eQ\u0003BA\u000e\u0003K!b!!\b\u0002(\u0005eBcA\u0016\u0002 !9q&a\u0006A\u0004\u0005\u0005\u0002\u0003B\u0019F\u0003G\u00012!SA\u0013\t\u0019Y\u0015q\u0003b\u0001\u0019\"A\u0011\u0011FA\f\u0001\u0004\tY#A\u0003n_\u0012,G\u000e\u0005\u0004\u0002.\u0005M\u00121\u0005\b\u0004i\u0005=\u0012bAA\u0019\u0011\u00059\u0001/Y2lC\u001e,\u0017\u0002BA\u001b\u0003o\u0011a!T8ek2,'bAA\u0019\u0011!1Q/a\u0006A\u0002Y\u0004")
/**
 * Parameter processor that clips gradients by their L2 norm.
 *
 * <p>When the L2 norm of the gradients exceeds {@code l2NormThreshold}, the gradients are
 * divided by {@code l2Norm / l2NormThreshold}, i.e. multiplied by
 * {@code l2NormThreshold / l2Norm}, which rescales them to have norm {@code l2NormThreshold}.
 * Gradients whose norm is at or below the threshold are left unchanged.
 *
 * <p>NOTE(review): this file is CFR-decompiled Scala; the anonymous {@code Serializable}
 * classes in {@link #collectGlobalData} stand in for Scala closures and are kept verbatim.
 */
public class L2NormClippingProcessor
implements ParameterProcessor {
    /** Maximum allowed gradient L2 norm; guaranteed strictly positive by the constructor. */
    private final double l2NormThreshold;

    /**
     * Computes the L2 norm of the distributed gradients and stores it in {@code state} under
     * the key {@code "l2Norm"}.
     *
     * <p>Reads {@code "numFinishedModel"}, {@code "parallelism"} and {@code "isGradientUpdated"}
     * from {@code state}. If {@code "isGradientUpdated"} is false, each partition first calls
     * {@link AllReduceParameter#aggregateGradientPartition} and records the elapsed time in
     * {@code metrics}. Every partition then emits the sum of squares of its gradient partition;
     * these are added with {@code reduce}, and the square root is written back. Finally
     * {@code "isGradientUpdated"} is set to {@code true}, so a later call skips the
     * aggregation branch.
     *
     * @param models RDD of per-partition model caches; only its partitioning is used here
     * @param parameters2 all-reduce parameter holding each partition's gradient slice
     * @param metrics sink for the aggregation timing metric
     * @param state shared optimizer state, read and updated as described above
     * @param ev numeric implementation for {@code T}
     */
    @Override
    public <T> void collectGlobalData(RDD<DistriOptimizer.Cache<T>> models, AllReduceParameter<T> parameters2, Metrics metrics, Table state, TensorNumericMath.TensorNumeric<T> ev) {
        int numFinishedModel = BoxesRunTime.unboxToInt((Object)state.get("numFinishedModel").get());
        int parallelism = BoxesRunTime.unboxToInt((Object)state.get("parallelism").get());
        boolean isGradientUpdated = BoxesRunTime.unboxToBoolean((Object)state.get("isGradientUpdated").get());
        // Per-partition closure: aggregate the gradient slice if still needed, then emit its
        // sum of squares as a single-element iterator.
        double sumSquare = BoxesRunTime.unboxToDouble((Object)models.mapPartitions((Function1)new Serializable(this, parameters2, metrics, ev, numFinishedModel, parallelism, isGradientUpdated){
            public static final long serialVersionUID = 0L;
            private final AllReduceParameter parameters$1;
            private final Metrics metrics$1;
            private final TensorNumericMath.TensorNumeric ev$1;
            private final int numFinishedModel$1;
            private final int parallelism$1;
            private final boolean isGradientUpdated$1;

            public final Iterator<Object> apply(Iterator<DistriOptimizer.Cache<T>> modelIter) {
                Object object;
                if (this.isGradientUpdated$1) {
                    object = BoxedUnit.UNIT;
                } else {
                    long getG = System.nanoTime();
                    this.parameters$1.aggregateGradientPartition(this.numFinishedModel$1);
                    // NOTE(review): the metric key is misspelled; it is left unchanged because
                    // other code may look metrics up by this exact name.
                    object = this.metrics$1.add("aggregrateGradientParition average executor", System.nanoTime() - getG);
                }
                double sum2 = Util$.MODULE$.getSumsquareInParallel(this.parameters$1.gradientPartition(), this.parallelism$1, this.ev$1);
                return scala.package$.MODULE$.Iterator().single((Object)BoxesRunTime.boxToDouble((double)sum2));
            }
            {
                this.parameters$1 = parameters$1;
                this.metrics$1 = metrics$1;
                this.ev$1 = ev$1;
                this.numFinishedModel$1 = numFinishedModel$1;
                this.parallelism$1 = parallelism$1;
                this.isGradientUpdated$1 = isGradientUpdated$1;
            }
        }, models.mapPartitions$default$2(), ClassTag$.MODULE$.Double()).reduce((Function2)new Serializable(this){
            public static final long serialVersionUID = 0L;

            public final double apply(double x$1, double x$2) {
                return this.apply$mcDDD$sp(x$1, x$2);
            }

            public double apply$mcDDD$sp(double x$1, double x$2) {
                return x$1 + x$2;
            }
        }));
        state.update("isGradientUpdated", BoxesRunTime.boxToBoolean((boolean)true));
        state.update("l2Norm", BoxesRunTime.boxToDouble((double)package$.MODULE$.sqrt(sumSquare)));
    }

    /**
     * Clips this partition's gradient slice using the norm previously stored in {@code state}.
     *
     * <p>Reads {@code "l2Norm"} from {@code state} (written by {@link #collectGlobalData}); if it
     * exceeds {@code l2NormThreshold}, divides {@code parameters2.gradientPartition()} in place
     * by {@code l2Norm / l2NormThreshold}.
     *
     * @param parameters2 all-reduce parameter whose gradient partition is clipped in place
     * @param modelCache unused by this processor
     * @param state optimizer state; must already contain {@code "l2Norm"}
     * @param ev numeric implementation for {@code T}
     */
    @Override
    public <T> void processParameters(AllReduceParameter<T> parameters2, DistriOptimizer.Cache<T> modelCache, Table state, TensorNumericMath.TensorNumeric<T> ev) {
        double l2Norm2 = BoxesRunTime.unboxToDouble((Object)state.get("l2Norm").get());
        if (l2Norm2 > this.l2NormThreshold) {
            parameters2.gradientPartition().div(this.clippingDivisor(l2Norm2, ev));
        }
    }

    /**
     * Clips the gradients of a single (local) model.
     *
     * <p>Reads {@code "parallelism"} from {@code state}, computes the L2 norm of the tensor
     * returned by {@code model.getParameters()._2()} (presumably the flattened gradients —
     * confirm against {@code AbstractModule#getParameters}), and if it exceeds
     * {@code l2NormThreshold} divides that tensor in place by {@code l2Norm / l2NormThreshold}.
     *
     * @param model module whose gradients are clipped in place
     * @param state optimizer state; must contain {@code "parallelism"}
     * @param ev numeric implementation for {@code T}
     */
    @Override
    public <T> void processParameters(AbstractModule<Activity, Activity, T> model, Table state, TensorNumericMath.TensorNumeric<T> ev) {
        int parallelism = BoxesRunTime.unboxToInt((Object)state.get("parallelism").get());
        Tensor gradients = (Tensor)model.getParameters()._2();
        double l2Norm2 = package$.MODULE$.sqrt(Util$.MODULE$.getSumsquareInParallel(gradients, parallelism, ev));
        if (l2Norm2 > this.l2NormThreshold) {
            gradients.div(this.clippingDivisor(l2Norm2, ev));
        }
    }

    /**
     * Converts the clipping divisor {@code l2Norm / l2NormThreshold} into the numeric type
     * {@code T}.
     *
     * <p>Shared by both {@code processParameters} overloads so they always scale the same way.
     *
     * @param l2Norm current gradient L2 norm; callers only pass values above the threshold
     * @param ev numeric implementation for {@code T}
     * @return divisor to apply to the gradients
     */
    private <T> T clippingDivisor(double l2Norm, TensorNumericMath.TensorNumeric<T> ev) {
        return ev.fromType(BoxesRunTime.boxToDouble(l2Norm / this.l2NormThreshold), ConvertableFrom$ConvertableFromDouble$.MODULE$);
    }

    /**
     * Creates a processor that clips gradients whose L2 norm exceeds {@code l2NormThreshold}.
     *
     * @param l2NormThreshold maximum allowed gradient L2 norm; must be strictly positive
     * @throws IllegalArgumentException if {@code l2NormThreshold} is not {@code > 0} (incl. NaN)
     */
    public L2NormClippingProcessor(double l2NormThreshold) {
        // A negative threshold makes every norm "exceed" it and yields a non-positive divisor,
        // flipping (or NaN-ing) the gradients; zero makes the divisor Infinity and zeroes them;
        // NaN silently disables clipping. Reject all of these up front.
        if (!(l2NormThreshold > 0.0)) {
            throw new IllegalArgumentException("l2NormThreshold should be larger than zero, but got " + l2NormThreshold);
        }
        this.l2NormThreshold = l2NormThreshold;
        ParameterProcessor$class.$init$(this);
    }
}

