diff --git a/src/lib/node-oauth-client-provider.ts b/src/lib/node-oauth-client-provider.ts index f070a88..0836182 100644 --- a/src/lib/node-oauth-client-provider.ts +++ b/src/lib/node-oauth-client-provider.ts @@ -24,6 +24,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider { private softwareVersion: string private staticOAuthClientMetadata: StaticOAuthClientMetadata private staticOAuthClientInfo: StaticOAuthClientInformationFull + private useOidcConfig: boolean /** * Creates a new NodeOAuthClientProvider @@ -38,6 +39,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider { this.softwareVersion = options.softwareVersion || MCP_REMOTE_VERSION this.staticOAuthClientMetadata = options.staticOAuthClientMetadata this.staticOAuthClientInfo = options.staticOAuthClientInfo + this.useOidcConfig = !!options.useOidcConfig } get redirectUrl(): string { @@ -193,4 +195,13 @@ export class NodeOAuthClientProvider implements OAuthClientProvider { if (DEBUG) await debugLog(this.serverUrlHash, 'Code verifier found:', !!verifier) return verifier } + + /** + * Gets the PKCE code verifier + * @returns The code verifier + */ + async useOidcProviderConfiguration(): Promise { + if (DEBUG) await debugLog(this.serverUrlHash, 'Use OpenID Configuration:', !!this.useOidcConfig) + return !!this.useOidcConfig; + } } diff --git a/src/lib/types.ts b/src/lib/types.ts index 74b3a96..800a934 100644 --- a/src/lib/types.ts +++ b/src/lib/types.ts @@ -27,6 +27,8 @@ export interface OAuthProviderOptions { staticOAuthClientMetadata?: StaticOAuthClientMetadata /** Static OAuth client information to use instead of OAuth registration */ staticOAuthClientInfo?: StaticOAuthClientInformationFull + /** Static OAuth client information to use instead of OAuth registration */ + useOidcConfig?: boolean } /** diff --git a/src/lib/utils.ts b/src/lib/utils.ts index ba544aa..fbda2a2 100644 --- a/src/lib/utils.ts +++ b/src/lib/utils.ts @@ -546,6 +546,7 @@ export async function parseCommandLineArgs(args: string[], usage: string) { const serverUrl = args[0] const specifiedPort = args[1] ? parseInt(args[1]) : undefined const allowHttp = args.includes('--allow-http') + const useOidcConfig = args.includes('--use-oidc-config') // Check for debug flag const debug = args.includes('--debug') @@ -669,7 +670,7 @@ export async function parseCommandLineArgs(args: string[], usage: string) { }) } - return { serverUrl, callbackPort, headers, transportStrategy, host, debug, staticOAuthClientMetadata, staticOAuthClientInfo } + return { serverUrl, callbackPort, headers, transportStrategy, host, debug, staticOAuthClientMetadata, staticOAuthClientInfo, useOidcConfig } } /** diff --git a/src/proxy.ts b/src/proxy.ts index 3847d35..0dfb7bc 100644 --- a/src/proxy.ts +++ b/src/proxy.ts @@ -36,6 +36,7 @@ async function runProxy( host: string, staticOAuthClientMetadata: StaticOAuthClientMetadata, staticOAuthClientInfo: StaticOAuthClientInformationFull, + useOidcConfig: boolean = false, ) { // Set up event emitter for auth flow const events = new EventEmitter() @@ -54,6 +55,7 @@ async function runProxy( clientName: 'MCP CLI Proxy', staticOAuthClientMetadata, staticOAuthClientInfo, + useOidcConfig, }) // Create the STDIO transport for local connections @@ -143,8 +145,8 @@ to the CA certificate file. If using claude_desktop_config.json, this might look // Parse command-line arguments and run the proxy parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx proxy.ts [callback-port] [--debug]') - .then(({ serverUrl, callbackPort, headers, transportStrategy, host, debug, staticOAuthClientMetadata, staticOAuthClientInfo }) => { - return runProxy(serverUrl, callbackPort, headers, transportStrategy, host, staticOAuthClientMetadata, staticOAuthClientInfo) + .then(({ serverUrl, callbackPort, headers, transportStrategy, host, debug, staticOAuthClientMetadata, staticOAuthClientInfo, useOidcConfig }) => { + return runProxy(serverUrl, callbackPort, headers, transportStrategy, host, staticOAuthClientMetadata, staticOAuthClientInfo, useOidcConfig) }) .catch((error) => { log('Fatal error:', error)