@@ -212,7 +212,6 @@ class FlatIndex {
212212 // / Return the logical number of dimensions of the indexed vectors.
213213 size_t dimensions () const { return data_.dimensions (); }
214214
215- // /
216215 // / @anchor flat_class_search_mutating
217216 // / @brief Fill the result with the ``num_neighbors`` nearest neighbors for each query.
218217 // /
@@ -225,6 +224,9 @@ class FlatIndex {
225224 // / @param result The result data structure to populate.
226225 // / Row `i` in the result corresponds to the neighbors for the `i`th query.
227226 // / Neighbors within each row are ordered from nearest to furthest.
227+ // / @param cancel A predicate called during the search to determine if the search should be cancelled.
228+ // Return ``true`` if the search should be cancelled. This functor must implement ``bool operator()()``.
229+ // Note: This predicate should be thread-safe as it can be called concurrently by different threads during the search.
228230 // / @param predicate A predicate functor that can be used to exclude certain dataset
229231 // / elements from consideration. This functor must implement
230232 // / ``bool operator()(size_t)`` where the ``size_t`` argument is an index in
@@ -260,6 +262,7 @@ class FlatIndex {
260262 QueryResultView<size_t > result,
261263 const data::ConstSimpleDataView<QueryType>& queries,
262264 const search_parameters_type& search_parameters,
265+ const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false >()),
263266 Pred predicate = lib::Returns(lib::Const<true >())
264267 ) {
265268 const size_t data_max_size = data_.size ();
@@ -276,12 +279,18 @@ class FlatIndex {
276279
277280 size_t start = 0 ;
278281 while (start < data_.size ()) {
282+ // Check if request to cancel the search
283+ if (cancel ()) {
284+ scratch.cleanup ();
285+ return ;
286+ }
279287 size_t stop = std::min (data_max_size, start + data_batch_size);
280288 search_subset (
281289 queries,
282290 threads::UnitRange (start, stop),
283291 scratch,
284292 search_parameters,
293+ cancel,
285294 predicate
286295 );
287296 start = stop;
@@ -311,6 +320,7 @@ class FlatIndex {
311320 const threads::UnitRange<size_t >& data_indices,
312321 sorter_type& scratch,
313322 const search_parameters_type& search_parameters,
323+ const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false >()),
314324 Pred predicate = lib::Returns(lib::Const<true >())
315325 ) {
316326 // Process all queries.
@@ -331,6 +341,7 @@ class FlatIndex {
331341 threads::UnitRange (query_indices),
332342 scratch,
333343 distances,
344+ cancel,
334345 predicate
335346 );
336347 }
@@ -352,6 +363,7 @@ class FlatIndex {
352363 const threads::UnitRange<size_t >& query_indices,
353364 sorter_type& scratch,
354365 distance::BroadcastDistance<DistFull>& distance_functors,
366+ const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false >()),
355367 Pred predicate = lib::Returns(lib::Const<true >())
356368 ) {
357369 assert (distance_functors.size () >= query_indices.size ());
@@ -365,6 +377,11 @@ class FlatIndex {
365377 }
366378
367379 for (auto data_index : data_indices) {
380+ // Check if request to cancel the search
381+ if (cancel ()) {
382+ return ;
383+ }
384+
368385 // Skip this index if it doesn't pass the predicate.
369386 if (!predicate (data_index)) {
370387 continue ;
0 commit comments