# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/03c_utils.tile2d.ipynb.

# %% auto 0
__all__ = ['tile2d', 'hypertile']

# %% ../../nbs/03c_utils.tile2d.ipynb 3
from math import floor, ceil, log2
import numpy as np
from .. import lovely
from .pad import pad_frame_gutters


# %% ../../nbs/03c_utils.tile2d.ipynb 8
def fit_columns(t: np.ndarray, # Tensor with images, shape=(n,h,w,c)
                view_width=966): # Target width of the tiled image, in pixels
    """
    Find out how many columns and rows to use to display the images.

    Returns `(n_rows, n_cols)`. The column count starts as the largest power
    of two such that that many images fit into `view_width` without
    rescaling. It is then clamped to the range [1, n], and halved (integer
    division) while the grid would be more than three times wider (in images)
    than it is tall. An empty batch gives `(0, 0)`.
    """
    assert t.ndim == 4

    n_img = t.shape[0]
    width = t.shape[-2]

    if n_img == 0:
        # Nothing to lay out; also avoids a division by zero in
        # ceil(n_img / n_cols) below.
        return (0, 0)

    # Figure out how many images we can put in a row without re-scaling,
    # keeping the number a power of 2 in case we need multiple rows.
    if width > 0 and view_width >= width:
        n_cols = 2**floor(log2(view_width / width))
    else:
        # Images wider than the view, or a degenerate (non-positive) view
        # width: log2 would be negative or undefined, so show at least one
        # image per row even if it does not fit without rescaling.
        n_cols = 1

    # If we don't have enough images to fill a single power-of-two row,
    # just display as many as we have.
    n_cols = min(n_img, n_cols)

    n_rows = ceil(n_img / n_cols) # The last row might have free space.

    # Avoid producing tilings that are very wide and short.
    while n_rows < n_cols / 3:
        n_cols //= 2
        n_rows = ceil(n_img / n_cols)

    return (n_rows, n_cols)

# %% ../../nbs/03c_utils.tile2d.ipynb 16
def tile2d(t: np.ndarray,      # Tensor containing images, shape=(n,h,w,c)
            view_width=966):   # Try to produce an image at most this wide
    """
    Tile images in a grid.

    The grid size comes from `fit_columns`. If the last row is not full, it is
    padded with dummy images filled with ones, of the same dtype as `t`.
    Returns a single channel-last image of shape (n_rows*h, n_cols*w, c).
    """
    assert t.ndim == 4
    assert t.shape[-1] in (3, 4) # Either RGB or RGBA.

    n_channels = t.shape[-1]
    img_h, img_w = t.shape[1:3]

    n_rows, n_cols = fit_columns(t, view_width=view_width)

    # We need to form the images into a rectangular area. For this, we might
    # need to add some dummy images to the last row, which might not be full.
    n_extra_images = n_rows*n_cols - t.shape[0]
    if n_extra_images:
        # Match the input dtype: a plain np.ones() is float64 and would
        # silently upcast the whole result (e.g. uint8 -> float64), but only
        # when padding happens to be needed.
        extra_images = np.ones((n_extra_images, *t.shape[1:]), dtype=t.dtype)
        t = np.concatenate([t, extra_images])

    # This is where the fun begins! Imagine 't' is tensor[20, 128, 128, 3]
    # and we want 5 rows, 4 columns each.

    t = t.reshape(n_rows, n_cols, img_h, img_w, n_channels)
    # Now t is tensor[5, 4, 128, 128, 3]

    t = t.transpose(0, 2, 1, 3, 4)
    # Now t is tensor[5, 128, 4, 128, 3]
    # If we just squash dimensions 0,1 and 2,3 together, we get the image we want.
    t = t.reshape(n_rows*img_h, n_cols*img_w, n_channels)

    # Now t is tensor[640, 512, 3], channel-last.
    return t

# %% ../../nbs/03c_utils.tile2d.ipynb 22
def hypertile(t: np.ndarray,   # input tensor, shape=([...], B, H, W, C)
            frame_px=1,        # Frame width for the innermost group
            gutter_px=3,       # Gutter width for the innermost group
            view_width=966):   # Try to produce an image at most this wide
    """
    Recursively tile images on a 2d grid.

    Each leading dimension beyond (B, H, W, C) adds one nesting level. The
    innermost level is 1 for 4-d input. Inner groups are tiled first with
    `hypertile` and stacked into a batch, then that batch is tiled with
    `tile2d`. The frame and gutter widths passed to `pad_frame_gutters` are
    multiplied by the level, so outer groups get wider separators. Returns
    the single (H, W, C) image produced by `tile2d`.
    """
    assert t.ndim >= 4, f"Tiling requires at least 4 dimensions ([...], B, H, W, C), got {t.shape}"
    level = t.ndim - 3

    if t.ndim > 4:
        # Give inner groups a narrower width budget, presumably to leave room
        # for the wider frames/gutters added at this level. Clamp to 1 so a
        # small view_width with deep nesting cannot reach zero or go negative,
        # which would break the column fitting in tile2d.
        sub_view_width = max(1, view_width - (gutter_px - frame_px) * level * 2)
        # Iterating over an ndarray yields its sub-arrays along axis 0.
        tiles = [ hypertile(item,
                            frame_px=frame_px,
                            gutter_px=gutter_px,
                            view_width=sub_view_width)
                  for item in t ]
        # Each tile is a (H, W, C) image, so stacking gives a 4-d batch.
        t = np.stack(tiles)

    # NOTE(review): pad_frame_gutters comes from .pad; it presumably adds a
    # frame and gutter around each image in the batch. Confirm there.
    return tile2d(pad_frame_gutters(t,
                                     frame_px=frame_px*level,
                                     gutter_px=gutter_px*level),
                    view_width=view_width)
