from functools import reduce
from ..interface import element_wise_shape,numpy
from ...yu import array_index_traversal, multi_range
from enum import Enum
class ConvolutionMode(Enum):
    """Output-size policy for N-dimensional convolution.

    ``valid`` keeps only the positions where the kernel lies entirely inside the data;
    ``full`` zero-pads the data so every partial overlap also produces an output.
    """
    valid = 0
    full = 1


# String spelling -> enum member. Built from the enum itself so the accepted
# strings can never drift out of sync with the members.
convolution_map = {member.name: member for member in ConvolutionMode}


def __get_convolution_mode_string(mode):
    """Normalise a convolution mode to its string name.

    :param mode: a key of ``convolution_map`` (e.g. ``'valid'``) or a ``ConvolutionMode`` member.
    :return: the mode name as a ``str``.
    :raises ValueError: if ``mode`` is an unknown string or is neither a string nor a
        ``ConvolutionMode``.
    """
    if isinstance(mode, ConvolutionMode):
        return mode.name
    if not isinstance(mode, str):
        raise ValueError('Invalid mode type: {}'.format(type(mode)))
    if mode not in convolution_map:
        raise ValueError('No such convolution mode: {}'.format(mode))
    return mode


def basic_convolution_shape(shape_data, shape_kernel, dimension: int, mode: str):
    """Return the output shape of a convolution over the trailing ``dimension`` axes.

    Per axis the size is ``data - kernel + 1`` for ``'valid'`` and ``data + kernel - 1``
    for ``'full'``; leading axes of the input shapes are ignored.

    :param shape_data: shape of the data array.
    :param shape_kernel: shape of the kernel array.
    :param dimension: number of trailing axes that are convolved.
    :param mode: the mode *string* (``'valid'`` or ``'full'``).
    :return: tuple of ``dimension`` output sizes, in axis order.
    :raises ValueError: for any other ``mode`` value.
    """
    if mode == 'valid':
        direction = -1
    elif mode == 'full':
        direction = 1
    else:
        raise ValueError('Invalid convolution mode: {}'.format(mode))
    # valid: d - (k - 1); full: d + (k - 1)
    return tuple(
        shape_data[axis] + direction * (shape_kernel[axis] - 1)
        for axis in range(-dimension, 0)
    )

def __compute_valid_convolution_nd(data, kernel, dimension: int):
    """'valid' N-dimensional convolution of ``data`` with ``kernel`` (im2col + matmul).

    Every kernel-sized window of the trailing ``dimension`` axes of ``data`` is copied into
    one row of a matrix, which is then multiplied by the flipped, flattened kernel. Leading
    axes of ``data`` are carried through unchanged.

    :param data: array of shape ``prefix + spatial``; ``spatial`` has ``dimension`` axes.
    :param kernel: array whose axes are all spatial (its flattened size must equal the
        product of its last ``dimension`` axes).
    :param dimension: number of trailing axes to convolve over.
    :return: float array of shape ``prefix + (data_i - kernel_i + 1 for each spatial axis)``.
    """
    kernel_shape = kernel.shape[-dimension:]
    # Output shape in natural axis order, matching the row-major flattening below
    # (previously built in reversed order, which broke non-square outputs).
    convolution_shape = tuple(data.shape[i] - kernel.shape[i] + 1 for i in range(-dimension, 0))
    list_dimension = reduce(lambda a, b: a * b, convolution_shape)
    data_prefix = data.shape[:-dimension]
    kernel_flat = kernel.ravel()
    prefix_slices = (slice(None),) * len(data_prefix)
    data_flat = numpy.zeros(data_prefix + (list_dimension, len(kernel_flat)))
    for i in range(list_dimension):
        # Row-major multi-index of output position i == top corner of its window.
        start = numpy.unravel_index(i, convolution_shape)
        window = prefix_slices + tuple(slice(s, s + k) for s, k in zip(start, kernel_shape))
        data_flat[prefix_slices + (i, slice(None))] = data[window].reshape(data_prefix + (len(kernel_flat),))
    # Flipping the flat kernel flips it along every axis, giving true convolution
    # rather than correlation.
    convolution_flat = numpy.matmul(data_flat, numpy.flip(kernel_flat, axis=0))
    return convolution_flat.reshape(data_prefix + convolution_shape)

def __compute_convolution_nd(data, kernel, dimension: int, mode: str):
    """Convolve ``data`` with a single ``kernel`` over the trailing ``dimension`` axes.

    Leading (batch) axes of ``data`` are preserved. ``'full'`` mode is implemented by
    zero-padding each spatial axis with ``kernel_size - 1`` on both sides and then running
    a 'valid' convolution.

    :param data: array of shape ``prefix + spatial``.
    :param kernel: kernel array whose trailing ``dimension`` axes are spatial.
    :param dimension: number of trailing axes to convolve over.
    :param mode: mode string or ``ConvolutionMode`` member.
    :return: array of shape ``prefix + output_spatial``.
    :raises ValueError: for an invalid mode, or in 'valid' mode when a spatial axis of
        ``data`` is shorter than the kernel.
    """
    mode_string = __get_convolution_mode_string(mode)
    # Compare/pad only the trailing spatial axes; positive indices would hit batch axes.
    data_spatial = data.shape[-dimension:]
    kernel_spatial = kernel.shape[-dimension:]
    if mode_string == 'valid':
        if any(d < k for d, k in zip(data_spatial, kernel_spatial)):
            raise ValueError('Data shape smaller than kernel shape: {} {}'.format(data.shape, kernel.shape))
        return __compute_valid_convolution_nd(data, kernel, dimension)
    elif mode_string == 'full':
        # Full convolution is defined for any sizes: the padded data is always at
        # least as large as the kernel, so no size check is needed here.
        data_prefix = data.shape[:-dimension]
        padding = tuple(k - 1 for k in kernel_spatial)
        expand_data = numpy.zeros(data_prefix + tuple(d + 2 * p for d, p in zip(data_spatial, padding)))
        interior = (slice(None),) * len(data_prefix) + tuple(
            slice(p, p + d) for p, d in zip(padding, data_spatial)
        )
        expand_data[interior] = data
        return __compute_valid_convolution_nd(expand_data, kernel, dimension)
    else:
        raise ValueError('Never reached.')


def compute_convolution_nd(data, kernel, dimension: int, mode=ConvolutionMode.valid, element_wise: bool=False):
    """N-dimensional convolution over the trailing ``dimension`` axes of ``data`` and ``kernel``.

    :param data: array whose last ``dimension`` axes are convolved; leading axes are batch axes.
    :param kernel: array whose last ``dimension`` axes form a kernel; leading axes index
        separate kernels.
    :param dimension: number of trailing axes to convolve over.
    :param mode: ``'valid'``/``'full'`` or a ``ConvolutionMode`` member.
    :param element_wise: if true, the leading axes of ``data`` and ``kernel`` are combined
        via ``element_wise_shape`` (presumably broadcasting — confirm) and each data slice is
        convolved only with the kernel at the same leading index. Otherwise every kernel is
        applied to the whole of ``data``.
    :return: element-wise: ``combined_prefix + output_spatial``; otherwise
        ``data_prefix + kernel_prefix + output_spatial``.
    :raises ValueError: for an invalid ``mode`` or a data/kernel size mismatch.
    """
    mode_string = __get_convolution_mode_string(mode)
    data_prefix_shape = data.shape[:-dimension]
    kernel_prefix_shape = kernel.shape[:-dimension]
    if element_wise:
        final_shape = element_wise_shape(data_prefix_shape, kernel_prefix_shape)[0]
        # Keep exactly the `dimension` spatial axes (was hard-coded to the last two).
        data = numpy.broadcast_to(data, final_shape + data.shape[-dimension:])
        kernel = numpy.broadcast_to(kernel, final_shape + kernel.shape[-dimension:])
        if not final_shape:
            return __compute_convolution_nd(data, kernel, dimension, mode_string)
        result = [
            __compute_convolution_nd(data[index], kernel[index], dimension, mode_string)
            for index in array_index_traversal(final_shape)
        ]
        return numpy.array(result).reshape(final_shape + result[0].shape)
    if not kernel_prefix_shape:
        return __compute_convolution_nd(data, kernel, dimension, mode_string)
    # Each kernel in the kernel prefix is applied to all of `data`; its output goes into
    # the slot [data_prefix..., kernel_index, spatial...].
    output_spatial = basic_convolution_shape(data.shape[-dimension:], kernel.shape[-dimension:], dimension, mode_string)
    result = numpy.zeros(data_prefix_shape + kernel_prefix_shape + output_spatial)
    data_slices = (slice(None),) * len(data_prefix_shape)
    spatial_slices = (slice(None),) * dimension
    for kernel_index in array_index_traversal(kernel_prefix_shape):
        result[data_slices + kernel_index + spatial_slices] = __compute_convolution_nd(data, kernel[kernel_index], dimension, mode_string)
    return result

def __compute_max_pooling_nd(data, size, step, dimension: int, reference=None):
    """Max-pool ``data`` with a sliding window over its first ``dimension`` axes.

    The window spans ``size[axis]`` elements and advances ``step[axis]`` elements per axis;
    only windows fully inside ``data`` are used. The public wrapper strips leading axes, so
    ``data`` is expected to consist of exactly the pooled axes.

    :param reference: optional array with the same shape as ``data``. When given, each output
        is the ``data`` value at the position where ``reference`` is maximal within the window
        (rather than the maximum of ``data`` itself).
    :return: array shaped ``[number of window positions per axis]``. This assumes
        ``multi_range`` yields window corners in row-major order — TODO confirm.
    :raises ValueError: if any pooled axis of ``data`` is shorter than the window.
    """
    for i in range(dimension):
        if data.shape[i] < size[i]:
            raise ValueError('Data shape smaller than size: {} {}'.format(data.shape, size))
    pooling_array = []
    pooling_grid = [range(0, data.shape[i] - size[i] + 1, step[i]) for i in range(dimension)]
    for index in multi_range(pooling_grid):
        # Must be a tuple: NumPy rejects a *list* of slices as a multidimensional index.
        window = tuple(slice(index[i], index[i] + size[i]) for i in range(dimension))
        sub_data = data[window]
        if reference is None:
            pooling_array.append(numpy.max(sub_data))
        else:
            max_index = numpy.argmax(reference[window])
            pooling_array.append(sub_data[numpy.unravel_index(max_index, sub_data.shape)])
    return numpy.array(pooling_array).reshape([len(g) for g in pooling_grid])

def compute_max_pooling_nd(data, size, step, dimension: int, reference=None):
    """Max-pool the trailing ``dimension`` axes of ``data``, looping over any leading axes.

    :param data: array of shape ``prefix + pooled_axes``.
    :param size: window size per pooled axis.
    :param step: stride per pooled axis.
    :param dimension: number of trailing axes to pool.
    :param reference: optional array indexed like ``data``; when given, each window picks the
        ``data`` value at the argmax of ``reference`` instead of the max of ``data``.
    :return: array of shape ``prefix + pooled_output``.
    """
    data_prefix_shape = data.shape[:-dimension]
    if not data_prefix_shape:
        # Forward `reference` here too, so unbatched input honours it like the batched path.
        return __compute_max_pooling_nd(data, size, step, dimension, reference)
    result = []
    for key in array_index_traversal(data_prefix_shape):
        sub_reference = None if reference is None else reference[key]
        result.append(__compute_max_pooling_nd(data[key], size, step, dimension, sub_reference))
    return numpy.array(result).reshape(data_prefix_shape + result[0].shape)


def __compute_max_unpooling_nd(data, pooling, size, step, dimension: int):
    """Scatter pooled values back to the argmax position of each pooling window.

    The windows are the same as in max pooling (``size``/``step`` per axis, fully inside
    ``data``). Window number ``n`` takes its value from ``pooling`` at the row-major position
    ``n`` and writes it where ``data`` is maximal within that window; every other element
    stays 0.

    NOTE(review): overlapping windows overwrite (``=``) rather than accumulate, unlike
    average unpooling — confirm this is intended.

    :return: float array with the same shape as ``data``.
    :raises ValueError: if any pooled axis of ``data`` is shorter than the window.
    """
    if any(data.shape[axis] < size[axis] for axis in range(dimension)):
        raise ValueError('Data shape smaller than size: {} {}'.format(data.shape, size))
    unpooled = numpy.zeros(data.shape)
    window_starts = [range(0, data.shape[axis] - size[axis] + 1, step[axis]) for axis in range(dimension)]
    for window_number, corner in enumerate(multi_range(window_starts)):
        window = tuple(slice(corner[axis], corner[axis] + size[axis]) for axis in range(dimension))
        winner_flat = numpy.argmax(data[window])
        # Basic slicing yields a view, so writing into it updates `unpooled`.
        target_view = unpooled[window]
        winner = numpy.unravel_index(winner_flat, target_view.shape)
        target_view[winner] = pooling[numpy.unravel_index(window_number, pooling.shape)]
    return unpooled


def compute_max_unpooling_nd(data, pooling, size, step, dimension: int):
    """Max-unpool over the trailing ``dimension`` axes, looping over leading axes.

    The leading axes of ``data`` and ``pooling`` are combined via ``element_wise_shape``
    (presumably a broadcast — confirm), both arrays are broadcast to that prefix, and each
    slice is unpooled independently.

    :return: array of shape ``combined_prefix + data_spatial``.
    """
    combined_prefix = element_wise_shape(data.shape[:-dimension], pooling.shape[:-dimension])[0]
    data = numpy.broadcast_to(data, combined_prefix + data.shape[-dimension:])
    pooling = numpy.broadcast_to(pooling, combined_prefix + pooling.shape[-dimension:])
    if not combined_prefix:
        return __compute_max_unpooling_nd(data, pooling, size, step, dimension)
    pieces = [
        __compute_max_unpooling_nd(data[key], pooling[key], size, step, dimension)
        for key in array_index_traversal(combined_prefix)
    ]
    return numpy.array(pieces).reshape(combined_prefix + pieces[0].shape)


def __compute_average_pooling_nd(data, size, step, dimension: int):
    """Average-pool ``data`` with a sliding window over its first ``dimension`` axes.

    The window spans ``size[axis]`` elements and advances ``step[axis]`` elements per axis;
    only windows lying entirely inside ``data`` are used.

    :return: array of window means shaped ``[number of window positions per axis]``.
    :raises ValueError: if any pooled axis of ``data`` is shorter than the window.
    """
    if any(data.shape[axis] < size[axis] for axis in range(dimension)):
        raise ValueError('Data shape smaller than size: {} {}'.format(data.shape, size))
    window_starts = [range(0, data.shape[axis] - size[axis] + 1, step[axis]) for axis in range(dimension)]
    means = [
        numpy.mean(data[tuple(slice(corner[axis], corner[axis] + size[axis]) for axis in range(dimension))])
        for corner in multi_range(window_starts)
    ]
    return numpy.array(means).reshape([len(starts) for starts in window_starts])


def compute_average_pooling_nd(data, size, step, dimension: int):
    """Average-pool the trailing ``dimension`` axes of ``data``, looping over leading axes.

    :return: array of shape ``prefix + pooled_output``.
    """
    leading_shape = data.shape[:-dimension]
    if not leading_shape:
        return __compute_average_pooling_nd(data, size, step, dimension)
    pieces = [
        __compute_average_pooling_nd(data[key], size, step, dimension)
        for key in array_index_traversal(leading_shape)
    ]
    return numpy.array(pieces).reshape(leading_shape + pieces[0].shape)


def __compute_average_unpooling_nd(pooling, size, step, dimension: int, unpooling_size=None):
    """Spread each pooled value back over its pooling window, summing where windows overlap.

    :param pooling: pooled array whose axes are all pooled axes.
    :param unpooling_size: output shape. If omitted, it is the smallest shape that yields
        ``pooling.shape`` windows: ``size + (pooling.shape - 1) * step`` per axis.
    :return: float array of the unpooled shape. Each window adds the full pooled value
        (no division by window size) — presumably the caller scales; confirm.
    """
    if unpooling_size is not None:
        target_shape = unpooling_size
    else:
        target_shape = [size[axis] + (pooling.shape[axis] - 1) * step[axis] for axis in range(dimension)]
    unpooled = numpy.zeros(target_shape)
    window_starts = [range(0, unpooled.shape[axis] - size[axis] + 1, step[axis]) for axis in range(dimension)]
    for window_number, corner in enumerate(multi_range(window_starts)):
        window = tuple(slice(corner[axis], corner[axis] + size[axis]) for axis in range(dimension))
        unpooled[window] += pooling[numpy.unravel_index(window_number, pooling.shape)]
    return unpooled


def compute_average_unpooling_nd(pooling, size, step, dimension: int, unpooling_size=None):
    """Average-unpool the trailing ``dimension`` axes of ``pooling``, looping over leading axes.

    :return: array of shape ``prefix + unpooled_spatial``.
    """
    leading_shape = pooling.shape[:-dimension]
    if not leading_shape:
        return __compute_average_unpooling_nd(pooling, size, step, dimension, unpooling_size)
    pieces = [
        __compute_average_unpooling_nd(pooling[key], size, step, dimension, unpooling_size)
        for key in array_index_traversal(leading_shape)
    ]
    return numpy.array(pieces).reshape(leading_shape + pieces[0].shape)
