diff --git a/src/tether/runtime/server.py b/src/tether/runtime/server.py index 8318b70..499de77 100644 --- a/src/tether/runtime/server.py +++ b/src/tether/runtime/server.py @@ -19,10 +19,12 @@ from __future__ import annotations import base64 +import contextvars import io import json import logging import time +import uuid from pathlib import Path from typing import Any @@ -65,6 +67,7 @@ def track_in_flight(*args, **kwargs): logger = logging.getLogger(__name__) _tracer = get_tracer(__name__) +_request_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("request_id", default="") try: from tether import __version__ as _TETHER_VERSION @@ -2180,6 +2183,17 @@ async def _heartbeat_loop(): lifespan=lifespan, ) + @app.middleware("http") + async def _request_id_middleware(request, call_next): + req_id = str(uuid.uuid4()) + token = _request_id_var.set(req_id) + try: + response = await call_next(request) + finally: + _request_id_var.reset(token) + response.headers["X-Reflex-Request-ID"] = req_id + return response + # Bearer auth dependency (Phase 1 auth-bearer feature). # If api_key is set at app-creation time, every protected route requires # the caller to pass `Authorization: Bearer ` (preferred) OR the @@ -2274,6 +2288,9 @@ async def act(request: PredictRequest, _auth: None = Depends(_require_api_key)): # Non-standard attrs under gen_ai.action.* — proposed for upstream # OTel GenAI working group contribution (Phase 2 per spec). span.set_attribute("gen_ai.action.embodiment", _emb_label) + _req_id = _request_id_var.get() + if _req_id: + span.set_attribute("tether.request_id", _req_id) # chunk_size + denoise_steps are set AFTER predict returns (we don't # know them until the result is in hand). See ~line 1590 below. span.set_attribute( diff --git a/tests/test_request_id_header.py b/tests/test_request_id_header.py new file mode 100644 index 0000000..9df2ad2 --- /dev/null +++ b/tests/test_request_id_header.py @@ -0,0 +1,86 @@ +"""Tests for the X-Reflex-Request-ID response header middleware. + +Verifies that every HTTP response carries a unique UUID4 in the +X-Reflex-Request-ID header, and that the value is available to the +OTel span via the _request_id_var context variable. + +Uses FastAPI's TestClient — no model loading, no ONNX runtime. +""" +from __future__ import annotations + +import contextvars +import uuid + +import pytest +from fastapi import FastAPI +from fastapi.responses import JSONResponse +from fastapi.testclient import TestClient + + +def _build_app() -> tuple[FastAPI, contextvars.ContextVar[str]]: + """Minimal FastAPI app with only the request-ID middleware wired in.""" + request_id_var: contextvars.ContextVar[str] = contextvars.ContextVar( + "request_id", default="" + ) + + app = FastAPI() + + @app.middleware("http") + async def _request_id_middleware(request, call_next): + req_id = str(uuid.uuid4()) + token = request_id_var.set(req_id) + try: + response = await call_next(request) + finally: + request_id_var.reset(token) + response.headers["X-Reflex-Request-ID"] = req_id + return response + + @app.get("/health") + async def health(): + return JSONResponse({"status": "ok"}) + + @app.post("/act") + async def act(): + return JSONResponse({"actions": []}) + + return app, request_id_var + + +@pytest.fixture +def client(): + app, _ = _build_app() + return TestClient(app) + + +class TestRequestIDHeader: + def test_health_response_has_header(self, client): + assert "X-Reflex-Request-ID" in client.get("/health").headers + + def test_act_response_has_header(self, client): + assert "X-Reflex-Request-ID" in client.post("/act").headers + + def test_header_value_is_valid_uuid(self, client): + value = client.get("/health").headers["X-Reflex-Request-ID"] + assert str(uuid.UUID(value)) == value # raises ValueError if malformed + + def test_each_request_gets_a_unique_id(self, client): + ids = {client.get("/health").headers["X-Reflex-Request-ID"] for _ in range(5)} + assert len(ids) == 5 + + def test_request_id_is_accessible_inside_route(self): + """Context var holds the same ID that ends up in the response header.""" + app, request_id_var = _build_app() + + captured: dict[str, str] = {} + + @app.post("/act_span") + async def act_span(): + captured["id"] = request_id_var.get() + return JSONResponse({"actions": []}) + + c = TestClient(app) + response = c.post("/act_span") + + assert captured["id"] != "" + assert response.headers["X-Reflex-Request-ID"] == captured["id"]