# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import warnings
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union

import requests

from ..common import streaming_response_iterator

if TYPE_CHECKING:
    from ...types import (
        ChatCompletion,
        ChatCompletionChunk,
        ChatCompletionMessage,
        ChatglmCppGenerateConfig,
        Completion,
        CompletionChunk,
        Embedding,
        ImageList,
        LlamaCppGenerateConfig,
        PytorchGenerateConfig,
    )


def _get_error_string(response: requests.Response) -> str:
    """
    Extract a human-readable error description from a failed HTTP response.

    The server's JSON ``detail`` field is preferred. If the body is empty, cannot be
    read or decoded, lacks ``detail``, or ``detail`` is ``null``, fall back to the
    message of the ``HTTPError`` raised by ``response.raise_for_status()``. If that
    does not raise either, ``"Unknown error"`` is returned.

    Parameters
    ----------
    response: requests.Response
        The response whose status indicated a failure.

    Returns
    -------
    str
        The error description. A non-string ``detail`` (for example a list of
        validation errors) is converted with ``str()`` so the result is always a str.
    """
    detail = None
    try:
        if response.content:
            detail = response.json()["detail"]
    except Exception:
        # Deliberately best-effort: the body may be unreadable, not JSON, or not a
        # JSON object with a "detail" key. Fall through to the HTTP status instead.
        detail = None
    if detail is not None:
        return detail if isinstance(detail, str) else str(detail)
    try:
        response.raise_for_status()
    except requests.HTTPError as e:
        return str(e)
    return "Unknown error"


class RESTfulModelHandle:
    """
    Base class for the synchronous model handles used by the RESTful client.

    A handle remembers which model it targets, the server root URL and the headers
    to send; subclasses add typed helper methods for each kind of model so that
    xinference is easy to use programmatically.
    """

    def __init__(self, model_uid: str, base_url: str, auth_headers: Dict):
        # Target model id and server root URL; subclasses append endpoint paths
        # to ``_base_url`` and send ``_model_uid`` in their request bodies.
        self._model_uid, self._base_url = model_uid, base_url
        # Passed as ``headers`` on the requests subclasses make.
        self.auth_headers = auth_headers


class RESTfulEmbeddingModelHandle(RESTfulModelHandle):
    def create_embedding(self, input: Union[str, List[str]]) -> "Embedding":
        """
        Embed the given input with this model via RESTful APIs.

        Parameters
        ----------
        input: Union[str, List[str]]
            Text to embed. Pass a list to embed several inputs in one request.

        Returns
        -------
        Embedding
           The decoded JSON body returned by ``POST /v1/embeddings``.

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status; the message carries the
            server's error detail.

        """
        response = requests.post(
            f"{self._base_url}/v1/embeddings",
            json={"model": self._model_uid, "input": input},
            headers=self.auth_headers,
        )
        if response.status_code == 200:
            return response.json()
        raise RuntimeError(
            f"Failed to create the embeddings, detail: {_get_error_string(response)}"
        )


class RESTfulRerankModelHandle(RESTfulModelHandle):
    def rerank(
        self,
        documents: List[str],
        query: str,
        top_n: Optional[int] = None,
        max_chunks_per_doc: Optional[int] = None,
        return_documents: Optional[bool] = None,
    ):
        """
        Returns an ordered list of documents ordered by their relevance to the provided query.

        Parameters
        ----------
        documents: List[str]
            The documents to rerank.
        query: str
            The search query.
        top_n: Optional[int]
            The number of results to return, defaults to returning all results.
        max_chunks_per_doc: Optional[int]
            The maximum number of chunks derived from a document.
        return_documents: Optional[bool]
            Whether to return documents; forwarded to the server unchanged.

        Returns
        -------
        dict
           The decoded JSON response. The client sets the ``"document"`` field of each
           entry in ``results`` to ``documents[entry["index"]]``.

        Raises
        ------
        RuntimeError
            Report the failure of rerank and provide the error message.
        """
        url = f"{self._base_url}/v1/rerank"
        request_body = {
            "model": self._model_uid,
            "documents": documents,
            "query": query,
            "top_n": top_n,
            "max_chunks_per_doc": max_chunks_per_doc,
            "return_documents": return_documents,
        }
        response = requests.post(url, json=request_body, headers=self.auth_headers)
        if response.status_code != 200:
            # Use the shared helper so a non-JSON error body (e.g. from a proxy)
            # does not replace the real failure with a JSON decoding error.
            raise RuntimeError(
                f"Failed to rerank documents, detail: {_get_error_string(response)}"
            )
        response_data = response.json()
        # NOTE(review): the original text is attached to every result regardless of
        # ``return_documents``; confirm this is intended.
        for result in response_data["results"]:
            result["document"] = documents[result["index"]]
        return response_data


class RESTfulImageModelHandle(RESTfulModelHandle):
    def text_to_image(
        self,
        prompt: str,
        n: int = 1,
        size: str = "1024*1024",
        response_format: str = "url",
        **kwargs,
    ) -> "ImageList":
        """
        Creates images from the input text via RESTful APIs.

        Parameters
        ----------
        prompt: `str`
            The prompt to guide image generation.
        n: `int`, defaults to 1
            The number of images to generate per prompt.
        size: `str`, defaults to `1024*1024`
            The width*height in pixels of the generated image, written in the same
            ``<width>*<height>`` form as the default.
        response_format: `str`, defaults to `url`
            The format in which the generated images are returned, e.g. url or b64_json.
        **kwargs:
            Additional generation options; JSON-encoded and sent in the ``kwargs``
            field of the request body.

        Returns
        -------
        ImageList
            A list of image objects.

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status; the message carries the
            server's error detail.
        """
        url = f"{self._base_url}/v1/images/generations"
        request_body = {
            "model": self._model_uid,
            "prompt": prompt,
            "n": n,
            "size": size,
            "response_format": response_format,
            "kwargs": json.dumps(kwargs),
        }
        response = requests.post(url, json=request_body, headers=self.auth_headers)
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the images, detail: {_get_error_string(response)}"
            )

        response_data = response.json()
        return response_data

    def image_to_image(
        self,
        image: Union[str, bytes],
        prompt: str,
        negative_prompt: str,
        n: int = 1,
        size: str = "1024*1024",
        response_format: str = "url",
        **kwargs,
    ) -> "ImageList":
        """
        Creates new images from an input image and a text prompt via RESTful APIs.

        The request is sent as multipart form data: every scalar parameter becomes a
        form field, and ``image`` is uploaded as the file field ``image``.

        Parameters
        ----------
        image: `Union[str, bytes]`
            Content of the input image, uploaded unchanged with content type
            ``application/octet-stream``.
        prompt: `str`
            The prompt to guide image generation.
        negative_prompt: `str`
            The prompt not to guide the image generation (required by this method).
        n: `int`, defaults to 1
            The number of images to generate per prompt.
        size: `str`, defaults to `1024*1024`
            The width*height in pixels of the generated image, written in the same
            ``<width>*<height>`` form as the default.
        response_format: `str`, defaults to `url`
            The format in which the generated images are returned, e.g. url or b64_json.
        **kwargs:
            Additional generation options; JSON-encoded and sent in the ``kwargs``
            form field.

        Returns
        -------
        ImageList
            A list of image objects.

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status; the message carries the
            server's error detail.
        """
        url = f"{self._base_url}/v1/images/variations"
        params = {
            "model": self._model_uid,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "n": n,
            "size": size,
            "response_format": response_format,
            "kwargs": json.dumps(kwargs),
        }
        # A ``(None, value)`` tuple makes requests emit a plain form field
        # (no filename) rather than a file part.
        files: List[Any] = []
        for key, value in params.items():
            files.append((key, (None, value)))
        files.append(("image", ("image", image, "application/octet-stream")))
        response = requests.post(url, files=files, headers=self.auth_headers)
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to variants the images, detail: {_get_error_string(response)}"
            )

        response_data = response.json()
        return response_data


class RESTfulGenerateModelHandle(RESTfulEmbeddingModelHandle):
    def generate(
        self,
        prompt: str,
        generate_config: Optional[
            Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]
        ] = None,
    ) -> Union["Completion", Iterator["CompletionChunk"]]:
        """
        Ask the model to complete ``prompt`` via RESTful APIs.

        Parameters
        ----------
        prompt: str
            The user's message or user's input.
        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]]
            Extra generation options, merged into the request body key by key.
            "LlamaCppGenerateConfig" -> Configuration for ggml model
            "PytorchGenerateConfig" -> Configuration for pytorch model

        Returns
        -------
        Union["Completion", Iterator["CompletionChunk"]]
            With a truthy ``stream`` entry in generate_config, an iterator of
            "CompletionChunk" built from the streamed lines; otherwise the decoded
            "Completion" JSON body.

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status; the message carries the
            server's error detail.

        """
        request_body: Dict[str, Any] = {"model": self._model_uid, "prompt": prompt}
        if generate_config is not None:
            request_body.update(generate_config.items())

        wants_stream = False
        if generate_config:
            wants_stream = bool(generate_config.get("stream"))

        response = requests.post(
            f"{self._base_url}/v1/completions",
            json=request_body,
            stream=wants_stream,
            headers=self.auth_headers,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to generate completion, detail: {_get_error_string(response)}"
            )

        return (
            streaming_response_iterator(response.iter_lines())
            if wants_stream
            else response.json()
        )


class RESTfulChatModelHandle(RESTfulGenerateModelHandle):
    def chat(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        chat_history: Optional[List["ChatCompletionMessage"]] = None,
        tools: Optional[List[Dict]] = None,
        generate_config: Optional[
            Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]
        ] = None,
    ) -> Union["ChatCompletion", Iterator["ChatCompletionChunk"]]:
        """
        Given a list of messages comprising a conversation, the model will return a response via RESTful APIs.

        Parameters
        ----------
        prompt: str
            The user's input, appended as the final "user" message.
        system_prompt: Optional[str]
            The system context provided to the model prior to any chats. If the history
            already starts with a "system" message, its content is replaced in the
            request; otherwise a "system" message is prepended.
        chat_history: Optional[List["ChatCompletionMessage"]]
            A list of messages comprising the conversation so far. The caller's list
            and message dicts are not modified; the request uses a copy.
        tools: Optional[List[Dict]]
            A tool list, sent as the ``tools`` field when given.
        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]]
            Additional configuration for the chat generation.
            "LlamaCppGenerateConfig" -> configuration for ggml model
            "PytorchGenerateConfig" -> configuration for pytorch model

        Returns
        -------
        Union["ChatCompletion", Iterator["ChatCompletionChunk"]]
            Stream is a parameter in generate_config.
            When stream is set to True, the function will return Iterator["ChatCompletionChunk"].
            When stream is set to False, the function will return "ChatCompletion".

        Raises
        ------
        RuntimeError
            Report the failure to generate the chat from the server. Detailed information provided in error message.

        """

        url = f"{self._base_url}/v1/chat/completions"

        # Build the request messages on a shallow copy so repeated calls with the
        # same history list do not accumulate prompts in the caller's list.
        messages: List[Any] = list(chat_history) if chat_history else []

        if system_prompt is not None:
            if messages and messages[0]["role"] == "system":
                # Replace with a new dict rather than editing the caller's message.
                messages[0] = {**messages[0], "content": system_prompt}
            else:
                messages.insert(0, {"role": "system", "content": system_prompt})

        messages.append({"role": "user", "content": prompt})

        request_body: Dict[str, Any] = {
            "model": self._model_uid,
            "messages": messages,
        }
        if tools is not None:
            request_body["tools"] = tools
        if generate_config is not None:
            for key, value in generate_config.items():
                request_body[key] = value

        stream = bool(generate_config and generate_config.get("stream"))
        response = requests.post(
            url, json=request_body, stream=stream, headers=self.auth_headers
        )

        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to generate chat completion, detail: {_get_error_string(response)}"
            )

        if stream:
            return streaming_response_iterator(response.iter_lines())

        response_data = response.json()
        return response_data


class RESTfulMultimodalModelHandle(RESTfulModelHandle):
    def chat(
        self,
        prompt: Any,
        system_prompt: Optional[str] = None,
        chat_history: Optional[List["ChatCompletionMessage"]] = None,
        tools: Optional[List[Dict]] = None,
        generate_config: Optional[
            Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]
        ] = None,
    ) -> Union["ChatCompletion", Iterator["ChatCompletionChunk"]]:
        """
        Given a list of messages comprising a conversation, the model will return a response via RESTful APIs.

        Parameters
        ----------
        prompt: Any
            The user's input, used as the ``content`` of the appended "user" message.
            NOTE(review): accepted shapes (plain text vs. structured content parts)
            are decided by the server — confirm.
        system_prompt: Optional[str]
            The system context provided to the model prior to any chats. If the history
            already starts with a "system" message, its content is replaced in the
            request; otherwise a "system" message is prepended.
        chat_history: Optional[List["ChatCompletionMessage"]]
            A list of messages comprising the conversation so far. The caller's list
            and message dicts are not modified; the request uses a copy.
        tools: Optional[List[Dict]]
            Not supported; passing anything other than None raises RuntimeError.
        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]]
            Additional configuration for the chat generation.
            "LlamaCppGenerateConfig" -> configuration for ggml model
            "PytorchGenerateConfig" -> configuration for pytorch model

        Returns
        -------
        Union["ChatCompletion", Iterator["ChatCompletionChunk"]]
            Stream is a parameter in generate_config.
            When stream is set to True, the function will return Iterator["ChatCompletionChunk"].
            When stream is set to False, the function will return "ChatCompletion".

        Raises
        ------
        RuntimeError
            If ``tools`` is given, or the server fails to generate the chat. Detailed
            information provided in error message.

        """
        # Reject unsupported input before doing any work.
        if tools is not None:
            raise RuntimeError("Multimodal does not support function call.")

        url = f"{self._base_url}/v1/chat/completions"

        # Build the request messages on a shallow copy so repeated calls with the
        # same history list do not accumulate prompts in the caller's list.
        messages: List[Any] = list(chat_history) if chat_history else []

        if system_prompt is not None:
            if messages and messages[0]["role"] == "system":
                # Replace with a new dict rather than editing the caller's message.
                messages[0] = {**messages[0], "content": system_prompt}
            else:
                messages.insert(0, {"role": "system", "content": system_prompt})

        messages.append({"role": "user", "content": prompt})

        request_body: Dict[str, Any] = {
            "model": self._model_uid,
            "messages": messages,
        }
        if generate_config is not None:
            for key, value in generate_config.items():
                request_body[key] = value

        stream = bool(generate_config and generate_config.get("stream"))
        response = requests.post(
            url, json=request_body, stream=stream, headers=self.auth_headers
        )

        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to generate chat completion, detail: {_get_error_string(response)}"
            )

        if stream:
            return streaming_response_iterator(response.iter_lines())

        response_data = response.json()
        return response_data


class RESTfulChatglmCppChatModelHandle(RESTfulEmbeddingModelHandle):
    def chat(
        self,
        prompt: str,
        chat_history: Optional[List["ChatCompletionMessage"]] = None,
        tools: Optional[List[Dict]] = None,
        generate_config: Optional["ChatglmCppGenerateConfig"] = None,
    ) -> Union["ChatCompletion", Iterator["ChatCompletionChunk"]]:
        """
        Given a list of messages comprising a conversation, the ChatGLM model will return a response via RESTful APIs.

        Parameters
        ----------
        prompt: str
            The user's input, appended as the final "user" message.
        chat_history: Optional[List["ChatCompletionMessage"]]
            A list of messages comprising the conversation so far. The caller's list
            is not modified; the request uses a copy.
        tools: Optional[List[Dict]]
            A tool list, sent as the ``tools`` field when given.
        generate_config: Optional["ChatglmCppGenerateConfig"]
            Additional configuration for ChatGLM chat generation.

        Returns
        -------
        Union["ChatCompletion", Iterator["ChatCompletionChunk"]]
            Stream is a parameter in generate_config.
            When stream is set to True, the function will return Iterator["ChatCompletionChunk"].
            When stream is set to False, the function will return "ChatCompletion".

        Raises
        ------
        RuntimeError
            Report the failure to generate the chat from the server. Detailed information provided in error message.

        """

        url = f"{self._base_url}/v1/chat/completions"

        # Copy so repeated calls with the same history list do not keep appending
        # prompts to the caller's list.
        messages: List[Any] = list(chat_history) if chat_history else []
        messages.append({"role": "user", "content": prompt})

        request_body: Dict[str, Any] = {
            "model": self._model_uid,
            "messages": messages,
        }
        if tools is not None:
            request_body["tools"] = tools
        if generate_config is not None:
            for key, value in generate_config.items():
                request_body[key] = value

        stream = bool(generate_config and generate_config.get("stream"))
        response = requests.post(
            url, json=request_body, stream=stream, headers=self.auth_headers
        )

        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to generate chat completion, detail: {_get_error_string(response)}"
            )

        if stream:
            return streaming_response_iterator(response.iter_lines())

        response_data = response.json()
        return response_data


class RESTfulChatglmCppGenerateModelHandle(RESTfulChatglmCppChatModelHandle):
    def generate(
        self,
        prompt: str,
        generate_config: Optional["ChatglmCppGenerateConfig"] = None,
    ) -> Union["Completion", Iterator["CompletionChunk"]]:
        """
        Given a prompt, the ChatGLM model will generate a response via RESTful APIs.

        Parameters
        ----------
        prompt: str
            The user's input.
        generate_config: Optional["ChatglmCppGenerateConfig"]
            Additional configuration for ChatGLM chat generation.

        Returns
        -------
        Union["Completion", Iterator["CompletionChunk"]]
            Stream is a parameter in generate_config.
            When stream is set to True, the function will return Iterator["CompletionChunk"].
            When stream is set to False, the function will return "Completion".

        Raises
        ------
        RuntimeError
            Report the failure to generate the content from the server. Detailed information provided in error message.

        """

        url = f"{self._base_url}/v1/completions"

        request_body: Dict[str, Any] = {"model": self._model_uid, "prompt": prompt}
        if generate_config is not None:
            for key, value in generate_config.items():
                request_body[key] = value

        stream = bool(generate_config and generate_config.get("stream"))

        response = requests.post(
            url, json=request_body, stream=stream, headers=self.auth_headers
        )
        if response.status_code != 200:
            # Shared helper: tolerates non-JSON error bodies instead of raising a
            # decoding error that hides the real failure.
            raise RuntimeError(
                f"Failed to generate completion, detail: {_get_error_string(response)}"
            )

        if stream:
            return streaming_response_iterator(response.iter_lines())

        response_data = response.json()
        return response_data


class Client:
    def __init__(self, base_url):
        # Server root URL; endpoint paths such as "/v1/models" are appended to it.
        self.base_url = base_url
        # Assume no authentication until the probe below reports otherwise.
        self._cluster_authed = False
        # Headers sent with requests; login/_set_token add "Authorization" here.
        self._headers = {}
        # Network round-trip: asks the server whether it requires authentication.
        self._check_cluster_authenticated()

    def _set_token(self, token: Optional[str]):
        """Store ``token`` as a Bearer Authorization header when the cluster uses auth.

        Does nothing if the cluster is not authenticated or ``token`` is None.
        """
        if self._cluster_authed and token is not None:
            self._headers["Authorization"] = f"Bearer {token}"

    def _get_token(self) -> Optional[str]:
        """Return the stored token with "Bearer " removed, or None if no Authorization header is set."""
        if "Authorization" not in self._headers:
            return None
        header_value = str(self._headers["Authorization"])
        return header_value.replace("Bearer ", "")

    def _check_cluster_authenticated(self):
        """
        Ask the server whether it requires authentication and record the answer.

        Sends ``GET /v1/cluster/auth`` without auth headers and sets
        ``self._cluster_authed`` from the truthiness of the ``auth`` field.

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status; the message carries the
            server's error detail.
        """
        url = f"{self.base_url}/v1/cluster/auth"
        response = requests.get(url)
        if response.status_code != 200:
            # Shared helper: a non-JSON error body must not turn into a decode error.
            raise RuntimeError(
                f"Failed to get cluster information, detail: {_get_error_string(response)}"
            )
        response_data = response.json()
        self._cluster_authed = bool(response_data["auth"])

    def login(self, username: str, password: str):
        """
        Authenticate against the server and store the returned access token.

        Does nothing when the cluster does not require authentication.

        Parameters
        ----------
        username: str
            The user name sent to ``POST /token``.
        password: str
            The password sent to ``POST /token``.

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status; the message carries the
            server's error detail.
        """
        if not self._cluster_authed:
            return
        url = f"{self.base_url}/token"

        payload = {"username": username, "password": password}

        response = requests.post(url, json=payload)
        if response.status_code != 200:
            # Shared helper: a non-JSON error body must not turn into a decode error.
            raise RuntimeError(f"Failed to login, detail: {_get_error_string(response)}")

        response_data = response.json()
        # Only bearer token for now
        access_token = response_data["access_token"]
        self._headers["Authorization"] = f"Bearer {access_token}"

    def list_models(self) -> Dict[str, Dict[str, Any]]:
        """
        Fetch the model specifications from the server.

        Returns
        -------
        Dict[str, Dict[str, Any]]
            The decoded JSON body of ``GET /v1/models``: the collection of model
            specifications keyed by model.

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status; the message carries the
            server's error detail.

        """
        response = requests.get(f"{self.base_url}/v1/models", headers=self._headers)
        if response.status_code == 200:
            return response.json()
        raise RuntimeError(
            f"Failed to list model, detail: {_get_error_string(response)}"
        )

    def launch_speculative_llm(
        self,
        model_name: str,
        model_size_in_billions: Optional[int],
        quantization: Optional[str],
        draft_model_name: str,
        draft_model_size_in_billions: Optional[int],
        draft_quantization: Optional[str],
        n_gpu: Optional[Union[int, str]] = "auto",
    ):
        """
        Launch the LLM along with a draft model based on the parameters on the server via RESTful APIs. This is an
        experimental feature and the API may change in the future.

        Parameters
        ----------
        model_name: str
            The name of the target model.
        model_size_in_billions: Optional[int]
            The size (in billions) of the target model.
        quantization: Optional[str]
            The quantization of the target model.
        draft_model_name: str
            The name of the draft model.
        draft_model_size_in_billions: Optional[int]
            The size (in billions) of the draft model.
        draft_quantization: Optional[str]
            The quantization of the draft model.
        n_gpu: Optional[Union[int, str]]
            The number of GPUs to use, default is "auto".

        Returns
        -------
        str
            The unique model_uid for the launched model.

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status; the message carries the
            server's error detail.

        """
        # stacklevel=2 attributes the warning to the caller's line, not this method.
        warnings.warn(
            "`launch_speculative_llm` is an experimental feature and the API may change in the future.",
            stacklevel=2,
        )

        payload = {
            # The server chooses the uid; it is read back from the response below.
            "model_uid": None,
            "model_name": model_name,
            "model_size_in_billions": model_size_in_billions,
            "quantization": quantization,
            "draft_model_name": draft_model_name,
            "draft_model_size_in_billions": draft_model_size_in_billions,
            "draft_quantization": draft_quantization,
            "n_gpu": n_gpu,
        }

        url = f"{self.base_url}/experimental/speculative_llms"
        response = requests.post(url, json=payload, headers=self._headers)
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to launch model, detail: {_get_error_string(response)}"
            )

        response_data = response.json()
        model_uid = response_data["model_uid"]
        return model_uid

    def launch_model(
        self,
        model_name: str,
        model_type: str = "LLM",
        model_uid: Optional[str] = None,
        model_size_in_billions: Optional[Union[int, str]] = None,
        model_format: Optional[str] = None,
        quantization: Optional[str] = None,
        replica: int = 1,
        n_gpu: Optional[Union[int, str]] = "auto",
        request_limits: Optional[int] = None,
        **kwargs,
    ) -> str:
        """
        Ask the server to launch a model via RESTful APIs.

        Parameters
        ----------
        model_name: str
            The name of model.
        model_type: str
            Type of model, "LLM" by default.
        model_uid: Optional[str]
            UID of model; when None the server auto-generates one.
        model_size_in_billions: Optional[Union[int, str]]
            The size (in billions) of the model.
        model_format: Optional[str]
            The format of the model.
        quantization: Optional[str]
            The quantization of model.
        replica: int
            The replica of model, default is 1.
        n_gpu: Optional[Union[int, str]]
            The number of GPUs used by the model, default is "auto".
            ``n_gpu=None`` means cpu only, ``n_gpu=auto`` lets the system automatically determine the best number of GPUs to use.
        request_limits: Optional[int]
            The number of request limits for this model, default is None.
            ``request_limits=None`` means no limits for this model.
        **kwargs:
            Any other parameters, forwarded as additional fields of the request payload.

        Returns
        -------
        str
            The unique model_uid for the launched model.

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status; the message carries the
            server's error detail.

        """
        payload: Dict[str, Any] = dict(
            model_uid=model_uid,
            model_name=model_name,
            model_type=model_type,
            model_size_in_billions=model_size_in_billions,
            model_format=model_format,
            quantization=quantization,
            replica=replica,
            n_gpu=n_gpu,
            request_limits=request_limits,
        )
        payload.update({str(name): extra for name, extra in kwargs.items()})

        response = requests.post(
            f"{self.base_url}/v1/models", json=payload, headers=self._headers
        )
        if response.status_code == 200:
            return response.json()["model_uid"]
        raise RuntimeError(
            f"Failed to launch model, detail: {_get_error_string(response)}"
        )

    def terminate_model(self, model_uid: str):
        """
        Stop a model that is running on the server.

        Parameters
        ----------
        model_uid: str
            The unique id of the model to terminate.

        Raises
        ------
        RuntimeError
            If the server answers ``DELETE /v1/models/{model_uid}`` with a non-200
            status; the message carries the server's error detail.

        """
        target = f"{self.base_url}/v1/models/{model_uid}"
        response = requests.delete(target, headers=self._headers)
        if response.status_code == 200:
            return
        raise RuntimeError(
            f"Failed to terminate model, detail: {_get_error_string(response)}"
        )

    def _get_supervisor_internal_address(self):
        """
        Fetch the supervisor's internal address from the server.

        Returns
        -------
        The decoded JSON body of ``GET /v1/address``.

        Raises
        ------
        RuntimeError
            If the server answers with a non-200 status; the message now carries the
            server's error detail, like the other client methods.
        """
        url = f"{self.base_url}/v1/address"
        response = requests.get(url, headers=self._headers)
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to get supervisor internal address, detail: {_get_error_string(response)}"
            )
        response_data = response.json()
        return response_data

    def get_model(self, model_uid: str) -> RESTfulModelHandle:
        """
        Get a handle to a running model, choosing the handle class from its description.

        Parameters
        ----------
        model_uid: str
            The unique id that identify the model.

        Returns
        -------
        RESTfulModelHandle
            The handle matching the description returned by ``GET /v1/models/{model_uid}``:
              - ``RESTfulChatglmCppGenerateModelHandle`` for "LLM" models in "ggmlv3" format whose
                name contains "chatglm" (provides both ``generate`` and ``chat``).
              - ``RESTfulChatModelHandle`` for other "LLM" models with the "chat" ability.
              - ``RESTfulGenerateModelHandle`` for other "LLM" models with the "generate" ability.
              - ``RESTfulEmbeddingModelHandle``, ``RESTfulImageModelHandle``,
                ``RESTfulRerankModelHandle`` or ``RESTfulMultimodalModelHandle`` for model types
                "embedding", "image", "rerank" and "multimodal" respectively.

        Raises
        ------
        RuntimeError
            Report failure to get the wanted model with given model_uid. Provide details of failure through error message.
        ValueError
            If the model type, or an LLM's abilities, are not recognized.

        """

        url = f"{self.base_url}/v1/models/{model_uid}"
        response = requests.get(url, headers=self._headers)
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to get the model description, detail: {_get_error_string(response)}"
            )
        desc = response.json()

        model_type = desc["model_type"]
        if model_type == "LLM":
            # ChatGLM ggml models need the dedicated ChatGLM handle regardless of ability.
            if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
                return RESTfulChatglmCppGenerateModelHandle(
                    model_uid, self.base_url, auth_headers=self._headers
                )
            # "chat" wins over "generate": the chat handle also inherits generate().
            if "chat" in desc["model_ability"]:
                return RESTfulChatModelHandle(
                    model_uid, self.base_url, auth_headers=self._headers
                )
            if "generate" in desc["model_ability"]:
                return RESTfulGenerateModelHandle(
                    model_uid, self.base_url, auth_headers=self._headers
                )
            raise ValueError(f"Unrecognized model ability: {desc['model_ability']}")

        handle_by_type = {
            "embedding": RESTfulEmbeddingModelHandle,
            "image": RESTfulImageModelHandle,
            "rerank": RESTfulRerankModelHandle,
            "multimodal": RESTfulMultimodalModelHandle,
        }
        if model_type not in handle_by_type:
            raise ValueError(f"Unknown model type:{desc['model_type']}")
        return handle_by_type[model_type](
            model_uid, self.base_url, auth_headers=self._headers
        )

    def describe_model(self, model_uid: str):
        """
        Fetch the description of a launched model through the RESTful API.

        Issues ``GET {base_url}/v1/models/{model_uid}`` and hands back the
        decoded JSON body unchanged.

        Parameters
        ----------
        model_uid: str
            Unique identifier of the model to describe.

        Returns
        -------
        dict
            The server's description of the model. Documented keys:

            - "model_type": str
               the type of the model determined by its function, e.g. "LLM" (Large Language Model)
            - "model_name": str
               the name of the specific LLM model family
            - "model_lang": List[str]
               the languages supported by the LLM model
            - "model_ability": List[str]
               the ability or capabilities of the LLM model
            - "model_description": str
               a detailed description of the LLM model
            - "model_format": str
               the format specification of the LLM model
            - "model_size_in_billions": int
               the size of the LLM model in billions
            - "quantization": str
               the quantization applied to the model
            - "revision": str
               the revision number of the LLM model specification
            - "context_length": int
               the maximum text length the LLM model can accommodate (include all input & output)

        Raises
        ------
        RuntimeError
            If the server answers with any status code other than 200; the
            message carries the error detail extracted from the response.

        """
        endpoint = f"{self.base_url}/v1/models/{model_uid}"
        resp = requests.get(endpoint, headers=self._headers)
        if resp.status_code == 200:
            return resp.json()
        raise RuntimeError(
            f"Failed to get the model description, detail: {_get_error_string(resp)}"
        )

    def register_model(self, model_type: str, model: str, persist: bool):
        """
        Register a custom model definition with the server.

        Sends ``POST {base_url}/v1/model_registrations/{model_type}`` with a
        JSON body of the form ``{"model": model, "persist": persist}``.

        Parameters
        ----------
        model_type: str
            The type of model; used as the final path segment of the request URL.
        model: str
            The model definition. (refer to: https://inference.readthedocs.io/en/latest/models/custom.html)
        persist: bool
            Forwarded verbatim as the ``persist`` field of the request body.
            Presumably asks the server to keep the registration beyond its
            current lifetime -- NOTE(review): confirm against the server handler.

        Returns
        -------
        The decoded JSON body of the server's response.

        Raises
        ------
        RuntimeError
            If the server answers with any status code other than 200; the
            message carries the error detail extracted from the response.
        """
        endpoint = f"{self.base_url}/v1/model_registrations/{model_type}"
        resp = requests.post(
            endpoint,
            json={"model": model, "persist": persist},
            headers=self._headers,
        )
        if resp.status_code == 200:
            return resp.json()
        raise RuntimeError(
            f"Failed to register model, detail: {_get_error_string(resp)}"
        )

    def unregister_model(self, model_type: str, model_name: str):
        """
        Unregister a custom model.

        Sends ``DELETE {base_url}/v1/model_registrations/{model_type}/{model_name}``.

        Parameters
        ----------
        model_type: str
            The type of model.
        model_name: str
            The name of the model.

        Returns
        -------
        The decoded JSON body of the server's response.

        Raises
        ------
        RuntimeError
            Report failure to unregister the custom model. Provide details of failure through error message.
        """
        url = f"{self.base_url}/v1/model_registrations/{model_type}/{model_name}"
        response = requests.delete(url, headers=self._headers)
        if response.status_code != 200:
            # The message previously said "register"; this is the unregister path.
            raise RuntimeError(
                f"Failed to unregister model, detail: {_get_error_string(response)}"
            )

        response_data = response.json()
        return response_data

    def list_model_registrations(self, model_type: str) -> List[Dict[str, Any]]:
        """
        Retrieve every model registration of a given type from the server.

        Issues ``GET {base_url}/v1/model_registrations/{model_type}`` and
        returns the decoded JSON body.

        Parameters
        ----------
        model_type: str
            The type of the model whose registrations should be listed.

        Returns
        -------
        List[Dict[str, Any]]
            The registrations known to the server for ``model_type``.

        Raises
        ------
        RuntimeError
            If the server answers with any status code other than 200; the
            message carries the error detail extracted from the response.

        """
        endpoint = f"{self.base_url}/v1/model_registrations/{model_type}"
        resp = requests.get(endpoint, headers=self._headers)
        if resp.status_code == 200:
            return resp.json()
        raise RuntimeError(
            f"Failed to list model registration, detail: {_get_error_string(resp)}"
        )

    def get_model_registration(
        self, model_type: str, model_name: str
    ) -> Dict[str, Any]:
        """
        Get the model with the model type and model name registered on the server.

        Issues ``GET {base_url}/v1/model_registrations/{model_type}/{model_name}``.

        Parameters
        ----------
        model_type: str
            The type of the model.

        model_name: str
            The name of the model.

        Returns
        -------
        Dict[str, Any]
            The decoded JSON body describing the single requested registration.

        Raises
        ------
        RuntimeError
            Report failure to get the model registration. Provide details of failure through error message.
        """
        url = f"{self.base_url}/v1/model_registrations/{model_type}/{model_name}"
        response = requests.get(url, headers=self._headers)
        if response.status_code != 200:
            # This fetches one registration; the old "list" wording was copied
            # from list_model_registrations and misdescribed the failure.
            raise RuntimeError(
                f"Failed to get model registration, detail: {_get_error_string(response)}"
            )

        response_data = response.json()
        return response_data
