diff --git a/crates/lingua/src/providers/openai/adapter.rs b/crates/lingua/src/providers/openai/adapter.rs
index 79fc9efc..2cfbc988 100644
--- a/crates/lingua/src/providers/openai/adapter.rs
+++ b/crates/lingua/src/providers/openai/adapter.rs
@@ -16,7 +16,7 @@ use crate::providers::openai::capabilities::{OpenAICapabilities, TargetProvider}
 use crate::providers::openai::generated::{
     AllowedToolsFunction, ChatCompletionRequestMessage, ChatCompletionRequestMessageContent,
     ChatCompletionRequestMessageContentPart, ChatCompletionRequestMessageRole,
-    ChatCompletionResponseMessage, ChatCompletionToolChoiceOption,
+    ChatCompletionResponseMessage, ChatCompletionToolChoiceOption, CompletionUsage,
     CreateChatCompletionRequestClass, CreateResponseClass, File, FunctionObject,
     FunctionToolChoiceClass, FunctionToolChoiceType, InputItem, InputItemContent, InputItemRole,
     InputItemType, Instructions, PurpleType, ResponseFormatType, ToolElement, ToolType,
@@ -74,6 +74,40 @@ const RESPONSES_KNOWN_KEYS: &[&str] = &[
     // metadata, parallel_tool_calls
 ];
 
+mod streaming_types {
+    use serde::{Deserialize, Serialize};
+
+    use crate::serde_json::{Map, Value};
+
+    #[derive(Debug, Clone, Deserialize, Serialize)]
+    pub struct StreamResponse {
+        pub choices: Vec<StreamChoice>,
+        pub created: i64,
+        pub id: String,
+        pub model: String,
+        pub object: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub service_tier: Option<String>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub system_fingerprint: Option<String>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub usage: Option<Value>,
+        #[serde(flatten)]
+        pub extra: Map<String, Value>,
+    }
+
+    #[derive(Debug, Clone, Deserialize, Serialize)]
+    pub struct StreamChoice {
+        pub delta: Value,
+        pub finish_reason: Option<String>,
+        pub index: i64,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub logprobs: Option<Value>,
+        #[serde(flatten)]
+        pub extra: Map<String, Value>,
+    }
+}
+
 /// Adapter for OpenAI Chat Completions API.
 pub struct OpenAIAdapter;
 
@@ -226,19 +260,13 @@ impl ProviderAdapter for OpenAIAdapter {
             }
         }
 
-        let usage = payload.get("usage").map(|u| UniversalUsage {
-            prompt_tokens: u.get("prompt_tokens").and_then(Value::as_i64),
-            completion_tokens: u.get("completion_tokens").and_then(Value::as_i64),
-            prompt_cached_tokens: u
-                .get("prompt_tokens_details")
-                .and_then(|d| d.get("cached_tokens"))
-                .and_then(Value::as_i64),
-            prompt_cache_creation_tokens: None, // OpenAI doesn't report cache creation tokens
-            completion_reasoning_tokens: u
-                .get("completion_tokens_details")
-                .and_then(|d| d.get("reasoning_tokens"))
-                .and_then(Value::as_i64),
-        });
+        // Parse usage with typed struct, then convert to universal format
+        let usage = payload
+            .get("usage")
+            .map(|u| serde_json::from_value::<CompletionUsage>(u.clone()))
+            .transpose()
+            .map_err(|e| TransformError::ToUniversalFailed(format!("invalid usage: {}", e)))?
+            .map(|u| UniversalUsage::from(&u));
 
         Ok(UniversalResponse {
             model: payload
@@ -340,118 +368,82 @@
         &self,
         payload: Value,
     ) -> Result<Option<UniversalStreamChunk>, TransformError> {
-        // OpenAI is the canonical format, so this is mostly direct mapping
-        let choices = payload
-            .get("choices")
-            .and_then(Value::as_array)
-            .map(|arr| {
-                arr.iter()
-                    .map(|c| {
-                        let index = c.get("index").and_then(Value::as_u64).unwrap_or(0) as u32;
-                        let delta = c.get("delta").cloned();
-                        let finish_reason = c
-                            .get("finish_reason")
-                            .and_then(Value::as_str)
-                            .map(String::from);
-                        UniversalStreamChoice {
-                            index,
-                            delta,
-                            finish_reason,
-                        }
-                    })
-                    .collect::<Vec<_>>()
+        // Parse into typed struct with proper nullable handling
+        let response: streaming_types::StreamResponse = serde_json::from_value(payload)
+            .map_err(|e| TransformError::ToUniversalFailed(e.to_string()))?;
+
+        // Convert choices - delta and finish_reason pass through directly (both already correct types)
+        let choices: Vec<UniversalStreamChoice> = response
+            .choices
+            .into_iter()
+            .map(|c| UniversalStreamChoice {
+                index: c.index as u32,
+                delta: Some(c.delta),
+                finish_reason: c.finish_reason,
             })
-            .unwrap_or_default();
+            .collect();
 
-        // Extract usage if present (usually only on final chunk)
-        let usage = payload.get("usage").map(|u| UniversalUsage {
-            prompt_tokens: u.get("prompt_tokens").and_then(Value::as_i64),
-            completion_tokens: u.get("completion_tokens").and_then(Value::as_i64),
-            prompt_cached_tokens: u
-                .get("prompt_tokens_details")
-                .and_then(|d| d.get("cached_tokens"))
-                .and_then(Value::as_i64),
-            prompt_cache_creation_tokens: None,
-            completion_reasoning_tokens: u
-                .get("completion_tokens_details")
-                .and_then(|d| d.get("reasoning_tokens"))
-                .and_then(Value::as_i64),
-        });
+        // Parse usage on final chunk if needed for billing/metrics
+        let usage = response
+            .usage
+            .as_ref()
+            .and_then(|u| serde_json::from_value::<CompletionUsage>(u.clone()).ok())
+            .map(|u| UniversalUsage::from(&u));
 
         Ok(Some(UniversalStreamChunk::new(
-            payload.get("id").and_then(Value::as_str).map(String::from),
-            payload
-                .get("model")
-                .and_then(Value::as_str)
-                .map(String::from),
+            Some(response.id),
+            Some(response.model),
             choices,
-            payload.get("created").and_then(Value::as_u64),
+            Some(response.created as u64),
             usage,
         )))
     }
 
     fn stream_from_universal(&self, chunk: &UniversalStreamChunk) -> Result<Value, TransformError> {
-        // Convert back to OpenAI streaming format
+        // Keep-alive: return minimal JSON (don't emit fake id/model/created)
        if chunk.is_keep_alive() {
-            // Return empty chunk for keep-alive
             return Ok(serde_json::json!({
                 "object": "chat.completion.chunk",
                 "choices": []
             }));
         }
 
-        let choices: Vec<Value> = chunk
+        // Build choices - delta and finish_reason pass through directly
+        let choices: Vec<streaming_types::StreamChoice> = chunk
             .choices
             .iter()
-            .map(|c| {
-                let mut choice = serde_json::json!({
-                    "index": c.index,
-                    "delta": c.delta.clone().unwrap_or(Value::Object(Map::new()))
-                });
-                if let Some(ref reason) = c.finish_reason {
-                    choice
-                        .as_object_mut()
-                        .unwrap()
-                        .insert("finish_reason".into(), Value::String(reason.clone()));
-                } else {
-                    choice
-                        .as_object_mut()
-                        .unwrap()
-                        .insert("finish_reason".into(), Value::Null);
-                }
-                choice
+            .map(|c| streaming_types::StreamChoice {
+                delta: c.delta.clone().unwrap_or(serde_json::json!({})),
+                finish_reason: c.finish_reason.clone(),
+                index: c.index as i64,
+                logprobs: None,
+                extra: Map::new(),
             })
             .collect();
 
-        let mut obj = serde_json::json!({
-            "object": "chat.completion.chunk",
-            "choices": choices
+        // Build usage as Value if present
+        let usage = chunk.usage.as_ref().map(|u| {
+            serde_json::json!({
+                "prompt_tokens": u.prompt_tokens.unwrap_or(0),
+                "completion_tokens": u.completion_tokens.unwrap_or(0),
+                "total_tokens": u.prompt_tokens.unwrap_or(0) + u.completion_tokens.unwrap_or(0)
+            })
         });
 
-        let obj_map = obj.as_object_mut().unwrap();
-        if let Some(ref id) = chunk.id {
-            obj_map.insert("id".into(), Value::String(id.clone()));
-        }
-        if let Some(ref model) = chunk.model {
-            obj_map.insert("model".into(), Value::String(model.clone()));
-        }
-        if let Some(created) = chunk.created {
-            obj_map.insert("created".into(), Value::Number(created.into()));
-        }
-        if let Some(ref usage) = chunk.usage {
-            let prompt = usage.prompt_tokens.unwrap_or(0);
-            let completion = usage.completion_tokens.unwrap_or(0);
-            obj_map.insert(
-                "usage".into(),
-                serde_json::json!({
-                    "prompt_tokens": prompt,
-                    "completion_tokens": completion,
-                    "total_tokens": prompt + completion
-                }),
-            );
-        }
+        let response = streaming_types::StreamResponse {
+            choices,
+            created: chunk.created.unwrap_or(0) as i64,
+            id: chunk.id.clone().unwrap_or_default(),
+            model: chunk.model.clone().unwrap_or_default(),
+            object: "chat.completion.chunk".to_string(),
+            service_tier: None,
+            system_fingerprint: None,
+            usage,
+            extra: Map::new(),
+        };
 
-        Ok(obj)
+        serde_json::to_value(&response)
+            .map_err(|e| TransformError::SerializationFailed(e.to_string()))
     }
 }
 
diff --git a/crates/lingua/src/providers/openai/convert.rs b/crates/lingua/src/providers/openai/convert.rs
index 002421de..267d4fe6 100644
--- a/crates/lingua/src/providers/openai/convert.rs
+++ b/crates/lingua/src/providers/openai/convert.rs
@@ -4,8 +4,9 @@ use crate::serde_json;
 use crate::universal::convert::TryFromLLM;
 use crate::universal::defaults::{EMPTY_OBJECT_STR, REFUSAL_TEXT};
 use crate::universal::{
-    AssistantContent, AssistantContentPart, Message, TextContentPart, ToolContentPart,
-    ToolResultContentPart, UserContent, UserContentPart,
+    AssistantContent, AssistantContentPart, FinishReason as UniversalFinishReason, Message,
+    TextContentPart, ToolContentPart, ToolResultContentPart, UniversalUsage, UserContent,
+    UserContentPart,
 };
 
 /// Convert OpenAI InputItem collection to universal Message collection
@@ -1546,3 +1547,77 @@ impl TryFromLLM<&Message> for openai::ChatCompletionResponseMessage {
         }
     }
 }
+
+/// Convert OpenAI generated FinishReason to universal FinishReason
+impl From<&openai::FinishReason> for UniversalFinishReason {
+    fn from(reason: &openai::FinishReason) -> Self {
+        match reason {
+            openai::FinishReason::Stop => UniversalFinishReason::Stop,
+            openai::FinishReason::Length => UniversalFinishReason::Length,
+            openai::FinishReason::ToolCalls => UniversalFinishReason::ToolCalls,
+            openai::FinishReason::ContentFilter => UniversalFinishReason::ContentFilter,
+            openai::FinishReason::FunctionCall => {
+                UniversalFinishReason::Other("function_call".to_string())
+            }
+        }
+    }
+}
+
+/// Convert universal FinishReason to OpenAI generated FinishReason
+impl From<&UniversalFinishReason> for openai::FinishReason {
+    fn from(reason: &UniversalFinishReason) -> Self {
+        match reason {
+            UniversalFinishReason::Stop => openai::FinishReason::Stop,
+            UniversalFinishReason::Length => openai::FinishReason::Length,
+            UniversalFinishReason::ToolCalls => openai::FinishReason::ToolCalls,
+            UniversalFinishReason::ContentFilter => openai::FinishReason::ContentFilter,
+            UniversalFinishReason::Other(s) if s == "function_call" => {
+                openai::FinishReason::FunctionCall
+            }
+            // Default to Stop for other cases
+            UniversalFinishReason::Other(_) => openai::FinishReason::Stop,
+        }
+    }
+}
+
+/// Convert FinishReason to its string representation (for streaming)
+pub fn finish_reason_to_string(reason: &openai::FinishReason) -> String {
+    match reason {
+        openai::FinishReason::Stop => "stop".to_string(),
+        openai::FinishReason::Length => "length".to_string(),
+        openai::FinishReason::ToolCalls => "tool_calls".to_string(),
+        openai::FinishReason::ContentFilter => "content_filter".to_string(),
+        openai::FinishReason::FunctionCall => "function_call".to_string(),
+    }
+}
+
+/// Parse a string to OpenAI FinishReason (for streaming)
+pub fn string_to_finish_reason(s: &str) -> openai::FinishReason {
+    match s {
+        "stop" => openai::FinishReason::Stop,
+        "length" => openai::FinishReason::Length,
+        "tool_calls" => openai::FinishReason::ToolCalls,
+        "content_filter" => openai::FinishReason::ContentFilter,
+        "function_call" => openai::FinishReason::FunctionCall,
+        _ => openai::FinishReason::Stop, // Default fallback
+    }
+}
+
+/// Convert OpenAI CompletionUsage to universal UniversalUsage
+impl From<&openai::CompletionUsage> for UniversalUsage {
+    fn from(usage: &openai::CompletionUsage) -> Self {
+        UniversalUsage {
+            prompt_tokens: Some(usage.prompt_tokens),
+            completion_tokens: Some(usage.completion_tokens),
+            prompt_cached_tokens: usage
+                .prompt_tokens_details
+                .as_ref()
+                .and_then(|d| d.cached_tokens),
+            prompt_cache_creation_tokens: None, // OpenAI doesn't report cache creation tokens
+            completion_reasoning_tokens: usage
+                .completion_tokens_details
+                .as_ref()
+                .and_then(|d| d.reasoning_tokens),
+        }
+    }
+}
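
Illustrative sketch, not part of the diff: a round trip through the new typed streaming path, assuming the ProviderAdapter signatures and UniversalStreamChunk fields shown above are as written; the payload values, module name, and test name are hypothetical.

#[cfg(test)]
mod stream_round_trip_sketch {
    use super::*;
    use crate::serde_json;

    #[test]
    fn chunk_round_trips_through_universal() {
        let adapter = OpenAIAdapter;
        // Hypothetical streaming chunk in OpenAI wire format.
        let payload = serde_json::json!({
            "id": "chatcmpl-123",
            "object": "chat.completion.chunk",
            "created": 1700000000,
            "model": "gpt-4o-mini",
            "choices": [
                { "index": 0, "delta": { "content": "Hello" }, "finish_reason": null }
            ]
        });

        // stream_to_universal parses into streaming_types::StreamResponse and maps
        // each choice into a UniversalStreamChoice.
        let chunk = adapter
            .stream_to_universal(payload)
            .expect("well-formed chunk should parse")
            .expect("chunk should not be empty");
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].finish_reason, None);

        // stream_from_universal rebuilds an OpenAI-shaped chunk via StreamResponse.
        let value = adapter
            .stream_from_universal(&chunk)
            .expect("serialization should succeed");
        assert_eq!(value["object"], "chat.completion.chunk");
        assert_eq!(value["id"], "chatcmpl-123");
    }
}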
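Likewise a hedged sketch for the new convert.rs helpers, assuming the generated openai::CompletionUsage deserializes directly from OpenAI's usage object (with optional prompt_tokens_details and completion_tokens_details) and that the surrounding module's imports (crate::serde_json, the openai alias) are in scope; the numbers are made up.

#[test]
fn usage_and_finish_reason_sketch() {
    // Hypothetical usage payload in OpenAI's wire format.
    let usage: openai::CompletionUsage = serde_json::from_value(serde_json::json!({
        "prompt_tokens": 12,
        "completion_tokens": 34,
        "total_tokens": 46,
        "prompt_tokens_details": { "cached_tokens": 8 },
        "completion_tokens_details": { "reasoning_tokens": 5 }
    }))
    .expect("usage in OpenAI wire format should deserialize");

    // CompletionUsage -> UniversalUsage picks up cached and reasoning token details.
    let universal = UniversalUsage::from(&usage);
    assert_eq!(universal.prompt_tokens, Some(12));
    assert_eq!(universal.prompt_cached_tokens, Some(8));
    assert_eq!(universal.completion_reasoning_tokens, Some(5));
    assert_eq!(universal.prompt_cache_creation_tokens, None);

    // The string helpers mirror the enum variants used on the streaming path.
    assert_eq!(
        finish_reason_to_string(&openai::FinishReason::ToolCalls),
        "tool_calls"
    );
    assert!(matches!(
        string_to_finish_reason("content_filter"),
        openai::FinishReason::ContentFilter
    ));
}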