import cupy

from dexp.processing.backends._cupy.texture.texture import create_cuda_texture
from dexp.processing.backends.backend import Backend


def _warp_3d_cupy(image,
                  vector_field,
                  mode,
                  block_size: int = 8):
    """
    Warp a 3D image with a 3D displacement field on the GPU (CuPy + CUDA textures).

    For each output voxel (z, y, x), the vector field is linearly interpolated
    at the voxel's position, and the input image is then linearly sampled at
    the voxel position displaced by minus that vector.

    Parameters
    ----------
    image
        3D array in numpy axis order (z, y, x) to warp. It is sampled through
        a float32 texture, so non-float32 inputs are converted to float32
        internally.
    vector_field
        4D array of shape (vz, vy, vx, 3): one displacement vector per
        vector-field voxel. Components are in numpy axis order (z, y, x) and
        expressed in image voxels. The spatial shape may differ from the image
        (e.g. coarser); the field is stretched over the whole image through
        normalised-coordinate texture sampling (clamped at the borders).
    mode
        Address mode of the image texture, i.e. how samples falling outside
        the image are handled (forwarded to create_cuda_texture as address_mode).
    block_size
        Edge length of the cubic CUDA thread block. block_size**3 threads per
        block must not exceed the device limit (typically 1024).

    Returns
    -------
    Warped image with the same shape and dtype as `image`.

    Raises
    ------
    ValueError
        If `image` is not 3D, or `vector_field` is not 4D with 3 components
        along its last axis.
    """
    xp = Backend.get_xp_module()

    source = r'''
                extern "C"{
                __global__ void warp_3d(float* warped_image,
                                        cudaTextureObject_t input_image,
                                        cudaTextureObject_t vector_field,
                                        int width,
                                        int height,
                                        int depth)
                {
                    unsigned int x = blockIdx.x * blockDim.x + threadIdx.x;
                    unsigned int y = blockIdx.y * blockDim.y + threadIdx.y;
                    unsigned int z = blockIdx.z * blockDim.z + threadIdx.z;

                    if (x < width && y < height && z < depth)
                    {
                        // Centre of voxel (x,y,z) in the normalised [0,1] coordinates of
                        // the vector_field texture. With linear filtering, texel i sits
                        // at (i+0.5)/N, so the +0.5 keeps voxel and texel centres aligned
                        // (matching the +0.5 used for the image lookup below):
                        float u = (float(x) + 0.5f) / width;
                        float v = (float(y) + 0.5f) / height;
                        float w = (float(z) + 0.5f) / depth;

                        // Linearly interpolated displacement at (u,v,w);
                        // the 4th channel is zero padding:
                        float4 vector = tex3D<float4>(vector_field, u, v, w);

                        // Source position in (unnormalised) image texel coordinates,
                        // offset by 0.5 to address voxel centres. The vector components
                        // are in numpy order (z,y,x), hence the flipped axes:
                        float sx = 0.5f + float(x) - vector.z;
                        float sy = 0.5f + float(y) - vector.y;
                        float sz = 0.5f + float(z) - vector.x;

                        // Linearly interpolated source value:
                        float value = tex3D<float>(input_image, sx, sy, sz);

                        // 64-bit flat index so volumes above 2^32 voxels do not overflow:
                        warped_image[((size_t)z * height + y) * width + x] = value;

                        //TODO: supersampling would help in regions for which warping misses voxels in the source image,
                        //better: adaptive supersampling would automatically use the vector field divergence to determine where
                        //to super sample and by how much.
                    }
                }
                }
                '''

    if image.ndim != 3 or vector_field.ndim != 4:
        raise ValueError("image or vector field has wrong number of dimensions!")
    if vector_field.shape[-1] != 3:
        raise ValueError(f"vector field must have 3 components along its last axis, "
                         f"got {vector_field.shape[-1]}!")

    # The kernel fetches float / float4 values with linear filtering, so both
    # textures are built from float32 data (no copy when already float32):
    image_f32 = image.astype(cupy.float32, copy=False)
    vector_field = vector_field.astype(cupy.float32, copy=False)

    # set up textures:
    input_image_tex, input_image_cudarr = create_cuda_texture(image_f32,
                                                              num_channels=1,
                                                              normalised_coords=False,
                                                              sampling_mode='linear',
                                                              address_mode=mode)

    # Pad the 3 displacement components with a zero 4th channel to form float4 texels:
    vector_field = cupy.pad(vector_field, pad_width=((0, 0),) * 3 + ((0, 1),), mode='constant')

    vector_field_tex, vector_field_cudarr = create_cuda_texture(vector_field,
                                                                num_channels=4,
                                                                normalised_coords=True,
                                                                sampling_mode='linear',
                                                                address_mode='clamp')

    # The kernel writes 32-bit floats (`float*`), so the output buffer must be float32
    # regardless of the input dtype:
    warped_image = xp.empty(shape=image.shape, dtype=xp.float32)

    # get the kernel, which copies from texture memory
    warp_3d_kernel = cupy.RawKernel(source, 'warp_3d')

    # launch kernel: one thread per voxel, grid rounded up to cover the whole volume
    depth, height, width = image.shape

    grid_x = (width + block_size - 1) // block_size
    grid_y = (height + block_size - 1) // block_size
    grid_z = (depth + block_size - 1) // block_size
    warp_3d_kernel((grid_x, grid_y, grid_z),
                   (block_size,) * 3,
                   (warped_image, input_image_tex, vector_field_tex, width, height, depth))

    del input_image_tex, input_image_cudarr, vector_field_tex, vector_field_cudarr

    # Restore the caller's dtype (no copy when the image already was float32):
    return warped_image.astype(image.dtype, copy=False)
