Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
326 changes: 148 additions & 178 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,21 +439,19 @@ func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (*epheme
}
req.Body.Close()
req.Body = io.NopCloser(bytes.NewBuffer(body))
msgs, _, err := readBatch(body)
msg, err := jsonrpc2.DecodeMessage(body)
if err == nil {
for _, msg := range msgs {
if r, ok := msg.(*jsonrpc.Request); ok {
switch r.Method {
case methodInitialize:
hasInitialize = true
case notificationInitialized:
hasInitialized = true
case methodSubscriptionsListen:
isSubscriptionsListen = true
}
if protocolVersion >= protocolVersion20260728 {
usesNewProtocol = true
}
if r, ok := msg.(*jsonrpc.Request); ok {
switch r.Method {
case methodInitialize:
hasInitialize = true
case notificationInitialized:
hasInitialized = true
case methodSubscriptionsListen:
isSubscriptionsListen = true
}
if protocolVersion >= protocolVersion20260728 {
usesNewProtocol = true
}
}
}
Expand Down Expand Up @@ -916,8 +914,6 @@ type stream struct {
// collected here until the stream is complete, at which point they are
// flushed as a single JSON response. Note that the non-nilness of this field
// is significant, as it signals the expected content type.
//
// Note: if we remove support for batching, this could just be a bool.
pendingJSONMessages []json.RawMessage

// w is the HTTP response writer for this stream. A non-nil w indicates
Expand All @@ -942,9 +938,6 @@ type stream struct {
// requests is the set of unanswered incoming requests for the stream.
//
// Requests are removed when their response has been received.
// In practice, there is only one request, but in the 2025-03-26 version of
// the spec and earlier there was a concept of batching, in which POST
// payloads could hold multiple requests or responses.
requests map[jsonrpc.ID]struct{}

// isListen reports whether this stream was opened by a
Expand Down Expand Up @@ -1381,10 +1374,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
http.Error(w, "POST requires a non-empty body", http.StatusBadRequest)
return
}
// TODO(#674): once we've documented the support matrix for 2025-03-26 and
// earlier, drop support for matching entirely; that will simplify this
// logic.
incoming, isBatch, err := readBatch(body)
incoming, err := jsonrpc2.DecodeMessage(body)
if err != nil {
http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
return
Expand All @@ -1395,160 +1385,144 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
protocolVersion = protocolVersion20250326
}

if isBatch && protocolVersion >= protocolVersion20250618 {
http.Error(w, fmt.Sprintf("JSON-RPC batching is not supported in %s and later (request version: %s)", protocolVersion20250618, protocolVersion), http.StatusBadRequest)
return
}

// TODO(rfindley): no tests fail if we reject batch JSON requests entirely.
// We need to test this with older protocol versions.
// if isBatch && c.jsonResponse {
// http.Error(w, "server does not support batch requests", http.StatusBadRequest)
// return
// }

calls := make(map[jsonrpc.ID]struct{})
tokenInfo := auth.TokenInfoFromContext(req.Context())
isInitialize := false
isSubscriptionsListen := false
var initializeProtocolVersion string
for _, msg := range incoming {
if jreq, ok := msg.(*jsonrpc.Request); ok {
// Preemptively check that this is a valid request, so that we can fail
// the HTTP request. If we didn't do this, a request with a bad method or
// missing ID could be silently swallowed.
// Use the server's receiving method infos (which include any custom
// methods registered via AddReceivingCustomMethod) when available;
// fall back to the standard methods otherwise, e.g. in tests that
// exercise streamableServerConn directly without a server.
methodInfos := serverMethodInfos
if c.server != nil {
methodInfos = c.server.receivingMethodInfos()
if jreq, ok := incoming.(*jsonrpc.Request); ok {
// Preemptively check that this is a valid request, so that we can fail
// the HTTP request. If we didn't do this, a request with a bad method or
// missing ID could be silently swallowed.
// Use the server's receiving method infos (which include any custom
// methods registered via AddReceivingCustomMethod) when available;
// fall back to the standard methods otherwise, e.g. in tests that
// exercise streamableServerConn directly without a server.
methodInfos := serverMethodInfos
if c.server != nil {
methodInfos = c.server.receivingMethodInfos()
}
if _, err := checkRequest(jreq, methodInfos); err != nil {
if protocolVersion >= protocolVersion20260728 && errors.Is(err, jsonrpc2.ErrNotHandled) && jreq.IsCall() {
writeJSONRPCError(w, http.StatusNotFound, jreq.ID, &jsonrpc.Error{
Code: jsonrpc.CodeMethodNotFound,
Message: err.Error(),
})
return
}
if _, err := checkRequest(jreq, methodInfos); err != nil {
if protocolVersion >= protocolVersion20260728 && errors.Is(err, jsonrpc2.ErrNotHandled) && jreq.IsCall() {
writeJSONRPCError(w, http.StatusNotFound, jreq.ID, &jsonrpc.Error{
Code: jsonrpc.CodeMethodNotFound,
Message: err.Error(),
})
return
}
http.Error(w, err.Error(), http.StatusBadRequest)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if jreq.Method == methodInitialize {
isInitialize = true
// Extract the protocol version from InitializeParams.
var params InitializeParams
if err := internaljson.Unmarshal(jreq.Params, &params); err == nil {
initializeProtocolVersion = params.ProtocolVersion
}
}
if jreq.Method == methodSubscriptionsListen {
isSubscriptionsListen = true
}
// SEP-2575: requests carrying `_meta.protocolVersion` require the
// Mcp-Protocol-Version HTTP header to be present and to match the
// per-request `_meta.protocolVersion` value.
// The new (>= 2026-07-28) protocol is supported on the HTTP transport
// only when [StreamableHTTPOptions.Stateless] is true.
//
// TODO: this validation can be moved within validateMcpHeaders.
var metaVersion string
if meta := extractRequestMeta(jreq.Params); meta != nil {
metaVersion, _ = meta[MetaKeyProtocolVersion].(string)
}
if protocolVersion >= protocolVersion20260728 || metaVersion != "" {
// Extract again the protcol version from the context to see what the client
// is advertising in the Mcp-Protocol-Version HTTP header.
headerVersion := protocolVersionFromContext(req.Context())
// server/discover is exempt from the stateful
// rejection as it should learn about the supported protocols from the
// DiscoverResult response.
if !c.stateless && jreq.Method != methodDiscover {
http.Error(w, fmt.Sprintf(
"Bad Request: protocol version %q is only supported on stateless HTTP servers (set StreamableHTTPOptions.Stateless = true)",
protocolVersion),
http.StatusBadRequest)
return
}
if jreq.Method == methodInitialize {
isInitialize = true
// Extract the protocol version from InitializeParams.
var params InitializeParams
if err := internaljson.Unmarshal(jreq.Params, &params); err == nil {
initializeProtocolVersion = params.ProtocolVersion
}
if headerVersion == "" {
writeJSONRPCError(w, http.StatusBadRequest, jreq.ID, &jsonrpc.Error{
Code: CodeHeaderMismatch,
Message: fmt.Sprintf(
"%s header is required for requests carrying %q",
protocolVersionHeader, MetaKeyProtocolVersion),
})
return
}
if jreq.Method == methodSubscriptionsListen {
isSubscriptionsListen = true
if metaVersion == "" {
writeJSONRPCError(w, http.StatusBadRequest, jreq.ID, &jsonrpc.Error{
Code: jsonrpc.CodeInvalidParams,
Message: fmt.Sprintf(
"missing or invalid _meta field %q",
MetaKeyProtocolVersion),
})
return
}
// SEP-2575: requests carrying `_meta.protocolVersion` require the
// Mcp-Protocol-Version HTTP header to be present and to match the
// per-request `_meta.protocolVersion` value.
// The new (>= 2026-07-28) protocol is supported on the HTTP transport
// only when [StreamableHTTPOptions.Stateless] is true.
//
// TODO: this validation can be moved within validateMcpHeaders.
var metaVersion string
if meta := extractRequestMeta(jreq.Params); meta != nil {
metaVersion, _ = meta[MetaKeyProtocolVersion].(string)
if headerVersion != metaVersion {
writeJSONRPCError(w, http.StatusBadRequest, jreq.ID, &jsonrpc.Error{
Code: CodeHeaderMismatch,
Message: fmt.Sprintf(
"%s header %q does not match request %s %q",
protocolVersionHeader, headerVersion,
MetaKeyProtocolVersion, metaVersion),
})
return
}
if protocolVersion >= protocolVersion20260728 || metaVersion != "" {
// Extract again the protcol version from the context to see what the client
// is advertising in the Mcp-Protocol-Version HTTP header.
headerVersion := protocolVersionFromContext(req.Context())
// server/discover is exempt from the stateful
// rejection as it should learn about the supported protocols from the
// DiscoverResult response.
if !c.stateless && jreq.Method != methodDiscover {
http.Error(w, fmt.Sprintf(
"Bad Request: protocol version %q is only supported on stateless HTTP servers (set StreamableHTTPOptions.Stateless = true)",
protocolVersion),
http.StatusBadRequest)
return
}
if headerVersion == "" {
writeJSONRPCError(w, http.StatusBadRequest, jreq.ID, &jsonrpc.Error{
Code: CodeHeaderMismatch,
Message: fmt.Sprintf(
"%s header is required for requests carrying %q",
protocolVersionHeader, MetaKeyProtocolVersion),
})
return
}
if metaVersion == "" {
writeJSONRPCError(w, http.StatusBadRequest, jreq.ID, &jsonrpc.Error{
Code: jsonrpc.CodeInvalidParams,
Message: fmt.Sprintf(
"missing or invalid _meta field %q",
MetaKeyProtocolVersion),
})
}
// Include metadata for all requests (including notifications).
jreq.Extra = &RequestExtra{
TokenInfo: tokenInfo,
Header: req.Header,
}
if jreq.IsCall() {
calls[jreq.ID] = struct{}{}
// See the doc for CloseSSEStream: allow the request handler to
// explicitly close the ongoing stream.
jreq.Extra.(*RequestExtra).CloseSSEStream = func(args CloseSSEStreamArgs) {
// This mechanism was designed to trigger client reconnection with
// Last-Event-ID for server-initiated disconnect scenarios. It is
// deprecated in protocol version 2026-07-28.
if protocolVersion >= protocolVersion20260728 {
return
}
if headerVersion != metaVersion {
writeJSONRPCError(w, http.StatusBadRequest, jreq.ID, &jsonrpc.Error{
Code: CodeHeaderMismatch,
Message: fmt.Sprintf(
"%s header %q does not match request %s %q",
protocolVersionHeader, headerVersion,
MetaKeyProtocolVersion, metaVersion),
})
return
c.mu.Lock()
streamID, ok := c.requestStreams[jreq.ID]
var stream *stream
if ok {
stream = c.streams[streamID]
}
}
// Include metadata for all requests (including notifications).
jreq.Extra = &RequestExtra{
TokenInfo: tokenInfo,
Header: req.Header,
}
if jreq.IsCall() {
calls[jreq.ID] = struct{}{}
// See the doc for CloseSSEStream: allow the request handler to
// explicitly close the ongoing stream.
jreq.Extra.(*RequestExtra).CloseSSEStream = func(args CloseSSEStreamArgs) {
// This mechanism was designed to trigger client reconnection with
// Last-Event-ID for server-initiated disconnect scenarios. It is
// deprecated in protocol version 2026-07-28.
if protocolVersion >= protocolVersion20260728 {
return
}
c.mu.Lock()
streamID, ok := c.requestStreams[jreq.ID]
var stream *stream
if ok {
stream = c.streams[streamID]
}
c.mu.Unlock()

if stream != nil {
stream.close(args.RetryAfter)
}
c.mu.Unlock()

if stream != nil {
stream.close(args.RetryAfter)
}
}
}
}

// Validate MCP standard headers (Mcp-Method, Mcp-Name, Mcp-Param-*)
if !isBatch && len(incoming) == 1 {
if err := validateMcpHeaders(req.Header, incoming[0], c.toolLookup); err != nil {
resp := &jsonrpc.Response{
Error: jsonrpc2.NewError(CodeHeaderMismatch, err.Error()),
}
if jreq, ok := incoming[0].(*jsonrpc.Request); ok {
resp.ID = jreq.ID
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
if data, err := jsonrpc2.EncodeMessage(resp); err == nil {
w.Write(data)
}
return
if err := validateMcpHeaders(req.Header, incoming, c.toolLookup); err != nil {
resp := &jsonrpc.Response{
Error: jsonrpc2.NewError(CodeHeaderMismatch, err.Error()),
}
if jreq, ok := incoming.(*jsonrpc.Request); ok {
resp.ID = jreq.ID
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
if data, err := jsonrpc2.EncodeMessage(resp); err == nil {
w.Write(data)
}
return
}

// The prime and close events were added in protocol version 2025-11-25 (SEP-1699).
Expand All @@ -1567,15 +1541,13 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
//
// [§2.1.4]: https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#sending-messages-to-the-server
if len(calls) == 0 {
for _, msg := range incoming {
select {
case c.incoming <- msg:
case <-c.done:
// The session is closing. Since we haven't yet written any data to the
// response, we can signal to the client that the session is gone.
http.Error(w, "session is closing", http.StatusNotFound)
return
}
select {
case c.incoming <- incoming:
case <-c.done:
// The session is closing. Since we haven't yet written any data to the
// response, we can signal to the client that the session is gone.
http.Error(w, "session is closing", http.StatusNotFound)
return
}
w.WriteHeader(http.StatusAccepted)
return
Expand Down Expand Up @@ -1671,19 +1643,17 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
}

// Publish incoming messages.
for _, msg := range incoming {
select {
case c.incoming <- msg:
// Note: don't select on req.Context().Done() here, since we've already
// received the requests and may have already published a response message
// or notification. The client could resume the stream.
//
// In fact, this send could be in a separate goroutine.
case <-c.done:
// Session closed: we don't know if any data has been written, so it's
// too late to write a status code here.
return
}
select {
case c.incoming <- incoming:
// Note: don't select on req.Context().Done() here, since we've already
// received the requests and may have already published a response message
// or notification. The client could resume the stream.
//
// In fact, this send could be in a separate goroutine.
case <-c.done:
// Session closed: we don't know if any data has been written, so it's
// too late to write a status code here.
return
}

c.hangResponse(req.Context(), done)
Expand Down
Loading
Loading