diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 793c59bc7..8623079e1 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -22,6 +22,7 @@ CONNECTION_CLOSED, INTERNAL_ERROR, INVALID_PARAMS, + INVALID_REQUEST, REQUEST_TIMEOUT, ErrorData, JSONRPCError, @@ -551,6 +552,23 @@ async def _dispatch_request( on_request: OnRequest, sender_ctx: contextvars.Context | None, ) -> None: + # Key coerced so a stringified `notifications/cancelled` id still correlates. + key = coerce_request_id(req.id) + if key in self._in_flight: + # Duplicate in-flight id. The spec requires request ids to be unique + # within a session while a request is outstanding; a blind overwrite + # here would silently retarget `notifications/cancelled` onto the newer + # request and orphan the older one (see #3060). Reject the duplicate + # instead - ids may still be reused once the earlier request completes. + # Mirrors `direct_dispatcher`'s guard for caller-supplied ids. + logger.warning("duplicate in-flight request id %r; rejecting with INVALID_REQUEST", req.id) + self._spawn( + self._write_error, + req.id, + ErrorData(code=INVALID_REQUEST, message=f"request id {req.id!r} is already in flight"), + sender_ctx=sender_ctx, + ) + return progress_token = progress_token_from_params(req.params) try: transport_ctx = self._transport_builder(metadata) @@ -572,10 +590,7 @@ async def _dispatch_request( _progress_token=progress_token, ) scope = anyio.CancelScope() - # TODO(maxisbey): duplicate ids blind-overwrite (v1/TS parity); revisit - # rejecting with INVALID_REQUEST. Key coerced so a stringified - # `notifications/cancelled` id still correlates. - self._in_flight[coerce_request_id(req.id)] = _InFlight(scope=scope, dctx=dctx) + self._in_flight[key] = _InFlight(scope=scope, dctx=dctx) if req.method in self._inline_methods: # Spawn so `sender_ctx` applies, but park the read loop until the # handler returns - that's the inline ordering guarantee. @@ -699,12 +714,10 @@ async def _handle_request( result = await on_request(dctx, req.method, req.params) finally: # Close the back-channel and drop from `_in_flight`; no checkpoint - # since handler return, so a peer cancel can't interleave. - # Identity guard: don't evict a duplicate id's newer entry. + # since handler return, so a peer cancel can't interleave. Duplicate + # ids are rejected at registration, so this entry is always ours. dctx.close() - key = coerce_request_id(req.id) - if (entry := self._in_flight.get(key)) is not None and entry.dctx is dctx: - del self._in_flight[key] + self._in_flight.pop(coerce_request_id(req.id), None) # A write interrupted by cancellation may still have delivered # (a memory-stream send can hand its item to the receiver and # still raise), so a started answer write counts as sent below: @@ -744,7 +757,7 @@ async def _handle_request( await self._write_error(req.id, ErrorData(code=0, message=str(e))) if self._raise_handler_exceptions: raise - # No `_in_flight` pop here: the inner finally covers every path, and a late pop could evict a reused id. + # No `_in_flight` pop here: the inner finally covers every path. def _allocate_id(self) -> int: self._next_id += 1 diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index e91fc2de2..1b9aee39b 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -14,6 +14,7 @@ CONNECTION_CLOSED, INTERNAL_ERROR, INVALID_PARAMS, + INVALID_REQUEST, REQUEST_TIMEOUT, CallToolRequest, CallToolRequestParams, @@ -2191,27 +2192,69 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio -async def test_completed_handler_does_not_evict_reused_request_id_from_in_flight(): - """A second request reusing an id while the first handler is parked in its response write - keeps its own `_in_flight` entry (a post-write pop would evict it and break peer-cancellation).""" +async def test_duplicate_in_flight_request_id_is_rejected_with_invalid_request(): + """A second inbound request that reuses an id still in flight is rejected with INVALID_REQUEST + rather than blindly overwriting the first's `_in_flight` entry (#3060). The duplicate never + reaches the handler; the original request is untouched and still completes.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) - # buffer=0: the first handler's response write parks until the test receives. - s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) - calls = 0 - second_started = anyio.Event() - second_exited = anyio.Event() + handled: list[str] = [] + started = anyio.Event() + release = anyio.Event() async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - nonlocal calls - calls += 1 - if calls == 1: - return {"first": True} - second_started.set() + handled.append(method) + started.set() + await release.wait() + return {"method": method} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError # no notifications are sent in this test + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + with anyio.fail_after(5): + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="first"))) + await started.wait() + # Duplicate id while the first request is still outstanding. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="second"))) + rejection = await s2c_recv.receive() + assert isinstance(rejection, SessionMessage) + assert isinstance(rejection.message, JSONRPCError) + assert rejection.message.id == 7 + assert rejection.message.error.code == INVALID_REQUEST + # The original request is untouched and still completes normally. + release.set() + resp = await s2c_recv.receive() + assert isinstance(resp, SessionMessage) + assert isinstance(resp.message, JSONRPCResponse) + assert resp.message.result == {"method": "first"} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert handled == ["first"] # the rejected duplicate never reached the handler + + +@pytest.mark.anyio +async def test_duplicate_id_rejection_leaves_original_request_cancellable(): + """Rejecting the duplicate keeps `_in_flight` pointing at the original request, so a later + `notifications/cancelled` still targets it - the duplicate can no longer steal cancellation + away from the older, still-running request (#3060).""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + started = anyio.Event() + exited = anyio.Event() + + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + started.set() try: await anyio.sleep_forever() finally: - second_exited.set() + exited.set() raise NotImplementedError async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: @@ -2221,18 +2264,15 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> async with anyio.create_task_group() as tg: await tg.start(server.run, on_request, on_notify) with anyio.fail_after(5): - await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="a"))) - # First handler is now parked in `_write_result`; reuse its id. - await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="b"))) - await second_started.wait() - resp1 = await s2c_recv.receive() - assert isinstance(resp1, SessionMessage) - assert isinstance(resp1.message, JSONRPCResponse) - assert resp1.message.result == {"first": True} - # Let the first handler task run to completion past the write. - await anyio.wait_all_tasks_blocked() - assert 7 in server._in_flight # pyright: ignore[reportPrivateUsage] - # The surviving entry must still be cancellable. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="slow"))) + await started.wait() + # Duplicate id is rejected; it must not become the cancellation target. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="dup"))) + rejection = await s2c_recv.receive() + assert isinstance(rejection, SessionMessage) + assert isinstance(rejection.message, JSONRPCError) + assert rejection.message.error.code == INVALID_REQUEST + # Cancelling id 7 must reach the original, still-parked request. await c2s_send.send( SessionMessage( message=JSONRPCNotification( @@ -2240,11 +2280,12 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> ) ) ) - resp2 = await s2c_recv.receive() - assert isinstance(resp2, SessionMessage) - assert isinstance(resp2.message, JSONRPCError) - assert resp2.message.error == ErrorData(code=0, message="Request cancelled") - assert second_exited.is_set() + cancelled = await s2c_recv.receive() + assert isinstance(cancelled, SessionMessage) + assert isinstance(cancelled.message, JSONRPCError) + assert cancelled.message.id == 7 + assert cancelled.message.error == ErrorData(code=0, message="Request cancelled") + assert exited.is_set() tg.cancel_scope.cancel() finally: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): @@ -2252,62 +2293,33 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio -async def test_duplicate_request_id_completion_of_first_handler_keeps_second_cancellable(): - """A duplicate inbound id overwrites `_in_flight` (parity with v1/TS); the identity-guarded pop - keeps the first handler's completion from evicting the second's entry and breaking its cancellation.""" +async def test_request_id_is_reusable_after_the_earlier_request_completes(): + """Sequential reuse of an id after the earlier request has completed is still accepted - + deployed clients that send a constant id depend on it; only *in-flight* duplicates are rejected.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) - first_started = anyio.Event() - release_first = anyio.Event() - second_started = anyio.Event() - second_exited = anyio.Event() + calls = 0 async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - if method == "first": - first_started.set() - await release_first.wait() - return {"first": True} - second_started.set() - try: - await anyio.sleep_forever() - finally: - second_exited.set() - raise NotImplementedError + nonlocal calls + calls += 1 + return {"call": calls} async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: - pass # the cancelled notification is teed here; nothing to observe + raise NotImplementedError # no notifications are sent in this test try: async with anyio.create_task_group() as tg: await tg.start(server.run, on_request, on_notify) with anyio.fail_after(5): - await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="first"))) - await first_started.wait() - # Duplicate id: the table entry now belongs to the second request. - await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="second"))) - await second_started.wait() - release_first.set() - resp1 = await s2c_recv.receive() - assert isinstance(resp1, SessionMessage) - assert isinstance(resp1.message, JSONRPCResponse) - assert resp1.message.result == {"first": True} - # Let the first handler task run past its pop entirely. - await anyio.wait_all_tasks_blocked() - assert 7 in server._in_flight # pyright: ignore[reportPrivateUsage] - # The surviving entry must still be cancellable by the peer. - await c2s_send.send( - SessionMessage( - message=JSONRPCNotification( - jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 7} - ) - ) - ) - resp2 = await s2c_recv.receive() - assert isinstance(resp2, SessionMessage) - assert isinstance(resp2.message, JSONRPCError) - assert resp2.message.error == ErrorData(code=0, message="Request cancelled") - assert second_exited.is_set() + for expected in (1, 2): + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="t"))) + resp = await s2c_recv.receive() + assert isinstance(resp, SessionMessage) + assert isinstance(resp.message, JSONRPCResponse) + assert resp.message.result == {"call": expected} + assert 7 not in server._in_flight # pyright: ignore[reportPrivateUsage] tg.cancel_scope.cancel() finally: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv):