diff --git a/mcp/protocol.go b/mcp/protocol.go index f077c6db..c1877417 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -119,72 +119,58 @@ type InputResponse interface{ isInputResponse() } // input-required result. type InputResponseMap map[string]InputResponse -func (m InputResponseMap) MarshalJSON() ([]byte, error) { - type wire struct { - Method string `json:"method"` - Result InputResponse `json:"result,omitempty"` - } - typeToMethod := func(v InputResponse) (string, error) { - switch v.(type) { - case *ElicitResult: - return methodElicit, nil - case *CreateMessageResult, *CreateMessageWithToolsResult: - return methodCreateMessage, nil - case *ListRootsResult: - return methodListRoots, nil - default: - return "", fmt.Errorf("unsupported type: %T", v) - } - } - converted := map[string]*wire{} - for k, v := range m { - method, err := typeToMethod(v) - if err != nil { - return nil, err - } - converted[k] = &wire{Method: method, Result: v} - } - return json.Marshal(converted) -} - func (m *InputResponseMap) UnmarshalJSON(data []byte) error { - type raw struct { - Method string `json:"method"` - Result json.RawMessage `json:"result"` - } - var rawMap map[string]*raw + var rawMap map[string]json.RawMessage if err := json.Unmarshal(data, &rawMap); err != nil { return err } result := make(InputResponseMap, len(rawMap)) for k, raw := range rawMap { - switch raw.Method { - case methodElicit: - var p ElicitResult - if err := json.Unmarshal(raw.Result, &p); err != nil { - return err - } - result[k] = &p - case methodCreateMessage: - var p CreateMessageWithToolsResult - if err := json.Unmarshal(raw.Result, &p); err != nil { - return err - } - result[k] = &p - case methodListRoots: - var p ListRootsResult - if err := json.Unmarshal(raw.Result, &p); err != nil { - return err - } - result[k] = &p - default: - return fmt.Errorf("unsupported InputResponse method: %q", raw.Method) + v, err := unmarshalInputResponse(raw) + if err != nil { + return fmt.Errorf("inputResponses[%q]: %w", k, err) } + result[k] = v } *m = result return nil } +// unmarshalInputResponse determines the concrete InputResponse type from the +// JSON structure by searching for a discriminating key in a raw message. +func unmarshalInputResponse(data json.RawMessage) (InputResponse, error) { + var probe struct { + Action json.RawMessage `json:"action"` + Role json.RawMessage `json:"role"` + Roots json.RawMessage `json:"roots"` + } + if err := json.Unmarshal(data, &probe); err != nil { + return nil, err + } + switch { + case probe.Roots != nil: + var p ListRootsResult + if err := json.Unmarshal(data, &p); err != nil { + return nil, err + } + return &p, nil + case probe.Action != nil: + var p ElicitResult + if err := json.Unmarshal(data, &p); err != nil { + return nil, err + } + return &p, nil + case probe.Role != nil: + var p CreateMessageWithToolsResult + if err := json.Unmarshal(data, &p); err != nil { + return nil, err + } + return &p, nil + default: + return nil, fmt.Errorf(`cannot determine InputResponse type: expected "action", "role", or "roots" key`) + } +} + // Optional annotations for the client. The client can use annotations to inform // how objects are used or displayed. type Annotations struct { diff --git a/mcp/protocol_test.go b/mcp/protocol_test.go index edf3623d..38fd48f4 100644 --- a/mcp/protocol_test.go +++ b/mcp/protocol_test.go @@ -1216,6 +1216,43 @@ func TestInputRequestMapJSON(t *testing.T) { }) } +func TestInputResponseMapJSON(t *testing.T) { + tests := []struct { + name string + value InputResponse + check func(t *testing.T, got InputResponse) + }{ + { + name: "elicit", + value: &ElicitResult{Action: "accept", Content: map[string]any{"ok": true}}, + }, + { + name: "sampling", + value: &CreateMessageWithToolsResult{Role: "assistant", Model: "test-model", Content: []Content{&TextContent{Text: "hello"}}}, + }, + { + name: "list-roots", + value: &ListRootsResult{Roots: []*Root{{URI: "file:///tmp", Name: "tmp"}}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key := tt.name + data, err := json.Marshal(InputResponseMap{key: tt.value}) + if err != nil { + t.Fatal(err) + } + var got InputResponseMap + if err := json.Unmarshal(data, &got); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(tt.value, got[key], ctrCmpOpts...); diff != "" { + t.Errorf("mismatch (-want, +got):\n%s", diff) + } + }) + } +} + func TestContentUnmarshal(t *testing.T) { // Verify that types with a Content field round-trip properly. roundtrip := func(in, out any) { diff --git a/mcp/testdata/conformance/server/mrtr.txtar b/mcp/testdata/conformance/server/mrtr.txtar index e0dc18ce..7ee6dc50 100644 --- a/mcp/testdata/conformance/server/mrtr.txtar +++ b/mcp/testdata/conformance/server/mrtr.txtar @@ -36,10 +36,7 @@ confirmThenGreet "arguments": {}, "requestState": "step=1", "inputResponses": { - "who": { - "method": "elicitation/create", - "result": { "action": "accept", "content": { "name": "MCP Go" } } - } + "who": { "action": "accept", "content": { "name": "MCP Go" } } } } }