# AUTOGENERATED! DO NOT EDIT! File to edit: ../04_mr_artifacts.ipynb.

# %% auto 0
# Public API, grouped per artifact: the image transform followed by its
# prediction-series and evaluation-series wrappers.
__all__ = [
    'spikeTransform', 'get_spike_series', 'eval_spike_series',
    'spikePosTransform', 'get_spike_pos_series', 'eval_spike_pos_series',
    'biasfieldTransform', 'get_biasfield_series', 'eval_biasfield_series',
]

# %% ../04_mr_artifacts.ipynb 5
from torchio.transforms import Spike
from misas.core import gif_series

# %% ../04_mr_artifacts.ipynb 7
from PIL import Image, ImageEnhance, ImageShow, ImageOps
import numpy as np
from misas.core import get_generic_series, plot_series, eval_generic_series
from misas.fastai_model import Fastai2_model
from functools import partial
import torch

# %% ../04_mr_artifacts.ipynb 9
def spikeTransform(image, intensityFactor, spikePosition=(.1, .1)):
    """Return a copy of `image` corrupted by a torchio spike artifact.

    Only the first channel of `image` is used. The steps are:

    1. Scale the channel to [0, 1] and give it a leading singleton axis,
       so the shape is (1, H, W).
    2. Pass it through torchio's `Spike.add_artifact`.
    3. Clamp the result back to [0, 1].
    4. Replicate it into three identical uint8 channels.

    Args:
        image: PIL image (or array-like) that `np.array` converts to shape
            (H, W, C).
        intensityFactor: spike intensity forwarded to torchio's `Spike`.
        spikePosition: spike coordinates for the two image axes. Any
            two-element sequence (list or tuple) is accepted. Its scale
            follows torchio's `Spike` convention (presumably fractions of the
            image size; confirm against the installed torchio version).

    Returns:
        PIL.Image.Image: a 3-channel uint8 image.
    """
    # Keep only the first channel, add a leading singleton axis and rescale
    # from the uint8 range to [0, 1].
    volume = np.expand_dims(np.array(image)[:, :, 0], 0) / 255
    tensor = torch.Tensor(volume)

    # Prepend 0.0 for the singleton leading axis. Unpacking builds a new list,
    # so tuples are accepted and the caller's sequence is never modified.
    positions = [[0.0, *spikePosition]]
    spike = Spike(positions, intensityFactor)
    corrupted = spike.add_artifact(tensor, positions, intensityFactor)[0]

    # Back to an 8-bit, 3-channel, channels-last image.
    rgb = torch.clamp(torch.stack((corrupted, corrupted, corrupted)), 0, 1)
    pixels = (np.array(rgb) * 255).astype(np.uint8)
    return Image.fromarray(np.moveaxis(pixels, 0, 2))

def get_spike_series(image, model, start=0, end=2.5, step=.5, spikePosition=(.1, .1), **kwargs):
    """Build a series of spike-corrupted images and the model's outputs.

    Thin wrapper around `misas.core.get_generic_series` that uses
    `spikeTransform` as the transform, with the spike position held fixed.
    The swept value (`start`/`end`/`step`) is presumably passed as
    `spikeTransform`'s second positional argument, `intensityFactor`;
    confirm against `get_generic_series`.

    Args:
        image: input image forwarded to `get_generic_series`.
        model: model forwarded to `get_generic_series`.
        start: sweep start, forwarded to `get_generic_series`.
        end: sweep end, forwarded to `get_generic_series`.
        step: sweep step, forwarded to `get_generic_series`.
        spikePosition: fixed spike position (any two-element sequence).
        **kwargs: forwarded unchanged to `get_generic_series`.

    Returns:
        Whatever `get_generic_series` returns.
    """
    # Convert to a fresh list. The tuple default is then accepted by
    # spikeTransform even if it only handles lists, and no mutable default
    # is shared between calls.
    transform = partial(spikeTransform, spikePosition=list(spikePosition))
    return get_generic_series(image, model, transform, start=start, end=end, step=step, **kwargs)

# %% ../04_mr_artifacts.ipynb 22
def eval_spike_series(image, mask, model, step=.1, start=0, end=2.5, spikePosition=(.1, .1), param_name="intensity", **kwargs):
    """Evaluate `model` against `mask` over a sweep of spike intensities.

    Thin wrapper around `misas.core.eval_generic_series` that uses
    `spikeTransform` as the image transform, with the spike position held
    fixed. `mask_transform_function=None` is passed, presumably meaning the
    mask is left untouched because the artifact only alters intensities.
    Confirm against `eval_generic_series`.

    Args:
        image: input image forwarded to `eval_generic_series`.
        mask: ground-truth mask forwarded to `eval_generic_series`.
        model: model forwarded to `eval_generic_series`.
        step: sweep step, forwarded to `eval_generic_series`.
        start: sweep start, forwarded to `eval_generic_series`.
        end: sweep end, forwarded to `eval_generic_series`.
        spikePosition: fixed spike position (any two-element sequence).
        param_name: label of the swept parameter, forwarded as-is.
        **kwargs: forwarded unchanged to `eval_generic_series`.

    Returns:
        Whatever `eval_generic_series` returns.
    """
    # Convert to a fresh list. The tuple default is then accepted by
    # spikeTransform even if it only handles lists, and no mutable default
    # is shared between calls.
    transform = partial(spikeTransform, spikePosition=list(spikePosition))
    return eval_generic_series(
        image,
        mask,
        model,
        transform,
        start=start,
        end=end,
        step=step,
        mask_transform_function=None,
        param_name=param_name,
        **kwargs
    )

# %% ../04_mr_artifacts.ipynb 25
def spikePosTransform(image, spikePositionX, spikePositionY=0.1, intensityFactor=0.5):
    """Return `image` with a spike artifact at (spikePositionX, spikePositionY).

    Position-first variant of `spikeTransform`. The X position is the first
    parameter after `image`, so series helpers can sweep it while the
    intensity and Y position are fixed via `functools.partial`.

    Args:
        image: PIL image (or array-like) of shape (H, W, C); only channel 0
            is used.
        spikePositionX: spike coordinate for the first image axis
            (presumably rows, per torchio's axis ordering).
        spikePositionY: spike coordinate for the second image axis.
        intensityFactor: spike intensity.

    Returns:
        PIL.Image.Image: exactly what `spikeTransform` produces for the same
        image, intensity and position.
    """
    # This used to be a verbatim copy of spikeTransform's pipeline; delegate
    # so both transforms stay in sync.
    return spikeTransform(image, intensityFactor, [spikePositionX, spikePositionY])

def get_spike_pos_series(image, model, start=0.1, end=0.9, step=.1, intensityFactor=0.5, spikePositionY=0.1, **kwargs):
    """Build a series of images with a spike at varying X positions.

    Thin wrapper around `misas.core.get_generic_series` that uses
    `spikePosTransform` as the transform. The intensity and Y position are
    held fixed, so the swept value is presumably passed as `spikePositionX`.
    Confirm against `get_generic_series`.

    Args:
        image: input image forwarded to `get_generic_series`.
        model: model forwarded to `get_generic_series`.
        start: sweep start, forwarded to `get_generic_series`.
        end: sweep end, forwarded to `get_generic_series`.
        step: sweep step, forwarded to `get_generic_series`.
        intensityFactor: fixed spike intensity.
        spikePositionY: fixed spike Y position.
        **kwargs: forwarded unchanged to `get_generic_series`.

    Returns:
        Whatever `get_generic_series` returns.
    """
    fixed = dict(intensityFactor=intensityFactor, spikePositionY=spikePositionY)
    transform = partial(spikePosTransform, **fixed)
    return get_generic_series(image, model, transform, start=start, end=end, step=step, **kwargs)

# %% ../04_mr_artifacts.ipynb 27
def eval_spike_pos_series(image, mask, model, step=0.1, start=0.1, end=0.9, intensityFactor=0.1, param_name="Spike X Position", spikePositionY=0.1, **kwargs):
    """Evaluate `model` against `mask` while sweeping the spike's X position.

    Thin wrapper around `misas.core.eval_generic_series` that uses
    `spikePosTransform` as the image transform. The intensity and Y position
    are held fixed. `mask_transform_function=None` is passed, presumably
    meaning the mask is left untouched because the artifact only alters
    intensities. Confirm against `eval_generic_series`.

    Args:
        image: input image forwarded to `eval_generic_series`.
        mask: ground-truth mask forwarded to `eval_generic_series`.
        model: model forwarded to `eval_generic_series`.
        step: sweep step, forwarded to `eval_generic_series`.
        start: sweep start, forwarded to `eval_generic_series`.
        end: sweep end, forwarded to `eval_generic_series`.
        intensityFactor: fixed spike intensity. Its default (0.1) differs
            from `get_spike_pos_series` (0.5) and is kept for backward
            compatibility.
        param_name: label of the swept parameter, forwarded as-is.
        spikePositionY: fixed spike Y position. Added so evaluation can use
            the same setting as `get_spike_pos_series`; the default matches
            `spikePosTransform`'s own default.
        **kwargs: forwarded unchanged to `eval_generic_series`.

    Returns:
        Whatever `eval_generic_series` returns.
    """
    transform = partial(spikePosTransform, intensityFactor=intensityFactor, spikePositionY=spikePositionY)
    return eval_generic_series(
        image,
        mask,
        model,
        transform,
        start=start,
        end=end,
        step=step,
        mask_transform_function=None,
        param_name=param_name,
        **kwargs
    )

# %% ../04_mr_artifacts.ipynb 33
from torchio import RandomBiasField, BiasField

# %% ../04_mr_artifacts.ipynb 40
def biasfieldTransform(image, coef, order=3):
    """Return `image` multiplied by a torchio polynomial bias field.

    Only the first channel of `image` is used. The steps are:

    1. Scale the channel to [0, 1].
    2. Sample the coefficients with `RandomBiasField.get_params` from the
       range [coef, coef], so they are presumably all equal to `coef`.
    3. Build the field with `BiasField.generate_bias_field`, clamp it to
       [0, 1] and multiply it into the channel.
    4. Replicate the result into three uint8 channels.

    Args:
        image: PIL image (or array-like) of shape (H, W, C).
        coef: value defining every bias-field coefficient.
        order: polynomial order of the bias field. It is used both to size
            the coefficient list and to build the field.

    Returns:
        PIL.Image.Image: a 3-channel uint8 image.
    """
    # (H, W) first channel -> (1, 1, H, W) in [0, 1]. The two leading
    # singleton axes presumably serve as torchio's channel and depth axes.
    volume = torch.Tensor(np.array(image)[:, :, 0][np.newaxis, np.newaxis] / 255)

    # Both calls must use the same `order`: it determines how many
    # coefficients are produced and consumed. (Previously both were
    # hard-coded to 3, silently ignoring the `order` argument.)
    coefficients = RandomBiasField().get_params(order, [coef, coef])
    field = BiasField.generate_bias_field(volume, order=order, coefficients=coefficients)

    # field[0] is the (H, W) map multiplied into the image slice. Converting
    # it explicitly avoids writing a tensor into the array in place and
    # mixing tensor/ndarray arithmetic.
    bias = torch.clamp(torch.as_tensor(field[0], dtype=torch.float32), 0, 1)
    biased = volume[0, 0] * bias

    # Back to an 8-bit, 3-channel, channels-last image.
    rgb = torch.clamp(torch.stack((biased, biased, biased)), 0, 1)
    pixels = (np.array(rgb) * 255).astype(np.uint8)
    return Image.fromarray(np.moveaxis(pixels, 0, 2))

def get_biasfield_series(image, model, start=0, end=-.6, step=-.2, order=3, **kwargs):
    """Build a series of bias-field-corrupted images and the model's outputs.

    Thin wrapper around `misas.core.get_generic_series` that uses
    `biasfieldTransform` as the transform, with the polynomial `order`
    fixed. The swept value (default 0 down to -0.6 in steps of -0.2) is
    presumably passed as `biasfieldTransform`'s `coef`; confirm against
    `get_generic_series`.

    Args:
        image: input image forwarded to `get_generic_series`.
        model: model forwarded to `get_generic_series`.
        start: sweep start, forwarded to `get_generic_series`.
        end: sweep end, forwarded to `get_generic_series`.
        step: sweep step, forwarded to `get_generic_series`.
        order: polynomial order bound into `biasfieldTransform`.
        **kwargs: forwarded unchanged to `get_generic_series`.

    Returns:
        Whatever `get_generic_series` returns.
    """
    transform = partial(biasfieldTransform, order=order)
    return get_generic_series(
        image,
        model,
        transform,
        start=start,
        end=end,
        step=step,
        **kwargs,
    )

# %% ../04_mr_artifacts.ipynb 43
def eval_biasfield_series(image, mask, model, step=-.05, start=0, end=-.55, order=3, param_name="coefficient", **kwargs):
    """Evaluate `model` against `mask` over a sweep of bias-field coefficients.

    Thin wrapper around `misas.core.eval_generic_series` that uses
    `biasfieldTransform` as the image transform, with the polynomial `order`
    fixed. `mask_transform_function=None` is passed, presumably meaning the
    mask is left untouched because the artifact only alters intensities.
    Confirm against `eval_generic_series`.

    Args:
        image: input image forwarded to `eval_generic_series`.
        mask: ground-truth mask forwarded to `eval_generic_series`.
        model: model forwarded to `eval_generic_series`.
        step: sweep step, forwarded to `eval_generic_series`.
        start: sweep start, forwarded to `eval_generic_series`.
        end: sweep end, forwarded to `eval_generic_series`.
        order: polynomial order bound into `biasfieldTransform`.
        param_name: label of the swept parameter. It was previously
            hard-coded to "coefficient" (still the default) and is now
            configurable like in the other `eval_*` wrappers.
        **kwargs: forwarded unchanged to `eval_generic_series`.

    Returns:
        Whatever `eval_generic_series` returns.
    """
    return eval_generic_series(
        image,
        mask,
        model,
        partial(biasfieldTransform, order=order),
        start=start,
        end=end,
        step=step,
        mask_transform_function=None,
        param_name=param_name,
        **kwargs
    )
