# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00_core.ipynb (unless otherwise specified).

__all__ = ['map_dict_ex', 'D', 'tensor2shape', 'tensor2mu', 'TensorBatch', 'obj2tensor', 'BD']

# Cell
# Python native modules
import os,warnings
# Third party libs
from fastcore.all import *
from fastai.torch_core import *
from fastai.basics import *
import pandas as pd
import torch
import numpy as np
# Local modules

# Cell
def map_dict_ex(d,f,*args,gen=False,wise=None,**kwargs):
    """Like `map`, but for dicts and uses `bind`, and supports `str` and indexing.

    `f` is first turned into a one-argument function:
      - a callable  -> `bind(f,*args,**kwargs)`
      - a `str`     -> `f.format`
      - anything else -> `f.__getitem__` (e.g. a lookup dict or a list)

    `wise` chooses what that function is applied to:
      - `None`      -> each `(key, value)` item tuple; a lazy `map` object is returned
      - `'value'`   -> each value; lazily yields `(key, fn(value))` pairs
      - any other value -> each key; lazily yields `(fn(key), value)` pairs

    `gen` is accepted but not used by this function.
    """
    if callable(f): fn = bind(f,*args,**kwargs)
    elif isinstance(f,str): fn = f.format
    else: fn = f.__getitem__

    items = d.items()
    if wise is None: return map(fn,items)
    if wise == 'value': return ((key,fn(val)) for key,val in items)
    # Every non-None, non-'value' setting maps the keys
    return ((fn(key),val) for key,val in items)

# Cell
# %-style message template with three placeholders (offending idxs, the limit, example values).
# Not referenced by the code in this section; presumably used by callers elsewhere.
_error_msg='Found idxs: %s have values more than %s e.g.: %s'

class D(dict):
    """A `dict` subclass with key comparison and `map`-style helpers.

    `mapping` is a plain flag attribute; `map` sets it to True on the dicts it builds.
    """
    def __init__(self,*args,mapping=False,**kwargs):
        self.mapping = mapping
        super().__init__(*args,**kwargs)

    def eq_k(self,o:'D',with_diff=False):
        """Return whether `o` and `self` have exactly the same keys.

        With `with_diff=True`, return `(same, diff)` where `diff` is the set of keys
        present in only one of the two.
        """
        theirs, ours = set(o.keys()), set(self.keys())
        same = theirs == ours
        if not with_diff: return same
        return same, theirs ^ ours

    def _new(self,*args,**kwargs):
        "Build a new instance of the same class as `self`."
        return type(self)(*args,**kwargs)

    def map(self,f,*args,gen=False,**kwargs):
        """Apply `map_dict_ex` to this dict.

        With `gen=True` the lazy result of `map_dict_ex` is returned as-is; otherwise it is
        collected into a new instance of this class with `mapping=True`.
        """
        pairs = map_dict_ex(self,f,*args,**kwargs)
        if gen: return pairs
        return self._new(pairs,mapping=True)

    def mapk(self,f,*args,gen=False,wise='key',**kwargs):
        "`map` that by default transforms the keys and keeps the values."
        return self.map(f,*args,gen=gen,wise=wise,**kwargs)

    def mapv(self,f,*args,gen=False,wise='value',**kwargs):
        "`map` that by default transforms the values and keeps the keys."
        return self.map(f,*args,gen=gen,wise=wise,**kwargs)

# Cell
def tensor2shape(k,t:'TensorBatch',relative_shape=False):
    """Converts a tensor into a dict of shapes, or a 1d numpy array.

    Returns a one-entry dict keyed by `k`:
      - if `t` is 2-d with a single column (shape `(bs, 1)`), the value is that column
        flattened to a 1-d numpy array;
      - otherwise the value is a list of `t.shape[0]` identical strings describing the shape:
        `str(t.shape)`, or `str((1, *t.shape[1:]))` when `relative_shape` is True.
    """
    if len(t.shape)==2 and t.shape[1]==1:
        # `detach` is required: `Tensor.numpy()` refuses tensors that require grad
        return {k:t.detach().cpu().numpy().reshape(-1,)}
    shape = (1,*t.shape[1:]) if relative_shape else t.shape
    return {k:[str(shape)]*t.shape[0]}

# Cell
def tensor2mu(k,t:Tensor):
    """Returns a dict with key `k`_mu with the mean of `t` batchwise.

    All dimensions after the first are flattened, and each row is averaged in float64,
    giving a 1-d tensor of length `t.shape[0]`.
    """
    return {f'{k}_mu':t.reshape(t.shape[0],-1).double().mean(dim=1)}

# Cell
class TensorBatch(TensorBase):
    "A `TensorBase` whose first dimension is treated as the batch dimension"
    def __new__(cls, x, bs=1,**kwargs):
        # An existing `TensorBatch` supplies its own batch size, overriding `bs`
        expected = x.bs() if isinstance(x,cls) else bs
        res = super().__new__(cls,x,**kwargs)
        # NOTE(review): `bs` defaults to 1, so a multi-row input without an explicit `bs`
        # fails this check.
        assert res.shape[0]==expected,f'Tensor has shape {res.shape} while bs is {expected}'
        return res

    def bs(self):
        "Batch size: the length of the first dimension."
        return self.shape[0]

    def get(self,*args):
        """Index with `args` (as a tuple), keeping a leading batch dim.

        If indexing drops dimensions, one leading dim of size 1 is added back.
        """
        picked = self[args]
        if picked.ndim < self.ndim: picked = picked.unsqueeze(0)
        return picked

    @classmethod
    def vstack(cls,*args):
        """`torch.vstack` the given batches and wrap the result.

        The result is checked against the sum of the inputs' batch sizes.
        """
        stacked = torch.vstack(*args)
        total = sum(cls(part).bs() for part in L(*args))
        return cls(stacked,bs=total)

def obj2tensor(o):
    """Convert `o` into a `TensorBatch`.

    - A `TensorBatch` is returned unchanged.
    - An `L`, `list`, `np.ndarray` or `Tensor` with at least one dimension becomes a
      `TensorBatch` with batch size `len(o)`.
    - Anything else, including scalars and 0-d arrays/tensors, is first wrapped in a
      one-element list, giving a batch of size 1.
    """
    if isinstance(o,TensorBatch): return o
    # 0-d arrays/tensors have no first dim to act as the batch, so they take the scalar path
    if isinstance(o,(L,list,np.ndarray,Tensor)) and getattr(o,'ndim',1)>0:
        # `TensorBatch` defaults to bs=1 and asserts on it, so pass the real batch size
        return TensorBatch(o,bs=len(o))
    return TensorBatch([o])

def _get_bs(o):
    """Return `bs` of `o` viewed as a `TensorBatch`, wrapping `o` first if needed.

    Since `TensorBatch.bs` is a method, this returns the bound method, not an int.
    """
    batch = o if isinstance(o,TensorBatch) else TensorBatch(o)
    return batch.bs

# export
class BD(D):
    """A `D` of batched values supporting batch indexing, concatenation and pandas export.

    Values must expose `bs` (read in `__init__`) and `get` (used for batch indexing),
    presumably `TensorBatch` instances sharing one batch size.
    """
    def __init__(self,*args,**kwargs):
        super().__init__(*args,**kwargs)
        # Taken from the first value; for a `TensorBatch` this is the bound `bs` method,
        # so call it as `self.bs()`. An empty `BD` raises IndexError here.
        self.bs=list(self.values())[0].bs

    def __radd__(self,o):
        """Support `sum(bds)`, whose int start value returns `self` unchanged.

        Also handles `o + self` when `o` cannot add a `BD`.
        """
        if isinstance(o,int): return self
        # `o` is the left operand, so its rows come first in the result
        return type(self)({k:TensorBatch.vstack((o[k],v)) for k,v in self.items()})

    def __add__(self,o):
        """Concatenate `self` and `o` key by key along the batch dimension, `self` first.

        Keys are taken from `self`: a key missing from `o` raises KeyError, and extra keys in `o`
        are ignored.
        """
        return type(self)({k:TensorBatch.vstack((v,o[k])) for k,v in self.items()})

    def __getitem__(self,o):
        """Index every value by batch position when `o` is list-like, an int or a tensor.

        Otherwise perform a normal key lookup.
        """
        if is_listy(o) or isinstance(o,(TensorBatch,int,Tensor)):
            # Read values via `items()`: `self[k]` would re-enter this override
            # (and recurse forever for int keys).
            return type(self)({k:v.get(o) for k,v in self.items()})
        return super().__getitem__(o)

    @classmethod
    def merge(cls,*ds,**kwargs):
        "Build a `cls` from the dicts `ds` combined with fastcore's `merge`, forwarding `kwargs`."
        return cls(merge(*ds),**kwargs)

    @delegates(pd.DataFrame)
    def pandas(self,mu=False,relative_shape=False,**kwargs):
        "Turns a `BD` into a pandas Dataframe optionally showing `mu` of values."
        shapes = [tensor2shape(k,v,relative_shape) for k,v in self.items()]
        # `tensor2mu` returns torch tensors. Move them to host numpy arrays so pandas can build
        # columns even from CUDA tensors or tensors that require grad.
        mus = [{mk:mv.detach().cpu().numpy() for mk,mv in tensor2mu(k,v).items()}
               for k,v in self.items()] if mu else []
        return pd.DataFrame(merge(*shapes,*mus),**kwargs)