diff --git a/Cargo.lock b/Cargo.lock index 89af681..48eda10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2085,9 +2085,9 @@ checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" [[package]] name = "is-ai-agent" -version = "0.2.1" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8745cc12e6796e1b20733dcc27e6062d51a22083eb00eaf0ed0a0880a60d302" +checksum = "9848cf556ca7b76f8e28f386a76deab397c60ef815f124f54f5c59016b1586f0" [[package]] name = "is-docker" diff --git a/crates/clickhousectl/Cargo.toml b/crates/clickhousectl/Cargo.toml index a7eb750..6040931 100644 --- a/crates/clickhousectl/Cargo.toml +++ b/crates/clickhousectl/Cargo.toml @@ -42,7 +42,7 @@ url = "2.5.8" tabled = "0.20.0" clickhouse-cloud-api = { version = "0.3.1", path = "../clickhouse-cloud-api" } uuid = { version = "1.23.0", features = ["v4"] } -is-ai-agent = "0.2.1" +is-ai-agent = "0.4.0" base64 = "0.22.1" bollard = "0.21.0" crossterm = "0.29.0" diff --git a/crates/clickhousectl/src/cloud/auth.rs b/crates/clickhousectl/src/cloud/auth.rs index 49a536c..9199c1b 100644 --- a/crates/clickhousectl/src/cloud/auth.rs +++ b/crates/clickhousectl/src/cloud/auth.rs @@ -148,8 +148,7 @@ pub async fn device_auth_login(api_url: &str) -> Result, url_override: Option<&str>, ) -> Result { - let http = reqwest::Client::builder() - .user_agent(crate::user_agent::user_agent()) + let http = crate::http::client_builder() .build() .map_err(|e| CloudError::new(format!("Failed to create HTTP client: {}", e)))?; diff --git a/crates/clickhousectl/src/http.rs b/crates/clickhousectl/src/http.rs new file mode 100644 index 0000000..97b85d4 --- /dev/null +++ b/crates/clickhousectl/src/http.rs @@ -0,0 +1,102 @@ +//! Canonical construction of outbound HTTP clients. +//! +//! Every `reqwest::Client` the CLI builds — Cloud API, OAuth, the updater, the +//! version manager — goes through [`client_builder`], so they uniformly carry +//! the `User-Agent` (built in `crate::user_agent`) and the agent +//! session/trace correlation headers, and any future builder picks these up +//! for free. + +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; + +/// `agent-session-id`: the calling agent's stable per-session/conversation id. +const AGENT_SESSION_ID: HeaderName = HeaderName::from_static("agent-session-id"); +/// `traceparent`: the raw W3C Trace Context value the agent published. +const TRACEPARENT: HeaderName = HeaderName::from_static("traceparent"); + +/// Default headers that correlate every outbound request with the calling AI +/// agent's session/trace, so backend telemetry can group a single agent run's +/// calls. Empty when not running under a detected agent (or the agent exposes +/// neither id). +pub fn agent_headers() -> HeaderMap { + match is_ai_agent::detect() { + Some(agent) => { + agent_headers_from(agent.session_id.as_deref(), agent.traceparent.as_deref()) + } + None => HeaderMap::new(), + } +} + +/// Build the header map from the raw id values. Split out from [`agent_headers`] +/// so it can be unit-tested without constructing an `is_ai_agent::Agent` (which +/// is `#[non_exhaustive]` and not constructible outside its crate). +fn agent_headers_from(session_id: Option<&str>, traceparent: Option<&str>) -> HeaderMap { + let mut headers = HeaderMap::new(); + // Session ids / traceparent are opaque vendor strings; an invalid header + // value is dropped rather than panicking. + if let Some(value) = session_id.and_then(|s| HeaderValue::from_str(s).ok()) { + headers.insert(AGENT_SESSION_ID, value); + } + if let Some(value) = traceparent.and_then(|s| HeaderValue::from_str(s).ok()) { + headers.insert(TRACEPARENT, value); + } + headers +} + +/// Canonical `reqwest::ClientBuilder` for all outbound HTTP: pre-applies the +/// `User-Agent` and the agent session/trace default headers. Callers chain any +/// extra config (`.timeout(..)`, etc.) and `.build()`. +pub fn client_builder() -> reqwest::ClientBuilder { + reqwest::Client::builder() + .user_agent(crate::user_agent::user_agent()) + .default_headers(agent_headers()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn both_ids_present_yields_both_headers() { + let headers = agent_headers_from(Some("sess-123"), Some("00-abc-def-01")); + assert_eq!(headers.get("agent-session-id").unwrap(), "sess-123"); + assert_eq!(headers.get("traceparent").unwrap(), "00-abc-def-01"); + assert_eq!(headers.len(), 2); + } + + #[test] + fn no_ids_yields_empty_map() { + let headers = agent_headers_from(None, None); + assert!(headers.is_empty()); + } + + #[test] + fn session_only_yields_single_header() { + let headers = agent_headers_from(Some("sess-123"), None); + assert_eq!(headers.get("agent-session-id").unwrap(), "sess-123"); + assert!(headers.get("traceparent").is_none()); + assert_eq!(headers.len(), 1); + } + + #[test] + fn traceparent_only_yields_single_header() { + let headers = agent_headers_from(None, Some("00-abc-def-01")); + assert_eq!(headers.get("traceparent").unwrap(), "00-abc-def-01"); + assert!(headers.get("agent-session-id").is_none()); + assert_eq!(headers.len(), 1); + } + + #[test] + fn invalid_header_value_is_dropped_without_panic() { + // A newline is not a legal header value; the entry is skipped, the + // valid one still lands. + let headers = agent_headers_from(Some("bad\nvalue"), Some("00-ok-01")); + assert!(headers.get("agent-session-id").is_none()); + assert_eq!(headers.get("traceparent").unwrap(), "00-ok-01"); + } + + #[test] + fn client_builder_builds() { + // Smoke: the canonical builder produces a usable client. + let _client = client_builder().build().unwrap(); + } +} diff --git a/crates/clickhousectl/src/main.rs b/crates/clickhousectl/src/main.rs index b8d3eb4..b2ff006 100644 --- a/crates/clickhousectl/src/main.rs +++ b/crates/clickhousectl/src/main.rs @@ -2,6 +2,7 @@ mod cli; mod cloud; mod dotenv; mod error; +mod http; mod init; mod local; mod paths; diff --git a/crates/clickhousectl/src/update.rs b/crates/clickhousectl/src/update.rs index 50d2e15..95dd9cc 100644 --- a/crates/clickhousectl/src/update.rs +++ b/crates/clickhousectl/src/update.rs @@ -59,8 +59,7 @@ fn is_newer(current: &str, latest: &str) -> bool { /// Fetch the latest release info from GitHub with configurable timeout. async fn fetch_latest_release(timeout: std::time::Duration) -> Result { let url = format!("https://api.github.com/repos/{}/releases/latest", GITHUB_REPO); - let client = reqwest::Client::builder() - .user_agent(crate::user_agent::user_agent()) + let client = crate::http::client_builder() .timeout(timeout) .build()?; @@ -149,8 +148,7 @@ pub async fn perform_update() -> Result<()> { let display = latest.strip_prefix('v').unwrap_or(latest); println!("Downloading clickhousectl v{}...", display); - let client = reqwest::Client::builder() - .user_agent(crate::user_agent::user_agent()) + let client = crate::http::client_builder() .timeout(std::time::Duration::from_secs(300)) .build()?; diff --git a/crates/clickhousectl/src/version_manager/download.rs b/crates/clickhousectl/src/version_manager/download.rs index 5e2f604..77eaf80 100644 --- a/crates/clickhousectl/src/version_manager/download.rs +++ b/crates/clickhousectl/src/version_manager/download.rs @@ -17,8 +17,7 @@ pub async fn download_from_source( /// Downloads a file from a URL to the specified path, with progress bar pub async fn download_url(url: &str, dest_path: &Path) -> Result<()> { - let client = reqwest::Client::builder() - .user_agent(crate::user_agent::user_agent()) + let client = crate::http::client_builder() .build()?; let response = client diff --git a/crates/clickhousectl/src/version_manager/list.rs b/crates/clickhousectl/src/version_manager/list.rs index 46e1328..6c2012b 100644 --- a/crates/clickhousectl/src/version_manager/list.rs +++ b/crates/clickhousectl/src/version_manager/list.rs @@ -72,8 +72,7 @@ pub struct VersionEntry { /// Fetches available versions from GitHub releases pub async fn list_available_versions() -> Result> { let url = "https://api.github.com/repos/ClickHouse/ClickHouse/releases?per_page=100"; - let client = reqwest::Client::builder() - .user_agent(crate::user_agent::user_agent()) + let client = crate::http::client_builder() .build()?; let response = client @@ -114,8 +113,7 @@ pub async fn list_available_versions_from_builds() -> Result> { use crate::version_manager::platform::{Platform, builds_probe_url}; let platform = Platform::detect()?; - let client = reqwest::Client::builder() - .user_agent(crate::user_agent::user_agent()) + let client = crate::http::client_builder() .build() .map_err(|e| Error::Download(e.to_string()))?; diff --git a/crates/clickhousectl/src/version_manager/resolve.rs b/crates/clickhousectl/src/version_manager/resolve.rs index bd9e15e..06d8770 100644 --- a/crates/clickhousectl/src/version_manager/resolve.rs +++ b/crates/clickhousectl/src/version_manager/resolve.rs @@ -110,8 +110,7 @@ async fn resolve_channel(channel: Channel, platform: &Platform) -> Result Result { // Probe builds.clickhouse.com for all possible minors in this major (1..12) let mut highest_available: Option = None; - let client = reqwest::Client::builder() - .user_agent(crate::user_agent::user_agent()) + let client = crate::http::client_builder() .build() .map_err(|e| Error::Download(e.to_string()))?; @@ -188,8 +187,7 @@ async fn find_exact_channel(version: &str) -> Result { "https://api.github.com/repos/ClickHouse/ClickHouse/git/matching-refs/tags/v{}-", version ); - let client = reqwest::Client::builder() - .user_agent(crate::user_agent::user_agent()) + let client = crate::http::client_builder() .build()?; let response = client @@ -249,8 +247,7 @@ fn fallback_source(version: &str, channel: Channel, platform: &Platform) -> Reso /// Probe builds.clickhouse.com with a HEAD request to check if a version exists async fn probe_builds(version_path: &str, platform: &Platform) -> bool { let url = builds_probe_url(version_path, platform); - let client = match reqwest::Client::builder() - .user_agent(crate::user_agent::user_agent()) + let client = match crate::http::client_builder() .build() { Ok(c) => c, @@ -277,8 +274,7 @@ async fn find_version_by_refs(prefix: &str) -> Result { "https://api.github.com/repos/ClickHouse/ClickHouse/git/matching-refs/tags/v{}.", prefix ); - let client = reqwest::Client::builder() - .user_agent(crate::user_agent::user_agent()) + let client = crate::http::client_builder() .build()?; let response = client diff --git a/crates/clickhousectl/tests/cli_request_shape_test.rs b/crates/clickhousectl/tests/cli_request_shape_test.rs index eae3890..13b4aba 100644 --- a/crates/clickhousectl/tests/cli_request_shape_test.rs +++ b/crates/clickhousectl/tests/cli_request_shape_test.rs @@ -1667,3 +1667,69 @@ async fn shell_env_overrides_dotenv_creds_in_request() { "shell env vars must override .env values on the wire" ); } + +// ── Issue #267: agent session/trace headers land on outbound requests ──────── +// +// When invoked under a detected AI agent that publishes a session id / +// traceparent to its subprocesses (Claude Code uses CLAUDE_CODE_SESSION_ID; +// TRACEPARENT is the W3C standard var), `clickhousectl` forwards them as the +// `agent-session-id` and `traceparent` request headers via the default headers +// on the shared HTTP client (`crate::http::client_builder`). This proves they +// reach the wire through the client the Cloud library actually uses. + +#[tokio::test] +async fn agent_session_and_trace_headers_are_forwarded() { + let mock = MockServer::start().await; + + let stub_orgs = serde_json::json!({ + "result": [], + "status": 200, + "requestId": "stub-org-list", + }); + Mock::given(method("GET")) + .and(path("/v1/organizations")) + .respond_with(ResponseTemplate::new(200).set_body_json(stub_orgs)) + .mount(&mock) + .await; + + let url = mock.uri(); + let traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"; + let output = Command::new(clickhousectl_binary()) + .args(["cloud", "--url", &url, "--json", "org", "list"]) + .env("CLICKHOUSE_CLOUD_API_KEY", "fake-key-for-tests") + .env("CLICKHOUSE_CLOUD_API_SECRET", "fake-secret-for-tests") + // Mark this invocation as Claude Code and expose the session/trace ids. + .env("AGENT", "claude-code") + .env("CLAUDE_CODE_SESSION_ID", "sess-test-267") + .env("TRACEPARENT", traceparent) + .output() + .expect("failed to spawn clickhousectl"); + + assert_success(&output); + + let requests = mock + .received_requests() + .await + .expect("mock requests log unavailable"); + let req = requests + .iter() + .find(|r| r.method == wiremock::http::Method::GET) + .expect("no GET request recorded"); + + assert_eq!( + req.headers + .get("agent-session-id") + .expect("agent-session-id header missing") + .to_str() + .unwrap(), + "sess-test-267", + ); + assert_eq!( + req.headers + .get("traceparent") + .expect("traceparent header missing") + .to_str() + .unwrap(), + traceparent, + ); +}