diff --git a/next-plaid-onnx/src/lib.rs b/next-plaid-onnx/src/lib.rs index 86d5f0b..359c554 100644 --- a/next-plaid-onnx/src/lib.rs +++ b/next-plaid-onnx/src/lib.rs @@ -228,7 +228,7 @@ fn get_cuda_device_id() -> i32 { fn configured_cuda_execution_provider() -> CUDAExecutionProvider { CUDAExecutionProvider::default() .with_device_id(get_cuda_device_id()) - .with_tf32(true) + .with_tf32(false) } /// Check if CPU-only mode is forced via environment variable. @@ -664,6 +664,13 @@ pub struct PreparedDocumentBatch { original_lengths: Vec, is_query: bool, filter_skiplist: bool, + /// Position of each document in the original input slice passed to + /// `tokenize_documents_in_batches`. Used to restore input order in + /// `encode_prepared_document_batches` after GPU dynamic batching + /// reorders documents by length. For batches produced outside of + /// `tokenize_documents_in_batches`, this is empty and no reordering + /// is applied. + original_input_indices: Vec, } struct TokenizedDocument { @@ -1070,12 +1077,15 @@ impl Colbert { let batch_docs = self.batch_size.max(1); let mut batches = Vec::new(); - let mut tokenized_iter = tokenized.into_iter(); - while let Some(first) = tokenized_iter.next() { + let mut tokenized_iter = tokenized.into_iter().enumerate(); + while let Some((first_idx, first)) = tokenized_iter.next() { let mut piece_encodings = Vec::with_capacity(batch_docs); + let mut piece_indices = Vec::with_capacity(batch_docs); piece_encodings.push(first); - for encoding in tokenized_iter.by_ref().take(batch_docs - 1) { + piece_indices.push(first_idx); + for (idx, encoding) in tokenized_iter.by_ref().take(batch_docs - 1) { piece_encodings.push(encoding); + piece_indices.push(idx); } batches.push(prepare_batch_from_tokenized_documents( @@ -1084,6 +1094,7 @@ impl Colbert { piece_encodings, false, true, + piece_indices, )?); } @@ -1094,36 +1105,46 @@ impl Colbert { // length and bucketed into fixed shapes (quantized to 32-token steps). // This lets the GPU reuse execution plans across batches with the same // shape, reducing kernel launch overhead and minimizing padding waste. + // We carry the original input index alongside each tokenized doc so + // `encode_prepared_document_batches` can restore the caller-visible + // input order in the returned embeddings. let prepared_lengths: Vec = tokenized .iter() .map(|doc| doc.ids.len().min(truncate_limit) + 1) .collect(); - let mut items: Vec<(usize, TokenizedDocument)> = - prepared_lengths.into_iter().zip(tokenized).collect(); - items.sort_by_key(|(prepared_len, _)| *prepared_len); + let mut items: Vec<(usize, usize, TokenizedDocument)> = prepared_lengths + .into_iter() + .zip(tokenized) + .enumerate() + .map(|(idx, (len, doc))| (len, idx, doc)) + .collect(); + items.sort_by_key(|(prepared_len, _, _)| *prepared_len); let shapes = build_fixed_dynamic_shapes(self.batch_size.max(1), self.config.document_length); - let mut buckets: Vec> = + let mut buckets: Vec> = (0..shapes.len()).map(|_| Vec::new()).collect(); - for (prepared_len, encoding) in items { + for (prepared_len, orig_idx, encoding) in items { let bucket_idx = shapes .iter() .position(|shape| prepared_len <= shape.planned_len) .unwrap_or(shapes.len().saturating_sub(1)); - buckets[bucket_idx].push(encoding); + buckets[bucket_idx].push((orig_idx, encoding)); } let mut batches = Vec::new(); for (shape, bucket_docs) in shapes.iter().zip(buckets) { let docs_per_batch = shape.docs.max(1); let mut bucket_iter = bucket_docs.into_iter(); - while let Some(first) = bucket_iter.next() { + while let Some((first_idx, first)) = bucket_iter.next() { let mut piece_encodings = Vec::with_capacity(docs_per_batch); + let mut piece_indices = Vec::with_capacity(docs_per_batch); piece_encodings.push(first); - for encoding in bucket_iter.by_ref().take(docs_per_batch - 1) { + piece_indices.push(first_idx); + for (idx, encoding) in bucket_iter.by_ref().take(docs_per_batch - 1) { piece_encodings.push(encoding); + piece_indices.push(idx); } batches.push(prepare_batch_from_tokenized_documents( &self.tokenizer, @@ -1131,6 +1152,7 @@ impl Colbert { piece_encodings, false, true, + piece_indices, )?); } } @@ -1156,46 +1178,89 @@ impl Colbert { return Ok(Vec::new()); } - if self.sessions.len() <= 1 || prepared_batches.len() == 1 { + // Collect the original-input position for every document across all + // batches in the order they appear here. When `tokenize_documents_in_batches` + // sorts documents by length (GPU dynamic batching path) the embeddings + // come out in a permuted order; we restore the caller's input order + // before returning so downstream consumers (which index embeddings by + // input position) get correct (doc, embedding) pairs. + let mut combined_indices: Vec = + Vec::with_capacity(prepared_batches.iter().map(|b| b.batch_size).sum()); + let mut has_reordering = false; + for batch in &prepared_batches { + if !batch.original_input_indices.is_empty() { + combined_indices.extend_from_slice(&batch.original_input_indices); + has_reordering = true; + } + } + + let encoded: Vec> = if self.sessions.len() <= 1 || prepared_batches.len() == 1 { let mut all_embeddings = Vec::new(); for prepared_batch in prepared_batches { all_embeddings.extend(self.encode_prepared_documents(prepared_batch)?); } - return Ok(all_embeddings); - } + all_embeddings + } else { + let results: Vec>>> = std::thread::scope(|scope| { + let mut handles = Vec::with_capacity(prepared_batches.len()); - let results: Vec>>> = std::thread::scope(|scope| { - let mut handles = Vec::with_capacity(prepared_batches.len()); - - for (i, prepared_batch) in prepared_batches.into_iter().enumerate() { - let session_idx = i % self.sessions.len(); - let session_mutex = &self.sessions[session_idx]; - let config = &self.config; - let skiplist_ids = &self.skiplist_ids; - - handles.push(scope.spawn(move || { - let mut session = session_mutex.lock().unwrap(); - encode_prepared_batch_with_session( - &mut session, - config, - skiplist_ids, - prepared_batch, - ) - })); - } + for (i, prepared_batch) in prepared_batches.into_iter().enumerate() { + let session_idx = i % self.sessions.len(); + let session_mutex = &self.sessions[session_idx]; + let config = &self.config; + let skiplist_ids = &self.skiplist_ids; - handles - .into_iter() - .map(|handle| handle.join().unwrap()) - .collect() - }); + handles.push(scope.spawn(move || { + let mut session = session_mutex.lock().unwrap(); + encode_prepared_batch_with_session( + &mut session, + config, + skiplist_ids, + prepared_batch, + ) + })); + } - let mut all_embeddings = Vec::new(); - for result in results { - all_embeddings.extend(result?); + handles + .into_iter() + .map(|handle| handle.join().unwrap()) + .collect() + }); + + let mut all_embeddings = Vec::new(); + for result in results { + all_embeddings.extend(result?); + } + all_embeddings + }; + + if !has_reordering || combined_indices.len() != encoded.len() { + return Ok(encoded); } - Ok(all_embeddings) + // Restore input order: encoded[i] belongs at output position combined_indices[i]. + let n = encoded.len(); + let mut reordered: Vec>> = (0..n).map(|_| None).collect(); + for (encoded_pos, embedding) in encoded.into_iter().enumerate() { + let target = combined_indices[encoded_pos]; + if target >= n { + anyhow::bail!( + "original_input_indices points to out-of-range slot ({} >= {})", + target, + n + ); + } + reordered[target] = Some(embedding); + } + reordered + .into_iter() + .enumerate() + .map(|(i, opt)| { + opt.ok_or_else(|| { + anyhow::anyhow!("original_input_indices missing slot {} in output", i) + }) + }) + .collect() } /// Stream document embeddings chunk-by-chunk. @@ -1680,6 +1745,7 @@ fn prepare_batch_for_session( original_lengths: Vec::new(), is_query, filter_skiplist, + original_input_indices: Vec::new(), }); } @@ -1701,6 +1767,7 @@ fn prepare_batch_from_tokenized_documents( batch_docs: Vec, is_query: bool, filter_skiplist: bool, + original_input_indices: Vec, ) -> Result { let (prefix_str, prefix_token_id_opt, max_length) = if is_query { ( @@ -1814,6 +1881,7 @@ fn prepare_batch_from_tokenized_documents( original_lengths, is_query, filter_skiplist, + original_input_indices, }) } @@ -1957,6 +2025,10 @@ fn prepare_batch_from_tokenizer_encodings( original_lengths, is_query, filter_skiplist, + // No reordering happens in this code path — callers that need to + // restore an original input order should populate this themselves + // before calling `encode_prepared_document_batches`. + original_input_indices: Vec::new(), }) } @@ -1976,6 +2048,7 @@ fn encode_prepared_batch_with_session( original_lengths, is_query, filter_skiplist, + original_input_indices: _, } = prepared; if batch_size == 0 {