Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 93 additions & 34 deletions crates/token_proxy_core/src/proxy/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,12 @@ fn parse_query_key(query: Option<&str>) -> Result<Option<String>, String> {
#[derive(Clone, Default)]
pub(crate) struct RequestAuth {
pub(crate) openai_bearer: Option<HeaderValue>,
pub(crate) anthropic_api_key: Option<HeaderValue>,
pub(crate) anthropic_request_auth: Option<UpstreamAuthHeader>,
pub(crate) gemini_api_key: Option<String>,
pub(crate) authorization_fallback: Option<HeaderValue>,
}

#[derive(Clone)]
pub(crate) struct UpstreamAuthHeader {
pub(crate) name: HeaderName,
pub(crate) value: HeaderValue,
Expand All @@ -188,6 +189,7 @@ pub(crate) struct UpstreamAuthHeader {
pub(crate) fn resolve_request_auth(
config: &ProxyConfig,
headers: &HeaderMap,
path: &str,
) -> Result<RequestAuth, String> {
let mut auth = RequestAuth::default();
// When local auth is enabled, request auth headers are reserved for local access and not used upstream.
Expand All @@ -202,15 +204,8 @@ pub(crate) fn resolve_request_auth(
);
}

// Anthropic uses `x-api-key`; allow explicit overrides as well.
if let Some(value) = headers
.get(X_API_KEY)
.or_else(|| headers.get(X_ANTHROPIC_API_KEY))
{
let Ok(_) = value.to_str() else {
return Err("Upstream API key is invalid.".to_string());
};
auth.anthropic_api_key = Some(value.clone());
if is_anthropic_path(path) {
auth.anthropic_request_auth = resolve_anthropic_request_auth(headers)?;
}

if let Some(value) = headers.get(AUTHORIZATION) {
Expand All @@ -230,6 +225,44 @@ pub(crate) fn resolve_request_auth(
Ok(auth)
}

fn resolve_anthropic_request_auth(
headers: &HeaderMap,
) -> Result<Option<UpstreamAuthHeader>, String> {
if let Some(value) = headers.get(X_API_KEY) {
let Ok(_) = value.to_str() else {
return Err("Upstream API key is invalid.".to_string());
};
return Ok(Some(UpstreamAuthHeader {
name: HeaderName::from_static(X_API_KEY),
value: value.clone(),
}));
}

if let Some(value) = headers.get(X_ANTHROPIC_API_KEY) {
let Ok(_) = value.to_str() else {
return Err("Upstream API key is invalid.".to_string());
};
return Ok(Some(UpstreamAuthHeader {
name: HeaderName::from_static(X_ANTHROPIC_API_KEY),
value: value.clone(),
}));
}

let Some(value) = headers.get(AUTHORIZATION) else {
return Ok(None);
};
let Ok(value_str) = value.to_str() else {
return Err("Upstream API key is invalid.".to_string());
};
if extract_bearer_token(value_str).is_none() {
return Err("Upstream API key is invalid.".to_string());
}
Ok(Some(UpstreamAuthHeader {
name: AUTHORIZATION,
value: value.clone(),
}))
}

pub(crate) fn resolve_upstream_auth(
provider: &str,
upstream: &UpstreamRuntime,
Expand All @@ -240,40 +273,66 @@ pub(crate) fn resolve_upstream_auth(
upstream_id = %upstream.id,
has_upstream_key = upstream.api_key.is_some(),
has_openai_bearer = request_auth.openai_bearer.is_some(),
has_anthropic_key = request_auth.anthropic_api_key.is_some(),
has_anthropic_key = request_auth.anthropic_request_auth.is_some(),
has_auth_fallback = request_auth.authorization_fallback.is_some(),
"resolving upstream auth"
);

match provider {
"anthropic" => {
let value = match upstream.api_key.as_ref() {
Some(key) => {
tracing::debug!("using upstream.api_key for Anthropic");
HeaderValue::from_str(key).map_err(|_| {
error_response(
StatusCode::UNAUTHORIZED,
"Upstream API key contains invalid characters.",
)
})?
}
None => {
let Some(value) = request_auth.anthropic_api_key.clone().or_else(|| {
request_auth
.authorization_fallback
.as_ref()
.and_then(|value| value.to_str().ok())
.and_then(extract_bearer_token)
.and_then(|value| HeaderValue::from_str(value).ok())
}) else {
tracing::warn!("no API key for Anthropic");
return Ok(None);
if let Some(key) = upstream.api_key.as_ref() {
tracing::debug!("using upstream.api_key for Anthropic");
if let Some(request_header) = request_auth.anthropic_request_auth.as_ref() {
let value = if request_header.name == AUTHORIZATION {
bearer_header(key).ok_or_else(|| {
error_response(
StatusCode::UNAUTHORIZED,
"Upstream API key contains invalid characters.",
)
})?
} else {
HeaderValue::from_str(key).map_err(|_| {
error_response(
StatusCode::UNAUTHORIZED,
"Upstream API key contains invalid characters.",
)
})?
};
tracing::debug!("using request auth fallback for Anthropic");
value
return Ok(Some(UpstreamAuthHeader {
name: request_header.name.clone(),
value,
}));
}

let value = HeaderValue::from_str(key).map_err(|_| {
error_response(
StatusCode::UNAUTHORIZED,
"Upstream API key contains invalid characters.",
)
})?;
return Ok(Some(UpstreamAuthHeader {
name: HeaderName::from_static(X_API_KEY),
value,
}));
}

if let Some(header) = request_auth.anthropic_request_auth.clone() {
tracing::debug!("using native anthropic request auth header");
return Ok(Some(header));
}

let Some(value) = request_auth
.authorization_fallback
.as_ref()
.and_then(|value| value.to_str().ok())
.and_then(extract_bearer_token)
.and_then(|value| HeaderValue::from_str(value).ok())
else {
tracing::warn!("no API key for Anthropic");
return Ok(None);
};

tracing::debug!("using request auth fallback for Anthropic");
Ok(Some(UpstreamAuthHeader {
name: HeaderName::from_static(X_API_KEY),
value,
Expand Down
67 changes: 66 additions & 1 deletion crates/token_proxy_core/src/proxy/http/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,76 @@ fn anthropic_upstream_auth_accepts_authorization_bearer_fallback() {
HeaderValue::from_static("Bearer anthropic-request-key"),
);

let request_auth = resolve_request_auth(&config, &headers).expect("request auth");
let request_auth =
resolve_request_auth(&config, &headers, "/v1/messages").expect("request auth");
let auth = resolve_upstream_auth("anthropic", &upstream_without_key(), &request_auth)
.expect("upstream auth")
.expect("anthropic auth header");

assert_eq!(auth.name, AUTHORIZATION);
assert_eq!(
auth.value.to_str().ok(),
Some("Bearer anthropic-request-key")
);
}

#[test]
fn anthropic_upstream_auth_defaults_to_x_api_key_for_non_native_inbound_requests() {
let config = config_without_local();
let mut headers = HeaderMap::new();
headers.insert(
AUTHORIZATION,
HeaderValue::from_static("Bearer anthropic-request-key"),
);

let request_auth =
resolve_request_auth(&config, &headers, "/v1/chat/completions").expect("request auth");
let auth = resolve_upstream_auth("anthropic", &upstream_without_key(), &request_auth)
.expect("upstream auth")
.expect("anthropic auth header");

assert_eq!(auth.name.as_str(), "x-api-key");
assert_eq!(auth.value.to_str().ok(), Some("anthropic-request-key"));
}

#[test]
fn anthropic_upstream_auth_reuses_authorization_header_name_with_upstream_key() {
let config = config_without_local();
let mut headers = HeaderMap::new();
headers.insert(
AUTHORIZATION,
HeaderValue::from_static("Bearer local-debug-key"),
);

let request_auth =
resolve_request_auth(&config, &headers, "/v1/messages").expect("request auth");
let mut upstream = upstream_without_key();
upstream.api_key = Some("upstream-anthropic-key".to_string());
let auth = resolve_upstream_auth("anthropic", &upstream, &request_auth)
.expect("upstream auth")
.expect("anthropic auth header");

assert_eq!(auth.name, AUTHORIZATION);
assert_eq!(
auth.value.to_str().ok(),
Some("Bearer upstream-anthropic-key")
);
}

#[test]
fn anthropic_upstream_auth_reuses_x_api_key_header_name_with_upstream_key() {
let config = config_without_local();
let mut headers = HeaderMap::new();
headers.insert("x-api-key", HeaderValue::from_static("local-debug-key"));

let request_auth =
resolve_request_auth(&config, &headers, "/v1/messages").expect("request auth");
let mut upstream = upstream_without_key();
upstream.api_key = Some("upstream-anthropic-key".to_string());
let auth = resolve_upstream_auth("anthropic", &upstream, &request_auth)
.expect("upstream auth")
.expect("anthropic auth header");

assert_eq!(auth.name.as_str(), "x-api-key");
assert_eq!(auth.value.to_str().ok(), Some("upstream-anthropic-key"));
}
2 changes: 1 addition & 1 deletion crates/token_proxy_core/src/proxy/server/prepared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ pub(super) fn resolve_request_auth_or_respond(
provider: &str,
request_start: Instant,
) -> Result<http::RequestAuth, Response> {
match http::resolve_request_auth(config, headers) {
match http::resolve_request_auth(config, headers, path) {
Ok(auth) => Ok(auth),
Err(message) => {
log_request_error(
Expand Down
71 changes: 71 additions & 0 deletions crates/token_proxy_core/src/proxy/server/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2706,6 +2706,40 @@ async fn send_anthropic_messages_request(
(status, json)
}

async fn send_anthropic_count_tokens_request(
state: ProxyStateHandle,
headers: HeaderMap,
) -> (StatusCode, Value) {
let response = proxy_request(
State(state),
Method::POST,
Uri::from_static("/v1/messages/count_tokens"),
headers,
Body::from(
json!({
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "user",
"content": [
{ "type": "text", "text": "hi from claude" }
]
}
]
})
.to_string(),
),
)
.await;

let status = response.status();
let body = to_bytes(response.into_body(), usize::MAX)
.await
.expect("proxy response bytes");
let json = serde_json::from_slice(&body).expect("proxy response json");
(status, json)
}

#[test]
fn responses_request_uses_chat_compat_for_coding_plan_runtime_upstream() {
run_async(async {
Expand Down Expand Up @@ -3813,6 +3847,43 @@ fn gemini_count_tokens_route_dispatches_to_gemini() {
assert_eq!(plan.response_transform, FormatTransform::None);
}

#[test]
fn anthropic_count_tokens_preserves_authorization_header_name_for_upstream() {
run_async(async {
let upstream = spawn_mock_upstream(StatusCode::OK, json!({ "input_tokens": 12 })).await;
let config = config_with_runtime_upstreams(&[(
PROVIDER_ANTHROPIC,
0,
"anthropic-auth-relay",
upstream.base_url.as_str(),
FORMATS_MESSAGES,
)]);
let data_dir = next_test_data_dir("anthropic_count_tokens_authorization_header");
let state = build_test_state_handle(config, data_dir.clone()).await;
let mut headers = HeaderMap::new();
headers.insert(
axum::http::header::AUTHORIZATION,
HeaderValue::from_static("Bearer local-debug-key"),
);
headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));

let (status, body) = send_anthropic_count_tokens_request(state, headers).await;
let requests = upstream.requests();

upstream.abort();
let _ = std::fs::remove_dir_all(&data_dir);

assert_eq!(status, StatusCode::OK);
assert_eq!(body["input_tokens"].as_u64(), Some(12));
assert_eq!(requests.len(), 1);
assert_eq!(requests[0].path, "/v1/messages/count_tokens");
assert_eq!(
requests[0].authorization.as_deref(),
Some("Bearer test-key")
);
});
}

#[test]
fn gemini_embed_route_dispatches_to_gemini() {
let config = config_with_providers(&[(PROVIDER_GEMINI, FORMATS_GEMINI)]);
Expand Down
Loading