diff --git a/crates/braintrust-llm-router/examples/custom_auth.rs b/crates/braintrust-llm-router/examples/custom_auth.rs index 904088d3..9c87ec4a 100644 --- a/crates/braintrust-llm-router/examples/custom_auth.rs +++ b/crates/braintrust-llm-router/examples/custom_auth.rs @@ -15,7 +15,8 @@ use anyhow::Result; use braintrust_llm_router::{ - serde_json::json, AuthConfig, OpenAIConfig, OpenAIProvider, ProviderFormat, Router, + serde_json::json, AuthConfig, ClientHeaders, OpenAIConfig, OpenAIProvider, ProviderFormat, + Router, }; use bytes::Bytes; use serde_json::Value; @@ -127,7 +128,14 @@ async fn main() -> Result<()> { println!(" Sending authenticated request to GPT-4..."); let body = Bytes::from(serde_json::to_vec(&payload)?); - let bytes = router.complete(body, model, ProviderFormat::OpenAI).await?; + let bytes = router + .complete( + body, + model, + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) + .await?; let response: Value = serde_json::from_slice(&bytes)?; if let Some(text) = extract_assistant_text(&response) { diff --git a/crates/braintrust-llm-router/examples/multi_provider.rs b/crates/braintrust-llm-router/examples/multi_provider.rs index 98f7a0c1..22ace6bf 100644 --- a/crates/braintrust-llm-router/examples/multi_provider.rs +++ b/crates/braintrust-llm-router/examples/multi_provider.rs @@ -10,8 +10,8 @@ use anyhow::Result; use braintrust_llm_router::{ - serde_json::json, AnthropicConfig, AnthropicProvider, AuthConfig, OpenAIConfig, OpenAIProvider, - ProviderFormat, Router, + serde_json::json, AnthropicConfig, AnthropicProvider, AuthConfig, ClientHeaders, OpenAIConfig, + OpenAIProvider, ProviderFormat, Router, }; use bytes::Bytes; use serde_json::Value; @@ -77,7 +77,15 @@ async fn main() -> Result<()> { }); let body = Bytes::from(serde_json::to_vec(&payload)?); - match router.complete(body, model, ProviderFormat::OpenAI).await { + match router + .complete( + body, + model, + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) + .await + { Ok(bytes) => { if let Ok(response) = serde_json::from_slice::(&bytes) { if let Some(text) = extract_assistant_text(&response) { @@ -100,7 +108,15 @@ async fn main() -> Result<()> { }); let body = Bytes::from(serde_json::to_vec(&payload)?); - match router.complete(body, model, ProviderFormat::OpenAI).await { + match router + .complete( + body, + model, + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) + .await + { Ok(bytes) => { if let Ok(response) = serde_json::from_slice::(&bytes) { if let Some(text) = extract_assistant_text(&response) { diff --git a/crates/braintrust-llm-router/examples/simple.rs b/crates/braintrust-llm-router/examples/simple.rs index 826aea60..33578973 100644 --- a/crates/braintrust-llm-router/examples/simple.rs +++ b/crates/braintrust-llm-router/examples/simple.rs @@ -7,7 +7,7 @@ use anyhow::Result; use braintrust_llm_router::{ - serde_json::json, OpenAIConfig, OpenAIProvider, ProviderFormat, Router, + serde_json::json, ClientHeaders, OpenAIConfig, OpenAIProvider, ProviderFormat, Router, }; use bytes::Bytes; use serde_json::Value; @@ -53,7 +53,14 @@ async fn main() -> Result<()> { // Convert payload to bytes and send request let body = Bytes::from(serde_json::to_vec(&payload)?); - let bytes = router.complete(body, model, ProviderFormat::OpenAI).await?; + let bytes = router + .complete( + body, + model, + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) + .await?; let response: Value = serde_json::from_slice(&bytes)?; println!("📝 Response:\n"); diff --git a/crates/braintrust-llm-router/examples/streaming.rs b/crates/braintrust-llm-router/examples/streaming.rs index 1462fe9a..88659761 100644 --- a/crates/braintrust-llm-router/examples/streaming.rs +++ b/crates/braintrust-llm-router/examples/streaming.rs @@ -7,8 +7,8 @@ use anyhow::Result; use braintrust_llm_router::{ - serde_json::json, AnthropicConfig, AnthropicProvider, AuthConfig, ProviderFormat, - ResponseStream, Router, + serde_json::json, AnthropicConfig, AnthropicProvider, AuthConfig, ClientHeaders, + ProviderFormat, ResponseStream, Router, }; use bytes::Bytes; use futures::StreamExt; @@ -54,7 +54,12 @@ async fn main() -> Result<()> { let body = Bytes::from(serde_json::to_vec(&payload)?); let mut stream = router - .complete_stream(body, model, ProviderFormat::OpenAI) + .complete_stream( + body, + model, + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) .await?; // Process the stream @@ -128,7 +133,12 @@ async fn main() -> Result<()> { }); let body = Bytes::from(serde_json::to_vec(&payload)?); let stream = router - .complete_stream(body, model, ProviderFormat::OpenAI) + .complete_stream( + body, + model, + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) .await?; streams.push((model.to_string(), stream)); } diff --git a/crates/braintrust-llm-router/src/lib.rs b/crates/braintrust-llm-router/src/lib.rs index 0a040b8a..bab93b85 100644 --- a/crates/braintrust-llm-router/src/lib.rs +++ b/crates/braintrust-llm-router/src/lib.rs @@ -25,9 +25,9 @@ pub use lingua::ProviderFormat; pub use lingua::{FinishReason, UniversalStreamChoice, UniversalStreamChunk}; pub use providers::{ is_openai_compatible, openai_compatible_endpoint, AnthropicConfig, AnthropicProvider, - AzureConfig, AzureProvider, BedrockConfig, BedrockProvider, GoogleConfig, GoogleProvider, - MistralConfig, MistralProvider, OpenAICompatibleEndpoint, OpenAIConfig, OpenAIProvider, - OpenAIResponsesProvider, Provider, VertexConfig, VertexProvider, + AzureConfig, AzureProvider, BedrockConfig, BedrockProvider, ClientHeaders, GoogleConfig, + GoogleProvider, MistralConfig, MistralProvider, OpenAICompatibleEndpoint, OpenAIConfig, + OpenAIProvider, OpenAIResponsesProvider, Provider, VertexConfig, VertexProvider, }; pub use retry::{RetryPolicy, RetryStrategy}; pub use router::{create_provider, extract_request_hints, RequestHints, Router, RouterBuilder}; diff --git a/crates/braintrust-llm-router/src/providers/anthropic.rs b/crates/braintrust-llm-router/src/providers/anthropic.rs index 7109332e..901d6958 100644 --- a/crates/braintrust-llm-router/src/providers/anthropic.rs +++ b/crates/braintrust-llm-router/src/providers/anthropic.rs @@ -2,21 +2,25 @@ use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; -use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE}; +use reqwest::header::{HeaderMap, HeaderValue}; use reqwest::{Client, StatusCode, Url}; use crate::auth::AuthConfig; use crate::catalog::ModelSpec; use crate::client::{default_client, ClientSettings}; use crate::error::{Error, Result, UpstreamHttpError}; +use crate::providers::ClientHeaders; use crate::streaming::{single_bytes_stream, sse_stream, RawResponseStream}; use lingua::ProviderFormat; +const ANTHROPIC_VERSION: &str = "anthropic-version"; +const ANTHROPIC_BETA: &str = "anthropic-beta"; +const STRUCTURED_OUTPUTS_BETA: &str = "structured-outputs-2025-11-13"; + #[derive(Debug, Clone)] pub struct AnthropicConfig { pub endpoint: Url, pub version: String, - pub beta: Option, pub timeout: Option, } @@ -26,7 +30,6 @@ impl Default for AnthropicConfig { endpoint: Url::parse("https://api.anthropic.com/v1/") .expect("valid Anthropic endpoint"), version: "2023-06-01".to_string(), - beta: None, timeout: None, } } @@ -56,7 +59,6 @@ impl AnthropicProvider { /// /// Extracts Anthropic-specific options from metadata: /// - `version`: Anthropic API version (defaults to "2023-06-01") - /// - `beta`: Beta feature flag pub fn from_config( endpoint: Option<&Url>, timeout: Option, @@ -74,9 +76,6 @@ impl AnthropicProvider { if let Some(version) = metadata.get("version").and_then(Value::as_str) { config.version = version.to_string(); } - if let Some(beta) = metadata.get("beta").and_then(Value::as_str) { - config.beta = Some(beta.to_string()); - } Self::new(config) } @@ -88,18 +87,23 @@ impl AnthropicProvider { .expect("join messages path") } - fn apply_headers(&self, headers: &mut HeaderMap) { - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + fn build_headers(&self, client_headers: &ClientHeaders) -> HeaderMap { + let mut headers = client_headers.to_json_headers(); + headers.insert( - "anthropic-version", + ANTHROPIC_VERSION, HeaderValue::from_str(&self.config.version).expect("version header"), ); - if let Some(beta) = &self.config.beta { + + // Respect caller override: only set default if missing. + if !headers.contains_key(ANTHROPIC_BETA) { headers.insert( - "anthropic-beta", - HeaderValue::from_str(beta).unwrap_or_else(|_| HeaderValue::from_static("")), + ANTHROPIC_BETA, + HeaderValue::from_static(STRUCTURED_OUTPUTS_BETA), ); } + + headers } } @@ -118,6 +122,7 @@ impl crate::providers::Provider for AnthropicProvider { payload: Bytes, auth: &AuthConfig, _spec: &ModelSpec, + client_headers: &ClientHeaders, ) -> Result { let url = self.messages_url(); @@ -129,8 +134,7 @@ impl crate::providers::Provider for AnthropicProvider { "sending request to Anthropic" ); - let mut headers = HeaderMap::new(); - self.apply_headers(&mut headers); + let mut headers = self.build_headers(client_headers); auth.apply_headers(&mut headers)?; let response = self @@ -176,9 +180,10 @@ impl crate::providers::Provider for AnthropicProvider { payload: Bytes, auth: &AuthConfig, spec: &ModelSpec, + client_headers: &ClientHeaders, ) -> Result { if !spec.supports_streaming { - let response = self.complete(payload, auth, spec).await?; + let response = self.complete(payload, auth, spec, client_headers).await?; return Ok(single_bytes_stream(response)); } @@ -194,8 +199,7 @@ impl crate::providers::Provider for AnthropicProvider { "sending streaming request to Anthropic" ); - let mut headers = HeaderMap::new(); - self.apply_headers(&mut headers); + let mut headers = self.build_headers(client_headers); auth.apply_headers(&mut headers)?; let response = self @@ -243,8 +247,7 @@ impl crate::providers::Provider for AnthropicProvider { .endpoint .join("models") .expect("join models path"); - let mut headers = HeaderMap::new(); - self.apply_headers(&mut headers); + let mut headers = self.build_headers(&ClientHeaders::default()); auth.apply_headers(&mut headers)?; let response = self.client.get(url).headers(headers).send().await?; diff --git a/crates/braintrust-llm-router/src/providers/azure.rs b/crates/braintrust-llm-router/src/providers/azure.rs index dd435619..ecb3cacd 100644 --- a/crates/braintrust-llm-router/src/providers/azure.rs +++ b/crates/braintrust-llm-router/src/providers/azure.rs @@ -3,13 +3,14 @@ use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; use lingua::serde_json::Value; -use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE}; +use reqwest::header::HeaderMap; use reqwest::{Client, StatusCode, Url}; use crate::auth::AuthConfig; use crate::catalog::ModelSpec; use crate::client::{default_client, ClientSettings}; use crate::error::{Error, Result, UpstreamHttpError}; +use crate::providers::ClientHeaders; use crate::streaming::{single_bytes_stream, sse_stream, RawResponseStream}; use lingua::ProviderFormat; @@ -143,7 +144,13 @@ impl crate::providers::Provider for AzureProvider { ProviderFormat::OpenAI } - async fn complete(&self, payload: Bytes, auth: &AuthConfig, spec: &ModelSpec) -> Result { + async fn complete( + &self, + payload: Bytes, + auth: &AuthConfig, + spec: &ModelSpec, + client_headers: &ClientHeaders, + ) -> Result { let url = self.chat_url(&spec.model)?; #[cfg(feature = "tracing")] @@ -154,8 +161,7 @@ impl crate::providers::Provider for AzureProvider { "sending request to Azure" ); - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + let mut headers = self.build_headers(client_headers); auth.apply_headers(&mut headers)?; let response = self @@ -201,9 +207,10 @@ impl crate::providers::Provider for AzureProvider { payload: Bytes, auth: &AuthConfig, spec: &ModelSpec, + client_headers: &ClientHeaders, ) -> Result { if !spec.supports_streaming { - let response = self.complete(payload, auth, spec).await?; + let response = self.complete(payload, auth, spec, client_headers).await?; return Ok(single_bytes_stream(response)); } @@ -219,8 +226,7 @@ impl crate::providers::Provider for AzureProvider { "sending streaming request to Azure" ); - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + let mut headers = self.build_headers(client_headers); auth.apply_headers(&mut headers)?; let response = self diff --git a/crates/braintrust-llm-router/src/providers/bedrock.rs b/crates/braintrust-llm-router/src/providers/bedrock.rs index 1dc06a42..862b9386 100644 --- a/crates/braintrust-llm-router/src/providers/bedrock.rs +++ b/crates/braintrust-llm-router/src/providers/bedrock.rs @@ -17,6 +17,7 @@ use crate::auth::AuthConfig; use crate::catalog::ModelSpec; use crate::client::{default_client, ClientSettings}; use crate::error::{Error, Result, UpstreamHttpError}; +use crate::providers::ClientHeaders; use crate::streaming::{bedrock_event_stream, single_bytes_stream, RawResponseStream}; use lingua::ProviderFormat; @@ -179,6 +180,12 @@ impl BedrockProvider { } Ok(headers) } + + fn build_headers(&self, url: &Url, payload: &[u8], auth: &AuthConfig) -> Result { + let mut headers = self.sign_request(url, payload, auth)?; + headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + Ok(headers) + } } #[async_trait] @@ -191,7 +198,13 @@ impl crate::providers::Provider for BedrockProvider { ProviderFormat::Converse } - async fn complete(&self, payload: Bytes, auth: &AuthConfig, spec: &ModelSpec) -> Result { + async fn complete( + &self, + payload: Bytes, + auth: &AuthConfig, + spec: &ModelSpec, + _client_headers: &ClientHeaders, + ) -> Result { let url = self.invoke_url(&spec.model, false)?; #[cfg(feature = "tracing")] @@ -202,8 +215,7 @@ impl crate::providers::Provider for BedrockProvider { "sending request to Bedrock" ); - let mut headers = self.sign_request(&url, &payload, auth)?; - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + let headers = self.build_headers(&url, payload.as_ref(), auth)?; let response = self .client @@ -248,9 +260,10 @@ impl crate::providers::Provider for BedrockProvider { payload: Bytes, auth: &AuthConfig, spec: &ModelSpec, + client_headers: &ClientHeaders, ) -> Result { if !spec.supports_streaming { - let response = self.complete(payload, auth, spec).await?; + let response = self.complete(payload, auth, spec, client_headers).await?; return Ok(single_bytes_stream(response)); } @@ -266,8 +279,7 @@ impl crate::providers::Provider for BedrockProvider { "sending streaming request to Bedrock" ); - let mut headers = self.sign_request(&url, &payload, auth)?; - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + let headers = self.build_headers(&url, payload.as_ref(), auth)?; let response = self .client diff --git a/crates/braintrust-llm-router/src/providers/google.rs b/crates/braintrust-llm-router/src/providers/google.rs index 650c6081..c57f673a 100644 --- a/crates/braintrust-llm-router/src/providers/google.rs +++ b/crates/braintrust-llm-router/src/providers/google.rs @@ -2,13 +2,14 @@ use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; -use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE}; +use reqwest::header::HeaderMap; use reqwest::{Client, StatusCode, Url}; use crate::auth::AuthConfig; use crate::catalog::ModelSpec; use crate::client::{default_client, ClientSettings}; use crate::error::{Error, Result, UpstreamHttpError}; +use crate::providers::ClientHeaders; use crate::streaming::{single_bytes_stream, sse_stream, RawResponseStream}; use lingua::ProviderFormat; @@ -93,7 +94,13 @@ impl crate::providers::Provider for GoogleProvider { ProviderFormat::Google } - async fn complete(&self, payload: Bytes, auth: &AuthConfig, spec: &ModelSpec) -> Result { + async fn complete( + &self, + payload: Bytes, + auth: &AuthConfig, + spec: &ModelSpec, + client_headers: &ClientHeaders, + ) -> Result { let url = self.generate_url(&spec.model, false)?; #[cfg(feature = "tracing")] @@ -104,8 +111,7 @@ impl crate::providers::Provider for GoogleProvider { "sending request to Google" ); - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + let mut headers = self.build_headers(client_headers); auth.apply_headers(&mut headers)?; let response = self @@ -151,9 +157,10 @@ impl crate::providers::Provider for GoogleProvider { payload: Bytes, auth: &AuthConfig, spec: &ModelSpec, + client_headers: &ClientHeaders, ) -> Result { if !spec.supports_streaming { - let response = self.complete(payload, auth, spec).await?; + let response = self.complete(payload, auth, spec, client_headers).await?; return Ok(single_bytes_stream(response)); } @@ -169,8 +176,7 @@ impl crate::providers::Provider for GoogleProvider { "sending streaming request to Google" ); - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + let mut headers = self.build_headers(client_headers); auth.apply_headers(&mut headers)?; let response = self diff --git a/crates/braintrust-llm-router/src/providers/mistral.rs b/crates/braintrust-llm-router/src/providers/mistral.rs index 6ef74495..6296281c 100644 --- a/crates/braintrust-llm-router/src/providers/mistral.rs +++ b/crates/braintrust-llm-router/src/providers/mistral.rs @@ -2,13 +2,14 @@ use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; -use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE}; +use reqwest::header::HeaderMap; use reqwest::{Client, StatusCode, Url}; use crate::auth::AuthConfig; use crate::catalog::ModelSpec; use crate::client::{default_client, ClientSettings}; use crate::error::{Error, Result, UpstreamHttpError}; +use crate::providers::ClientHeaders; use crate::streaming::{single_bytes_stream, sse_stream, RawResponseStream}; use lingua::ProviderFormat; @@ -99,6 +100,7 @@ impl crate::providers::Provider for MistralProvider { payload: Bytes, auth: &AuthConfig, _spec: &ModelSpec, + client_headers: &ClientHeaders, ) -> Result { let url = self.chat_url()?; @@ -110,8 +112,7 @@ impl crate::providers::Provider for MistralProvider { "sending request to Mistral" ); - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + let mut headers = self.build_headers(client_headers); auth.apply_headers(&mut headers)?; let response = self @@ -157,9 +158,10 @@ impl crate::providers::Provider for MistralProvider { payload: Bytes, auth: &AuthConfig, spec: &ModelSpec, + client_headers: &ClientHeaders, ) -> Result { if !spec.supports_streaming { - let response = self.complete(payload, auth, spec).await?; + let response = self.complete(payload, auth, spec, client_headers).await?; return Ok(single_bytes_stream(response)); } @@ -175,8 +177,7 @@ impl crate::providers::Provider for MistralProvider { "sending streaming request to Mistral" ); - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + let mut headers = self.build_headers(client_headers); auth.apply_headers(&mut headers)?; let response = self diff --git a/crates/braintrust-llm-router/src/providers/mod.rs b/crates/braintrust-llm-router/src/providers/mod.rs index a52e5136..bb5b33e0 100644 --- a/crates/braintrust-llm-router/src/providers/mod.rs +++ b/crates/braintrust-llm-router/src/providers/mod.rs @@ -21,6 +21,8 @@ pub use vertex::{VertexConfig, VertexProvider}; use async_trait::async_trait; use bytes::Bytes; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE}; +use std::collections::HashMap; use std::sync::Arc; use crate::auth::AuthConfig; @@ -29,6 +31,89 @@ use crate::error::Result; use crate::streaming::RawResponseStream; use lingua::ProviderFormat; +/// Header prefixes blocked from forwarding to upstream LLM providers. +pub const BLOCKED_HEADER_PREFIXES: &[&str] = &["x-amzn", "x-bt", "sec-"]; + +/// Exact header names blocked from forwarding to upstream LLM providers +/// from https://github.com/braintrustdata/braintrust-proxy/blob/e992f51734c71e689ea0090f9e0a6759c9a593a4/packages/proxy/src/proxy.ts#L247 +pub const BLOCKED_HEADERS: &[&str] = &[ + "authorization", + "api-key", + "x-api-key", + "x-auth-token", + "content-length", + "origin", + "priority", + "referer", + "user-agent", + "cache-control", + // Avoid forwarding client Accept-Encoding: upstream may return compressed bytes we + // don't decode (e.g., unsupported encodings), which breaks JSON parsing in the router. + "accept-encoding", +]; + +#[derive(Debug, Clone, Default)] +pub struct ClientHeaders { + inner: HashMap, +} + +impl ClientHeaders { + pub fn new() -> Self { + Self::default() + } + + fn is_blocked(name: &str) -> bool { + let name = name.to_lowercase(); + BLOCKED_HEADER_PREFIXES + .iter() + .any(|prefix| name.starts_with(prefix)) + || BLOCKED_HEADERS.iter().any(|&blocked| name == blocked) + } + + pub fn insert_if_allowed(&mut self, name: impl Into, value: impl Into) { + let name = name.into(); + if !Self::is_blocked(&name) { + self.inner.insert(name.to_lowercase(), value.into()); + } + } + + pub fn apply(&self, headers: &mut HeaderMap) { + for (name, value) in &self.inner { + if name == "host" { + // Don't forward client Host; reqwest sets it from the upstream URL. + continue; + } + if let (Ok(header_name), Ok(header_value)) = ( + HeaderName::from_bytes(name.as_bytes()), + HeaderValue::from_str(value), + ) { + headers.insert(header_name, header_value); + } + } + } + + pub(crate) fn to_json_headers(&self) -> HeaderMap { + let mut headers = HeaderMap::new(); + self.apply(&mut headers); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + headers + } +} + +// NOTE: This FromIterator impl exists to support collecting forwarded client +// headers from `(String, String)` pairs at crate boundaries. We use pairs instead +// of `http::HeaderMap` because different workspace crates depend on different +// major versions of the `http` crate, making `HeaderMap` incompatible. +impl FromIterator<(String, String)> for ClientHeaders { + fn from_iter>(iter: T) -> Self { + let mut client_headers = ClientHeaders::new(); + for (name, value) in iter { + client_headers.insert_if_allowed(name, value); + } + client_headers + } +} + /// Provider trait for LLM API backends. /// /// Implementations should be `Send + Sync` to allow concurrent access. @@ -55,7 +140,14 @@ pub trait Provider: Send + Sync { /// * `payload` - Pre-transformed bytes payload ready to send to the provider /// * `auth` - Authentication configuration /// * `spec` - Model specification - async fn complete(&self, payload: Bytes, auth: &AuthConfig, spec: &ModelSpec) -> Result; + /// * `client_headers` - Client headers to forward to the upstream provider + async fn complete( + &self, + payload: Bytes, + auth: &AuthConfig, + spec: &ModelSpec, + client_headers: &ClientHeaders, + ) -> Result; /// Execute a streaming completion request. /// @@ -67,15 +159,25 @@ pub trait Provider: Send + Sync { /// * `payload` - Pre-transformed bytes payload ready to send to the provider /// * `auth` - Authentication configuration /// * `spec` - Model specification + /// * `client_headers` - Client headers to forward to the upstream provider async fn complete_stream( &self, payload: Bytes, auth: &AuthConfig, spec: &ModelSpec, + client_headers: &ClientHeaders, ) -> Result; /// Check if the provider is reachable. async fn health_check(&self, auth: &AuthConfig) -> Result<()>; + + /// Build HTTP headers for a request. + /// + /// Default implementation returns JSON content-type headers. + /// Override for provider-specific headers (e.g., OpenAI-Organization). + fn build_headers(&self, client_headers: &ClientHeaders) -> HeaderMap { + client_headers.to_json_headers() + } } impl dyn Provider { diff --git a/crates/braintrust-llm-router/src/providers/openai.rs b/crates/braintrust-llm-router/src/providers/openai.rs index 1690421c..de72f3f9 100644 --- a/crates/braintrust-llm-router/src/providers/openai.rs +++ b/crates/braintrust-llm-router/src/providers/openai.rs @@ -3,13 +3,14 @@ use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; use lingua::serde_json::Value; -use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE}; +use reqwest::header::{HeaderMap, HeaderValue}; use reqwest::{Client, StatusCode, Url}; use crate::auth::AuthConfig; use crate::catalog::ModelSpec; use crate::client::{default_client, ClientSettings}; use crate::error::{Error, Result, UpstreamHttpError}; +use crate::providers::ClientHeaders; use crate::streaming::{single_bytes_stream, sse_stream, RawResponseStream}; use lingua::ProviderFormat; @@ -148,6 +149,12 @@ impl OpenAIProvider { ); } } + + fn build_headers(&self, client_headers: &ClientHeaders) -> HeaderMap { + let mut headers = client_headers.to_json_headers(); + self.apply_headers(&mut headers); + headers + } } #[async_trait] @@ -160,7 +167,13 @@ impl crate::providers::Provider for OpenAIProvider { ProviderFormat::OpenAI } - async fn complete(&self, payload: Bytes, auth: &AuthConfig, spec: &ModelSpec) -> Result { + async fn complete( + &self, + payload: Bytes, + auth: &AuthConfig, + spec: &ModelSpec, + client_headers: &ClientHeaders, + ) -> Result { let url = self.chat_url(Some(&spec.model))?; #[cfg(feature = "tracing")] @@ -171,9 +184,7 @@ impl crate::providers::Provider for OpenAIProvider { "sending request to OpenAI" ); - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); - self.apply_headers(&mut headers); + let mut headers = self.build_headers(client_headers); auth.apply_headers(&mut headers)?; let response = self @@ -219,9 +230,10 @@ impl crate::providers::Provider for OpenAIProvider { payload: Bytes, auth: &AuthConfig, spec: &ModelSpec, + client_headers: &ClientHeaders, ) -> Result { if !spec.supports_streaming { - let response = self.complete(payload, auth, spec).await?; + let response = self.complete(payload, auth, spec, client_headers).await?; return Ok(single_bytes_stream(response)); } @@ -237,9 +249,7 @@ impl crate::providers::Provider for OpenAIProvider { "sending streaming request to OpenAI" ); - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); - self.apply_headers(&mut headers); + let mut headers = self.build_headers(client_headers); auth.apply_headers(&mut headers)?; let response = self @@ -283,8 +293,7 @@ impl crate::providers::Provider for OpenAIProvider { async fn health_check(&self, auth: &AuthConfig) -> Result<()> { let url = self.chat_url(None)?; - let mut headers = HeaderMap::new(); - self.apply_headers(&mut headers); + let mut headers = self.build_headers(&ClientHeaders::default()); auth.apply_headers(&mut headers)?; let response = self.client.get(url).headers(headers).send().await?; diff --git a/crates/braintrust-llm-router/src/providers/openai_responses.rs b/crates/braintrust-llm-router/src/providers/openai_responses.rs index 9f52216b..17363088 100644 --- a/crates/braintrust-llm-router/src/providers/openai_responses.rs +++ b/crates/braintrust-llm-router/src/providers/openai_responses.rs @@ -2,13 +2,14 @@ use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; -use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE}; +use reqwest::header::{HeaderMap, HeaderValue}; use reqwest::{Client, StatusCode}; use crate::auth::AuthConfig; use crate::catalog::ModelSpec; use crate::client::{default_client, ClientSettings}; use crate::error::{Error, Result, UpstreamHttpError}; +use crate::providers::ClientHeaders; use crate::streaming::{single_bytes_stream, RawResponseStream}; use lingua::ProviderFormat; @@ -67,6 +68,12 @@ impl OpenAIResponsesProvider { ); } } + + fn build_headers(&self, client_headers: &ClientHeaders) -> HeaderMap { + let mut headers = client_headers.to_json_headers(); + self.apply_headers(&mut headers); + headers + } } #[async_trait] @@ -84,6 +91,7 @@ impl crate::providers::Provider for OpenAIResponsesProvider { payload: Bytes, auth: &AuthConfig, _spec: &ModelSpec, + client_headers: &ClientHeaders, ) -> Result { let url = self.responses_url()?; @@ -95,9 +103,7 @@ impl crate::providers::Provider for OpenAIResponsesProvider { "sending request to OpenAI Responses API" ); - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); - self.apply_headers(&mut headers); + let mut headers = self.build_headers(client_headers); auth.apply_headers(&mut headers)?; let response = self @@ -143,16 +149,16 @@ impl crate::providers::Provider for OpenAIResponsesProvider { payload: Bytes, auth: &AuthConfig, spec: &ModelSpec, + client_headers: &ClientHeaders, ) -> Result { // Responses API doesn't support streaming, return single-bytes stream - let response = self.complete(payload, auth, spec).await?; + let response = self.complete(payload, auth, spec, client_headers).await?; Ok(single_bytes_stream(response)) } async fn health_check(&self, auth: &AuthConfig) -> Result<()> { let url = self.responses_url()?; - let mut headers = HeaderMap::new(); - self.apply_headers(&mut headers); + let mut headers = self.build_headers(&ClientHeaders::default()); auth.apply_headers(&mut headers)?; let response = self.client.get(url).headers(headers).send().await?; diff --git a/crates/braintrust-llm-router/src/providers/vertex.rs b/crates/braintrust-llm-router/src/providers/vertex.rs index f0a79dfb..e16e0f38 100644 --- a/crates/braintrust-llm-router/src/providers/vertex.rs +++ b/crates/braintrust-llm-router/src/providers/vertex.rs @@ -3,13 +3,14 @@ use std::time::Duration; use async_trait::async_trait; use bytes::Bytes; use lingua::serde_json::Value; -use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE}; +use reqwest::header::HeaderMap; use reqwest::{Client, StatusCode, Url}; use crate::auth::AuthConfig; use crate::catalog::ModelSpec; use crate::client::{default_client, ClientSettings}; use crate::error::{Error, Result, UpstreamHttpError}; +use crate::providers::ClientHeaders; use crate::streaming::{single_bytes_stream, sse_stream, RawResponseStream}; use lingua::ProviderFormat; @@ -149,7 +150,13 @@ impl crate::providers::Provider for VertexProvider { ProviderFormat::Google } - async fn complete(&self, payload: Bytes, auth: &AuthConfig, spec: &ModelSpec) -> Result { + async fn complete( + &self, + payload: Bytes, + auth: &AuthConfig, + spec: &ModelSpec, + client_headers: &ClientHeaders, + ) -> Result { let mode = self.determine_mode(&spec.model); let url = self.endpoint_for_mode(&mode, false)?; @@ -161,8 +168,7 @@ impl crate::providers::Provider for VertexProvider { "sending request to Vertex" ); - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + let mut headers = self.build_headers(client_headers); auth.apply_headers(&mut headers)?; let response = self @@ -208,9 +214,10 @@ impl crate::providers::Provider for VertexProvider { payload: Bytes, auth: &AuthConfig, spec: &ModelSpec, + client_headers: &ClientHeaders, ) -> Result { if !spec.supports_streaming { - let response = self.complete(payload, auth, spec).await?; + let response = self.complete(payload, auth, spec, client_headers).await?; return Ok(single_bytes_stream(response)); } @@ -226,8 +233,7 @@ impl crate::providers::Provider for VertexProvider { "sending streaming request to Vertex" ); - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + let mut headers = self.build_headers(client_headers); auth.apply_headers(&mut headers)?; // Router should have already added stream options to payload diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 3e138fc6..f3a4b9e5 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -13,7 +13,7 @@ use crate::catalog::{ default_catalog, load_catalog_from_disk, ModelCatalog, ModelResolver, ModelSpec, }; use crate::error::{Error, Result}; -use crate::providers::Provider; +use crate::providers::{ClientHeaders, Provider}; use crate::retry::{RetryPolicy, RetryStrategy}; use crate::streaming::{transform_stream, ResponseStream}; use lingua::serde_json::Value; @@ -129,6 +129,7 @@ impl Router { /// * `body` - Raw request body bytes in any supported format (OpenAI, Anthropic, Google, etc.) /// * `model` - The model name for routing (e.g., "gpt-4", "claude-3-opus") /// * `output_format` - The output format, or None to auto-detect from body + /// * `client_headers` - Client headers to forward to the upstream provider /// /// The body will be automatically transformed to the target provider's format if needed. /// The response will be converted to the requested output format. @@ -136,7 +137,7 @@ impl Router { feature = "tracing", tracing::instrument( name = "bt.router.complete", - skip(self, body), + skip(self, body, client_headers), fields(llm.model = %model) ) )] @@ -145,6 +146,7 @@ impl Router { body: Bytes, model: &str, output_format: ProviderFormat, + client_headers: &ClientHeaders, ) -> Result { let (provider, auth, spec, strategy) = self.resolve_provider(model)?; let payload = match lingua::transform_request(body.clone(), provider.format(), Some(model)) @@ -156,7 +158,14 @@ impl Router { }; let response_bytes = self - .execute_with_retry(provider.clone(), auth, spec, payload, strategy) + .execute_with_retry( + provider.clone(), + auth, + spec, + payload, + strategy, + client_headers, + ) .await?; let result = lingua::transform_response(response_bytes.clone(), output_format) @@ -177,6 +186,7 @@ impl Router { /// * `body` - Raw request body bytes in any supported format (OpenAI, Anthropic, Google, etc.) /// * `model` - The model name for routing (e.g., "gpt-4", "claude-3-opus") /// * `output_format` - The output format, or None to auto-detect from body + /// * `client_headers` - Client headers to forward to the upstream provider /// /// The body will be automatically transformed to the target provider's format if needed. /// Stream chunks will be transformed to the requested output format. @@ -184,7 +194,7 @@ impl Router { feature = "tracing", tracing::instrument( name = "bt.router.complete_stream", - skip(self, body), + skip(self, body, client_headers), fields(llm.model = %model) ) )] @@ -193,6 +203,7 @@ impl Router { body: Bytes, model: &str, output_format: ProviderFormat, + client_headers: &ClientHeaders, ) -> Result { let (provider, auth, spec, _) = self.resolve_provider(model)?; let payload = match lingua::transform_request(body.clone(), provider.format(), Some(model)) @@ -203,7 +214,9 @@ impl Router { Err(e) => return Err(Error::Lingua(e.to_string())), }; - let raw_stream = provider.complete_stream(payload, auth, &spec).await?; + let raw_stream = provider + .complete_stream(payload, auth, &spec, client_headers) + .await?; Ok(transform_stream(raw_stream, output_format)) } @@ -236,6 +249,7 @@ impl Router { spec: Arc, payload: Bytes, mut strategy: RetryStrategy, + client_headers: &ClientHeaders, ) -> Result { #[cfg(feature = "tracing")] let mut attempt = 0u32; @@ -253,13 +267,19 @@ impl Router { llm.provider = %provider.id(), attempt = attempt, ); - async { provider.complete(payload.clone(), auth, &spec).await } - .instrument(span) - .await + async { + provider + .complete(payload.clone(), auth, &spec, client_headers) + .await + } + .instrument(span) + .await }; #[cfg(not(feature = "tracing"))] - let result = provider.complete(payload.clone(), auth, &spec).await; + let result = provider + .complete(payload.clone(), auth, &spec, client_headers) + .await; match result { Ok(response) => return Ok(response), diff --git a/crates/braintrust-llm-router/tests/client_headers.rs b/crates/braintrust-llm-router/tests/client_headers.rs new file mode 100644 index 00000000..9a084a59 --- /dev/null +++ b/crates/braintrust-llm-router/tests/client_headers.rs @@ -0,0 +1,37 @@ +use braintrust_llm_router::ClientHeaders; +use http::HeaderMap; + +fn apply_headers(cases: &[(&str, &str, bool)]) -> HeaderMap { + let header_pairs = cases + .iter() + .map(|(name, value, _)| (name.to_string(), value.to_string())) + .collect::>(); + let client_headers: ClientHeaders = header_pairs.into_iter().collect(); + let mut headers = HeaderMap::new(); + client_headers.apply(&mut headers); + headers +} + +#[test] +fn client_headers_filter_and_host_behavior() { + let cases = [ + ("x-amzn-trace-id", "1", false), + ("x-bt-project-id", "1", false), + ("sec-fetch-mode", "cors", false), + ("content-length", "123", false), + ("origin", "https://example.com", false), + ("priority", "u=1", false), + ("referer", "https://example.com", false), + ("user-agent", "test", false), + ("cache-control", "no-cache", false), + ("host", "api.example.com", false), + ("anthropic-beta", "tools-2024-05-16", true), + ("accept", "application/json", true), + ("x-custom-header", "1", true), + ]; + + let headers = apply_headers(&cases); + for (name, _value, expected) in cases { + assert_eq!(headers.contains_key(name), expected, "header {name}"); + } +} diff --git a/crates/braintrust-llm-router/tests/router.rs b/crates/braintrust-llm-router/tests/router.rs index fa48b66f..605ec3df 100644 --- a/crates/braintrust-llm-router/tests/router.rs +++ b/crates/braintrust-llm-router/tests/router.rs @@ -5,8 +5,8 @@ use std::time::Duration; use async_trait::async_trait; use braintrust_llm_router::{ serde_json::{json, Value}, - AuthConfig, Error, ModelCatalog, ModelFlavor, ModelSpec, Provider, ProviderFormat, - RawResponseStream, RetryPolicy, RouterBuilder, + AuthConfig, ClientHeaders, Error, ModelCatalog, ModelFlavor, ModelSpec, Provider, + ProviderFormat, RawResponseStream, RetryPolicy, RouterBuilder, }; use bytes::Bytes; @@ -33,6 +33,7 @@ impl Provider for StubProvider { payload: Bytes, _auth: &AuthConfig, _spec: &ModelSpec, + _client_headers: &ClientHeaders, ) -> braintrust_llm_router::Result { // Parse the incoming payload to extract model name let value: Value = @@ -71,6 +72,7 @@ impl Provider for StubProvider { _payload: Bytes, _auth: &AuthConfig, _spec: &ModelSpec, + _client_headers: &ClientHeaders, ) -> braintrust_llm_router::Result { Ok(Box::pin(tokio_stream::empty())) } @@ -129,7 +131,12 @@ async fn router_routes_to_stub_provider() { })); let bytes = router - .complete(body, model, ProviderFormat::OpenAI) + .complete( + body, + model, + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) .await .expect("complete"); // Parse bytes to Value using braintrust_llm_router's serde_json @@ -179,7 +186,12 @@ async fn router_requires_auth_for_provider() { })); let err = router - .complete(body, model, ProviderFormat::OpenAI) + .complete( + body, + model, + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) .await .expect_err("missing auth"); assert!(matches!(err, Error::NoAuth(alias) if alias == "stub")); @@ -220,7 +232,12 @@ async fn router_reports_missing_provider() { })); let err = router - .complete(body, model, ProviderFormat::OpenAI) + .complete( + body, + model, + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) .await .expect_err("missing provider"); assert!(matches!(err, Error::NoProvider(ProviderFormat::OpenAI))); @@ -247,7 +264,7 @@ async fn router_propagates_validation_errors() { "messages": [] })); let err = router - .complete(body, "", ProviderFormat::OpenAI) + .complete(body, "", ProviderFormat::OpenAI, &ClientHeaders::default()) .await .expect_err("validation"); // Empty model is treated as unknown model, not invalid request @@ -274,6 +291,7 @@ impl Provider for FailingProvider { _payload: Bytes, _auth: &AuthConfig, _spec: &ModelSpec, + _client_headers: &ClientHeaders, ) -> braintrust_llm_router::Result { self.attempts.fetch_add(1, Ordering::SeqCst); Err(Error::Timeout) @@ -284,6 +302,7 @@ impl Provider for FailingProvider { _payload: Bytes, _auth: &AuthConfig, _spec: &ModelSpec, + _client_headers: &ClientHeaders, ) -> braintrust_llm_router::Result { Err(Error::Timeout) } @@ -353,7 +372,12 @@ async fn router_retries_and_propagates_terminal_error() { })); let err = router - .complete(body, model, ProviderFormat::OpenAI) + .complete( + body, + model, + ProviderFormat::OpenAI, + &ClientHeaders::default(), + ) .await .expect_err("terminal error"); assert!(matches!(err, Error::Timeout));