# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/04_data.transforms.ipynb.

# %% ../../nbs/04_data.transforms.ipynb 3
from __future__ import annotations
from fastai.vision.all import *
from fastai.data.all import *
from fastcore.basics import patch
from pathlib import Path
import pandas as pd
from scipy.signal import savgol_filter
from tqdm import tqdm

# %% auto 0
# Public names exported by `from <this module> import *` (auto-generated together with this file)
__all__ = ['Spectra', 'SpectraTfm', 'ToAbsorbanceTfm', 'SpectraBlock', 'RandWAvgTfm', 'AvgTfm', 'SNVTfm', 'NormalizeTfm',
           'TrimTfm', 'DerivTfm', 'Analytes', 'AnalytesTfm', 'AnalytesBlock', 'LogTfm']

# %% ../../nbs/04_data.transforms.ipynb 7
class Spectra(Tensor):
    "A 'showable' spectra class subclassing torch.Tensor"
    # Metadata shared by ALL instances (class attributes, not per-tensor state).
    # `SpectraTfm` fills them from the column names / first column of the first CSV it reads:
    #   domain_name -> x-axis column name, light -> y-axis column name, domain -> x-axis values.
    domain = domain_name = light = None

    def show(self, ctx=None, wn=np.arange(4000, 600, -2), figsize=(8,2), **kwargs):
        """Plot every row of this tensor against `Spectra.domain` and return the axes.

        ctx: matplotlib axes to draw on; when None a new figure of size `figsize` is created.
        wn, kwargs: accepted for backward compatibility (e.g. fastai's `show_batch` passes
            extra keyword arguments) but not used.
        Requires `Spectra.domain`, `Spectra.domain_name` and `Spectra.light` to be set
        (e.g. by `SpectraTfm`); otherwise labelling fails.
        """
        if ctx is None: _, ctx = plt.subplots(figsize=figsize)
        domain_name, light = Spectra.domain_name, Spectra.light

        if domain_name == 'wavenumber':
            # Draw the wavenumber axis from high to low values (reversed x-limits)
            ctx.set_xlim(np.max(Spectra.domain), np.min(Spectra.domain))

        ctx.set_xlabel(domain_name.capitalize() + ' →', loc='right')
        ctx.set_ylabel(light.capitalize() + ' →', loc='top')

        ctx.set_axisbelow(True)
        # matplotlib needs host memory and no autograd graph; a single 1-D spectrum is
        # promoted to one row so it is plotted as a line rather than element by element.
        data = self.detach().cpu()
        rows = data if data.ndim > 1 else data.unsqueeze(0)
        for spectrum in rows:
            ctx.plot(Spectra.domain, spectrum, c='steelblue', lw=1)
        ctx.grid(True, which='both')
        return ctx

    @classmethod
    def reset(cls):
        "Clear the shared domain/label metadata so the next `SpectraTfm` read repopulates it"
        cls.domain = cls.domain_name = cls.light = None

# %% ../../nbs/04_data.transforms.ipynb 8
class SpectraTfm(Transform):
    "Read a list of replicate CSV files into a `Spectra` tensor of shape (n_replicates, n_points)"
    def __init__(self):
        # Clear metadata left by any earlier instance so the next file read sets it afresh
        Spectra.reset()

    def encodes(self, 
                o:L # list of spectrum replicates (CSV files with a domain column and a light column)
               ):
        if len(o) == 0:
            raise IndexError("SpectraTfm received an empty list of replicate files")
        # Read each replicate exactly once
        frames = [pd.read_csv(fname) for fname in o]
        if Spectra.light is None:
            # First file ever seen: its two column names define the domain and light labels
            Spectra.domain_name, Spectra.light = frames[0].columns
            Spectra.domain = frames[0][Spectra.domain_name].values
        # One row per replicate; replicates of unequal length raise ValueError
        x = np.stack([df[Spectra.light].to_numpy(dtype=float) for df in frames])
        return Spectra(torch.Tensor(x))

# %% ../../nbs/04_data.transforms.ipynb 10
class ToAbsorbanceTfm(Transform):
    "Transform spectra by taking the element-wise base-10 logarithm of their values"
    def encodes(self, x:Spectra):
        # NOTE(review): absorbance is conventionally log10(1/R) = -log10(R); this returns
        # log10(x) without the sign flip -- confirm which convention is intended.
        # Zero values give -inf and negative values give nan.
        return torch.log10(x)

# %% ../../nbs/04_data.transforms.ipynb 11
def SpectraBlock():
    "A `TransformBlock` whose type transform loads replicate spectra with `SpectraTfm`"
    loader = SpectraTfm()
    return TransformBlock(type_tfms=loader)

# %% ../../nbs/04_data.transforms.ipynb 12
class RandWAvgTfm(Transform):
    "Data augmentation: collapse spectra replicates into one random convex combination of them"
    def encodes(self, x:Spectra):
        # (1, n_replicates) weight row @ (n_replicates, n_points) -> (1, n_points)
        return self._weights(len(x)) @ x

    def _weights(self, n):
        "Return `n` random non-negative weights that sum to 1, as a (1, n) row"
        raw = torch.rand(n)
        return (raw / raw.sum())[None, :]

# %% ../../nbs/04_data.transforms.ipynb 13
class AvgTfm(Transform):
    "Collapse spectra replicates (rows) into their mean spectrum, keeping a leading axis of size 1"
    def encodes(self, x:Spectra):
        return x.mean(dim=0, keepdim=True)

# %% ../../nbs/04_data.transforms.ipynb 14
class SNVTfm(Transform):
    "Standard Normal Variate (SNV) transform: standardize each spectrum by its own mean and std"
    def encodes(self, x:Spectra):
        # SNV is defined per spectrum, so reduce over the last (points) axis only.
        # For a single spectrum, shape (1, n) or (n,), this equals the whole-tensor mean/std;
        # with several replicate rows, each row is standardized independently.
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)  # unbiased (n-1) estimator, torch's default
        # A constant spectrum has std == 0 and yields inf/nan.
        return (x - mean) / std

# %% ../../nbs/04_data.transforms.ipynb 15
class NormalizeTfm(Transform):
    "Standardize spectra with fixed, precomputed statistics: (x - mean) / std"
    def __init__(self, mean, std):
        # presumably dataset-level statistics (scalars or tensors broadcastable to x) -- confirm with caller
        self.mean, self.std = mean, std

    def encodes(self, x:Spectra):
        centered = x - self.mean
        return centered / self.std

# %% ../../nbs/04_data.transforms.ipynb 16
class TrimTfm(Transform):
    "Flatten the tail of each spectrum: every point from index `idx` onward takes the value at `idx`"
    def __init__(self, idx:int=1675):
        super().__init__()
        self.idx = idx  # position along the last axis where the flat tail starts

    def encodes(self, 
                x:Spectra
               ):
        n = x.shape[-1]
        if not -n <= self.idx < n:
            raise IndexError(f"TrimTfm: index {self.idx} is out of range for spectra with {n} points")
        start = self.idx % n  # normalize negative indices
        out = x.clone()  # work on a copy so the caller's tensor is not modified in place
        # Slice with a trailing singleton axis so each row's value broadcasts along that row
        # (indexing `x[:, i]` would broadcast across rows when there are several replicates).
        out[..., start:] = x[..., start:start + 1]
        return out

# %% ../../nbs/04_data.transforms.ipynb 17
class DerivTfm(Transform):
    "Savitzky-Golay smoothing / derivative of each spectrum along its last axis"
    def __init__(self, window_length=11, polyorder=1, deriv=1):
        super().__init__()
        self.window_length = window_length  # filter window size (points)
        self.polyorder = polyorder          # order of the local polynomial fit
        self.deriv = deriv                  # derivative order (0 = smoothing only)

    def encodes(self, 
                x:Spectra
               ):
        # savgol_filter works on NumPy arrays: detach from autograd and move to host memory
        # (`.cpu()` is a no-op for CPU tensors and avoids a failure for GPU tensors).
        arr = x.detach().cpu().numpy()
        # scipy's default delta=1.0, so derivatives are per sample point along axis -1
        filtered = savgol_filter(arr, self.window_length, self.polyorder, self.deriv, axis=-1)
        # torch.Tensor(...) yields float32, as before; return on the input's device
        return Spectra(torch.Tensor(filtered)).to(x.device)

# %% ../../nbs/04_data.transforms.ipynb 20
class Analytes(Tensor, ShowTitle):
    "Tensor of analyte measurement values; the constructor argument is also kept on `ys`"
    def __init__(self, ys):
        # TODO: ideally show the decoded values, e.g. (idx: 1234 | 725: 0.123 | 433: 0.7)
        self.ys = ys

# %% ../../nbs/04_data.transforms.ipynb 21
class AnalytesTfm(Transform):
    "Read one sample's analyte CSV file into an `Analytes` tensor built from its `value` column"
    def __init__(self, 
                 analytes:list|None=None):
        # Analyte ids to keep; None (or any falsy value such as []) keeps every row
        self.analytes = analytes

    def encodes(self,
                o:Path # Path to a CSV file with a `value` column (plus an `analyte` column when filtering)
               ):
        table = pd.read_csv(o)
        if not self.analytes:
            return Analytes(table['value'].values)
        # NOTE(review): selected rows keep the CSV's row order, not the order of `self.analytes`;
        # confirm every CSV lists analytes in the same order.
        wanted = table['analyte'].isin(self.analytes)
        return Analytes(table.loc[wanted, 'value'].values)

# %% ../../nbs/04_data.transforms.ipynb 23
def AnalytesBlock(analytes=None):
    """A `TransformBlock` that reads analyte measurements with `AnalytesTfm`.

    analytes: analyte ids to keep; None (the default, matching `AnalytesTfm`) keeps all rows.
    """
    return TransformBlock(type_tfms=AnalytesTfm(analytes))

# %% ../../nbs/04_data.transforms.ipynb 24
class LogTfm(Transform):
    "Take the base-10 logarithm of analyte targets"
    def encodes(self, ys:Analytes):
        # NOTE(review): no `decodes` is defined, so predictions are presumably not mapped
        # back with 10**y on decoding -- confirm this is intended. Zero targets give -inf.
        return ys.log10()
