Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,4 @@ rust-executor/CUSTOM_DENO_SNAPSHOT.bin
rust-executor/test_data

.npmrc
.worktrees/
10 changes: 6 additions & 4 deletions core/src/ai/AIClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -302,17 +302,19 @@ export class AIClient {
endThreshold?: number;
endWindow?: number;
timeBeforeSpeech?: number;
}
},
language?: string
): Promise<string> {
const { aiOpenTranscriptionStream } = unwrapApolloResult(await this.#apolloClient.mutate({
mutation: gql`
mutation AiOpenTranscriptionStream($modelId: String!, $params: VoiceActivityParamsInput) {
aiOpenTranscriptionStream(modelId: $modelId, params: $params)
mutation AiOpenTranscriptionStream($modelId: String!, $params: VoiceActivityParamsInput, $language: String) {
aiOpenTranscriptionStream(modelId: $modelId, params: $params, language: $language)
}
`,
variables: {
modelId,
params
params,
language
}
}));

Expand Down
3 changes: 2 additions & 1 deletion core/src/ai/AIResolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,8 @@ export default class AIResolver {
@Mutation(() => String)
aiOpenTranscriptionStream(
@Arg("modelId") modelId: string,
@Arg("params", () => VoiceActivityParamsInput, { nullable: true }) params?: VoiceActivityParamsInput
@Arg("params", () => VoiceActivityParamsInput, { nullable: true }) params?: VoiceActivityParamsInput,
@Arg("language", { nullable: true }) language?: string
): string {
return "streamId"
}
Expand Down
77 changes: 54 additions & 23 deletions rust-executor/src/ai_service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ pub struct AIService {
llm_channel: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<LLMTaskRequest>>>>,
transcription_streams: Arc<Mutex<HashMap<String, TranscriptionSession>>>,
cleanup_task_shutdown: Arc<std::sync::Mutex<Option<oneshot::Sender<()>>>>,
/// Shared Whisper models - ONE model per size, shared across ALL streams using that size
/// Key = WhisperSource (Tiny/Small/Medium/Large), Value = Arc<Whisper>
/// Shared Whisper models - ONE model per size+language combination, shared across ALL streams
/// Key = "{size:?}_{lang_key}" (e.g., "Small_auto", "Small_de"), Value = Arc<Whisper>
/// Cloning Arc is cheap (just increments ref count), model weights stay in memory once
/// Saves 500MB-1.5GB per stream!
shared_whisper_models: Arc<Mutex<HashMap<String, Arc<Whisper>>>>,
Expand Down Expand Up @@ -1113,42 +1113,73 @@ impl AIService {
&self,
model_id: String,
params: Option<VoiceActivityParams>,
language: Option<String>,
) -> Result<String> {
let model_size = Self::get_whisper_model_size(model_id.clone())?;

// MEMORY OPTIMIZATION: Load each Whisper model size ONCE and share across all streams using that size
// Arc cloning is cheap (just increments ref count), saves 500MB-1.5GB per stream!
// Language is set at build time, so include it in the cache key
let whisper_model = {
let mut shared_models = self.shared_whisper_models.lock().await;
let model_key = format!("{:?}", model_size); // Use Debug format as key (e.g., "Small", "Medium")

if !shared_models.contains_key(&model_key) {
// Parse language upfront and use canonical form for cache key
let whisper_lang = if let Some(ref lang) = language {
Some(
lang.to_lowercase()
.parse::<WhisperLanguage>()
.map_err(|_| anyhow::anyhow!("Unsupported whisper language: {}", lang))?,
)
} else {
None
};
let lang_key = whisper_lang
.as_ref()
.map(|l| l.to_string())
.unwrap_or_else(|| "auto".to_string());
let model_key = format!("{:?}_{}", model_size, lang_key);

// First check: acquire lock briefly to see if model already exists
let existing = {
let shared_models = self.shared_whisper_models.lock().await;
shared_models.get(&model_key).cloned()
}; // Lock dropped here

if let Some(model) = existing {
log::info!(
"Loading shared Whisper model {} ({:?}) (ONE model per size, ~500MB-1.5GB)...",
"Reusing existing shared Whisper model {:?} (language: {}) for new stream",
model_size,
lang_key
);
model
} else {
// Model not found — build it outside the lock to avoid serializing callers
log::info!(
"Loading shared Whisper model {} ({:?}, language: {}) (ONE model per size, ~500MB-1.5GB)...",
model_id,
model_size
model_size,
lang_key
);

let model = WhisperBuilder::default()
let mut builder = WhisperBuilder::default()
.with_source(model_size)
.with_device(Self::new_candle_device())
.build()
.await?;
.with_device(Self::new_candle_device());

builder = builder.with_language(whisper_lang);

let model = builder.build().await?;

log::info!(
"Shared Whisper model {:?} loaded! All streams using this size will reuse this model.",
model_size
);
shared_models.insert(model_key.clone(), Arc::new(model));
} else {
log::info!(
"Reusing existing shared Whisper model {:?} for new stream",
model_size
"Shared Whisper model {:?} (language: {}) loaded! All streams using this config will reuse this model.",
model_size,
lang_key
);
}

// Clone the Arc - this is CHEAP! Just increments a reference count
shared_models.get(&model_key).unwrap().clone()
// Re-acquire lock and insert; use entry to handle race where another caller built it first
let mut shared_models = self.shared_whisper_models.lock().await;
let model_arc = shared_models
.entry(model_key)
.or_insert_with(|| Arc::new(model));
model_arc.clone()
}
};

log::info!("Opening transcription stream with model {:?}", model_size);
Expand Down
3 changes: 2 additions & 1 deletion rust-executor/src/graphql/mutation_resolvers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2614,11 +2614,12 @@ impl Mutation {
context: &RequestContext,
model_id: String,
params: Option<VoiceActivityParamsInput>,
language: Option<String>,
) -> FieldResult<String> {
check_capability(&context.capabilities, &AI_TRANSCRIBE_CAPABILITY)?;
Ok(AIService::global_instance()
.await?
.open_transcription_stream(model_id, params.map(|p| p.into()))
.open_transcription_stream(model_id, params.map(|p| p.into()), language)
.await?)
}

Expand Down