Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions colgrep/src/parser/qml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,17 +432,15 @@ fn extract_object_variables(node: Node, bytes: &[u8]) -> Vec<String> {
push_unique(&mut variables, name);
}
}
"ui_binding" => {
if field_text(child, "name", bytes).as_deref() == Some("id") {
if let Some(value) = child.child_by_field_name("value").and_then(|value| {
value
.utf8_text(bytes)
.ok()
.map(|text| text.trim().to_string())
}) {
if is_simple_identifier(&value) {
push_unique(&mut variables, value);
}
"ui_binding" if field_text(child, "name", bytes).as_deref() == Some("id") => {
if let Some(value) = child.child_by_field_name("value").and_then(|value| {
value
.utf8_text(bytes)
.ok()
.map(|text| text.trim().to_string())
}) {
if is_simple_identifier(&value) {
push_unique(&mut variables, value);
}
}
}
Expand Down
106 changes: 87 additions & 19 deletions next-plaid-api/src/handlers/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ use crate::error::{ApiError, ApiResult};
use crate::models::InputType;
use crate::models::{EncodeRequest, EncodeResponse};
use crate::state::AppState;
#[cfg(feature = "model")]
use crate::state::{EncodeLane, EncodePoolKind};

// --- Batch Configuration ---

Expand Down Expand Up @@ -82,28 +84,66 @@ struct EncodeWorkerPool {
sender: mpsc::Sender<EncodeBatchItem>,
}

/// Global encode worker pool (singleton).
#[cfg(feature = "model")]
static ENCODE_WORKER_POOL: OnceLock<std::sync::Mutex<Option<EncodeWorkerPool>>> = OnceLock::new();
/// Cache key identifying one encode worker pool in `ENCODE_WORKER_POOLS`.
///
/// Combines the physical lane (`pool_kind`) with every model-configuration
/// field that affects how a worker is built, plus the worker count, so two
/// lanes sharing an identical configuration still hash to distinct pools
/// only when `pool_kind` differs.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct EncodePoolCacheKey {
    // Physical lane this pool serves (see `EncodePoolKind`).
    pool_kind: EncodePoolKind,
    // Filesystem path to the model; cloned from `ModelConfig::path`.
    path: std::path::PathBuf,
    // Execution backend toggles mirrored from `ModelConfig`.
    use_cuda: bool,
    use_int8: bool,
    // Optional tuning knobs mirrored from `ModelConfig`; `None` means
    // the model's own default is used.
    parallel_sessions: Option<usize>,
    batch_size: Option<usize>,
    threads: Option<usize>,
    query_length: Option<usize>,
    document_length: Option<usize>,
    // Number of workers spawned for this pool.
    pool_size: usize,
}

/// Get or create the global encode worker pool.
/// Spawns multiple workers, each owning its own Colbert model instance.
#[cfg(feature = "model")]
fn get_or_create_encode_pool(state: Arc<AppState>) -> ApiResult<mpsc::Sender<EncodeBatchItem>> {
let pool_lock: &std::sync::Mutex<Option<EncodeWorkerPool>> =
ENCODE_WORKER_POOL.get_or_init(|| std::sync::Mutex::new(None));

let mut pool_opt = pool_lock.lock().unwrap();

if let Some(pool) = pool_opt.as_ref() {
return Ok(pool.sender.clone());
impl EncodePoolCacheKey {
fn new(
pool_kind: EncodePoolKind,
model_config: &crate::state::ModelConfig,
pool_size: usize,
) -> Self {
Self {
pool_kind,
path: model_config.path.clone(),
use_cuda: model_config.use_cuda,
use_int8: model_config.use_int8,
parallel_sessions: model_config.parallel_sessions,
batch_size: model_config.batch_size,
threads: model_config.threads,
query_length: model_config.query_length,
document_length: model_config.document_length,
pool_size,
}
}
}

// Get model pool configuration
/// Global encode worker pools keyed by model configuration and physical lane.
#[cfg(feature = "model")]
static ENCODE_WORKER_POOLS: OnceLock<
std::sync::Mutex<HashMap<EncodePoolCacheKey, EncodeWorkerPool>>,
> = OnceLock::new();

/// Get or create the encode worker pool for the selected physical lane.
#[cfg(feature = "model")]
fn get_or_create_encode_pool(
state: Arc<AppState>,
pool_kind: EncodePoolKind,
) -> ApiResult<mpsc::Sender<EncodeBatchItem>> {
let pool_lock: &std::sync::Mutex<HashMap<EncodePoolCacheKey, EncodeWorkerPool>> =
ENCODE_WORKER_POOLS.get_or_init(|| std::sync::Mutex::new(HashMap::new()));
let model_pool = state
.model_pool
.as_ref()
.model_pool_for_kind(pool_kind)
.ok_or_else(|| ApiError::ModelNotLoaded)?;
let cache_key =
EncodePoolCacheKey::new(pool_kind, &model_pool.model_config, model_pool.pool_size);
let mut pools = pool_lock.lock().unwrap();
if let Some(pool) = pools.get(&cache_key) {
return Ok(pool.sender.clone());
}

let pool_size = model_pool.pool_size;
let model_config = model_pool.model_config.clone();
Expand All @@ -115,23 +155,33 @@ fn get_or_create_encode_pool(state: Arc<AppState>) -> ApiResult<mpsc::Sender<Enc
let shared_receiver = Arc::new(tokio::sync::Mutex::new(receiver));

// Spawn N workers, each building and owning its own model
tracing::info!(pool_size = pool_size, "encode.worker.pool.starting");
tracing::info!(
pool_kind = pool_kind.as_str(),
pool_size = pool_size,
"encode.worker.pool.starting"
);

for worker_id in 0..pool_size {
let receiver_clone = Arc::clone(&shared_receiver);
let config_clone = model_config.clone();
let pool_kind_copy = pool_kind;

// Spawn worker in a blocking task since model building is CPU-intensive
tokio::spawn(async move {
// Build model for this worker
let model = match build_model_from_config(&config_clone) {
Ok(m) => {
tracing::info!(worker_id = worker_id, "encode.worker.started");
tracing::info!(
worker_id = worker_id,
pool_kind = pool_kind_copy.as_str(),
"encode.worker.started"
);
m
}
Err(e) => {
tracing::error!(
worker_id = worker_id,
pool_kind = pool_kind_copy.as_str(),
error = %e,
"encode.worker.start.failed"
);
Expand All @@ -147,7 +197,7 @@ fn get_or_create_encode_pool(state: Arc<AppState>) -> ApiResult<mpsc::Sender<Enc
let pool = EncodeWorkerPool {
sender: sender.clone(),
};
*pool_opt = Some(pool);
pools.insert(cache_key, pool);

Ok(sender)
}
Expand Down Expand Up @@ -515,10 +565,28 @@ pub async fn encode_texts_internal(
texts: &[String],
input_type: InputType,
pool_factor: Option<usize>,
) -> ApiResult<Vec<ndarray::Array2<f32>>> {
let encode_lane = match &input_type {
InputType::Query => EncodeLane::Query,
InputType::Document => EncodeLane::Ingest,
};

encode_texts_internal_with_lane(state, texts, input_type, encode_lane, pool_factor).await
}

/// Internal function to encode texts on an explicit logical lane.
#[cfg(feature = "model")]
pub(crate) async fn encode_texts_internal_with_lane(
state: Arc<AppState>,
texts: &[String],
input_type: InputType,
encode_lane: EncodeLane,
pool_factor: Option<usize>,
) -> ApiResult<Vec<ndarray::Array2<f32>>> {
if !state.has_model() {
return Err(ApiError::ModelNotLoaded);
}
let pool_kind = state.encode_pool_kind_for_lane(encode_lane);

// Create oneshot channel for receiving results
let (response_tx, response_rx) = oneshot::channel();
Expand All @@ -532,7 +600,7 @@ pub async fn encode_texts_internal(
};

// Get or create the worker pool
let sender = get_or_create_encode_pool(state)?;
let sender = get_or_create_encode_pool(state, pool_kind)?;

// Send to worker pool
sender.try_send(batch_item).map_err(|e| match e {
Expand Down
10 changes: 8 additions & 2 deletions next-plaid-api/src/handlers/rerank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ use crate::state::AppState;
use crate::tracing_middleware::TraceId;
use crate::PrettyJson;

/// Logical lane used when encoding candidate documents during rerank.
///
/// NOTE(review): documents are routed to the query lane here, presumably
/// because rerank runs at query time and should share the query-serving
/// lane rather than the ingest lane — confirm against lane sizing.
#[cfg(feature = "model")]
fn rerank_document_encode_lane() -> crate::state::EncodeLane {
    use crate::state::EncodeLane;

    EncodeLane::Query
}

/// Convert a Vec<Vec<f32>> to an ndarray::Array2<f32>.
fn to_ndarray(embeddings: &[Vec<f32>]) -> ApiResult<Array2<f32>> {
if embeddings.is_empty() {
Expand Down Expand Up @@ -218,7 +223,7 @@ pub async fn rerank_with_encoding(
trace_id: Option<Extension<TraceId>>,
Json(request): Json<crate::models::RerankWithEncodingRequest>,
) -> ApiResult<PrettyJson<RerankResponse>> {
use crate::handlers::encode::encode_texts_internal;
use crate::handlers::encode::{encode_texts_internal, encode_texts_internal_with_lane};
use crate::models::InputType;

let trace_id = trace_id.map(|t| t.0).unwrap_or_default();
Expand Down Expand Up @@ -256,10 +261,11 @@ pub async fn rerank_with_encoding(
.ok_or_else(|| ApiError::Internal("Failed to encode query".to_string()))?;

// Encode documents
let doc_embeddings = encode_texts_internal(
let doc_embeddings = encode_texts_internal_with_lane(
state,
&request.documents,
InputType::Document,
rerank_document_encode_lane(),
request.pool_factor,
)
.await?;
Expand Down
Loading
Loading