Skip to content

Commit 48884e7

Browse files
authored
Add updateable random scorer interface for vector index building (#14181)
As stated by @ChrisHegarty and @msokolov, the amount of garbage we create during vector index creation is pretty astounding. This change introduces an "Updateable" random vector scorer interface (@msokolov's idea, if I remember correctly) and refactors its usage to keep it thread-safe.
1 parent fe42efc commit 48884e7

20 files changed

+368
-154
lines changed

lucene/CHANGES.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,11 @@ Optimizations
9090
* GITHUB#14133: Dense blocks of postings are now encoded as bit sets.
9191
(Adrien Grand)
9292

93-
# GITHUB#14169: Optimize ContextQuery with big number of contexts. (Mayya Sharipova)
93+
* GITHUB#14169: Optimize ContextQuery with big number of contexts. (Mayya Sharipova)
94+
95+
* GITHUB#14181: Add updateable random scorer interface for knn vector index building. This allows
96+
for fewer objects to be created during indexing and simplifies internally used interfaces.
97+
(Ben Trent)
9498

9599
Bug Fixes
96100
---------------------

lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerBenchmark.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.apache.lucene.util.IOUtils;
3636
import org.apache.lucene.util.hnsw.RandomVectorScorer;
3737
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
38+
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
3839
import org.openjdk.jmh.annotations.*;
3940

4041
@BenchmarkMode(Mode.Throughput)
@@ -57,7 +58,7 @@ public class VectorScorerBenchmark {
5758
IndexInput in;
5859
KnnVectorValues vectorValues;
5960
byte[] vec1, vec2;
60-
RandomVectorScorer scorer;
61+
UpdateableRandomVectorScorer scorer;
6162

6263
@Setup(Level.Iteration)
6364
public void init() throws IOException {
@@ -76,7 +77,8 @@ public void init() throws IOException {
7677
scorer =
7778
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
7879
.getRandomVectorScorerSupplier(DOT_PRODUCT, vectorValues)
79-
.scorer(0);
80+
.scorer();
81+
scorer.setScoringOrdinal(0);
8082
}
8183

8284
@TearDown

lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@
2626
import org.apache.lucene.util.VectorUtil;
2727
import org.apache.lucene.util.hnsw.RandomVectorScorer;
2828
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
29+
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
2930

3031
/** A bit vector scorer for scoring byte vectors. */
3132
public class FlatBitVectorsScorer implements FlatVectorsScorer {
3233
@Override
3334
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
3435
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
3536
throws IOException {
36-
assert vectorValues instanceof ByteVectorValues;
3737
if (vectorValues instanceof ByteVectorValues byteVectorValues) {
3838
return new BitRandomVectorScorerSupplier(byteVectorValues);
3939
}
@@ -51,14 +51,13 @@ public RandomVectorScorer getRandomVectorScorer(
5151
public RandomVectorScorer getRandomVectorScorer(
5252
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
5353
throws IOException {
54-
assert vectorValues instanceof ByteVectorValues;
5554
if (vectorValues instanceof ByteVectorValues byteVectorValues) {
5655
return new BitRandomVectorScorer(byteVectorValues, target);
5756
}
5857
throw new IllegalArgumentException("vectorValues must be an instance of ByteVectorValues");
5958
}
6059

61-
static class BitRandomVectorScorer implements RandomVectorScorer {
60+
static class BitRandomVectorScorer implements UpdateableRandomVectorScorer {
6261
private final ByteVectorValues vectorValues;
6362
private final int bitDimensions;
6463
private final byte[] query;
@@ -80,6 +79,11 @@ public int maxOrd() {
8079
return vectorValues.size();
8180
}
8281

82+
@Override
83+
public void setScoringOrdinal(int node) throws IOException {
84+
System.arraycopy(vectorValues.vectorValue(node), 0, query, 0, query.length);
85+
}
86+
8387
@Override
8488
public int ordToDoc(int ord) {
8589
return vectorValues.ordToDoc(ord);
@@ -93,24 +97,22 @@ public Bits getAcceptOrds(Bits acceptDocs) {
9397

9498
static class BitRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
9599
protected final ByteVectorValues vectorValues;
96-
protected final ByteVectorValues vectorValues1;
97-
protected final ByteVectorValues vectorValues2;
100+
protected final ByteVectorValues targetVectors;
98101

99102
public BitRandomVectorScorerSupplier(ByteVectorValues vectorValues) throws IOException {
100103
this.vectorValues = vectorValues;
101-
this.vectorValues1 = vectorValues.copy();
102-
this.vectorValues2 = vectorValues.copy();
104+
this.targetVectors = vectorValues.copy();
103105
}
104106

105107
@Override
106-
public RandomVectorScorer scorer(int ord) throws IOException {
107-
byte[] query = vectorValues1.vectorValue(ord);
108-
return new BitRandomVectorScorer(vectorValues2, query);
108+
public UpdateableRandomVectorScorer scorer() throws IOException {
109+
byte[] query = new byte[vectorValues.dimension()];
110+
return new BitRandomVectorScorer(vectorValues, query);
109111
}
110112

111113
@Override
112114
public RandomVectorScorerSupplier copy() throws IOException {
113-
return new BitRandomVectorScorerSupplier(vectorValues.copy());
115+
return new BitRandomVectorScorerSupplier(vectorValues);
114116
}
115117
}
116118

lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.lucene.index.VectorSimilarityFunction;
2525
import org.apache.lucene.util.hnsw.RandomVectorScorer;
2626
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
27+
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
2728

2829
/**
2930
* Default implementation of {@link FlatVectorsScorer}.
@@ -89,24 +90,29 @@ public String toString() {
8990
/** RandomVectorScorerSupplier for bytes vector */
9091
private static final class ByteScoringSupplier implements RandomVectorScorerSupplier {
9192
private final ByteVectorValues vectors;
92-
private final ByteVectorValues vectors1;
93-
private final ByteVectorValues vectors2;
93+
private final ByteVectorValues targetVectors;
9494
private final VectorSimilarityFunction similarityFunction;
9595

9696
private ByteScoringSupplier(
9797
ByteVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException {
9898
this.vectors = vectors;
99-
vectors1 = vectors.copy();
100-
vectors2 = vectors.copy();
99+
targetVectors = vectors.copy();
101100
this.similarityFunction = similarityFunction;
102101
}
103102

104103
@Override
105-
public RandomVectorScorer scorer(int ord) {
106-
return new RandomVectorScorer.AbstractRandomVectorScorer(vectors) {
104+
public UpdateableRandomVectorScorer scorer() throws IOException {
105+
byte[] vector = new byte[vectors.dimension()];
106+
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectors) {
107+
108+
@Override
109+
public void setScoringOrdinal(int node) throws IOException {
110+
System.arraycopy(targetVectors.vectorValue(node), 0, vector, 0, vector.length);
111+
}
112+
107113
@Override
108114
public float score(int node) throws IOException {
109-
return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(node));
115+
return similarityFunction.compare(vector, targetVectors.vectorValue(node));
110116
}
111117
};
112118
}
@@ -125,24 +131,28 @@ public String toString() {
125131
/** RandomVectorScorerSupplier for Float vector */
126132
private static final class FloatScoringSupplier implements RandomVectorScorerSupplier {
127133
private final FloatVectorValues vectors;
128-
private final FloatVectorValues vectors1;
129-
private final FloatVectorValues vectors2;
134+
private final FloatVectorValues targetVectors;
130135
private final VectorSimilarityFunction similarityFunction;
131136

132137
private FloatScoringSupplier(
133138
FloatVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException {
134139
this.vectors = vectors;
135-
vectors1 = vectors.copy();
136-
vectors2 = vectors.copy();
140+
targetVectors = vectors.copy();
137141
this.similarityFunction = similarityFunction;
138142
}
139143

140144
@Override
141-
public RandomVectorScorer scorer(int ord) {
142-
return new RandomVectorScorer.AbstractRandomVectorScorer(vectors) {
145+
public UpdateableRandomVectorScorer scorer() throws IOException {
146+
float[] vector = new float[vectors.dimension()];
147+
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectors) {
143148
@Override
144149
public float score(int node) throws IOException {
145-
return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(node));
150+
return similarityFunction.compare(vector, targetVectors.vectorValue(node));
151+
}
152+
153+
@Override
154+
public void setScoringOrdinal(int node) throws IOException {
155+
System.arraycopy(targetVectors.vectorValue(node), 0, vector, 0, vector.length);
146156
}
147157
};
148158
}

lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.lucene.util.VectorUtil;
2525
import org.apache.lucene.util.hnsw.RandomVectorScorer;
2626
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
27+
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
2728
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
2829
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
2930
import org.apache.lucene.util.quantization.ScalarQuantizer;
@@ -147,11 +148,18 @@ private ScalarQuantizedRandomVectorScorerSupplier(
147148
}
148149

149150
@Override
150-
public RandomVectorScorer scorer(int ord) throws IOException {
151+
public UpdateableRandomVectorScorer scorer() throws IOException {
151152
final QuantizedByteVectorValues vectorsCopy = values.copy();
152-
final byte[] queryVector = values.vectorValue(ord);
153-
final float queryOffset = values.getScoreCorrectionConstant(ord);
154-
return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy) {
153+
byte[] queryVector = new byte[values.dimension()];
154+
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectorsCopy) {
155+
float queryOffset = 0;
156+
157+
@Override
158+
public void setScoringOrdinal(int node) throws IOException {
159+
System.arraycopy(vectorsCopy.vectorValue(node), 0, queryVector, 0, queryVector.length);
160+
queryOffset = vectorsCopy.getScoreCorrectionConstant(node);
161+
}
162+
155163
@Override
156164
public float score(int node) throws IOException {
157165
byte[] nodeVector = vectorsCopy.vectorValue(node);

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@
5252
import org.apache.lucene.util.IOUtils;
5353
import org.apache.lucene.util.RamUsageEstimator;
5454
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
55-
import org.apache.lucene.util.hnsw.RandomVectorScorer;
5655
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
56+
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
5757

5858
/**
5959
* Writes vector values to index segments.
@@ -507,8 +507,8 @@ static final class FlatCloseableRandomVectorScorerSupplier
507507
}
508508

509509
@Override
510-
public RandomVectorScorer scorer(int ord) throws IOException {
511-
return supplier.scorer(ord);
510+
public UpdateableRandomVectorScorer scorer() throws IOException {
511+
return supplier.scorer();
512512
}
513513

514514
@Override

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.apache.lucene.util.hnsw.NeighborArray;
5858
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
5959
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
60+
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
6061
import org.apache.lucene.util.packed.DirectMonotonicWriter;
6162

6263
/**
@@ -561,6 +562,7 @@ private static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
561562
private int lastDocID = -1;
562563
private int node = 0;
563564
private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;
565+
private UpdateableRandomVectorScorer scorer;
564566

565567
@SuppressWarnings("unchecked")
566568
static FieldWriter<?> create(
@@ -616,6 +618,7 @@ static FieldWriter<?> create(
616618
(List<float[]>) flatFieldVectorsWriter.getVectors(),
617619
fieldInfo.getVectorDimension()));
618620
};
621+
this.scorer = scorerSupplier.scorer();
619622
hnswGraphBuilder =
620623
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
621624
hnswGraphBuilder.setInfoStream(infoStream);
@@ -631,7 +634,8 @@ public void addValue(int docID, T vectorValue) throws IOException {
631634
+ "\" appears more than once in this document (only one value is allowed per field)");
632635
}
633636
flatFieldVectorsWriter.addValue(docID, vectorValue);
634-
hnswGraphBuilder.addGraphNode(node);
637+
scorer.setScoringOrdinal(node);
638+
hnswGraphBuilder.addGraphNode(node, scorer);
635639
node++;
636640
lastDocID = docID;
637641
}

0 commit comments

Comments
 (0)