# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/06_utils.ipynb (unless otherwise specified).

__all__ = ['ensemble_results', 'plot_results', 'iou', 'mean_entropy', 'label_mask', 'get_candidates', 'iou_mapping',
           'staple', 'mvoting']

# Cell
import numpy as np
from scipy import ndimage
from scipy.spatial.distance import jaccard
from scipy.stats import entropy
from skimage.feature import peak_local_max
from skimage.segmentation import clear_border
from skimage.measure import label
from skimage.segmentation import relabel_sequential
from skimage.morphology import watershed
from scipy.optimize import linear_sum_assignment
import SimpleITK as sitk

# Cell
def ensemble_results(res_dict, file, idx=0):
    """Average the per-model predictions for one file and take the argmax.

    `res_dict` is keyed by `(model, file)` pairs; for every key whose second
    element equals `file`, entry `idx` of the stored value is collected.
    The collected arrays are averaged element-wise and the index of the
    largest value along the last axis is returned.
    """
    member_preds = []
    for model, fname in res_dict:
        if fname != file:
            continue
        member_preds.append(np.array(res_dict[(model, fname)][idx]))

    averaged = np.mean(member_preds, axis=0)
    return np.argmax(averaged, axis=-1)

# Cell
def plot_results(*args, df):
    """Show image, target mask, prediction and uncertainty map in one row.

    `args` must unpack to exactly four array-likes, in the order
    `(img, msk, pred, pred_std)`. `df` supplies the `file`, `iou` and
    `entropy` attributes that are used in the panel titles.

    NOTE(review): this relies on `plt` (presumably `matplotlib.pyplot`) being
    available at module level; no such import is visible in this file - confirm.
    """
    img, msk, pred, pred_std = args
    _fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(20, 20))

    # One (image, title) pair per subplot, left to right.
    panels = (
        (img, f'File {df.file}'),
        (msk, 'Target'),
        (pred, f'Prediction \n IoU: {df.iou:.2f}'),
        (pred_std, f'Uncertainty \n Entropy: {df.entropy:.2f}'),
    )
    for ax, (picture, title) in zip(axs, panels):
        ax.imshow(picture)
        ax.set_axis_off()
        ax.set_title(title)

    plt.show()

# Cell
def iou(a, b, threshold=0.5):
    """Compute the Intersection-over-Union of two thresholded masks.

    Both inputs are converted to arrays and binarised with `> threshold`.

    Returns the number of pixels set in both masks divided by the number of
    pixels set in at least one of them. If neither mask has any pixel above
    the threshold the union is empty; the two masks then agree everywhere, so
    1.0 is returned instead of dividing by zero.
    """
    a = np.asarray(a) > threshold
    b = np.asarray(b) > threshold

    union = np.count_nonzero(np.logical_or(a, b))
    if union == 0:
        # Two empty masks are identical; treat this as perfect overlap
        # rather than raising ZeroDivisionError.
        return 1.0

    overlap = np.count_nonzero(np.logical_and(a, b))
    return overlap / union

# Cell
def mean_entropy(a):
    """Return the mean of `scipy.stats.entropy(a)`.

    `entropy` is called with its default arguments, and the resulting
    value(s) are averaged into a single number.
    """
    entropies = entropy(a)
    return entropies.mean()

# Cell
def label_mask(mask, threshold=0.5, min_pixel=15, do_watershed=True, exclude_border=False):
    """Threshold a mask and return an integer image of labelled regions.

    Steps, in order:
      1. A 3D input has its third axis squeezed away (it must have length 1).
      2. The mask is binarised with `> threshold` and its connected
         components are labelled.
      3. If `do_watershed` is set, regions are split by a watershed on the
         negated distance transform, seeded at local maxima of the distance
         to the background.
      4. If `exclude_border` is set, regions touching the image border are
         removed.
      5. Regions with fewer than `min_pixel` pixels are set to 0.
      6. The remaining labels are renumbered sequentially starting at 1.

    NOTE(review): `peak_local_max(..., indices=False)` is only accepted by
    older scikit-image releases - confirm the pinned version before upgrading.
    """
    # Drop the trailing singleton axis so everything below sees a 2D array.
    if mask.ndim == 3:
        mask = np.squeeze(mask, axis=2)

    binary = (mask > threshold).astype(int)

    # connectivity=2 also joins diagonal neighbours (8-connectivity in 2D).
    regions = label(binary, connectivity=2)

    if do_watershed:
        edt = ndimage.distance_transform_edt(binary)
        # Radius of a disc whose area is `min_pixel`; used as the minimum
        # spacing between two watershed seeds.
        peak_spacing = int(np.ceil(np.sqrt(min_pixel / np.pi)))
        peaks = peak_local_max(
            edt,
            indices=False,
            exclude_border=False,
            min_distance=peak_spacing,
            labels=regions,
        )
        seeds = label(peaks)
        regions = watershed(-edt, seeds, mask=binary)

    if exclude_border:
        regions = clear_border(regions)

    # Zero out every label whose pixel count is below the size limit.
    ids, sizes = np.unique(regions, return_counts=True)
    too_small = ids[sizes < min_pixel]
    regions[np.isin(regions, too_small)] = 0

    # Close the gaps left by the removed labels.
    regions = relabel_sequential(regions, offset=1)[0]
    return regions

# Cell
def get_candidates(labels_a, labels_b):
    """Return the unique label pairs that co-occur at the same position.

    The two label images are stacked depth-wise, flattened to one row per
    position, and de-duplicated. Rows whose product is not positive (i.e.
    rows containing a 0 label) are dropped.
    """
    stacked = np.dstack([labels_a, labels_b])
    per_pixel = stacked.reshape(-1, stacked.shape[2])
    pairs = np.unique(per_pixel, axis=0)

    # Keep only pairs with a positive product.
    keep = np.prod(pairs, axis=1) > 0
    return pairs[keep]

# Cell
def iou_mapping(labels_a, labels_b, min_roi_size=30):
    """Match the ROIs of two label images one-to-one by Jaccard similarity.

    Returns a 5-tuple
    `(similarities, row_ind, col_ind, max_label_a, max_label_b)`:
    the similarity of each assigned pair, the assigned label indices in
    `labels_a` and `labels_b`, and the largest label of each input.
    If no pair of non-zero labels overlaps, the first three entries are
    `[]`, `np.nan` and `np.nan`.

    `min_roi_size` is accepted but not used by this function.
    """
    pairs = get_candidates(labels_a, labels_b)

    # No overlapping ROI pairs: nothing to assign.
    if pairs.size == 0:
        return ([], np.nan, np.nan, np.max(labels_a), np.max(labels_b))

    # Matrix indexed directly by label value, so row/column 0 is the
    # background and stays at zero.
    n_rows = np.max(pairs[:, 0]) + 1
    n_cols = np.max(pairs[:, 1]) + 1
    scores = np.zeros((n_rows, n_cols))

    for id_a, id_b in pairs:
        in_a = (labels_a == id_a).astype(np.uint8).flatten()
        in_b = (labels_b == id_b).astype(np.uint8).flatten()
        # jaccard() is a distance; convert it to a similarity.
        scores[id_a, id_b] = 1 - jaccard(in_a, in_b)

    # The solver minimises cost, so negate to maximise total similarity.
    rows, cols = linear_sum_assignment(-scores)

    return (scores[rows, cols], rows, cols, np.max(labels_a), np.max(labels_b))

# Cell
def staple(segmentations, foregroundValue = 1, threshold = 0.5):
    """STAPLE: Simultaneous Truth and Performance Level Estimation with SimpleITK.

    Each element of `segmentations` is converted to a SimpleITK image and
    the set is fused with the STAPLE filter.

    Parameters:
        segmentations: iterable of arrays accepted by `sitk.GetImageFromArray`.
        foregroundValue: pixel value that the filter treats as foreground.
        threshold: the STAPLE probability image is binarised with `> threshold`.

    Returns a NumPy array holding the thresholded consensus segmentation.
    The array owns its data and is independent of any SimpleITK image.
    """
    images = [sitk.GetImageFromArray(x) for x in segmentations]

    # Use the filter object so that `foregroundValue` is actually applied;
    # it was previously accepted but never forwarded to STAPLE.
    staple_filter = sitk.STAPLEImageFilter()
    staple_filter.SetForegroundValue(float(foregroundValue))
    probabilities = staple_filter.Execute(images)

    consensus = probabilities > threshold

    # Return a copy, not a view: an array view is only valid while the
    # SimpleITK image is alive, and `consensus` is local to this call.
    return sitk.GetArrayFromImage(consensus)

# Cell
def mvoting(segmentations, labelForUndecidedPixels = 0):
    """Majority voting via SimpleITK `LabelVoting`.

    Each element of `segmentations` is converted to a SimpleITK image and
    the set is fused with `sitk.LabelVoting`.

    Parameters:
        segmentations: iterable of arrays accepted by `sitk.GetImageFromArray`.
        labelForUndecidedPixels: passed through to `sitk.LabelVoting` as the
            label for pixels without a decision.

    Returns a NumPy array holding the voted segmentation. The array owns its
    data and is independent of any SimpleITK image.
    """
    images = [sitk.GetImageFromArray(x) for x in segmentations]
    voted = sitk.LabelVoting(images, labelForUndecidedPixels)

    # Return a copy, not a view: an array view is only valid while the
    # SimpleITK image is alive, and `voted` is local to this call.
    return sitk.GetArrayFromImage(voted)