# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/03_utils.utils.ipynb.

# %% auto 0
# Public names exported by `from <this module> import *`; generated from the source notebook along with the rest of this file.
__all__ = ['PRINT_OPTS', 'sci_mode', 'pretty_str', 'sparse_join', 'ansi_color', 'np_to_str_common', 'history_warning']

# %% ../../nbs/03_utils.utils.ipynb 3
from collections import defaultdict
import warnings
import numpy as np
from typing import Optional, Union

# %% ../../nbs/03_utils.utils.ipynb 4
class __PrinterOptions:
    """Process-wide formatting settings shared by the printing helpers.

    Only one instance, ``PRINT_OPTS``, is created below. The helpers read its
    attributes at call time, so assigning to them (e.g.
    ``PRINT_OPTS.precision = 5``) affects all subsequent formatting.
    """
    # Digits after the decimal point, for both fixed and scientific format.
    precision: int = 3
    # Values whose abs() exceeds 10**threshold_max (1e3) use scientific mode.
    threshold_max: int = 3
    # Values whose abs() is below 10**threshold_min (1e-4) use scientific mode.
    threshold_min: int = -4
    # None -> pick the mode per value automatically; True/False -> force it.
    sci_mode: Optional[bool] = None
    # Indentation step used by .deeper().
    indent: int = 2
    # Default for whether ANSI colour codes are emitted.
    color: bool = True

PRINT_OPTS = __PrinterOptions()

# %% ../../nbs/03_utils.utils.ipynb 5
# Do we want this float in decimal or scientific mode?
def sci_mode(f: float):
    """Return True when ``f`` should be printed in scientific notation.

    Scientific mode is used when ``abs(f)`` lies strictly below
    ``10**PRINT_OPTS.threshold_min`` or strictly above
    ``10**PRINT_OPTS.threshold_max``. Both bounds are exclusive, so a value
    exactly on a bound stays in decimal mode. A zero falls below the lower
    bound, which is why callers special-case it before asking.
    """
    lower = 10 ** PRINT_OPTS.threshold_min
    upper = 10 ** PRINT_OPTS.threshold_max
    magnitude = abs(f)
    return magnitude < lower or magnitude > upper

# %% ../../nbs/03_utils.utils.ipynb 8
# Convert an ndarray or scalar into a string.
# This only looks good for small tensors, which is how it's intended to be used.
def pretty_str(x):
    """Format a scalar or array as a compact, human-readable string.

    Accepts Python ``int``/``float`` values and array-likes that expose
    ``ndim``, ``item()``, ``shape`` and integer indexing (e.g. ``np.ndarray``,
    ``torch.Tensor``, ``jax.DeviceArray``). Arrays are rendered recursively as
    nested ``[a, b, ...]`` lists, so the result is only practical for small
    inputs.

    Integers are printed as-is. A float that equals zero prints as ``"0."``.
    Any other float is printed with ``PRINT_OPTS.precision`` digits after the
    point. The format is scientific when ``PRINT_OPTS.sci_mode`` is True and
    fixed-point when it is False. When it is None, ``sci_mode(x)`` decides.
    """
    if isinstance(x, int):
        return format(x)

    if isinstance(x, float):
        if x == 0.:
            return "0."
        use_sci = PRINT_OPTS.sci_mode
        if use_sci is None:
            use_sci = sci_mode(x)
        kind = 'e' if use_sci else 'f'
        return format(x, f".{PRINT_OPTS.precision}{kind}")

    # 0-d arrays / numpy scalars: unwrap to a Python number and retry.
    if x.ndim == 0:
        return pretty_str(x.item())

    # Otherwise recurse along the first axis.
    parts = (pretty_str(x[idx]) for idx in range(x.shape[0]))
    return '[' + ", ".join(parts) + ']'

# %% ../../nbs/03_utils.utils.ipynb 13
def sparse_join(lst, sep=" "):
    """Join the truthy items of ``lst`` with ``sep``.

    Falsy entries such as ``None`` or ``""`` are skipped. This lets callers
    pass optional fragments directly without producing doubled separators.
    """
    return sep.join(filter(None, lst))

# %% ../../nbs/03_utils.utils.ipynb 15
def ansi_color(s: str, col: str, use_color=True):
    """Wrap ``s`` in a minimal ANSI colour escape.

    Known colours are ``"grey"`` and ``"red"``. Any other ``col`` adds no
    opening code, but the reset sequence is still appended. When
    ``use_color`` is falsy, ``s`` is returned unchanged.
    """
    if not use_color:
        return s
    codes = {
        "grey": "\x1b[38;2;127;127;127m",
        "red": "\x1b[31m",
    }
    return codes.get(col, "") + s + "\x1b[0m"

# %% ../../nbs/03_utils.utils.ipynb 18
def np_to_str_common(x: Union[np.ndarray, np.generic],
                     color=None):
    """Build the ``"<stats> <warnings>"`` description of a numpy value.

    Warnings (``all_zeros``, ``+Inf!``, ``-Inf!``, ``NaN!``) are collected for
    arrays and numpy scalars alike. ``all_zeros`` is only reported when there
    is more than one element. For an ``np.ndarray`` that is not all zeros, a
    stats prefix is added:

    * ``n=<size>`` when size > 5 and no single dimension already equals the
      size;
    * ``x∈[min, max]`` when there are more than 2 finite values;
    * ``μ=... σ=...`` when there are at least 2 finite values.

    Args:
        x: the array or numpy scalar to describe.
        color: whether to ANSI-colour the warnings; ``None`` defers to
            ``PRINT_OPTS.color``.

    Returns:
        The stats and warnings joined with single spaces, skipping empty
        parts.
    """
    if color is None:
        color = PRINT_OPTS.color

    is_all_zeros = np.equal(x, 0.).all() and x.size > 1

    # (triggered?, label, colour name), evaluated and shown in this order.
    checks = [
        (is_all_zeros, "all_zeros", "grey"),
        (np.isposinf(x).any(), "+Inf!", "red"),
        (np.isneginf(x).any(), "-Inf!", "red"),
        (np.isnan(x).any(), "NaN!", "red"),
    ]
    attention = sparse_join([ansi_color(label, shade, color)
                             for hit, label, shade in checks if hit])

    summary = None
    if isinstance(x, np.ndarray) and not is_all_zeros:
        total = x.size
        numel = f"n={total}" if total > 5 and max(x.shape) != total else None

        # Stats use finite values only, so NaN/±Inf don't poison them.
        gx = x[np.isfinite(x)]
        minmax = (f"x∈[{pretty_str(gx.min())}, {pretty_str(gx.max())}]"
                  if gx.size > 2 else None)
        meanstd = (f"μ={pretty_str(gx.mean())} σ={pretty_str(gx.std())}"
                   if gx.size >= 2 else None)
        summary = sparse_join([numel, minmax, meanstd])

    return sparse_join([summary, attention])

# %% ../../nbs/03_utils.utils.ipynb 21
def history_warning():
    """Warn if running under IPython with its output cache enabled.

    ``get_ipython`` is resolved by normal name lookup at call time (module
    globals, then builtins, where IPython installs it while executing code).
    The original ``"get_ipython" in globals()`` test only looked at this
    module's own globals. Once the code lives in a module, that test never
    succeeds, so the warning could never fire.

    The message goes through :func:`warnings.warn`, so the active warning
    filters decide how often it is shown. By default that is once per call
    site.
    """
    try:
        shell = get_ipython()  # noqa: F821 -- provided by IPython when present
    except NameError:
        # Plain Python interpreter: there is no output cache to warn about.
        return
    # IPython's cache_size is the output-cache size; 0 disables the cache.
    if getattr(shell, "cache_size", 0) > 0:
        warnings.warn("IPython has its output cache enabled. See https://xl0.github.io/lovely-tensors/history.html")
