diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index fa9964dd2..90781fade 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -56,6 +56,8 @@ pub struct StreamContext { http_protocol: Option, sse_buffer: Option, sse_chunk_processor: Option, + /// Accumulates upstream non-streaming response chunks until end of stream. + non_streaming_response_buffer: Vec, } impl StreamContext { @@ -87,6 +89,7 @@ impl StreamContext { http_protocol: None, sse_buffer: None, sse_chunk_processor: None, + non_streaming_response_buffer: Vec::new(), } } @@ -816,6 +819,31 @@ impl StreamContext { } } } + + fn flush_streaming_response_tail(&mut self) -> Option> { + let provider_id = self.get_provider_id(); + let has_buffered_sse = self + .sse_chunk_processor + .as_ref() + .is_some_and(|processor| processor.has_buffered_data()); + + if has_buffered_sse { + match self.handle_streaming_response(&[], provider_id) { + Ok(bytes) if !bytes.is_empty() => return Some(bytes), + Ok(_) => {} + Err(_) => return None, + } + } + + self.sse_buffer.as_mut().and_then(|buffer| { + let bytes = buffer.to_bytes(); + if bytes.is_empty() { + None + } else { + Some(bytes) + } + }) + } } // HttpContext is the trait that allows the Rust code to interact with HTTP objects. @@ -1174,6 +1202,21 @@ impl HttpContext for StreamContext { let current_time = get_current_time().unwrap(); if end_of_stream && body_size == 0 { + if self.streaming_response { + if let Some(serialized_body) = self.flush_streaming_response_tail() { + self.set_http_response_body(0, 0, &serialized_body); + } + } else if !self.non_streaming_response_buffer.is_empty() { + let body = std::mem::take(&mut self.non_streaming_response_buffer); + let provider_id = self.get_provider_id(); + match self.handle_non_streaming_response(&body, provider_id) { + Ok(serialized_body) => { + self.set_http_response_body(0, 0, &serialized_body); + } + Err(action) => return action, + } + } + debug!( "request_id={}: response body complete, total_bytes={}", self.request_identifier(), @@ -1248,7 +1291,15 @@ impl HttpContext for StreamContext { Err(action) => return action, } } else { - match self.handle_non_streaming_response(&body, provider_id) { + self.non_streaming_response_buffer.extend_from_slice(&body); + if !end_of_stream { + // Hold chunks until the full JSON body arrives. + self.set_http_response_body(0, body_size, &[]); + return Action::Continue; + } + + let complete_body = std::mem::take(&mut self.non_streaming_response_buffer); + match self.handle_non_streaming_response(&complete_body, provider_id) { Ok(serialized_body) => { self.set_http_response_body(0, body_size, &serialized_body); } diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index e3d00b3ff..fdbacacde 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -7,8 +7,7 @@ use common::{ ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH, MODEL_SERVER_NAME, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER, TOOL_ROLE, - TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_API_RESPONSE, X_ARCH_FC_MODEL_RESPONSE, - X_ARCH_STATE_HEADER, X_ARCH_TOOL_CALL, + TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_STATE_HEADER, }, errors::ServerError, http::{CallArgs, Client}, @@ -17,7 +16,6 @@ use common::{ use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::{traits::HttpContext, types::Action}; -use serde_json::Value; use std::{ collections::HashMap, time::{Duration, SystemTime, UNIX_EPOCH}, @@ -291,6 +289,10 @@ impl HttpContext for StreamContext { } if end_of_stream && body_size == 0 { + if !self.streaming_response && !self.non_streaming_response_buffer.is_empty() { + let body = std::mem::take(&mut self.non_streaming_response_buffer); + self.process_non_streaming_response_body(&body, 0); + } return Action::Continue; } @@ -326,15 +328,15 @@ impl HttpContext for StreamContext { } }; - let body_utf8 = match String::from_utf8(body) { - Ok(body_utf8) => body_utf8, - Err(e) => { - info!("could not convert to utf8: {}", e); - return Action::Continue; - } - }; - if self.streaming_response { + let body_utf8 = match String::from_utf8(body) { + Ok(body_utf8) => body_utf8, + Err(e) => { + info!("could not convert to utf8: {}", e); + return Action::Continue; + } + }; + debug!("streaming response"); if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() { @@ -359,70 +361,15 @@ impl HttpContext for StreamContext { self.set_http_response_body(0, body_size, response_str.as_bytes()); self.tool_calls = None; } - } else if let Some(tool_calls) = self.tool_calls.as_ref() { - if !tool_calls.is_empty() { - if self.arch_state.is_none() { - self.arch_state = Some(Vec::new()); - } - - let mut data = match serde_json::from_str(&body_utf8) { - Ok(data) => data, - Err(e) => { - warn!( - "could not deserialize response, sending data as it is: {}", - e - ); - return Action::Continue; - } - }; - // use serde::Value to manipulate the json object and ensure that we don't lose any data - if let Value::Object(ref mut map) = data { - // serialize arch state and add to metadata - let metadata = map - .entry("metadata") - .or_insert(Value::Object(serde_json::Map::new())); - if metadata == &Value::Null { - *metadata = Value::Object(serde_json::Map::new()); - } - - let tool_call_message = self.generate_tool_call_message(); - let tool_call_message_str = serde_json::to_string(&tool_call_message).unwrap(); - metadata.as_object_mut().unwrap().insert( - X_ARCH_TOOL_CALL.to_string(), - serde_json::Value::String(tool_call_message_str), - ); - - let api_response_message = self.generate_api_response_message(); - let api_response_message_str = - serde_json::to_string(&api_response_message).unwrap(); - metadata.as_object_mut().unwrap().insert( - X_ARCH_API_RESPONSE.to_string(), - serde_json::Value::String(api_response_message_str), - ); - - let fc_messages = vec![tool_call_message, api_response_message]; - - let fc_messages_str = serde_json::to_string(&fc_messages).unwrap(); - let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]); - let arch_state_str = serde_json::to_string(&arch_state).unwrap(); - metadata.as_object_mut().unwrap().insert( - X_ARCH_STATE_HEADER.to_string(), - serde_json::Value::String(arch_state_str), - ); - - if let Some(arch_fc_response) = self.arch_fc_response.as_ref() { - metadata.as_object_mut().unwrap().insert( - X_ARCH_FC_MODEL_RESPONSE.to_string(), - serde_json::Value::String( - serde_json::to_string(arch_fc_response).unwrap(), - ), - ); - } - let data_serialized = serde_json::to_string(&data).unwrap(); - info!("plano <= developer: {}", data_serialized); - self.set_http_response_body(0, body_size, data_serialized.as_bytes()); - }; + } else { + self.non_streaming_response_buffer.extend_from_slice(&body); + if !end_of_stream { + self.set_http_response_body(0, body_size, &[]); + return Action::Continue; } + + let complete_body = std::mem::take(&mut self.non_streaming_response_buffer); + self.process_non_streaming_response_body(&complete_body, body_size); } debug!("recv [S={}] end_stream={}", self.context_id, end_of_stream); diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 8ff44d522..e2754ebd8 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -9,7 +9,7 @@ use common::consts::{ API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE, - X_ARCH_FC_MODEL_RESPONSE, + X_ARCH_API_RESPONSE, X_ARCH_FC_MODEL_RESPONSE, X_ARCH_STATE_HEADER, X_ARCH_TOOL_CALL, }; use common::errors::ServerError; use common::http::{CallArgs, Client}; @@ -18,6 +18,7 @@ use derivative::Derivative; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::traits::*; +use serde_json::Value; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; @@ -66,6 +67,8 @@ pub struct StreamContext { pub traceparent: Option, pub _tracing: Rc>, pub arch_fc_response: Option, + /// Accumulates upstream non-streaming response chunks until end of stream. + pub non_streaming_response_buffer: Vec, } impl StreamContext { @@ -100,6 +103,7 @@ impl StreamContext { start_upstream_llm_request_time: 0, time_to_first_token: None, arch_fc_response: None, + non_streaming_response_buffer: Vec::new(), } } @@ -803,6 +807,80 @@ impl StreamContext { self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes()); self.resume_http_request(); } + + pub fn process_non_streaming_response_body(&mut self, body: &[u8], body_size: usize) { + let body_utf8 = match String::from_utf8(body.to_vec()) { + Ok(body_utf8) => body_utf8, + Err(e) => { + info!("could not convert to utf8: {}", e); + return; + } + }; + + if let Some(tool_calls) = self.tool_calls.as_ref() { + if !tool_calls.is_empty() { + if self.arch_state.is_none() { + self.arch_state = Some(Vec::new()); + } + + let mut data = match serde_json::from_str(&body_utf8) { + Ok(data) => data, + Err(e) => { + warn!( + "could not deserialize response, sending data as it is: {}", + e + ); + return; + } + }; + if let Value::Object(ref mut map) = data { + let metadata = map + .entry("metadata") + .or_insert(Value::Object(serde_json::Map::new())); + if metadata == &Value::Null { + *metadata = Value::Object(serde_json::Map::new()); + } + + let tool_call_message = self.generate_tool_call_message(); + let tool_call_message_str = serde_json::to_string(&tool_call_message).unwrap(); + metadata.as_object_mut().unwrap().insert( + X_ARCH_TOOL_CALL.to_string(), + serde_json::Value::String(tool_call_message_str), + ); + + let api_response_message = self.generate_api_response_message(); + let api_response_message_str = + serde_json::to_string(&api_response_message).unwrap(); + metadata.as_object_mut().unwrap().insert( + X_ARCH_API_RESPONSE.to_string(), + serde_json::Value::String(api_response_message_str), + ); + + let fc_messages = vec![tool_call_message, api_response_message]; + + let fc_messages_str = serde_json::to_string(&fc_messages).unwrap(); + let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]); + let arch_state_str = serde_json::to_string(&arch_state).unwrap(); + metadata.as_object_mut().unwrap().insert( + X_ARCH_STATE_HEADER.to_string(), + serde_json::Value::String(arch_state_str), + ); + + if let Some(arch_fc_response) = self.arch_fc_response.as_ref() { + metadata.as_object_mut().unwrap().insert( + X_ARCH_FC_MODEL_RESPONSE.to_string(), + serde_json::Value::String( + serde_json::to_string(arch_fc_response).unwrap(), + ), + ); + } + let data_serialized = serde_json::to_string(&data).unwrap(); + info!("plano <= developer: {}", data_serialized); + self.set_http_response_body(0, body_size, data_serialized.as_bytes()); + } + } + } + } } fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool { diff --git a/demos/shared/test_runner/run_demo_tests.sh b/demos/shared/test_runner/run_demo_tests.sh index 098408142..44a433279 100644 --- a/demos/shared/test_runner/run_demo_tests.sh +++ b/demos/shared/test_runner/run_demo_tests.sh @@ -19,7 +19,8 @@ run_hurl_with_retries() { local max_attempts=1 local attempt=1 - if [ "$demo_name" = "llm_routing/preference_based_routing" ]; then + if [ "$demo_name" = "llm_routing/preference_based_routing" ] \ + || [ "$demo_name" = "advanced/currency_exchange" ]; then max_attempts=3 fi