From ea95b139f9752e73f7cf979e035431ed23b902f0 Mon Sep 17 00:00:00 2001 From: jamie8johnson Date: Fri, 3 Jul 2026 16:23:37 -0500 Subject: [PATCH] feat(rust): add serialize/deserialize to brute force index --- rust/cuvs/src/brute_force.rs | 154 ++++++++++++++++++++++++++++++++++- 1 file changed, 153 insertions(+), 1 deletion(-) diff --git a/rust/cuvs/src/brute_force.rs b/rust/cuvs/src/brute_force.rs index 318346cc4d..1d563040c5 100644 --- a/rust/cuvs/src/brute_force.rs +++ b/rust/cuvs/src/brute_force.rs @@ -11,12 +11,14 @@ //! [`dlpack`](crate::dlpack) module for the tensor model and `examples/cagra.rs` //! for the same build/search workflow. +use std::ffi::CString; use std::io::{Write, stderr}; use std::marker::PhantomData; +use std::path::Path; use crate::distance_type::DistanceType; use crate::dlpack::{AsDlTensor, AsDlTensorMut}; -use crate::error::{Result, check_cuvs}; +use crate::error::{Error, Result, check_cuvs}; use crate::resources::Resources; /// Brute Force KNN Index @@ -28,6 +30,17 @@ pub struct Index<'d> { _dataset: PhantomData<&'d ()>, } +/// Convert a filesystem path into a `CString` suitable for the cuVS C API, +/// returning `Error::InvalidArgument` instead of panicking for paths that are +/// not valid UTF-8 or that contain an interior NUL byte. +fn path_to_cstring(path: &Path) -> Result { + let path_str = path + .to_str() + .ok_or_else(|| Error::InvalidArgument(format!("path is not valid UTF-8: {path:?}")))?; + CString::new(path_str) + .map_err(|e| Error::InvalidArgument(format!("path contains an interior NUL byte: {e}"))) +} + impl<'d> Index<'d> { /// Builds a brute-force index over `dataset` for exact k-NN search. /// @@ -103,6 +116,43 @@ impl<'d> Index<'d> { )) } } + + /// Save the Brute Force index to file. + /// + /// The serialization format is subject to change, so loading an + /// index saved with a previous version of cuVS is not guaranteed to work. 
+ /// + /// # Arguments + /// + /// * `res` - Resources to use + /// * `filename` - The file path for saving the index (must be valid UTF-8 without an interior NUL byte, else `Error::InvalidArgument` is returned) + pub fn serialize>(&self, res: &Resources, filename: P) -> Result<()> { + let c_filename = path_to_cstring(filename.as_ref())?; + unsafe { check_cuvs(ffi::cuvsBruteForceSerialize(res.0, c_filename.as_ptr(), self.inner)) } + } + + /// Load a Brute Force index from file. + /// + /// The serialization format is subject to change, so loading an + /// index saved with a previous version of cuVS is not guaranteed to work. + /// + /// A deserialized index owns its data internally and borrows nothing from + /// Rust, hence the `'static` dataset lifetime. + /// + /// # Arguments + /// + /// * `res` - Resources to use + /// * `filename` - The path of the file that stores the index (must be valid UTF-8 without an interior NUL byte, else `Error::InvalidArgument` is returned) + pub fn deserialize>(res: &Resources, filename: P) -> Result> { + let c_filename = path_to_cstring(filename.as_ref())?; + // Create the Index handle first so that any error path below still runs + // its `Drop` and releases the C-side index allocation. + let index: Index<'static> = Index::new()?; + unsafe { + check_cuvs(ffi::cuvsBruteForceDeserialize(res.0, c_filename.as_ptr(), index.inner))?; + } + Ok(index) + } } impl Drop for Index<'_> { @@ -187,4 +237,106 @@ mod tests { fn test_l2() { test_bfknn(DistanceType::L2Expanded); } + + const N_DATAPOINTS: usize = 16; + const N_FEATURES: usize = 8; + + /// Search the first `n_queries` rows of `dataset` against `index` and assert + /// each query finds itself as the top-1 neighbor. 
+ fn search_and_verify_self_neighbors( + res: &Resources, + index: &Index, + dataset: &ndarray::Array2, + n_queries: usize, + k: usize, + ) { + let queries = dataset.slice(s![0..n_queries, ..]).to_owned(); + let queries = DeviceTensor::from_host(res, &queries).unwrap(); + + let mut neighbors_host = ndarray::Array::::zeros((n_queries, k)); + let mut neighbors = DeviceTensor::::zeros(res, &[n_queries, k]).unwrap(); + + let mut distances_host = ndarray::Array::::zeros((n_queries, k)); + let mut distances = DeviceTensor::::zeros(res, &[n_queries, k]).unwrap(); + + index.search(res, &queries, &mut neighbors, &mut distances).expect("search failed"); + + distances.copy_to_host(res, &mut distances_host).unwrap(); + neighbors.copy_to_host(res, &mut neighbors_host).unwrap(); + res.sync_stream().unwrap(); + + for i in 0..n_queries { + assert_eq!( + neighbors_host[[i, 0]], + i as i64, + "query {i} should be its own nearest neighbor" + ); + } + } + + #[test] + fn test_brute_force_serialize_deserialize() { + let res = Resources::new().unwrap(); + + // Keep `device_dataset` alive for the whole test: the C++ index stores a + // non-owning view into it, so it must not be dropped while `index` lives. 
+ let dataset = ndarray::Array::::random( + (N_DATAPOINTS, N_FEATURES), + Uniform::new(0., 1.0).unwrap(), + ); + let device_dataset = DeviceTensor::from_host(&res, &dataset).unwrap(); + let index = Index::build(&res, DistanceType::L2Expanded, None, &device_dataset) + .expect("failed to build brute force index"); + res.sync_stream().unwrap(); + + let unique = format!( + "test_brute_force_index_{}_{}.bin", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("system clock before unix epoch") + .as_nanos() + ); + let filepath = std::env::temp_dir().join(unique); + index.serialize(&res, &filepath).expect("failed to serialize brute force index"); + + assert!(filepath.exists(), "serialized index file should exist"); + assert!( + std::fs::metadata(&filepath).unwrap().len() > 0, + "serialized index file should not be empty" + ); + + let loaded_index = + Index::deserialize(&res, &filepath).expect("failed to deserialize brute force index"); + + // The deserialized index should still find each query as its own + // nearest neighbor. + search_and_verify_self_neighbors(&res, &loaded_index, &dataset, 4, 4); + + let _ = std::fs::remove_file(&filepath); + } + + /// Passing a filename containing an interior NUL byte must surface as an + /// `InvalidArgument` error rather than panicking inside the serializer. + #[test] + fn test_brute_force_serialize_rejects_interior_nul() { + let res = Resources::new().unwrap(); + + let dataset = ndarray::Array::::random( + (N_DATAPOINTS, N_FEATURES), + Uniform::new(0., 1.0).unwrap(), + ); + let device_dataset = DeviceTensor::from_host(&res, &dataset).unwrap(); + let index = Index::build(&res, DistanceType::L2Expanded, None, &device_dataset) + .expect("failed to build brute force index"); + res.sync_stream().unwrap(); + + // A Rust string literal may contain a NUL character and `PathBuf::from` keeps it, + // so the path reaches `path_to_cstring` with an interior NUL and must be rejected. 
+ let bad_path = std::path::PathBuf::from("/tmp/has\0nul.bin"); + let err = index + .serialize(&res, &bad_path) + .expect_err("serialize should reject paths with interior NUL"); + assert!(matches!(err, Error::InvalidArgument(_)), "expected InvalidArgument, got {err:?}"); + } }