diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 9e4d82c5..8a5bd63a 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -272,6 +272,11 @@ name = "test_streamable_http_json_response" required-features = ["server", "client", "transport-streamable-http-server", "reqwest"] path = "tests/test_streamable_http_json_response.rs" +[[test]] +name = "test_streamable_http_protocol_version" +required-features = ["server", "client", "transport-streamable-http-server", "reqwest"] +path = "tests/test_streamable_http_protocol_version.rs" + [[test]] name = "test_streamable_http_4xx_error_body" required-features = ["transport-streamable-http-client", "transport-streamable-http-client-reqwest"] diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 5993c75b..cd2f5f1e 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -1,4 +1,6 @@ -use std::{collections::HashMap, convert::Infallible, fmt::Display, sync::Arc, time::Duration}; +use std::{ + borrow::Cow, collections::HashMap, convert::Infallible, fmt::Display, sync::Arc, time::Duration, +}; use bytes::Bytes; use futures::{StreamExt, future::BoxFuture}; @@ -14,8 +16,8 @@ use super::session::{ use crate::{ RoleServer, model::{ - ClientJsonRpcMessage, ClientNotification, ClientRequest, GetExtensions, InitializeRequest, - InitializedNotification, ProtocolVersion, + ClientJsonRpcMessage, ClientNotification, ClientRequest, ErrorData, GetExtensions, + InitializeRequest, InitializedNotification, JsonRpcError, ProtocolVersion, RequestId, }, serve_server, service::serve_directly, @@ -209,6 +211,54 @@ fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), Box Ok(()) } +fn invalid_request_jsonrpc_response( + id: Option, + message: impl Into>, +) -> BoxResponse { + let err = JsonRpcError::new(id, ErrorData::invalid_request(message, None)); + let body = serde_json::to_vec(&err).expect("serialize JsonRpcError"); + Response::builder() + .status(http::StatusCode::BAD_REQUEST) + .header(http::header::CONTENT_TYPE, JSON_MIME_TYPE) + .body(Full::new(Bytes::from(body)).boxed()) + .expect("valid response") +} + +#[expect( + clippy::result_large_err, + reason = "BoxResponse is intentionally large; matches other handlers in this file" +)] +/// Absent header is allowed; the first initialize round-trip may legitimately omit it. +fn validate_header_matches_init_body( + headers: &http::HeaderMap, + body_version: &str, + request_id: Option, +) -> Result<(), BoxResponse> { + let Some(header_value) = headers.get(HEADER_MCP_PROTOCOL_VERSION) else { + return Ok(()); + }; + let header_str = header_value.to_str().map_err(|_| { + invalid_request_jsonrpc_response( + request_id.clone(), + "Invalid Request: MCP-Protocol-Version header is not valid UTF-8", + ) + })?; + if header_str != body_version { + tracing::warn!( + header = header_str, + body = body_version, + "rejecting initialize: MCP-Protocol-Version header does not match params.protocolVersion" + ); + return Err(invalid_request_jsonrpc_response( + request_id, + format!( + "Invalid Request: MCP-Protocol-Version header ({header_str}) does not match initialize params.protocolVersion ({body_version})" + ), + )); + } + Ok(()) +} + fn forbidden_response(message: impl Into) -> BoxResponse { Response::builder() .status(http::StatusCode::FORBIDDEN) @@ -1095,9 +1145,15 @@ where None }; if let ClientJsonRpcMessage::Request(req) = &mut message { - if !matches!(req.request, ClientRequest::InitializeRequest(_)) { + let ClientRequest::InitializeRequest(init_req) = &req.request else { return Err(unexpected_message_response("initialize request")); - } + }; + // Reject mismatched MCP-Protocol-Version header before binding the session to anything. + validate_header_matches_init_body( + &part.headers, + init_req.params.protocol_version.as_str(), + Some(req.id.clone()), + )?; // inject request part to extensions req.request.extensions_mut().insert(part); } else { @@ -1163,13 +1219,24 @@ where Ok(response) } } else { - // Stateless mode: validate MCP-Protocol-Version on non-init requests - let is_init = matches!( - &message, - ClientJsonRpcMessage::Request(req) if matches!(req.request, ClientRequest::InitializeRequest(_)) - ); - if !is_init { - validate_protocol_version_header(&part.headers)?; + // Stateless mode: + // - on initialize: the header (if present) must match `params.protocolVersion` + // - on every other request: the header must name a known version. + match &message { + ClientJsonRpcMessage::Request(req) => { + if let ClientRequest::InitializeRequest(init_req) = &req.request { + validate_header_matches_init_body( + &part.headers, + init_req.params.protocol_version.as_str(), + Some(req.id.clone()), + )?; + } else { + validate_protocol_version_header(&part.headers)?; + } + } + _ => { + validate_protocol_version_header(&part.headers)?; + } } let service = self .get_service() diff --git a/crates/rmcp/tests/test_streamable_http_protocol_version.rs b/crates/rmcp/tests/test_streamable_http_protocol_version.rs new file mode 100644 index 00000000..3500266b --- /dev/null +++ b/crates/rmcp/tests/test_streamable_http_protocol_version.rs @@ -0,0 +1,149 @@ +#![cfg(not(feature = "local"))] +//! Regression tests for the `MCP-Protocol-Version` header / initialize body consistency check. +use rmcp::transport::streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, +}; +use tokio_util::sync::CancellationToken; + +mod common; +use common::calculator::Calculator; + +fn init_body(body_version: &str) -> String { + format!( + r#"{{"jsonrpc":"2.0","id":1,"method":"initialize","params":{{"protocolVersion":"{body_version}","capabilities":{{}},"clientInfo":{{"name":"test","version":"1.0"}}}}}}"# + ) +} + +async fn spawn_server( + config: StreamableHttpServerConfig, +) -> (reqwest::Client, String, CancellationToken) { + let ct = config.cancellation_token.clone(); + let service: StreamableHttpService = + StreamableHttpService::new(|| Ok(Calculator::new()), Default::default(), config); + + let router = axum::Router::new().nest_service("/mcp", service); + let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = tcp_listener.local_addr().unwrap(); + + tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(tcp_listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + + let client = reqwest::Client::new(); + let base_url = format!("http://{addr}/mcp"); + (client, base_url, ct) +} + +fn stateless_json_config() -> StreamableHttpServerConfig { + StreamableHttpServerConfig::default() + .with_stateful_mode(false) + .with_json_response(true) + .with_sse_keep_alive(None) + .with_cancellation_token(CancellationToken::new()) +} + +fn stateful_config() -> StreamableHttpServerConfig { + StreamableHttpServerConfig::default() + .with_stateful_mode(true) + .with_sse_keep_alive(None) + .with_cancellation_token(CancellationToken::new()) +} + +async fn post_init( + client: &reqwest::Client, + url: &str, + header: Option<&str>, + body_version: &str, +) -> reqwest::Response { + let mut req = client + .post(url) + .header("Content-Type", "application/json") + .header("Accept", "application/json, text/event-stream") + .body(init_body(body_version)); + if let Some(h) = header { + req = req.header("MCP-Protocol-Version", h); + } + req.send().await.expect("send initialize request") +} + +#[tokio::test] +async fn stateless_init_rejects_when_header_older_than_body() -> anyhow::Result<()> { + let (client, url, ct) = spawn_server(stateless_json_config()).await; + + let response = post_init(&client, &url, Some("2025-03-26"), "2025-11-25").await; + assert_eq!(response.status(), 400); + + let body: serde_json::Value = response.json().await?; + assert_eq!(body["error"]["code"], -32600); + assert!( + body["error"]["message"] + .as_str() + .unwrap_or_default() + .contains("MCP-Protocol-Version"), + "expected error message to mention the header, got: {body}" + ); + + ct.cancel(); + Ok(()) +} + +#[tokio::test] +async fn stateless_init_rejects_when_header_newer_than_body() -> anyhow::Result<()> { + let (client, url, ct) = spawn_server(stateless_json_config()).await; + + let response = post_init(&client, &url, Some("2025-11-25"), "2025-03-26").await; + assert_eq!(response.status(), 400); + + let body: serde_json::Value = response.json().await?; + assert_eq!(body["error"]["code"], -32600); + + ct.cancel(); + Ok(()) +} + +#[tokio::test] +async fn stateless_init_accepts_when_header_matches_body() -> anyhow::Result<()> { + let (client, url, ct) = spawn_server(stateless_json_config()).await; + + let response = post_init(&client, &url, Some("2025-11-25"), "2025-11-25").await; + assert_eq!(response.status(), 200); + + let body: serde_json::Value = response.json().await?; + assert!( + body["result"].is_object(), + "expected an InitializeResult, got: {body}" + ); + + ct.cancel(); + Ok(()) +} + +#[tokio::test] +async fn stateless_init_accepts_when_header_absent() -> anyhow::Result<()> { + let (client, url, ct) = spawn_server(stateless_json_config()).await; + + let response = post_init(&client, &url, None, "2025-11-25").await; + assert_eq!(response.status(), 200); + + ct.cancel(); + Ok(()) +} + +#[tokio::test] +async fn stateful_init_rejects_when_header_mismatches_body() -> anyhow::Result<()> { + let (client, url, ct) = spawn_server(stateful_config()).await; + + let response = post_init(&client, &url, Some("2024-11-05"), "2025-11-25").await; + assert_eq!(response.status(), 400); + + let body: serde_json::Value = response.json().await?; + assert_eq!(body["error"]["code"], -32600); + + ct.cancel(); + Ok(()) +}