Skip to content

Commit 9009320

Browse files
committed
feat: refactor text generation to use LanguageModel abstraction
- Replace direct use of model string and AIProvider in `GenerateTextRequest` with a `LanguageModel` struct.
- Update OpenAI provider and tests to construct and use `LanguageModel` instances.
- Simplify `GenerateTextRequestBuilder` by removing separate provider field; model now encapsulates provider.
- Adjust internal logic to clone provider from the model.
- Improves type safety and paves the way for richer model metadata.
1 parent 6c810fe commit 9009320

2 files changed

Lines changed: 21 additions & 25 deletions

File tree

crates/umem_ai/src/providers/openai.rs

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ impl OpenAIProvider {
7070
let normalized_user_messages = self.normalize_user_messages(&request.messages);
7171

7272
serde_json::json!({
73-
"model": request.model,
73+
"model": request.model.model_name,
7474
"instructions":system,
7575
"input": [serde_json::json!({
7676
"type": "message",
@@ -408,19 +408,20 @@ impl Default for OpenAIProviderBuilder {
408408

409409
#[cfg(test)]
410410
mod tests {
411+
411412
use super::*;
412413
use crate::{
413414
response_generators::{
414415
generate_object::{generate_object, GenerateObjectRequestBuilder},
415416
generate_text, GenerateTextRequestBuilder,
416417
},
417-
LLMProvider, LLM,
418+
AIProvider, LanguageModel,
418419
};
419420
use std::sync::Arc;
420421

421422
#[tokio::test(flavor = "multi_thread")]
422423
async fn test_generate_object() {
423-
let provider = Arc::new(LLMProvider::from(
424+
let provider = Arc::new(AIProvider::from(
424425
OpenAIProviderBuilder::new()
425426
.api_key("")
426427
.base_url("https://openrouter.ai/api/v1")
@@ -434,13 +435,13 @@ mod tests {
434435
traditions: String,
435436
}
436437

437-
let llm = Arc::new(LLM {
438+
let model = Arc::new(LanguageModel {
438439
provider,
439440
model_name: "allenai/olmo-3.1-32b-think:free".to_string(),
440441
});
441442

442443
let request = GenerateObjectRequestBuilder::<Holiday>::new()
443-
.model(llm)
444+
.model(model)
444445
.system("You are a helpful assistant.".to_string())
445446
.prompt("Invent a new holiday and describe its traditions.".to_string())
446447
.max_output_tokens(2000)
@@ -455,19 +456,23 @@ mod tests {
455456

456457
#[tokio::test(flavor = "multi_thread")]
457458
async fn test_generate_text() {
458-
let provider = Arc::new(LLMProvider::from(
459+
let provider = Arc::new(AIProvider::from(
459460
OpenAIProviderBuilder::new()
460461
.api_key("")
461462
.base_url("https://openrouter.ai/api/v1")
462463
.build()
463464
.unwrap(),
464465
));
465466

467+
let model = Arc::new(LanguageModel {
468+
provider,
469+
model_name: "arcee-ai/trinity-mini:free".to_string(),
470+
});
471+
466472
let request = GenerateTextRequestBuilder::new()
467-
.model("arcee-ai/trinity-mini:free".to_string())
473+
.model(model)
468474
.system("You are a helpful assistant.".to_string())
469475
.prompt("Invent a new holiday and describe its traditions.".to_string())
470-
.provider(Arc::clone(&provider))
471476
.max_output_tokens(10000)
472477
.temperature(0.7)
473478
.build()

crates/umem_ai/src/response_generators/generate_text.rs

Lines changed: 8 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
use crate::response_generators::messages::Message;
22
use crate::utils;
33
use crate::utils::is_retryable_error;
4-
use crate::AIProvider;
4+
use crate::LanguageModel;
55
use crate::ResponseGeneratorError;
66
use backon::ExponentialBuilder;
77
use backon::Retryable;
@@ -19,8 +19,10 @@ pub async fn generate_text(
1919
let total_delay = per_request_timeout.mul_f32(max_retries as f32 / 2.0);
2020

2121
let generation = || {
22-
let provider = Arc::clone(&request.provider);
22+
let model = Arc::clone(&request.model);
23+
let provider = Arc::clone(&model.provider);
2324
let request = request.clone();
25+
2426
async move {
2527
tokio::time::timeout(per_request_timeout, provider.do_generate_text(request))
2628
.await
@@ -50,8 +52,7 @@ pub struct GenerateTextResponse {
5052

5153
#[derive(Clone)]
5254
pub struct GenerateTextRequest {
53-
pub model: String,
54-
pub provider: Arc<AIProvider>,
55+
pub model: Arc<LanguageModel>,
5556
pub messages: Vec<Message>,
5657
pub max_output_tokens: Option<usize>,
5758
pub temperature: Option<f32>,
@@ -80,8 +81,7 @@ pub enum GenerateTextRequestBuilderError {
8081
}
8182

8283
pub struct GenerateTextRequestBuilder {
83-
pub model: Option<String>,
84-
pub provider: Option<Arc<AIProvider>>,
84+
pub model: Option<Arc<LanguageModel>>,
8585
pub system: Option<String>,
8686
pub prompt: Option<String>,
8787
pub messages: Vec<Message>,
@@ -100,7 +100,6 @@ impl GenerateTextRequestBuilder {
100100
pub fn new() -> Self {
101101
GenerateTextRequestBuilder {
102102
model: None,
103-
provider: None,
104103
system: None,
105104
prompt: None,
106105
max_output_tokens: None,
@@ -116,13 +115,8 @@ impl GenerateTextRequestBuilder {
116115
}
117116
}
118117

119-
pub fn model(mut self, model: impl Into<String>) -> Self {
120-
self.model = Some(model.into());
121-
self
122-
}
123-
124-
pub fn provider(mut self, provider: Arc<AIProvider>) -> Self {
125-
self.provider = Some(provider);
118+
pub fn model(mut self, model: Arc<LanguageModel>) -> Self {
119+
self.model = Some(model);
126120
self
127121
}
128122

@@ -214,9 +208,6 @@ impl GenerateTextRequestBuilder {
214208
.model
215209
.ok_or(GenerateTextRequestBuilderError::MissingModel)?,
216210
messages: self.messages,
217-
provider: self
218-
.provider
219-
.ok_or(GenerateTextRequestBuilderError::MissingProvider)?,
220211
max_output_tokens: self.max_output_tokens,
221212
top_p: self.top_p,
222213
top_k: self.top_k,

0 commit comments

Comments
 (0)