from __future__ import annotations

from typing import Any, AsyncGenerator, Dict, Iterator, List

from seekrai.abstract import api_requestor
from seekrai.seekrflow_response import SeekrFlowResponse
from seekrai.types import (
    ChatCompletionChunk,
    ChatCompletionRequest,
    ChatCompletionResponse,
    SeekrFlowClient,
    SeekrFlowRequest,
)


class ChatCompletions:
    """Synchronous client for the ``inference/chat/completions`` endpoint."""

    def __init__(self, client: SeekrFlowClient) -> None:
        # Stored so that every `create` call can build an APIRequestor from it.
        self._client = client

    def create(
        self,
        *,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int | None = 512,
        stop: List[str] | None = None,
        temperature: float = 0.7,
        top_p: float = 1,
        top_k: int = 5,
        repetition_penalty: float = 1,
        stream: bool = False,
        logprobs: int = 0,
        echo: bool = False,
        n: int = 1,
        safety_model: str | None = None,
        response_format: Dict[str, str | Dict[str, Any]] | None = None,
        tools: Dict[str, str | Dict[str, Any]] | None = None,
        tool_choice: str | Dict[str, str | Dict[str, str]] | None = None,
    ) -> ChatCompletionResponse | Iterator[ChatCompletionChunk]:
        """
        Request chat completions for a conversation from the given model.

        All arguments are keyword-only. They are packed into a
        `ChatCompletionRequest`, dumped to a plain dict and POSTed to
        `inference/chat/completions`.

        Args:
            messages (List[Dict[str, str]]): The conversation, in the format
                `[{"role": seekrai.types.chat_completions.MessageRole, "content": TEXT}, ...]`
            model (str): The name of the model to query.
            max_tokens (int, optional): The maximum number of tokens to generate.
                Defaults to 512.
            stop (List[str], optional): Strings at which generation stops.
                Defaults to None.
            temperature (float, optional): Degree of randomness in the response.
                Defaults to 0.7.
            top_p (float, optional): Nucleus-sampling parameter; dynamically adjusts the number of
                    choices for each predicted token based on the cumulative probabilities.
                Defaults to 1.
            top_k (int, optional): Limits the number of choices for the next predicted token.
                Defaults to 5.
            repetition_penalty (float, optional): Controls the diversity of generated text by reducing
                    the likelihood of repeated sequences. Higher values decrease repetition.
                Defaults to 1.
            stream (bool, optional): Whether to stream the completion as chunks.
                Defaults to False.
            logprobs (int, optional): Number of top-k logprobs to return.
                Defaults to 0.
            echo (bool, optional): Echo the prompt in the output. Can be used with logprobs to
                    return prompt logprobs.
                Defaults to False.
            n (int, optional): Number of completions to generate.
                Defaults to 1.
            safety_model (str, optional): A moderation model to validate tokens. Available moderation
                    models are listed [here](https://docs.seekrflow.ai/docs/inference-models#moderation-models).
                Defaults to None.
            response_format (Dict[str, Any], optional): An object specifying the format that the
                    model must output.
                Defaults to None.
            tools (Dict[str, Any], optional): Tools the model may call. Currently, only functions
                    are supported as a tool; use this to provide the functions the model may
                    generate JSON inputs for.
                Defaults to None.
            tool_choice (str | Dict, optional): Controls which (if any) function is called by the
                    model. `auto` lets the model pick between generating a message or calling a
                    function; `{"type": "function", "function": {"name": "my_function"}}` forces
                    a call to that function. Sets to `auto` if None.
                Defaults to None.

        Returns:
            ChatCompletionResponse | Iterator[ChatCompletionChunk]: The parsed completion when
            `stream` is False, otherwise a lazy iterator over completion chunks.
        """

        api = api_requestor.APIRequestor(client=self._client)

        request_model = ChatCompletionRequest(
            messages=messages,
            model=model,
            max_tokens=max_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            stream=stream,
            logprobs=logprobs,
            echo=echo,
            n=n,
            safety_model=safety_model,
            response_format=response_format,
            tools=tools,
            tool_choice=tool_choice,
        )
        options = SeekrFlowRequest(
            method="POST",
            url="inference/chat/completions",
            params=request_model.model_dump(),
        )

        # Only the first element of the returned triple is used here.
        result, _, _ = api.request(options=options, stream=stream)

        if not stream:
            # A non-streaming call must produce one complete response object.
            assert isinstance(result, SeekrFlowResponse)
            return ChatCompletionResponse(**result.data)

        # A streaming call must produce an iterable of events, not a single response.
        assert not isinstance(result, SeekrFlowResponse)
        return (ChatCompletionChunk(**event.data) for event in result)


class AsyncChatCompletions:
    """Asynchronous client for the ``inference/chat/completions`` endpoint."""

    def __init__(self, client: SeekrFlowClient) -> None:
        # Stored so that every `create` call can build an APIRequestor from it.
        self._client = client

    async def create(
        self,
        *,
        messages: List[Dict[str, str]],
        model: str,
        max_tokens: int | None = 512,
        stop: List[str] | None = None,
        temperature: float = 0.7,
        top_p: float = 1,
        top_k: int = 5,
        repetition_penalty: float = 1,
        stream: bool = False,
        logprobs: int = 0,
        echo: bool = False,
        n: int = 1,
        safety_model: str | None = None,
        response_format: Dict[str, str | Dict[str, Any]] | None = None,
        # NOTE(review): `tools` and `tool_choice` are accepted by
        # `ChatCompletions.create` but are disabled on this async method --
        # confirm whether that is intended before re-enabling them.
    ) -> AsyncGenerator[ChatCompletionChunk, None] | ChatCompletionResponse:
        """
        Asynchronously request chat completions for a conversation from the given model.

        All arguments are keyword-only. They are packed into a
        `ChatCompletionRequest`, dumped to a plain dict and POSTed to
        `inference/chat/completions`.

        Args:
            messages (List[Dict[str, str]]): The conversation, in the format
                `[{"role": seekrai.types.chat_completions.MessageRole, "content": TEXT}, ...]`
            model (str): The name of the model to query.
            max_tokens (int, optional): The maximum number of tokens to generate.
                Defaults to 512.
            stop (List[str], optional): Strings at which generation stops.
                Defaults to None.
            temperature (float, optional): Degree of randomness in the response.
                Defaults to 0.7.
            top_p (float, optional): Nucleus-sampling parameter; dynamically adjusts the number of
                    choices for each predicted token based on the cumulative probabilities.
                Defaults to 1.
            top_k (int, optional): Limits the number of choices for the next predicted token.
                Defaults to 5.
            repetition_penalty (float, optional): Controls the diversity of generated text by reducing
                    the likelihood of repeated sequences. Higher values decrease repetition.
                Defaults to 1.
            stream (bool, optional): Whether to stream the completion as chunks.
                Defaults to False.
            logprobs (int, optional): Number of top-k logprobs to return.
                Defaults to 0.
            echo (bool, optional): Echo the prompt in the output. Can be used with logprobs to
                    return prompt logprobs.
                Defaults to False.
            n (int, optional): Number of completions to generate.
                Defaults to 1.
            safety_model (str, optional): A moderation model to validate tokens. Available moderation
                    models are listed [here](https://docs.seekrflow.ai/docs/inference-models#moderation-models).
                Defaults to None.
            response_format (Dict[str, Any], optional): An object specifying the format that the
                    model must output.
                Defaults to None.

        Returns:
            AsyncGenerator[ChatCompletionChunk, None] | ChatCompletionResponse: The parsed
            completion when `stream` is False, otherwise an async generator over completion chunks.
        """

        api = api_requestor.APIRequestor(client=self._client)

        request_model = ChatCompletionRequest(
            messages=messages,
            model=model,
            max_tokens=max_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            stream=stream,
            logprobs=logprobs,
            echo=echo,
            n=n,
            safety_model=safety_model,
            response_format=response_format,
        )
        options = SeekrFlowRequest(
            method="POST",
            url="inference/chat/completions",
            params=request_model.model_dump(),
        )

        # Only the first element of the returned triple is used here.
        result, _, _ = await api.arequest(options=options, stream=stream)

        if not stream:
            # A non-streaming call must produce one complete response object.
            assert isinstance(result, SeekrFlowResponse)
            return ChatCompletionResponse(**result.data)

        # A streaming call must produce an async iterable of events, not a single response.
        assert not isinstance(result, SeekrFlowResponse)
        return (ChatCompletionChunk(**event.data) async for event in result)
