diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5824bd39..ba1eddb6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -85,6 +85,7 @@ jobs: cargo semver-checks \ --package rmcp \ --baseline-rev ${{ github.event.pull_request.base.sha }} \ + --release-type minor \ --only-explicit-features \ --features default @@ -97,6 +98,7 @@ jobs: cargo semver-checks \ --package rmcp \ --baseline-rev ${{ github.event.pull_request.base.sha }} \ + --release-type minor \ --only-explicit-features \ --features "$FEATURES" diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index 530e508e..c185696e 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -69,6 +69,10 @@ pub enum ServerInitializeError { #[error("initialize failed: {0}")] InitializeFailed(ErrorData), + #[deprecated( + since = "1.8.0", + note = "Negotiation now falls back to the server-configured version. This variant is never constructed and will be removed in a future major release." + )] #[error("unsupported protocol version: {0}")] UnsupportedProtocolVersion(ProtocolVersion), @@ -155,6 +159,23 @@ where } } +/// Echoes the client-requested version if known; otherwise returns `server_fallback`. +fn negotiate_protocol_version( + client_requested: &ProtocolVersion, + server_fallback: ProtocolVersion, +) -> ProtocolVersion { + if ProtocolVersion::KNOWN_VERSIONS.contains(client_requested) { + client_requested.clone() + } else { + tracing::warn!( + client_requested = %client_requested, + server_fallback = %server_fallback, + "client requested unsupported protocol version; falling back to server default" + ); + server_fallback + } +} + async fn serve_server_with_ct_inner( service: S, transport: T, @@ -227,16 +248,10 @@ where return Err(ServerInitializeError::InitializeFailed(e)); } }; - let peer_protocol_version = peer_info.params.protocol_version.clone(); - let protocol_version = match peer_protocol_version - .partial_cmp(&init_response.protocol_version) - .ok_or(ServerInitializeError::UnsupportedProtocolVersion( - peer_protocol_version, - ))? { - std::cmp::Ordering::Less => peer_info.params.protocol_version.clone(), - _ => init_response.protocol_version, - }; - init_response.protocol_version = protocol_version; + init_response.protocol_version = negotiate_protocol_version( + &peer_info.params.protocol_version, + init_response.protocol_version, + ); transport .send(ServerJsonRpcMessage::response( ServerResult::InitializeResult(init_response), diff --git a/crates/rmcp/tests/test_server_initialization.rs b/crates/rmcp/tests/test_server_initialization.rs index 8cf5c2c4..e2e04896 100644 --- a/crates/rmcp/tests/test_server_initialization.rs +++ b/crates/rmcp/tests/test_server_initialization.rs @@ -4,8 +4,11 @@ mod common; use common::handlers::TestServer; use rmcp::{ - ServiceExt, - model::{ClientJsonRpcMessage, ServerJsonRpcMessage, ServerResult}, + ServerHandler, ServiceExt, + model::{ + ClientJsonRpcMessage, ProtocolVersion, ServerCapabilities, ServerInfo, + ServerJsonRpcMessage, ServerResult, + }, transport::{IntoTransport, Transport}, }; @@ -220,6 +223,82 @@ async fn server_init_buffers_request_before_initialized() { result.unwrap().cancel().await.unwrap(); } +fn init_request_with_version(v: &str) -> ClientJsonRpcMessage { + msg(&format!( + r#"{{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": {{ + "protocolVersion": "{v}", + "capabilities": {{}}, + "clientInfo": {{ "name": "test-client", "version": "0.0.1" }} + }} + }}"# + )) +} + +async fn negotiate_version(handler: H, client_version: &str) -> ProtocolVersion +where + H: ServerHandler + 'static, +{ + let (server_transport, client_transport) = tokio::io::duplex(4096); + let _server = tokio::spawn(async move { handler.serve(server_transport).await }); + let mut client = IntoTransport::::into_transport(client_transport); + + client + .send(init_request_with_version(client_version)) + .await + .unwrap(); + let response = client.receive().await.unwrap(); + let ServerJsonRpcMessage::Response(r) = response else { + panic!("expected initialize response, got {response:?}"); + }; + let ServerResult::InitializeResult(init) = r.result else { + panic!("expected InitializeResult"); + }; + init.protocol_version +} + +#[tokio::test] +async fn server_echoes_client_protocol_version_when_known_old() { + let negotiated = negotiate_version(TestServer::new(), "2024-11-05").await; + assert_eq!(negotiated, ProtocolVersion::V_2024_11_05); +} + +#[tokio::test] +async fn server_echoes_client_protocol_version_when_latest() { + let negotiated = negotiate_version(TestServer::new(), "2025-11-25").await; + assert_eq!(negotiated, ProtocolVersion::LATEST); +} + +#[tokio::test] +async fn server_falls_back_when_client_protocol_version_unknown() { + let negotiated = negotiate_version(TestServer::new(), "2099-99-99").await; + assert_eq!(negotiated, ProtocolVersion::LATEST); +} + +struct PinnedServer; + +impl ServerHandler for PinnedServer { + fn get_info(&self) -> ServerInfo { + ServerInfo::new(ServerCapabilities::builder().build()) + .with_protocol_version(ProtocolVersion::V_2025_06_18) + } +} + +#[tokio::test] +async fn server_pinned_version_does_not_override_known_client_request() { + let negotiated = negotiate_version(PinnedServer, "2025-11-25").await; + assert_eq!(negotiated, ProtocolVersion::LATEST); +} + +#[tokio::test] +async fn server_pinned_version_used_as_fallback_for_unknown_client_request() { + let negotiated = negotiate_version(PinnedServer, "2099-99-99").await; + assert_eq!(negotiated, ProtocolVersion::V_2025_06_18); +} + // Server buffers multiple requests before initialized and processes them in order. #[tokio::test] async fn server_init_buffers_multiple_requests_before_initialized() {