diff --git a/crates/coverage-report/src/report.rs b/crates/coverage-report/src/report.rs index 5e938809..209ceb07 100644 --- a/crates/coverage-report/src/report.rs +++ b/crates/coverage-report/src/report.rs @@ -23,17 +23,28 @@ pub fn format_cell(pair_result: &PairResult) -> String { format!("{} {}/{}", emoji, pair_result.passed, total) } +/// Returns (table_markdown, stats, failures, limitations, missing_fixtures) pub fn generate_table( results: &HashMap<(usize, usize), PairResult>, adapters: &[Box<dyn ProviderAdapter>], title: &str, -) -> (String, TableStats, Vec<(String, String, String)>) { +) -> ( + String, + TableStats, + Vec<(String, String, String)>, + Vec<(String, String, String)>, + Vec<(String, String, String)>, +) { let mut table = String::new(); let mut stats = TableStats { passed: 0, failed: 0, + limitations: 0, + missing_fixtures: 0, }; let mut all_failures: Vec<(String, String, String)> = Vec::new(); + let mut all_limitations: Vec<(String, String, String)> = Vec::new(); + let mut all_missing_fixtures: Vec<(String, String, String)> = Vec::new(); table.push_str(&format!("### {}\n\n", title)); table.push_str("| Source ↓ / Target → |"); @@ -57,6 +68,8 @@ pub fn generate_table( stats.passed += pair_result.passed; stats.failed += pair_result.failed; + stats.limitations += pair_result.limitations; + stats.missing_fixtures += pair_result.missing_fixtures; for (test_case, error) in &pair_result.failures { all_failures.push(( @@ -65,12 +78,28 @@ pub fn generate_table( error.clone(), )); } + + for (test_case, error) in &pair_result.limitation_details { + all_limitations.push(( + format!("{} → {}", source.display_name(), target.display_name()), + test_case.clone(), + error.clone(), + )); + } + + for (test_case, error) in &pair_result.missing_fixture_details { + all_missing_fixtures.push(( + format!("{} → {}", source.display_name(), target.display_name()), + test_case.clone(), + error.clone(), + )); + } } } table.push('\n'); } - (table, stats, all_failures) + (table, stats, all_failures, all_limitations, all_missing_fixtures) } // ============================================================================ @@ -259,25 +288,30 @@ pub fn generate_report( report.push_str("## Cross-Provider Transformation Coverage\n\n"); - let (req_table, req_stats, req_failures) = + let (req_table, req_stats, req_failures, req_limitations, req_missing) = generate_table(request_results, adapters, "Request Transformations"); report.push_str(&req_table); report.push('\n'); - let (resp_table, resp_stats, resp_failures) = + let (resp_table, resp_stats, resp_failures, resp_limitations, resp_missing) = generate_table(response_results, adapters, "Response Transformations"); report.push_str(&resp_table); report.push('\n'); - let (stream_table, stream_stats, stream_failures) = generate_table( - streaming_results, - adapters, - "Streaming Response Transformations", - ); + let (stream_table, stream_stats, stream_failures, stream_limitations, stream_missing) = + generate_table( + streaming_results, + adapters, + "Streaming Response Transformations", + ); + report.push_str(&stream_table); let total_passed = req_stats.passed + resp_stats.passed + stream_stats.passed; let total_failed = req_stats.failed + resp_stats.failed + stream_stats.failed; + let total_limitations = + req_stats.limitations + resp_stats.limitations + stream_stats.limitations; + let total_missing = + req_stats.missing_fixtures + resp_stats.missing_fixtures + stream_stats.missing_fixtures; let total = total_passed + total_failed; let pass_percentage = if total > 0 { @@ -288,8
+322,8 @@ pub fn generate_report( report.push_str("\n### Summary\n\n"); report.push_str(&format!( - "**{}/{} ({:.1}%)** - {} failed\n", - total_passed, total, pass_percentage, total_failed + "**{}/{} ({:.1}%)** - {} failed, {} limitations, {} missing fixtures\n", + total_passed, total, pass_percentage, total_failed, total_limitations, total_missing )); let req_total = req_stats.passed + req_stats.failed; @@ -297,16 +331,28 @@ let stream_total = stream_stats.passed + stream_stats.failed; report.push_str(&format!( - "\n**Requests:** {}/{} passed, {} failed\n", - req_stats.passed, req_total, req_stats.failed + "\n**Requests:** {}/{} passed, {} failed, {} limitations, {} missing\n", + req_stats.passed, + req_total, + req_stats.failed, + req_stats.limitations, + req_stats.missing_fixtures )); report.push_str(&format!( - "**Responses:** {}/{} passed, {} failed\n", - resp_stats.passed, resp_total, resp_stats.failed + "**Responses:** {}/{} passed, {} failed, {} limitations, {} missing\n", + resp_stats.passed, + resp_total, + resp_stats.failed, + resp_stats.limitations, + resp_stats.missing_fixtures )); report.push_str(&format!( - "**Streaming:** {}/{} passed, {} failed\n", - stream_stats.passed, stream_total, stream_stats.failed + "**Streaming:** {}/{} passed, {} failed, {} limitations, {} missing\n", + stream_stats.passed, + stream_total, + stream_stats.failed, + stream_stats.limitations, + stream_stats.missing_fixtures )); // Organize issues by source provider → request/response/streaming → target @@ -509,6 +555,114 @@ pub fn generate_report( } } + // Add provider limitations section + let all_limitations: Vec<_> = req_limitations + .into_iter() + .chain(resp_limitations) + .chain(stream_limitations) + .collect(); + + if !all_limitations.is_empty() { + report.push_str("\n### Provider Limitations\n\n"); + report.push_str("These are provider-specific features that cannot be transformed:\n\n"); + + // Group by source provider + let mut by_source: HashMap<String, Vec<(String, String, String)>> = HashMap::new(); + for (direction, test_case, error) in all_limitations { + let source = direction + .split(" → ") + .next() + .unwrap_or(&direction) + .to_string(); + by_source + .entry(source) + .or_default() + .push((direction, test_case, error)); + } + + let mut sources: Vec<_> = by_source.into_iter().collect(); + sources.sort_by(|a, b| b.1.len().cmp(&a.1.len())); + + for (source, limitations) in sources { + report.push_str("<details>
\n"); + report.push_str(&format!( + "⚠️ {} ({} limitations)\n\n", + source, + limitations.len() + )); + + // Group by target + let mut by_target: HashMap> = HashMap::new(); + for (direction, test_case, error) in limitations { + let target = direction + .split(" → ") + .nth(1) + .unwrap_or("Unknown") + .to_string(); + by_target + .entry(target) + .or_default() + .push((test_case, error)); + } + + let mut targets: Vec<_> = by_target.into_iter().collect(); + targets.sort_by(|a, b| b.1.len().cmp(&a.1.len())); + + for (target, target_limitations) in targets { + report.push_str(&format!("**→ {}:**\n", target)); + for (test_case, error) in target_limitations { + report.push_str(&format!(" - `{}` - {}\n", test_case, error)); + } + report.push('\n'); + } + + report.push_str("
\n\n"); + } + } + + // Add missing fixtures section (collapsed by default) + let all_missing: Vec<_> = req_missing + .into_iter() + .chain(resp_missing) + .chain(stream_missing) + .collect(); + + if !all_missing.is_empty() { + report.push_str("\n### Missing Test Fixtures\n\n"); + report.push_str("
\n"); + report.push_str(&format!( + "📁 {} missing fixtures (expand to see details)\n\n", + all_missing.len() + )); + + // Group by source provider + let mut by_source: HashMap> = HashMap::new(); + for (direction, test_case, error) in all_missing { + let source = direction + .split(" → ") + .next() + .unwrap_or(&direction) + .to_string(); + by_source + .entry(source) + .or_default() + .push((direction, test_case, error)); + } + + let mut sources: Vec<_> = by_source.into_iter().collect(); + sources.sort_by(|a, b| b.1.len().cmp(&a.1.len())); + + for (source, missing) in sources { + report.push_str(&format!("**{}** ({} missing):\n", source, missing.len())); + for (_, test_case, _) in missing { + report.push_str(&format!(" - `{}`\n", test_case)); + } + report.push('\n'); + } + + report.push_str("
\n"); + } + // Add roundtrip section report.push('\n'); report.push_str(&generate_roundtrip_section(roundtrip_results, adapters)); diff --git a/crates/coverage-report/src/runner.rs b/crates/coverage-report/src/runner.rs index 01dddba8..18c76ca7 100644 --- a/crates/coverage-report/src/runner.rs +++ b/crates/coverage-report/src/runner.rs @@ -18,6 +18,30 @@ use crate::types::{PairResult, TransformResult, ValidationLevel}; type PairResults = HashMap<(usize, usize), PairResult>; type AllResults = (PairResults, PairResults, PairResults); +// Patterns that indicate provider limitations (real gaps, not bugs) +const LIMITATION_PATTERNS: &[&str] = &[ + "Provider limitation", + "has no OpenAI equivalent", + "has no Anthropic equivalent", + "has no Bedrock equivalent", + "has no Google equivalent", + "Unsupported", +]; + +// Patterns that indicate missing test fixtures (test coverage gaps) +const MISSING_FIXTURE_PATTERNS: &[&str] = &["Source payload not found"]; + +/// Classify an error into failure, limitation, or missing fixture. +fn classify_error(error: &str) -> ValidationLevel { + if MISSING_FIXTURE_PATTERNS.iter().any(|p| error.contains(p)) { + ValidationLevel::MissingFixture + } else if LIMITATION_PATTERNS.iter().any(|p| error.contains(p)) { + ValidationLevel::Limitation + } else { + ValidationLevel::Fail + } +} + // Validation uses request_to_universal/response_to_universal from the adapter trait. // These methods return Result with detailed error info when validation fails. @@ -30,10 +54,11 @@ pub fn test_request_transformation( let payload = match load_payload(test_case, source_adapter.directory_name(), filename) { Some(p) => p, None => { + let error = format!("Source payload not found: {}", filename); return TransformResult { - level: ValidationLevel::Fail, - error: Some(format!("Source payload not found: {}", filename)), - } + level: ValidationLevel::MissingFixture, + error: Some(error), + }; } }; @@ -77,10 +102,11 @@ pub fn test_request_transformation( }, } } - Err(e) => TransformResult { - level: ValidationLevel::Fail, - error: Some(format!("{}", e)), - }, + Err(e) => { + let error = format!("{}", e); + let level = classify_error(&error); + TransformResult { level, error: Some(error) } + } } } @@ -265,6 +291,22 @@ pub fn run_all_tests(adapters: &[Box]) -> AllResults { .push((format!("{} (request)", test_case), error)); } } + ValidationLevel::Limitation => { + pair_result.limitations += 1; + if let Some(error) = result.error { + pair_result + .limitation_details + .push((format!("{} (request)", test_case), error)); + } + } + ValidationLevel::MissingFixture => { + pair_result.missing_fixtures += 1; + if let Some(error) = result.error { + pair_result + .missing_fixture_details + .push((format!("{} (request)", test_case), error)); + } + } } // Test followup request if exists @@ -285,6 +327,22 @@ pub fn run_all_tests(adapters: &[Box]) -> AllResults { .push((format!("{} (followup)", test_case), error)); } } + ValidationLevel::Limitation => { + pair_result.limitations += 1; + if let Some(error) = followup_result.error { + pair_result + .limitation_details + .push((format!("{} (followup)", test_case), error)); + } + } + ValidationLevel::MissingFixture => { + pair_result.missing_fixtures += 1; + if let Some(error) = followup_result.error { + pair_result + .missing_fixture_details + .push((format!("{} (followup)", test_case), error)); + } + } } } @@ -308,6 +366,22 @@ pub fn run_all_tests(adapters: &[Box]) -> AllResults { .push((format!("{} (response)", test_case), error)); } } + 
ValidationLevel::Limitation => { + resp_pair_result.limitations += 1; + if let Some(error) = response_result.error { + resp_pair_result + .limitation_details + .push((format!("{} (response)", test_case), error)); + } + } + ValidationLevel::MissingFixture => { + resp_pair_result.missing_fixtures += 1; + if let Some(error) = response_result.error { + resp_pair_result + ..missing_fixture_details + .push((format!("{} (response)", test_case), error)); + } + } } } @@ -337,6 +411,22 @@ pub fn run_all_tests(adapters: &[Box<dyn ProviderAdapter>]) -> AllResults { .push((format!("{} (streaming)", test_case), error)); } } + ValidationLevel::Limitation => { + stream_pair_result.limitations += 1; + if let Some(error) = streaming_result.error { + stream_pair_result + .limitation_details + .push((format!("{} (streaming)", test_case), error)); + } + } + ValidationLevel::MissingFixture => { + stream_pair_result.missing_fixtures += 1; + if let Some(error) = streaming_result.error { + stream_pair_result + ..missing_fixture_details + .push((format!("{} (streaming)", test_case), error)); + } + } } } @@ -362,6 +452,22 @@ pub fn run_all_tests(adapters: &[Box<dyn ProviderAdapter>]) -> AllResults { .push((format!("{} (followup-streaming)", test_case), error)); } } + ValidationLevel::Limitation => { + stream_pair_result.limitations += 1; + if let Some(error) = followup_streaming_result.error { + stream_pair_result + .limitation_details + .push((format!("{} (followup-streaming)", test_case), error)); + } + } + ValidationLevel::MissingFixture => { + stream_pair_result.missing_fixtures += 1; + if let Some(error) = followup_streaming_result.error { + stream_pair_result + .missing_fixture_details + .push((format!("{} (followup-streaming)", test_case), error)); + } + } } } } @@ -612,7 +718,9 @@ pub fn run_roundtrip_tests(adapters: &[Box<dyn ProviderAdapter>]) -> RoundtripRe if let Some(result) = test_request_roundtrip(test_case, adapter, "request.json") { match result.level { ValidationLevel::Pass => provider_result.request_passed += 1, - ValidationLevel::Fail => { + ValidationLevel::Fail + | ValidationLevel::Limitation + | ValidationLevel::MissingFixture => { provider_result.request_failed += 1; provider_result .request_failures @@ -627,7 +735,9 @@ pub fn run_roundtrip_tests(adapters: &[Box<dyn ProviderAdapter>]) -> RoundtripRe { match result.level { ValidationLevel::Pass => provider_result.request_passed += 1, - ValidationLevel::Fail => { + ValidationLevel::Fail + | ValidationLevel::Limitation + | ValidationLevel::MissingFixture => { provider_result.request_failed += 1; provider_result .request_failures @@ -640,7 +750,9 @@ pub fn run_roundtrip_tests(adapters: &[Box<dyn ProviderAdapter>]) -> RoundtripRe if let Some(result) = test_response_roundtrip(test_case, adapter, "response.json") { match result.level { ValidationLevel::Pass => provider_result.response_passed += 1, - ValidationLevel::Fail => { + ValidationLevel::Fail + | ValidationLevel::Limitation + | ValidationLevel::MissingFixture => { provider_result.response_failed += 1; provider_result .response_failures diff --git a/crates/coverage-report/src/types.rs b/crates/coverage-report/src/types.rs index 12230c93..b7d49ae4 100644 --- a/crates/coverage-report/src/types.rs +++ b/crates/coverage-report/src/types.rs @@ -6,6 +6,10 @@ Type definitions for coverage-report.
pub enum ValidationLevel { Pass, Fail, + /// Provider limitation - feature that can't be transformed (e.g., "has no OpenAI equivalent") + Limitation, + /// Missing test fixture - "Source payload not found" + MissingFixture, } #[derive(Debug)] @@ -18,12 +22,18 @@ pub struct TransformResult { pub struct PairResult { pub passed: usize, pub failed: usize, + pub limitations: usize, + pub missing_fixtures: usize, pub failures: Vec<(String, String)>, + pub limitation_details: Vec<(String, String)>, + pub missing_fixture_details: Vec<(String, String)>, } pub struct TableStats { pub passed: usize, pub failed: usize, + pub limitations: usize, + pub missing_fixtures: usize, } // ============================================================================ diff --git a/crates/lingua/src/processing/transform.rs b/crates/lingua/src/processing/transform.rs index aeceefd0..3f03038b 100644 --- a/crates/lingua/src/processing/transform.rs +++ b/crates/lingua/src/processing/transform.rs @@ -55,6 +55,9 @@ pub enum TransformError { #[error("Streaming not implemented: {0}")] StreamingNotImplemented(String), + + #[error("Provider limitation: {0}")] + ProviderLimitation(String), } /// Result of a transformation operation. diff --git a/crates/lingua/src/providers/anthropic/adapter.rs b/crates/lingua/src/providers/anthropic/adapter.rs index 265f9a3f..8384b3bf 100644 --- a/crates/lingua/src/providers/anthropic/adapter.rs +++ b/crates/lingua/src/providers/anthropic/adapter.rs @@ -8,41 +8,27 @@ Anthropic's Messages API has some unique requirements: use crate::capabilities::ProviderFormat; use crate::processing::adapters::{ - collect_extras, insert_opt_bool, insert_opt_f64, insert_opt_i64, insert_opt_value, - ProviderAdapter, + insert_opt_bool, insert_opt_f64, insert_opt_i64, insert_opt_value, ProviderAdapter, }; use crate::processing::transform::TransformError; -use crate::providers::anthropic::generated::{ContentBlock, CreateMessageParams, InputMessage}; +use crate::providers::anthropic::generated::{ContentBlock, InputMessage}; +use crate::providers::anthropic::params::AnthropicParams; use crate::providers::anthropic::try_parse_anthropic; use crate::serde_json::{self, Map, Value}; use crate::universal::convert::TryFromLLM; use crate::universal::message::{Message, UserContent}; use crate::universal::transform::extract_system_messages; +use std::convert::TryInto; +use crate::universal::tools::{find_builtin_tool, is_anthropic_custom_format, openai_to_anthropic_tools}; use crate::universal::{ FinishReason, UniversalParams, UniversalRequest, UniversalResponse, UniversalStreamChoice, UniversalStreamChunk, UniversalUsage, PLACEHOLDER_ID, PLACEHOLDER_MODEL, }; +use std::collections::HashMap; /// Default max_tokens for Anthropic requests (matches legacy proxy behavior). pub const DEFAULT_MAX_TOKENS: i64 = 4096; -/// Known request fields for Anthropic Messages API. -/// Fields not in this list go into `extras`. -const ANTHROPIC_KNOWN_KEYS: &[&str] = &[ - "model", - "messages", - "system", - "max_tokens", - "temperature", - "top_p", - "top_k", - "stop_sequences", - "stream", - "metadata", - "tools", - "tool_choice", -]; - /// Adapter for Anthropic Messages API. 
pub struct AnthropicAdapter; @@ -64,37 +50,72 @@ impl ProviderAdapter for AnthropicAdapter { } fn request_to_universal(&self, payload: Value) -> Result<UniversalRequest, TransformError> { - let extras = collect_extras(&payload, ANTHROPIC_KNOWN_KEYS); - let stop = payload.get("stop_sequences").cloned(); - - let request: CreateMessageParams = serde_json::from_value(payload) + // Parse into typed params - extras are automatically captured via #[serde(flatten)] + let typed_params: AnthropicParams = serde_json::from_value(payload.clone()) .map_err(|e| TransformError::ToUniversalFailed(e.to_string()))?; - let messages = <Vec<Message> as TryFromLLM<Vec<InputMessage>>>::try_from(request.messages) + // Extract messages array directly to avoid deserializing problematic Tool enum + // (CreateMessageParams has tools: Option<Vec<Tool>>, but Tool enum's #[serde(tag = "type")] + // fails on custom tools that lack a type field. We only need messages here anyway.) + let input_messages: Vec<InputMessage> = payload + .get("messages") + .and_then(|v| serde_json::from_value(v.clone()).ok()) + .ok_or_else(|| TransformError::ToUniversalFailed("missing messages field".to_string()))?; + + let messages = <Vec<Message> as TryFromLLM<Vec<InputMessage>>>::try_from(input_messages) .map_err(|e| TransformError::ToUniversalFailed(e.to_string()))?; let params = UniversalParams { - temperature: request.temperature, - top_p: request.top_p, - top_k: request.top_k, - max_tokens: Some(request.max_tokens), - stop, - tools: request.tools.and_then(|t| serde_json::to_value(t).ok()), - tool_choice: request + temperature: typed_params.temperature, + top_p: typed_params.top_p, + top_k: typed_params.top_k, + max_tokens: typed_params.max_tokens, + stop: typed_params + .stop_sequences + .as_ref() + .and_then(|v| (ProviderFormat::Anthropic, v).try_into().ok()), + tools: typed_params.tools, + tool_choice: typed_params .tool_choice - .and_then(|t| serde_json::to_value(t).ok()), - response_format: None, // Anthropic doesn't use response_format - seed: None, // Anthropic doesn't support seed + .as_ref() + .and_then(|v| (ProviderFormat::Anthropic, v).try_into().ok()), + response_format: None, // Anthropic doesn't use response_format + seed: None, // Anthropic doesn't support seed presence_penalty: None, // Anthropic doesn't support these frequency_penalty: None, - stream: request.stream, + stream: typed_params.stream, + // Extract parallel_tool_calls from Anthropic's disable_parallel_tool_use in tool_choice + parallel_tool_calls: typed_params + .tool_choice + .as_ref() + .and_then(|tc| tc.get("disable_parallel_tool_use")) + .and_then(Value::as_bool) + .map(|disabled| !disabled), // disable_parallel_tool_use: true → parallel_tool_calls: false + reasoning: typed_params + .thinking + .as_ref() + .and_then(|v| (ProviderFormat::Anthropic, v).try_into().ok()), + metadata: typed_params.metadata, + store: None, // Anthropic doesn't support store + service_tier: typed_params.service_tier, + logprobs: None, // Anthropic doesn't support logprobs + top_logprobs: None, // Anthropic doesn't support top_logprobs }; + // Use extras captured automatically via #[serde(flatten)] + let mut provider_extras = HashMap::new(); + if !typed_params.extras.is_empty() { + provider_extras.insert( + ProviderFormat::Anthropic, + typed_params.extras.into_iter().collect(), + ); + } + Ok(UniversalRequest { - model: Some(request.model), + model: typed_params.model, messages, params, - extras, + provider_extras, }) } @@ -144,20 +165,66 @@ impl ProviderAdapter for AnthropicAdapter { insert_opt_i64(&mut obj, "top_k", req.params.top_k); // Anthropic uses stop_sequences instead of stop - if let
Some(stop) = &req.params.stop { - obj.insert("stop_sequences".into(), stop.clone()); + if let Some(sequences) = req.params.stop.as_ref().and_then(|s| s.to_sequences_array()) { + obj.insert( + "stop_sequences".into(), + Value::Array(sequences.into_iter().map(Value::String).collect()), + ); } - insert_opt_value(&mut obj, "tools", req.params.tools.clone()); - insert_opt_value(&mut obj, "tool_choice", req.params.tool_choice.clone()); + // Convert tools from OpenAI format to Anthropic format if needed + if let Some(tools) = &req.params.tools { + if is_anthropic_custom_format(tools) || find_builtin_tool(tools).is_some() { + // Already in Anthropic format (custom tools or built-ins) - pass through + insert_opt_value(&mut obj, "tools", Some(tools.clone())); + } else { + // Convert from OpenAI format + insert_opt_value(&mut obj, "tools", openai_to_anthropic_tools(tools)); + } + } + + // Convert tool_choice from canonical ToolChoiceConfig to Anthropic format + // parallel_tool_calls is passed explicitly for disable_parallel_tool_use + if let Some(tool_choice_val) = req + .params + .tool_choice + .as_ref() + .and_then(|tc| tc.to_provider(ProviderFormat::Anthropic, req.params.parallel_tool_calls).ok()) + .flatten() + { + obj.insert("tool_choice".into(), tool_choice_val); + } insert_opt_bool(&mut obj, "stream", req.params.stream); - // Merge extras - only include Anthropic-known fields - // This filters out OpenAI-specific fields like stream_options that would cause - // Anthropic to reject the request with "extra inputs are not permitted" - for (k, v) in &req.extras { - if ANTHROPIC_KNOWN_KEYS.contains(&k.as_str()) { - obj.insert(k.clone(), v.clone()); + // Add reasoning as thinking if present (convert ReasoningConfig to Anthropic thinking format) + // max_tokens is passed explicitly for effort→budget conversion + if let Some(thinking_val) = req + .params + .reasoning + .as_ref() + .and_then(|r| r.to_provider(ProviderFormat::Anthropic, req.params.max_tokens).ok()) + .flatten() + { + obj.insert("thinking".into(), thinking_val); + } + + // Add metadata from canonical params + if let Some(metadata) = req.params.metadata.as_ref() { + obj.insert("metadata".into(), metadata.clone()); + } + + // Add service_tier from canonical params + if let Some(ref service_tier) = req.params.service_tier { + obj.insert("service_tier".into(), Value::String(service_tier.clone())); + } + + // Merge back provider-specific extras (only for Anthropic) + if let Some(extras) = req.provider_extras.get(&ProviderFormat::Anthropic) { + for (k, v) in extras { + // Don't overwrite canonical fields we already handled + if !obj.contains_key(k) { + obj.insert(k.clone(), v.clone()); + } } } @@ -559,7 +626,7 @@ mod tests { model: Some("claude-3-5-sonnet-20241022".to_string()), messages: vec![], params: UniversalParams::default(), - extras: Map::new(), + provider_extras: HashMap::new(), }; assert!(req.params.max_tokens.is_none()); @@ -577,7 +644,7 @@ mod tests { max_tokens: Some(8192), ..Default::default() }, - extras: Map::new(), + provider_extras: HashMap::new(), }; adapter.apply_defaults(&mut req); diff --git a/crates/lingua/src/providers/anthropic/convert.rs b/crates/lingua/src/providers/anthropic/convert.rs index 5de0b99d..1b4131e4 100644 --- a/crates/lingua/src/providers/anthropic/convert.rs +++ b/crates/lingua/src/providers/anthropic/convert.rs @@ -2,8 +2,9 @@ use crate::error::ConvertError; use crate::providers::anthropic::generated; use crate::serde_json; use crate::universal::{ - convert::TryFromLLM, AssistantContent, 
AssistantContentPart, Message, TextContentPart, - ToolCallArguments, ToolContentPart, ToolResultContentPart, UserContent, UserContentPart, + convert::TryFromLLM, message::ProviderOptions, AssistantContent, AssistantContentPart, Message, + TextContentPart, ToolCallArguments, ToolContentPart, ToolResultContentPart, UserContent, + UserContentPart, }; impl TryFromLLM<generated::InputMessage> for Message { @@ -139,6 +140,74 @@ impl TryFromLLM<generated::InputMessage> for Message { // Skip tool results for now - should be handled properly continue; } + generated::InputContentBlockType::Document => { + // Map document to File with provider_options for title/context + if let Some(source) = &block.source { + let mut opts = serde_json::Map::new(); + // Store document-specific fields in provider_options + opts.insert( + "anthropic_type".into(), + serde_json::Value::String("document".to_string()), + ); + if let Some(title) = &block.title { + opts.insert( + "title".into(), + serde_json::Value::String(title.clone()), + ); + } + if let Some(context) = &block.context { + opts.insert( + "context".into(), + serde_json::Value::String(context.clone()), + ); + } + + // Extract data and media_type from source + match source { + generated::Source::SourceSource(s) => { + let media_type = s.media_type.as_ref().map(|mt| { + match mt { + generated::FluffyMediaType::ImageJpeg => { + "image/jpeg".to_string() + } + generated::FluffyMediaType::ImagePng => { + "image/png".to_string() + } + generated::FluffyMediaType::ImageGif => { + "image/gif".to_string() + } + generated::FluffyMediaType::ImageWebp => { + "image/webp".to_string() + } + generated::FluffyMediaType::ApplicationPdf => { + "application/pdf".to_string() + } + generated::FluffyMediaType::TextPlain => { + "text/plain".to_string() + } + } + }); + content_parts.push(UserContentPart::File { + data: s + .data + .clone() + .map(serde_json::Value::String) + .unwrap_or(serde_json::Value::Null), + filename: None, + media_type: media_type + .unwrap_or_else(|| "text/plain".to_string()), + provider_options: Some(ProviderOptions { + options: opts, + }), + }); + } + _ => { + // Skip other source types + continue; + } + } + } + } _ => { // Skip other types for now continue; } @@ -180,10 +249,19 @@ match block.input_content_block_type { generated::InputContentBlockType::Text => { if let Some(text) = block.text { + // Preserve citations in provider_options for roundtrip + let provider_options = block.citations.as_ref().map(|citations| { + let mut opts = serde_json::Map::new(); + if let Ok(v) = serde_json::to_value(citations) { + opts.insert("citations".into(), v); + } + ProviderOptions { options: opts } + }); + content_parts.push(AssistantContentPart::Text( TextContentPart { text, - provider_options: None, + provider_options, }, )); } @@ -192,7 +270,8 @@ if let Some(thinking) = block.thinking { content_parts.push(AssistantContentPart::Reasoning { text: thinking, - encrypted_content: None, + // Preserve the signature in encrypted_content for roundtrip + encrypted_content: block.signature.clone(), }); } } @@ -218,6 +297,51 @@ if let (Some(id), Some(name)) = (&block.id, &block.name) { }); } } + generated::InputContentBlockType::ServerToolUse => { + // Server-executed tool use (web search, etc.)
+ if let (Some(id), Some(name)) = (&block.id, &block.name) { + let input = if let Some(input_map) = &block.input { + serde_json::to_value(input_map) + .unwrap_or(serde_json::Value::Null) + } else { + serde_json::Value::Null + }; + + content_parts.push(AssistantContentPart::ToolCall { + tool_call_id: id.clone(), + tool_name: name.clone(), + arguments: serde_json::to_string(&input) + .unwrap_or_else(|_| "{}".to_string()) + .into(), + provider_options: None, + provider_executed: Some(true), // Mark as server-executed + }); + } + } + generated::InputContentBlockType::WebSearchToolResult => { + // Web search tool result - convert to ToolResult with marker + if let Some(id) = &block.tool_use_id { + let mut output = serde_json::Map::new(); + output.insert( + "anthropic_type".into(), + serde_json::Value::String( + "web_search_tool_result".to_string(), + ), + ); + if let Some(content) = &block.content { + if let Ok(v) = serde_json::to_value(content) { + output.insert("content".into(), v); + } + } + + content_parts.push(AssistantContentPart::ToolResult { + tool_call_id: id.clone(), + tool_name: "web_search".to_string(), // Server-executed web search tool + output: serde_json::Value::Object(output), + provider_options: None, + }); + } + } _ => { // Skip other types for now continue; } @@ -344,7 +468,82 @@ impl TryFromLLM<Message> for generated::InputMessage { None } } - _ => None, // Skip other parts for now + UserContentPart::File { + data, + media_type, + provider_options, + .. + } => { + // Check if this was originally a Document block + let is_document = provider_options + .as_ref() + .and_then(|opts| opts.options.get("anthropic_type")) + .and_then(|v| v.as_str()) + == Some("document"); + + if is_document { + // Restore as Document block + let title = provider_options + .as_ref() + .and_then(|opts| opts.options.get("title")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let context = provider_options + .as_ref() + .and_then(|opts| opts.options.get("context")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let anthropic_media_type = match media_type.as_str() { + "image/jpeg" => Some(generated::FluffyMediaType::ImageJpeg), + "image/png" => Some(generated::FluffyMediaType::ImagePng), + "image/gif" => Some(generated::FluffyMediaType::ImageGif), + "image/webp" => Some(generated::FluffyMediaType::ImageWebp), + "application/pdf" => { + Some(generated::FluffyMediaType::ApplicationPdf) + } + "text/plain" => Some(generated::FluffyMediaType::TextPlain), + _ => Some(generated::FluffyMediaType::TextPlain), + }; + + let data_str = match data { + serde_json::Value::String(s) => Some(s), + _ => None, + }; + + Some(generated::InputContentBlock { + cache_control: None, + citations: None, + text: None, + input_content_block_type: + generated::InputContentBlockType::Document, + source: Some(generated::Source::SourceSource( + generated::SourceSource { + data: data_str, + media_type: anthropic_media_type, + source_type: generated::FluffyType::Text, + url: None, + content: None, + }, + )), + context, + title, + content: None, + signature: None, + thinking: None, + data: None, + id: None, + input: None, + name: None, + is_error: None, + tool_use_id: None, + }) + } else { + // Regular file - skip for now + None + } + } }) .collect(); generated::MessageContent::InputContentBlockArray(blocks) @@ -382,9 +581,15 @@ impl TryFromLLM<Message> for generated::InputMessage { .into_iter() .filter_map(|part| match part { AssistantContentPart::Text(text_part) => { + // Restore citations from provider_options + let citations =
text_part.provider_options + .as_ref() + .and_then(|opts| opts.options.get("citations")) + .and_then(|v| serde_json::from_value(v.clone()).ok()); + Some(generated::InputContentBlock { cache_control: None, - citations: None, + citations, text: Some(text_part.text), input_content_block_type: generated::InputContentBlockType::Text, @@ -402,7 +607,10 @@ tool_use_id: None, }) } - AssistantContentPart::Reasoning { text, .. } => { + AssistantContentPart::Reasoning { + text, + encrypted_content, + } => { Some(generated::InputContentBlock { cache_control: None, citations: None, text: None, @@ -413,7 +621,8 @@ context: None, title: None, content: None, - signature: None, + // Restore signature from encrypted_content + signature: encrypted_content, thinking: Some(text), data: None, id: None, @@ -427,6 +636,7 @@ tool_call_id, tool_name, arguments, + provider_executed, .. } => { // Convert ToolCallArguments to serde_json::Map @@ -435,12 +645,18 @@ ToolCallArguments::Invalid(_) => None, }; + // Use ServerToolUse for provider-executed tools + let block_type = if provider_executed == Some(true) { + generated::InputContentBlockType::ServerToolUse + } else { + generated::InputContentBlockType::ToolUse + }; + Some(generated::InputContentBlock { cache_control: None, citations: None, text: None, - input_content_block_type: - generated::InputContentBlockType::ToolUse, + input_content_block_type: block_type, source: None, context: None, title: None, @@ -455,6 +671,46 @@ tool_use_id: None, }) } + AssistantContentPart::ToolResult { + tool_call_id, + output, + ..
+ } => { + // Check if this was a web_search_tool_result + let is_web_search_result = output.as_object() + .and_then(|obj| obj.get("anthropic_type")) + .and_then(|v| v.as_str()) + == Some("web_search_tool_result"); + + if is_web_search_result { + // Restore WebSearchToolResult block + let content = output.as_object() + .and_then(|obj| obj.get("content")) + .and_then(|v| serde_json::from_value(v.clone()).ok()); + + Some(generated::InputContentBlock { + cache_control: None, + citations: None, + text: None, + input_content_block_type: + generated::InputContentBlockType::WebSearchToolResult, + source: None, + context: None, + title: None, + content, + signature: None, + thinking: None, + data: None, + id: None, + input: None, + name: None, + is_error: None, + tool_use_id: Some(tool_call_id.clone()), + }) + } else { + None // Skip other tool results in assistant messages + } + } _ => None, // Skip other types for now }) .collect(), @@ -530,9 +786,17 @@ impl TryFromLLM<Vec<generated::ContentBlock>> for Vec<AssistantContentPart> { match block.content_block_type { generated::ContentBlockType::Text => { if let Some(text) = block.text { + // Preserve citations in provider_options for roundtrip + let provider_options = block.citations.as_ref().map(|citations| { + let mut opts = serde_json::Map::new(); + if let Ok(v) = serde_json::to_value(citations) { + opts.insert("citations".into(), v); + } + ProviderOptions { options: opts } + }); content_parts.push(AssistantContentPart::Text(TextContentPart { text, - provider_options: None, + provider_options, })); } } @@ -540,7 +804,8 @@ if let Some(thinking) = block.thinking { content_parts.push(AssistantContentPart::Reasoning { text: thinking, - encrypted_content: None, + // Preserve signature in encrypted_content for roundtrip + encrypted_content: block.signature.clone(), }); } } @@ -564,8 +829,51 @@ }); } } + generated::ContentBlockType::ServerToolUse => { + // Server-executed tool (similar to ToolUse but provider_executed=true) + if let (Some(id), Some(name)) = (block.id, block.name) { + let input = if let Some(input_map) = block.input { + serde_json::to_value(input_map).unwrap_or(serde_json::Value::Null) + } else { + serde_json::Value::Null + }; + + content_parts.push(AssistantContentPart::ToolCall { + tool_call_id: id, + tool_name: name, + arguments: serde_json::to_string(&input) + .unwrap_or_else(|_| "{}".to_string()) + .into(), + provider_options: None, + provider_executed: Some(true), // Mark as server-executed + }); + } + } + generated::ContentBlockType::WebSearchToolResult => { + // Web search tool result - convert to ToolResult with full data + if let Some(id) = block.tool_use_id { + // Store the entire block data for roundtrip + let mut output = serde_json::Map::new(); + output.insert( + "anthropic_type".into(), + serde_json::Value::String("web_search_tool_result".to_string()), + ); + if let Some(content) = &block.content { + if let Ok(v) = serde_json::to_value(content) { + output.insert("content".into(), v); + } + } + + content_parts.push(AssistantContentPart::ToolResult { + tool_call_id: id, + tool_name: "web_search".to_string(), + output: serde_json::Value::Object(output), + provider_options: None, + }); + } + } _ => { - // Skip other types for now + // Skip other types (RedactedThinking, etc.)
continue; } } @@ -614,8 +922,19 @@ impl TryFromLLM<Vec<AssistantContentPart>> for Vec<generated::ContentBlock> { for part in parts { match part { AssistantContentPart::Text(text_part) => { + // Restore citations from provider_options if present + let citations = text_part + .provider_options + .as_ref() + .and_then(|opts| opts.options.get("citations")) + .and_then(|v| { + serde_json::from_value( + v.clone(), + ) + .ok() + }); content_blocks.push(generated::ContentBlock { - citations: None, + citations, text: Some(text_part.text), content_block_type: generated::ContentBlockType::Text, signature: None, @@ -628,12 +947,16 @@ tool_use_id: None, }); } - AssistantContentPart::Reasoning { text, .. } => { + AssistantContentPart::Reasoning { + text, + encrypted_content, + } => { content_blocks.push(generated::ContentBlock { citations: None, text: None, content_block_type: generated::ContentBlockType::Thinking, - signature: None, + // Restore signature from encrypted_content + signature: encrypted_content, thinking: Some(text), data: None, id: None, @@ -647,6 +970,7 @@ tool_call_id, tool_name, arguments, + provider_executed, .. } => { // Convert ToolCallArguments to serde_json::Map for response generation @@ -655,10 +979,17 @@ ToolCallArguments::Invalid(_) => None, }; + // Use ServerToolUse if provider_executed is true + let block_type = if provider_executed == Some(true) { + generated::ContentBlockType::ServerToolUse + } else { + generated::ContentBlockType::ToolUse + }; + content_blocks.push(generated::ContentBlock { citations: None, text: None, - content_block_type: generated::ContentBlockType::ToolUse, + content_block_type: block_type, signature: None, thinking: None, data: None, @@ -669,6 +1000,45 @@ tool_use_id: None, }); } + AssistantContentPart::ToolResult { + tool_call_id, + output, + .. + } => { + // Check if this is a web_search_tool_result + let is_web_search_result = output + .get("anthropic_type") + .and_then(|v| v.as_str()) + == Some("web_search_tool_result"); + + if is_web_search_result { + // Restore as WebSearchToolResult + let content = output + .get("content") + .and_then(|v| { + serde_json::from_value( + v.clone(), + ) + .ok() + }); + + content_blocks.push(generated::ContentBlock { + citations: None, + text: None, + content_block_type: + generated::ContentBlockType::WebSearchToolResult, + signature: None, + thinking: None, + data: None, + id: None, + input: None, + name: None, + content, + tool_use_id: Some(tool_call_id.clone()), + }); + } + // Skip other tool results - they shouldn't appear in response content + } _ => { // Skip other types for now continue; diff --git a/crates/lingua/src/providers/anthropic/mod.rs b/crates/lingua/src/providers/anthropic/mod.rs index df2fe6b1..b33bb334 100644 --- a/crates/lingua/src/providers/anthropic/mod.rs +++ b/crates/lingua/src/providers/anthropic/mod.rs @@ -9,6 +9,7 @@ pub mod adapter; pub mod convert; pub mod detect; pub mod generated; +pub mod params; #[cfg(test)] pub mod test_anthropic; diff --git a/crates/lingua/src/providers/anthropic/params.rs b/crates/lingua/src/providers/anthropic/params.rs new file mode 100644 index 00000000..16292be0 --- /dev/null +++ b/crates/lingua/src/providers/anthropic/params.rs @@ -0,0 +1,149 @@ +/*! +Typed parameter structs for Anthropic Messages API. + +These structs use `#[serde(flatten)]` to automatically capture unknown fields, +eliminating the need for explicit KNOWN_KEYS arrays.
+*/ + +use crate::serde_json::Value; +use serde::{Deserialize, Serialize}; +use std::collections::BTreeMap; + +/// Anthropic Messages API request parameters. +/// +/// All known fields are explicitly typed. Unknown fields automatically +/// go into `extras` via `#[serde(flatten)]`. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct AnthropicParams { + // === Core fields === + pub model: Option<String>, + pub messages: Option<Value>, + + // === System prompt (can be string or array with cache_control) === + pub system: Option<Value>, + + // === Required output control === + pub max_tokens: Option<i64>, + + // === Sampling parameters === + pub temperature: Option<f64>, + pub top_p: Option<f64>, + pub top_k: Option<i64>, + pub stop_sequences: Option<Value>, + + // === Streaming === + pub stream: Option<bool>, + + // === Tools and function calling === + pub tools: Option<Value>, + pub tool_choice: Option<Value>, + + // === Extended thinking === + pub thinking: Option<Value>, + + // === Metadata and identification === + pub metadata: Option<Value>, + pub service_tier: Option<String>, + + /// Unknown fields - automatically captured by serde flatten. + /// These are provider-specific fields not in the canonical set. + #[serde(flatten)] + pub extras: BTreeMap<String, Value>, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::serde_json; + use crate::serde_json::json; + + #[test] + fn test_anthropic_params_known_fields() { + let json = json!({ + "model": "claude-sonnet-4-20250514", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 1024, + "temperature": 0.7, + "top_k": 40 + }); + + let params: AnthropicParams = serde_json::from_value(json).unwrap(); + assert_eq!(params.model, Some("claude-sonnet-4-20250514".to_string())); + assert_eq!(params.max_tokens, Some(1024)); + assert_eq!(params.temperature, Some(0.7)); + assert_eq!(params.top_k, Some(40)); + assert!(params.extras.is_empty()); + } + + #[test] + fn test_anthropic_params_with_thinking() { + let json = json!({ + "model": "claude-sonnet-4-20250514", + "messages": [], + "max_tokens": 16000, + "thinking": { + "type": "enabled", + "budget_tokens": 10000 + } + }); + + let params: AnthropicParams = serde_json::from_value(json).unwrap(); + assert!(params.thinking.is_some()); + let thinking = params.thinking.unwrap(); + assert_eq!(thinking.get("type"), Some(&json!("enabled"))); + assert_eq!(thinking.get("budget_tokens"), Some(&json!(10000))); + } + + #[test] + fn test_anthropic_params_with_system_cache_control() { + let json = json!({ + "model": "claude-sonnet-4-20250514", + "messages": [], + "max_tokens": 1024, + "system": [ + { + "type": "text", + "text": "Be helpful.", + "cache_control": {"type": "ephemeral", "ttl": "5m"} + } + ] + }); + + let params: AnthropicParams = serde_json::from_value(json).unwrap(); + assert!(params.system.is_some()); + assert!(params.extras.is_empty()); + } + + #[test] + fn test_anthropic_params_unknown_fields_go_to_extras() { + let json = json!({ + "model": "claude-sonnet-4-20250514", + "messages": [], + "max_tokens": 1024, + "some_future_param": "value" + }); + + let params: AnthropicParams = serde_json::from_value(json).unwrap(); + assert_eq!(params.extras.len(), 1); + assert_eq!( + params.extras.get("some_future_param"), + Some(&Value::String("value".to_string())) + ); + } + + #[test] + fn test_anthropic_roundtrip_preserves_extras() { + let json = json!({ + "model": "claude-sonnet-4-20250514", + "messages": [], + "max_tokens": 1024, + "custom_field": {"nested": "data"} + }); + + let params: AnthropicParams = serde_json::from_value(json.clone()).unwrap(); + let back: Value =
serde_json::to_value(&params).unwrap(); + + // Custom field should be preserved + assert_eq!(back.get("custom_field"), json.get("custom_field")); + } +} diff --git a/crates/lingua/src/providers/bedrock/adapter.rs b/crates/lingua/src/providers/bedrock/adapter.rs index e5b29217..61bb85f0 100644 --- a/crates/lingua/src/providers/bedrock/adapter.rs +++ b/crates/lingua/src/providers/bedrock/adapter.rs @@ -9,7 +9,7 @@ Bedrock's Converse API has some unique characteristics: */ use crate::capabilities::ProviderFormat; -use crate::processing::adapters::{collect_extras, ProviderAdapter}; +use crate::processing::adapters::ProviderAdapter; use crate::processing::transform::TransformError; use crate::providers::bedrock::request::{ BedrockInferenceConfiguration, BedrockMessage, ConverseRequest, @@ -18,25 +18,38 @@ use crate::providers::bedrock::try_parse_bedrock; use crate::serde_json::{self, Map, Value}; use crate::universal::convert::TryFromLLM; use crate::universal::message::Message; +use std::convert::TryInto; use crate::universal::{ FinishReason, UniversalParams, UniversalRequest, UniversalResponse, UniversalStreamChoice, UniversalStreamChunk, UniversalUsage, }; +use std::collections::HashMap; -/// Known request fields for Bedrock Converse API. -/// Fields not in this list go into `extras`. -const BEDROCK_KNOWN_KEYS: &[&str] = &[ +/// Fields that are mapped to UniversalParams canonical fields for Bedrock Converse API. +/// These are excluded from extras since they're handled explicitly. +const BEDROCK_CANONICAL_KEYS: &[&str] = &[ "modelId", "messages", "system", "inferenceConfig", "toolConfig", - "guardrailConfig", - "additionalModelRequestFields", - "additionalModelResponseFieldPaths", - "promptVariables", ]; +/// Collect fields not in BEDROCK_CANONICAL_KEYS as extras for passthrough. +fn collect_bedrock_extras(payload: &Value) -> Map<String, Value> { + let obj = match payload.as_object() { + Some(o) => o, + None => return Map::new(), + }; + let mut extras = Map::new(); + for (k, v) in obj { + if !BEDROCK_CANONICAL_KEYS.contains(&k.as_str()) { + extras.insert(k.clone(), v.clone()); + } + } + extras +} + /// Adapter for Amazon Bedrock Converse API.
pub struct BedrockAdapter; @@ -58,9 +71,7 @@ impl ProviderAdapter for BedrockAdapter { } fn request_to_universal(&self, payload: Value) -> Result<UniversalRequest, TransformError> { - let extras = collect_extras(&payload, BEDROCK_KNOWN_KEYS); - - let request: ConverseRequest = serde_json::from_value(payload) + let request: ConverseRequest = serde_json::from_value(payload.clone()) .map_err(|e| TransformError::ToUniversalFailed(e.to_string()))?; let messages = @@ -77,7 +88,9 @@ impl ProviderAdapter for BedrockAdapter { config .stop_sequences .as_ref() - .and_then(|s| serde_json::to_value(s).ok()), + .and_then(|s| serde_json::to_value(s).ok()) + .as_ref() + .and_then(|v| (ProviderFormat::Converse, v).try_into().ok()), ) } else { (None, None, None, None) } @@ -98,13 +111,29 @@ impl ProviderAdapter for BedrockAdapter { presence_penalty: None, frequency_penalty: None, stream: None, // Bedrock uses separate endpoint for streaming + // New canonical fields - Bedrock doesn't support these + parallel_tool_calls: None, + reasoning: None, + metadata: None, + store: None, + service_tier: None, + logprobs: None, + top_logprobs: None, }; + // Collect unknown fields as extras for this provider + let extras = collect_bedrock_extras(&payload); + + let mut provider_extras = HashMap::new(); + if !extras.is_empty() { + provider_extras.insert(ProviderFormat::Converse, extras); + } + Ok(UniversalRequest { model: Some(request.model_id), messages, params, - extras, + provider_extras, }) } @@ -138,19 +167,7 @@ impl ProviderAdapter for BedrockAdapter { temperature: req.params.temperature, top_p: req.params.top_p, max_tokens: req.params.max_tokens.map(|t| t as i32), - stop_sequences: req.params.stop.as_ref().and_then(|v| { - if let Value::Array(arr) = v { - Some( - arr.iter() - .filter_map(|s| s.as_str().map(String::from)) - .collect(), - ) - } else if let Value::String(s) = v { - Some(vec![s.clone()]) - } else { - None - } - }), + stop_sequences: req.params.stop.as_ref().and_then(|s| s.to_sequences_array()), }; obj.insert( @@ -165,9 +182,14 @@ impl ProviderAdapter for BedrockAdapter { obj.insert("toolConfig".into(), tools.clone()); } - // Merge extras - for (k, v) in &req.extras { - obj.insert(k.clone(), v.clone()); + // Merge back provider-specific extras (only for Bedrock/Converse) + if let Some(extras) = req.provider_extras.get(&ProviderFormat::Converse) { + for (k, v) in extras { + // Don't overwrite canonical fields we already handled + if !obj.contains_key(k) { + obj.insert(k.clone(), v.clone()); + } + } } Ok(Value::Object(obj)) diff --git a/crates/lingua/src/providers/google/adapter.rs b/crates/lingua/src/providers/google/adapter.rs index 2dcca98c..c9e89e4f 100644 --- a/crates/lingua/src/providers/google/adapter.rs +++ b/crates/lingua/src/providers/google/adapter.rs @@ -9,7 +9,7 @@ Google's API has some unique characteristics: */ use crate::capabilities::ProviderFormat; -use crate::processing::adapters::{collect_extras, ProviderAdapter}; +use crate::processing::adapters::ProviderAdapter; use crate::processing::transform::TransformError; use crate::providers::google::detect::{ try_parse_google, GoogleContent, GoogleGenerateContentRequest, GoogleGenerationConfig, @@ -17,11 +17,13 @@ use crate::providers::google::detect::{ use crate::serde_json::{self, Map, Value}; use crate::universal::convert::TryFromLLM; use crate::universal::message::Message; +use std::convert::TryInto; use crate::universal::{ extract_system_messages, flatten_consecutive_messages, FinishReason, UniversalParams, UniversalRequest, UniversalResponse, UniversalStreamChoice,
UniversalStreamChunk, UniversalUsage, UserContent, }; +use std::collections::HashMap; /// Known request fields for Google GenerateContent API. /// Fields not in this list go into `extras`. @@ -35,6 +37,21 @@ const GOOGLE_KNOWN_KEYS: &[&str] = &[ "model", ]; +/// Collect fields not in GOOGLE_KNOWN_KEYS as extras for passthrough. +fn collect_google_extras(payload: &Value) -> Map<String, Value> { + let obj = match payload.as_object() { + Some(o) => o, + None => return Map::new(), + }; + let mut extras = Map::new(); + for (k, v) in obj { + if !GOOGLE_KNOWN_KEYS.contains(&k.as_str()) { + extras.insert(k.clone(), v.clone()); + } + } + extras +} + /// Adapter for Google AI GenerateContent API. pub struct GoogleAdapter; @@ -56,13 +73,12 @@ impl ProviderAdapter for GoogleAdapter { } fn request_to_universal(&self, payload: Value) -> Result<UniversalRequest, TransformError> { - let extras = collect_extras(&payload, GOOGLE_KNOWN_KEYS); let model = payload .get("model") .and_then(Value::as_str) .map(String::from); - let request: GoogleGenerateContentRequest = serde_json::from_value(payload) + let request: GoogleGenerateContentRequest = serde_json::from_value(payload.clone()) .map_err(|e| TransformError::ToUniversalFailed(e.to_string()))?; let messages = <Vec<Message> as TryFromLLM<Vec<GoogleContent>>>::try_from(request.contents) .map_err(|e| TransformError::ToUniversalFailed(e.to_string()))?; @@ -79,7 +95,9 @@ impl ProviderAdapter for GoogleAdapter { config .stop_sequences .as_ref() - .and_then(|s| serde_json::to_value(s).ok()), + .and_then(|s| serde_json::to_value(s).ok()) + .as_ref() + .and_then(|v| (ProviderFormat::Google, v).try_into().ok()), ) } else { (None, None, None, None, None) } @@ -98,13 +116,29 @@ impl ProviderAdapter for GoogleAdapter { presence_penalty: None, frequency_penalty: None, stream: None, // Google uses endpoint-based streaming + // New canonical fields - Google doesn't support most of these + parallel_tool_calls: None, + reasoning: None, + metadata: None, + store: None, + service_tier: None, + logprobs: None, + top_logprobs: None, }; + // Collect unknown fields as extras for this provider + let extras = collect_google_extras(&payload); + + let mut provider_extras = HashMap::new(); + if !extras.is_empty() { + provider_extras.insert(ProviderFormat::Google, extras); + } + Ok(UniversalRequest { model, messages, params, - extras, + provider_extras, }) } @@ -176,19 +210,7 @@ impl ProviderAdapter for GoogleAdapter { top_p: req.params.top_p, top_k: req.params.top_k.map(|k| k as i32), max_output_tokens: req.params.max_tokens.map(|t| t as i32), - stop_sequences: req.params.stop.as_ref().and_then(|v| { - if let Value::Array(arr) = v { - Some( - arr.iter() - .filter_map(|s| s.as_str().map(String::from)) - .collect(), - ) - } else if let Value::String(s) = v { - Some(vec![s.clone()]) - } else { - None - } - }), + stop_sequences: req.params.stop.as_ref().and_then(|s| s.to_sequences_array()), }; obj.insert( @@ -203,12 +225,13 @@ impl ProviderAdapter for GoogleAdapter { obj.insert("tools".into(), tools.clone()); } - // Merge extras - only include Google-known fields - // This filters out OpenAI-specific fields like stream_options that would cause - // Google to reject the request with "Unknown name: stream_options" - for (k, v) in &req.extras { - if GOOGLE_KNOWN_KEYS.contains(&k.as_str()) { - obj.insert(k.clone(), v.clone()); + // Merge back provider-specific extras (only for Google) + if let Some(extras) = req.provider_extras.get(&ProviderFormat::Google) { + for (k, v) in extras { + // Don't overwrite canonical fields we already handled + if !obj.contains_key(k) { + obj.insert(k.clone(), v.clone()); + } } } diff --git
a/crates/lingua/src/providers/openai/adapter.rs b/crates/lingua/src/providers/openai/adapter.rs index 79fc9efc..2f9c9073 100644 --- a/crates/lingua/src/providers/openai/adapter.rs +++ b/crates/lingua/src/providers/openai/adapter.rs @@ -1,15 +1,16 @@ /*! -OpenAI provider adapters for chat completions and responses API. +OpenAI Chat Completions API adapter. -This module provides two adapters: -- `OpenAIAdapter` for the standard Chat Completions API -- `ResponsesAdapter` for the Responses API (used by reasoning models like o1) +This module provides the `OpenAIAdapter` for the standard Chat Completions API, +along with target-specific transformation utilities for providers like Azure, +Vertex, and Mistral. */ use crate::capabilities::ProviderFormat; +use std::collections::HashMap; + use crate::processing::adapters::{ - collect_extras, insert_opt_bool, insert_opt_f64, insert_opt_i64, insert_opt_value, - ProviderAdapter, + insert_opt_bool, insert_opt_f64, insert_opt_i64, insert_opt_value, ProviderAdapter, }; use crate::processing::transform::TransformError; use crate::providers::openai::capabilities::{OpenAICapabilities, TargetProvider}; @@ -17,63 +18,22 @@ use crate::providers::openai::generated::{ AllowedToolsFunction, ChatCompletionRequestMessage, ChatCompletionRequestMessageContent, ChatCompletionRequestMessageContentPart, ChatCompletionRequestMessageRole, ChatCompletionResponseMessage, ChatCompletionToolChoiceOption, - CreateChatCompletionRequestClass, CreateResponseClass, File, FunctionObject, - FunctionToolChoiceClass, FunctionToolChoiceType, InputItem, InputItemContent, InputItemRole, - InputItemType, Instructions, PurpleType, ResponseFormatType, ToolElement, ToolType, -}; -use crate::providers::openai::{ - try_parse_openai, try_parse_responses, universal_to_responses_input, + CreateChatCompletionRequestClass, File, FunctionObject, FunctionToolChoiceClass, + FunctionToolChoiceType, PurpleType, ResponseFormatType, ToolElement, ToolType, }; +use crate::providers::openai::params::OpenAIChatParams; +use crate::providers::openai::try_parse_openai; use crate::serde_json::{self, Map, Value}; use crate::universal::convert::TryFromLLM; -use crate::universal::message::{AssistantContent, Message, UserContent}; +use crate::universal::message::Message; +use std::convert::TryInto; +use crate::universal::tools::{anthropic_to_openai_tools, find_builtin_tool, is_openai_format}; use crate::universal::{ FinishReason, UniversalParams, UniversalRequest, UniversalResponse, UniversalStreamChoice, UniversalStreamChunk, UniversalUsage, PLACEHOLDER_ID, PLACEHOLDER_MODEL, }; use crate::util::media::parse_base64_data_url; -/// Known request fields for OpenAI Chat Completions API. -/// These are fields extracted into UniversalRequest/UniversalParams. -/// Fields not in this list go into `extras` for passthrough. -const OPENAI_KNOWN_KEYS: &[&str] = &[ - "model", - "messages", - "temperature", - "top_p", - "max_tokens", - "max_completion_tokens", - "stop", - "tools", - "tool_choice", - "response_format", - "seed", - "presence_penalty", - "frequency_penalty", - "stream", - // OpenAI-specific fields (not in UniversalParams) go to extras: - // stream_options, n, logprobs, top_logprobs, logit_bias, - // user, store, metadata, parallel_tool_calls, service_tier -]; - -/// Known request fields for OpenAI Responses API. -/// These are fields extracted into UniversalRequest/UniversalParams. -/// Fields not in this list go into `extras` for passthrough. 
-const RESPONSES_KNOWN_KEYS: &[&str] = &[ - "model", - "input", - "temperature", - "top_p", - "max_output_tokens", - "tools", - "tool_choice", - "stream", - // Responses-specific fields (not in UniversalParams) go to extras: - // instructions, stop, response_format, seed, presence_penalty, - // frequency_penalty, reasoning, truncation, user, store, - // metadata, parallel_tool_calls -]; - /// Adapter for OpenAI Chat Completions API. pub struct OpenAIAdapter; @@ -95,37 +55,91 @@ impl ProviderAdapter for OpenAIAdapter { } fn request_to_universal(&self, payload: Value) -> Result<UniversalRequest, TransformError> { - let extras = collect_extras(&payload, OPENAI_KNOWN_KEYS); + // Parse into typed params - extras are automatically captured via #[serde(flatten)] + let typed_params: OpenAIChatParams = serde_json::from_value(payload.clone()) + .map_err(|e| TransformError::ToUniversalFailed(e.to_string()))?; + + // Parse again for strongly-typed message conversion let request: CreateChatCompletionRequestClass = serde_json::from_value(payload) .map_err(|e| TransformError::ToUniversalFailed(e.to_string()))?; let messages = <Vec<Message> as TryFromLLM<Vec<ChatCompletionRequestMessage>>>::try_from(request.messages) .map_err(|e| TransformError::ToUniversalFailed(e.to_string()))?; + // Build canonical params from typed fields let params = UniversalParams { - temperature: request.temperature, - top_p: request.top_p, + temperature: typed_params.temperature, + top_p: typed_params.top_p, top_k: None, // OpenAI doesn't support top_k - max_tokens: request.max_tokens.or(request.max_completion_tokens), - stop: request.stop.and_then(|s| serde_json::to_value(s).ok()), - tools: request.tools.and_then(|t| serde_json::to_value(t).ok()), - tool_choice: request + max_tokens: typed_params.max_tokens.or(typed_params.max_completion_tokens), + stop: typed_params + .stop + .as_ref() + .and_then(|v| (ProviderFormat::OpenAI, v).try_into().ok()), + tools: typed_params.tools, + tool_choice: typed_params .tool_choice - .and_then(|t| serde_json::to_value(t).ok()), - response_format: request + .as_ref() + .and_then(|v| (ProviderFormat::OpenAI, v).try_into().ok()), + response_format: typed_params .response_format - .and_then(|r| serde_json::to_value(r).ok()), - seed: request.seed, - presence_penalty: request.presence_penalty, - frequency_penalty: request.frequency_penalty, - stream: request.stream, + .as_ref() + .and_then(|v| (ProviderFormat::OpenAI, v).try_into().ok()), + seed: typed_params.seed, + presence_penalty: typed_params.presence_penalty, + frequency_penalty: typed_params.frequency_penalty, + stream: typed_params.stream, + // New canonical fields + parallel_tool_calls: typed_params.parallel_tool_calls, + reasoning: typed_params.reasoning_effort.as_ref().and_then(|e| { + let json_value = serde_json::json!(e); + (ProviderFormat::OpenAI, &json_value).try_into().ok() + }), + metadata: typed_params.metadata, + store: typed_params.store, + service_tier: typed_params.service_tier, + logprobs: typed_params.logprobs, + top_logprobs: typed_params.top_logprobs, }; + // Collect provider-specific extras for round-trip preservation + // This includes both unknown fields (from serde flatten) and known OpenAI fields + // that aren't part of UniversalParams + let mut extras_map: Map<String, Value> = typed_params.extras.into_iter().collect(); + + // Add OpenAI-specific known fields that aren't in UniversalParams + if let Some(user) = typed_params.user { + extras_map.insert("user".into(), Value::String(user)); + } + if let Some(n) = typed_params.n { + extras_map.insert("n".into(), Value::Number(n.into())); + } + if let Some(logit_bias) =
typed_params.logit_bias { + extras_map.insert("logit_bias".into(), logit_bias); + } + if let Some(stream_options) = typed_params.stream_options { + extras_map.insert("stream_options".into(), stream_options); + } + if let Some(prediction) = typed_params.prediction { + extras_map.insert("prediction".into(), prediction); + } + if let Some(safety_identifier) = typed_params.safety_identifier { + extras_map.insert("safety_identifier".into(), Value::String(safety_identifier)); + } + if let Some(prompt_cache_key) = typed_params.prompt_cache_key { + extras_map.insert("prompt_cache_key".into(), Value::String(prompt_cache_key)); + } + + let mut provider_extras = HashMap::new(); + if !extras_map.is_empty() { + provider_extras.insert(ProviderFormat::OpenAI, extras_map); + } + Ok(UniversalRequest { - model: Some(request.model), + model: typed_params.model, messages, params, - extras, + provider_extras, }) } @@ -153,19 +167,91 @@ impl ProviderAdapter for OpenAIAdapter { insert_opt_f64(&mut obj, "temperature", req.params.temperature); insert_opt_f64(&mut obj, "top_p", req.params.top_p); insert_opt_i64(&mut obj, "max_completion_tokens", req.params.max_tokens); - insert_opt_value(&mut obj, "stop", req.params.stop.clone()); - insert_opt_value(&mut obj, "tools", req.params.tools.clone()); - insert_opt_value(&mut obj, "tool_choice", req.params.tool_choice.clone()); + insert_opt_value( + &mut obj, + "stop", + req.params + .stop + .as_ref() + .and_then(|s| s.to_provider(ProviderFormat::OpenAI).ok()) + .flatten(), + ); + // Convert tools from Anthropic format to OpenAI format if needed + if let Some(tools) = &req.params.tools { + // Check for Anthropic built-in tools that have no OpenAI equivalent + if let Some(builtin_tool) = find_builtin_tool(tools) { + return Err(TransformError::ProviderLimitation(format!( + "Anthropic built-in tool '{}' has no OpenAI equivalent", + builtin_tool + ))); + } + + if is_openai_format(tools) { + insert_opt_value(&mut obj, "tools", Some(tools.clone())); + } else { + // Convert from Anthropic format + insert_opt_value(&mut obj, "tools", anthropic_to_openai_tools(tools)); + } + } + // OpenAI doesn't use parallel_tool_calls in tool_choice conversion, pass None + insert_opt_value( + &mut obj, + "tool_choice", + req.params + .tool_choice + .as_ref() + .and_then(|tc| tc.to_provider(ProviderFormat::OpenAI, None).ok()) + .flatten(), + ); insert_opt_value( &mut obj, "response_format", - req.params.response_format.clone(), + req.params + .response_format + .as_ref() + .and_then(|rf| rf.to_provider(ProviderFormat::OpenAI).ok()) + .flatten(), ); insert_opt_i64(&mut obj, "seed", req.params.seed); insert_opt_f64(&mut obj, "presence_penalty", req.params.presence_penalty); insert_opt_f64(&mut obj, "frequency_penalty", req.params.frequency_penalty); + insert_opt_bool(&mut obj, "logprobs", req.params.logprobs); + insert_opt_i64(&mut obj, "top_logprobs", req.params.top_logprobs); insert_opt_bool(&mut obj, "stream", req.params.stream); + // Add parallel_tool_calls from canonical params + if let Some(parallel) = req.params.parallel_tool_calls { + obj.insert("parallel_tool_calls".into(), Value::Bool(parallel)); + } + + // Add reasoning_effort from canonical params (convert ReasoningConfig back to string) + // to_provider returns Value::String for OpenAI format + // max_tokens is passed explicitly for budget→effort conversion + if let Some(effort_value) = req + .params + .reasoning + .as_ref() + .and_then(|r| r.to_provider(ProviderFormat::OpenAI, req.params.max_tokens).ok()) + .flatten() + { + 
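+            // Illustrative mapping (values assumed, not from a real request):
+            // a canonical `budget_tokens: 2048` with `max_tokens: Some(4096)`
+            // is a ratio of 0.5, which the budget→effort thresholds in
+            // universal::reasoning map to "medium", so `effort_value` here
+            // would be `Value::String("medium")`.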
obj.insert("reasoning_effort".into(), effort_value); + } + + // Add metadata from canonical params + if let Some(metadata) = req.params.metadata.as_ref() { + obj.insert("metadata".into(), metadata.clone()); + } + + // Add store from canonical params + if let Some(store) = req.params.store { + obj.insert("store".into(), Value::Bool(store)); + } + + // Add service_tier from canonical params + if let Some(ref service_tier) = req.params.service_tier { + obj.insert("service_tier".into(), Value::String(service_tier.clone())); + } + // If streaming, ensure stream_options.include_usage is set for usage reporting if req.params.stream == Some(true) { let stream_options = obj @@ -176,9 +262,11 @@ impl ProviderAdapter for OpenAIAdapter { } } - // Merge extras (provider-specific fields) - for (k, v) in &req.extras { - obj.insert(k.clone(), v.clone()); + // Merge back provider-specific extras (only for OpenAI) + if let Some(extras) = req.provider_extras.get(&ProviderFormat::OpenAI) { + for (k, v) in extras { + obj.insert(k.clone(), v.clone()); + } } Ok(Value::Object(obj)) @@ -455,660 +543,6 @@ impl ProviderAdapter for OpenAIAdapter { } } -/// Adapter for OpenAI Responses API (used by reasoning models like o1). -pub struct ResponsesAdapter; - -impl ProviderAdapter for ResponsesAdapter { - fn format(&self) -> ProviderFormat { - ProviderFormat::Responses - } - - fn directory_name(&self) -> &'static str { - "responses" - } - - fn display_name(&self) -> &'static str { - "Responses" - } - - fn detect_request(&self, payload: &Value) -> bool { - try_parse_responses(payload).is_ok() - } - - fn request_to_universal(&self, payload: Value) -> Result { - let extras = collect_extras(&payload, RESPONSES_KNOWN_KEYS); - let request: CreateResponseClass = serde_json::from_value(payload) - .map_err(|e| TransformError::ToUniversalFailed(e.to_string()))?; - - // Extract input items from the request - let input_items: Vec = match request.input { - Some(Instructions::InputItemArray(items)) => items, - Some(Instructions::String(s)) => { - // Single string input - create a user message InputItem - vec![InputItem { - input_item_type: Some(InputItemType::Message), - role: Some(InputItemRole::User), - content: Some(InputItemContent::String(s)), - ..Default::default() - }] - } - None => vec![], - }; - - let messages = as TryFromLLM>>::try_from(input_items) - .map_err(|e| TransformError::ToUniversalFailed(e.to_string()))?; - - let params = UniversalParams { - temperature: request.temperature, - top_p: request.top_p, - top_k: None, - max_tokens: request.max_output_tokens, - stop: None, // Responses API doesn't use stop - tools: request.tools.and_then(|t| serde_json::to_value(t).ok()), - tool_choice: request - .tool_choice - .and_then(|t| serde_json::to_value(t).ok()), - response_format: None, // Different structure in Responses API - seed: None, // Responses API uses different randomness control - presence_penalty: None, // Responses API doesn't support penalties - frequency_penalty: None, - stream: request.stream, - }; - - Ok(UniversalRequest { - model: request.model, // Already Option in CreateResponseClass - messages, - params, - extras, - }) - } - - fn request_from_universal(&self, req: &UniversalRequest) -> Result { - let model = req.model.as_ref().ok_or(TransformError::ValidationFailed { - target: ProviderFormat::Responses, - reason: "missing model".to_string(), - })?; - - // Use existing conversion with 1:N Tool message expansion - let input_items = universal_to_responses_input(&req.messages) - .map_err(|e| 
TransformError::FromUniversalFailed(e.to_string()))?; - - let mut obj = Map::new(); - obj.insert("model".into(), Value::String(model.clone())); - obj.insert( - "input".into(), - serde_json::to_value(input_items) - .map_err(|e| TransformError::SerializationFailed(e.to_string()))?, - ); - - // Note: temperature is intentionally NOT included for Responses API - // as reasoning models (o1, o3) don't support it - insert_opt_f64(&mut obj, "top_p", req.params.top_p); - insert_opt_i64(&mut obj, "max_output_tokens", req.params.max_tokens); - insert_opt_f64(&mut obj, "presence_penalty", req.params.presence_penalty); - insert_opt_f64(&mut obj, "frequency_penalty", req.params.frequency_penalty); - insert_opt_bool(&mut obj, "stream", req.params.stream); - - // Transform tools from OpenAI Chat format to Responses API format - // {type: "function", function: {name, description, parameters}} - // → {type: "function", name, description, parameters, strict: false} - // Tools can come from params.tools or extras.tools depending on how the request was built - let tools_value = req - .params - .tools - .as_ref() - .or_else(|| req.extras.get("tools")); - if let Some(Value::Array(tools)) = tools_value { - let response_tools: Vec = tools - .iter() - .filter_map(|tool| { - if tool.get("type").and_then(Value::as_str) == Some("function") { - let func = tool.get("function")?; - Some(serde_json::json!({ - "type": "function", - "name": func.get("name")?, - "description": func.get("description"), - "parameters": func.get("parameters").cloned().unwrap_or(serde_json::json!({})), - "strict": false - })) - } else { - None - } - }) - .collect(); - if !response_tools.is_empty() { - obj.insert("tools".into(), Value::Array(response_tools)); - } - } - - // Transform tool_choice from OpenAI Chat format to Responses API format - // {function: {name: "foo"}} → {type: "function", name: "foo"} - // tool_choice can come from params or extras depending on how the request was built - let tool_choice_value = req - .params - .tool_choice - .as_ref() - .or_else(|| req.extras.get("tool_choice")); - if let Some(tool_choice) = tool_choice_value { - let converted = match tool_choice { - Value::String(s) if s == "none" || s == "auto" || s == "required" => { - Value::String(s.clone()) - } - Value::Object(obj_tc) if obj_tc.contains_key("function") => { - if let Some(func) = obj_tc.get("function") { - if let Some(name) = func.get("name").and_then(Value::as_str) { - serde_json::json!({ "type": "function", "name": name }) - } else { - Value::String("auto".into()) - } - } else { - Value::String("auto".into()) - } - } - _ => Value::String("auto".into()), - }; - obj.insert("tool_choice".into(), converted); - } - - // Transform response_format to nested text.format structure for Responses API - if let Some(response_format) = req.extras.get("response_format") { - let text_format = match response_format.get("type").and_then(Value::as_str) { - Some("text") | Some("json_object") => { - Some(serde_json::json!({ "format": response_format })) - } - Some("json_schema") => response_format.get("json_schema").map(|json_schema| { - serde_json::json!({ - "format": { - "type": "json_schema", - "schema": json_schema.get("schema").cloned().unwrap_or(serde_json::json!({})), - "name": json_schema.get("name"), - "description": json_schema.get("description"), - "strict": json_schema.get("strict") - } - }) - }), - _ => None, - }; - if let Some(tf) = text_format { - obj.insert("text".into(), tf); - } - } - - // Transform reasoning_effort to nested reasoning.effort structure 
- if let Some(effort) = req.extras.get("reasoning_effort") { - obj.insert( - "reasoning".into(), - serde_json::json!({ "effort": effort.clone() }), - ); - } - - // Pass through parallel_tool_calls - if let Some(Value::Bool(parallel)) = req.extras.get("parallel_tool_calls") { - obj.insert("parallel_tool_calls".into(), Value::Bool(*parallel)); - } - - // Merge remaining extras (except those we handled specially) - for (k, v) in &req.extras { - if !matches!( - k.as_str(), - "tools" - | "tool_choice" - | "response_format" - | "reasoning_effort" - | "parallel_tool_calls" - ) { - obj.insert(k.clone(), v.clone()); - } - } - - Ok(Value::Object(obj)) - } - - fn apply_defaults(&self, _req: &mut UniversalRequest) { - // Responses API doesn't require any specific defaults - } - - fn detect_response(&self, payload: &Value) -> bool { - // Responses API response has output[] array and object="response" - payload.get("output").and_then(Value::as_array).is_some() - && payload - .get("object") - .and_then(Value::as_str) - .is_some_and(|o| o == "response") - } - - fn response_to_universal(&self, payload: Value) -> Result { - let output = payload - .get("output") - .and_then(Value::as_array) - .ok_or_else(|| TransformError::ToUniversalFailed("missing output".to_string()))?; - - // Convert output items to messages - // Responses API has multiple output types: message, function_call, reasoning, etc. - let mut messages: Vec = Vec::new(); - let mut tool_calls: Vec = Vec::new(); - - for item in output { - let item_type = item.get("type").and_then(Value::as_str); - - match item_type { - Some("message") => { - // Message type - extract text content - if let Some(content) = item.get("content") { - if let Some(content_arr) = content.as_array() { - let text: String = content_arr - .iter() - .filter_map(|c| { - if c.get("type").and_then(Value::as_str) == Some("output_text") - { - c.get("text").and_then(Value::as_str).map(String::from) - } else { - None - } - }) - .collect::>() - .join(""); - if !text.is_empty() { - messages.push(Message::Assistant { - content: AssistantContent::String(text), - id: None, - }); - } - } - } - } - Some("function_call") => { - // Function call - collect for later conversion to tool calls - tool_calls.push(item.clone()); - } - _ => { - // Skip reasoning and other types for now - } - } - } - - // If we have tool calls but no messages, create an assistant message with tool calls - if !tool_calls.is_empty() && messages.is_empty() { - // Convert function_call items to tool call format - use crate::universal::message::{AssistantContentPart, ToolCallArguments}; - let parts: Vec = tool_calls - .iter() - .filter_map(|tc| { - let name = tc.get("name").and_then(Value::as_str)?; - let call_id = tc.get("call_id").and_then(Value::as_str)?; - let arguments = tc.get("arguments").and_then(Value::as_str)?; - - // Try to parse arguments as JSON, fall back to invalid string - let args = serde_json::from_str::>(arguments) - .map(ToolCallArguments::Valid) - .unwrap_or_else(|_| ToolCallArguments::Invalid(arguments.to_string())); - - Some(AssistantContentPart::ToolCall { - tool_call_id: call_id.to_string(), - tool_name: name.to_string(), - arguments: args, - provider_options: None, - provider_executed: None, - }) - }) - .collect(); - - if !parts.is_empty() { - messages.push(Message::Assistant { - content: AssistantContent::Array(parts), - id: None, - }); - } - } - - // If still no messages, try output_text field as fallback - if messages.is_empty() { - if let Some(text) = 
payload.get("output_text").and_then(Value::as_str) { - if !text.is_empty() { - messages.push(Message::Assistant { - content: AssistantContent::String(text.to_string()), - id: None, - }); - } - } - } - - // Map status to finish_reason - let finish_reason = payload - .get("status") - .and_then(Value::as_str) - .map(|s| s.parse().unwrap()); - - let usage = payload.get("usage").map(|u| UniversalUsage { - prompt_tokens: u.get("input_tokens").and_then(Value::as_i64), - completion_tokens: u.get("output_tokens").and_then(Value::as_i64), - prompt_cached_tokens: u - .get("input_tokens_details") - .and_then(|d| d.get("cached_tokens")) - .and_then(Value::as_i64), - prompt_cache_creation_tokens: None, - completion_reasoning_tokens: u - .get("output_tokens_details") - .and_then(|d| d.get("reasoning_tokens")) - .and_then(Value::as_i64), - }); - - Ok(UniversalResponse { - model: payload - .get("model") - .and_then(Value::as_str) - .map(String::from), - messages, - usage, - finish_reason, - }) - } - - fn response_from_universal(&self, resp: &UniversalResponse) -> Result { - // Build Responses API response format - let output: Vec = resp - .messages - .iter() - .map(|msg| { - let text = match msg { - Message::Assistant { content, .. } => match content { - AssistantContent::String(s) => s.clone(), - AssistantContent::Array(_) => String::new(), // TODO: extract text from parts - }, - Message::User { content } => match content { - UserContent::String(s) => s.clone(), - UserContent::Array(_) => String::new(), - }, - _ => String::new(), - }; - - serde_json::json!({ - "type": "message", - "role": "assistant", - "content": [{ - "type": "output_text", - "text": text - }] - }) - }) - .collect(); - - let status = self - .map_finish_reason(resp.finish_reason.as_ref()) - .unwrap_or_else(|| "completed".to_string()); - - // Build response with all required fields for TheResponseObject - let mut obj = serde_json::json!({ - "id": format!("resp_{}", PLACEHOLDER_ID), - "object": "response", - "model": resp.model.as_deref().unwrap_or(PLACEHOLDER_MODEL), - "output": output, - "status": status, - "created_at": 0.0, - "tool_choice": "none", - "tools": [], - "parallel_tool_calls": false - }); - - if let Some(usage) = &resp.usage { - let input = usage.prompt_tokens.unwrap_or(0); - let output = usage.completion_tokens.unwrap_or(0); - obj.as_object_mut().unwrap().insert( - "usage".into(), - serde_json::json!({ - "input_tokens": input, - "output_tokens": output, - "total_tokens": input + output, - "input_tokens_details": { - "cached_tokens": usage.prompt_cached_tokens.unwrap_or(0) - }, - "output_tokens_details": { - "reasoning_tokens": usage.completion_reasoning_tokens.unwrap_or(0) - } - }), - ); - } - - Ok(obj) - } - - fn map_finish_reason(&self, reason: Option<&FinishReason>) -> Option { - reason.map(|r| match r { - FinishReason::Stop => "completed".to_string(), - FinishReason::Length => "incomplete".to_string(), - FinishReason::ToolCalls => "completed".to_string(), // Tool calls also complete - FinishReason::ContentFilter => "incomplete".to_string(), - FinishReason::Other(s) => s.clone(), - }) - } - - // ========================================================================= - // Streaming response handling - // ========================================================================= - - fn detect_stream_response(&self, payload: &Value) -> bool { - // Responses API streaming has type field starting with "response." 
- payload - .get("type") - .and_then(Value::as_str) - .is_some_and(|t| t.starts_with("response.")) - } - - fn stream_to_universal( - &self, - payload: Value, - ) -> Result, TransformError> { - let event_type = payload - .get("type") - .and_then(Value::as_str) - .ok_or_else(|| TransformError::ToUniversalFailed("missing type field".to_string()))?; - - match event_type { - "response.output_text.delta" => { - // Text delta - extract from delta field - let text = payload.get("delta").and_then(Value::as_str).unwrap_or(""); - let output_index = payload - .get("output_index") - .and_then(Value::as_u64) - .unwrap_or(0) as u32; - - Ok(Some(UniversalStreamChunk::new( - None, - None, - vec![UniversalStreamChoice { - index: output_index, - delta: Some(serde_json::json!({ - "role": "assistant", - "content": text - })), - finish_reason: None, - }], - None, - None, - ))) - } - - "response.completed" => { - // Final event with usage - let response = payload.get("response"); - let usage = response - .and_then(|r| r.get("usage")) - .map(|u| UniversalUsage { - prompt_tokens: u.get("input_tokens").and_then(Value::as_i64), - completion_tokens: u.get("output_tokens").and_then(Value::as_i64), - prompt_cached_tokens: u - .get("input_tokens_details") - .and_then(|d| d.get("cached_tokens")) - .and_then(Value::as_i64), - prompt_cache_creation_tokens: None, - completion_reasoning_tokens: u - .get("output_tokens_details") - .and_then(|d| d.get("reasoning_tokens")) - .and_then(Value::as_i64), - }); - - let model = response - .and_then(|r| r.get("model")) - .and_then(Value::as_str) - .map(String::from); - - let id = response - .and_then(|r| r.get("id")) - .and_then(Value::as_str) - .map(String::from); - - Ok(Some(UniversalStreamChunk::new( - id, - model, - vec![UniversalStreamChoice { - index: 0, - delta: Some(serde_json::json!({})), - finish_reason: Some("stop".to_string()), - }], - None, - usage, - ))) - } - - "response.incomplete" => { - // Incomplete response - typically due to length - let response = payload.get("response"); - let usage = response - .and_then(|r| r.get("usage")) - .map(|u| UniversalUsage { - prompt_tokens: u.get("input_tokens").and_then(Value::as_i64), - completion_tokens: u.get("output_tokens").and_then(Value::as_i64), - prompt_cached_tokens: u - .get("input_tokens_details") - .and_then(|d| d.get("cached_tokens")) - .and_then(Value::as_i64), - prompt_cache_creation_tokens: None, - completion_reasoning_tokens: u - .get("output_tokens_details") - .and_then(|d| d.get("reasoning_tokens")) - .and_then(Value::as_i64), - }); - - Ok(Some(UniversalStreamChunk::new( - None, - None, - vec![UniversalStreamChoice { - index: 0, - delta: Some(serde_json::json!({})), - finish_reason: Some("length".to_string()), - }], - None, - usage, - ))) - } - - "response.created" | "response.in_progress" => { - // Initial metadata events - extract model/id - let response = payload.get("response"); - let model = response - .and_then(|r| r.get("model")) - .and_then(Value::as_str) - .map(String::from); - let id = response - .and_then(|r| r.get("id")) - .and_then(Value::as_str) - .map(String::from); - - Ok(Some(UniversalStreamChunk::new( - id, - model, - vec![UniversalStreamChoice { - index: 0, - delta: Some(serde_json::json!({"role": "assistant", "content": ""})), - finish_reason: None, - }], - None, - None, - ))) - } - - // All other events are metadata/keep-alive - _ => Ok(Some(UniversalStreamChunk::keep_alive())), - } - } - - fn stream_from_universal(&self, chunk: &UniversalStreamChunk) -> Result { - if chunk.is_keep_alive() { 
- // Return a generic in_progress event - return Ok(serde_json::json!({ - "type": "response.in_progress", - "sequence_number": 0 - })); - } - - // Check for finish chunk - let has_finish = chunk - .choices - .first() - .and_then(|c| c.finish_reason.as_ref()) - .is_some(); - - if has_finish { - let finish_reason = chunk.choices.first().and_then(|c| c.finish_reason.as_ref()); - let status = match finish_reason.map(|r| r.as_str()) { - Some("stop") => "completed", - Some("length") => "incomplete", - _ => "completed", - }; - - let id = chunk - .id - .clone() - .unwrap_or_else(|| format!("resp_{}", PLACEHOLDER_ID)); - let mut response = serde_json::json!({ - "id": id, - "object": "response", - "model": chunk.model.as_deref().unwrap_or(PLACEHOLDER_MODEL), - "status": status, - "output": [] - }); - - if let Some(usage) = &chunk.usage { - response.as_object_mut().unwrap().insert( - "usage".into(), - serde_json::json!({ - "input_tokens": usage.prompt_tokens.unwrap_or(0), - "output_tokens": usage.completion_tokens.unwrap_or(0), - "total_tokens": usage.prompt_tokens.unwrap_or(0) + usage.completion_tokens.unwrap_or(0) - }), - ); - } - - return Ok(serde_json::json!({ - "type": if status == "completed" { "response.completed" } else { "response.incomplete" }, - "response": response - })); - } - - // Check for content delta - if let Some(choice) = chunk.choices.first() { - if let Some(delta) = &choice.delta { - if let Some(content) = delta.get("content").and_then(Value::as_str) { - return Ok(serde_json::json!({ - "type": "response.output_text.delta", - "output_index": choice.index, - "content_index": 0, - "delta": content - })); - } - } - } - - // Fallback - return output_text.delta with empty content - Ok(serde_json::json!({ - "type": "response.output_text.delta", - "output_index": 0, - "content_index": 0, - "delta": "" - })) - } -} - // ============================================================================= // OpenAI Target-Specific Transformations // ============================================================================= @@ -1418,8 +852,12 @@ mod tests { }); let universal = adapter.request_to_universal(payload).unwrap(); - assert!(universal.extras.contains_key("user")); - assert!(universal.extras.contains_key("custom_field")); + let openai_extras = universal + .provider_extras + .get(&ProviderFormat::OpenAI) + .expect("should have OpenAI extras"); + assert!(openai_extras.contains_key("user")); + assert!(openai_extras.contains_key("custom_field")); let reconstructed = adapter.request_from_universal(&universal).unwrap(); assert_eq!(reconstructed.get("user").unwrap(), "test-user-123"); @@ -1428,14 +866,4 @@ mod tests { "should_be_preserved" ); } - - #[test] - fn test_responses_detect_request() { - let adapter = ResponsesAdapter; - let payload = json!({ - "model": "o1", - "input": [{"role": "user", "content": "Hello"}] - }); - assert!(adapter.detect_request(&payload)); - } } diff --git a/crates/lingua/src/providers/openai/mod.rs b/crates/lingua/src/providers/openai/mod.rs index 9ba962b6..7a502934 100644 --- a/crates/lingua/src/providers/openai/mod.rs +++ b/crates/lingua/src/providers/openai/mod.rs @@ -11,9 +11,12 @@ pub mod capabilities; pub mod convert; pub mod detect; pub mod generated; +pub mod params; +pub mod responses_adapter; // Re-export adapters and transformations -pub use adapter::{apply_target_transforms, OpenAIAdapter, OpenAITransformError, ResponsesAdapter}; +pub use adapter::{apply_target_transforms, OpenAIAdapter, OpenAITransformError}; +pub use 
responses_adapter::ResponsesAdapter;
 
 #[cfg(test)]
 pub mod test_responses;
diff --git a/crates/lingua/src/providers/openai/params.rs b/crates/lingua/src/providers/openai/params.rs
new file mode 100644
index 00000000..e22b3c46
--- /dev/null
+++ b/crates/lingua/src/providers/openai/params.rs
@@ -0,0 +1,194 @@
+/*!
+Typed parameter structs for OpenAI APIs.
+
+These structs use `#[serde(flatten)]` to automatically capture unknown fields,
+eliminating the need for explicit KNOWN_KEYS arrays.
+*/
+
+use crate::serde_json::Value;
+use serde::{Deserialize, Serialize};
+use std::collections::BTreeMap;
+
+/// OpenAI Chat Completions API request parameters.
+///
+/// All known fields are explicitly typed. Unknown fields automatically
+/// go into `extras` via `#[serde(flatten)]`.
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+pub struct OpenAIChatParams {
+    // === Core fields ===
+    pub model: Option<String>,
+    pub messages: Option<Value>,
+
+    // === Sampling parameters ===
+    pub temperature: Option<f64>,
+    pub top_p: Option<f64>,
+    pub seed: Option<i64>,
+    pub presence_penalty: Option<f64>,
+    pub frequency_penalty: Option<f64>,
+
+    // === Output control ===
+    pub max_tokens: Option<i64>,
+    pub max_completion_tokens: Option<i64>,
+    pub stop: Option<Value>,
+    pub n: Option<i64>,
+    pub logprobs: Option<bool>,
+    pub top_logprobs: Option<i64>,
+    pub logit_bias: Option<Value>,
+
+    // === Tools and function calling ===
+    pub tools: Option<Value>,
+    pub tool_choice: Option<Value>,
+    pub parallel_tool_calls: Option<bool>,
+
+    // === Response format ===
+    pub response_format: Option<Value>,
+
+    // === Streaming ===
+    pub stream: Option<bool>,
+    pub stream_options: Option<Value>,
+
+    // === Reasoning (o-series models) ===
+    pub reasoning_effort: Option<String>,
+
+    // === Metadata and identification ===
+    pub metadata: Option<Value>,
+    pub store: Option<bool>,
+    pub service_tier: Option<String>,
+    pub user: Option<String>,
+    pub safety_identifier: Option<String>,
+    pub prompt_cache_key: Option<String>,
+
+    // === Prediction ===
+    pub prediction: Option<Value>,
+
+    /// Unknown fields - automatically captured by serde flatten.
+    /// These are provider-specific fields not in the canonical set.
+    #[serde(flatten)]
+    pub extras: BTreeMap<String, Value>,
+}
+
+/// OpenAI Responses API request parameters.
+///
+/// The Responses API has different field names and structure than Chat Completions.
+/// All known fields are explicitly typed. Unknown fields automatically
+/// go into `extras` via `#[serde(flatten)]`.
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+pub struct OpenAIResponsesParams {
+    // === Core fields ===
+    pub model: Option<String>,
+    pub input: Option<Value>,
+    pub instructions: Option<String>,
+
+    // === Sampling parameters ===
+    pub temperature: Option<f64>,
+    pub top_p: Option<f64>,
+
+    // === Output control ===
+    pub max_output_tokens: Option<i64>,
+
+    // === Tools and function calling ===
+    pub tools: Option<Value>,
+    pub tool_choice: Option<Value>,
+    pub parallel_tool_calls: Option<bool>,
+
+    // === Text/Response format (nested structure) ===
+    pub text: Option<Value>,
+
+    // === Streaming ===
+    pub stream: Option<bool>,
+
+    // === Reasoning configuration (nested structure) ===
+    pub reasoning: Option<Value>,
+
+    // === Context management ===
+    pub truncation: Option<Value>,
+
+    // === Metadata and identification ===
+    pub metadata: Option<Value>,
+    pub store: Option<bool>,
+    pub service_tier: Option<String>,
+    pub user: Option<String>,
+    pub safety_identifier: Option<String>,
+    pub prompt_cache_key: Option<String>,
+
+    /// Unknown fields - automatically captured by serde flatten.
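+    /// For example, deserializing `{"model": "o1", "background": true}` fills
+    /// `model` and routes the `background` key (a hypothetical future field,
+    /// shown for illustration only) into `extras`.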
+    #[serde(flatten)]
+    pub extras: BTreeMap<String, Value>,
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::serde_json;
+    use crate::serde_json::json;
+
+    #[test]
+    fn test_chat_params_known_fields() {
+        let json = json!({
+            "model": "gpt-4o",
+            "messages": [{"role": "user", "content": "Hello"}],
+            "temperature": 0.7,
+            "max_tokens": 100
+        });
+
+        let params: OpenAIChatParams = serde_json::from_value(json).unwrap();
+        assert_eq!(params.model, Some("gpt-4o".to_string()));
+        assert_eq!(params.temperature, Some(0.7));
+        assert_eq!(params.max_tokens, Some(100));
+        assert!(params.extras.is_empty());
+    }
+
+    #[test]
+    fn test_chat_params_unknown_fields_go_to_extras() {
+        let json = json!({
+            "model": "gpt-4o",
+            "messages": [],
+            "some_future_param": "value",
+            "another_unknown": 42
+        });
+
+        let params: OpenAIChatParams = serde_json::from_value(json).unwrap();
+        assert_eq!(params.model, Some("gpt-4o".to_string()));
+        assert_eq!(params.extras.len(), 2);
+        assert_eq!(
+            params.extras.get("some_future_param"),
+            Some(&Value::String("value".to_string()))
+        );
+        assert_eq!(
+            params.extras.get("another_unknown"),
+            Some(&Value::Number(42.into()))
+        );
+    }
+
+    #[test]
+    fn test_responses_params_known_fields() {
+        let json = json!({
+            "model": "gpt-5-nano",
+            "input": [{"role": "user", "content": "Hello"}],
+            "instructions": "Be helpful",
+            "max_output_tokens": 500,
+            "reasoning": {"effort": "medium"}
+        });
+
+        let params: OpenAIResponsesParams = serde_json::from_value(json).unwrap();
+        assert_eq!(params.model, Some("gpt-5-nano".to_string()));
+        assert_eq!(params.instructions, Some("Be helpful".to_string()));
+        assert_eq!(params.max_output_tokens, Some(500));
+        assert!(params.extras.is_empty());
+    }
+
+    #[test]
+    fn test_roundtrip_preserves_extras() {
+        let json = json!({
+            "model": "gpt-4o",
+            "messages": [],
+            "custom_field": {"nested": "data"}
+        });
+
+        let params: OpenAIChatParams = serde_json::from_value(json.clone()).unwrap();
+        let back: Value = serde_json::to_value(&params).unwrap();
+
+        // Custom field should be preserved
+        assert_eq!(back.get("custom_field"), json.get("custom_field"));
+    }
+}
diff --git a/crates/lingua/src/providers/openai/responses_adapter.rs b/crates/lingua/src/providers/openai/responses_adapter.rs
new file mode 100644
index 00000000..43068cbf
--- /dev/null
+++ b/crates/lingua/src/providers/openai/responses_adapter.rs
@@ -0,0 +1,740 @@
+/*!
+OpenAI Responses API adapter.
+
+This module provides the `ResponsesAdapter` for the Responses API,
+which is used by reasoning models like o1 and o3.
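+
+## Usage
+
+A minimal round-trip sketch (the payload below is illustrative, not a fixture
+from this repository):
+
+```ignore
+use crate::processing::adapters::ProviderAdapter;
+use crate::providers::openai::responses_adapter::ResponsesAdapter;
+use crate::serde_json::json;
+
+let adapter = ResponsesAdapter;
+let payload = json!({
+    "model": "o1",
+    "input": [{"role": "user", "content": "Hello"}],
+    "reasoning": {"effort": "medium"}
+});
+assert!(adapter.detect_request(&payload));
+
+// Request → universal → request; extras and reasoning survive the round trip.
+let universal = adapter.request_to_universal(payload)?;
+let rebuilt = adapter.request_from_universal(&universal)?;
+```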
+*/ + +use crate::capabilities::ProviderFormat; +use std::collections::HashMap; + +use crate::processing::adapters::{ + insert_opt_bool, insert_opt_f64, insert_opt_i64, ProviderAdapter, +}; +use crate::processing::transform::TransformError; +use crate::providers::openai::generated::{ + InputItem, InputItemContent, InputItemRole, InputItemType, Instructions, +}; +use crate::providers::openai::params::OpenAIResponsesParams; +use crate::providers::openai::{try_parse_responses, universal_to_responses_input}; +use crate::serde_json::{self, Map, Value}; +use crate::universal::convert::TryFromLLM; +use crate::universal::message::{AssistantContent, Message, UserContent}; +use std::convert::TryInto; +use crate::universal::tools::is_responses_tool_format; +use crate::universal::{ + FinishReason, UniversalParams, UniversalRequest, UniversalResponse, UniversalStreamChoice, + UniversalStreamChunk, UniversalUsage, PLACEHOLDER_ID, PLACEHOLDER_MODEL, +}; + +/// Adapter for OpenAI Responses API (used by reasoning models like o1). +pub struct ResponsesAdapter; + +impl ProviderAdapter for ResponsesAdapter { + fn format(&self) -> ProviderFormat { + ProviderFormat::Responses + } + + fn directory_name(&self) -> &'static str { + "responses" + } + + fn display_name(&self) -> &'static str { + "Responses" + } + + fn detect_request(&self, payload: &Value) -> bool { + try_parse_responses(payload).is_ok() + } + + fn request_to_universal(&self, payload: Value) -> Result { + // Parse into typed params - extras are automatically captured via #[serde(flatten)] + let typed_params: OpenAIResponsesParams = serde_json::from_value(payload.clone()) + .map_err(|e| TransformError::ToUniversalFailed(e.to_string()))?; + + // Parse again for strongly-typed message conversion + let request: crate::providers::openai::generated::CreateResponseClass = + serde_json::from_value(payload) + .map_err(|e| TransformError::ToUniversalFailed(e.to_string()))?; + + // Extract input items from the request + let input_items: Vec = match request.input { + Some(Instructions::InputItemArray(items)) => items, + Some(Instructions::String(s)) => { + // Single string input - create a user message InputItem + vec![InputItem { + input_item_type: Some(InputItemType::Message), + role: Some(InputItemRole::User), + content: Some(InputItemContent::String(s)), + ..Default::default() + }] + } + None => vec![], + }; + + let messages = as TryFromLLM>>::try_from(input_items) + .map_err(|e| TransformError::ToUniversalFailed(e.to_string()))?; + + // Extract response_format from nested text.format structure and convert to typed config + let response_format = typed_params + .text + .as_ref() + .and_then(|t| t.get("format")) + .and_then(|v| (ProviderFormat::Responses, v).try_into().ok()); + + let params = UniversalParams { + temperature: typed_params.temperature, + top_p: typed_params.top_p, + top_k: None, + max_tokens: typed_params.max_output_tokens, + stop: None, // Responses API doesn't use stop + tools: typed_params.tools, + tool_choice: typed_params + .tool_choice + .as_ref() + .and_then(|v| (ProviderFormat::Responses, v).try_into().ok()), + response_format, + seed: None, // Responses API uses different randomness control + presence_penalty: None, // Responses API doesn't support penalties + frequency_penalty: None, + stream: typed_params.stream, + // New canonical fields + parallel_tool_calls: typed_params.parallel_tool_calls, + reasoning: typed_params + .reasoning + .as_ref() + .and_then(|v| (ProviderFormat::Responses, v).try_into().ok()), + metadata: 
typed_params.metadata, + store: typed_params.store, + service_tier: typed_params.service_tier, + logprobs: None, // Responses API doesn't support logprobs + top_logprobs: None, // Responses API doesn't support top_logprobs + }; + + // Collect provider-specific extras for round-trip preservation + // This includes both unknown fields (from serde flatten) and known Responses API fields + // that aren't part of UniversalParams + let mut extras_map: Map = typed_params.extras.into_iter().collect(); + + // Add Responses API specific known fields that aren't in UniversalParams + if let Some(instructions) = typed_params.instructions { + extras_map.insert("instructions".into(), Value::String(instructions)); + } + if let Some(text) = typed_params.text { + extras_map.insert("text".into(), text); + } + if let Some(truncation) = typed_params.truncation { + extras_map.insert("truncation".into(), truncation); + } + if let Some(user) = typed_params.user { + extras_map.insert("user".into(), Value::String(user)); + } + if let Some(safety_identifier) = typed_params.safety_identifier { + extras_map.insert("safety_identifier".into(), Value::String(safety_identifier)); + } + if let Some(prompt_cache_key) = typed_params.prompt_cache_key { + extras_map.insert("prompt_cache_key".into(), Value::String(prompt_cache_key)); + } + + let mut provider_extras = HashMap::new(); + if !extras_map.is_empty() { + provider_extras.insert(ProviderFormat::Responses, extras_map); + } + + Ok(UniversalRequest { + model: typed_params.model, + messages, + params, + provider_extras, + }) + } + + fn request_from_universal(&self, req: &UniversalRequest) -> Result { + let model = req.model.as_ref().ok_or(TransformError::ValidationFailed { + target: ProviderFormat::Responses, + reason: "missing model".to_string(), + })?; + + // Use existing conversion with 1:N Tool message expansion + let input_items = universal_to_responses_input(&req.messages) + .map_err(|e| TransformError::FromUniversalFailed(e.to_string()))?; + + let mut obj = Map::new(); + obj.insert("model".into(), Value::String(model.clone())); + obj.insert( + "input".into(), + serde_json::to_value(input_items) + .map_err(|e| TransformError::SerializationFailed(e.to_string()))?, + ); + + // Note: temperature is intentionally NOT included for Responses API + // as reasoning models (o1, o3) don't support it + insert_opt_f64(&mut obj, "top_p", req.params.top_p); + insert_opt_i64(&mut obj, "max_output_tokens", req.params.max_tokens); + insert_opt_f64(&mut obj, "presence_penalty", req.params.presence_penalty); + insert_opt_f64(&mut obj, "frequency_penalty", req.params.frequency_penalty); + insert_opt_bool(&mut obj, "stream", req.params.stream); + + // Get provider-specific extras for Responses API + let responses_extras = req.provider_extras.get(&ProviderFormat::Responses); + + // Transform tools - but if already in Responses format, pass through unchanged + if let Some(tools) = req.params.tools.as_ref() { + if is_responses_tool_format(tools) { + // Already in Responses format - pass through + obj.insert("tools".into(), tools.clone()); + } else if let Value::Array(tools_arr) = tools { + // Convert from OpenAI Chat format to Responses API format + // {type: "function", function: {name, description, parameters}} + // → {type: "function", name, description, parameters, strict: false} + let response_tools: Vec = tools_arr + .iter() + .filter_map(|tool| { + if tool.get("type").and_then(Value::as_str) == Some("function") { + let func = tool.get("function")?; + Some(serde_json::json!({ + "type": 
"function", + "name": func.get("name")?, + "description": func.get("description"), + "parameters": func.get("parameters").cloned().unwrap_or(serde_json::json!({})), + "strict": false + })) + } else { + None + } + }) + .collect(); + if !response_tools.is_empty() { + obj.insert("tools".into(), Value::Array(response_tools)); + } + } + } + + // Convert tool_choice from canonical ToolChoiceConfig to Responses API format + // Responses API doesn't use parallel_tool_calls in tool_choice, pass None + if let Some(tool_choice_val) = req + .params + .tool_choice + .as_ref() + .and_then(|tc| tc.to_provider(ProviderFormat::Responses, None).ok()) + .flatten() + { + obj.insert("tool_choice".into(), tool_choice_val); + } + + // Convert response_format from canonical ResponseFormatConfig to Responses API text format + if let Some(text_val) = req + .params + .response_format + .as_ref() + .and_then(|rf| rf.to_provider(ProviderFormat::Responses).ok()) + .flatten() + { + obj.insert("text".into(), text_val); + } + + // Add reasoning from canonical params - convert ReasoningConfig to Responses API format + // max_tokens is passed explicitly for budget→effort conversion + if let Some(reasoning_val) = req + .params + .reasoning + .as_ref() + .and_then(|r| r.to_provider(ProviderFormat::Responses, req.params.max_tokens).ok()) + .flatten() + { + obj.insert("reasoning".into(), reasoning_val); + } + + // Add parallel_tool_calls from canonical params + if let Some(parallel) = req.params.parallel_tool_calls { + obj.insert("parallel_tool_calls".into(), Value::Bool(parallel)); + } + + // Add metadata from canonical params + if let Some(metadata) = req.params.metadata.as_ref() { + obj.insert("metadata".into(), metadata.clone()); + } + + // Add store from canonical params + if let Some(store) = req.params.store { + obj.insert("store".into(), Value::Bool(store)); + } + + // Add service_tier from canonical params + if let Some(ref service_tier) = req.params.service_tier { + obj.insert("service_tier".into(), Value::String(service_tier.clone())); + } + + // Merge back provider-specific extras (only for Responses API) + if let Some(extras) = responses_extras { + for (k, v) in extras { + // Don't overwrite canonical fields we already handled + if !obj.contains_key(k) { + obj.insert(k.clone(), v.clone()); + } + } + } + + Ok(Value::Object(obj)) + } + + fn apply_defaults(&self, _req: &mut UniversalRequest) { + // Responses API doesn't require any specific defaults + } + + fn detect_response(&self, payload: &Value) -> bool { + // Responses API response has output[] array and object="response" + payload.get("output").and_then(Value::as_array).is_some() + && payload + .get("object") + .and_then(Value::as_str) + .is_some_and(|o| o == "response") + } + + fn response_to_universal(&self, payload: Value) -> Result { + let output = payload + .get("output") + .and_then(Value::as_array) + .ok_or_else(|| TransformError::ToUniversalFailed("missing output".to_string()))?; + + // Convert output items to messages + // Responses API has multiple output types: message, function_call, reasoning, etc. 
+ let mut messages: Vec = Vec::new(); + let mut tool_calls: Vec = Vec::new(); + + for item in output { + let item_type = item.get("type").and_then(Value::as_str); + + match item_type { + Some("message") => { + // Message type - extract text content + if let Some(content) = item.get("content") { + if let Some(content_arr) = content.as_array() { + let text: String = content_arr + .iter() + .filter_map(|c| { + if c.get("type").and_then(Value::as_str) == Some("output_text") + { + c.get("text").and_then(Value::as_str).map(String::from) + } else { + None + } + }) + .collect::>() + .join(""); + if !text.is_empty() { + messages.push(Message::Assistant { + content: AssistantContent::String(text), + id: None, + }); + } + } + } + } + Some("function_call") => { + // Function call - collect for later conversion to tool calls + tool_calls.push(item.clone()); + } + _ => { + // Skip reasoning and other types for now + } + } + } + + // If we have tool calls but no messages, create an assistant message with tool calls + if !tool_calls.is_empty() && messages.is_empty() { + // Convert function_call items to tool call format + use crate::universal::message::{AssistantContentPart, ToolCallArguments}; + let parts: Vec = tool_calls + .iter() + .filter_map(|tc| { + let name = tc.get("name").and_then(Value::as_str)?; + let call_id = tc.get("call_id").and_then(Value::as_str)?; + let arguments = tc.get("arguments").and_then(Value::as_str)?; + + // Try to parse arguments as JSON, fall back to invalid string + let args = serde_json::from_str::>(arguments) + .map(ToolCallArguments::Valid) + .unwrap_or_else(|_| ToolCallArguments::Invalid(arguments.to_string())); + + Some(AssistantContentPart::ToolCall { + tool_call_id: call_id.to_string(), + tool_name: name.to_string(), + arguments: args, + provider_options: None, + provider_executed: None, + }) + }) + .collect(); + + if !parts.is_empty() { + messages.push(Message::Assistant { + content: AssistantContent::Array(parts), + id: None, + }); + } + } + + // If still no messages, try output_text field as fallback + if messages.is_empty() { + if let Some(text) = payload.get("output_text").and_then(Value::as_str) { + if !text.is_empty() { + messages.push(Message::Assistant { + content: AssistantContent::String(text.to_string()), + id: None, + }); + } + } + } + + // Map status to finish_reason + let finish_reason = payload + .get("status") + .and_then(Value::as_str) + .map(|s| s.parse().unwrap()); + + let usage = payload.get("usage").map(|u| UniversalUsage { + prompt_tokens: u.get("input_tokens").and_then(Value::as_i64), + completion_tokens: u.get("output_tokens").and_then(Value::as_i64), + prompt_cached_tokens: u + .get("input_tokens_details") + .and_then(|d| d.get("cached_tokens")) + .and_then(Value::as_i64), + prompt_cache_creation_tokens: None, + completion_reasoning_tokens: u + .get("output_tokens_details") + .and_then(|d| d.get("reasoning_tokens")) + .and_then(Value::as_i64), + }); + + Ok(UniversalResponse { + model: payload + .get("model") + .and_then(Value::as_str) + .map(String::from), + messages, + usage, + finish_reason, + }) + } + + fn response_from_universal(&self, resp: &UniversalResponse) -> Result { + // Build Responses API response format + let output: Vec = resp + .messages + .iter() + .map(|msg| { + let text = match msg { + Message::Assistant { content, .. 
} => match content { + AssistantContent::String(s) => s.clone(), + AssistantContent::Array(_) => String::new(), // TODO: extract text from parts + }, + Message::User { content } => match content { + UserContent::String(s) => s.clone(), + UserContent::Array(_) => String::new(), + }, + _ => String::new(), + }; + + serde_json::json!({ + "type": "message", + "role": "assistant", + "content": [{ + "type": "output_text", + "text": text + }] + }) + }) + .collect(); + + let status = self + .map_finish_reason(resp.finish_reason.as_ref()) + .unwrap_or_else(|| "completed".to_string()); + + // Build response with all required fields for TheResponseObject + let mut obj = serde_json::json!({ + "id": format!("resp_{}", PLACEHOLDER_ID), + "object": "response", + "model": resp.model.as_deref().unwrap_or(PLACEHOLDER_MODEL), + "output": output, + "status": status, + "created_at": 0.0, + "tool_choice": "none", + "tools": [], + "parallel_tool_calls": false + }); + + if let Some(usage) = &resp.usage { + let input = usage.prompt_tokens.unwrap_or(0); + let output = usage.completion_tokens.unwrap_or(0); + obj.as_object_mut().unwrap().insert( + "usage".into(), + serde_json::json!({ + "input_tokens": input, + "output_tokens": output, + "total_tokens": input + output, + "input_tokens_details": { + "cached_tokens": usage.prompt_cached_tokens.unwrap_or(0) + }, + "output_tokens_details": { + "reasoning_tokens": usage.completion_reasoning_tokens.unwrap_or(0) + } + }), + ); + } + + Ok(obj) + } + + fn map_finish_reason(&self, reason: Option<&FinishReason>) -> Option { + reason.map(|r| match r { + FinishReason::Stop => "completed".to_string(), + FinishReason::Length => "incomplete".to_string(), + FinishReason::ToolCalls => "completed".to_string(), // Tool calls also complete + FinishReason::ContentFilter => "incomplete".to_string(), + FinishReason::Other(s) => s.clone(), + }) + } + + // ========================================================================= + // Streaming response handling + // ========================================================================= + + fn detect_stream_response(&self, payload: &Value) -> bool { + // Responses API streaming has type field starting with "response." 
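+        // e.g. "response.created", "response.output_text.delta",
+        // "response.completed"; see stream_to_universal below for the mapping.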
+ payload + .get("type") + .and_then(Value::as_str) + .is_some_and(|t| t.starts_with("response.")) + } + + fn stream_to_universal( + &self, + payload: Value, + ) -> Result, TransformError> { + let event_type = payload + .get("type") + .and_then(Value::as_str) + .ok_or_else(|| TransformError::ToUniversalFailed("missing type field".to_string()))?; + + match event_type { + "response.output_text.delta" => { + // Text delta - extract from delta field + let text = payload.get("delta").and_then(Value::as_str).unwrap_or(""); + let output_index = payload + .get("output_index") + .and_then(Value::as_u64) + .unwrap_or(0) as u32; + + Ok(Some(UniversalStreamChunk::new( + None, + None, + vec![UniversalStreamChoice { + index: output_index, + delta: Some(serde_json::json!({ + "role": "assistant", + "content": text + })), + finish_reason: None, + }], + None, + None, + ))) + } + + "response.completed" => { + // Final event with usage + let response = payload.get("response"); + let usage = response + .and_then(|r| r.get("usage")) + .map(|u| UniversalUsage { + prompt_tokens: u.get("input_tokens").and_then(Value::as_i64), + completion_tokens: u.get("output_tokens").and_then(Value::as_i64), + prompt_cached_tokens: u + .get("input_tokens_details") + .and_then(|d| d.get("cached_tokens")) + .and_then(Value::as_i64), + prompt_cache_creation_tokens: None, + completion_reasoning_tokens: u + .get("output_tokens_details") + .and_then(|d| d.get("reasoning_tokens")) + .and_then(Value::as_i64), + }); + + let model = response + .and_then(|r| r.get("model")) + .and_then(Value::as_str) + .map(String::from); + + let id = response + .and_then(|r| r.get("id")) + .and_then(Value::as_str) + .map(String::from); + + Ok(Some(UniversalStreamChunk::new( + id, + model, + vec![UniversalStreamChoice { + index: 0, + delta: Some(serde_json::json!({})), + finish_reason: Some("stop".to_string()), + }], + None, + usage, + ))) + } + + "response.incomplete" => { + // Incomplete response - typically due to length + let response = payload.get("response"); + let usage = response + .and_then(|r| r.get("usage")) + .map(|u| UniversalUsage { + prompt_tokens: u.get("input_tokens").and_then(Value::as_i64), + completion_tokens: u.get("output_tokens").and_then(Value::as_i64), + prompt_cached_tokens: u + .get("input_tokens_details") + .and_then(|d| d.get("cached_tokens")) + .and_then(Value::as_i64), + prompt_cache_creation_tokens: None, + completion_reasoning_tokens: u + .get("output_tokens_details") + .and_then(|d| d.get("reasoning_tokens")) + .and_then(Value::as_i64), + }); + + Ok(Some(UniversalStreamChunk::new( + None, + None, + vec![UniversalStreamChoice { + index: 0, + delta: Some(serde_json::json!({})), + finish_reason: Some("length".to_string()), + }], + None, + usage, + ))) + } + + "response.created" | "response.in_progress" => { + // Initial metadata events - extract model/id + let response = payload.get("response"); + let model = response + .and_then(|r| r.get("model")) + .and_then(Value::as_str) + .map(String::from); + let id = response + .and_then(|r| r.get("id")) + .and_then(Value::as_str) + .map(String::from); + + Ok(Some(UniversalStreamChunk::new( + id, + model, + vec![UniversalStreamChoice { + index: 0, + delta: Some(serde_json::json!({"role": "assistant", "content": ""})), + finish_reason: None, + }], + None, + None, + ))) + } + + // All other events are metadata/keep-alive + _ => Ok(Some(UniversalStreamChunk::keep_alive())), + } + } + + fn stream_from_universal(&self, chunk: &UniversalStreamChunk) -> Result { + if chunk.is_keep_alive() { 
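+            // Keep-alive chunks carry no delta; surface them as a progress
+            // event. Note that `sequence_number` is hard-coded to 0 here
+            // rather than tracked per stream.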
+ // Return a generic in_progress event + return Ok(serde_json::json!({ + "type": "response.in_progress", + "sequence_number": 0 + })); + } + + // Check for finish chunk + let has_finish = chunk + .choices + .first() + .and_then(|c| c.finish_reason.as_ref()) + .is_some(); + + if has_finish { + let finish_reason = chunk.choices.first().and_then(|c| c.finish_reason.as_ref()); + let status = match finish_reason.map(|r| r.as_str()) { + Some("stop") => "completed", + Some("length") => "incomplete", + _ => "completed", + }; + + let id = chunk + .id + .clone() + .unwrap_or_else(|| format!("resp_{}", PLACEHOLDER_ID)); + let mut response = serde_json::json!({ + "id": id, + "object": "response", + "model": chunk.model.as_deref().unwrap_or(PLACEHOLDER_MODEL), + "status": status, + "output": [] + }); + + if let Some(usage) = &chunk.usage { + response.as_object_mut().unwrap().insert( + "usage".into(), + serde_json::json!({ + "input_tokens": usage.prompt_tokens.unwrap_or(0), + "output_tokens": usage.completion_tokens.unwrap_or(0), + "total_tokens": usage.prompt_tokens.unwrap_or(0) + usage.completion_tokens.unwrap_or(0) + }), + ); + } + + return Ok(serde_json::json!({ + "type": if status == "completed" { "response.completed" } else { "response.incomplete" }, + "response": response + })); + } + + // Check for content delta + if let Some(choice) = chunk.choices.first() { + if let Some(delta) = &choice.delta { + if let Some(content) = delta.get("content").and_then(Value::as_str) { + return Ok(serde_json::json!({ + "type": "response.output_text.delta", + "output_index": choice.index, + "content_index": 0, + "delta": content + })); + } + } + } + + // Fallback - return output_text.delta with empty content + Ok(serde_json::json!({ + "type": "response.output_text.delta", + "output_index": 0, + "content_index": 0, + "delta": "" + })) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::serde_json::json; + + #[test] + fn test_responses_detect_request() { + let adapter = ResponsesAdapter; + let payload = json!({ + "model": "o1", + "input": [{"role": "user", "content": "Hello"}] + }); + assert!(adapter.detect_request(&payload)); + } +} diff --git a/crates/lingua/src/universal/mod.rs b/crates/lingua/src/universal/mod.rs index eacc0c6e..533177ae 100644 --- a/crates/lingua/src/universal/mod.rs +++ b/crates/lingua/src/universal/mod.rs @@ -12,15 +12,23 @@ This module provides a 1:1 Rust implementation of the AI SDK ModelMessage format pub mod convert; pub mod defaults; pub mod message; +pub mod reasoning; pub mod request; pub mod response; +pub mod response_format; +pub mod stop; pub mod stream; +pub mod tool_choice; +pub mod tools; pub mod transform; // Re-export main types for convenience pub use defaults::*; pub use message::*; -pub use request::{UniversalParams, UniversalRequest}; +pub use request::{ + JsonSchemaConfig, ReasoningConfig, ReasoningEffort, ResponseFormatConfig, ResponseFormatType, + StopConfig, SummaryMode, ToolChoiceConfig, ToolChoiceMode, UniversalParams, UniversalRequest, +}; pub use response::{FinishReason, UniversalResponse, UniversalUsage}; pub use stream::{UniversalStreamChoice, UniversalStreamChunk}; pub use transform::{extract_system_messages, flatten_consecutive_messages}; diff --git a/crates/lingua/src/universal/reasoning.rs b/crates/lingua/src/universal/reasoning.rs new file mode 100644 index 00000000..68c39333 --- /dev/null +++ b/crates/lingua/src/universal/reasoning.rs @@ -0,0 +1,436 @@ +/*! +Reasoning conversion utilities for cross-provider semantic translation. 
+
+This module provides heuristics for converting between different providers'
+reasoning/thinking configurations:
+- OpenAI Chat: `reasoning_effort` (low/medium/high)
+- OpenAI Responses: `reasoning` object with `effort` and `summary` fields
+- Anthropic: `thinking.budget_tokens`
+- Google: `thinkingConfig.thinkingBudget`
+
+## Design
+
+The conversion uses documented, deterministic heuristics:
+- `effort_to_budget`: Converts effort level to token budget using multipliers
+- `budget_to_effort`: Converts token budget to effort level using thresholds
+
+These heuristics match the existing proxy behavior for consistency.
+
+## Usage
+
+```ignore
+use std::convert::TryInto;
+use crate::capabilities::ProviderFormat;
+use crate::universal::request::ReasoningConfig;
+
+// FROM: Parse provider-specific value to universal config
+let config: ReasoningConfig = (ProviderFormat::Anthropic, &raw_json).try_into()?;
+
+// TO: Convert universal config to provider-specific value
+// Note: max_tokens is passed explicitly for effort→budget conversion
+let output = config.to_provider(ProviderFormat::Anthropic, Some(4096))?;
+```
+*/
+
+use std::convert::TryFrom;
+
+use crate::capabilities::ProviderFormat;
+use crate::processing::transform::TransformError;
+use crate::serde_json::{json, Map, Value};
+use crate::universal::request::{ReasoningConfig, ReasoningEffort};
+#[cfg(test)]
+use crate::universal::request::SummaryMode;
+
+// =============================================================================
+// Heuristic Constants
+// =============================================================================
+
+/// Multiplier for "low" effort (25% of max_tokens)
+pub const EFFORT_LOW_MULTIPLIER: f64 = 0.25;
+
+/// Multiplier for "medium" effort (50% of max_tokens)
+pub const EFFORT_MEDIUM_MULTIPLIER: f64 = 0.50;
+
+/// Multiplier for "high" effort (75% of max_tokens)
+pub const EFFORT_HIGH_MULTIPLIER: f64 = 0.75;
+
+/// Threshold below which budget is considered "low" effort
+pub const EFFORT_LOW_THRESHOLD: f64 = 0.35;
+
+/// Threshold above which budget is considered "high" effort
+pub const EFFORT_HIGH_THRESHOLD: f64 = 0.65;
+
+/// Minimum thinking budget for Anthropic
+pub const MIN_THINKING_BUDGET: i64 = 1024;
+
+/// Default max_tokens to use when not specified
+pub const DEFAULT_MAX_TOKENS: i64 = 4096;
+
+/// Required temperature for Anthropic when thinking is enabled
+pub const ANTHROPIC_THINKING_TEMPERATURE: f64 = 1.0;
+
+// =============================================================================
+// Effort ↔ Budget Conversion
+// =============================================================================
+
+/// Convert effort level to token budget.
+///
+/// Uses multipliers applied to max_tokens:
+/// - low: 25% of max_tokens
+/// - medium: 50% of max_tokens
+/// - high: 75% of max_tokens
+///
+/// Result is clamped to a minimum of 1024 tokens (Anthropic requirement).
+pub fn effort_to_budget(effort: ReasoningEffort, max_tokens: Option<i64>) -> i64 {
+    let max = max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
+    let multiplier = match effort {
+        ReasoningEffort::Low => EFFORT_LOW_MULTIPLIER,
+        ReasoningEffort::Medium => EFFORT_MEDIUM_MULTIPLIER,
+        ReasoningEffort::High => EFFORT_HIGH_MULTIPLIER,
+    };
+    let budget = (max as f64 * multiplier).floor() as i64;
+    budget.max(MIN_THINKING_BUDGET)
+}
+
+/// Convert token budget to effort level.
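+/// (Worked example, assuming max_tokens = 4096: a budget of 2048 is a ratio
+/// of 0.5, which falls in the medium band under the thresholds below.)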
+///
+/// Uses ratio of budget/max_tokens with thresholds:
+/// - ratio < 0.35: low
+/// - 0.35 <= ratio < 0.65: medium
+/// - ratio >= 0.65: high
+pub fn budget_to_effort(budget: i64, max_tokens: Option<i64>) -> ReasoningEffort {
+    let max = max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
+    let ratio = budget as f64 / max as f64;
+
+    if ratio < EFFORT_LOW_THRESHOLD {
+        ReasoningEffort::Low
+    } else if ratio < EFFORT_HIGH_THRESHOLD {
+        ReasoningEffort::Medium
+    } else {
+        ReasoningEffort::High
+    }
+}
+
+// =============================================================================
+// TryFrom Implementation for FROM Conversions
+// =============================================================================
+
+impl<'a> TryFrom<(ProviderFormat, &'a Value)> for ReasoningConfig {
+    type Error = TransformError;
+
+    fn try_from((provider, value): (ProviderFormat, &'a Value)) -> Result<Self, Self::Error> {
+        match provider {
+            ProviderFormat::OpenAI => {
+                // For OpenAI Chat, value is expected to be the reasoning_effort string
+                if let Some(effort_str) = value.as_str() {
+                    Ok(from_openai_chat_reasoning_effort(effort_str, value.clone()))
+                } else {
+                    Ok(Self::default())
+                }
+            }
+            ProviderFormat::Responses => Ok(from_openai_responses(value)),
+            ProviderFormat::Anthropic => Ok(from_anthropic(value)),
+            ProviderFormat::Google => Ok(from_google(value)),
+            _ => Ok(Self::default()),
+        }
+    }
+}
+
+// =============================================================================
+// to_provider Method for TO Conversions
+// =============================================================================
+
+impl ReasoningConfig {
+    /// Convert this config to a provider-specific value.
+    ///
+    /// # Arguments
+    /// * `provider` - Target provider format
+    /// * `max_tokens` - Max tokens for effort→budget conversion (for Anthropic/Google)
+    ///
+    /// # Returns
+    /// `Ok(Some(value))` if conversion succeeded
+    /// `Ok(None)` if reasoning is not enabled or no value should be set
+    /// `Err(_)` if conversion failed
+    pub fn to_provider(
+        &self,
+        provider: ProviderFormat,
+        max_tokens: Option<i64>,
+    ) -> Result<Option<Value>, TransformError> {
+        match provider {
+            ProviderFormat::OpenAI => Ok(to_openai_chat(self, max_tokens).map(Value::String)),
+            ProviderFormat::Responses => Ok(to_openai_responses(self, max_tokens)),
+            ProviderFormat::Anthropic => Ok(to_anthropic(self, max_tokens)),
+            ProviderFormat::Google => Ok(to_google(self, max_tokens)),
+            _ => Ok(None),
+        }
+    }
+}
+
+// =============================================================================
+// Private Helper Functions - FROM Provider Formats
+// =============================================================================
+
+/// Parse OpenAI Chat `reasoning_effort` string into ReasoningConfig.
+fn from_openai_chat_reasoning_effort(reasoning_effort: &str, raw_value: Value) -> ReasoningConfig {
+    ReasoningConfig {
+        enabled: Some(true),
+        effort: reasoning_effort.parse().ok(),
+        budget_tokens: None, // OpenAI Chat doesn't have budget
+        summary: None,
+        raw: Some(raw_value),
+    }
+}
+
+/// Parse OpenAI Responses API `reasoning` object into ReasoningConfig.
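+///
+/// For example, `{"effort": "high", "summary": "detailed"}` maps to
+/// `enabled: Some(true)`, `effort: Some(High)`, `summary: Some(Detailed)`
+/// (see `test_from_openai_responses` below).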
+fn from_openai_responses(reasoning: &Value) -> ReasoningConfig {
+    let effort = reasoning
+        .get("effort")
+        .and_then(Value::as_str)
+        .and_then(|s| s.parse().ok());
+
+    let summary = reasoning
+        .get("summary")
+        .and_then(Value::as_str)
+        .and_then(|s| s.parse().ok());
+
+    ReasoningConfig {
+        enabled: Some(true),
+        effort,
+        budget_tokens: None,
+        summary,
+        raw: Some(reasoning.clone()),
+    }
+}
+
+/// Parse Anthropic `thinking` object into ReasoningConfig.
+fn from_anthropic(thinking: &Value) -> ReasoningConfig {
+    let enabled = thinking
+        .get("type")
+        .and_then(Value::as_str)
+        .map(|t| t == "enabled");
+
+    let budget_tokens = thinking.get("budget_tokens").and_then(Value::as_i64);
+
+    ReasoningConfig {
+        enabled,
+        effort: None, // Anthropic doesn't have effort level
+        budget_tokens,
+        summary: None,
+        raw: Some(thinking.clone()),
+    }
+}
+
+/// Parse Google `thinkingConfig` object into ReasoningConfig.
+fn from_google(config: &Value) -> ReasoningConfig {
+    let enabled = config
+        .get("includeThoughts")
+        .and_then(Value::as_bool)
+        .or_else(|| {
+            // If thinkingBudget > 0, thinking is enabled
+            config
+                .get("thinkingBudget")
+                .and_then(Value::as_i64)
+                .map(|b| b > 0)
+        });
+
+    let budget_tokens = config.get("thinkingBudget").and_then(Value::as_i64);
+
+    ReasoningConfig {
+        enabled,
+        effort: None, // Google doesn't have effort level
+        budget_tokens,
+        summary: None,
+        raw: Some(config.clone()),
+    }
+}
+
+// =============================================================================
+// Private Helper Functions - TO Provider Formats
+// =============================================================================
+
+/// Convert ReasoningConfig to OpenAI Chat `reasoning_effort` string.
+fn to_openai_chat(config: &ReasoningConfig, max_tokens: Option<i64>) -> Option<String> {
+    // If we have effort, use it directly
+    if let Some(effort) = config.effort {
+        return Some(effort.to_string());
+    }
+
+    // If we have budget_tokens, derive effort from it
+    if let Some(budget) = config.budget_tokens {
+        let effort = budget_to_effort(budget, max_tokens);
+        return Some(effort.to_string());
+    }
+
+    // If just enabled with no specifics, default to medium
+    if config.enabled == Some(true) {
+        return Some("medium".to_string());
+    }
+
+    None
+}
+
+/// Convert ReasoningConfig to OpenAI Responses API `reasoning` object.
+fn to_openai_responses(config: &ReasoningConfig, max_tokens: Option<i64>) -> Option<Value> {
+    if config.enabled != Some(true) {
+        return None;
+    }
+
+    let mut obj = Map::new();
+
+    // Effort
+    let effort = config
+        .effort
+        .map(|e| e.to_string())
+        .or_else(|| {
+            config
+                .budget_tokens
+                .map(|b| budget_to_effort(b, max_tokens).to_string())
+        })
+        .unwrap_or_else(|| "medium".to_string());
+
+    obj.insert("effort".into(), Value::String(effort));
+
+    // Summary
+    if let Some(summary) = config.summary {
+        obj.insert("summary".into(), Value::String(summary.to_string()));
+    }
+
+    Some(Value::Object(obj))
+}
+
+/// Convert ReasoningConfig to Anthropic `thinking` object.
+fn to_anthropic(config: &ReasoningConfig, max_tokens: Option<i64>) -> Option<Value> {
+    if config.enabled != Some(true) {
+        return None;
+    }
+
+    // Calculate budget_tokens
+    let budget = config.budget_tokens.unwrap_or_else(|| {
+        config
+            .effort
+            .map(|e| effort_to_budget(e, max_tokens))
+            .unwrap_or(MIN_THINKING_BUDGET)
+    });
+
+    Some(json!({
+        "type": "enabled",
+        "budget_tokens": budget
+    }))
+}
+
+/// Convert ReasoningConfig to Google `thinkingConfig` object.
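+///
+/// For example (using the default heuristics above), `effort = Some(Medium)` with
+/// `max_tokens = Some(4096)` yields `{"includeThoughts": true, "thinkingBudget": 2048}`.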
+fn to_google(config: &ReasoningConfig, max_tokens: Option<i64>) -> Option<Value> {
+    if config.enabled != Some(true) {
+        return None;
+    }
+
+    // Calculate thinkingBudget
+    let budget = config.budget_tokens.unwrap_or_else(|| {
+        config
+            .effort
+            .map(|e| effort_to_budget(e, max_tokens))
+            .unwrap_or(MIN_THINKING_BUDGET)
+    });
+
+    Some(json!({
+        "includeThoughts": true,
+        "thinkingBudget": budget
+    }))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::convert::TryInto;
+
+    #[test]
+    fn test_effort_to_budget() {
+        // With default max_tokens (4096)
+        assert_eq!(effort_to_budget(ReasoningEffort::Low, None), 1024); // 4096 * 0.25 = 1024
+        assert_eq!(effort_to_budget(ReasoningEffort::Medium, None), 2048); // 4096 * 0.50 = 2048
+        assert_eq!(effort_to_budget(ReasoningEffort::High, None), 3072); // 4096 * 0.75 = 3072
+
+        // With custom max_tokens
+        assert_eq!(effort_to_budget(ReasoningEffort::Medium, Some(8192)), 4096);
+
+        // Minimum budget enforced
+        assert_eq!(effort_to_budget(ReasoningEffort::Low, Some(1000)), 1024); // Would be 250, clamped to 1024
+    }
+
+    #[test]
+    fn test_budget_to_effort() {
+        // With default max_tokens (4096)
+        assert_eq!(budget_to_effort(500, None), ReasoningEffort::Low); // 500/4096 = 0.12 < 0.35
+        assert_eq!(budget_to_effort(2000, None), ReasoningEffort::Medium); // 2000/4096 = 0.49
+        assert_eq!(budget_to_effort(3000, None), ReasoningEffort::High); // 3000/4096 = 0.73 >= 0.65
+
+        // With custom max_tokens
+        assert_eq!(budget_to_effort(4096, Some(8192)), ReasoningEffort::Medium); // 4096/8192 = 0.5
+    }
+
+    #[test]
+    fn test_roundtrip_effort() {
+        // effort → budget → effort should preserve the original level
+        for effort in [
+            ReasoningEffort::Low,
+            ReasoningEffort::Medium,
+            ReasoningEffort::High,
+        ] {
+            let budget = effort_to_budget(effort, Some(4096));
+            let back = budget_to_effort(budget, Some(4096));
+            assert_eq!(effort, back, "Roundtrip failed for {:?}", effort);
+        }
+    }
+
+    #[test]
+    fn test_from_anthropic() {
+        let value = json!({
+            "type": "enabled",
+            "budget_tokens": 2048
+        });
+        let config: ReasoningConfig = (ProviderFormat::Anthropic, &value).try_into().unwrap();
+        assert_eq!(config.enabled, Some(true));
+        assert_eq!(config.budget_tokens, Some(2048));
+    }
+
+    #[test]
+    fn test_to_anthropic_thinking() {
+        let config = ReasoningConfig {
+            enabled: Some(true),
+            effort: Some(ReasoningEffort::Medium),
+            budget_tokens: None,
+            summary: None,
+            raw: None,
+        };
+
+        let thinking = config.to_provider(ProviderFormat::Anthropic, Some(4096)).unwrap().unwrap();
+        assert_eq!(thinking.get("type").unwrap(), "enabled");
+        assert_eq!(thinking.get("budget_tokens").unwrap(), 2048);
+    }
+
+    #[test]
+    fn test_to_openai_chat_reasoning() {
+        let config = ReasoningConfig {
+            enabled: Some(true),
+            effort: None,
+            budget_tokens: Some(2048),
+            summary: None,
+            raw: None,
+        };
+
+        let effort = config.to_provider(ProviderFormat::OpenAI, Some(4096)).unwrap().unwrap();
+        assert_eq!(effort.as_str().unwrap(), "medium"); // 2048/4096 = 0.5 → medium
+    }
+
+    #[test]
+    fn test_from_openai_responses() {
+        let value = json!({
+            "effort": "high",
+            "summary": "detailed"
+        });
+        let config: ReasoningConfig = (ProviderFormat::Responses, &value).try_into().unwrap();
+        assert_eq!(config.enabled, Some(true));
+        assert_eq!(config.effort, Some(ReasoningEffort::High));
+        assert_eq!(config.summary, Some(SummaryMode::Detailed));
+    }
+}
diff --git a/crates/lingua/src/universal/request.rs b/crates/lingua/src/universal/request.rs
index bd0b7ece..732b3b57 100644
--- a/crates/lingua/src/universal/request.rs
+++ b/crates/lingua/src/universal/request.rs
@@ -6,23 +6,56 @@ converted to/from any provider format.
 
 ## Design principles
 
-1. **Round-trip preservation**: Any field not mapped to a canonical field goes
-   into `extras` and is restored when converting back to the source format.
+1. **Round-trip preservation**: Provider-specific fields are stored in
+   `provider_extras` keyed by `ProviderFormat`, and restored when converting
+   back to the same provider format.
 
 2. **Canonical naming**: Uses consistent field names (e.g., `max_tokens`,
    `top_p`) regardless of what individual providers call them.
 
-3. **Minimal typing for complex fields**: Fields like `tools`, `tool_choice`, and
-   `response_format` are kept as `Value` since they vary significantly across providers.
+3. **Typed configs with lossless round-trip**: Complex fields like `tool_choice`,
+   `response_format`, and `stop` use typed structs with a `raw` field for lossless
+   preservation. Only `tools` and `metadata` remain as `Value`.
+
+4. **Provider isolation**: Provider-specific extras are scoped by `ProviderFormat`
+   to prevent cross-provider contamination (e.g., OpenAI extras don't bleed into
+   Anthropic requests).
 */
+
+use std::collections::HashMap;
+use std::fmt;
+use std::str::FromStr;
+
+use crate::capabilities::ProviderFormat;
 use crate::serde_json::{Map, Value};
 use crate::universal::message::Message;
 
+// =============================================================================
+// Error Types
+// =============================================================================
+
+/// Error type for enum parsing from strings.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ParseEnumError {
+    /// The type name that failed to parse.
+    pub type_name: &'static str,
+    /// The invalid input value.
+    pub value: String,
+}
+
+impl fmt::Display for ParseEnumError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "invalid {} value: '{}'", self.type_name, self.value)
+    }
+}
+
+impl std::error::Error for ParseEnumError {}
+
 /// Universal request envelope for LLM API calls.
 ///
 /// This type captures the common structure across all provider request formats.
-/// Provider-specific fields that don't map to canonical params go into `extras`.
+/// Provider-specific fields are stored in `provider_extras`, keyed by the source
+/// provider format to prevent cross-provider contamination.
 #[derive(Debug, Clone)]
 pub struct UniversalRequest {
     /// Model identifier (may be None for providers that use endpoint-based model selection)
@@ -31,18 +64,27 @@ pub struct UniversalRequest {
     /// Conversation messages in universal format
     pub messages: Vec<Message>,
 
-    /// Common request parameters
+    /// Common request parameters (canonical fields only)
     pub params: UniversalParams,
 
-    /// Provider-specific fields not captured in params
-    pub extras: Map<String, Value>,
+    /// Provider-specific fields, keyed by the source ProviderFormat.
+    ///
+    /// When transforming back to the same provider, these extras are merged back.
+    /// When transforming to a different provider, they are ignored (no cross-pollination).
+    ///
+    /// Example: OpenAI Chat extras stay in `provider_extras[ProviderFormat::OpenAI]`
+    /// and are only merged back when converting to OpenAI Chat, not to Anthropic.
+    pub provider_extras: HashMap<ProviderFormat, Map<String, Value>>,
 }
 
 /// Common request parameters across providers.
 ///
 /// Uses canonical names - adapters handle mapping to provider-specific names.
+/// This struct contains ONLY canonical fields - no extras or provider-specific baggage.
 #[derive(Debug, Clone, Default)]
 pub struct UniversalParams {
+    // === Sampling parameters ===
+
     /// Sampling temperature (0.0 to 2.0 typically)
     pub temperature: Option<f64>,
@@ -52,30 +94,427 @@ pub struct UniversalParams {
     /// Top-k sampling (not supported by all providers)
     pub top_k: Option<i64>,
 
+    /// Random seed for deterministic generation
+    pub seed: Option<i64>,
+
+    /// Presence penalty (-2.0 to 2.0)
+    pub presence_penalty: Option<f64>,
+
+    /// Frequency penalty (-2.0 to 2.0)
+    pub frequency_penalty: Option<f64>,
+
+    // === Output control ===
+
     /// Maximum tokens to generate
     pub max_tokens: Option<i64>,
 
-    /// Stop sequences (kept as Value due to union type in OpenAI)
-    pub stop: Option<Value>,
+    /// Stop sequences configuration.
+    ///
+    /// This is a typed struct that supports both:
+    /// - **Lossless round-trip**: Original provider value stored in `raw`
+    /// - **Cross-provider conversion**: Normalized `sequences` array enables semantic translation
+    pub stop: Option<StopConfig>,
+
+    /// Whether to return log probabilities (OpenAI-specific but canonical)
+    pub logprobs: Option<bool>,
+
+    /// Number of top logprobs to return (0-20)
+    pub top_logprobs: Option<i64>,
+
+    // === Tools and function calling ===
 
     /// Tool definitions (schema varies by provider)
     pub tools: Option<Value>,
 
-    /// Tool selection strategy (varies by provider)
-    pub tool_choice: Option<Value>,
+    /// Tool selection strategy configuration.
+    ///
+    /// This is a typed struct that supports both:
+    /// - **Lossless round-trip**: Original provider value stored in `raw`
+    /// - **Cross-provider conversion**: Canonical fields enable semantic translation
+    pub tool_choice: Option<ToolChoiceConfig>,
 
-    /// Output format specification (varies by provider)
-    pub response_format: Option<Value>,
+    /// Whether tools can be called in parallel
+    pub parallel_tool_calls: Option<bool>,
 
-    /// Random seed for deterministic generation
-    pub seed: Option<i64>,
+    // === Response format ===
 
-    /// Presence penalty (-2.0 to 2.0)
-    pub presence_penalty: Option<f64>,
+    /// Response format configuration.
+    ///
+    /// This is a typed struct that supports both:
+    /// - **Lossless round-trip**: Original provider value stored in `raw`
+    /// - **Cross-provider conversion**: Canonical fields enable semantic translation
+    pub response_format: Option<ResponseFormatConfig>,
 
-    /// Frequency penalty (-2.0 to 2.0)
-    pub frequency_penalty: Option<f64>,
+    // === Reasoning / Extended thinking ===
+
+    /// Reasoning configuration for extended thinking / chain-of-thought.
+    ///
+    /// This is a typed struct that supports both:
+    /// - **Lossless round-trip**: Original provider value stored in `raw`
+    /// - **Cross-provider conversion**: Canonical fields enable semantic translation
+    pub reasoning: Option<ReasoningConfig>,
+
+    // === Metadata and identification ===
+
+    /// Request metadata (user tracking, experiment tags, etc.)
+    pub metadata: Option<Value>,
+
+    /// Whether to store completion for training/evals (OpenAI-specific but canonical)
+    pub store: Option<bool>,
+
+    /// Service tier preference
+    pub service_tier: Option<String>,
+
+    // === Streaming ===
 
     /// Whether to stream the response
     pub stream: Option<bool>,
 }
+
+// =============================================================================
+// Reasoning Configuration
+// =============================================================================
+
+/// Configuration for extended thinking / reasoning capabilities.
+///
+/// Supports two usage modes:
+/// 1. **Same-provider round-trip**: Use `raw` field for exact preservation
+/// 2. **Cross-provider conversion**: Use canonical fields (`effort`, `budget_tokens`)
+///
+/// When converting TO a provider:
+/// - If `provider_extras` has the target provider's reasoning field → use it (lossless)
+/// - Otherwise, derive from canonical fields using heuristics
+#[derive(Debug, Clone, Default)]
+pub struct ReasoningConfig {
+    /// Whether reasoning/thinking is enabled.
+    pub enabled: Option<bool>,
+
+    /// Effort level (portable across providers that support effort-based control).
+    /// Maps to OpenAI's `reasoning_effort` and can be converted to Anthropic's `budget_tokens`.
+    pub effort: Option<ReasoningEffort>,
+
+    /// Token budget for thinking (Anthropic's native format).
+    /// Can be derived from `effort` using heuristics when not explicitly set.
+    pub budget_tokens: Option<i64>,
+
+    /// Summary mode for reasoning output.
+    /// Maps to OpenAI Responses API's `reasoning.summary` field.
+    pub summary: Option<SummaryMode>,
+
+    /// Original provider-specific value for lossless round-trip.
+    /// This stores the exact JSON that was received, enabling perfect reconstruction.
+    pub raw: Option<Value>,
+}
+
+/// Reasoning effort level (portable across providers).
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ReasoningEffort {
+    Low,
+    Medium,
+    High,
+}
+
+impl ReasoningEffort {
+    /// Returns the string representation.
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Self::Low => "low",
+            Self::Medium => "medium",
+            Self::High => "high",
+        }
+    }
+}
+
+impl FromStr for ReasoningEffort {
+    type Err = ParseEnumError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s.to_lowercase().as_str() {
+            "low" => Ok(Self::Low),
+            "medium" => Ok(Self::Medium),
+            "high" => Ok(Self::High),
+            _ => Err(ParseEnumError {
+                type_name: "ReasoningEffort",
+                value: s.to_string(),
+            }),
+        }
+    }
+}
+
+impl fmt::Display for ReasoningEffort {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(self.as_str())
+    }
+}
+
+impl AsRef<str> for ReasoningEffort {
+    fn as_ref(&self) -> &str {
+        self.as_str()
+    }
+}
+
+/// Summary mode for reasoning output.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum SummaryMode {
+    /// No summary included in response.
+    None,
+    /// Provider decides whether to include summary.
+    Auto,
+    /// Detailed summary included in response.
+    Detailed,
+}
+
+impl SummaryMode {
+    /// Returns the string representation.
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Self::None => "none",
+            Self::Auto => "auto",
+            Self::Detailed => "detailed",
+        }
+    }
+}
+
+impl FromStr for SummaryMode {
+    type Err = ParseEnumError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s.to_lowercase().as_str() {
+            "none" => Ok(Self::None),
+            "auto" => Ok(Self::Auto),
+            "detailed" => Ok(Self::Detailed),
+            _ => Err(ParseEnumError {
+                type_name: "SummaryMode",
+                value: s.to_string(),
+            }),
+        }
+    }
+}
+
+impl fmt::Display for SummaryMode {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(self.as_str())
+    }
+}
+
+impl AsRef<str> for SummaryMode {
+    fn as_ref(&self) -> &str {
+        self.as_str()
+    }
+}
+
+// =============================================================================
+// Tool Choice Configuration
+// =============================================================================
+
+/// Tool selection strategy configuration.
+///
+/// Supports two usage modes:
+/// 1. **Same-provider round-trip**: Use `raw` field for exact preservation
+/// 2. **Cross-provider conversion**: Use canonical fields (`mode`, `tool_name`)
+///
+/// Provider mapping:
+/// - OpenAI Chat: `"auto"` | `"none"` | `"required"` | `{ type: "function", function: { name } }`
+/// - OpenAI Responses: `"auto"` | `{ type: "function", name }`
+/// - Anthropic: `{ type: "auto" | "any" | "none" | "tool", name?, disable_parallel_tool_use? }`
+#[derive(Debug, Clone, Default)]
+pub struct ToolChoiceConfig {
+    /// Selection mode - the semantic intent of the tool choice
+    pub mode: Option<ToolChoiceMode>,
+
+    /// Specific tool name (when mode = Tool)
+    pub tool_name: Option<String>,
+
+    /// Whether to disable parallel tool calls.
+    /// Maps to Anthropic's `disable_parallel_tool_use` field.
+    /// For OpenAI, this is handled via the separate `parallel_tool_calls` param.
+    pub disable_parallel: Option<bool>,
+
+    /// Original provider-specific value for lossless round-trip.
+    /// This stores the exact JSON that was received, enabling perfect reconstruction.
+    pub raw: Option<Value>,
+}
+
+/// Tool selection mode (portable across providers).
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ToolChoiceMode {
+    /// Provider decides whether to use tools
+    Auto,
+    /// No tools allowed
+    None,
+    /// Must use a tool (OpenAI "required" / Anthropic "any")
+    Required,
+    /// Specific tool required (use `tool_name` field)
+    Tool,
+}
+
+impl ToolChoiceMode {
+    /// Returns the string representation (OpenAI format).
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Self::Auto => "auto",
+            Self::None => "none",
+            Self::Required => "required",
+            Self::Tool => "function",
+        }
+    }
+
+    /// Convert to Anthropic format string.
+    pub fn as_anthropic_str(&self) -> &'static str {
+        match self {
+            Self::Auto => "auto",
+            Self::None => "none",
+            Self::Required => "any",
+            Self::Tool => "tool",
+        }
+    }
+}
+
+impl FromStr for ToolChoiceMode {
+    type Err = ParseEnumError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s.to_lowercase().as_str() {
+            "auto" => Ok(Self::Auto),
+            "none" => Ok(Self::None),
+            "required" | "any" => Ok(Self::Required),
+            "tool" | "function" => Ok(Self::Tool),
+            _ => Err(ParseEnumError {
+                type_name: "ToolChoiceMode",
+                value: s.to_string(),
+            }),
+        }
+    }
+}
+
+impl fmt::Display for ToolChoiceMode {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(self.as_str())
+    }
+}
+
+impl AsRef<str> for ToolChoiceMode {
+    fn as_ref(&self) -> &str {
+        self.as_str()
+    }
+}
+
+// =============================================================================
+// Response Format Configuration
+// =============================================================================
+
+/// Response format configuration for structured output.
+///
+/// Supports two usage modes:
+/// 1. **Same-provider round-trip**: Use `raw` field for exact preservation
+/// 2. **Cross-provider conversion**: Use canonical fields (`format_type`, `json_schema`)
+///
+/// Provider mapping:
+/// - OpenAI Chat: `{ type: "text" | "json_object" | "json_schema", json_schema? }`
+/// - OpenAI Responses: nested under `text.format`
+/// - Google: `response_mime_type` + `response_schema`
+/// - Anthropic: Not supported
+#[derive(Debug, Clone, Default)]
+pub struct ResponseFormatConfig {
+    /// Output format type
+    pub format_type: Option<ResponseFormatType>,
+
+    /// JSON schema configuration (when format_type = JsonSchema)
+    pub json_schema: Option<JsonSchemaConfig>,
+
+    /// Original provider-specific value for lossless round-trip.
+    /// This stores the exact JSON that was received, enabling perfect reconstruction.
+    pub raw: Option<Value>,
+}
+
+/// Response format type (portable across providers).
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ResponseFormatType {
+    /// Plain text output (default)
+    Text,
+    /// JSON object output (unstructured)
+    JsonObject,
+    /// JSON output conforming to a schema
+    JsonSchema,
+}
+
+impl ResponseFormatType {
+    /// Returns the string representation.
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Self::Text => "text",
+            Self::JsonObject => "json_object",
+            Self::JsonSchema => "json_schema",
+        }
+    }
+}
+
+impl FromStr for ResponseFormatType {
+    type Err = ParseEnumError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s.to_lowercase().as_str() {
+            "text" => Ok(Self::Text),
+            "json_object" => Ok(Self::JsonObject),
+            "json_schema" => Ok(Self::JsonSchema),
+            _ => Err(ParseEnumError {
+                type_name: "ResponseFormatType",
+                value: s.to_string(),
+            }),
+        }
+    }
+}
+
+impl fmt::Display for ResponseFormatType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(self.as_str())
+    }
+}
+
+impl AsRef<str> for ResponseFormatType {
+    fn as_ref(&self) -> &str {
+        self.as_str()
+    }
+}
+
+/// JSON schema configuration for structured output.
+#[derive(Debug, Clone)]
+pub struct JsonSchemaConfig {
+    /// Schema name (required by OpenAI)
+    pub name: String,
+
+    /// The JSON schema definition
+    pub schema: Value,
+
+    /// Whether to enable strict schema validation
+    pub strict: Option<bool>,
+
+    /// Human-readable description of the schema
+    pub description: Option<String>,
+}
+
+// =============================================================================
+// Stop Configuration
+// =============================================================================
+
+/// Stop sequences configuration.
+///
+/// Supports two usage modes:
+/// 1. **Same-provider round-trip**: Use `raw` field for exact preservation
+/// 2. **Cross-provider conversion**: Use normalized `sequences` array
+///
+/// Provider mapping:
+/// - OpenAI: `stop: string | string[]` (allows single string or array)
+/// - Anthropic: `stop_sequences: string[]`
+/// - Google: `generationConfig.stop_sequences: string[]`
+/// - Bedrock: `inferenceConfig.stopSequences: string[]`
+#[derive(Debug, Clone, Default)]
+pub struct StopConfig {
+    /// Normalized stop sequences (always an array).
+    /// Single string inputs are converted to single-element arrays.
+    pub sequences: Vec<String>,
+
+    /// Original provider-specific value for lossless round-trip.
+    /// This stores the exact JSON that was received, enabling perfect reconstruction.
+    pub raw: Option<Value>,
+}
diff --git a/crates/lingua/src/universal/response_format.rs b/crates/lingua/src/universal/response_format.rs
new file mode 100644
index 00000000..3adfd996
--- /dev/null
+++ b/crates/lingua/src/universal/response_format.rs
@@ -0,0 +1,338 @@
+/*!
+Response format conversion utilities for cross-provider semantic translation.
+
+This module provides bidirectional conversion between different providers'
+response format configurations:
+- OpenAI Chat: `{ type: "text" | "json_object" | "json_schema", json_schema? }`
+- OpenAI Responses: nested under `text.format` with flattened schema
+- Google: `response_mime_type` + `response_schema`
+- Anthropic: Not supported
+
+## Design
+
+The conversion preserves the original value in `raw` for lossless same-provider round-trips,
+while extracting canonical fields (`format_type`, `json_schema`) for cross-provider conversion.
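+
+For example (a hypothetical schema named `person`), the OpenAI Chat form
+
+```ignore
+{ "type": "json_schema", "json_schema": { "name": "person", "schema": { "type": "object" } } }
+```
+
+maps to the flattened Responses API `text` wrapper
+
+```ignore
+{ "format": { "type": "json_schema", "name": "person", "schema": { "type": "object" } } }
+```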
+
+## Usage
+
+```ignore
+use std::convert::TryInto;
+use crate::capabilities::ProviderFormat;
+use crate::universal::request::ResponseFormatConfig;
+
+// FROM: Parse provider-specific value to universal config
+let config: ResponseFormatConfig = (ProviderFormat::OpenAI, &raw_json).try_into()?;
+
+// TO: Convert universal config to provider-specific value
+let output = config.to_provider(ProviderFormat::OpenAI)?;
+```
+*/
+
+use std::convert::TryFrom;
+
+use crate::capabilities::ProviderFormat;
+use crate::processing::transform::TransformError;
+use crate::serde_json::{json, Map, Value};
+use crate::universal::request::{JsonSchemaConfig, ResponseFormatConfig, ResponseFormatType};
+
+// =============================================================================
+// TryFrom Implementation for FROM Conversions
+// =============================================================================
+
+impl<'a> TryFrom<(ProviderFormat, &'a Value)> for ResponseFormatConfig {
+    type Error = TransformError;
+
+    fn try_from((provider, value): (ProviderFormat, &'a Value)) -> Result<Self, Self::Error> {
+        match provider {
+            ProviderFormat::OpenAI => Ok(from_openai_chat(value)),
+            ProviderFormat::Responses => Ok(from_openai_responses(value)),
+            _ => Ok(Self::default()),
+        }
+    }
+}
+
+// =============================================================================
+// to_provider Method for TO Conversions
+// =============================================================================
+
+impl ResponseFormatConfig {
+    /// Convert this config to a provider-specific value.
+    ///
+    /// # Arguments
+    /// * `provider` - Target provider format
+    ///
+    /// # Returns
+    /// `Ok(Some(value))` if conversion succeeded
+    /// `Ok(None)` if no value should be set
+    /// `Err(_)` if conversion failed
+    pub fn to_provider(
+        &self,
+        provider: ProviderFormat,
+    ) -> Result<Option<Value>, TransformError> {
+        match provider {
+            ProviderFormat::OpenAI => Ok(to_openai_chat(self)),
+            ProviderFormat::Responses => Ok(to_openai_responses_text(self)),
+            _ => Ok(None),
+        }
+    }
+}
+
+// =============================================================================
+// Private Helper Functions - FROM Provider Formats
+// =============================================================================
+
+/// Parse OpenAI Chat `response_format` into ResponseFormatConfig.
+///
+/// Handles:
+/// - `{ type: "text" }`
+/// - `{ type: "json_object" }`
+/// - `{ type: "json_schema", json_schema: { name, schema, strict?, description? } }`
+fn from_openai_chat(value: &Value) -> ResponseFormatConfig {
+    let format_type = value
+        .get("type")
+        .and_then(Value::as_str)
+        .and_then(|s| s.parse().ok());
+
+    let json_schema = if format_type == Some(ResponseFormatType::JsonSchema) {
+        value.get("json_schema").and_then(|js| {
+            let name = js.get("name").and_then(Value::as_str)?;
+            let schema = js.get("schema").cloned()?;
+            Some(JsonSchemaConfig {
+                name: name.to_string(),
+                schema,
+                strict: js.get("strict").and_then(Value::as_bool),
+                description: js
+                    .get("description")
+                    .and_then(Value::as_str)
+                    .map(String::from),
+            })
+        })
+    } else {
+        None
+    };
+
+    ResponseFormatConfig {
+        format_type,
+        json_schema,
+        raw: Some(value.clone()),
+    }
+}
+
+/// Parse OpenAI Responses API `text.format` into ResponseFormatConfig.
+///
+/// Handles the flattened structure:
+/// - `{ type: "json_schema", name, schema, strict?, description? }`
+fn from_openai_responses(value: &Value) -> ResponseFormatConfig {
+    let format_type = value
+        .get("type")
+        .and_then(Value::as_str)
+        .and_then(|s| s.parse().ok());
+
+    let json_schema = if format_type == Some(ResponseFormatType::JsonSchema) {
+        value
+            .get("name")
+            .and_then(Value::as_str)
+            .and_then(|name| {
+                value.get("schema").cloned().map(|schema| JsonSchemaConfig {
+                    name: name.to_string(),
+                    schema,
+                    strict: value.get("strict").and_then(Value::as_bool),
+                    description: value
+                        .get("description")
+                        .and_then(Value::as_str)
+                        .map(String::from),
+                })
+            })
+    } else {
+        None
+    };
+
+    ResponseFormatConfig {
+        format_type,
+        json_schema,
+        raw: Some(value.clone()),
+    }
+}
+
+// =============================================================================
+// Private Helper Functions - TO Provider Formats
+// =============================================================================
+
+/// Convert ResponseFormatConfig to OpenAI Chat `response_format` value.
+///
+/// Output format:
+/// - `{ type: "text" }`
+/// - `{ type: "json_object" }`
+/// - `{ type: "json_schema", json_schema: { name, schema, strict?, description? } }`
+fn to_openai_chat(config: &ResponseFormatConfig) -> Option<Value> {
+    // If we have raw and it looks like OpenAI Chat format, prefer it
+    if let Some(ref raw) = config.raw {
+        if raw.as_object().map_or(false, |o| {
+            o.contains_key("type")
+                && (o.contains_key("json_schema") || !o.contains_key("schema"))
+        }) {
+            return Some(raw.clone());
+        }
+    }
+
+    let format_type = config.format_type?;
+
+    match format_type {
+        ResponseFormatType::Text => Some(json!({ "type": "text" })),
+        ResponseFormatType::JsonObject => Some(json!({ "type": "json_object" })),
+        ResponseFormatType::JsonSchema => {
+            let js = config.json_schema.as_ref()?;
+            let mut json_schema = Map::new();
+            json_schema.insert("name".into(), Value::String(js.name.clone()));
+            json_schema.insert("schema".into(), js.schema.clone());
+            if let Some(strict) = js.strict {
+                json_schema.insert("strict".into(), Value::Bool(strict));
+            }
+            if let Some(ref desc) = js.description {
+                json_schema.insert("description".into(), Value::String(desc.clone()));
+            }
+            Some(json!({
+                "type": "json_schema",
+                "json_schema": json_schema
+            }))
+        }
+    }
+}
+
+/// Convert ResponseFormatConfig to OpenAI Responses API `text` object.
+///
+/// Output format (flattened, wrapped in text object):
+/// - `{ format: { type: "text" } }`
+/// - `{ format: { type: "json_schema", name, schema, strict?, description? } }`
+///
+/// Returns the full `text` object, not just the format.
+fn to_openai_responses_text(config: &ResponseFormatConfig) -> Option<Value> {
+    // If we have raw from Responses API format, reconstruct text wrapper
+    if let Some(ref raw) = config.raw {
+        // Check if it's Responses format (flat, has name at top level)
+        if raw.as_object().map_or(false, |o| {
+            o.contains_key("type") && (o.contains_key("name") || o.contains_key("schema"))
+        }) {
+            return Some(json!({ "format": raw }));
+        }
+    }
+
+    let format_type = config.format_type?;
+
+    let format_obj = match format_type {
+        ResponseFormatType::Text => json!({ "type": "text" }),
+        ResponseFormatType::JsonObject => json!({ "type": "json_object" }),
+        ResponseFormatType::JsonSchema => {
+            let js = config.json_schema.as_ref()?;
+            let mut obj = Map::new();
+            obj.insert("type".into(), Value::String("json_schema".into()));
+            obj.insert("name".into(), Value::String(js.name.clone()));
+            obj.insert("schema".into(), js.schema.clone());
+            if let Some(strict) = js.strict {
+                obj.insert("strict".into(), Value::Bool(strict));
+            }
+            if let Some(ref desc) = js.description {
+                obj.insert("description".into(), Value::String(desc.clone()));
+            }
+            Value::Object(obj)
+        }
+    };
+
+    Some(json!({ "format": format_obj }))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::convert::TryInto;
+
+    #[test]
+    fn test_from_openai_chat_text() {
+        let value = json!({ "type": "text" });
+        let config: ResponseFormatConfig = (ProviderFormat::OpenAI, &value).try_into().unwrap();
+        assert_eq!(config.format_type, Some(ResponseFormatType::Text));
+        assert!(config.json_schema.is_none());
+    }
+
+    #[test]
+    fn test_from_openai_chat_json_schema() {
+        let value = json!({
+            "type": "json_schema",
+            "json_schema": {
+                "name": "person_info",
+                "schema": {
+                    "type": "object",
+                    "properties": {
+                        "name": { "type": "string" }
+                    }
+                },
+                "strict": true
+            }
+        });
+        let config: ResponseFormatConfig = (ProviderFormat::OpenAI, &value).try_into().unwrap();
+        assert_eq!(config.format_type, Some(ResponseFormatType::JsonSchema));
+        let js = config.json_schema.unwrap();
+        assert_eq!(js.name, "person_info");
+        assert_eq!(js.strict, Some(true));
+    }
+
+    #[test]
+    fn test_to_openai_chat_json_schema() {
+        let config = ResponseFormatConfig {
+            format_type: Some(ResponseFormatType::JsonSchema),
+            json_schema: Some(JsonSchemaConfig {
+                name: "test_schema".into(),
+                schema: json!({ "type": "object" }),
+                strict: Some(true),
+                description: None,
+            }),
+            raw: None,
+        };
+        let value = config.to_provider(ProviderFormat::OpenAI).unwrap().unwrap();
+        assert_eq!(value.get("type").unwrap(), "json_schema");
+        assert!(value.get("json_schema").is_some());
+        assert_eq!(
+            value
+                .get("json_schema")
+                .unwrap()
+                .get("name")
+                .unwrap()
+                .as_str()
+                .unwrap(),
+            "test_schema"
+        );
+    }
+
+    #[test]
+    fn test_roundtrip_openai_chat() {
+        let original = json!({
+            "type": "json_schema",
+            "json_schema": {
+                "name": "test",
+                "schema": { "type": "object" },
+                "strict": true
+            }
+        });
+        let config: ResponseFormatConfig = (ProviderFormat::OpenAI, &original).try_into().unwrap();
+        let back = config.to_provider(ProviderFormat::OpenAI).unwrap().unwrap();
+        assert_eq!(original, back);
+    }
+
+    #[test]
+    fn test_to_responses_text_format() {
+        let config = ResponseFormatConfig {
+            format_type: Some(ResponseFormatType::JsonSchema),
+            json_schema: Some(JsonSchemaConfig {
+                name: "test".into(),
+                schema: json!({ "type": "object" }),
+                strict: Some(true),
+                description: None,
+            }),
+            raw: None,
+        };
+        let value = config.to_provider(ProviderFormat::Responses).unwrap().unwrap();
+        let format = value.get("format").unwrap();
+        assert_eq!(format.get("type").unwrap(), "json_schema");
+        assert_eq!(format.get("name").unwrap(), "test");
+    }
+}
diff --git a/crates/lingua/src/universal/stop.rs b/crates/lingua/src/universal/stop.rs
new file mode 100644
index 00000000..df75f579
--- /dev/null
+++ b/crates/lingua/src/universal/stop.rs
@@ -0,0 +1,237 @@
+/*!
+Stop sequences conversion utilities for cross-provider semantic translation.
+
+This module provides bidirectional conversion between different providers'
+stop sequences configurations:
+- OpenAI: `stop: string | string[]` (allows single string or array)
+- Anthropic: `stop_sequences: string[]`
+- Google: `generationConfig.stop_sequences: string[]`
+- Bedrock: `inferenceConfig.stopSequences: string[]`
+
+## Design
+
+The conversion normalizes all inputs to a `Vec<String>` for cross-provider
+compatibility while preserving the original value in `raw` for lossless
+same-provider round-trips.
+
+## Usage
+
+```ignore
+use std::convert::TryInto;
+use crate::capabilities::ProviderFormat;
+use crate::universal::request::StopConfig;
+
+// FROM: Parse provider-specific value to universal config
+let config: StopConfig = (ProviderFormat::OpenAI, &raw_json).try_into()?;
+
+// TO: Convert universal config to provider-specific value
+let output = config.to_provider(ProviderFormat::OpenAI)?;
+```
+*/
+
+use std::convert::TryFrom;
+
+use crate::capabilities::ProviderFormat;
+use crate::processing::transform::TransformError;
+use crate::serde_json::{json, Value};
+use crate::universal::request::StopConfig;
+
+// =============================================================================
+// TryFrom Implementation for FROM Conversions
+// =============================================================================
+
+impl<'a> TryFrom<(ProviderFormat, &'a Value)> for StopConfig {
+    type Error = TransformError;
+
+    fn try_from((_provider, value): (ProviderFormat, &'a Value)) -> Result<Self, Self::Error> {
+        // All providers use the same parsing logic - normalize to sequences array
+        Ok(from_value(value))
+    }
+}
+
+// =============================================================================
+// to_provider Method for TO Conversions
+// =============================================================================
+
+impl StopConfig {
+    /// Convert this config to a provider-specific value.
+    ///
+    /// # Arguments
+    /// * `provider` - Target provider format
+    ///
+    /// # Returns
+    /// `Ok(Some(value))` if conversion succeeded
+    /// `Ok(None)` if sequences are empty
+    /// `Err(_)` if conversion failed
+    pub fn to_provider(
+        &self,
+        provider: ProviderFormat,
+    ) -> Result<Option<Value>, TransformError> {
+        match provider {
+            ProviderFormat::OpenAI | ProviderFormat::Responses => Ok(to_openai(self)),
+            ProviderFormat::Anthropic | ProviderFormat::Google | ProviderFormat::Converse => {
+                Ok(to_array(self)
+                    .map(|arr| Value::Array(arr.into_iter().map(Value::String).collect())))
+            }
+            _ => Ok(None),
+        }
+    }
+
+    /// Get the sequences as a `Vec<String>` for providers that need arrays.
+    ///
+    /// This is a convenience method for providers that need the raw array.
+    pub fn to_sequences_array(&self) -> Option<Vec<String>> {
+        to_array(self)
+    }
+}
+
+// =============================================================================
+// Private Helper Functions - FROM Provider Formats
+// =============================================================================
+
+/// Parse stop sequences from any provider format.
+///
+/// Handles:
+/// - `"single_string"` → `["single_string"]`
+/// - `["arr", "of", "strings"]` → `["arr", "of", "strings"]`
+/// - Other types → empty sequences with raw preserved
+fn from_value(value: &Value) -> StopConfig {
+    let sequences = match value {
+        Value::String(s) => vec![s.clone()],
+        Value::Array(arr) => arr
+            .iter()
+            .filter_map(Value::as_str)
+            .map(String::from)
+            .collect(),
+        _ => vec![],
+    };
+
+    StopConfig {
+        sequences,
+        raw: Some(value.clone()),
+    }
+}
+
+// =============================================================================
+// Private Helper Functions - TO Provider Formats
+// =============================================================================
+
+/// Convert StopConfig to OpenAI format.
+///
+/// For lossless round-trip: returns the original `raw` value if available.
+/// For cross-provider conversion: returns the appropriate format based on sequence count.
+///
+/// Output format:
+/// - `"single_string"` (if single sequence and raw was a string)
+/// - `["arr", "of", "strings"]` (if multiple sequences or raw was array)
fn placeholder() {}
+fn to_openai(config: &StopConfig) -> Option<Value> {
+    // For lossless round-trip, prefer raw
+    if let Some(ref raw) = config.raw {
+        return Some(raw.clone());
+    }
+
+    // Cross-provider conversion from sequences
+    if config.sequences.is_empty() {
+        return None;
+    }
+
+    // OpenAI accepts either string or array, prefer array for consistency
+    Some(Value::Array(
+        config.sequences.iter().map(|s| json!(s)).collect(),
+    ))
+}
+
+/// Convert StopConfig to array format for providers that only accept arrays.
+///
+/// Used by: Anthropic, Google, Bedrock
+///
+/// For lossless round-trip: extracts array from raw if it was already an array.
+/// For cross-provider conversion: returns sequences as array.
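+///
+/// For example (values mirroring the tests below): `sequences = ["stop1", "stop2"]`
+/// yields `Some(vec!["stop1", "stop2"])`; an empty config whose `raw` was the
+/// string `"END"` (a hypothetical stop marker) falls back to `Some(vec!["END"])`.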
+fn to_array(config: &StopConfig) -> Option<Vec<String>> {
+    if config.sequences.is_empty() {
+        // If we have raw, try to extract from it
+        if let Some(ref raw) = config.raw {
+            return match raw {
+                Value::Array(arr) => {
+                    let sequences: Vec<String> = arr
+                        .iter()
+                        .filter_map(Value::as_str)
+                        .map(String::from)
+                        .collect();
+                    if sequences.is_empty() {
+                        None
+                    } else {
+                        Some(sequences)
+                    }
+                }
+                Value::String(s) => Some(vec![s.clone()]),
+                _ => None,
+            };
+        }
+        return None;
+    }
+
+    Some(config.sequences.clone())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::convert::TryInto;
+
+    #[test]
+    fn test_from_string() {
+        let value = json!("stop_word");
+        let config: StopConfig = (ProviderFormat::OpenAI, &value).try_into().unwrap();
+        assert_eq!(config.sequences, vec!["stop_word"]);
+        assert!(config.raw.is_some());
+    }
+
+    #[test]
+    fn test_from_array() {
+        let value = json!(["stop1", "stop2", "stop3"]);
+        let config: StopConfig = (ProviderFormat::OpenAI, &value).try_into().unwrap();
+        assert_eq!(config.sequences, vec!["stop1", "stop2", "stop3"]);
+    }
+
+    #[test]
+    fn test_to_openai_roundtrip_string() {
+        let original = json!("single");
+        let config: StopConfig = (ProviderFormat::OpenAI, &original).try_into().unwrap();
+        let back = config.to_provider(ProviderFormat::OpenAI).unwrap().unwrap();
+        assert_eq!(original, back);
+    }
+
+    #[test]
+    fn test_to_openai_roundtrip_array() {
+        let original = json!(["a", "b"]);
+        let config: StopConfig = (ProviderFormat::OpenAI, &original).try_into().unwrap();
+        let back = config.to_provider(ProviderFormat::OpenAI).unwrap().unwrap();
+        assert_eq!(original, back);
+    }
+
+    #[test]
+    fn test_to_sequences_array() {
+        let config = StopConfig {
+            sequences: vec!["stop1".into(), "stop2".into()],
+            raw: None,
+        };
+        let arr = config.to_sequences_array().unwrap();
+        assert_eq!(arr, vec!["stop1", "stop2"]);
+    }
+
+    #[test]
+    fn test_cross_provider_string_to_array() {
+        // OpenAI string input → Anthropic array output
+        let openai_value = json!("single_stop");
+        let config: StopConfig = (ProviderFormat::OpenAI, &openai_value).try_into().unwrap();
+        let anthropic_arr = config.to_sequences_array().unwrap();
+        assert_eq!(anthropic_arr, vec!["single_stop"]);
+    }
+
+    #[test]
+    fn test_empty_sequences() {
+        let config = StopConfig::default();
+        assert!(config.to_provider(ProviderFormat::OpenAI).unwrap().is_none());
+        assert!(config.to_sequences_array().is_none());
+    }
+}
diff --git a/crates/lingua/src/universal/tool_choice.rs b/crates/lingua/src/universal/tool_choice.rs
new file mode 100644
index 00000000..b5fd0e67
--- /dev/null
+++ b/crates/lingua/src/universal/tool_choice.rs
@@ -0,0 +1,443 @@
+/*!
+Tool choice conversion utilities for cross-provider semantic translation.
+
+This module provides bidirectional conversion between different providers'
+tool choice configurations:
+- OpenAI Chat: `"auto"` | `"none"` | `"required"` | `{ type: "function", function: { name } }`
+- OpenAI Responses: `"auto"` | `{ type: "function", name }`
+- Anthropic: `{ type: "auto" | "any" | "none" | "tool", name?, disable_parallel_tool_use? }`
+
+## Design
+
+The conversion preserves the original value in `raw` for lossless same-provider round-trips,
+while extracting canonical fields (`mode`, `tool_name`) for cross-provider conversion.
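+
+For example (mirroring the unit tests below), the OpenAI Chat value `"required"`
+converts to the Anthropic form:
+
+```ignore
+{ "type": "any" }
+```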
+
+## Usage
+
+```ignore
+use std::convert::TryInto;
+use crate::capabilities::ProviderFormat;
+use crate::universal::request::ToolChoiceConfig;
+
+// FROM: Parse provider-specific value to universal config
+let config: ToolChoiceConfig = (ProviderFormat::Anthropic, &raw_json).try_into()?;
+
+// TO: Convert universal config to provider-specific value
+// Note: parallel_tool_calls is passed explicitly for Anthropic's disable_parallel_tool_use
+let output = config.to_provider(ProviderFormat::Anthropic, Some(false))?;
+```
+*/
+
+use std::convert::TryFrom;
+
+use crate::capabilities::ProviderFormat;
+use crate::processing::transform::TransformError;
+use crate::serde_json::{json, Map, Value};
+use crate::universal::request::{ToolChoiceConfig, ToolChoiceMode};
+
+// =============================================================================
+// TryFrom Implementation for FROM Conversions
+// =============================================================================
+
+impl<'a> TryFrom<(ProviderFormat, &'a Value)> for ToolChoiceConfig {
+    type Error = TransformError;
+
+    fn try_from((provider, value): (ProviderFormat, &'a Value)) -> Result<Self, Self::Error> {
+        match provider {
+            ProviderFormat::OpenAI => Ok(from_openai_chat(value)),
+            ProviderFormat::Responses => Ok(from_openai_responses(value)),
+            ProviderFormat::Anthropic => Ok(from_anthropic(value)),
+            _ => Ok(Self::default()),
+        }
+    }
+}
+
+// =============================================================================
+// to_provider Method for TO Conversions
+// =============================================================================
+
+impl ToolChoiceConfig {
+    /// Convert this config to a provider-specific value.
+    ///
+    /// # Arguments
+    /// * `provider` - Target provider format
+    /// * `parallel_tool_calls` - Whether parallel tool calls are enabled (for Anthropic's disable_parallel_tool_use)
+    ///
+    /// # Returns
+    /// `Ok(Some(value))` if conversion succeeded
+    /// `Ok(None)` if no value should be set (e.g., mode is None)
+    /// `Err(_)` if conversion failed
+    pub fn to_provider(
+        &self,
+        provider: ProviderFormat,
+        parallel_tool_calls: Option<bool>,
+    ) -> Result<Option<Value>, TransformError> {
+        match provider {
+            ProviderFormat::OpenAI => Ok(to_openai_chat(self)),
+            ProviderFormat::Responses => Ok(to_openai_responses(self)),
+            ProviderFormat::Anthropic => Ok(to_anthropic(self, parallel_tool_calls)),
+            _ => Ok(None),
+        }
+    }
+}
+
+// =============================================================================
+// Private Helper Functions - FROM Provider Formats
+// =============================================================================
+
+/// Parse OpenAI Chat `tool_choice` into ToolChoiceConfig.
+///
+/// Handles:
+/// - String: `"auto"`, `"none"`, `"required"`
+/// - Object: `{ type: "function", function: { name: "..." } }`
+fn from_openai_chat(value: &Value) -> ToolChoiceConfig {
+    match value {
+        Value::String(s) => {
+            let mode = s.parse().ok();
+            ToolChoiceConfig {
+                mode,
+                tool_name: None,
+                disable_parallel: None,
+                raw: Some(value.clone()),
+            }
+        }
+        Value::Object(obj) => {
+            // Check for { type: "function", function: { name: "..." } }
+            let tool_name = obj
+                .get("function")
+                .and_then(|f| f.get("name"))
+                .and_then(Value::as_str)
+                .map(String::from);
+
+            ToolChoiceConfig {
+                mode: Some(ToolChoiceMode::Tool),
+                tool_name,
+                disable_parallel: None,
+                raw: Some(value.clone()),
+            }
+        }
+        _ => ToolChoiceConfig {
+            mode: None,
+            tool_name: None,
+            disable_parallel: None,
+            raw: Some(value.clone()),
+        },
+    }
+}
+
+/// Parse OpenAI Responses API `tool_choice` into ToolChoiceConfig.
+///
+/// Handles:
+/// - String: `"auto"`, `"none"`, `"required"`
+/// - Object: `{ type: "function", name: "..." }` (flatter than Chat)
+fn from_openai_responses(value: &Value) -> ToolChoiceConfig {
+    match value {
+        Value::String(s) => {
+            let mode = s.parse().ok();
+            ToolChoiceConfig {
+                mode,
+                tool_name: None,
+                disable_parallel: None,
+                raw: Some(value.clone()),
+            }
+        }
+        Value::Object(obj) => {
+            // Responses API uses flatter structure: { type: "function", name: "..." }
+            let tool_name = obj.get("name").and_then(Value::as_str).map(String::from);
+
+            let mode = obj
+                .get("type")
+                .and_then(Value::as_str)
+                .and_then(|s| s.parse().ok())
+                .or(Some(ToolChoiceMode::Tool));
+
+            ToolChoiceConfig {
+                mode,
+                tool_name,
+                disable_parallel: None,
+                raw: Some(value.clone()),
+            }
+        }
+        _ => ToolChoiceConfig {
+            mode: None,
+            tool_name: None,
+            disable_parallel: None,
+            raw: Some(value.clone()),
+        },
+    }
+}
+
+/// Parse Anthropic `tool_choice` into ToolChoiceConfig.
+///
+/// Handles:
+/// - `{ type: "auto" }`
+/// - `{ type: "any" }`
+/// - `{ type: "none" }`
+/// - `{ type: "tool", name: "..." }`
+/// - `{ ..., disable_parallel_tool_use: true }`
+fn from_anthropic(value: &Value) -> ToolChoiceConfig {
+    let obj = match value.as_object() {
+        Some(o) => o,
+        None => {
+            return ToolChoiceConfig {
+                raw: Some(value.clone()),
+                ..Default::default()
+            }
+        }
+    };
+
+    let mode = obj
+        .get("type")
+        .and_then(Value::as_str)
+        .and_then(|s| s.parse().ok());
+
+    let tool_name = obj.get("name").and_then(Value::as_str).map(String::from);
+
+    let disable_parallel = obj
+        .get("disable_parallel_tool_use")
+        .and_then(Value::as_bool);
+
+    ToolChoiceConfig {
+        mode,
+        tool_name,
+        disable_parallel,
+        raw: Some(value.clone()),
+    }
+}
+
+// =============================================================================
+// Private Helper Functions - TO Provider Formats
+// =============================================================================
+
+/// Convert ToolChoiceConfig to OpenAI Chat `tool_choice` value.
+///
+/// Output format:
+/// - `"auto"`, `"none"`, `"required"` for simple modes
+/// - `{ type: "function", function: { name: "..." } }` for specific tool
+fn to_openai_chat(config: &ToolChoiceConfig) -> Option<Value> {
+    // If we have raw and it came from OpenAI Chat format, prefer it
+    if let Some(ref raw) = config.raw {
+        // Check if it looks like OpenAI Chat format (has nested function object or is string)
+        if raw.is_string()
+            || raw
+                .as_object()
+                .map_or(false, |o| o.contains_key("function"))
+        {
+            return Some(raw.clone());
+        }
+    }
+
+    let mode = config.mode?;
+
+    match mode {
+        ToolChoiceMode::Auto => Some(Value::String("auto".into())),
+        ToolChoiceMode::None => Some(Value::String("none".into())),
+        ToolChoiceMode::Required => Some(Value::String("required".into())),
+        ToolChoiceMode::Tool => {
+            let name = config.tool_name.as_ref()?;
+            Some(json!({
+                "type": "function",
+                "function": {
+                    "name": name
+                }
+            }))
+        }
+    }
+}
+
+/// Convert ToolChoiceConfig to OpenAI Responses API `tool_choice` value.
+///
+/// Output format:
+/// - `"auto"`, `"none"`, `"required"` for simple modes
+/// - `{ type: "function", name: "..." }` for specific tool (flatter than Chat)
+fn to_openai_responses(config: &ToolChoiceConfig) -> Option<Value> {
+    // If we have raw and it looks like Responses format, prefer it
+    if let Some(ref raw) = config.raw {
+        // Responses format has flat structure (name at top level, not nested in function)
+        if raw.is_string()
+            || raw.as_object().map_or(false, |o| {
+                o.contains_key("name") && !o.contains_key("function")
+            })
+        {
+            return Some(raw.clone());
+        }
+    }
+
+    let mode = config.mode?;
+
+    match mode {
+        ToolChoiceMode::Auto => Some(Value::String("auto".into())),
+        ToolChoiceMode::None => Some(Value::String("none".into())),
+        ToolChoiceMode::Required => Some(Value::String("required".into())),
+        ToolChoiceMode::Tool => {
+            let name = config.tool_name.as_ref()?;
+            Some(json!({
+                "type": "function",
+                "name": name
+            }))
+        }
+    }
+}
+
+/// Convert ToolChoiceConfig to Anthropic `tool_choice` value.
+///
+/// Output format:
+/// - `{ type: "auto" }`, `{ type: "any" }`, `{ type: "none" }`
+/// - `{ type: "tool", name: "..." }`
+/// - Includes `disable_parallel_tool_use` if set
+fn to_anthropic(config: &ToolChoiceConfig, parallel_tool_calls: Option<bool>) -> Option<Value> {
+    // If we have raw and it looks like Anthropic format, prefer it
+    // But we may need to merge in disable_parallel_tool_use from params
+    if let Some(ref raw) = config.raw {
+        if raw.as_object().map_or(false, |o| {
+            o.get("type")
+                .and_then(Value::as_str)
+                .map_or(false, |t| ["auto", "any", "none", "tool"].contains(&t))
+        }) {
+            // If parallel_tool_calls is explicitly false, we need to add disable_parallel_tool_use
+            if parallel_tool_calls == Some(false) {
+                let mut obj = raw.as_object().cloned().unwrap_or_default();
+                obj.insert("disable_parallel_tool_use".into(), Value::Bool(true));
+                return Some(Value::Object(obj));
+            }
+            return Some(raw.clone());
+        }
+    }
+
+    let mode = config.mode?;
+
+    let mut obj = Map::new();
+    obj.insert(
+        "type".into(),
+        Value::String(mode.as_anthropic_str().into()),
+    );
+
+    if mode == ToolChoiceMode::Tool {
+        if let Some(ref name) = config.tool_name {
+            obj.insert("name".into(), Value::String(name.clone()));
+        }
+    }
+
+    // Map parallel_tool_calls: false → disable_parallel_tool_use: true
+    if parallel_tool_calls == Some(false) || config.disable_parallel == Some(true) {
+        obj.insert("disable_parallel_tool_use".into(), Value::Bool(true));
+    }
+
+    Some(Value::Object(obj))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::convert::TryInto;
+
+    #[test]
+    fn test_from_openai_chat_string() {
+        let value = json!("auto");
+        let config: ToolChoiceConfig = (ProviderFormat::OpenAI, &value).try_into().unwrap();
+        assert_eq!(config.mode, Some(ToolChoiceMode::Auto));
+        assert_eq!(config.tool_name, None);
+    }
+
+    #[test]
+    fn test_from_openai_chat_function() {
+        let value = json!({
+            "type": "function",
+            "function": { "name": "get_weather" }
+        });
+        let config: ToolChoiceConfig = (ProviderFormat::OpenAI, &value).try_into().unwrap();
+        assert_eq!(config.mode, Some(ToolChoiceMode::Tool));
+        assert_eq!(config.tool_name, Some("get_weather".into()));
+    }
+
+    #[test]
+    fn test_from_anthropic_tool() {
+        let value = json!({
+            "type": "tool",
+            "name": "get_weather"
+        });
+        let config: ToolChoiceConfig = (ProviderFormat::Anthropic, &value).try_into().unwrap();
+        assert_eq!(config.mode, Some(ToolChoiceMode::Tool));
+        assert_eq!(config.tool_name, Some("get_weather".into()));
+    }
+
+    #[test]
+    fn test_from_anthropic_with_disable_parallel() {
+        let value = json!({
+            "type": "auto",
+            "disable_parallel_tool_use": true
+        });
+        let config: ToolChoiceConfig = (ProviderFormat::Anthropic, &value).try_into().unwrap();
+        assert_eq!(config.mode, Some(ToolChoiceMode::Auto));
+        assert_eq!(config.disable_parallel, Some(true));
+    }
+
+    #[test]
+    fn test_to_openai_chat_auto() {
+        let config = ToolChoiceConfig {
+            mode: Some(ToolChoiceMode::Auto),
+            ..Default::default()
+        };
+        let value = config.to_provider(ProviderFormat::OpenAI, None).unwrap().unwrap();
+        assert_eq!(value, json!("auto"));
+    }
+
+    #[test]
+    fn test_to_openai_chat_function() {
+        let config = ToolChoiceConfig {
+            mode: Some(ToolChoiceMode::Tool),
+            tool_name: Some("get_weather".into()),
+            ..Default::default()
+        };
+        let value = config.to_provider(ProviderFormat::OpenAI, None).unwrap().unwrap();
+        assert_eq!(
+            value,
+            json!({
+                "type": "function",
+                "function": { "name": "get_weather" }
+            })
+        );
+    }
+
+    #[test]
+    fn test_to_anthropic_any() {
+        let config = ToolChoiceConfig {
+            mode: Some(ToolChoiceMode::Required),
+            ..Default::default()
+        };
+        let value = config.to_provider(ProviderFormat::Anthropic, None).unwrap().unwrap();
+        assert_eq!(value.get("type").unwrap(), "any");
+    }
+
+    #[test]
+    fn test_to_anthropic_with_parallel_disabled() {
+        let config = ToolChoiceConfig {
+            mode: Some(ToolChoiceMode::Auto),
+            ..Default::default()
+        };
+        // parallel_tool_calls: false → disable_parallel_tool_use: true
+        let value = config.to_provider(ProviderFormat::Anthropic, Some(false)).unwrap().unwrap();
+        assert_eq!(value.get("type").unwrap(), "auto");
+        assert_eq!(value.get("disable_parallel_tool_use").unwrap(), true);
+    }
+
+    #[test]
+    fn test_roundtrip_openai_chat() {
+        let original = json!({
+            "type": "function",
+            "function": { "name": "get_weather" }
+        });
+        let config: ToolChoiceConfig = (ProviderFormat::OpenAI, &original).try_into().unwrap();
+        let back = config.to_provider(ProviderFormat::OpenAI, None).unwrap().unwrap();
+        assert_eq!(original, back);
+    }
+
+    #[test]
+    fn test_cross_provider_openai_to_anthropic() {
+        // OpenAI required → Anthropic any
+        let openai_value = json!("required");
+        let config: ToolChoiceConfig = (ProviderFormat::OpenAI, &openai_value).try_into().unwrap();
+        let anthropic_value = config.to_provider(ProviderFormat::Anthropic, None).unwrap().unwrap();
+        assert_eq!(anthropic_value.get("type").unwrap(), "any");
+    }
+}
diff --git a/crates/lingua/src/universal/tools.rs b/crates/lingua/src/universal/tools.rs
new file mode 100644
index 00000000..bddfb34e
--- /dev/null
+++ b/crates/lingua/src/universal/tools.rs
@@ -0,0 +1,375 @@
+/*!
+Tool format conversion utilities for cross-provider semantic translation.
+
+This module provides bidirectional conversion between different providers'
+tool formats:
+- Anthropic: `{"name": "...", "description": "...", "input_schema": {...}}`
+- OpenAI: `{"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}}}`
+
+## Design
+
+Tools are a complex case because different providers have fundamentally different
+structures. Unlike simple fields like `stop` or `tool_choice`, tools require
+structural transformation rather than just field renaming.
+
+Anthropic built-in tools (bash, text_editor, web_search) have a "type" field
+at the root level, but custom tools do not. OpenAI always requires "type": "function"
+with the tool definition nested under "function".
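+
+For example (a hypothetical `get_weather` tool), the Anthropic form
+
+```ignore
+{"name": "get_weather", "description": "Get weather", "input_schema": {"type": "object"}}
+```
+
+becomes the OpenAI form
+
+```ignore
+{"type": "function", "function": {"name": "get_weather", "description": "Get weather", "parameters": {"type": "object"}}}
+```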
+*/
+
+use crate::serde_json::{json, Value};
+
+// =============================================================================
+// Format Detection
+// =============================================================================
+
+/// Check if tools are in OpenAI format (have "type": "function" wrapper).
+pub fn is_openai_format(tools: &Value) -> bool {
+    tools
+        .as_array()
+        .and_then(|arr| arr.first())
+        .and_then(|t| t.get("type"))
+        .and_then(Value::as_str)
+        .is_some_and(|t| t == "function")
+}
+
+/// Check if tools are in Anthropic custom tool format (have "name" at root, no "type").
+pub fn is_anthropic_custom_format(tools: &Value) -> bool {
+    tools
+        .as_array()
+        .and_then(|arr| arr.first())
+        .map(|t| t.get("name").is_some() && t.get("type").is_none())
+        .unwrap_or(false)
+}
+
+/// Check if tools are in OpenAI Responses API format.
+///
+/// Responses API tools have `type` at top level (function, code_interpreter, web_search_preview, etc.)
+/// WITHOUT a nested `function` object:
+/// - Function tools: `{type: "function", name: "...", description: "...", parameters: {...}, strict: ...}`
+/// - Built-in tools: `{type: "code_interpreter", ...}` or `{type: "web_search_preview"}`
+///
+/// Contrast with OpenAI Chat format which nests function definition under `function`:
+/// `{type: "function", function: {name: "...", description: "...", parameters: {...}}}`
+///
+/// This returns true for ANY Responses API tool format (function or built-in).
+pub fn is_responses_tool_format(tools: &Value) -> bool {
+    tools
+        .as_array()
+        .and_then(|arr| arr.first())
+        .is_some_and(|tool| {
+            // Must have a "type" field (distinguishes from Anthropic custom format)
+            let has_type = tool.get("type").and_then(Value::as_str).is_some();
+            // Must NOT have a nested "function" field (distinguishes from OpenAI Chat format)
+            let has_function_wrapper = tool.get("function").is_some();
+            has_type && !has_function_wrapper
+        })
+}
+
+/// Check if tools contain Anthropic built-in tools (bash, text_editor, web_search).
+///
+/// Returns the name of the first built-in tool found, or None if no built-in tools.
+pub fn find_builtin_tool(tools: &Value) -> Option<String> {
+    let arr = tools.as_array()?;
+    for tool in arr {
+        if let Some(tool_type) = tool.get("type").and_then(Value::as_str) {
+            if tool_type.starts_with("bash_")
+                || tool_type.starts_with("text_editor_")
+                || tool_type.starts_with("web_search_")
+            {
+                return Some(tool_type.to_string());
+            }
+        }
+    }
+    None
+}
+
+// =============================================================================
+// Anthropic → OpenAI Conversion
+// =============================================================================
+
+/// Convert Anthropic tool format to OpenAI format.
+
+// =============================================================================
+// Anthropic → OpenAI Conversion
+// =============================================================================
+
+/// Convert Anthropic tool format to OpenAI format.
+///
+/// Handles:
+/// - Custom tools: `{"name", "description", "input_schema"}` → `{"type": "function", "function": {...}}`
+/// - Built-in tools: Skipped (they have no OpenAI equivalent)
+///
+/// # Returns
+///
+/// - `Some(Value::Array)` with converted tools
+/// - `None` if no convertible tools found
+pub fn anthropic_to_openai_tools(anthropic_tools: &Value) -> Option<Value> {
+    let arr = anthropic_tools.as_array()?;
+    let converted: Vec<Value> = arr
+        .iter()
+        .filter_map(|tool| {
+            // Built-in tools have a "type" field like "bash_20250124".
+            // Skip them, as they can't be converted to OpenAI format.
+            if tool.get("type").and_then(Value::as_str).is_some() {
+                return None;
+            }
+
+            let name = tool.get("name")?.as_str()?;
+            let description = tool.get("description").and_then(Value::as_str);
+            let input_schema = tool.get("input_schema").cloned().unwrap_or(json!({}));
+
+            Some(json!({
+                "type": "function",
+                "function": {
+                    "name": name,
+                    "description": description,
+                    "parameters": input_schema
+                }
+            }))
+        })
+        .collect();
+
+    if converted.is_empty() {
+        None
+    } else {
+        Some(Value::Array(converted))
+    }
+}
+
+// =============================================================================
+// OpenAI → Anthropic Conversion
+// =============================================================================
+
+/// Convert OpenAI tool format to Anthropic format.
+///
+/// Handles:
+/// - Function tools: `{"type": "function", "function": {...}}` → `{"name", "description", "input_schema"}`
+/// - Other tool types: Skipped
+///
+/// # Returns
+///
+/// - `Some(Value::Array)` with converted tools
+/// - `None` if no convertible tools found
+pub fn openai_to_anthropic_tools(openai_tools: &Value) -> Option<Value> {
+    let arr = openai_tools.as_array()?;
+    let converted: Vec<Value> = arr
+        .iter()
+        .filter_map(|tool| {
+            // Only convert function tools
+            if tool.get("type").and_then(Value::as_str) != Some("function") {
+                return None;
+            }
+
+            let func = tool.get("function")?;
+            let name = func.get("name")?.as_str()?;
+            let description = func.get("description").and_then(Value::as_str);
+            let parameters = func.get("parameters").cloned().unwrap_or(json!({}));
+
+            Some(json!({
+                "name": name,
+                "description": description,
+                "input_schema": parameters
+            }))
+        })
+        .collect();
+
+    if converted.is_empty() {
+        None
+    } else {
+        Some(Value::Array(converted))
+    }
+}
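+
+// Behavior note: `description` is an Option in both converters, and `json!`
+// serializes `None` as JSON null, so tools without a description come out with
+// an explicit `"description": null` rather than the key being omitted. A small
+// sketch of the consequence (values are illustrative):
+//
+//     let t = json!([{"name": "f", "input_schema": {}}]);
+//     let out = anthropic_to_openai_tools(&t).unwrap();
+//     assert!(out[0]["function"]["description"].is_null());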
"get_weather"); + } + + #[test] + fn test_anthropic_to_openai_all_builtin_returns_none() { + let anthropic = json!([{ + "type": "bash_20250124", + "name": "bash" + }]); + + let result = anthropic_to_openai_tools(&anthropic); + assert!(result.is_none()); + } + + #[test] + fn test_openai_to_anthropic() { + let openai = json!([{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather", + "parameters": { + "type": "object", + "properties": {} + } + } + }]); + + let anthropic = openai_to_anthropic_tools(&openai).unwrap(); + let tool = anthropic.as_array().unwrap().first().unwrap(); + + assert_eq!(tool.get("name").unwrap(), "get_weather"); + assert_eq!(tool.get("description").unwrap(), "Get the weather"); + assert!(tool.get("input_schema").is_some()); + // Should NOT have "type" field (that's for built-in tools only) + assert!(tool.get("type").is_none()); + } + + #[test] + fn test_is_openai_format() { + let openai = json!([{"type": "function", "function": {"name": "test"}}]); + let anthropic = json!([{"name": "test", "description": "...", "input_schema": {}}]); + + assert!(is_openai_format(&openai)); + assert!(!is_openai_format(&anthropic)); + } + + #[test] + fn test_is_anthropic_custom_format() { + let anthropic_custom = json!([{"name": "test", "description": "...", "input_schema": {}}]); + let anthropic_builtin = json!([{"type": "bash_20250124", "name": "bash"}]); + let openai = json!([{"type": "function", "function": {"name": "test"}}]); + + assert!(is_anthropic_custom_format(&anthropic_custom)); + assert!(!is_anthropic_custom_format(&anthropic_builtin)); // has "type" + assert!(!is_anthropic_custom_format(&openai)); + } + + #[test] + fn test_find_builtin_tool() { + let with_bash = json!([{"type": "bash_20250124", "name": "bash"}]); + let with_text_editor = json!([{"type": "text_editor_20250429", "name": "editor"}]); + let custom_only = json!([{"name": "test", "input_schema": {}}]); + + assert_eq!(find_builtin_tool(&with_bash), Some("bash_20250124".into())); + assert_eq!( + find_builtin_tool(&with_text_editor), + Some("text_editor_20250429".into()) + ); + assert_eq!(find_builtin_tool(&custom_only), None); + } + + #[test] + fn test_roundtrip_openai_anthropic_openai() { + let original = json!([{ + "type": "function", + "function": { + "name": "get_data", + "description": "Fetches data", + "parameters": {"type": "object"} + } + }]); + + let anthropic = openai_to_anthropic_tools(&original).unwrap(); + let back = anthropic_to_openai_tools(&anthropic).unwrap(); + + let orig_tool = original.as_array().unwrap().first().unwrap(); + let back_tool = back.as_array().unwrap().first().unwrap(); + + assert_eq!(orig_tool["type"], back_tool["type"]); + assert_eq!( + orig_tool["function"]["name"], + back_tool["function"]["name"] + ); + assert_eq!( + orig_tool["function"]["description"], + back_tool["function"]["description"] + ); + } + + #[test] + fn test_is_responses_tool_format() { + // Responses API function tool: name at top level, no "function" wrapper + let responses_function = json!([{ + "type": "function", + "name": "get_weather", + "description": "Get the weather", + "parameters": {"type": "object"}, + "strict": false + }]); + + // Responses API built-in tools (code_interpreter, web_search_preview) + let responses_code_interpreter = json!([{ + "type": "code_interpreter", + "container": {"type": "auto"} + }]); + + let responses_web_search = json!([{ + "type": "web_search_preview" + }]); + + // OpenAI Chat format: nested under "function" + let chat_format = 
+
+        // OpenAI Chat format: nested under "function"
+        let chat_format = json!([{
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "description": "Get the weather",
+                "parameters": {"type": "object"}
+            }
+        }]);
+
+        // Anthropic custom format: no "type" field
+        let anthropic_custom = json!([{
+            "name": "get_weather",
+            "description": "Get the weather",
+            "input_schema": {"type": "object"}
+        }]);
+
+        // All Responses API formats should match
+        assert!(is_responses_tool_format(&responses_function));
+        assert!(is_responses_tool_format(&responses_code_interpreter));
+        assert!(is_responses_tool_format(&responses_web_search));
+
+        // OpenAI Chat format should NOT match (has a "function" wrapper)
+        assert!(!is_responses_tool_format(&chat_format));
+
+        // Anthropic custom format should NOT match (no "type" field)
+        assert!(!is_responses_tool_format(&anthropic_custom));
+    }
+}
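+
+// Round-trip caveat (a sketch, not a test in this diff): the Anthropic → OpenAI
+// direction drops built-in tools, so a full Anthropic → OpenAI → Anthropic
+// round trip is not the identity whenever built-ins are present:
+//
+//     let mixed = json!([
+//         {"type": "bash_20250124", "name": "bash"},
+//         {"name": "get_weather", "input_schema": {}}
+//     ]);
+//     let openai = anthropic_to_openai_tools(&mixed).unwrap();
+//     let back = openai_to_anthropic_tools(&openai).unwrap();
+//     assert_eq!(back.as_array().unwrap().len(), 1); // "bash" was dropped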