import asyncio
import logging
import sys
from typing import Any, AsyncIterator, List, NamedTuple, Optional, Union

import httpx
import httpx_sse
import orjson
from httpx._types import QueryParamTypes

from langgraph_sdk.schema import (
    Assistant,
    Config,
    GraphSchema,
    Metadata,
    Run,
    RunEvent,
    StreamMode,
    Thread,
)

logger = logging.getLogger(__name__)


def get_client(*, url: str = "http://localhost:8123") -> "LangServeClient":
    """Build a ``LangServeClient`` that talks to the langgraph API at ``url``.

    The underlying ``httpx.AsyncClient`` uses a transport configured with
    ``retries=5`` and a timeout of 5 s for connect/pool and 60 s for
    read/write.
    """
    transport = httpx.AsyncHTTPTransport(retries=5)
    timeout = httpx.Timeout(connect=5, read=60, write=60, pool=5)
    http_client = httpx.AsyncClient(
        base_url=url,
        transport=transport,
        timeout=timeout,
    )
    return LangServeClient(http_client)


class StreamPart(NamedTuple):
    """One server-sent event received from a streaming request.

    ``data`` holds the JSON-decoded event payload, or ``None`` when the event
    carried no data at all.
    """

    event: str  # SSE event name
    data: Optional[dict]  # decoded JSON payload; None for data-less events


class LangServeClient:
    """Top-level async client exposing one sub-client per API resource.

    All sub-clients share a single ``HttpClient`` that wraps the supplied
    ``httpx.AsyncClient``.
    """

    def __init__(self, client: httpx.AsyncClient) -> None:
        shared_http = HttpClient(client)
        self.http = shared_http
        self.assistants = AssistantsClient(shared_http)
        self.threads = ThreadsClient(shared_http)
        self.runs = RunsClient(shared_http)


class HttpClient:
    """Async JSON-over-HTTP helper shared by the resource clients.

    Every request method raises ``httpx.HTTPStatusError`` for a non-2xx
    response. Before raising, it logs the response body and, on Python 3.11+,
    attaches the body to the exception as a note.
    """

    def __init__(self, client: httpx.AsyncClient) -> None:
        self.client = client

    @staticmethod
    async def _raise_for_status(r: httpx.Response) -> None:
        """Raise ``httpx.HTTPStatusError`` if ``r`` is not a 2xx response.

        The response body is read, logged, and (on Python 3.11+) added to the
        exception as a note so callers can see the server's error message.
        """
        try:
            r.raise_for_status()
        except httpx.HTTPStatusError as e:
            # Decode leniently: a non-UTF-8 error body must not replace the
            # HTTPStatusError with a UnicodeDecodeError.
            body = (await r.aread()).decode(errors="replace")
            if sys.version_info >= (3, 11):
                e.add_note(body)
            logger.error("Error from langgraph-api: %s", body, exc_info=e)
            raise

    async def get(self, path: str, *, params: QueryParamTypes = None) -> dict:
        """Make a GET request and return the decoded JSON response body."""
        r = await self.client.get(path, params=params)
        await self._raise_for_status(r)
        return await decode_json(r)

    async def post(self, path: str, *, json: dict) -> dict:
        """Make a POST request with a JSON body; return the decoded JSON response."""
        headers, content = await encode_json(json)
        r = await self.client.post(path, headers=headers, content=content)
        await self._raise_for_status(r)
        return await decode_json(r)

    async def put(self, path: str, *, json: dict) -> dict:
        """Make a PUT request with a JSON body; return the decoded JSON response."""
        headers, content = await encode_json(json)
        r = await self.client.put(path, headers=headers, content=content)
        await self._raise_for_status(r)
        return await decode_json(r)

    async def delete(self, path: str) -> None:
        """Make a DELETE request; the response body is not decoded."""
        r = await self.client.delete(path)
        await self._raise_for_status(r)

    async def stream(
        self, path: str, method: str, *, json: Optional[dict] = None
    ) -> AsyncIterator[StreamPart]:
        """Send a request and yield its server-sent events as ``StreamPart``s.

        Each event's data is JSON-decoded; events without data yield a part
        whose ``data`` is ``None``.
        """
        headers, content = await encode_json(json)
        async with httpx_sse.aconnect_sse(
            self.client, method, path, headers=headers, content=content
        ) as sse:
            await self._raise_for_status(sse.response)
            async for event in sse.aiter_sse():
                yield StreamPart(
                    event.event, orjson.loads(event.data) if event.data else None
                )


def _orjson_default(obj: Any) -> Any:
    if hasattr(obj, "model_dump") and callable(obj.model_dump):
        return obj.model_dump()
    elif hasattr(obj, "dict") and callable(obj.dict):
        return obj.dict()
    else:
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")


async def encode_json(json: Any) -> tuple[dict[str, str], bytes]:
    """Serialize ``json`` with orjson in the loop's default executor.

    NumPy values and non-string dict keys are allowed, and other objects go
    through ``_orjson_default``.

    Returns:
        ``(headers, body)``: the JSON-encoded bytes together with matching
        ``Content-Length`` and ``Content-Type`` headers.
    """
    loop = asyncio.get_running_loop()
    options = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_NON_STR_KEYS
    payload = await loop.run_in_executor(
        None, orjson.dumps, json, _orjson_default, options
    )
    headers = {
        "Content-Length": str(len(payload)),
        "Content-Type": "application/json",
    }
    return headers, payload


async def decode_json(r: httpx.Response) -> dict:
    """Read the whole body of ``r`` and parse it with orjson in an executor."""
    loop = asyncio.get_running_loop()
    raw_body = await r.aread()
    return await loop.run_in_executor(None, orjson.loads, raw_body)


class AssistantsClient:
    """Async operations on the ``/assistants`` resource.

    Each method issues a single request through the shared ``HttpClient`` and
    returns the decoded JSON response unchanged.
    """

    def __init__(self, http: HttpClient) -> None:
        self.http = http

    async def get(self, assistant_id: str) -> Assistant:
        """Fetch one assistant by its ID."""
        path = f"/assistants/{assistant_id}"
        return await self.http.get(path)

    async def get_graph(self, assistant_id: str) -> dict[str, list[dict[str, Any]]]:
        """Fetch the graph description of an assistant by its ID."""
        path = f"/assistants/{assistant_id}/graph"
        return await self.http.get(path)

    async def get_schemas(self, assistant_id: str) -> GraphSchema:
        """Fetch the graph schemas of an assistant by its ID."""
        path = f"/assistants/{assistant_id}/schemas"
        return await self.http.get(path)

    async def create(
        self,
        graph_id: Optional[str],
        config: Optional[Config] = None,
        *,
        metadata: Metadata = None,
    ) -> Assistant:
        """Create an assistant via ``POST /assistants``.

        Args:
            graph_id: Graph ID sent in the request body.
            config: Config sent in the body; a falsy value is sent as ``{}``.
            metadata: Metadata sent in the body as-is (``None`` included).
        """
        payload = {
            "metadata": metadata,
            "graph_id": graph_id,
            "config": config if config else {},
        }
        return await self.http.post("/assistants", json=payload)

    async def upsert(
        self,
        assistant_id: str,
        graph_id: str,
        config: Optional[Config] = None,
        *,
        metadata: Metadata = None,
    ) -> Assistant:
        """Create or replace the assistant with ``assistant_id`` via PUT.

        The body has the same shape as in ``create``; a falsy ``config`` is
        sent as ``{}``.
        """
        path = f"/assistants/{assistant_id}"
        payload = {
            "metadata": metadata,
            "graph_id": graph_id,
            "config": config if config else {},
        }
        return await self.http.put(path, json=payload)

    async def search(
        self, *, metadata: Metadata = None, limit: int = 10, offset: int = 0
    ) -> list[Assistant]:
        """Search assistants; the filter and paging values are sent in the JSON body."""
        payload = {"metadata": metadata, "limit": limit, "offset": offset}
        return await self.http.post("/assistants/search", json=payload)


class ThreadsClient:
    """Async operations on the ``/threads`` resource and on thread state."""

    def __init__(self, http: HttpClient) -> None:
        self.http = http

    async def get(self, thread_id: str) -> Thread:
        """Get a thread by ID."""
        return await self.http.get(f"/threads/{thread_id}")

    async def create(self, *, metadata: Metadata = None) -> Thread:
        """Create a new thread with the given metadata."""
        return await self.http.post("/threads", json={"metadata": metadata})

    async def upsert(self, thread_id: str, *, metadata: Metadata) -> Thread:
        """Create or update the thread with ``thread_id``."""
        return await self.http.put(f"/threads/{thread_id}", json={"metadata": metadata})

    async def delete(self, thread_id: str) -> None:
        """Delete a thread."""
        await self.http.delete(f"/threads/{thread_id}")

    async def search(
        self, *, metadata: Metadata = None, limit: int = 10, offset: int = 0
    ) -> list[Thread]:
        """Search for threads.

        ``metadata``, ``limit`` and ``offset`` are sent unchanged in the JSON
        body of the POST.
        """
        return await self.http.post(
            "/threads/search",
            json={"metadata": metadata, "limit": limit, "offset": offset},
        )

    async def get_state(self, thread_id: str) -> dict:
        """Get the current state of a thread."""
        return await self.http.get(f"/threads/{thread_id}/state")

    async def update_state(
        self,
        thread_id: Union[str, Config],
        values: dict,
        *,
        as_node: Optional[str] = None,
    ) -> dict:
        """Update the state of a thread.

        Args:
            thread_id: Either a thread ID, or a config dict whose
                ``["configurable"]["thread_id"]`` entry names the thread. A
                config dict is also forwarded in the request body.
            values: State values sent in the request body.
            as_node: Optional node name forwarded as ``as_node``.

        Returns:
            The decoded JSON response from the server.

        Raises:
            KeyError: If a config dict lacks ``["configurable"]["thread_id"]``.
        """
        if isinstance(thread_id, dict):
            config = thread_id
            thread_id_: str = thread_id["configurable"]["thread_id"]
        else:
            thread_id_ = thread_id
            config = None
        return await self.http.post(
            f"/threads/{thread_id_}/state",
            json={"values": values, "config": config, "as_node": as_node},
        )

    async def get_history(
        self, thread_id: str, limit: int = 10, before: Optional[Config] = None
    ) -> list[dict]:
        """Get the state history of a thread.

        Args:
            thread_id: Thread whose history is requested.
            limit: Sent as the ``limit`` query parameter.
            before: Optional config sent as the ``before`` query parameter,
                JSON-encoded. The parameter is omitted when ``None``.
        """
        params: dict[str, Any] = {"limit": limit}
        if before is not None:
            # httpx turns non-primitive query values into str(value), i.e. a
            # Python repr such as "{'configurable': ...}"; send JSON instead.
            # NOTE(review): confirm the server parses `before` as JSON.
            params["before"] = orjson.dumps(before).decode()
        return await self.http.get(f"/threads/{thread_id}/history", params=params)


class RunsClient:
    """Async operations on runs nested under ``/threads/{thread_id}/runs``."""

    def __init__(self, http: HttpClient) -> None:
        self.http = http

    def stream(
        self,
        thread_id: str,
        assistant_id: str,
        *,
        input: Optional[dict] = None,
        stream_mode: StreamMode = "values",
        metadata: Optional[dict] = None,
        config: Optional[Config] = None,
        interrupt_before: Optional[list[str]] = None,
        interrupt_after: Optional[list[str]] = None,
    ) -> AsyncIterator[StreamPart]:
        """Create a run on ``thread_id`` and stream its results.

        This is a plain method, not a coroutine. It returns the async iterator
        produced by ``HttpClient.stream`` without awaiting it; iterate it with
        ``async for``.
        """
        return self.http.stream(
            f"/threads/{thread_id}/runs/stream",
            "POST",
            json={
                "input": input,
                "config": config,
                "metadata": metadata,
                "stream_mode": stream_mode,
                "assistant_id": assistant_id,
                "interrupt_before": interrupt_before,
                "interrupt_after": interrupt_after,
            },
        )

    async def create(
        self,
        thread_id: str,
        assistant_id: str,
        *,
        input: Optional[dict] = None,
        stream_mode: StreamMode = "values",
        metadata: Optional[dict] = None,
        config: Optional[Config] = None,
        interrupt_before: Optional[list[str]] = None,
        interrupt_after: Optional[list[str]] = None,
        webhook: Optional[str] = None,
    ) -> Run:
        """Create a background run on ``thread_id`` and return the created run.

        All keyword arguments are forwarded unchanged in the JSON body.
        """
        return await self.http.post(
            f"/threads/{thread_id}/runs",
            json={
                "input": input,
                "config": config,
                "metadata": metadata,
                "stream_mode": stream_mode,
                "assistant_id": assistant_id,
                "interrupt_before": interrupt_before,
                "interrupt_after": interrupt_after,
                "webhook": webhook,
            },
        )

    # NOTE: from this definition on, the class-body name ``list`` refers to
    # this method rather than the builtin, so later annotations use
    # ``typing.List``. Keep methods annotated with ``list[...]`` above it.
    async def list(
        self, thread_id: str, *, limit: int = 10, offset: int = 0
    ) -> List[Run]:
        """List the runs of a thread.

        Args:
            thread_id: Thread whose runs are listed.
            limit: Sent as the ``limit`` query parameter.
            offset: Sent as the ``offset`` query parameter.
        """
        return await self.http.get(
            f"/threads/{thread_id}/runs",
            params={"limit": limit, "offset": offset},
        )

    async def get(self, thread_id: str, run_id: str) -> Run:
        """Get a single run of a thread."""
        return await self.http.get(f"/threads/{thread_id}/runs/{run_id}")

    async def list_events(
        self, thread_id: str, run_id: str, *, limit: int = 10, offset: int = 0
    ) -> List[RunEvent]:
        """List the events of a run.

        Args:
            thread_id: Thread owning the run.
            run_id: Run whose events are listed.
            limit: Sent as the ``limit`` query parameter.
            offset: Sent as the ``offset`` query parameter.
        """
        return await self.http.get(
            f"/threads/{thread_id}/runs/{run_id}/events",
            params={"limit": limit, "offset": offset},
        )
