# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/03a_utils.colormap.ipynb.

# %% auto 0
__all__ = ['InfCmap']

# %% ../../nbs/03a_utils.colormap.ipynb 3
from typing import Optional as  O
import numpy as np
import matplotlib as mpl, matplotlib.cm as cm
from matplotlib.colors import Colormap, to_rgba

from ..repr_rgb import rgb

# %% ../../nbs/03a_utils.colormap.ipynb 5
def get_cmap(cmap: str) -> Colormap:
    """
    Return the matplotlib colormap named `cmap`.

    A `Colormap` instance is returned unchanged, like the old
    `matplotlib.cm.get_cmap` did.

    Newer matplotlib versions expose colormaps through the
    `matplotlib.colormaps` registry and deprecate (and eventually remove)
    `matplotlib.cm.get_cmap`. Support both, and avoid the deprecation warning
    on versions that have the registry.
    """
    if isinstance(cmap, Colormap):
        return cmap

    # Feature-detect the registry instead of parsing `mpl.__version__`, so
    # this keeps working on any future major version.
    registry = getattr(mpl, "colormaps", None)
    if registry is not None:
        return registry[cmap]
    return cm.get_cmap(cmap)


# %% ../../nbs/03a_utils.colormap.ipynb 21
class InfCmap():
    """
    Matplotlib colormap extended to have colors for +/-inf

    Calling an instance rescales values from [-1, 1] to [0, 1], looks them up
    in the base colormap, and returns RGB colors (alpha dropped) with a
    trailing axis of size 3. Values outside [-1, 1], NaNs and infinities get
    their own colors.

    Parameters except `cmap` are matplotlib color strings.
    """
    def __init__(self,
                 cmap:  Colormap, # Base matplotlib colormap
                 below: O[str] =None, # Values below -1 (below 0 after rescaling)
                 above: O[str] =None, # Values above +1 (above 1 after rescaling)
                 nan:   O[str] =None, # NaNs
                 ninf:  O[str] =None, # -inf (defaults to the `below` color)
                 pinf:  O[str] =None, # +inf (defaults to the `above` color)
                ):
        _ = cmap(0) # one call to make sure the cmap is initialized
        lut = cmap._lut.copy()
        cmax = cmap.N-1
        # Expected layout: N colors, then the under / over / bad entries at
        # N, N+1, N+2. Raise (not assert) so the check survives `python -O`.
        if len(lut) != cmap.N+3:
            raise ValueError("Unexpected colormap LUT size")

        if below: lut[cmax+1] = np.array(to_rgba(below))
        if above: lut[cmax+2] = np.array(to_rgba(above))
        if nan:   lut[cmax+3] = np.array(to_rgba(nan))

        # For +/- inf, use below/above as defaults.
        ninf_rgba = np.array(to_rgba(ninf)) if ninf else lut[cmax+1]
        pinf_rgba = np.array(to_rgba(pinf)) if pinf else lut[cmax+2]

        # Final LUT rows: [0, cmax] base colors, cmax+1 below, cmax+2 above,
        # cmax+3 NaN, cmax+4 -inf, cmax+5 +inf.
        # Remove the alpha channel, it causes problems in pad_frame_gutters().
        self.lut = np.concatenate([ lut, ninf_rgba[None], pinf_rgba[None] ])[:,:3]
        self.cmax = cmax

    def __call__(self, t: np.ndarray) -> np.ndarray:
        """
        Map `t` to RGB colors of shape `t.shape + (3,)`.

        Finite values in [-1, 1] index the base colormap; t < -1 and t > 1 get
        the below/above colors; NaN, -inf and +inf get their own colors.
        Scalars and array-likes are accepted as well as ndarrays.
        """
        t = np.asarray(t)
        vals = ((t + 1) / 2)
        cmax = self.cmax

        # Only cast finite, in-range values to int. NaN/inf (or huge finite
        # values) would otherwise raise "invalid value encountered in cast"
        # RuntimeWarnings on NumPy >= 1.24; their indices are overwritten
        # below anyway, and clipping leaves in-range values untouched.
        # NOTE(review): scaling by cmax (not N, as matplotlib's own Colormap
        # does) means the last base color is only hit when t == 1 exactly.
        # Kept as-is to preserve existing output.
        in_range = np.clip(np.where(np.isfinite(vals), vals, 0.), 0., 1.)
        lut_idxs = np.asarray((in_range * cmax).astype(np.int64))

        # Later assignments take precedence (inf overrides out-of-range).
        lut_idxs[ vals < 0. ] = cmax+1
        lut_idxs[ vals > 1. ] = cmax+2
        lut_idxs[ np.isnan(t) ] = cmax+3
        lut_idxs[ np.isneginf(t) ] = cmax+4
        lut_idxs[ np.isposinf(t) ] = cmax+5

        return self.lut.take(lut_idxs, axis=0, mode="clip") # RGB added as color-last.
