diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index ce0cc708..b0ea8d1e 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -347,6 +347,35 @@ describe("OAuth Authorization", () => { const [url] = calls[0]; expect(url.toString()).toBe("https://custom.example.com/metadata"); }); + + it("supports overriding the fetch function used for requests", async () => { + const validMetadata = { + resource: "https://resource.example.com", + authorization_servers: ["https://auth.example.com"], + }; + + const customFetch = jest.fn().mockResolvedValue({ + ok: true, + status: 200, + json: async () => validMetadata, + }); + + const metadata = await discoverOAuthProtectedResourceMetadata( + "https://resource.example.com", + undefined, + customFetch + ); + + expect(metadata).toEqual(validMetadata); + expect(customFetch).toHaveBeenCalledTimes(1); + expect(mockFetch).not.toHaveBeenCalled(); + + const [url, options] = customFetch.mock.calls[0]; + expect(url.toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource"); + expect(options.headers).toEqual({ + "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION + }); + }); }); describe("discoverOAuthMetadata", () => { @@ -619,6 +648,39 @@ describe("OAuth Authorization", () => { discoverOAuthMetadata("https://auth.example.com") ).rejects.toThrow(); }); + + it("supports overriding the fetch function used for requests", async () => { + const validMetadata = { + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + registration_endpoint: "https://auth.example.com/register", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }; + + const customFetch = jest.fn().mockResolvedValue({ + ok: true, + status: 200, + json: async () => validMetadata, + }); + + const metadata = await discoverOAuthMetadata( + "https://auth.example.com", + {}, + customFetch + ); + + expect(metadata).toEqual(validMetadata); + expect(customFetch).toHaveBeenCalledTimes(1); + expect(mockFetch).not.toHaveBeenCalled(); + + const [url, options] = customFetch.mock.calls[0]; + expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); + expect(options.headers).toEqual({ + "MCP-Protocol-Version": LATEST_PROTOCOL_VERSION + }); + }); }); describe("startAuthorization", () => { @@ -917,6 +979,46 @@ describe("OAuth Authorization", () => { }) ).rejects.toThrow("Token exchange failed"); }); + + it("supports overriding the fetch function used for requests", async () => { + const customFetch = jest.fn().mockResolvedValue({ + ok: true, + status: 200, + json: async () => validTokens, + }); + + const tokens = await exchangeAuthorization("https://auth.example.com", { + clientInformation: validClientInfo, + authorizationCode: "code123", + codeVerifier: "verifier123", + redirectUri: "http://localhost:3000/callback", + resource: new URL("https://api.example.com/mcp-server"), + fetchFn: customFetch, + }); + + expect(tokens).toEqual(validTokens); + expect(customFetch).toHaveBeenCalledTimes(1); + expect(mockFetch).not.toHaveBeenCalled(); + + const [url, options] = customFetch.mock.calls[0]; + expect(url.toString()).toBe("https://auth.example.com/token"); + expect(options).toEqual( + expect.objectContaining({ + method: "POST", + headers: expect.any(Headers), + body: expect.any(URLSearchParams), + }) + ); + + const body = options.body as URLSearchParams; + expect(body.get("grant_type")).toBe("authorization_code"); + expect(body.get("code")).toBe("code123"); + expect(body.get("code_verifier")).toBe("verifier123"); + expect(body.get("client_id")).toBe("client123"); + expect(body.get("client_secret")).toBe("secret123"); + expect(body.get("redirect_uri")).toBe("http://localhost:3000/callback"); + expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); + }); }); describe("refreshAuthorization", () => { @@ -1824,6 +1926,68 @@ describe("OAuth Authorization", () => { // Second call should be to AS metadata with the path from authorization server expect(calls[1][0].toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server/oauth"); }); + + it("supports overriding the fetch function used for requests", async () => { + const customFetch = jest.fn(); + + // Mock PRM discovery + customFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + resource: "https://resource.example.com", + authorization_servers: ["https://auth.example.com"], + }), + }); + + // Mock AS metadata discovery + customFetch.mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + registration_endpoint: "https://auth.example.com/register", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + + const mockProvider: OAuthClientProvider = { + get redirectUrl() { return "http://localhost:3000/callback"; }, + get clientMetadata() { + return { + client_name: "Test Client", + redirect_uris: ["http://localhost:3000/callback"], + }; + }, + clientInformation: jest.fn().mockResolvedValue({ + client_id: "client123", + client_secret: "secret123", + }), + tokens: jest.fn().mockResolvedValue(undefined), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn().mockResolvedValue("verifier123"), + }; + + const result = await auth(mockProvider, { + serverUrl: "https://resource.example.com", + fetchFn: customFetch, + }); + + expect(result).toBe("REDIRECT"); + expect(customFetch).toHaveBeenCalledTimes(2); + expect(mockFetch).not.toHaveBeenCalled(); + + // Verify custom fetch was called for PRM discovery + expect(customFetch.mock.calls[0][0].toString()).toBe("https://resource.example.com/.well-known/oauth-protected-resource"); + + // Verify custom fetch was called for AS metadata discovery + expect(customFetch.mock.calls[1][0].toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server"); + }); }); describe("exchangeAuthorization with multiple client authentication methods", () => { diff --git a/src/client/auth.ts b/src/client/auth.ts index 4a8bbe2d..b5a3a6a4 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -19,6 +19,7 @@ import { ServerError, UnauthorizedClientError } from "../server/auth/errors.js"; +import { FetchLike } from "../shared/transport.js"; /** * Implements an end-to-end OAuth client to be used with one MCP server. @@ -281,8 +282,9 @@ export async function auth( serverUrl: string | URL; authorizationCode?: string; scope?: string; - resourceMetadataUrl?: URL }): Promise { - + resourceMetadataUrl?: URL; + fetchFn?: FetchLike; +}): Promise { try { return await authInternal(provider, options); } catch (error) { @@ -305,18 +307,21 @@ async function authInternal( { serverUrl, authorizationCode, scope, - resourceMetadataUrl + resourceMetadataUrl, + fetchFn, }: { serverUrl: string | URL; authorizationCode?: string; scope?: string; - resourceMetadataUrl?: URL - }): Promise { + resourceMetadataUrl?: URL; + fetchFn?: FetchLike; + }, +): Promise { let resourceMetadata: OAuthProtectedResourceMetadata | undefined; let authorizationServerUrl = serverUrl; try { - resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, { resourceMetadataUrl }); + resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, { resourceMetadataUrl }, fetchFn); if (resourceMetadata.authorization_servers && resourceMetadata.authorization_servers.length > 0) { authorizationServerUrl = resourceMetadata.authorization_servers[0]; } @@ -328,7 +333,7 @@ async function authInternal( const metadata = await discoverOAuthMetadata(serverUrl, { authorizationServerUrl - }); + }, fetchFn); // Handle client registration if needed let clientInformation = await Promise.resolve(provider.clientInformation()); @@ -361,6 +366,7 @@ async function authInternal( redirectUri: provider.redirectUrl, resource, addClientAuthentication: provider.addClientAuthentication, + fetchFn: fetchFn, }); await provider.saveTokens(tokens); @@ -469,10 +475,12 @@ export function extractResourceMetadataUrl(res: Response): URL | undefined { export async function discoverOAuthProtectedResourceMetadata( serverUrl: string | URL, opts?: { protocolVersion?: string, resourceMetadataUrl?: string | URL }, + fetchFn: FetchLike = fetch, ): Promise { const response = await discoverMetadataWithFallback( serverUrl, 'oauth-protected-resource', + fetchFn, { protocolVersion: opts?.protocolVersion, metadataUrl: opts?.resourceMetadataUrl, @@ -497,14 +505,15 @@ export async function discoverOAuthProtectedResourceMetadata( async function fetchWithCorsRetry( url: URL, headers?: Record, + fetchFn: FetchLike = fetch, ): Promise { try { - return await fetch(url, { headers }); + return await fetchFn(url, { headers }); } catch (error) { if (error instanceof TypeError) { if (headers) { // CORS errors come back as TypeError, retry without headers - return fetchWithCorsRetry(url) + return fetchWithCorsRetry(url, undefined, fetchFn) } else { // We're getting CORS errors on retry too, return undefined return undefined @@ -532,11 +541,12 @@ function buildWellKnownPath(wellKnownPrefix: string, pathname: string): string { async function tryMetadataDiscovery( url: URL, protocolVersion: string, + fetchFn: FetchLike = fetch, ): Promise { const headers = { "MCP-Protocol-Version": protocolVersion }; - return await fetchWithCorsRetry(url, headers); + return await fetchWithCorsRetry(url, headers, fetchFn); } /** @@ -552,6 +562,7 @@ function shouldAttemptFallback(response: Response | undefined, pathname: string) async function discoverMetadataWithFallback( serverUrl: string | URL, wellKnownType: 'oauth-authorization-server' | 'oauth-protected-resource', + fetchFn: FetchLike, opts?: { protocolVersion?: string; metadataUrl?: string | URL, metadataServerUrl?: string | URL }, ): Promise { const issuer = new URL(serverUrl); @@ -567,12 +578,12 @@ async function discoverMetadataWithFallback( url.search = issuer.search; } - let response = await tryMetadataDiscovery(url, protocolVersion); + let response = await tryMetadataDiscovery(url, protocolVersion, fetchFn); // If path-aware discovery fails with 404 and we're not already at root, try fallback to root discovery if (!opts?.metadataUrl && shouldAttemptFallback(response, issuer.pathname)) { const rootUrl = new URL(`/.well-known/${wellKnownType}`, issuer); - response = await tryMetadataDiscovery(rootUrl, protocolVersion); + response = await tryMetadataDiscovery(rootUrl, protocolVersion, fetchFn); } return response; @@ -593,6 +604,7 @@ export async function discoverOAuthMetadata( authorizationServerUrl?: string | URL, protocolVersion?: string, } = {}, + fetchFn: FetchLike = fetch, ): Promise { if (typeof issuer === 'string') { issuer = new URL(issuer); @@ -608,6 +620,7 @@ export async function discoverOAuthMetadata( const response = await discoverMetadataWithFallback( authorizationServerUrl, 'oauth-authorization-server', + fetchFn, { protocolVersion, metadataServerUrl: authorizationServerUrl, @@ -730,7 +743,8 @@ export async function exchangeAuthorization( codeVerifier, redirectUri, resource, - addClientAuthentication + addClientAuthentication, + fetchFn, }: { metadata?: OAuthMetadata; clientInformation: OAuthClientInformation; @@ -739,6 +753,7 @@ export async function exchangeAuthorization( redirectUri: string | URL; resource?: URL; addClientAuthentication?: OAuthClientProvider["addClientAuthentication"]; + fetchFn?: FetchLike; }, ): Promise { const grantType = "authorization_code"; @@ -781,7 +796,7 @@ export async function exchangeAuthorization( params.set("resource", resource.href); } - const response = await fetch(tokenUrl, { + const response = await (fetchFn ?? fetch)(tokenUrl, { method: "POST", headers, body: params, @@ -814,12 +829,14 @@ export async function refreshAuthorization( refreshToken, resource, addClientAuthentication, + fetchFn, }: { metadata?: OAuthMetadata; clientInformation: OAuthClientInformation; refreshToken: string; resource?: URL; addClientAuthentication?: OAuthClientProvider["addClientAuthentication"]; + fetchFn?: FetchLike; } ): Promise { const grantType = "refresh_token"; @@ -863,7 +880,7 @@ export async function refreshAuthorization( params.set("resource", resource.href); } - const response = await fetch(tokenUrl, { + const response = await (fetchFn ?? fetch)(tokenUrl, { method: "POST", headers, body: params, @@ -883,9 +900,11 @@ export async function registerClient( { metadata, clientMetadata, + fetchFn, }: { metadata?: OAuthMetadata; clientMetadata: OAuthClientMetadata; + fetchFn?: FetchLike; }, ): Promise { let registrationUrl: URL; @@ -900,7 +919,7 @@ export async function registerClient( registrationUrl = new URL("/register", authorizationServerUrl); } - const response = await fetch(registrationUrl, { + const response = await (fetchFn ?? fetch)(registrationUrl, { method: "POST", headers: { "Content-Type": "application/json", diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 2cc4a1dd..24bfe094 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -1,4 +1,4 @@ -import { createServer, type IncomingMessage, type Server } from "http"; +import { createServer, ServerResponse, type IncomingMessage, type Server } from "http"; import { AddressInfo } from "net"; import { JSONRPCMessage } from "../types.js"; import { SSEClientTransport } from "./sse.js"; @@ -1108,4 +1108,337 @@ describe("SSEClientTransport", () => { expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('tokens'); }); }); + + describe("custom fetch in auth code paths", () => { + let customFetch: jest.MockedFunction; + let globalFetchSpy: jest.SpyInstance; + let mockAuthProvider: jest.Mocked; + let resourceServerHandler: jest.Mock & { + req: IncomingMessage; + }], void>; + + /** + * Helper function to create a mock auth provider with configurable behavior + */ + const createMockAuthProvider = (config: { + hasTokens?: boolean; + tokensExpired?: boolean; + hasRefreshToken?: boolean; + clientRegistered?: boolean; + authorizationCode?: string; + } = {}): jest.Mocked => { + const tokens = config.hasTokens ? { + access_token: config.tokensExpired ? "expired-token" : "valid-token", + token_type: "Bearer" as const, + ...(config.hasRefreshToken && { refresh_token: "refresh-token" }) + } : undefined; + + const clientInfo = config.clientRegistered ? { + client_id: "test-client-id", + client_secret: "test-client-secret" + } : undefined; + + return { + get redirectUrl() { return "http://localhost/callback"; }, + get clientMetadata() { + return { + redirect_uris: ["http://localhost/callback"], + client_name: "Test Client" + }; + }, + clientInformation: jest.fn().mockResolvedValue(clientInfo), + tokens: jest.fn().mockResolvedValue(tokens), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn().mockResolvedValue("test-verifier"), + invalidateCredentials: jest.fn(), + }; + }; + + const createCustomFetchMockAuthServer = async () => { + authServer = createServer((req, res) => { + if (req.url === "/.well-known/oauth-authorization-server") { + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(JSON.stringify({ + issuer: `http://127.0.0.1:${(authServer.address() as AddressInfo).port}`, + authorization_endpoint: `http://127.0.0.1:${(authServer.address() as AddressInfo).port}/authorize`, + token_endpoint: `http://127.0.0.1:${(authServer.address() as AddressInfo).port}/token`, + registration_endpoint: `http://127.0.0.1:${(authServer.address() as AddressInfo).port}/register`, + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + })); + return; + } + + if (req.url === "/token" && req.method === "POST") { + // Handle token exchange request + let body = ""; + req.on("data", chunk => { body += chunk; }); + req.on("end", () => { + const params = new URLSearchParams(body); + if (params.get("grant_type") === "authorization_code" && + params.get("code") === "test-auth-code" && + params.get("client_id") === "test-client-id") { + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(JSON.stringify({ + access_token: "new-access-token", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "new-refresh-token" + })); + } else { + res.writeHead(400).end(); + } + }); + return; + } + + res.writeHead(404).end(); + }); + + // Start auth server on random port + await new Promise(resolve => { + authServer.listen(0, "127.0.0.1", () => { + const addr = authServer.address() as AddressInfo; + authBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + }; + + const createCustomFetchMockResourceServer = async () => { + // Set up resource server that provides OAuth metadata + resourceServer = createServer((req, res) => { + lastServerRequest = req; + + if (req.url === "/.well-known/oauth-protected-resource") { + res.writeHead(200, { "Content-Type": "application/json" }); + res.end(JSON.stringify({ + resource: resourceBaseUrl.href, + authorization_servers: [authBaseUrl.href], + })); + return; + } + + resourceServerHandler(req, res); + }); + + // Start resource server on random port + await new Promise(resolve => { + resourceServer.listen(0, "127.0.0.1", () => { + const addr = resourceServer.address() as AddressInfo; + resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + }; + + beforeEach(async () => { + // Close existing servers to set up custom auth flow servers + resourceServer.close(); + authServer.close(); + + const originalFetch = fetch; + + // Create custom fetch spy that delegates to real fetch + customFetch = jest.fn((url, init) => { + return originalFetch(url.toString(), init); + }); + + // Spy on global fetch to detect unauthorized usage + globalFetchSpy = jest.spyOn(global, 'fetch'); + + // Create mock auth provider with default configuration + mockAuthProvider = createMockAuthProvider({ + hasTokens: false, + clientRegistered: true + }); + + // Set up auth server that handles OAuth discovery and token requests + await createCustomFetchMockAuthServer(); + + // Set up resource server + resourceServerHandler = jest.fn((_req: IncomingMessage, res: ServerResponse & { + req: IncomingMessage; + }) => { + res.writeHead(404).end(); + }); + await createCustomFetchMockResourceServer(); + }); + + afterEach(() => { + globalFetchSpy.mockRestore(); + }); + + it("uses custom fetch during auth flow on SSE connection 401 - no global fetch fallback", async () => { + // Set up resource server that returns 401 on SSE connection and provides OAuth metadata + resourceServerHandler.mockImplementation((req, res) => { + if (req.url === "/") { + // Return 401 to trigger auth flow + res.writeHead(401, { + "WWW-Authenticate": `Bearer realm="mcp", resource_metadata="${resourceBaseUrl.href}.well-known/oauth-protected-resource"` + }); + res.end(); + return; + } + + res.writeHead(404).end(); + }); + + // Create transport with custom fetch and auth provider + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider, + fetch: customFetch, + }); + + // Attempt to start - should trigger auth flow and eventually fail with UnauthorizedError + await expect(transport.start()).rejects.toThrow(UnauthorizedError); + + // Verify custom fetch was used + expect(customFetch).toHaveBeenCalled(); + + // Verify specific OAuth endpoints were called with custom fetch + const customFetchCalls = customFetch.mock.calls; + const callUrls = customFetchCalls.map(([url]) => url.toString()); + + // Should have called resource metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); + + // Should have called OAuth authorization server metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); + + // Verify auth provider was called to redirect to authorization + expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); + + // Global fetch should never have been called + expect(globalFetchSpy).not.toHaveBeenCalled(); + }); + + it("uses custom fetch during auth flow on POST request 401 - no global fetch fallback", async () => { + // Set up resource server that accepts SSE connection but returns 401 on POST + resourceServerHandler.mockImplementation((req, res) => { + switch (req.method) { + case "GET": + if (req.url === "/") { + // Accept SSE connection + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache, no-transform", + Connection: "keep-alive", + }); + res.write("event: endpoint\n"); + res.write(`data: ${resourceBaseUrl.href}\n\n`); + return; + } + break; + + case "POST": + if (req.url === "/") { + // Return 401 to trigger auth retry + res.writeHead(401, { + "WWW-Authenticate": `Bearer realm="mcp", resource_metadata="${resourceBaseUrl.href}.well-known/oauth-protected-resource"` + }); + res.end(); + return; + } + break; + } + + res.writeHead(404).end(); + }); + + // Create transport with custom fetch and auth provider + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider, + fetch: customFetch, + }); + + // Start the transport (should succeed) + await transport.start(); + + // Send a message that should trigger 401 and auth retry + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: "1", + method: "test", + params: {}, + }; + + // Attempt to send message - should trigger auth flow and eventually fail + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + + // Verify custom fetch was used + expect(customFetch).toHaveBeenCalled(); + + // Verify specific OAuth endpoints were called with custom fetch + const customFetchCalls = customFetch.mock.calls; + const callUrls = customFetchCalls.map(([url]) => url.toString()); + + // Should have called resource metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); + + // Should have called OAuth authorization server metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); + + // Should have attempted the POST request that triggered the 401 + const postCalls = customFetchCalls.filter(([url, options]) => + url.toString() === resourceBaseUrl.href && options?.method === "POST" + ); + expect(postCalls.length).toBeGreaterThan(0); + + // Verify auth provider was called to redirect to authorization + expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); + + // Global fetch should never have been called + expect(globalFetchSpy).not.toHaveBeenCalled(); + }); + + it("uses custom fetch in finishAuth method - no global fetch fallback", async () => { + // Create mock auth provider that expects to save tokens + const authProviderWithCode = createMockAuthProvider({ + clientRegistered: true, + authorizationCode: "test-auth-code" + }); + + // Create transport with custom fetch and auth provider + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: authProviderWithCode, + fetch: customFetch, + }); + + // Call finishAuth with authorization code + await transport.finishAuth("test-auth-code"); + + // Verify custom fetch was used + expect(customFetch).toHaveBeenCalled(); + + // Verify specific OAuth endpoints were called with custom fetch + const customFetchCalls = customFetch.mock.calls; + const callUrls = customFetchCalls.map(([url]) => url.toString()); + + // Should have called resource metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); + + // Should have called OAuth authorization server metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); + + // Should have called token endpoint for authorization code exchange + const tokenCalls = customFetchCalls.filter(([url, options]) => + url.toString().includes('/token') && options?.method === "POST" + ); + expect(tokenCalls.length).toBeGreaterThan(0); + + // Verify tokens were saved + expect(authProviderWithCode.saveTokens).toHaveBeenCalledWith({ + access_token: "new-access-token", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "new-refresh-token" + }); + + // Global fetch should never have been called + expect(globalFetchSpy).not.toHaveBeenCalled(); + }); + }); }); diff --git a/src/client/sse.ts b/src/client/sse.ts index 568a5159..e1c86ccd 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -93,7 +93,7 @@ export class SSEClientTransport implements Transport { let result: AuthResult; try { - result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); + result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); } catch (error) { this.onerror?.(error as Error); throw error; @@ -218,7 +218,7 @@ export class SSEClientTransport implements Transport { throw new UnauthorizedError("No auth provider"); } - const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl }); + const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); if (result !== "AUTHORIZED") { throw new UnauthorizedError("Failed to authorize"); } @@ -246,13 +246,13 @@ export class SSEClientTransport implements Transport { signal: this._abortController?.signal, }; -const response = await (this._fetch ?? fetch)(this._endpoint, init); + const response = await (this._fetch ?? fetch)(this._endpoint, init); if (!response.ok) { if (response.status === 401 && this._authProvider) { this._resourceMetadataUrl = extractResourceMetadataUrl(response); - const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); + const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); if (result !== "AUTHORIZED") { throw new UnauthorizedError(); } diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index c54cf289..88fd4801 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -1,4 +1,4 @@ -import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions, StartSSEOptions } from "./streamableHttp.js"; +import { StartSSEOptions, StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions } from "./streamableHttp.js"; import { OAuthClientProvider, UnauthorizedError } from "./auth.js"; import { JSONRPCMessage, JSONRPCRequest } from "../types.js"; import { InvalidClientError, InvalidGrantError, UnauthorizedClientError } from "../server/auth/errors.js"; @@ -445,36 +445,31 @@ describe("StreamableHTTPClientTransport", () => { expect(errorSpy).toHaveBeenCalled(); }); - it("uses custom fetch implementation", async () => { - const authToken = "Bearer custom-token"; - - const fetchWithAuth = jest.fn((url: string | URL, init?: RequestInit) => { - const headers = new Headers(init?.headers); - headers.set("Authorization", authToken); - return (global.fetch as jest.Mock)(url, { ...init, headers }); - }); - - (global.fetch as jest.Mock) + it("uses custom fetch implementation if provided", async () => { + // Create custom fetch + const customFetch = jest.fn() .mockResolvedValueOnce( new Response(null, { status: 200, headers: { "content-type": "text/event-stream" } }) ) .mockResolvedValueOnce(new Response(null, { status: 202 })); - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { fetch: fetchWithAuth }); + // Create transport instance + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + fetch: customFetch + }); await transport.start(); await (transport as unknown as { _startOrAuthSse: (opts: StartSSEOptions) => Promise })._startOrAuthSse({}); await transport.send({ jsonrpc: "2.0", method: "test", params: {}, id: "1" } as JSONRPCMessage); - expect(fetchWithAuth).toHaveBeenCalled(); - for (const call of (global.fetch as jest.Mock).mock.calls) { - const headers = call[1].headers as Headers; - expect(headers.get("Authorization")).toBe(authToken); - } + // Verify custom fetch was used + expect(customFetch).toHaveBeenCalled(); + + // Global fetch should never have been called + expect(global.fetch).not.toHaveBeenCalled(); }); - it("should always send specified custom headers", async () => { const requestInit = { headers: { @@ -855,4 +850,149 @@ describe("StreamableHTTPClientTransport", () => { await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('tokens'); }); + + describe("custom fetch in auth code paths", () => { + it("uses custom fetch during auth flow on 401 - no global fetch fallback", async () => { + const unauthedResponse = { + ok: false, + status: 401, + statusText: "Unauthorized", + headers: new Headers() + }; + + // Create custom fetch + const customFetch = jest.fn() + // Initial connection + .mockResolvedValueOnce(unauthedResponse) + // Resource discovery + .mockResolvedValueOnce(unauthedResponse) + // OAuth metadata discovery + .mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + issuer: "http://localhost:1234", + authorization_endpoint: "http://localhost:1234/authorize", + token_endpoint: "http://localhost:1234/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }) + // Token refresh fails with InvalidClientError + .mockResolvedValueOnce(Response.json( + new InvalidClientError("Client authentication failed").toResponseObject(), + { status: 400 } + )) + // Fallback should fail to complete the flow + .mockResolvedValue({ + ok: false, + status: 404 + }); + + // Create transport instance + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + authProvider: mockAuthProvider, + fetch: customFetch + }); + + // Attempt to start - should trigger auth flow and eventually fail with UnauthorizedError + await transport.start(); + await expect((transport as unknown as { _startOrAuthSse: (opts: StartSSEOptions) => Promise })._startOrAuthSse({})).rejects.toThrow(UnauthorizedError); + + // Verify custom fetch was used + expect(customFetch).toHaveBeenCalled(); + + // Verify specific OAuth endpoints were called with custom fetch + const customFetchCalls = customFetch.mock.calls; + const callUrls = customFetchCalls.map(([url]) => url.toString()); + + // Should have called resource metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); + + // Should have called OAuth authorization server metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); + + // Verify auth provider was called to redirect to authorization + expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); + + // Global fetch should never have been called + expect(global.fetch).not.toHaveBeenCalled(); + }); + + it("uses custom fetch in finishAuth method - no global fetch fallback", async () => { + // Create custom fetch + const customFetch = jest.fn() + // Protected resource metadata discovery + .mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + authorization_servers: ["http://localhost:1234"], + resource: "http://localhost:1234/mcp" + }), + }) + // OAuth metadata discovery + .mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + issuer: "http://localhost:1234", + authorization_endpoint: "http://localhost:1234/authorize", + token_endpoint: "http://localhost:1234/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }) + // Code exchange + .mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + access_token: "new-access-token", + refresh_token: "new-refresh-token", + token_type: "Bearer", + expires_in: 3600, + }), + }); + + // Create transport instance + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + authProvider: mockAuthProvider, + fetch: customFetch + }); + + // Call finishAuth with authorization code + await transport.finishAuth("test-auth-code"); + + // Verify custom fetch was used + expect(customFetch).toHaveBeenCalled(); + + // Verify specific OAuth endpoints were called with custom fetch + const customFetchCalls = customFetch.mock.calls; + const callUrls = customFetchCalls.map(([url]) => url.toString()); + + // Should have called resource metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-protected-resource'))).toBe(true); + + // Should have called OAuth authorization server metadata discovery + expect(callUrls.some(url => url.includes('/.well-known/oauth-authorization-server'))).toBe(true); + + // Should have called token endpoint for authorization code exchange + const tokenCalls = customFetchCalls.filter(([url, options]) => + url.toString().includes('/token') && options?.method === "POST" + ); + expect(tokenCalls.length).toBeGreaterThan(0); + + // Verify tokens were saved + expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({ + access_token: "new-access-token", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "new-refresh-token" + }); + + // Global fetch should never have been called + expect(global.fetch).not.toHaveBeenCalled(); + }); + }); }); diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index b0894fce..77a15c92 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -156,7 +156,7 @@ export class StreamableHTTPClientTransport implements Transport { let result: AuthResult; try { - result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); + result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); } catch (error) { this.onerror?.(error as Error); throw error; @@ -392,7 +392,7 @@ const response = await (this._fetch ?? fetch)(this._url, { throw new UnauthorizedError("No auth provider"); } - const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl }); + const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); if (result !== "AUTHORIZED") { throw new UnauthorizedError("Failed to authorize"); } @@ -440,7 +440,7 @@ const response = await (this._fetch ?? fetch)(this._url, init); this._resourceMetadataUrl = extractResourceMetadataUrl(response); - const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); + const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl, fetchFn: this._fetch }); if (result !== "AUTHORIZED") { throw new UnauthorizedError(); } diff --git a/src/server/auth/providers/proxyProvider.ts b/src/server/auth/providers/proxyProvider.ts index de74862b..c66a8707 100644 --- a/src/server/auth/providers/proxyProvider.ts +++ b/src/server/auth/providers/proxyProvider.ts @@ -10,6 +10,7 @@ import { import { AuthInfo } from "../types.js"; import { AuthorizationParams, OAuthServerProvider } from "../provider.js"; import { ServerError } from "../errors.js"; +import { FetchLike } from "../../../shared/transport.js"; export type ProxyEndpoints = { authorizationUrl: string; @@ -34,6 +35,10 @@ export type ProxyOptions = { */ getClient: (clientId: string) => Promise; + /** + * Custom fetch implementation used for all network requests. + */ + fetch?: FetchLike; }; /** @@ -43,6 +48,7 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { protected readonly _endpoints: ProxyEndpoints; protected readonly _verifyAccessToken: (token: string) => Promise; protected readonly _getClient: (clientId: string) => Promise; + protected readonly _fetch?: FetchLike; skipLocalPkceValidation = true; @@ -55,6 +61,7 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { this._endpoints = options.endpoints; this._verifyAccessToken = options.verifyAccessToken; this._getClient = options.getClient; + this._fetch = options.fetch; if (options.endpoints?.revocationUrl) { this.revokeToken = async ( client: OAuthClientInformationFull, @@ -76,7 +83,7 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { params.set("token_type_hint", request.token_type_hint); } - const response = await fetch(revocationUrl, { + const response = await (this._fetch ?? fetch)(revocationUrl, { method: "POST", headers: { "Content-Type": "application/x-www-form-urlencoded", @@ -97,7 +104,7 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { getClient: this._getClient, ...(registrationUrl && { registerClient: async (client: OAuthClientInformationFull) => { - const response = await fetch(registrationUrl, { + const response = await (this._fetch ?? fetch)(registrationUrl, { method: "POST", headers: { "Content-Type": "application/json", @@ -178,7 +185,7 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { params.append("resource", resource.href); } - const response = await fetch(this._endpoints.tokenUrl, { + const response = await (this._fetch ?? fetch)(this._endpoints.tokenUrl, { method: "POST", headers: { "Content-Type": "application/x-www-form-urlencoded", @@ -220,7 +227,7 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { params.set("resource", resource.href); } - const response = await fetch(this._endpoints.tokenUrl, { + const response = await (this._fetch ?? fetch)(this._endpoints.tokenUrl, { method: "POST", headers: { "Content-Type": "application/x-www-form-urlencoded",