Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/fix-request-handler-result-types.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@modelcontextprotocol/sdk': patch
---

Tighten `setRequestHandler` return types for standard MCP methods so fixed-shape request handlers reject invalid result objects at compile time.
4 changes: 2 additions & 2 deletions src/client/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js';
import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestHandlerResult, type RequestOptions } from '../shared/protocol.js';
import type { Transport } from '../shared/transport.js';

import {
Expand Down Expand Up @@ -331,7 +331,7 @@ export class Client<
handler: (
request: SchemaOutput<T>,
extra: RequestHandlerExtra<ClientRequest | RequestT, ClientNotification | NotificationT>
) => ClientResult | ResultT | Promise<ClientResult | ResultT>
) => RequestHandlerResult<T, ClientResult | ResultT> | Promise<RequestHandlerResult<T, ClientResult | ResultT>>
): void {
const shape = getObjectShape(requestSchema);
const methodSchema = shape?.method;
Expand Down
7 changes: 2 additions & 5 deletions src/examples/client/simpleTaskInteractiveClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import {
CreateMessageRequestSchema,
CreateMessageRequest,
CreateMessageResult,
type ElicitResult,
ErrorCode,
McpError
} from '../../types.js';
Expand All @@ -40,11 +41,7 @@ function getTextContent(result: { content: Array<{ type: string; text?: string }
return textContent?.text ?? '(no text)';
}

async function elicitationCallback(params: {
mode?: string;
message: string;
requestedSchema?: object;
}): Promise<{ action: string; content?: Record<string, unknown> }> {
async function elicitationCallback(params: { mode?: string; message: string; requestedSchema?: object }): Promise<ElicitResult> {
console.log(`\n[Elicitation] Server asks: ${params.message}`);

// Simple terminal prompt for y/n
Expand Down
11 changes: 9 additions & 2 deletions src/server/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import { mergeCapabilities, Protocol, type NotificationOptions, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js';
import {
mergeCapabilities,
Protocol,
type NotificationOptions,
type ProtocolOptions,
type RequestHandlerResult,
type RequestOptions
} from '../shared/protocol.js';
import {
type ClientCapabilities,
type CreateMessageRequest,
Expand Down Expand Up @@ -220,7 +227,7 @@ export class Server<
handler: (
request: SchemaOutput<T>,
extra: RequestHandlerExtra<ServerRequest | RequestT, ServerNotification | NotificationT>
) => ServerResult | ResultT | Promise<ServerResult | ResultT>
) => RequestHandlerResult<T, ServerResult | ResultT> | Promise<RequestHandlerResult<T, ServerResult | ResultT>>
): void {
const shape = getObjectShape(requestSchema);
const methodSchema = shape?.method;
Expand Down
45 changes: 38 additions & 7 deletions src/shared/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import {
RELATED_TASK_META_KEY,
RequestId,
Result,
ResultInputTypeMap,
ServerCapabilities,
RequestMeta,
MessageExtraInfo,
Expand Down Expand Up @@ -301,6 +302,38 @@ export type RequestHandlerExtra<SendRequestT extends Request, SendNotificationT
closeStandaloneSSEStream?: () => void;
};

type StrictRequestHandlerResultMap = Pick<
ResultInputTypeMap,
| 'ping'
| 'initialize'
| 'completion/complete'
| 'logging/setLevel'
| 'sampling/createMessage'
| 'prompts/get'
| 'prompts/list'
| 'resources/list'
| 'resources/templates/list'
| 'resources/read'
| 'resources/subscribe'
| 'resources/unsubscribe'
| 'tools/call'
| 'tools/list'
| 'elicitation/create'
| 'roots/list'
| 'tasks/get'
| 'tasks/list'
| 'tasks/cancel'
>;

export type RequestHandlerResult<T extends AnyObjectSchema, Fallback extends Result> =
SchemaOutput<T> extends {
method: infer M;
}
? M extends keyof StrictRequestHandlerResultMap
? StrictRequestHandlerResultMap[M]
: Fallback
: Fallback;

/**
* Information about a request's timeout state
*/
Expand Down Expand Up @@ -390,10 +423,9 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e

// Per spec: tasks/get responses SHALL NOT include related-task metadata
// as the taskId parameter is the source of truth
// @ts-expect-error SendResultT cannot contain GetTaskResult, but we include it in our derived types everywhere else
return {
...task
} as SendResultT;
};
});

this.setRequestHandler(GetTaskPayloadRequestSchema, async (request, extra) => {
Expand Down Expand Up @@ -486,12 +518,11 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
this.setRequestHandler(ListTasksRequestSchema, async (request, extra) => {
try {
const { tasks, nextCursor } = await this._taskStore!.listTasks(request.params?.cursor, extra.sessionId);
// @ts-expect-error SendResultT cannot contain ListTasksResult, but we include it in our derived types everywhere else
return {
tasks,
nextCursor,
_meta: {}
} as SendResultT;
};
} catch (error) {
throw new McpError(
ErrorCode.InvalidParams,
Expand Down Expand Up @@ -532,7 +563,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
return {
_meta: {},
...cancelledTask
} as unknown as SendResultT;
};
} catch (error) {
// Re-throw McpError as-is
if (error instanceof McpError) {
Expand Down Expand Up @@ -1420,14 +1451,14 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
handler: (
request: SchemaOutput<T>,
extra: RequestHandlerExtra<SendRequestT, SendNotificationT>
) => SendResultT | Promise<SendResultT>
) => RequestHandlerResult<T, SendResultT> | Promise<RequestHandlerResult<T, SendResultT>>
): void {
const method = getMethodLiteral(requestSchema);
this.assertRequestHandlerCapability(method);

this._requestHandlers.set(method, (request, extra) => {
const parsed = parseWithCompat(requestSchema, request) as SchemaOutput<T>;
return Promise.resolve(handler(parsed, extra));
return Promise.resolve(handler(parsed, extra)) as Promise<SendResultT>;
});
}

Expand Down
71 changes: 71 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2361,6 +2361,21 @@ type Flatten<T> = T extends Primitive
: T;

type Infer<Schema extends z.ZodTypeAny> = Flatten<z.infer<Schema>>;
type InputInfer<Schema extends z.ZodTypeAny> = Flatten<z.input<Schema>>;
type CallToolResultInput = InputInfer<typeof CallToolResultSchema>;
type CallToolHandlerResult =
| (Omit<CallToolResultInput, 'content' | 'structuredContent'> &
(
| {
content: NonNullable<CallToolResultInput['content']>;
structuredContent?: CallToolResultInput['structuredContent'];
}
| {
content?: CallToolResultInput['content'];
structuredContent: NonNullable<CallToolResultInput['structuredContent']>;
}
))
| CreateTaskResult;

/**
* Headers that are compatible with both Node.js and the browser.
Expand Down Expand Up @@ -2625,3 +2640,59 @@ export type ClientResult = Infer<typeof ClientResultSchema>;
export type ServerRequest = Infer<typeof ServerRequestSchema>;
export type ServerNotification = Infer<typeof ServerNotificationSchema>;
export type ServerResult = Infer<typeof ServerResultSchema>;

/* Protocol type maps */
type MethodToTypeMap<U> = {
[T in U as T extends { method: infer M extends string } ? M : never]: T;
};
export type RequestMethod = ClientRequest['method'] | ServerRequest['method'];
export type NotificationMethod = ClientNotification['method'] | ServerNotification['method'];
export type RequestTypeMap = MethodToTypeMap<ClientRequest | ServerRequest>;
export type NotificationTypeMap = MethodToTypeMap<ClientNotification | ServerNotification>;
export type ResultTypeMap = {
ping: EmptyResult;
initialize: InitializeResult;
'completion/complete': CompleteResult;
'logging/setLevel': EmptyResult;
'prompts/get': GetPromptResult;
'prompts/list': ListPromptsResult;
'resources/list': ListResourcesResult;
'resources/templates/list': ListResourceTemplatesResult;
'resources/read': ReadResourceResult;
'resources/subscribe': EmptyResult;
'resources/unsubscribe': EmptyResult;
'tools/call': CallToolResult | CreateTaskResult;
'tools/list': ListToolsResult;
'sampling/createMessage': CreateMessageResult | CreateMessageResultWithTools | CreateTaskResult;
'elicitation/create': ElicitResult | CreateTaskResult;
'roots/list': ListRootsResult;
'tasks/get': GetTaskResult;
'tasks/result': Result;
'tasks/list': ListTasksResult;
'tasks/cancel': CancelTaskResult;
};
export type ResultInputTypeMap = {
ping: InputInfer<typeof EmptyResultSchema>;
initialize: InputInfer<typeof InitializeResultSchema>;
'completion/complete': InputInfer<typeof CompleteResultSchema>;
'logging/setLevel': InputInfer<typeof EmptyResultSchema>;
'prompts/get': InputInfer<typeof GetPromptResultSchema>;
'prompts/list': InputInfer<typeof ListPromptsResultSchema>;
'resources/list': InputInfer<typeof ListResourcesResultSchema>;
'resources/templates/list': InputInfer<typeof ListResourceTemplatesResultSchema>;
'resources/read': InputInfer<typeof ReadResourceResultSchema>;
'resources/subscribe': InputInfer<typeof EmptyResultSchema>;
'resources/unsubscribe': InputInfer<typeof EmptyResultSchema>;
'tools/call': CallToolHandlerResult;
'tools/list': InputInfer<typeof ListToolsResultSchema>;
'sampling/createMessage':
| InputInfer<typeof CreateMessageResultSchema>
| InputInfer<typeof CreateMessageResultWithToolsSchema>
| InputInfer<typeof CreateTaskResultSchema>;
'elicitation/create': ElicitResult | CreateTaskResult;
'roots/list': InputInfer<typeof ListRootsResultSchema>;
'tasks/get': InputInfer<typeof GetTaskResultSchema>;
'tasks/result': InputInfer<typeof ResultSchema>;
'tasks/list': InputInfer<typeof ListTasksResultSchema>;
'tasks/cancel': InputInfer<typeof CancelTaskResultSchema>;
};
89 changes: 84 additions & 5 deletions test/client/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,85 @@ import * as z3 from 'zod/v3';
import * as z4 from 'zod/v4';

describe('Zod v4', () => {
test('should reject invalid standard request handler results at typecheck time', () => {
const client = new Client(
{
name: 'TypecheckClient',
version: '1.0.0'
},
{
capabilities: {
elicitation: {},
roots: {},
sampling: {}
}
}
);

client.setRequestHandler(
ListRootsRequestSchema,
// @ts-expect-error roots/list results must include roots
async () => ({})
);
client.setRequestHandler(
CreateMessageRequestSchema,
// @ts-expect-error sampling/createMessage results must include model, role, and content
async () => ({})
);
client.setRequestHandler(
ElicitRequestSchema,
// @ts-expect-error elicitation/create results must include action
async () => ({})
);
client.setRequestHandler(
ElicitRequestSchema,
// @ts-expect-error elicitation/create action must be a protocol action
async () => ({ action: 'approve' })
);
client.setRequestHandler(
ElicitRequestSchema,
// @ts-expect-error elicitation/create content values must match the requested schema value types
async () => ({ action: 'accept', content: { nested: {} } })
);

client.setRequestHandler(CreateMessageRequestSchema, async () => ({
model: 'test-model',
role: 'assistant',
content: {
type: 'text',
text: 'Test response'
}
}));
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
model: 'test-model',
role: 'assistant',
stopReason: 'toolUse',
content: [{ type: 'tool_use', id: 'call_1', name: 'test_tool', input: { arg: 'value' } }]
}));
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
task: {
taskId: 'sampling-task',
status: 'working',
ttl: null,
createdAt: '2026-05-16T00:00:00.000Z',
lastUpdatedAt: '2026-05-16T00:00:00.000Z'
}
}));
client.setRequestHandler(ElicitRequestSchema, async () => ({
action: 'accept',
content: { confirmed: true }
}));
client.setRequestHandler(ElicitRequestSchema, async () => ({
task: {
taskId: 'elicitation-task',
status: 'working',
ttl: null,
createdAt: '2026-05-16T00:00:00.000Z',
lastUpdatedAt: '2026-05-16T00:00:00.000Z'
}
}));
});

/***
* Test: Type Checking
* Test that custom request/notification/result schemas can be used with the Client class.
Expand Down Expand Up @@ -724,7 +803,7 @@ test('should only allow setRequestHandler for declared capabilities', () => {

// This should throw because roots listing is not a declared capability
expect(() => {
client.setRequestHandler(ListRootsRequestSchema, () => ({}));
client.setRequestHandler(ListRootsRequestSchema, () => ({ roots: [] }));
}).toThrow('Client does not support roots capability');
});

Expand Down Expand Up @@ -2707,7 +2786,7 @@ describe('Task-based execution', () => {

client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
const result = {
action: 'accept',
action: 'accept' as const,
content: { username: 'list-user' }
};

Expand Down Expand Up @@ -2800,7 +2879,7 @@ describe('Task-based execution', () => {

client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
const result = {
action: 'accept',
action: 'accept' as const,
content: { username: 'list-user' }
};

Expand Down Expand Up @@ -2892,7 +2971,7 @@ describe('Task-based execution', () => {

client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
const result = {
action: 'accept',
action: 'accept' as const,
content: { username: 'result-user' }
};

Expand Down Expand Up @@ -2983,7 +3062,7 @@ describe('Task-based execution', () => {

client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
const result = {
action: 'accept',
action: 'accept' as const,
content: { username: 'list-user' }
};

Expand Down
4 changes: 2 additions & 2 deletions test/server/elicitation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import { Client } from '../../src/client/index.js';
import { InMemoryTransport } from '../../src/inMemory.js';
import { ElicitRequestFormParams, ElicitRequestSchema } from '../../src/types.js';
import { ElicitRequestFormParams, ElicitRequestSchema, type ElicitResult } from '../../src/types.js';
import { AjvJsonSchemaValidator } from '../../src/validation/ajv-provider.js';
import { CfWorkerJsonSchemaValidator } from '../../src/validation/cfworker-provider.js';
import { Server } from '../../src/server/index.js';
Expand Down Expand Up @@ -338,7 +338,7 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo

test(`${validatorName}: should handle multiple sequential elicitation requests`, async () => {
let requestCount = 0;
client.setRequestHandler(ElicitRequestSchema, request => {
client.setRequestHandler(ElicitRequestSchema, (request): ElicitResult => {
requestCount++;
if (request.params.message.includes('name')) {
return { action: 'accept', content: { name: 'Alice' } };
Expand Down
Loading
Loading