import requests
import sys
from datetime import datetime
from enum import StrEnum
from logging import Logger
from pypomes_core import TZ_LOCAL, exc_format
from pypomes_crypto import crypto_jwk_convert
from threading import Lock
from typing import Any, Final


class IamServer(StrEnum):
    """
    Supported IAM servers.
    """
    IAM_JUSRBR = "iam-jusbr",
    IAM_KEYCLOAK = "iam-keycloak"


# the logger for IAM service operations
# (used exclusively at the HTTP endpoints - all other functions receive the logger as parameter)
# set by '_register_logger()' and read by '_get_logger()'; it remains None until a logger is registered
__IAM_LOGGER: Logger | None = None

# registry of the IAM servers, keyed by 'IamServer' (empty at declaration time)
# registry structure:
# { <IamServer>:
#    {
#       "client-id": <str>,
#       "client-secret": <str>,
#       "client-timeout": <int>,      <-- login timeout, honored only if a positive integer
#       "recipient-attr": <str>,
#       "public-key": <str>,          <-- PEM-formatted key, cached by '_get_public_key()'
#       "pk-lifetime": <int>,         <-- seconds a fetched public key is kept before being re-fetched
#       "pk-expiration": <int>,       <-- timestamp (in seconds) after which the public key is re-fetched
#       "base-url": <str>,            <-- certificates are fetched from '<base-url>/protocol/openid-connect/certs'
#       "cache": <FIFOCache>,         <-- accessed as a mapping, e.g. 'registry["cache"]["users"]'
#       "redirect-uri": <str>  <-- transient
#    },
#    ...
# }
# data in "cache":
# {
#    "users": {
#       "<user-id>": {
#          "access-token": <str | None>,
#          "refresh-token": <str | None>,
#          "access-expiration": <timestamp>,
#          "refresh-expiration": <timestamp>,
#          "login-expiration": <timestamp>,    <-- transient
#          "login-id": <str>,                  <-- transient
#       }
#    },
#   ...
# }
_IAM_SERVERS: Final[dict[IamServer, dict[str, Any]]] = {}

# the lock protecting the data in '_IAM_SERVERS'
# (because it is 'Final' and set at declaration time, it can be accessed through simple imports)
# NOTE(review): the helper functions in this module section do not acquire it themselves -
#               presumably their callers hold it while accessing the registry; confirm
_iam_lock: Final[Lock] = Lock()


def _get_logger() -> Logger | None:
    """
    Retrieve the registered logger for *IAM* operations.

    This function is invoked exclusively from the HTTP endpoints.
    All other functions receive the logger as parameter.

    :return: the registered logger for *IAM* operations.
    """
    return __IAM_LOGGER


def _register_logger(logger: Logger) -> None:
    """
    Make *logger* the logger for *IAM* operations.

    A previously registered logger, if any, is replaced. The registered logger
    is the one subsequently returned by '_get_logger()'.

    :param logger: the logger to be registered
    """
    # rebind the module-level reference, rather than creating a local name
    global __IAM_LOGGER
    __IAM_LOGGER = logger


def _get_public_key(iam_server: IamServer,
                    errors: list[str] | None,
                    logger: Logger | None) -> str | None:
    """
    Obtain the public key used by *iam_server* to sign the authentication tokens.

    The key is cached in *iam_server*'s registry, under *public-key*. While the cached key
    has not expired (*pk-expiration*), it is returned as is. Otherwise, a new key is fetched
    from *<base-url>/protocol/openid-connect/certs*, the first key of the returned set is
    converted to *PEM*, and it is cached for *pk-lifetime* seconds. Failures are logged,
    when a logger is given, and appended to *errors*, when it is a list.

    :param iam_server: the reference registered *IAM* server
    :param errors: incidental error messages
    :param logger: optional logger
    :return: the public key in *PEM* format, or *None* if the server is unknown or the key could not be obtained
    """
    # initialize the return variable
    result: str | None = None

    registry: dict[str, Any] | None = _get_iam_registry(iam_server=iam_server,
                                                        errors=errors,
                                                        logger=logger)
    if registry:
        now: int = int(datetime.now(tz=TZ_LOCAL).timestamp())
        # a missing expiration is taken as 'already expired', thus forcing a fetch
        if now > (registry.get("pk-expiration") or 0):
            # obtain a new public key
            # (single quotes for the key: nesting double quotes in an f-string requires Python 3.12+)
            url: str = f"{registry['base-url']}/protocol/openid-connect/certs"
            if logger:
                logger.debug(msg=f"GET '{url}'")
            try:
                # NOTE(review): no timeout is set - an unresponsive server blocks this call indefinitely
                response: requests.Response = requests.get(url=url)
                if response.status_code == 200:
                    # request succeeded
                    if logger:
                        logger.debug(msg=f"GET success, status {response.status_code}")
                    reply: dict[str, Any] = response.json()
                    # the first key in the returned key set is taken as the signing key
                    result = crypto_jwk_convert(jwk=reply["keys"][0],
                                                fmt="PEM")
                    registry["public-key"] = result
                    lifetime: int = registry.get("pk-lifetime") or 0
                    registry["pk-expiration"] = now + lifetime
                else:
                    # request failed - report it even if no logger has been provided
                    msg: str = f"GET failure, status {response.status_code}, reason '{response.reason}'"
                    if hasattr(response, "content") and response.content:
                        msg += f", content '{response.content}'"
                    if logger:
                        logger.error(msg=msg)
                    if isinstance(errors, list):
                        errors.append(msg)
            except Exception as e:
                # the operation raised an exception
                msg = exc_format(exc=e,
                                 exc_info=sys.exc_info())
                if logger:
                    logger.error(msg=msg)
                if isinstance(errors, list):
                    errors.append(msg)
        else:
            result = registry.get("public-key")

    return result


def _get_login_timeout(iam_server: IamServer,
                       errors: list[str] | None,
                       logger: Logger | None) -> int | None:
    """
    Retrieve the timeout currently applicable for the login operation.

    The value is read from the *client-timeout* entry in *iam_server*'s registry,
    and it is reported only if it is a positive integer.

    :param iam_server: the reference registered *IAM* server
    :param errors: incidental error messages
    :param logger: optional logger
    :return: the current login timeout, or *None* if the server is unknown or none has been set.
    """
    # initialize the return variable
    result: int | None = None

    registry: dict[str, Any] | None = _get_iam_registry(iam_server=iam_server,
                                                        errors=errors,
                                                        logger=logger)
    if registry:
        timeout: Any = registry.get("client-timeout")
        # 'bool' is a subclass of 'int', but 'True' must not be taken as a timeout
        if isinstance(timeout, int) and not isinstance(timeout, bool) and timeout > 0:
            result = timeout

    return result


def _get_user_data(iam_server: IamServer,
                   user_id: str,
                   errors: list[str] | None,
                   logger: Logger | None) -> dict[str, Any] | None:
    """
    Retrieve the data for *user_id* from *iam_server*'s registry.

    If an entry is not found for *user_id* in the registry, it is created, with no tokens,
    an access expiration set to the current timestamp, and a refresh expiration set to
    *sys.maxsize*. It will remain there until the user is logged out.

    :param iam_server: the reference registered *IAM* server
    :param user_id: the identification of the user whose data is sought
    :param errors: incidental error messages
    :param logger: optional logger
    :return: the data for *user_id* in *iam_server*'s registry, or *None* if the server is unknown
    """
    # initialize the return variable
    result: dict[str, Any] | None = None

    users: dict[str, dict[str, Any]] | None = _get_iam_users(iam_server=iam_server,
                                                             errors=errors,
                                                             logger=logger)
    # an empty users storage is still valid - test for 'None', so that the first entry can be created
    if users is not None:
        result = users.get(user_id)
        if not result:
            result = {
                "access-token": None,
                "refresh-token": None,
                "access-expiration": int(datetime.now(tz=TZ_LOCAL).timestamp()),
                "refresh-expiration": sys.maxsize
            }
            users[user_id] = result
            if logger:
                logger.debug(msg=f"Entry for '{user_id}' added to {iam_server}'s registry")
        elif logger:
            logger.debug(msg=f"Entry for '{user_id}' obtained from {iam_server}'s registry")

    return result


def _get_iam_server(endpoint: str,
                    errors: list[str] | None,
                    logger: Logger | None) -> IamServer | None:
    """
    Retrieve the registered *IAM* server associated with the service's invocation *endpoint*.

    The association is made by prefix: endpoints starting with *jusbr* map to
    *IamServer.IAM_JUSRBR*, and endpoints starting with *keycloak* map to
    *IamServer.IAM_KEYCLOAK*. Any other endpoint is reported as an error.

    :param endpoint: the service's invocation endpoint
    :param errors: incidental error messages
    :param logger: optional logger
    :return: the corresponding *IAM* server, or *None* if one could not be obtained
    """
    # declare the return variable
    result: IamServer | None

    if endpoint.startswith("jusbr"):
        result = IamServer.IAM_JUSRBR
    elif endpoint.startswith("keycloak"):
        result = IamServer.IAM_KEYCLOAK
    else:
        result = None
        msg: str = f"Unknown endpoint '{endpoint}'"
        if logger:
            logger.error(msg=msg)
        if isinstance(errors, list):
            errors.append(msg)

    return result


def _get_iam_registry(iam_server: IamServer,
                      errors: list[str] | None,
                      logger: Logger | None) -> dict[str, Any]:
    """
    Retrieve the registry associated with *iam_server*.

    :param iam_server: the reference registered *IAM* server
    :param errors: incidental error messages
    :param logger: optional logger
    :return: the registry associated with *iam_server*, or *None* if the server is unknown
    """
    # declare the return variable
    result: dict[str, Any] | None

    match iam_server:
        case IamServer.IAM_JUSRBR:
            result = _IAM_SERVERS[IamServer.IAM_JUSRBR]
        case IamServer.IAM_KEYCLOAK:
            result = _IAM_SERVERS[IamServer.IAM_KEYCLOAK]
        case _:
            result = None
            msg = f"Unknown IAM server '{iam_server}'"
            if logger:
                logger.error(msg=msg)
            if isinstance(errors, list):
                errors.append(msg)

    return result


def _get_iam_users(iam_server: IamServer,
                   errors: list[str] | None,
                   logger: Logger | None) -> dict[str, dict[str, Any]] | None:
    """
    Retrieve the users storage kept in the cache of *iam_server*'s registry.

    :param iam_server: the reference registered *IAM* server
    :param errors: incidental error messages
    :param logger: optional logger
    :return: the users storage (*registry["cache"]["users"]*), or *None* if the server is unknown
    """
    registry: dict[str, Any] | None = _get_iam_registry(iam_server=iam_server,
                                                        errors=errors,
                                                        logger=logger)
    return registry["cache"]["users"] if registry else None
