import json
import logging
from typing import (
    Optional,
    Type,
)

from langchain_core.callbacks import CallbackManagerForToolRun
from pydantic import BaseModel, Field

from ttyg.utils import timeit
from .base import BaseGraphDBTool


class SparqlQueryTool(BaseGraphDBTool):
    """
    Tool that executes SPARQL queries generated by the agent.

    The query is handed to ``self.graph.eval_sparql_query`` and whatever it
    returns is serialized to a JSON string for the caller.
    """

    # Input schema of the tool: a single SPARQL query string.
    # (Kept as a comment rather than a docstring, because pydantic copies
    # model docstrings into the generated JSON schema.)
    class SearchInput(BaseModel):
        query: str = Field(
            description="A valid SPARQL SELECT, CONSTRUCT, DESCRIBE or ASK query without prefixes"
        )

    name: str = "sparql_query"
    description: str = "Query GraphDB by SPARQL SELECT, CONSTRUCT, DESCRIBE or ASK query and return result."
    args_schema: Type[BaseModel] = SearchInput

    @timeit
    def _run(
            self,
            query: str,
            run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """
        Execute the given SPARQL query and return its result as JSON text.

        :param query: the SPARQL query to execute; it is passed unchanged to
            ``self.graph.eval_sparql_query`` and logged at DEBUG level
        :param run_manager: optional callback manager supplied by the tool
            framework; it is not used by this method
        :return: the query result serialized with ``json.dumps`` using an
            indentation of 2 spaces
        :raises TypeError: if the query result is not JSON serializable
            (raised by ``json.dumps``)

        Exceptions raised by ``self.graph.eval_sparql_query`` are not caught
        here and propagate to the caller.
        """
        # Pass the query as a logging argument, so that the message is only
        # formatted when DEBUG logging is actually enabled.
        logging.debug("Executing generated SPARQL query %s", query)
        # NOTE(review): ``self.graph`` is not defined in this class; presumably
        # it is provided by BaseGraphDBTool - confirm.
        query_results = self.graph.eval_sparql_query(query)
        # NOTE(review): json.dumps escapes non-ASCII characters by default
        # (ensure_ascii=True); consider ensure_ascii=False if results with
        # non-ASCII literals should stay human/LLM readable.
        return json.dumps(query_results, indent=2)
