Skip to content

Commit 4621105

Browse files
authored
Merge pull request #627 from KKonstantinov/request-propagation-in-tools
Raw request propagation in tools - passed to callbacks via RequestHandlerExtra<ServerRequest, ServerNotification>
2 parents f733207 + a94a63a commit 4621105

File tree

9 files changed

+319
-21
lines changed

9 files changed

+319
-21
lines changed

src/server/index.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import {
1515
ListResourcesRequestSchema,
1616
ListToolsRequestSchema,
1717
SetLevelRequestSchema,
18-
ErrorCode,
18+
ErrorCode
1919
} from "../types.js";
2020
import { Transport } from "../shared/transport.js";
2121
import { InMemoryTransport } from "../inMemory.js";

src/server/mcp.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import {
1414
LoggingMessageNotificationSchema,
1515
Notification,
1616
TextContent,
17-
ElicitRequestSchema,
17+
ElicitRequestSchema
1818
} from "../types.js";
1919
import { ResourceTemplate } from "./mcp.js";
2020
import { completable } from "./completable.js";

src/server/sse.test.ts

Lines changed: 198 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,146 @@
11
import http from 'http';
22
import { jest } from '@jest/globals';
33
import { SSEServerTransport } from './sse.js';
4+
import { McpServer } from './mcp.js';
5+
import { createServer, type Server } from "node:http";
6+
import { AddressInfo } from "node:net";
7+
import { z } from 'zod';
8+
import { CallToolResult, JSONRPCMessage } from 'src/types.js';
49

510
const createMockResponse = () => {
611
const res = {
7-
writeHead: jest.fn<http.ServerResponse['writeHead']>(),
8-
write: jest.fn<http.ServerResponse['write']>().mockReturnValue(true),
9-
on: jest.fn<http.ServerResponse['on']>(),
12+
writeHead: jest.fn<http.ServerResponse['writeHead']>().mockReturnThis(),
13+
write: jest.fn<http.ServerResponse['write']>().mockReturnThis(),
14+
on: jest.fn<http.ServerResponse['on']>().mockReturnThis(),
15+
end: jest.fn<http.ServerResponse['end']>().mockReturnThis(),
1016
};
11-
res.writeHead.mockReturnThis();
12-
res.on.mockReturnThis();
1317

14-
return res as unknown as http.ServerResponse;
18+
return res as unknown as jest.Mocked<http.ServerResponse>;
1519
};
1620

21+
/**
22+
* Helper to create and start test HTTP server with MCP setup
23+
*/
24+
async function createTestServerWithSse(args: {
25+
mockRes: http.ServerResponse;
26+
}): Promise<{
27+
server: Server;
28+
transport: SSEServerTransport;
29+
mcpServer: McpServer;
30+
baseUrl: URL;
31+
sessionId: string
32+
serverPort: number;
33+
}> {
34+
const mcpServer = new McpServer(
35+
{ name: "test-server", version: "1.0.0" },
36+
{ capabilities: { logging: {} } }
37+
);
38+
39+
mcpServer.tool(
40+
"greet",
41+
"A simple greeting tool",
42+
{ name: z.string().describe("Name to greet") },
43+
async ({ name }): Promise<CallToolResult> => {
44+
return { content: [{ type: "text", text: `Hello, ${name}!` }] };
45+
}
46+
);
47+
48+
const endpoint = '/messages';
49+
50+
const transport = new SSEServerTransport(endpoint, args.mockRes);
51+
const sessionId = transport.sessionId;
52+
53+
await mcpServer.connect(transport);
54+
55+
const server = createServer(async (req, res) => {
56+
try {
57+
await transport.handlePostMessage(req, res);
58+
} catch (error) {
59+
console.error("Error handling request:", error);
60+
if (!res.headersSent) res.writeHead(500).end();
61+
}
62+
});
63+
64+
const baseUrl = await new Promise<URL>((resolve) => {
65+
server.listen(0, "127.0.0.1", () => {
66+
const addr = server.address() as AddressInfo;
67+
resolve(new URL(`http://127.0.0.1:${addr.port}`));
68+
});
69+
});
70+
71+
const port = (server.address() as AddressInfo).port;
72+
73+
return { server, transport, mcpServer, baseUrl, sessionId, serverPort: port };
74+
}
75+
76+
async function readAllSSEEvents(response: Response): Promise<string[]> {
77+
const reader = response.body?.getReader();
78+
if (!reader) throw new Error('No readable stream');
79+
80+
const events: string[] = [];
81+
const decoder = new TextDecoder();
82+
83+
try {
84+
while (true) {
85+
const { done, value } = await reader.read();
86+
if (done) break;
87+
88+
if (value) {
89+
events.push(decoder.decode(value));
90+
}
91+
}
92+
} finally {
93+
reader.releaseLock();
94+
}
95+
96+
return events;
97+
}
98+
99+
/**
100+
* Helper to send JSON-RPC request
101+
*/
102+
async function sendSsePostRequest(baseUrl: URL, message: JSONRPCMessage | JSONRPCMessage[], sessionId?: string, extraHeaders?: Record<string, string>): Promise<Response> {
103+
const headers: Record<string, string> = {
104+
"Content-Type": "application/json",
105+
Accept: "application/json, text/event-stream",
106+
...extraHeaders
107+
};
108+
109+
if (sessionId) {
110+
baseUrl.searchParams.set('sessionId', sessionId);
111+
}
112+
113+
return fetch(baseUrl, {
114+
method: "POST",
115+
headers,
116+
body: JSON.stringify(message),
117+
});
118+
}
119+
17120
describe('SSEServerTransport', () => {
121+
122+
async function initializeServer(baseUrl: URL): Promise<void> {
123+
const response = await sendSsePostRequest(baseUrl, {
124+
jsonrpc: "2.0",
125+
method: "initialize",
126+
params: {
127+
clientInfo: { name: "test-client", version: "1.0" },
128+
protocolVersion: "2025-03-26",
129+
capabilities: {
130+
},
131+
},
132+
133+
id: "init-1",
134+
} as JSONRPCMessage);
135+
136+
expect(response.status).toBe(202);
137+
138+
const text = await readAllSSEEvents(response);
139+
140+
expect(text).toHaveLength(1);
141+
expect(text[0]).toBe('Accepted');
142+
}
143+
18144
describe('start method', () => {
19145
it('should correctly append sessionId to a simple relative endpoint', async () => {
20146
const mockRes = createMockResponse();
@@ -105,5 +231,71 @@ describe('SSEServerTransport', () => {
105231
`event: endpoint\ndata: /?sessionId=${expectedSessionId}\n\n`
106232
);
107233
});
234+
235+
/**
236+
* Test: Tool With Request Info
237+
*/
238+
it("should pass request info to tool callback", async () => {
239+
const mockRes = createMockResponse();
240+
const { mcpServer, baseUrl, sessionId, serverPort } = await createTestServerWithSse({ mockRes });
241+
await initializeServer(baseUrl);
242+
243+
mcpServer.tool(
244+
"test-request-info",
245+
"A simple test tool with request info",
246+
{ name: z.string().describe("Name to greet") },
247+
async ({ name }, { requestInfo }): Promise<CallToolResult> => {
248+
return { content: [{ type: "text", text: `Hello, ${name}!` }, { type: "text", text: `${JSON.stringify(requestInfo)}` }] };
249+
}
250+
);
251+
252+
const toolCallMessage: JSONRPCMessage = {
253+
jsonrpc: "2.0",
254+
method: "tools/call",
255+
params: {
256+
name: "test-request-info",
257+
arguments: {
258+
name: "Test User",
259+
},
260+
},
261+
id: "call-1",
262+
};
263+
264+
const response = await sendSsePostRequest(baseUrl, toolCallMessage, sessionId);
265+
266+
expect(response.status).toBe(202);
267+
268+
expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /messages?sessionId=${sessionId}\n\n`);
269+
270+
const expectedMessage = {
271+
result: {
272+
content: [
273+
{
274+
type: "text",
275+
text: "Hello, Test User!",
276+
},
277+
{
278+
type: "text",
279+
text: JSON.stringify({
280+
headers: {
281+
host: `127.0.0.1:${serverPort}`,
282+
connection: 'keep-alive',
283+
'content-type': 'application/json',
284+
accept: 'application/json, text/event-stream',
285+
'accept-language': '*',
286+
'sec-fetch-mode': 'cors',
287+
'user-agent': 'node',
288+
'accept-encoding': 'gzip, deflate',
289+
'content-length': '124'
290+
},
291+
})
292+
},
293+
],
294+
},
295+
jsonrpc: "2.0",
296+
id: "call-1",
297+
};
298+
expect(mockRes.write).toHaveBeenCalledWith(`event: message\ndata: ${JSON.stringify(expectedMessage)}\n\n`);
299+
});
108300
});
109301
});

src/server/sse.ts

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { randomUUID } from "node:crypto";
22
import { IncomingMessage, ServerResponse } from "node:http";
33
import { Transport } from "../shared/transport.js";
4-
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
4+
import { JSONRPCMessage, JSONRPCMessageSchema, MessageExtraInfo, RequestInfo } from "../types.js";
55
import getRawBody from "raw-body";
66
import contentType from "content-type";
77
import { AuthInfo } from "./auth/types.js";
@@ -19,7 +19,7 @@ export class SSEServerTransport implements Transport {
1919
private _sessionId: string;
2020
onclose?: () => void;
2121
onerror?: (error: Error) => void;
22-
onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void;
22+
onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void;
2323

2424
/**
2525
* Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`.
@@ -86,6 +86,7 @@ export class SSEServerTransport implements Transport {
8686
throw new Error(message);
8787
}
8888
const authInfo: AuthInfo | undefined = req.auth;
89+
const requestInfo: RequestInfo = { headers: req.headers };
8990

9091
let body: string | unknown;
9192
try {
@@ -105,7 +106,7 @@ export class SSEServerTransport implements Transport {
105106
}
106107

107108
try {
108-
await this.handleMessage(typeof body === 'string' ? JSON.parse(body) : body, { authInfo });
109+
await this.handleMessage(typeof body === 'string' ? JSON.parse(body) : body, { requestInfo, authInfo });
109110
} catch {
110111
res.writeHead(400).end(`Invalid message: ${body}`);
111112
return;
@@ -117,7 +118,7 @@ export class SSEServerTransport implements Transport {
117118
/**
118119
* Handle a client message, regardless of how it arrived. This can be used to inform the server of messages that arrive via a means different than HTTP POST.
119120
*/
120-
async handleMessage(message: unknown, extra?: { authInfo?: AuthInfo }): Promise<void> {
121+
async handleMessage(message: unknown, extra?: MessageExtraInfo): Promise<void> {
121122
let parsedMessage: JSONRPCMessage;
122123
try {
123124
parsedMessage = JSONRPCMessageSchema.parse(message);

src/server/streamableHttp.test.ts

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ function expectErrorResponse(data: unknown, expectedCode: number, expectedMessag
208208

209209
describe("StreamableHTTPServerTransport", () => {
210210
let server: Server;
211+
let mcpServer: McpServer;
211212
let transport: StreamableHTTPServerTransport;
212213
let baseUrl: URL;
213214
let sessionId: string;
@@ -216,6 +217,7 @@ describe("StreamableHTTPServerTransport", () => {
216217
const result = await createTestServer();
217218
server = result.server;
218219
transport = result.transport;
220+
mcpServer = result.mcpServer;
219221
baseUrl = result.baseUrl;
220222
});
221223

@@ -347,6 +349,69 @@ describe("StreamableHTTPServerTransport", () => {
347349
});
348350
});
349351

352+
/***
353+
* Test: Tool With Request Info
354+
*/
355+
it("should pass request info to tool callback", async () => {
356+
sessionId = await initializeServer();
357+
358+
mcpServer.tool(
359+
"test-request-info",
360+
"A simple test tool with request info",
361+
{ name: z.string().describe("Name to greet") },
362+
async ({ name }, { requestInfo }): Promise<CallToolResult> => {
363+
return { content: [{ type: "text", text: `Hello, ${name}!` }, { type: "text", text: `${JSON.stringify(requestInfo)}` }] };
364+
}
365+
);
366+
367+
const toolCallMessage: JSONRPCMessage = {
368+
jsonrpc: "2.0",
369+
method: "tools/call",
370+
params: {
371+
name: "test-request-info",
372+
arguments: {
373+
name: "Test User",
374+
},
375+
},
376+
id: "call-1",
377+
};
378+
379+
const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId);
380+
expect(response.status).toBe(200);
381+
382+
const text = await readSSEEvent(response);
383+
const eventLines = text.split("\n");
384+
const dataLine = eventLines.find(line => line.startsWith("data:"));
385+
expect(dataLine).toBeDefined();
386+
387+
const eventData = JSON.parse(dataLine!.substring(5));
388+
389+
expect(eventData).toMatchObject({
390+
jsonrpc: "2.0",
391+
result: {
392+
content: [
393+
{ type: "text", text: "Hello, Test User!" },
394+
{ type: "text", text: expect.any(String) }
395+
],
396+
},
397+
id: "call-1",
398+
});
399+
400+
const requestInfo = JSON.parse(eventData.result.content[1].text);
401+
expect(requestInfo).toMatchObject({
402+
headers: {
403+
'content-type': 'application/json',
404+
accept: 'application/json, text/event-stream',
405+
connection: 'keep-alive',
406+
'mcp-session-id': sessionId,
407+
'accept-language': '*',
408+
'user-agent': expect.any(String),
409+
'accept-encoding': expect.any(String),
410+
'content-length': expect.any(String),
411+
},
412+
});
413+
});
414+
350415
it("should reject requests without a valid session ID", async () => {
351416
const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList);
352417

0 commit comments

Comments
 (0)