import json
import threading
import datetime
import uuid
import inspect
import enum
from textwrap import dedent
from concurrent.futures import Future
from hashlib import md5
from typing import Any, Dict, List, Set, Optional, Callable, Type
from typing_extensions import Annotated, Self

from pydantic import UUID4, BaseModel, Field, PrivateAttr, field_validator, model_validator, InstanceOf, field_validator
from pydantic_core import PydanticCustomError

import versionhq as vhq
from versionhq.task.log_handler import TaskOutputStorageHandler
from versionhq.task.evaluate import Evaluation, EvaluationItem
from versionhq.tool.model import Tool, ToolSet
from versionhq._utils import process_config, Logger


class TaskExecutionType(enum.Enum):
    """
    How an independent task (one without dependencies) is executed.

    SYNC runs the task in the caller's thread; ASYNC runs it in a background thread and hands back a Future.
    """
    SYNC = enum.auto()  # == 1
    ASYNC = enum.auto()  # == 2


class ResponseField(BaseModel):
    """
    A class to store a response format that will generate a JSON schema.
    One layer of nested child is acceptable.
    """

    title: str = Field(default=None, description="title of the field")
    data_type: Type = Field(default=None)
    items: Optional[Type] = Field(default=None, description="store data type of the array items")
    properties: Optional[List[BaseModel]] = Field(default=None, description="store dict items in ResponseField format")
    required: bool = Field(default=True)
    nullable: bool = Field(default=False)
    config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="additional rules")


    @model_validator(mode="after")
    def validate_instance(self) -> Self:
        """
        Validate the model instance based on the given `data_type`.

        - A list without `items` defaults its items to `str`.
        - A dict, or a list of dicts, must provide `properties`, and every property must be a `ResponseField`.

        Raises:
            PydanticCustomError: `missing_properties` / `invalid_properties` when the rules above are violated.
        """

        if self.data_type is list and self.items is None:
            self.items = str

        if self.data_type is dict or (self.data_type is list and self.items is dict):
            if self.properties is None:
                raise PydanticCustomError("missing_properties", "The dict type has to set the properties.", {})

            if not all(isinstance(item, ResponseField) for item in self.properties):
                raise PydanticCustomError("invalid_properties", "Properties field must input in ResponseField format.", {})

        return self


    def _nested_schema(self) -> tuple[Dict[str, Any], List[str]]:
        """
        Collect the JSON-schema `properties` mapping and the list of `required` titles from the nested `properties`.
        """
        nested_props: Dict[str, Any] = {}
        nested_required: List[str] = []

        for item in self.properties or []:
            nested_props.update(item._format_props())
            nested_required.append(item.title)

        return nested_props, nested_required


    def _format_props(self) -> Dict[str, Any]:
        """
        Structure valid properties from the ResponseField object. 1 layer of nested child is accepted.

        Returns:
            `{title: schema}` where `schema` is the JSON-schema fragment for this field, merged with `config` when set.
        """
        from versionhq.llm.llm_vars import SchemaType

        schema_type = SchemaType(type=self.data_type).convert()

        if self.data_type is list:
            if self.items is dict:
                nested_p, nested_r = self._nested_schema()
                props = {
                    "type": schema_type,
                    "items": {
                        "type": SchemaType(type=self.items).convert(),
                        "properties": nested_p,
                        "required": nested_r,
                        "additionalProperties": False
                    }
                }

            elif self.items is list:
                # A list of lists: inner items are always described as strings.
                props = {
                    "type": schema_type,
                    "items": { "type": SchemaType(type=self.items).convert(), "items": { "type": SchemaType(type=str).convert() }},
                }

            else:
                props = {
                    "type": schema_type,
                    "items": { "type": SchemaType(type=self.items).convert() },
                }

        elif self.data_type is dict:
            p, r = self._nested_schema()
            props = {
                "type": schema_type,
                "properties": p,
                "required": r,
                "additionalProperties": False
            }

        else:
            props = {
                "type": schema_type,
                "nullable": self.nullable,
            }

        return { self.title: { **props, **self.config }} if self.config else { self.title: props }


    def _convert(self, value: Any) -> Any:
        """
        Convert the given value to `data_type` on a best-effort basis.

        Lists/dicts given as text are parsed as JSON first, then as Python literals. The text is never evaluated
        as code, because it may come from model output.

        Returns:
            The converted value, or `value` unchanged when `data_type` is `Any`/unknown or the conversion fails.
        """
        import ast

        target = self.data_type
        try:
            if target is Any:
                return value
            if target is int:
                return int(value)
            if target is float:
                return float(value)
            if target is list or target is dict:
                if isinstance(value, target):
                    return value
                text = str(value)
                try:
                    parsed = json.loads(text)
                except ValueError:
                    parsed = ast.literal_eval(text)
                return parsed if isinstance(parsed, target) else value
            if target is str:
                return str(value)
            return value
        except (ValueError, TypeError, SyntaxError, OverflowError, RecursionError):
            return value


    def _annotate(self, value: Any) -> Annotated:
        """
        Address Pydantic's `create_model`.

        Returns `Annotated[data_type, value]` when `value` is an instance of `data_type`,
        otherwise `Annotated[str, str(value)]`.
        """
        expected = self.data_type
        # isinstance() cannot be used with None/Any/non-class types, so fall back to str for those.
        if isinstance(expected, type) and expected is not Any and isinstance(value, expected):
            return Annotated[expected, value]
        return Annotated[str, str(value)]


    def create_pydantic_model(self, result: Dict, base_model: InstanceOf[BaseModel] | Any) -> Any:
        """
        Set the entry of `result` whose key equals `title` onto `base_model`, converting it to `data_type` if needed.

        Returns:
            `base_model` (mutated in place).
        """
        for key, value in result.items():
            if key != self.title:
                continue
            if type(value) is not self.data_type:
                value = self._convert(value)
            setattr(base_model, key, value)
        return base_model


class TaskOutput(BaseModel):
    """
    A class to store the final output of the given task in raw (string), json_dict, and pydantic class formats.
    """

    task_id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True, description="store Task ID")
    raw: str = Field(default="", description="Raw output of the task")
    json_dict: Optional[Dict[str, Any]] = Field(default=None, description="`raw` converted to dictionary")
    pydantic: Optional[Any] = Field(default=None)
    tool_output: Optional[Any] = Field(default=None, description="store tool result when the task takes tool output as its final output")
    callback_output: Optional[Any] = Field(default=None, description="store task or agent callback outcome")
    evaluation: Optional[InstanceOf[Evaluation]] = Field(default=None, description="store overall evaluation of the task output. passed to ltm")


    def to_context_prompt(self) -> str:
        """
        Return the output as a string that can be embedded in another task's prompt as context.

        Preference order:
        1. `json_dict` serialized as JSON.
        2. The `pydantic` output serialized as JSON.
        3. `raw`.

        A representation that is missing or cannot be serialized falls through to the next one.
        """
        if self.json_dict is not None:
            try:
                return json.dumps(self.json_dict)
            except (TypeError, ValueError):
                pass  # not JSON-serializable (or circular); try the next representation

        if self.pydantic is not None:
            try:
                return self.pydantic.model_dump_json()
            except (AttributeError, TypeError, ValueError):
                pass  # not a BaseModel instance (e.g. a model class) or not serializable

        return self.raw


    def evaluate(self, task) -> Evaluation:
        """
        Evaluate the output based on the criteria, score each from 0 to 1 scale, and raise suggestions for future improvement.

        One evaluation task is executed per criterion in `task.eval_criteria`. The default is a single overall criterion.
        Results that cannot be converted into an `EvaluationItem` are logged and skipped.

        Returns:
            The (possibly newly created) `self.evaluation` with the new items appended.
        """
        from versionhq.task.TEMPLATES.Description import EVALUATE

        self.evaluation = Evaluation() if not self.evaluation else self.evaluation

        eval_criteria = task.eval_criteria if task.eval_criteria else  ["Overall competitiveness", ]

        for criterion in eval_criteria:
            task_eval = Task(
                description=EVALUATE.format(task_description=task.description, task_output=self.raw, eval_criteria=str(criterion)),
                pydantic_output=EvaluationItem
            )
            res = task_eval.execute(agent=self.evaluation.eval_by)

            if res.pydantic:
                eval_item = EvaluationItem(score=res.pydantic.score, suggestion=res.pydantic.suggestion, criteria=res.pydantic.criteria)
                self.evaluation.items.append(eval_item)

            else:
                # Fall back to the dict output when no pydantic object could be built.
                try:
                    eval_item = EvaluationItem(
                        score=float(res.json_dict["score"]), suggestion=res.json_dict["suggestion"], criteria=res.json_dict["criteria"]
                    )
                    self.evaluation.items.append(eval_item)
                except Exception as e:
                    Logger(verbose=True).log(level="error", message=f"Failed to convert the evaluation items: {str(e)}", color="red")

        return self.evaluation


    @property
    def aggregate_score(self) -> float | int:
        """
        Aggregate score of `evaluation`, or 0 when the output has not been evaluated.
        """
        if self.evaluation is None:
            return 0
        return self.evaluation.aggregate_score


    @property
    def json_string(self) -> Optional[str]:
        """
        `json_dict` serialized as a JSON string.
        """
        return json.dumps(self.json_dict)


    def __str__(self) -> str:
        return str(self.pydantic) if self.pydantic else str(self.json_dict) if self.json_dict else self.raw


class Task(BaseModel):
    """
    A class that stores independent task information and handles task executions.
    """

    __hash__ = object.__hash__
    _original_description: str = PrivateAttr(default=None)
    _task_output_handler = TaskOutputStorageHandler()
    config: Optional[Dict[str, Any]] = Field(default=None, description="values to set on Task class")

    id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True, description="unique identifier for the object, not set by user")
    name: Optional[str] = Field(default=None)
    description: str = Field(description="Description of the actual task")

    # response format
    pydantic_output: Optional[Type[BaseModel]] = Field(default=None, description="store Pydantic class as structured response format")
    response_fields: Optional[List[ResponseField]] = Field(default_factory=list, description="store list of ResponseField as structured response format")

    # tool usage
    tools: Optional[List[ToolSet | Tool | Any]] = Field(default_factory=list, description="tools that the agent can use aside from their tools")
    can_use_agent_tools: bool = Field(default=True, description="whether the agent can use their own tools when executing the task")
    tool_res_as_final: bool = Field(default=False, description="when set True, tools res will be stored in the `TaskOutput`")

    # executing
    execution_type: TaskExecutionType = Field(default=TaskExecutionType.SYNC)
    allow_delegation: bool = Field(default=False, description="ask other agents for help and run the task instead")
    callback: Optional[Callable] = Field(default=None, description="callback to be executed after the task is completed.")
    callback_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict, description="kwargs for the callback when the callback is callable")

    # evaluation
    should_evaluate: bool = Field(default=False, description="True to run the evaluation flow")
    eval_criteria: Optional[List[str]] = Field(default_factory=list, description="criteria to evaluate the outcome. i.e., fit to the brand tone")

    # recording !# REFINEME - eval_callbacks
    processed_agents: Set[str] = Field(default_factory=set, description="store roles of the agents that executed the task")
    tool_errors: int = 0
    delegations: int = 0
    latency: int | float = 0 # job latency in sec
    tokens: int = 0 # tokens consumed
    output: Optional[TaskOutput] = Field(default=None, description="store the final task output in TaskOutput class")


    @model_validator(mode="before")
    @classmethod
    def process_config(cls, values: Dict[str, Any]) -> None:
        return process_config(values_to_update=values, model_class=cls)


    @field_validator("id", mode="before")
    @classmethod
    def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
        if v:
            raise PydanticCustomError("may_not_set_field", "This field is not to be set by the user.", {})


    @model_validator(mode="after")
    def validate_required_fields(self) -> Self:
        required_fields = ["description",]
        for field in required_fields:
            if getattr(self, field) is None:
                raise ValueError( f"{field} must be provided either directly or through config")
        return self


    @model_validator(mode="after")
    def set_up_tools(self) -> Self:
        if not self.tools:
            pass
        else:
            tool_list = []
            for item in self.tools:
                if isinstance(item, Tool) or isinstance(item, ToolSet):
                    tool_list.append(item)
                elif (isinstance(item, dict) and "function" not in item) or isinstance(item, str):
                    pass
                else:
                    tool_list.append(item) # address custom tool
            self.tools = tool_list
        return self


    def _draft_output_prompt(self, model_provider: str) -> str:
        """
        Draft prompts on the output format by converting `
        """

        output_prompt = ""

        if self.pydantic_output:
            output_prompt = f"""
Your response MUST STRICTLY follow the given repsonse format:
JSON schema: {str(self.pydantic_output)}
"""

        elif self.response_fields:
            output_prompt, output_formats_to_follow = "", dict()
            response_format = str(self._structure_response_format(model_provider=model_provider))
            for item in self.response_fields:
                if item:
                    output_formats_to_follow[item.title] = f"<Return your answer in {item.data_type.__name__}>"

            output_prompt = f"""
Your response MUST be a valid JSON string that strictly follows the response format. Use double quotes for all keys and string values. Do not use single quotes, trailing commas, or any other non-standard JSON syntax.
Response format: {response_format}
Ref. Output image: {output_formats_to_follow}
"""
        else:
            output_prompt = "You MUST Return your response as a valid JSON serializable string, enclosed in double quotes. Do not use single quotes, trailing commas, or other non-standard JSON syntax."

        return dedent(output_prompt)


    def _draft_context_prompt(self, context: Any) -> str:
        """
        Create a context prompt from the given context in any format: a task object, task output object, list, dict.
        """

        context_to_add = None
        if not context:
            # Logger().log(level="error", color="red", message="Missing a context to add to the prompt. We'll return ''.")
            return context_to_add

        match context:
            case str():
                context_to_add = context

            case Task():
                if not context.output:
                    res = context.execute()
                    context_to_add = res.to_context_prompt()

                else:
                    context_to_add = context.output.raw

            case TaskOutput():
                context_to_add = context.to_context_prompt()


            case dict():
                context_to_add = str(context)

            case list():
                res = ", ".join([self._draft_context_prompt(context=item) for item in context])
                context_to_add = res

            case _:
                pass

        return dedent(context_to_add)


    def _prompt(self, model_provider: str = None, context: Optional[Any] = None) -> str:
        """
        Format the task prompt and cascade it to the agent.
        """
        output_prompt = self._draft_output_prompt(model_provider=model_provider)
        task_slices = [self.description, output_prompt, ]

        if context:
            context_prompt = self._draft_context_prompt(context=context)
            task_slices.insert(len(task_slices), f"Consider the following context when responding: {context_prompt}")

        return "\n".join(task_slices)


    def _structure_response_format(self, data_type: str = "object", model_provider: str = "gemini") -> Dict[str, Any] | None:
        """
        Structure a response format either from`response_fields` or `pydantic_output`.
        1 nested item is accepted.
        """

        from versionhq.task.structured_response import StructuredOutput

        response_format: Dict[str, Any] = None

        if model_provider == "openrouter":
            return response_format

        else:
            if self.response_fields:
                properties, required_fields = {}, []
                for i, item in enumerate(self.response_fields):
                    if item:
                        properties.update(item._format_props())
                        required_fields.append(item.title)

                response_schema = {
                    "type": "object",
                    "properties": properties,
                    "required": required_fields,
                    "additionalProperties": False,
                }
                response_format = {
                    "type": "json_schema",
                    "json_schema": { "name": "outcome", "schema": response_schema }
                }


            elif self.pydantic_output:
                response_format = StructuredOutput(response_format=self.pydantic_output, provider=model_provider)._format()

            return response_format


    def _create_json_output(self, raw: str) -> Dict[str, Any]:
        """
        Create json (dict) output from the raw output and `response_fields` information.
        """

        if raw is None or raw == "":
            Logger().log(level="warning", message="The model returned an empty response. Returning an empty dict.", color="yellow")
            output = { "output": "" }
            return output

        try:
            r = str(raw).replace("true", "True").replace("false", "False")
            j = json.dumps(eval(r))
            output = json.loads(j)
            if isinstance(output, dict):
                return output

            else:
                r = str(raw).strip().replace("{'", '{"').replace("{ '", '{"').replace("': '", '": "').replace("'}", '"}').replace("' }", '"}').replace("', '", '", "').replace("['", '["').replace("[ '", '[ "').replace("']", '"]').replace("' ]", '" ]').replace("{\n'", '{"').replace("{\'", '{"').replace("true", "True").replace("false", "False").replace('\"', "'")
                j = json.dumps(eval(r))
                output = json.loads(j)

                if isinstance(output, dict):
                    return output

                else:
                    import ast
                    output = ast.literal_eval(r)
                    return output if isinstance(output, dict) else { "output": str(r) }

        except:
            output = { "output": str(raw) }
            return output


    def _create_pydantic_output(self, raw: str = None, json_dict: Dict[str, Any] = None) -> InstanceOf[BaseModel]:
        """
        Create pydantic output from raw or json_dict output.
        """

        output_pydantic = self.pydantic_output

        try:
            json_dict = json_dict if json_dict else self._create_json_output(raw=raw)

            for k, v in json_dict.items():
                setattr(output_pydantic, k, v)
        except:
            pass

        return output_pydantic


    def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
        """
        Interpolate inputs into the task description.
        """
        self._original_description = self.description

        if inputs:
            self.description = self._original_description.format(**inputs)


    def _create_short_and_long_term_memories(self, agent: Any, task_output: TaskOutput) -> None:
        """
        After the task execution, create and save short-term/long-term memories in the storage.
        """
        from versionhq.agent.model import Agent
        from versionhq.memory.model import ShortTermMemory, MemoryMetadata, LongTermMemory

        agent = agent if isinstance(agent, Agent) else Agent(role=str(agent), goal=str(agent), use_memory=True)

        if agent.use_memory == False:
            return None

        try:
            evaluation = task_output.evaluation if task_output.evaluation else None
            memory_metadata = evaluation._create_memory_metadata() if evaluation else MemoryMetadata()

            agent.short_term_memory = agent.short_term_memory if agent.short_term_memory else ShortTermMemory(agent=agent, embedder_config=agent.embedder_config)
            agent.short_term_memory.save(
                task_description=str(self.description),
                task_output=str(task_output.raw),
                agent=str(agent.role),
                metadata=memory_metadata
            )

            agent.long_term_memory = agent.long_term_memory if agent.long_term_memory else LongTermMemory()
            agent.long_term_memory.save(
                task_description=str(self.description),
                task_output=str(task_output.raw),
                agent=str(agent.role),
                metadata=memory_metadata
                )

        except AttributeError as e:
            Logger().log(level="error", message=f"Missing attributes for long term memory: {str(e)}", color="red")
            pass

        except Exception as e:
            Logger().log(level="error", message=f"Failed to add to the memory: {str(e)}", color="red")
            pass


    def _build_agent_from_task(self, task_description: str = None) -> InstanceOf["vhq.Agent"]:
        task_description = task_description if task_description else self.description
        if not task_description:
            Logger().log(level="error", message="Task is missing the description.", color="red")
            pass

        agent = vhq.Agent(goal=task_description, role=task_description, maxit=1) #! REFINEME
        return agent


    # task execution
    def execute(
            self, type: TaskExecutionType = None, agent: Optional["vhq.Agent"] = None, context: Optional[Any] = None
            ) -> TaskOutput | Future[TaskOutput]:
        """
        A main method to handle task execution. Build an agent when the agent is not given.
        """
        type = type if type else  self.execution_type if self.execution_type else TaskExecutionType.SYNC

        if not agent:
            agent = self._build_agent_from_task(task_description=self.description)

        match type:
            case TaskExecutionType.SYNC:
                return self._execute_sync(agent=agent, context=context)

            case TaskExecutionType.ASYNC:
                return self._execute_async(agent=agent, context=context)


    def _execute_sync(self, agent, context: Optional[Any] = None) -> TaskOutput:
        """Executes the task synchronously."""
        return self._execute_core(agent, context)


    def _execute_async(self, agent, context: Optional[Any] = None) -> Future[TaskOutput]:
        """Executes the task asynchronously."""
        future: Future[TaskOutput] = Future()
        threading.Thread(daemon=True, target=self._execute_task_async, args=(agent, context, future)).start()
        return future


    def _execute_task_async(self, agent, context: Optional[str], future: Future[TaskOutput]) -> None:
        """
        Executes the task asynchronously with context handling.
        """
        result = self._execute_core(agent, context)
        future.set_result(result)


    def _execute_core(self, agent, context: Optional[Any]) -> TaskOutput:
        """
        A core method for task execution.
        Handles 1. agent delegation, 2. tools, 3. context to add to the prompt, and 4. callbacks.
        """

        from versionhq.agent.model import Agent
        from versionhq.agent_network.model import AgentNetwork

        task_output: InstanceOf[TaskOutput] = None
        raw_output: str = None
        tool_output: str | list = None
        task_tools: List[List[InstanceOf[Tool]| InstanceOf[ToolSet] | Type[Tool]]] = []
        started_at, ended_at = datetime.datetime.now(), datetime.datetime.now()

        if self.tools:
            for item in self.tools:
                if isinstance(item, ToolSet) or isinstance(item, Tool) or type(item) == Tool:
                    task_tools.append(item)

        if self.allow_delegation == True:
            agent_to_delegate = None

            if hasattr(agent, "network") and isinstance(agent.network, AgentNetwork):
                if agent.network.managers:
                    idling_manager_agents = [manager.agent for manager in agent.network.managers if manager.is_idling]
                    agent_to_delegate = idling_manager_agents[0] if idling_manager_agents else agent.network.managers[0]
                else:
                    peers = [member.agent for member in agent.network.members if member.is_manager == False and member.agent.id is not agent.id]
                    if len(peers) > 0:
                        agent_to_delegate = peers[0]
            else:
                agent_to_delegate = Agent(role="vhq-Delegated-Agent", goal=agent.goal, llm=agent.llm)

            agent = agent_to_delegate
            self.delegations += 1


        if self.tool_res_as_final == True:
            started_at = datetime.datetime.now()
            tool_output = agent.execute_task(task=self, context=context, task_tools=task_tools)
            ended_at = datetime.datetime.now()
            task_output = TaskOutput(task_id=self.id, tool_output=tool_output, raw=str(tool_output) if tool_output else "")

        else:
            started_at = datetime.datetime.now()
            raw_output = agent.execute_task(task=self, context=context, task_tools=task_tools)
            ended_at = datetime.datetime.now()

            json_dict_output = self._create_json_output(raw=raw_output)
            if "outcome" in json_dict_output:
                json_dict_output = self._create_json_output(raw=str(json_dict_output["outcome"]))

            pydantic_output = self._create_pydantic_output(raw=raw_output, json_dict=json_dict_output) if self.pydantic_output else None

            task_output = TaskOutput(
                task_id=self.id,
                raw=raw_output if raw_output is not None else "",
                pydantic=pydantic_output,
                json_dict=json_dict_output
            )

        self.latency = (ended_at - started_at).total_seconds()
        task_output.evaluation = Evaluation(latency=self.latency, tokens=self.tokens)
        self.output = task_output
        self.processed_agents.add(agent.role)

        if self.should_evaluate and raw_output: # eval only when raw output exsits
            task_output.evaluate(task=self)

        self._create_short_and_long_term_memories(agent=agent, task_output=task_output)

        if self.callback and isinstance(self.callback, Callable):
            kwargs = { **self.callback_kwargs, **task_output.json_dict }
            sig = inspect.signature(self.callback)
            valid_keys = [param.name for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD]
            valid_kwargs = { k: kwargs[k] if  k in kwargs else None for k in valid_keys }
            callback_res = self.callback(**valid_kwargs)
            task_output.callback_output = callback_res

        # if self.output_file: ## disabled for now
        #     content = (
        #         json_output
        #         if json_output
        #         else pydantic_output.model_dump_json() if pydantic_output else result
        #     )
        #     self._save_file(content)
        return task_output


    def _store_execution_log(self, task_index: int, was_replayed: bool = False, inputs: Optional[Dict[str, Any]] = {}) -> None:
        """
        Store the task execution log.
        """

        self._task_output_handler.update(task=self, task_index=task_index, was_replayed=was_replayed, inputs=inputs)


    @property
    def key(self) -> str:
        output_format = "json" if self.response_fields else "pydantic" if self.pydantic_output is not None else "raw"
        source = [self.description, output_format]
        return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()


    @property
    def summary(self) -> str:
        return f"""
Task ID: {str(self.id)}
"Description": {self.description}
"Tools": {", ".join([tool.name for tool in self.tools])}
        """
