diff --git a/.changeset/fix-request-handler-result-types.md b/.changeset/fix-request-handler-result-types.md new file mode 100644 index 0000000000..2ec9fc8b52 --- /dev/null +++ b/.changeset/fix-request-handler-result-types.md @@ -0,0 +1,5 @@ +--- +'@modelcontextprotocol/sdk': patch +--- + +Tighten `setRequestHandler` return types for standard MCP methods so fixed-shape request handlers reject invalid result objects at compile time. diff --git a/src/client/index.ts b/src/client/index.ts index 03a6b40b5c..3a2bb095f6 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,4 +1,4 @@ -import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js'; +import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestHandlerResult, type RequestOptions } from '../shared/protocol.js'; import type { Transport } from '../shared/transport.js'; import { @@ -331,7 +331,7 @@ export class Client< handler: ( request: SchemaOutput, extra: RequestHandlerExtra - ) => ClientResult | ResultT | Promise + ) => RequestHandlerResult | Promise> ): void { const shape = getObjectShape(requestSchema); const methodSchema = shape?.method; diff --git a/src/examples/client/simpleTaskInteractiveClient.ts b/src/examples/client/simpleTaskInteractiveClient.ts index 06ed0ead10..cc7df2f12a 100644 --- a/src/examples/client/simpleTaskInteractiveClient.ts +++ b/src/examples/client/simpleTaskInteractiveClient.ts @@ -17,6 +17,7 @@ import { CreateMessageRequestSchema, CreateMessageRequest, CreateMessageResult, + type ElicitResult, ErrorCode, McpError } from '../../types.js'; @@ -40,11 +41,7 @@ function getTextContent(result: { content: Array<{ type: string; text?: string } return textContent?.text ?? '(no text)'; } -async function elicitationCallback(params: { - mode?: string; - message: string; - requestedSchema?: object; -}): Promise<{ action: string; content?: Record }> { +async function elicitationCallback(params: { mode?: string; message: string; requestedSchema?: object }): Promise { console.log(`\n[Elicitation] Server asks: ${params.message}`); // Simple terminal prompt for y/n diff --git a/src/server/index.ts b/src/server/index.ts index 531a559dd5..80f3fb4368 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -1,4 +1,11 @@ -import { mergeCapabilities, Protocol, type NotificationOptions, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js'; +import { + mergeCapabilities, + Protocol, + type NotificationOptions, + type ProtocolOptions, + type RequestHandlerResult, + type RequestOptions +} from '../shared/protocol.js'; import { type ClientCapabilities, type CreateMessageRequest, @@ -220,7 +227,7 @@ export class Server< handler: ( request: SchemaOutput, extra: RequestHandlerExtra - ) => ServerResult | ResultT | Promise + ) => RequestHandlerResult | Promise> ): void { const shape = getObjectShape(requestSchema); const methodSchema = shape?.method; diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 2637be65bc..4ac264cca7 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -29,6 +29,7 @@ import { RELATED_TASK_META_KEY, RequestId, Result, + ResultInputTypeMap, ServerCapabilities, RequestMeta, MessageExtraInfo, @@ -301,6 +302,38 @@ export type RequestHandlerExtra void; }; +type StrictRequestHandlerResultMap = Pick< + ResultInputTypeMap, + | 'ping' + | 'initialize' + | 'completion/complete' + | 'logging/setLevel' + | 'sampling/createMessage' + | 'prompts/get' + | 'prompts/list' + | 'resources/list' + | 'resources/templates/list' + | 'resources/read' + | 'resources/subscribe' + | 'resources/unsubscribe' + | 'tools/call' + | 'tools/list' + | 'elicitation/create' + | 'roots/list' + | 'tasks/get' + | 'tasks/list' + | 'tasks/cancel' +>; + +export type RequestHandlerResult = + SchemaOutput extends { + method: infer M; + } + ? M extends keyof StrictRequestHandlerResultMap + ? StrictRequestHandlerResultMap[M] + : Fallback + : Fallback; + /** * Information about a request's timeout state */ @@ -390,10 +423,9 @@ export abstract class Protocol { @@ -486,12 +518,11 @@ export abstract class Protocol { try { const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor, extra.sessionId); - // @ts-expect-error SendResultT cannot contain ListTasksResult, but we include it in our derived types everywhere else return { tasks, nextCursor, _meta: {} - } as SendResultT; + }; } catch (error) { throw new McpError( ErrorCode.InvalidParams, @@ -532,7 +563,7 @@ export abstract class Protocol, extra: RequestHandlerExtra - ) => SendResultT | Promise + ) => RequestHandlerResult | Promise> ): void { const method = getMethodLiteral(requestSchema); this.assertRequestHandlerCapability(method); this._requestHandlers.set(method, (request, extra) => { const parsed = parseWithCompat(requestSchema, request) as SchemaOutput; - return Promise.resolve(handler(parsed, extra)); + return Promise.resolve(handler(parsed, extra)) as Promise; }); } diff --git a/src/types.ts b/src/types.ts index 835eac89f8..76914559c5 100644 --- a/src/types.ts +++ b/src/types.ts @@ -2361,6 +2361,21 @@ type Flatten = T extends Primitive : T; type Infer = Flatten>; +type InputInfer = Flatten>; +type CallToolResultInput = InputInfer; +type CallToolHandlerResult = + | (Omit & + ( + | { + content: NonNullable; + structuredContent?: CallToolResultInput['structuredContent']; + } + | { + content?: CallToolResultInput['content']; + structuredContent: NonNullable; + } + )) + | CreateTaskResult; /** * Headers that are compatible with both Node.js and the browser. @@ -2625,3 +2640,59 @@ export type ClientResult = Infer; export type ServerRequest = Infer; export type ServerNotification = Infer; export type ServerResult = Infer; + +/* Protocol type maps */ +type MethodToTypeMap = { + [T in U as T extends { method: infer M extends string } ? M : never]: T; +}; +export type RequestMethod = ClientRequest['method'] | ServerRequest['method']; +export type NotificationMethod = ClientNotification['method'] | ServerNotification['method']; +export type RequestTypeMap = MethodToTypeMap; +export type NotificationTypeMap = MethodToTypeMap; +export type ResultTypeMap = { + ping: EmptyResult; + initialize: InitializeResult; + 'completion/complete': CompleteResult; + 'logging/setLevel': EmptyResult; + 'prompts/get': GetPromptResult; + 'prompts/list': ListPromptsResult; + 'resources/list': ListResourcesResult; + 'resources/templates/list': ListResourceTemplatesResult; + 'resources/read': ReadResourceResult; + 'resources/subscribe': EmptyResult; + 'resources/unsubscribe': EmptyResult; + 'tools/call': CallToolResult | CreateTaskResult; + 'tools/list': ListToolsResult; + 'sampling/createMessage': CreateMessageResult | CreateMessageResultWithTools | CreateTaskResult; + 'elicitation/create': ElicitResult | CreateTaskResult; + 'roots/list': ListRootsResult; + 'tasks/get': GetTaskResult; + 'tasks/result': Result; + 'tasks/list': ListTasksResult; + 'tasks/cancel': CancelTaskResult; +}; +export type ResultInputTypeMap = { + ping: InputInfer; + initialize: InputInfer; + 'completion/complete': InputInfer; + 'logging/setLevel': InputInfer; + 'prompts/get': InputInfer; + 'prompts/list': InputInfer; + 'resources/list': InputInfer; + 'resources/templates/list': InputInfer; + 'resources/read': InputInfer; + 'resources/subscribe': InputInfer; + 'resources/unsubscribe': InputInfer; + 'tools/call': CallToolHandlerResult; + 'tools/list': InputInfer; + 'sampling/createMessage': + | InputInfer + | InputInfer + | InputInfer; + 'elicitation/create': ElicitResult | CreateTaskResult; + 'roots/list': InputInfer; + 'tasks/get': InputInfer; + 'tasks/result': InputInfer; + 'tasks/list': InputInfer; + 'tasks/cancel': InputInfer; +}; diff --git a/test/client/index.test.ts b/test/client/index.test.ts index f5c6a348d1..a717b4f0e4 100644 --- a/test/client/index.test.ts +++ b/test/client/index.test.ts @@ -35,6 +35,85 @@ import * as z3 from 'zod/v3'; import * as z4 from 'zod/v4'; describe('Zod v4', () => { + test('should reject invalid standard request handler results at typecheck time', () => { + const client = new Client( + { + name: 'TypecheckClient', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {}, + roots: {}, + sampling: {} + } + } + ); + + client.setRequestHandler( + ListRootsRequestSchema, + // @ts-expect-error roots/list results must include roots + async () => ({}) + ); + client.setRequestHandler( + CreateMessageRequestSchema, + // @ts-expect-error sampling/createMessage results must include model, role, and content + async () => ({}) + ); + client.setRequestHandler( + ElicitRequestSchema, + // @ts-expect-error elicitation/create results must include action + async () => ({}) + ); + client.setRequestHandler( + ElicitRequestSchema, + // @ts-expect-error elicitation/create action must be a protocol action + async () => ({ action: 'approve' }) + ); + client.setRequestHandler( + ElicitRequestSchema, + // @ts-expect-error elicitation/create content values must match the requested schema value types + async () => ({ action: 'accept', content: { nested: {} } }) + ); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + })); + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + stopReason: 'toolUse', + content: [{ type: 'tool_use', id: 'call_1', name: 'test_tool', input: { arg: 'value' } }] + })); + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + task: { + taskId: 'sampling-task', + status: 'working', + ttl: null, + createdAt: '2026-05-16T00:00:00.000Z', + lastUpdatedAt: '2026-05-16T00:00:00.000Z' + } + })); + client.setRequestHandler(ElicitRequestSchema, async () => ({ + action: 'accept', + content: { confirmed: true } + })); + client.setRequestHandler(ElicitRequestSchema, async () => ({ + task: { + taskId: 'elicitation-task', + status: 'working', + ttl: null, + createdAt: '2026-05-16T00:00:00.000Z', + lastUpdatedAt: '2026-05-16T00:00:00.000Z' + } + })); + }); + /*** * Test: Type Checking * Test that custom request/notification/result schemas can be used with the Client class. @@ -724,7 +803,7 @@ test('should only allow setRequestHandler for declared capabilities', () => { // This should throw because roots listing is not a declared capability expect(() => { - client.setRequestHandler(ListRootsRequestSchema, () => ({})); + client.setRequestHandler(ListRootsRequestSchema, () => ({ roots: [] })); }).toThrow('Client does not support roots capability'); }); @@ -2707,7 +2786,7 @@ describe('Task-based execution', () => { client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { const result = { - action: 'accept', + action: 'accept' as const, content: { username: 'list-user' } }; @@ -2800,7 +2879,7 @@ describe('Task-based execution', () => { client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { const result = { - action: 'accept', + action: 'accept' as const, content: { username: 'list-user' } }; @@ -2892,7 +2971,7 @@ describe('Task-based execution', () => { client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { const result = { - action: 'accept', + action: 'accept' as const, content: { username: 'result-user' } }; @@ -2983,7 +3062,7 @@ describe('Task-based execution', () => { client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { const result = { - action: 'accept', + action: 'accept' as const, content: { username: 'list-user' } }; diff --git a/test/server/elicitation.test.ts b/test/server/elicitation.test.ts index c6f297b462..d4496ce5da 100644 --- a/test/server/elicitation.test.ts +++ b/test/server/elicitation.test.ts @@ -9,7 +9,7 @@ import { Client } from '../../src/client/index.js'; import { InMemoryTransport } from '../../src/inMemory.js'; -import { ElicitRequestFormParams, ElicitRequestSchema } from '../../src/types.js'; +import { ElicitRequestFormParams, ElicitRequestSchema, type ElicitResult } from '../../src/types.js'; import { AjvJsonSchemaValidator } from '../../src/validation/ajv-provider.js'; import { CfWorkerJsonSchemaValidator } from '../../src/validation/cfworker-provider.js'; import { Server } from '../../src/server/index.js'; @@ -338,7 +338,7 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo test(`${validatorName}: should handle multiple sequential elicitation requests`, async () => { let requestCount = 0; - client.setRequestHandler(ElicitRequestSchema, request => { + client.setRequestHandler(ElicitRequestSchema, (request): ElicitResult => { requestCount++; if (request.params.message.includes('name')) { return { action: 'accept', content: { name: 'Alice' } }; diff --git a/test/server/index.test.ts b/test/server/index.test.ts index e434e57fc0..c7dd9667da 100644 --- a/test/server/index.test.ts +++ b/test/server/index.test.ts @@ -115,6 +115,59 @@ describe('Zod v3', () => { }); describe('Zod v4', () => { + test('should reject invalid standard request handler results at typecheck time', () => { + const server = new Server( + { + name: 'TypecheckServer', + version: '1.0.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {} + } + } + ); + + server.setRequestHandler( + ListToolsRequestSchema, + // @ts-expect-error tools/list results must include tools + async () => ({}) + ); + server.setRequestHandler( + CallToolRequestSchema, + // @ts-expect-error tools/call results must include content, structuredContent, or a task + async () => ({}) + ); + server.setRequestHandler( + ListResourcesRequestSchema, + // @ts-expect-error resources/list results must include resources + async () => ({}) + ); + server.setRequestHandler( + ListPromptsRequestSchema, + // @ts-expect-error prompts/list results must include prompts + async () => ({}) + ); + + server.setRequestHandler(CallToolRequestSchema, async () => ({ + content: [{ type: 'text', text: 'ok' }] + })); + server.setRequestHandler(CallToolRequestSchema, async () => ({ + structuredContent: { ok: true } + })); + server.setRequestHandler(CallToolRequestSchema, async () => ({ + task: { + taskId: 'tool-task', + status: 'working', + ttl: null, + createdAt: '2026-05-16T00:00:00.000Z', + lastUpdatedAt: '2026-05-16T00:00:00.000Z' + } + })); + }); + test('should typecheck', () => { const GetWeatherRequestSchema = RequestSchema.extend({ method: z4.literal('weather/get'), @@ -452,13 +505,18 @@ test('should respect client elicitation capabilities', async () => { } ); - client.setRequestHandler(ElicitRequestSchema, params => ({ - action: 'accept', - content: { - username: params.params.message.includes('username') ? 'test-user' : undefined, + client.setRequestHandler(ElicitRequestSchema, params => { + const content: Record = { confirmed: true + }; + if (params.params.message.includes('username')) { + content.username = 'test-user'; } - })); + return { + action: 'accept', + content + }; + }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -536,13 +594,18 @@ test('should use elicitInput with mode: "form" by default for backwards compatib } ); - client.setRequestHandler(ElicitRequestSchema, params => ({ - action: 'accept', - content: { - username: params.params.message.includes('username') ? 'test-user' : undefined, + client.setRequestHandler(ElicitRequestSchema, params => { + const content: Record = { confirmed: true + }; + if (params.params.message.includes('username')) { + content.username = 'test-user'; } - })); + return { + action: 'accept', + content + }; + }); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); @@ -2596,7 +2659,7 @@ describe('Task-based execution', () => { client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { const result = { - action: 'accept', + action: 'accept' as const, content: { username: 'server-test-user', confirmed: true } }; @@ -2677,7 +2740,7 @@ describe('Task-based execution', () => { client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { const result = { - action: 'accept', + action: 'accept' as const, content: { username: 'list-user' } }; @@ -2756,7 +2819,7 @@ describe('Task-based execution', () => { client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { const result = { - action: 'accept', + action: 'accept' as const, content: { username: 'result-user', confirmed: true } }; @@ -2837,7 +2900,7 @@ describe('Task-based execution', () => { client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { const result = { - action: 'accept', + action: 'accept' as const, content: { username: 'list-user' } }; @@ -3180,7 +3243,7 @@ test('should respect client task capabilities', async () => { client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { const result = { - action: 'accept', + action: 'accept' as const, content: { username: 'test-user' } }; diff --git a/test/shared/protocol.test.ts b/test/shared/protocol.test.ts index 733146f29b..bee411624a 100644 --- a/test/shared/protocol.test.ts +++ b/test/shared/protocol.test.ts @@ -1263,7 +1263,9 @@ describe('Task-based execution', () => { ttl: 60000, pollInterval: 1000 }); - return { result: 'success' }; + return { + content: [{ type: 'text' as const, text: 'success' }] + }; }); transport.onmessage?.({ @@ -2008,7 +2010,7 @@ describe('Task-based execution', () => { }); return { - content: [{ type: 'text', text: 'done' }] + content: [{ type: 'text' as const, text: 'done' }] }; });