/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.runtime.matrix.data;

import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnDropoutDescriptor;
import jcuda.jcudnn.cudnnFilterDescriptor;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcudnn.cudnnRNNDescriptor;
import jcuda.jcudnn.cudnnTensorDescriptor;
import org.tugraz.sysds.api.DMLScript;
import org.tugraz.sysds.runtime.DMLRuntimeException;
import org.tugraz.sysds.runtime.controlprogram.context.ExecutionContext;
import org.tugraz.sysds.runtime.instructions.gpu.context.GPUContext;
import org.tugraz.sysds.runtime.matrix.data.LibMatrixCUDA;

public class LibMatrixCuDNNRnnAlgorithm
implements AutoCloseable {
    // cuDNN enum values, passed through JCudnn's int-based API.
    /** cudnnRNNMode_t.CUDNN_RNN_RELU */
    private static final int CUDNN_RNN_RELU = 0;
    /** cudnnRNNMode_t.CUDNN_RNN_TANH */
    private static final int CUDNN_RNN_TANH = 1;
    /** cudnnRNNMode_t.CUDNN_LSTM */
    private static final int CUDNN_LSTM = 2;
    /** cudnnRNNMode_t.CUDNN_GRU */
    private static final int CUDNN_GRU = 3;
    /** cudnnRNNInputMode_t.CUDNN_LINEAR_INPUT */
    private static final int CUDNN_LINEAR_INPUT = 0;
    /** cudnnDirectionMode_t.CUDNN_UNIDIRECTIONAL */
    private static final int CUDNN_UNIDIRECTIONAL = 0;
    /** cudnnRNNAlgo_t.CUDNN_RNN_ALGO_STANDARD */
    private static final int CUDNN_RNN_ALGO_STANDARD = 0;
    /** cudnnTensorFormat_t.CUDNN_TENSOR_NCHW, used for the flat weight filter descriptor. */
    private static final int CUDNN_TENSOR_NCHW = 0;
    /** Number of stacked RNN layers; also the leading dimension of the hidden/cell state descriptors. */
    private static final int NUM_LAYERS = 1;
    /** Dropout probability configured on the dropout descriptor (0 = no dropout). */
    private static final float DROPOUT_PROBABILITY = 0.0f;
    /** Fixed seed for the cuDNN dropout RNG state. */
    private static final long DROPOUT_SEED = 12345L;

    // Context used for cuDNN handles and GPU (de)allocation.
    GPUContext gCtx;
    // Instruction name, forwarded to the GPU allocator/free helper.
    String instName;
    cudnnDropoutDescriptor dropoutDesc;
    cudnnRNNDescriptor rnnDesc;
    // Per-time-step descriptors (length T) for input/output and their gradients.
    cudnnTensorDescriptor[] xDesc;
    cudnnTensorDescriptor[] dxDesc;
    cudnnTensorDescriptor[] yDesc;
    cudnnTensorDescriptor[] dyDesc;
    // Hidden/cell state descriptors (and gradients), each of shape (NUM_LAYERS, N, M).
    cudnnTensorDescriptor hxDesc;
    cudnnTensorDescriptor cxDesc;
    cudnnTensorDescriptor hyDesc;
    cudnnTensorDescriptor cyDesc;
    cudnnTensorDescriptor dhxDesc;
    cudnnTensorDescriptor dcxDesc;
    cudnnTensorDescriptor dhyDesc;
    cudnnTensorDescriptor dcyDesc;
    // Flat weight (and weight-gradient) filter descriptors.
    cudnnFilterDescriptor wDesc;
    cudnnFilterDescriptor dwDesc;
    // Workspace buffer and its size in bytes (size is 0 when nothing was allocated).
    long sizeInBytes;
    Pointer workSpace;
    // Training reserve buffer and its size in bytes (0 when not training or nothing was allocated).
    long reserveSpaceSizeInBytes;
    Pointer reserveSpace;
    // Dropout RNG state buffer and its size in bytes.
    long dropOutSizeInBytes;
    Pointer dropOutStateSpace;

    /**
     * Creates all cuDNN descriptors and GPU buffers needed for a single-layer, unidirectional RNN.
     *
     * <p>If construction fails, every resource acquired so far is released before the exception
     * propagates. Failures during that cleanup are attached to it as suppressed exceptions.
     *
     * @param ec         execution context (not used by this constructor)
     * @param gCtx       GPU context providing the cuDNN handle and GPU memory allocation
     * @param instName   instruction name forwarded to the allocator
     * @param rnnMode    one of "rnn_relu", "rnn_tanh", "lstm", "gru" (case-insensitive)
     * @param N          mini-batch size
     * @param T          sequence length (number of time steps)
     * @param M          hidden size
     * @param D          number of input features
     * @param isTraining whether to allocate the training reserve space
     * @param w          weight pointer (not used by this constructor)
     * @throws DMLRuntimeException for an unsupported mode, non-positive dimensions,
     *                             an LSTM parameter-count mismatch, or allocation failures
     */
    public LibMatrixCuDNNRnnAlgorithm(ExecutionContext ec, GPUContext gCtx, String instName, String rnnMode, int N, int T, int M, int D, boolean isTraining, Pointer w) throws DMLRuntimeException {
        this.gCtx = gCtx;
        this.instName = instName;
        // Validate before acquiring any native resource so invalid input cannot leak anything.
        int cudnnRnnMode = getCuDNNRnnMode(rnnMode);
        if (N <= 0 || T <= 0 || M <= 0 || D <= 0) {
            throw new DMLRuntimeException("Invalid RNN dimensions: N=" + N + ", T=" + T + ", M=" + M
                + ", D=" + D + " (all must be positive)");
        }
        try {
            initTensorDescriptors(N, T, M, D);
            initDropoutDescriptor();
            initRnnDescriptor(M, cudnnRnnMode);
            int expectedNumWeights = getExpectedNumWeights();
            if (cudnnRnnMode == CUDNN_LSTM) {
                // 4 gates x (input weights D*M + recurrent weights M*M + 2 bias vectors of size M);
                // computed in long arithmetic to avoid int overflow.
                long lstmNumWeights = 4L * M * (D + M + 2L);
                if (lstmNumWeights != expectedNumWeights) {
                    throw new DMLRuntimeException("Incorrect number of RNN parameters " + lstmNumWeights + " != " + expectedNumWeights + ", where numFeatures=" + D + ", hiddenSize=" + M);
                }
            }
            wDesc = allocateFilterDescriptor(expectedNumWeights);
            dwDesc = allocateFilterDescriptor(expectedNumWeights);
            allocateWorkAndReserveSpace(T, isTraining);
        }
        catch (Throwable t) {
            releaseAfterFailedInit(t);
            // Precise rethrow: only DMLRuntimeException or unchecked throwables reach here.
            throw t;
        }
    }

    /** Creates the per-time-step input/output descriptors and the hidden/cell state descriptors. */
    private void initTensorDescriptors(int N, int T, int M, int D) throws DMLRuntimeException {
        xDesc = new cudnnTensorDescriptor[T];
        dxDesc = new cudnnTensorDescriptor[T];
        yDesc = new cudnnTensorDescriptor[T];
        dyDesc = new cudnnTensorDescriptor[T];
        for (int t = 0; t < T; ++t) {
            xDesc[t] = allocateTensorDescriptorWithStride(N, D, 1);
            dxDesc[t] = allocateTensorDescriptorWithStride(N, D, 1);
            yDesc[t] = allocateTensorDescriptorWithStride(N, M, 1);
            dyDesc[t] = allocateTensorDescriptorWithStride(N, M, 1);
        }
        hxDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, N, M);
        dhxDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, N, M);
        cxDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, N, M);
        dcxDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, N, M);
        hyDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, N, M);
        dhyDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, N, M);
        cyDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, N, M);
        dcyDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, N, M);
    }

    /** Creates the dropout descriptor and allocates its RNG state buffer on the GPU. */
    private void initDropoutDescriptor() throws DMLRuntimeException {
        cudnnDropoutDescriptor desc = new cudnnDropoutDescriptor();
        JCudnn.cudnnCreateDropoutDescriptor(desc);
        dropoutDesc = desc;
        long[] stateSize = new long[] {-1L};
        JCudnn.cudnnDropoutGetStatesSize(gCtx.getCudnnHandle(), stateSize);
        dropOutStateSpace = new Pointer();
        if (stateSize[0] != 0L) {
            dropOutStateSpace = gCtx.allocate(instName, stateSize[0]);
        }
        // Record the size only once the allocation succeeded, so cleanup never frees an unallocated pointer.
        dropOutSizeInBytes = stateSize[0];
        JCudnn.cudnnSetDropoutDescriptor(dropoutDesc, gCtx.getCudnnHandle(), DROPOUT_PROBABILITY,
            dropOutStateSpace, dropOutSizeInBytes, DROPOUT_SEED);
    }

    /** Creates and configures the RNN descriptor; requires the dropout descriptor to exist. */
    private void initRnnDescriptor(int M, int cudnnRnnMode) {
        cudnnRNNDescriptor desc = new cudnnRNNDescriptor();
        JCudnn.cudnnCreateRNNDescriptor(desc);
        rnnDesc = desc;
        JCudnn.cudnnSetRNNDescriptor_v6(gCtx.getCudnnHandle(), rnnDesc, M, NUM_LAYERS, dropoutDesc,
            CUDNN_LINEAR_INPUT, CUDNN_UNIDIRECTIONAL, cudnnRnnMode, CUDNN_RNN_ALGO_STANDARD,
            LibMatrixCUDA.CUDNN_DATA_TYPE);
    }

    /** Allocates the workspace and, when training, the reserve space; sizes depend on xDesc. */
    private void allocateWorkAndReserveSpace(int T, boolean isTraining) throws DMLRuntimeException {
        workSpace = new Pointer();
        reserveSpace = new Pointer();
        long workSpaceSize = getWorkspaceSize(T);
        if (workSpaceSize != 0L) {
            workSpace = gCtx.allocate(instName, workSpaceSize);
        }
        sizeInBytes = workSpaceSize;
        reserveSpaceSizeInBytes = 0L;
        if (isTraining) {
            long reserveSize = getReservespaceSize(T);
            if (reserveSize != 0L) {
                reserveSpace = gCtx.allocate(instName, reserveSize);
            }
            reserveSpaceSizeInBytes = reserveSize;
        }
    }

    /**
     * Returns the number of linear layers (weight matrices) per RNN layer for the given mode.
     *
     * @throws DMLRuntimeException if the mode is unsupported or null
     */
    private static int getNumLinearLayers(String rnnMode) throws DMLRuntimeException {
        if ("rnn_relu".equalsIgnoreCase(rnnMode) || "rnn_tanh".equalsIgnoreCase(rnnMode)) {
            return 2;
        }
        if ("lstm".equalsIgnoreCase(rnnMode)) {
            return 8;
        }
        if ("gru".equalsIgnoreCase(rnnMode)) {
            return 6;
        }
        throw new DMLRuntimeException("Unsupported rnn mode:" + rnnMode);
    }

    /** Queries cuDNN for the workspace size (bytes) of a sequence of the given length. */
    private long getWorkspaceSize(int seqLength) {
        long[] sizeInBytesArray = new long[1];
        JCudnn.cudnnGetRNNWorkspaceSize(gCtx.getCudnnHandle(), rnnDesc, seqLength, xDesc, sizeInBytesArray);
        return sizeInBytesArray[0];
    }

    /** Queries cuDNN for the training reserve-space size (bytes) of a sequence of the given length. */
    private long getReservespaceSize(int seqLength) {
        long[] sizeInBytesArray = new long[1];
        JCudnn.cudnnGetRNNTrainingReserveSize(gCtx.getCudnnHandle(), rnnDesc, seqLength, xDesc, sizeInBytesArray);
        return sizeInBytesArray[0];
    }

    /**
     * Maps a SystemDS RNN mode name to the cuDNN cudnnRNNMode_t value.
     *
     * @throws DMLRuntimeException if the mode is unsupported or null
     */
    private static int getCuDNNRnnMode(String rnnMode) throws DMLRuntimeException {
        if ("rnn_relu".equalsIgnoreCase(rnnMode)) {
            return CUDNN_RNN_RELU;
        }
        if ("rnn_tanh".equalsIgnoreCase(rnnMode)) {
            return CUDNN_RNN_TANH;
        }
        if ("lstm".equalsIgnoreCase(rnnMode)) {
            return CUDNN_LSTM;
        }
        if ("gru".equalsIgnoreCase(rnnMode)) {
            return CUDNN_GRU;
        }
        throw new DMLRuntimeException("Unsupported rnn mode:" + rnnMode);
    }

    /** Returns the number of weight elements cuDNN expects (parameter bytes / element size). */
    private int getExpectedNumWeights() throws DMLRuntimeException {
        long[] weightSizeInBytesArray = new long[] {-1L};
        JCudnn.cudnnGetRNNParamsSize(gCtx.getCudnnHandle(), rnnDesc, xDesc[0], weightSizeInBytesArray,
            LibMatrixCUDA.CUDNN_DATA_TYPE);
        return LibMatrixCUDA.toInt(weightSizeInBytesArray[0] / (long) LibMatrixCUDA.sizeOfDataType);
    }

    /** Creates a 3-d filter descriptor of shape (numWeights, 1, 1); destroys it again if configuration fails. */
    private static cudnnFilterDescriptor allocateFilterDescriptor(int numWeights) {
        cudnnFilterDescriptor filterDesc = new cudnnFilterDescriptor();
        JCudnn.cudnnCreateFilterDescriptor(filterDesc);
        try {
            JCudnn.cudnnSetFilterNdDescriptor(filterDesc, LibMatrixCUDA.CUDNN_DATA_TYPE, CUDNN_TENSOR_NCHW, 3,
                new int[] {numWeights, 1, 1});
        }
        catch (RuntimeException e) {
            JCudnn.cudnnDestroyFilterDescriptor(filterDesc);
            throw e;
        }
        return filterDesc;
    }

    /**
     * Creates a fully-packed 3-d tensor descriptor of shape (firstDim, secondDim, thirdDim);
     * destroys it again if configuration fails.
     */
    private static cudnnTensorDescriptor allocateTensorDescriptorWithStride(int firstDim, int secondDim, int thirdDim) throws DMLRuntimeException {
        cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(tensorDescriptor);
        int[] dimA = new int[] {firstDim, secondDim, thirdDim};
        // Row-major packed strides: the innermost dimension is contiguous.
        int[] strideA = new int[] {thirdDim * secondDim, thirdDim, 1};
        try {
            JCudnn.cudnnSetTensorNdDescriptor(tensorDescriptor, LibMatrixCUDA.CUDNN_DATA_TYPE, 3, dimA, strideA);
        }
        catch (RuntimeException e) {
            JCudnn.cudnnDestroyTensorDescriptor(tensorDescriptor);
            throw e;
        }
        return tensorDescriptor;
    }

    /** Destroys a tensor descriptor if present; always returns null so callers can clear the field. */
    private static cudnnTensorDescriptor destroyTensorDescriptor(cudnnTensorDescriptor desc) {
        if (desc != null) {
            JCudnn.cudnnDestroyTensorDescriptor(desc);
        }
        return null;
    }

    /** Destroys every non-null descriptor in the array (tolerates partially filled arrays); returns null. */
    private static cudnnTensorDescriptor[] destroyTensorDescriptors(cudnnTensorDescriptor[] descs) {
        if (descs != null) {
            for (cudnnTensorDescriptor desc : descs) {
                destroyTensorDescriptor(desc);
            }
        }
        return null;
    }

    /** Destroys a filter descriptor if present; always returns null so callers can clear the field. */
    private static cudnnFilterDescriptor destroyFilterDescriptor(cudnnFilterDescriptor desc) {
        if (desc != null) {
            JCudnn.cudnnDestroyFilterDescriptor(desc);
        }
        return null;
    }

    /**
     * Frees a GPU buffer if it was allocated (non-zero size, non-null pointer).
     *
     * @return {@code priorError} if non-null (with any new failure attached as suppressed),
     *         otherwise the new failure, or null if none occurred
     */
    private DMLRuntimeException freeGpuBuffer(Pointer ptr, long size, DMLRuntimeException priorError) {
        if (size == 0L || ptr == null) {
            return priorError;
        }
        try {
            gCtx.cudaFreeHelper(instName, ptr, DMLScript.EAGER_CUDA_FREE);
        }
        catch (DMLRuntimeException e) {
            if (priorError == null) {
                return e;
            }
            priorError.addSuppressed(e);
        }
        return priorError;
    }

    /**
     * Releases all descriptors and GPU buffers, clearing fields so repeated calls are no-ops.
     * All three GPU buffers are attempted even if an earlier free fails.
     *
     * @return the first GPU free failure, or null if all succeeded
     */
    private DMLRuntimeException releaseResources() {
        // Descriptors first: the dropout descriptor references the dropout state buffer freed below.
        if (dropoutDesc != null) {
            JCudnn.cudnnDestroyDropoutDescriptor(dropoutDesc);
        }
        dropoutDesc = null;
        if (rnnDesc != null) {
            JCudnn.cudnnDestroyRNNDescriptor(rnnDesc);
        }
        rnnDesc = null;
        hxDesc = destroyTensorDescriptor(hxDesc);
        dhxDesc = destroyTensorDescriptor(dhxDesc);
        hyDesc = destroyTensorDescriptor(hyDesc);
        dhyDesc = destroyTensorDescriptor(dhyDesc);
        cxDesc = destroyTensorDescriptor(cxDesc);
        dcxDesc = destroyTensorDescriptor(dcxDesc);
        cyDesc = destroyTensorDescriptor(cyDesc);
        dcyDesc = destroyTensorDescriptor(dcyDesc);
        wDesc = destroyFilterDescriptor(wDesc);
        dwDesc = destroyFilterDescriptor(dwDesc);
        xDesc = destroyTensorDescriptors(xDesc);
        dxDesc = destroyTensorDescriptors(dxDesc);
        yDesc = destroyTensorDescriptors(yDesc);
        dyDesc = destroyTensorDescriptors(dyDesc);

        DMLRuntimeException firstError = freeGpuBuffer(workSpace, sizeInBytes, null);
        workSpace = null;
        sizeInBytes = 0L;
        firstError = freeGpuBuffer(reserveSpace, reserveSpaceSizeInBytes, firstError);
        reserveSpace = null;
        reserveSpaceSizeInBytes = 0L;
        firstError = freeGpuBuffer(dropOutStateSpace, dropOutSizeInBytes, firstError);
        dropOutStateSpace = null;
        dropOutSizeInBytes = 0L;
        return firstError;
    }

    /** Best-effort cleanup after a constructor failure; cleanup errors are attached to {@code primary}. */
    private void releaseAfterFailedInit(Throwable primary) {
        try {
            DMLRuntimeException cleanupError = releaseResources();
            if (cleanupError != null && cleanupError != primary) {
                primary.addSuppressed(cleanupError);
            }
        }
        catch (Throwable cleanupFailure) {
            if (cleanupFailure != primary) {
                primary.addSuppressed(cleanupFailure);
            }
        }
    }

    /**
     * Destroys all cuDNN descriptors and frees all GPU buffers. Safe to call more than once.
     *
     * @throws RuntimeException wrapping the first {@link DMLRuntimeException} raised while freeing GPU memory
     */
    @Override
    public void close() {
        DMLRuntimeException error = releaseResources();
        if (error != null) {
            throw new RuntimeException(error);
        }
    }
}

