from typing import Optional
from uuid import uuid4

from openai import OpenAI
from openai.resources.chat import Completions
from typing_extensions import override

from ...scribe import scribe
from ..logger import Generation, GenerationConfig, Logger, Trace, TraceConfig
from .utils import OpenAIUtils


class MaximOpenAIChatCompletions(Completions):
    """OpenAI chat ``Completions`` resource that logs each call to Maxim.

    Every ``create`` call is recorded as a generation on a Maxim trace. A caller
    can attach the generation to an existing trace by passing its id in the
    ``x-maxim-trace-id`` entry of ``extra_headers``. Otherwise a new trace is
    created for the call and ended once the call finishes (successfully or not).
    Logging is best-effort: failures while logging are reported as warnings and
    never change what the OpenAI call returns or raises.
    """

    def __init__(self, client: OpenAI, logger: Logger):
        """Wrap the chat completions of ``client`` and log through ``logger``.

        Args:
            client: The OpenAI client used to issue requests.
            logger: Maxim logger on which traces and generations are created.
        """
        super().__init__(client=client)
        self._logger = logger

    @override
    def create(self, *args, **kwargs):
        """Create a chat completion and log it to Maxim.

        All arguments are forwarded unchanged to ``Completions.create``. This
        includes ``extra_headers``, so the Maxim headers below are sent to the
        API as well. Two optional ``extra_headers`` entries are read:

        - ``x-maxim-trace-id``: id of an existing trace to log into. When it is
          absent or empty, a new trace with a random id is created, and this
          method ends it.
        - ``x-maxim-generation-name``: name given to the logged generation.

        Returns:
            Whatever ``Completions.create`` returns.

        Raises:
            Any exception raised by ``Completions.create``. Errors raised while
            logging are only reported as warnings.
        """
        extra_headers = kwargs.get("extra_headers", None)
        trace_id = None
        generation_name = None
        if extra_headers is not None:
            trace_id = extra_headers.get("x-maxim-trace-id", None)
            generation_name = extra_headers.get("x-maxim-generation-name", None)
        # Treat an empty trace id like a missing one. A fresh id is generated
        # below in that case, so this call owns the trace and must end it.
        is_local_trace = not trace_id
        model = kwargs.get("model", None)
        final_trace_id = trace_id or str(uuid4())
        generation: Optional[Generation] = None
        trace: Optional[Trace] = None
        messages = kwargs.get("messages", None)
        try:
            trace = self._logger.trace(TraceConfig(id=final_trace_id))
            gen_config = GenerationConfig(
                id=str(uuid4()),
                model=model,
                provider="openai",
                name=generation_name,
                model_parameters=OpenAIUtils.get_model_params(**kwargs),
                messages=OpenAIUtils.parse_message_param(messages),
            )
            generation = trace.generation(gen_config)
        except Exception as e:
            scribe().warning(
                f"[MaximSDK][MaximOpenAIChatCompletions] Error in generating content: {str(e)}"
            )

        # Only a trace created for this call is ended here. A caller-supplied
        # trace is left open for the caller to manage.
        owned_trace: Optional[Trace] = trace if is_local_trace else None

        try:
            response = super().create(*args, **kwargs)
        except Exception:
            # End our own trace so a failed request does not leave it open,
            # then re-raise the original API error unchanged.
            if owned_trace is not None:
                self._end_trace_safely(owned_trace)
            raise

        try:
            if generation is not None:
                generation.result(OpenAIUtils.parse_completion(response))
            # NOTE(review): a streamed response (stream=True) presumably has no
            # `.choices`. The AttributeError is then reported as a warning below,
            # and the trace is still ended. Confirm whether streaming should be
            # logged properly.
            if owned_trace is not None and response.choices:
                owned_trace.set_output(response.choices[0].message.content or "")
        except Exception as e:
            scribe().warning(
                f"[MaximSDK][MaximOpenAIChatCompletions] Error in logging generation: {str(e)}"
            )
        finally:
            # End the owned trace even if parsing or recording the result failed.
            if owned_trace is not None:
                self._end_trace_safely(owned_trace)

        return response

    @staticmethod
    def _end_trace_safely(trace: Trace) -> None:
        """End ``trace``, reporting any failure as a warning instead of raising."""
        try:
            trace.end()
        except Exception as e:
            scribe().warning(
                f"[MaximSDK][MaximOpenAIChatCompletions] Error in logging generation: {str(e)}"
            )
