from locale import normalize
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrow
import ipywidgets as ipw

from .show_plots import show_edge_scatter, colorize_raster
from ..splineutils import edge_colored_by_features, spline_curvature

def animate_edge_vect(data, param, res, fig, ax, k, curvature=False):
    """Update an existing contour plot (as made by show_plots.show_edge_scatter()) to frame ``k``.

    The axis is expected to already contain an image (its first image is
    replaced by frame ``k``) and the line/arrow artists from a previous
    ``show_edge_scatter`` call, which are removed before redrawing.

    Parameters
    ----------
    data : data object
        Must provide ``load_frame_morpho(k)``.
    param : param object
        Must provide ``n_curve``.
    res : result object
        Must provide ``spline``, ``param0``, ``param`` and ``displacement``.
    fig : matplotlib figure
    ax : matplotlib axis
    k : int
        Frame index; the plot shows the transition from frame k-1 to k.
    curvature : bool, optional
        Colour by spline curvature of frame k instead of by displacement,
        by default False.
    """
    ax.set_title(f'Frame {k-1} to {k}')
    image = data.load_frame_morpho(k)

    # Drop the arrows drawn for the previously displayed frame.
    for artist in ax.get_children():
        if isinstance(artist, FancyArrow):
            artist.remove()

    if curvature:
        f = spline_curvature(res.spline[k], np.linspace(0, 1, param.n_curve + 1))
    else:
        f = res.displacement[:, k]

    ax.get_images()[0].set_data(image)

    # Remove the first two line artists (presumably the two contours drawn by
    # the previous show_edge_scatter call). Use Artist.remove(): mutating
    # ax.lines directly (ax.lines.pop) is deprecated since Matplotlib 3.5 and
    # fails on newer versions. Copy first so removal does not affect iteration.
    for line in list(ax.lines)[:2]:
        line.remove()

    # show_edge_scatter draws into the given (fig, ax); its returned pair is
    # not needed here.
    show_edge_scatter(
        param.n_curve,
        res.spline[k - 1],
        res.spline[k],
        res.param0[k],
        res.param[k],
        f,
        fig_ax=(fig, ax),
    )

def interact_edge_vect(data, param, res, fig, ax, curvature=False):
    """Build an interactive frame slider that drives ``animate_edge_vect``.

    Parameters
    ----------
    data : data object
        Must provide ``K`` (number of frames); the slider spans 1 .. K-2.
    param : param object
    res : result object
    fig : matplotlib figure
    ax : matplotlib axis
    curvature : bool, optional
        Forwarded to ``animate_edge_vect``; colour by curvature instead of
        displacement, by default False.

    Returns
    -------
    int_box
        ipywidget interactive object holding the frame slider
    """
    frame_slider = ipw.IntSlider(1, min=1, max=data.K-2)

    def _show_frame(k):
        # Redraw the plot for the frame currently selected on the slider.
        animate_edge_vect(data, param, res, fig, ax, k, curvature=curvature)

    return ipw.interactive(_show_frame, k=frame_slider)

def animate_edge_raster_coloured_by_feature(
    data, param, res, k, N, feature, fig, ax, width=1, min_val=None,
    max_val=None, cmap_name='seismic', alpha=0.5):
    """Refresh the rasterized, feature-coloured contour overlay for frame ``k``.

    The axis must already hold at least two images: image 0 is replaced by
    the raw frame ``k`` and image 1 by the coloured contour raster.

    Parameters
    ----------
    data : data object
        Must provide ``load_frame_morpho(k)``.
    param : param object
        Not used by this function.
    res : result object
    k : int
        frame
    N : int
        number of points on contour
    feature : str
        feature for coloring 'displacement', 'displacement_cumul', 'curvature'
    fig : matplotlib figure
        Not used by this function.
    ax : matplotlib axis
    width : int, optional
        width of contour for display, by default 1
    min_val : float, optional
        lower bound of the colour scale; None lets colorize_raster decide
    max_val : float, optional
        upper bound of the colour scale; None lets colorize_raster decide
    cmap_name : str
        Matplotlib colormap
    alpha : float, optional
        transparency of the overlay, by default 0.5
    """
    raster, contour_mask = edge_colored_by_features(
        data, res, t=k, feature=feature, N=N, enlarge_width=width)
    overlay = colorize_raster(
        raster,
        cmap_name=cmap_name,
        min_val=min_val,
        max_val=max_val,
        mask=contour_mask,
        alpha=alpha,
    )
    images = ax.get_images()
    images[1].set_data(overlay)
    images[0].set_data(data.load_frame_morpho(k))


def interact_edge_raster_coloured_by_feature(
    data, param, res, N, feature, fig, ax, width=1,
    normalize=False, cmap_name='seismic', alpha=0.5):
    """Create interactive slider for colored rasterized contour image.

    Parameters
    ----------
    data : data object
        Must provide ``K`` (number of frames); the slider spans 1 .. K-2.
    param : param object
    res : result object
    N : int
        number of points on contour
    feature : str
        feature for coloring 'displacement', 'displacement_cumul', 'curvature'
    fig : matplotlib figure
    ax : matplotlib axis
    width : int, optional
        width of contour for display, by default 1
    normalize : bool, optional
        Use one colour range for the whole time-lapse, by default False.
        Only applied for 'displacement' and 'displacement_cumul'; for other
        features the range is left to colorize_raster (per frame).
    cmap_name : str
        Matplotlib colormap
    alpha : float, optional
        transparency of image, by default 0.5

    Returns
    -------
    int_box
        ipywidget time slider
    """
    min_val = None
    max_val = None
    if normalize:
        if feature == 'displacement':
            values = res.displacement
        elif feature == 'displacement_cumul':
            # Columns index frames (cf. res.displacement[:, k]), so cumulate
            # along axis 1. Computed once and reused for both bounds.
            values = np.cumsum(res.displacement, axis=1)
        else:
            values = None
        if values is not None:
            min_val = values.min()
            max_val = values.max()

    def _show_frame(k):
        # Redraw the coloured contour for the frame selected on the slider.
        animate_edge_raster_coloured_by_feature(
            data, param, res, k, N, feature, fig, ax,
            width=width, min_val=min_val, max_val=max_val,
            cmap_name=cmap_name, alpha=alpha)

    int_box = ipw.interactive(
        _show_frame, k=ipw.IntSlider(1, min=1, max=data.K-2))
    return int_box


