From 3dcb2b2f71827bc021f0db5279b68d89db32b49c Mon Sep 17 00:00:00 2001 From: Anton Chen Date: Mon, 20 Apr 2026 20:23:16 +0800 Subject: [PATCH 1/2] Clean up search execution and add optional query CPU pool --- next-plaid-api/src/handlers/encode.rs | 106 ++- next-plaid-api/src/handlers/rerank.rs | 10 +- next-plaid-api/src/handlers/search.rs | 1001 +++++++++++++++++-------- next-plaid-api/src/main.rs | 319 +++++--- next-plaid-api/src/state.rs | 198 ++++- 5 files changed, 1185 insertions(+), 449 deletions(-) diff --git a/next-plaid-api/src/handlers/encode.rs b/next-plaid-api/src/handlers/encode.rs index aa2c6735..3b1fa76b 100644 --- a/next-plaid-api/src/handlers/encode.rs +++ b/next-plaid-api/src/handlers/encode.rs @@ -22,6 +22,8 @@ use crate::error::{ApiError, ApiResult}; use crate::models::InputType; use crate::models::{EncodeRequest, EncodeResponse}; use crate::state::AppState; +#[cfg(feature = "model")] +use crate::state::{EncodeLane, EncodePoolKind}; // --- Batch Configuration --- @@ -82,28 +84,66 @@ struct EncodeWorkerPool { sender: mpsc::Sender, } -/// Global encode worker pool (singleton). #[cfg(feature = "model")] -static ENCODE_WORKER_POOL: OnceLock>> = OnceLock::new(); +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct EncodePoolCacheKey { + pool_kind: EncodePoolKind, + path: std::path::PathBuf, + use_cuda: bool, + use_int8: bool, + parallel_sessions: Option, + batch_size: Option, + threads: Option, + query_length: Option, + document_length: Option, + pool_size: usize, +} -/// Get or create the global encode worker pool. -/// Spawns multiple workers, each owning its own Colbert model instance. #[cfg(feature = "model")] -fn get_or_create_encode_pool(state: Arc) -> ApiResult> { - let pool_lock: &std::sync::Mutex> = - ENCODE_WORKER_POOL.get_or_init(|| std::sync::Mutex::new(None)); - - let mut pool_opt = pool_lock.lock().unwrap(); - - if let Some(pool) = pool_opt.as_ref() { - return Ok(pool.sender.clone()); +impl EncodePoolCacheKey { + fn new( + pool_kind: EncodePoolKind, + model_config: &crate::state::ModelConfig, + pool_size: usize, + ) -> Self { + Self { + pool_kind, + path: model_config.path.clone(), + use_cuda: model_config.use_cuda, + use_int8: model_config.use_int8, + parallel_sessions: model_config.parallel_sessions, + batch_size: model_config.batch_size, + threads: model_config.threads, + query_length: model_config.query_length, + document_length: model_config.document_length, + pool_size, + } } +} - // Get model pool configuration +/// Global encode worker pools keyed by model configuration and physical lane. +#[cfg(feature = "model")] +static ENCODE_WORKER_POOLS: OnceLock< + std::sync::Mutex>, +> = OnceLock::new(); + +/// Get or create the encode worker pool for the selected physical lane. +#[cfg(feature = "model")] +fn get_or_create_encode_pool( + state: Arc, + pool_kind: EncodePoolKind, +) -> ApiResult> { + let pool_lock: &std::sync::Mutex> = + ENCODE_WORKER_POOLS.get_or_init(|| std::sync::Mutex::new(HashMap::new())); let model_pool = state - .model_pool - .as_ref() + .model_pool_for_kind(pool_kind) .ok_or_else(|| ApiError::ModelNotLoaded)?; + let cache_key = + EncodePoolCacheKey::new(pool_kind, &model_pool.model_config, model_pool.pool_size); + let mut pools = pool_lock.lock().unwrap(); + if let Some(pool) = pools.get(&cache_key) { + return Ok(pool.sender.clone()); + } let pool_size = model_pool.pool_size; let model_config = model_pool.model_config.clone(); @@ -115,23 +155,33 @@ fn get_or_create_encode_pool(state: Arc) -> ApiResult { - tracing::info!(worker_id = worker_id, "encode.worker.started"); + tracing::info!( + worker_id = worker_id, + pool_kind = pool_kind_copy.as_str(), + "encode.worker.started" + ); m } Err(e) => { tracing::error!( worker_id = worker_id, + pool_kind = pool_kind_copy.as_str(), error = %e, "encode.worker.start.failed" ); @@ -147,7 +197,7 @@ fn get_or_create_encode_pool(state: Arc) -> ApiResult, +) -> ApiResult>> { + let encode_lane = match &input_type { + InputType::Query => EncodeLane::Query, + InputType::Document => EncodeLane::Ingest, + }; + + encode_texts_internal_with_lane(state, texts, input_type, encode_lane, pool_factor).await +} + +/// Internal function to encode texts on an explicit logical lane. +#[cfg(feature = "model")] +pub(crate) async fn encode_texts_internal_with_lane( + state: Arc, + texts: &[String], + input_type: InputType, + encode_lane: EncodeLane, + pool_factor: Option, ) -> ApiResult>> { if !state.has_model() { return Err(ApiError::ModelNotLoaded); } + let pool_kind = state.encode_pool_kind_for_lane(encode_lane); // Create oneshot channel for receiving results let (response_tx, response_rx) = oneshot::channel(); @@ -532,7 +600,7 @@ pub async fn encode_texts_internal( }; // Get or create the worker pool - let sender = get_or_create_encode_pool(state)?; + let sender = get_or_create_encode_pool(state, pool_kind)?; // Send to worker pool sender.try_send(batch_item).map_err(|e| match e { diff --git a/next-plaid-api/src/handlers/rerank.rs b/next-plaid-api/src/handlers/rerank.rs index 327d29a8..1773bfae 100644 --- a/next-plaid-api/src/handlers/rerank.rs +++ b/next-plaid-api/src/handlers/rerank.rs @@ -15,6 +15,11 @@ use crate::state::AppState; use crate::tracing_middleware::TraceId; use crate::PrettyJson; +#[cfg(feature = "model")] +fn rerank_document_encode_lane() -> crate::state::EncodeLane { + crate::state::EncodeLane::Query +} + /// Convert a Vec> to an ndarray::Array2. fn to_ndarray(embeddings: &[Vec]) -> ApiResult> { if embeddings.is_empty() { @@ -218,7 +223,7 @@ pub async fn rerank_with_encoding( trace_id: Option>, Json(request): Json, ) -> ApiResult> { - use crate::handlers::encode::encode_texts_internal; + use crate::handlers::encode::{encode_texts_internal, encode_texts_internal_with_lane}; use crate::models::InputType; let trace_id = trace_id.map(|t| t.0).unwrap_or_default(); @@ -256,10 +261,11 @@ pub async fn rerank_with_encoding( .ok_or_else(|| ApiError::Internal("Failed to encode query".to_string()))?; // Encode documents - let doc_embeddings = encode_texts_internal( + let doc_embeddings = encode_texts_internal_with_lane( state, &request.documents, InputType::Document, + rerank_document_encode_lane(), request.pool_factor, ) .await?; diff --git a/next-plaid-api/src/handlers/search.rs b/next-plaid-api/src/handlers/search.rs index 64443916..ea994360 100644 --- a/next-plaid-api/src/handlers/search.rs +++ b/next-plaid-api/src/handlers/search.rs @@ -2,14 +2,16 @@ //! //! Handles search operations on indices. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use std::time::Instant as StdInstant; use axum::{ extract::{Path, State}, Extension, Json, }; use ndarray::Array2; +use tokio::task; use next_plaid::{filtering, text_search, SearchParameters}; @@ -17,25 +19,77 @@ use crate::error::{ApiError, ApiResult}; use crate::handlers::encode::encode_texts_internal; use crate::models::{ ErrorResponse, FilteredSearchRequest, FilteredSearchWithEncodingRequest, InputType, - QueryEmbeddings, QueryResultResponse, SearchRequest, SearchResponse, SearchWithEncodingRequest, + QueryEmbeddings, QueryResultResponse, SearchParamsRequest, SearchRequest, SearchResponse, + SearchWithEncodingRequest, }; use crate::state::AppState; use crate::tracing_middleware::TraceId; use crate::PrettyJson; -// Fusion algorithms are in next_plaid::text_search::{fuse_rrf, fuse_relative_score} +#[derive(Debug, Clone, Copy)] +enum FusionMode { + Rrf, + RelativeScore, +} + +#[derive(Debug, Clone, Copy)] +struct PreparedSearchConfig { + top_k: usize, + fetch_k: usize, + n_ivf_probe: usize, + n_full_scores: usize, + centroid_score_threshold: Option, + alpha: f32, + fusion_mode: FusionMode, +} + +#[derive(Debug)] +struct PreparedSearchRequest { + semantic_queries: Option>>, + text_queries: Vec, + subset: Option>, + filter_condition: Option, + filter_parameters: Vec, + config: PreparedSearchConfig, +} + +#[derive(Debug)] +struct PreparedSearchRequestInput { + semantic_queries: Option>>, + params: SearchParamsRequest, + subset: Option>, + text_query: Option>, + alpha: Option, + fusion: Option, + filter_condition: Option, + filter_parameters: Option>, + default_top_k: usize, +} + +#[derive(Debug)] +struct SearchExecutionMetrics { + mode: &'static str, + num_queries: usize, + top_k: usize, + total_results: usize, +} + +#[derive(Debug)] +struct SearchExecutionOutput { + response: SearchResponse, + metrics: SearchExecutionMetrics, +} /// Convert query embeddings from JSON or base64 format to ndarray. fn to_ndarray(query: &QueryEmbeddings) -> ApiResult> { - // Prefer base64 if provided (more efficient) if let (Some(b64), Some(shape)) = (&query.embeddings_b64, &query.shape) { let floats = crate::models::decode_b64_embeddings(b64, *shape).map_err(ApiError::BadRequest)?; - return Array2::from_shape_vec((shape[0], shape[1]), floats) - .map_err(|e| ApiError::BadRequest(format!("Failed to create query array: {}", e))); + return Array2::from_shape_vec((shape[0], shape[1]), floats).map_err(|error| { + ApiError::BadRequest(format!("Failed to create query array: {}", error)) + }); } - // Fall back to JSON array format let embeddings = query.embeddings.as_ref().ok_or_else(|| { ApiError::BadRequest( "Must provide either 'embeddings' or 'embeddings_b64' + 'shape'".to_string(), @@ -54,12 +108,11 @@ fn to_ndarray(query: &QueryEmbeddings) -> ApiResult> { )); } - // Verify all rows have the same dimension - for (i, row) in embeddings.iter().enumerate() { + for (index, row) in embeddings.iter().enumerate() { if row.len() != cols { return Err(ApiError::BadRequest(format!( "Inconsistent query embedding dimension at row {}: expected {}, got {}", - i, + index, cols, row.len() ))); @@ -68,74 +121,32 @@ fn to_ndarray(query: &QueryEmbeddings) -> ApiResult> { let flat: Vec = embeddings.iter().flatten().copied().collect(); Array2::from_shape_vec((rows, cols), flat) - .map_err(|e| ApiError::BadRequest(format!("Failed to create query array: {}", e))) + .map_err(|error| ApiError::BadRequest(format!("Failed to create query array: {}", error))) } -/// Fetch metadata for a list of document IDs. -/// Returns a Vec of Option in the same order as document_ids. -/// If metadata doesn't exist for an index or a specific document, returns None for that entry. -/// -/// # Errors -/// Returns an error if the metadata database exists but fails to query. -/// If no metadata database exists, returns Ok with None for all entries (not an error). -fn fetch_metadata_for_docs( - path_str: &str, - document_ids: &[i64], -) -> ApiResult>> { - if !filtering::exists(path_str) { - // No metadata database - return None for all (this is not an error) - return Ok(vec![None; document_ids.len()]); +fn decode_semantic_queries( + queries: Option>, +) -> ApiResult>>> { + match queries { + Some(queries) if !queries.is_empty() => queries + .iter() + .map(to_ndarray) + .collect::>>() + .map(Some), + _ => Ok(None), } - - // Fetch metadata for the document IDs - let metadata_list = filtering::get(path_str, None, &[], Some(document_ids)).map_err(|e| { - tracing::error!("Failed to fetch metadata from database: {}", e); - ApiError::Internal(format!("Failed to fetch metadata: {}", e)) - })?; - - // Build a map from _subset_ to metadata for quick lookup - let meta_map: HashMap = metadata_list - .into_iter() - .filter_map(|m| m.get("_subset_").and_then(|v| v.as_i64()).map(|id| (id, m))) - .collect(); - - // Map document_ids to their metadata (or None if not found) - Ok(document_ids - .iter() - .map(|doc_id| meta_map.get(doc_id).cloned()) - .collect()) } -/// Search an index with query embeddings. -#[utoipa::path( - post, - path = "/indices/{name}/search", - tag = "search", - params( - ("name" = String, Path, description = "Index name") - ), - request_body = SearchRequest, - responses( - (status = 200, description = "Search results", body = SearchResponse), - (status = 400, description = "Invalid request", body = ErrorResponse), - (status = 404, description = "Index not found", body = ErrorResponse) - ) -)] -pub async fn search( - State(state): State>, - Path(name): Path, - trace_id: Option>, - Json(req): Json, -) -> ApiResult> { - let trace_id = trace_id.map(|t| t.0).unwrap_or_default(); - let start = std::time::Instant::now(); - - let has_queries = req.queries.as_ref().map(|q| !q.is_empty()).unwrap_or(false); - let has_text_query = req - .text_query - .as_ref() - .map(|q| !q.is_empty()) - .unwrap_or(false); +fn build_search_request_config( + semantic_query_count: usize, + has_queries: bool, + text_queries: &[String], + params: &SearchParamsRequest, + alpha: Option, + fusion: Option<&str>, + default_top_k: usize, +) -> ApiResult { + let has_text_query = !text_queries.is_empty(); if !has_queries && !has_text_query { return Err(ApiError::BadRequest( @@ -144,273 +155,559 @@ pub async fn search( )); } - let alpha = req.alpha.unwrap_or(0.75); - if !(0.0..=1.0).contains(&alpha) { + let alpha_value = alpha.unwrap_or(0.75); + if !(0.0..=1.0).contains(&alpha_value) { return Err(ApiError::BadRequest( "alpha must be between 0.0 and 1.0".to_string(), )); } - let fusion_mode = req.fusion.as_deref().unwrap_or("rrf"); - if fusion_mode != "rrf" && fusion_mode != "relative_score" { + let fusion_mode = match fusion.unwrap_or("rrf") { + "rrf" => FusionMode::Rrf, + "relative_score" => FusionMode::RelativeScore, + _ => { + return Err(ApiError::BadRequest( + "fusion must be 'rrf' or 'relative_score'".to_string(), + )); + } + }; + + if has_queries && has_text_query && semantic_query_count != text_queries.len() { + return Err(ApiError::BadRequest(format!( + "queries length ({}) must match text_query length ({}) in hybrid mode", + semantic_query_count, + text_queries.len() + ))); + } + + let top_k = params.top_k.unwrap_or(default_top_k); + if top_k == 0 { + return Err(ApiError::BadRequest( + "top_k must be greater than 0".to_string(), + )); + } + + let n_ivf_probe = params.n_ivf_probe.unwrap_or(8); + if n_ivf_probe == 0 { return Err(ApiError::BadRequest( - "fusion must be 'rrf' or 'relative_score'".to_string(), + "n_ivf_probe must be greater than 0".to_string(), )); } - // Hybrid mode: text_query is a single string, so queries must have exactly 1 element - if has_queries && has_text_query { - let queries_len = req.queries.as_ref().unwrap().len(); - if queries_len != 1 { - return Err(ApiError::BadRequest(format!( - "Hybrid search requires exactly 1 query embedding (got {}). \ - text_query is a single string and can only fuse with one semantic query.", - queries_len - ))); - } + let n_full_scores = params.n_full_scores.unwrap_or(4096); + if n_full_scores == 0 { + return Err(ApiError::BadRequest( + "n_full_scores must be greater than 0".to_string(), + )); + } + + let fetch_k = if has_queries && has_text_query { + top_k.checked_mul(3).ok_or_else(|| { + ApiError::BadRequest("hybrid search fetch_k overflowed top_k * 3".to_string()) + })? + } else { + top_k + }; + + if has_queries && has_text_query && n_full_scores < fetch_k { + return Err(ApiError::BadRequest(format!( + "hybrid search requires n_full_scores ({}) to be greater than or equal to fetch_k ({})", + n_full_scores, fetch_k + ))); } - let top_k = req.params.top_k.unwrap_or(state.config.default_top_k); - let path_str = state.index_path(&name).to_string_lossy().to_string(); + Ok(PreparedSearchConfig { + top_k, + fetch_k, + n_ivf_probe, + n_full_scores, + centroid_score_threshold: params.centroid_score_threshold.unwrap_or_default(), + alpha: alpha_value, + fusion_mode, + }) +} + +fn build_prepared_search_request( + input: PreparedSearchRequestInput, +) -> ApiResult { + let text_queries = input.text_query.unwrap_or_default(); + let has_queries = input + .semantic_queries + .as_ref() + .map(|queries| !queries.is_empty()) + .unwrap_or(false); + + let config = build_search_request_config( + input + .semantic_queries + .as_ref() + .map(|queries| queries.len()) + .unwrap_or(0), + has_queries, + &text_queries, + &input.params, + input.alpha, + input.fusion.as_deref(), + input.default_top_k, + )?; + + Ok(PreparedSearchRequest { + semantic_queries: input.semantic_queries, + text_queries, + subset: input.subset, + filter_condition: input.filter_condition, + filter_parameters: input.filter_parameters.unwrap_or_default(), + config, + }) +} + +fn build_search_params(config: PreparedSearchConfig, top_k: usize) -> ApiResult { + if config.n_full_scores < top_k { + return Err(ApiError::BadRequest(format!( + "n_full_scores ({}) must be greater than or equal to top_k ({})", + config.n_full_scores, top_k + ))); + } - // Resolve filter condition to subset - let mut subset = req.subset.clone(); - if let Some(ref condition) = req.filter_condition { - if !filtering::exists(&path_str) { - return Err(ApiError::MetadataNotFound(name.clone())); + Ok(SearchParameters { + top_k, + n_ivf_probe: config.n_ivf_probe, + n_full_scores: config.n_full_scores, + batch_size: 2000, + centroid_score_threshold: config.centroid_score_threshold, + ..Default::default() + }) +} + +fn validate_query_dimensions(queries: &[Array2], expected_dim: usize) -> ApiResult<()> { + for query in queries { + if query.ncols() != expected_dim { + return Err(ApiError::DimensionMismatch { + expected: expected_dim, + actual: query.ncols(), + }); } - let filter_params = req.filter_parameters.as_deref().unwrap_or(&[]); - let filtered_ids = filtering::where_condition(&path_str, condition, filter_params) - .map_err(|e| ApiError::BadRequest(format!("Invalid filter condition: {}", e)))?; - subset = Some(filtered_ids); } - // --- Pure semantic search (preserves batch query support) --- - if has_queries && !has_text_query { - let queries_vec = req.queries.as_ref().unwrap(); - let queries: Vec> = queries_vec - .iter() - .map(to_ndarray) - .collect::>>()?; - - let idx = state.get_index_for_read(&name)?; - let expected_dim = idx.embedding_dim(); - for query in queries.iter() { - if query.ncols() != expected_dim { - return Err(ApiError::DimensionMismatch { - expected: expected_dim, - actual: query.ncols(), - }); + Ok(()) +} + +fn fetch_metadata_for_docs( + path_str: &str, + document_ids: &[i64], +) -> ApiResult>> { + if !filtering::exists(path_str) { + return Ok(vec![None; document_ids.len()]); + } + + let metadata_list = + filtering::get(path_str, None, &[], Some(document_ids)).map_err(|error| { + tracing::error!("Failed to fetch metadata from database: {}", error); + ApiError::Internal(format!("Failed to fetch metadata: {}", error)) + })?; + + let metadata_map: HashMap = metadata_list + .into_iter() + .filter_map(|metadata| { + metadata + .get("_subset_") + .and_then(|value| value.as_i64()) + .map(|doc_id| (doc_id, metadata)) + }) + .collect(); + + Ok(document_ids + .iter() + .map(|document_id| metadata_map.get(document_id).cloned()) + .collect()) +} + +fn is_filter_request_error(message: &str) -> bool { + message.contains("SQL comments are not allowed") + || message.contains("Semicolons are not allowed") + || message.contains("is not allowed in conditions") + || message.contains("Unterminated quoted identifier") + || message.contains("Unexpected character") + || message.contains("Expected ") + || message.contains("Unexpected token") + || message.contains("Unknown column") + || message.contains("REGEXP requires a pattern parameter") + || message.contains("Invalid regex pattern") + || message.contains("Invalid parameter count") + || message.contains("Invalid parameter name") + || message.contains("syntax error") +} + +fn map_filter_error(index_name: &str, error: next_plaid::Error) -> ApiError { + let message = error.to_string(); + if message.contains("No metadata database found") { + ApiError::MetadataNotFound(index_name.to_string()) + } else if is_filter_request_error(&message) { + ApiError::BadRequest(format!("Invalid filter condition: {}", message)) + } else { + ApiError::Internal(format!("Filter resolution failed: {}", message)) + } +} + +fn is_keyword_request_error(message: &str) -> bool { + message.contains("syntax error") + || message.contains("malformed MATCH expression") + || message.contains("unterminated string") + || message.contains("Failed to prepare FTS5 query") + || message.contains("FTS5 query failed") +} + +fn map_keyword_search_error(index_name: &str, error: next_plaid::Error) -> ApiError { + let message = error.to_string(); + if message.contains("No metadata database found") || message.contains("FTS5 index not found") { + ApiError::MetadataNotFound(index_name.to_string()) + } else if is_keyword_request_error(&message) { + ApiError::BadRequest(format!("Invalid keyword query: {}", message)) + } else { + ApiError::Internal(format!("Keyword search failed: {}", message)) + } +} + +fn intersect_subset_with_filtered_ids(subset: Vec, filtered_ids: Vec) -> Vec { + let allowed_ids: HashSet = filtered_ids.into_iter().collect(); + let mut seen_ids: HashSet = HashSet::new(); + subset + .into_iter() + .filter(|document_id| allowed_ids.contains(document_id) && seen_ids.insert(*document_id)) + .collect() +} + +fn resolve_filtered_subset( + index_name: &str, + path_str: &str, + subset: Option>, + filter_condition: Option<&str>, + filter_parameters: &[serde_json::Value], +) -> ApiResult>> { + match filter_condition { + Some(condition) => { + if !filtering::exists(path_str) { + return Err(ApiError::MetadataNotFound(index_name.to_string())); + } + let filtered_ids = filtering::where_condition(path_str, condition, filter_parameters) + .map_err(|error| map_filter_error(index_name, error))?; + match subset { + Some(subset_ids) => Ok(Some(intersect_subset_with_filtered_ids( + subset_ids, + filtered_ids, + ))), + None => Ok(Some(filtered_ids)), } } + None => Ok(subset), + } +} - let params = SearchParameters { - top_k, - n_ivf_probe: req.params.n_ivf_probe.unwrap_or(8), - n_full_scores: req.params.n_full_scores.unwrap_or(4096), - batch_size: 2000, - centroid_score_threshold: req.params.centroid_score_threshold.unwrap_or_default(), - ..Default::default() - }; +async fn execute_search_prepared( + state: Arc, + name: String, + prepared: PreparedSearchRequest, +) -> ApiResult { + task::spawn_blocking(move || execute_search_prepared_blocking(&state, &name, prepared)) + .await + .map_err(|error| ApiError::Internal(format!("Search task failed: {}", error)))? +} - let index = &**idx; - let raw_results: Vec<(usize, Vec, Vec)> = if queries.len() == 1 { - let r = index.search(&queries[0], ¶ms, subset.as_deref())?; - vec![(r.query_id, r.passage_ids, r.scores)] - } else { - let batch = index.search_batch(&queries, ¶ms, true, subset.as_deref())?; - batch - .into_iter() - .map(|r| (r.query_id, r.passage_ids, r.scores)) - .collect() - }; +fn execute_search_prepared_blocking( + state: &AppState, + name: &str, + prepared: PreparedSearchRequest, +) -> ApiResult { + let path_str = state.index_path(name).to_string_lossy().to_string(); + let subset = resolve_filtered_subset( + name, + &path_str, + prepared.subset, + prepared.filter_condition.as_deref(), + &prepared.filter_parameters, + )?; + + let has_queries = prepared + .semantic_queries + .as_ref() + .map(|queries| !queries.is_empty()) + .unwrap_or(false); + let has_text_query = !prepared.text_queries.is_empty(); + + if has_queries && !has_text_query { + return execute_semantic_search( + state, + name, + &path_str, + prepared + .semantic_queries + .expect("validated semantic queries"), + subset, + prepared.config, + ); + } + + execute_keyword_or_hybrid_search( + state, + name, + &path_str, + prepared.semantic_queries, + prepared.text_queries, + subset, + prepared.config, + ) +} - let total_results: usize = raw_results.iter().map(|(_, ids, _)| ids.len()).sum(); - let results: Vec = raw_results +fn execute_semantic_search( + state: &AppState, + name: &str, + path_str: &str, + semantic_queries: Vec>, + subset: Option>, + config: PreparedSearchConfig, +) -> ApiResult { + let index = state.get_index_for_read(name)?; + validate_query_dimensions(&semantic_queries, index.embedding_dim())?; + let params = build_search_params(config, config.top_k)?; + + let raw_results: Vec<(usize, Vec, Vec)> = if semantic_queries.len() == 1 { + let result = index.search(&semantic_queries[0], ¶ms, subset.as_deref())?; + vec![(result.query_id, result.passage_ids, result.scores)] + } else { + index + .search_batch(&semantic_queries, ¶ms, true, subset.as_deref())? .into_iter() - .map(|(query_id, document_ids, scores)| { - let metadata = fetch_metadata_for_docs(&path_str, &document_ids)?; - Ok(QueryResultResponse { - query_id, - document_ids, - scores, - metadata, - }) - }) - .collect::>>()?; + .map(|result| (result.query_id, result.passage_ids, result.scores)) + .collect() + }; - let total_ms = start.elapsed().as_millis() as u64; - tracing::info!( - trace_id = %trace_id, - index = %name, - mode = "semantic", - num_queries = queries.len(), - top_k = top_k, - total_results = total_results, - total_ms = total_ms, - "search.complete" - ); - if total_ms > 1000 { - tracing::warn!(trace_id = %trace_id, index = %name, total_ms = total_ms, "search.slow"); - } + let total_results: usize = raw_results.iter().map(|(_, ids, _)| ids.len()).sum(); + let results = raw_results + .into_iter() + .map(|(query_id, document_ids, scores)| { + let metadata = fetch_metadata_for_docs(path_str, &document_ids)?; + Ok(QueryResultResponse { + query_id, + document_ids, + scores, + metadata, + }) + }) + .collect::>>()?; - return Ok(PrettyJson(SearchResponse { - num_queries: queries.len(), + Ok(SearchExecutionOutput { + response: SearchResponse { + num_queries: semantic_queries.len(), results, - })); - } + }, + metrics: SearchExecutionMetrics { + mode: "semantic", + num_queries: semantic_queries.len(), + top_k: config.top_k, + total_results, + }, + }) +} - // --- Keyword or hybrid search (supports batch) --- - let empty_text: Vec = vec![]; - let text_queries = req.text_query.as_ref().unwrap_or(&empty_text); - let embedding_queries = req.queries.as_ref(); +fn execute_keyword_or_hybrid_search( + state: &AppState, + name: &str, + path_str: &str, + semantic_queries: Option>>, + text_queries: Vec, + subset: Option>, + config: PreparedSearchConfig, +) -> ApiResult { + let has_queries = semantic_queries + .as_ref() + .map(|queries| !queries.is_empty()) + .unwrap_or(false); + let has_text_query = !text_queries.is_empty(); - // Validate: in hybrid mode, queries and text_query must have the same length - if has_queries && has_text_query { - let n_emb = embedding_queries.unwrap().len(); - let n_txt = text_queries.len(); - if n_emb != n_txt { - return Err(ApiError::BadRequest(format!( - "queries length ({}) must match text_query length ({}) in hybrid mode", - n_emb, n_txt - ))); - } + if has_text_query && !filtering::exists(path_str) { + return Err(ApiError::MetadataNotFound(name.to_string())); } - let num_queries = if has_text_query { - text_queries.len() + let semantic_index = if let Some(queries) = semantic_queries.as_ref() { + let index = state.get_index_for_read(name)?; + validate_query_dimensions(queries, index.embedding_dim())?; + Some(index) + } else { + None + }; + let semantic_params = if semantic_queries.is_some() { + Some(build_search_params(config, config.fetch_k)?) } else { - embedding_queries.map(|q| q.len()).unwrap_or(0) + None }; - let fetch_k = if has_queries && has_text_query { - top_k * 3 + let num_queries = if has_text_query { + text_queries.len() } else { - top_k + semantic_queries + .as_ref() + .map(|queries| queries.len()) + .unwrap_or(0) }; - // Process each query - let mut all_results: Vec = Vec::with_capacity(num_queries); - - #[allow(clippy::needless_range_loop)] - for i in 0..num_queries { - // Semantic component for this query - let semantic: Option<(Vec, Vec)> = if has_queries { - let query = to_ndarray(&embedding_queries.unwrap()[i])?; - let idx = state.get_index_for_read(&name)?; - let expected_dim = idx.embedding_dim(); - if query.ncols() != expected_dim { - return Err(ApiError::DimensionMismatch { - expected: expected_dim, - actual: query.ncols(), - }); - } - let params = SearchParameters { - top_k: fetch_k, - n_ivf_probe: req.params.n_ivf_probe.unwrap_or(8), - n_full_scores: req.params.n_full_scores.unwrap_or(4096), - batch_size: 2000, - centroid_score_threshold: req.params.centroid_score_threshold.unwrap_or_default(), - ..Default::default() - }; - let r = idx.search(&query, ¶ms, subset.as_deref())?; - Some((r.passage_ids, r.scores)) + let mut results = Vec::with_capacity(num_queries); + for query_id in 0..num_queries { + let semantic = if let Some(queries) = semantic_queries.as_ref() { + let result = semantic_index + .as_ref() + .expect("semantic index present") + .search( + &queries[query_id], + semantic_params.as_ref().expect("semantic params present"), + subset.as_deref(), + )?; + Some((result.passage_ids, result.scores)) } else { None }; - // Keyword component for this query - let keyword: Option<(Vec, Vec)> = if has_text_query { - let tq = &text_queries[i]; - let result = if let Some(ref sub) = subset { - text_search::search_filtered(&path_str, tq, fetch_k, sub) + let keyword = if has_text_query { + let result = if let Some(ref subset_ids) = subset { + text_search::search_filtered( + path_str, + &text_queries[query_id], + config.fetch_k, + subset_ids, + ) } else { - text_search::search(&path_str, tq, fetch_k) + text_search::search(path_str, &text_queries[query_id], config.fetch_k) }; - match result { - Ok(r) => Some((r.passage_ids, r.scores)), - Err(e) => { - tracing::warn!(trace_id = %trace_id, index = %name, error = %e, "search.keyword.failed"); - None - } - } + Some( + result + .map(|value| (value.passage_ids, value.scores)) + .map_err(|error| map_keyword_search_error(name, error))?, + ) } else { None }; - // Fuse let (document_ids, scores) = match (semantic, keyword) { - (Some((sem_ids, sem_scores)), Some((kw_ids, kw_scores))) => match fusion_mode { - "relative_score" => text_search::fuse_relative_score( - &sem_ids, - &sem_scores, - &kw_ids, - &kw_scores, - alpha, - top_k, - ), - _ => text_search::fuse_rrf(&sem_ids, &kw_ids, alpha, top_k), - }, - (Some((ids, scores)), None) => { - let mut r: Vec<(i64, f32)> = ids.into_iter().zip(scores).collect(); - r.truncate(top_k); - ( - r.iter().map(|x| x.0).collect(), - r.iter().map(|x| x.1).collect(), - ) + (Some((semantic_ids, semantic_scores)), Some((keyword_ids, keyword_scores))) => { + match config.fusion_mode { + FusionMode::RelativeScore => text_search::fuse_relative_score( + &semantic_ids, + &semantic_scores, + &keyword_ids, + &keyword_scores, + config.alpha, + config.top_k, + ), + FusionMode::Rrf => text_search::fuse_rrf( + &semantic_ids, + &keyword_ids, + config.alpha, + config.top_k, + ), + } } - (None, Some((ids, scores))) => { - let mut r: Vec<(i64, f32)> = ids.into_iter().zip(scores).collect(); - r.truncate(top_k); + (Some((ids, scores)), None) | (None, Some((ids, scores))) => { + let mut ranked: Vec<(i64, f32)> = ids.into_iter().zip(scores).collect(); + ranked.truncate(config.top_k); ( - r.iter().map(|x| x.0).collect(), - r.iter().map(|x| x.1).collect(), + ranked.iter().map(|(document_id, _)| *document_id).collect(), + ranked.iter().map(|(_, score)| *score).collect(), ) } - (None, None) => (vec![], vec![]), + (None, None) => (Vec::new(), Vec::new()), }; - let metadata = fetch_metadata_for_docs(&path_str, &document_ids)?; - all_results.push(QueryResultResponse { - query_id: i, + let metadata = fetch_metadata_for_docs(path_str, &document_ids)?; + results.push(QueryResultResponse { + query_id, document_ids, scores, metadata, }); } - let total_results: usize = all_results.iter().map(|r| r.document_ids.len()).sum(); - let total_ms = start.elapsed().as_millis() as u64; - + let total_results: usize = results.iter().map(|result| result.document_ids.len()).sum(); let mode = if has_queries && has_text_query { "hybrid" } else { "keyword" }; + Ok(SearchExecutionOutput { + response: SearchResponse { + num_queries, + results, + }, + metrics: SearchExecutionMetrics { + mode, + num_queries, + top_k: config.top_k, + total_results, + }, + }) +} + +/// Search an index with query embeddings. +#[utoipa::path( + post, + path = "/indices/{name}/search", + tag = "search", + params( + ("name" = String, Path, description = "Index name") + ), + request_body = SearchRequest, + responses( + (status = 200, description = "Search results", body = SearchResponse), + (status = 400, description = "Invalid request", body = ErrorResponse), + (status = 404, description = "Index not found", body = ErrorResponse) + ) +)] +pub async fn search( + State(state): State>, + Path(name): Path, + trace_id: Option>, + Json(req): Json, +) -> ApiResult> { + let trace_id_value = trace_id.map(|value| value.0).unwrap_or_default(); + let start = StdInstant::now(); + + let semantic_queries = decode_semantic_queries(req.queries)?; + let prepared = build_prepared_search_request(PreparedSearchRequestInput { + semantic_queries, + params: req.params, + subset: req.subset, + text_query: req.text_query, + alpha: req.alpha, + fusion: req.fusion, + filter_condition: req.filter_condition, + filter_parameters: req.filter_parameters, + default_top_k: state.config.default_top_k, + })?; + let output = execute_search_prepared(state.clone(), name.clone(), prepared).await?; + + let total_ms = start.elapsed().as_millis() as u64; tracing::info!( - trace_id = %trace_id, + trace_id = %trace_id_value, index = %name, - mode = mode, - num_queries = num_queries, - top_k = top_k, - total_results = total_results, + mode = output.metrics.mode, + num_queries = output.metrics.num_queries, + top_k = output.metrics.top_k, + total_results = output.metrics.total_results, total_ms = total_ms, "search.complete" ); if total_ms > 1000 { - tracing::warn!(trace_id = %trace_id, index = %name, total_ms = total_ms, "search.slow"); + tracing::warn!( + trace_id = %trace_id_value, + index = %name, + total_ms = total_ms, + "search.slow" + ); } - Ok(PrettyJson(SearchResponse { - num_queries, - results: all_results, - })) + Ok(PrettyJson(output.response)) } /// Search with a pre-filtered subset from metadata query. -/// -/// This is a convenience endpoint that combines metadata filtering and search. #[utoipa::path( post, path = "/indices/{name}/search/filtered", @@ -435,8 +732,7 @@ pub async fn search_filtered( return Err(ApiError::BadRequest("No queries provided".to_string())); } - // Convert to unified SearchRequest with filter_condition - let search_req = SearchRequest { + let search_request = SearchRequest { queries: Some(req.queries), params: req.params, subset: None, @@ -447,13 +743,10 @@ pub async fn search_filtered( filter_parameters: Some(req.filter_parameters), }; - search(State(state), Path(name), trace_id, Json(search_req)).await + search(State(state), Path(name), trace_id, Json(search_request)).await } /// Search an index using text queries (requires model to be loaded). -/// -/// This endpoint encodes the text queries using the loaded model and then performs a search. -/// Requires the server to be started with `--model `. #[utoipa::path( post, path = "/indices/{name}/search_with_encoding", @@ -474,34 +767,24 @@ pub async fn search_with_encoding( trace_id: Option>, Json(req): Json, ) -> ApiResult> { - let trace_id_val = trace_id.as_ref().map(|t| t.0.clone()).unwrap_or_default(); - let start = std::time::Instant::now(); + let trace_id_value = trace_id + .as_ref() + .map(|value| value.0.clone()) + .unwrap_or_default(); + let start = StdInstant::now(); if req.queries.is_empty() { return Err(ApiError::BadRequest("No queries provided".to_string())); } let num_queries = req.queries.len(); - - // Encode the text queries (async, uses batch queue) - let encode_start = std::time::Instant::now(); + let encode_start = StdInstant::now(); let query_embeddings = encode_texts_internal(state.clone(), &req.queries, InputType::Query, None).await?; let encode_ms = encode_start.elapsed().as_millis() as u64; - // Convert to QueryEmbeddings format - let queries: Vec = query_embeddings - .into_iter() - .map(|arr| QueryEmbeddings { - embeddings: Some(arr.rows().into_iter().map(|r| r.to_vec()).collect()), - embeddings_b64: None, - shape: None, - }) - .collect(); - - // Create a standard SearchRequest (pass through hybrid fields) - let search_req = SearchRequest { - queries: Some(queries), + let prepared = build_prepared_search_request(PreparedSearchRequestInput { + semantic_queries: Some(query_embeddings), params: req.params, subset: req.subset, text_query: req.text_query, @@ -509,15 +792,15 @@ pub async fn search_with_encoding( fusion: req.fusion, filter_condition: None, filter_parameters: None, - }; - - // Delegate to the standard search - let result = search(State(state), Path(name.clone()), trace_id, Json(search_req)).await; + default_top_k: state.config.default_top_k, + })?; + let response = execute_search_prepared(state.clone(), name.clone(), prepared) + .await? + .response; let total_ms = start.elapsed().as_millis() as u64; - tracing::info!( - trace_id = %trace_id_val, + trace_id = %trace_id_value, index = %name, num_queries = num_queries, encode_ms = encode_ms, @@ -525,13 +808,10 @@ pub async fn search_with_encoding( "search.with_encoding.complete" ); - result + Ok(PrettyJson(response)) } /// Search with text queries and a metadata filter (requires model to be loaded). -/// -/// This endpoint encodes the text queries using the loaded model and performs a filtered search. -/// Requires the server to be started with `--model `. #[utoipa::path( post, path = "/indices/{name}/search/filtered_with_encoding", @@ -552,34 +832,24 @@ pub async fn search_filtered_with_encoding( trace_id: Option>, Json(req): Json, ) -> ApiResult> { - let trace_id_val = trace_id.as_ref().map(|t| t.0.clone()).unwrap_or_default(); - let start = std::time::Instant::now(); + let trace_id_value = trace_id + .as_ref() + .map(|value| value.0.clone()) + .unwrap_or_default(); + let start = StdInstant::now(); if req.queries.is_empty() { return Err(ApiError::BadRequest("No queries provided".to_string())); } let num_queries = req.queries.len(); - - // Encode the text queries (async, uses batch queue) - let encode_start = std::time::Instant::now(); + let encode_start = StdInstant::now(); let query_embeddings = encode_texts_internal(state.clone(), &req.queries, InputType::Query, None).await?; let encode_ms = encode_start.elapsed().as_millis() as u64; - // Convert to QueryEmbeddings format - let queries: Vec = query_embeddings - .into_iter() - .map(|arr| QueryEmbeddings { - embeddings: Some(arr.rows().into_iter().map(|r| r.to_vec()).collect()), - embeddings_b64: None, - shape: None, - }) - .collect(); - - // Create a unified SearchRequest with filter (pass through hybrid fields) - let search_req = SearchRequest { - queries: Some(queries), + let prepared = build_prepared_search_request(PreparedSearchRequestInput { + semantic_queries: Some(query_embeddings), params: req.params, subset: None, text_query: req.text_query, @@ -587,15 +857,15 @@ pub async fn search_filtered_with_encoding( fusion: req.fusion, filter_condition: Some(req.filter_condition.clone()), filter_parameters: Some(req.filter_parameters), - }; - - // Delegate to the unified search handler - let result = search(State(state), Path(name.clone()), trace_id, Json(search_req)).await; + default_top_k: state.config.default_top_k, + })?; + let response = execute_search_prepared(state.clone(), name.clone(), prepared) + .await? + .response; let total_ms = start.elapsed().as_millis() as u64; - tracing::info!( - trace_id = %trace_id_val, + trace_id = %trace_id_value, index = %name, num_queries = num_queries, filter = %req.filter_condition, @@ -604,5 +874,86 @@ pub async fn search_filtered_with_encoding( "search.filtered_with_encoding.complete" ); - result + Ok(PrettyJson(response)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn intersect_subset_with_filtered_ids_preserves_subset_order() { + let actual = intersect_subset_with_filtered_ids(vec![8, 3, 8, 2, 9], vec![2, 3, 8]); + + assert_eq!(actual, vec![8, 3, 2]); + } + + #[test] + fn build_search_request_config_rejects_zero_top_k() { + let error = build_search_request_config( + 1, + true, + &[], + &SearchParamsRequest { + top_k: Some(0), + n_ivf_probe: None, + n_full_scores: None, + centroid_score_threshold: None, + }, + None, + None, + 10, + ) + .expect_err("top_k=0 should be rejected"); + + match error { + ApiError::BadRequest(message) => { + assert!(message.contains("top_k must be greater than 0")); + } + other => panic!("unexpected error: {}", other), + } + } + + #[test] + fn build_search_request_config_rejects_hybrid_when_n_full_scores_below_fetch_k() { + let error = build_search_request_config( + 1, + true, + &["keyword".to_string()], + &SearchParamsRequest { + top_k: Some(10), + n_ivf_probe: Some(8), + n_full_scores: Some(20), + centroid_score_threshold: None, + }, + Some(0.75), + Some("rrf"), + 10, + ) + .expect_err("hybrid search should reject n_full_scores below fetch_k"); + + match error { + ApiError::BadRequest(message) => { + assert!(message.contains("hybrid search requires n_full_scores")); + } + other => panic!("unexpected error: {}", other), + } + } + + #[test] + fn map_keyword_search_error_marks_malformed_match_as_bad_request() { + let error = map_keyword_search_error( + "test", + next_plaid::Error::Filtering( + "FTS5 query failed: malformed MATCH expression: [\"]".to_string(), + ), + ); + + match error { + ApiError::BadRequest(message) => { + assert!(message.contains("Invalid keyword query")); + } + other => panic!("unexpected error: {}", other), + } + } } diff --git a/next-plaid-api/src/main.rs b/next-plaid-api/src/main.rs index d15f1afc..209a0573 100644 --- a/next-plaid-api/src/main.rs +++ b/next-plaid-api/src/main.rs @@ -502,6 +502,106 @@ fn build_router(state: Arc) -> Router { .merge(api_router) } +#[cfg(feature = "model")] +fn execution_provider_name(use_cuda: bool) -> &'static str { + if use_cuda { + "cuda" + } else { + "cpu" + } +} + +#[cfg(feature = "model")] +fn build_model_instance( + model_config: &state::ModelConfig, +) -> Result { + let execution_provider = if model_config.use_cuda { + next_plaid_onnx::ExecutionProvider::Cuda + } else { + next_plaid_onnx::ExecutionProvider::Cpu + }; + + let mut builder = next_plaid_onnx::Colbert::builder(&model_config.path) + .with_execution_provider(execution_provider) + .with_quantized(model_config.use_int8); + + if let Some(value) = model_config.parallel_sessions { + builder = builder.with_parallel(value); + } + if let Some(value) = model_config.batch_size { + builder = builder.with_batch_size(value); + } + if let Some(value) = model_config.threads { + builder = builder.with_threads(value); + } + if let Some(value) = model_config.query_length { + builder = builder.with_query_length(value); + } + if let Some(value) = model_config.document_length { + builder = builder.with_document_length(value); + } + + builder.build().map_err(|error| error.to_string()) +} + +#[cfg(feature = "model")] +fn log_loaded_model( + model_path: &std::path::Path, + model: &next_plaid_onnx::Colbert, + use_cuda: bool, + use_int8: bool, +) { + let config = model.config(); + tracing::info!( + model_path = %model_path.display(), + model_name = ?config.model_name(), + execution_provider = execution_provider_name(use_cuda), + quantized = use_int8, + embedding_dim = model.embedding_dim(), + batch_size = model.batch_size(), + num_sessions = model.num_sessions(), + query_length = config.query_length, + document_length = config.document_length, + query_expansion = config.do_query_expansion, + "model.load.complete" + ); +} + +#[cfg(feature = "model")] +fn build_model_pool( + model: next_plaid_onnx::Colbert, + model_config: state::ModelConfig, + pool_size: usize, +) -> state::ModelPool { + let loaded_model_config = model.config(); + let cached_info = state::CachedModelInfo { + name: loaded_model_config + .model_name() + .map(|value| value.to_string()), + path: model_config.path.to_string_lossy().to_string(), + quantized: model_config.use_int8, + embedding_dim: model.embedding_dim(), + batch_size: model.batch_size(), + num_sessions: model.num_sessions(), + query_prefix: loaded_model_config.query_prefix.clone(), + document_prefix: loaded_model_config.document_prefix.clone(), + query_length: loaded_model_config.query_length, + document_length: loaded_model_config.document_length, + do_query_expansion: loaded_model_config.do_query_expansion, + uses_token_type_ids: loaded_model_config.uses_token_type_ids, + mask_token_id: loaded_model_config.mask_token_id, + pad_token_id: loaded_model_config.pad_token_id, + }; + + drop(model); + + state::ModelPool { + pool_size, + model_config, + cached_info, + } +} + #[tokio::main] async fn main() { // Initialize tracing @@ -521,6 +621,7 @@ async fn main() { let mut index_dir = PathBuf::from("./indices"); let mut model_path: Option = None; let mut _use_cuda = false; + let mut _query_on_cpu = false; let mut _use_int8 = false; let mut _parallel_sessions: Option = None; let mut _batch_size: Option = None; @@ -575,6 +676,10 @@ async fn main() { _use_cuda = true; i += 1; } + "--query-on-cpu" => { + _query_on_cpu = true; + i += 1; + } "--int8" => { _use_int8 = true; i += 1; @@ -663,6 +768,8 @@ Options: -d, --index-dir Directory for storing indices (default: ./indices) -m, --model Path to ONNX model directory for encoding (optional) --cuda Use CUDA for model inference (requires --model) + --query-on-cpu CUDA route only: run query/rerank encoding on CPU while + keeping ingest encoding on CUDA --int8 Use INT8 quantized model for faster inference (requires --model) --parallel Number of parallel ONNX sessions (default: 1) More sessions = more parallelism but also more memory. @@ -691,6 +798,8 @@ Examples: next-plaid-api -p 3000 -d /data/indices # Custom port and directory next-plaid-api --model ./models/colbert # Enable text encoding next-plaid-api --model ./models/colbert --cuda # Enable encoding with CUDA + next-plaid-api --model ./models/colbert --cuda --query-on-cpu + # Query/rerank on CPU, ingest on CUDA next-plaid-api --model ./models/colbert --int8 # Enable encoding with INT8 quantization next-plaid-api --model ./models/colbert --parallel 16 # 16 parallel sessions for high throughput next-plaid-api --model ./models/colbert --parallel 8 --batch-size 4 # Fine-tuned parallel config @@ -708,6 +817,15 @@ Examples: } } + if _query_on_cpu && !_use_cuda { + eprintln!("Error: --query-on-cpu requires --cuda"); + std::process::exit(1); + } + if _query_on_cpu && model_path.is_none() { + eprintln!("Error: --query-on-cpu requires --model "); + std::process::exit(1); + } + // Create config let config = ApiConfig { index_dir, @@ -719,69 +837,6 @@ Examples: "server.starting" ); - // Load model if specified - #[cfg(feature = "model")] - let model = if let Some(ref model_path) = model_path { - let execution_provider = if _use_cuda { - next_plaid_onnx::ExecutionProvider::Cuda - } else { - next_plaid_onnx::ExecutionProvider::Cpu - }; - - let mut builder = next_plaid_onnx::Colbert::builder(model_path) - .with_execution_provider(execution_provider) - .with_quantized(_use_int8); - - // Apply optional model configuration - if let Some(parallel) = _parallel_sessions { - builder = builder.with_parallel(parallel); - } - if let Some(batch_size) = _batch_size { - builder = builder.with_batch_size(batch_size); - } - if let Some(threads) = _threads { - builder = builder.with_threads(threads); - } - if let Some(query_length) = _query_length { - builder = builder.with_query_length(query_length); - } - if let Some(document_length) = _document_length { - builder = builder.with_document_length(document_length); - } - - match builder.build() { - Ok(model) => { - let cfg = model.config(); - tracing::info!( - model_path = %model_path.display(), - model_name = ?cfg.model_name(), - execution_provider = if _use_cuda { "cuda" } else { "cpu" }, - quantized = _use_int8, - embedding_dim = model.embedding_dim(), - batch_size = model.batch_size(), - num_sessions = model.num_sessions(), - query_length = cfg.query_length, - document_length = cfg.document_length, - query_expansion = cfg.do_query_expansion, - "model.load.complete" - ); - Some(model) - } - Err(e) => { - tracing::error!( - model_path = %model_path.display(), - error = %e, - "model.load.failed" - ); - eprintln!("Error: Failed to load model from {:?}: {}", model_path, e); - std::process::exit(1); - } - } - } else { - tracing::debug!("model.disabled"); - None - }; - // Create state #[cfg(feature = "model")] let state = { @@ -789,36 +844,11 @@ Examples: path: path.to_string_lossy().to_string(), quantized: _use_int8, }); + let pool_size = _model_pool_size.unwrap_or(1); - // Create model pool if model was loaded successfully - let model_pool = model.map(|m| { - let model_cfg = m.config(); - let pool_size = _model_pool_size.unwrap_or(1); - - // Create cached model info for lock-free health endpoint access - let cached_info = state::CachedModelInfo { - name: model_cfg.model_name().map(|s| s.to_string()), - path: model_path - .as_ref() - .map(|p| p.to_string_lossy().to_string()) - .unwrap_or_default(), - quantized: _use_int8, - embedding_dim: m.embedding_dim(), - batch_size: m.batch_size(), - num_sessions: m.num_sessions(), - query_prefix: model_cfg.query_prefix.clone(), - document_prefix: model_cfg.document_prefix.clone(), - query_length: model_cfg.query_length, - document_length: model_cfg.document_length, - do_query_expansion: model_cfg.do_query_expansion, - uses_token_type_ids: model_cfg.uses_token_type_ids, - mask_token_id: model_cfg.mask_token_id, - pad_token_id: model_cfg.pad_token_id, - }; - - // Create model config for workers to build their own instances - let model_config = state::ModelConfig { - path: model_path.clone().unwrap(), + if let Some(ref loaded_model_path) = model_path { + let ingest_model_config = state::ModelConfig { + path: loaded_model_path.clone(), use_cuda: _use_cuda, use_int8: _use_int8, parallel_sessions: _parallel_sessions, @@ -827,18 +857,105 @@ Examples: query_length: _query_length, document_length: _document_length, }; + if _query_on_cpu { + let query_model_config = state::ModelConfig { + path: loaded_model_path.clone(), + use_cuda: false, + use_int8: _use_int8, + parallel_sessions: _parallel_sessions, + batch_size: _batch_size, + threads: _threads, + query_length: _query_length, + document_length: _document_length, + }; + let ingest_model = match build_model_instance(&state::ModelConfig { + use_cuda: true, + ..ingest_model_config.clone() + }) { + Ok(model) => model, + Err(error) => { + tracing::error!( + model_path = %loaded_model_path.display(), + lane = "ingest", + execution_provider = execution_provider_name(true), + error = %error, + "model.load.failed" + ); + eprintln!( + "Error: Failed to load ingest model from {:?}: {}", + loaded_model_path, error + ); + std::process::exit(1); + } + }; + log_loaded_model(loaded_model_path, &ingest_model, true, _use_int8); + + let query_model = match build_model_instance(&query_model_config) { + Ok(model) => model, + Err(error) => { + tracing::error!( + model_path = %loaded_model_path.display(), + lane = "query", + execution_provider = execution_provider_name(false), + error = %error, + "model.load.failed" + ); + eprintln!( + "Error: Failed to load query model from {:?}: {}", + loaded_model_path, error + ); + std::process::exit(1); + } + }; + log_loaded_model(loaded_model_path, &query_model, false, _use_int8); + + let ingest_pool = build_model_pool( + ingest_model, + state::ModelConfig { + use_cuda: true, + ..ingest_model_config.clone() + }, + pool_size, + ); + let query_pool = build_model_pool(query_model, query_model_config, pool_size); + + Arc::new(AppState::with_dual_model_pools( + config, + ingest_pool, + query_pool, + model_info, + )) + } else { + let model = match build_model_instance(&ingest_model_config) { + Ok(model) => model, + Err(error) => { + tracing::error!( + model_path = %loaded_model_path.display(), + execution_provider = execution_provider_name(_use_cuda), + error = %error, + "model.load.failed" + ); + eprintln!( + "Error: Failed to load model from {:?}: {}", + loaded_model_path, error + ); + std::process::exit(1); + } + }; + log_loaded_model(loaded_model_path, &model, _use_cuda, _use_int8); - // Drop the initial model - workers will create their own - drop(m); + let model_pool = build_model_pool(model, ingest_model_config, pool_size); - state::ModelPool { - pool_size, - model_config, - cached_info, + Arc::new(AppState::with_model_pool( + config, + Some(model_pool), + model_info, + )) } - }); - - Arc::new(AppState::with_model_pool(config, model_pool, model_info)) + } else { + tracing::debug!("model.disabled"); + Arc::new(AppState::with_model_pool(config, None, model_info)) + } }; #[cfg(not(feature = "model"))] diff --git a/next-plaid-api/src/state.rs b/next-plaid-api/src/state.rs index 960ee9dd..4e076f93 100644 --- a/next-plaid-api/src/state.rs +++ b/next-plaid-api/src/state.rs @@ -125,6 +125,32 @@ pub struct ModelPool { pub cached_info: CachedModelInfo, } +#[cfg(feature = "model")] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum EncodeLane { + Query, + Ingest, +} + +#[cfg(feature = "model")] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum EncodePoolKind { + Unified, + Query, + Ingest, +} + +#[cfg(feature = "model")] +impl EncodePoolKind { + pub fn as_str(self) -> &'static str { + match self { + Self::Unified => "unified", + Self::Query => "query", + Self::Ingest => "ingest", + } + } +} + /// Cached model information that doesn't require locking. /// This information is immutable after model initialization. #[cfg(feature = "model")] @@ -174,12 +200,129 @@ pub struct AppState { /// Optional model pool for concurrent encoding #[cfg(feature = "model")] pub model_pool: Option, + /// Optional dedicated pool for query/rerank encoding. + #[cfg(feature = "model")] + pub query_model_pool: Option, + /// Whether query/rerank should use the dedicated query pool. + #[cfg(feature = "model")] + pub query_on_cpu: bool, /// Model configuration info (path, quantization status) - for logging #[cfg(feature = "model")] #[allow(dead_code)] pub model_info: Option, } +#[cfg(all(test, feature = "model"))] +mod tests { + use tempfile::TempDir; + + use super::{ + ApiConfig, AppState, CachedModelInfo, EncodeLane, EncodePoolKind, ModelConfig, ModelPool, + }; + + fn dummy_model_pool(path: &str, use_cuda: bool, pool_size: usize) -> ModelPool { + ModelPool { + pool_size, + model_config: ModelConfig { + path: std::path::PathBuf::from(path), + use_cuda, + use_int8: false, + parallel_sessions: Some(1), + batch_size: Some(4), + threads: None, + query_length: Some(48), + document_length: Some(300), + }, + cached_info: CachedModelInfo { + name: Some("dummy".to_string()), + path: path.to_string(), + quantized: false, + embedding_dim: 48, + batch_size: 4, + num_sessions: 1, + query_prefix: "[Q] ".to_string(), + document_prefix: "[D] ".to_string(), + query_length: 48, + document_length: 300, + do_query_expansion: false, + uses_token_type_ids: false, + mask_token_id: 50284, + pad_token_id: 50284, + }, + } + } + + #[test] + fn query_on_cpu_routes_logical_lanes_to_distinct_pools() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let config = ApiConfig { + index_dir: temp_dir.path().to_path_buf(), + default_top_k: 10, + }; + + let state = AppState::with_dual_model_pools( + config, + dummy_model_pool("/models/ingest", true, 2), + dummy_model_pool("/models/query", false, 1), + None, + ); + + assert_eq!( + state.encode_pool_kind_for_lane(EncodeLane::Query), + EncodePoolKind::Query + ); + assert_eq!( + state.encode_pool_kind_for_lane(EncodeLane::Ingest), + EncodePoolKind::Ingest + ); + assert!( + !state + .model_pool_for_kind(EncodePoolKind::Query) + .expect("query pool must exist") + .model_config + .use_cuda + ); + assert!( + state + .model_pool_for_kind(EncodePoolKind::Ingest) + .expect("ingest pool must exist") + .model_config + .use_cuda + ); + } + + #[test] + fn single_model_pool_keeps_unified_encode_lane() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let config = ApiConfig { + index_dir: temp_dir.path().to_path_buf(), + default_top_k: 10, + }; + + let state = AppState::with_model_pool( + config, + Some(dummy_model_pool("/models/unified", true, 2)), + None, + ); + + assert_eq!( + state.encode_pool_kind_for_lane(EncodeLane::Query), + EncodePoolKind::Unified + ); + assert_eq!( + state.encode_pool_kind_for_lane(EncodeLane::Ingest), + EncodePoolKind::Unified + ); + assert!( + state + .model_pool_for_kind(EncodePoolKind::Unified) + .expect("unified pool must exist") + .model_config + .use_cuda + ); + } +} + impl AppState { /// Create a new application state (without model feature). #[cfg(not(feature = "model"))] @@ -213,6 +356,31 @@ impl AppState { indices: RwLock::new(HashMap::new()), index_configs: RwLock::new(HashMap::new()), model_pool, + query_model_pool: None, + query_on_cpu: false, + model_info, + } + } + + /// Create a new application state with dedicated ingest and query pools. + #[cfg(feature = "model")] + pub fn with_dual_model_pools( + config: ApiConfig, + ingest_model_pool: ModelPool, + query_model_pool: ModelPool, + model_info: Option, + ) -> Self { + if !config.index_dir.exists() { + std::fs::create_dir_all(&config.index_dir).ok(); + } + + Self { + config, + indices: RwLock::new(HashMap::new()), + index_configs: RwLock::new(HashMap::new()), + model_pool: Some(ingest_model_pool), + query_model_pool: Some(query_model_pool), + query_on_cpu: true, model_info, } } @@ -275,13 +443,39 @@ impl AppState { /// Check if model pool is available. #[cfg(feature = "model")] pub fn has_model(&self) -> bool { - self.model_pool.is_some() + self.model_pool.is_some() || self.query_model_pool.is_some() } /// Get cached model info if model pool is available. #[cfg(feature = "model")] pub fn cached_model_info(&self) -> Option<&CachedModelInfo> { - self.model_pool.as_ref().map(|p| &p.cached_info) + self.model_pool + .as_ref() + .or(self.query_model_pool.as_ref()) + .map(|pool| &pool.cached_info) + } + + /// Resolve the physical worker pool for a logical encoding lane. + #[cfg(feature = "model")] + pub fn encode_pool_kind_for_lane(&self, lane: EncodeLane) -> EncodePoolKind { + if self.query_on_cpu { + match lane { + EncodeLane::Query => EncodePoolKind::Query, + EncodeLane::Ingest => EncodePoolKind::Ingest, + } + } else { + EncodePoolKind::Unified + } + } + + /// Get the model pool for a physical worker pool. + #[cfg(feature = "model")] + pub fn model_pool_for_kind(&self, pool_kind: EncodePoolKind) -> Option<&ModelPool> { + match pool_kind { + EncodePoolKind::Unified => self.model_pool.as_ref(), + EncodePoolKind::Query => self.query_model_pool.as_ref().or(self.model_pool.as_ref()), + EncodePoolKind::Ingest => self.model_pool.as_ref(), + } } /// Get the path for an index by name. From 7dc824f92651ac30ad1984f86389ed90c87aa82d Mon Sep 17 00:00:00 2001 From: Anton Chen Date: Mon, 20 Apr 2026 19:28:49 +0800 Subject: [PATCH 2/2] Fix qml parser clippy warning --- colgrep/src/parser/qml.rs | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/colgrep/src/parser/qml.rs b/colgrep/src/parser/qml.rs index 053aa3b9..7fdb781b 100644 --- a/colgrep/src/parser/qml.rs +++ b/colgrep/src/parser/qml.rs @@ -432,17 +432,15 @@ fn extract_object_variables(node: Node, bytes: &[u8]) -> Vec { push_unique(&mut variables, name); } } - "ui_binding" => { - if field_text(child, "name", bytes).as_deref() == Some("id") { - if let Some(value) = child.child_by_field_name("value").and_then(|value| { - value - .utf8_text(bytes) - .ok() - .map(|text| text.trim().to_string()) - }) { - if is_simple_identifier(&value) { - push_unique(&mut variables, value); - } + "ui_binding" if field_text(child, "name", bytes).as_deref() == Some("id") => { + if let Some(value) = child.child_by_field_name("value").and_then(|value| { + value + .utf8_text(bytes) + .ok() + .map(|text| text.trim().to_string()) + }) { + if is_simple_identifier(&value) { + push_unique(&mut variables, value); } } }