Skip to content

Commit 11e84f0

Browse files
Merge pull request #751 from modelcontextprotocol/jerome/fix/allow-async-onsession-callbacks
feat: support async callbacks for onsessioninitialized and onsessionclosed
2 parents 031dfc2 + 0ddf682 commit 11e84f0

File tree

2 files changed

+217
-7
lines changed

2 files changed

+217
-7
lines changed

src/server/streamableHttp.test.ts

Lines changed: 211 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ interface TestServerConfig {
2929
enableJsonResponse?: boolean;
3030
customRequestHandler?: (req: IncomingMessage, res: ServerResponse, parsedBody?: unknown) => Promise<void>;
3131
eventStore?: EventStore;
32-
onsessionclosed?: (sessionId: string) => void;
32+
onsessioninitialized?: (sessionId: string) => void | Promise<void>;
33+
onsessionclosed?: (sessionId: string) => void | Promise<void>;
3334
}
3435

3536
/**
@@ -59,6 +60,7 @@ async function createTestServer(config: TestServerConfig = { sessionIdGenerator:
5960
sessionIdGenerator: config.sessionIdGenerator,
6061
enableJsonResponse: config.enableJsonResponse ?? false,
6162
eventStore: config.eventStore,
63+
onsessioninitialized: config.onsessioninitialized,
6264
onsessionclosed: config.onsessionclosed
6365
});
6466

@@ -114,6 +116,7 @@ async function createTestAuthServer(config: TestServerConfig = { sessionIdGenera
114116
sessionIdGenerator: config.sessionIdGenerator,
115117
enableJsonResponse: config.enableJsonResponse ?? false,
116118
eventStore: config.eventStore,
119+
onsessioninitialized: config.onsessioninitialized,
117120
onsessionclosed: config.onsessionclosed
118121
});
119122

@@ -1666,6 +1669,213 @@ describe("StreamableHTTPServerTransport onsessionclosed callback", () => {
16661669
});
16671670
});
16681671

1672+
// Test async callbacks for onsessioninitialized and onsessionclosed
1673+
describe("StreamableHTTPServerTransport async callbacks", () => {
1674+
it("should support async onsessioninitialized callback", async () => {
1675+
const initializationOrder: string[] = [];
1676+
1677+
// Create server with async onsessioninitialized callback
1678+
const result = await createTestServer({
1679+
sessionIdGenerator: () => randomUUID(),
1680+
onsessioninitialized: async (sessionId: string) => {
1681+
initializationOrder.push('async-start');
1682+
// Simulate async operation
1683+
await new Promise(resolve => setTimeout(resolve, 10));
1684+
initializationOrder.push('async-end');
1685+
initializationOrder.push(sessionId);
1686+
},
1687+
});
1688+
1689+
const tempServer = result.server;
1690+
const tempUrl = result.baseUrl;
1691+
1692+
// Initialize to trigger the callback
1693+
const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);
1694+
const tempSessionId = initResponse.headers.get("mcp-session-id");
1695+
1696+
// Give time for async callback to complete
1697+
await new Promise(resolve => setTimeout(resolve, 50));
1698+
1699+
expect(initializationOrder).toEqual(['async-start', 'async-end', tempSessionId]);
1700+
1701+
// Clean up
1702+
tempServer.close();
1703+
});
1704+
1705+
it("should support sync onsessioninitialized callback (backwards compatibility)", async () => {
1706+
const capturedSessionId: string[] = [];
1707+
1708+
// Create server with sync onsessioninitialized callback
1709+
const result = await createTestServer({
1710+
sessionIdGenerator: () => randomUUID(),
1711+
onsessioninitialized: (sessionId: string) => {
1712+
capturedSessionId.push(sessionId);
1713+
},
1714+
});
1715+
1716+
const tempServer = result.server;
1717+
const tempUrl = result.baseUrl;
1718+
1719+
// Initialize to trigger the callback
1720+
const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);
1721+
const tempSessionId = initResponse.headers.get("mcp-session-id");
1722+
1723+
expect(capturedSessionId).toEqual([tempSessionId]);
1724+
1725+
// Clean up
1726+
tempServer.close();
1727+
});
1728+
1729+
it("should support async onsessionclosed callback", async () => {
1730+
const closureOrder: string[] = [];
1731+
1732+
// Create server with async onsessionclosed callback
1733+
const result = await createTestServer({
1734+
sessionIdGenerator: () => randomUUID(),
1735+
onsessionclosed: async (sessionId: string) => {
1736+
closureOrder.push('async-close-start');
1737+
// Simulate async operation
1738+
await new Promise(resolve => setTimeout(resolve, 10));
1739+
closureOrder.push('async-close-end');
1740+
closureOrder.push(sessionId);
1741+
},
1742+
});
1743+
1744+
const tempServer = result.server;
1745+
const tempUrl = result.baseUrl;
1746+
1747+
// Initialize to get a session ID
1748+
const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);
1749+
const tempSessionId = initResponse.headers.get("mcp-session-id");
1750+
expect(tempSessionId).toBeDefined();
1751+
1752+
// DELETE the session
1753+
const deleteResponse = await fetch(tempUrl, {
1754+
method: "DELETE",
1755+
headers: {
1756+
"mcp-session-id": tempSessionId || "",
1757+
"mcp-protocol-version": "2025-03-26",
1758+
},
1759+
});
1760+
1761+
expect(deleteResponse.status).toBe(200);
1762+
1763+
// Give time for async callback to complete
1764+
await new Promise(resolve => setTimeout(resolve, 50));
1765+
1766+
expect(closureOrder).toEqual(['async-close-start', 'async-close-end', tempSessionId]);
1767+
1768+
// Clean up
1769+
tempServer.close();
1770+
});
1771+
1772+
it("should propagate errors from async onsessioninitialized callback", async () => {
1773+
const consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation();
1774+
1775+
// Create server with async onsessioninitialized callback that throws
1776+
const result = await createTestServer({
1777+
sessionIdGenerator: () => randomUUID(),
1778+
onsessioninitialized: async (_sessionId: string) => {
1779+
throw new Error('Async initialization error');
1780+
},
1781+
});
1782+
1783+
const tempServer = result.server;
1784+
const tempUrl = result.baseUrl;
1785+
1786+
// Initialize should fail when callback throws
1787+
const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);
1788+
expect(initResponse.status).toBe(400);
1789+
1790+
// Clean up
1791+
consoleErrorSpy.mockRestore();
1792+
tempServer.close();
1793+
});
1794+
1795+
it("should propagate errors from async onsessionclosed callback", async () => {
1796+
const consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation();
1797+
1798+
// Create server with async onsessionclosed callback that throws
1799+
const result = await createTestServer({
1800+
sessionIdGenerator: () => randomUUID(),
1801+
onsessionclosed: async (_sessionId: string) => {
1802+
throw new Error('Async closure error');
1803+
},
1804+
});
1805+
1806+
const tempServer = result.server;
1807+
const tempUrl = result.baseUrl;
1808+
1809+
// Initialize to get a session ID
1810+
const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);
1811+
const tempSessionId = initResponse.headers.get("mcp-session-id");
1812+
1813+
// DELETE should fail when callback throws
1814+
const deleteResponse = await fetch(tempUrl, {
1815+
method: "DELETE",
1816+
headers: {
1817+
"mcp-session-id": tempSessionId || "",
1818+
"mcp-protocol-version": "2025-03-26",
1819+
},
1820+
});
1821+
1822+
expect(deleteResponse.status).toBe(500);
1823+
1824+
// Clean up
1825+
consoleErrorSpy.mockRestore();
1826+
tempServer.close();
1827+
});
1828+
1829+
it("should handle both async callbacks together", async () => {
1830+
const events: string[] = [];
1831+
1832+
// Create server with both async callbacks
1833+
const result = await createTestServer({
1834+
sessionIdGenerator: () => randomUUID(),
1835+
onsessioninitialized: async (sessionId: string) => {
1836+
await new Promise(resolve => setTimeout(resolve, 5));
1837+
events.push(`initialized:${sessionId}`);
1838+
},
1839+
onsessionclosed: async (sessionId: string) => {
1840+
await new Promise(resolve => setTimeout(resolve, 5));
1841+
events.push(`closed:${sessionId}`);
1842+
},
1843+
});
1844+
1845+
const tempServer = result.server;
1846+
const tempUrl = result.baseUrl;
1847+
1848+
// Initialize to trigger first callback
1849+
const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize);
1850+
const tempSessionId = initResponse.headers.get("mcp-session-id");
1851+
1852+
// Wait for async callback
1853+
await new Promise(resolve => setTimeout(resolve, 20));
1854+
1855+
expect(events).toContain(`initialized:${tempSessionId}`);
1856+
1857+
// DELETE to trigger second callback
1858+
const deleteResponse = await fetch(tempUrl, {
1859+
method: "DELETE",
1860+
headers: {
1861+
"mcp-session-id": tempSessionId || "",
1862+
"mcp-protocol-version": "2025-03-26",
1863+
},
1864+
});
1865+
1866+
expect(deleteResponse.status).toBe(200);
1867+
1868+
// Wait for async callback
1869+
await new Promise(resolve => setTimeout(resolve, 20));
1870+
1871+
expect(events).toContain(`closed:${tempSessionId}`);
1872+
expect(events).toHaveLength(2);
1873+
1874+
// Clean up
1875+
tempServer.close();
1876+
});
1877+
});
1878+
16691879
// Test DNS rebinding protection
16701880
describe("StreamableHTTPServerTransport DNS rebinding protection", () => {
16711881
let server: Server;

src/server/streamableHttp.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ export interface StreamableHTTPServerTransportOptions {
4747
* and need to keep track of them.
4848
* @param sessionId The generated session ID
4949
*/
50-
onsessioninitialized?: (sessionId: string) => void;
50+
onsessioninitialized?: (sessionId: string) => void | Promise<void>;
5151

5252
/**
5353
* A callback for session close events
@@ -59,7 +59,7 @@ export interface StreamableHTTPServerTransportOptions {
5959
* session open/running.
6060
* @param sessionId The session ID that was closed
6161
*/
62-
onsessionclosed?: (sessionId: string) => void;
62+
onsessionclosed?: (sessionId: string) => void | Promise<void>;
6363

6464
/**
6565
* If true, the server will return JSON responses instead of starting an SSE stream.
@@ -138,8 +138,8 @@ export class StreamableHTTPServerTransport implements Transport {
138138
private _enableJsonResponse: boolean = false;
139139
private _standaloneSseStreamId: string = '_GET_stream';
140140
private _eventStore?: EventStore;
141-
private _onsessioninitialized?: (sessionId: string) => void;
142-
private _onsessionclosed?: (sessionId: string) => void;
141+
private _onsessioninitialized?: (sessionId: string) => void | Promise<void>;
142+
private _onsessionclosed?: (sessionId: string) => void | Promise<void>;
143143
private _allowedHosts?: string[];
144144
private _allowedOrigins?: string[];
145145
private _enableDnsRebindingProtection: boolean;
@@ -460,7 +460,7 @@ export class StreamableHTTPServerTransport implements Transport {
460460
// If we have a session ID and an onsessioninitialized handler, call it immediately
461461
// This is needed in cases where the server needs to keep track of multiple sessions
462462
if (this.sessionId && this._onsessioninitialized) {
463-
this._onsessioninitialized(this.sessionId);
463+
await Promise.resolve(this._onsessioninitialized(this.sessionId));
464464
}
465465

466466
}
@@ -552,7 +552,7 @@ export class StreamableHTTPServerTransport implements Transport {
552552
if (!this.validateProtocolVersion(req, res)) {
553553
return;
554554
}
555-
this._onsessionclosed?.(this.sessionId!);
555+
await Promise.resolve(this._onsessionclosed?.(this.sessionId!));
556556
await this.close();
557557
res.writeHead(200).end();
558558
}

0 commit comments

Comments
 (0)