From d619a3bdc69da7283e478fa3af3cf9fff1458f92 Mon Sep 17 00:00:00 2001 From: Sambhav Kothari Date: Sat, 27 Jun 2026 20:33:24 +0100 Subject: [PATCH] mcp: bidirectional custom methods, CustomMethod type, and extension registry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Builds on the custom JSON-RPC method support in #956 in three ways: 1. **Bidirectionality** — the server can now send requests to the client, and the client can register handlers for them, mirroring the existing client→server direction: - AddServerSendingCustomMethod / ServerCallCustomMethod - AddClientReceivingCustomMethod - CustomMethod.RegisterServerSending / .ServerCall - CustomMethod.RegisterClientReceiving 2. **CustomMethod[P, R, T]** — a phantom-type wrapper that captures the method name and parameter/result types once at package level, so call sites never repeat generic arguments or method-name strings: var Method = mcp.NewCustomMethod[*Params, *Result]("acme/method") result, err := Method.Call(ctx, cs, &Params{...}) 3. **Extension registry** — a mechanism for libraries to auto-wire custom methods into every Server/Client without requiring callers to do any manual setup: - RegisterExtension (global, typically called from init) - ServerOptions.Extensions / ClientOptions.Extensions (per-instance) Global extensions are applied first; per-instance extensions after (last writer wins on name collision). The example is restructured to demonstrate the pattern: a latinext sub-package is the "extension author" (defines types, registers via init(), exports a Translate() helper); main.go is the "consumer" (just imports latinext, no generics or method-name strings visible). Co-Authored-By: Claude Sonnet 4.6 --- .../server/custom-method/latinext/latin.go | 87 +++++++++++ examples/server/custom-method/main.go | 80 ++-------- mcp/client.go | 51 +++++- mcp/extension.go | 145 ++++++++++++++++++ mcp/server.go | 80 +++++++++- 5 files changed, 373 insertions(+), 70 deletions(-) create mode 100644 examples/server/custom-method/latinext/latin.go create mode 100644 mcp/extension.go diff --git a/examples/server/custom-method/latinext/latin.go b/examples/server/custom-method/latinext/latin.go new file mode 100644 index 00000000..b619a032 --- /dev/null +++ b/examples/server/custom-method/latinext/latin.go @@ -0,0 +1,87 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package latinext is an example MCP extension that adds a "latin/translate" +// custom JSON-RPC method. It demonstrates the extension-author pattern: types, +// the CustomMethod variable, and the init() registration are all defined here +// so that importers get everything wired up automatically. +package latinext + +import ( + "context" + "fmt" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// TranslateParams are the parameters for the latin/translate method. +type TranslateParams struct { + mcp.ParamsBase + Text string `json:"text"` +} + +// TranslateResult is the result of the latin/translate method. +type TranslateResult struct { + mcp.ResultBase + Latin string `json:"latin"` +} + +// Method captures the method name and types once. Extension consumers call +// Translate() rather than this directly. +var Method = mcp.NewCustomMethod[*TranslateParams, *TranslateResult]("latin/translate") + +func init() { + mcp.RegisterExtension(mcp.Extension{ + Server: func(s *mcp.Server) error { + return Method.RegisterServerReceiving(s, DefaultHandler) + }, + Client: func(c *mcp.Client) error { + return Method.RegisterClientSending(c) + }, + }) +} + +// Translate calls the latin/translate method on the server via cs. +// This is the one-liner that extension consumers use — no generics, no method +// name strings. +func Translate(ctx context.Context, cs *mcp.ClientSession, text string) (*TranslateResult, error) { + return Method.Call(ctx, cs, &TranslateParams{Text: text}) +} + +// DefaultHandler is the reference server-side implementation. It can be +// overridden per-server by calling Method.RegisterServer(server, myHandler). +func DefaultHandler(_ context.Context, _ *mcp.ServerSession, params *TranslateParams) (*TranslateResult, error) { + key := strings.ToLower(strings.TrimSpace(params.Text)) + latin, ok := translations[key] + if !ok { + latin = fmt.Sprintf("[unknown: %q]", params.Text) + } + return &TranslateResult{Latin: latin}, nil +} + +var translations = map[string]string{ + "hello": "salve", + "goodbye": "vale", + "thank you": "gratias tibi ago", + "how are you": "quid agis", + "good morning": "bonum mane", + "good night": "bonam noctem", + "friend": "amicus", + "water": "aqua", + "love": "amor", + "war": "bellum", + "peace": "pax", + "truth": "veritas", + "light": "lux", + "time": "tempus", + "life": "vita", + "death": "mors", + "star": "stella", + "earth": "terra", + "sea": "mare", + "the die is cast": "alea iacta est", + "i came i saw i conquered": "veni vidi vici", + "seize the day": "carpe diem", +} diff --git a/examples/server/custom-method/main.go b/examples/server/custom-method/main.go index 6b2861fe..e9590ede 100644 --- a/examples/server/custom-method/main.go +++ b/examples/server/custom-method/main.go @@ -2,74 +2,31 @@ // Use of this source code is governed by the license // that can be found in the LICENSE file. -// The custom-method example demonstrates registering and calling a custom -// JSON-RPC method that is not part of the standard MCP spec. +// The custom-method example demonstrates the extension-author / extension-consumer +// split for custom JSON-RPC methods. // -// The server registers a "latin/translate" method that translates simple -// English phrases into Latin. A client connects over an in-memory transport, -// calls the custom method, and prints the result. +// The latinext sub-package is the "extension author": it defines the types, +// registers the method via init(), and exposes a domain-specific Translate() +// helper. This file is the "extension consumer": importing latinext is all +// that's needed to wire up both sides. package main import ( "context" "fmt" "log" - "strings" + "github.com/modelcontextprotocol/go-sdk/examples/server/custom-method/latinext" "github.com/modelcontextprotocol/go-sdk/mcp" ) -type TranslateParams struct { - mcp.ParamsBase - Text string `json:"text"` -} - -type TranslateResult struct { - mcp.ResultBase - Latin string `json:"latin"` -} - -var translations = map[string]string{ - "hello": "salve", - "goodbye": "vale", - "thank you": "gratias tibi ago", - "how are you": "quid agis", - "good morning": "bonum mane", - "good night": "bonam noctem", - "friend": "amicus", - "water": "aqua", - "love": "amor", - "war": "bellum", - "peace": "pax", - "truth": "veritas", - "light": "lux", - "time": "tempus", - "life": "vita", - "death": "mors", - "star": "stella", - "earth": "terra", - "sea": "mare", - "the die is cast": "alea iacta est", - "i came i saw i conquered": "veni vidi vici", - "seize the day": "carpe diem", -} - func main() { ctx := context.Background() + // NewServer and NewClient automatically apply all extensions registered via + // init() — including latinext's handler and sending registration. server := mcp.NewServer(&mcp.Implementation{Name: "latin-server", Version: "v1.0.0"}, nil) - - if err := mcp.AddReceivingCustomMethod(server, "latin/translate", - func(ctx context.Context, ss *mcp.ServerSession, params *TranslateParams) (*TranslateResult, error) { - key := strings.ToLower(strings.TrimSpace(params.Text)) - latin, ok := translations[key] - if !ok { - latin = fmt.Sprintf("[unknown: %q — try: %s]", params.Text, knownPhrases()) - } - return &TranslateResult{Latin: latin}, nil - }); err != nil { - log.Fatal(err) - } + client := mcp.NewClient(&mcp.Implementation{Name: "latin-client", Version: "v1.0.0"}, nil) ct, st := mcp.NewInMemoryTransports() @@ -79,32 +36,19 @@ func main() { } defer ss.Close() - client := mcp.NewClient(&mcp.Implementation{Name: "latin-client", Version: "v1.0.0"}, nil) - if err := mcp.AddSendingCustomMethod[*TranslateParams, *TranslateResult](client, "latin/translate"); err != nil { - log.Fatal(err) - } - cs, err := client.Connect(ctx, ct, nil) if err != nil { log.Fatal(err) } defer cs.Close() + // Call the custom method — no generics, no method-name strings. phrases := []string{"Hello", "Seize the day", "Peace", "Truth", "I came I saw I conquered"} for _, phrase := range phrases { - result, err := mcp.CallCustomMethod[*TranslateParams, *TranslateResult]( - ctx, cs, "latin/translate", &TranslateParams{Text: phrase}) + result, err := latinext.Translate(ctx, cs, phrase) if err != nil { log.Fatalf("translate %q: %v", phrase, err) } fmt.Printf("%-35s → %s\n", phrase, result.Latin) } } - -func knownPhrases() string { - phrases := make([]string, 0, len(translations)) - for k := range translations { - phrases = append(phrases, fmt.Sprintf("%q", k)) - } - return strings.Join(phrases, ", ") -} diff --git a/mcp/client.go b/mcp/client.go index 0a6c906f..9d633c84 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -39,6 +39,11 @@ type Client struct { // serverMethodInfos) plus any custom methods registered via // [AddSendingCustomMethod]. sendMethods map[string]methodInfo + // receiveMethods is the merged map of methods this client may receive from + // a server: it always contains the standard client methods (from + // clientMethodInfos) plus any custom methods registered via + // [AddClientReceivingCustomMethod]. + receiveMethods map[string]methodInfo } // NewClient creates a new [Client]. @@ -68,6 +73,9 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { sendMethods := make(map[string]methodInfo, len(serverMethodInfos)) maps.Copy(sendMethods, serverMethodInfos) + receiveMethods := make(map[string]methodInfo, len(clientMethodInfos)) + maps.Copy(receiveMethods, clientMethodInfos) + c := &Client{ impl: impl, opts: opts, @@ -75,10 +83,13 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], sendMethods: sendMethods, + receiveMethods: receiveMethods, } if opts.MultiRoundTrip == nil || !opts.MultiRoundTrip.Disabled { c.AddSendingMiddleware(clientMultiRoundTripMiddleware()) } + applyExtensionsToClient(c) + runExtensions(opts.Extensions, func(e Extension) func(*Client) error { return e.Client }, c) return c } @@ -204,6 +215,10 @@ type ClientOptions struct { // reset" guidance, letting a transient miss pass without tearing down an // otherwise live session. Has no effect unless KeepAlive is non-zero. KeepAliveFailureThreshold int + // Extensions are applied to the client during [NewClient], after any + // globally registered extensions (see [RegisterExtension]). Per-client + // extensions override global ones for the same method names. + Extensions []Extension } // toolContextKeyType is the context key type for passing tool definitions @@ -1170,7 +1185,9 @@ func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { } func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { - return clientMethodInfos + cs.client.mu.Lock() + defer cs.client.mu.Unlock() + return cs.client.receiveMethods } func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { @@ -1674,3 +1691,35 @@ func CallCustomMethod[P paramsPtr[PT], R Result, PT any]( Params: params, }) } + +// AddClientReceivingCustomMethod registers a handler for a custom +// (non-standard) JSON-RPC method that the client may receive from a server. +// +// When a server sends a request with the given method name, the params will be +// unmarshaled into P, the handler will be called, and the returned R will be +// marshaled as the JSON-RPC result. +// +// P and R must implement [Params] and [Result] respectively, which is most +// easily done by embedding [ParamsBase] and [ResultBase]. +// +// AddClientReceivingCustomMethod returns an error if method is the name of a +// standard MCP method. Registering the same custom method twice replaces the +// previous handler. +func AddClientReceivingCustomMethod[P paramsPtr[T], R Result, T any]( + c *Client, + method string, + handler func(ctx context.Context, cs *ClientSession, params P) (R, error), +) error { + if _, ok := clientMethodInfos[method]; ok { + return fmt.Errorf("mcp: AddClientReceivingCustomMethod: %q shadows a standard MCP method", method) + } + + typed := typedClientMethodHandler[P, R](func(ctx context.Context, req *ClientRequest[P]) (R, error) { + return handler(ctx, req.Session, req.Params) + }) + + c.mu.Lock() + defer c.mu.Unlock() + c.receiveMethods[method] = newClientMethodInfo(typed, missingParamsOK) + return nil +} diff --git a/mcp/extension.go b/mcp/extension.go new file mode 100644 index 00000000..a92c0ce1 --- /dev/null +++ b/mcp/extension.go @@ -0,0 +1,145 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "fmt" + "sync" +) + +// CustomMethod captures the method name and parameter/result types for a +// custom JSON-RPC method. Extension authors define a package-level var using +// [NewCustomMethod]; consumers call the resulting methods without ever writing +// generic type parameters or method-name strings. +// +// For the client-to-server direction: +// +// var Method = mcp.NewCustomMethod[*MyParams, *MyResult]("acme/method") +// +// // Extension author wires server and client once: +// Method.RegisterServerReceiving(server, MyHandler) +// Method.RegisterClientSending(client) +// +// // Consumer calls with no generics visible: +// result, err := Method.Call(ctx, cs, &MyParams{...}) +// +// For the server-to-client direction use [CustomMethod.RegisterServerSending], +// [CustomMethod.RegisterClientReceiving], and [CustomMethod.ServerCall]. +// +// P, R, and T are phantom type parameters — they are not stored in the struct +// but thread through to the wrapped generic functions so call sites stay +// type-safe without repeating type arguments. +type CustomMethod[P paramsPtr[T], R Result, T any] struct { + name string +} + +// NewCustomMethod creates a [CustomMethod] that captures the method name and +// its parameter and result types. The name must not be the name of a standard +// MCP method. +func NewCustomMethod[P paramsPtr[T], R Result, T any](name string) *CustomMethod[P, R, T] { + return &CustomMethod[P, R, T]{name: name} +} + +// Name returns the JSON-RPC method name. +func (m *CustomMethod[P, R, T]) Name() string { return m.name } + +// RegisterServerReceiving registers handler on s to handle incoming requests +// for this method from clients. It wraps [AddReceivingCustomMethod]. +func (m *CustomMethod[P, R, T]) RegisterServerReceiving(s *Server, handler func(ctx context.Context, ss *ServerSession, params P) (R, error)) error { + return AddReceivingCustomMethod(s, m.name, handler) +} + +// RegisterServerSending registers this method on s so that the server may +// call clients via [CustomMethod.ServerCall]. It wraps +// [AddServerSendingCustomMethod]. +func (m *CustomMethod[P, R, T]) RegisterServerSending(s *Server) error { + return AddServerSendingCustomMethod[P, R](s, m.name) +} + +// RegisterClientSending registers this method on c so that the client may +// send it to a server via [CustomMethod.Call]. It wraps [AddSendingCustomMethod]. +func (m *CustomMethod[P, R, T]) RegisterClientSending(c *Client) error { + return AddSendingCustomMethod[P, R](c, m.name) +} + +// RegisterClientReceiving registers handler on c to handle incoming requests +// for this method from servers. It wraps [AddClientReceivingCustomMethod]. +func (m *CustomMethod[P, R, T]) RegisterClientReceiving(c *Client, handler func(ctx context.Context, cs *ClientSession, params P) (R, error)) error { + return AddClientReceivingCustomMethod(c, m.name, handler) +} + +// Call invokes this method on the server via cs. It wraps [CallCustomMethod]. +// The method must have been registered on the client via [CustomMethod.RegisterClientSending]. +func (m *CustomMethod[P, R, T]) Call(ctx context.Context, cs *ClientSession, params P) (R, error) { + return CallCustomMethod[P, R](ctx, cs, m.name, params) +} + +// ServerCall invokes this method on the client via ss. It wraps +// [ServerCallCustomMethod]. The method must have been registered on the server +// via [CustomMethod.RegisterServerSending]. +func (m *CustomMethod[P, R, T]) ServerCall(ctx context.Context, ss *ServerSession, params P) (R, error) { + return ServerCallCustomMethod[P, R](ctx, ss, m.name, params) +} + +// Extension describes a set of custom methods that can be auto-applied to +// every new [Server] and [Client] via [RegisterExtension]. +// +// Extension authors typically call [RegisterExtension] in an init function so +// that importing the extension package is sufficient to wire everything up. +// Either field may be nil if the extension only applies to one side. +// +// If applying an extension returns an error (e.g. the method name shadows a +// standard method), [NewServer] or [NewClient] will panic. +type Extension struct { + // Server, if non-nil, is called by [NewServer] to register the extension. + Server func(*Server) error + // Client, if non-nil, is called by [NewClient] to register the extension. + Client func(*Client) error +} + +var ( + extensionsMu sync.Mutex + extensions []Extension +) + +// RegisterExtension adds ext to the global extension registry. [NewServer] +// and [NewClient] apply all registered extensions in registration order, +// before any per-instance extensions set in [ServerOptions.Extensions] or +// [ClientOptions.Extensions]. +// +// RegisterExtension is safe for concurrent use and is typically called from +// init functions. For scoped registration that does not affect the whole +// process, use [ServerOptions.Extensions] / [ClientOptions.Extensions] instead. +func RegisterExtension(ext Extension) { + extensionsMu.Lock() + defer extensionsMu.Unlock() + extensions = append(extensions, ext) +} + +func applyExtensionsToServer(s *Server) { + applyExtensions(func(e Extension) func(*Server) error { return e.Server }, s) +} + +func applyExtensionsToClient(c *Client) { + applyExtensions(func(e Extension) func(*Client) error { return e.Client }, c) +} + +func applyExtensions[T any](get func(Extension) func(T) error, arg T) { + extensionsMu.Lock() + exts := append([]Extension(nil), extensions...) + extensionsMu.Unlock() + runExtensions(exts, get, arg) +} + +func runExtensions[T any](exts []Extension, get func(Extension) func(T) error, arg T) { + for _, ext := range exts { + if fn := get(ext); fn != nil { + if err := fn(arg); err != nil { + panic(fmt.Errorf("mcp: applying extension: %w", err)) + } + } + } +} diff --git a/mcp/server.go b/mcp/server.go index 21b5a722..9bfff00d 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -62,6 +62,11 @@ type Server struct { // serverMethodInfos) plus any custom methods registered via // [AddReceivingCustomMethod]. receiveMethods map[string]methodInfo + // sendMethods is the merged map of methods this server may send to a + // client: it always contains the standard client methods (from + // clientMethodInfos) plus any custom methods registered via + // [AddServerSendingCustomMethod]. + sendMethods map[string]methodInfo } // ServerOptions is used to configure behavior of the server. @@ -168,6 +173,10 @@ type ServerOptions struct { // GetSessionID is not consulted when [StreamableHTTPOptions.Stateless] is // true, since stateless servers do not maintain sessions. GetSessionID func() string + // Extensions are applied to the server during [NewServer], after any + // globally registered extensions (see [RegisterExtension]). Per-server + // extensions override global ones for the same method names. + Extensions []Extension } // NewServer creates a new MCP server. The resulting server has no features: @@ -211,6 +220,9 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { receiveMethods := make(map[string]methodInfo, len(serverMethodInfos)) maps.Copy(receiveMethods, serverMethodInfos) + sendMethods := make(map[string]methodInfo, len(clientMethodInfos)) + maps.Copy(sendMethods, clientMethodInfos) + s := &Server{ impl: impl, opts: opts, @@ -226,8 +238,11 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { resourceSubscriptions: make(map[string]map[*ServerSession]jsonrpc.ID), pendingNotifications: make(map[string]*time.Timer), receiveMethods: receiveMethods, + sendMethods: sendMethods, } s.AddReceivingMiddleware(serverMultiRoundTripMiddleware()) + applyExtensionsToServer(s) + runExtensions(opts.Extensions, func(e Extension) func(*Server) error { return e.Server }, s) return s } @@ -1768,7 +1783,12 @@ func initializeMethodInfo() methodInfo { return info } -func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos } +func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { + s := ss.server + s.mu.Lock() + defer s.mu.Unlock() + return s.sendMethods +} func (s *Server) receivingMethodInfos() map[string]methodInfo { s.mu.Lock() @@ -2085,3 +2105,61 @@ func AddReceivingCustomMethod[P paramsPtr[T], R Result, T any]( s.receiveMethods[method] = newServerMethodInfo(typed, missingParamsOK) return nil } + +// AddServerSendingCustomMethod registers a custom (non-standard) JSON-RPC +// method that the server may send to clients. +// +// Registration is decoupled from invocation: extensions typically call this +// during setup, while the actual call site uses [ServerCallCustomMethod]. +// +// if err := mcp.AddServerSendingCustomMethod[*PingParams, *PingResult](server, "acme/ping"); err != nil { +// return err +// } +// // ... later, with a *ServerSession: +// result, err := mcp.ServerCallCustomMethod[*PingParams, *PingResult]( +// ctx, ss, "acme/ping", &PingParams{}) +// +// AddServerSendingCustomMethod returns an error if method is the name of a +// standard MCP method. Registering the same method twice replaces the +// previous registration. +func AddServerSendingCustomMethod[P paramsPtr[T], R Result, T any]( + s *Server, + method string, +) error { + if _, ok := clientMethodInfos[method]; ok { + return fmt.Errorf("mcp: AddServerSendingCustomMethod: %q shadows a standard MCP method", method) + } + + mi := methodInfo{ + newResult: func() Result { + return reflect.New(reflect.TypeFor[R]().Elem()).Interface().(R) + }, + } + + s.mu.Lock() + defer s.mu.Unlock() + s.sendMethods[method] = mi + return nil +} + +// ServerCallCustomMethod sends a custom (non-standard) JSON-RPC method to the +// client and decodes the response into R. +// +// The method must have been registered on the session's server via +// [AddServerSendingCustomMethod]. +func ServerCallCustomMethod[P paramsPtr[T], R Result, T any]( + ctx context.Context, + ss *ServerSession, + method string, + params P, +) (R, error) { + s := ss.server + s.mu.Lock() + _, ok := s.sendMethods[method] + s.mu.Unlock() + if !ok { + var zero R + return zero, fmt.Errorf("mcp: ServerCallCustomMethod: %q is not registered; call AddServerSendingCustomMethod first", method) + } + return handleSend[R](ctx, method, newServerRequest(ss, params)) +}