from __future__ import annotations import logging import time import uuid from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response logger = logging.getLogger("app.http") class RequestContextMiddleware(BaseHTTPMiddleware): """注入 request_id,记录访问日志与耗时。""" async def dispatch(self, request: Request, call_next) -> Response: rid = request.headers.get("X-Request-ID") or str(uuid.uuid4()) request.state.request_id = rid path = request.url.path if path.startswith("/static"): response = await call_next(request) response.headers["X-Request-ID"] = rid return response client = request.client.host if request.client else "-" if path.startswith("/api"): logger.info( "http_in method=%s path=%s rid=%s client=%s", request.method, path, rid, client, ) started = time.perf_counter() try: response = await call_next(request) except Exception: duration_ms = (time.perf_counter() - started) * 1000 logger.exception( "http_error method=%s path=%s duration_ms=%.1f rid=%s", request.method, path, duration_ms, rid, ) raise duration_ms = (time.perf_counter() - started) * 1000 response.headers["X-Request-ID"] = rid logger.info( "http_out method=%s path=%s status=%s duration_ms=%.1f rid=%s", request.method, path, response.status_code, duration_ms, rid, ) return response