Skip to content

Commit

Permalink
Merge pull request #157 from modelcontextprotocol/jerome/fix/no-mcp-h…
Browse files Browse the repository at this point in the history
…eader-in-metadata-fetching

* Check token expiry in requireAuth middleware
* Gracefully handle cors issues in metadata fetching (experienced due to MCP-Protocol-Version header)
* Don't set client secret expiry for public clients
  • Loading branch information
jerome3o-anthropic authored Feb 25, 2025
2 parents 5c07636 + c521710 commit 42b1738
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 8 deletions.
58 changes: 58 additions & 0 deletions src/client/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,64 @@ describe("OAuth Authorization", () => {
});
});

it("returns metadata when first fetch fails but second without MCP header succeeds", async () => {
// Set up a counter to control behavior
let callCount = 0;

// Mock implementation that changes behavior based on call count
mockFetch.mockImplementation((_url, _options) => {
callCount++;

if (callCount === 1) {
// First call with MCP header - fail with TypeError (simulating CORS error)
// We need to use TypeError specifically because that's what the implementation checks for
return Promise.reject(new TypeError("Network error"));
} else {
// Second call without header - succeed
return Promise.resolve({
ok: true,
status: 200,
json: async () => validMetadata
});
}
});

// Should succeed with the second call
const metadata = await discoverOAuthMetadata("https://auth.example.com");
expect(metadata).toEqual(validMetadata);

// Verify both calls were made
expect(mockFetch).toHaveBeenCalledTimes(2);

// Verify first call had MCP header
expect(mockFetch.mock.calls[0][1]?.headers).toHaveProperty("MCP-Protocol-Version");
});

it("throws an error when all fetch attempts fail", async () => {
// Set up a counter to control behavior
let callCount = 0;

// Mock implementation that changes behavior based on call count
mockFetch.mockImplementation((_url, _options) => {
callCount++;

if (callCount === 1) {
// First call - fail with TypeError
return Promise.reject(new TypeError("First failure"));
} else {
// Second call - fail with different error
return Promise.reject(new Error("Second failure"));
}
});

// Should fail with the second error
await expect(discoverOAuthMetadata("https://auth.example.com"))
.rejects.toThrow("Second failure");

// Verify both calls were made
expect(mockFetch).toHaveBeenCalledTimes(2);
});

it("returns undefined when discovery endpoint returns 404", async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
Expand Down
18 changes: 14 additions & 4 deletions src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,21 @@ export async function discoverOAuthMetadata(
opts?: { protocolVersion?: string },
): Promise<OAuthMetadata | undefined> {
const url = new URL("/.well-known/oauth-authorization-server", serverUrl);
const response = await fetch(url, {
headers: {
"MCP-Protocol-Version": opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION
let response: Response;
try {
response = await fetch(url, {
headers: {
"MCP-Protocol-Version": opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION
}
});
} catch (error) {
// CORS errors come back as TypeError
if (error instanceof TypeError) {
response = await fetch(url);
} else {
throw error;
}
});
}

if (response.status === 404) {
return undefined;
Expand Down
31 changes: 31 additions & 0 deletions src/server/auth/handlers/register.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,37 @@ describe('Client Registration Handler', () => {

expect(response.status).toBe(201);
expect(response.body.client_secret).toBeUndefined();
expect(response.body.client_secret_expires_at).toBeUndefined();
});

it('sets client_secret_expires_at for public clients only', async () => {
// Test for public client (token_endpoint_auth_method not 'none')
const publicClientMetadata: OAuthClientMetadata = {
redirect_uris: ['https://example.com/callback'],
token_endpoint_auth_method: 'client_secret_basic'
};

const publicResponse = await supertest(app)
.post('/register')
.send(publicClientMetadata);

expect(publicResponse.status).toBe(201);
expect(publicResponse.body.client_secret).toBeDefined();
expect(publicResponse.body.client_secret_expires_at).toBeDefined();

// Test for non-public client (token_endpoint_auth_method is 'none')
const nonPublicClientMetadata: OAuthClientMetadata = {
redirect_uris: ['https://example.com/callback'],
token_endpoint_auth_method: 'none'
};

const nonPublicResponse = await supertest(app)
.post('/register')
.send(nonPublicClientMetadata);

expect(nonPublicResponse.status).toBe(201);
expect(nonPublicResponse.body.client_secret).toBeUndefined();
expect(nonPublicResponse.body.client_secret_expires_at).toBeUndefined();
});

it('sets expiry based on clientSecretExpirySeconds', async () => {
Expand Down
14 changes: 10 additions & 4 deletions src/server/auth/handlers/register.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,26 @@ export function clientRegistrationHandler({
}

const clientMetadata = parseResult.data;
const isPublicClient = clientMetadata.token_endpoint_auth_method === 'none'

// Generate client credentials
const clientId = crypto.randomUUID();
const clientSecret = clientMetadata.token_endpoint_auth_method !== 'none'
? crypto.randomBytes(32).toString('hex')
: undefined;
const clientSecret = isPublicClient
? undefined
: crypto.randomBytes(32).toString('hex');
const clientIdIssuedAt = Math.floor(Date.now() / 1000);

// Calculate client secret expiry time
const clientsDoExpire = clientSecretExpirySeconds > 0
const secretExpiryTime = clientsDoExpire ? clientIdIssuedAt + clientSecretExpirySeconds : 0
const clientSecretExpiresAt = isPublicClient ? undefined : secretExpiryTime

let clientInfo: OAuthClientInformationFull = {
...clientMetadata,
client_id: clientId,
client_secret: clientSecret,
client_id_issued_at: clientIdIssuedAt,
client_secret_expires_at: clientSecretExpirySeconds > 0 ? clientIdIssuedAt + clientSecretExpirySeconds : 0
client_secret_expires_at: clientSecretExpiresAt,
};

clientInfo = await clientsStore.registerClient!(clientInfo);
Expand Down
51 changes: 51 additions & 0 deletions src/server/auth/middleware/bearerAuth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,57 @@ describe("requireBearerAuth middleware", () => {
expect(mockResponse.status).not.toHaveBeenCalled();
expect(mockResponse.json).not.toHaveBeenCalled();
});

it("should reject expired tokens", async () => {
const expiredAuthInfo: AuthInfo = {
token: "expired-token",
clientId: "client-123",
scopes: ["read", "write"],
expiresAt: Math.floor(Date.now() / 1000) - 100, // Token expired 100 seconds ago
};
mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo);

mockRequest.headers = {
authorization: "Bearer expired-token",
};

const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);

expect(mockVerifyAccessToken).toHaveBeenCalledWith("expired-token");
expect(mockResponse.status).toHaveBeenCalledWith(401);
expect(mockResponse.set).toHaveBeenCalledWith(
"WWW-Authenticate",
expect.stringContaining('Bearer error="invalid_token"')
);
expect(mockResponse.json).toHaveBeenCalledWith(
expect.objectContaining({ error: "invalid_token", error_description: "Token has expired" })
);
expect(nextFunction).not.toHaveBeenCalled();
});

it("should accept non-expired tokens", async () => {
const nonExpiredAuthInfo: AuthInfo = {
token: "valid-token",
clientId: "client-123",
scopes: ["read", "write"],
expiresAt: Math.floor(Date.now() / 1000) + 3600, // Token expires in an hour
};
mockVerifyAccessToken.mockResolvedValue(nonExpiredAuthInfo);

mockRequest.headers = {
authorization: "Bearer valid-token",
};

const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);

expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token");
expect(mockRequest.auth).toEqual(nonExpiredAuthInfo);
expect(nextFunction).toHaveBeenCalled();
expect(mockResponse.status).not.toHaveBeenCalled();
expect(mockResponse.json).not.toHaveBeenCalled();
});

it("should require specific scopes when configured", async () => {
const authInfo: AuthInfo = {
Expand Down
5 changes: 5 additions & 0 deletions src/server/auth/middleware/bearerAuth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ export function requireBearerAuth({ provider, requiredScopes = [] }: BearerAuthM
}
}

// Check if the token is expired
if (!!authInfo.expiresAt && authInfo.expiresAt < Date.now() / 1000) {
throw new InvalidTokenError("Token has expired");
}

req.auth = authInfo;
next();
} catch (error) {
Expand Down

0 comments on commit 42b1738

Please sign in to comment.