From 4cd4e905078a49746f8e37fa792b785ac4be6915 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Fri, 24 Jan 2025 16:54:31 -0500 Subject: [PATCH] fixing seed refactor --- .../lucene/search/SeededKnnVectorQuery.java | 38 +++-- .../search/BaseKnnVectorQueryTestCase.java | 2 +- .../search/TestSeededKnnByteVectorQuery.java | 140 +++++++++++++++++- .../search/TestSeededKnnFloatVectorQuery.java | 65 +++++++- 4 files changed, 222 insertions(+), 23 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnVectorQuery.java index 7470287fd113..81d3860fe396 100644 --- a/lucene/core/src/java/org/apache/lucene/search/SeededKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnVectorQuery.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.Objects; +import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; @@ -92,6 +93,12 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (seedWeight != null) { return super.rewrite(indexSearcher); } + SeededKnnVectorQuery rewritten = + new SeededKnnVectorQuery(delegate, seed, createSeedWeight(indexSearcher)); + return rewritten.rewrite(indexSearcher); + } + + Weight createSeedWeight(IndexSearcher indexSearcher) throws IOException { BooleanQuery.Builder booleanSeedQueryBuilder = new BooleanQuery.Builder() .add(seed, BooleanClause.Occur.MUST) @@ -100,9 +107,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); } Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); - Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); - SeededKnnVectorQuery rewritten = new SeededKnnVectorQuery(delegate, seed, seedWeight); - return rewritten.rewrite(indexSearcher); + return indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); } @Override @@ -261,7 +266,7 @@ public int nextDoc() { } class SeededCollectorManager implements KnnCollectorManager { - private final KnnCollectorManager knnCollectorManager; + final KnnCollectorManager knnCollectorManager; SeededCollectorManager(KnnCollectorManager knnCollectorManager) { this.knnCollectorManager = knnCollectorManager; @@ -289,25 +294,30 @@ public KnnCollector newCollector(int visitLimit, LeafReaderContext ctx) throws I } leafCollector.finish(); } - KnnCollector delegateCollector = knnCollectorManager.newCollector(k, ctx); + KnnCollector delegateCollector = knnCollectorManager.newCollector(visitLimit, ctx); TopDocs seedTopDocs = seedCollector.topDocs(); VectorScorer scorer = delegate.createVectorScorer(ctx, leafReader.getFieldInfos().fieldInfo(field)); - DocIdSetIterator vectorIterator = scorer.iterator(); - if (seedTopDocs.totalHits.value() == 0 - || vectorIterator instanceof KnnVectorValues.DocIndexIterator == false) { + if (seedTopDocs.totalHits.value() == 0 || scorer == null) { return delegateCollector; } - DocIdSetIterator seedDocs = - new MappedDISI( - (KnnVectorValues.DocIndexIterator) vectorIterator, new TopDocsDISI(seedTopDocs)); - return new SeededKnnCollector(delegateCollector, seedDocs, seedTopDocs.scoreDocs.length); + DocIdSetIterator vectorIterator = scorer.iterator(); + // Handle sparse + if (vectorIterator instanceof IndexedDISI indexedDISI) { + vectorIterator = IndexedDISI.asDocIndexIterator(indexedDISI); + } + // Most underlying iterators are indexed, so we can map the seed docs to the vector docs + if (vectorIterator instanceof KnnVectorValues.DocIndexIterator indexIterator) { + DocIdSetIterator seedDocs = new MappedDISI(indexIterator, new TopDocsDISI(seedTopDocs)); + return new SeededKnnCollector(delegateCollector, seedDocs, seedTopDocs.scoreDocs.length); + } + return delegateCollector; } } static class SeededKnnCollector extends KnnCollector.Decorator implements EntryPointProvider { - private final DocIdSetIterator entryPoints; - private final int numberOfEntryPoints; + final DocIdSetIterator entryPoints; + final int numberOfEntryPoints; SeededKnnCollector( KnnCollector collector, DocIdSetIterator entryPoints, int numberOfEntryPoints) { diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index 8a0d3b65aea9..8dd969a7ca3a 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -1123,7 +1123,7 @@ public void testSameFieldDifferentFormats() throws IOException { } } - private static class CountingQueryTimeout implements QueryTimeout { + static class CountingQueryTimeout implements QueryTimeout { private int remaining; public CountingQueryTimeout(int count) { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java index 8218ab8f34f6..e407301b7f33 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java @@ -19,6 +19,7 @@ import static org.apache.lucene.search.TestKnnByteVectorQuery.floatToBytes; import java.io.IOException; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.IntPoint; @@ -27,10 +28,13 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.TestVectorUtil; public class TestSeededKnnByteVectorQuery extends BaseKnnVectorQueryTestCase { @@ -73,6 +77,46 @@ Field getKnnVectorField(String name, float[] vector) { return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN); } + public void testSeedWithTimeout() throws IOException { + int numDocs = atLeast(50); + int dimension = atLeast(5); + int numIters = atLeast(5); + try (Directory d = newDirectoryForTest()) { + IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()); + RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(getKnnVectorField("field", randomVector(dimension))); + doc.add(new NumericDocValuesField("tag", i)); + doc.add(new IntPoint("tag", i)); + w.addDocument(doc); + } + w.close(); + + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + searcher.setTimeout(() -> true); + int k = random().nextInt(80) + 1; + for (int i = 0; i < numIters; i++) { + // All documents as seeds + Query seed = + random().nextBoolean() + ? IntPoint.newRangeQuery("tag", 1, 6) + : new MatchAllDocsQuery(); + Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery(); + KnnByteVectorQuery byteVectorQuery = + new KnnByteVectorQuery("field", floatToBytes(randomVector(dimension)), k, filter); + Query knnQuery = SeededKnnVectorQuery.fromByteQuery(byteVectorQuery, seed); + assertEquals(0, searcher.count(knnQuery)); + // No seed documents -- falls back on full approx search + seed = new MatchNoDocsQuery(); + knnQuery = SeededKnnVectorQuery.fromByteQuery(byteVectorQuery, seed); + assertEquals(0, searcher.count(knnQuery)); + } + } + } + } + /** Tests with random vectors and a random seed. Uses RandomIndexWriter. */ public void testRandomWithSeed() throws IOException { int numDocs = 1000; @@ -103,6 +147,12 @@ public void testRandomWithSeed() throws IOException { try (IndexReader reader = DirectoryReader.open(d)) { IndexSearcher searcher = newSearcher(reader); for (int i = 0; i < numIters; i++) { + // verify timeout collector wrapping is used + if (random().nextBoolean()) { + searcher.setTimeout(() -> false); + } else { + searcher.setTimeout(null); + } int k = random().nextInt(80) + 1; int n = random().nextInt(100) + 1; // we may get fewer results than requested if there are deletions, but this test doesn't @@ -110,12 +160,15 @@ public void testRandomWithSeed() throws IOException { assert reader.hasDeletions() == false; // All documents as seeds + AtomicInteger seedCalls = new AtomicInteger(); Query seed1 = new MatchAllDocsQuery(); Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery(); KnnByteVectorQuery byteVectorQuery = new KnnByteVectorQuery("field", floatToBytes(randomVector(dimension)), k, filter); - SeededKnnVectorQuery query = SeededKnnVectorQuery.fromByteQuery(byteVectorQuery, seed1); + AssertingSeededKnnVectorQuery query = + new AssertingSeededKnnVectorQuery(byteVectorQuery, seed1, null, seedCalls); TopDocs results = searcher.search(query, n); + assertEquals(seedCalls.get(), 1); int expected = Math.min(Math.min(n, k), numDocsWithVector); assertEquals(expected, results.scoreDocs.length); @@ -131,8 +184,9 @@ public void testRandomWithSeed() throws IOException { Query seed2 = IntPoint.newRangeQuery("tag", 1, 6); byteVectorQuery = new KnnByteVectorQuery("field", floatToBytes(randomVector(dimension)), k, null); - query = SeededKnnVectorQuery.fromByteQuery(byteVectorQuery, seed2); + query = new AssertingSeededKnnVectorQuery(byteVectorQuery, seed2, null, seedCalls); results = searcher.search(query, n); + assertEquals(seedCalls.get(), 2); expected = Math.min(Math.min(n, k), reader.numDocs()); assertEquals(expected, results.scoreDocs.length); assertTrue(results.totalHits.value() >= results.scoreDocs.length); @@ -145,7 +199,7 @@ public void testRandomWithSeed() throws IOException { // No seed documents -- falls back on full approx search Query seed3 = new MatchNoDocsQuery(); - query = SeededKnnVectorQuery.fromByteQuery(byteVectorQuery, seed3); + query = new AssertingSeededKnnVectorQuery(byteVectorQuery, seed3, null, null); results = searcher.search(query, n); expected = Math.min(Math.min(n, k), reader.numDocs()); assertEquals(expected, results.scoreDocs.length); @@ -160,4 +214,84 @@ public void testRandomWithSeed() throws IOException { } } } + + static class AssertingSeededKnnVectorQuery extends SeededKnnVectorQuery { + private final AtomicInteger seedCalls; + + public AssertingSeededKnnVectorQuery( + AbstractKnnVectorQuery query, Query seed, Weight seedWeight, AtomicInteger seedCalls) { + super(query, seed, seedWeight); + this.seedCalls = seedCalls; + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + AssertingSeededKnnVectorQuery rewritten = + new AssertingSeededKnnVectorQuery( + delegate, seed, createSeedWeight(indexSearcher), seedCalls); + return rewritten.rewrite(indexSearcher); + } + + @Override + protected TopDocs approximateSearch( + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager) + throws IOException { + return delegate.approximateSearch( + context, + acceptDocs, + visitedLimit, + new AssertingSeededCollectorManager(new SeededCollectorManager(knnCollectorManager))); + } + + class AssertingSeededCollectorManager extends SeededCollectorManager { + + public AssertingSeededCollectorManager(SeededCollectorManager delegate) { + super(delegate); + } + + @Override + public KnnCollector newCollector(int numVisisted, LeafReaderContext context) + throws IOException { + KnnCollector knnCollector = knnCollectorManager.newCollector(numVisisted, context); + if (knnCollector instanceof SeededKnnCollector seededKnnCollector) { + if (seedCalls == null) { + fail("Expected non-seeded collector"); + } + return new AssertingKnnCollector(seededKnnCollector); + } + if (seedCalls != null) { + fail("Expected seeded collector"); + } + return knnCollector; + } + } + + class AssertingKnnCollector extends SeededKnnCollector { + private final SeededKnnCollector seeded; + + public AssertingKnnCollector(SeededKnnCollector collector) { + super(collector, collector.entryPoints, collector.numberOfEntryPoints()); + this.seeded = collector; + } + + @Override + public DocIdSetIterator entryPoints() { + DocIdSetIterator iterator = seeded.entryPoints(); + assert iterator.cost() > 0; + seedCalls.incrementAndGet(); + return iterator; + } + + @Override + public int numberOfEntryPoints() { + return seeded.numberOfEntryPoints(); + } + } + } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java index 3969d94ee914..6d4d2851f6ec 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java @@ -17,6 +17,7 @@ package org.apache.lucene.search; import java.io.IOException; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.IntPoint; @@ -36,7 +37,7 @@ public class TestSeededKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase { @Override AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { - KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(field, query, k); + KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(field, query, k, queryFilter); return SeededKnnVectorQuery.fromFloatQuery(knnQuery, MATCH_NONE); } @@ -63,6 +64,46 @@ Field getKnnVectorField(String name, float[] vector) { return new KnnFloatVectorField(name, vector); } + public void testSeedWithTimeout() throws IOException { + int numDocs = atLeast(50); + int dimension = atLeast(5); + int numIters = atLeast(5); + try (Directory d = newDirectoryForTest()) { + IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()); + RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(getKnnVectorField("field", randomVector(dimension))); + doc.add(new NumericDocValuesField("tag", i)); + doc.add(new IntPoint("tag", i)); + w.addDocument(doc); + } + w.close(); + + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + searcher.setTimeout(() -> true); + int k = random().nextInt(80) + 1; + for (int i = 0; i < numIters; i++) { + // All documents as seeds + Query seed = + random().nextBoolean() + ? IntPoint.newRangeQuery("tag", 1, 6) + : new MatchAllDocsQuery(); + Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery(); + KnnFloatVectorQuery knnFloatVectorQuery = + new KnnFloatVectorQuery("field", randomVector(dimension), k, filter); + Query knnQuery = SeededKnnVectorQuery.fromFloatQuery(knnFloatVectorQuery, seed); + assertEquals(0, searcher.count(knnQuery)); + // No seed documents -- falls back on full approx search + seed = new MatchNoDocsQuery(); + knnQuery = SeededKnnVectorQuery.fromFloatQuery(knnFloatVectorQuery, seed); + assertEquals(0, searcher.count(knnQuery)); + } + } + } + } + /** Tests with random vectors and a random seed. Uses RandomIndexWriter. */ public void testRandomWithSeed() throws IOException { int numDocs = 1000; @@ -93,6 +134,12 @@ public void testRandomWithSeed() throws IOException { try (IndexReader reader = DirectoryReader.open(d)) { IndexSearcher searcher = newSearcher(reader); for (int i = 0; i < numIters; i++) { + // verify timeout collector wrapping is used + if (random().nextBoolean()) { + searcher.setTimeout(() -> false); + } else { + searcher.setTimeout(null); + } int k = random().nextInt(80) + 1; int n = random().nextInt(100) + 1; // we may get fewer results than requested if there are deletions, but this test doesn't @@ -100,13 +147,16 @@ public void testRandomWithSeed() throws IOException { assert reader.hasDeletions() == false; // All documents as seeds + AtomicInteger seedCalls = new AtomicInteger(); Query seed1 = new MatchAllDocsQuery(); Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery(); KnnFloatVectorQuery knnFloatVectorQuery = new KnnFloatVectorQuery("field", randomVector(dimension), k, filter); - AbstractKnnVectorQuery query = - SeededKnnVectorQuery.fromFloatQuery(knnFloatVectorQuery, seed1); + TestSeededKnnByteVectorQuery.AssertingSeededKnnVectorQuery query = + new TestSeededKnnByteVectorQuery.AssertingSeededKnnVectorQuery( + knnFloatVectorQuery, seed1, null, seedCalls); TopDocs results = searcher.search(query, n); + assertEquals(seedCalls.get(), 1); int expected = Math.min(Math.min(n, k), numDocsWithVector); assertEquals(expected, results.scoreDocs.length); @@ -120,8 +170,11 @@ public void testRandomWithSeed() throws IOException { // Restrictive seed query -- 6 documents Query seed2 = IntPoint.newRangeQuery("tag", 1, 6); - query = SeededKnnVectorQuery.fromFloatQuery(knnFloatVectorQuery, seed2); + query = + new TestSeededKnnByteVectorQuery.AssertingSeededKnnVectorQuery( + knnFloatVectorQuery, seed2, null, seedCalls); results = searcher.search(query, n); + assertEquals(seedCalls.get(), 2); expected = Math.min(Math.min(n, k), reader.numDocs()); assertEquals(expected, results.scoreDocs.length); assertTrue(results.totalHits.value() >= results.scoreDocs.length); @@ -134,7 +187,9 @@ public void testRandomWithSeed() throws IOException { // No seed documents -- falls back on full approx search Query seed3 = new MatchNoDocsQuery(); - query = SeededKnnVectorQuery.fromFloatQuery(knnFloatVectorQuery, seed3); + query = + new TestSeededKnnByteVectorQuery.AssertingSeededKnnVectorQuery( + knnFloatVectorQuery, seed3, null, null); results = searcher.search(query, n); expected = Math.min(Math.min(n, k), reader.numDocs()); assertEquals(expected, results.scoreDocs.length);