/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.nn;

import com.intel.analytics.bigdl.dllib.nn.SpatialBatchNormalization;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.DataFormat;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.DataFormat$NCHW$;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.ParameterSynchronizer$;
import java.util.Map;
import scala.Function1;
import scala.Serializable;
import scala.collection.IterableLike;
import scala.collection.JavaConverters$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.Null$;

public final class SpatialBatchNormalization$
implements Serializable {
    /**
     * Singleton instance of this Scala companion object.
     * NOTE(review): presumably assigned inside the constructor, which is not visible in this
     * part of the file — confirm before touching the initialization order.
     */
    public static final SpatialBatchNormalization$ MODULE$;

    // Class-load hook: instantiates the object once; the created reference is not stored here.
    static {
        new SpatialBatchNormalization$();
    }

    /**
     * Factory method: creates a {@link SpatialBatchNormalization} by handing every argument,
     * unchanged and in the same order, to its constructor.
     *
     * <p>The meaning of each argument is defined by that constructor; this method adds no
     * validation or defaults of its own.
     *
     * @return the newly constructed layer
     */
    public <T> SpatialBatchNormalization<T> apply(int nOutput, double eps, double momentum, boolean affine, Tensor<T> initWeight, Tensor<T> initBias, Tensor<T> initGradWeight, Tensor<T> initGradBias, DataFormat dataFormat, ClassTag<T> evidence$2, TensorNumericMath.TensorNumeric<T> ev) {
        SpatialBatchNormalization<T> layer = new SpatialBatchNormalization<T>(nOutput, eps, momentum, affine, initWeight, initBias, initGradWeight, initGradBias, dataFormat, evidence$2, ev);
        return layer;
    }

    /**
     * Default-argument accessor; by the Scala {@code name$default$N} convention this supplies
     * the 2nd parameter of {@code apply} ({@code eps}).
     *
     * @return {@code 1.0E-5}
     */
    public <T> double apply$default$2() {
        final double defaultEps = 1.0E-5;
        return defaultEps;
    }

    /**
     * Default-argument accessor; by the Scala {@code name$default$N} convention this supplies
     * the 3rd parameter of {@code apply} ({@code momentum}).
     *
     * @return {@code 0.1}
     */
    public <T> double apply$default$3() {
        final double defaultMomentum = 0.1;
        return defaultMomentum;
    }

    /**
     * Default-argument accessor; by the Scala {@code name$default$N} convention this supplies
     * the 4th parameter of {@code apply} ({@code affine}).
     *
     * @return {@code true}
     */
    public <T> boolean apply$default$4() {
        final boolean defaultAffine = true;
        return defaultAffine;
    }

    /**
     * Default-argument accessor; by the Scala {@code name$default$N} convention this supplies
     * the 5th parameter of {@code apply} ({@code initWeight}).
     *
     * @return always {@code null}
     */
    public <T> Null$ apply$default$5() {
        final Null$ noInitWeight = null;
        return noInitWeight;
    }

    /**
     * Default-argument accessor; by the Scala {@code name$default$N} convention this supplies
     * the 6th parameter of {@code apply} ({@code initBias}).
     *
     * @return always {@code null}
     */
    public <T> Null$ apply$default$6() {
        final Null$ noInitBias = null;
        return noInitBias;
    }

    /**
     * Default-argument accessor; by the Scala {@code name$default$N} convention this supplies
     * the 7th parameter of {@code apply} ({@code initGradWeight}).
     *
     * @return always {@code null}
     */
    public <T> Null$ apply$default$7() {
        final Null$ noInitGradWeight = null;
        return noInitGradWeight;
    }

    /**
     * Default-argument accessor; by the Scala {@code name$default$N} convention this supplies
     * the 8th parameter of {@code apply} ({@code initGradBias}).
     *
     * @return always {@code null}
     */
    public <T> Null$ apply$default$8() {
        final Null$ noInitGradBias = null;
        return noInitGradBias;
    }

    /**
     * Default-argument accessor; by the Scala {@code name$default$N} convention this supplies
     * the 9th parameter of {@code apply} ({@code dataFormat}).
     *
     * @return the {@code NCHW} data-format singleton
     */
    public <T> DataFormat apply$default$9() {
        final DataFormat defaultFormat = DataFormat$NCHW$.MODULE$;
        return defaultFormat;
    }

    /**
     * Inference-time batch normalization on float storage, with the channel taken from the 4th
     * dimension of {@code input} and stored as the fastest-varying index.
     *
     * <p>For every element: {@code out = (in - mean[c]) / sqrt(variance[c] + eps)}, optionally
     * followed by {@code * scale[c] + offset[c]}. The result is written straight into the
     * backing array of {@code output}; this method does not resize {@code output}.
     *
     * @param input    contiguous input tensor (rejected via {@code invalidInputError} otherwise)
     * @param output   destination tensor
     * @param mean2    per-channel mean
     * @param variance per-channel variance
     * @param scale    per-channel multiplier, or {@code null} to skip the affine step; if its
     *                 backing array holds exactly one value, that value is used for all channels
     * @param offset   per-channel additive term; only read when {@code scale} is non-null
     * @param eps      added to the variance before the square root
     */
    public void updateOutputNHWCInferFloat(Tensor<Object> input, Tensor<Object> output, Tensor<Object> mean2, Tensor<Object> variance, Tensor<Object> scale, Tensor<Object> offset, float eps) {
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "BatchNorm NHWC require a contiguous input", Log4Error$.MODULE$.invalidInputError$default$3());
        final float[] src = (float[])input.storage().array();
        final int srcBase = input.storageOffset() - 1;
        final float[] dst = (float[])output.storage().array();
        final int dstBase = output.storageOffset() - 1;
        final int channels = input.size(4);
        final int total = input.nElement();
        final float[] mu = (float[])mean2.storage().array();
        final int muBase = mean2.storageOffset() - 1;
        final float[] var = (float[])variance.storage().array();
        final int varBase = variance.storageOffset() - 1;

        if (scale == null) {
            // Plain normalization, no affine transform.
            int pixel = 0;
            while (pixel < total) {
                int ch = 0;
                while (ch < channels) {
                    float rstd = 1.0f / (float)Math.sqrt(var[varBase + ch] + eps);
                    float centered = src[pixel + srcBase + ch] - mu[ch + muBase];
                    dst[pixel + dstBase + ch] = centered * rstd;
                    ++ch;
                }
                pixel += channels;
            }
            return;
        }

        final float[] gamma = (float[])scale.storage().array();
        final int gammaBase = scale.storageOffset() - 1;
        final float[] beta = (float[])offset.storage().array();
        final int betaBase = offset.storageOffset() - 1;
        // A one-element scale array means a single multiplier shared by every channel.
        final boolean sharedScale = gamma.length == 1;
        final float sharedScaleValue = sharedScale ? gamma[0] : 0.0f;

        int pixel = 0;
        while (pixel < total) {
            int ch = 0;
            while (ch < channels) {
                float rstd = 1.0f / (float)Math.sqrt(var[varBase + ch] + eps);
                float centered = src[pixel + srcBase + ch] - mu[ch + muBase];
                if (sharedScale) {
                    dst[pixel + dstBase + ch] = centered * rstd * sharedScaleValue + beta[betaBase + ch];
                } else {
                    dst[pixel + dstBase + ch] = centered * rstd * gamma[gammaBase + ch] + beta[betaBase + ch];
                }
                ++ch;
            }
            pixel += channels;
        }
    }

    /**
     * Inference-time batch normalization on double storage, with the channel taken from the 4th
     * dimension of {@code input} and stored as the fastest-varying index.
     *
     * <p>For every element: {@code out = (in - mean[c]) / sqrt(variance[c] + eps)}, optionally
     * followed by {@code * scale[c] + offset[c]}. The result is written straight into the
     * backing array of {@code output}; this method does not resize {@code output}.
     *
     * @param input    contiguous input tensor (rejected via {@code invalidInputError} otherwise)
     * @param output   destination tensor
     * @param mean2    per-channel mean
     * @param variance per-channel variance
     * @param scale    per-channel multiplier, or {@code null} to skip the affine step
     * @param offset   per-channel additive term; only read when {@code scale} is non-null
     * @param eps      added to the variance before the square root
     */
    public void updateOutputNHWCInferDouble(Tensor<Object> input, Tensor<Object> output, Tensor<Object> mean2, Tensor<Object> variance, Tensor<Object> scale, Tensor<Object> offset, double eps) {
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "BatchNorm NHWC require a contiguous input", Log4Error$.MODULE$.invalidInputError$default$3());
        final double[] src = (double[])input.storage().array();
        final int srcBase = input.storageOffset() - 1;
        final double[] dst = (double[])output.storage().array();
        final int dstBase = output.storageOffset() - 1;
        final int channels = input.size(4);
        final int total = input.nElement();
        final double[] mu = (double[])mean2.storage().array();
        final int muBase = mean2.storageOffset() - 1;
        final double[] var = (double[])variance.storage().array();
        final int varBase = variance.storageOffset() - 1;

        if (scale == null) {
            // Plain normalization, no affine transform.
            int pixel = 0;
            while (pixel < total) {
                int ch = 0;
                while (ch < channels) {
                    double rstd = 1.0 / Math.sqrt(var[varBase + ch] + eps);
                    double centered = src[pixel + srcBase + ch] - mu[muBase + ch];
                    dst[pixel + dstBase + ch] = centered * rstd;
                    ++ch;
                }
                pixel += channels;
            }
            return;
        }

        final double[] gamma = (double[])scale.storage().array();
        final int gammaBase = scale.storageOffset() - 1;
        final double[] beta = (double[])offset.storage().array();
        final int betaBase = offset.storageOffset() - 1;
        int pixel = 0;
        while (pixel < total) {
            int ch = 0;
            while (ch < channels) {
                double rstd = 1.0 / Math.sqrt(var[varBase + ch] + eps);
                double centered = src[pixel + srcBase + ch] - mu[muBase + ch];
                dst[pixel + dstBase + ch] = centered * rstd * gamma[gammaBase + ch] + beta[betaBase + ch];
                ++ch;
            }
            pixel += channels;
        }
    }

    /**
     * Training-time batch normalization on float storage, with the channel taken from the 4th
     * dimension of {@code input} and stored as the fastest-varying index.
     *
     * <p>Steps, all per channel:
     * <ol>
     *   <li>sum the input into {@code saveMean}, divide by the element count per channel, and
     *       blend the result into {@code runningMean} with weight {@code momentum};</li>
     *   <li>sum squared deviations into {@code saveStd}, then overwrite it with
     *       {@code 1 / sqrt(biasedVar + eps)}; blend the unbiased variance into
     *       {@code runningVar};</li>
     *   <li>write the normalized (and, if {@code scale} is non-null, affine-transformed)
     *       values into the backing array of {@code output}.</li>
     * </ol>
     *
     * <p>NOTE(review): the sums in steps 1 and 2 are added on top of whatever
     * {@code saveMean}/{@code saveStd} already contain; presumably the caller zeroes them
     * beforehand — confirm.
     *
     * @param input       contiguous input tensor (rejected via {@code invalidInputError} otherwise)
     * @param output      destination tensor; not resized here
     * @param saveMean    receives the batch mean; resized together with {@code saveStd},
     *                    {@code runningMean} and {@code runningVar} when its length differs
     *                    from the channel count
     * @param saveStd     receives the inverse standard deviation of the batch
     * @param runningMean updated in place
     * @param runningVar  updated in place (left untouched for a channel whose squared-deviation
     *                    sum and {@code eps} are both zero)
     * @param scale       per-channel multiplier, or {@code null} to skip the affine step
     * @param offset      per-channel additive term; only read when {@code scale} is non-null
     * @param eps         added to the biased variance before the square root
     * @param momentum    weight of the current batch in the running statistics
     * @param batchVar    if non-null, receives the unbiased batch variance (1-based index)
     * @param saveVar     if non-null, receives the biased batch variance (1-based index)
     */
    public void updateOutputNHWCTrainFloat(Tensor<Object> input, Tensor<Object> output, Tensor<Object> saveMean, Tensor<Object> saveStd, Tensor<Object> runningMean, Tensor<Object> runningVar, Tensor<Object> scale, Tensor<Object> offset, float eps, float momentum, Tensor<Object> batchVar, Tensor<Object> saveVar) {
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "BatchNorm NHWC require a contiguous input", Log4Error$.MODULE$.invalidInputError$default$3());
        final float[] src = (float[])input.storage().array();
        final int srcBase = input.storageOffset() - 1;
        final float[] dst = (float[])output.storage().array();
        final int dstBase = output.storageOffset() - 1;
        final int channels = input.size(4);
        if (saveMean.size(1) != channels) {
            saveMean.resize(channels);
            saveStd.resize(channels);
            runningMean.resize(channels);
            runningVar.resize(channels);
        }

        // Step 1: per-channel sum, then mean and running mean.
        final float[] mu = (float[])saveMean.storage().array();
        final int muBase = saveMean.storageOffset() - 1;
        final int total = input.nElement();
        final int perChannel = total / channels;
        int pixel = 0;
        while (pixel < total) {
            int ch = 0;
            while (ch < channels) {
                mu[muBase + ch] += src[srcBase + pixel + ch];
                ++ch;
            }
            pixel += channels;
        }
        final float[] runMean = (float[])runningMean.storage().array();
        final int runMeanBase = runningMean.storageOffset() - 1;
        int ch = 0;
        while (ch < channels) {
            mu[muBase + ch] = mu[muBase + ch] / (float)perChannel;
            runMean[runMeanBase + ch] = mu[muBase + ch] * momentum + (1.0f - momentum) * runMean[ch + runMeanBase];
            ++ch;
        }

        // Step 2: per-channel sum of squared deviations, then inverse std and running variance.
        final float[] sd = (float[])saveStd.storage().array();
        final int sdBase = saveStd.storageOffset() - 1;
        pixel = 0;
        while (pixel < total) {
            int k = 0;
            while (k < channels) {
                float delta = src[srcBase + pixel + k] - mu[muBase + k];
                sd[sdBase + k] += delta * delta;
                ++k;
            }
            pixel += channels;
        }
        final float[] runVar = (float[])runningVar.storage().array();
        final int runVarBase = runningVar.storageOffset() - 1;
        ch = 0;
        while (ch < channels) {
            final int slot = ch + sdBase;
            if (sd[slot] == 0.0f && eps == 0.0f) {
                // Zero variance and zero eps: store 0 instead of dividing by zero.
                sd[slot] = 0.0f;
                if (saveVar != null) {
                    saveVar.setValue(ch + 1, BoxesRunTime.boxToFloat(0.0f));
                }
                if (batchVar != null) {
                    batchVar.setValue(ch + 1, BoxesRunTime.boxToFloat(0.0f));
                }
            } else {
                final float sumSq = sd[slot];
                final float unbiasedVar = sumSq / (float)(perChannel - 1);
                if (saveVar != null) {
                    saveVar.setValue(ch + 1, BoxesRunTime.boxToFloat(sumSq / (float)perChannel));
                }
                if (batchVar != null) {
                    batchVar.setValue(ch + 1, BoxesRunTime.boxToFloat(unbiasedVar));
                }
                sd[slot] = 1.0f / (float)Math.sqrt(sumSq / (float)perChannel + eps);
                runVar[ch + runVarBase] = momentum * unbiasedVar + (1.0f - momentum) * runVar[ch + runVarBase];
            }
            ++ch;
        }

        // Step 3: normalize, with optional affine transform.
        if (scale == null) {
            pixel = 0;
            while (pixel < total) {
                int k = 0;
                while (k < channels) {
                    dst[pixel + dstBase + k] = (src[pixel + srcBase + k] - mu[muBase + k]) * sd[k + sdBase];
                    ++k;
                }
                pixel += channels;
            }
            return;
        }
        final float[] gamma = (float[])scale.storage().array();
        final int gammaBase = scale.storageOffset() - 1;
        final float[] beta = (float[])offset.storage().array();
        final int betaBase = offset.storageOffset() - 1;
        pixel = 0;
        while (pixel < total) {
            int k = 0;
            while (k < channels) {
                dst[pixel + dstBase + k] = (src[pixel + srcBase + k] - mu[muBase + k]) * sd[k + sdBase] * gamma[gammaBase + k] + beta[betaBase + k];
                ++k;
            }
            pixel += channels;
        }
    }

    /**
     * Default-argument accessor; by the Scala {@code name$default$N} convention this supplies
     * the 11th parameter of {@code updateOutputNHWCTrainFloat} ({@code batchVar}).
     *
     * @return always {@code null}
     */
    public Tensor<Object> updateOutputNHWCTrainFloat$default$11() {
        final Tensor<Object> noBatchVar = null;
        return noBatchVar;
    }

    /**
     * Default-argument accessor; by the Scala {@code name$default$N} convention this supplies
     * the 12th parameter of {@code updateOutputNHWCTrainFloat} ({@code saveVar}).
     *
     * @return always {@code null}
     */
    public Tensor<Object> updateOutputNHWCTrainFloat$default$12() {
        final Tensor<Object> noSaveVar = null;
        return noSaveVar;
    }

    /**
     * Training-time batch normalization on double storage, with the channel taken from the 4th
     * dimension of {@code input} and stored as the fastest-varying index.
     *
     * <p>Steps, all per channel:
     * <ol>
     *   <li>sum the input into {@code saveMean}, divide by the element count per channel, and
     *       blend the result into {@code runningMean} with weight {@code momentum};</li>
     *   <li>sum squared deviations into {@code saveStd}, then overwrite it with
     *       {@code 1 / sqrt(biasedVar + eps)}; blend the unbiased variance into
     *       {@code runningVar};</li>
     *   <li>write the normalized (and, if {@code scale} is non-null, affine-transformed)
     *       values into the backing array of {@code output}.</li>
     * </ol>
     *
     * <p>The inverse standard deviation is computed entirely in double precision. An earlier
     * revision narrowed the square root to {@code float} first, which discarded roughly half
     * of the significant digits of every normalized output.
     *
     * <p>NOTE(review): the sums in steps 1 and 2 are added on top of whatever
     * {@code saveMean}/{@code saveStd} already contain; presumably the caller zeroes them
     * beforehand — confirm.
     *
     * @param input       contiguous input tensor (rejected via {@code invalidInputError} otherwise)
     * @param output      destination tensor; not resized here
     * @param saveMean    receives the batch mean; resized together with {@code saveStd},
     *                    {@code runningMean} and {@code runningVar} when its length differs
     *                    from the channel count
     * @param saveStd     receives the inverse standard deviation of the batch
     * @param runningMean updated in place
     * @param runningVar  updated in place (left untouched for a channel whose squared-deviation
     *                    sum and {@code eps} are both zero)
     * @param scale       per-channel multiplier, or {@code null} to skip the affine step
     * @param offset      per-channel additive term; only read when {@code scale} is non-null
     * @param eps         added to the biased variance before the square root
     * @param momentum    weight of the current batch in the running statistics
     * @param batchVar    if non-null, receives the unbiased batch variance (1-based index)
     * @param saveVar     if non-null, receives the biased batch variance (1-based index)
     */
    public void updateOutputNHWCTrainDouble(Tensor<Object> input, Tensor<Object> output, Tensor<Object> saveMean, Tensor<Object> saveStd, Tensor<Object> runningMean, Tensor<Object> runningVar, Tensor<Object> scale, Tensor<Object> offset, double eps, double momentum, Tensor<Object> batchVar, Tensor<Object> saveVar) {
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "BatchNorm NHWC require a contiguous input", Log4Error$.MODULE$.invalidInputError$default$3());
        double[] inputData = (double[])input.storage().array();
        int inputOffset = input.storageOffset() - 1;
        double[] outputData = (double[])output.storage().array();
        int outputOffset = output.storageOffset() - 1;
        int nChannels = input.size(4);
        if (saveMean.size(1) != nChannels) {
            saveMean.resize(nChannels);
            saveStd.resize(nChannels);
            runningMean.resize(nChannels);
            runningVar.resize(nChannels);
        }

        // Step 1: per-channel sum, then mean and running mean.
        double[] meanData = (double[])saveMean.storage().array();
        int meanOffset = saveMean.storageOffset() - 1;
        int n = input.nElement();
        int frameSize = n / nChannels;
        for (int i = 0; i < n; i += nChannels) {
            for (int c = 0; c < nChannels; ++c) {
                meanData[c + meanOffset] += inputData[inputOffset + i + c];
            }
        }
        double[] runningMeanData = (double[])runningMean.storage().array();
        int runningMeanOffset = runningMean.storageOffset() - 1;
        for (int c = 0; c < nChannels; ++c) {
            meanData[c + meanOffset] = meanData[c + meanOffset] / (double)frameSize;
            runningMeanData[c + runningMeanOffset] = meanData[c + meanOffset] * momentum + (1.0 - momentum) * runningMeanData[c + runningMeanOffset];
        }

        // Step 2: per-channel sum of squared deviations, then inverse std and running variance.
        double[] stdData = (double[])saveStd.storage().array();
        int stdOffset = saveStd.storageOffset() - 1;
        for (int i = 0; i < n; i += nChannels) {
            for (int c = 0; c < nChannels; ++c) {
                double diff = inputData[inputOffset + i + c] - meanData[c + meanOffset];
                stdData[c + stdOffset] += diff * diff;
            }
        }
        double[] runningVarData = (double[])runningVar.storage().array();
        int runningVarOffset = runningVar.storageOffset() - 1;
        for (int c = 0; c < nChannels; ++c) {
            if (stdData[c + stdOffset] == 0.0 && eps == 0.0) {
                // Zero variance and zero eps: store 0 instead of dividing by zero.
                stdData[c + stdOffset] = 0.0;
                if (saveVar != null) {
                    saveVar.setValue(c + 1, BoxesRunTime.boxToDouble(0.0));
                }
                if (batchVar != null) {
                    batchVar.setValue(c + 1, BoxesRunTime.boxToDouble(0.0));
                }
                continue;
            }
            double s2 = stdData[c + stdOffset];
            double unbiasedVar = s2 / (double)(frameSize - 1);
            if (saveVar != null) {
                saveVar.setValue(c + 1, BoxesRunTime.boxToDouble(s2 / (double)frameSize));
            }
            if (batchVar != null) {
                batchVar.setValue(c + 1, BoxesRunTime.boxToDouble(unbiasedVar));
            }
            // Keep full double precision: no float narrowing of the square root.
            stdData[c + stdOffset] = 1.0 / Math.sqrt(s2 / (double)frameSize + eps);
            runningVarData[c + runningVarOffset] = momentum * unbiasedVar + (1.0 - momentum) * runningVarData[c + runningVarOffset];
        }

        // Step 3: normalize, with optional affine transform.
        if (scale == null) {
            for (int i = 0; i < n; i += nChannels) {
                for (int c = 0; c < nChannels; ++c) {
                    outputData[i + outputOffset + c] = (inputData[i + inputOffset + c] - meanData[c + meanOffset]) * stdData[c + stdOffset];
                }
            }
        } else {
            double[] scaleData = (double[])scale.storage().array();
            int scaleOffset = scale.storageOffset() - 1;
            double[] offsetData = (double[])offset.storage().array();
            int offsetOffset = offset.storageOffset() - 1;
            for (int i = 0; i < n; i += nChannels) {
                for (int c = 0; c < nChannels; ++c) {
                    outputData[i + outputOffset + c] = (inputData[i + inputOffset + c] - meanData[c + meanOffset]) * stdData[c + stdOffset] * scaleData[c + scaleOffset] + offsetData[c + offsetOffset];
                }
            }
        }
    }

    /**
     * Default-argument accessor; by the Scala {@code name$default$N} convention this supplies
     * the 11th parameter of {@code updateOutputNHWCTrainDouble} ({@code batchVar}).
     *
     * @return always {@code null}
     */
    public Tensor<Object> updateOutputNHWCTrainDouble$default$11() {
        final Tensor<Object> noBatchVar = null;
        return noBatchVar;
    }

    /**
     * Default-argument accessor; by the Scala {@code name$default$N} convention this supplies
     * the 12th parameter of {@code updateOutputNHWCTrainDouble} ({@code saveVar}).
     *
     * @return always {@code null}
     */
    public Tensor<Object> updateOutputNHWCTrainDouble$default$12() {
        final Tensor<Object> noSaveVar = null;
        return noSaveVar;
    }

    /**
     * Inference-time batch normalization on float storage for a 4D input whose 1st dimension is
     * the batch, 2nd the channel, and 3rd/4th the frame.
     *
     * <p>For every element: {@code out = (in - mean[c]) / sqrt(variance[c] + eps)}, optionally
     * followed by {@code * scale[c] + offset[c]}. The result is written straight into the
     * backing array of {@code output}; this method does not resize {@code output}.
     *
     * <p>All per-channel quantities (inverse std, mean, scale, offset) are read once per
     * (batch, channel) pair rather than once per element; the arithmetic applied to each
     * element is unchanged.
     *
     * @param input    contiguous input tensor (rejected via {@code invalidInputError} otherwise)
     * @param output   destination tensor
     * @param mean2    per-channel mean
     * @param variance per-channel variance
     * @param scale    per-channel multiplier, or {@code null} to skip the affine step
     * @param offset   per-channel additive term; only read when {@code scale} is non-null
     * @param eps      added to the variance before the square root
     */
    public void updateOutputNCHWInferFloat(Tensor<Object> input, Tensor<Object> output, Tensor<Object> mean2, Tensor<Object> variance, Tensor<Object> scale, Tensor<Object> offset, float eps) {
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "BatchNorm NCHW require a contiguous input", Log4Error$.MODULE$.invalidInputError$default$3());
        float[] inputData = (float[])input.storage().array();
        int inputOffset = input.storageOffset() - 1;
        float[] outputData = (float[])output.storage().array();
        int outputOffset = output.storageOffset() - 1;
        float[] meanData = (float[])mean2.storage().array();
        int meanOffset = mean2.storageOffset() - 1;
        float[] varData = (float[])variance.storage().array();
        int varOffset = variance.storageOffset() - 1;
        int nChannels = input.size(2);
        int nBatch = input.size(1);
        int nFrame = input.size(3) * input.size(4);
        if (scale == null) {
            int i = 0;
            for (int b = 0; b < nBatch; ++b) {
                for (int c = 0; c < nChannels; ++c) {
                    // Loop-invariant for the whole frame of this channel.
                    final float invStd = 1.0f / (float)Math.sqrt(varData[varOffset + c] + eps);
                    final float mean = meanData[c + meanOffset];
                    for (int k = 0; k < nFrame; ++k, ++i) {
                        outputData[i + outputOffset] = (inputData[i + inputOffset] - mean) * invStd;
                    }
                }
            }
        } else {
            float[] scaleData = (float[])scale.storage().array();
            int scaleOffset = scale.storageOffset() - 1;
            float[] offsetData = (float[])offset.storage().array();
            int offsetOffset = offset.storageOffset() - 1;
            int i = 0;
            for (int b = 0; b < nBatch; ++b) {
                for (int c = 0; c < nChannels; ++c) {
                    // Loop-invariant for the whole frame of this channel.
                    final float invStd = 1.0f / (float)Math.sqrt(varData[varOffset + c] + eps);
                    final float mean = meanData[c + meanOffset];
                    final float gamma = scaleData[c + scaleOffset];
                    final float beta = offsetData[c + offsetOffset];
                    for (int k = 0; k < nFrame; ++k, ++i) {
                        // Same left-to-right evaluation as before, so rounding is unchanged.
                        outputData[i + outputOffset] = (inputData[i + inputOffset] - mean) * invStd * gamma + beta;
                    }
                }
            }
        }
    }

    /**
     * Inference-time batch normalization on double storage for a 4D input whose 1st dimension
     * is the batch, 2nd the channel, and 3rd/4th the frame.
     *
     * <p>For every element: {@code out = (in - mean[c]) / sqrt(variance[c] + eps)}, optionally
     * followed by {@code * scale[c] + offset[c]}. The result is written straight into the
     * backing array of {@code output}; this method does not resize {@code output}.
     *
     * <p>All per-channel quantities (inverse std, mean, scale, offset) are read once per
     * (batch, channel) pair rather than once per element; the arithmetic applied to each
     * element is unchanged.
     *
     * @param input    contiguous input tensor (rejected via {@code invalidInputError} otherwise)
     * @param output   destination tensor
     * @param mean2    per-channel mean
     * @param variance per-channel variance
     * @param scale    per-channel multiplier, or {@code null} to skip the affine step
     * @param offset   per-channel additive term; only read when {@code scale} is non-null
     * @param eps      added to the variance before the square root
     */
    public void updateOutputNCHWInferDouble(Tensor<Object> input, Tensor<Object> output, Tensor<Object> mean2, Tensor<Object> variance, Tensor<Object> scale, Tensor<Object> offset, double eps) {
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "BatchNorm NCHW require a contiguous input", Log4Error$.MODULE$.invalidInputError$default$3());
        double[] inputData = (double[])input.storage().array();
        int inputOffset = input.storageOffset() - 1;
        double[] outputData = (double[])output.storage().array();
        int outputOffset = output.storageOffset() - 1;
        double[] meanData = (double[])mean2.storage().array();
        int meanOffset = mean2.storageOffset() - 1;
        double[] varData = (double[])variance.storage().array();
        int varOffset = variance.storageOffset() - 1;
        int nChannels = input.size(2);
        int nBatch = input.size(1);
        int nFrame = input.size(3) * input.size(4);
        if (scale == null) {
            int i = 0;
            for (int b = 0; b < nBatch; ++b) {
                for (int c = 0; c < nChannels; ++c) {
                    // Loop-invariant for the whole frame of this channel.
                    final double invStd = 1.0 / Math.sqrt(varData[varOffset + c] + eps);
                    final double mean = meanData[c + meanOffset];
                    for (int k = 0; k < nFrame; ++k, ++i) {
                        outputData[i + outputOffset] = (inputData[i + inputOffset] - mean) * invStd;
                    }
                }
            }
        } else {
            double[] scaleData = (double[])scale.storage().array();
            int scaleOffset = scale.storageOffset() - 1;
            double[] offsetData = (double[])offset.storage().array();
            int offsetOffset = offset.storageOffset() - 1;
            int i = 0;
            for (int b = 0; b < nBatch; ++b) {
                for (int c = 0; c < nChannels; ++c) {
                    // Loop-invariant for the whole frame of this channel.
                    final double invStd = 1.0 / Math.sqrt(varData[varOffset + c] + eps);
                    final double mean = meanData[c + meanOffset];
                    final double gamma = scaleData[c + scaleOffset];
                    final double beta = offsetData[c + offsetOffset];
                    for (int k = 0; k < nFrame; ++k, ++i) {
                        // Same left-to-right evaluation as before, so rounding is unchanged.
                        outputData[i + outputOffset] = (inputData[i + inputOffset] - mean) * invStd * gamma + beta;
                    }
                }
            }
        }
    }

    /**
     * Training-time input gradient of batch normalization on float storage, with the channel
     * taken from the 4th dimension of {@code gradOutput} and stored as the fastest-varying index.
     *
     * <p>Per channel it first forms {@code gMean = mean(gradOutput)} and
     * {@code gxMean = mean(gradOutput * (input - saveMean))}, then writes
     * {@code gradInput = [scale *] invStd * (gradOutput - gMean - gxMean * invStd^2 * (input - saveMean))},
     * where {@code invStd} is the value read from {@code saveStd}.
     *
     * <p>{@code gMean} and {@code gxMean} are addressed through their storage offsets, like
     * every other tensor here, so tensors that do not start at the beginning of their backing
     * array are handled correctly.
     *
     * <p>NOTE(review): the sums are added on top of whatever {@code gMean}/{@code gxMean}
     * already contain; presumably the caller zeroes them beforehand — confirm.
     *
     * @param input      contiguous 4D input of the forward pass
     * @param gradOutput contiguous 4D gradient with respect to the output
     * @param gradInput  resized to the shape of {@code gradOutput} and filled with the result
     * @param scale      per-channel multiplier of length equal to the channel count, or
     *                   {@code null} for no scaling
     * @param saveMean   per-channel batch mean; length must equal the channel count
     * @param saveStd    per-channel inverse standard deviation; length must equal the channel count
     * @param gMean      scratch/result tensor; resized (together with {@code gxMean}) when empty
     * @param gxMean     scratch/result tensor
     */
    public void updateGradInputNHWCTrainFloat(Tensor<Object> input, Tensor<Object> gradOutput, Tensor<Object> gradInput, Tensor<Object> scale, Tensor<Object> saveMean, Tensor<Object> saveStd, Tensor<Object> gMean, Tensor<Object> gxMean) {
        Log4Error$.MODULE$.invalidInputError(input.nDimension() == 4, "BN require a 4D input", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "input is not contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradOutput.nDimension() == 4, "BN require a 4D gradient", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradOutput.isContiguous(), "gradient is not contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        int nChannel = gradOutput.size(4);
        Log4Error$.MODULE$.invalidInputError(saveMean.size(1) == nChannel, "saveMean length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(saveStd.size(1) == nChannel, "saveStd length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
        gradInput.resizeAs(gradOutput);
        if (gMean.isEmpty()) {
            gMean.resize(nChannel);
            gxMean.resize(nChannel);
        }
        float[] inputData = (float[])input.storage().array();
        int inputOffset = input.storageOffset() - 1;
        float[] gradOutputData = (float[])gradOutput.storage().array();
        int gradOutputOffset = gradOutput.storageOffset() - 1;
        float[] gradInputData = (float[])gradInput.storage().array();
        int gradInputOffset = gradInput.storageOffset() - 1;
        float[] saveMeanData = (float[])saveMean.storage().array();
        int saveMeanOffset = saveMean.storageOffset() - 1;
        float[] saveStdData = (float[])saveStd.storage().array();
        int saveStdOffset = saveStd.storageOffset() - 1;
        float[] gMeanData = (float[])gMean.storage().array();
        int gMeanOffset = gMean.storageOffset() - 1;
        float[] gxMeanData = (float[])gxMean.storage().array();
        int gxMeanOffset = gxMean.storageOffset() - 1;
        int n = gradOutput.nElement();

        // Pass 1: accumulate the per-channel sums.
        int i = 0;
        while (i < n) {
            for (int c = 0; c < nChannel; ++c, ++i) {
                float g = gradOutputData[i + gradOutputOffset];
                gMeanData[gMeanOffset + c] += g;
                gxMeanData[gxMeanOffset + c] += g * (inputData[i + inputOffset] - saveMeanData[c + saveMeanOffset]);
            }
        }

        // Turn the sums into means.
        int size = n / nChannel;
        for (int c = 0; c < nChannel; ++c) {
            gMeanData[gMeanOffset + c] = gMeanData[gMeanOffset + c] / (float)size;
            gxMeanData[gxMeanOffset + c] = gxMeanData[gxMeanOffset + c] / (float)size;
        }

        // Pass 2: write the gradient.
        if (scale == null) {
            i = 0;
            while (i < n) {
                for (int c = 0; c < nChannel; ++c, ++i) {
                    float invStd = saveStdData[saveStdOffset + c];
                    gradInputData[gradInputOffset + i] = invStd * (gradOutputData[gradOutputOffset + i] - gMeanData[gMeanOffset + c] - gxMeanData[gxMeanOffset + c] * invStd * invStd * (inputData[inputOffset + i] - saveMeanData[saveMeanOffset + c]));
                }
            }
        } else {
            Log4Error$.MODULE$.invalidInputError(scale.size(1) == nChannel, "scale length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
            float[] scaleData = (float[])scale.storage().array();
            int scaleOffset = scale.storageOffset() - 1;
            i = 0;
            while (i < n) {
                for (int c = 0; c < nChannel; ++c, ++i) {
                    float invStd = saveStdData[saveStdOffset + c];
                    gradInputData[gradInputOffset + i] = scaleData[scaleOffset + c] * invStd * (gradOutputData[gradOutputOffset + i] - gMeanData[gMeanOffset + c] - gxMeanData[gxMeanOffset + c] * invStd * invStd * (inputData[inputOffset + i] - saveMeanData[saveMeanOffset + c]));
                }
            }
        }
    }

    /**
     * Training-time input gradient of batch normalization on double storage, with the channel
     * taken from the 4th dimension of {@code gradOutput} and stored as the fastest-varying index.
     *
     * <p>Per channel it first forms {@code gMean = mean(gradOutput)} and
     * {@code gxMean = mean(gradOutput * (input - saveMean))}, then writes
     * {@code gradInput = [scale *] invStd * (gradOutput - gMean - gxMean * invStd^2 * (input - saveMean))},
     * where {@code invStd} is the value read from {@code saveStd}.
     *
     * <p>{@code gMean} and {@code gxMean} are addressed through their storage offsets, like
     * every other tensor here, so tensors that do not start at the beginning of their backing
     * array are handled correctly.
     *
     * <p>NOTE(review): the sums are added on top of whatever {@code gMean}/{@code gxMean}
     * already contain; presumably the caller zeroes them beforehand — confirm.
     *
     * @param input      contiguous 4D input of the forward pass
     * @param gradOutput contiguous 4D gradient with respect to the output
     * @param gradInput  resized to the shape of {@code gradOutput} and filled with the result
     * @param scale      per-channel multiplier of length equal to the channel count, or
     *                   {@code null} for no scaling
     * @param saveMean   per-channel batch mean; length must equal the channel count
     * @param saveStd    per-channel inverse standard deviation; length must equal the channel count
     * @param gMean      scratch/result tensor; resized (together with {@code gxMean}) when empty
     * @param gxMean     scratch/result tensor
     */
    public void updateGradInputNHWCTrainDouble(Tensor<Object> input, Tensor<Object> gradOutput, Tensor<Object> gradInput, Tensor<Object> scale, Tensor<Object> saveMean, Tensor<Object> saveStd, Tensor<Object> gMean, Tensor<Object> gxMean) {
        Log4Error$.MODULE$.invalidInputError(input.nDimension() == 4, "BN require a 4D input", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "input is not contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradOutput.nDimension() == 4, "BN require a 4D gradient", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradOutput.isContiguous(), "gradient is not contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        int nChannel = gradOutput.size(4);
        Log4Error$.MODULE$.invalidInputError(saveMean.size(1) == nChannel, "saveMean length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(saveStd.size(1) == nChannel, "saveStd length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
        gradInput.resizeAs(gradOutput);
        if (gMean.isEmpty()) {
            gMean.resize(nChannel);
            gxMean.resize(nChannel);
        }
        double[] inputData = (double[])input.storage().array();
        int inputOffset = input.storageOffset() - 1;
        double[] gradOutputData = (double[])gradOutput.storage().array();
        int gradOutputOffset = gradOutput.storageOffset() - 1;
        double[] gradInputData = (double[])gradInput.storage().array();
        int gradInputOffset = gradInput.storageOffset() - 1;
        double[] saveMeanData = (double[])saveMean.storage().array();
        int saveMeanOffset = saveMean.storageOffset() - 1;
        double[] saveStdData = (double[])saveStd.storage().array();
        int saveStdOffset = saveStd.storageOffset() - 1;
        double[] gMeanData = (double[])gMean.storage().array();
        int gMeanOffset = gMean.storageOffset() - 1;
        double[] gxMeanData = (double[])gxMean.storage().array();
        int gxMeanOffset = gxMean.storageOffset() - 1;
        int n = gradOutput.nElement();

        // Pass 1: accumulate the per-channel sums.
        int i = 0;
        while (i < n) {
            for (int c = 0; c < nChannel; ++c, ++i) {
                double g = gradOutputData[i + gradOutputOffset];
                gMeanData[gMeanOffset + c] += g;
                gxMeanData[gxMeanOffset + c] += g * (inputData[i + inputOffset] - saveMeanData[c + saveMeanOffset]);
            }
        }

        // Turn the sums into means.
        int size = n / nChannel;
        for (int c = 0; c < nChannel; ++c) {
            gMeanData[gMeanOffset + c] = gMeanData[gMeanOffset + c] / (double)size;
            gxMeanData[gxMeanOffset + c] = gxMeanData[gxMeanOffset + c] / (double)size;
        }

        // Pass 2: write the gradient.
        if (scale == null) {
            i = 0;
            while (i < n) {
                for (int c = 0; c < nChannel; ++c, ++i) {
                    double invStd = saveStdData[saveStdOffset + c];
                    gradInputData[gradInputOffset + i] = invStd * (gradOutputData[gradOutputOffset + i] - gMeanData[gMeanOffset + c] - gxMeanData[gxMeanOffset + c] * invStd * invStd * (inputData[inputOffset + i] - saveMeanData[saveMeanOffset + c]));
                }
            }
        } else {
            Log4Error$.MODULE$.invalidInputError(scale.size(1) == nChannel, "scale length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
            double[] scaleData = (double[])scale.storage().array();
            int scaleOffset = scale.storageOffset() - 1;
            i = 0;
            while (i < n) {
                for (int c = 0; c < nChannel; ++c, ++i) {
                    double invStd = saveStdData[saveStdOffset + c];
                    gradInputData[gradInputOffset + i] = scaleData[scaleOffset + c] * invStd * (gradOutputData[gradOutputOffset + i] - gMeanData[gMeanOffset + c] - gxMeanData[gxMeanOffset + c] * invStd * invStd * (inputData[inputOffset + i] - saveMeanData[saveMeanOffset + c]));
                }
            }
        }
    }

    /**
     * Inference-time input gradient of batch normalization on float storage, with the channel
     * taken from the 4th dimension of {@code gradOutput} and stored as the fastest-varying index.
     *
     * <p>Each element becomes {@code gradInput = [scale[c] *] saveStd[c] * gradOutput}; the
     * value read from {@code saveStd} is used directly as a multiplier.
     *
     * @param gradOutput contiguous 4D gradient with respect to the output
     * @param gradInput  resized to the shape of {@code gradOutput} and filled with the result
     * @param scale      per-channel multiplier of length equal to the channel count, or
     *                   {@code null} for no scaling
     * @param saveStd    per-channel multiplier; length must equal the channel count
     */
    public void updateGradInputNHWCInferFloat(Tensor<Object> gradOutput, Tensor<Object> gradInput, Tensor<Object> scale, Tensor<Object> saveStd) {
        Log4Error$.MODULE$.invalidInputError(gradOutput.nDimension() == 4, "BN require a 4D gradient", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradOutput.isContiguous(), "gradient is not contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        final int channels = gradOutput.size(4);
        Log4Error$.MODULE$.invalidInputError(saveStd.size(1) == channels, "saveStd length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
        gradInput.resizeAs(gradOutput);
        final float[] gOut = (float[])gradOutput.storage().array();
        final int gOutBase = gradOutput.storageOffset() - 1;
        final float[] gIn = (float[])gradInput.storage().array();
        final int gInBase = gradInput.storageOffset() - 1;
        final float[] rstd = (float[])saveStd.storage().array();
        final int rstdBase = saveStd.storageOffset() - 1;
        final int total = gradOutput.nElement();

        if (scale == null) {
            for (int pixel = 0; pixel < total; pixel += channels) {
                for (int ch = 0; ch < channels; ++ch) {
                    final int pos = pixel + ch;
                    final float factor = rstd[rstdBase + ch];
                    gIn[gInBase + pos] = factor * gOut[gOutBase + pos];
                }
            }
            return;
        }

        Log4Error$.MODULE$.invalidInputError(scale.size(1) == channels, "scale length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
        final float[] gamma = (float[])scale.storage().array();
        final int gammaBase = scale.storageOffset() - 1;
        for (int pixel = 0; pixel < total; pixel += channels) {
            for (int ch = 0; ch < channels; ++ch) {
                final int pos = pixel + ch;
                final float factor = rstd[rstdBase + ch];
                gIn[gInBase + pos] = gamma[gammaBase + ch] * factor * gOut[gOutBase + pos];
            }
        }
    }

    /**
     * Double-precision counterpart of the inference-mode backward pass where the channel
     * count is taken from dimension 4 of {@code gradOutput}.
     *
     * <p>After resizing {@code gradInput} like {@code gradOutput}, each element becomes
     * {@code gradOutput * saveStd[c]}, further multiplied by {@code scale[c]} when
     * {@code scale} is non-null; {@code c} cycles through the channels as the flat index
     * advances.
     *
     * @param gradOutput 4-D, contiguous gradient tensor (both conditions are validated)
     * @param gradInput  destination tensor; resized to the shape of {@code gradOutput}
     * @param scale      optional per-channel multiplier; when non-null its first dimension
     *                   must equal the channel count
     * @param saveStd    per-channel factor applied to the gradient; its first dimension must
     *                   equal the channel count
     */
    public void updateGradInputNHWCInferDouble(Tensor<Object> gradOutput, Tensor<Object> gradInput, Tensor<Object> scale, Tensor<Object> saveStd) {
        final Log4Error$ guard = Log4Error$.MODULE$;
        guard.invalidInputError(gradOutput.nDimension() == 4, "BN require a 4D gradient", guard.invalidInputError$default$3());
        guard.invalidInputError(gradOutput.isContiguous(), "gradient is not contiguous", guard.invalidInputError$default$3());
        final int channels = gradOutput.size(4);
        guard.invalidInputError(saveStd.size(1) == channels, "saveStd length is not consistent with channel number", guard.invalidInputError$default$3());
        gradInput.resizeAs(gradOutput);

        // Raw backing arrays plus zero-based start positions (storageOffset is one-based).
        final double[] src = (double[])gradOutput.storage().array();
        final int srcBase = gradOutput.storageOffset() - 1;
        final double[] dst = (double[])gradInput.storage().array();
        final int dstBase = gradInput.storageOffset() - 1;
        final double[] factors = (double[])saveStd.storage().array();
        final int factorBase = saveStd.storageOffset() - 1;
        final int total = gradOutput.nElement();

        if (scale != null) {
            guard.invalidInputError(scale.size(1) == channels, "scale length is not consistent with channel number", guard.invalidInputError$default$3());
            final double[] gains = (double[])scale.storage().array();
            final int gainBase = scale.storageOffset() - 1;
            // Walk the flat buffer one channel-group at a time.
            for (int group = 0; group < total; group += channels) {
                for (int ch = 0; ch < channels; ++ch) {
                    final int pos = group + ch;
                    final double invStd = factors[factorBase + ch];
                    dst[dstBase + pos] = gains[gainBase + ch] * invStd * src[srcBase + pos];
                }
            }
            return;
        }

        for (int group = 0; group < total; group += channels) {
            for (int ch = 0; ch < channels; ++ch) {
                final int pos = group + ch;
                final double invStd = factors[factorBase + ch];
                dst[dstBase + pos] = invStd * src[srcBase + pos];
            }
        }
    }

    /**
     * Training-mode backward pass for NCHW batch normalization over float storage.
     *
     * <p>The method works in four steps:
     * <ol>
     *   <li>Per channel, add the sum of {@code gradOutput} into {@code gMean} and the sum of
     *       {@code gradOutput * (input - saveMean)} into {@code gxMean}. The sums are added
     *       onto whatever those tensors already hold; nothing is zeroed here.</li>
     *   <li>When {@code needSync} is true, publish both tensors through
     *       {@code ParameterSynchronizer$}, collect every thread's copy, and total them per
     *       channel into {@code globalGmean} / {@code globalGxmean}.</li>
     *   <li>Divide the (local or global) sums by the element count per channel times the
     *       number of collected contributions.</li>
     *   <li>Write {@code invStd * (gradOutput - gMean - gxMean * invStd^2 * (input - saveMean))}
     *       into {@code gradInput}, multiplied by {@code scale} when it is non-null.</li>
     * </ol>
     *
     * <p>{@code gMean} and {@code gxMean} are indexed from position 0 of their backing arrays
     * (their storage offsets are not applied).
     *
     * @param input        4-D contiguous layer input
     * @param gradOutput   4-D contiguous gradient of the layer output; dimension 2 is the channel count
     * @param gradInput    destination tensor; resized to the shape of {@code gradOutput}
     * @param scale        optional per-channel multiplier; may be {@code null}
     * @param saveMean     per-channel mean; first dimension must equal the channel count
     * @param saveStd      per-channel factor used as {@code invStd}; first dimension must equal the channel count
     * @param gMean        per-channel accumulator for the gradient sum; resized when empty
     * @param gxMean       per-channel accumulator for the centred gradient sum; resized together with {@code gMean}
     * @param globalGmean  scratch array receiving the cross-thread gradient sums; only used when {@code needSync}
     * @param globalGxmean scratch array receiving the cross-thread centred gradient sums; only used when {@code needSync}
     * @param gMeanKey     synchronizer key for {@code gMean}; only used when {@code needSync}
     * @param gxMeanKey    synchronizer key for {@code gxMean}; only used when {@code needSync}
     * @param needSync     whether to aggregate the sums across threads
     */
    public void updateGradInputNCHWTrainFloat(Tensor<Object> input, Tensor<Object> gradOutput, Tensor<Object> gradInput, Tensor<Object> scale, Tensor<Object> saveMean, Tensor<Object> saveStd, Tensor<Object> gMean, Tensor<Object> gxMean, float[] globalGmean, float[] globalGxmean, String gMeanKey, String gxMeanKey, boolean needSync) {
        Log4Error$.MODULE$.invalidInputError(input.nDimension() == 4, "BN require a 4D input", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "input is not contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradOutput.nDimension() == 4, "BN require a 4D gradient", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradOutput.isContiguous(), "gradient is not contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        int nChannel = gradOutput.size(2);
        Log4Error$.MODULE$.invalidInputError(saveMean.size(1) == nChannel, "saveMean length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(saveStd.size(1) == nChannel, "saveStd length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
        gradInput.resizeAs(gradOutput);
        if (gMean.isEmpty()) {
            gMean.resize(nChannel);
            gxMean.resize(nChannel);
        }

        // Backing arrays and zero-based start positions (storageOffset is one-based).
        float[] inputData = (float[])input.storage().array();
        int inputOffset = input.storageOffset() - 1;
        float[] gradOutputData = (float[])gradOutput.storage().array();
        int gradOutputOffset = gradOutput.storageOffset() - 1;
        float[] gradInputData = (float[])gradInput.storage().array();
        int gradInputOffset = gradInput.storageOffset() - 1;
        float[] saveMeanData = (float[])saveMean.storage().array();
        int saveMeanOffset = saveMean.storageOffset() - 1;
        float[] saveStdData = (float[])saveStd.storage().array();
        int saveStdOffset = saveStd.storageOffset() - 1;
        float[] gMeanData = (float[])gMean.storage().array();
        float[] gxMeanData = (float[])gxMean.storage().array();
        int nBatch = gradOutput.size(1);
        int frameSize = gradOutput.size(3) * gradOutput.size(4);
        int n = gradOutput.nElement();

        // Step 1: per-channel sums of the gradient and of the centred gradient.
        int i = 0;
        for (int b = 0; b < nBatch; ++b) {
            for (int c = 0; c < nChannel; ++c) {
                for (int k = 0; k < frameSize; ++k) {
                    gMeanData[c] += gradOutputData[i + gradOutputOffset];
                    gxMeanData[c] += gradOutputData[i + gradOutputOffset] * (inputData[i + inputOffset] - saveMeanData[c + saveMeanOffset]);
                    ++i;
                }
            }
        }

        // Step 2: optional cross-thread aggregation of both sums.
        int gmeanEventLen = 1;
        int gmxmeanEventLen = 1;
        if (needSync) {
            ParameterSynchronizer$.MODULE$.syncData(gMeanKey, gMean, ClassTag$.MODULE$.Float());
            Map gMeanEventData = ParameterSynchronizer$.MODULE$.collect(gMeanKey, ClassTag$.MODULE$.Float());
            for (int c = 0; c < nChannel; ++c) {
                globalGmean[c] = 0.0f;
                for (Object threadId : gMeanEventData.keySet()) {
                    Tensor localGmean = (Tensor)gMeanEventData.get(threadId);
                    int localGmeanOffset = localGmean.storageOffset() - 1;
                    globalGmean[c] += ((float[])localGmean.storage().array())[c + localGmeanOffset];
                }
            }
            gmeanEventLen = gMeanEventData.size();
            ParameterSynchronizer$.MODULE$.reset(gMeanKey, ClassTag$.MODULE$.Float());

            ParameterSynchronizer$.MODULE$.syncData(gxMeanKey, gxMean, ClassTag$.MODULE$.Float());
            Map gxMeanEventData = ParameterSynchronizer$.MODULE$.collect(gxMeanKey, ClassTag$.MODULE$.Float());
            for (int c = 0; c < nChannel; ++c) {
                globalGxmean[c] = 0.0f;
                for (Object threadId : gxMeanEventData.keySet()) {
                    Tensor localGxmean = (Tensor)gxMeanEventData.get(threadId);
                    int localGxmeanOffset = localGxmean.storageOffset() - 1;
                    globalGxmean[c] += ((float[])localGxmean.storage().array())[c + localGxmeanOffset];
                }
            }
            gmxmeanEventLen = gxMeanEventData.size();
            ParameterSynchronizer$.MODULE$.reset(gxMeanKey, ClassTag$.MODULE$.Float());
        }

        // Step 3: turn the sums into means.
        int size = n / nChannel;
        for (int c = 0; c < nChannel; ++c) {
            if (needSync) {
                gMeanData[c] = globalGmean[c] / (float)(size * gmeanEventLen);
                gxMeanData[c] = globalGxmean[c] / (float)(size * gmxmeanEventLen);
            } else {
                gMeanData[c] = gMeanData[c] / (float)(size * gmeanEventLen);
                gxMeanData[c] = gxMeanData[c] / (float)(size * gmxmeanEventLen);
            }
        }

        // Step 4: write the gradient. Each loop owns its batch counter so that the
        // accumulation pass above cannot leave it exhausted for either branch.
        i = 0;
        if (scale == null) {
            for (int b = 0; b < nBatch; ++b) {
                for (int c = 0; c < nChannel; ++c) {
                    for (int k = 0; k < frameSize; ++k) {
                        float invStd = saveStdData[saveStdOffset + c];
                        gradInputData[gradInputOffset + i] = invStd * (gradOutputData[gradOutputOffset + i] - gMeanData[c] - gxMeanData[c] * invStd * invStd * (inputData[inputOffset + i] - saveMeanData[saveMeanOffset + c]));
                        ++i;
                    }
                }
            }
        } else {
            Log4Error$.MODULE$.invalidInputError(scale.size(1) == nChannel, "scale length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
            float[] scaleData = (float[])scale.storage().array();
            int scaleOffset = scale.storageOffset() - 1;
            for (int b = 0; b < nBatch; ++b) {
                for (int c = 0; c < nChannel; ++c) {
                    for (int k = 0; k < frameSize; ++k) {
                        float invStd = saveStdData[saveStdOffset + c];
                        gradInputData[gradInputOffset + i] = scaleData[scaleOffset + c] * invStd * (gradOutputData[gradOutputOffset + i] - gMeanData[c] - gxMeanData[c] * invStd * invStd * (inputData[inputOffset + i] - saveMeanData[saveMeanOffset + c]));
                        ++i;
                    }
                }
            }
        }
    }

    /**
     * Default-argument accessor; by its {@code $default$11} name this presumably supplies
     * the default for the 11th parameter ({@code gMeanKey}) of
     * {@code updateGradInputNCHWTrainFloat}.
     *
     * @return always {@code null}
     */
    public String updateGradInputNCHWTrainFloat$default$11() {
        return null;
    }

    /**
     * Default-argument accessor; by its {@code $default$12} name this presumably supplies
     * the default for the 12th parameter ({@code gxMeanKey}) of
     * {@code updateGradInputNCHWTrainFloat}.
     *
     * @return always {@code null}
     */
    public String updateGradInputNCHWTrainFloat$default$12() {
        return null;
    }

    /**
     * Default-argument accessor; by its {@code $default$13} name this presumably supplies
     * the default for the 13th parameter ({@code needSync}) of
     * {@code updateGradInputNCHWTrainFloat}.
     *
     * @return always {@code false}
     */
    public boolean updateGradInputNCHWTrainFloat$default$13() {
        return false;
    }

    /**
     * Training-mode forward pass for NCHW batch normalization over float storage.
     *
     * <p>Sequence of work:
     * <ol>
     *   <li>Add each channel's input sum onto {@code saveMean} (not zeroed here); with
     *       {@code needSync}, exchange the sums via {@code ParameterSynchronizer$}, total all
     *       collected copies into {@code globalMean} and copy that back into {@code saveMean}.</li>
     *   <li>Divide by elements-per-channel times the number of contributions to get the mean,
     *       and blend it into {@code runningMean} with {@code momentum}.</li>
     *   <li>Add each channel's sum of squared deviations onto {@code saveStd}, with the same
     *       optional cross-thread aggregation through {@code globalStd}.</li>
     *   <li>Per channel, store the biased variance in {@code saveVar} and the unbiased one in
     *       {@code batchVar} (when non-null), replace {@code saveStd} by
     *       {@code 1 / sqrt(biasedVar + eps)} and blend the unbiased variance into
     *       {@code runningVar}. A channel whose squared-deviation sum is exactly 0 while
     *       {@code eps} is 0 gets 0 everywhere and leaves {@code runningVar} untouched.</li>
     *   <li>With {@code needFix}, overwrite the mean with 0 and the inverse std with 1e-4.</li>
     *   <li>Write {@code (input - mean) * invStd}, followed by {@code * scale + offset} when
     *       {@code scale} is non-null, into {@code output}.</li>
     * </ol>
     *
     * @param input       contiguous input; dimension 1 is the batch, dimension 2 the channels
     * @param output      destination tensor; written through its storage, not resized here
     * @param saveMean    per-channel mean accumulator/result; resized when its length differs from the channel count
     * @param saveStd     per-channel accumulator, ends up holding the inverse standard deviation
     * @param runningMean momentum-blended mean
     * @param runningVar  momentum-blended unbiased variance
     * @param scale       optional per-channel multiplier; {@code null} disables scale and offset
     * @param offset      per-channel additive term, read only when {@code scale} is non-null
     * @param eps         value added to the variance before the square root
     * @param momentum    weight of the current batch statistics in the running values
     * @param batchVar    optional receiver of the unbiased variance (one-based {@code setValue})
     * @param saveVar     optional receiver of the biased variance (one-based {@code setValue})
     * @param needFix     forces mean 0 and inverse std 1e-4 before the output is written
     * @param globalMean  scratch array for cross-thread mean sums; only used when {@code needSync}
     * @param globalStd   scratch array for cross-thread squared-deviation sums; only used when {@code needSync}
     * @param meanKey     synchronizer key for the mean exchange
     * @param stdKey      synchronizer key for the squared-deviation exchange
     * @param needSync    whether to aggregate statistics across threads
     */
    public void updateOutputNCHWTrainFloat(Tensor<Object> input, Tensor<Object> output, Tensor<Object> saveMean, Tensor<Object> saveStd, Tensor<Object> runningMean, Tensor<Object> runningVar, Tensor<Object> scale, Tensor<Object> offset, float eps, float momentum, Tensor<Object> batchVar, Tensor<Object> saveVar, boolean needFix, float[] globalMean, float[] globalStd, String meanKey, String stdKey, boolean needSync) {
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "BatchNorm NCHW require a contiguous input", Log4Error$.MODULE$.invalidInputError$default$3());
        final float[] src = (float[])input.storage().array();
        final int srcBase = input.storageOffset() - 1;
        final float[] dst = (float[])output.storage().array();
        final int dstBase = output.storageOffset() - 1;
        final int channels = input.size(2);
        final int batches = input.size(1);
        final int plane = input.size(3) * input.size(4);
        if (saveMean.size(1) != channels) {
            saveMean.resize(channels);
            saveStd.resize(channels);
            runningMean.resize(channels);
            runningVar.resize(channels);
        }
        final float[] means = (float[])saveMean.storage().array();
        final int meanBase = saveMean.storageOffset() - 1;

        // Pass 1: per-channel sums of the input.
        int cursor = 0;
        for (int b = 0; b < batches; ++b) {
            for (int ch = 0; ch < channels; ++ch) {
                float planeSum = 0.0f;
                for (int k = 0; k < plane; ++k) {
                    planeSum += src[cursor + srcBase];
                    ++cursor;
                }
                means[ch + meanBase] += planeSum;
            }
        }

        int meanContributors = 1;
        if (needSync) {
            ParameterSynchronizer$.MODULE$.syncData(meanKey, saveMean, ClassTag$.MODULE$.Float());
            final Map meanByThread = ParameterSynchronizer$.MODULE$.collect(meanKey, ClassTag$.MODULE$.Float());
            meanContributors = meanByThread.size();
            for (int ch = 0; ch < channels; ++ch) {
                globalMean[ch] = 0.0f;
                for (Object threadId : meanByThread.keySet()) {
                    final Tensor local = (Tensor)meanByThread.get(threadId);
                    final int localBase = local.storageOffset() - 1;
                    globalMean[ch] += ((float[])local.storage().array())[ch + localBase];
                }
            }
            ParameterSynchronizer$.MODULE$.reset(meanKey, ClassTag$.MODULE$.Float());
            System.arraycopy(globalMean, 0, means, meanBase, channels);
        }

        // Sums -> means, and fold the means into the running average.
        final int elements = input.nElement();
        final int perChannel = elements / channels;
        final float[] runMean = (float[])runningMean.storage().array();
        final int runMeanBase = runningMean.storageOffset() - 1;
        for (int ch = 0; ch < channels; ++ch) {
            means[ch + meanBase] /= (float)(perChannel * meanContributors);
            runMean[ch + runMeanBase] = means[ch + meanBase] * momentum + (1.0f - momentum) * runMean[ch + runMeanBase];
        }

        // Pass 2: per-channel sums of squared deviations from the mean.
        final float[] invStds = (float[])saveStd.storage().array();
        final int stdBase = saveStd.storageOffset() - 1;
        cursor = 0;
        for (int b = 0; b < batches; ++b) {
            for (int ch = 0; ch < channels; ++ch) {
                float squares = 0.0f;
                for (int k = 0; k < plane; ++k) {
                    final float centred = src[cursor + srcBase] - means[ch + meanBase];
                    squares += centred * centred;
                    ++cursor;
                }
                invStds[ch + stdBase] += squares;
            }
        }

        int stdContributors = 1;
        if (needSync) {
            ParameterSynchronizer$.MODULE$.syncData(stdKey, saveStd, ClassTag$.MODULE$.Float());
            final Map stdByThread = ParameterSynchronizer$.MODULE$.collect(stdKey, ClassTag$.MODULE$.Float());
            for (int ch = 0; ch < channels; ++ch) {
                globalStd[ch] = 0.0f;
                for (Object threadId : stdByThread.keySet()) {
                    final Tensor local = (Tensor)stdByThread.get(threadId);
                    final int localBase = local.storageOffset() - 1;
                    globalStd[ch] += ((float[])local.storage().array())[ch + localBase];
                }
            }
            stdContributors = stdByThread.size();
            ParameterSynchronizer$.MODULE$.reset(stdKey, ClassTag$.MODULE$.Float());
            System.arraycopy(globalStd, 0, invStds, stdBase, channels);
        }

        // Variances, inverse standard deviations and running variance.
        final float[] runVar = (float[])runningVar.storage().array();
        final int runVarBase = runningVar.storageOffset() - 1;
        for (int ch = 0; ch < channels; ++ch) {
            if (invStds[ch + stdBase] == 0.0f && eps == 0.0f) {
                // Degenerate channel: report zero variance and skip the running update.
                invStds[ch + stdBase] = 0.0f;
                if (saveVar != null) {
                    saveVar.setValue(ch + 1, BoxesRunTime.boxToFloat((float)0.0f));
                }
                if (batchVar != null) {
                    batchVar.setValue(ch + 1, BoxesRunTime.boxToFloat((float)0.0f));
                }
                continue;
            }
            final float squares = invStds[ch + stdBase];
            final float unbiased = squares / (float)(perChannel * stdContributors - 1);
            if (saveVar != null) {
                saveVar.setValue(ch + 1, BoxesRunTime.boxToFloat((float)(squares / (float)(perChannel * stdContributors))));
            }
            if (batchVar != null) {
                batchVar.setValue(ch + 1, BoxesRunTime.boxToFloat((float)unbiased));
            }
            invStds[ch + stdBase] = 1.0f / (float)Math.sqrt(squares / (float)(perChannel * stdContributors) + eps);
            runVar[ch + runVarBase] = momentum * unbiased + (1.0f - momentum) * runVar[ch + runVarBase];
        }

        if (needFix) {
            for (int ch = 0; ch < channels; ++ch) {
                means[ch + meanBase] = 0.0f;
                invStds[ch + stdBase] = 1.0E-4f;
            }
        }

        // Normalise, optionally applying scale and offset.
        if (scale != null) {
            final float[] gains = (float[])scale.storage().array();
            final int gainBase = scale.storageOffset() - 1;
            final float[] shifts = (float[])offset.storage().array();
            final int shiftBase = offset.storageOffset() - 1;
            cursor = 0;
            for (int b = 0; b < batches; ++b) {
                for (int ch = 0; ch < channels; ++ch) {
                    for (int k = 0; k < plane; ++k) {
                        dst[cursor + dstBase] = (src[cursor + srcBase] - means[ch + meanBase]) * invStds[ch + stdBase] * gains[ch + gainBase] + shifts[ch + shiftBase];
                        ++cursor;
                    }
                }
            }
        } else {
            cursor = 0;
            for (int b = 0; b < batches; ++b) {
                for (int ch = 0; ch < channels; ++ch) {
                    for (int k = 0; k < plane; ++k) {
                        dst[cursor + dstBase] = (src[cursor + srcBase] - means[ch + meanBase]) * invStds[ch + stdBase];
                        ++cursor;
                    }
                }
            }
        }
    }

    /**
     * Default-argument accessor; by its {@code $default$11} name this presumably supplies
     * the default for the 11th parameter ({@code batchVar}) of
     * {@code updateOutputNCHWTrainFloat}.
     *
     * @return always {@code null}
     */
    public Tensor<Object> updateOutputNCHWTrainFloat$default$11() {
        return null;
    }

    /**
     * Default-argument accessor; by its {@code $default$12} name this presumably supplies
     * the default for the 12th parameter ({@code saveVar}) of
     * {@code updateOutputNCHWTrainFloat}.
     *
     * @return always {@code null}
     */
    public Tensor<Object> updateOutputNCHWTrainFloat$default$12() {
        return null;
    }

    /**
     * Default-argument accessor; by its {@code $default$13} name this presumably supplies
     * the default for the 13th parameter ({@code needFix}) of
     * {@code updateOutputNCHWTrainFloat}.
     *
     * @return always {@code false}
     */
    public boolean updateOutputNCHWTrainFloat$default$13() {
        return false;
    }

    /**
     * Default-argument accessor; by its {@code $default$14} name this presumably supplies
     * the default for the 14th parameter ({@code globalMean}) of
     * {@code updateOutputNCHWTrainFloat}.
     *
     * @return always {@code null}
     */
    public float[] updateOutputNCHWTrainFloat$default$14() {
        return null;
    }

    /**
     * Default-argument accessor; by its {@code $default$15} name this presumably supplies
     * the default for the 15th parameter ({@code globalStd}) of
     * {@code updateOutputNCHWTrainFloat}.
     *
     * @return always {@code null}
     */
    public float[] updateOutputNCHWTrainFloat$default$15() {
        return null;
    }

    /**
     * Default-argument accessor; by its {@code $default$16} name this presumably supplies
     * the default for the 16th parameter ({@code meanKey}) of
     * {@code updateOutputNCHWTrainFloat}.
     *
     * @return always {@code null}
     */
    public String updateOutputNCHWTrainFloat$default$16() {
        return null;
    }

    /**
     * Default-argument accessor; by its {@code $default$17} name this presumably supplies
     * the default for the 17th parameter ({@code stdKey}) of
     * {@code updateOutputNCHWTrainFloat}.
     *
     * @return always {@code null}
     */
    public String updateOutputNCHWTrainFloat$default$17() {
        return null;
    }

    /**
     * Default-argument accessor; by its {@code $default$18} name this presumably supplies
     * the default for the 18th parameter ({@code needSync}) of
     * {@code updateOutputNCHWTrainFloat}.
     *
     * @return always {@code false}
     */
    public boolean updateOutputNCHWTrainFloat$default$18() {
        return false;
    }

    /**
     * Training-mode forward pass for NCHW batch normalization over double storage.
     *
     * <p>Sequence of work:
     * <ol>
     *   <li>Add each channel's input sum onto {@code saveMean} (not zeroed here); with
     *       {@code needSync}, exchange the sums via {@code ParameterSynchronizer$}, total all
     *       collected copies into {@code globalMean} and copy that back into {@code saveMean}.</li>
     *   <li>Divide by elements-per-channel times the number of contributions to get the mean,
     *       and blend it into {@code runningMean} with {@code momentum}.</li>
     *   <li>Add each channel's squared deviations onto {@code saveStd}, with the same
     *       optional cross-thread aggregation through {@code globalStd}.</li>
     *   <li>Per channel, store the biased variance in {@code saveVar} and the unbiased one in
     *       {@code batchVar} (when non-null), replace {@code saveStd} by
     *       {@code 1 / sqrt(biasedVar + eps)} and blend the unbiased variance into
     *       {@code runningVar}. A channel whose squared-deviation sum is exactly 0 while
     *       {@code eps} is 0 gets 0 everywhere and leaves {@code runningVar} untouched.</li>
     *   <li>With {@code needFix}, overwrite the mean with 0 and the inverse std with 1e-4.</li>
     *   <li>Write {@code (input - mean) * invStd}, followed by {@code * scale + offset} when
     *       {@code scale} is non-null, into {@code output}.</li>
     * </ol>
     *
     * @param input       contiguous input; dimension 1 is the batch, dimension 2 the channels
     * @param output      destination tensor; written through its storage, not resized here
     * @param saveMean    per-channel mean accumulator/result; resized when its length differs from the channel count
     * @param saveStd     per-channel accumulator, ends up holding the inverse standard deviation
     * @param runningMean momentum-blended mean
     * @param runningVar  momentum-blended unbiased variance
     * @param scale       optional per-channel multiplier; {@code null} disables scale and offset
     * @param offset      per-channel additive term, read only when {@code scale} is non-null
     * @param eps         value added to the variance before the square root
     * @param momentum    weight of the current batch statistics in the running values
     * @param batchVar    optional receiver of the unbiased variance (one-based {@code setValue})
     * @param saveVar     optional receiver of the biased variance (one-based {@code setValue})
     * @param needFix     forces mean 0 and inverse std 1e-4 before the output is written
     * @param globalMean  scratch array for cross-thread mean sums; only used when {@code needSync}
     * @param globalStd   scratch array for cross-thread squared-deviation sums; only used when {@code needSync}
     * @param meanKey     synchronizer key for the mean exchange
     * @param stdKey      synchronizer key for the squared-deviation exchange
     * @param needSync    whether to aggregate statistics across threads
     */
    public void updateOutputNCHWTrainDouble(Tensor<Object> input, Tensor<Object> output, Tensor<Object> saveMean, Tensor<Object> saveStd, Tensor<Object> runningMean, Tensor<Object> runningVar, Tensor<Object> scale, Tensor<Object> offset, double eps, double momentum, Tensor<Object> batchVar, Tensor<Object> saveVar, boolean needFix, double[] globalMean, double[] globalStd, String meanKey, String stdKey, boolean needSync) {
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "BatchNorm NCHW require a contiguous input", Log4Error$.MODULE$.invalidInputError$default$3());
        double[] inputData = (double[])input.storage().array();
        int inputOffset = input.storageOffset() - 1;
        double[] outputData = (double[])output.storage().array();
        int outputOffset = output.storageOffset() - 1;
        int nChannels = input.size(2);
        int nBatch = input.size(1);
        int nFrame = input.size(3) * input.size(4);
        if (saveMean.size(1) != nChannels) {
            saveMean.resize(nChannels);
            saveStd.resize(nChannels);
            runningMean.resize(nChannels);
            runningVar.resize(nChannels);
        }
        double[] meanData = (double[])saveMean.storage().array();
        int meanOffset = saveMean.storageOffset() - 1;

        // Pass 1: per-channel sums of the input.
        int i = 0;
        for (int b = 0; b < nBatch; ++b) {
            for (int c = 0; c < nChannels; ++c) {
                double meanSum = 0.0;
                for (int k = 0; k < nFrame; ++k) {
                    meanSum += inputData[i + inputOffset];
                    ++i;
                }
                meanData[c + meanOffset] += meanSum;
            }
        }

        int meanLen = 1;
        if (needSync) {
            ParameterSynchronizer$.MODULE$.syncData(meanKey, saveMean, ClassTag$.MODULE$.Double());
            Map meanEventData = ParameterSynchronizer$.MODULE$.collect(meanKey, ClassTag$.MODULE$.Double());
            meanLen = meanEventData.size();
            for (int c = 0; c < nChannels; ++c) {
                globalMean[c] = 0.0;
                for (Object threadId : meanEventData.keySet()) {
                    Tensor localMean = (Tensor)meanEventData.get(threadId);
                    int localOffset = localMean.storageOffset() - 1;
                    globalMean[c] += ((double[])localMean.storage().array())[c + localOffset];
                }
            }
            ParameterSynchronizer$.MODULE$.reset(meanKey, ClassTag$.MODULE$.Double());
            System.arraycopy(globalMean, 0, meanData, meanOffset, nChannels);
        }

        // Sums -> means, and fold the means into the running average.
        int n = input.nElement();
        int frameSize = n / nChannels;
        double[] runningMeanData = (double[])runningMean.storage().array();
        int runningMeanOffset = runningMean.storageOffset() - 1;
        for (int c = 0; c < nChannels; ++c) {
            meanData[c + meanOffset] = meanData[c + meanOffset] / (double)(frameSize * meanLen);
            runningMeanData[c + runningMeanOffset] = meanData[c + meanOffset] * momentum + (1.0 - momentum) * runningMeanData[c + runningMeanOffset];
        }

        // Pass 2: per-channel sums of squared deviations, accumulated element by element.
        double[] stdData = (double[])saveStd.storage().array();
        int stdOffset = saveStd.storageOffset() - 1;
        i = 0;
        for (int b = 0; b < nBatch; ++b) {
            for (int c = 0; c < nChannels; ++c) {
                for (int k = 0; k < nFrame; ++k) {
                    double diff = inputData[i + inputOffset] - meanData[c + meanOffset];
                    stdData[c + stdOffset] += diff * diff;
                    ++i;
                }
            }
        }

        int stdLen = 1;
        if (needSync) {
            ParameterSynchronizer$.MODULE$.syncData(stdKey, saveStd, ClassTag$.MODULE$.Double());
            Map stdEventData = ParameterSynchronizer$.MODULE$.collect(stdKey, ClassTag$.MODULE$.Double());
            for (int c = 0; c < nChannels; ++c) {
                globalStd[c] = 0.0;
                for (Object threadId : stdEventData.keySet()) {
                    Tensor localStd = (Tensor)stdEventData.get(threadId);
                    int localStdOffSet = localStd.storageOffset() - 1;
                    globalStd[c] += ((double[])localStd.storage().array())[c + localStdOffSet];
                }
            }
            stdLen = stdEventData.size();
            ParameterSynchronizer$.MODULE$.reset(stdKey, ClassTag$.MODULE$.Double());
            System.arraycopy(globalStd, 0, stdData, stdOffset, nChannels);
        }

        // Variances, inverse standard deviations and running variance.
        double[] runningVarData = (double[])runningVar.storage().array();
        int runningVarOffset = runningVar.storageOffset() - 1;
        for (int c = 0; c < nChannels; ++c) {
            if (stdData[c + stdOffset] == 0.0 && eps == 0.0) {
                // Degenerate channel: report zero variance and skip the running update.
                stdData[c + stdOffset] = 0.0;
                if (saveVar != null) {
                    saveVar.setValue(c + 1, BoxesRunTime.boxToDouble((double)0.0));
                }
                if (batchVar != null) {
                    batchVar.setValue(c + 1, BoxesRunTime.boxToDouble((double)0.0));
                }
            } else {
                double s2 = stdData[c + stdOffset];
                double unbiasedVar = s2 / (double)(frameSize * stdLen - 1);
                if (saveVar != null) {
                    saveVar.setValue(c + 1, BoxesRunTime.boxToDouble((double)(s2 / (double)(frameSize * stdLen))));
                }
                if (batchVar != null) {
                    batchVar.setValue(c + 1, BoxesRunTime.boxToDouble((double)unbiasedVar));
                }
                stdData[c + stdOffset] = 1.0 / Math.sqrt(s2 / (double)(frameSize * stdLen) + eps);
                // Read and write runningVar through its own storage offset, never saveStd's.
                runningVarData[c + runningVarOffset] = momentum * unbiasedVar + (1.0 - momentum) * runningVarData[c + runningVarOffset];
            }
        }

        if (needFix) {
            for (int c = 0; c < nChannels; ++c) {
                meanData[c + meanOffset] = 0.0;
                stdData[c + stdOffset] = 1.0E-4;
            }
        }

        // Normalise, optionally applying scale and offset.
        if (scale == null) {
            i = 0;
            for (int b = 0; b < nBatch; ++b) {
                for (int c = 0; c < nChannels; ++c) {
                    for (int k = 0; k < nFrame; ++k) {
                        outputData[i + outputOffset] = (inputData[i + inputOffset] - meanData[c + meanOffset]) * stdData[c + stdOffset];
                        ++i;
                    }
                }
            }
        } else {
            double[] scaleData = (double[])scale.storage().array();
            int scaleOffset = scale.storageOffset() - 1;
            double[] offsetData = (double[])offset.storage().array();
            int offsetOffset = offset.storageOffset() - 1;
            i = 0;
            for (int b = 0; b < nBatch; ++b) {
                for (int c = 0; c < nChannels; ++c) {
                    for (int k = 0; k < nFrame; ++k) {
                        outputData[i + outputOffset] = (inputData[i + inputOffset] - meanData[c + meanOffset]) * stdData[c + stdOffset] * scaleData[c + scaleOffset] + offsetData[c + offsetOffset];
                        ++i;
                    }
                }
            }
        }
    }

    /**
     * Default-argument accessor; by its {@code $default$11} name this presumably supplies
     * the default for the 11th parameter ({@code batchVar}) of
     * {@code updateOutputNCHWTrainDouble}.
     *
     * @return always {@code null}
     */
    public Tensor<Object> updateOutputNCHWTrainDouble$default$11() {
        return null;
    }

    /**
     * Default-argument accessor; by its {@code $default$12} name this presumably supplies
     * the default for the 12th parameter ({@code saveVar}) of
     * {@code updateOutputNCHWTrainDouble}.
     *
     * @return always {@code null}
     */
    public Tensor<Object> updateOutputNCHWTrainDouble$default$12() {
        return null;
    }

    /**
     * Default-argument accessor; by its {@code $default$13} name this presumably supplies
     * the default for the 13th parameter ({@code needFix}) of
     * {@code updateOutputNCHWTrainDouble}.
     *
     * @return always {@code false}
     */
    public boolean updateOutputNCHWTrainDouble$default$13() {
        return false;
    }

    /**
     * Default-argument accessor; by its {@code $default$14} name this presumably supplies
     * the default for the 14th parameter ({@code globalMean}) of
     * {@code updateOutputNCHWTrainDouble}.
     *
     * @return always {@code null}
     */
    public double[] updateOutputNCHWTrainDouble$default$14() {
        return null;
    }

    /**
     * Default-argument accessor; by its {@code $default$15} name this presumably supplies
     * the default for the 15th parameter ({@code globalStd}) of
     * {@code updateOutputNCHWTrainDouble}.
     *
     * @return always {@code null}
     */
    public double[] updateOutputNCHWTrainDouble$default$15() {
        return null;
    }

    /**
     * Default-argument accessor; by its {@code $default$16} name this presumably supplies
     * the default for the 16th parameter ({@code meanKey}) of
     * {@code updateOutputNCHWTrainDouble}.
     *
     * @return always {@code null}
     */
    public String updateOutputNCHWTrainDouble$default$16() {
        return null;
    }

    /**
     * Default-argument accessor; by its {@code $default$17} name this presumably supplies
     * the default for the 17th parameter ({@code stdKey}) of
     * {@code updateOutputNCHWTrainDouble}.
     *
     * @return always {@code null}
     */
    public String updateOutputNCHWTrainDouble$default$17() {
        return null;
    }

    /**
     * Default-argument accessor; by its {@code $default$18} name this presumably supplies
     * the default for the 18th parameter ({@code needSync}) of
     * {@code updateOutputNCHWTrainDouble}.
     *
     * @return always {@code false}
     */
    public boolean updateOutputNCHWTrainDouble$default$18() {
        return false;
    }

    /**
     * Computes the training-mode batch-normalization gradient with respect to the
     * input for contiguous 4-D {@code double} tensors indexed as
     * (batch, channel, height, width), writing the result into {@code gradInput}.
     *
     * <p>The work is done in three passes over the flat storage arrays:
     * <ol>
     *   <li>per channel, add {@code gradOutput} and
     *       {@code gradOutput * (input - saveMean)} onto {@code gMean} and
     *       {@code gxMean};</li>
     *   <li>optionally ({@code needSync}) publish those sums through
     *       {@code ParameterSynchronizer}, add up every collected per-thread copy
     *       into {@code globalGmean}/{@code globalGxmean}, then turn the sums into
     *       means by dividing by {@code (nElement / nChannel) * participants}
     *       (the {@code gxMean} value is additionally multiplied by
     *       {@code saveStd[c]} twice);</li>
     *   <li>per element, write
     *       {@code (gradOutput - gMean[c] - (input - saveMean[c]) * gxMean[c]) * saveStd[c]},
     *       additionally multiplied by {@code scale[c]} when {@code scale} is given.</li>
     * </ol>
     *
     * <p>NOTE(review): pass 1 adds onto whatever {@code gMean}/{@code gxMean} already
     * contain and indexes their storage from element 0 (their storage offset is not
     * applied); nothing in this method zeroes them. Confirm that callers provide
     * zeroed, offset-free accumulators.
     *
     * @param input        layer input; must be 4-D and contiguous
     * @param gradOutput   gradient w.r.t. the layer output; must be 4-D and contiguous
     * @param gradInput    destination; resized to the shape of {@code gradOutput}
     * @param scale        optional per-channel multiplier; {@code null} to skip it,
     *                     otherwise its first dimension must equal the channel count
     * @param saveMean     per-channel mean; first dimension must equal the channel count
     * @param saveStd      per-channel value used as a multiplier (the code treats it as
     *                     an inverse standard deviation); first dimension must equal
     *                     the channel count
     * @param gMean        per-channel accumulator, resized when empty; holds the mean
     *                     gradient on return
     * @param gxMean       per-channel accumulator, resized together with {@code gMean};
     *                     holds the scaled mean of {@code gradOutput * (input - mean)}
     *                     on return
     * @param globalGmean  scratch array of at least channel-count length; only
     *                     touched when {@code needSync} is set
     * @param globalGxmean scratch array of at least channel-count length; only
     *                     touched when {@code needSync} is set
     * @param gMeanKey     synchronizer key for {@code gMean}; only used when
     *                     {@code needSync} is set
     * @param gxMeanKey    synchronizer key for {@code gxMean}; only used when
     *                     {@code needSync} is set
     * @param needSync     whether to combine the per-channel sums across the
     *                     participants registered with {@code ParameterSynchronizer}
     */
    public void updateGradInputNCHWTrainDouble(Tensor<Object> input, Tensor<Object> gradOutput, Tensor<Object> gradInput, Tensor<Object> scale, Tensor<Object> saveMean, Tensor<Object> saveStd, Tensor<Object> gMean, Tensor<Object> gxMean, double[] globalGmean, double[] globalGxmean, String gMeanKey, String gxMeanKey, boolean needSync) {
        Log4Error$.MODULE$.invalidInputError(input.nDimension() == 4, "BN require a 4D input", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "input is not contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradOutput.nDimension() == 4, "BN require a 4D gradient", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradOutput.isContiguous(), "gradient is not contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        int nChannel = gradOutput.size(2);
        Log4Error$.MODULE$.invalidInputError(saveMean.size(1) == nChannel, "saveMean length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(saveStd.size(1) == nChannel, "saveStd length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
        gradInput.resizeAs(gradOutput);
        if (gMean.isEmpty()) {
            gMean.resize(saveMean.size(1));
            gxMean.resize(saveMean.size(1));
        }

        // Raw storage arrays plus zero-based offsets (storageOffset() is one-based).
        double[] inputData = (double[])input.storage().array();
        int inputOffset = input.storageOffset() - 1;
        double[] gradOutputData = (double[])gradOutput.storage().array();
        int gradOutputOffset = gradOutput.storageOffset() - 1;
        double[] gradInputData = (double[])gradInput.storage().array();
        int gradInputOffset = gradInput.storageOffset() - 1;
        double[] saveMeanData = (double[])saveMean.storage().array();
        int saveMeanOffset = saveMean.storageOffset() - 1;
        double[] saveStdData = (double[])saveStd.storage().array();
        int saveStdOffset = saveStd.storageOffset() - 1;
        double[] gMeanData = (double[])gMean.storage().array();
        double[] gxMeanData = (double[])gxMean.storage().array();
        int nBatch = gradOutput.size(1);
        int frameSize = gradOutput.size(3) * gradOutput.size(4);
        int n = gradOutput.nElement();

        // Pass 1: per-channel sums. `i` walks the contiguous buffer in storage order.
        int i = 0;
        for (int b = 0; b < nBatch; ++b) {
            for (int c = 0; c < nChannel; ++c) {
                double mean = saveMeanData[saveMeanOffset + c];
                for (int k = 0; k < frameSize; ++k, ++i) {
                    double g = gradOutputData[gradOutputOffset + i];
                    gMeanData[c] += g;
                    gxMeanData[c] += g * (inputData[inputOffset + i] - mean);
                }
            }
        }

        // Pass 2 (optional): combine the sums of every participant.
        int gmeanEventLen = 1;
        int gmxmeanEventLen = 1;
        if (needSync) {
            ParameterSynchronizer$.MODULE$.syncData(gMeanKey, gMean, ClassTag$.MODULE$.Double());
            Map<?, ?> gMeanEventData = ParameterSynchronizer$.MODULE$.collect(gMeanKey, ClassTag$.MODULE$.Double());
            sumPerThreadPartialsDouble(gMeanEventData, globalGmean, nChannel);
            gmeanEventLen = gMeanEventData.size();
            ParameterSynchronizer$.MODULE$.reset(gMeanKey, ClassTag$.MODULE$.Double());

            ParameterSynchronizer$.MODULE$.syncData(gxMeanKey, gxMean, ClassTag$.MODULE$.Double());
            Map<?, ?> gxMeanEventData = ParameterSynchronizer$.MODULE$.collect(gxMeanKey, ClassTag$.MODULE$.Double());
            sumPerThreadPartialsDouble(gxMeanEventData, globalGxmean, nChannel);
            gmxmeanEventLen = gxMeanEventData.size();
            ParameterSynchronizer$.MODULE$.reset(gxMeanKey, ClassTag$.MODULE$.Double());
        }

        // Turn sums into means; without sync the participant count stays 1.
        int size = n / nChannel;
        double gMeanDivisor = (double)(size * gmeanEventLen);
        double gxMeanDivisor = (double)(size * gmxmeanEventLen);
        for (int c = 0; c < nChannel; ++c) {
            double invStd = saveStdData[saveStdOffset + c];
            gMeanData[c] = (needSync ? globalGmean[c] : gMeanData[c]) / gMeanDivisor;
            gxMeanData[c] = (needSync ? globalGxmean[c] : gxMeanData[c]) * invStd * invStd / gxMeanDivisor;
        }

        // Pass 3: write the gradient. Each branch owns its batch counter; the previous
        // version reused a counter that was already at nBatch in the scaled branch,
        // so that branch never wrote anything.
        i = 0;
        if (scale == null) {
            for (int b = 0; b < nBatch; ++b) {
                for (int c = 0; c < nChannel; ++c) {
                    double invStd = saveStdData[saveStdOffset + c];
                    double mean = saveMeanData[saveMeanOffset + c];
                    for (int k = 0; k < frameSize; ++k, ++i) {
                        gradInputData[gradInputOffset + i] = (gradOutputData[gradOutputOffset + i] - gMeanData[c] - (inputData[inputOffset + i] - mean) * gxMeanData[c]) * invStd;
                    }
                }
            }
        } else {
            Log4Error$.MODULE$.invalidInputError(scale.size(1) == nChannel, "scale length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
            double[] scaleData = (double[])scale.storage().array();
            int scaleOffset = scale.storageOffset() - 1;
            for (int b = 0; b < nBatch; ++b) {
                for (int c = 0; c < nChannel; ++c) {
                    double invStd = saveStdData[saveStdOffset + c];
                    double mean = saveMeanData[saveMeanOffset + c];
                    double gamma = scaleData[scaleOffset + c];
                    for (int k = 0; k < frameSize; ++k, ++i) {
                        // Keep the multiplication order (… * invStd * gamma) so rounding is unchanged.
                        gradInputData[gradInputOffset + i] = (gradOutputData[gradOutputOffset + i] - gMeanData[c] - (inputData[inputOffset + i] - mean) * gxMeanData[c]) * invStd * gamma;
                    }
                }
            }
        }
    }

    /**
     * Overwrites the first {@code nChannel} entries of {@code global} with the
     * element-wise sum of every tensor stored as a value in {@code perThread}.
     *
     * @param perThread map whose values are {@code Tensor}s backed by {@code double[]} storage
     * @param global    destination array, at least {@code nChannel} long
     * @param nChannel  number of leading elements to sum
     */
    private static void sumPerThreadPartialsDouble(Map<?, ?> perThread, double[] global, int nChannel) {
        for (int c = 0; c < nChannel; ++c) {
            global[c] = 0.0;
        }
        for (Object partial : perThread.values()) {
            Tensor<?> local = (Tensor<?>)partial;
            double[] localData = (double[])local.storage().array();
            int localOffset = local.storageOffset() - 1;
            for (int c = 0; c < nChannel; ++c) {
                global[c] += localData[c + localOffset];
            }
        }
    }

    /**
     * Accessor named after the Scala {@code $default$N} convention; by position it
     * corresponds to {@code gMeanKey}, the 11th parameter of
     * {@code updateGradInputNCHWTrainDouble}.
     *
     * @return always {@code null}
     */
    public String updateGradInputNCHWTrainDouble$default$11() {
        return null;
    }

    /**
     * Accessor named after the Scala {@code $default$N} convention; by position it
     * corresponds to {@code gxMeanKey}, the 12th parameter of
     * {@code updateGradInputNCHWTrainDouble}.
     *
     * @return always {@code null}
     */
    public String updateGradInputNCHWTrainDouble$default$12() {
        return null;
    }

    /**
     * Accessor named after the Scala {@code $default$N} convention; by position it
     * corresponds to {@code needSync}, the 13th parameter of
     * {@code updateGradInputNCHWTrainDouble}.
     *
     * @return always {@code false}
     */
    public boolean updateGradInputNCHWTrainDouble$default$13() {
        return false;
    }

    /**
     * Computes the inference-mode batch-normalization gradient with respect to the
     * input for a contiguous 4-D {@code float} gradient indexed as
     * (batch, channel, height, width).
     *
     * <p>Every element of {@code gradOutput} is multiplied by {@code saveStd[c]} of
     * its channel (the code treats that value as an inverse standard deviation) and,
     * when {@code scale} is supplied, also by {@code scale[c]}.
     *
     * @param gradOutput gradient w.r.t. the layer output; must be 4-D and contiguous
     * @param gradInput  destination; resized to the shape of {@code gradOutput}
     * @param scale      optional per-channel multiplier; {@code null} to skip it,
     *                   otherwise its first dimension must equal the channel count
     * @param saveStd    per-channel multiplier; first dimension must equal the
     *                   channel count
     */
    public void updateGradInputNCHWInferFloat(Tensor<Object> gradOutput, Tensor<Object> gradInput, Tensor<Object> scale, Tensor<Object> saveStd) {
        Log4Error$.MODULE$.invalidInputError(gradOutput.nDimension() == 4, "BN require a 4D gradient", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradOutput.isContiguous(), "gradient is not contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        int nChannel = gradOutput.size(2);
        Log4Error$.MODULE$.invalidInputError(saveStd.size(1) == nChannel, "saveStd length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
        gradInput.resizeAs(gradOutput);

        // Raw storage arrays plus zero-based offsets (storageOffset() is one-based).
        float[] gradOutputData = (float[])gradOutput.storage().array();
        int gradOutputOffset = gradOutput.storageOffset() - 1;
        float[] gradInputData = (float[])gradInput.storage().array();
        int gradInputOffset = gradInput.storageOffset() - 1;
        float[] saveStdData = (float[])saveStd.storage().array();
        int saveStdOffset = saveStd.storageOffset() - 1;
        int nBatch = gradOutput.size(1);
        int frameSize = gradOutput.size(3) * gradOutput.size(4);

        // `i` walks the contiguous buffer in storage order. Each branch declares its
        // own batch counter starting at 0; the scaled branch previously read a
        // counter that had never been assigned.
        int i = 0;
        if (scale == null) {
            for (int b = 0; b < nBatch; ++b) {
                for (int c = 0; c < nChannel; ++c) {
                    float invStd = saveStdData[saveStdOffset + c];
                    for (int k = 0; k < frameSize; ++k, ++i) {
                        gradInputData[gradInputOffset + i] = invStd * gradOutputData[gradOutputOffset + i];
                    }
                }
            }
        } else {
            Log4Error$.MODULE$.invalidInputError(scale.size(1) == nChannel, "scale length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
            float[] scaleData = (float[])scale.storage().array();
            int scaleOffset = scale.storageOffset() - 1;
            for (int b = 0; b < nBatch; ++b) {
                for (int c = 0; c < nChannel; ++c) {
                    // scale * invStd is the leftmost product of the per-element
                    // expression, so computing it once per channel keeps the rounding.
                    float factor = scaleData[scaleOffset + c] * saveStdData[saveStdOffset + c];
                    for (int k = 0; k < frameSize; ++k, ++i) {
                        gradInputData[gradInputOffset + i] = factor * gradOutputData[gradOutputOffset + i];
                    }
                }
            }
        }
    }

    /**
     * Computes the inference-mode batch-normalization gradient with respect to the
     * input for a contiguous 4-D {@code double} gradient indexed as
     * (batch, channel, height, width).
     *
     * <p>Every element of {@code gradOutput} is multiplied by {@code saveStd[c]} of
     * its channel (the code treats that value as an inverse standard deviation) and,
     * when {@code scale} is supplied, also by {@code scale[c]}.
     *
     * @param gradOutput gradient w.r.t. the layer output; must be 4-D and contiguous
     * @param gradInput  destination; resized to the shape of {@code gradOutput}
     * @param scale      optional per-channel multiplier; {@code null} to skip it,
     *                   otherwise its first dimension must equal the channel count
     * @param saveStd    per-channel multiplier; first dimension must equal the
     *                   channel count
     */
    public void updateGradInputNCHWInferDouble(Tensor<Object> gradOutput, Tensor<Object> gradInput, Tensor<Object> scale, Tensor<Object> saveStd) {
        Log4Error$.MODULE$.invalidInputError(gradOutput.nDimension() == 4, "BN require a 4D gradient", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradOutput.isContiguous(), "gradient is not contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        int nChannel = gradOutput.size(2);
        Log4Error$.MODULE$.invalidInputError(saveStd.size(1) == nChannel, "saveStd length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
        gradInput.resizeAs(gradOutput);

        // Raw storage arrays plus zero-based offsets (storageOffset() is one-based).
        double[] gradOutputData = (double[])gradOutput.storage().array();
        int gradOutputOffset = gradOutput.storageOffset() - 1;
        double[] gradInputData = (double[])gradInput.storage().array();
        int gradInputOffset = gradInput.storageOffset() - 1;
        double[] saveStdData = (double[])saveStd.storage().array();
        int saveStdOffset = saveStd.storageOffset() - 1;
        int nBatch = gradOutput.size(1);
        int frameSize = gradOutput.size(3) * gradOutput.size(4);

        // `i` walks the contiguous buffer in storage order. Each branch declares its
        // own batch counter starting at 0; the scaled branch previously read a
        // counter that had never been assigned.
        int i = 0;
        if (scale == null) {
            for (int b = 0; b < nBatch; ++b) {
                for (int c = 0; c < nChannel; ++c) {
                    double invStd = saveStdData[saveStdOffset + c];
                    for (int k = 0; k < frameSize; ++k, ++i) {
                        gradInputData[gradInputOffset + i] = invStd * gradOutputData[gradOutputOffset + i];
                    }
                }
            }
        } else {
            Log4Error$.MODULE$.invalidInputError(scale.size(1) == nChannel, "scale length is not consistent with channel number", Log4Error$.MODULE$.invalidInputError$default$3());
            double[] scaleData = (double[])scale.storage().array();
            int scaleOffset = scale.storageOffset() - 1;
            for (int b = 0; b < nBatch; ++b) {
                for (int c = 0; c < nChannel; ++c) {
                    // scale * invStd is the leftmost product of the per-element
                    // expression, so computing it once per channel keeps the rounding.
                    double factor = scaleData[scaleOffset + c] * saveStdData[saveStdOffset + c];
                    for (int k = 0; k < frameSize; ++k, ++i) {
                        gradInputData[gradInputOffset + i] = factor * gradOutputData[gradOutputOffset + i];
                    }
                }
            }
        }
    }

    /**
     * Accumulates the per-channel weight and bias gradients for {@code float}
     * tensors whose flat storage has the channel as the fastest-varying index.
     *
     * <p>For every element, {@code gradOutput * (input - saveMean[c]) * saveStd[c] * scaleW}
     * is added to {@code gradWeight[c]} and {@code gradOutput * scaleB} to
     * {@code gradBias[c]}. Existing contents of both gradients are kept and added to.
     *
     * <p>NOTE(review): the loop consumes {@code nChannel} elements per step, so it
     * assumes {@code input.nElement()} is a multiple of the channel count -- this is
     * not validated here.
     *
     * @param gradOutput gradient w.r.t. the layer output; must be contiguous
     * @param gradWeight accumulator for the weight gradient; must be contiguous
     * @param gradBias   accumulator for the bias gradient; must be contiguous
     * @param input      layer input; must be contiguous
     * @param saveMean   per-channel mean; must be 1-D, its length is the channel count
     * @param saveStd    per-channel multiplier applied to the centred input; must be 1-D
     * @param scaleW     factor applied to every weight-gradient contribution
     * @param scaleB     factor applied to every bias-gradient contribution
     */
    public void accGradientNHWCFloat(Tensor<Object> gradOutput, Tensor<Object> gradWeight, Tensor<Object> gradBias, Tensor<Object> input, Tensor<Object> saveMean, Tensor<Object> saveStd, float scaleW, float scaleB) {
        Log4Error$.MODULE$.invalidInputError(gradOutput.isContiguous(), "gradOutput must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradWeight.isContiguous(), "gradWeight must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradBias.isContiguous(), "gradBias must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        // The message used to name gradWeight although the check is on input.
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "input must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(saveMean.nDimension() == 1, "saveMean must be 1D", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(saveStd.nDimension() == 1, "saveStd must be 1D", Log4Error$.MODULE$.invalidInputError$default$3());
        int nChannel = saveMean.size(1);

        // Raw storage arrays plus zero-based offsets (storageOffset() is one-based).
        float[] gradOutputData = (float[])gradOutput.storage().array();
        int gradOutputOffset = gradOutput.storageOffset() - 1;
        float[] gradWeightData = (float[])gradWeight.storage().array();
        int gradWeightOffset = gradWeight.storageOffset() - 1;
        float[] gradBiasData = (float[])gradBias.storage().array();
        int gradBiasOffset = gradBias.storageOffset() - 1;
        float[] inputData = (float[])input.storage().array();
        int inputOffset = input.storageOffset() - 1;
        float[] saveMeanData = (float[])saveMean.storage().array();
        int saveMeanOffset = saveMean.storageOffset() - 1;
        float[] saveStdData = (float[])saveStd.storage().array();
        int saveStdOffset = saveStd.storageOffset() - 1;

        int n = input.nElement();
        int i = 0;
        while (i < n) {
            // One step per group of nChannel consecutive elements.
            for (int c = 0; c < nChannel; ++c, ++i) {
                float g = gradOutputData[gradOutputOffset + i];
                gradWeightData[gradWeightOffset + c] += g * (inputData[inputOffset + i] - saveMeanData[saveMeanOffset + c]) * saveStdData[saveStdOffset + c] * scaleW;
                gradBiasData[gradBiasOffset + c] += g * scaleB;
            }
        }
    }

    /**
     * Accumulates the per-channel weight and bias gradients for {@code double}
     * tensors whose flat storage has the channel as the fastest-varying index.
     *
     * <p>For every element, {@code gradOutput * (input - saveMean[c]) * saveStd[c] * scaleW}
     * is added to {@code gradWeight[c]} and {@code gradOutput * scaleB} to
     * {@code gradBias[c]}. Existing contents of both gradients are kept and added to.
     *
     * <p>NOTE(review): the loop consumes {@code nChannel} elements per step, so it
     * assumes {@code input.nElement()} is a multiple of the channel count -- this is
     * not validated here.
     *
     * @param gradOutput gradient w.r.t. the layer output; must be contiguous
     * @param gradWeight accumulator for the weight gradient; must be contiguous
     * @param gradBias   accumulator for the bias gradient; must be contiguous
     * @param input      layer input; must be contiguous
     * @param saveMean   per-channel mean; must be 1-D, its length is the channel count
     * @param saveStd    per-channel multiplier applied to the centred input; must be 1-D
     * @param scaleW     factor applied to every weight-gradient contribution
     * @param scaleB     factor applied to every bias-gradient contribution
     */
    public void accGradientNHWCDouble(Tensor<Object> gradOutput, Tensor<Object> gradWeight, Tensor<Object> gradBias, Tensor<Object> input, Tensor<Object> saveMean, Tensor<Object> saveStd, double scaleW, double scaleB) {
        Log4Error$.MODULE$.invalidInputError(gradOutput.isContiguous(), "gradOutput must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradWeight.isContiguous(), "gradWeight must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradBias.isContiguous(), "gradBias must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        // The message used to name gradWeight although the check is on input.
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "input must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(saveMean.nDimension() == 1, "saveMean must be 1D", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(saveStd.nDimension() == 1, "saveStd must be 1D", Log4Error$.MODULE$.invalidInputError$default$3());
        int nChannel = saveMean.size(1);

        // Raw storage arrays plus zero-based offsets (storageOffset() is one-based).
        double[] gradOutputData = (double[])gradOutput.storage().array();
        int gradOutputOffset = gradOutput.storageOffset() - 1;
        double[] gradWeightData = (double[])gradWeight.storage().array();
        int gradWeightOffset = gradWeight.storageOffset() - 1;
        double[] gradBiasData = (double[])gradBias.storage().array();
        int gradBiasOffset = gradBias.storageOffset() - 1;
        double[] inputData = (double[])input.storage().array();
        int inputOffset = input.storageOffset() - 1;
        double[] saveMeanData = (double[])saveMean.storage().array();
        int saveMeanOffset = saveMean.storageOffset() - 1;
        double[] saveStdData = (double[])saveStd.storage().array();
        int saveStdOffset = saveStd.storageOffset() - 1;

        int n = input.nElement();
        int i = 0;
        while (i < n) {
            // One step per group of nChannel consecutive elements.
            for (int c = 0; c < nChannel; ++c, ++i) {
                double g = gradOutputData[gradOutputOffset + i];
                gradWeightData[gradWeightOffset + c] += g * (inputData[inputOffset + i] - saveMeanData[saveMeanOffset + c]) * saveStdData[saveStdOffset + c] * scaleW;
                gradBiasData[gradBiasOffset + c] += g * scaleB;
            }
        }
    }

    /**
     * Accumulates the per-channel weight and bias gradients for {@code float}
     * tensors indexed as (batch, channel, height, width).
     *
     * <p>For every element, {@code gradOutput * (input - saveMean[c]) * saveStd[c] * scaleW}
     * is added to {@code gradWeight[c]} and {@code gradOutput * scaleB} to
     * {@code gradBias[c]}. Existing contents of both gradients are kept and added to.
     *
     * @param gradOutput gradient w.r.t. the layer output; must be contiguous. Its
     *                   dimensions 1, 3 and 4 give the batch count and frame size
     * @param gradWeight accumulator for the weight gradient; must be contiguous
     * @param gradBias   accumulator for the bias gradient; must be contiguous
     * @param input      layer input; must be contiguous
     * @param saveMean   per-channel mean; must be 1-D, its length is the channel count
     * @param saveStd    per-channel multiplier applied to the centred input; must be 1-D
     * @param scaleW     factor applied to every weight-gradient contribution
     * @param scaleB     factor applied to every bias-gradient contribution
     */
    public void accGradientNCHWFloat(Tensor<Object> gradOutput, Tensor<Object> gradWeight, Tensor<Object> gradBias, Tensor<Object> input, Tensor<Object> saveMean, Tensor<Object> saveStd, float scaleW, float scaleB) {
        Log4Error$.MODULE$.invalidInputError(gradOutput.isContiguous(), "gradOutput must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradWeight.isContiguous(), "gradWeight must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradBias.isContiguous(), "gradBias must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        // The message used to name gradWeight although the check is on input.
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "input must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(saveMean.nDimension() == 1, "saveMean must be 1D", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(saveStd.nDimension() == 1, "saveStd must be 1D", Log4Error$.MODULE$.invalidInputError$default$3());
        int nChannel = saveMean.size(1);

        // Raw storage arrays plus zero-based offsets (storageOffset() is one-based).
        float[] gradOutputData = (float[])gradOutput.storage().array();
        int gradOutputOffset = gradOutput.storageOffset() - 1;
        float[] gradWeightData = (float[])gradWeight.storage().array();
        int gradWeightOffset = gradWeight.storageOffset() - 1;
        float[] gradBiasData = (float[])gradBias.storage().array();
        int gradBiasOffset = gradBias.storageOffset() - 1;
        float[] inputData = (float[])input.storage().array();
        int inputOffset = input.storageOffset() - 1;
        float[] saveMeanData = (float[])saveMean.storage().array();
        int saveMeanOffset = saveMean.storageOffset() - 1;
        float[] saveStdData = (float[])saveStd.storage().array();
        int saveStdOffset = saveStd.storageOffset() - 1;
        int nBatch = gradOutput.size(1);
        int frameSize = gradOutput.size(3) * gradOutput.size(4);

        // `i` walks the contiguous buffer in storage order.
        int i = 0;
        for (int b = 0; b < nBatch; ++b) {
            for (int c = 0; c < nChannel; ++c) {
                int weightIdx = gradWeightOffset + c;
                int biasIdx = gradBiasOffset + c;
                for (int k = 0; k < frameSize; ++k, ++i) {
                    float g = gradOutputData[gradOutputOffset + i];
                    gradWeightData[weightIdx] += g * (inputData[inputOffset + i] - saveMeanData[saveMeanOffset + c]) * saveStdData[saveStdOffset + c] * scaleW;
                    gradBiasData[biasIdx] += g * scaleB;
                }
            }
        }
    }

    /**
     * Accumulates the per-channel weight and bias gradients for {@code double}
     * tensors indexed as (batch, channel, height, width).
     *
     * <p>For every element, {@code scaleW * (input - saveMean[c]) * gradOutput * saveStd[c]}
     * is added to {@code gradWeight[c]} and {@code gradOutput * scaleB} to
     * {@code gradBias[c]}. Existing contents of both gradients are kept and added to.
     *
     * @param gradOutput gradient w.r.t. the layer output; must be contiguous. Its
     *                   dimensions 1, 3 and 4 give the batch count and frame size
     * @param gradWeight accumulator for the weight gradient; must be contiguous
     * @param gradBias   accumulator for the bias gradient; must be contiguous
     * @param input      layer input; must be contiguous
     * @param saveMean   per-channel mean; must be 1-D, its length is the channel count
     * @param saveStd    per-channel multiplier applied to the centred input; must be 1-D
     * @param scaleW     factor applied to every weight-gradient contribution
     * @param scaleB     factor applied to every bias-gradient contribution
     */
    public void accGradientNCHWDouble(Tensor<Object> gradOutput, Tensor<Object> gradWeight, Tensor<Object> gradBias, Tensor<Object> input, Tensor<Object> saveMean, Tensor<Object> saveStd, double scaleW, double scaleB) {
        Log4Error$.MODULE$.invalidInputError(gradOutput.isContiguous(), "gradOutput must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradWeight.isContiguous(), "gradWeight must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(gradBias.isContiguous(), "gradBias must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        // The message used to name gradWeight although the check is on input.
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "input must be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(saveMean.nDimension() == 1, "saveMean must be 1D", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(saveStd.nDimension() == 1, "saveStd must be 1D", Log4Error$.MODULE$.invalidInputError$default$3());
        int nChannel = saveMean.size(1);

        // Raw storage arrays plus zero-based offsets (storageOffset() is one-based).
        double[] gradOutputData = (double[])gradOutput.storage().array();
        int gradOutputOffset = gradOutput.storageOffset() - 1;
        double[] gradWeightData = (double[])gradWeight.storage().array();
        int gradWeightOffset = gradWeight.storageOffset() - 1;
        double[] gradBiasData = (double[])gradBias.storage().array();
        int gradBiasOffset = gradBias.storageOffset() - 1;
        double[] inputData = (double[])input.storage().array();
        int inputOffset = input.storageOffset() - 1;
        double[] saveMeanData = (double[])saveMean.storage().array();
        int saveMeanOffset = saveMean.storageOffset() - 1;
        double[] saveStdData = (double[])saveStd.storage().array();
        int saveStdOffset = saveStd.storageOffset() - 1;
        int nBatch = gradOutput.size(1);
        int frameSize = gradOutput.size(3) * gradOutput.size(4);

        // `i` walks the contiguous buffer in storage order.
        int i = 0;
        for (int b = 0; b < nBatch; ++b) {
            for (int c = 0; c < nChannel; ++c) {
                int weightIdx = gradWeightOffset + c;
                int biasIdx = gradBiasOffset + c;
                for (int k = 0; k < frameSize; ++k, ++i) {
                    double g = gradOutputData[gradOutputOffset + i];
                    // Operand order is kept as-is so floating-point rounding is unchanged.
                    gradWeightData[weightIdx] += scaleW * (inputData[inputOffset + i] - saveMeanData[saveMeanOffset + c]) * g * saveStdData[saveStdOffset + c];
                    gradBiasData[biasIdx] += g * scaleB;
                }
            }
        }
    }

    /**
     * Default-value accessor ({@code $lessinit$greater} is the encoded form of
     * {@code <init>}); by position it corresponds to {@code eps}, the 2nd
     * constructor argument of {@code SpatialBatchNormalization}.
     *
     * @return {@code 1.0E-5}
     */
    public <T> double $lessinit$greater$default$2() {
        return 1.0E-5;
    }

    /**
     * Default-value accessor ({@code $lessinit$greater} is the encoded form of
     * {@code <init>}); by position it corresponds to {@code momentum}, the 3rd
     * constructor argument of {@code SpatialBatchNormalization}.
     *
     * @return {@code 0.1}
     */
    public <T> double $lessinit$greater$default$3() {
        return 0.1;
    }

    /**
     * Default-value accessor ({@code $lessinit$greater} is the encoded form of
     * {@code <init>}); by position it corresponds to {@code affine}, the 4th
     * constructor argument of {@code SpatialBatchNormalization}.
     *
     * @return always {@code true}
     */
    public <T> boolean $lessinit$greater$default$4() {
        return true;
    }

    /**
     * Default-value accessor ({@code $lessinit$greater} is the encoded form of
     * {@code <init>}); by position it corresponds to {@code initWeight}, the 5th
     * constructor argument of {@code SpatialBatchNormalization}.
     *
     * @return always {@code null}
     */
    public <T> Null$ $lessinit$greater$default$5() {
        return null;
    }

    /**
     * Default-value accessor ({@code $lessinit$greater} is the encoded form of
     * {@code <init>}); by position it corresponds to {@code initBias}, the 6th
     * constructor argument of {@code SpatialBatchNormalization}.
     *
     * @return always {@code null}
     */
    public <T> Null$ $lessinit$greater$default$6() {
        return null;
    }

    /**
     * Default-value accessor ({@code $lessinit$greater} is the encoded form of
     * {@code <init>}); by position it corresponds to {@code initGradWeight}, the
     * 7th constructor argument of {@code SpatialBatchNormalization}.
     *
     * @return always {@code null}
     */
    public <T> Null$ $lessinit$greater$default$7() {
        return null;
    }

    /**
     * Default-value accessor ({@code $lessinit$greater} is the encoded form of
     * {@code <init>}); by position it corresponds to {@code initGradBias}, the 8th
     * constructor argument of {@code SpatialBatchNormalization}.
     *
     * @return always {@code null}
     */
    public <T> Null$ $lessinit$greater$default$8() {
        return null;
    }

    /**
     * Default-value accessor ({@code $lessinit$greater} is the encoded form of
     * {@code <init>}); by position it corresponds to {@code dataFormat}, the 9th
     * constructor argument of {@code SpatialBatchNormalization}.
     *
     * @return the {@code DataFormat$NCHW$} singleton
     */
    public <T> DataFormat $lessinit$greater$default$9() {
        return DataFormat$NCHW$.MODULE$;
    }

    /**
     * Java serialization {@code readResolve} hook: substitutes {@link #MODULE$} for
     * any deserialized copy, so this class keeps a single instance.
     *
     * @return the {@code MODULE$} singleton
     */
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Factory variant whose {@code $mDc$sp} suffix follows Scala's specialization
     * naming (presumably the {@code Double} specialization of {@code apply} --
     * confirm). Forwards every argument unchanged to the
     * {@code SpatialBatchNormalization} constructor.
     *
     * @return a newly constructed {@code SpatialBatchNormalization}
     */
    public SpatialBatchNormalization<Object> apply$mDc$sp(int nOutput, double eps, double momentum, boolean affine, Tensor<Object> initWeight, Tensor<Object> initBias, Tensor<Object> initGradWeight, Tensor<Object> initGradBias, DataFormat dataFormat, ClassTag<Object> evidence$2, TensorNumericMath.TensorNumeric<Object> ev) {
        return new SpatialBatchNormalization<Object>(nOutput, eps, momentum, affine, initWeight, initBias, initGradWeight, initGradBias, dataFormat, evidence$2, ev);
    }

    /**
     * Factory variant whose {@code $mFc$sp} suffix follows Scala's specialization
     * naming (presumably the {@code Float} specialization of {@code apply} --
     * confirm). Forwards every argument unchanged to the
     * {@code SpatialBatchNormalization} constructor.
     *
     * @return a newly constructed {@code SpatialBatchNormalization}
     */
    public SpatialBatchNormalization<Object> apply$mFc$sp(int nOutput, double eps, double momentum, boolean affine, Tensor<Object> initWeight, Tensor<Object> initBias, Tensor<Object> initGradWeight, Tensor<Object> initGradBias, DataFormat dataFormat, ClassTag<Object> evidence$2, TensorNumericMath.TensorNumeric<Object> ev) {
        return new SpatialBatchNormalization<Object>(nOutput, eps, momentum, affine, initWeight, initBias, initGradWeight, initGradBias, dataFormat, evidence$2, ev);
    }

    /**
     * Private constructor, invoked from the class's static initializer; publishes
     * the new instance as {@code MODULE$}.
     */
    private SpatialBatchNormalization$() {
        // Decompiler rendering of the singleton assignment to MODULE$.
        MODULE$ = this;
    }
}

