"""Utils for using the OpenAI API."""
import inspect
import time
from functools import wraps
from typing import TYPE_CHECKING

import openai
from loguru import logger

from .chat_configs import OpenAiApiCallOptions
from .tokens import get_n_tokens_from_msgs

if TYPE_CHECKING:
    from .chat import Chat


class CannotConnectToApiError(Exception):
    """Raised when a request to the OpenAI API cannot be completed.

    The helpers in this module raise it in place of the underlying `openai`
    error, which is chained as the exception's `__cause__`.
    """


def retry_api_call(max_n_attempts=5, auth_error_msg="Problems connecting to OpenAI API."):
    """Retry connecting to the API up to a maximum number of times."""

    def on_error(error, n_attempts):
        if n_attempts < max_n_attempts:
            logger.warning(
                "{}. Making new attempt ({}/{})...", error, n_attempts + 1, max_n_attempts
            )
            time.sleep(1)
        else:
            raise CannotConnectToApiError(auth_error_msg) from error

    def retry_api_call_decorator(function):
        """Wrap `function` and log beginning, exit and elapsed time."""

        @wraps(function)
        def wrapper_f(*args, **kwargs):
            n_attempts = 0
            while True:
                n_attempts += 1
                try:
                    return function(*args, **kwargs)
                except openai.APITimeoutError as error:
                    on_error(error=error, n_attempts=n_attempts)
                except (openai.APIError, openai.OpenAIError) as error:
                    raise CannotConnectToApiError(auth_error_msg) from error

        @wraps(function)
        def wrapper_generator_f(*args, **kwargs):
            n_attempts = 0
            success = False
            while not success:
                n_attempts += 1
                try:
                    yield from function(*args, **kwargs)
                except openai.APITimeoutError as error:
                    on_error(error=error, n_attempts=n_attempts)
                except (openai.APIError, openai.OpenAIError) as error:
                    raise CannotConnectToApiError(auth_error_msg) from error
                else:
                    success = True

        return wrapper_generator_f if inspect.isgeneratorfunction(function) else wrapper_f

    return retry_api_call_decorator


def make_api_chat_completion_call(conversation: list, chat_obj: "Chat"):
    """Stream a chat completion from OpenAI API given a conversation and a chat object.

    Token usage is recorded in both `chat_obj.general_token_usage_db` and
    `chat_obj.token_usage_db`. Input tokens are recorded once per attempt, and output
    tokens are recorded once the stream has been fully consumed.

    Args:
        conversation (list): A list of messages passed as input for the completion.
        chat_obj (Chat): Chat object containing the configurations for the chat.

    Yields:
        str: Chunks of text generated by the API in response to the conversation.

    Raises:
        CannotConnectToApiError: If the API call fails (raised by `retry_api_call`).
    """
    # Forward only the API options that are actually set on the chat object
    api_call_args = {}
    for field in OpenAiApiCallOptions.model_fields:
        value = getattr(chat_obj, field)
        if value is not None:
            api_call_args[field] = value

    @retry_api_call(auth_error_msg=chat_obj.api_connection_error_msg)
    def stream_reply(conversation, **api_call_args):
        # Update the chat's token usage database with tokens used in chat input
        # Do this here because every attempt consumes tokens, even if it fails
        n_input_tokens = get_n_tokens_from_msgs(messages=conversation, model=chat_obj.model)
        for db in [chat_obj.general_token_usage_db, chat_obj.token_usage_db]:
            db.insert_data(model=chat_obj.model, n_input_tokens=n_input_tokens)

        reply_chunks = []
        for completion_chunk in openai.chat.completions.create(
            messages=conversation, stream=True, **api_call_args
        ):
            # A chunk may carry no choices at all (e.g. a trailing usage-only chunk),
            # and a delta may have `content=None` (e.g. the final chunk). Skip such
            # chunks rather than failing on `choices[0]` or abandoning the stream.
            if not completion_chunk.choices:
                continue
            reply_chunk = getattr(completion_chunk.choices[0].delta, "content", "")
            if reply_chunk is None:
                continue
            reply_chunks.append(reply_chunk)
            yield reply_chunk

        # Update the chat's token usage database with tokens used in chat output
        reply_as_msg = {"role": "assistant", "content": "".join(reply_chunks)}
        n_output_tokens = get_n_tokens_from_msgs(messages=[reply_as_msg], model=chat_obj.model)
        for db in [chat_obj.general_token_usage_db, chat_obj.token_usage_db]:
            db.insert_data(model=chat_obj.model, n_output_tokens=n_output_tokens)

    yield from stream_reply(conversation, **api_call_args)
