# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_repr_rgb.ipynb.

# %% auto 0
__all__ = ['rgb']

# %% ../nbs/01_repr_rgb.ipynb 4
from PIL import Image
import torch

from .utils.pad import pad_frame_gutters
from .utils.tile2d import hypertile

# %% ../nbs/01_repr_rgb.ipynb 5
# This is here for the monkey-patched tensor use case.

# I want to be able to call both `tensor.rgb` and `tensor.rgb(stats)`. For the
# first case, the class defines `_repr_png_` to send the image to Jupyter. For
# the latter case, it defines `__call__`, which accepts the arguments.

class RGBProxy():
    """Flexible `PIL.Image.Image` wrapper.

    Wraps a tensor so it can either be shown directly in Jupyter (through
    `_repr_png_`) or rendered with options by calling the proxy (`__call__`).
    The tensor reference is released after the first render, so each proxy
    renders only once.
    """
    @torch.no_grad()
    def __init__(self, t:torch.Tensor):
        super().__init__()
        # Keep a detached CPU copy: rendering goes through NumPy/PIL, which
        # need host memory and no autograd history.
        self.t = t.detach().cpu()

    @torch.no_grad()
    def __call__(self, denorm=None, cl=False,
                    gutter_px=3, frame_px=1, view_width=966):
        """Render the wrapped tensor as a `PIL.Image.Image`.

        Args:
            denorm: optional `(means, stds)` pair; the image is mapped back
                with `t * stds + means`, broadcast over the trailing
                (channel) axis.
            cl: `True` if the tensor is already channel-last
                (`[..., H, W, C]`); otherwise it is treated as channel-first
                (`[..., C, H, W]`).
            gutter_px, frame_px, view_width: forwarded to `hypertile` when
                the tensor has more than 3 dims.

        Returns:
            A `PIL.Image.Image` built from the 3- or 4-channel uint8 array.

        Raises:
            AttributeError: if this proxy has already been rendered.
            AssertionError: if the channel dimension is not 3 or 4.
        """
        try:
            t = self.t
        except AttributeError:
            raise AttributeError(
                "RGBProxy has already been rendered and released its tensor; "
                "create a new one to render again") from None

        # This object might linger in IPython's output history.
        # Drop the tensor reference, since it won't be needed after this call.
        del self.t

        if not cl:
            # [..., C, H, W] -> [..., H, W, C], whatever the number of leading dims.
            t = t.movedim(-3, -1)

        n_ch = t.shape[-1]
        assert n_ch in (3, 4), f"Expecting 3 (RGB) or 4 (RGBA) channels, got {n_ch}" 
        if denorm:
            # as_tensor accepts lists or tensors without a copy warning;
            # device= keeps the stats on the same device as the image.
            means = torch.as_tensor(denorm[0], device=t.device)
            stds = torch.as_tensor(denorm[1], device=t.device)
            t = t.mul(stds).add(means)

        if t.ndim > 3:
            t = hypertile(  t=t,
                            gutter_px=gutter_px,
                            frame_px=frame_px,
                            view_width=view_width)

        # Clamp before the uint8 cast: out-of-range values (common after
        # denormalization) would otherwise wrap around instead of saturating.
        return Image.fromarray(t.clamp(0, 1).mul(255).byte().numpy())
    
    @torch.no_grad()
    def _repr_png_(self):
        """Return PNG bytes so Jupyter can display the proxy directly."""
        # Note: to prevent IPython from hogging memory, `__call__` deletes the
        # reference to the tensor after rendering. This means the proxy can be
        # displayed only once, which is fine for Jupyter use.
        return self()._repr_png_()


# %% ../nbs/01_repr_rgb.ipynb 6
def rgb(t: torch.Tensor, # Tensor to display. [[...], C,H,W] or [[...], H,W,C]
        denorm=None,     # Reverse per-channel normalization
        cl=False,        # Channel-last
        gutter_px=3,     # If more than one tensor -> tile with this gutter width
        frame_px=1,      # If more than one tensor -> tile with this frame width
        view_width=966): # Target width of the image
    "Render `t` as a `PIL.Image.Image` by wrapping it in an `RGBProxy` and calling it once."
    proxy = RGBProxy(t)
    return proxy(denorm=denorm,
                 cl=cl,
                 gutter_px=gutter_px,
                 frame_px=frame_px,
                 view_width=view_width)
