diff --git a/.gitignore b/.gitignore index e6d5efa..f4b6a03 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,9 @@ node_modules/ dist/ coverage/ *.log +.venv/ +__pycache__/ +*.py[cod] +*.egg-info/ +build/ +.pytest_cache/ diff --git a/README.md b/README.md index 7985392..0b4b022 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,87 @@ pi --model dynamo/ -p "Reply exactly ok." For local Dynamo, the API key is usually not checked. This package defaults to `dynamo-local` if `DYNAMO_API_KEY` is unset. + +## Generic Agent Proxy + +This repository also ships a standalone Python `dynamo-agent-proxy` CLI for +agents that can speak OpenAI or Anthropic APIs but cannot load a Pi plugin. The +proxy is not Pi-specific: it reads `DYNAMO_*` and `DYN_AGENT_*` environment +variables, injects `nvext.agent_context` / `nvext.agent_hints`, and forwards +requests to a Dynamo OpenAI-compatible endpoint. Existing unknown `nvext` fields +are preserved, while proxy-provided context and hint keys take precedence. + +Supported client surfaces: + +- `POST /v1/chat/completions` - pass-through to Dynamo with `nvext` injection. +- `POST /v1/responses` - pass-through to Dynamo with `nvext` injection. +- `POST /v1/messages` - Anthropic Messages API compatibility, translated to + Dynamo `/v1/chat/completions` and translated back to Anthropic responses. + +Install directly from GitHub with `uv`: + +```bash +uv pip install "dynamo-agent-proxy @ git+ssh://git@github.com:ai-dynamo/pi-dynamo-provider.git" +``` + +For local development from a checkout: + +```bash +uv pip install -e . +``` + +Run the proxy: + +```bash +export DYNAMO_BASE_URL=http://127.0.0.1:8000/v1 +export DYNAMO_API_KEY=dummy +export DYN_AGENT_SESSION_TYPE_ID=generic_agent +export DYN_AGENT_SESSION_ID=agent-demo-001 +export DYN_AGENT_TRAJECTORY_ID=agent-demo-001:main + +dynamo-agent-proxy --listen-port 18080 --priority 5 --osl 1024 +``` + +Point OpenAI-compatible clients at: + +```text +http://127.0.0.1:18080/v1 +``` + +Point Anthropic-compatible clients at: + +```text +http://127.0.0.1:18080 +``` + +For example, an Anthropic client will call `POST /v1/messages`; the proxy maps +that request onto Dynamo's OpenAI-compatible `/v1/chat/completions` endpoint. +Tool definitions and tool-choice hints are converted between Anthropic tool +schemas and OpenAI function tools. + +Proxy-specific options mirror the environment variables: + +```text +--listen-host HOST Bind host. Default: 127.0.0.1 +--listen-port PORT Bind port. Default: 18080 +--upstream URL Dynamo OpenAI-compatible base URL. Default: http://127.0.0.1:8000/v1 +--api-key KEY Bearer token sent to Dynamo. Default: dynamo-local +--model MODEL Fallback model for Anthropic requests without model +--session-type-id VALUE nvext.agent_context.session_type_id. Default: generic_agent +--session-id VALUE nvext.agent_context.session_id. Default: generated proxy id +--trajectory-id VALUE nvext.agent_context.trajectory_id. Default: :main +--parent-trajectory-id ID Optional parent trajectory id +--priority INT Convenience hint for nvext.agent_hints.priority +--osl INT Convenience hint for nvext.agent_hints.osl +--agent-hint KEY=VALUE Additional nvext.agent_hints entry; VALUE may be JSON +``` + +You can also set `DYN_AGENT_HINTS` to a JSON object, for example: + +```bash +export DYN_AGENT_HINTS='{"priority":5,"osl":1024}' +``` + ## Local Dynamo Launcher For local onboarding, this repo includes two small Dynamo helper scripts. @@ -440,6 +521,9 @@ npm install npm run check npm run test npm run build + +PYTHONPATH=python python3 -m unittest discover -s test/python +python3 -m pip wheel . --no-deps --no-build-isolation -w /tmp/dynamo-agent-proxy-wheel ``` Run from source without installing: @@ -483,6 +567,7 @@ Authentication fails: Included: - OpenAI-compatible chat-completions path. +- Generic proxy for OpenAI `/v1/chat/completions`, OpenAI `/v1/responses`, and Anthropic `/v1/messages`. - Model discovery from `/v1/models`. - Dynamo request metadata injection. - Pi session id as default `trajectory_id`. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2461cdf --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,23 @@ +[build-system] +requires = ["setuptools>=69"] +build-backend = "setuptools.build_meta" + +[project] +name = "dynamo-agent-proxy" +version = "0.1.0" +description = "Generic OpenAI and Anthropic compatible proxy that injects Dynamo nvext agent metadata." +readme = "README.md" +requires-python = ">=3.11" +license = { file = "LICENSE" } +authors = [{ name = "NVIDIA" }] +keywords = ["dynamo", "openai", "anthropic", "proxy", "agent", "nvext"] +dependencies = [] + +[project.scripts] +dynamo-agent-proxy = "dynamo_agent_proxy.proxy:main" + +[project.urls] +Repository = "https://github.com/ai-dynamo/pi-dynamo-provider" + +[tool.setuptools.packages.find] +where = ["python"] diff --git a/python/dynamo_agent_proxy/__init__.py b/python/dynamo_agent_proxy/__init__.py new file mode 100644 index 0000000..b7ebd4c --- /dev/null +++ b/python/dynamo_agent_proxy/__init__.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Generic Dynamo agent proxy.""" + +from .proxy import ( + AgentAnnotation, + ProxyConfig, + annotate_json_request_body, + make_handler, + merge_dynamo_metadata, + read_proxy_config, + translate_anthropic_messages_request, + translate_openai_chat_response_to_anthropic, +) + +__all__ = [ + "AgentAnnotation", + "ProxyConfig", + "annotate_json_request_body", + "make_handler", + "merge_dynamo_metadata", + "read_proxy_config", + "translate_anthropic_messages_request", + "translate_openai_chat_response_to_anthropic", +] + +__version__ = "0.1.0" diff --git a/python/dynamo_agent_proxy/__main__.py b/python/dynamo_agent_proxy/__main__.py new file mode 100644 index 0000000..a1db4ca --- /dev/null +++ b/python/dynamo_agent_proxy/__main__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from .proxy import main + + +if __name__ == "__main__": + main() diff --git a/python/dynamo_agent_proxy/proxy.py b/python/dynamo_agent_proxy/proxy.py new file mode 100644 index 0000000..b76aad0 --- /dev/null +++ b/python/dynamo_agent_proxy/proxy.py @@ -0,0 +1,917 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Generic OpenAI/Anthropic proxy that injects Dynamo ``nvext`` metadata.""" + +from __future__ import annotations + +import argparse +import gzip +import json +import os +import sys +import uuid +import zlib +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from http.client import HTTPConnection, HTTPSConnection, HTTPResponse +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import Any +from urllib.parse import urlsplit + +DEFAULT_BASE_URL = "http://127.0.0.1:8000/v1" +DEFAULT_API_KEY = "dynamo-local" +DEFAULT_LISTEN_HOST = "127.0.0.1" +DEFAULT_LISTEN_PORT = 18080 +DEFAULT_SESSION_TYPE_ID = "generic_agent" +DEFAULT_MODEL = "default" +STREAM_CHUNK_SIZE = 64 * 1024 + +HOP_BY_HOP_HEADERS = { + "connection", + "content-length", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailer", + "transfer-encoding", + "upgrade", +} + +REQUEST_BODY_METHODS = {"POST", "PUT", "PATCH"} +OPENAI_INJECTION_PATHS = { + "/chat/completions", + "/responses", + "/v1/chat/completions", + "/v1/responses", +} +ANTHROPIC_MESSAGES_PATHS = {"/messages", "/v1/messages"} + + +@dataclass(frozen=True) +class AgentAnnotation: + session_type_id: str + session_id: str + trajectory_id: str + parent_trajectory_id: str | None = None + agent_hints: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class ProxyConfig: + upstream_scheme: str + upstream_host: str + upstream_port: int + upstream_prefix: str + api_key: str + model: str + annotation: AgentAnnotation + + +@dataclass(frozen=True) +class ListenConfig: + host: str + port: int + + +@dataclass(frozen=True) +class RuntimeConfig: + proxy: ProxyConfig + listen: ListenConfig + upstream_display: str + + +class UnsupportedContentEncoding(ValueError): + pass + + +def _env_value(env: Mapping[str, str], key: str) -> str | None: + value = env.get(key) + if value is None: + return None + stripped = value.strip() + return stripped or None + + +def _normalize_base_url(raw_base_url: str | None) -> str: + raw = (raw_base_url or DEFAULT_BASE_URL).strip().rstrip("/") + try: + parsed = urlsplit(raw) + except ValueError: + return raw + if parsed.scheme and parsed.netloc and parsed.path in {"", "/"}: + return raw + "/v1" + return raw + + +def _parse_json_object(raw: str | None) -> dict[str, Any]: + if not raw: + return {} + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + return {} + return parsed if isinstance(parsed, dict) else {} + + +def _parse_hint(value: str) -> tuple[str, Any]: + if "=" not in value: + raise argparse.ArgumentTypeError("agent hints must use key=value") + key, raw_value = value.split("=", 1) + if not key: + raise argparse.ArgumentTypeError("agent hint key must not be empty") + try: + return key, json.loads(raw_value) + except json.JSONDecodeError: + return key, raw_value + + +def read_proxy_config( + argv: Sequence[str] | None = None, + env: Mapping[str, str] | None = None, +) -> RuntimeConfig: + env = env or os.environ + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--listen-host", default=_env_value(env, "DYNAMO_PROXY_LISTEN_HOST") or DEFAULT_LISTEN_HOST) + parser.add_argument( + "--listen-port", + type=int, + default=int(_env_value(env, "DYNAMO_PROXY_LISTEN_PORT") or DEFAULT_LISTEN_PORT), + ) + parser.add_argument( + "--upstream", + "--base-url", + dest="base_url", + default=_env_value(env, "DYNAMO_BASE_URL") or _env_value(env, "OPENAI_BASE_URL") or DEFAULT_BASE_URL, + ) + parser.add_argument("--api-key", default=_env_value(env, "DYNAMO_API_KEY") or DEFAULT_API_KEY) + parser.add_argument("--model", default=_env_value(env, "DYNAMO_MODEL") or DEFAULT_MODEL) + parser.add_argument( + "--session-type-id", + default=_env_value(env, "DYN_AGENT_SESSION_TYPE_ID") or DEFAULT_SESSION_TYPE_ID, + ) + parser.add_argument("--session-id", default=_env_value(env, "DYN_AGENT_SESSION_ID")) + parser.add_argument("--trajectory-id", default=_env_value(env, "DYN_AGENT_TRAJECTORY_ID")) + parser.add_argument( + "--parent-trajectory-id", + default=_env_value(env, "DYN_AGENT_PARENT_TRAJECTORY_ID"), + ) + parser.add_argument("--priority", type=int) + parser.add_argument("--osl", type=int) + parser.add_argument( + "--agent-hint", + action="append", + default=[], + type=_parse_hint, + help="Additional nvext.agent_hints entry as key=value; value may be JSON.", + ) + args = parser.parse_args(argv) + + agent_hints = _parse_json_object(_env_value(env, "DYN_AGENT_HINTS")) + if args.priority is not None: + agent_hints["priority"] = args.priority + if args.osl is not None: + agent_hints["osl"] = args.osl + for key, value in args.agent_hint: + agent_hints[key] = value + + normalized = _normalize_base_url(args.base_url) + upstream = urlsplit(normalized) + if upstream.scheme not in {"http", "https"}: + raise SystemExit("--upstream must start with http:// or https://") + if not upstream.hostname: + raise SystemExit("--upstream must include a host") + + session_id = args.session_id or f"proxy-{uuid.uuid4().hex}" + proxy = ProxyConfig( + upstream_scheme=upstream.scheme, + upstream_host=upstream.hostname, + upstream_port=upstream.port or (443 if upstream.scheme == "https" else 80), + upstream_prefix=upstream.path.rstrip("/"), + api_key=args.api_key, + model=args.model, + annotation=AgentAnnotation( + session_type_id=args.session_type_id, + session_id=session_id, + trajectory_id=args.trajectory_id or f"{session_id}:main", + parent_trajectory_id=args.parent_trajectory_id, + agent_hints=agent_hints, + ), + ) + return RuntimeConfig( + proxy=proxy, + listen=ListenConfig(host=args.listen_host, port=args.listen_port), + upstream_display=normalized, + ) + + +def merge_dynamo_metadata(payload: Any, annotation: AgentAnnotation) -> dict[str, Any]: + payload_record = dict(payload) if isinstance(payload, dict) else {} + existing_nvext = payload_record.get("nvext") if isinstance(payload_record.get("nvext"), dict) else {} + existing_agent_context = ( + existing_nvext.get("agent_context") + if isinstance(existing_nvext.get("agent_context"), dict) + else {} + ) + existing_agent_hints = ( + existing_nvext.get("agent_hints") + if isinstance(existing_nvext.get("agent_hints"), dict) + else {} + ) + + agent_context: dict[str, Any] = dict(existing_agent_context) + agent_context.update( + { + "session_type_id": annotation.session_type_id, + "session_id": annotation.session_id, + "trajectory_id": annotation.trajectory_id, + } + ) + if annotation.parent_trajectory_id: + agent_context["parent_trajectory_id"] = annotation.parent_trajectory_id + + nvext = dict(existing_nvext) + nvext["agent_context"] = agent_context + if annotation.agent_hints or existing_agent_hints: + agent_hints = dict(existing_agent_hints) + agent_hints.update(annotation.agent_hints) + nvext["agent_hints"] = agent_hints + + payload_record["nvext"] = nvext + return payload_record + + +def _is_json_content_type(value: str | None) -> bool: + if not value: + return False + media_type = value.split(";", 1)[0].strip().lower() + return media_type == "application/json" or media_type.endswith("+json") + + +def _decode_request_body(body: bytes, content_encoding: str | None) -> bytes: + encodings = [ + item.strip().lower() + for item in (content_encoding or "identity").split(",") + if item.strip() + ] + for encoding in reversed(encodings): + if encoding == "identity": + continue + if encoding in {"gzip", "x-gzip"}: + try: + body = gzip.decompress(body) + except (EOFError, OSError) as exc: + raise UnsupportedContentEncoding(f"invalid gzip request body: {exc}") from exc + continue + if encoding == "deflate": + try: + body = zlib.decompress(body) + except zlib.error as exc: + raise UnsupportedContentEncoding(f"invalid deflate request body: {exc}") from exc + continue + raise UnsupportedContentEncoding(f"unsupported request Content-Encoding: {encoding}") + return body + + +def annotate_json_request_body( + headers: Mapping[str, str], + body: bytes, + annotation: AgentAnnotation, +) -> tuple[bytes, bool]: + if not body or not _is_json_content_type(headers.get("content-type")): + return body, False + + decoded = _decode_request_body(body, headers.get("content-encoding")) + try: + payload = json.loads(decoded.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError) as exc: + raise ValueError(f"invalid JSON request body: {exc}") from exc + + if not isinstance(payload, dict): + return body, False + + annotated = json.dumps( + merge_dynamo_metadata(payload, annotation), + separators=(",", ":"), + ensure_ascii=False, + ).encode("utf-8") + return annotated, True + + +def _normalized_path(request_path: str) -> str: + path = urlsplit(request_path).path.rstrip("/") + return path or "/" + + +def _make_upstream_path(prefix: str, request_path: str, forced_path: str | None = None) -> str: + split = urlsplit(request_path) + suffix = forced_path or split.path or "/" + if suffix == "/v1": + suffix = "/" + elif suffix.startswith("/v1/"): + suffix = suffix[3:] + if not suffix.startswith("/"): + suffix = f"/{suffix}" + base = prefix.rstrip("/") + path = f"{base}{suffix}" if base else suffix + if forced_path is None and split.query: + path = f"{path}?{split.query}" + return path + + +def _connection(config: ProxyConfig) -> HTTPConnection: + if config.upstream_scheme == "https": + return HTTPSConnection(config.upstream_host, config.upstream_port, timeout=300) + return HTTPConnection(config.upstream_host, config.upstream_port, timeout=300) + + +def _anthropic_block_to_text(block: Any) -> str: + if isinstance(block, str): + return block + if not isinstance(block, dict): + return "" + if block.get("type") == "text" and isinstance(block.get("text"), str): + return block["text"] + if block.get("type") == "tool_result": + content = block.get("content") + if isinstance(content, str): + return content + if isinstance(content, list): + return "\n".join(filter(None, (_anthropic_block_to_text(item) for item in content))) + return "" + + +def _anthropic_system_to_openai(system: Any) -> str | None: + if isinstance(system, str): + return system + if isinstance(system, list): + text = "\n".join(filter(None, (_anthropic_block_to_text(block) for block in system))) + return text or None + return None + + +def _anthropic_message_to_openai_messages(message: Mapping[str, Any]) -> list[dict[str, Any]]: + role = "assistant" if message.get("role") == "assistant" else "user" + content = message.get("content") + if isinstance(content, str): + return [{"role": role, "content": content}] + if not isinstance(content, list): + return [{"role": role, "content": ""}] + + messages: list[dict[str, Any]] = [] + text_parts: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + + for block in content: + if not isinstance(block, dict): + continue + if block.get("type") == "tool_result": + if text_parts: + messages.append({"role": "user", "content": "\n".join(text_parts)}) + text_parts = [] + messages.append( + { + "role": "tool", + "tool_call_id": str(block.get("tool_use_id", "")), + "content": _anthropic_block_to_text(block), + } + ) + continue + if role == "assistant" and block.get("type") == "tool_use": + tool_calls.append( + { + "id": str(block.get("id") or f"toolu_{uuid.uuid4().hex}"), + "type": "function", + "function": { + "name": str(block.get("name") or "tool"), + "arguments": json.dumps(block.get("input") or {}), + }, + } + ) + continue + text = _anthropic_block_to_text(block) + if text: + text_parts.append(text) + + message_body: dict[str, Any] = {"role": role, "content": "\n".join(text_parts)} + if role == "assistant": + message_body["content"] = message_body["content"] or None + if tool_calls: + message_body["tool_calls"] = tool_calls + messages.append(message_body) + return messages + + +def _anthropic_tools_to_openai(tools: Any) -> Any | None: + if not isinstance(tools, list): + return None + converted = [] + for tool in tools: + if not isinstance(tool, dict): + continue + function: dict[str, Any] = { + "name": str(tool.get("name") or "tool"), + "parameters": tool.get("input_schema") if isinstance(tool.get("input_schema"), dict) else {"type": "object", "properties": {}}, + } + if isinstance(tool.get("description"), str): + function["description"] = tool["description"] + converted.append({"type": "function", "function": function}) + return converted + + +def _anthropic_tool_choice_to_openai(tool_choice: Any) -> Any | None: + if not isinstance(tool_choice, dict): + return None + choice_type = tool_choice.get("type") + if choice_type == "auto": + return "auto" + if choice_type == "any": + return "required" + if choice_type == "none": + return "none" + if choice_type == "tool" and isinstance(tool_choice.get("name"), str): + return {"type": "function", "function": {"name": tool_choice["name"]}} + return None + + +def translate_anthropic_messages_request(payload: Any, config: ProxyConfig) -> dict[str, Any]: + if not isinstance(payload, dict): + raise ValueError("Anthropic messages request body must be a JSON object") + + messages: list[dict[str, Any]] = [] + system = _anthropic_system_to_openai(payload.get("system")) + if system is not None: + messages.append({"role": "system", "content": system}) + if isinstance(payload.get("messages"), list): + for message in payload["messages"]: + if isinstance(message, dict): + messages.extend(_anthropic_message_to_openai_messages(message)) + + openai_request: dict[str, Any] = { + "model": payload.get("model") if isinstance(payload.get("model"), str) else config.model, + "messages": messages, + "stream": payload.get("stream") is True, + } + for anthropic_key, openai_key in ( + ("max_tokens", "max_tokens"), + ("temperature", "temperature"), + ("top_p", "top_p"), + ): + if isinstance(payload.get(anthropic_key), (int, float)): + openai_request[openai_key] = payload[anthropic_key] + if isinstance(payload.get("stop_sequences"), list): + openai_request["stop"] = payload["stop_sequences"] + tools = _anthropic_tools_to_openai(payload.get("tools")) + if tools is not None: + openai_request["tools"] = tools + tool_choice = _anthropic_tool_choice_to_openai(payload.get("tool_choice")) + if tool_choice is not None: + openai_request["tool_choice"] = tool_choice + + return merge_dynamo_metadata(openai_request, config.annotation) + + +def _map_finish_reason(finish_reason: Any) -> str | None: + if finish_reason == "length": + return "max_tokens" + if finish_reason == "tool_calls": + return "tool_use" + if finish_reason == "stop": + return "end_turn" + return None + + +def _text_content_from_openai_message(message: Mapping[str, Any]) -> str: + content = message.get("content") + if isinstance(content, str): + return content + if isinstance(content, list): + return "\n".join( + part.get("text", "") + for part in content + if isinstance(part, dict) and part.get("type") == "text" and isinstance(part.get("text"), str) + ) + return "" + + +def translate_openai_chat_response_to_anthropic(payload: Any, fallback_model: str) -> dict[str, Any]: + body = payload if isinstance(payload, dict) else {} + choices = body.get("choices") if isinstance(body.get("choices"), list) else [] + first_choice = choices[0] if choices and isinstance(choices[0], dict) else {} + message = first_choice.get("message") if isinstance(first_choice.get("message"), dict) else {} + content: list[dict[str, Any]] = [] + text = _text_content_from_openai_message(message) + if text: + content.append({"type": "text", "text": text}) + + tool_calls = message.get("tool_calls") if isinstance(message, dict) else None + if isinstance(tool_calls, list): + for tool_call in tool_calls: + if not isinstance(tool_call, dict) or not isinstance(tool_call.get("function"), dict): + continue + function = tool_call["function"] + tool_input: Any = {} + arguments = function.get("arguments") + if isinstance(arguments, str) and arguments.strip(): + try: + tool_input = json.loads(arguments) + except json.JSONDecodeError: + tool_input = {"raw_arguments": arguments} + content.append( + { + "type": "tool_use", + "id": str(tool_call.get("id") or f"toolu_{uuid.uuid4().hex}"), + "name": str(function.get("name") or "tool"), + "input": tool_input, + } + ) + + usage = body.get("usage") if isinstance(body.get("usage"), dict) else {} + return { + "id": body.get("id") if isinstance(body.get("id"), str) else f"msg_{uuid.uuid4().hex}", + "type": "message", + "role": "assistant", + "model": body.get("model") if isinstance(body.get("model"), str) else fallback_model, + "content": content, + "stop_reason": _map_finish_reason(first_choice.get("finish_reason")), + "stop_sequence": None, + "usage": { + "input_tokens": usage.get("prompt_tokens") if isinstance(usage.get("prompt_tokens"), int) else 0, + "output_tokens": usage.get("completion_tokens") if isinstance(usage.get("completion_tokens"), int) else 0, + }, + } + + +def _write_sse(handler: BaseHTTPRequestHandler, event: str, data: Any) -> None: + handler.wfile.write(f"event: {event}\n".encode("utf-8")) + handler.wfile.write( + b"data: " + json.dumps(data, separators=(",", ":")).encode("utf-8") + b"\n\n" + ) + handler.wfile.flush() + + +def _iter_sse_data(response: HTTPResponse): + data_lines: list[str] = [] + while line := response.readline(): + decoded = line.decode("utf-8", errors="replace").rstrip("\r\n") + if decoded == "": + if data_lines: + yield "\n".join(data_lines) + data_lines = [] + continue + if decoded.startswith("data:"): + data_lines.append(decoded[5:].lstrip()) + if data_lines: + yield "\n".join(data_lines) + + +def _send_anthropic_stream(handler: BaseHTTPRequestHandler, response: HTTPResponse, model: str) -> None: + handler.send_response(response.status, response.reason) + handler.send_header("Content-Type", "text/event-stream; charset=utf-8") + handler.send_header("Cache-Control", "no-cache") + handler.send_header("Connection", "close") + handler.end_headers() + + message_id = f"msg_{uuid.uuid4().hex}" + _write_sse( + handler, + "message_start", + { + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": {"input_tokens": 0, "output_tokens": 0}, + }, + }, + ) + + text_block_open = False + next_block_index = 0 + stop_reason: str | None = None + output_tokens = 0 + tool_blocks: dict[int, int] = {} + + for data in _iter_sse_data(response): + if data == "[DONE]": + continue + try: + chunk = json.loads(data) + except json.JSONDecodeError: + continue + if isinstance(chunk.get("usage"), dict) and isinstance(chunk["usage"].get("completion_tokens"), int): + output_tokens = chunk["usage"]["completion_tokens"] + choices = chunk.get("choices") if isinstance(chunk.get("choices"), list) else [] + for choice in choices: + if not isinstance(choice, dict): + continue + if choice.get("finish_reason") is not None: + stop_reason = _map_finish_reason(choice.get("finish_reason")) + delta = choice.get("delta") if isinstance(choice.get("delta"), dict) else {} + if isinstance(delta.get("content"), str) and delta["content"]: + if not text_block_open: + _write_sse( + handler, + "content_block_start", + { + "type": "content_block_start", + "index": next_block_index, + "content_block": {"type": "text", "text": ""}, + }, + ) + text_block_open = True + next_block_index += 1 + _write_sse( + handler, + "content_block_delta", + { + "type": "content_block_delta", + "index": next_block_index - 1, + "delta": {"type": "text_delta", "text": delta["content"]}, + }, + ) + if isinstance(delta.get("tool_calls"), list): + if text_block_open: + _write_sse( + handler, + "content_block_stop", + {"type": "content_block_stop", "index": next_block_index - 1}, + ) + text_block_open = False + for tool_call in delta["tool_calls"]: + if not isinstance(tool_call, dict): + continue + index = tool_call.get("index") if isinstance(tool_call.get("index"), int) else 0 + block_index = tool_blocks.get(index) + function = tool_call.get("function") if isinstance(tool_call.get("function"), dict) else {} + if block_index is None: + block_index = next_block_index + tool_blocks[index] = block_index + next_block_index += 1 + _write_sse( + handler, + "content_block_start", + { + "type": "content_block_start", + "index": block_index, + "content_block": { + "type": "tool_use", + "id": str(tool_call.get("id") or f"toolu_{uuid.uuid4().hex}"), + "name": str(function.get("name") or "tool"), + "input": {}, + }, + }, + ) + arguments = function.get("arguments") + if isinstance(arguments, str) and arguments: + _write_sse( + handler, + "content_block_delta", + { + "type": "content_block_delta", + "index": block_index, + "delta": {"type": "input_json_delta", "partial_json": arguments}, + }, + ) + + if text_block_open: + _write_sse( + handler, + "content_block_stop", + {"type": "content_block_stop", "index": next_block_index - 1}, + ) + for block_index in tool_blocks.values(): + _write_sse(handler, "content_block_stop", {"type": "content_block_stop", "index": block_index}) + _write_sse( + handler, + "message_delta", + { + "type": "message_delta", + "delta": {"stop_reason": stop_reason or "end_turn", "stop_sequence": None}, + "usage": {"output_tokens": output_tokens}, + }, + ) + _write_sse(handler, "message_stop", {"type": "message_stop"}) + + +def make_handler(config: ProxyConfig) -> type[BaseHTTPRequestHandler]: + class DynamoAgentProxyHandler(BaseHTTPRequestHandler): + protocol_version = "HTTP/1.1" + + def do_GET(self) -> None: + self._proxy_openai() + + def do_HEAD(self) -> None: + self._proxy_openai() + + def do_POST(self) -> None: + if _normalized_path(self.path) in ANTHROPIC_MESSAGES_PATHS: + self._proxy_anthropic_messages() + else: + self._proxy_openai() + + def do_PUT(self) -> None: + self._proxy_openai() + + def do_PATCH(self) -> None: + self._proxy_openai() + + def do_DELETE(self) -> None: + self._proxy_openai() + + def _read_body(self) -> bytes: + if self.command not in REQUEST_BODY_METHODS: + return b"" + length_header = self.headers.get("content-length") + if not length_header: + return b"" + try: + length = int(length_header) + except ValueError as exc: + raise ValueError(f"invalid Content-Length: {length_header}") from exc + return self.rfile.read(length) + + def _outgoing_headers(self, annotated: bool, body_length: int) -> dict[str, str]: + outgoing: dict[str, str] = {} + for key, value in self.headers.items(): + lower = key.lower() + if lower in HOP_BY_HOP_HEADERS or lower == "host": + continue + if annotated and lower in {"content-encoding", "content-type"}: + continue + if lower in {"authorization", "x-api-key", "anthropic-version", "accept-encoding"}: + continue + outgoing[key] = value + + outgoing["Host"] = f"{config.upstream_host}:{config.upstream_port}" + outgoing["Authorization"] = f"Bearer {config.api_key}" + if not any(key.lower() == "x-request-id" for key in outgoing): + outgoing["x-request-id"] = str(uuid.uuid4()) + if self.command in REQUEST_BODY_METHODS: + outgoing["Content-Length"] = str(body_length) + if annotated: + outgoing["Content-Type"] = "application/json" + return outgoing + + def _proxy_openai(self) -> None: + try: + request_headers = {key.lower(): value for key, value in self.headers.items()} + body = self._read_body() + annotated = False + if self.command in REQUEST_BODY_METHODS and _normalized_path(self.path) in OPENAI_INJECTION_PATHS: + body, annotated = annotate_json_request_body( + request_headers, + body, + config.annotation, + ) + headers = self._outgoing_headers(annotated, len(body)) + except UnsupportedContentEncoding as exc: + self.send_error(415, str(exc)) + return + except ValueError as exc: + self.send_error(400, str(exc)) + return + + upstream_path = _make_upstream_path(config.upstream_prefix, self.path) + try: + conn = _connection(config) + conn.request( + self.command, + upstream_path, + body=body if self.command in REQUEST_BODY_METHODS else None, + headers=headers, + ) + response = conn.getresponse() + self._send_upstream_response(response) + except OSError as exc: + self.send_error(502, f"upstream request failed: {exc}") + finally: + try: + conn.close() + except UnboundLocalError: + pass + + def _proxy_anthropic_messages(self) -> None: + try: + body = self._read_body() + decoded = _decode_request_body(body, self.headers.get("content-encoding")) + payload = json.loads(decoded.decode("utf-8")) + translated = translate_anthropic_messages_request(payload, config) + outgoing_body = json.dumps(translated, separators=(",", ":"), ensure_ascii=False).encode("utf-8") + headers = self._outgoing_headers(True, len(outgoing_body)) + headers["Content-Type"] = "application/json" + except UnsupportedContentEncoding as exc: + self.send_error(415, str(exc)) + return + except (UnicodeDecodeError, json.JSONDecodeError, ValueError) as exc: + self.send_error(400, f"invalid Anthropic request body: {exc}") + return + + try: + conn = _connection(config) + conn.request( + "POST", + _make_upstream_path(config.upstream_prefix, self.path, forced_path="/chat/completions"), + body=outgoing_body, + headers=headers, + ) + response = conn.getresponse() + content_type = response.getheader("content-type", "") + if response.status >= 400: + self._send_upstream_response(response) + return + if translated.get("stream") is True and "text/event-stream" in content_type.lower(): + _send_anthropic_stream(self, response, str(translated.get("model") or config.model)) + self.close_connection = True + return + response_body = response.read() + try: + upstream_payload = json.loads(response_body.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError) as exc: + self.send_error(502, f"invalid upstream JSON response: {exc}") + return + anthropic_body = json.dumps( + translate_openai_chat_response_to_anthropic( + upstream_payload, + str(translated.get("model") or config.model), + ), + separators=(",", ":"), + ensure_ascii=False, + ).encode("utf-8") + self.send_response(response.status, response.reason) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(anthropic_body))) + self.send_header("Connection", "close") + self.end_headers() + self.wfile.write(anthropic_body) + self.close_connection = True + except OSError as exc: + self.send_error(502, f"upstream request failed: {exc}") + finally: + try: + conn.close() + except UnboundLocalError: + pass + + def _send_upstream_response(self, response: HTTPResponse) -> None: + self.send_response(response.status, response.reason) + content_type = None + for key, value in response.getheaders(): + lower = key.lower() + if lower in HOP_BY_HOP_HEADERS: + continue + if lower == "content-type": + content_type = value + self.send_header(key, value) + self.send_header("Connection", "close") + self.end_headers() + + if self.command == "HEAD": + self.close_connection = True + return + + media_type = (content_type or "").split(";", 1)[0].strip().lower() + if media_type == "text/event-stream": + while line := response.readline(): + self.wfile.write(line) + self.wfile.flush() + else: + while chunk := response.read(STREAM_CHUNK_SIZE): + self.wfile.write(chunk) + self.wfile.flush() + self.close_connection = True + + return DynamoAgentProxyHandler + + +def serve(config: RuntimeConfig) -> None: + handler = make_handler(config.proxy) + server = ThreadingHTTPServer((config.listen.host, config.listen.port), handler) + print( + "dynamo-agent-proxy listening on " + f"http://{config.listen.host}:{config.listen.port} -> {config.upstream_display}; " + f"session_id={config.proxy.annotation.session_id}; " + f"trajectory_id={config.proxy.annotation.trajectory_id}", + file=sys.stderr, + flush=True, + ) + try: + server.serve_forever() + except KeyboardInterrupt: + pass + finally: + server.server_close() + + +def main(argv: Sequence[str] | None = None) -> None: + serve(read_proxy_config(argv)) + + +if __name__ == "__main__": + main() diff --git a/test/dynamo-provider.test.ts b/test/dynamo-provider.test.ts index e38e1f9..dc7598f 100644 --- a/test/dynamo-provider.test.ts +++ b/test/dynamo-provider.test.ts @@ -94,7 +94,8 @@ describe("pi-subagents trajectory bridge", () => { }); it("skips the bridge when PI_SUBAGENT_CHILD is not 1", () => { - expect(computeSubagentTrajectoryRewrite({ ...childEnv, PI_SUBAGENT_CHILD: undefined })).toBeNull(); + const { PI_SUBAGENT_CHILD: _omit, ...envWithoutChildFlag } = childEnv; + expect(computeSubagentTrajectoryRewrite(envWithoutChildFlag)).toBeNull(); }); it("does NOT override an explicit DYN_AGENT_PARENT_TRAJECTORY_ID (manual wins)", () => { @@ -104,12 +105,15 @@ describe("pi-subagents trajectory bridge", () => { }); it("skips when inherited DYN_AGENT_TRAJECTORY_ID is absent", () => { - expect(computeSubagentTrajectoryRewrite({ ...childEnv, DYN_AGENT_TRAJECTORY_ID: undefined })).toBeNull(); + const { DYN_AGENT_TRAJECTORY_ID: _omit, ...envWithoutTrajectory } = childEnv; + expect(computeSubagentTrajectoryRewrite(envWithoutTrajectory)).toBeNull(); }); it("skips when PI_SUBAGENT_RUN_ID or PI_SUBAGENT_CHILD_AGENT is missing", () => { - expect(computeSubagentTrajectoryRewrite({ ...childEnv, PI_SUBAGENT_RUN_ID: undefined })).toBeNull(); - expect(computeSubagentTrajectoryRewrite({ ...childEnv, PI_SUBAGENT_CHILD_AGENT: undefined })).toBeNull(); + const { PI_SUBAGENT_RUN_ID: _omitRunId, ...envWithoutRunId } = childEnv; + const { PI_SUBAGENT_CHILD_AGENT: _omitChildAgent, ...envWithoutChildAgent } = childEnv; + expect(computeSubagentTrajectoryRewrite(envWithoutRunId)).toBeNull(); + expect(computeSubagentTrajectoryRewrite(envWithoutChildAgent)).toBeNull(); }); it("readDynamoConfig surfaces the synthesized ids", () => { @@ -132,9 +136,9 @@ describe("pi-subagents trajectory bridge", () => { // passes { ...process.env, ...subagentEnv }. The grandchild then sees // its own synthesized id as inherited DYN_AGENT_TRAJECTORY_ID, so the // next rewrite treats THIS generation as the parent. + const { DYN_AGENT_PARENT_TRAJECTORY_ID: _omitParent, ...envWithoutParent } = env; const grandchildEnv = { - ...env, - DYN_AGENT_PARENT_TRAJECTORY_ID: undefined, + ...envWithoutParent, PI_SUBAGENT_CHILD_AGENT: "subworker", PI_SUBAGENT_CHILD_INDEX: "0", }; diff --git a/test/python/test_agent_proxy.py b/test/python/test_agent_proxy.py new file mode 100644 index 0000000..da7c7fd --- /dev/null +++ b/test/python/test_agent_proxy.py @@ -0,0 +1,328 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import http.client +import json +import threading +import unittest +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer + +from dynamo_agent_proxy.proxy import ( + AgentAnnotation, + ProxyConfig, + annotate_json_request_body, + make_handler, + merge_dynamo_metadata, + read_proxy_config, + translate_anthropic_messages_request, + translate_openai_chat_response_to_anthropic, +) + + +def annotation() -> AgentAnnotation: + return AgentAnnotation( + session_type_id="generic_agent", + session_id="session-1", + trajectory_id="session-1:main", + agent_hints={"priority": 7, "osl": 512}, + ) + + +def proxy_config(upstream_port: int = 8000) -> ProxyConfig: + return ProxyConfig( + upstream_scheme="http", + upstream_host="127.0.0.1", + upstream_port=upstream_port, + upstream_prefix="/v1", + api_key="test-key", + model="fallback-model", + annotation=annotation(), + ) + + +def start_server(server: ThreadingHTTPServer) -> None: + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + + +def close_server(server: ThreadingHTTPServer) -> None: + server.shutdown() + server.server_close() + + +class RecordingUpstreamHandler(BaseHTTPRequestHandler): + protocol_version = "HTTP/1.1" + + def do_POST(self) -> None: + length = int(self.headers.get("content-length", "0")) + body = self.rfile.read(length) + self.server.recorded.append( + { + "path": self.path, + "headers": dict(self.headers.items()), + "body": body, + } + ) + if self.path == "/v1/chat/completions": + response = json.dumps( + { + "id": "chatcmpl-1", + "model": "demo", + "choices": [ + {"finish_reason": "stop", "message": {"role": "assistant", "content": "ok"}} + ], + "usage": {"prompt_tokens": 3, "completion_tokens": 1}, + }, + separators=(",", ":"), + ).encode("utf-8") + else: + response = b'{"ok":true}' + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(response))) + self.end_headers() + self.wfile.write(response) + + def log_message(self, format: str, *args: object) -> None: + return + + +class StreamingUpstreamHandler(BaseHTTPRequestHandler): + protocol_version = "HTTP/1.1" + + def do_POST(self) -> None: + length = int(self.headers.get("content-length", "0")) + self.server.recorded_body = self.rfile.read(length) + chunks = [ + b'data: {"choices":[{"delta":{"content":"he"}}]}\n\n', + b'data: {"choices":[{"delta":{"content":"llo"},"finish_reason":"stop"}],"usage":{"completion_tokens":1}}\n\n', + b"data: [DONE]\n\n", + ] + body = b"".join(chunks) + self.send_response(200) + self.send_header("Content-Type", "text/event-stream") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def log_message(self, format: str, *args: object) -> None: + return + + +class AgentProxyTests(unittest.TestCase): + def test_reads_env_and_cli_overrides(self) -> None: + runtime = read_proxy_config( + ["--listen-port", "19090", "--session-id", "session-cli", "--agent-hint", "osl=1024"], + env={ + "DYNAMO_BASE_URL": "http://dynamo.test", + "DYNAMO_API_KEY": "dyn-key", + "DYN_AGENT_SESSION_TYPE_ID": "agent-kind", + "DYN_AGENT_HINTS": '{"priority":3}', + }, + ) + + self.assertEqual(runtime.listen.port, 19090) + self.assertEqual(runtime.proxy.upstream_prefix, "/v1") + self.assertEqual(runtime.proxy.upstream_host, "dynamo.test") + self.assertEqual(runtime.proxy.api_key, "dyn-key") + self.assertEqual(runtime.proxy.annotation.session_type_id, "agent-kind") + self.assertEqual(runtime.proxy.annotation.session_id, "session-cli") + self.assertEqual(runtime.proxy.annotation.trajectory_id, "session-cli:main") + self.assertEqual(runtime.proxy.annotation.agent_hints, {"priority": 3, "osl": 1024}) + + def test_merge_dynamo_metadata_preserves_request_values(self) -> None: + payload = merge_dynamo_metadata( + { + "model": "demo", + "nvext": { + "extra_fields": ["worker_id"], + "agent_context": {"trajectory_id": "client-traj", "custom": "kept"}, + "agent_hints": {"priority": 1, "custom_hint": True}, + }, + }, + annotation(), + ) + + self.assertEqual(payload["nvext"]["agent_context"]["session_id"], "session-1") + self.assertEqual(payload["nvext"]["agent_context"]["trajectory_id"], "session-1:main") + self.assertEqual(payload["nvext"]["agent_context"]["custom"], "kept") + self.assertEqual(payload["nvext"]["agent_hints"], {"priority": 7, "osl": 512, "custom_hint": True}) + + def test_annotates_json_object_request_body(self) -> None: + body, annotated = annotate_json_request_body( + {"content-type": "application/json"}, + json.dumps({"model": "demo", "input": "hello"}).encode("utf-8"), + annotation(), + ) + + self.assertTrue(annotated) + forwarded = json.loads(body.decode("utf-8")) + self.assertEqual(forwarded["nvext"]["agent_context"]["trajectory_id"], "session-1:main") + self.assertEqual(forwarded["nvext"]["agent_hints"], {"priority": 7, "osl": 512}) + + def test_translates_anthropic_messages_request(self) -> None: + translated = translate_anthropic_messages_request( + { + "model": "claude-style-model", + "system": "You are concise.", + "max_tokens": 64, + "messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}], + "tools": [ + { + "name": "search", + "description": "Search docs", + "input_schema": {"type": "object", "properties": {"query": {"type": "string"}}}, + } + ], + "tool_choice": {"type": "tool", "name": "search"}, + }, + proxy_config(), + ) + + self.assertEqual(translated["model"], "claude-style-model") + self.assertEqual( + translated["messages"], + [ + {"role": "system", "content": "You are concise."}, + {"role": "user", "content": "hello"}, + ], + ) + self.assertEqual(translated["tool_choice"], {"type": "function", "function": {"name": "search"}}) + self.assertEqual(translated["nvext"]["agent_context"]["session_id"], "session-1") + + def test_translates_openai_response_to_anthropic(self) -> None: + translated = translate_openai_chat_response_to_anthropic( + { + "id": "chatcmpl-1", + "model": "demo", + "choices": [ + { + "finish_reason": "tool_calls", + "message": { + "content": "checking", + "tool_calls": [ + { + "id": "call-1", + "type": "function", + "function": {"name": "search", "arguments": '{"query":"dynamo"}'}, + } + ], + }, + } + ], + "usage": {"prompt_tokens": 11, "completion_tokens": 7}, + }, + "fallback", + ) + + self.assertEqual(translated["stop_reason"], "tool_use") + self.assertEqual(translated["usage"], {"input_tokens": 11, "output_tokens": 7}) + self.assertEqual( + translated["content"], + [ + {"type": "text", "text": "checking"}, + {"type": "tool_use", "id": "call-1", "name": "search", "input": {"query": "dynamo"}}, + ], + ) + + def test_openai_chat_and_responses_requests_are_forwarded_with_nvext(self) -> None: + upstream = ThreadingHTTPServer(("127.0.0.1", 0), RecordingUpstreamHandler) + upstream.recorded = [] + start_server(upstream) + proxy = ThreadingHTTPServer(("127.0.0.1", 0), make_handler(proxy_config(upstream.server_address[1]))) + start_server(proxy) + conn = http.client.HTTPConnection("127.0.0.1", proxy.server_address[1], timeout=5) + try: + for path in ["/v1/chat/completions", "/v1/responses"]: + conn.request( + "POST", + path, + body=json.dumps({"model": "demo", "input": "hello"}), + headers={"Content-Type": "application/json", "Authorization": "Bearer client-key"}, + ) + response = conn.getresponse() + self.assertEqual(response.status, 200) + response.read() + finally: + conn.close() + close_server(proxy) + close_server(upstream) + + self.assertEqual([item["path"] for item in upstream.recorded], ["/v1/chat/completions", "/v1/responses"]) + self.assertEqual([item["headers"]["Authorization"] for item in upstream.recorded], ["Bearer test-key", "Bearer test-key"]) + first_body = json.loads(upstream.recorded[0]["body"].decode("utf-8")) + second_body = json.loads(upstream.recorded[1]["body"].decode("utf-8")) + self.assertEqual(first_body["nvext"]["agent_context"]["trajectory_id"], "session-1:main") + self.assertEqual(second_body["nvext"]["agent_hints"], {"priority": 7, "osl": 512}) + + def test_anthropic_messages_request_forwards_chat_completion_to_dynamo(self) -> None: + upstream = ThreadingHTTPServer(("127.0.0.1", 0), RecordingUpstreamHandler) + upstream.recorded = [] + start_server(upstream) + proxy = ThreadingHTTPServer(("127.0.0.1", 0), make_handler(proxy_config(upstream.server_address[1]))) + start_server(proxy) + conn = http.client.HTTPConnection("127.0.0.1", proxy.server_address[1], timeout=5) + try: + conn.request( + "POST", + "/v1/messages", + body=json.dumps({"model": "demo", "max_tokens": 8, "messages": [{"role": "user", "content": "hello"}]}), + headers={"Content-Type": "application/json", "x-api-key": "client-key"}, + ) + response = conn.getresponse() + body = json.loads(response.read().decode("utf-8")) + self.assertEqual(response.status, 200) + finally: + conn.close() + close_server(proxy) + close_server(upstream) + + self.assertEqual(upstream.recorded[0]["path"], "/v1/chat/completions") + forwarded = json.loads(upstream.recorded[0]["body"].decode("utf-8")) + self.assertEqual(forwarded["nvext"]["agent_context"]["session_type_id"], "generic_agent") + self.assertEqual(body["type"], "message") + self.assertEqual(body["content"], [{"type": "text", "text": "ok"}]) + self.assertEqual(body["stop_reason"], "end_turn") + + def test_anthropic_streaming_response_is_translated_to_anthropic_sse(self) -> None: + upstream = ThreadingHTTPServer(("127.0.0.1", 0), StreamingUpstreamHandler) + start_server(upstream) + proxy = ThreadingHTTPServer(("127.0.0.1", 0), make_handler(proxy_config(upstream.server_address[1]))) + start_server(proxy) + conn = http.client.HTTPConnection("127.0.0.1", proxy.server_address[1], timeout=5) + try: + conn.request( + "POST", + "/v1/messages", + body=json.dumps( + { + "model": "demo", + "max_tokens": 8, + "stream": True, + "messages": [{"role": "user", "content": "hello"}], + } + ), + headers={"Content-Type": "application/json", "x-api-key": "client-key"}, + ) + response = conn.getresponse() + body = response.read().decode("utf-8") + self.assertEqual(response.status, 200) + finally: + conn.close() + close_server(proxy) + close_server(upstream) + + forwarded = json.loads(upstream.recorded_body.decode("utf-8")) + self.assertTrue(forwarded["stream"]) + self.assertEqual(forwarded["nvext"]["agent_context"]["trajectory_id"], "session-1:main") + self.assertIn("event: message_start", body) + self.assertIn('"type":"text_delta","text":"he"', body) + self.assertIn('"type":"text_delta","text":"llo"', body) + self.assertIn('"stop_reason":"end_turn"', body) + self.assertIn("event: message_stop", body) + + +if __name__ == "__main__": + unittest.main()