|
1 | 1 | import http from 'http';
|
2 | 2 | import { jest } from '@jest/globals';
|
3 | 3 | import { SSEServerTransport } from './sse.js';
|
| 4 | +import { McpServer } from './mcp.js'; |
| 5 | +import { createServer, type Server } from "node:http"; |
| 6 | +import { AddressInfo } from "node:net"; |
| 7 | +import { z } from 'zod'; |
| 8 | +import { CallToolResult, JSONRPCMessage } from 'src/types.js'; |
4 | 9 |
|
5 | 10 | const createMockResponse = () => {
|
6 | 11 | const res = {
|
7 |
| - writeHead: jest.fn<http.ServerResponse['writeHead']>(), |
8 |
| - write: jest.fn<http.ServerResponse['write']>().mockReturnValue(true), |
9 |
| - on: jest.fn<http.ServerResponse['on']>(), |
| 12 | + writeHead: jest.fn<http.ServerResponse['writeHead']>().mockReturnThis(), |
| 13 | + write: jest.fn<http.ServerResponse['write']>().mockReturnThis(), |
| 14 | + on: jest.fn<http.ServerResponse['on']>().mockReturnThis(), |
| 15 | + end: jest.fn<http.ServerResponse['end']>().mockReturnThis(), |
10 | 16 | };
|
11 |
| - res.writeHead.mockReturnThis(); |
12 |
| - res.on.mockReturnThis(); |
13 | 17 |
|
14 |
| - return res as unknown as http.ServerResponse; |
| 18 | + return res as unknown as jest.Mocked<http.ServerResponse>; |
15 | 19 | };
|
16 | 20 |
|
| 21 | +/** |
| 22 | + * Helper to create and start test HTTP server with MCP setup |
| 23 | + */ |
| 24 | +async function createTestServerWithSse(args: { |
| 25 | + mockRes: http.ServerResponse; |
| 26 | +}): Promise<{ |
| 27 | + server: Server; |
| 28 | + transport: SSEServerTransport; |
| 29 | + mcpServer: McpServer; |
| 30 | + baseUrl: URL; |
| 31 | + sessionId: string |
| 32 | + serverPort: number; |
| 33 | +}> { |
| 34 | + const mcpServer = new McpServer( |
| 35 | + { name: "test-server", version: "1.0.0" }, |
| 36 | + { capabilities: { logging: {} } } |
| 37 | + ); |
| 38 | + |
| 39 | + mcpServer.tool( |
| 40 | + "greet", |
| 41 | + "A simple greeting tool", |
| 42 | + { name: z.string().describe("Name to greet") }, |
| 43 | + async ({ name }): Promise<CallToolResult> => { |
| 44 | + return { content: [{ type: "text", text: `Hello, ${name}!` }] }; |
| 45 | + } |
| 46 | + ); |
| 47 | + |
| 48 | + const endpoint = '/messages'; |
| 49 | + |
| 50 | + const transport = new SSEServerTransport(endpoint, args.mockRes); |
| 51 | + const sessionId = transport.sessionId; |
| 52 | + |
| 53 | + await mcpServer.connect(transport); |
| 54 | + |
| 55 | + const server = createServer(async (req, res) => { |
| 56 | + try { |
| 57 | + await transport.handlePostMessage(req, res); |
| 58 | + } catch (error) { |
| 59 | + console.error("Error handling request:", error); |
| 60 | + if (!res.headersSent) res.writeHead(500).end(); |
| 61 | + } |
| 62 | + }); |
| 63 | + |
| 64 | + const baseUrl = await new Promise<URL>((resolve) => { |
| 65 | + server.listen(0, "127.0.0.1", () => { |
| 66 | + const addr = server.address() as AddressInfo; |
| 67 | + resolve(new URL(`http://127.0.0.1:${addr.port}`)); |
| 68 | + }); |
| 69 | + }); |
| 70 | + |
| 71 | + const port = (server.address() as AddressInfo).port; |
| 72 | + |
| 73 | + return { server, transport, mcpServer, baseUrl, sessionId, serverPort: port }; |
| 74 | +} |
| 75 | + |
| 76 | +async function readAllSSEEvents(response: Response): Promise<string[]> { |
| 77 | + const reader = response.body?.getReader(); |
| 78 | + if (!reader) throw new Error('No readable stream'); |
| 79 | + |
| 80 | + const events: string[] = []; |
| 81 | + const decoder = new TextDecoder(); |
| 82 | + |
| 83 | + try { |
| 84 | + while (true) { |
| 85 | + const { done, value } = await reader.read(); |
| 86 | + if (done) break; |
| 87 | + |
| 88 | + if (value) { |
| 89 | + events.push(decoder.decode(value)); |
| 90 | + } |
| 91 | + } |
| 92 | + } finally { |
| 93 | + reader.releaseLock(); |
| 94 | + } |
| 95 | + |
| 96 | + return events; |
| 97 | +} |
| 98 | + |
| 99 | +/** |
| 100 | + * Helper to send JSON-RPC request |
| 101 | + */ |
| 102 | +async function sendSsePostRequest(baseUrl: URL, message: JSONRPCMessage | JSONRPCMessage[], sessionId?: string, extraHeaders?: Record<string, string>): Promise<Response> { |
| 103 | + const headers: Record<string, string> = { |
| 104 | + "Content-Type": "application/json", |
| 105 | + Accept: "application/json, text/event-stream", |
| 106 | + ...extraHeaders |
| 107 | + }; |
| 108 | + |
| 109 | + if (sessionId) { |
| 110 | + baseUrl.searchParams.set('sessionId', sessionId); |
| 111 | + } |
| 112 | + |
| 113 | + return fetch(baseUrl, { |
| 114 | + method: "POST", |
| 115 | + headers, |
| 116 | + body: JSON.stringify(message), |
| 117 | + }); |
| 118 | +} |
| 119 | + |
17 | 120 | describe('SSEServerTransport', () => {
|
| 121 | + |
| 122 | + async function initializeServer(baseUrl: URL): Promise<void> { |
| 123 | + const response = await sendSsePostRequest(baseUrl, { |
| 124 | + jsonrpc: "2.0", |
| 125 | + method: "initialize", |
| 126 | + params: { |
| 127 | + clientInfo: { name: "test-client", version: "1.0" }, |
| 128 | + protocolVersion: "2025-03-26", |
| 129 | + capabilities: { |
| 130 | + }, |
| 131 | + }, |
| 132 | + |
| 133 | + id: "init-1", |
| 134 | + } as JSONRPCMessage); |
| 135 | + |
| 136 | + expect(response.status).toBe(202); |
| 137 | + |
| 138 | + const text = await readAllSSEEvents(response); |
| 139 | + |
| 140 | + expect(text).toHaveLength(1); |
| 141 | + expect(text[0]).toBe('Accepted'); |
| 142 | + } |
| 143 | + |
18 | 144 | describe('start method', () => {
|
19 | 145 | it('should correctly append sessionId to a simple relative endpoint', async () => {
|
20 | 146 | const mockRes = createMockResponse();
|
@@ -105,5 +231,71 @@ describe('SSEServerTransport', () => {
|
105 | 231 | `event: endpoint\ndata: /?sessionId=${expectedSessionId}\n\n`
|
106 | 232 | );
|
107 | 233 | });
|
| 234 | + |
| 235 | + /** |
| 236 | + * Test: Tool With Request Info |
| 237 | + */ |
| 238 | + it("should pass request info to tool callback", async () => { |
| 239 | + const mockRes = createMockResponse(); |
| 240 | + const { mcpServer, baseUrl, sessionId, serverPort } = await createTestServerWithSse({ mockRes }); |
| 241 | + await initializeServer(baseUrl); |
| 242 | + |
| 243 | + mcpServer.tool( |
| 244 | + "test-request-info", |
| 245 | + "A simple test tool with request info", |
| 246 | + { name: z.string().describe("Name to greet") }, |
| 247 | + async ({ name }, { requestInfo }): Promise<CallToolResult> => { |
| 248 | + return { content: [{ type: "text", text: `Hello, ${name}!` }, { type: "text", text: `${JSON.stringify(requestInfo)}` }] }; |
| 249 | + } |
| 250 | + ); |
| 251 | + |
| 252 | + const toolCallMessage: JSONRPCMessage = { |
| 253 | + jsonrpc: "2.0", |
| 254 | + method: "tools/call", |
| 255 | + params: { |
| 256 | + name: "test-request-info", |
| 257 | + arguments: { |
| 258 | + name: "Test User", |
| 259 | + }, |
| 260 | + }, |
| 261 | + id: "call-1", |
| 262 | + }; |
| 263 | + |
| 264 | + const response = await sendSsePostRequest(baseUrl, toolCallMessage, sessionId); |
| 265 | + |
| 266 | + expect(response.status).toBe(202); |
| 267 | + |
| 268 | + expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /messages?sessionId=${sessionId}\n\n`); |
| 269 | + |
| 270 | + const expectedMessage = { |
| 271 | + result: { |
| 272 | + content: [ |
| 273 | + { |
| 274 | + type: "text", |
| 275 | + text: "Hello, Test User!", |
| 276 | + }, |
| 277 | + { |
| 278 | + type: "text", |
| 279 | + text: JSON.stringify({ |
| 280 | + headers: { |
| 281 | + host: `127.0.0.1:${serverPort}`, |
| 282 | + connection: 'keep-alive', |
| 283 | + 'content-type': 'application/json', |
| 284 | + accept: 'application/json, text/event-stream', |
| 285 | + 'accept-language': '*', |
| 286 | + 'sec-fetch-mode': 'cors', |
| 287 | + 'user-agent': 'node', |
| 288 | + 'accept-encoding': 'gzip, deflate', |
| 289 | + 'content-length': '124' |
| 290 | + }, |
| 291 | + }) |
| 292 | + }, |
| 293 | + ], |
| 294 | + }, |
| 295 | + jsonrpc: "2.0", |
| 296 | + id: "call-1", |
| 297 | + }; |
| 298 | + expect(mockRes.write).toHaveBeenCalledWith(`event: message\ndata: ${JSON.stringify(expectedMessage)}\n\n`); |
| 299 | + }); |
108 | 300 | });
|
109 | 301 | });
|
0 commit comments