Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ name = "test_streamable_http_json_response"
required-features = ["server", "client", "transport-streamable-http-server", "reqwest"]
path = "tests/test_streamable_http_json_response.rs"

[[test]]
name = "test_streamable_http_protocol_version"
required-features = ["server", "client", "transport-streamable-http-server", "reqwest"]
path = "tests/test_streamable_http_protocol_version.rs"

[[test]]
name = "test_streamable_http_4xx_error_body"
required-features = ["transport-streamable-http-client", "transport-streamable-http-client-reqwest"]
Expand Down
91 changes: 79 additions & 12 deletions crates/rmcp/src/transport/streamable_http_server/tower.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use std::{collections::HashMap, convert::Infallible, fmt::Display, sync::Arc, time::Duration};
use std::{
borrow::Cow, collections::HashMap, convert::Infallible, fmt::Display, sync::Arc, time::Duration,
};

use bytes::Bytes;
use futures::{StreamExt, future::BoxFuture};
Expand All @@ -14,8 +16,8 @@ use super::session::{
use crate::{
RoleServer,
model::{
ClientJsonRpcMessage, ClientNotification, ClientRequest, GetExtensions, InitializeRequest,
InitializedNotification, ProtocolVersion,
ClientJsonRpcMessage, ClientNotification, ClientRequest, ErrorData, GetExtensions,
InitializeRequest, InitializedNotification, JsonRpcError, ProtocolVersion, RequestId,
},
serve_server,
service::serve_directly,
Expand Down Expand Up @@ -209,6 +211,54 @@ fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), Box
Ok(())
}

fn invalid_request_jsonrpc_response(
id: Option<RequestId>,
message: impl Into<Cow<'static, str>>,
) -> BoxResponse {
let err = JsonRpcError::new(id, ErrorData::invalid_request(message, None));
let body = serde_json::to_vec(&err).expect("serialize JsonRpcError");
Response::builder()
.status(http::StatusCode::BAD_REQUEST)
.header(http::header::CONTENT_TYPE, JSON_MIME_TYPE)
.body(Full::new(Bytes::from(body)).boxed())
.expect("valid response")
}

#[expect(
clippy::result_large_err,
reason = "BoxResponse is intentionally large; matches other handlers in this file"
)]
/// Absent header is allowed; the first initialize round-trip may legitimately omit it.
fn validate_header_matches_init_body(
headers: &http::HeaderMap,
body_version: &str,
request_id: Option<RequestId>,
) -> Result<(), BoxResponse> {
let Some(header_value) = headers.get(HEADER_MCP_PROTOCOL_VERSION) else {
return Ok(());
};
let header_str = header_value.to_str().map_err(|_| {
invalid_request_jsonrpc_response(
request_id.clone(),
"Invalid Request: MCP-Protocol-Version header is not valid UTF-8",
)
})?;
if header_str != body_version {
tracing::warn!(
header = header_str,
body = body_version,
"rejecting initialize: MCP-Protocol-Version header does not match params.protocolVersion"
);
return Err(invalid_request_jsonrpc_response(
request_id,
format!(
"Invalid Request: MCP-Protocol-Version header ({header_str}) does not match initialize params.protocolVersion ({body_version})"
),
));
}
Ok(())
}

fn forbidden_response(message: impl Into<String>) -> BoxResponse {
Response::builder()
.status(http::StatusCode::FORBIDDEN)
Expand Down Expand Up @@ -1095,9 +1145,15 @@ where
None
};
if let ClientJsonRpcMessage::Request(req) = &mut message {
if !matches!(req.request, ClientRequest::InitializeRequest(_)) {
let ClientRequest::InitializeRequest(init_req) = &req.request else {
return Err(unexpected_message_response("initialize request"));
}
};
// Reject mismatched MCP-Protocol-Version header before binding the session to anything.
validate_header_matches_init_body(
&part.headers,
init_req.params.protocol_version.as_str(),
Some(req.id.clone()),
)?;
// inject request part to extensions
req.request.extensions_mut().insert(part);
} else {
Expand Down Expand Up @@ -1163,13 +1219,24 @@ where
Ok(response)
}
} else {
// Stateless mode: validate MCP-Protocol-Version on non-init requests
let is_init = matches!(
&message,
ClientJsonRpcMessage::Request(req) if matches!(req.request, ClientRequest::InitializeRequest(_))
);
if !is_init {
validate_protocol_version_header(&part.headers)?;
// Stateless mode:
// - on initialize: the header (if present) must match `params.protocolVersion`
// - on every other request: the header must name a known version.
match &message {
ClientJsonRpcMessage::Request(req) => {
if let ClientRequest::InitializeRequest(init_req) = &req.request {
validate_header_matches_init_body(
&part.headers,
init_req.params.protocol_version.as_str(),
Some(req.id.clone()),
)?;
} else {
validate_protocol_version_header(&part.headers)?;
}
}
_ => {
validate_protocol_version_header(&part.headers)?;
}
}
let service = self
.get_service()
Expand Down
149 changes: 149 additions & 0 deletions crates/rmcp/tests/test_streamable_http_protocol_version.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#![cfg(not(feature = "local"))]
//! Regression tests for the `MCP-Protocol-Version` header / initialize body consistency check.
use rmcp::transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
};
use tokio_util::sync::CancellationToken;

mod common;
use common::calculator::Calculator;

fn init_body(body_version: &str) -> String {
format!(
r#"{{"jsonrpc":"2.0","id":1,"method":"initialize","params":{{"protocolVersion":"{body_version}","capabilities":{{}},"clientInfo":{{"name":"test","version":"1.0"}}}}}}"#
)
}

async fn spawn_server(
config: StreamableHttpServerConfig,
) -> (reqwest::Client, String, CancellationToken) {
let ct = config.cancellation_token.clone();
let service: StreamableHttpService<Calculator, LocalSessionManager> =
StreamableHttpService::new(|| Ok(Calculator::new()), Default::default(), config);

let router = axum::Router::new().nest_service("/mcp", service);
let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = tcp_listener.local_addr().unwrap();

tokio::spawn({
let ct = ct.clone();
async move {
let _ = axum::serve(tcp_listener, router)
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
.await;
}
});

let client = reqwest::Client::new();
let base_url = format!("http://{addr}/mcp");
(client, base_url, ct)
}

fn stateless_json_config() -> StreamableHttpServerConfig {
StreamableHttpServerConfig::default()
.with_stateful_mode(false)
.with_json_response(true)
.with_sse_keep_alive(None)
.with_cancellation_token(CancellationToken::new())
}

fn stateful_config() -> StreamableHttpServerConfig {
StreamableHttpServerConfig::default()
.with_stateful_mode(true)
.with_sse_keep_alive(None)
.with_cancellation_token(CancellationToken::new())
}

async fn post_init(
client: &reqwest::Client,
url: &str,
header: Option<&str>,
body_version: &str,
) -> reqwest::Response {
let mut req = client
.post(url)
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(init_body(body_version));
if let Some(h) = header {
req = req.header("MCP-Protocol-Version", h);
}
req.send().await.expect("send initialize request")
}

#[tokio::test]
async fn stateless_init_rejects_when_header_older_than_body() -> anyhow::Result<()> {
let (client, url, ct) = spawn_server(stateless_json_config()).await;

let response = post_init(&client, &url, Some("2025-03-26"), "2025-11-25").await;
assert_eq!(response.status(), 400);

let body: serde_json::Value = response.json().await?;
assert_eq!(body["error"]["code"], -32600);
assert!(
body["error"]["message"]
.as_str()
.unwrap_or_default()
.contains("MCP-Protocol-Version"),
"expected error message to mention the header, got: {body}"
);

ct.cancel();
Ok(())
}

#[tokio::test]
async fn stateless_init_rejects_when_header_newer_than_body() -> anyhow::Result<()> {
let (client, url, ct) = spawn_server(stateless_json_config()).await;

let response = post_init(&client, &url, Some("2025-11-25"), "2025-03-26").await;
assert_eq!(response.status(), 400);

let body: serde_json::Value = response.json().await?;
assert_eq!(body["error"]["code"], -32600);

ct.cancel();
Ok(())
}

#[tokio::test]
async fn stateless_init_accepts_when_header_matches_body() -> anyhow::Result<()> {
let (client, url, ct) = spawn_server(stateless_json_config()).await;

let response = post_init(&client, &url, Some("2025-11-25"), "2025-11-25").await;
assert_eq!(response.status(), 200);

let body: serde_json::Value = response.json().await?;
assert!(
body["result"].is_object(),
"expected an InitializeResult, got: {body}"
);

ct.cancel();
Ok(())
}

#[tokio::test]
async fn stateless_init_accepts_when_header_absent() -> anyhow::Result<()> {
let (client, url, ct) = spawn_server(stateless_json_config()).await;

let response = post_init(&client, &url, None, "2025-11-25").await;
assert_eq!(response.status(), 200);

ct.cancel();
Ok(())
}

#[tokio::test]
async fn stateful_init_rejects_when_header_mismatches_body() -> anyhow::Result<()> {
let (client, url, ct) = spawn_server(stateful_config()).await;

let response = post_init(&client, &url, Some("2024-11-05"), "2025-11-25").await;
assert_eq!(response.status(), 400);

let body: serde_json::Value = response.json().await?;
assert_eq!(body["error"]["code"], -32600);

ct.cancel();
Ok(())
}