import time
from starlette.requests import Request
from starlette.middleware.base import BaseHTTPMiddleware

from mineru_flow.internal.common.logging import get_logger

# Logger used by TracingMiddleWare for per-request trace events; obtained from
# the project's logging helper under the name "middleware.tracing".
trace_logger = get_logger("middleware.tracing")


class TracingMiddleWare(BaseHTTPMiddleware):
    """Log the start and end of every HTTP request handled by the app.

    For each request, a structured event is emitted on ``trace_logger``
    before the downstream app runs. A second event is emitted after it
    returns, carrying the response status and wall-clock duration. If the
    downstream app raises, a failure event with the duration and exception
    type is logged instead, and the exception is re-raised unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        """Time ``call_next(request)`` and log request start/end events.

        Args:
            request: The incoming Starlette request.
            call_next: Callable that forwards the request to the next
                middleware/endpoint and returns its response.

        Returns:
            The response produced by ``call_next``, unmodified.

        Raises:
            Any exception raised by ``call_next``; it is logged first and
            then re-raised as-is.
        """
        start_time = time.perf_counter()
        # Fields shared by every event for this request, so the start, end
        # and failure logs can be correlated with each other.
        context = {
            "method": request.method,
            "path": request.url.path,
            # request.client may be None, in which case no IP is recorded.
            "client_ip": request.client.host if request.client else None,
            "request_id": request.headers.get("x-request-id"),
        }
        trace_logger.info("请求开始", **context)

        try:
            response = await call_next(request)
        except Exception as exc:
            # Without this, a failing request would leave a start event with
            # no matching end event and no timing information.
            trace_logger.error(
                "请求异常",
                duration_ms=self._elapsed_ms(start_time),
                error_type=type(exc).__name__,
                **context,
            )
            raise

        # The response body is deliberately not logged: reading it requires
        # draining ``response.body_iterator``, which buffers the entire
        # (possibly streaming) body in memory before it is sent.
        trace_logger.info(
            "请求结束",
            status=response.status_code,
            duration_ms=self._elapsed_ms(start_time),
            **context,
        )
        return response

    @staticmethod
    def _elapsed_ms(start_time: float) -> float:
        """Return milliseconds since ``start_time`` (a ``time.perf_counter()`` value), rounded to 2 places."""
        return round((time.perf_counter() - start_time) * 1000, 2)
