/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.nn;

import com.intel.analytics.bigdl.dllib.nn.abstractnn.TensorModule;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.Shape;
import com.intel.analytics.bigdl.dllib.utils.Shape$;
import scala.Predef$;
import scala.StringContext;
import scala.collection.Seq;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

@ScalaSignature(bytes="\u0006\u0001\u0005\rd\u0001B\u0001\u0003\u0001=\u0011A\"\u00169TC6\u0004H.\u001b8hg\u0011S!a\u0001\u0003\u0002\u00059t'BA\u0003\u0007\u0003\u0015!G\u000e\\5c\u0015\t9\u0001\"A\u0003cS\u001e$GN\u0003\u0002\n\u0015\u0005I\u0011M\\1msRL7m\u001d\u0006\u0003\u00171\tQ!\u001b8uK2T\u0011!D\u0001\u0004G>l7\u0001A\u000b\u0003!e\u0019\"\u0001A\t\u0011\u0007I)r#D\u0001\u0014\u0015\t!\"!\u0001\u0006bEN$(/Y2u]:L!AF\n\u0003\u0019Q+gn]8s\u001b>$W\u000f\\3\u0011\u0005aIB\u0002\u0001\u0003\u00065\u0001\u0011\ra\u0007\u0002\u0002)F\u0011AD\t\t\u0003;\u0001j\u0011A\b\u0006\u0002?\u0005)1oY1mC&\u0011\u0011E\b\u0002\b\u001d>$\b.\u001b8h!\ti2%\u0003\u0002%=\t\u0019\u0011I\\=\t\u0011\u0019\u0002!Q1A\u0005\u0002\u001d\nAa]5{KV\t\u0001\u0006E\u0002\u001eS-J!A\u000b\u0010\u0003\u000b\u0005\u0013(/Y=\u0011\u0005ua\u0013BA\u0017\u001f\u0005\rIe\u000e\u001e\u0005\t_\u0001\u0011\t\u0011)A\u0005Q\u0005)1/\u001b>fA!A\u0011\u0007\u0001B\u0002B\u0003-!'\u0001\u0006fm&$WM\\2fIE\u00022a\r\u001c\u0018\u001b\u0005!$BA\u001b\u001f\u0003\u001d\u0011XM\u001a7fGRL!a\u000e\u001b\u0003\u0011\rc\u0017m]:UC\u001eD\u0001\"\u000f\u0001\u0003\u0002\u0003\u0006YAO\u0001\u0003KZ\u00042aO(\u0018\u001d\taDJ\u0004\u0002>\u0015:\u0011a(\u0013\b\u0003\u007f!s!\u0001Q$\u000f\u0005\u00053eB\u0001\"F\u001b\u0005\u0019%B\u0001#\u000f\u0003\u0019a$o\\8u}%\tQ\"\u0003\u0002\f\u0019%\u0011\u0011BC\u0005\u0003\u000f!I!!\u0002\u0004\n\u0005-#\u0011A\u0002;f]N|'/\u0003\u0002N\u001d\u0006\tB+\u001a8t_JtU/\\3sS\u000el\u0015\r\u001e5\u000b\u0005-#\u0011B\u0001)R\u00055!VM\\:pe:+X.\u001a:jG*\u0011QJ\u0014\u0005\u0006'\u0002!\t\u0001V\u0001\u0007y%t\u0017\u000e\u001e \u0015\u0005USFc\u0001,Y3B\u0019q\u000bA\f\u000e\u0003\tAQ!\r*A\u0004IBQ!\u000f*A\u0004iBQA\n*A\u0002!BQ\u0001\u0018\u0001\u0005Bu\u000b!cY8naV$XmT;uaV$8\u000b[1qKR\u0011a\f\u001a\t\u0003?\nl\u0011\u0001\u0019\u0006\u0003C\u0012\tQ!\u001e;jYNL!a\u00191\u0003\u000bMC\u0017\r]3\t\u000b\u0015\\\u0006\u0019\u00010\u0002\u0015%t\u0007/\u001e;TQ\u0006\u0004X\rC\u0003h\u0001\u0011\u0005\u0003.\u0001\u0007va\u0012\fG/Z(viB,H\u000f\u0006\u0002j[B\u0019!n[\f\u000e\u00039K!\u0001\u001c(\u0003\rQ+gn]8s\u0011\u0015qg\r1\u0001j\u0003\u0015Ig\u000e];u\u0011\u0015\u0001\b\u0001\"\u0011r\u0003=)\b\u000fZ1uK\u001e\u0013\u0018\rZ%oaV$HcA5sg\")an\u001ca\u0001S\")Ao\u001ca\u0001S\u0006QqM]1e\u001fV$\b/\u001e;)\t\u00011\u0018P\u001f\t\u0003;]L!\u0001\u001f\u0010\u0003!M+'/[1m-\u0016\u00148/[8o+&#\u0015!\u0002<bYV,g\u0004\u0003\u0019\r\u001d\"Ov\u0014\b\u0007\b\u000bq\u0014\u0001\u0012A?\u0002\u0019U\u00038+Y7qY&twm\r#\u0011\u0005]sh!B\u0001\u0003\u0011\u0003y8#\u0002@\u0002\u0002\u0005\u001d\u0001cA\u000f\u0002\u0004%\u0019\u0011Q\u0001\u0010\u0003\r\u0005s\u0017PU3g!\ri\u0012\u0011B\u0005\u0004\u0003\u0017q\"\u0001D*fe&\fG.\u001b>bE2,\u0007BB*\u007f\t\u0003\ty\u0001F\u0001~\u0011\u001d\t\u0019B C\u0001\u0003+\tQ!\u00199qYf,B!a\u0006\u0002 Q!\u0011\u0011DA&)\u0019\tY\"!\u0011\u0002HA!q\u000bAA\u000f!\rA\u0012q\u0004\u0003\u000b5\u0005E\u0001\u0015!A\u0001\u0006\u0004Y\u0002\u0006CA\u0010\u0003G\tI#a\u000e\u0011\u0007u\t)#C\u0002\u0002(y\u00111b\u001d9fG&\fG.\u001b>fIFJ1%a\u000b\u0002.\u0005E\u0012q\u0006\b\u0004;\u00055\u0012bAA\u0018=\u0005)a\t\\8biF2A%a\r\u00026}q1AQA\u001b\u0013\u0005y\u0012'C\u0012\u0002:\u0005m\u0012qHA\u001f\u001d\ri\u00121H\u0005\u0004\u0003{q\u0012A\u0002#pk\ndW-\r\u0004%\u0003g\t)d\b\u0005\u000b\u0003\u0007\n\t\"!AA\u0004\u0005\u0015\u0013AC3wS\u0012,gnY3%eA!1GNA\u000f\u0011\u001dI\u0014\u0011\u0003a\u0002\u0003\u0013\u0002BaO(\u0002\u001e!1a%!\u0005A\u0002!B\u0011\"a\u0014\u007f\u0003\u0003%I!!\u0015\u0002\u0017I,\u0017\r\u001a*fg>dg/\u001a\u000b\u0003\u0003'\u0002B!!\u0016\u0002`5\u0011\u0011q\u000b\u0006\u0005\u00033\nY&\u0001\u0003mC:<'BAA/\u0003\u0011Q\u0017M^1\n\t\u0005\u0005\u0014q\u000b\u0002\u0007\u001f\nTWm\u0019;")
public class UpSampling3D<T>
extends TensorModule<T> {
    public static final long serialVersionUID = 3462228835945094156L;
    private final int[] size;
    private final TensorNumericMath.TensorNumeric<T> ev;

    public int[] size() {
        return this.size;
    }

    @Override
    public Shape computeOutputShape(Shape inputShape) {
        int[] input = (int[])inputShape.toSingle().toArray(ClassTag$.MODULE$.Int());
        Log4Error$.MODULE$.invalidInputError(input.length == 5, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"UpSampling3D requires 5D input, but got input dim ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)input.length)})), Log4Error$.MODULE$.invalidInputError$default$3());
        return Shape$.MODULE$.apply((Seq<Object>)Predef$.MODULE$.wrapIntArray(new int[]{input[0], input[1], input[2] * this.size()[0], input[3] * this.size()[1], input[4] * this.size()[2]}));
    }

    @Override
    public Tensor<T> updateOutput(Tensor<T> input) {
        Log4Error$.MODULE$.invalidInputError(input.dim() == 5, "only supports 5d tensors", Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(input.isContiguous(), "input need to be contiguous", Log4Error$.MODULE$.invalidInputError$default$3());
        int inputDepth = input.size(3);
        int inputHeight = input.size(4);
        int inputWidth = input.size(5);
        int dT = this.size()[0];
        int dH = this.size()[1];
        int dW = this.size()[2];
        int outputDepth = inputDepth * dT;
        int outputHeight = inputHeight * dH;
        int outputWidth = inputWidth * dW;
        ((Tensor)this.output()).resize(input.size(1), input.size(2), outputDepth, outputHeight, outputWidth);
        int idim = input.dim();
        int xDim = idim - 1;
        int yDim = idim - 2;
        int zDim = idim - 3;
        int osz0 = ((Tensor)this.output()).size(1);
        int osz1 = ((Tensor)this.output()).size(2);
        int osz2 = ((Tensor)this.output()).size(3);
        int osz3 = ((Tensor)this.output()).size(4);
        int osz4 = ((Tensor)this.output()).size(5);
        int[] is = input.stride();
        int[] os = ((Tensor)this.output()).stride();
        Object pin = input.storage().array();
        int inOffset = input.storageOffset() - 1;
        Object pout = ((Tensor)this.output()).storage().array();
        int outOffset = ((Tensor)this.output()).storageOffset() - 1;
        int i0 = 0;
        int i1 = 0;
        int i2 = 0;
        int i3 = 0;
        int i4 = 0;
        int isrc = 0;
        int idst = 0;
        int[] iout = new int[5];
        int[] iin = new int[5];
        for (i0 = 0; i0 < osz0; ++i0) {
            iout[0] = i0;
            iin[0] = i0;
            for (i1 = 0; i1 < osz1; ++i1) {
                iout[1] = i1;
                iin[1] = i1;
                for (i2 = 0; i2 < osz2; ++i2) {
                    iout[2] = i2;
                    iin[2] = i2;
                    for (i3 = 0; i3 < osz3; ++i3) {
                        iout[3] = i3;
                        iin[3] = i3;
                        for (i4 = 0; i4 < osz4; ++i4) {
                            iout[4] = i4;
                            iin[4] = i4;
                            iin[xDim] = iout[xDim] / dW;
                            iin[yDim] = iout[yDim] / dH;
                            iin[zDim] = iout[zDim] / dT;
                            idst = i0 * os[0] + i1 * os[1] + i2 * os[2] + i3 * os[3];
                            isrc = iin[0] * is[0] + iin[1] * is[1] + iin[2] * is[2] + iin[3] * is[3];
                            if (idim > 4) {
                                idst += i4 * os[4];
                                isrc += iin[4] * is[4];
                            }
                            ScalaRunTime$.MODULE$.array_update(pout, outOffset + idst, ScalaRunTime$.MODULE$.array_apply(pin, inOffset + isrc));
                        }
                    }
                }
            }
        }
        return (Tensor)this.output();
    }

    @Override
    public Tensor<T> updateGradInput(Tensor<T> input, Tensor<T> gradOutput) {
        ((Tensor)this.gradInput()).resizeAs(input).zero();
        int dT = this.size()[0];
        int dH = this.size()[1];
        int dW = this.size()[2];
        int idim = ((Tensor)this.gradInput()).dim();
        int xDim = idim - 1;
        int yDim = idim - 2;
        int zDim = idim - 3;
        int isz0 = ((Tensor)this.gradInput()).size(1);
        int isz1 = ((Tensor)this.gradInput()).size(2);
        int isz2 = ((Tensor)this.gradInput()).size(3);
        int isz3 = ((Tensor)this.gradInput()).size(4);
        int isz4 = ((Tensor)this.gradInput()).size(5);
        int[] is = ((Tensor)this.gradInput()).stride();
        int[] os = gradOutput.stride();
        Object pin = ((Tensor)this.gradInput()).storage().array();
        Object pout = gradOutput.storage().array();
        int inOffset = ((Tensor)this.gradInput()).storageOffset() - 1;
        int outOffset = gradOutput.storageOffset() - 1;
        int i0 = 0;
        int i1 = 0;
        int i2 = 0;
        int i3 = 0;
        int i4 = 0;
        int isrc = 0;
        int idst = 0;
        int x = 0;
        int y = 0;
        int z = 0;
        int[] iin = new int[5];
        int[] iout = new int[5];
        for (i0 = 0; i0 < isz0; ++i0) {
            iout[0] = i0;
            iin[0] = i0;
            for (i1 = 0; i1 < isz1; ++i1) {
                iout[1] = i1;
                iin[1] = i1;
                for (i2 = 0; i2 < isz2; ++i2) {
                    iout[2] = i2;
                    iin[2] = i2;
                    for (i3 = 0; i3 < isz3; ++i3) {
                        iout[3] = i3;
                        iin[3] = i3;
                        for (i4 = 0; i4 < isz4; ++i4) {
                            iout[4] = i4;
                            iin[4] = i4;
                            idst = i0 * is[0] + i1 * is[1] + i2 * is[2] + i3 * is[3];
                            if (idim > 4) {
                                idst += i4 * is[4];
                            }
                            for (z = 0; z < dT; ++z) {
                                for (y = 0; y < dH; ++y) {
                                    for (x = 0; x < dW; ++x) {
                                        iout[xDim] = dW * iin[xDim] + x;
                                        iout[yDim] = dH * iin[yDim] + y;
                                        iout[zDim] = dT * iin[zDim] + z;
                                        isrc = iout[0] * os[0] + iout[1] * os[1] + iout[2] * os[2] + iout[3] * os[3];
                                        if (idim > 4) {
                                            isrc += iout[4] * os[4];
                                        }
                                        ScalaRunTime$.MODULE$.array_update(pin, inOffset + idst, this.ev.plus(ScalaRunTime$.MODULE$.array_apply(pin, inOffset + idst), ScalaRunTime$.MODULE$.array_apply(pout, outOffset + isrc)));
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        return (Tensor)this.gradInput();
    }

    public UpSampling3D(int[] size, ClassTag<T> evidence$1, TensorNumericMath.TensorNumeric<T> ev) {
        this.size = size;
        this.ev = ev;
        super(evidence$1, ev);
        Log4Error$.MODULE$.invalidInputError(size != null && size.length == 3, "the size should be 3 dims", Log4Error$.MODULE$.invalidInputError$default$3());
    }
}

