From 3b30c0712aa76c44afc9f5e587c5991ec6c64bd7 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Tue, 24 Dec 2024 15:06:29 +0100 Subject: [PATCH 01/23] Add a HNSW early termination based on nn queue saturation --- .../lucene99/Lucene99HnswVectorsReader.java | 5 +- .../lucene/search/HnswKnnCollector.java | 28 +++++ .../search/HnswQueueSaturationCollector.java | 117 ++++++++++++++++++ .../lucene/util/hnsw/HnswGraphSearcher.java | 5 + .../HnswQueueSaturationCollectorTest.java | 43 +++++++ 5 files changed, 196 insertions(+), 2 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java create mode 100644 lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java create mode 100644 lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index ed6388b53cb7..f981ecd7f472 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -37,6 +37,7 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntObjectHashMap; import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.HnswQueueSaturationCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IOContext; @@ -314,8 +315,8 @@ private void search( return; } final RandomVectorScorer scorer = scorerSupplier.get(); - final KnnCollector collector = - new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + final KnnCollector collector = new HnswQueueSaturationCollector( + new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc)); final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); if (knnCollector.k() < scorer.maxOrd()) { HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds); diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java new file mode 100644 index 000000000000..bd264a721d94 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +/** + * {@link KnnCollector} that exposes methods to hook into specific parts of the HNSW algorithm. + */ +public interface HnswKnnCollector extends KnnCollector { + + /** + * Indicates exploration of the next HNSW candidate graph node. + */ + void nextCandidate(); +} diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java new file mode 100644 index 000000000000..ef12938c767a --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.search; + +/** + * A {@link HnswKnnCollector} that early exits when nearest neighbor queue keeps saturating beyond a 'patience' + * parameter, function of {@link #k()}. + */ +public class HnswQueueSaturationCollector implements HnswKnnCollector { + + private static final double DEFAULT_SATURATION_THRESHOLD = 0.995d; + + private final KnnCollector delegate; + private double saturationThreshold; + private int patience; + private boolean globalPatienceFinished; + private int countSaturated; + private int previousQueueSize; + private int currentQueueSize; + + public HnswQueueSaturationCollector(KnnCollector delegate, double saturationThreshold, int patience) { + this.delegate = delegate; + this.previousQueueSize = 0; + this.currentQueueSize = 0; + this.countSaturated = 0; + this.globalPatienceFinished = false; + this.saturationThreshold = saturationThreshold; + this.patience = patience; + } + + public HnswQueueSaturationCollector(KnnCollector delegate) { + this.delegate = delegate; + this.previousQueueSize = 0; + this.currentQueueSize = 0; + this.countSaturated = 0; + this.globalPatienceFinished = false; + this.saturationThreshold = DEFAULT_SATURATION_THRESHOLD; + this.patience = defaultPatience(); + } + + private int defaultPatience() { + return Math.max(7, (int) (k() * 0.3)); + } + + @Override + public boolean earlyTerminated() { + return delegate.earlyTerminated() || globalPatienceFinished; + } + + @Override + public void incVisitedCount(int count) { + delegate.incVisitedCount(count); + } + + @Override + public long visitedCount() { + return delegate.visitedCount(); + } + + @Override + public long visitLimit() { + return delegate.visitLimit(); + } + + @Override + public int k() { + return delegate.k(); + } + + @Override + public boolean collect(int docId, float similarity) { + boolean collect = delegate.collect(docId, similarity); + if (collect) { + currentQueueSize++; + } + return collect; + } + + @Override + public float minCompetitiveSimilarity() { + return delegate.minCompetitiveSimilarity(); + } + + @Override + public TopDocs topDocs() { + return delegate.topDocs(); + } + + @Override + public void nextCandidate() { + double queueSaturation = (double) Math.min(currentQueueSize, previousQueueSize) / currentQueueSize; + previousQueueSize = currentQueueSize; + if (queueSaturation >= saturationThreshold) { + countSaturated++; + } else { + countSaturated = 0; + } + if (countSaturated > patience) { + globalPatienceFinished = true; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 46d6c93d52c3..6734e208e87f 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -20,6 +20,8 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; + +import org.apache.lucene.search.HnswKnnCollector; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.util.BitSet; @@ -245,6 +247,9 @@ void searchLevel( } } } + if (results instanceof HnswKnnCollector hnswKnnCollector) { + hnswKnnCollector.nextCandidate(); + } } } diff --git a/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java b/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java new file mode 100644 index 000000000000..d33cb9c0300a --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import org.apache.lucene.tests.util.LuceneTestCase; +import org.junit.Test; + +import java.util.Random; + +/** + * Tests for {@link HnswQueueSaturationCollector} + */ +public class HnswQueueSaturationCollectorTest extends LuceneTestCase { + + @Test + public void testDelegate() { + Random random = random(); + KnnCollector delegate = new TopKnnCollector(random.nextInt(100), random.nextInt(1000)); + HnswQueueSaturationCollector queueSaturationCollector = new HnswQueueSaturationCollector(delegate); + for (int i = 0; i < random.nextInt(100); i++) { + queueSaturationCollector.collect(random.nextInt(1000), random.nextFloat(1.0f)); + } + assertEquals(delegate.k(), queueSaturationCollector.k()); + assertEquals(delegate.visitedCount(), queueSaturationCollector.visitedCount()); + assertEquals(delegate.visitLimit(), queueSaturationCollector.visitLimit()); + assertEquals(delegate.minCompetitiveSimilarity(), queueSaturationCollector.minCompetitiveSimilarity(), 1e-3); + } + +} \ No newline at end of file From 0b24e7937f768fd40f3981fcb91b87701216e32b Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Thu, 2 Jan 2025 10:27:06 +0100 Subject: [PATCH 02/23] enable optimized collector with 1k+ docs --- .../codecs/lucene99/Lucene99HnswVectorsReader.java | 10 ++++++++-- .../lucene/search/HnswQueueSaturationCollector.java | 12 ++++++------ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index f981ecd7f472..0e0406bd3de8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -315,10 +315,16 @@ private void search( return; } final RandomVectorScorer scorer = scorerSupplier.get(); - final KnnCollector collector = new HnswQueueSaturationCollector( - new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc)); final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); if (knnCollector.k() < scorer.maxOrd()) { + final KnnCollector collector; + OrdinalTranslatedKnnCollector ordinalTranslatedKnnCollector = + new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + if (scorer.maxOrd() > 1000) { + collector = new HnswQueueSaturationCollector(ordinalTranslatedKnnCollector); + } else { + collector = ordinalTranslatedKnnCollector; + } HnswGraphSearcher.search(scorer, collector, getGraph(fieldEntry), acceptedOrds); } else { // if k is larger than the number of vectors, we can just iterate over all vectors diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java index ef12938c767a..43088db136b1 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -19,7 +19,7 @@ /** * A {@link HnswKnnCollector} that early exits when nearest neighbor queue keeps saturating beyond a 'patience' - * parameter, function of {@link #k()}. + * parameter. */ public class HnswQueueSaturationCollector implements HnswKnnCollector { @@ -28,7 +28,7 @@ public class HnswQueueSaturationCollector implements HnswKnnCollector { private final KnnCollector delegate; private double saturationThreshold; private int patience; - private boolean globalPatienceFinished; + private boolean patienceFinished; private int countSaturated; private int previousQueueSize; private int currentQueueSize; @@ -38,7 +38,7 @@ public HnswQueueSaturationCollector(KnnCollector delegate, double saturationThre this.previousQueueSize = 0; this.currentQueueSize = 0; this.countSaturated = 0; - this.globalPatienceFinished = false; + this.patienceFinished = false; this.saturationThreshold = saturationThreshold; this.patience = patience; } @@ -48,7 +48,7 @@ public HnswQueueSaturationCollector(KnnCollector delegate) { this.previousQueueSize = 0; this.currentQueueSize = 0; this.countSaturated = 0; - this.globalPatienceFinished = false; + this.patienceFinished = false; this.saturationThreshold = DEFAULT_SATURATION_THRESHOLD; this.patience = defaultPatience(); } @@ -59,7 +59,7 @@ private int defaultPatience() { @Override public boolean earlyTerminated() { - return delegate.earlyTerminated() || globalPatienceFinished; + return delegate.earlyTerminated() || patienceFinished; } @Override @@ -111,7 +111,7 @@ public void nextCandidate() { countSaturated = 0; } if (countSaturated > patience) { - globalPatienceFinished = true; + patienceFinished = true; } } } From 93fb470d40ad1ef11049364787993dfb2b430e76 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Thu, 2 Jan 2025 15:24:09 +0100 Subject: [PATCH 03/23] tidy --- .../lucene99/Lucene99HnswVectorsReader.java | 4 +- .../lucene/search/HnswKnnCollector.java | 10 +- .../search/HnswQueueSaturationCollector.java | 182 +++++++++--------- .../lucene/util/hnsw/HnswGraphSearcher.java | 1 - .../HnswQueueSaturationCollectorTest.java | 38 ++-- 5 files changed, 116 insertions(+), 119 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index 0e0406bd3de8..6f18170f2e8f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -36,8 +36,8 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntObjectHashMap; -import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.HnswQueueSaturationCollector; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IOContext; @@ -319,7 +319,7 @@ private void search( if (knnCollector.k() < scorer.maxOrd()) { final KnnCollector collector; OrdinalTranslatedKnnCollector ordinalTranslatedKnnCollector = - new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); if (scorer.maxOrd() > 1000) { collector = new HnswQueueSaturationCollector(ordinalTranslatedKnnCollector); } else { diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java index bd264a721d94..e145ea99dd63 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java @@ -16,13 +16,9 @@ */ package org.apache.lucene.search; -/** - * {@link KnnCollector} that exposes methods to hook into specific parts of the HNSW algorithm. - */ +/** {@link KnnCollector} that exposes methods to hook into specific parts of the HNSW algorithm. */ public interface HnswKnnCollector extends KnnCollector { - /** - * Indicates exploration of the next HNSW candidate graph node. - */ - void nextCandidate(); + /** Indicates exploration of the next HNSW candidate graph node. */ + void nextCandidate(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java index 43088db136b1..63b32930f3c0 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -18,100 +18,102 @@ package org.apache.lucene.search; /** - * A {@link HnswKnnCollector} that early exits when nearest neighbor queue keeps saturating beyond a 'patience' - * parameter. + * A {@link HnswKnnCollector} that early exits when nearest neighbor queue keeps saturating beyond a + * 'patience' parameter. */ public class HnswQueueSaturationCollector implements HnswKnnCollector { - private static final double DEFAULT_SATURATION_THRESHOLD = 0.995d; - - private final KnnCollector delegate; - private double saturationThreshold; - private int patience; - private boolean patienceFinished; - private int countSaturated; - private int previousQueueSize; - private int currentQueueSize; - - public HnswQueueSaturationCollector(KnnCollector delegate, double saturationThreshold, int patience) { - this.delegate = delegate; - this.previousQueueSize = 0; - this.currentQueueSize = 0; - this.countSaturated = 0; - this.patienceFinished = false; - this.saturationThreshold = saturationThreshold; - this.patience = patience; + private static final double DEFAULT_SATURATION_THRESHOLD = 0.995d; + + private final KnnCollector delegate; + private double saturationThreshold; + private int patience; + private boolean patienceFinished; + private int countSaturated; + private int previousQueueSize; + private int currentQueueSize; + + public HnswQueueSaturationCollector( + KnnCollector delegate, double saturationThreshold, int patience) { + this.delegate = delegate; + this.previousQueueSize = 0; + this.currentQueueSize = 0; + this.countSaturated = 0; + this.patienceFinished = false; + this.saturationThreshold = saturationThreshold; + this.patience = patience; + } + + public HnswQueueSaturationCollector(KnnCollector delegate) { + this.delegate = delegate; + this.previousQueueSize = 0; + this.currentQueueSize = 0; + this.countSaturated = 0; + this.patienceFinished = false; + this.saturationThreshold = DEFAULT_SATURATION_THRESHOLD; + this.patience = defaultPatience(); + } + + private int defaultPatience() { + return Math.max(7, (int) (k() * 0.3)); + } + + @Override + public boolean earlyTerminated() { + return delegate.earlyTerminated() || patienceFinished; + } + + @Override + public void incVisitedCount(int count) { + delegate.incVisitedCount(count); + } + + @Override + public long visitedCount() { + return delegate.visitedCount(); + } + + @Override + public long visitLimit() { + return delegate.visitLimit(); + } + + @Override + public int k() { + return delegate.k(); + } + + @Override + public boolean collect(int docId, float similarity) { + boolean collect = delegate.collect(docId, similarity); + if (collect) { + currentQueueSize++; } - - public HnswQueueSaturationCollector(KnnCollector delegate) { - this.delegate = delegate; - this.previousQueueSize = 0; - this.currentQueueSize = 0; - this.countSaturated = 0; - this.patienceFinished = false; - this.saturationThreshold = DEFAULT_SATURATION_THRESHOLD; - this.patience = defaultPatience(); - } - - private int defaultPatience() { - return Math.max(7, (int) (k() * 0.3)); - } - - @Override - public boolean earlyTerminated() { - return delegate.earlyTerminated() || patienceFinished; - } - - @Override - public void incVisitedCount(int count) { - delegate.incVisitedCount(count); - } - - @Override - public long visitedCount() { - return delegate.visitedCount(); - } - - @Override - public long visitLimit() { - return delegate.visitLimit(); - } - - @Override - public int k() { - return delegate.k(); - } - - @Override - public boolean collect(int docId, float similarity) { - boolean collect = delegate.collect(docId, similarity); - if (collect) { - currentQueueSize++; - } - return collect; + return collect; + } + + @Override + public float minCompetitiveSimilarity() { + return delegate.minCompetitiveSimilarity(); + } + + @Override + public TopDocs topDocs() { + return delegate.topDocs(); + } + + @Override + public void nextCandidate() { + double queueSaturation = + (double) Math.min(currentQueueSize, previousQueueSize) / currentQueueSize; + previousQueueSize = currentQueueSize; + if (queueSaturation >= saturationThreshold) { + countSaturated++; + } else { + countSaturated = 0; } - - @Override - public float minCompetitiveSimilarity() { - return delegate.minCompetitiveSimilarity(); - } - - @Override - public TopDocs topDocs() { - return delegate.topDocs(); - } - - @Override - public void nextCandidate() { - double queueSaturation = (double) Math.min(currentQueueSize, previousQueueSize) / currentQueueSize; - previousQueueSize = currentQueueSize; - if (queueSaturation >= saturationThreshold) { - countSaturated++; - } else { - countSaturated = 0; - } - if (countSaturated > patience) { - patienceFinished = true; - } + if (countSaturated > patience) { + patienceFinished = true; } + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 6734e208e87f..136ca37ae07d 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -20,7 +20,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; - import org.apache.lucene.search.HnswKnnCollector; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopKnnCollector; diff --git a/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java b/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java index d33cb9c0300a..68124ed533c6 100644 --- a/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java +++ b/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java @@ -16,28 +16,28 @@ */ package org.apache.lucene.search; +import java.util.Random; import org.apache.lucene.tests.util.LuceneTestCase; import org.junit.Test; -import java.util.Random; - -/** - * Tests for {@link HnswQueueSaturationCollector} - */ +/** Tests for {@link HnswQueueSaturationCollector} */ public class HnswQueueSaturationCollectorTest extends LuceneTestCase { - @Test - public void testDelegate() { - Random random = random(); - KnnCollector delegate = new TopKnnCollector(random.nextInt(100), random.nextInt(1000)); - HnswQueueSaturationCollector queueSaturationCollector = new HnswQueueSaturationCollector(delegate); - for (int i = 0; i < random.nextInt(100); i++) { - queueSaturationCollector.collect(random.nextInt(1000), random.nextFloat(1.0f)); - } - assertEquals(delegate.k(), queueSaturationCollector.k()); - assertEquals(delegate.visitedCount(), queueSaturationCollector.visitedCount()); - assertEquals(delegate.visitLimit(), queueSaturationCollector.visitLimit()); - assertEquals(delegate.minCompetitiveSimilarity(), queueSaturationCollector.minCompetitiveSimilarity(), 1e-3); + @Test + public void testDelegate() { + Random random = random(); + KnnCollector delegate = new TopKnnCollector(random.nextInt(100), random.nextInt(1000)); + HnswQueueSaturationCollector queueSaturationCollector = + new HnswQueueSaturationCollector(delegate); + for (int i = 0; i < random.nextInt(100); i++) { + queueSaturationCollector.collect(random.nextInt(1000), random.nextFloat(1.0f)); } - -} \ No newline at end of file + assertEquals(delegate.k(), queueSaturationCollector.k()); + assertEquals(delegate.visitedCount(), queueSaturationCollector.visitedCount()); + assertEquals(delegate.visitLimit(), queueSaturationCollector.visitLimit()); + assertEquals( + delegate.minCompetitiveSimilarity(), + queueSaturationCollector.minCompetitiveSimilarity(), + 1e-3); + } +} From b7eb24fa43d855bd7816a73fcee2413925a6cce0 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Wed, 15 Jan 2025 12:26:33 +0100 Subject: [PATCH 04/23] don't trigger exact search when early terminating --- .../lucene/search/HnswQueueSaturationCollector.java | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java index 63b32930f3c0..1cee33b297e1 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -99,7 +99,16 @@ public float minCompetitiveSimilarity() { @Override public TopDocs topDocs() { - return delegate.topDocs(); + TopDocs topDocs; + if (patienceFinished && delegate.earlyTerminated() == false) { + TopDocs delegateDocs = delegate.topDocs(); + TotalHits totalHits = + new TotalHits(delegateDocs.totalHits.value(), TotalHits.Relation.EQUAL_TO); + topDocs = new TopDocs(totalHits, delegateDocs.scoreDocs); + } else { + topDocs = delegate.topDocs(); + } + return topDocs; } @Override From d143bbb1f817aee62efebf8c83e5d7306057c84c Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Wed, 15 Jan 2025 14:39:45 +0100 Subject: [PATCH 05/23] improved javadoc --- .../apache/lucene/search/HnswQueueSaturationCollector.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java index 1cee33b297e1..ad7509d2d780 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -19,7 +19,9 @@ /** * A {@link HnswKnnCollector} that early exits when nearest neighbor queue keeps saturating beyond a - * 'patience' parameter. + * 'patience' parameter. This records the rate of collection of new nearest neighbors in the + * {@code delegate} {@link KnnCollector) queue, at each HNSW node candidate visit. Once it saturates for a number of + * consecutive node visits (e.g., the patience parameter), this early terminates. */ public class HnswQueueSaturationCollector implements HnswKnnCollector { From 51df9ee1d14f79d9bfbda6335341bd474a6a9b8a Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Wed, 15 Jan 2025 14:44:10 +0100 Subject: [PATCH 06/23] improved javadoc --- .../apache/lucene/search/HnswQueueSaturationCollector.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java index ad7509d2d780..e7a2b59b7d5d 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -20,8 +20,8 @@ /** * A {@link HnswKnnCollector} that early exits when nearest neighbor queue keeps saturating beyond a * 'patience' parameter. This records the rate of collection of new nearest neighbors in the - * {@code delegate} {@link KnnCollector) queue, at each HNSW node candidate visit. Once it saturates for a number of - * consecutive node visits (e.g., the patience parameter), this early terminates. + * {@code delegate} {@link org.apache.lucene.search.KnnCollector) queue, at each HNSW node candidate visit. + * Once it saturates for a number of consecutive node visits (e.g., the patience parameter), this early terminates. */ public class HnswQueueSaturationCollector implements HnswKnnCollector { From e55f967989207603952644df99eea702b32a3695 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Wed, 15 Jan 2025 14:46:50 +0100 Subject: [PATCH 07/23] improved javadoc --- .../org/apache/lucene/search/HnswQueueSaturationCollector.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java index e7a2b59b7d5d..dd205a0e8261 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -20,7 +20,7 @@ /** * A {@link HnswKnnCollector} that early exits when nearest neighbor queue keeps saturating beyond a * 'patience' parameter. This records the rate of collection of new nearest neighbors in the - * {@code delegate} {@link org.apache.lucene.search.KnnCollector) queue, at each HNSW node candidate visit. + * {@code delegate} KnnCollector queue, at each HNSW node candidate visit. * Once it saturates for a number of consecutive node visits (e.g., the patience parameter), this early terminates. */ public class HnswQueueSaturationCollector implements HnswKnnCollector { From e3f8db33cc94c7e5f71538b58fef1e08dd2e4eea Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Wed, 15 Jan 2025 15:37:31 +0100 Subject: [PATCH 08/23] minor fixes, more tests --- .../search/HnswQueueSaturationCollector.java | 6 ++-- .../HnswQueueSaturationCollectorTest.java | 30 +++++++++++++++++-- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java index dd205a0e8261..1b7ae3c01271 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -19,9 +19,9 @@ /** * A {@link HnswKnnCollector} that early exits when nearest neighbor queue keeps saturating beyond a - * 'patience' parameter. This records the rate of collection of new nearest neighbors in the - * {@code delegate} KnnCollector queue, at each HNSW node candidate visit. - * Once it saturates for a number of consecutive node visits (e.g., the patience parameter), this early terminates. + * 'patience' parameter. This records the rate of collection of new nearest neighbors in the {@code + * delegate} KnnCollector queue, at each HNSW node candidate visit. Once it saturates for a number + * of consecutive node visits (e.g., the patience parameter), this early terminates. */ public class HnswQueueSaturationCollector implements HnswKnnCollector { diff --git a/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java b/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java index 68124ed533c6..6766c666600f 100644 --- a/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java +++ b/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java @@ -26,11 +26,13 @@ public class HnswQueueSaturationCollectorTest extends LuceneTestCase { @Test public void testDelegate() { Random random = random(); - KnnCollector delegate = new TopKnnCollector(random.nextInt(100), random.nextInt(1000)); + int numDocs = 100; + int k = random.nextInt(10); + KnnCollector delegate = new TopKnnCollector(k, numDocs); HnswQueueSaturationCollector queueSaturationCollector = new HnswQueueSaturationCollector(delegate); - for (int i = 0; i < random.nextInt(100); i++) { - queueSaturationCollector.collect(random.nextInt(1000), random.nextFloat(1.0f)); + for (int i = 0; i < random.nextInt(numDocs); i++) { + queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f)); } assertEquals(delegate.k(), queueSaturationCollector.k()); assertEquals(delegate.visitedCount(), queueSaturationCollector.visitedCount()); @@ -40,4 +42,26 @@ public void testDelegate() { queueSaturationCollector.minCompetitiveSimilarity(), 1e-3); } + + @Test + public void testEarlyExit() { + Random random = random(); + int numDocs = 10000; + int k = random.nextInt(100); + KnnCollector delegate = new TopKnnCollector(k, numDocs); + HnswQueueSaturationCollector queueSaturationCollector = + new HnswQueueSaturationCollector(delegate); + for (int i = 0; i < random.nextInt(numDocs); i++) { + queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f)); + boolean earlyTerminatedSaturation = queueSaturationCollector.earlyTerminated(); + boolean earlyTerminatedDelegate = delegate.earlyTerminated(); + assertTrue(earlyTerminatedSaturation || !earlyTerminatedDelegate); + if (earlyTerminatedDelegate) { + assertTrue(earlyTerminatedSaturation); + } + if (!earlyTerminatedSaturation) { + assertFalse(earlyTerminatedSaturation); + } + } + } } From a71e93602e4c67b9317b5c416b72f8d93c679b6e Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Wed, 15 Jan 2025 15:57:48 +0100 Subject: [PATCH 09/23] tidy --- .../src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 3268b34ebbd6..fb93398d1596 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -20,8 +20,8 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; -import org.apache.lucene.search.HnswKnnCollector; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.HnswKnnCollector; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.search.knn.EntryPointProvider; From 09b0229712009dc3102f646612bf4bddec4c6188 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Thu, 16 Jan 2025 12:10:50 +0100 Subject: [PATCH 10/23] dropped useless assertions --- .../lucene/search/HnswQueueSaturationCollectorTest.java | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java b/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java index 6766c666600f..f2881106c890 100644 --- a/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java +++ b/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java @@ -56,12 +56,6 @@ public void testEarlyExit() { boolean earlyTerminatedSaturation = queueSaturationCollector.earlyTerminated(); boolean earlyTerminatedDelegate = delegate.earlyTerminated(); assertTrue(earlyTerminatedSaturation || !earlyTerminatedDelegate); - if (earlyTerminatedDelegate) { - assertTrue(earlyTerminatedSaturation); - } - if (!earlyTerminatedSaturation) { - assertFalse(earlyTerminatedSaturation); - } } } } From 74132f15752f130368ca937b992ea3b1f93416d2 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Thu, 16 Jan 2025 12:12:41 +0100 Subject: [PATCH 11/23] changes added --- lucene/CHANGES.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 5084c25f3560..4a108e5cc0ac 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -22,6 +22,7 @@ Optimizations --------------------- * GITHUB#14011: Reduce allocation rate in HNSW concurrent merge. (Viliam Durina) * GITHUB#14022: Optimize DFS marking of connected components in HNSW by reducing stack depth, improving performance and reducing allocations. (Viswanath Kuchibhotla) +* GITHUB#14094: Early terminate when HNSW nearest neighbor queue saturates (Tommaso Teofili) Bug Fixes --------------------- From 370f513ea0e4feb7eed67c59393803085a1ee382 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Thu, 16 Jan 2025 12:16:20 +0100 Subject: [PATCH 12/23] changes to 10.2 --- lucene/CHANGES.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 158ac129fa62..aa11b8b6fac8 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -22,7 +22,6 @@ Optimizations --------------------- * GITHUB#14011: Reduce allocation rate in HNSW concurrent merge. (Viliam Durina) * GITHUB#14022: Optimize DFS marking of connected components in HNSW by reducing stack depth, improving performance and reducing allocations. (Viswanath Kuchibhotla) -* GITHUB#14094: Early terminate when HNSW nearest neighbor queue saturates (Tommaso Teofili) Bug Fixes --------------------- @@ -77,6 +76,8 @@ Optimizations * GITHUB#14133: Dense blocks of postings are now encoded as bit sets. (Adrien Grand) +* GITHUB#14094: Early terminate when HNSW nearest neighbor queue saturates (Tommaso Teofili) + Bug Fixes --------------------- From fed77c9e3b943fa1789306350dd70965b0bfef62 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Thu, 16 Jan 2025 12:33:01 +0100 Subject: [PATCH 13/23] more tests --- .../HnswQueueSaturationCollectorTest.java | 39 ++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java b/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java index f2881106c890..11b756c04dba 100644 --- a/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java +++ b/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java @@ -44,7 +44,22 @@ public void testDelegate() { } @Test - public void testEarlyExit() { + public void testEarlyExpectedExit() { + int numDocs = 1000; + int k = 10; + KnnCollector delegate = new TopKnnCollector(k, numDocs); + HnswQueueSaturationCollector queueSaturationCollector = + new HnswQueueSaturationCollector(delegate, 0.9, 10); + for (int i = 0; i < numDocs; i++) { + queueSaturationCollector.collect(i, 1.0f - i * 1e-3f); + if (queueSaturationCollector.earlyTerminated()) { + assertEquals(20, i); + } + } + } + + @Test + public void testDelegateVsSaturateEarlyExit() { Random random = random(); int numDocs = 10000; int k = random.nextInt(100); @@ -58,4 +73,26 @@ public void testEarlyExit() { assertTrue(earlyTerminatedSaturation || !earlyTerminatedDelegate); } } + + @Test + public void testEarlyExitRelation() { + Random random = random(); + int numDocs = 10000; + int k = random.nextInt(100); + KnnCollector delegate = new TopKnnCollector(k, numDocs); + HnswQueueSaturationCollector queueSaturationCollector = + new HnswQueueSaturationCollector(delegate); + for (int i = 0; i < random.nextInt(numDocs); i++) { + queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f)); + if (delegate.earlyTerminated()) { + TopDocs topDocs = queueSaturationCollector.topDocs(); + assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, topDocs.totalHits.relation()); + } + if (queueSaturationCollector.earlyTerminated()) { + TopDocs topDocs = queueSaturationCollector.topDocs(); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocs.totalHits.relation()); + break; + } + } + } } From 88d22df28eab9990bb23a857fb28f0389d312194 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Thu, 16 Jan 2025 12:33:33 +0100 Subject: [PATCH 14/23] more tests --- .../apache/lucene/search/HnswQueueSaturationCollectorTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java b/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java index 11b756c04dba..7d59e52ab0d6 100644 --- a/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java +++ b/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java @@ -79,7 +79,7 @@ public void testEarlyExitRelation() { Random random = random(); int numDocs = 10000; int k = random.nextInt(100); - KnnCollector delegate = new TopKnnCollector(k, numDocs); + KnnCollector delegate = new TopKnnCollector(k, random.nextInt(numDocs)); HnswQueueSaturationCollector queueSaturationCollector = new HnswQueueSaturationCollector(delegate); for (int i = 0; i < random.nextInt(numDocs); i++) { From e86ebdc96c22c49fa0b18a4e0b29334ac3e421ff Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Fri, 17 Jan 2025 12:07:41 +0100 Subject: [PATCH 15/23] minor fixes --- .../search/HnswQueueSaturationCollector.java | 6 +++--- ...a => TestHnswQueueSaturationCollector.java} | 18 ++++++++++++++---- 2 files changed, 17 insertions(+), 7 deletions(-) rename lucene/core/src/test/org/apache/lucene/search/{HnswQueueSaturationCollectorTest.java => TestHnswQueueSaturationCollector.java} (89%) diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java index 1b7ae3c01271..b2af5cbd6853 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -28,14 +28,14 @@ public class HnswQueueSaturationCollector implements HnswKnnCollector { private static final double DEFAULT_SATURATION_THRESHOLD = 0.995d; private final KnnCollector delegate; - private double saturationThreshold; - private int patience; + private final double saturationThreshold; + private final int patience; private boolean patienceFinished; private int countSaturated; private int previousQueueSize; private int currentQueueSize; - public HnswQueueSaturationCollector( + HnswQueueSaturationCollector( KnnCollector delegate, double saturationThreshold, int patience) { this.delegate = delegate; this.previousQueueSize = 0; diff --git a/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java b/lucene/core/src/test/org/apache/lucene/search/TestHnswQueueSaturationCollector.java similarity index 89% rename from lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java rename to lucene/core/src/test/org/apache/lucene/search/TestHnswQueueSaturationCollector.java index 7d59e52ab0d6..1ef298450160 100644 --- a/lucene/core/src/test/org/apache/lucene/search/HnswQueueSaturationCollectorTest.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestHnswQueueSaturationCollector.java @@ -21,13 +21,13 @@ import org.junit.Test; /** Tests for {@link HnswQueueSaturationCollector} */ -public class HnswQueueSaturationCollectorTest extends LuceneTestCase { +public class TestHnswQueueSaturationCollector extends LuceneTestCase { @Test public void testDelegate() { Random random = random(); int numDocs = 100; - int k = random.nextInt(10); + int k = random.nextInt(1, 10); KnnCollector delegate = new TopKnnCollector(k, numDocs); HnswQueueSaturationCollector queueSaturationCollector = new HnswQueueSaturationCollector(delegate); @@ -52,8 +52,12 @@ public void testEarlyExpectedExit() { new HnswQueueSaturationCollector(delegate, 0.9, 10); for (int i = 0; i < numDocs; i++) { queueSaturationCollector.collect(i, 1.0f - i * 1e-3f); + if (i % 10 == 0) { + queueSaturationCollector.nextCandidate(); + } if (queueSaturationCollector.earlyTerminated()) { - assertEquals(20, i); + assertEquals(120, i); + break; } } } @@ -62,12 +66,15 @@ public void testEarlyExpectedExit() { public void testDelegateVsSaturateEarlyExit() { Random random = random(); int numDocs = 10000; - int k = random.nextInt(100); + int k = random.nextInt(1, 100); KnnCollector delegate = new TopKnnCollector(k, numDocs); HnswQueueSaturationCollector queueSaturationCollector = new HnswQueueSaturationCollector(delegate); for (int i = 0; i < random.nextInt(numDocs); i++) { queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f)); + if (i % 10 == 0) { + queueSaturationCollector.nextCandidate(); + } boolean earlyTerminatedSaturation = queueSaturationCollector.earlyTerminated(); boolean earlyTerminatedDelegate = delegate.earlyTerminated(); assertTrue(earlyTerminatedSaturation || !earlyTerminatedDelegate); @@ -84,6 +91,9 @@ public void testEarlyExitRelation() { new HnswQueueSaturationCollector(delegate); for (int i = 0; i < random.nextInt(numDocs); i++) { queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f)); + if (i % 10 == 0) { + queueSaturationCollector.nextCandidate(); + } if (delegate.earlyTerminated()) { TopDocs topDocs = queueSaturationCollector.topDocs(); assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, topDocs.totalHits.relation()); From 20a481f0f1c764343dd8c648297e1f0d23a610bc Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Mon, 3 Feb 2025 14:08:23 +0100 Subject: [PATCH 16/23] tidy --- .../org/apache/lucene/search/HnswQueueSaturationCollector.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java index b2af5cbd6853..00396b39f1db 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -35,8 +35,7 @@ public class HnswQueueSaturationCollector implements HnswKnnCollector { private int previousQueueSize; private int currentQueueSize; - HnswQueueSaturationCollector( - KnnCollector delegate, double saturationThreshold, int patience) { + HnswQueueSaturationCollector(KnnCollector delegate, double saturationThreshold, int patience) { this.delegate = delegate; this.previousQueueSize = 0; this.currentQueueSize = 0; From c6dbf7ec69eea1614f8efdbf2a77f57b0f1d8405 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Thu, 13 Feb 2025 13:21:58 +0100 Subject: [PATCH 17/23] make hnsw collector a decorator --- .../lucene/search/HnswKnnCollector.java | 8 ++++-- .../search/HnswQueueSaturationCollector.java | 26 +++---------------- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java index e145ea99dd63..24dfbcf88ef4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java @@ -17,8 +17,12 @@ package org.apache.lucene.search; /** {@link KnnCollector} that exposes methods to hook into specific parts of the HNSW algorithm. */ -public interface HnswKnnCollector extends KnnCollector { +public abstract class HnswKnnCollector extends KnnCollector.Decorator { + + public HnswKnnCollector(KnnCollector collector) { + super(collector); + } /** Indicates exploration of the next HNSW candidate graph node. */ - void nextCandidate(); + public abstract void nextCandidate(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java index 00396b39f1db..95a3c100c175 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -23,9 +23,9 @@ * delegate} KnnCollector queue, at each HNSW node candidate visit. Once it saturates for a number * of consecutive node visits (e.g., the patience parameter), this early terminates. */ -public class HnswQueueSaturationCollector implements HnswKnnCollector { +public class HnswQueueSaturationCollector extends HnswKnnCollector { - private static final double DEFAULT_SATURATION_THRESHOLD = 0.995d; + private static final double DEFAULT_SATURATION_THRESHOLD = 0.95d; private final KnnCollector delegate; private final double saturationThreshold; @@ -36,6 +36,7 @@ public class HnswQueueSaturationCollector implements HnswKnnCollector { private int currentQueueSize; HnswQueueSaturationCollector(KnnCollector delegate, double saturationThreshold, int patience) { + super(delegate); this.delegate = delegate; this.previousQueueSize = 0; this.currentQueueSize = 0; @@ -46,6 +47,7 @@ public class HnswQueueSaturationCollector implements HnswKnnCollector { } public HnswQueueSaturationCollector(KnnCollector delegate) { + super(delegate); this.delegate = delegate; this.previousQueueSize = 0; this.currentQueueSize = 0; @@ -64,26 +66,6 @@ public boolean earlyTerminated() { return delegate.earlyTerminated() || patienceFinished; } - @Override - public void incVisitedCount(int count) { - delegate.incVisitedCount(count); - } - - @Override - public long visitedCount() { - return delegate.visitedCount(); - } - - @Override - public long visitLimit() { - return delegate.visitLimit(); - } - - @Override - public int k() { - return delegate.k(); - } - @Override public boolean collect(int docId, float similarity) { boolean collect = delegate.collect(docId, similarity); From 460efd9f8796810ffc0ff7dee3e3e9565e6ea0c7 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Tue, 25 Feb 2025 12:19:32 +0100 Subject: [PATCH 18/23] moved the early termination logic into PatienceKnnVectorQuery --- .../lucene99/Lucene99HnswVectorsReader.java | 9 +- .../lucene/search/HnswKnnCollector.java | 9 +- .../search/HnswQueueSaturationCollector.java | 19 +- .../lucene/search/PatienceKnnVectorQuery.java | 237 ++++++++++++++++++ .../hnsw/OrdinalTranslatedKnnCollector.java | 10 +- .../TestHnswQueueSaturationCollector.java | 6 +- .../search/TestPatienceByteVectorQuery.java | 107 ++++++++ .../search/TestPatienceFloatVectorQuery.java | 98 ++++++++ 8 files changed, 464 insertions(+), 31 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java create mode 100644 lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java create mode 100644 lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index 44f502399063..1ea6b51bfd47 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -36,7 +36,6 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntObjectHashMap; -import org.apache.lucene.search.HnswQueueSaturationCollector; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; @@ -316,14 +315,8 @@ private void search( } final RandomVectorScorer scorer = scorerSupplier.get(); final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); - final KnnCollector collector; - OrdinalTranslatedKnnCollector ordinalTranslatedKnnCollector = + final KnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); - if (scorer.maxOrd() > 1000) { - collector = new HnswQueueSaturationCollector(ordinalTranslatedKnnCollector); - } else { - collector = ordinalTranslatedKnnCollector; - } HnswGraph graph = getGraph(fieldEntry); boolean doHnsw = knnCollector.k() < scorer.maxOrd(); // Take into account if quantized? E.g. some scorer cost? diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java index 24dfbcf88ef4..9902e340ad81 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java @@ -16,7 +16,11 @@ */ package org.apache.lucene.search; -/** {@link KnnCollector} that exposes methods to hook into specific parts of the HNSW algorithm. */ +/** + * {@link KnnCollector} that exposes methods to hook into specific parts of the HNSW algorithm. + * + * @lucene.experimental + */ public abstract class HnswKnnCollector extends KnnCollector.Decorator { public HnswKnnCollector(KnnCollector collector) { @@ -24,5 +28,6 @@ public HnswKnnCollector(KnnCollector collector) { } /** Indicates exploration of the next HNSW candidate graph node. */ - public abstract void nextCandidate(); + public void nextCandidate() {} + ; } diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java index 95a3c100c175..95ad5973ca43 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java @@ -22,11 +22,11 @@ * 'patience' parameter. This records the rate of collection of new nearest neighbors in the {@code * delegate} KnnCollector queue, at each HNSW node candidate visit. Once it saturates for a number * of consecutive node visits (e.g., the patience parameter), this early terminates. + * + * @lucene.experimental */ public class HnswQueueSaturationCollector extends HnswKnnCollector { - private static final double DEFAULT_SATURATION_THRESHOLD = 0.95d; - private final KnnCollector delegate; private final double saturationThreshold; private final int patience; @@ -46,21 +46,6 @@ public class HnswQueueSaturationCollector extends HnswKnnCollector { this.patience = patience; } - public HnswQueueSaturationCollector(KnnCollector delegate) { - super(delegate); - this.delegate = delegate; - this.previousQueueSize = 0; - this.currentQueueSize = 0; - this.countSaturated = 0; - this.patienceFinished = false; - this.saturationThreshold = DEFAULT_SATURATION_THRESHOLD; - this.patience = defaultPatience(); - } - - private int defaultPatience() { - return Math.max(7, (int) (k() * 0.3)); - } - @Override public boolean earlyTerminated() { return delegate.earlyTerminated() || patienceFinished; diff --git a/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java new file mode 100644 index 000000000000..b253fbeb2a30 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Objects; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.KnnSearchStrategy; +import org.apache.lucene.util.Bits; + +/** + * This is a version of knn vector query that exits early when HNSW queue saturates over a {@code + * #saturationThreshold} for more than {@code #patience} times. + * + *

See "Patience in + * Proximity: A Simple Early Termination Strategy for HNSW Graph Traversal in Approximate k-Nearest + * Neighbor Search" (Teofili and Lin). In ECIR '25: Proceedings of the 47th European Conference + * on Information Retrieval. + * + * @lucene.experimental + */ +public class PatienceKnnVectorQuery extends AbstractKnnVectorQuery { + + private static final double DEFAULT_SATURATION_THRESHOLD = 0.995d; + + private final int patience; + private final double saturationThreshold; + + final AbstractKnnVectorQuery delegate; + + /** + * Construct a new PatienceKnnVectorQuery instance for a float vector field + * + * @param knnQuery the knn query to be seeded + * @param saturationThreshold the early exit saturation threshold + * @param patience the patience parameter + * @return a new PatienceKnnVectorQuery instance + * @lucene.experimental + */ + public static PatienceKnnVectorQuery fromFloatQuery( + KnnFloatVectorQuery knnQuery, double saturationThreshold, int patience) { + return new PatienceKnnVectorQuery(knnQuery, saturationThreshold, patience); + } + + /** + * Construct a new PatienceKnnVectorQuery instance for a float vector field + * + * @param knnQuery the knn query to be seeded + * @return a new PatienceKnnVectorQuery instance + * @lucene.experimental + */ + public static PatienceKnnVectorQuery fromFloatQuery(KnnFloatVectorQuery knnQuery) { + return new PatienceKnnVectorQuery( + knnQuery, DEFAULT_SATURATION_THRESHOLD, defaultPatience(knnQuery)); + } + + /** + * Construct a new PatienceKnnVectorQuery instance for a byte vector field + * + * @param knnQuery the knn query to be seeded + * @param saturationThreshold the early exit saturation threshold + * @param patience the patience parameter + * @return a new PatienceKnnVectorQuery instance + * @lucene.experimental + */ + public static PatienceKnnVectorQuery fromByteQuery( + KnnByteVectorQuery knnQuery, double saturationThreshold, int patience) { + return new PatienceKnnVectorQuery(knnQuery, saturationThreshold, patience); + } + + /** + * Construct a new PatienceKnnVectorQuery instance for a byte vector field + * + * @param knnQuery the knn query to be seeded + * @return a new PatienceKnnVectorQuery instance + * @lucene.experimental + */ + public static PatienceKnnVectorQuery fromByteQuery(KnnByteVectorQuery knnQuery) { + return new PatienceKnnVectorQuery( + knnQuery, DEFAULT_SATURATION_THRESHOLD, defaultPatience(knnQuery)); + } + + /** + * Construct a new PatienceKnnVectorQuery instance for seeded vector field + * + * @param knnQuery the knn query to be seeded + * @param saturationThreshold the early exit saturation threshold + * @param patience the patience parameter + * @return a new PatienceKnnVectorQuery instance + * @lucene.experimental + */ + public static PatienceKnnVectorQuery fromSeededQuery( + SeededKnnVectorQuery knnQuery, double saturationThreshold, int patience) { + return new PatienceKnnVectorQuery(knnQuery, saturationThreshold, patience); + } + + /** + * Construct a new PatienceKnnVectorQuery instance for seeded vector field + * + * @param knnQuery the knn query to be seeded + * @return a new PatienceKnnVectorQuery instance + * @lucene.experimental + */ + public static PatienceKnnVectorQuery fromSeededQuery(SeededKnnVectorQuery knnQuery) { + return new PatienceKnnVectorQuery( + knnQuery, DEFAULT_SATURATION_THRESHOLD, defaultPatience(knnQuery)); + } + + PatienceKnnVectorQuery( + AbstractKnnVectorQuery knnQuery, double saturationThreshold, int patience) { + super(knnQuery.field, knnQuery.k, knnQuery.filter, knnQuery.searchStrategy); + this.delegate = knnQuery; + this.saturationThreshold = saturationThreshold; + this.patience = patience; + } + + private static int defaultPatience(AbstractKnnVectorQuery delegate) { + return Math.max(7, (int) (delegate.k * 0.3)); + } + + @Override + public String toString(String field) { + return "PatienceKnnVectorQuery{" + + "saturationThreshold=" + + saturationThreshold + + ", patience=" + + patience + + ", delegate=" + + delegate + + '}'; + } + + @Override + protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { + return delegate.getKnnCollectorManager(k, searcher); + } + + @Override + protected TopDocs approximateSearch( + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager) + throws IOException { + return delegate.approximateSearch( + context, acceptDocs, visitedLimit, new PatienceCollectorManager(knnCollectorManager)); + } + + @Override + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) + throws IOException { + return delegate.exactSearch(context, acceptIterator, queryTimeout); + } + + @Override + protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) { + return delegate.mergeLeafResults(perLeafResults); + } + + @Override + public void visit(QueryVisitor visitor) { + delegate.visit(visitor); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (!super.equals(o)) return false; + PatienceKnnVectorQuery that = (PatienceKnnVectorQuery) o; + return Objects.equals(saturationThreshold, that.saturationThreshold) + && Objects.equals(patience, that.patience) + && Objects.equals(delegate, that.delegate); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), saturationThreshold, patience, delegate); + } + + @Override + public String getField() { + return delegate.getField(); + } + + @Override + public int getK() { + return delegate.getK(); + } + + @Override + public Query getFilter() { + return delegate.getFilter(); + } + + @Override + VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException { + return delegate.createVectorScorer(context, fi); + } + + class PatienceCollectorManager implements KnnCollectorManager { + final KnnCollectorManager knnCollectorManager; + + PatienceCollectorManager(KnnCollectorManager knnCollectorManager) { + this.knnCollectorManager = knnCollectorManager; + } + + @Override + public KnnCollector newCollector( + int visitLimit, KnnSearchStrategy searchStrategy, LeafReaderContext ctx) + throws IOException { + return new HnswQueueSaturationCollector( + knnCollectorManager.newCollector(visitLimit, searchStrategy, ctx), + saturationThreshold, + patience); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java index 5225fe700ab9..1e92abf54aab 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java @@ -17,6 +17,7 @@ package org.apache.lucene.util.hnsw; +import org.apache.lucene.search.HnswKnnCollector; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; @@ -24,7 +25,7 @@ /** * Wraps a provided KnnCollector object, translating the provided vectorId ordinal to a documentId */ -public final class OrdinalTranslatedKnnCollector extends KnnCollector.Decorator { +public final class OrdinalTranslatedKnnCollector extends HnswKnnCollector { private final IntToIntFunction vectorOrdinalToDocId; @@ -50,4 +51,11 @@ public TopDocs topDocs() { : TotalHits.Relation.EQUAL_TO), td.scoreDocs); } + + @Override + public void nextCandidate() { + if (this.collector instanceof HnswKnnCollector) { + ((HnswKnnCollector) this.collector).nextCandidate(); + } + } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestHnswQueueSaturationCollector.java b/lucene/core/src/test/org/apache/lucene/search/TestHnswQueueSaturationCollector.java index 1ef298450160..9f6a98f3812c 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestHnswQueueSaturationCollector.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestHnswQueueSaturationCollector.java @@ -30,7 +30,7 @@ public void testDelegate() { int k = random.nextInt(1, 10); KnnCollector delegate = new TopKnnCollector(k, numDocs); HnswQueueSaturationCollector queueSaturationCollector = - new HnswQueueSaturationCollector(delegate); + new HnswQueueSaturationCollector(delegate, 1, Integer.MAX_VALUE); for (int i = 0; i < random.nextInt(numDocs); i++) { queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f)); } @@ -69,7 +69,7 @@ public void testDelegateVsSaturateEarlyExit() { int k = random.nextInt(1, 100); KnnCollector delegate = new TopKnnCollector(k, numDocs); HnswQueueSaturationCollector queueSaturationCollector = - new HnswQueueSaturationCollector(delegate); + new HnswQueueSaturationCollector(delegate, 0.5, 1); for (int i = 0; i < random.nextInt(numDocs); i++) { queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f)); if (i % 10 == 0) { @@ -88,7 +88,7 @@ public void testEarlyExitRelation() { int k = random.nextInt(100); KnnCollector delegate = new TopKnnCollector(k, random.nextInt(numDocs)); HnswQueueSaturationCollector queueSaturationCollector = - new HnswQueueSaturationCollector(delegate); + new HnswQueueSaturationCollector(delegate, 0.5, 1); for (int i = 0; i < random.nextInt(numDocs); i++) { queueSaturationCollector.collect(random.nextInt(numDocs), random.nextFloat(1.0f)); if (i % 10 == 0) { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java new file mode 100644 index 000000000000..f80e00c57eec --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import static org.apache.lucene.search.TestKnnByteVectorQuery.floatToBytes; + +import java.io.IOException; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.knn.KnnSearchStrategy; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.TestVectorUtil; + +public class TestPatienceByteVectorQuery extends BaseKnnVectorQueryTestCase { + + @Override + PatienceKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { + return PatienceKnnVectorQuery.fromByteQuery( + new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter)); + } + + @Override + AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { + return PatienceKnnVectorQuery.fromByteQuery( + new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query)); + } + + @Override + float[] randomVector(int dim) { + byte[] b = TestVectorUtil.randomVectorBytes(dim); + float[] v = new float[b.length]; + int vi = 0; + for (int i = 0; i < v.length; i++) { + v[vi++] = b[i]; + } + return v; + } + + @Override + Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction similarityFunction) { + return new KnnByteVectorField(name, floatToBytes(vector), similarityFunction); + } + + @Override + Field getKnnVectorField(String name, float[] vector) { + return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN); + } + + public void testToString() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10); + assertEquals( + "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnByteVectorQuery:field[0,...][10]}", + query.toString("ignored")); + + assertDocScoreQueryToString(query.rewrite(newSearcher(reader))); + + // test with filter + Query filter = new TermQuery(new Term("id", "text")); + query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, filter); + assertEquals( + "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnByteVectorQuery:field[0,...][10][id:text]}", + query.toString("ignored")); + } + } + + static class ThrowingKnnVectorQuery extends KnnByteVectorQuery { + + public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) { + super(field, target, k, filter, new KnnSearchStrategy.Hnsw(0)); + } + + @Override + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { + throw new UnsupportedOperationException("exact search is not supported"); + } + + @Override + public String toString(String field) { + return null; + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java new file mode 100644 index 000000000000..53ba460f09ad --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.knn.KnnSearchStrategy; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.TestVectorUtil; + +public class TestPatienceFloatVectorQuery extends BaseKnnVectorQueryTestCase { + + @Override + PatienceKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { + return PatienceKnnVectorQuery.fromFloatQuery( + new KnnFloatVectorQuery(field, query, k, queryFilter)); + } + + @Override + AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { + return PatienceKnnVectorQuery.fromFloatQuery(new ThrowingKnnVectorQuery(field, vec, k, query)); + } + + @Override + float[] randomVector(int dim) { + return TestVectorUtil.randomVector(dim); + } + + @Override + Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction similarityFunction) { + return new KnnFloatVectorField(name, vector, similarityFunction); + } + + @Override + Field getKnnVectorField(String name, float[] vector) { + return new KnnFloatVectorField(name, vector); + } + + public void testToString() throws IOException { + try (Directory indexStore = + getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10); + assertEquals( + "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnFloatVectorQuery:field[0.0,...][10]}", + query.toString("ignored")); + + assertDocScoreQueryToString(query.rewrite(newSearcher(reader))); + + // test with filter + Query filter = new TermQuery(new Term("id", "text")); + query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, filter); + assertEquals( + "PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnFloatVectorQuery:field[0.0,...][10][id:text]}", + query.toString("ignored")); + } + } + + static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery { + + public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) { + super(field, target, k, filter, new KnnSearchStrategy.Hnsw(0)); + } + + @Override + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { + throw new UnsupportedOperationException("exact search is not supported"); + } + + @Override + public String toString(String field) { + return null; + } + } +} From 3d2e46b9deeddf32f366a4893e66e47fbafb7f77 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Tue, 25 Feb 2025 12:23:20 +0100 Subject: [PATCH 19/23] minor fix --- .../java/org/apache/lucene/search/PatienceKnnVectorQuery.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java index b253fbeb2a30..ac2d401cbbea 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java @@ -187,8 +187,8 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; if (!super.equals(o)) return false; PatienceKnnVectorQuery that = (PatienceKnnVectorQuery) o; - return Objects.equals(saturationThreshold, that.saturationThreshold) - && Objects.equals(patience, that.patience) + return saturationThreshold == that.saturationThreshold + && patience == that.patience && Objects.equals(delegate, that.delegate); } From eef4f97d18534a7b57dba6958da2621eb33638c7 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Tue, 25 Feb 2025 12:30:16 +0100 Subject: [PATCH 20/23] updated CHANGES to reflect new query, minor fix --- lucene/CHANGES.txt | 4 ++-- .../src/java/org/apache/lucene/search/HnswKnnCollector.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 8cff8391fb1f..15fc4a7b7184 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -68,6 +68,8 @@ New Features histogram of the distribution of the values of a field, for documents matching a given query. (Adrien Grand) +* GITHUB#14094: New KNN query that early terminates when HNSW nearest neighbor queue saturates (Tommaso Teofili) + Improvements --------------------- @@ -135,8 +137,6 @@ Optimizations * GITHUB#14272: Use DocIdSetIterator#range for continuous-id BKD leaves. (Guo Feng) -* GITHUB#14094: Early terminate when HNSW nearest neighbor queue saturates (Tommaso Teofili) - Bug Fixes --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java index 9902e340ad81..7bdd2e22020a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java @@ -29,5 +29,5 @@ public HnswKnnCollector(KnnCollector collector) { /** Indicates exploration of the next HNSW candidate graph node. */ public void nextCandidate() {} - ; + } From acf58662682ab952048f527c9ede58b2aa2cf573 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Tue, 25 Feb 2025 12:32:27 +0100 Subject: [PATCH 21/23] reverted unneeded change --- .../lucene/codecs/lucene99/Lucene99HnswVectorsReader.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index 1ea6b51bfd47..101ebd1eef4d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -314,9 +314,9 @@ private void search( return; } final RandomVectorScorer scorer = scorerSupplier.get(); - final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); final KnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + final Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); HnswGraph graph = getGraph(fieldEntry); boolean doHnsw = knnCollector.k() < scorer.maxOrd(); // Take into account if quantized? E.g. some scorer cost? From 620e985507db1b5a1d60e6888a94586b4ca6adf9 Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Tue, 25 Feb 2025 12:40:19 +0100 Subject: [PATCH 22/23] tidy --- .../core/src/java/org/apache/lucene/search/HnswKnnCollector.java | 1 - 1 file changed, 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java index 7bdd2e22020a..88ab3f24025d 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java @@ -29,5 +29,4 @@ public HnswKnnCollector(KnnCollector collector) { /** Indicates exploration of the next HNSW candidate graph node. */ public void nextCandidate() {} - } From f116141e40ebf0e30ede017ed3b43029073ec88e Mon Sep 17 00:00:00 2001 From: Tommaso Teofili Date: Wed, 12 Mar 2025 16:11:46 +0100 Subject: [PATCH 23/23] minor tweaks --- .../src/java/org/apache/lucene/search/HnswKnnCollector.java | 2 +- .../org/apache/lucene/search/TestPatienceByteVectorQuery.java | 2 +- .../org/apache/lucene/search/TestPatienceFloatVectorQuery.java | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java index 88ab3f24025d..3147c30e418b 100644 --- a/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/HnswKnnCollector.java @@ -27,6 +27,6 @@ public HnswKnnCollector(KnnCollector collector) { super(collector); } - /** Indicates exploration of the next HNSW candidate graph node. */ + /** Triggers exploration of the next HNSW candidate graph node. */ public void nextCandidate() {} } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java index f80e00c57eec..47a036259bcb 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java @@ -101,7 +101,7 @@ protected TopDocs exactSearch( @Override public String toString(String field) { - return null; + return "ThrowingKnnVectorQuery{" + field + "}"; } } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java index 53ba460f09ad..8f66b2be1e6f 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java @@ -92,7 +92,7 @@ protected TopDocs exactSearch( @Override public String toString(String field) { - return null; + return "ThrowingKnnVectorQuery{" + field + "}"; } } }