# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_tensors.ipynb.

# %% auto 0
# Public API for `from <module> import *`; this cell is regenerated by nbdev.
__all__ = ['PRINT_OPTS', 'tensor_str', 'lovely', 'rgb', 'monkey_patch']

# %% ../nbs/00_tensors.ipynb 2
from nbdev.showdoc import *
from typing import Optional

from PIL import Image
import torch

from fastcore.test import test_eq
from fastcore.foundation import patch_to

# import wandb

# %% ../nbs/00_tensors.ipynb 3
class __PrinterOptions:
    """Knobs controlling how tensors are rendered by the lovely formatters."""

    # Digits after the decimal point, in both fixed and scientific notation.
    precision: int = 3
    # Auto sci mode kicks in when abs(x) > 10**threshold_max (1e3 by default) ...
    threshold_max: int = 3
    # ... or when abs(x) < 10**threshold_min (1e-4 by default).
    threshold_min: int = -4
    # None: decide per value (auto). True: always scientific. False: never.
    sci_mode: Optional[bool] = None
    # Spaces per nesting level for .deeper() and verbose output.
    indent: int = 2
    # Emit ANSI color codes around the zero / inf / nan markers.
    color: bool = False


# %% ../nbs/00_tensors.ipynb 4
# Module-wide printing options; change its attributes to alter how tensors are summarized.
PRINT_OPTS = __PrinterOptions()

# %% ../nbs/00_tensors.ipynb 5
# Do we want this float in decimal or scientific mode?
def sci_mode(f: float):
    """Return True if `f` should be printed in scientific notation.

    That is the case when abs(f) is below 10**PRINT_OPTS.threshold_min or
    above 10**PRINT_OPTS.threshold_max. NaN fails both comparisons, so it
    is never sent to scientific mode.
    """
    magnitude = abs(f)
    lower = 10**(PRINT_OPTS.threshold_min)
    upper = 10**PRINT_OPTS.threshold_max
    return magnitude < lower or magnitude > upper

# %% ../nbs/00_tensors.ipynb 8
# Convert a tensor into a string.
# This only looks good for small tensors, which is how it's intended to be used.
def tensor_str(t: torch.Tensor):
    """A slightly better way to print `float` values

    Higher-rank tensors are rendered recursively as nested `[a, b, ...]`
    lists. For scalars:

    - Non-float values print as `str(value)`.
    - Float zero (including -0.0) prints as "0.".
    - Other floats use `PRINT_OPTS.precision` digits, in scientific notation
      when `PRINT_OPTS.sci_mode` is truthy, or when it is None and
      `sci_mode()` asks for it.
    """
    if t.dim():
        # Iterating a tensor yields its slices along dim 0.
        return '[' + ", ".join(tensor_str(row) for row in t) + ']'

    value = t.item()
    if not t.is_floating_point():
        return str(value)  # Should we use sci mode for large ints too?
    if not t.is_nonzero():
        return "0."

    use_sci = PRINT_OPTS.sci_mode or (PRINT_OPTS.sci_mode is None and sci_mode(value))
    spec = f".{PRINT_OPTS.precision}{'e' if use_sci else 'f'}"
    return format(value, spec)

# %% ../nbs/00_tensors.ipynb 13
def space_join(lst):
    """Join the truthy (non-empty, non-None) items of `lst` with single spaces."""
    return " ".join(filter(None, lst))

# %% ../nbs/00_tensors.ipynb 15
class LovelyProxy():
    """Printable, human-friendly summary of a tensor.

    `repr()`/`str()` of a proxy renders `t` on one line. The line holds the
    type name, shape, element count, value range, mean/std, dtype, device and
    gradient info. It also carries markers for all-zero / +inf / -inf / nan
    contents. Tensors with at most 10 elements also get their values appended.

    Args:
        t: the tensor to describe.
        plain: render with torch's stock formatter instead of the summary.
        verbose: append the full tensor contents (torch formatting) on a new line.
        depth: how many leading dimensions to expand, one summary line per slice.
        lvl: nesting level of this proxy; controls indentation of expanded lines.
    """

    # Short dtype tags shown in the summary. float32 is the default and is omitted.
    # Any dtype not listed here is shown as its torch name minus the "torch." prefix.
    _DTNAMES = { torch.float32: "",
                 torch.float16: "f16",
                 torch.float64: "f64",
                 torch.uint8: "u8",
                 torch.int32: "i32",
               }

    def __init__(self, t: torch.Tensor, plain=False, verbose=False, depth=0, lvl=0):
        self.t = t
        self.plain = plain
        self.verbose = verbose
        self.depth = depth
        self.lvl = lvl

    @torch.no_grad()
    def to_str(self):
        """Render the proxy as a string (see the class docstring for the format)."""
        t = self.t
        if self.plain:
            return torch._tensor_str._tensor_str(t, indent=0)

        # ANSI escapes for highlighting, only when colored output is enabled.
        grey_style = "\x1b[38;2;127;127;127m" if PRINT_OPTS.color else ""
        red_style = "\x1b[31m" if PRINT_OPTS.color else ""
        end_style = "\x1b[0m" if PRINT_OPTS.color else ""

        tname = "tensor" if type(t) in [torch.Tensor, torch.nn.Parameter] else type(t).__name__

        grad_fn = "grad_fn" if t.grad_fn else None
        # All tensors along the compute path actually have requires_grad=True.
        # Torch's __repr__ just does not show it, so only leaves get "grad".
        grad = "grad" if not t.grad_fn and t.requires_grad else None

        shape = str(list(t.shape))

        # NOTE(review): an empty tensor also has count_nonzero() == 0, so it is
        # reported as "all_zeros". This also keeps min()/max() off empty input.
        zeros = grey_style+"all_zeros"+end_style if not t.count_nonzero() else None
        pinf = red_style+"+inf!"+end_style if t.isposinf().any() else None
        ninf = red_style+"-inf!"+end_style if t.isneginf().any() else None
        nan = red_style+"nan!"+end_style if t.isnan().any() else None

        attention = space_join([zeros, pinf, ninf, nan])

        numel = f"n={t.numel()}" if t.numel() > 5 else None
        summary = numel
        x = ""
        if not zeros:
            if t.numel() <= 10: x = " x=" + tensor_str(t)

            # Make sure it's float32. Also, we calculate stats on good (finite) values only.
            ft = t.float()[torch.isfinite(t)]

            minmax = f"x∈[{tensor_str(ft.min())}, {tensor_str(ft.max())}]" if t.numel() > 2 and ft.numel() > 2 else None
            meanstd = f"μ={tensor_str(ft.mean())} σ={tensor_str(ft.std())}" if t.numel() >= 2 and ft.numel() >= 2 else None

            summary = space_join([numel, minmax, meanstd])

        dtype = self._DTNAMES.get(t.dtype, str(t.dtype)[6:])
        dev = str(t.device) if t.device.type != "cpu" else None

        # No space between the type name and the shape, e.g. "tensor[2, 3] ...".
        res = tname + space_join([shape, summary, dtype, grad, grad_fn, dev, attention])

        if self.verbose:
            res += "\nx=" + torch._tensor_str._tensor_str(t, indent=PRINT_OPTS.indent)
        else:
            res += x

        # Expand the leading dimension into one indented summary line per slice.
        # Recurse until `depth` runs out or the slices are 1-D.
        if self.depth and t.dim() > 1:
            pad = " " * PRINT_OPTS.indent * (self.lvl + 1)
            res += "\n" + "\n".join(
                pad + str(LovelyProxy(t[i, :], depth=self.depth - 1, lvl=self.lvl + 1))
                for i in range(t.shape[0]))

        return res

    def __repr__(self):
        return self.to_str()

    def __call__(self, depth=0):
        """Return a new proxy for the same tensor with a different `depth`.

        The `plain`/`verbose` settings and the nesting level are carried over.
        For example, `lovely(t, verbose=True)(depth=1)` stays verbose.
        """
        return LovelyProxy(self.t, plain=self.plain, verbose=self.verbose,
                           depth=depth, lvl=self.lvl)


# %% ../nbs/00_tensors.ipynb 17
def lovely(t: torch.Tensor, verbose=False, plain=False, depth=0):
    """Wrap `t` in a `LovelyProxy` whose repr is a compact summary of the tensor.

    Args:
        t: tensor to describe.
        verbose: also append the full tensor contents.
        plain: use torch's stock formatting instead of the summary.
        depth: number of leading dimensions to expand into nested lines.
    """
    proxy = LovelyProxy(t, plain=plain, depth=depth, verbose=verbose)
    return proxy

# %% ../nbs/00_tensors.ipynb 32
# This is here for the monkey-patched tensor use case.

# I want to be able to call both `tensor.rgb` and `tensor.rgb(stats)`. For the
# first case, the class defines `_repr_png_` to send the image to Jupyter. For
# the latter case, it defines `__call__`, which accepts the argument.
class ProxyImage(Image.Image):
    """Flexible `PIL.Image.Image` wrapper

    Holds a tensor and renders it as an RGB image. Jupyter calls
    `_repr_png_` directly. Calling the object, optionally with de-normalization
    stats `{"mean": ..., "std": ...}`, returns a real `PIL.Image.Image`.
    """
    @torch.no_grad()
    def __init__(self, t:torch.Tensor):
        super().__init__()
        # Presumably t is (C, H, W) with C == 3 - TODO confirm with callers.
        # Keep an (H, W, C) view so per-channel stats broadcast over the last dim.
        self.t = t.permute(1, 2, 0)

        # Mode and size - to be used by super().__repr__()
        try:
            self.mode = "RGB"
        except AttributeError:
            # Pillow >= 10.1 exposes `mode` as a read-only property backed by `_mode`.
            self._mode = "RGB"
        # PIL sizes are (width, height), i.e. (W, H) - not the tensor's (H, W).
        self._size = (t.shape[2], t.shape[1])

    @torch.no_grad()
    def _to_pil(self, denorm=None):
        "Optionally de-normalize, scale [0, 1] to [0, 255] and build a `PIL.Image.Image`."
        # numpy() needs a CPU tensor that does not require grad.
        t = self.t.detach().cpu()
        means = torch.as_tensor(denorm["mean"] if denorm else (0., 0., 0.,), device=t.device)
        stds = torch.as_tensor(denorm["std"] if denorm else (1., 1., 1.,), device=t.device)
        # Clamp before the uint8 cast: out-of-range values would otherwise wrap around.
        return Image.fromarray((t*stds+means).mul(255).clamp(0, 255).byte().numpy())

    @torch.no_grad()
    def __call__(self, denorm=None):
        "Return the image as a `PIL.Image.Image`, de-normalized with `denorm` if given."
        return self._to_pil(denorm)

    @torch.no_grad()
    def _repr_png_(self):
        "Jupyter PNG representation"
        return self._to_pil()._repr_png_()


# %% ../nbs/00_tensors.ipynb 33
def rgb(t: torch.Tensor, denorm=None):
    """Render `t` as a `PIL.Image.Image`, optionally de-normalizing it first.

    The work is delegated to `ProxyImage`. That class permutes `t` with
    (1, 2, 0) and uses mode "RGB", so `t` is presumably (3, H, W).
    `denorm` is an optional mapping with "mean" and "std" entries.
    """
    proxy = ProxyImage(t)
    return proxy(denorm=denorm)

# %% ../nbs/00_tensors.ipynb 36
def monkey_patch(cls=torch.Tensor):
    """Monkey-patch lovely features into `cls`.

    After this call, instances of `cls` print as lovely summaries. They also
    gain four properties:

    - `.plain`, `.verbose` and `.deeper`, each returning a `LovelyProxy`.
    - `.rgb`, returning a `ProxyImage`.
    """

    # The summary replaces the default repr. The keyword-only `tensor_contents`
    # presumably mirrors torch.Tensor.__repr__'s signature; it is ignored.
    @patch_to(cls)
    def __repr__(self: torch.Tensor, *, tensor_contents=None):
        return LovelyProxy(self).to_str()

    # Escape hatch back to torch's standard formatting.
    @patch_to(cls, as_prop=True)
    def plain(self: torch.Tensor, *, tensor_contents=None):
        return LovelyProxy(self, plain=True)

    # Summary followed by the full contents.
    @patch_to(cls, as_prop=True)
    def verbose(self: torch.Tensor, *, tensor_contents=None):
        return LovelyProxy(self, verbose=True)

    # Summary expanded one level; call the result (e.g. `.deeper(2)`) for more.
    @patch_to(cls, as_prop=True)
    def deeper(self: torch.Tensor):
        return LovelyProxy(self, depth=1)

    # Image view of the tensor (displayable in Jupyter, callable to de-normalize).
    @patch_to(cls, as_prop=True)
    def rgb(self: torch.Tensor):
        return ProxyImage(self)
