# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/04_util.ipynb.

# %% auto 0
# Public API of this module (autogenerated together with the rest of the file, see header).
# NOTE(review): the helpers `getorpop` and `getpopkw` are defined in this module but are
# not listed here, so `from ... import *` will not export them — confirm that is intended.
__all__ = ['get_option', 'init_enum', 'getname', 'private', 'mangled', 'unmangle', 'bound_arguments_format', 'decorated_format',
           'decorated_format_from_func', 'nones', 'nextnone', 'tuple_arg1st', 'tuple_extend', 'join', 'passthrough',
           'get_mangled_keywords', 'overload_keyword', 'resolve_args', 'resolve_kwargs', 'hasvarg', 'hasvkws',
           'hasvpok', 'parameter_as_sort_tuple', 'sort_parameters', 'parameter_defaults', 'unbound_args',
           'unbound_kwargs', 'bind_unbound_args', 'bind_unbound_kwargs', 'handle_bound_methods',
           'handle_bound_variadic']

# %% ../nbs/04_util.ipynb 6
from inspect import (ismethod, Signature, BoundArguments, Parameter, _ParameterKind as ParamKind, _empty as Empty)
from functools import wraps, partial
from itertools import chain, zip_longest
from enum import EnumMeta

# %% ../nbs/04_util.ipynb 8
from typing import (Any, Callable, Iterable, Generator, overload, get_args)

# %% ../nbs/04_util.ipynb 10
#| export


# %% ../nbs/04_util.ipynb 12
#| export


# %% ../nbs/04_util.ipynb 14
from .type import (P, FoundArguments, VArgs, BoundReturn, DecoReturn)
from .cons import (_ARGS, _KWDS)
from .grds import (isnone, notnone, istuple, isdict, isvarg, isvkws, isvpos, isvpok, notempty, isoptional, issig)
from .enum import Direction, BoundFormat, DecoratedFormat, ResolutionOrder

# %% ../nbs/04_util.ipynb 17
def get_option(x, i: int = 0):
    '''Return the `i`-th type argument of `x` when `x` is an optional type.

    Returns ``None`` if `x` is not optional (as judged by ``isoptional``) or
    has no type arguments. An index past the end falls back to the last type
    argument, and a negative index past the start falls back to the first, so
    this never raises ``IndexError``.
    '''
    if not isoptional(x): return None
    args = get_args(x)
    if not args: return None
    n = len(args)
    # Clamp out-of-range indices instead of raising.
    if i >= n: i = -1
    elif i < -n: i = 0
    return args[i]

# %% ../nbs/04_util.ipynb 19
def init_enum(ecls: EnumMeta) -> Any:
    '''Best-effort lookup of a default member for the enum class `ecls`.

    Strategies are tried in order. The first one that neither raises an
    ``Exception`` nor produces ``None`` wins:

    1. ``ecls()``
    2. ``ecls._default_`` (called first if it is callable)
    3. ``ecls._missing_(None)``
    4. the first entry of ``ecls.__members__``
    5. ``ecls(0)``

    Returns ``None`` when every strategy fails.
    '''
    strategies = (
        lambda: ecls(),
        lambda: ecls._default_() if callable(ecls._default_) else ecls._default_,
        # `Enum._missing_` returns None unless overridden, so a None result
        # must not stop the remaining fallbacks from being tried.
        lambda: ecls._missing_(None),
        lambda: list(ecls.__members__.values())[0],
        lambda: ecls(0),
    )
    for strategy in strategies:
        try:
            member = strategy()
        except Exception:  # best-effort: any failure just means "try the next one"
            continue
        if member is not None:
            return member
    return None

# %% ../nbs/04_util.ipynb 21
def getname(obj) -> str:
    '''Best-effort display name for `obj`.

    Prefers ``obj.__name__``, then ``obj.__qualname__``, and finally ``str(obj)``.
    '''
    # str(obj) is evaluated eagerly as the innermost fallback.
    fallback = getattr(obj, '__qualname__', str(obj))
    return getattr(obj, '__name__', fallback)

def private(attr: str) -> str:
    '''Return `attr` as a single-underscore private name, i.e. `_{attr}`.

    Any leading underscores already on `attr` are dropped first, so
    ``private('__x') == '_x'``.
    '''
    bare = attr.lstrip('_')
    return '_' + bare

def mangled(cls: object, attr: str) -> str:
    '''Return the name-mangled form of `attr` for `cls`, i.e. `_{cls.__name__}__{attr}`.

    Leading underscores are stripped from both the class name and `attr`
    before they are joined.
    '''
    owner = private(getname(cls))
    # `lstrip` takes a character set, so stripping '_' removes every leading underscore.
    bare = attr.lstrip('_')
    return f'{owner}__{bare}'

def unmangle(cls: object, attr: str) -> str:
    '''Unmangle an attribute name i.e.  `_{cls.__name__}__{attr}` --> `__{attr}`.

    Only a leading ``_{ClassName}__`` prefix is removed. Names that do not
    start with that prefix are returned unchanged.
    '''
    prefix = f'{private(getname(cls))}__'
    # Match only a leading prefix: the class name may legitimately appear
    # elsewhere inside `attr` (e.g. '__x_Foo'), which must not be cut off.
    if attr.startswith(prefix):
        return '__' + attr[len(prefix):]
    return attr

# %% ../nbs/04_util.ipynb 23
@wraps(BoundFormat.format, assigned=('__module__', '__doc__', '__annotations__'))
def bound_arguments_format(
    bound: BoundArguments, *args: P.args, 
    __format: BoundFormat = BoundFormat.argskwargs, 
    __locals_if_empty: bool = False, 
    **kwargs: P.kwargs
) -> BoundReturn:
    '''Format `bound` by delegating to ``BoundFormat.format``.

    Parameters
    ----------
    bound : BoundArguments
        The bound arguments to format.
    *args : 
        Extra positional arguments forwarded to ``BoundFormat.format``.
    __format : BoundFormat, optional
        Formatting style, ``BoundFormat.argskwargs`` by default.
    __locals_if_empty : bool, optional
        Forwarded unchanged as ``__locals_if_empty``.
    **kwargs : 
        Extra keyword arguments forwarded to ``BoundFormat.format``.

    Returns
    -------
    BoundReturn
        Whatever ``BoundFormat.format`` returns.

    Notes
    -----
    ``@wraps`` copies ``__doc__`` and ``__annotations__`` from
    ``BoundFormat.format``, so at runtime this docstring is replaced by that one.
    '''
    # Both options are keyword-only parameters, so a caller's `__format=...`
    # binds to the parameter. The `kwargs` lookups are only a fallback.
    fmt = kwargs.get('__format', __format)
    use_locals = kwargs.get('__locals_if_empty', __locals_if_empty)
    return BoundFormat.format(
        bound, *args,
        __format=fmt,
        __locals_if_empty=use_locals,
        **kwargs,
    )

# %% ../nbs/04_util.ipynb 24
@wraps(DecoratedFormat.format, assigned=('__module__', '__doc__', '__annotations__'))
def decorated_format(
    dec: Callable, fn: Callable | None = None, /, *args: P.args, 
    __format: DecoratedFormat = DecoratedFormat.decorated, **kwargs: P.kwargs
) -> DecoReturn:
    '''Format a decorator/function pair by delegating to ``DecoratedFormat.format``.

    Parameters
    ----------
    dec : Callable
        The decorator (positional-only).
    fn : Callable, optional
        The function the decorator is applied to (positional-only).
    *args : 
        Extra positional arguments forwarded to ``DecoratedFormat.format``.
    __format : DecoratedFormat, optional
        Formatting style, ``DecoratedFormat.decorated`` by default.
    **kwargs : 
        Extra keyword arguments forwarded to ``DecoratedFormat.format``.

    Returns
    -------
    DecoReturn
        Whatever ``DecoratedFormat.format`` returns.

    Notes
    -----
    ``@wraps`` copies ``__doc__`` and ``__annotations__`` from
    ``DecoratedFormat.format``, so at runtime this docstring is replaced by that one.
    '''
    # `__format` is keyword-only, so the `kwargs` lookup is only a fallback.
    chosen = kwargs.get('__format', __format)
    return DecoratedFormat.format(dec, fn, *args, __format=chosen, **kwargs)

# %% ../nbs/04_util.ipynb 25
@wraps(DecoratedFormat.infer, assigned=('__module__', '__doc__', '__annotations__'))
def decorated_format_from_func(fn: Callable | None = None, *args, **kwargs: P.kwargs) -> DecoratedFormat:
    '''Infer the ``DecoratedFormat`` for `fn` via ``DecoratedFormat.infer``.

    `*args` and `**kwargs` are accepted for call-signature compatibility but are
    not used. (``@wraps`` replaces this docstring at runtime with the one from
    ``DecoratedFormat.infer``.)
    '''
    # The previous `__format = kwargs.get('__format', __format)` read a local
    # before assigning it (UnboundLocalError on every call), and its value was
    # never used, so it has been removed.
    return DecoratedFormat.infer(fn)

# %% ../nbs/04_util.ipynb 27
def nones(n: int) -> tuple[None, ...]: 
    '''Build a tuple holding `n` copies of ``None`` (empty for ``n <= 0``).'''
    # Tuple repetition already yields a tuple; no conversion needed.
    return (None,) * n

def nextnone(item: Iterable) -> Any: 
    '''Return the first element of `item` for which ``notnone`` holds, or ``None`` if there is none.

    Returns a single element, not a generator. The previous ``Generator``
    return annotation was incorrect.
    '''
    return next((x for x in item if notnone(x)), None)

def tuple_arg1st(*tups: tuple) -> tuple:
    '''Merge tuples position by position, keeping the first non-None item at each index.

    Shorter tuples are padded with ``None`` (via ``zip_longest``), so the
    result is as long as the longest input.
    '''
    columns = zip_longest(*tups)
    return tuple(map(nextnone, columns))

def tuple_extend(*tups: tuple) -> tuple:
    '''Concatenate every input tuple, in order, into a single tuple.'''
    return tuple(chain.from_iterable(tups))

def join(*tups: tuple, __overwrite: bool = True) -> tuple:
    '''Combine tuples in one of two ways.

    With `__overwrite` (the default), positions are merged and the first
    non-None item at each index wins (``tuple_arg1st``). Otherwise all items
    of all tuples are concatenated (``tuple_extend``).
    '''
    if __overwrite:
        return tuple_arg1st(*tups)
    return tuple_extend(*tups)

# %% ../nbs/04_util.ipynb 28
@overload
def passthrough(a, *args, **kwargs): ...
@overload
def passthrough(*args, **kwargs): ...
def passthrough(*args, **kwargs):
    '''Return the first positional argument, or ``None`` when none was given.

    Keyword arguments are accepted and ignored.
    '''
    if not args:
        return None
    first, *_ = args
    return first

# %% ../nbs/04_util.ipynb 30
def get_mangled_keywords(**kwargs) -> tuple[dict, dict]:
    '''Split keyword arguments by whether their name starts with a double underscore.

    Returns ``(dunder, rest)``. Both dicts keep the caller's keyword order.
    '''
    dunder = {name: value for name, value in kwargs.items() if name.startswith('__')}
    rest = {name: value for name, value in kwargs.items() if not name.startswith('__')}
    return dunder, rest

# %% ../nbs/04_util.ipynb 32
def getorpop(keywords: dict, __pop: bool = False, **kwargs):
    '''Return the bound ``pop`` method of `keywords` if `__pop` is true, else its ``get`` method.

    If `keywords` has no attribute with the requested name, its ``get`` method
    is returned instead. Extra keyword arguments are accepted and ignored.
    '''
    fallback = keywords.get
    method_name = 'pop' if __pop else 'get'
    return getattr(keywords, method_name, fallback)

def getpopkw(keywords: dict, __pop: bool = False, **kwargs):
    '''Build a ``getter(kw, default=None)`` that reads `kw` from `keywords`.

    The getter removes the key while reading it when `__pop` is true, and
    leaves it in place otherwise (see ``getorpop``).
    '''
    access = getorpop(keywords, __pop=__pop, **kwargs)

    def getter(kw: str, default = None):
        return access(kw, default)

    return getter

# %% ../nbs/04_util.ipynb 33
def overload_keyword(
    keyword: str, 
    options: tuple[str, ...] = (), 
    default = None, 
    __newkw: str | None= None,
    __pop_keywords: bool = True
):
    '''A decorator to overload a keyword argument for a function.

    Parameters
    ----------
    keyword : str
        The primary keyword to check in the function's keyword arguments.
    options : Tuple[str, ...], optional
        Other keyword arguments to check in order of preference.
    default : Any, optional
        The default value to use if none of the keywords are present.

    Returns
    -------
    Callable
        A decorator that wraps the function and overloads the specified keyword argument.
    '''
    def decorator(fn: Callable):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            result, keywords, results = default, kwargs, []
            getter = getpopkw(keywords, __pop=__pop_keywords)
            for opt in (keyword, *options):
                if opt not in keywords: continue
                result = getter(opt, result)
                results.append(result)
                if result: break
                
            result = next((r for r in results if r), default)
            if __pop_keywords:
                for opt in set((keyword, *options)): 
                    kwargs.pop(opt, None)
            kwargs[(__newkw or keyword)] = result
            
            return fn(*args, **kwargs)
        return wrapper
    return decorator

# %% ../nbs/04_util.ipynb 36
def resolve_args(
    oarg: VArgs, iarg: VArgs | None = (), earg: VArgs | None = (),
    order: ResolutionOrder = ResolutionOrder.default, __overwrite: bool = True,
) -> VArgs:
    '''
    Resolve and combine positional arguments from multiple sources according to a specified order.

    Parameters
    ----------
    oarg : VArgs
        Outer scope arguments.
    iarg : VArgs | None, optional
        Inner arguments, by default ().
    earg : VArgs | None, optional
        Extra arguments, by default ().
    order : ResolutionOrder, optional
        The order in which to combine the arguments, by default ResolutionOrder.default.
    __overwrite : bool, optional
        Whether to overwrite existing arguments with those from later sources, by default True.

    Returns
    -------
    tuple
        A tuple containing the combined arguments.
    '''
    oarg, iarg, earg = map(lambda x: x or (), (oarg, iarg, earg))
    comb = partial(join, __overwrite=__overwrite)
    match order:
        case ResolutionOrder.OuterInnerExtra: return comb(oarg, iarg, earg)
        case ResolutionOrder.OuterExtraInner: return comb(oarg, earg, iarg)
        case ResolutionOrder.InnerOuterExtra: return comb(iarg, oarg, earg)
        case ResolutionOrder.InnerExtraOuter: return comb(iarg, earg, oarg)
        case ResolutionOrder.ExtraOuterInner: return comb(earg, oarg, iarg)
        case ResolutionOrder.ExtraInnerOuter: return comb(earg, iarg, oarg)
        case _: return comb(oarg, iarg, earg)

# %% ../nbs/04_util.ipynb 37
def resolve_kwargs(
    okws: dict, ikws: dict | None = dict(), ekws: dict | None = dict(),
    order: ResolutionOrder = ResolutionOrder.default, 
    direction: Direction = Direction.rl
) -> dict:
    '''Resolve and combine keyword arguments from multiple sources according to specified order and direction.

    Parameters
    ----------
    okws : dict
        Outer scope keyword arguments.
    ikws : dict | None, optional
        Inner keyword arguments, by default an empty dict.
    ekws : dict | None, optional
        Extra keyword arguments, by default an empty dict.
    order : ResolutionOrder, optional
        The order in which to combine the keyword arguments, by default ResolutionOrder.default.
    direction : Direction, optional
        The direction in which to apply the keyword arguments, by default Direction.rl.

    Returns
    -------
    dict
        The combined keyword arguments.
    '''
    
    okws, ikws, ekws = map(lambda x: x or {}, (okws, ikws, ekws))
    match direction:
        case Direction.rl:
            match order:
                case ResolutionOrder.OuterInnerExtra: return {**ekws, **ikws, **okws}
                case ResolutionOrder.OuterExtraInner: return {**ikws, **ekws, **okws}
                case ResolutionOrder.InnerOuterExtra: return {**ekws, **okws, **ikws}
                case ResolutionOrder.InnerExtraOuter: return {**okws, **ekws, **ikws}
                case ResolutionOrder.ExtraOuterInner: return {**ikws, **okws, **ekws}
                case ResolutionOrder.ExtraInnerOuter: return {**okws, **ikws, **ekws}
                case _: return {**ikws, **ekws, **okws}
        case Direction.lr:
            match order:            
                case ResolutionOrder.OuterInnerExtra: return {**okws, **ikws, **ekws}
                case ResolutionOrder.OuterExtraInner: return {**okws, **ekws, **ikws}
                case ResolutionOrder.InnerOuterExtra: return {**ikws, **okws, **ekws}
                case ResolutionOrder.InnerExtraOuter: return {**ikws, **ekws, **okws}
                case ResolutionOrder.ExtraOuterInner: return {**ekws, **okws, **ikws}
                case ResolutionOrder.ExtraInnerOuter: return {**ekws, **ikws, **okws}
                case _: return {**okws, **ikws, **ekws}
        case _: 
            return {**ikws, **ekws, **okws}

# %% ../nbs/04_util.ipynb 39
def _spec_parameters(spec):
    '''Return the parameters held by `spec` as an iterable.

    A `Signature` (per ``issig``) yields ``spec.parameters.values()``, a
    name -> Parameter mapping (per ``isdict``) yields its values, and anything
    else is assumed to already be an iterable of parameters and is returned as-is.
    '''
    if issig(spec): return spec.parameters.values()
    if isdict(spec): return spec.values()
    return spec

def hasvarg(spec: Signature | dict[str, Parameter] | Iterable[Parameter]) -> bool:
    '''Check if `spec` has a VAR_POSITIONAL parameter (as judged by ``isvarg``).'''
    return any(isvarg(p) for p in _spec_parameters(spec))

def hasvkws(spec: Signature | dict[str, Parameter] | Iterable[Parameter]) -> bool:
    '''Check if `spec` has a VAR_KEYWORD parameter (as judged by ``isvkws``).'''
    return any(isvkws(p) for p in _spec_parameters(spec))

def hasvpok(spec: Signature | dict[str, Parameter] | Iterable[Parameter]) -> bool:
    '''Check if `spec` has any parameter for which ``isvpok`` holds.

    NOTE(review): presumably POSITIONAL_OR_KEYWORD parameters — confirm in `.grds`.
    '''
    return any(isvpok(p) for p in _spec_parameters(spec))

# %% ../nbs/04_util.ipynb 41
def parameter_as_sort_tuple(p, order: list | None = None) -> tuple[int, int, bool, bool]:
    kinds = list(ParamKind)
    idx = order.index(p.name) if (order and p.name in order) else -1
    return (kinds.index(p.kind), idx, (p.default != Empty), (p.default == None))
    
def sort_parameters(
    __parameters: dict[str, Parameter] | list[Parameter],
    __idealorder: list | None = None
) -> list[Parameter]:
    '''Return the parameters sorted by ``parameter_as_sort_tuple``.

    Accepts either a name -> Parameter mapping (per ``isdict``) or a list of
    parameters. `__idealorder` is forwarded as the preferred name order.
    '''
    if isdict(__parameters):
        candidates = list(__parameters.values())
    else:
        candidates = __parameters

    def rank(param):
        return parameter_as_sort_tuple(param, order=__idealorder)

    return sorted(candidates, key=rank)

# %% ../nbs/04_util.ipynb 43
def parameter_defaults(sig: Signature) -> dict:
    '''Map parameter names to their defaults, for parameters accepted by ``notempty``.

    NOTE(review): presumably ``notempty`` filters out parameters without a default.
    '''
    return {
        name: param.default
        for name, param in sig.parameters.items()
        if notempty(param)
    }

# %% ../nbs/04_util.ipynb 45
def unbound_args(
    bound: BoundArguments, 
    extra: tuple | None = None,
    var_only: bool = False
) -> tuple[FoundArguments, VArgs]:
    '''Extracts unbound arguments and vargs from a given 
    BoundArguments instance and extra arguments.

    Walks ``bound.signature.parameters`` in declaration order. Each parameter
    for which ``isvpos`` or ``isvpok`` holds is paired with the next value
    taken from the front of `extra`. Values left over once those parameters are
    exhausted are returned as the remaining varargs.

    Parameters
    ----------
    bound : BoundArguments
        The BoundArguments instance to extract from.
        
    extra : tuple, optional
        Extra arguments that were not bound during the initial binding.
        ``None`` short-circuits to ``((), ())``.
    var_only : bool, optional
        If true and the signature has no VAR_POSITIONAL parameter (per
        ``hasvarg``), return ``((), ())`` without pairing anything.

    Returns
    -------
    tuple[FoundArguments, VArgs]
        A tuple where the first element is a list of dictionaries with keys 'key', 'idx', and 'val' representing
        found arguments and their positions, and the second element is a tuple of remaining varargs.
        Note: the early-exit paths return tuples (``()`` / `extra` itself), while the
        main path returns two *lists*.
    '''
    if extra is None: 
        return (), ()
    
    if var_only and not hasvarg(bound.signature):
        return (), ()
    
    # No parameter satisfies `isvpok`: nothing to pair, all of `extra` is left over.
    if not hasvpok(bound.signature):
        return (), extra
    
    used, rest = list(), list(extra).copy()
    for arg, p in bound.signature.parameters.items():
        # NOTE(review): parameters already present in `bound.arguments` are not
        # skipped, so an extra value may be paired with an already-bound
        # parameter — confirm this is intended.
        if not (isvpos(p) or isvpok(p)): continue
        if len(rest) == 0: break # if there are no more arguments, done
        # `idx` is the position within `used`, not within the signature.
        val, idx = rest.pop(0), len(used) # get the next value and its index
        used.append(dict(key=arg, idx=idx, val=val))
    return used, rest

def unbound_kwargs(bound: BoundArguments) -> dict:
    '''Collect entries of ``bound.arguments`` whose names are not signature parameters.

    Parameters
    ----------
    bound : BoundArguments
        The BoundArguments instance to inspect.

    Returns
    -------
    dict
        Name -> value for every argument not declared in ``bound.signature``.
        Always empty when the signature has no VAR_KEYWORD parameter (per ``hasvkws``).
    '''
    if not hasvkws(bound.signature):
        return {}
    declared = bound.signature.parameters
    return {name: value for name, value in bound.arguments.items() if name not in declared}

# %% ../nbs/04_util.ipynb 47
def bind_unbound_args(
    bound: BoundArguments, 
    found: FoundArguments | None = None, 
    extra: VArgs | None = None
) -> BoundArguments:
    '''Binds unbound arguments and vargs to a BoundArguments instance.

    If `extra` is non-empty, the values already in `found` are placed in front
    of it and the combined sequence is re-paired from scratch with
    ``unbound_args``. Every resulting ``{'key': ..., 'val': ...}`` entry is then
    written into ``bound.arguments``. If the signature has a VAR_POSITIONAL
    parameter (per ``hasvarg``), any leftover values are appended to
    ``bound.arguments['args']``. `bound` is modified in place and returned.

    Parameters
    ----------
    bound : BoundArguments
        The BoundArguments instance to bind to.
    found : FoundArguments, optional
        A list of dictionaries (with at least 'key' and 'val') representing found arguments to bind.
    extra : VArgs, optional
        Extra varargs to bind.

    Returns
    -------
    BoundArguments
        The updated BoundArguments instance with unbound arguments and varargs bound.
    '''
    used, rest = found or (), extra or ()
    # `rest` can no longer be None here (`extra or ()`), so this is a non-empty check.
    if rest is not None and rest != ():
        # Re-pair from scratch: previously found values come first, then the extras.
        rest = tuple(p['val'] for p in used) + tuple(rest)
        used, rest = unbound_args(bound, rest)

    for p in used: bound.arguments.update({p['key']: p['val']})
    if not hasvarg(bound.signature): return bound
    # NOTE(review): leftovers go under the literal key 'args', not the signature's
    # actual `*name`. For e.g. `def f(*items)` this creates a separate 'args'
    # entry — presumably a library-wide convention; confirm.
    args = bound.arguments.get('args', ())
    bound.arguments.update(dict(args=args + tuple(rest)))
    return bound
    
def bind_unbound_kwargs(bound: BoundArguments, unbound: dict | None = None) -> BoundArguments:
    '''Binds unbound keyword arguments to a BoundArguments instance.

    The keyword arguments to move are `unbound` or, when it is empty/None,
    those found by ``unbound_kwargs(bound)``. If the signature accepts
    ``**kwargs`` (per ``hasvkws``), they are merged into
    ``bound.arguments['kwargs']`` and removed from the top level of
    ``bound.arguments``. Otherwise `bound` is returned untouched.

    Parameters
    ----------
    bound : BoundArguments
        The BoundArguments instance to bind to (modified in place).
    unbound : dict, optional
        A dictionary of unbound keyword arguments to bind.

    Returns
    -------
    BoundArguments
        The updated BoundArguments instance with unbound keyword arguments bound.
    '''
    unbound = unbound or unbound_kwargs(bound)
    if not hasvkws(bound.signature): 
        return bound
    # Build a new dict instead of `.update`-ing the stored one in place: the
    # dict under 'kwargs' may be owned by a caller (e.g. a fallback dict stored
    # there earlier) and must not be mutated behind its back. A missing or
    # None slot is treated as empty rather than crashing.
    existing = bound.arguments.get('kwargs') or {}
    bound.arguments['kwargs'] = {**existing, **unbound}
    for k in unbound: bound.arguments.pop(k, None)
    return bound

# %% ../nbs/04_util.ipynb 48
def handle_bound_methods(
    bound: BoundArguments, 
    __function: Callable | None = None, 
    __drop_cls: bool = True, 
    __drop_self: bool = True
) -> BoundArguments:
    '''Drop a leading ``cls`` / ``self`` entry from ``bound.arguments``.

    ``cls`` is removed when `__function` is a bound method (``ismethod``) or
    ``None``, `__drop_cls` is true, and ``cls`` is the first parameter of the
    signature or the first bound argument. ``self`` follows the same rule,
    except the guard is ``callable(__function)`` and the flag is `__drop_self`.

    Parameters
    ----------
    bound : BoundArguments
        The BoundArguments instance to modify in place.
    __function : Callable, optional
        The function used to decide whether the rules apply.
    __drop_cls : bool, optional
        Whether to drop `cls` parameter if present.
    __drop_self : bool, optional
        Whether to drop `self` parameter if present.

    Returns
    -------
    BoundArguments
        `bound`, with `self` / `cls` handled.
    '''
    # Both "firsts" are captured up front, before any entry is popped.
    first_param = nextnone(list(bound.signature.parameters.keys()))
    first_arg = nextnone(list(bound.arguments.keys()))

    checks = (
        (ismethod, 'cls', __drop_cls),
        (callable, 'self', __drop_self),
    )
    for is_kind, name, should_drop in checks:
        applies = is_kind(__function) or isnone(__function)
        leads = name in (first_arg, first_param)
        if applies and leads and should_drop:
            bound.arguments.pop(name, None)
    return bound

def handle_bound_variadic(
    bound: BoundArguments, 
    fallback_args: tuple | None = (), 
    fallback_kwds: dict | None = None
) -> BoundArguments:
    '''Ensure ``bound.arguments`` holds a tuple under 'args' and a dict under 'kwargs'.

    If the current 'args' entry is not a tuple (per ``istuple``), it is replaced
    by `fallback_args` (or ``()``). If the current 'kwargs' entry is not a dict
    (per ``isdict``), it is replaced by `fallback_kwds` (or a new ``{}``).
    Missing entries count as "not a tuple/dict", so both keys are added even
    when the signature declares no variadic parameters. That is presumably a
    library-wide convention — confirm.

    Parameters
    ----------
    bound : BoundArguments
        The BoundArguments instance to modify in place.
    fallback_args : tuple, optional
        The variadic arguments to restore if needed.
    fallback_kwds : dict, optional
        The keyword arguments to restore if needed. Stored by reference, not copied.

    Returns
    -------
    BoundArguments
        `bound`, with its 'args' / 'kwargs' entries restored where necessary.
    '''
    # Snapshot the current entries before any update.
    current_args = bound.arguments.get('args', None)
    current_kwds = bound.arguments.get('kwargs', None)
    # `None` default (rather than a shared `{}` literal) for the kwargs fallback.
    args = fallback_args or ()
    kwargs = fallback_kwds or {}
    if not istuple(current_args): bound.arguments.update(dict(args=args))
    if not isdict(current_kwds): bound.arguments.update(dict(kwargs=kwargs))
    return bound
