import inspect
import types
import warnings
from dataclasses import dataclass, fields, is_dataclass
from pathlib import Path
from typing import Any, Callable, Iterable, Mapping

from muutils.json_serialize.array import ArrayMode, serialize_array
from muutils.json_serialize.util import (
    ErrorMode,
    Hashableitem,
    JSONitem,
    MonoTuple,
    SerializationException,
    _recursive_hashify,
    isinstance_namedtuple,
    safe_getsource,
    string_as_lines,
    try_catch,
)

# pylint: disable=protected-access

# dunder attributes whose `str()` values are recorded by the "fallback" handler;
# order: __name__, __doc__, __module__, __class__, __dict__, __annotations__
SERIALIZER_SPECIAL_KEYS: MonoTuple[str] = tuple(
    f"__{attr_base}__"
    for attr_base in ("name", "doc", "module", "class", "dict", "annotations")
)

# name -> function applied to the object by the "fallback" handler.
# `str` and `dir` are used as-is; the remaining entries are wrapped in `try_catch`.
SERIALIZER_SPECIAL_FUNCS: dict[str, Callable] = dict(
    str=str,
    dir=dir,
    type=try_catch(lambda value: str(type(value).__name__)),
    repr=try_catch(lambda value: repr(value)),
    code=try_catch(lambda value: inspect.getsource(value)),
    sourcefile=try_catch(lambda value: inspect.getsourcefile(value)),
)

# `str(type(obj))` values for which the object is serialized simply as `str(obj)`
SERIALIZE_DIRECT_AS_STR: set[str] = {
    f"<class 'torch.{torch_type}'>" for torch_type in ("device", "dtype")
}

# location of the current object inside the top-level object being serialized:
# a tuple of keys / field names / list indices, extended by the container handlers
# as they recurse.
# NOTE(review): the Mapping handler appends the raw (unconverted) key, so entries can be
# any hashable key type in practice, not only str/int.
ObjectPath = MonoTuple[str | int]


@dataclass
class SerializerHandler:
    """a handler for a specific type of object

    handlers are tried in order by `JsonSerializer.json_serialize`; the first one whose
    `check` returns True has its `serialize_func` called.

    # Parameters:
        - `check : Callable[[JsonSerializer, Any, ObjectPath], bool]` takes a JsonSerializer, an object, and the current path, returns whether to use this handler
        - `serialize_func : Callable[[JsonSerializer, Any, ObjectPath], JSONitem]` takes a JsonSerializer, an object, and the current path, returns the serialized object
        - `uid : str` identifier for the handler (reported in serialization error messages)
        - `desc : str` description of the handler (optional, defaults to `"(no description)"`)
    """

    # (serializer, object, path) -> whether to use this handler
    check: Callable[["JsonSerializer", Any, ObjectPath], bool]
    # (serializer, object, path) -> serialized object
    serialize_func: Callable[["JsonSerializer", Any, ObjectPath], JSONitem]
    # identifier for the handler
    uid: str
    # optional description of how this serializer works
    desc: str = "(no description)"

    def serialize(self) -> dict:
        """serialize the handler info

        # Returns:
         - `dict` with the source code and docstring lines of `check` and `serialize_func`,
           plus `uid`, `source_pckg` and `__module__` (both read from `serialize_func`,
           `None` if absent), and `desc`
        """
        return {
            # get the code and doc of the check function
            "check": {
                "code": safe_getsource(self.check),
                "doc": string_as_lines(self.check.__doc__),
            },
            # get the code and doc of the serialize function
            "serialize_func": {
                "code": safe_getsource(self.serialize_func),
                "doc": string_as_lines(self.serialize_func.__doc__),
            },
            # get the uid, source_pckg, __module__, and desc
            "uid": str(self.uid),
            "source_pckg": getattr(self.serialize_func, "source_pckg", None),
            "__module__": getattr(self.serialize_func, "__module__", None),
            "desc": str(self.desc),
        }


# minimal handler chain: JSON-native scalars are passed through, Mappings become dicts
# with `str` keys, and lists/tuples become lists; containers recurse through
# `self.json_serialize`, extending `path` with the key or index of each child
BASE_HANDLERS: MonoTuple[SerializerHandler] = (
    SerializerHandler(
        uid="base types",
        check=lambda self, obj, path: isinstance(
            obj, (bool, int, float, str, types.NoneType)
        ),
        serialize_func=lambda self, obj, path: obj,
    ),
    SerializerHandler(
        uid="dictionaries",
        check=lambda self, obj, path: isinstance(obj, Mapping),
        serialize_func=lambda self, obj, path: dict(
            (str(key), self.json_serialize(val, (*path, key)))
            for key, val in obj.items()
        ),
    ),
    SerializerHandler(
        uid="(list, tuple) -> list",
        check=lambda self, obj, path: isinstance(obj, (list, tuple)),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(item, (*path, idx)) for idx, item in enumerate(obj)
        ],
    ),
)


def _serialize_override_serialize_func(
    self: "JsonSerializer", obj: Any, path: ObjectPath
) -> JSONitem:
    """serialize `obj` by delegating to its own `.serialize()` method

    `self` and `path` are accepted only to match the `SerializerHandler.serialize_func`
    signature; they are not used.

    # Returns:
     - whatever `obj.serialize()` returns
    """
    # NOTE: an earlier version called `type(obj)._register_self()` here (when present)
    # before serializing; that call is currently disabled.
    return obj.serialize()


# default handler chain: the base handlers, then handlers for objects with a `.serialize`
# method and for special types, then a generic Iterable handler, and finally a "fallback"
# handler that always matches, so the default chain always finds a handler
DEFAULT_HANDLERS: MonoTuple[SerializerHandler] = tuple(BASE_HANDLERS) + (
    SerializerHandler(
        # TODO: allow for custom serialization handler name
        check=lambda self, obj, path: hasattr(obj, "serialize")
        and callable(obj.serialize),
        serialize_func=_serialize_override_serialize_func,
        uid=".serialize override",
    ),
    SerializerHandler(
        # NOTE(review): standard namedtuples are tuple subclasses, so with this ordering the
        # "(list, tuple) -> list" base handler matches them first and they are emitted as
        # lists; this handler is only reached for objects the base handlers don't catch,
        # or when it is placed earlier (e.g. via `handlers_pre`)
        check=lambda self, obj, path: isinstance_namedtuple(obj),
        # forward `path` so errors inside the namedtuple report where they occurred
        serialize_func=lambda self, obj, path: self.json_serialize(
            dict(obj._asdict()), path
        ),
        uid="namedtuple -> dict",
    ),
    SerializerHandler(
        # `is_dataclass` is also True for dataclass *classes*; only handle instances
        check=lambda self, obj, path: is_dataclass(obj)
        and not isinstance(obj, type),
        # `fields()` excludes the ClassVar / InitVar pseudo-fields that
        # `__dataclass_fields__` contains; those are not (necessarily) instance attributes
        serialize_func=lambda self, obj, path: {
            fld.name: self.json_serialize(
                getattr(obj, fld.name), tuple(path) + (fld.name,)
            )
            for fld in fields(obj)
        },
        uid="dataclass -> dict",
    ),
    SerializerHandler(
        check=lambda self, obj, path: isinstance(obj, Path),
        serialize_func=lambda self, obj, path: obj.as_posix(),
        uid="path -> str",
    ),
    SerializerHandler(
        # types listed by their `str(type(obj))` so the library need not be imported
        check=lambda self, obj, path: str(type(obj)) in SERIALIZE_DIRECT_AS_STR,
        serialize_func=lambda self, obj, path: str(obj),
        uid="obj -> str(obj)",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'numpy.ndarray'>",
        serialize_func=lambda self, obj, path: serialize_array(self, obj, path=path),
        uid="numpy.ndarray",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj)) == "<class 'torch.Tensor'>",
        # detach from the autograd graph and move to CPU before array serialization
        serialize_func=lambda self, obj, path: serialize_array(
            self, obj.detach().cpu(), path=path
        ),
        uid="torch.Tensor",
    ),
    SerializerHandler(
        check=lambda self, obj, path: str(type(obj))
        == "<class 'pandas.core.frame.DataFrame'>",
        # NOTE(review): the record values are returned as-is, without recursing through
        # `self.json_serialize`
        serialize_func=lambda self, obj, path: obj.to_dict(orient="records"),
        uid="pandas.DataFrame",
    ),
    SerializerHandler(
        # set, list and tuple are all Iterable, so a single isinstance check covers them
        check=lambda self, obj, path: isinstance(obj, Iterable),
        serialize_func=lambda self, obj, path: [
            self.json_serialize(x, tuple(path) + (i,)) for i, x in enumerate(obj)
        ],
        uid="(set, list, tuple, Iterable) -> list",
    ),
    SerializerHandler(
        # always matches: records `str()` of selected dunder attributes plus the output
        # of each function in SERIALIZER_SPECIAL_FUNCS
        check=lambda self, obj, path: True,
        serialize_func=lambda self, obj, path: {
            **{k: str(getattr(obj, k, None)) for k in SERIALIZER_SPECIAL_KEYS},
            **{k: f(obj) for k, f in SERIALIZER_SPECIAL_FUNCS.items()},
        },
        uid="fallback",
    ),
)


class JsonSerializer:
    """Json serialization class (holds configs)

    serializes objects by trying each handler in `self.handlers` in order and using the
    first one whose `check` matches.
    """

    def __init__(
        self,
        *args,
        array_mode: ArrayMode = "array_list_meta",
        error_mode: ErrorMode = "except",
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[SerializerHandler] = DEFAULT_HANDLERS,
    ):
        """create a serializer; all configuration is keyword-only

        # Parameters:
         - `array_mode : ArrayMode` stored on the instance for handlers (e.g. array
           serialization) to read (defaults to `"array_list_meta"`)
         - `error_mode : ErrorMode` behavior when serializing an object fails: `"except"`
           raises `SerializationException`, `"warn"` emits a warning and returns
           `repr(obj)`, any other value silently returns `repr(obj)` (defaults to `"except"`)
         - `handlers_pre : MonoTuple[SerializerHandler]` handlers tried before the defaults
           (defaults to `()`)
         - `handlers_default : MonoTuple[SerializerHandler]` handlers tried after
           `handlers_pre` (defaults to `DEFAULT_HANDLERS`)

        # Raises:
         - `ValueError` if any positional arguments are given
        """
        if len(args) > 0:
            raise ValueError(
                f"JsonSerializer takes no positional arguments!\n{args = }"
            )

        self.array_mode: ArrayMode = array_mode
        self.error_mode: ErrorMode = error_mode
        # join up the handlers
        self.handlers: MonoTuple[SerializerHandler] = tuple(handlers_pre) + tuple(
            handlers_default
        )

    @staticmethod
    def _safe_repr(obj: Any) -> str:
        """`repr(obj)`, or a placeholder string if `repr` itself raises"""
        try:
            return repr(obj)
        except Exception as repr_err:
            return (
                f"<repr failed for {type(obj).__name__} object: "
                f"{type(repr_err).__name__}>"
            )

    def json_serialize(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
    ) -> JSONitem:
        """serialize `obj` with the first handler whose `check` returns True

        # Parameters:
         - `obj : Any` object to serialize
         - `path : ObjectPath` location of `obj` within the top-level object, used by
           recursive handlers and in error messages (defaults to `()`)

        # Returns:
         - `JSONitem` the serialized object; on failure in non-`"except"` error modes,
           `repr(obj)` instead

        # Raises:
         - `SerializationException` if serialization fails and `error_mode == "except"`
        """
        # the handler being tried when an error occurs; stays None for an empty handler chain
        handler: SerializerHandler | None = None
        try:
            for handler in self.handlers:
                if handler.check(self, obj, path):
                    return handler.serialize_func(self, obj, path)

            raise ValueError(f"no handler found for object with {type(obj) = }")

        except Exception as e:
            handler_uid: str = handler.uid if handler is not None else "(no handlers)"
            # computed defensively: a failing `__repr__` must not mask the original error
            obj_str: str = self._safe_repr(obj)
            if self.error_mode == "except":
                if len(obj_str) > 1000:
                    obj_str = obj_str[:1000] + "..."
                raise SerializationException(
                    f"error serializing at {path = } with last handler: '{handler_uid}'\nfrom: {e}\nobj: {obj_str}"
                ) from e
            elif self.error_mode == "warn":
                warnings.warn(
                    f"error serializing at {path = }, will return as string\nobj = {obj_str}\nexception = {e}"
                )

            return obj_str

    def hashify(
        self,
        obj: Any,
        path: ObjectPath = tuple(),
        force: bool = True,
    ) -> Hashableitem:
        """try to turn any object into something hashable

        serializes `obj` with `json_serialize`, then passes the result (and `force`) to
        `_recursive_hashify`, which turns dicts and lists into tuples.
        """
        data = self.json_serialize(obj, path=path)

        # recursive hashify, turning dicts and lists into tuples
        return _recursive_hashify(data, force=force)


def json_serialize(obj: Any, path: ObjectPath = tuple()) -> JSONitem:
    """serialize `obj` to a json-serializable object using a default-configured `JsonSerializer`

    # Parameters:
     - `obj : Any` object to serialize
     - `path : ObjectPath` starting location of `obj` (defaults to `()`)

    # Returns:
     - `JSONitem` the serialized object
    """
    default_serializer: JsonSerializer = JsonSerializer()
    return default_serializer.json_serialize(obj, path=path)
