Skip to content

Commit 0b0ea5b

Browse files
authored
Merge pull request #721 from cliffhall/allow-custom-fetch
Allow custom fetch
2 parents bf4c5be + 9766edc commit 0b0ea5b

File tree

5 files changed

+94
-17
lines changed

5 files changed

+94
-17
lines changed

src/client/sse.test.ts

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,38 @@ describe("SSEClientTransport", () => {
262262
expect(lastServerRequest.headers.authorization).toBe(authToken);
263263
});
264264

265+
it("uses custom fetch implementation from options", async () => {
266+
const authToken = "Bearer custom-token";
267+
268+
const fetchWithAuth = jest.fn((url: string | URL, init?: RequestInit) => {
269+
const headers = new Headers(init?.headers);
270+
headers.set("Authorization", authToken);
271+
return fetch(url.toString(), { ...init, headers });
272+
});
273+
274+
transport = new SSEClientTransport(resourceBaseUrl, {
275+
fetch: fetchWithAuth,
276+
});
277+
278+
await transport.start();
279+
280+
expect(lastServerRequest.headers.authorization).toBe(authToken);
281+
282+
// Send a message to verify fetchWithAuth used for POST as well
283+
const message: JSONRPCMessage = {
284+
jsonrpc: "2.0",
285+
id: "1",
286+
method: "test",
287+
params: {},
288+
};
289+
290+
await transport.send(message);
291+
292+
expect(fetchWithAuth).toHaveBeenCalledTimes(2);
293+
expect(lastServerRequest.method).toBe("POST");
294+
expect(lastServerRequest.headers.authorization).toBe(authToken);
295+
});
296+
265297
it("passes custom headers to fetch requests", async () => {
266298
const customHeaders = {
267299
Authorization: "Bearer test-token",

src/client/sse.ts

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource";
2-
import { Transport } from "../shared/transport.js";
2+
import { Transport, FetchLike } from "../shared/transport.js";
33
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
44
import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js";
55

@@ -47,6 +47,11 @@ export type SSEClientTransportOptions = {
4747
* Customizes recurring POST requests to the server.
4848
*/
4949
requestInit?: RequestInit;
50+
51+
/**
52+
* Custom fetch implementation used for all network requests.
53+
*/
54+
fetch?: FetchLike;
5055
};
5156

5257
/**
@@ -62,6 +67,7 @@ export class SSEClientTransport implements Transport {
6267
private _eventSourceInit?: EventSourceInit;
6368
private _requestInit?: RequestInit;
6469
private _authProvider?: OAuthClientProvider;
70+
private _fetch?: FetchLike;
6571
private _protocolVersion?: string;
6672

6773
onclose?: () => void;
@@ -77,6 +83,7 @@ export class SSEClientTransport implements Transport {
7783
this._eventSourceInit = opts?.eventSourceInit;
7884
this._requestInit = opts?.requestInit;
7985
this._authProvider = opts?.authProvider;
86+
this._fetch = opts?.fetch;
8087
}
8188

8289
private async _authThenStart(): Promise<void> {
@@ -117,7 +124,7 @@ export class SSEClientTransport implements Transport {
117124
}
118125

119126
private _startOrAuth(): Promise<void> {
120-
const fetchImpl = (this?._eventSourceInit?.fetch || fetch) as typeof fetch
127+
const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch
121128
return new Promise((resolve, reject) => {
122129
this._eventSource = new EventSource(
123130
this._url.href,
@@ -242,7 +249,7 @@ export class SSEClientTransport implements Transport {
242249
signal: this._abortController?.signal,
243250
};
244251

245-
const response = await fetch(this._endpoint, init);
252+
const response = await (this._fetch ?? fetch)(this._endpoint, init);
246253
if (!response.ok) {
247254
if (response.status === 401 && this._authProvider) {
248255

src/client/streamableHttp.test.ts

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions } from "./streamableHttp.js";
1+
import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions, StartSSEOptions } from "./streamableHttp.js";
22
import { OAuthClientProvider, UnauthorizedError } from "./auth.js";
33
import { JSONRPCMessage } from "../types.js";
44

@@ -443,6 +443,35 @@ describe("StreamableHTTPClientTransport", () => {
443443
expect(errorSpy).toHaveBeenCalled();
444444
});
445445

446+
it("uses custom fetch implementation", async () => {
447+
const authToken = "Bearer custom-token";
448+
449+
const fetchWithAuth = jest.fn((url: string | URL, init?: RequestInit) => {
450+
const headers = new Headers(init?.headers);
451+
headers.set("Authorization", authToken);
452+
return (global.fetch as jest.Mock)(url, { ...init, headers });
453+
});
454+
455+
(global.fetch as jest.Mock)
456+
.mockResolvedValueOnce(
457+
new Response(null, { status: 200, headers: { "content-type": "text/event-stream" } })
458+
)
459+
.mockResolvedValueOnce(new Response(null, { status: 202 }));
460+
461+
transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { fetch: fetchWithAuth });
462+
463+
await transport.start();
464+
await (transport as unknown as { _startOrAuthSse: (opts: StartSSEOptions) => Promise<void> })._startOrAuthSse({});
465+
466+
await transport.send({ jsonrpc: "2.0", method: "test", params: {}, id: "1" } as JSONRPCMessage);
467+
468+
expect(fetchWithAuth).toHaveBeenCalled();
469+
for (const call of (global.fetch as jest.Mock).mock.calls) {
470+
const headers = call[1].headers as Headers;
471+
expect(headers.get("Authorization")).toBe(authToken);
472+
}
473+
});
474+
446475

447476
it("should always send specified custom headers", async () => {
448477
const requestInit = {
@@ -530,7 +559,7 @@ describe("StreamableHTTPClientTransport", () => {
530559
// Second retry - should double (2^1 * 100 = 200)
531560
expect(getDelay(1)).toBe(200);
532561

533-
// Third retry - should double again (2^2 * 100 = 400)
562+
// Third retry - should double again (2^2 * 100 = 400)
534563
expect(getDelay(2)).toBe(400);
535564

536565
// Fourth retry - should double again (2^3 * 100 = 800)

src/client/streamableHttp.ts

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { Transport } from "../shared/transport.js";
1+
import { Transport, FetchLike } from "../shared/transport.js";
22
import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
33
import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js";
44
import { EventSourceParserStream } from "eventsource-parser/stream";
@@ -23,7 +23,7 @@ export class StreamableHTTPError extends Error {
2323
/**
2424
* Options for starting or authenticating an SSE connection
2525
*/
26-
interface StartSSEOptions {
26+
export interface StartSSEOptions {
2727
/**
2828
* The resumption token used to continue long-running requests that were interrupted.
2929
*
@@ -99,6 +99,11 @@ export type StreamableHTTPClientTransportOptions = {
9999
*/
100100
requestInit?: RequestInit;
101101

102+
/**
103+
* Custom fetch implementation used for all network requests.
104+
*/
105+
fetch?: FetchLike;
106+
102107
/**
103108
* Options to configure the reconnection behavior.
104109
*/
@@ -122,6 +127,7 @@ export class StreamableHTTPClientTransport implements Transport {
122127
private _resourceMetadataUrl?: URL;
123128
private _requestInit?: RequestInit;
124129
private _authProvider?: OAuthClientProvider;
130+
private _fetch?: FetchLike;
125131
private _sessionId?: string;
126132
private _reconnectionOptions: StreamableHTTPReconnectionOptions;
127133
private _protocolVersion?: string;
@@ -138,6 +144,7 @@ export class StreamableHTTPClientTransport implements Transport {
138144
this._resourceMetadataUrl = undefined;
139145
this._requestInit = opts?.requestInit;
140146
this._authProvider = opts?.authProvider;
147+
this._fetch = opts?.fetch;
141148
this._sessionId = opts?.sessionId;
142149
this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS;
143150
}
@@ -200,7 +207,7 @@ export class StreamableHTTPClientTransport implements Transport {
200207
headers.set("last-event-id", resumptionToken);
201208
}
202209

203-
const response = await fetch(this._url, {
210+
const response = await (this._fetch ?? fetch)(this._url, {
204211
method: "GET",
205212
headers,
206213
signal: this._abortController?.signal,
@@ -251,15 +258,15 @@ export class StreamableHTTPClientTransport implements Transport {
251258

252259
private _normalizeHeaders(headers: HeadersInit | undefined): Record<string, string> {
253260
if (!headers) return {};
254-
261+
255262
if (headers instanceof Headers) {
256263
return Object.fromEntries(headers.entries());
257264
}
258-
265+
259266
if (Array.isArray(headers)) {
260267
return Object.fromEntries(headers);
261268
}
262-
269+
263270
return { ...headers as Record<string, string> };
264271
}
265272

@@ -414,7 +421,7 @@ export class StreamableHTTPClientTransport implements Transport {
414421
signal: this._abortController?.signal,
415422
};
416423

417-
const response = await fetch(this._url, init);
424+
const response = await (this._fetch ?? fetch)(this._url, init);
418425

419426
// Handle session ID received during initialization
420427
const sessionId = response.headers.get("mcp-session-id");
@@ -520,7 +527,7 @@ export class StreamableHTTPClientTransport implements Transport {
520527
signal: this._abortController?.signal,
521528
};
522529

523-
const response = await fetch(this._url, init);
530+
const response = await (this._fetch ?? fetch)(this._url, init);
524531

525532
// We specifically handle 405 as a valid response according to the spec,
526533
// meaning the server does not support explicit session termination

src/shared/transport.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import { JSONRPCMessage, MessageExtraInfo, RequestId } from "../types.js";
22

3+
export type FetchLike = (url: string | URL, init?: RequestInit) => Promise<Response>;
4+
35
/**
46
* Options for sending a JSON-RPC message.
57
*/
68
export type TransportSendOptions = {
7-
/**
9+
/**
810
* If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with.
911
*/
1012
relatedRequestId?: RequestId;
@@ -38,7 +40,7 @@ export interface Transport {
3840

3941
/**
4042
* Sends a JSON-RPC message (request or response).
41-
*
43+
*
4244
* If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with.
4345
*/
4446
send(message: JSONRPCMessage, options?: TransportSendOptions): Promise<void>;
@@ -64,9 +66,9 @@ export interface Transport {
6466

6567
/**
6668
* Callback for when a message (request or response) is received over the connection.
67-
*
69+
*
6870
* Includes the requestInfo and authInfo if the transport is authenticated.
69-
*
71+
*
7072
* The requestInfo can be used to get the original request information (headers, etc.)
7173
*/
7274
onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void;

0 commit comments

Comments
 (0)