From 17210c875d0e047bf8d5b251ea0ebe704d59e20e Mon Sep 17 00:00:00 2001 From: Marcelo Paternostro Date: Wed, 2 Jul 2025 20:36:52 -0400 Subject: [PATCH] feature(auth): DelegatedAuthClientProvider An optional provider that can be passed to the SSE and StreamableHttp client transports in order to completely delegate the authentication to an external system. --- src/client/auth.ts | 36 +++++ src/client/sse.test.ts | 204 +++++++++++++++++++++++++++- src/client/sse.ts | 102 +++++++++++--- src/client/streamableHttp.test.ts | 212 +++++++++++++++++++++++++++++- src/client/streamableHttp.ts | 74 +++++++++-- 5 files changed, 598 insertions(+), 30 deletions(-) diff --git a/src/client/auth.ts b/src/client/auth.ts index 2b69a5d8..ca7e39f9 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -103,6 +103,42 @@ export interface OAuthClientProvider { validateResourceURL?(serverUrl: string | URL, resource?: string): Promise; } +/** + * A provider that delegates authentication to an external system. + * + * This interface allows for custom authentication mechanisms that are + * either already implemented on a specific platform or handled outside the + * standard OAuth flow, such as API keys, custom tokens, or integration with external + * authentication services. + */ +export interface DelegatedAuthClientProvider { + /** + * Returns authentication headers to be included in requests. + * + * These headers will be added to all HTTP requests made by the transport. + * Common examples include Authorization headers, API keys, or custom + * authentication tokens. + * + * @returns Headers to include in requests, or undefined if no authentication is available + */ + headers(): HeadersInit | undefined | Promise; + + /** + * Performs authentication when a 401 Unauthorized response is received. + * + * This method is called when the server responds with a 401 status code, + * indicating that the current authentication is invalid or expired. + * The implementation should attempt to refresh or re-establish authentication. + * + * @param context Authentication context providing server and resource information + * @param context.serverUrl The URL of the MCP server being authenticated against + * @param context.resourceMetadataUrl Optional URL for resource metadata, if available + * @returns Promise that resolves to true if authentication was successful, + * false if authentication failed + */ + authorize(context: { serverUrl: string | URL; resourceMetadataUrl?: URL }): boolean | Promise; +} + export type AuthResult = "AUTHORIZED" | "REDIRECT"; export class UnauthorizedError extends Error { diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 3e3abe68..b5ee1393 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -2,7 +2,7 @@ import { createServer, type IncomingMessage, type Server } from "http"; import { AddressInfo } from "net"; import { JSONRPCMessage } from "../types.js"; import { SSEClientTransport } from "./sse.js"; -import { OAuthClientProvider, UnauthorizedError } from "./auth.js"; +import { DelegatedAuthClientProvider, OAuthClientProvider, UnauthorizedError } from "./auth.js"; import { OAuthTokens } from "../shared/auth.js"; describe("SSEClientTransport", () => { @@ -935,4 +935,206 @@ describe("SSEClientTransport", () => { expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); }); }); + + describe("delegated authentication", () => { + let mockDelegatedAuthProvider: jest.Mocked; + + beforeEach(() => { + mockDelegatedAuthProvider = { + headers: jest.fn(), + authorize: jest.fn(), + }; + }); + + it("includes delegated auth headers in requests", async () => { + mockDelegatedAuthProvider.headers.mockResolvedValue({ + "Authorization": "Bearer delegated-token", + "X-API-Key": "api-key-123" + }); + + transport = new SSEClientTransport(resourceBaseUrl, { + delegatedAuthProvider: mockDelegatedAuthProvider, + }); + + await transport.start(); + + expect(lastServerRequest.headers.authorization).toBe("Bearer delegated-token"); + expect(lastServerRequest.headers["x-api-key"]).toBe("api-key-123"); + }); + + it("takes precedence over OAuth provider", async () => { + const mockOAuthProvider = { + get redirectUrl() { return "http://localhost/callback"; }, + get clientMetadata() { return { redirect_uris: ["http://localhost/callback"] }; }, + clientInformation: jest.fn(() => ({ client_id: "oauth-client", client_secret: "oauth-secret" })), + tokens: jest.fn(() => Promise.resolve({ access_token: "oauth-token", token_type: "Bearer" })), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn(), + }; + + mockDelegatedAuthProvider.headers.mockResolvedValue({ + "Authorization": "Bearer delegated-token" + }); + + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockOAuthProvider, + delegatedAuthProvider: mockDelegatedAuthProvider, + }); + + await transport.start(); + + expect(lastServerRequest.headers.authorization).toBe("Bearer delegated-token"); + expect(mockOAuthProvider.tokens).not.toHaveBeenCalled(); + }); + + it("handles 401 during SSE connection with successful reauth", async () => { + mockDelegatedAuthProvider.headers.mockResolvedValueOnce(undefined); + mockDelegatedAuthProvider.authorize.mockResolvedValue(true); + mockDelegatedAuthProvider.headers.mockResolvedValueOnce({ + "Authorization": "Bearer new-delegated-token" + }); + + // Create server that returns 401 on first attempt, 200 on second + resourceServer.close(); + + let attemptCount = 0; + resourceServer = createServer((req, res) => { + lastServerRequest = req; + attemptCount++; + + if (attemptCount === 1) { + res.writeHead(401).end(); + return; + } + + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache, no-transform", + Connection: "keep-alive", + }); + res.write("event: endpoint\n"); + res.write(`data: ${resourceBaseUrl.href}\n\n`); + }); + + await new Promise((resolve) => { + resourceServer.listen(0, "127.0.0.1", () => { + const addr = resourceServer.address() as AddressInfo; + resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(resourceBaseUrl, { + delegatedAuthProvider: mockDelegatedAuthProvider, + }); + + await transport.start(); + + expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledTimes(1); + expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledWith({ + serverUrl: resourceBaseUrl, + resourceMetadataUrl: undefined + }); + expect(attemptCount).toBe(2); + }); + + it("throws UnauthorizedError when reauth fails", async () => { + mockDelegatedAuthProvider.headers.mockResolvedValue(undefined); + mockDelegatedAuthProvider.authorize.mockResolvedValue(false); + + // Create server that always returns 401 + resourceServer.close(); + + resourceServer = createServer((req, res) => { + res.writeHead(401).end(); + }); + + await new Promise((resolve) => { + resourceServer.listen(0, "127.0.0.1", () => { + const addr = resourceServer.address() as AddressInfo; + resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(resourceBaseUrl, { + delegatedAuthProvider: mockDelegatedAuthProvider, + }); + + await expect(transport.start()).rejects.toThrow(UnauthorizedError); + expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledTimes(1); + expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledWith({ + serverUrl: resourceBaseUrl, + resourceMetadataUrl: undefined + }); + }); + + it("handles 401 during POST request with successful reauth", async () => { + mockDelegatedAuthProvider.headers.mockResolvedValue({ + "Authorization": "Bearer delegated-token" + }); + mockDelegatedAuthProvider.authorize.mockResolvedValue(true); + + // Create server that accepts SSE but returns 401 on first POST, 200 on second + resourceServer.close(); + + let postAttempts = 0; + resourceServer = createServer((req, res) => { + lastServerRequest = req; + + switch (req.method) { + case "GET": + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache, no-transform", + Connection: "keep-alive", + }); + res.write("event: endpoint\n"); + res.write(`data: ${resourceBaseUrl.href}\n\n`); + break; + + case "POST": + postAttempts++; + if (postAttempts === 1) { + res.writeHead(401).end(); + } else { + res.writeHead(200).end(); + } + break; + } + }); + + await new Promise((resolve) => { + resourceServer.listen(0, "127.0.0.1", () => { + const addr = resourceServer.address() as AddressInfo; + resourceBaseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(resourceBaseUrl, { + delegatedAuthProvider: mockDelegatedAuthProvider, + }); + + await transport.start(); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: "1", + method: "test", + params: {}, + }; + + await transport.send(message); + + expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledTimes(1); + expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledWith({ + serverUrl: resourceBaseUrl, + resourceMetadataUrl: undefined + }); + expect(postAttempts).toBe(2); + }); + }); }); diff --git a/src/client/sse.ts b/src/client/sse.ts index 568a5159..5d6a4c33 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -1,7 +1,7 @@ import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource"; import { Transport, FetchLike } from "../shared/transport.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; -import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js"; +import { auth, AuthResult, DelegatedAuthClientProvider, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js"; export class SseError extends Error { constructor( @@ -33,6 +33,19 @@ export type SSEClientTransportOptions = { */ authProvider?: OAuthClientProvider; + /** + * A delegated authentication provider that handles authentication externally. + * + * When a `delegatedAuthProvider` is specified: + * 1. Authentication headers are obtained via `headers()` and added to requests. + * 2. On 401 responses, `authorize()` is called to perform authentication. + * 3. If `authorize()` returns `true`, the request is retried. + * 4. If `authorize()` returns `false`, an `UnauthorizedError` is thrown. + * + * This provider takes precedence over `authProvider` when both are specified. + */ + delegatedAuthProvider?: DelegatedAuthClientProvider; + /** * Customizes the initial SSE request to the server (the request that begins the stream). * @@ -67,6 +80,7 @@ export class SSEClientTransport implements Transport { private _eventSourceInit?: EventSourceInit; private _requestInit?: RequestInit; private _authProvider?: OAuthClientProvider; + private _delegatedAuthProvider?: DelegatedAuthClientProvider; private _fetch?: FetchLike; private _protocolVersion?: string; @@ -83,6 +97,7 @@ export class SSEClientTransport implements Transport { this._eventSourceInit = opts?.eventSourceInit; this._requestInit = opts?.requestInit; this._authProvider = opts?.authProvider; + this._delegatedAuthProvider = opts?.delegatedAuthProvider; this._fetch = opts?.fetch; } @@ -107,13 +122,25 @@ export class SSEClientTransport implements Transport { } private async _commonHeaders(): Promise { - const headers: HeadersInit = {}; - if (this._authProvider) { + const headers = { + ...this._requestInit?.headers, + } as HeadersInit & Record; + + if (this._delegatedAuthProvider) { + const delegatedHeaders = await this._delegatedAuthProvider.headers(); + if (delegatedHeaders) { + const normalizedHeaders = this._normalizeHeaders(delegatedHeaders); + for (const [key, value] of Object.entries(normalizedHeaders)) { + headers[key] = value; + } + } + } else if (this._authProvider) { const tokens = await this._authProvider.tokens(); if (tokens) { headers["Authorization"] = `Bearer ${tokens.access_token}`; } } + if (this._protocolVersion) { headers["mcp-protocol-version"] = this._protocolVersion; } @@ -123,6 +150,20 @@ export class SSEClientTransport implements Transport { ); } + private _normalizeHeaders(headers: HeadersInit | undefined): Record { + if (!headers) return {}; + + if (headers instanceof Headers) { + return Object.fromEntries(headers.entries()); + } + + if (Array.isArray(headers)) { + return Object.fromEntries(headers); + } + + return { ...headers as Record }; + } + private _startOrAuth(): Promise { const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch return new Promise((resolve, reject) => { @@ -148,11 +189,30 @@ export class SSEClientTransport implements Transport { ); this._abortController = new AbortController(); - this._eventSource.onerror = (event) => { - if (event.code === 401 && this._authProvider) { + this._eventSource.onerror = async (event) => { + if (event.code === 401) { + if (this._delegatedAuthProvider) { + try { + const authorized = await this._delegatedAuthProvider.authorize({ + serverUrl: this._url, + resourceMetadataUrl: this._resourceMetadataUrl + }); + if (authorized) { + this._startOrAuth().then(resolve, reject); + return; + } + reject(new UnauthorizedError("Delegated authentication failed")); + return; + } catch (error) { + reject(error); + return; + } + } - this._authThenStart().then(resolve, reject); - return; + if (this._authProvider) { + this._authThenStart().then(resolve, reject); + return; + } } const error = new SseError(event.code, event.message, event); @@ -248,17 +308,29 @@ export class SSEClientTransport implements Transport { const response = await (this._fetch ?? fetch)(this._endpoint, init); if (!response.ok) { - if (response.status === 401 && this._authProvider) { + if (response.status === 401) { + if (this._delegatedAuthProvider) { + const authorized = await this._delegatedAuthProvider.authorize({ + serverUrl: this._url, + resourceMetadataUrl: this._resourceMetadataUrl + }); + if (authorized) { + return this.send(message); + } + throw new UnauthorizedError("Delegated authentication failed"); + } - this._resourceMetadataUrl = extractResourceMetadataUrl(response); + if (this._authProvider) { + this._resourceMetadataUrl = extractResourceMetadataUrl(response); - const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); - if (result !== "AUTHORIZED") { - throw new UnauthorizedError(); - } + const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); + if (result !== "AUTHORIZED") { + throw new UnauthorizedError(); + } - // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message); + // Purposely _not_ awaited, so we don't call onerror twice + return this.send(message); + } } const text = await response.text().catch(() => null); diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index dcd76528..e6b6b4f0 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -1,5 +1,5 @@ import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions, StartSSEOptions } from "./streamableHttp.js"; -import { OAuthClientProvider, UnauthorizedError } from "./auth.js"; +import { DelegatedAuthClientProvider, OAuthClientProvider, UnauthorizedError } from "./auth.js"; import { JSONRPCMessage } from "../types.js"; @@ -592,4 +592,214 @@ describe("StreamableHTTPClientTransport", () => { await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); }); + + describe("delegated authentication", () => { + let mockDelegatedAuthProvider: jest.Mocked; + + beforeEach(() => { + mockDelegatedAuthProvider = { + headers: jest.fn(), + authorize: jest.fn(), + }; + }); + + it("includes delegated auth headers in requests", async () => { + mockDelegatedAuthProvider.headers.mockResolvedValue({ + "Authorization": "Bearer delegated-token", + "X-API-Key": "api-key-123" + }); + + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + delegatedAuthProvider: mockDelegatedAuthProvider, + }); + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers(), + }); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-id" + }; + + await transport.send(message); + + const call = (global.fetch as jest.Mock).mock.calls[0]; + const headers = call[1].headers as Headers; + expect(headers.get("authorization")).toBe("Bearer delegated-token"); + expect(headers.get("x-api-key")).toBe("api-key-123"); + }); + + it("takes precedence over OAuth provider", async () => { + mockDelegatedAuthProvider.headers.mockResolvedValue({ + "Authorization": "Bearer delegated-token" + }); + + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + authProvider: mockAuthProvider, + delegatedAuthProvider: mockDelegatedAuthProvider, + }); + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers(), + }); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-id" + }; + + await transport.send(message); + + const call = (global.fetch as jest.Mock).mock.calls[0]; + const headers = call[1].headers as Headers; + expect(headers.get("authorization")).toBe("Bearer delegated-token"); + expect(mockAuthProvider.tokens).not.toHaveBeenCalled(); + }); + + it("handles 401 during SSE start with successful reauth", async () => { + mockDelegatedAuthProvider.headers.mockResolvedValue({ + "Authorization": "Bearer delegated-token" + }); + mockDelegatedAuthProvider.authorize.mockResolvedValue(true); + + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + delegatedAuthProvider: mockDelegatedAuthProvider, + }); + + // Test the internal SSE start method directly + const startMethod = transport["_startOrAuthSse"]; + + (global.fetch as jest.Mock) + .mockResolvedValueOnce({ + ok: false, + status: 401, + statusText: "Unauthorized", + headers: new Headers() + }) + .mockResolvedValueOnce({ + ok: false, + status: 405, + statusText: "Method Not Allowed", + headers: new Headers() + }); + + await startMethod.call(transport, { resumptionToken: undefined }); + + expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledTimes(1); + expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledWith({ + serverUrl: new URL("http://localhost:1234/mcp"), + resourceMetadataUrl: undefined + }); + expect(global.fetch).toHaveBeenCalledTimes(2); + }); + + it("throws UnauthorizedError when reauth fails during SSE start", async () => { + mockDelegatedAuthProvider.headers.mockResolvedValue({ + "Authorization": "Bearer delegated-token" + }); + mockDelegatedAuthProvider.authorize.mockResolvedValue(false); + + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + delegatedAuthProvider: mockDelegatedAuthProvider, + }); + + // Test the internal SSE start method directly + const startMethod = transport["_startOrAuthSse"]; + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: false, + status: 401, + statusText: "Unauthorized", + headers: new Headers() + }); + + await expect(startMethod.call(transport, { resumptionToken: undefined })).rejects.toThrow(UnauthorizedError); + expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledTimes(1); + expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledWith({ + serverUrl: new URL("http://localhost:1234/mcp"), + resourceMetadataUrl: undefined + }); + }); + + it("handles 401 during POST request with successful reauth", async () => { + mockDelegatedAuthProvider.headers.mockResolvedValue({ + "Authorization": "Bearer delegated-token" + }); + mockDelegatedAuthProvider.authorize.mockResolvedValue(true); + + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + delegatedAuthProvider: mockDelegatedAuthProvider, + }); + + (global.fetch as jest.Mock) + .mockResolvedValueOnce({ + ok: false, + status: 401, + statusText: "Unauthorized", + headers: new Headers() + }) + .mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers(), + }); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-id" + }; + + await transport.send(message); + + expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledTimes(1); + expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledWith({ + serverUrl: new URL("http://localhost:1234/mcp"), + resourceMetadataUrl: undefined + }); + expect(global.fetch).toHaveBeenCalledTimes(2); + }); + + it("throws UnauthorizedError when reauth fails during POST request", async () => { + mockDelegatedAuthProvider.headers.mockResolvedValue({ + "Authorization": "Bearer delegated-token" + }); + mockDelegatedAuthProvider.authorize.mockResolvedValue(false); + + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + delegatedAuthProvider: mockDelegatedAuthProvider, + }); + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: false, + status: 401, + statusText: "Unauthorized", + headers: new Headers() + }); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-id" + }; + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledTimes(1); + expect(mockDelegatedAuthProvider.authorize).toHaveBeenCalledWith({ + serverUrl: new URL("http://localhost:1234/mcp"), + resourceMetadataUrl: undefined + }); + }); + }); }); diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index b81f1a5d..9bf306da 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -1,6 +1,6 @@ import { Transport, FetchLike } from "../shared/transport.js"; import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; -import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js"; +import { auth, AuthResult, DelegatedAuthClientProvider, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js"; import { EventSourceParserStream } from "eventsource-parser/stream"; // Default reconnection options for StreamableHTTP connections @@ -94,6 +94,19 @@ export type StreamableHTTPClientTransportOptions = { */ authProvider?: OAuthClientProvider; + /** + * A delegated authentication provider that handles authentication externally. + * + * When a `delegatedAuthProvider` is specified: + * 1. Authentication headers are obtained via `tokens()` and added to requests. + * 2. On 401 responses, `authorize()` is called to perform authentication. + * 3. If `authorize()` returns `true`, the request is retried. + * 4. If `authorize()` returns `false`, an `UnauthorizedError` is thrown. + * + * This provider takes precedence over `authProvider` when both are specified. + */ + delegatedAuthProvider?: DelegatedAuthClientProvider; + /** * Customizes HTTP requests to the server. */ @@ -127,6 +140,7 @@ export class StreamableHTTPClientTransport implements Transport { private _resourceMetadataUrl?: URL; private _requestInit?: RequestInit; private _authProvider?: OAuthClientProvider; + private _delegatedAuthProvider?: DelegatedAuthClientProvider; private _fetch?: FetchLike; private _sessionId?: string; private _reconnectionOptions: StreamableHTTPReconnectionOptions; @@ -144,6 +158,7 @@ export class StreamableHTTPClientTransport implements Transport { this._resourceMetadataUrl = undefined; this._requestInit = opts?.requestInit; this._authProvider = opts?.authProvider; + this._delegatedAuthProvider = opts?.delegatedAuthProvider; this._fetch = opts?.fetch; this._sessionId = opts?.sessionId; this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; @@ -171,7 +186,15 @@ export class StreamableHTTPClientTransport implements Transport { private async _commonHeaders(): Promise { const headers: HeadersInit & Record = {}; - if (this._authProvider) { + + if (this._delegatedAuthProvider) { + const delegatedHeaders = await this._delegatedAuthProvider.headers(); + if (delegatedHeaders) { + for (const [key, value] of Object.entries(this._normalizeHeaders(delegatedHeaders))) { + headers[key] = value; + } + } + } else if (this._authProvider) { const tokens = await this._authProvider.tokens(); if (tokens) { headers["Authorization"] = `Bearer ${tokens.access_token}`; @@ -214,9 +237,22 @@ const response = await (this._fetch ?? fetch)(this._url, { }); if (!response.ok) { - if (response.status === 401 && this._authProvider) { - // Need to authenticate - return await this._authThenStart(); + if (response.status === 401) { + if (this._delegatedAuthProvider) { + const authorized = await this._delegatedAuthProvider.authorize({ + serverUrl: this._url, + resourceMetadataUrl: this._resourceMetadataUrl + }); + if (authorized) { + return await this._startOrAuthSse(options); + } + throw new UnauthorizedError("Delegated authentication failed"); + } + + if (this._authProvider) { + // Need to authenticate + return await this._authThenStart(); + } } // 405 indicates that the server does not offer an SSE stream at GET endpoint @@ -430,17 +466,29 @@ const response = await (this._fetch ?? fetch)(this._url, init); } if (!response.ok) { - if (response.status === 401 && this._authProvider) { + if (response.status === 401) { + if (this._delegatedAuthProvider) { + const authorized = await this._delegatedAuthProvider.authorize({ + serverUrl: this._url, + resourceMetadataUrl: this._resourceMetadataUrl + }); + if (authorized) { + return this.send(message); + } + throw new UnauthorizedError("Delegated authentication failed"); + } - this._resourceMetadataUrl = extractResourceMetadataUrl(response); + if (this._authProvider) { + this._resourceMetadataUrl = extractResourceMetadataUrl(response); - const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); - if (result !== "AUTHORIZED") { - throw new UnauthorizedError(); - } + const result = await auth(this._authProvider, { serverUrl: this._url, resourceMetadataUrl: this._resourceMetadataUrl }); + if (result !== "AUTHORIZED") { + throw new UnauthorizedError(); + } - // Purposely _not_ awaited, so we don't call onerror twice - return this.send(message); + // Purposely _not_ awaited, so we don't call onerror twice + return this.send(message); + } } const text = await response.text().catch(() => null);