Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 40 additions & 54 deletions mcp/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,72 +119,58 @@ type InputResponse interface{ isInputResponse() }
// input-required result.
type InputResponseMap map[string]InputResponse

func (m InputResponseMap) MarshalJSON() ([]byte, error) {
type wire struct {
Method string `json:"method"`
Result InputResponse `json:"result,omitempty"`
}
typeToMethod := func(v InputResponse) (string, error) {
switch v.(type) {
case *ElicitResult:
return methodElicit, nil
case *CreateMessageResult, *CreateMessageWithToolsResult:
return methodCreateMessage, nil
case *ListRootsResult:
return methodListRoots, nil
default:
return "", fmt.Errorf("unsupported type: %T", v)
}
}
converted := map[string]*wire{}
for k, v := range m {
method, err := typeToMethod(v)
if err != nil {
return nil, err
}
converted[k] = &wire{Method: method, Result: v}
}
return json.Marshal(converted)
}

func (m *InputResponseMap) UnmarshalJSON(data []byte) error {
type raw struct {
Method string `json:"method"`
Result json.RawMessage `json:"result"`
}
var rawMap map[string]*raw
var rawMap map[string]json.RawMessage
if err := json.Unmarshal(data, &rawMap); err != nil {
return err
}
result := make(InputResponseMap, len(rawMap))
for k, raw := range rawMap {
switch raw.Method {
case methodElicit:
var p ElicitResult
if err := json.Unmarshal(raw.Result, &p); err != nil {
return err
}
result[k] = &p
case methodCreateMessage:
var p CreateMessageWithToolsResult
if err := json.Unmarshal(raw.Result, &p); err != nil {
return err
}
result[k] = &p
case methodListRoots:
var p ListRootsResult
if err := json.Unmarshal(raw.Result, &p); err != nil {
return err
}
result[k] = &p
default:
return fmt.Errorf("unsupported InputResponse method: %q", raw.Method)
v, err := unmarshalInputResponse(raw)
if err != nil {
return fmt.Errorf("inputResponses[%q]: %w", k, err)
}
result[k] = v
}
*m = result
return nil
}

// unmarshalInputResponse determines the concrete InputResponse type from the
// JSON structure by searching for a discriminating key in a raw message.
func unmarshalInputResponse(data json.RawMessage) (InputResponse, error) {
var probe struct {
Action json.RawMessage `json:"action"`
Role json.RawMessage `json:"role"`
Roots json.RawMessage `json:"roots"`
}
if err := json.Unmarshal(data, &probe); err != nil {
return nil, err
}
switch {
case probe.Roots != nil:
var p ListRootsResult
if err := json.Unmarshal(data, &p); err != nil {
return nil, err
}
return &p, nil
case probe.Action != nil:
var p ElicitResult
if err := json.Unmarshal(data, &p); err != nil {
return nil, err
}
return &p, nil
case probe.Role != nil:
var p CreateMessageWithToolsResult
if err := json.Unmarshal(data, &p); err != nil {
return nil, err
}
return &p, nil
default:
return nil, fmt.Errorf(`cannot determine InputResponse type: expected "action", "role", or "roots" key`)
}
}

// Optional annotations for the client. The client can use annotations to inform
// how objects are used or displayed.
type Annotations struct {
Expand Down
37 changes: 37 additions & 0 deletions mcp/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,43 @@ func TestInputRequestMapJSON(t *testing.T) {
})
}

func TestInputResponseMapJSON(t *testing.T) {
tests := []struct {
name string
value InputResponse
check func(t *testing.T, got InputResponse)
}{
{
name: "elicit",
value: &ElicitResult{Action: "accept", Content: map[string]any{"ok": true}},
},
{
name: "sampling",
value: &CreateMessageWithToolsResult{Role: "assistant", Model: "test-model", Content: []Content{&TextContent{Text: "hello"}}},
},
{
name: "list-roots",
value: &ListRootsResult{Roots: []*Root{{URI: "file:///tmp", Name: "tmp"}}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
key := tt.name
data, err := json.Marshal(InputResponseMap{key: tt.value})
if err != nil {
t.Fatal(err)
}
var got InputResponseMap
if err := json.Unmarshal(data, &got); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(tt.value, got[key], ctrCmpOpts...); diff != "" {
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
})
}
}

func TestContentUnmarshal(t *testing.T) {
// Verify that types with a Content field round-trip properly.
roundtrip := func(in, out any) {
Expand Down
5 changes: 1 addition & 4 deletions mcp/testdata/conformance/server/mrtr.txtar
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ confirmThenGreet
"arguments": {},
"requestState": "step=1",
"inputResponses": {
"who": {
"method": "elicitation/create",
"result": { "action": "accept", "content": { "name": "MCP Go" } }
}
"who": { "action": "accept", "content": { "name": "MCP Go" } }
}
}
}
Expand Down
Loading