# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/096_Meta.ipynb.

# %% auto 0
__all__ = [
    "TorF",
    "T",
    "patch",
    "combine_params",
    "delegates",
    "use_parameters_of",
    "filter_using_signature",
    "export",
    "classcontextmanager",
]

# %% ../../nbs/096_Meta.ipynb 1
import builtins
import copy as cp
import functools
import inspect
import sys
import types
from functools import partial, wraps
from types import *
from typing import *

import docstring_parser

# %% ../../nbs/096_Meta.ipynb 4
def test_eq(a: Any, b: Any) -> None:
    "Check that `a == b`; raise `ValueError` showing both values otherwise."
    mismatch = a != b
    if mismatch:
        raise ValueError(f"{a} != {b}")

# %% ../../nbs/096_Meta.ipynb 6
F = TypeVar("F", bound=Callable[..., Any])


def copy_func(f: Union[F, FunctionType]) -> Union[F, FunctionType]:
    """Copy a non-builtin function (NB `copy.copy` does not work for this)

    Plain Python functions get a new function object that shares code, globals and
    closure with `f`; any other object is handed to `copy.copy`.

    Args:
        f: function (or other object) to copy

    Returns:
        The copy, carrying over `f`'s defaults, keyword-only defaults, `__dict__`,
        annotations, `__qualname__`, `__doc__` and `__module__`.
    """
    if not isinstance(f, FunctionType):
        return cp.copy(f)
    fn = FunctionType(
        f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__
    )
    # Give the copy its own keyword-defaults dict: sharing it would make a change to one
    # function's keyword-only defaults silently affect the other.
    fn.__kwdefaults__ = None if f.__kwdefaults__ is None else dict(f.__kwdefaults__)
    fn.__dict__.update(f.__dict__)
    fn.__annotations__.update(f.__annotations__)
    fn.__qualname__ = f.__qualname__
    fn.__doc__ = f.__doc__
    # `FunctionType` derives `__module__` from `f.__globals__["__name__"]`; keep a module
    # that was explicitly reassigned on `f` (e.g. by `export`) instead.
    fn.__module__ = f.__module__
    return fn

# %% ../../nbs/096_Meta.ipynb 11
def patch_to(
    cls: Union[Type, Iterable[Type]], as_prop: bool = False, cls_method: bool = False
) -> Callable[[F], F]:
    """Decorator: add `f` to `cls` (a class, or a tuple/list of classes)

    Each target class receives its own copy of `f` (via `copy_func`) whose
    `__qualname__` is `"<ClassName>.<f.__name__>"`.

    Args:
        cls: target class, or a tuple/list of target classes
        as_prop: install the copy as a `property`
        cls_method: install the copy as a `classmethod` (takes precedence over `as_prop`)

    Returns:
        A decorator. Instead of `f`, the decorator returns whatever is bound to
        `f.__name__` in this module's globals or in `builtins` (`None` if neither has it).
    """
    if not isinstance(cls, (tuple, list)):
        cls = (cls,)  # type: ignore

    def _inner(f: F) -> F:
        # Read the name before the loop so an empty `cls` does not leave `nm` unbound.
        nm = f.__name__
        for c_ in cls:
            nf = copy_func(f)
            # Copy `functools.WRAPPER_ASSIGNMENTS` by hand rather than calling
            # `functools.update_wrapper`; per the original note this matters when a
            # patched function is passed to `Pipeline`.
            for o in functools.WRAPPER_ASSIGNMENTS:
                setattr(nf, o, getattr(f, o))
            nf.__qualname__ = f"{c_.__name__}.{nm}"
            if cls_method:
                # `classmethod` binds to the class the attribute is looked up on, so a
                # subclass receives itself as `cls`; a `MethodType` bound to `c_` would
                # always pass `c_`.
                setattr(c_, nm, classmethod(nf))
            else:
                setattr(c_, nm, property(nf) if as_prop else nf)
        # Avoid clobbering existing functions
        # nosemgrep
        existing_func = globals().get(nm, builtins.__dict__.get(nm, None))
        return existing_func  # type: ignore

    return _inner

# %% ../../nbs/096_Meta.ipynb 22
def eval_type(
    t: Sequence, glb: Optional[Dict[str, Any]], loc: Optional[Mapping[str, object]]
) -> Any:
    "`eval` a type or collection of types, if needed, for annotations in py3.10+"
    if isinstance(t, str):
        if "|" in t:
            return Union[eval_type(tuple(t.split("|")), glb, loc)]
        # nosemgrep
        return eval(t, glb, loc)  # nosec B307:blacklist
    if isinstance(t, (tuple, list)):
        return type(t)([eval_type(c, glb, loc) for c in t])
    return t


def union2tuple(t) -> Tuple[Any, ...]:  # type: ignore
    if getattr(t, "__origin__", None) is Union:
        return t.__args__  # type: ignore

    if sys.version_info >= (3, 10):
        if isinstance(t, UnionType):
            return t.__args__

    return t  # type: ignore


def get_annotations_ex(
    obj: Union[FunctionType, Type, Callable[..., Any]],
    *,
    globals: Optional[Dict[str, Any]] = None,
    locals: Optional[Dict[str, Any]] = None,
) -> Tuple[Dict[str, Any], Union[Any, Dict[str, Any], None], Dict[str, Any]]:
    """Backport of py3.10 `get_annotations` that returns globals/locals

    Annotations are returned as stored (string annotations are not evaluated), together
    with the namespaces in which they would be evaluated.

    Args:
        obj: a class, module or callable
        globals: if given, returned instead of the globals inferred from `obj`
        locals: if given, returned instead of the locals inferred from `obj`

    Returns:
        `(annotations, globals, locals)`. The first item is a fresh dict of `obj`'s own
        annotations. The second is the globals of the class's defining module, of the
        module itself, or of the innermost unwrapped callable (following `__wrapped__`
        and `functools.partial`). The third is a copy of the class namespace for classes,
        otherwise `None`.

    Raises:
        TypeError: if `obj` is not a module, class or callable
        ValueError: if `obj.__annotations__` is neither a dict nor `None`
    """
    if isinstance(obj, type):
        obj_dict = getattr(obj, "__dict__", None)
        if obj_dict and hasattr(obj_dict, "get"):
            ann = obj_dict.get("__annotations__", None)
            if isinstance(ann, types.GetSetDescriptorType):
                ann = None
        else:
            ann = None
        if ann is None and sys.version_info >= (3, 14):
            # On 3.14+ (PEP 649/749) class annotations are evaluated lazily and may be
            # absent from the class `__dict__`. The `type.__annotations__` getter
            # materializes them and, since 3.10, returns only the class's own annotations.
            ann = getattr(obj, "__annotations__", None)

        obj_globals = None
        module_name = getattr(obj, "__module__", None)
        if module_name:
            module = sys.modules.get(module_name, None)
            if module:
                obj_globals = getattr(module, "__dict__", None)
        obj_locals = dict(vars(obj))
        unwrap = obj
    elif isinstance(obj, types.ModuleType):
        ann = getattr(obj, "__annotations__", None)
        obj_globals = getattr(obj, "__dict__")
        obj_locals, unwrap = None, None
    elif callable(obj):
        ann = getattr(obj, "__annotations__", None)
        obj_globals = getattr(obj, "__globals__", None)
        obj_locals, unwrap = None, obj  # type: ignore
    else:
        raise TypeError(f"{obj!r} is not a module, class, or callable.")

    if ann is None:
        ann = {}
    if not isinstance(ann, dict):
        raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")

    if unwrap is not None:
        # Peel off decorator wrappers and partials so the globals come from the
        # function whose source actually holds the annotations.
        while True:
            if hasattr(unwrap, "__wrapped__"):
                unwrap = unwrap.__wrapped__
                continue
            if isinstance(unwrap, functools.partial):
                unwrap = unwrap.func  # type: ignore
                continue
            break
        if hasattr(unwrap, "__globals__"):
            obj_globals = unwrap.__globals__

    if globals is None:
        globals = obj_globals
    if locals is None:
        locals = obj_locals

    return dict(ann), globals, locals  # type: ignore

# %% ../../nbs/096_Meta.ipynb 23
def patch(  # type: ignore
    f: Optional[F] = None, *, as_prop: bool = False, cls_method: bool = False
):
    """Decorator: add `f` to the class named by its first parameter's type annotation

    Usable bare (`@patch`) or with options (`@patch(as_prop=True)`). The annotation may
    be a class, a string to evaluate, or a union (`A | B` / `Union[A, B]`), in which
    case `f` is added to every member.

    Args:
        f: function to patch in; `None` when `patch` is called with options only
        as_prop: install `f` as a property
        cls_method: install `f` as a class method; the target class is then read from
            the annotation of the parameter named `cls`

    Raises:
        TypeError: if the parameter that names the target class is not annotated
    """
    if f is None:
        return partial(patch, as_prop=as_prop, cls_method=cls_method)
    ann, glb, loc = get_annotations_ex(f)
    if cls_method:
        key: Optional[str] = "cls"
    else:
        # Use the *first parameter's* annotation. Taking the first entry of `ann`
        # would pick a later parameter (or the return type) whenever the first
        # parameter is unannotated.
        params = list(inspect.signature(f).parameters)
        key = params[0] if params else None
    if key is None or key not in ann:
        which = "`cls`" if cls_method else "first"
        raise TypeError(
            f"@patch needs a type annotation on the {which} parameter of "
            f"{getattr(f, '__qualname__', f)!r} naming the class to patch"
        )
    cls = union2tuple(eval_type(ann[key], glb, loc))
    return patch_to(cls, as_prop=as_prop, cls_method=cls_method)(f)

# %% ../../nbs/096_Meta.ipynb 35
def _delegates_without_docs(
    to: Optional[F] = None,  # Delegatee
    keep: bool = False,  # Keep `kwargs` in decorated function?
    but: Optional[List[str]] = None,  # Exclude these parameters from signature
) -> Callable[[F], F]:
    """Decorator: replace `**kwargs` in signature with params from `to`

    Every parameter of `to` that has a default, is not already in the decorated
    function's signature and is not listed in `but` is appended as keyword-only. When
    `to` is `None`, the decorated object is treated as a class whose `__init__`
    delegates to its base class's `__init__`. Only `__signature__` and
    `__annotations__` are rewritten (and `__delwrap__` is set unless `keep`); runtime
    behaviour is untouched.
    """
    if but is None:
        but = []

    def _f(f: F) -> F:
        if to is None:
            to_f, from_f = f.__base__.__init__, f.__init__  # type: ignore
        else:
            to_f, from_f = to.__init__ if isinstance(to, type) else to, f  # type: ignore
        # Unwrap bound/class methods to the underlying function objects.
        from_f = getattr(from_f, "__func__", from_f)
        to_f = getattr(to_f, "__func__", to_f)
        # Already delegated (without `keep`): leave it alone.
        if hasattr(from_f, "__delwrap__"):
            return f
        sig = inspect.signature(from_f)
        sigd = dict(sig.parameters)
        # Locate the `**kwargs`-style parameter by kind, whatever it is called.
        var_kw = next(
            (p for p in sigd.values() if p.kind is inspect.Parameter.VAR_KEYWORD), None
        )
        if var_kw is not None:
            del sigd[var_kw.name]
        new_params = {
            name: p.replace(kind=inspect.Parameter.KEYWORD_ONLY)
            for name, p in inspect.signature(to_f).parameters.items()
            # `is not`: `!=` would call the default's own `__ne__`, whose result may not
            # be usable as a bool (e.g. array-like defaults).
            if p.default is not inspect.Parameter.empty
            and name not in sigd
            and name not in but  # type: ignore
        }
        # Only copy annotations of the parameters added above. Filtering on `sigd`
        # alone would also import `to`'s `return` annotation and the annotations of
        # its required parameters.
        new_anno = {
            name: a
            for name, a in getattr(to_f, "__annotations__", {}).items()
            if name in new_params
        }
        sigd.update(new_params)
        if keep and var_kw is not None:
            sigd[var_kw.name] = var_kw
        else:
            from_f.__delwrap__ = to_f
        from_f.__signature__ = sig.replace(parameters=list(sigd.values()))
        if hasattr(from_f, "__annotations__"):
            # Build a new dict instead of updating in place: a wrapper created with
            # `functools.wraps` can share its annotations dict with the wrapped function.
            from_f.__annotations__ = {**from_f.__annotations__, **new_anno}
        return f

    return _f

# %% ../../nbs/096_Meta.ipynb 45
def _format_args(xs: List[docstring_parser.DocstringParam]) -> str:
    "Render docstring parameters as an `Args:` header followed by one ` - ` bullet per parameter."
    entries = [f"{p.arg_name} ({p.type_name}): {p.description}" for p in xs]
    separator = "\n - "
    return "\nArgs:" + separator + separator.join(entries)


def combine_params(f: F, o: Union[Type, Callable[..., Any]]) -> F:
    """Combines docstring arguments of a function and another object or function

    Parameters documented on `f` are kept as they are. Parameters documented only on
    `o` are appended after them, and `f.__doc__` is rewritten in Google style.

    Args:
        f: destination function where the combined arguments will end up
        o: source object or function from which the arguments are taken

    Returns:
        Function `f` (modified in place) with its docstring augmented by the arguments of `o`
    """
    src_params = docstring_parser.parse_from_object(o).params
    docs = docstring_parser.parse_from_object(f)
    dst_param_names = {p.arg_name for p in docs.params}

    combined_params = docs.params + [
        p for p in src_params if p.arg_name not in dst_param_names
    ]

    # Keep the non-parameter meta entries and replace the parameter entries with the
    # combined list.
    docs.meta = [
        m for m in docs.meta if not isinstance(m, docstring_parser.DocstringParam)
    ] + combined_params  # type: ignore

    f.__doc__ = docstring_parser.compose(
        docs, style=docstring_parser.DocstringStyle.GOOGLE
    )
    return f

# %% ../../nbs/096_Meta.ipynb 47
def delegates(
    o: Union[Type, Callable[..., Any]],
    keep: bool = False,
    but: Optional[List[str]] = None,
) -> Callable[[F], F]:
    """Delegates keyword arguments from `o` to the function the decorator is applied to

    The decorated function is wrapped in a pass-through function. The wrapper's signature
    is rewritten by `_delegates_without_docs`, and `o`'s documented arguments are merged
    into its docstring by `combine_params`.

    Args:
        o: object (class or function) with default kwargs
        keep: Keep `kwargs` in decorated function?
        but: argument names not to include
    """

    def _inner(f: F, keep: bool = keep, but: Optional[List[str]] = but) -> F:
        @wraps(f)
        def _f(*args: Any, **kwargs: Any) -> Any:
            return f(*args, **kwargs)

        # Signature first, then docstring - the same order as the stacked decorators.
        with_signature = _delegates_without_docs(o, keep=keep, but=but)(_f)  # type: ignore
        return combine_params(with_signature, o)  # type: ignore

    return _inner

# %% ../../nbs/096_Meta.ipynb 64
def use_parameters_of(
    o: Union[Type, Callable[..., Any]], **kwargs: Any
) -> Dict[str, Any]:
    """Restrict parameters passed as keyword arguments to parameters from the signature of ``o``

    Only parameters that can be bound by keyword are kept, i.e. positional-or-keyword
    and keyword-only ones. Positional-only parameters and the names of ``*args`` /
    ``**kwargs`` are dropped, because passing them by name does not bind them.

    Args:
        o: object or callable whose signature is used for restricting keyword arguments
        kwargs: keyword arguments

    Returns:
        restricted keyword arguments

    """
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    allowed_keys = {
        name
        for name, p in inspect.signature(o).parameters.items()
        if p.kind in keyword_kinds
    }
    return {k: v for k, v in kwargs.items() if k in allowed_keys}

# %% ../../nbs/096_Meta.ipynb 66
def filter_using_signature(f: Callable, **kwargs: Any) -> Dict[str, Any]:
    """Keep only the keyword arguments that ``f`` can accept by name.

    Args:
        f: callable whose signature is used for filtering
        kwargs: candidate keyword arguments

    Returns:
        The subset of ``kwargs`` whose keys name a positional-or-keyword or keyword-only
        parameter of ``f``. Positional-only parameters and the names of ``*args`` /
        ``**kwargs`` are dropped, because passing them by name does not bind them.
    """
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    param_names = {
        name
        for name, p in inspect.signature(f).parameters.items()
        if p.kind in keyword_kinds
    }
    return {k: v for k, v in kwargs.items() if k in param_names}

# %% ../../nbs/096_Meta.ipynb 68
TorF = TypeVar("TorF", Type, Callable[..., Any])


def export(module_name: str) -> Callable[[TorF], TorF]:
    """Decorator factory: make a class or function report `module_name` as its module.

    Args:
        module_name: value assigned to the decorated object's `__module__`

    Returns:
        A decorator that sets `__module__` on its argument and returns that same object.
    """

    def _inner(o: TorF, module_name: str = module_name) -> TorF:
        setattr(o, "__module__", module_name)
        return o

    return _inner

# %% ../../nbs/096_Meta.ipynb 71
T = TypeVar("T")


def classcontextmanager(name: str = "lifecycle") -> Callable[[Type[T]], Type[T]]:
    """Class decorator: make instances usable in `with` through the method called `name`.

    `with obj:` calls `getattr(obj, name)()`, enters the returned context manager and
    yields what its `__enter__` returns. Leaving the block calls that manager's
    `__exit__` and returns its result, so the manager can suppress exceptions. Entered
    managers are kept on a per-instance stack (`_lifecycle_ctx`), so nested `with obj:`
    blocks pair up correctly.

    Args:
        name: name of the method that returns a context manager

    Raises:
        ValueError: (raised by the returned decorator) if the class has no attribute `name`
    """

    def _classcontextmanager(cls: Type[T], name: str = name) -> Type[T]:
        if not hasattr(cls, name):
            raise ValueError(
                f"{cls.__name__} has no attribute {name!r} to use as a context manager"
            )

        @patch
        def __enter__(self: cls) -> Any:  # type: ignore
            ctx = getattr(self, name)()
            entered = ctx.__enter__()
            # Push only after a successful `__enter__`: if entering raises, the `with`
            # statement never calls `__exit__`, so the manager must not stay stacked.
            if not hasattr(self, "_lifecycle_ctx"):
                self._lifecycle_ctx = []  # type: ignore
            self._lifecycle_ctx.append(ctx)  # type: ignore
            return entered

        @patch
        def __exit__(self: cls, *args: Any) -> Optional[bool]:  # type: ignore
            # Return the inner manager's result: a truthy value tells the `with`
            # statement to suppress the exception, which it would otherwise lose.
            return self._lifecycle_ctx.pop(-1).__exit__(*args)  # type: ignore

        return cls

    return _classcontextmanager

# %% ../../nbs/096_Meta.ipynb 74
def _get_default_kwargs_from_sig(f: Callable[..., Any], **kwargs: Any) -> Dict[str, Any]:
    """
    Get the default values of `f`'s parameters, overridden/extended by `kwargs`

    Args:
        f: Function to extract default values from
        kwargs: values that replace (or are added to) the extracted defaults

    Returns:
        Dict mapping every parameter of `f` that has a default to that default, updated
        with `kwargs`
    """
    defaults = {
        k: v.default
        for k, v in inspect.signature(f).parameters.items()
        # `is not` avoids calling the default's own `__ne__`, whose result may not be
        # usable as a bool (e.g. array-like defaults).
        if v.default is not inspect.Parameter.empty
    }
    defaults.update(kwargs)
    return defaults
