# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/04_util.ipynb.

# %% auto 0
__all__ = ['DataSplits']

# %% ../nbs/04_util.ipynb 6
from functools import wraps, partial

# %% ../nbs/04_util.ipynb 8
from types import FunctionType, ModuleType
from typing import Optional, Iterable

# %% ../nbs/04_util.ipynb 11
#| export

# %% ../nbs/04_util.ipynb 13
try: import pandas as pd
except ImportError: ...

# %% ../nbs/04_util.ipynb 15
try: import torch
except ImportError: ...

# %% ../nbs/04_util.ipynb 17
#| export


# %% ../nbs/04_util.ipynb 19
from chck import istuple, isdict, isnone
from pcts import (permute_indicies, normalize_percents, diff_percents, sum_percents)

# %% ../nbs/04_util.ipynb 21
from .cons import (FIT, VAL, TEST, PRED)

# %% ../nbs/04_util.ipynb 24
class DataSplits:
    """Sizes of the fit / val / test / pred splits of a dataset.

    The four raw sizes given to the constructor are passed through
    `normalize_percents`. The normalized values are stored as `fit`, `val`,
    `test` and `pred`. The results of `sum_percents` (`acc`) and
    `diff_percents` (`slc`) are stored alongside them.
    """
    # Split names from `.cons`, in the same order as `astuple()` / iteration.
    # `fromdict` keeps only keys found here and forwards them as keyword
    # arguments to `__init__`, so these presumably equal 'fit', 'val', 'test',
    # 'pred' -- NOTE(review): confirm against `.cons`.
    _names = (FIT, VAL, TEST, PRED)

    def __init__(self, fit: float | None = 1., val: float | None = 0., test: float | None = 0., pred: float | None = 0.):
        """Normalize the four split sizes and precompute their summaries.

        Parameters
        ----------
        fit, val, test, pred : float | None
            Raw split sizes, handed to `normalize_percents` unchanged. How
            `None` is interpreted is decided by that helper (presumably
            "fill in the remainder" -- TODO confirm in `pcts`).
        """
        pcts = normalize_percents(fit, val, test, pred)
        sums = sum_percents(*pcts, total=1)
        # `diff_percents` reuses the sums computed above instead of recomputing them.
        diff = diff_percents(*pcts, sums=sums, total=1)
        self.fit, self.val, self.test, self.pred = pcts
        self.acc = sums  # output of sum_percents (presumably cumulative boundaries -- verify)
        self.slc = diff  # output of diff_percents (presumably per-split spans -- verify)

    def astuple(self) -> tuple[float, float, float, float]:
        """Return the normalized sizes as `(fit, val, test, pred)`."""
        return (self.fit, self.val, self.test, self.pred)

    def asdict(self) -> dict[str, float]:
        """Return the normalized sizes keyed by the names in `_names`."""
        return dict(zip(self._names, self.astuple()))

    @classmethod
    def fromdict(cls, dct: dict, **kwargs) -> 'DataSplits':
        """Build an instance from a mapping of split name -> size.

        Entries in `kwargs` override those in `dct`. Keys not listed in
        `_names` are silently dropped before the constructor is called.
        `dct` itself is never modified.
        """
        merged = {**dct, **kwargs}
        # Build a new filtered dict instead of popping from the dict being
        # iterated: removing keys during `.items()` iteration raises
        # "RuntimeError: dictionary changed size during iteration".
        kws = {k: v for k, v in merged.items() if k in cls._names}
        return cls(**kws)

    @classmethod
    def make(cls, ins: tuple[float, ...] | dict[str, float] | Optional['DataSplits'], **kwargs) -> Optional['DataSplits']:
        """Coerce `ins` into an instance of `cls`.

        - ``None`` or an unsupported type -> ``None``
        - tuple -> ``cls(*ins)`` (``kwargs`` are ignored)
        - dict -> ``cls.fromdict(ins, **kwargs)``
        - instance of ``cls`` -> returned unchanged
        - any other ``DataSplits`` (e.g. a base-class instance when ``cls``
          is a subclass) -> rebuilt as ``cls`` from its four values
        """
        if isnone(ins): return None
        if not isinstance(ins, (tuple, dict, DataSplits)): return None
        if istuple(ins): return cls(*ins)
        if isdict(ins): return cls.fromdict(ins, **kwargs)
        if isinstance(ins, cls): return ins
        # Previously `cls.fromdict(**kwargs)`, which always raised TypeError
        # (missing positional `dct`). Iterating yields the four sizes.
        return cls(*ins)

    def __iter__(self) -> Iterable[float]:
        """Iterate over the normalized sizes in `(fit, val, test, pred)` order."""
        return iter((self.fit, self.val, self.test, self.pred))

    def __len__(self) -> int: return 4

    def __repr__(self):
        return f'datasplits{self.fit, self.val, self.test, self.pred}'

    def indicies(self, total: int = 1):
        """Return `permute_indicies(total, fit, val, test, pred)`.

        NOTE(review): the meaning of the result (e.g. per-split index groups
        for a dataset of `total` items) is defined by `pcts.permute_indicies`
        -- confirm there.
        """
        return permute_indicies(total, *self)

# %% ../nbs/04_util.ipynb 27
#| export
