Skip to content

Commit

Permalink
fixing seed refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
benwtrent committed Jan 24, 2025
1 parent 7944ba8 commit 4cd4e90
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
Expand Down Expand Up @@ -92,6 +93,12 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (seedWeight != null) {
return super.rewrite(indexSearcher);
}
SeededKnnVectorQuery rewritten =
new SeededKnnVectorQuery(delegate, seed, createSeedWeight(indexSearcher));
return rewritten.rewrite(indexSearcher);
}

Weight createSeedWeight(IndexSearcher indexSearcher) throws IOException {
BooleanQuery.Builder booleanSeedQueryBuilder =
new BooleanQuery.Builder()
.add(seed, BooleanClause.Occur.MUST)
Expand All @@ -100,9 +107,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER);
}
Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build());
Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f);
SeededKnnVectorQuery rewritten = new SeededKnnVectorQuery(delegate, seed, seedWeight);
return rewritten.rewrite(indexSearcher);
return indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f);
}

@Override
Expand Down Expand Up @@ -261,7 +266,7 @@ public int nextDoc() {
}

class SeededCollectorManager implements KnnCollectorManager {
private final KnnCollectorManager knnCollectorManager;
final KnnCollectorManager knnCollectorManager;

SeededCollectorManager(KnnCollectorManager knnCollectorManager) {
this.knnCollectorManager = knnCollectorManager;
Expand Down Expand Up @@ -289,25 +294,30 @@ public KnnCollector newCollector(int visitLimit, LeafReaderContext ctx) throws I
}
leafCollector.finish();
}
KnnCollector delegateCollector = knnCollectorManager.newCollector(k, ctx);
KnnCollector delegateCollector = knnCollectorManager.newCollector(visitLimit, ctx);
TopDocs seedTopDocs = seedCollector.topDocs();
VectorScorer scorer =
delegate.createVectorScorer(ctx, leafReader.getFieldInfos().fieldInfo(field));
DocIdSetIterator vectorIterator = scorer.iterator();
if (seedTopDocs.totalHits.value() == 0
|| vectorIterator instanceof KnnVectorValues.DocIndexIterator == false) {
if (seedTopDocs.totalHits.value() == 0 || scorer == null) {
return delegateCollector;
}
DocIdSetIterator seedDocs =
new MappedDISI(
(KnnVectorValues.DocIndexIterator) vectorIterator, new TopDocsDISI(seedTopDocs));
return new SeededKnnCollector(delegateCollector, seedDocs, seedTopDocs.scoreDocs.length);
DocIdSetIterator vectorIterator = scorer.iterator();
// Handle sparse
if (vectorIterator instanceof IndexedDISI indexedDISI) {
vectorIterator = IndexedDISI.asDocIndexIterator(indexedDISI);
}
// Most underlying iterators are indexed, so we can map the seed docs to the vector docs
if (vectorIterator instanceof KnnVectorValues.DocIndexIterator indexIterator) {
DocIdSetIterator seedDocs = new MappedDISI(indexIterator, new TopDocsDISI(seedTopDocs));
return new SeededKnnCollector(delegateCollector, seedDocs, seedTopDocs.scoreDocs.length);
}
return delegateCollector;
}
}

static class SeededKnnCollector extends KnnCollector.Decorator implements EntryPointProvider {
private final DocIdSetIterator entryPoints;
private final int numberOfEntryPoints;
final DocIdSetIterator entryPoints;
final int numberOfEntryPoints;

SeededKnnCollector(
KnnCollector collector, DocIdSetIterator entryPoints, int numberOfEntryPoints) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,7 @@ public void testSameFieldDifferentFormats() throws IOException {
}
}

private static class CountingQueryTimeout implements QueryTimeout {
static class CountingQueryTimeout implements QueryTimeout {
private int remaining;

public CountingQueryTimeout(int count) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.apache.lucene.search.TestKnnByteVectorQuery.floatToBytes;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.IntPoint;
Expand All @@ -27,10 +28,13 @@
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.TestVectorUtil;

public class TestSeededKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
Expand Down Expand Up @@ -73,6 +77,46 @@ Field getKnnVectorField(String name, float[] vector) {
return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN);
}

public void testSeedWithTimeout() throws IOException {
int numDocs = atLeast(50);
int dimension = atLeast(5);
int numIters = atLeast(5);
try (Directory d = newDirectoryForTest()) {
IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec());
RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc);
for (int i = 0; i < numDocs; i++) {
Document doc = new Document();
doc.add(getKnnVectorField("field", randomVector(dimension)));
doc.add(new NumericDocValuesField("tag", i));
doc.add(new IntPoint("tag", i));
w.addDocument(doc);
}
w.close();

try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
searcher.setTimeout(() -> true);
int k = random().nextInt(80) + 1;
for (int i = 0; i < numIters; i++) {
// All documents as seeds
Query seed =
random().nextBoolean()
? IntPoint.newRangeQuery("tag", 1, 6)
: new MatchAllDocsQuery();
Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery();
KnnByteVectorQuery byteVectorQuery =
new KnnByteVectorQuery("field", floatToBytes(randomVector(dimension)), k, filter);
Query knnQuery = SeededKnnVectorQuery.fromByteQuery(byteVectorQuery, seed);
assertEquals(0, searcher.count(knnQuery));
// No seed documents -- falls back on full approx search
seed = new MatchNoDocsQuery();
knnQuery = SeededKnnVectorQuery.fromByteQuery(byteVectorQuery, seed);
assertEquals(0, searcher.count(knnQuery));
}
}
}
}

/** Tests with random vectors and a random seed. Uses RandomIndexWriter. */
public void testRandomWithSeed() throws IOException {
int numDocs = 1000;
Expand Down Expand Up @@ -103,19 +147,28 @@ public void testRandomWithSeed() throws IOException {
try (IndexReader reader = DirectoryReader.open(d)) {
IndexSearcher searcher = newSearcher(reader);
for (int i = 0; i < numIters; i++) {
// verify timeout collector wrapping is used
if (random().nextBoolean()) {
searcher.setTimeout(() -> false);
} else {
searcher.setTimeout(null);
}
int k = random().nextInt(80) + 1;
int n = random().nextInt(100) + 1;
// we may get fewer results than requested if there are deletions, but this test doesn't
// check that
assert reader.hasDeletions() == false;

// All documents as seeds
AtomicInteger seedCalls = new AtomicInteger();
Query seed1 = new MatchAllDocsQuery();
Query filter = random().nextBoolean() ? null : new MatchAllDocsQuery();
KnnByteVectorQuery byteVectorQuery =
new KnnByteVectorQuery("field", floatToBytes(randomVector(dimension)), k, filter);
SeededKnnVectorQuery query = SeededKnnVectorQuery.fromByteQuery(byteVectorQuery, seed1);
AssertingSeededKnnVectorQuery query =
new AssertingSeededKnnVectorQuery(byteVectorQuery, seed1, null, seedCalls);
TopDocs results = searcher.search(query, n);
assertEquals(seedCalls.get(), 1);
int expected = Math.min(Math.min(n, k), numDocsWithVector);

assertEquals(expected, results.scoreDocs.length);
Expand All @@ -131,8 +184,9 @@ public void testRandomWithSeed() throws IOException {
Query seed2 = IntPoint.newRangeQuery("tag", 1, 6);
byteVectorQuery =
new KnnByteVectorQuery("field", floatToBytes(randomVector(dimension)), k, null);
query = SeededKnnVectorQuery.fromByteQuery(byteVectorQuery, seed2);
query = new AssertingSeededKnnVectorQuery(byteVectorQuery, seed2, null, seedCalls);
results = searcher.search(query, n);
assertEquals(seedCalls.get(), 2);
expected = Math.min(Math.min(n, k), reader.numDocs());
assertEquals(expected, results.scoreDocs.length);
assertTrue(results.totalHits.value() >= results.scoreDocs.length);
Expand All @@ -145,7 +199,7 @@ public void testRandomWithSeed() throws IOException {

// No seed documents -- falls back on full approx search
Query seed3 = new MatchNoDocsQuery();
query = SeededKnnVectorQuery.fromByteQuery(byteVectorQuery, seed3);
query = new AssertingSeededKnnVectorQuery(byteVectorQuery, seed3, null, null);
results = searcher.search(query, n);
expected = Math.min(Math.min(n, k), reader.numDocs());
assertEquals(expected, results.scoreDocs.length);
Expand All @@ -160,4 +214,84 @@ public void testRandomWithSeed() throws IOException {
}
}
}

static class AssertingSeededKnnVectorQuery extends SeededKnnVectorQuery {
private final AtomicInteger seedCalls;

public AssertingSeededKnnVectorQuery(
AbstractKnnVectorQuery query, Query seed, Weight seedWeight, AtomicInteger seedCalls) {
super(query, seed, seedWeight);
this.seedCalls = seedCalls;
}

@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (seedWeight != null) {
return super.rewrite(indexSearcher);
}
AssertingSeededKnnVectorQuery rewritten =
new AssertingSeededKnnVectorQuery(
delegate, seed, createSeedWeight(indexSearcher), seedCalls);
return rewritten.rewrite(indexSearcher);
}

@Override
protected TopDocs approximateSearch(
LeafReaderContext context,
Bits acceptDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager)
throws IOException {
return delegate.approximateSearch(
context,
acceptDocs,
visitedLimit,
new AssertingSeededCollectorManager(new SeededCollectorManager(knnCollectorManager)));
}

class AssertingSeededCollectorManager extends SeededCollectorManager {

public AssertingSeededCollectorManager(SeededCollectorManager delegate) {
super(delegate);
}

@Override
public KnnCollector newCollector(int numVisisted, LeafReaderContext context)
throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(numVisisted, context);
if (knnCollector instanceof SeededKnnCollector seededKnnCollector) {
if (seedCalls == null) {
fail("Expected non-seeded collector");
}
return new AssertingKnnCollector(seededKnnCollector);
}
if (seedCalls != null) {
fail("Expected seeded collector");
}
return knnCollector;
}
}

class AssertingKnnCollector extends SeededKnnCollector {
private final SeededKnnCollector seeded;

public AssertingKnnCollector(SeededKnnCollector collector) {
super(collector, collector.entryPoints, collector.numberOfEntryPoints());
this.seeded = collector;
}

@Override
public DocIdSetIterator entryPoints() {
DocIdSetIterator iterator = seeded.entryPoints();
assert iterator.cost() > 0;
seedCalls.incrementAndGet();
return iterator;
}

@Override
public int numberOfEntryPoints() {
return seeded.numberOfEntryPoints();
}
}
}
}
Loading

0 comments on commit 4cd4e90

Please sign in to comment.