Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 116 additions & 43 deletions next-plaid-onnx/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ fn get_cuda_device_id() -> i32 {
/// Builds the `CUDAExecutionProvider` configuration: starts from the default
/// provider and selects the device returned by `get_cuda_device_id()`.
///
/// NOTE(review): `with_tf32` is chained twice below with opposite arguments
/// (`true`, then `false`). This text looks like a rendered diff in which the
/// removed and the added line both survive; presumably only `.with_tf32(false)`
/// is intended — confirm against the real file before relying on either value.
fn configured_cuda_execution_provider() -> CUDAExecutionProvider {
CUDAExecutionProvider::default()
.with_device_id(get_cuda_device_id())
.with_tf32(true)
.with_tf32(false)
}

/// Check if CPU-only mode is forced via environment variable.
Expand Down Expand Up @@ -664,6 +664,13 @@ pub struct PreparedDocumentBatch {
original_lengths: Vec<usize>,
is_query: bool,
filter_skiplist: bool,
/// Position of each document in the original input slice passed to
/// `tokenize_documents_in_batches`. Used to restore input order in
/// `encode_prepared_document_batches` after GPU dynamic batching
/// reorders documents by length. For batches produced outside of
/// `tokenize_documents_in_batches`, this is empty and no reordering
/// is applied.
original_input_indices: Vec<usize>,
}

struct TokenizedDocument {
Expand Down Expand Up @@ -1070,12 +1077,15 @@ impl Colbert {
let batch_docs = self.batch_size.max(1);
let mut batches = Vec::new();

let mut tokenized_iter = tokenized.into_iter();
while let Some(first) = tokenized_iter.next() {
let mut tokenized_iter = tokenized.into_iter().enumerate();
while let Some((first_idx, first)) = tokenized_iter.next() {
let mut piece_encodings = Vec::with_capacity(batch_docs);
let mut piece_indices = Vec::with_capacity(batch_docs);
piece_encodings.push(first);
for encoding in tokenized_iter.by_ref().take(batch_docs - 1) {
piece_indices.push(first_idx);
for (idx, encoding) in tokenized_iter.by_ref().take(batch_docs - 1) {
piece_encodings.push(encoding);
piece_indices.push(idx);
}

batches.push(prepare_batch_from_tokenized_documents(
Expand All @@ -1084,6 +1094,7 @@ impl Colbert {
piece_encodings,
false,
true,
piece_indices,
)?);
}

Expand All @@ -1094,43 +1105,54 @@ impl Colbert {
// length and bucketed into fixed shapes (quantized to 32-token steps).
// This lets the GPU reuse execution plans across batches with the same
// shape, reducing kernel launch overhead and minimizing padding waste.
// We carry the original input index alongside each tokenized doc so
// `encode_prepared_document_batches` can restore the caller-visible
// input order in the returned embeddings.
let prepared_lengths: Vec<usize> = tokenized
.iter()
.map(|doc| doc.ids.len().min(truncate_limit) + 1)
.collect();
let mut items: Vec<(usize, TokenizedDocument)> =
prepared_lengths.into_iter().zip(tokenized).collect();
items.sort_by_key(|(prepared_len, _)| *prepared_len);
let mut items: Vec<(usize, usize, TokenizedDocument)> = prepared_lengths
.into_iter()
.zip(tokenized)
.enumerate()
.map(|(idx, (len, doc))| (len, idx, doc))
.collect();
items.sort_by_key(|(prepared_len, _, _)| *prepared_len);

let shapes =
build_fixed_dynamic_shapes(self.batch_size.max(1), self.config.document_length);
let mut buckets: Vec<Vec<TokenizedDocument>> =
let mut buckets: Vec<Vec<(usize, TokenizedDocument)>> =
(0..shapes.len()).map(|_| Vec::new()).collect();

for (prepared_len, encoding) in items {
for (prepared_len, orig_idx, encoding) in items {
let bucket_idx = shapes
.iter()
.position(|shape| prepared_len <= shape.planned_len)
.unwrap_or(shapes.len().saturating_sub(1));
buckets[bucket_idx].push(encoding);
buckets[bucket_idx].push((orig_idx, encoding));
}

let mut batches = Vec::new();
for (shape, bucket_docs) in shapes.iter().zip(buckets) {
let docs_per_batch = shape.docs.max(1);
let mut bucket_iter = bucket_docs.into_iter();
while let Some(first) = bucket_iter.next() {
while let Some((first_idx, first)) = bucket_iter.next() {
let mut piece_encodings = Vec::with_capacity(docs_per_batch);
let mut piece_indices = Vec::with_capacity(docs_per_batch);
piece_encodings.push(first);
for encoding in bucket_iter.by_ref().take(docs_per_batch - 1) {
piece_indices.push(first_idx);
for (idx, encoding) in bucket_iter.by_ref().take(docs_per_batch - 1) {
piece_encodings.push(encoding);
piece_indices.push(idx);
}
batches.push(prepare_batch_from_tokenized_documents(
&self.tokenizer,
&self.config,
piece_encodings,
false,
true,
piece_indices,
)?);
}
}
Expand All @@ -1156,46 +1178,89 @@ impl Colbert {
return Ok(Vec::new());
}

if self.sessions.len() <= 1 || prepared_batches.len() == 1 {
// Collect the original-input position for every document across all
// batches in the order they appear here. When `tokenize_documents_in_batches`
// sorts documents by length (GPU dynamic batching path) the embeddings
// come out in a permuted order; we restore the caller's input order
// before returning so downstream consumers (which index embeddings by
// input position) get correct (doc, embedding) pairs.
let mut combined_indices: Vec<usize> =
Vec::with_capacity(prepared_batches.iter().map(|b| b.batch_size).sum());
let mut has_reordering = false;
for batch in &prepared_batches {
if !batch.original_input_indices.is_empty() {
combined_indices.extend_from_slice(&batch.original_input_indices);
has_reordering = true;
}
}

let encoded: Vec<Array2<f32>> = if self.sessions.len() <= 1 || prepared_batches.len() == 1 {
let mut all_embeddings = Vec::new();
for prepared_batch in prepared_batches {
all_embeddings.extend(self.encode_prepared_documents(prepared_batch)?);
}
return Ok(all_embeddings);
}
all_embeddings
} else {
let results: Vec<Result<Vec<Array2<f32>>>> = std::thread::scope(|scope| {
let mut handles = Vec::with_capacity(prepared_batches.len());

let results: Vec<Result<Vec<Array2<f32>>>> = std::thread::scope(|scope| {
let mut handles = Vec::with_capacity(prepared_batches.len());

for (i, prepared_batch) in prepared_batches.into_iter().enumerate() {
let session_idx = i % self.sessions.len();
let session_mutex = &self.sessions[session_idx];
let config = &self.config;
let skiplist_ids = &self.skiplist_ids;

handles.push(scope.spawn(move || {
let mut session = session_mutex.lock().unwrap();
encode_prepared_batch_with_session(
&mut session,
config,
skiplist_ids,
prepared_batch,
)
}));
}
for (i, prepared_batch) in prepared_batches.into_iter().enumerate() {
let session_idx = i % self.sessions.len();
let session_mutex = &self.sessions[session_idx];
let config = &self.config;
let skiplist_ids = &self.skiplist_ids;

handles
.into_iter()
.map(|handle| handle.join().unwrap())
.collect()
});
handles.push(scope.spawn(move || {
let mut session = session_mutex.lock().unwrap();
encode_prepared_batch_with_session(
&mut session,
config,
skiplist_ids,
prepared_batch,
)
}));
}

let mut all_embeddings = Vec::new();
for result in results {
all_embeddings.extend(result?);
handles
.into_iter()
.map(|handle| handle.join().unwrap())
.collect()
});

let mut all_embeddings = Vec::new();
for result in results {
all_embeddings.extend(result?);
}
all_embeddings
};

if !has_reordering || combined_indices.len() != encoded.len() {
return Ok(encoded);
}

Ok(all_embeddings)
// Restore input order: encoded[i] belongs at output position combined_indices[i].
let n = encoded.len();
let mut reordered: Vec<Option<Array2<f32>>> = (0..n).map(|_| None).collect();
for (encoded_pos, embedding) in encoded.into_iter().enumerate() {
let target = combined_indices[encoded_pos];
if target >= n {
anyhow::bail!(
"original_input_indices points to out-of-range slot ({} >= {})",
target,
n
);
}
reordered[target] = Some(embedding);
}
reordered
.into_iter()
.enumerate()
.map(|(i, opt)| {
opt.ok_or_else(|| {
anyhow::anyhow!("original_input_indices missing slot {} in output", i)
})
})
.collect()
}

/// Stream document embeddings chunk-by-chunk.
Expand Down Expand Up @@ -1680,6 +1745,7 @@ fn prepare_batch_for_session(
original_lengths: Vec::new(),
is_query,
filter_skiplist,
original_input_indices: Vec::new(),
});
}

Expand All @@ -1701,6 +1767,7 @@ fn prepare_batch_from_tokenized_documents(
batch_docs: Vec<TokenizedDocument>,
is_query: bool,
filter_skiplist: bool,
original_input_indices: Vec<usize>,
) -> Result<PreparedDocumentBatch> {
let (prefix_str, prefix_token_id_opt, max_length) = if is_query {
(
Expand Down Expand Up @@ -1814,6 +1881,7 @@ fn prepare_batch_from_tokenized_documents(
original_lengths,
is_query,
filter_skiplist,
original_input_indices,
})
}

Expand Down Expand Up @@ -1957,6 +2025,10 @@ fn prepare_batch_from_tokenizer_encodings(
original_lengths,
is_query,
filter_skiplist,
// No reordering happens in this code path — callers that need to
// restore an original input order should populate this themselves
// before calling `encode_prepared_document_batches`.
original_input_indices: Vec::new(),
})
}

Expand All @@ -1976,6 +2048,7 @@ fn encode_prepared_batch_with_session(
original_lengths,
is_query,
filter_skiplist,
original_input_indices: _,
} = prepared;

if batch_size == 0 {
Expand Down
Loading