Skip to content

Commit 07bb04a

Browse files
dian-lun-linmihaic
authored andcommitted
Add cancel (#409)
GitOrigin-RevId: 98418f02d3916800f722fb97b1755086973895ad
1 parent ab91cd9 commit 07bb04a

File tree

13 files changed

+290
-42
lines changed

13 files changed

+290
-42
lines changed

include/svs/index/flat/flat.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,6 @@ class FlatIndex {
212212
/// Return the logical number of dimensions of the indexed vectors.
213213
size_t dimensions() const { return data_.dimensions(); }
214214

215-
///
216215
/// @anchor flat_class_search_mutating
217216
/// @brief Fill the result with the ``num_neighbors`` nearest neighbors for each query.
218217
///
@@ -225,6 +224,9 @@ class FlatIndex {
225224
/// @param result The result data structure to populate.
226225
/// Row `i` in the result corresponds to the neighbors for the `i`th query.
227226
/// Neighbors within each row are ordered from nearest to furthest.
227+
/// @param cancel A predicate called during the search to determine if the search should be cancelled.
228+
// Return ``true`` if the search should be cancelled. This functor must implement ``bool operator()()``.
229+
// Note: This predicate should be thread-safe as it can be called concurrently by different threads during the search.
228230
/// @param predicate A predicate functor that can be used to exclude certain dataset
229231
/// elements from consideration. This functor must implement
230232
/// ``bool operator()(size_t)`` where the ``size_t`` argument is an index in
@@ -260,6 +262,7 @@ class FlatIndex {
260262
QueryResultView<size_t> result,
261263
const data::ConstSimpleDataView<QueryType>& queries,
262264
const search_parameters_type& search_parameters,
265+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>()),
263266
Pred predicate = lib::Returns(lib::Const<true>())
264267
) {
265268
const size_t data_max_size = data_.size();
@@ -276,12 +279,18 @@ class FlatIndex {
276279

277280
size_t start = 0;
278281
while (start < data_.size()) {
282+
// Check if request to cancel the search
283+
if (cancel()) {
284+
scratch.cleanup();
285+
return;
286+
}
279287
size_t stop = std::min(data_max_size, start + data_batch_size);
280288
search_subset(
281289
queries,
282290
threads::UnitRange(start, stop),
283291
scratch,
284292
search_parameters,
293+
cancel,
285294
predicate
286295
);
287296
start = stop;
@@ -311,6 +320,7 @@ class FlatIndex {
311320
const threads::UnitRange<size_t>& data_indices,
312321
sorter_type& scratch,
313322
const search_parameters_type& search_parameters,
323+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>()),
314324
Pred predicate = lib::Returns(lib::Const<true>())
315325
) {
316326
// Process all queries.
@@ -331,6 +341,7 @@ class FlatIndex {
331341
threads::UnitRange(query_indices),
332342
scratch,
333343
distances,
344+
cancel,
334345
predicate
335346
);
336347
}
@@ -352,6 +363,7 @@ class FlatIndex {
352363
const threads::UnitRange<size_t>& query_indices,
353364
sorter_type& scratch,
354365
distance::BroadcastDistance<DistFull>& distance_functors,
366+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>()),
355367
Pred predicate = lib::Returns(lib::Const<true>())
356368
) {
357369
assert(distance_functors.size() >= query_indices.size());
@@ -365,6 +377,11 @@ class FlatIndex {
365377
}
366378

367379
for (auto data_index : data_indices) {
380+
// Check if request to cancel the search
381+
if (cancel()) {
382+
return;
383+
}
384+
368385
// Skip this index if it doesn't pass the predicate.
369386
if (!predicate(data_index)) {
370387
continue;

include/svs/index/index.h

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,24 @@ void search_batch_into_with(
4141
Index& index,
4242
svs::QueryResultView<I> result,
4343
const Queries& queries,
44-
const search_parameters_t<Index>& search_parameters
44+
const search_parameters_t<Index>& search_parameters,
45+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
4546
) {
4647
// Assert pre-conditions.
4748
assert(result.n_queries() == queries.size());
48-
index.search(result, queries, search_parameters);
49+
index.search(result, queries, search_parameters, cancel);
4950
}
5051

5152
// Apply default search parameters
5253
template <typename Index, std::integral I, data::ImmutableMemoryDataset Queries>
5354
void search_batch_into(
54-
Index& index, svs::QueryResultView<I> result, const Queries& queries
55+
Index& index,
56+
svs::QueryResultView<I> result,
57+
const Queries& queries,
58+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
5559
) {
5660
svs::index::search_batch_into_with(
57-
index, result, queries, index.get_search_parameters()
61+
index, result, queries, index.get_search_parameters(), cancel
5862
);
5963
}
6064

@@ -64,19 +68,26 @@ svs::QueryResult<size_t> search_batch_with(
6468
Index& index,
6569
const Queries& queries,
6670
size_t num_neighbors,
67-
const search_parameters_t<Index>& search_parameters
71+
const search_parameters_t<Index>& search_parameters,
72+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
6873
) {
6974
auto result = svs::QueryResult<size_t>{queries.size(), num_neighbors};
70-
svs::index::search_batch_into_with(index, result.view(), queries, search_parameters);
75+
svs::index::search_batch_into_with(
76+
index, result.view(), queries, search_parameters, cancel
77+
);
7178
return result;
7279
}
7380

7481
// Obtain default search parameters.
7582
template <typename Index, data::ImmutableMemoryDataset Queries>
76-
svs::QueryResult<size_t>
77-
search_batch(Index& index, const Queries& queries, size_t num_neighbors) {
83+
svs::QueryResult<size_t> search_batch(
84+
Index& index,
85+
const Queries& queries,
86+
size_t num_neighbors,
87+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
88+
) {
7889
return svs::index::search_batch_with(
79-
index, queries, num_neighbors, index.get_search_parameters()
90+
index, queries, num_neighbors, index.get_search_parameters(), cancel
8091
);
8192
}
8293
} // namespace svs::index

include/svs/index/inverted/memory_based.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,8 @@ template <typename Index, typename Cluster> class InvertedIndex {
376376
void search(
377377
QueryResultView<Idx> results,
378378
const Queries& queries,
379-
const search_parameters_type& search_parameters
379+
const search_parameters_type& search_parameters,
380+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
380381
) {
381382
threads::run(
382383
threadpool_,
@@ -396,9 +397,11 @@ template <typename Index, typename Cluster> class InvertedIndex {
396397

397398
for (auto i : is) {
398399
buffer.clear();
400+
399401
auto&& query = queries.get_datum(i);
400402
// Primary Index Search
401-
index_.search(query, scratch);
403+
index_.search(query, scratch, cancel);
404+
402405
auto& d = scratch.scratch;
403406
auto compare = distance::comparator(d);
404407

@@ -411,6 +414,11 @@ template <typename Index, typename Cluster> class InvertedIndex {
411414
);
412415

413416
for (size_t j = 0, jmax = scratch_buffer.size(); j < jmax; ++j) {
417+
// Check if request to cancel the search
418+
// TODO: this cancel may also be inside on_leaves()
419+
if (cancel()) {
420+
return;
421+
}
414422
auto candidate = scratch_buffer[j];
415423
if (!compare(candidate.distance(), cutoff_distance)) {
416424
break;

include/svs/index/vamana/dynamic_index.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ class MutableVamanaIndex {
402402
///
403403
size_t dimensions() const { return data_.dimensions(); }
404404

405-
auto greedy_search_closure(GreedySearchPrefetchParameters prefetch_parameters) const {
405+
auto greedy_search_closure(GreedySearchPrefetchParameters prefetch_parameters, const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())) const {
406406
return [&, prefetch_parameters](
407407
const auto& query, auto& accessor, auto& distance, auto& buffer
408408
) {
@@ -416,7 +416,8 @@ class MutableVamanaIndex {
416416
buffer,
417417
entry_point_,
418418
ValidBuilder{status_},
419-
prefetch_parameters
419+
prefetch_parameters,
420+
cancel
420421
);
421422
// Take a pass over the search buffer to remove any deleted elements that
422423
// might remain.
@@ -426,7 +427,10 @@ class MutableVamanaIndex {
426427

427428
template <typename I, data::ImmutableMemoryDataset Queries>
428429
void search(
429-
QueryResultView<I> results, const Queries& queries, const search_parameters_type& sp
430+
QueryResultView<I> results,
431+
const Queries& queries,
432+
const search_parameters_type& sp,
433+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
430434
) {
431435
threads::run(
432436
threadpool_,
@@ -452,11 +456,17 @@ class MutableVamanaIndex {
452456
queries,
453457
results,
454458
threads::UnitRange{is},
455-
greedy_search_closure(prefetch_parameters)
459+
greedy_search_closure(prefetch_parameters, cancel),
460+
cancel
456461
);
457462
}
458463
);
459464

465+
// Check if request to cancel the search
466+
if (cancel()) {
467+
return;
468+
}
469+
460470
// After the search procedure, the indices in `results` are internal.
461471
// Perform one more pass to convert these to external ids.
462472
translate_to_external(results.indices());

include/svs/index/vamana/extensions.h

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -411,15 +411,19 @@ struct VamanaSingleSearchType {
411411
SearchBuffer& search_buffer,
412412
Scratch& scratch,
413413
const Query& query,
414-
const Search& search
414+
const Search& search,
415+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
415416
) const {
416-
svs::svs_invoke(*this, data, search_buffer, scratch, query, search);
417+
svs::svs_invoke(*this, data, search_buffer, scratch, query, search, cancel);
417418
}
418419
};
419420

420421
/// Customization point object for processing single queries.
421422
inline constexpr VamanaSingleSearchType single_search{};
422423

424+
/// In this function, cancel does not need to be called since search will call cancel
425+
/// However, lvq requires cancel to be called to skip reranking
426+
/// Here we have cancel to make the interface consistent with lvq
423427
template <
424428
typename Data,
425429
typename SearchBuffer,
@@ -432,8 +436,13 @@ SVS_FORCE_INLINE void svs_invoke(
432436
SearchBuffer& search_buffer,
433437
Distance& distance,
434438
const Query& query,
435-
const Search& search
439+
const Search& search,
440+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
436441
) {
442+
// Check if request to cancel the search
443+
if (cancel()) {
444+
return;
445+
}
437446
// Perform graph search.
438447
auto accessor = data::GetDatumAccessor();
439448
search(query, accessor, distance, search_buffer);
@@ -482,10 +491,19 @@ struct VamanaPerThreadBatchSearchType {
482491
const Queries& queries,
483492
QueryResultView<I>& result,
484493
threads::UnitRange<size_t> thread_indices,
485-
const Search& search
494+
const Search& search,
495+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
486496
) const {
487497
svs::svs_invoke(
488-
*this, data, search_buffer, scratch, queries, result, thread_indices, search
498+
*this,
499+
data,
500+
search_buffer,
501+
scratch,
502+
queries,
503+
result,
504+
thread_indices,
505+
search,
506+
cancel
489507
);
490508
}
491509
};
@@ -509,13 +527,18 @@ void svs_invoke(
509527
const Queries& queries,
510528
QueryResultView<I>& result,
511529
threads::UnitRange<size_t> thread_indices,
512-
const Search& search
530+
const Search& search,
531+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
513532
) {
514533
// Fallback implementation
515534
size_t num_neighbors = result.n_neighbors();
516535
for (auto i : thread_indices) {
536+
// Check if request to cancel the search
537+
if (cancel()) {
538+
return;
539+
}
517540
// Perform search - results will be queued in the search buffer.
518-
single_search(dataset, search_buffer, distance, queries.get_datum(i), search);
541+
single_search(dataset, search_buffer, distance, queries.get_datum(i), search, cancel);
519542

520543
// Copy back results.
521544
for (size_t j = 0; j < num_neighbors; ++j) {

include/svs/index/vamana/greedy_search.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ void greedy_search(
7878
const Ep& entry_points,
7979
const Builder& builder,
8080
Tracker& search_tracker,
81-
GreedySearchPrefetchParameters prefetch_parameters = {}
81+
GreedySearchPrefetchParameters prefetch_parameters = {},
82+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
8283
) {
8384
using I = typename Graph::index_type;
8485

@@ -102,6 +103,10 @@ void greedy_search(
102103
// Main search routine.
103104
search_buffer.sort();
104105
while (!search_buffer.done()) {
106+
// Check if request to cancel the search
107+
if (cancel()) {
108+
return;
109+
}
105110
// Get the next unvisited vertex.
106111
const auto& node = search_buffer.next();
107112
auto node_id = node.id();
@@ -169,7 +174,8 @@ void greedy_search(
169174
Buffer& search_buffer,
170175
const Ep& entry_points,
171176
const Builder& builder = NeighborBuilder(),
172-
GreedySearchPrefetchParameters prefetch_parameters = {}
177+
GreedySearchPrefetchParameters prefetch_parameters = {},
178+
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
173179
) {
174180
auto null_tracker = NullTracker{};
175181
greedy_search(
@@ -182,7 +188,8 @@ void greedy_search(
182188
entry_points,
183189
builder,
184190
null_tracker,
185-
prefetch_parameters
191+
prefetch_parameters,
192+
cancel
186193
);
187194
}
188195
} // namespace svs::index::vamana

0 commit comments

Comments
 (0)