From 6c0c0b4b5cd4fffeb029fe5a988b74b87da2698e Mon Sep 17 00:00:00 2001 From: Tinco Andringa Date: Tue, 12 Aug 2025 19:20:59 +0200 Subject: [PATCH 1/3] unify StreamError and OpenAIError --- async-openai/src/client.rs | 128 +++++++++++++-------- async-openai/src/error.rs | 29 ++++- async-openai/src/types/assistant_stream.rs | 8 +- 3 files changed, 111 insertions(+), 54 deletions(-) diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index fe2ed232..9ee92ef8 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -2,13 +2,13 @@ use std::pin::Pin; use bytes::Bytes; use futures::{stream::StreamExt, Stream}; -use reqwest::multipart::Form; -use reqwest_eventsource::{Event, EventSource, RequestBuilderExt}; +use reqwest::{multipart::Form, Response}; +use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt}; use serde::{de::DeserializeOwned, Serialize}; use crate::{ config::{Config, OpenAIConfig}, - error::{map_deserialization_error, ApiError, OpenAIError, WrappedError}, + error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError}, file::Files, image::Images, moderation::Moderations, @@ -335,52 +335,34 @@ impl Client { .map_err(backoff::Error::Permanent)?; let status = response.status(); - let bytes = response - .bytes() - .await - .map_err(OpenAIError::Reqwest) - .map_err(backoff::Error::Permanent)?; - if status.is_server_error() { - // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them. - let message: String = String::from_utf8_lossy(&bytes).into_owned(); - tracing::warn!("Server error: {status} - {message}"); - return Err(backoff::Error::Transient { - err: OpenAIError::ApiError(ApiError { - message, - r#type: None, - param: None, - code: None, - }), - retry_after: None, - }); - } - - // Deserialize response body from either error object or actual response object - if !status.is_success() { - let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref()) - .map_err(|e| map_deserialization_error(e, bytes.as_ref())) - .map_err(backoff::Error::Permanent)?; - - if status.as_u16() == 429 - // API returns 429 also when: - // "You exceeded your current quota, please check your plan and billing details." - && wrapped_error.error.r#type != Some("insufficient_quota".to_string()) - { - // Rate limited retry... - tracing::warn!("Rate limited: {}", wrapped_error.error.message); - return Err(backoff::Error::Transient { - err: OpenAIError::ApiError(wrapped_error.error), - retry_after: None, - }); - } else { - return Err(backoff::Error::Permanent(OpenAIError::ApiError( - wrapped_error.error, - ))); - } + match read_response(response).await { + Ok(bytes) => Ok(bytes), + Err(e) => { + match e { + OpenAIError::ApiError(api_error) => { + if status.is_server_error() { + Err(backoff::Error::Transient { + err: OpenAIError::ApiError(api_error), + retry_after: None, + }) + } else { + if status.as_u16() == 429 && api_error.r#type != Some("insufficient_quota".to_string()) { + // Rate limited retry... + tracing::warn!("Rate limited: {}", api_error.message); + Err(backoff::Error::Transient { + err: OpenAIError::ApiError(api_error), + retry_after: None, + }) + } else { + Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error))) + } + } + } + _ => Err(backoff::Error::Permanent(e)), + } + }, } - - Ok(bytes) }) .await } @@ -471,6 +453,54 @@ impl Client { } } +async fn read_response(response: Response) -> Result { + let status = response.status(); + let bytes = response + .bytes() + .await + .map_err(OpenAIError::Reqwest)?; + + if status.is_server_error() { + // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them. + let message: String = String::from_utf8_lossy(&bytes).into_owned(); + tracing::warn!("Server error: {status} - {message}"); + return Err( + OpenAIError::ApiError(ApiError { + message, + r#type: None, + param: None, + code: None, + })); + } + + // Deserialize response body from either error object or actual response object + if !status.is_success() { + let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref()) + .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?; + + return Err(OpenAIError::ApiError(wrapped_error.error)); + } + + Ok(bytes) +} + +async fn map_stream_error(value: EventSourceError) -> OpenAIError { + match value { + EventSourceError::Parser(e) => OpenAIError::StreamError(StreamError::Parser(e.to_string())), + EventSourceError::InvalidContentType(e, response) => OpenAIError::StreamError(StreamError::InvalidContentType(e, response)), + EventSourceError::InvalidLastEventId(e) => OpenAIError::StreamError(StreamError::InvalidLastEventId(e)), + EventSourceError::StreamEnded => OpenAIError::StreamError(StreamError::StreamEnded), + EventSourceError::Utf8(e) => OpenAIError::StreamError(StreamError::Utf8(e)), + EventSourceError::Transport(error) => OpenAIError::Reqwest(error), + EventSourceError::InvalidStatusCode(_status_code, response) => { + read_response(response) + .await + .expect_err("Unreachable because read_response returns err when status_code is invalid") + } + } +} + + /// Request which responds with SSE. /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format) pub(crate) async fn stream( @@ -485,7 +515,7 @@ where while let Some(ev) = event_source.next().await { match ev { Err(e) => { - if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) { + if let Err(_e) = tx.send(Err(map_stream_error(e).await)) { // rx dropped break; } @@ -530,7 +560,7 @@ where while let Some(ev) = event_source.next().await { match ev { Err(e) => { - if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) { + if let Err(_e) = tx.send(Err(map_stream_error(e).await)) { // rx dropped break; } diff --git a/async-openai/src/error.rs b/async-openai/src/error.rs index a1139c9f..35ab388d 100644 --- a/async-openai/src/error.rs +++ b/async-openai/src/error.rs @@ -1,6 +1,11 @@ //! Errors originating from API calls, parsing responses, and reading-or-writing to the file system. +use std::string::FromUtf8Error; + +use reqwest::{header::HeaderValue, Response}; use serde::{Deserialize, Serialize}; +use reqwest_eventsource::Error as EventSourceError; + #[derive(Debug, thiserror::Error)] pub enum OpenAIError { /// Underlying error from reqwest library after an API call was made @@ -20,13 +25,35 @@ pub enum OpenAIError { FileReadError(String), /// Error on SSE streaming #[error("stream failed: {0}")] - StreamError(String), + StreamError(StreamError), /// Error from client side validation /// or when builder fails to build request before making API call #[error("invalid args: {0}")] InvalidArgument(String), } +#[derive(Debug, thiserror::Error)] +pub enum StreamError { + /// Source stream is not valid UTF8 + #[error(transparent)] + Utf8(FromUtf8Error), + /// Source stream is not a valid EventStream + #[error("Source stream is not a valid event stream: {0}")] + Parser(String), + /// The `Content-Type` returned by the server is invalid + #[error("Invalid content type for event stream: {0:?}")] + InvalidContentType(HeaderValue, Response), + /// The `Last-Event-ID` cannot be formed into a Header to be submitted to the server + #[error("Invalid `Last-Event-ID` for event stream: {0}")] + InvalidLastEventId(String), + /// The server sent an unrecognized event type + #[error("Unrecognized event type: {0}")] + UnrecognizedEventType(String), + /// The stream ended + #[error("Stream ended")] + StreamEnded, +} + /// OpenAI API returns error object on failure #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ApiError { diff --git a/async-openai/src/types/assistant_stream.rs b/async-openai/src/types/assistant_stream.rs index 755a322d..8bf3b690 100644 --- a/async-openai/src/types/assistant_stream.rs +++ b/async-openai/src/types/assistant_stream.rs @@ -3,7 +3,7 @@ use std::pin::Pin; use futures::Stream; use serde::Deserialize; -use crate::error::{map_deserialization_error, ApiError, OpenAIError}; +use crate::error::{map_deserialization_error, ApiError, OpenAIError, StreamError}; use super::{ MessageDeltaObject, MessageObject, RunObject, RunStepDeltaObject, RunStepObject, ThreadObject, @@ -207,9 +207,9 @@ impl TryFrom for AssistantStreamEvent { .map(AssistantStreamEvent::ErrorEvent), "done" => Ok(AssistantStreamEvent::Done(value.data)), - _ => Err(OpenAIError::StreamError( - "Unrecognized event: {value:?#}".into(), - )), + _ => Err(OpenAIError::StreamError(StreamError::UnrecognizedEventType( + value.event, + ))), } } } From de767d0b26e5e6d5d37a75320c1f0ce2935a95a0 Mon Sep 17 00:00:00 2001 From: Tinco Andringa Date: Tue, 12 Aug 2025 19:28:53 +0200 Subject: [PATCH 2/3] format --- async-openai/src/client.rs | 39 +++++++++++----------- async-openai/src/types/assistant_stream.rs | 6 ++-- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index 9ee92ef8..93a76e19 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -347,7 +347,9 @@ impl Client { retry_after: None, }) } else { - if status.as_u16() == 429 && api_error.r#type != Some("insufficient_quota".to_string()) { + if status.as_u16() == 429 + && api_error.r#type != Some("insufficient_quota".to_string()) + { // Rate limited retry... tracing::warn!("Rate limited: {}", api_error.message); Err(backoff::Error::Transient { @@ -361,7 +363,7 @@ impl Client { } _ => Err(backoff::Error::Permanent(e)), } - }, + } } }) .await @@ -455,22 +457,18 @@ impl Client { async fn read_response(response: Response) -> Result { let status = response.status(); - let bytes = response - .bytes() - .await - .map_err(OpenAIError::Reqwest)?; + let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?; if status.is_server_error() { // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them. let message: String = String::from_utf8_lossy(&bytes).into_owned(); tracing::warn!("Server error: {status} - {message}"); - return Err( - OpenAIError::ApiError(ApiError { - message, - r#type: None, - param: None, - code: None, - })); + return Err(OpenAIError::ApiError(ApiError { + message, + r#type: None, + param: None, + code: None, + })); } // Deserialize response body from either error object or actual response object @@ -487,20 +485,23 @@ async fn read_response(response: Response) -> Result { async fn map_stream_error(value: EventSourceError) -> OpenAIError { match value { EventSourceError::Parser(e) => OpenAIError::StreamError(StreamError::Parser(e.to_string())), - EventSourceError::InvalidContentType(e, response) => OpenAIError::StreamError(StreamError::InvalidContentType(e, response)), - EventSourceError::InvalidLastEventId(e) => OpenAIError::StreamError(StreamError::InvalidLastEventId(e)), + EventSourceError::InvalidContentType(e, response) => { + OpenAIError::StreamError(StreamError::InvalidContentType(e, response)) + } + EventSourceError::InvalidLastEventId(e) => { + OpenAIError::StreamError(StreamError::InvalidLastEventId(e)) + } EventSourceError::StreamEnded => OpenAIError::StreamError(StreamError::StreamEnded), EventSourceError::Utf8(e) => OpenAIError::StreamError(StreamError::Utf8(e)), EventSourceError::Transport(error) => OpenAIError::Reqwest(error), EventSourceError::InvalidStatusCode(_status_code, response) => { - read_response(response) - .await - .expect_err("Unreachable because read_response returns err when status_code is invalid") + read_response(response).await.expect_err( + "Unreachable because read_response returns err when status_code is invalid", + ) } } } - /// Request which responds with SSE. /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format) pub(crate) async fn stream( diff --git a/async-openai/src/types/assistant_stream.rs b/async-openai/src/types/assistant_stream.rs index 8bf3b690..d7bec396 100644 --- a/async-openai/src/types/assistant_stream.rs +++ b/async-openai/src/types/assistant_stream.rs @@ -207,9 +207,9 @@ impl TryFrom for AssistantStreamEvent { .map(AssistantStreamEvent::ErrorEvent), "done" => Ok(AssistantStreamEvent::Done(value.data)), - _ => Err(OpenAIError::StreamError(StreamError::UnrecognizedEventType( - value.event, - ))), + _ => Err(OpenAIError::StreamError( + StreamError::UnrecognizedEventType(value.event), + )), } } } From 5b85afed73b97bcaa47f7ac1106194c82758e14f Mon Sep 17 00:00:00 2001 From: Tinco Andringa Date: Tue, 12 Aug 2025 19:31:08 +0200 Subject: [PATCH 3/3] clippy --- async-openai/src/client.rs | 22 ++++++++++------------ async-openai/src/error.rs | 1 - 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index 93a76e19..877d9c26 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -346,19 +346,17 @@ impl Client { err: OpenAIError::ApiError(api_error), retry_after: None, }) + } else if status.as_u16() == 429 + && api_error.r#type != Some("insufficient_quota".to_string()) + { + // Rate limited retry... + tracing::warn!("Rate limited: {}", api_error.message); + Err(backoff::Error::Transient { + err: OpenAIError::ApiError(api_error), + retry_after: None, + }) } else { - if status.as_u16() == 429 - && api_error.r#type != Some("insufficient_quota".to_string()) - { - // Rate limited retry... - tracing::warn!("Rate limited: {}", api_error.message); - Err(backoff::Error::Transient { - err: OpenAIError::ApiError(api_error), - retry_after: None, - }) - } else { - Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error))) - } + Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error))) } } _ => Err(backoff::Error::Permanent(e)), diff --git a/async-openai/src/error.rs b/async-openai/src/error.rs index 35ab388d..46b2fc1b 100644 --- a/async-openai/src/error.rs +++ b/async-openai/src/error.rs @@ -4,7 +4,6 @@ use std::string::FromUtf8Error; use reqwest::{header::HeaderValue, Response}; use serde::{Deserialize, Serialize}; -use reqwest_eventsource::Error as EventSourceError; #[derive(Debug, thiserror::Error)] pub enum OpenAIError {