-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a HNSW collector that exits early when nearest neighbor queue saturates #14094
base: main
Are you sure you want to change the base?
Changes from 40 commits
3b30c07
0b24e79
70b6144
93fb470
c5aa473
7fc49c5
aed6fd5
b7eb24f
d143bbb
51df9ee
e55f967
e3f8db3
ec1e686
a71e936
09b0229
74132f1
8d00ae8
370f513
fed77c9
88d22df
e86ebdc
e69730f
5b001ee
20a481f
1dbaa1a
c6dbf7e
55fdea2
460efd9
3d2e46b
0f3f047
eef4f97
acf5866
620e985
ca0f05d
f116141
45b2031
66bd51d
0b47585
bb57ca1
8f846b8
695a4eb
c899f29
36b9931
a84032e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.lucene.search; | ||
|
||
/** | ||
* {@link KnnCollector} that exposes methods to hook into specific parts of the HNSW algorithm. | ||
* | ||
* @lucene.experimental | ||
*/ | ||
public abstract class HnswKnnCollector extends KnnCollector.Decorator { | ||
|
||
public HnswKnnCollector(KnnCollector collector) { | ||
super(collector); | ||
} | ||
|
||
/** Triggers exploration of the next HNSW candidate graph node. */ | ||
public void nextCandidate() {} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.lucene.search; | ||
|
||
/** | ||
* A {@link HnswKnnCollector} that early exits when nearest neighbor queue keeps saturating beyond a | ||
* 'patience' parameter. This records the rate of collection of new nearest neighbors in the {@code | ||
* delegate} KnnCollector queue, at each HNSW node candidate visit. Once it saturates for a number | ||
* of consecutive node visits (e.g., the patience parameter), this early terminates. | ||
* | ||
* @lucene.experimental | ||
*/ | ||
public class HnswQueueSaturationCollector extends HnswKnnCollector { | ||
|
||
private final KnnCollector delegate; | ||
private final double saturationThreshold; | ||
private final int patience; | ||
private boolean patienceFinished; | ||
private int countSaturated; | ||
private int previousQueueSize; | ||
private int currentQueueSize; | ||
|
||
HnswQueueSaturationCollector(KnnCollector delegate, double saturationThreshold, int patience) { | ||
super(delegate); | ||
this.delegate = delegate; | ||
this.previousQueueSize = 0; | ||
this.currentQueueSize = 0; | ||
this.countSaturated = 0; | ||
this.patienceFinished = false; | ||
this.saturationThreshold = saturationThreshold; | ||
this.patience = patience; | ||
} | ||
|
||
@Override | ||
public boolean earlyTerminated() { | ||
return delegate.earlyTerminated() || patienceFinished; | ||
tteofili marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
@Override | ||
public boolean collect(int docId, float similarity) { | ||
boolean collect = delegate.collect(docId, similarity); | ||
if (collect) { | ||
currentQueueSize++; | ||
} | ||
return collect; | ||
} | ||
|
||
@Override | ||
public float minCompetitiveSimilarity() { | ||
return delegate.minCompetitiveSimilarity(); | ||
} | ||
tteofili marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@Override | ||
public TopDocs topDocs() { | ||
TopDocs topDocs; | ||
if (patienceFinished && delegate.earlyTerminated() == false) { | ||
TopDocs delegateDocs = delegate.topDocs(); | ||
TotalHits totalHits = | ||
new TotalHits(delegateDocs.totalHits.value(), TotalHits.Relation.EQUAL_TO); | ||
topDocs = new TopDocs(totalHits, delegateDocs.scoreDocs); | ||
} else { | ||
topDocs = delegate.topDocs(); | ||
} | ||
return topDocs; | ||
} | ||
|
||
@Override | ||
public void nextCandidate() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tteofili what do you think of making this more general? I think having a "nextCandidate" or "nextBlockOfVectors" is generally useful, and might be applicable to all types of kNN indices. For example:
Do you think we can make this API general? Maybe not, I am not sure. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I really like this idea Ben, I'll see if I can make up something reasonable for that ;) |
||
double queueSaturation = | ||
(double) Math.min(currentQueueSize, previousQueueSize) / currentQueueSize; | ||
mayya-sharipova marked this conversation as resolved.
Show resolved
Hide resolved
|
||
previousQueueSize = currentQueueSize; | ||
if (queueSaturation >= saturationThreshold) { | ||
countSaturated++; | ||
} else { | ||
countSaturated = 0; | ||
} | ||
if (countSaturated > patience) { | ||
patienceFinished = true; | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.lucene.search; | ||
|
||
import java.io.IOException; | ||
import java.util.Objects; | ||
import org.apache.lucene.index.FieldInfo; | ||
import org.apache.lucene.index.LeafReaderContext; | ||
import org.apache.lucene.index.QueryTimeout; | ||
import org.apache.lucene.search.knn.KnnCollectorManager; | ||
import org.apache.lucene.search.knn.KnnSearchStrategy; | ||
import org.apache.lucene.util.Bits; | ||
|
||
/** | ||
* This is a version of knn vector query that exits early when HNSW queue saturates over a {@code | ||
* #saturationThreshold} for more than {@code #patience} times. | ||
* | ||
* <p>See <a | ||
* href="https://cs.uwaterloo.ca/~jimmylin/publications/Teofili_Lin_ECIR2025.pdf">"Patience in | ||
* Proximity: A Simple Early Termination Strategy for HNSW Graph Traversal in Approximate k-Nearest | ||
* Neighbor Search"</a> (Teofili and Lin). In ECIR '25: Proceedings of the 47th European Conference | ||
* on Information Retrieval. | ||
* | ||
* @lucene.experimental | ||
*/ | ||
public class PatienceKnnVectorQuery extends AbstractKnnVectorQuery { | ||
|
||
private static final double DEFAULT_SATURATION_THRESHOLD = 0.995d; | ||
|
||
private final int patience; | ||
private final double saturationThreshold; | ||
|
||
final AbstractKnnVectorQuery delegate; | ||
|
||
/** | ||
* Construct a new PatienceKnnVectorQuery instance for a float vector field | ||
* | ||
* @param knnQuery the knn query to be seeded | ||
* @param saturationThreshold the early exit saturation threshold | ||
* @param patience the patience parameter | ||
* @return a new PatienceKnnVectorQuery instance | ||
* @lucene.experimental | ||
*/ | ||
public static PatienceKnnVectorQuery fromFloatQuery( | ||
KnnFloatVectorQuery knnQuery, double saturationThreshold, int patience) { | ||
return new PatienceKnnVectorQuery(knnQuery, saturationThreshold, patience); | ||
} | ||
|
||
/** | ||
* Construct a new PatienceKnnVectorQuery instance for a float vector field | ||
* | ||
* @param knnQuery the knn query to be seeded | ||
* @return a new PatienceKnnVectorQuery instance | ||
* @lucene.experimental | ||
*/ | ||
public static PatienceKnnVectorQuery fromFloatQuery(KnnFloatVectorQuery knnQuery) { | ||
return new PatienceKnnVectorQuery( | ||
knnQuery, DEFAULT_SATURATION_THRESHOLD, defaultPatience(knnQuery)); | ||
} | ||
|
||
/** | ||
* Construct a new PatienceKnnVectorQuery instance for a byte vector field | ||
* | ||
* @param knnQuery the knn query to be seeded | ||
* @param saturationThreshold the early exit saturation threshold | ||
* @param patience the patience parameter | ||
* @return a new PatienceKnnVectorQuery instance | ||
* @lucene.experimental | ||
*/ | ||
public static PatienceKnnVectorQuery fromByteQuery( | ||
KnnByteVectorQuery knnQuery, double saturationThreshold, int patience) { | ||
return new PatienceKnnVectorQuery(knnQuery, saturationThreshold, patience); | ||
} | ||
|
||
/** | ||
* Construct a new PatienceKnnVectorQuery instance for a byte vector field | ||
* | ||
* @param knnQuery the knn query to be seeded | ||
* @return a new PatienceKnnVectorQuery instance | ||
* @lucene.experimental | ||
*/ | ||
public static PatienceKnnVectorQuery fromByteQuery(KnnByteVectorQuery knnQuery) { | ||
return new PatienceKnnVectorQuery( | ||
knnQuery, DEFAULT_SATURATION_THRESHOLD, defaultPatience(knnQuery)); | ||
} | ||
|
||
/** | ||
* Construct a new PatienceKnnVectorQuery instance for seeded vector field | ||
* | ||
* @param knnQuery the knn query to be seeded | ||
* @param saturationThreshold the early exit saturation threshold | ||
* @param patience the patience parameter | ||
* @return a new PatienceKnnVectorQuery instance | ||
* @lucene.experimental | ||
*/ | ||
public static PatienceKnnVectorQuery fromSeededQuery( | ||
SeededKnnVectorQuery knnQuery, double saturationThreshold, int patience) { | ||
return new PatienceKnnVectorQuery(knnQuery, saturationThreshold, patience); | ||
} | ||
|
||
/** | ||
* Construct a new PatienceKnnVectorQuery instance for seeded vector field | ||
* | ||
* @param knnQuery the knn query to be seeded | ||
* @return a new PatienceKnnVectorQuery instance | ||
* @lucene.experimental | ||
*/ | ||
public static PatienceKnnVectorQuery fromSeededQuery(SeededKnnVectorQuery knnQuery) { | ||
return new PatienceKnnVectorQuery( | ||
knnQuery, DEFAULT_SATURATION_THRESHOLD, defaultPatience(knnQuery)); | ||
} | ||
|
||
PatienceKnnVectorQuery( | ||
AbstractKnnVectorQuery knnQuery, double saturationThreshold, int patience) { | ||
super(knnQuery.field, knnQuery.k, knnQuery.filter, knnQuery.searchStrategy); | ||
this.delegate = knnQuery; | ||
this.saturationThreshold = saturationThreshold; | ||
this.patience = patience; | ||
} | ||
|
||
private static int defaultPatience(AbstractKnnVectorQuery delegate) { | ||
return Math.max(7, (int) (delegate.k * 0.3)); | ||
} | ||
|
||
@Override | ||
public String toString(String field) { | ||
return "PatienceKnnVectorQuery{" | ||
+ "saturationThreshold=" | ||
+ saturationThreshold | ||
+ ", patience=" | ||
+ patience | ||
+ ", delegate=" | ||
+ delegate | ||
+ '}'; | ||
} | ||
|
||
@Override | ||
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { | ||
return delegate.getKnnCollectorManager(k, searcher); | ||
} | ||
|
||
@Override | ||
protected TopDocs approximateSearch( | ||
LeafReaderContext context, | ||
Bits acceptDocs, | ||
int visitedLimit, | ||
KnnCollectorManager knnCollectorManager) | ||
throws IOException { | ||
return delegate.approximateSearch( | ||
context, acceptDocs, visitedLimit, new PatienceCollectorManager(knnCollectorManager)); | ||
} | ||
|
||
@Override | ||
protected TopDocs exactSearch( | ||
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) | ||
throws IOException { | ||
return delegate.exactSearch(context, acceptIterator, queryTimeout); | ||
} | ||
|
||
@Override | ||
protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { | ||
return delegate.mergeLeafResults(perLeafResults); | ||
} | ||
|
||
@Override | ||
public void visit(QueryVisitor visitor) { | ||
delegate.visit(visitor); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
if (!super.equals(o)) return false; | ||
PatienceKnnVectorQuery that = (PatienceKnnVectorQuery) o; | ||
return saturationThreshold == that.saturationThreshold | ||
&& patience == that.patience | ||
&& Objects.equals(delegate, that.delegate); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(super.hashCode(), saturationThreshold, patience, delegate); | ||
} | ||
|
||
@Override | ||
public String getField() { | ||
return delegate.getField(); | ||
} | ||
|
||
@Override | ||
public int getK() { | ||
return delegate.getK(); | ||
} | ||
|
||
@Override | ||
public Query getFilter() { | ||
return delegate.getFilter(); | ||
} | ||
|
||
@Override | ||
VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException { | ||
return delegate.createVectorScorer(context, fi); | ||
} | ||
|
||
class PatienceCollectorManager implements KnnCollectorManager { | ||
final KnnCollectorManager knnCollectorManager; | ||
|
||
PatienceCollectorManager(KnnCollectorManager knnCollectorManager) { | ||
this.knnCollectorManager = knnCollectorManager; | ||
} | ||
|
||
@Override | ||
public KnnCollector newCollector( | ||
int visitLimit, KnnSearchStrategy searchStrategy, LeafReaderContext ctx) | ||
throws IOException { | ||
return new HnswQueueSaturationCollector( | ||
knnCollectorManager.newCollector(visitLimit, searchStrategy, ctx), | ||
saturationThreshold, | ||
patience); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, it is a little frustrating as we already have an "HNSWStrategy" and now we have an "HNSWCollector".
Could we utilize an HNSWStrategy? Or make `nextCandidate` a more general API?
My thought on the strategy would be for the graph searcher to indicate, through the strategy object, when the next group of vectors will be searched; the strategy would hold a reference to the collector, to which it can forward the request.
Of course, this still requires a new `HnswQueueSaturationCollector`, but it won't require these new base classes.