import functools
import inspect
import logging
import os
import re
from importlib.metadata import version
from typing import Any, Optional

import httpx
from rich.logging import RichHandler
from rich.pretty import pprint

# Version of the installed "datalab-api" distribution, resolved at import time;
# it is advertised in the client's User-Agent header. importlib.metadata.version
# raises PackageNotFoundError here if the distribution is not installed.
__version__ = version("datalab-api")

# Public names exported by `from <module> import *`.
__all__ = ("__version__", "BaseDatalabClient")


def pretty_displayer(method):
    """Decorate an instance method so callers may pass ``display=True`` to show its result.

    The ``display`` keyword argument is consumed by the wrapper and never forwarded
    to ``method``. When it is truthy, every entry of ``result["blocks_obj"]`` that
    carries a ``"bokeh_plot_data"`` key is rendered with ``bokeh_from_json``, and the
    whole result is then printed with ``rich.pretty.pprint``. The original return
    value is always passed back unchanged.
    """

    @functools.wraps(method)
    def rich_wrapper(self, *args, **kwargs):
        should_display = kwargs.pop("display", False)
        output = method(self, *args, **kwargs)
        if not should_display:
            return output

        if isinstance(output, dict) and "blocks_obj" in output:
            block_map: dict[str, dict] = output["blocks_obj"]
            plottable = (blk for blk in block_map.values() if "bokeh_plot_data" in blk)
            for blk in plottable:
                bokeh_from_json(blk)
        pprint(output, max_length=None, max_string=100, max_depth=3)
        return output

    return rich_wrapper


class AutoPrettyPrint(type):
    """Metaclass that wraps the plain methods of a class with ``pretty_displayer``.

    Every non-dunder *function* defined in the class body is replaced by its
    ``pretty_displayer`` wrapper, which gives it an optional ``display=`` keyword.
    ``staticmethod``/``classmethod`` objects, properties, nested classes and other
    callable objects are left untouched. ``pretty_displayer`` assumes an
    instance-method signature ``(self, *args, **kwargs)``, so wrapping anything
    else would bind the instance as an unexpected first argument.
    """

    def __new__(cls, name, bases, dct):
        # Iterate over a snapshot so the namespace can be updated safely in the loop.
        for attr, value in list(dct.items()):
            # `callable()` is not a sufficient test: since Python 3.10 `staticmethod`
            # objects are callable too, and wrapping one (or a nested class) turns it
            # into a bound method that receives `self` as its first argument.
            if inspect.isfunction(value) and not attr.startswith("__"):
                dct[attr] = pretty_displayer(value)
        return super().__new__(cls, name, bases, dct)


def bokeh_from_json(block_data):
    """Load a serialized Bokeh document and show its first root.

    ``block_data`` may either be a block that carries a ``"bokeh_plot_data"`` entry,
    or that plot payload itself. In both cases the payload's ``"doc"`` entry replaces
    the contents of Bokeh's current document, and the document's first root is passed
    to ``bokeh.plotting.show``. Bokeh is imported inside the function, so it is only
    needed when this function is actually called.
    """
    from bokeh.io import curdoc
    from bokeh.plotting import show

    plot_payload = (
        block_data["bokeh_plot_data"] if "bokeh_plot_data" in block_data else block_data
    )
    curdoc().replace_with_json(plot_payload["doc"])
    first_root = curdoc().roots[0]
    show(first_root)


def _parse_version_tuple(version_string) -> Optional[tuple[int, int, int]]:
    """Extract a ``(major, minor, patch)`` tuple from a version string.

    An optional leading ``v`` and any trailing pre-release/local segment are ignored,
    e.g. ``"0.2.0"``, ``"v0.2.0"`` and ``"0.2.0rc1"`` all give ``(0, 2, 0)``.

    Returns:
        The parsed tuple, or ``None`` if the value is not a string starting with
        three dot-separated integers.
    """
    if not isinstance(version_string, str):
        return None
    match = re.match(r"v?(\d+)\.(\d+)\.(\d+)", version_string.strip())
    if match is None:
        return None
    major, minor, patch = (int(group) for group in match.groups())
    return major, minor, patch


class BaseDatalabClient(metaclass=AutoPrettyPrint):
    """A base class that implements some of the shared/logistical functionality
    (hopefully) common to all Datalab clients.

    Mainly used to keep the namespace of the 'real' client classes clean and
    readable by users.
    """

    _api_key: Optional[str] = None
    _session: Optional[httpx.Client] = None
    # Class-level default only: `__init__` gives every instance its own copy so that
    # per-client headers (User-Agent, API key) are never shared between clients.
    _headers: dict[str, str] = {}

    bad_server_versions: Optional[tuple[tuple[int, int, int], ...]] = ((0, 2, 0),)
    """Any known server versions that are not supported by this client."""

    min_server_version: tuple[int, int, int] = (0, 1, 0)
    """The minimum supported server version that this client supports."""

    def __init__(self, datalab_api_url: str, log_level: str = "WARNING"):
        """Creates an authenticated client.

        An API key is required to authenticate requests. The client will attempt to load it from a
        series of environment variables, `DATALAB_API_KEY` and prefixed versions for the given
        requested instance (e.g., `PUBLIC_DATALAB_API_KEY` for the public deployment
        which has prefix `public`).

        Parameters:
            datalab_api_url: The URL of the Datalab API. If no scheme is given,
                `https://` is prepended.
                TODO: If the URL of a datalab *UI* is provided, a request will be made to attempt
                to resolve the underlying API URL (e.g., `https://public.datalab.odbx.science`
                will 'redirect' to `https://public.api.odbx.science`).
            log_level: The logging level to use for the client. Defaults to "WARNING".

        Raises:
            ValueError: If no API URL is provided, or no API key can be found in the
                environment.

        """

        self.datalab_api_url = datalab_api_url
        if not self.datalab_api_url:
            raise ValueError("No Datalab API URL provided.")
        if not self.datalab_api_url.startswith("http"):
            self.datalab_api_url = f"https://{self.datalab_api_url}"
        logging.basicConfig(level=log_level, handlers=[RichHandler()])
        self.log = logging.getLogger(__name__)

        self._http_client = httpx.Client
        # Copy the (class-level) mapping before mutating it; otherwise every client
        # instance would write into, and share, the same dict.
        self._headers = dict(self._headers)
        self._headers["User-Agent"] = f"Datalab Python API/{__version__}"

        info_json = self.get_info()
        attributes = info_json["data"]["attributes"]

        self._datalab_api_versions: list[str] = attributes["available_api_versions"]
        self._datalab_server_version: str = attributes["server_version"]
        self._datalab_instance_prefix: Optional[str] = attributes.get("identifier_prefix")

        self._find_api_key()

    def get_info(self) -> dict[str, Any]:
        """Fetch the server's info document; must be implemented by concrete clients.

        `__init__` reads `["data"]["attributes"]` from the returned mapping.
        """
        raise NotImplementedError

    @property
    def session(self) -> httpx.Client:
        """The HTTP client used for requests, built lazily from the current headers.

        The client is cached on the instance so that requests share one connection
        pool and `__exit__` can close it. A fresh client is built when none exists
        yet or when the cached one has been closed (e.g. by a caller that used it
        as a context manager).
        """
        if self._session is None or getattr(self._session, "is_closed", False):
            self._session = self._http_client(headers=self.headers)
        return self._session

    @property
    def headers(self):
        """The (mutable) headers dict used when building new HTTP clients."""
        return self._headers

    def _version_negotiation(self):
        """Check whether this client is expected to work with this instance.

        Selects the first advertised API version whose major/minor matches
        `min_api_version`, then rejects known-bad or too-old server versions.

        Raises:
            RuntimeError: If the server version is not supported or if no supported API versions are found.

        """

        for available_api_version in sorted(self._datalab_api_versions):
            major, minor, _ = (int(part) for part in available_api_version.split("."))
            # NOTE(review): `min_api_version` is not declared on this base class;
            # presumably concrete subclasses define it — confirm.
            if major == self.min_api_version[0] and minor == self.min_api_version[1]:
                self._selected_api_version = available_api_version
                break
        else:
            raise RuntimeError(f"No supported API versions found in {self._datalab_api_versions=}")

        # `bad_server_versions` holds integer tuples while the server reports a string,
        # so compare the parsed tuple (and the raw string, for string-valued overrides).
        bad_versions = self.bad_server_versions or ()
        parsed_server_version = _parse_version_tuple(self._datalab_server_version)
        if self._datalab_server_version in bad_versions or (
            parsed_server_version is not None and parsed_server_version in bad_versions
        ):
            raise RuntimeError(
                f"Server version {self._datalab_server_version} is not supported by this client."
            )

        # Unparseable versions (e.g. branch names) cannot be ordered, so they are allowed.
        if (
            parsed_server_version is not None
            and self.min_server_version
            and parsed_server_version < tuple(self.min_server_version)
        ):
            minimum = ".".join(str(part) for part in self.min_server_version)
            raise RuntimeError(
                f"Server version {self._datalab_server_version} is older than the minimum "
                f"version {minimum} supported by this client."
            )

    @property
    def api_key(self) -> str:
        """The API key used to authenticate requests to the Datalab API, passed
        as the `DATALAB-API-KEY` HTTP header.

        This can be retrieved by an authenticated user with the `/get-api-key`
        endpoint of a Datalab API.

        """
        if self._api_key is not None:
            return self._api_key
        return self._find_api_key()

    def _find_api_key(self) -> str:
        """Checks various environment variables for an API key and sets the value in the
        session headers.

        The instance-specific variable (`<PREFIX>_DATALAB_API_KEY`) is tried first,
        then `DATALAB_API_KEY`. Surrounding quotes are stripped. Unset and empty
        values both count as missing.

        Returns:
            The API key.

        Raises:
            ValueError: If no non-empty API key is found.
        """
        if self._api_key is None:
            key_env_var = "DATALAB_API_KEY"

            api_key: Optional[str] = None

            # probe the prefixed environment variable first
            if self._datalab_instance_prefix is not None:
                api_key = os.getenv(f"{self._datalab_instance_prefix.upper()}_{key_env_var}")

            # An unset *or empty* prefixed variable falls back to the generic one.
            if not api_key:
                api_key = os.getenv(key_env_var)

            # Remove single and double quotes around API key if present
            if api_key is not None:
                api_key = api_key.strip("'").strip('"')

            if not api_key:
                raise ValueError(
                    f"No API key found in environment variables {key_env_var}/<prefix>_{key_env_var}."
                )

            self._api_key = api_key

            # The cached client was built with the old headers; drop it so that the next
            # access to `session` builds a client that sends the API key.
            if self._session is not None:
                try:
                    self._session.close()
                except Exception:
                    # Best-effort cleanup: failing to close a stale client must not block
                    # authentication, but keep a trace for debugging.
                    logging.getLogger(__name__).debug(
                        "Failed to close stale HTTP session", exc_info=True
                    )
                finally:
                    self._session = None
            self._headers["DATALAB-API-KEY"] = self.api_key

        return self.api_key

    def __enter__(self) -> "BaseDatalabClient":
        return self

    def __exit__(self, *_):
        """Close the cached HTTP client (if any) and forget it, so later use rebuilds one."""
        if self._session is not None:
            try:
                self._session.close()
            finally:
                self._session = None
