/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.runtime.matrix.data;

import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnConvolutionDescriptor;
import jcuda.jcudnn.cudnnFilterDescriptor;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcudnn.cudnnTensorDescriptor;
import org.tugraz.sysds.api.DMLScript;
import org.tugraz.sysds.runtime.DMLRuntimeException;
import org.tugraz.sysds.runtime.instructions.gpu.context.GPUContext;
import org.tugraz.sysds.runtime.matrix.data.LibMatrixCUDA;
import org.tugraz.sysds.runtime.matrix.data.LibMatrixCuDNN;

/**
 * Bundles everything a cuDNN convolution call needs besides the data itself:
 * the selected algorithm id, an optional device workspace, and the tensor,
 * filter and convolution descriptors.
 *
 * <p>Instances are created through the static {@code cudnnGetConvolution*Algorithm}
 * factories and own native resources, so they should be used in a
 * try-with-resources statement. {@link #close()} may be called more than once;
 * each native resource is released at most once.
 *
 * <p>Nothing in this class synchronizes access; do not share an instance
 * between threads without external coordination.
 */
public class LibMatrixCuDNNConvolutionAlgorithm
implements AutoCloseable {
    /** Upper bound (1e9 bytes) on the workspace limit handed to cuDNN when it selects an algorithm. */
    static long MAX_WORKSPACE_LIMIT_BYTES = 1000000000L;
    /** Algorithm id chosen for the operation; {@code -1} until a factory has filled it in. */
    public int algo = -1;
    /** Device workspace; only backed by an allocation when {@link #sizeInBytes} is non-zero. */
    public Pointer workSpace = new Pointer();
    /** Size of {@link #workSpace} in bytes; {@code 0} means no workspace is allocated. */
    public long sizeInBytes = 0L;
    /** Descriptor of the N x C x H x W tensor. */
    cudnnTensorDescriptor nchwTensorDesc = null;
    /** Descriptor of the N x K x P x Q tensor. */
    cudnnTensorDescriptor nkpqTensorDesc = null;
    /** Descriptor of the K x C x R x S filter. */
    cudnnFilterDescriptor filterDesc = null;
    /** Convolution descriptor carrying padding and strides. */
    cudnnConvolutionDescriptor convDesc = null;
    /** Context used to allocate and free the workspace. */
    GPUContext gCtx = null;
    /** Instruction name passed through to the context's allocate/free calls. */
    String instName = null;

    /**
     * Creates the convolution, tensor and filter descriptors for the given shapes.
     * If creating any descriptor throws, the descriptors created so far are
     * destroyed before the exception propagates.
     */
    private LibMatrixCuDNNConvolutionAlgorithm(GPUContext gCtx, String instName, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q) {
        this.gCtx = gCtx;
        this.instName = instName;
        try {
            this.convDesc = allocateConvolutionDescriptor(new int[]{pad_h, pad_w}, new int[]{stride_h, stride_w});
            this.nchwTensorDesc = allocateTensorDescriptor(N, C, H, W);
            this.nkpqTensorDesc = allocateTensorDescriptor(N, K, P, Q);
            this.filterDesc = allocateFilterDescriptor(K, C, R, S);
        }
        catch (RuntimeException | Error e) {
            // The caller never receives a reference to this object, so it must clean up itself.
            closeAfterFailure(this, e);
            throw e;
        }
    }

    /**
     * Destroys the descriptors and frees the workspace, if one was allocated.
     * Safe to call repeatedly: every released resource is cleared so that it is
     * not released a second time.
     *
     * @throws RuntimeException wrapping the {@link DMLRuntimeException} raised
     *         while freeing the workspace
     */
    @Override
    public void close() {
        try {
            if (this.nchwTensorDesc != null) {
                JCudnn.cudnnDestroyTensorDescriptor(this.nchwTensorDesc);
                this.nchwTensorDesc = null;
            }
            if (this.nkpqTensorDesc != null) {
                JCudnn.cudnnDestroyTensorDescriptor(this.nkpqTensorDesc);
                this.nkpqTensorDesc = null;
            }
            if (this.filterDesc != null) {
                JCudnn.cudnnDestroyFilterDescriptor(this.filterDesc);
                this.filterDesc = null;
            }
            if (this.convDesc != null) {
                JCudnn.cudnnDestroyConvolutionDescriptor(this.convDesc);
                this.convDesc = null;
            }
        }
        finally {
            // Runs even if destroying a descriptor threw, so the workspace is not leaked.
            if (this.sizeInBytes != 0L) {
                // Clear the size first: the free is attempted at most once, even if it fails.
                this.sizeInBytes = 0L;
                try {
                    this.gCtx.cudaFreeHelper(this.instName, this.workSpace, DMLScript.EAGER_CUDA_FREE);
                }
                catch (DMLRuntimeException e) {
                    throw new RuntimeException(e);
                }
            }
        }
    }

    /**
     * Selects a forward-convolution algorithm and allocates the workspace it requires.
     *
     * @param gCtx GPU context providing the cuDNN handle and device memory
     * @param instName instruction name forwarded to the context's allocator
     * @param N first dimension of both tensors
     * @param C second dimension of the N x C x H x W tensor and of the filter
     * @param H third dimension of the N x C x H x W tensor
     * @param W fourth dimension of the N x C x H x W tensor
     * @param K second dimension of the N x K x P x Q tensor, first of the filter
     * @param R third dimension of the filter
     * @param S fourth dimension of the filter
     * @param pad_h padding along the height
     * @param pad_w padding along the width
     * @param stride_h stride along the height
     * @param stride_w stride along the width
     * @param P third dimension of the N x K x P x Q tensor
     * @param Q fourth dimension of the N x K x P x Q tensor
     * @param workspaceLimit workspace budget in bytes; capped at {@code MAX_WORKSPACE_LIMIT_BYTES}
     * @return a holder that the caller must close
     */
    public static LibMatrixCuDNNConvolutionAlgorithm cudnnGetConvolutionForwardAlgorithm(GPUContext gCtx, String instName, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, long workspaceLimit) {
        LibMatrixCuDNNConvolutionAlgorithm ret = new LibMatrixCuDNNConvolutionAlgorithm(gCtx, instName, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
        try {
            int[] algos = new int[]{-1};
            long[] sizeInBytesArray = new long[]{Math.min(workspaceLimit, MAX_WORKSPACE_LIMIT_BYTES)};
            // Preference 2 is the decompiler-inlined enum constant; in cuDNN's forward preference
            // enum it corresponds to "specify workspace limit" (see the cuDNN API reference).
            JCudnn.cudnnGetConvolutionForwardAlgorithm(LibMatrixCuDNN.getCudnnHandle(gCtx), ret.nchwTensorDesc, ret.filterDesc, ret.convDesc, ret.nkpqTensorDesc, 2, sizeInBytesArray[0], algos);
            // Overwrites sizeInBytesArray[0] with the size the chosen algorithm actually needs.
            JCudnn.cudnnGetConvolutionForwardWorkspaceSize(LibMatrixCuDNN.getCudnnHandle(gCtx), ret.nchwTensorDesc, ret.filterDesc, ret.convDesc, ret.nkpqTensorDesc, algos[0], sizeInBytesArray);
            ret.attachWorkspace(algos[0], sizeInBytesArray[0]);
            return ret;
        }
        catch (RuntimeException | Error e) {
            closeAfterFailure(ret, e);
            throw e;
        }
    }

    /**
     * Selects a backward-filter algorithm and allocates the workspace it requires.
     * Parameters have the same meaning as in
     * {@code cudnnGetConvolutionForwardAlgorithm}.
     *
     * @return a holder that the caller must close
     */
    public static LibMatrixCuDNNConvolutionAlgorithm cudnnGetConvolutionBackwardFilterAlgorithm(GPUContext gCtx, String instName, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, long workspaceLimit) {
        LibMatrixCuDNNConvolutionAlgorithm ret = new LibMatrixCuDNNConvolutionAlgorithm(gCtx, instName, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
        try {
            int[] algos = new int[]{-1};
            long[] sizeInBytesArray = new long[]{Math.min(workspaceLimit, MAX_WORKSPACE_LIMIT_BYTES)};
            // Preference 2: "specify workspace limit" in cuDNN's backward-filter preference enum.
            JCudnn.cudnnGetConvolutionBackwardFilterAlgorithm(LibMatrixCuDNN.getCudnnHandle(gCtx), ret.nchwTensorDesc, ret.nkpqTensorDesc, ret.convDesc, ret.filterDesc, 2, sizeInBytesArray[0], algos);
            JCudnn.cudnnGetConvolutionBackwardFilterWorkspaceSize(LibMatrixCuDNN.getCudnnHandle(gCtx), ret.nchwTensorDesc, ret.nkpqTensorDesc, ret.convDesc, ret.filterDesc, algos[0], sizeInBytesArray);
            ret.attachWorkspace(algos[0], sizeInBytesArray[0]);
            return ret;
        }
        catch (RuntimeException | Error e) {
            closeAfterFailure(ret, e);
            throw e;
        }
    }

    /**
     * Selects a backward-data algorithm and allocates the workspace it requires.
     * Parameters have the same meaning as in
     * {@code cudnnGetConvolutionForwardAlgorithm}.
     *
     * <p>When {@code H == R} or {@code W == S}, algorithm id {@code 0} is used
     * directly: cuDNN is not queried and no workspace is allocated.
     *
     * @return a holder that the caller must close
     */
    public static LibMatrixCuDNNConvolutionAlgorithm cudnnGetConvolutionBackwardDataAlgorithm(GPUContext gCtx, String instName, int N, int C, int H, int W, int K, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w, int P, int Q, long workspaceLimit) {
        LibMatrixCuDNNConvolutionAlgorithm ret = new LibMatrixCuDNNConvolutionAlgorithm(gCtx, instName, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
        if (H == R || W == S) {
            // NOTE(review): presumably a workaround for a poor heuristic choice when the
            // filter spans the full height or width - confirm against project history.
            ret.algo = 0;
            return ret;
        }
        try {
            int[] algos = new int[]{-1};
            long[] sizeInBytesArray = new long[]{Math.min(workspaceLimit, MAX_WORKSPACE_LIMIT_BYTES)};
            // Preference 2: "specify workspace limit" in cuDNN's backward-data preference enum.
            JCudnn.cudnnGetConvolutionBackwardDataAlgorithm(LibMatrixCuDNN.getCudnnHandle(gCtx), ret.filterDesc, ret.nkpqTensorDesc, ret.convDesc, ret.nchwTensorDesc, 2, sizeInBytesArray[0], algos);
            JCudnn.cudnnGetConvolutionBackwardDataWorkspaceSize(LibMatrixCuDNN.getCudnnHandle(gCtx), ret.filterDesc, ret.nkpqTensorDesc, ret.convDesc, ret.nchwTensorDesc, algos[0], sizeInBytesArray);
            ret.attachWorkspace(algos[0], sizeInBytesArray[0]);
            return ret;
        }
        catch (RuntimeException | Error e) {
            closeAfterFailure(ret, e);
            throw e;
        }
    }

    /**
     * Allocates the workspace (when a non-zero size is required) and records the
     * chosen algorithm and workspace size on this holder.
     */
    private void attachWorkspace(int chosenAlgo, long requiredBytes) {
        if (requiredBytes != 0L) {
            this.workSpace = this.gCtx.allocate(this.instName, requiredBytes);
        }
        // Recorded only after a successful allocation, so close() never frees an unallocated pointer.
        this.sizeInBytes = requiredBytes;
        this.algo = chosenAlgo;
    }

    /**
     * Closes a holder whose setup failed. A failure during cleanup is attached
     * to {@code cause} as a suppressed exception rather than replacing it.
     */
    private static void closeAfterFailure(LibMatrixCuDNNConvolutionAlgorithm partial, Throwable cause) {
        try {
            partial.close();
        }
        catch (RuntimeException cleanupFailure) {
            cause.addSuppressed(cleanupFailure);
        }
    }

    /**
     * Creates a 4-d tensor descriptor of shape N x C x H x W using
     * {@code LibMatrixCUDA.CUDNN_DATA_TYPE}.
     */
    private static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) {
        cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(tensorDescriptor);
        // Format 0 is the decompiler-inlined tensor-format constant (NCHW in cuDNN's enum).
        JCudnn.cudnnSetTensor4dDescriptor(tensorDescriptor, 0, LibMatrixCUDA.CUDNN_DATA_TYPE, N, C, H, W);
        return tensorDescriptor;
    }

    /**
     * Creates a 4-d filter descriptor of shape K x C x R x S using
     * {@code LibMatrixCUDA.CUDNN_DATA_TYPE}.
     */
    private static cudnnFilterDescriptor allocateFilterDescriptor(int K, int C, int R, int S) {
        cudnnFilterDescriptor filterDesc = new cudnnFilterDescriptor();
        JCudnn.cudnnCreateFilterDescriptor(filterDesc);
        // Format 0 is the decompiler-inlined tensor-format constant (NCHW in cuDNN's enum).
        JCudnn.cudnnSetFilter4dDescriptor(filterDesc, LibMatrixCUDA.CUDNN_DATA_TYPE, 0, K, C, R, S);
        return filterDesc;
    }

    /**
     * Creates a 2-d convolution descriptor.
     *
     * @param padding {@code {pad_h, pad_w}}
     * @param strides {@code {stride_h, stride_w}}
     */
    private static cudnnConvolutionDescriptor allocateConvolutionDescriptor(int[] padding, int[] strides) {
        cudnnConvolutionDescriptor convDesc = new cudnnConvolutionDescriptor();
        JCudnn.cudnnCreateConvolutionDescriptor(convDesc);
        // Trailing literals: dilation 1 x 1, then mode 1 (cross-correlation in cuDNN's enum).
        JCudnn.cudnnSetConvolution2dDescriptor(convDesc, padding[0], padding[1], strides[0], strides[1], 1, 1, 1, LibMatrixCUDA.CUDNN_DATA_TYPE);
        return convDesc;
    }
}

