/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.utils.tf;

import com.intel.analytics.bigdl.dllib.nn.SpatialBatchNormalization;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericInt$;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.tf.BigDLToTensorflow;
import com.intel.analytics.bigdl.dllib.utils.tf.Tensorflow$;
import java.nio.ByteOrder;
import org.tensorflow.framework.NodeDef;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/**
 * Converts an evaluation-mode {@link SpatialBatchNormalization} layer into a TensorFlow subgraph
 * made of constant, reshape, rsqrt and element-wise arithmetic nodes.
 *
 * <p>This is the Java view of a Scala {@code object}; the single instance is {@link #MODULE$}.
 *
 * <p>NOTE(review): calls spelled {@code Tensorflow$.MODULE$.const(...)} are kept exactly as the
 * decompiler emitted them. {@code const} is a reserved word in Java, so those call sites only
 * compile from Scala (or via reflection). TODO confirm how this file is meant to be built.
 */
public final class BatchNorm2DToTF$
implements BigDLToTensorflow {
    /** The singleton converter instance. */
    public static final BatchNorm2DToTF$ MODULE$;

    static {
        // A static final field may only be assigned here, not in the constructor.
        MODULE$ = new BatchNorm2DToTF$();
    }

    /**
     * Emits the TensorFlow nodes that reproduce {@code module}'s inference-time batch norm.
     *
     * <p>Without a weight tensor the graph computes
     * {@code input * rsqrt(var) - mean * rsqrt(var)}. With one it computes
     * {@code input * (scale * rsqrt(var)) + (offset - mean * (scale * rsqrt(var)))}.
     * Every per-channel constant is first reshaped to the broadcast shape built by
     * {@link #broadcastShape}.
     *
     * <p>NOTE(review): rsqrt is applied to the running variance directly; no epsilon term is
     * added here. If the layer's own inference path uses {@code var + eps}, the exported graph
     * will differ slightly. TODO confirm.
     *
     * @param module the layer to convert; it is cast to {@link SpatialBatchNormalization}
     * @param inputs2 the upstream nodes; exactly one is required
     * @param byteOrder forwarded to every constant node created here
     * @return all created nodes, with the output node first
     */
    @Override
    public Seq<NodeDef> toTFDef(AbstractModule<?, ?, ?> module, Seq<NodeDef> inputs2, ByteOrder byteOrder) {
        // invalidInputError is called whenever its condition is false (presumably it throws).
        Log4Error$.MODULE$.invalidInputError(inputs2.length() == 1, "BatchNorm only accept one input", Log4Error$.MODULE$.invalidInputError$default$3());
        SpatialBatchNormalization layer = (SpatialBatchNormalization) module;
        Log4Error$.MODULE$.invalidInputError(!layer.isTraining(), "Only support evaluate mode batch norm", Log4Error$.MODULE$.invalidInputError$default$3());

        Tensor<Object> shape = broadcastShape(layer);
        String name = layer.getName();
        NodeDef input = (NodeDef) inputs2.apply(0);
        if (layer.weight() == null) {
            return normalizeOnly(layer, name, input, shape, byteOrder);
        }
        return normalizeAndScale(layer, name, input, shape, byteOrder);
    }

    /**
     * Builds an int tensor of length {@code layer.nDim()}, filled with 1 except at (1-based)
     * index 2, which holds {@code runningVar().size(1)}.
     *
     * <p>Presumably index 2 is the channel axis of an NCHW input. TODO confirm for layers
     * using another data layout.
     */
    @SuppressWarnings("rawtypes")
    private static Tensor<Object> broadcastShape(SpatialBatchNormalization layer) {
        int rank = layer.nDim();
        Tensor<Object> shape = Tensor$.MODULE$.apply(rank, ClassTag$.MODULE$.Int(), TensorNumericMath$TensorNumeric$NumericInt$.MODULE$);
        // Tensor indices are 1-based.
        for (int i = 1; i <= rank; i++) {
            shape.setValue(i, BoxesRunTime.boxToInteger(1));
        }
        shape.update(2, (Object) BoxesRunTime.boxToInteger(layer.runningVar().size(1)));
        return shape;
    }

    /**
     * Emits the graph for a layer without a weight tensor:
     * {@code output = input * rsqrt(var) - mean * rsqrt(var)}.
     */
    @SuppressWarnings("rawtypes")
    private static Seq<NodeDef> normalizeOnly(SpatialBatchNormalization layer, String name, NodeDef input, Tensor<Object> shape, ByteOrder byteOrder) {
        NodeDef shapeVar = Tensorflow$.MODULE$.const(shape, name + "/reshape_1/shape", byteOrder);
        NodeDef shapeMean = Tensorflow$.MODULE$.const(shape, name + "/reshape_2/shape", byteOrder);
        NodeDef varNode = Tensorflow$.MODULE$.const(layer.runningVar(), name + "/var", byteOrder);
        NodeDef meanNode = Tensorflow$.MODULE$.const(layer.runningMean(), name + "/mean", byteOrder);
        NodeDef reshapeVar = Tensorflow$.MODULE$.reshape(varNode, shapeVar, name + "/reshape_1");
        NodeDef reshapeMean = Tensorflow$.MODULE$.reshape(meanNode, shapeMean, name + "/reshape_2");
        NodeDef sqrtVar = Tensorflow$.MODULE$.rsqrt(reshapeVar, name + "/sqrtvar");
        NodeDef mul1 = Tensorflow$.MODULE$.multiply(input, sqrtVar, name + "/mul1");
        NodeDef mul2 = Tensorflow$.MODULE$.multiply(reshapeMean, sqrtVar, name + "/mul2");
        NodeDef output = Tensorflow$.MODULE$.subtract(mul1, mul2, name + "/output");
        // Ordering kept identical to the original converter output.
        return seqOf(output, mul2, mul1, reshapeMean, shapeMean, meanNode, sqrtVar, reshapeVar, shapeVar, varNode);
    }

    /**
     * Emits the graph for a layer with weight (scale) and bias (offset) tensors:
     * {@code mul0 = scale * rsqrt(var)}, {@code output = input * mul0 + (offset - mean * mul0)}.
     */
    @SuppressWarnings("rawtypes")
    private static Seq<NodeDef> normalizeAndScale(SpatialBatchNormalization layer, String name, NodeDef input, Tensor<Object> shape, ByteOrder byteOrder) {
        NodeDef shapeVar = Tensorflow$.MODULE$.const(shape, name + "/reshape_1/shape", byteOrder);
        NodeDef shapeMean = Tensorflow$.MODULE$.const(shape, name + "/reshape_2/shape", byteOrder);
        NodeDef shapeScale = Tensorflow$.MODULE$.const(shape, name + "/reshape_3/shape", byteOrder);
        NodeDef shapeOffset = Tensorflow$.MODULE$.const(shape, name + "/reshape_4/shape", byteOrder);
        NodeDef varNode = Tensorflow$.MODULE$.const(layer.runningVar(), name + "/var", byteOrder);
        NodeDef meanNode = Tensorflow$.MODULE$.const(layer.runningMean(), name + "/mean", byteOrder);
        NodeDef scale = Tensorflow$.MODULE$.const(layer.weight(), name + "/scale", byteOrder);
        NodeDef offset = Tensorflow$.MODULE$.const(layer.bias(), name + "/offset", byteOrder);
        NodeDef reshapeVar = Tensorflow$.MODULE$.reshape(varNode, shapeVar, name + "/reshape_1");
        NodeDef reshapeMean = Tensorflow$.MODULE$.reshape(meanNode, shapeMean, name + "/reshape_2");
        NodeDef reshapeScale = Tensorflow$.MODULE$.reshape(scale, shapeScale, name + "/reshape_3");
        NodeDef reshapeOffset = Tensorflow$.MODULE$.reshape(offset, shapeOffset, name + "/reshape_4");
        NodeDef sqrtVar = Tensorflow$.MODULE$.rsqrt(reshapeVar, name + "/sqrtvar");
        NodeDef mul0 = Tensorflow$.MODULE$.multiply(reshapeScale, sqrtVar, name + "/mul0");
        NodeDef mul1 = Tensorflow$.MODULE$.multiply(input, mul0, name + "/mul1");
        NodeDef mul2 = Tensorflow$.MODULE$.multiply(reshapeMean, mul0, name + "/mul2");
        NodeDef sub = Tensorflow$.MODULE$.subtract(reshapeOffset, mul2, name + "/sub");
        NodeDef output = Tensorflow$.MODULE$.add(mul1, sub, name + "/output");
        // Ordering kept identical to the original converter output.
        return seqOf(output, sub, mul2, mul1, mul0, reshapeOffset, reshapeMean, reshapeScale,
            shapeOffset, shapeMean, shapeScale, offset, scale, meanNode, sqrtVar, reshapeVar, shapeVar, varNode);
    }

    /** Wraps the given nodes, in order, into a Scala {@code Seq}. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static Seq<NodeDef> seqOf(NodeDef... nodes) {
        Seq wrapped = (Seq) Predef$.MODULE$.wrapRefArray((Object[]) nodes);
        return (Seq) Seq$.MODULE$.apply(wrapped);
    }

    private BatchNorm2DToTF$() {
    }
}

