diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index f9091382913e..b78f088e5375 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -103,6 +103,8 @@ New Features * GITHUB#13470: Added `TopDocs#rrf` to combine multiple TopDocs instances using reciprocal rank fusion. (Haren Lin, Adrien Grand) +* GITHUB#14094: New KNN query that early terminates when HNSW nearest neighbor queue saturates (Tommaso Teofili) + Improvements --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java new file mode 100644 index 000000000000..3147c30e418b --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +/** + * {@link KnnCollector} that exposes methods to hook into specific parts of the HNSW algorithm. + * + * @lucene.experimental + */ +public abstract class HnswKnnCollector extends KnnCollector.Decorator { + + public HnswKnnCollector(KnnCollector collector) { + super(collector); + } + + /** Triggers exploration of the next HNSW candidate graph node. */ + public void nextCandidate() {} +} diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java new file mode 100644 index 000000000000..95ad5973ca43 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +/** + * A {@link HnswKnnCollector} that early exits when nearest neighbor queue keeps saturating beyond a + * 'patience' parameter. This records the rate of collection of new nearest neighbors in the {@code + * delegate} KnnCollector queue, at each HNSW node candidate visit. Once it saturates for a number + * of consecutive node visits (e.g., the patience parameter), this early terminates. + * + * @lucene.experimental + */ +public class HnswQueueSaturationCollector extends HnswKnnCollector { + + private final KnnCollector delegate; + private final double saturationThreshold; + private final int patience; + private boolean patienceFinished; + private int countSaturated; + private int previousQueueSize; + private int currentQueueSize; + + HnswQueueSaturationCollector(KnnCollector delegate, double saturationThreshold, int patience) { + super(delegate); + this.delegate = delegate; + this.previousQueueSize = 0; + this.currentQueueSize = 0; + this.countSaturated = 0; + this.patienceFinished = false; + this.saturationThreshold = saturationThreshold; + this.patience = patience; + } + + @Override + public boolean earlyTerminated() { + return delegate.earlyTerminated() || patienceFinished; + } + + @Override + public boolean collect(int docId, float similarity) { + boolean collect = delegate.collect(docId, similarity); + if (collect) { + currentQueueSize++; + } + return collect; + } + + @Override + public float minCompetitiveSimilarity() { + return delegate.minCompetitiveSimilarity(); + } + + @Override + public TopDocs topDocs() { + TopDocs topDocs; + if (patienceFinished && delegate.earlyTerminated() == false) { + TopDocs delegateDocs = delegate.topDocs(); + TotalHits totalHits = + new TotalHits(delegateDocs.totalHits.value(), TotalHits.Relation.EQUAL_TO); + topDocs = new TopDocs(totalHits, delegateDocs.scoreDocs); + } else { + topDocs = delegate.topDocs(); + } + return topDocs; + } + + @Override + public void nextCandidate() { + double queueSaturation = + (double) Math.min(currentQueueSize, previousQueueSize) / currentQueueSize; + previousQueueSize = currentQueueSize; + if (queueSaturation >= saturationThreshold) { + countSaturated++; + } else { + countSaturated = 0; + } + if (countSaturated > patience) { + patienceFinished = true; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java new file mode 100644 index 000000000000..ac2d401cbbea --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Objects; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.KnnSearchStrategy; +import org.apache.lucene.util.Bits; + +/** + * This is a version of knn vector query that exits early when HNSW queue saturates over a {@code + * #saturationThreshold} for more than {@code #patience} times. + * + *

See "Patience in + * Proximity: A Simple Early Termination Strategy for HNSW Graph Traversal in Approximate k-Nearest + * Neighbor Search" (Teofili and Lin). In ECIR '25: Proceedings of the 47th European Conference + * on Information Retrieval. + * + * @lucene.experimental + */ +public class PatienceKnnVectorQuery extends AbstractKnnVectorQuery { + + private static final double DEFAULT_SATURATION_THRESHOLD = 0.995d; + + private final int patience; + private final double saturationThreshold; + + final AbstractKnnVectorQuery delegate; + + /** + * Construct a new PatienceKnnVectorQuery instance for a float vector field + * + * @param knnQuery the knn query to be seeded + * @param saturationThreshold the early exit saturation threshold + * @param patience the patience parameter + * @return a new PatienceKnnVectorQuery instance + * @lucene.experimental + */ + public static PatienceKnnVectorQuery fromFloatQuery( + KnnFloatVectorQuery knnQuery, double saturationThreshold, int patience) { + return new PatienceKnnVectorQuery(knnQuery, saturationThreshold, patience); + } + + /** + * Construct a new PatienceKnnVectorQuery instance for a float vector field + * + * @param knnQuery the knn query to be seeded + * @return a new PatienceKnnVectorQuery instance + * @lucene.experimental + */ + public static PatienceKnnVectorQuery fromFloatQuery(KnnFloatVectorQuery knnQuery) { + return new PatienceKnnVectorQuery( + knnQuery, DEFAULT_SATURATION_THRESHOLD, defaultPatience(knnQuery)); + } + + /** + * Construct a new PatienceKnnVectorQuery instance for a byte vector field + * + * @param knnQuery the knn query to be seeded + * @param saturationThreshold the early exit saturation threshold + * @param patience the patience parameter + * @return a new PatienceKnnVectorQuery instance + * @lucene.experimental + */ + public static PatienceKnnVectorQuery fromByteQuery( + KnnByteVectorQuery knnQuery, double saturationThreshold, int patience) { + return new PatienceKnnVectorQuery(knnQuery, saturationThreshold, patience); + } + + /** + * Construct a new PatienceKnnVectorQuery instance for a byte vector field + * + * @param knnQuery the knn query to be seeded + * @return a new PatienceKnnVectorQuery instance + * @lucene.experimental + */ + public static PatienceKnnVectorQuery fromByteQuery(KnnByteVectorQuery knnQuery) { + return new PatienceKnnVectorQuery( + knnQuery, DEFAULT_SATURATION_THRESHOLD, defaultPatience(knnQuery)); + } + + /** + * Construct a new PatienceKnnVectorQuery instance for seeded vector field + * + * @param knnQuery the knn query to be seeded + * @param saturationThreshold the early exit saturation threshold + * @param patience the patience parameter + * @return a new PatienceKnnVectorQuery instance + * @lucene.experimental + */ + public static PatienceKnnVectorQuery fromSeededQuery( + SeededKnnVectorQuery knnQuery, double saturationThreshold, int patience) { + return new PatienceKnnVectorQuery(knnQuery, saturationThreshold, patience); + } + + /** + * Construct a new PatienceKnnVectorQuery instance for seeded vector field + * + * @param knnQuery the knn query to be seeded + * @return a new PatienceKnnVectorQuery instance + * @lucene.experimental + */ + public static PatienceKnnVectorQuery fromSeededQuery(SeededKnnVectorQuery knnQuery) { + return new PatienceKnnVectorQuery( + knnQuery, DEFAULT_SATURATION_THRESHOLD, defaultPatience(knnQuery)); + } + + PatienceKnnVectorQuery( + AbstractKnnVectorQuery knnQuery, double saturationThreshold, int patience) { + super(knnQuery.field, knnQuery.k, knnQuery.filter, knnQuery.searchStrategy); + this.delegate = knnQuery; + this.saturationThreshold = saturationThreshold; + this.patience = patience; + } + + private static int defaultPatience(AbstractKnnVectorQuery delegate) { + return Math.max(7, (int) (delegate.k * 0.3)); + } + + @Override + public String toString(String field) { + return "PatienceKnnVectorQuery{" + + "saturationThreshold=" + + saturationThreshold + + ", patience=" + + patience + + ", delegate=" + + delegate + + '}'; + } + + @Override + protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { + return delegate.getKnnCollectorManager(k, searcher); + } + + @Override + protected TopDocs approximateSearch( + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager) + throws IOException { + return delegate.approximateSearch( + context, acceptDocs, visitedLimit, new PatienceCollectorManager(knnCollectorManager)); + } + + @Override + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) + throws IOException { + return delegate.exactSearch(context, acceptIterator, queryTimeout); + } + + @Override + protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { + return delegate.mergeLeafResults(perLeafResults); + } + + @Override + public void visit(QueryVisitor visitor) { + delegate.visit(visitor); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (!super.equals(o)) return false; + PatienceKnnVectorQuery that = (PatienceKnnVectorQuery) o; + return saturationThreshold == that.saturationThreshold + && patience == that.patience + && Objects.equals(delegate, that.delegate); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), saturationThreshold, patience, delegate); + } + + @Override + public String getField() { + return delegate.getField(); + } + + @Override + public int getK() { + return delegate.getK(); + } + + @Override + public Query getFilter() { + return delegate.getFilter(); + } + + @Override + VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException { + return delegate.createVectorScorer(context, fi); + } + + class PatienceCollectorManager implements KnnCollectorManager { + final KnnCollectorManager knnCollectorManager; + + PatienceCollectorManager(KnnCollectorManager knnCollectorManager) { + this.knnCollectorManager = knnCollectorManager; + } + + @Override + public KnnCollector newCollector( + int visitLimit, KnnSearchStrategy searchStrategy, LeafReaderContext ctx) + throws IOException { + return new HnswQueueSaturationCollector( + knnCollectorManager.newCollector(visitLimit, searchStrategy, ctx), + saturationThreshold, + patience); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 4c7075dbd9f5..2105fa893952 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -20,6 +20,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; +import org.apache.lucene.search.HnswKnnCollector; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.search.knn.KnnSearchStrategy; @@ -296,6 +297,9 @@ void searchLevel( } } } + if (results instanceof HnswKnnCollector hnswKnnCollector) { + hnswKnnCollector.nextCandidate(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java index 5225fe700ab9..1e92abf54aab 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java @@ -17,6 +17,7 @@ package org.apache.lucene.util.hnsw; +import org.apache.lucene.search.HnswKnnCollector; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; @@ -24,7 +25,7 @@ /** * Wraps a provided KnnCollector object, translating the provided vectorId ordinal to a documentId */ -public final class OrdinalTranslatedKnnCollector extends KnnCollector.Decorator { +public final class OrdinalTranslatedKnnCollector extends HnswKnnCollector { private final IntToIntFunction vectorOrdinalToDocId; @@ -50,4 +51,11 @@ public TopDocs topDocs() { : TotalHits.Relation.EQUAL_TO), td.scoreDocs); } + + @Override + public void nextCandidate() { + if (this.collector instanceof HnswKnnCollector) { + ((HnswKnnCollector) this.collector).nextCandidate(); + } + } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestHnswQueueSaturationCollector.java b/lucene/core/src/test/org/apache/lucene/search/TestHnswQueueSaturationCollector.java new file mode 100644 index 000000000000..9f6a98f3812c --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestHnswQueueSaturationCollector.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.util.Random; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.junit.Test; + +/** Tests for {@link HnswQueueSaturationCollector} */ +public class TestHnswQueueSaturationCollector extends LuceneTestCase { + + @Test + public void testDelegate() { + Random random = random(); + int numDocs = 100; + int k = random.nextInt(1, 10); + KnnCollector delegate = new TopKnnCollector(k, numDocs); + HnswQueueSaturationCollector queueSaturationCollector = + new HnswQueueSaturationCollector(delegate, 1, Integer.MAX_VALUE); + for (int i = 0; i < random.nextInt(numDocs); i++) { + queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f)); + } + assertEquals(delegate.k(), queueSaturationCollector.k()); + assertEquals(delegate.visitedCount(), queueSaturationCollector.visitedCount()); + assertEquals(delegate.visitLimit(), queueSaturationCollector.visitLimit()); + assertEquals( + delegate.minCompetitiveSimilarity(), + queueSaturationCollector.minCompetitiveSimilarity(), + 1e-3); + } + + @Test + public void testEarlyExpectedExit() { + int numDocs = 1000; + int k = 10; + KnnCollector delegate = new TopKnnCollector(k, numDocs); + HnswQueueSaturationCollector queueSaturationCollector = + new HnswQueueSaturationCollector(delegate, 0.9, 10); + for (int i = 0; i < numDocs; i++) { + queueSaturationCollector.collect(i, 1.0f - i * 1e-3f); + if (i % 10 == 0) { + queueSaturationCollector.nextCandidate(); + } + if (queueSaturationCollector.earlyTerminated()) { + assertEquals(120, i); + break; + } + } + } + + @Test + public void testDelegateVsSaturateEarlyExit() { + Random random = random(); + int numDocs = 10000; + int k = random.nextInt(1, 100); + KnnCollector delegate = new TopKnnCollector(k, numDocs); + HnswQueueSaturationCollector queueSaturationCollector = + new HnswQueueSaturationCollector(delegate, 0.5, 1); + for (int i = 0; i < random.nextInt(numDocs); i++) { + queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f)); + if (i % 10 == 0) { + queueSaturationCollector.nextCandidate(); + } + boolean earlyTerminatedSaturation = queueSaturationCollector.earlyTerminated(); + boolean earlyTerminatedDelegate = delegate.earlyTerminated(); + assertTrue(earlyTerminatedSaturation || !earlyTerminatedDelegate); + } + } + + @Test + public void testEarlyExitRelation() { + Random random = random(); + int numDocs = 10000; + int k = random.nextInt(100); + KnnCollector delegate = new TopKnnCollector(k, random.nextInt(numDocs)); + HnswQueueSaturationCollector queueSaturationCollector = + new HnswQueueSaturationCollector(delegate, 0.5, 1); + for (int i = 0; i < random.nextInt(numDocs); i++) { + queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f)); + if (i % 10 == 0) { + queueSaturationCollector.nextCandidate(); + } + if (delegate.earlyTerminated()) { + TopDocs topDocs = queueSaturationCollector.topDocs(); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, topDocs.totalHits.relation()); + } + if (queueSaturationCollector.earlyTerminated()) { + TopDocs topDocs = queueSaturationCollector.topDocs(); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocs.totalHits.relation()); + break; + } + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java new file mode 100644 index 000000000000..47a036259bcb --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import static org.apache.lucene.search.TestKnnByteVectorQuery.floatToBytes; + +import java.io.IOException; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.knn.KnnSearchStrategy; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.TestVectorUtil; + +public class TestPatienceByteVectorQuery extends BaseKnnVectorQueryTestCase { + + @Override + PatienceKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { + return PatienceKnnVectorQuery.fromByteQuery( + new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter)); + } + + @Override + AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { + return PatienceKnnVectorQuery.fromByteQuery( + new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query)); + } + + @Override + float[] randomVector(int dim) { + byte[] b = TestVectorUtil.randomVectorBytes(dim); + float[] v = new float[b.length]; + int vi = 0; + for (int i = 0; i < v.length; i++) { + v[vi++] = b[i]; + } + return v; + } + + @Override + Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction similarityFunction) { + return new KnnByteVectorField(name, floatToBytes(vector), similarityFunction); + } + + @Override + Field getKnnVectorField(String name, float[] vector) { + return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN); + } + + public void testToString() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10); + assertEquals( + "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnByteVectorQuery:field[0,...][10]}", + query.toString("ignored")); + + assertDocScoreQueryToString(query.rewrite(newSearcher(reader))); + + // test with filter + Query filter = new TermQuery(new Term("id", "text")); + query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, filter); + assertEquals( + "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnByteVectorQuery:field[0,...][10][id:text]}", + query.toString("ignored")); + } + } + + static class ThrowingKnnVectorQuery extends KnnByteVectorQuery { + + public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) { + super(field, target, k, filter, new KnnSearchStrategy.Hnsw(0)); + } + + @Override + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { + throw new UnsupportedOperationException("exact search is not supported"); + } + + @Override + public String toString(String field) { + return "ThrowingKnnVectorQuery{" + field + "}"; + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java new file mode 100644 index 000000000000..8f66b2be1e6f --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.knn.KnnSearchStrategy; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.TestVectorUtil; + +public class TestPatienceFloatVectorQuery extends BaseKnnVectorQueryTestCase { + + @Override + PatienceKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { + return PatienceKnnVectorQuery.fromFloatQuery( + new KnnFloatVectorQuery(field, query, k, queryFilter)); + } + + @Override + AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { + return PatienceKnnVectorQuery.fromFloatQuery(new ThrowingKnnVectorQuery(field, vec, k, query)); + } + + @Override + float[] randomVector(int dim) { + return TestVectorUtil.randomVector(dim); + } + + @Override + Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction similarityFunction) { + return new KnnFloatVectorField(name, vector, similarityFunction); + } + + @Override + Field getKnnVectorField(String name, float[] vector) { + return new KnnFloatVectorField(name, vector); + } + + public void testToString() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10); + assertEquals( + "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnFloatVectorQuery:field[0.0,...][10]}", + query.toString("ignored")); + + assertDocScoreQueryToString(query.rewrite(newSearcher(reader))); + + // test with filter + Query filter = new TermQuery(new Term("id", "text")); + query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, filter); + assertEquals( + "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnFloatVectorQuery:field[0.0,...][10][id:text]}", + query.toString("ignored")); + } + } + + static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery { + + public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) { + super(field, target, k, filter, new KnnSearchStrategy.Hnsw(0)); + } + + @Override + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { + throw new UnsupportedOperationException("exact search is not supported"); + } + + @Override + public String toString(String field) { + return "ThrowingKnnVectorQuery{" + field + "}"; + } + } +}