# AUTOGENERATED! DO NOT EDIT! File to edit: ../01_local_interpret.ipynb.

# %% auto 0
# Names exported by `from <this module> import *`. The file header marks this
# module as generated from ../01_local_interpret.ipynb, so changes to this list
# belong in the notebook.
__all__ = ['get_generic_series', 'plot_series', 'plot_frame', 'gif_series', 'eval_generic_series', 'plot_eval_series',
           'rotationTransform', 'get_rotation_series', 'eval_rotation_series', 'cropTransform', 'get_crop_series',
           'eval_crop_series', 'brightnessTransform', 'get_brightness_series', 'eval_bright_series',
           'contrastTransform', 'get_contrast_series', 'eval_contrast_series', 'zoomTransform', 'get_zoom_series',
           'eval_zoom_series', 'dihedralTransform', 'get_dihedral_series', 'eval_dihedral_series', 'resizeTransform',
           'get_resize_series', 'eval_resize_series', 'get_confusion', 'plot_confusion', 'plot_confusion_series']

# %% ../01_local_interpret.ipynb 4
from PIL import Image, ImageEnhance, ImageOps
from functools import partial
import itertools
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np
import altair as alt
from io import BytesIO as Buffer

# %% ../01_local_interpret.ipynb 7
def dice_by_component(predictedMask, trueMask, component = 1):
    """
    Dice coefficient of one label between a predicted and a true mask.

    Both masks are converted with np.array and compared with `component`,
    giving boolean masks A (prediction) and B (truth). The score is
    2 * |A & B| / (|A| + |B|): 0 means no overlap, 1 means identical masks.
    If the label occurs in neither mask the score is defined as 1.0.
    eval_generic_series calls this once per component index.
    """
    is_pred = np.array(predictedMask) == component
    is_true = np.array(trueMask) == component
    denominator = np.sum(is_pred) + np.sum(is_true)
    if denominator <= 0:
        # label absent from both masks: treat as perfect agreement
        return 1.0
    overlap = np.sum(is_pred & is_true)
    return 2 * overlap.astype(float) / denominator

# %% ../01_local_interpret.ipynb 8
def recall_by_component(predictedMask, trueMask, component = 1):
    """
    Recall of one label: the fraction of true-mask pixels equal to
    `component` that the prediction also marks as `component`,
    i.e. |pred & true| / |true|.

    Returns 1.0 when the label occurs in neither mask. When the label is
    absent from the true mask but present in the prediction, recall is
    undefined; 0.0 is returned (mirroring precision_by_component's handling
    of an empty prediction) instead of dividing by zero.
    """
    recall = 1.0
    pred = np.array(predictedMask) == component
    msk = np.array(trueMask) == component
    intersect = pred & msk
    total = np.sum(pred) + np.sum(msk)
    if total > 0:
        recall = 0.0
        # guard: an empty true mask would otherwise give 0/0 -> NaN + warning
        if np.sum(msk) > 0:
            recall = np.sum(intersect).astype(float) / np.sum(msk)
    return recall

# %% ../01_local_interpret.ipynb 9
def precision_by_component(predictedMask, trueMask, component = 1):
    """
    Precision of one label: the fraction of pixels predicted as `component`
    that are also `component` in the true mask, i.e. |pred & true| / |pred|.

    Returns 1.0 if the label occurs in neither mask, and 0 if it occurs only
    in the true mask (nothing was predicted for it).
    """
    in_pred = np.array(predictedMask) == component
    in_true = np.array(trueMask) == component
    n_pred = np.sum(in_pred)
    if n_pred + np.sum(in_true) <= 0:
        return 1.0
    if n_pred <= 0:
        return 0
    return np.sum(in_pred & in_true).astype(float) / n_pred

# %% ../01_local_interpret.ipynb 16
from matplotlib import cm
from matplotlib.colors import ListedColormap

def _transparent_lowest(rgba_table):
    """
    Set the alpha of the first colour of an RGBA lookup table to 0 and wrap
    the table in a ListedColormap, so that when it is used as an overlay the
    lowest mapped value lets the image underneath show through.
    """
    rgba_table[0, -1] = 0
    return ListedColormap(rgba_table)


# Prediction overlay: viridis with a transparent lowest colour.
default_cmap = _transparent_lowest(cm.viridis(np.arange(cm.viridis.N)))
# True-mask overlay: reversed plasma with a transparent lowest colour.
default_cmap_true_mask = _transparent_lowest(cm.plasma_r(np.arange(cm.plasma.N)))

# %% ../01_local_interpret.ipynb 18
def get_generic_series(image,
        model,
        transform,
        truth=None,
        tfm_y=False,
        start=0,
        end=180,
        step=30,
        log_steps=False,
    ):
    """
    Apply `transform` to `image` for a range of parameter values and collect
    the model's prediction for each transformed image.

    Input:
    image: the sample image (a PIL image)
    model: object with a `predict(img)` method returning the predicted mask;
        if it also has `prepareSize(img)`, that is applied to the transformed
        image (and to the true mask) before prediction
    transform: function `transform(img, param) -> img`
    truth: optional true mask; None (default) means no true mask is used
    tfm_y: if True, `transform` is applied to the true mask as well (needed
        for geometric transforms such as rotation, where the mask must move
        with the image)
    start, end, step: parameter range; with log_steps=False the values are
        np.arange(start, end, step) (end excluded)
    log_steps: if True, the values are spaced geometrically from start to end
        (both included), with a ratio of roughly `step` between neighbours;
        start, end and step must then be positive and step != 1
    Output: a list of [param, img, pred, trueMask] entries, one per param,
    where img is the transformed image, pred the model's prediction for it
    and trueMask the (optionally transformed) true mask, or None.
    """
    if log_steps:
        n_steps = round((np.log2(end) - np.log2(start)) / np.log2(step) + 1)
        steps = np.exp2(np.linspace(np.log2(start), np.log2(end), n_steps))
    else:
        steps = np.arange(start, end, step)

    series = []
    for param in tqdm(steps, leave=False):
        img = transform(image, param)
        if hasattr(model, "prepareSize"):
            img = model.prepareSize(img)
        pred = model.predict(img)
        trueMask = None
        # `is not None` instead of truthiness so array-like masks also work
        if truth is not None:
            trueMask = truth
            if tfm_y:
                trueMask = transform(trueMask, param)
            if hasattr(model, "prepareSize"):
                trueMask = model.prepareSize(trueMask)
        series.append([param, img, pred, trueMask])
    return series

# %% ../01_local_interpret.ipynb 19
def plot_series(
        series,
        nrow=1,
        figsize=(16,6),
        param_name='param',
        overlay_truth=False,
        vmax=None,
        vmin=0,
        cmap=default_cmap,
        cmap_true_mask = default_cmap_true_mask,
        **kwargs
    ):
    """
    Plot each transformed image with its prediction overlaid and, optionally,
    the true mask.

    Input:
    series: list of [param, img, pred, trueMask] entries as returned by
        get_generic_series
    nrow: number of subplot rows; the number of columns is
        ceil(len(series) / nrow)
    figsize: matplotlib figure size
    param_name: label used in each subplot title ("<param_name>:<param>")
    overlay_truth: if True, also overlay the true mask (when one is present)
    vmax: upper limit of the prediction colour scale; if None it is the
        largest pixel value over all predictions (and true masks, if any),
        so that every subplot uses the same colours
    vmin: lower limit of the prediction colour scale
    cmap: colormap for the prediction overlay
    cmap_true_mask: colormap for the true-mask overlay
    **kwargs: forwarded to plt.subplots
    Output: None; the plot is drawn with matplotlib.
    """
    ncol = int(np.ceil(len(series) / nrow))
    _, axs = plt.subplots(nrow, ncol, figsize=figsize, **kwargs)
    # plt.subplots returns a bare Axes (not an array) for a 1x1 grid
    axs = axs.ravel() if isinstance(axs, np.ndarray) else [axs]
    has_truth = series[0][3] is not None
    if vmax is None:
        vmax = max(x[2].getextrema()[1] for x in series)
        if has_truth:
            vmax_truth = max(x[3].getextrema()[1] for x in series)
            vmax = max(vmax_truth, vmax)
    for (param, img, pred, truth), ax in zip(series, axs):
        ax.imshow(np.array(img))
        ax.imshow(np.array(pred), vmax=vmax, cmap=cmap, vmin=vmin, alpha=.5, interpolation="nearest")
        ax.set_title(f"{param_name}:{np.around(param,decimals=2)}")
        ax.axes.xaxis.set_visible(False)
        ax.axes.yaxis.set_visible(False)
        if overlay_truth and truth is not None:
            ax.imshow(np.array(truth), alpha = 0.2, cmap = cmap_true_mask, interpolation="nearest")

# %% ../01_local_interpret.ipynb 20
def plot_frame(param, img, pred, param_name="param",vmax=None, vmin=0, cmap=default_cmap,**kwargs):
    """
    Render a single frame and return it as a PIL image. The frame shows `img`
    with `pred` overlaid and the title "<param_name>:<param>". Used by
    gif_series.

    **kwargs are forwarded to plt.subplots. The figure is written to an
    in-memory PNG and closed before the decoded image is returned.
    """
    fig, ax = plt.subplots(**kwargs)
    ax.imshow(img)
    ax.imshow(np.array(pred), vmax=vmax, cmap=cmap, vmin=vmin, alpha=.5, interpolation="nearest")
    ax.set_title(f"{param_name}:{np.around(param,decimals=2)}")
    for axis in (ax.axes.xaxis, ax.axes.yaxis):
        axis.set_visible(False)
    # in-memory PNG round trip, thanks to https://github.com/maxhumber/gif/
    png = Buffer()
    fig.savefig(png, format="png")
    plt.close(fig)
    png.seek(0)
    return Image.open(png)
    


# %% ../01_local_interpret.ipynb 21
def gif_series(series, fname, duration=150, param_name="param", vmax=None, vmin=0, cmap=default_cmap):
    """
    Save an animated GIF containing one plot_frame per entry of `series`.

    series: list of [param, img, pred, ...] entries (e.g. from get_generic_series)
    fname: output target passed to PIL's Image.save
    duration: per-frame display time forwarded to PIL's GIF writer
    vmax: shared colour-scale maximum; if None, the largest pixel value over
        all predictions, so every frame uses the same colours
    The GIF is written with loop=0 (repeat forever) and disposal=2.
    """
    if vmax is None:
        vmax = max(entry[2].getextrema()[1] for entry in series)
    frames = []
    for entry in series:
        frames.append(plot_frame(*entry[:3], param_name=param_name, vmax=vmax, cmap=cmap, vmin=vmin))
    first_frame = frames[0]
    first_frame.save(
        fname,
        save_all=True,
        append_images=frames[1:],
        optimize=False,
        duration=duration,
        disposal=2,
        loop=0,
    )

# %% ../01_local_interpret.ipynb 22
def eval_generic_series(
        image,
        mask,
        model,
        transform_function,
        start=0,
        end=360,
        step=5,
        param_name="param",
        mask_transform_function=None,
        components=('bg', 'c1', 'c2'),
        eval_function=dice_by_component,
        mask_prepareSize=True,
        log_steps=False,
    ):
    """
    Measure how a transformation affects the model's segmentation quality.

    For each parameter value the image (and, if `mask_transform_function` is
    given, the true mask) is transformed and `model.predict` is run on the
    transformed image. Every component index i is then scored with
    `eval_function(prediction, trueMask, component=i)`.

    image, mask: sample image and its true mask
    model: object with `predict(img)`; if it has `prepareSize(img)`, that is
        applied to the transformed image and, when `mask_prepareSize` is
        True, to the mask as well
    transform_function: `transform_function(img, param) -> img`
    start, end, step: parameter values np.arange(start, end, step); with
        log_steps=True they are instead spaced geometrically from start to
        end (both included) with a ratio of roughly `step` (all positive)
    param_name: name of the parameter column in the result
    mask_transform_function: optional `f(mask, param) -> mask`; None leaves
        the mask untouched (e.g. for brightness/contrast changes)
    components: one column name per label value 0, 1, 2, ...; the i-th name
        holds the score for component=i
    eval_function: metric called per component (dice_by_component by
        default; recall_by_component and precision_by_component share its
        signature)
    mask_prepareSize: whether `model.prepareSize` is applied to the mask
    log_steps: see start/end/step
    Returns a pandas DataFrame with columns [param_name, *components] and
    one row per parameter value.
    """
    if log_steps:
        n_steps = round((np.log2(end) - np.log2(start)) / np.log2(step) + 1)
        steps = np.exp2(np.linspace(np.log2(start), np.log2(end), n_steps))
    else:
        steps = np.arange(start, end, step)

    results = []
    for param in tqdm(steps, leave=False):
        img = transform_function(image, param)
        trueMask = mask
        if mask_transform_function is not None:
            trueMask = mask_transform_function(trueMask, param)
        if hasattr(model, "prepareSize"):
            img = model.prepareSize(img)
            if mask_prepareSize:
                trueMask = model.prepareSize(trueMask)
        prediction = model.predict(img)
        scores = [eval_function(prediction, trueMask, component=i) for i, _ in enumerate(components)]
        results.append([param, *scores])

    return pd.DataFrame(results, columns=[param_name, *components])

# %% ../01_local_interpret.ipynb 23
def plot_eval_series(results, chart_type="line", value_vars=None, value_name='Dice Score'):
    """
    Plot the scores returned by eval_generic_series (or one of its wrappers)
    as an interactive Altair chart.

    results: DataFrame whose first column holds the transformation parameter
        and whose remaining columns hold one score per component
    chart_type: "line" or "point"
    value_vars: columns to plot; if None or empty, results.columns[2:] is
        used, i.e. the first component column is skipped (with
        eval_generic_series's default components that is 'bg')
    value_name: y-axis title and name of the melted value column
    Returns the Altair chart. Raises ValueError for an unknown chart_type.
    """
    id_var = results.columns[0]
    # Check the length rather than the truthiness: `not value_vars` raises
    # for a pandas Index / numpy array such as results.columns[1:].
    if value_vars is None or len(value_vars) == 0:
        value_vars = results.columns[2:]
    long_form = results.melt(id_vars=[id_var], value_vars=value_vars, value_name=value_name)
    plot = alt.Chart(long_form)
    if chart_type == "line":
        plot = plot.mark_line()
    elif chart_type == "point":
        plot = plot.mark_point(size=80)
    else:
        raise ValueError('chart_type must be one of "line" or "point"')
    plot = plot.encode(
        x=id_var,
        y=value_name,
        color=alt.Color("variable"),
        tooltip=value_name,
    ).properties(width=700, height=300).interactive()
    return plot

# %% ../01_local_interpret.ipynb 25
def rotationTransform(image, deg):
    """
    Return `image` rotated by `deg` degrees via its `rotate` method.

    `deg` is truncated with int() first (45.9 rotates by 45), which also
    turns numpy scalars produced by np.arange into a plain int.
    """
    angle = int(deg)
    return image.rotate(angle)
    

def get_rotation_series(image, model, start=0, end=361, step=60, **kwargs):
    """
    get_generic_series with rotationTransform. The image, and the true mask
    if one is passed as `truth` (tfm_y=True), is rotated by each angle from
    start to end in steps of `step`. Extra kwargs go to get_generic_series.
    """
    return get_generic_series(
        image,
        model,
        rotationTransform,
        start=start,
        end=end,
        step=step,
        tfm_y=True,
        **kwargs,
    )

# %% ../01_local_interpret.ipynb 28
def eval_rotation_series(image, mask, model, step=5, start=0, end=360,  param_name="deg", **kwargs):
    """
    Score the model on rotated copies of `image`. Image and mask are both
    rotated with rotationTransform for every angle. Returns the DataFrame
    built by eval_generic_series, whose parameter column is named "deg" by
    default. Extra kwargs go to eval_generic_series.
    """
    settings = dict(
        start=start,
        end=end,
        step=step,
        mask_transform_function=rotationTransform,
        param_name=param_name,
    )
    return eval_generic_series(image, mask, model, rotationTransform, **settings, **kwargs)

# %% ../01_local_interpret.ipynb 34
def cropTransform(image, pxls, finalSize = None):
    """
    Simulate a crop while keeping the frame size. The image is first fitted
    to `finalSize` (ImageOps.fit). Then a `pxls`-wide border is removed on
    every side and added back as an empty border of the same width.

    image: PIL image
    pxls: border width in pixels
    finalSize: (width, height), a single int for a square, or None (default)
        to keep the image's own size
    Returns an image of size finalSize.
    """
    if finalSize is None:
        finalSize = image.size
    elif isinstance(finalSize, (int, np.integer)):
        finalSize = (finalSize, finalSize)
    image = ImageOps.fit(image, finalSize)
    image = ImageOps.crop(image, pxls)
    # A negative border makes crop() enlarge the image again; the new border
    # pixels presumably come out as zeros/black — TODO confirm for all modes.
    image = ImageOps.crop(image, -pxls)
    return image



def get_crop_series(image, model, start=0, end=256, step=10, finalSize = None, **kwargs):
    """
    get_generic_series with cropTransform. For each border width from start
    to end (in steps of `step`), the image is fitted to `finalSize` and a
    border of that width is blanked out. A true mask passed as `truth` is
    transformed the same way (tfm_y=True).

    finalSize: (width, height), an int for a square, or None for image.size.
    `end` is capped at min(finalSize) // 2, the width at which the border
    would cover the whole frame. Extra kwargs go to get_generic_series.
    """
    if finalSize is None:
        finalSize = image.size
    if isinstance(finalSize, (int, np.integer)):
        finalSize = (finalSize, finalSize)
    end = min(end, min(finalSize) // 2)

    return get_generic_series(image,model,partial(cropTransform,finalSize=finalSize), start=start, end=end, step=step, tfm_y = True, **kwargs)

# %% ../01_local_interpret.ipynb 37
def eval_crop_series(image, mask, model, step=10, start=0, end=256, finalSize=None, param_name="pixels", **kwargs):
    """
    Score the model on border-cropped copies of `image`. cropTransform is
    applied to both the image and the mask, and the DataFrame built by
    eval_generic_series is returned.

    finalSize: None fits the image and the mask to their own sizes
        (image.size and mask.size). An int n or a (width, height) tuple fits
        both to that same size, so image and mask stay aligned.
    `end` is capped at min(finalSize) // 2. Extra kwargs go to
    eval_generic_series.
    """
    if finalSize is None:
        finalSize = image.size
        finalmaskSize = mask.size
    else:
        if isinstance(finalSize, (int, np.integer)):
            finalSize = (finalSize, finalSize)
        # an int is treated exactly like the equivalent (n, n) tuple
        finalmaskSize = finalSize
    end = min(end, min(finalSize) // 2)
    return eval_generic_series(
        image,
        mask,
        model,
        partial(cropTransform,finalSize=finalSize),
        start=start,
        end=end,
        step=step,
        mask_transform_function=partial(cropTransform,finalSize=finalmaskSize),
        param_name=param_name,
        **kwargs
    )

# %% ../01_local_interpret.ipynb 43
def brightnessTransform(image, light):
    """
    Return `image` with its brightness scaled by the factor `light`, using
    PIL's ImageEnhance.Brightness (1.0 leaves the image unchanged; smaller
    factors darken it, larger ones brighten it).
    """
    return ImageEnhance.Brightness(image).enhance(light)

def get_brightness_series(image, model, start=0.25, end=8, step=np.sqrt(2), log_steps = True, **kwargs):
    """
    get_generic_series with brightnessTransform. By default the factors are
    log-spaced from 0.25 to 8 (both included), each about sqrt(2) times the
    previous one. Extra kwargs (e.g. `truth`) go to get_generic_series.
    """
    transform = brightnessTransform
    return get_generic_series(
        image, model, transform,
        start=start, end=end, step=step,
        log_steps=log_steps,
        **kwargs,
    )

# %% ../01_local_interpret.ipynb 46
def eval_bright_series(image, mask, model, start=0.05, end=.95, step=0.05, param_name="brightness", **kwargs):
    """
    Score the model on brightness-adjusted copies of `image`
    (brightnessTransform); the mask itself is not transformed. The default
    factors np.arange(0.05, 0.95, 0.05) are all below 1, i.e. the series
    only darkens the image. Extra kwargs go to eval_generic_series, whose
    DataFrame is returned.
    """
    settings = {"start": start, "end": end, "step": step, "param_name": param_name}
    return eval_generic_series(image, mask, model, brightnessTransform, **settings, **kwargs)

# %% ../01_local_interpret.ipynb 51
def contrastTransform(image, scale):
    """
    Return `image` with its contrast scaled by the factor `scale`, using
    PIL's ImageEnhance.Contrast (1.0 leaves the image unchanged).
    """
    return ImageEnhance.Contrast(image).enhance(scale)
def get_contrast_series(image, model, start=0.25, end=8, step=np.sqrt(2),log_steps = True, **kwargs):
    """
    get_generic_series with contrastTransform. By default the factors are
    log-spaced from 0.25 to 8 (both included), each about sqrt(2) times the
    previous one. Extra kwargs (e.g. `truth`) go to get_generic_series.
    """
    transform = contrastTransform
    return get_generic_series(
        image, model, transform,
        start=start, end=end, step=step,
        log_steps=log_steps,
        **kwargs,
    )

# %% ../01_local_interpret.ipynb 54
def eval_contrast_series(image, mask, model, start=0.25, end=8, step=np.sqrt(2), param_name="contrast", **kwargs):
    """
    Score the model on contrast-adjusted copies of `image`
    (contrastTransform); the mask itself is not transformed. Extra kwargs go
    to eval_generic_series, whose DataFrame is returned.

    NOTE(review): eval_generic_series builds its values with
    np.arange(start, end, step) by default, so `step` is added, not
    multiplied. The defaults give 0.25, 1.66, 3.08, ... rather than the
    geometric sqrt(2) spacing used by get_contrast_series.
    """
    settings = {"start": start, "end": end, "step": step, "param_name": param_name}
    return eval_generic_series(image, mask, model, contrastTransform, **settings, **kwargs)

# %% ../01_local_interpret.ipynb 59
def zoomTransform(image, zoom, finalSize= None):
    """
    Zoom into the centre of `image`, then pad the result to `finalSize`.

    zoom: 0 keeps the whole image and 1 keeps only the central few pixels.
        round(zoom * (w//2 - 1)) pixels are removed on the left and on the
        right, and round(zoom * (h//2 - 1)) on the top and on the bottom.
    finalSize: (width, height), an int for a square, or None (default) to
        pad back to the image's original size
    The cropped image is resized and padded into finalSize with
    ImageOps.pad, which keeps the aspect ratio.
    """
    if finalSize is None:
        finalSize = image.size
    elif isinstance(finalSize, (int, np.integer)):
        finalSize = (finalSize, finalSize)
    max_border = (image.size[0] // 2 - 1, image.size[1] // 2 - 1)  # zoom=0 -> (0, 0)
    border = (round(max_border[0] * zoom), round(max_border[1] * zoom))
    # a 2-tuple border means (left/right, top/bottom) for ImageOps.crop
    image = ImageOps.crop(image, border)
    return ImageOps.pad(image, finalSize)

def get_zoom_series(image, model, start=0, end=1, step=.1, finalSize= None, **kwargs):
    """
    get_generic_series with zoomTransform. For each zoom level from start to
    end (in steps of `step`), the image is zoomed and padded to `finalSize`.
    A true mask passed as `truth` is transformed the same way (tfm_y=True).

    finalSize: (width, height), an int for a square, or None for image.size.
    `end` is capped at 1, the maximum zoom. Extra kwargs go to
    get_generic_series.
    """
    if finalSize is None:
        finalSize = image.size
    if isinstance(finalSize, (int, np.integer)):
        finalSize = (finalSize, finalSize)
    end = min(end, 1)
    return get_generic_series(image,model,partial(zoomTransform,finalSize=finalSize), start=start, end=end, step=step, tfm_y = True, **kwargs)

# %% ../01_local_interpret.ipynb 62
def eval_zoom_series(image, mask, model, step=0.1, start=0, end=1, finalSize=None, param_name="scale", **kwargs):
    """
    Score the model on zoomed copies of `image`. zoomTransform is applied to
    both the image and the mask, and the DataFrame built by
    eval_generic_series is returned.

    finalSize: None pads the image and the mask back to their own sizes
        (image.size and mask.size). An int n or a (width, height) tuple pads
        both to that same size, so image and mask stay aligned.
    `end` is capped at 1, the maximum zoom. Extra kwargs go to
    eval_generic_series.
    """
    if finalSize is None:
        finalSize = image.size
        finalmaskSize = mask.size
    else:
        if isinstance(finalSize, (int, np.integer)):
            finalSize = (finalSize, finalSize)
        # an int is treated exactly like the equivalent (n, n) tuple
        finalmaskSize = finalSize
    end = min(end, 1)
    return eval_generic_series(
        image,
        mask,
        model,
        partial(zoomTransform,finalSize=finalSize),
        start=start,
        end=end,
        step=step,
        mask_transform_function=partial(zoomTransform,finalSize=finalmaskSize),
        param_name=param_name,
        **kwargs
    )

# %% ../01_local_interpret.ipynb 67
def dihedralTransform(image, sym_im):
    """
    Apply the `sym_im`-th of the 8 dihedral symmetries to `image`.

    The symmetries are enumerated as itertools.product of the rotations
    (0, 90, 180, 270) and (no mirror, mirror). Even indices therefore only
    rotate, while odd indices rotate and then mirror with ImageOps.mirror.
    `sym_im` is used as a list index (0..7; negative values count from the
    end).
    """
    symmetries = list(itertools.product((0, 90, 180, 270), (False, True)))
    angle, mirrored = symmetries[sym_im]
    image = image.rotate(angle)
    if mirrored:
        image = ImageOps.mirror(image)
    return image


def get_dihedral_series(image, model, start=0, end=8, step=1, **kwargs):
    """
    get_generic_series with dihedralTransform. By default it steps through
    all 8 dihedral symmetries (indices 0..7). A true mask passed as `truth`
    is transformed the same way (tfm_y=True). Extra kwargs go to
    get_generic_series.
    """
    return get_generic_series(
        image,
        model,
        dihedralTransform,
        start=start,
        end=end,
        step=step,
        tfm_y=True,
        **kwargs,
    )

# %% ../01_local_interpret.ipynb 70
def eval_dihedral_series(image, mask, model, start=0, end=8, step=1, param_name="k", **kwargs):
    """
    Score the model under the dihedral symmetries (dihedralTransform applied
    to both image and mask). By default it covers indices 0..7, and the
    parameter column is named "k". Extra kwargs go to eval_generic_series,
    whose DataFrame is returned.
    """
    settings = dict(
        start=start,
        end=end,
        step=step,
        param_name=param_name,
        mask_transform_function=dihedralTransform,
    )
    return eval_generic_series(image, mask, model, dihedralTransform, **settings, **kwargs)

# %% ../01_local_interpret.ipynb 75
def resizeTransform(image, size):
    """
    Resize `image` to the largest size that fits inside a size x size box
    while keeping its aspect ratio (ImageOps.contain). The result is not
    padded, so it is square only if the input was.
    """
    bounding_box = (size, size)
    return ImageOps.contain(image, bounding_box)


# %% ../01_local_interpret.ipynb 76
def get_resize_series(image, model, start=10, end=200, step=30, **kwargs):
    """
    get_generic_series with resizeTransform, one entry per bounding-box size
    from start to end in steps of `step`. tfm_y=False: a true mask passed as
    `truth` is not resized by the transform. Extra kwargs go to
    get_generic_series.
    """
    return get_generic_series(
        image,
        model,
        resizeTransform,
        start=start,
        end=end,
        step=step,
        tfm_y=False,
        **kwargs,
    )

# %% ../01_local_interpret.ipynb 79
def eval_resize_series(image, mask, model, start=22, end=3000, step=100, param_name="px", **kwargs):
    """
    Score the model on resized copies of `image`. resizeTransform is applied
    to both image and mask, and the parameter column is named "px" by
    default. Extra kwargs go to eval_generic_series, whose DataFrame is
    returned.
    """
    settings = dict(
        start=start,
        end=end,
        step=step,
        param_name=param_name,
        mask_transform_function=resizeTransform,
    )
    return eval_generic_series(image, mask, model, resizeTransform, **settings, **kwargs)

# %% ../01_local_interpret.ipynb 98
def get_confusion(prediction, truth, max_class=None):
    """
    Pixel-wise confusion matrix between a predicted and a true label mask.

    prediction, truth: array-likes (e.g. PIL images) of equal shape holding
        non-negative integer class labels
    max_class: largest class label; the matrix is (max_class+1) x
        (max_class+1). If None or 0, the largest label found in either input
        is used.
    Returns an int array `confusion` where confusion[p, t] counts the pixels
    predicted as class p whose true class is t (rows: prediction, columns:
    truth).
    """
    pred = np.asarray(prediction)
    true = np.asarray(truth)
    if not max_class:
        # initial=0 keeps empty inputs from raising
        max_class = max(pred.max(initial=0), true.max(initial=0))
    # int() avoids fixed-width overflow, e.g. uint8 255 + 1 == 0 under NumPy 2
    n_classes = int(max_class) + 1
    # https://stackoverflow.com/a/50023660
    confusion = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(confusion, (pred, true), 1)
    return confusion

# %% ../01_local_interpret.ipynb 103
def plot_confusion(confusion_matrix, norm_axis=0, components=None, ax=None, ax_label=True, cmap="Blues"):
    """
    Draw a normalised confusion matrix as a heatmap and print each cell's
    value in it.

    confusion_matrix: square count matrix as returned by get_confusion
        (rows: prediction, columns: truth)
    norm_axis: axis along which the counts are normalised to sum to 1. The
        default 0 divides each column by its total, i.e. by the number of
        pixels of that true class. An all-zero column/row becomes NaN.
    components: tick labels; if not given, "c0", "c1", ...
    ax: matplotlib Axes to draw into; a new figure is created if not given
    ax_label: if True, label the x axis "truth" and the y axis "prediction"
    cmap: matplotlib colormap (name) for the heatmap
    Returns the Axes.
    """
    normed = confusion_matrix / confusion_matrix.sum(axis=norm_axis, keepdims=True)
    labels = components if components else [f"c{k}" for k in range(normed.shape[0])]
    if not ax:
        _, ax = plt.subplots()
    ax.imshow(normed, cmap=cmap)

    # one tick per component, labelled with its name
    ticks = np.arange(len(labels))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    if ax_label:
        ax.set_xlabel("truth")
        ax.set_ylabel("prediction")

    # annotate every cell with its rounded value
    for row, col in itertools.product(range(normed.shape[0]), range(normed.shape[1])):
        ax.text(col, row, round(normed[row, col], 2), ha="center", va="center")

    return ax

# %% ../01_local_interpret.ipynb 106
def plot_confusion_series(
        series,
        nrow=1,
        figsize=(16,6),
        param_name='param',
        cmap="Blues",
        components=None,
        norm_axis=0,
        **kwargs
    ):
    """
    Plot one normalised confusion matrix (prediction vs. true mask) per entry
    of `series`, e.g. the output of get_generic_series called with a true mask.

    series: list of [param, img, pred, trueMask] entries; trueMask must be set
    nrow: number of subplot rows; the number of columns is
        ceil(len(series) / nrow)
    figsize: matplotlib figure size
    param_name: label used in each subplot title ("<param_name>:<param>"),
        as in plot_series
    cmap, components, norm_axis: forwarded to plot_confusion
    **kwargs: forwarded to plt.subplots
    """
    ncol = int(np.ceil(len(series) / nrow))
    _, axs = plt.subplots(nrow, ncol, figsize=figsize, **kwargs)
    # plt.subplots returns a bare Axes (not an array) for a 1x1 grid
    axs = axs.ravel() if isinstance(axs, np.ndarray) else [axs]
    for (param, _img, pred, truth), ax in zip(series, axs):
        confusion = get_confusion(pred, truth)
        plot_confusion(confusion, ax=ax, components=components, ax_label=False, norm_axis=norm_axis, cmap=cmap)
        ax.set_title(f"{param_name}:{np.around(param,decimals=2)}")
