import json
import inspect

from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Callable
from pprint import pprint
from typing import Optional, List, Union
from pydantic import BaseModel, ValidationError
from dataclasses import dataclass, field

from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call   import ChatCompletionMessageToolCall, Function

from llm_easy_tools.schema_generator import get_name, llm_function, parameters_basemodel_from_function

class NoMatchingTool(Exception):
    """
    Signals that a tool call name could not be matched to an available tool
    or to the expected prefix.

    Attributes:
        message: The human-readable description passed at construction.
    """

    def __init__(self, message):
        super().__init__(message)
        # Kept as an attribute so callers can read it without going through args.
        self.message = message


@dataclass
class ToolResult:
    """
    Outcome of processing one tool call.

    Attributes:
        tool_call_id (str): Identifier of the tool call this result belongs to.
        name (str): Name of the tool that was called.
        output (Optional[Union[str, BaseModel]]): Value produced by the tool, if any.
        error (Optional[Exception]): The exception that made the call fail, if any.
        soft_errors (List[Exception]): Non-fatal problems collected while processing the call.
        prefix (Optional[BaseModel]): Prefix model instance extracted from the call arguments, if any.
        tool (Optional[Callable|BaseModel]): The function or model the call was matched to, if any.

    Methods:
        to_message(): Builds the "tool" role message dictionary for this result.
    """
    tool_call_id: str
    name: str
    output: Optional[Union[str, BaseModel]] = None
    error: Optional[Exception] = None
    soft_errors: List[Exception] = field(default_factory=list)
    prefix: Optional[BaseModel] = None
    tool: Optional[Callable|BaseModel] = None

    def to_message(self) -> dict[str, str]:
        """
        Convert this result into a chat message with role "tool".

        The content is the error text when an error is set; otherwise an empty
        string for no output, "<name> created" for a model output, or the
        string form of any other output.
        """
        produced = self.output
        if self.error is not None:
            # An error takes precedence over whatever output may be present.
            text = f"{self.error}"
        elif isinstance(produced, BaseModel):
            text = f"{self.name} created"
        elif produced is None:
            text = ''
        else:
            text = str(produced)

        message = {
            "role": "tool",
            "tool_call_id": self.tool_call_id,
            "name": self.name,
        }
        message["content"] = text
        return message

def process_tool_call(tool_call, functions_or_models, prefix_class=None, fix_json_args=True, case_insensitive=False) -> ToolResult:
    """
    Executes a single tool call against the matching function or model.

    Args:
        tool_call: Tool call object exposing `id` and `function`, where `function`
            has a `name` and a JSON-encoded `arguments` string.
        functions_or_models: Candidates to match by name (compared via `get_name`).
        prefix_class (optional): When given, arguments named in its annotations are
            pulled out of the call arguments to build a prefix instance, and the
            tool name is expected to start with the class name; the prefix name
            plus '_and_' is then stripped from the tool name.
        fix_json_args (bool, optional): When True, arguments that fail to parse are
            retried after removing trailing commas before a closing brace.
        case_insensitive (bool, optional): Lowercases the prefix name and is passed
            on to `get_name` for the name comparison.

    Returns:
        ToolResult: The outcome of the call. Failures are reported in `error`
        (unparseable arguments, no matching tool, or an exception raised while
        running the tool); recoverable problems are listed in `soft_errors`.
    """
    function_call = tool_call.function
    tool_name = function_call.name
    args = function_call.arguments
    soft_errors = []
    error = None
    prefix = None
    output = None
    try:
        tool_args = json.loads(args)
    except json.decoder.JSONDecodeError as e:
        if not fix_json_args:
            return ToolResult(tool_call_id=tool_call.id, name=tool_name, error=e)
        soft_errors.append(e)
        # Drop trailing commas before a closing brace and try once more.
        args = args.replace(', }', '}').replace(',}', '}')
        try:
            tool_args = json.loads(args)
        except json.decoder.JSONDecodeError as retry_error:
            # Still malformed after the fix: report it in the result instead of
            # letting the exception escape, like every other failure here.
            return ToolResult(
                tool_call_id=tool_call.id,
                name=tool_name,
                error=retry_error,
                soft_errors=soft_errors,
            )

    if prefix_class is not None:
        try:
            # Removes the prefix fields from tool_args as a side effect.
            prefix = _extract_prefix_unpacked(tool_args, prefix_class)
        except ValidationError as e:
            soft_errors.append(e)
        prefix_name = prefix_class.__name__
        if case_insensitive:
            prefix_name = prefix_name.lower()
        if not tool_name.startswith(prefix_name):
            soft_errors.append(NoMatchingTool(f"Trying to decode function call with a name '{tool_name}' not matching prefix '{prefix_name}'"))
        else:
            tool_name = tool_name[len(prefix_name + '_and_'):]

    tool = None

    for f in functions_or_models:
        if get_name(f, case_insensitive=case_insensitive) == tool_name:
            tool = f
            try:
                output = _process_unpacked(f, tool_args)
            except Exception as e:
                # Any failure while validating arguments or running the tool is
                # captured in the result rather than raised.
                error = e
            break
    else:
        # Loop finished without a break: nothing matched the requested name.
        error = NoMatchingTool(f"Function {tool_name} not found")
    result = ToolResult(
        tool_call_id=tool_call.id,
        name=tool_name,
        output=output,
        error=error,
        soft_errors=soft_errors,
        prefix=prefix,
        tool=tool,
    )
    return result

def _process_unpacked(function, tool_args=None) -> Union[str, BaseModel]:
    """
    Validates the arguments against the function's parameter model and calls it.

    Args:
        function: The function or model to call.
        tool_args (optional): Mapping of argument names to raw values. Defaults
            to no arguments.

    Returns:
        Whatever `function` returns when called with the validated values.
    """
    if tool_args is None:
        # None instead of a shared mutable `{}` default.
        tool_args = {}
    model = parameters_basemodel_from_function(function)
    # Instantiating the model validates (and may coerce) the raw arguments.
    model_instance = model(**tool_args)
    call_kwargs = {name: getattr(model_instance, name) for name in model.model_fields}
    return function(**call_kwargs)

def _extract_prefix_unpacked(tool_args, prefix_class):
    # modifies tool_args
    prefix_args = {}
    for key in list(tool_args.keys()):  # copy keys to list because we modify the dict while iterating over it
        if key in prefix_class.__annotations__:
            prefix_args[key] = tool_args.pop(key)
    prefix = prefix_class(**prefix_args)
    return(prefix)

def process_response(
        response: ChatCompletion,
        functions: List[Callable],
        choice_num=0,
        prefix_class=None,
        fix_json_args=True,
        case_insensitive=False,
        executor: ThreadPoolExecutor|ProcessPoolExecutor|None=None,
        ) -> list[ToolResult]:
    """
    Processes a ChatCompletion response, executing contained tool calls.
    For each tool call matches a function from the 'functions' list by name.
    The result of the tool call is returned as a ToolResult object.
    If the tool call raises an exception, that exception is saved in the 'error' field in the result.

    Args:
        response (ChatCompletion): The response object containing tool calls.
        functions (List[Callable]): A list of functions or pydantic models to call.
        choice_num (int, optional): The index of the choice to process from the response. Defaults to 0.
        prefix_class (optional): Passed through to process_tool_call for every tool call.
        fix_json_args (bool, optional): Passed through to process_tool_call. Defaults to True.
        case_insensitive (bool, optional): Passed through to process_tool_call. Defaults to False.
        executor (optional): Executor used to process the tool calls via its `map`;
            when omitted the calls are processed sequentially in this thread.

    Returns:
        list[ToolResult]: A list of ToolResult objects, each representing the outcome of a processed tool call.
        Empty when the selected message carries no tool calls.
    """
    message = response.choices[choice_num].message
    if hasattr(message, 'function_call') and (function_call := message.function_call):
        # this is obsolete in openai - but maybe it is used by other llms?
        tool_calls = [ChatCompletionMessageToolCall(id='A', function=Function(name=function_call.name, arguments=function_call.arguments), type='function')]
    elif hasattr(message, 'tool_calls') and message.tool_calls:
        tool_calls = message.tool_calls
    else:
        tool_calls = []
    if not tool_calls:
        return []

    # Prepare the arguments for each tool call
    args_list = [(tool_call, functions, prefix_class, fix_json_args, case_insensitive) for tool_call in tool_calls]

    if executor:
        # Pass the module-level function itself rather than a lambda: lambdas
        # cannot be pickled, which a ProcessPoolExecutor requires. zip(*...)
        # transposes the per-call tuples into one iterable per parameter.
        return list(executor.map(process_tool_call, *zip(*args_list)))
    return [process_tool_call(*call_args) for call_args in args_list]

def get_toolset_tools(obj: object) -> list[Callable]:
    """
    Collects the tools exposed by a toolset object.

    A tool is either a bound method of `obj` or a class attribute of its class
    that is itself a class, provided it carries the
    'LLMEasyTools_external_function' attribute.

    Returns:
        list[Callable]: Marked methods first, followed by marked classes.
    """
    marker = 'LLMEasyTools_external_function'

    tools = [
        member
        for _, member in inspect.getmembers(obj, predicate=inspect.ismethod)
        if hasattr(member, marker)
    ]

    owner = obj.__class__
    for attribute in dir(owner):
        candidate = getattr(owner, attribute)
        if not isinstance(candidate, type):
            continue
        if hasattr(candidate, marker):
            tools.append(candidate)
    return tools



#######################################
#
# Examples


if __name__ == "__main__":
    # Runnable examples: build fake chat completions carrying a single tool
    # call and print what the processing functions return for them.

    @llm_function(schema_name="altered_name")
    def function_decorated():
        return 'Result of function_decorated'

    class ExampleClass:
        def simple_method(self, count: int, size: float):
            """simple method does something"""
            return 'Result of simple_method'

    example_object = ExampleClass()

    class User(BaseModel):
        name: str
        email: str

    def mk_chat_with_tool_call(name, args):
        # Minimal ChatCompletion whose only choice holds one tool call.
        tool_call = {
            "id": 'A',
            "type": 'function',
            "function": {
                "arguments": json.dumps(args),
                "name": name
            }
        }
        choice = {
            'finish_reason': 'stop',
            'index': 0,
            'message': ChatCompletionMessage(role="assistant", tool_calls=[tool_call]),
        }
        return ChatCompletion(
            id='A',
            created=0,
            model='A',
            choices=[choice],
            object='chat.completion'
        )

    def first_tool_call(name, args):
        # The single tool call inside a freshly built completion.
        return mk_chat_with_tool_call(name, args).choices[0].message.tool_calls[0]

    # Whole-response processing with a decorated function.
    pprint(process_response(mk_chat_with_tool_call('altered_name', {}), [function_decorated]))

    # The same call processed directly as a single tool call.
    call_to_altered_name = first_tool_call('altered_name', {})
    pprint(call_to_altered_name)
    pprint(process_tool_call(call_to_altered_name, [function_decorated]))

    # A bound method used as a tool.
    call_to_simple_method = first_tool_call('simple_method', {"count": 1, "size": 2.2})
    pprint(process_tool_call(call_to_simple_method, [example_object.simple_method]))

    # A pydantic model used as a tool.
    call_to_model = first_tool_call('User', {"name": 'John', "email": 'john@example.com'})
    pprint(process_tool_call(call_to_model, [User]))

