# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/05_augment_PIL-img_filters.ipynb (unless otherwise specified).

# Public names exported by `from <this module> import *`
__all__ = ['is_3dlut_row', 'read_lut', 'ApplyPILFilter']

# Cell
try:
    from fastai.vision.all import *
except:
    from fastai2.vision.all import *
from PIL import ImageFilter
from typing import List, Tuple, Callable, Union, Optional, Any

# Cell
def is_3dlut_row(row:list) -> bool:
    '''Check if one line in the file has exactly 3 values.

    `row` is a sequence of tokens (e.g. a line split on spaces). Tokens that `float`
    cannot convert are ignored; the row qualifies when exactly 3 tokens are numeric.
    '''
    n_numeric = 0
    for val in row:
        # Only the errors `float()` raises for unconvertible input; a bare `except`
        # would also swallow KeyboardInterrupt/SystemExit.
        try: float(val)
        except (ValueError, TypeError, OverflowError): continue
        n_numeric += 1
    return n_numeric == 3

def read_lut(path_lut:Union[str,Path], num_channels:int=3):
    '''Read LUT from raw text file and return it as a `PIL.ImageFilter.Color3DLUT`.

    A line counts as a LUT entry only when it consists entirely of exactly
    `num_channels` whitespace-separated numbers. Every other line is skipped:
    titles, size/domain keyword lines, comments, blank lines.
    The cube size is inferred from the number of entries, not from the raw line
    count, so header lines do not skew it.

    Raises `ValueError` if no entries are found or their count is not a perfect cube.
    '''
    def _parse_row(line):
        'Return `line` as a tuple of floats if it is a pure `num_channels`-value row, else None'
        tokens = line.split()  # any whitespace run: tolerates tabs, repeated and trailing spaces
        if len(tokens) != num_channels: return None
        try: return tuple(float(tok) for tok in tokens)
        except ValueError: return None  # keyword/header line such as `DOMAIN_MIN 0 0 0`

    with open(path_lut) as f: rows = [_parse_row(line) for line in f]
    lut_table = [row for row in rows if row is not None]
    if not lut_table:
        raise ValueError(f'No {num_channels}-value LUT rows found in {path_lut}')

    size = round(len(lut_table) ** (1/3))
    if size ** 3 != len(lut_table):
        raise ValueError(f'{path_lut} has {len(lut_table)} LUT rows, which is not a perfect cube')
    return ImageFilter.Color3DLUT(size, lut_table, num_channels)

# Cell
class ApplyPILFilter(RandTransform):
    '''Apply a `PIL.ImageFilter` and return the result as a `PILImage`.

    `filters` is either a single PIL filter, or a `tuple`/`list`/`L` of filters. In the
    collection case, one filter is drawn uniformly at random (via `np.random.randint`)
    each time the transform is applied. `p` is forwarded to `RandTransform`.
    '''
    order = 0 # Apply before `ToTensor`
    def __init__(self, filters, p=1.):
        super().__init__(p=p)
        self.filter = filters # a single filter, or a collection to sample one from

    def select_filter(self, o):
        'Apply `self.filter` to `o`; if it is a collection, apply one randomly chosen member'
        if isinstance(self.filter, (tuple,list,L)):
            rand_idx = np.random.randint(0, len(self.filter))
            return o.filter(self.filter[rand_idx])
        return o.filter(self.filter)

    def _encodes(self, o): return PILImage(self.select_filter(o))

    # fastai's `Transform` metaclass collects every `encodes` definition into a
    # type-dispatch table, so these two defs coexist rather than overwrite each other.
    # PIL images are filtered directly; tensors and file paths go through `PILImage.create` first.
    def encodes(self, o:PILImage):               return self._encodes(o)
    def encodes(self, o:(TensorImage,str,Path)): return self._encodes(PILImage.create(o))