diff --git a/crates/token_proxy_core/src/proxy/http/mod.rs b/crates/token_proxy_core/src/proxy/http/mod.rs index 9f6f088..1063d47 100644 --- a/crates/token_proxy_core/src/proxy/http/mod.rs +++ b/crates/token_proxy_core/src/proxy/http/mod.rs @@ -175,11 +175,12 @@ fn parse_query_key(query: Option<&str>) -> Result, String> { #[derive(Clone, Default)] pub(crate) struct RequestAuth { pub(crate) openai_bearer: Option, - pub(crate) anthropic_api_key: Option, + pub(crate) anthropic_request_auth: Option, pub(crate) gemini_api_key: Option, pub(crate) authorization_fallback: Option, } +#[derive(Clone)] pub(crate) struct UpstreamAuthHeader { pub(crate) name: HeaderName, pub(crate) value: HeaderValue, @@ -188,6 +189,7 @@ pub(crate) struct UpstreamAuthHeader { pub(crate) fn resolve_request_auth( config: &ProxyConfig, headers: &HeaderMap, + path: &str, ) -> Result { let mut auth = RequestAuth::default(); // When local auth is enabled, request auth headers are reserved for local access and not used upstream. @@ -202,15 +204,8 @@ pub(crate) fn resolve_request_auth( ); } - // Anthropic uses `x-api-key`; allow explicit overrides as well. - if let Some(value) = headers - .get(X_API_KEY) - .or_else(|| headers.get(X_ANTHROPIC_API_KEY)) - { - let Ok(_) = value.to_str() else { - return Err("Upstream API key is invalid.".to_string()); - }; - auth.anthropic_api_key = Some(value.clone()); + if is_anthropic_path(path) { + auth.anthropic_request_auth = resolve_anthropic_request_auth(headers)?; } if let Some(value) = headers.get(AUTHORIZATION) { @@ -230,6 +225,44 @@ pub(crate) fn resolve_request_auth( Ok(auth) } +fn resolve_anthropic_request_auth( + headers: &HeaderMap, +) -> Result, String> { + if let Some(value) = headers.get(X_API_KEY) { + let Ok(_) = value.to_str() else { + return Err("Upstream API key is invalid.".to_string()); + }; + return Ok(Some(UpstreamAuthHeader { + name: HeaderName::from_static(X_API_KEY), + value: value.clone(), + })); + } + + if let Some(value) = headers.get(X_ANTHROPIC_API_KEY) { + let Ok(_) = value.to_str() else { + return Err("Upstream API key is invalid.".to_string()); + }; + return Ok(Some(UpstreamAuthHeader { + name: HeaderName::from_static(X_ANTHROPIC_API_KEY), + value: value.clone(), + })); + } + + let Some(value) = headers.get(AUTHORIZATION) else { + return Ok(None); + }; + let Ok(value_str) = value.to_str() else { + return Err("Upstream API key is invalid.".to_string()); + }; + if extract_bearer_token(value_str).is_none() { + return Err("Upstream API key is invalid.".to_string()); + } + Ok(Some(UpstreamAuthHeader { + name: AUTHORIZATION, + value: value.clone(), + })) +} + pub(crate) fn resolve_upstream_auth( provider: &str, upstream: &UpstreamRuntime, @@ -240,40 +273,66 @@ pub(crate) fn resolve_upstream_auth( upstream_id = %upstream.id, has_upstream_key = upstream.api_key.is_some(), has_openai_bearer = request_auth.openai_bearer.is_some(), - has_anthropic_key = request_auth.anthropic_api_key.is_some(), + has_anthropic_key = request_auth.anthropic_request_auth.is_some(), has_auth_fallback = request_auth.authorization_fallback.is_some(), "resolving upstream auth" ); match provider { "anthropic" => { - let value = match upstream.api_key.as_ref() { - Some(key) => { - tracing::debug!("using upstream.api_key for Anthropic"); - HeaderValue::from_str(key).map_err(|_| { - error_response( - StatusCode::UNAUTHORIZED, - "Upstream API key contains invalid characters.", - ) - })? - } - None => { - let Some(value) = request_auth.anthropic_api_key.clone().or_else(|| { - request_auth - .authorization_fallback - .as_ref() - .and_then(|value| value.to_str().ok()) - .and_then(extract_bearer_token) - .and_then(|value| HeaderValue::from_str(value).ok()) - }) else { - tracing::warn!("no API key for Anthropic"); - return Ok(None); + if let Some(key) = upstream.api_key.as_ref() { + tracing::debug!("using upstream.api_key for Anthropic"); + if let Some(request_header) = request_auth.anthropic_request_auth.as_ref() { + let value = if request_header.name == AUTHORIZATION { + bearer_header(key).ok_or_else(|| { + error_response( + StatusCode::UNAUTHORIZED, + "Upstream API key contains invalid characters.", + ) + })? + } else { + HeaderValue::from_str(key).map_err(|_| { + error_response( + StatusCode::UNAUTHORIZED, + "Upstream API key contains invalid characters.", + ) + })? }; - tracing::debug!("using request auth fallback for Anthropic"); - value + return Ok(Some(UpstreamAuthHeader { + name: request_header.name.clone(), + value, + })); } + + let value = HeaderValue::from_str(key).map_err(|_| { + error_response( + StatusCode::UNAUTHORIZED, + "Upstream API key contains invalid characters.", + ) + })?; + return Ok(Some(UpstreamAuthHeader { + name: HeaderName::from_static(X_API_KEY), + value, + })); + } + + if let Some(header) = request_auth.anthropic_request_auth.clone() { + tracing::debug!("using native anthropic request auth header"); + return Ok(Some(header)); + } + + let Some(value) = request_auth + .authorization_fallback + .as_ref() + .and_then(|value| value.to_str().ok()) + .and_then(extract_bearer_token) + .and_then(|value| HeaderValue::from_str(value).ok()) + else { + tracing::warn!("no API key for Anthropic"); + return Ok(None); }; + tracing::debug!("using request auth fallback for Anthropic"); Ok(Some(UpstreamAuthHeader { name: HeaderName::from_static(X_API_KEY), value, diff --git a/crates/token_proxy_core/src/proxy/http/tests.rs b/crates/token_proxy_core/src/proxy/http/tests.rs index c097ef9..60cb0ed 100644 --- a/crates/token_proxy_core/src/proxy/http/tests.rs +++ b/crates/token_proxy_core/src/proxy/http/tests.rs @@ -148,7 +148,30 @@ fn anthropic_upstream_auth_accepts_authorization_bearer_fallback() { HeaderValue::from_static("Bearer anthropic-request-key"), ); - let request_auth = resolve_request_auth(&config, &headers).expect("request auth"); + let request_auth = + resolve_request_auth(&config, &headers, "/v1/messages").expect("request auth"); + let auth = resolve_upstream_auth("anthropic", &upstream_without_key(), &request_auth) + .expect("upstream auth") + .expect("anthropic auth header"); + + assert_eq!(auth.name, AUTHORIZATION); + assert_eq!( + auth.value.to_str().ok(), + Some("Bearer anthropic-request-key") + ); +} + +#[test] +fn anthropic_upstream_auth_defaults_to_x_api_key_for_non_native_inbound_requests() { + let config = config_without_local(); + let mut headers = HeaderMap::new(); + headers.insert( + AUTHORIZATION, + HeaderValue::from_static("Bearer anthropic-request-key"), + ); + + let request_auth = + resolve_request_auth(&config, &headers, "/v1/chat/completions").expect("request auth"); let auth = resolve_upstream_auth("anthropic", &upstream_without_key(), &request_auth) .expect("upstream auth") .expect("anthropic auth header"); @@ -156,3 +179,45 @@ fn anthropic_upstream_auth_accepts_authorization_bearer_fallback() { assert_eq!(auth.name.as_str(), "x-api-key"); assert_eq!(auth.value.to_str().ok(), Some("anthropic-request-key")); } + +#[test] +fn anthropic_upstream_auth_reuses_authorization_header_name_with_upstream_key() { + let config = config_without_local(); + let mut headers = HeaderMap::new(); + headers.insert( + AUTHORIZATION, + HeaderValue::from_static("Bearer local-debug-key"), + ); + + let request_auth = + resolve_request_auth(&config, &headers, "/v1/messages").expect("request auth"); + let mut upstream = upstream_without_key(); + upstream.api_key = Some("upstream-anthropic-key".to_string()); + let auth = resolve_upstream_auth("anthropic", &upstream, &request_auth) + .expect("upstream auth") + .expect("anthropic auth header"); + + assert_eq!(auth.name, AUTHORIZATION); + assert_eq!( + auth.value.to_str().ok(), + Some("Bearer upstream-anthropic-key") + ); +} + +#[test] +fn anthropic_upstream_auth_reuses_x_api_key_header_name_with_upstream_key() { + let config = config_without_local(); + let mut headers = HeaderMap::new(); + headers.insert("x-api-key", HeaderValue::from_static("local-debug-key")); + + let request_auth = + resolve_request_auth(&config, &headers, "/v1/messages").expect("request auth"); + let mut upstream = upstream_without_key(); + upstream.api_key = Some("upstream-anthropic-key".to_string()); + let auth = resolve_upstream_auth("anthropic", &upstream, &request_auth) + .expect("upstream auth") + .expect("anthropic auth header"); + + assert_eq!(auth.name.as_str(), "x-api-key"); + assert_eq!(auth.value.to_str().ok(), Some("upstream-anthropic-key")); +} diff --git a/crates/token_proxy_core/src/proxy/server/prepared.rs b/crates/token_proxy_core/src/proxy/server/prepared.rs index bdf747e..14436f5 100644 --- a/crates/token_proxy_core/src/proxy/server/prepared.rs +++ b/crates/token_proxy_core/src/proxy/server/prepared.rs @@ -54,7 +54,7 @@ pub(super) fn resolve_request_auth_or_respond( provider: &str, request_start: Instant, ) -> Result { - match http::resolve_request_auth(config, headers) { + match http::resolve_request_auth(config, headers, path) { Ok(auth) => Ok(auth), Err(message) => { log_request_error( diff --git a/crates/token_proxy_core/src/proxy/server/tests.rs b/crates/token_proxy_core/src/proxy/server/tests.rs index 8131781..6900897 100644 --- a/crates/token_proxy_core/src/proxy/server/tests.rs +++ b/crates/token_proxy_core/src/proxy/server/tests.rs @@ -2706,6 +2706,40 @@ async fn send_anthropic_messages_request( (status, json) } +async fn send_anthropic_count_tokens_request( + state: ProxyStateHandle, + headers: HeaderMap, +) -> (StatusCode, Value) { + let response = proxy_request( + State(state), + Method::POST, + Uri::from_static("/v1/messages/count_tokens"), + headers, + Body::from( + json!({ + "model": "claude-sonnet-4-5", + "messages": [ + { + "role": "user", + "content": [ + { "type": "text", "text": "hi from claude" } + ] + } + ] + }) + .to_string(), + ), + ) + .await; + + let status = response.status(); + let body = to_bytes(response.into_body(), usize::MAX) + .await + .expect("proxy response bytes"); + let json = serde_json::from_slice(&body).expect("proxy response json"); + (status, json) +} + #[test] fn responses_request_uses_chat_compat_for_coding_plan_runtime_upstream() { run_async(async { @@ -3813,6 +3847,43 @@ fn gemini_count_tokens_route_dispatches_to_gemini() { assert_eq!(plan.response_transform, FormatTransform::None); } +#[test] +fn anthropic_count_tokens_preserves_authorization_header_name_for_upstream() { + run_async(async { + let upstream = spawn_mock_upstream(StatusCode::OK, json!({ "input_tokens": 12 })).await; + let config = config_with_runtime_upstreams(&[( + PROVIDER_ANTHROPIC, + 0, + "anthropic-auth-relay", + upstream.base_url.as_str(), + FORMATS_MESSAGES, + )]); + let data_dir = next_test_data_dir("anthropic_count_tokens_authorization_header"); + let state = build_test_state_handle(config, data_dir.clone()).await; + let mut headers = HeaderMap::new(); + headers.insert( + axum::http::header::AUTHORIZATION, + HeaderValue::from_static("Bearer local-debug-key"), + ); + headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01")); + + let (status, body) = send_anthropic_count_tokens_request(state, headers).await; + let requests = upstream.requests(); + + upstream.abort(); + let _ = std::fs::remove_dir_all(&data_dir); + + assert_eq!(status, StatusCode::OK); + assert_eq!(body["input_tokens"].as_u64(), Some(12)); + assert_eq!(requests.len(), 1); + assert_eq!(requests[0].path, "/v1/messages/count_tokens"); + assert_eq!( + requests[0].authorization.as_deref(), + Some("Bearer test-key") + ); + }); +} + #[test] fn gemini_embed_route_dispatches_to_gemini() { let config = config_with_providers(&[(PROVIDER_GEMINI, FORMATS_GEMINI)]);