Skip to content

Commit

Permalink
Add updateable random scorer interface for vector index building (#14181
Browse files Browse the repository at this point in the history
)

As stated by @ChrisHegarty and @msokolov the amount of garbage we create during vector index creation is pretty astounding.

This adjusts the interface to allow an "Updateable" random vector interface (@msokolov 's idea if I remember correctly) and refactors the usage to keep it threadsafe.
  • Loading branch information
benwtrent authored Feb 6, 2025
1 parent fe42efc commit 48884e7
Show file tree
Hide file tree
Showing 20 changed files with 368 additions and 154 deletions.
6 changes: 5 additions & 1 deletion lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ Optimizations
* GITHUB#14133: Dense blocks of postings are now encoded as bit sets.
(Adrien Grand)

# GITHUB#14169: Optimize ContextQuery with big number of contexts. (Mayya Sharipova)
* GITHUB#14169: Optimize ContextQuery with big number of contexts. (Mayya Sharipova)

* GITHUB#14181: Add updateable random scorer interface for knn vector index building. This allows
for fewer objects to be created during indexing and simplifies internally used iterfaces.
(Ben Trent)

Bug Fixes
---------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.openjdk.jmh.annotations.*;

@BenchmarkMode(Mode.Throughput)
Expand All @@ -57,7 +58,7 @@ public class VectorScorerBenchmark {
IndexInput in;
KnnVectorValues vectorValues;
byte[] vec1, vec2;
RandomVectorScorer scorer;
UpdateableRandomVectorScorer scorer;

@Setup(Level.Iteration)
public void init() throws IOException {
Expand All @@ -76,7 +77,8 @@ public void init() throws IOException {
scorer =
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
.getRandomVectorScorerSupplier(DOT_PRODUCT, vectorValues)
.scorer(0);
.scorer();
scorer.setScoringOrdinal(0);
}

@TearDown
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;

/** A bit vector scorer for scoring byte vectors. */
public class FlatBitVectorsScorer implements FlatVectorsScorer {
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
throws IOException {
assert vectorValues instanceof ByteVectorValues;
if (vectorValues instanceof ByteVectorValues byteVectorValues) {
return new BitRandomVectorScorerSupplier(byteVectorValues);
}
Expand All @@ -51,14 +51,13 @@ public RandomVectorScorer getRandomVectorScorer(
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
throws IOException {
assert vectorValues instanceof ByteVectorValues;
if (vectorValues instanceof ByteVectorValues byteVectorValues) {
return new BitRandomVectorScorer(byteVectorValues, target);
}
throw new IllegalArgumentException("vectorValues must be an instance of ByteVectorValues");
}

static class BitRandomVectorScorer implements RandomVectorScorer {
static class BitRandomVectorScorer implements UpdateableRandomVectorScorer {
private final ByteVectorValues vectorValues;
private final int bitDimensions;
private final byte[] query;
Expand All @@ -80,6 +79,11 @@ public int maxOrd() {
return vectorValues.size();
}

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(vectorValues.vectorValue(node), 0, query, 0, query.length);
}

@Override
public int ordToDoc(int ord) {
return vectorValues.ordToDoc(ord);
Expand All @@ -93,24 +97,22 @@ public Bits getAcceptOrds(Bits acceptDocs) {

static class BitRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
protected final ByteVectorValues vectorValues;
protected final ByteVectorValues vectorValues1;
protected final ByteVectorValues vectorValues2;
protected final ByteVectorValues targetVectors;

public BitRandomVectorScorerSupplier(ByteVectorValues vectorValues) throws IOException {
this.vectorValues = vectorValues;
this.vectorValues1 = vectorValues.copy();
this.vectorValues2 = vectorValues.copy();
this.targetVectors = vectorValues.copy();
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
byte[] query = vectorValues1.vectorValue(ord);
return new BitRandomVectorScorer(vectorValues2, query);
public UpdateableRandomVectorScorer scorer() throws IOException {
byte[] query = new byte[vectorValues.dimension()];
return new BitRandomVectorScorer(vectorValues, query);
}

@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new BitRandomVectorScorerSupplier(vectorValues.copy());
return new BitRandomVectorScorerSupplier(vectorValues);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;

/**
* Default implementation of {@link FlatVectorsScorer}.
Expand Down Expand Up @@ -89,24 +90,29 @@ public String toString() {
/** RandomVectorScorerSupplier for bytes vector */
private static final class ByteScoringSupplier implements RandomVectorScorerSupplier {
private final ByteVectorValues vectors;
private final ByteVectorValues vectors1;
private final ByteVectorValues vectors2;
private final ByteVectorValues targetVectors;
private final VectorSimilarityFunction similarityFunction;

private ByteScoringSupplier(
ByteVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException {
this.vectors = vectors;
vectors1 = vectors.copy();
vectors2 = vectors.copy();
targetVectors = vectors.copy();
this.similarityFunction = similarityFunction;
}

@Override
public RandomVectorScorer scorer(int ord) {
return new RandomVectorScorer.AbstractRandomVectorScorer(vectors) {
public UpdateableRandomVectorScorer scorer() throws IOException {
byte[] vector = new byte[vectors.dimension()];
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectors) {

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(targetVectors.vectorValue(node), 0, vector, 0, vector.length);
}

@Override
public float score(int node) throws IOException {
return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(node));
return similarityFunction.compare(vector, targetVectors.vectorValue(node));
}
};
}
Expand All @@ -125,24 +131,28 @@ public String toString() {
/** RandomVectorScorerSupplier for Float vector */
private static final class FloatScoringSupplier implements RandomVectorScorerSupplier {
private final FloatVectorValues vectors;
private final FloatVectorValues vectors1;
private final FloatVectorValues vectors2;
private final FloatVectorValues targetVectors;
private final VectorSimilarityFunction similarityFunction;

private FloatScoringSupplier(
FloatVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException {
this.vectors = vectors;
vectors1 = vectors.copy();
vectors2 = vectors.copy();
targetVectors = vectors.copy();
this.similarityFunction = similarityFunction;
}

@Override
public RandomVectorScorer scorer(int ord) {
return new RandomVectorScorer.AbstractRandomVectorScorer(vectors) {
public UpdateableRandomVectorScorer scorer() throws IOException {
float[] vector = new float[vectors.dimension()];
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectors) {
@Override
public float score(int node) throws IOException {
return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(node));
return similarityFunction.compare(vector, targetVectors.vectorValue(node));
}

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(targetVectors.vectorValue(node), 0, vector, 0, vector.length);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import org.apache.lucene.util.quantization.ScalarQuantizer;
Expand Down Expand Up @@ -147,11 +148,18 @@ private ScalarQuantizedRandomVectorScorerSupplier(
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
public UpdateableRandomVectorScorer scorer() throws IOException {
final QuantizedByteVectorValues vectorsCopy = values.copy();
final byte[] queryVector = values.vectorValue(ord);
final float queryOffset = values.getScoreCorrectionConstant(ord);
return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy) {
byte[] queryVector = new byte[values.dimension()];
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectorsCopy) {
float queryOffset = 0;

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(vectorsCopy.vectorValue(node), 0, queryVector, 0, queryVector.length);
queryOffset = vectorsCopy.getScoreCorrectionConstant(node);
}

@Override
public float score(int node) throws IOException {
byte[] nodeVector = vectorsCopy.vectorValue(node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;

/**
* Writes vector values to index segments.
Expand Down Expand Up @@ -507,8 +507,8 @@ static final class FlatCloseableRandomVectorScorerSupplier
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
return supplier.scorer(ord);
public UpdateableRandomVectorScorer scorer() throws IOException {
return supplier.scorer();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicWriter;

/**
Expand Down Expand Up @@ -561,6 +562,7 @@ private static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
private int lastDocID = -1;
private int node = 0;
private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;
private UpdateableRandomVectorScorer scorer;

@SuppressWarnings("unchecked")
static FieldWriter<?> create(
Expand Down Expand Up @@ -616,6 +618,7 @@ static FieldWriter<?> create(
(List<float[]>) flatFieldVectorsWriter.getVectors(),
fieldInfo.getVectorDimension()));
};
this.scorer = scorerSupplier.scorer();
hnswGraphBuilder =
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(infoStream);
Expand All @@ -631,7 +634,8 @@ public void addValue(int docID, T vectorValue) throws IOException {
+ "\" appears more than once in this document (only one value is allowed per field)");
}
flatFieldVectorsWriter.addValue(docID, vectorValue);
hnswGraphBuilder.addGraphNode(node);
scorer.setScoringOrdinal(node);
hnswGraphBuilder.addGraphNode(node, scorer);
node++;
lastDocID = docID;
}
Expand Down
Loading

0 comments on commit 48884e7

Please sign in to comment.