Skip to content

Commit

Permalink
Refactor the vector scorer interface to allow reuse during HNSW graph…
Browse files Browse the repository at this point in the history
… building
  • Loading branch information
benwtrent committed Jan 29, 2025
1 parent de4f07b commit 59dab18
Show file tree
Hide file tree
Showing 16 changed files with 313 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;

/** A bit vector scorer for scoring byte vectors. */
public class FlatBitVectorsScorer implements FlatVectorsScorer {
Expand Down Expand Up @@ -58,7 +59,7 @@ public RandomVectorScorer getRandomVectorScorer(
throw new IllegalArgumentException("vectorValues must be an instance of ByteVectorValues");
}

static class BitRandomVectorScorer implements RandomVectorScorer {
static class BitRandomVectorScorer implements UpdateableRandomVectorScorer {
private final ByteVectorValues vectorValues;
private final int bitDimensions;
private final byte[] query;
Expand All @@ -80,6 +81,11 @@ public int maxOrd() {
return vectorValues.size();
}

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(vectorValues.vectorValue(node), 0, query, 0, query.length);
}

@Override
public int ordToDoc(int ord) {
return vectorValues.ordToDoc(ord);
Expand All @@ -103,7 +109,7 @@ public BitRandomVectorScorerSupplier(ByteVectorValues vectorValues) throws IOExc
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
public UpdateableRandomVectorScorer scorer(int ord) throws IOException {
byte[] query = vectorValues1.vectorValue(ord);
return new BitRandomVectorScorer(vectorValues2, query);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;

/**
* Default implementation of {@link FlatVectorsScorer}.
Expand Down Expand Up @@ -90,23 +91,29 @@ public String toString() {
private static final class ByteScoringSupplier implements RandomVectorScorerSupplier {
private final ByteVectorValues vectors;
private final ByteVectorValues vectors1;
private final ByteVectorValues vectors2;
private final VectorSimilarityFunction similarityFunction;

private ByteScoringSupplier(
ByteVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException {
this.vectors = vectors;
vectors1 = vectors.copy();
vectors2 = vectors.copy();
this.similarityFunction = similarityFunction;
}

@Override
public RandomVectorScorer scorer(int ord) {
return new RandomVectorScorer.AbstractRandomVectorScorer(vectors) {
public UpdateableRandomVectorScorer scorer(int ord) throws IOException {
byte[] vector = new byte[vectors.dimension()];
System.arraycopy(vectors1.vectorValue(ord), 0, vector, 0, vector.length);
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectors) {

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(vectors1.vectorValue(node), 0, vector, 0, vector.length);
}

@Override
public float score(int node) throws IOException {
return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(node));
return similarityFunction.compare(vector, vectors1.vectorValue(ord));
}
};
}
Expand Down Expand Up @@ -138,12 +145,19 @@ private FloatScoringSupplier(
}

@Override
public RandomVectorScorer scorer(int ord) {
return new RandomVectorScorer.AbstractRandomVectorScorer(vectors) {
public UpdateableRandomVectorScorer scorer(int ord) throws IOException {
float[] vector = new float[vectors.dimension()];
System.arraycopy(vectors1.vectorValue(ord), 0, vector, 0, vector.length);
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectors) {
@Override
public float score(int node) throws IOException {
return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(node));
}

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(vectors1.vectorValue(node), 0, vector, 0, vector.length);
}
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import org.apache.lucene.util.quantization.ScalarQuantizer;
Expand Down Expand Up @@ -147,11 +148,19 @@ private ScalarQuantizedRandomVectorScorerSupplier(
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
public UpdateableRandomVectorScorer scorer(int ord) throws IOException {
final QuantizedByteVectorValues vectorsCopy = values.copy();
final byte[] queryVector = values.vectorValue(ord);
final float queryOffset = values.getScoreCorrectionConstant(ord);
return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy) {
byte[] queryVector = new byte[values.dimension()];
System.arraycopy(values.vectorValue(ord), 0, queryVector, 0, queryVector.length);
return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(vectorsCopy) {
float queryOffset = values.getScoreCorrectionConstant(ord);

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(vectorsCopy.vectorValue(node), 0, queryVector, 0, queryVector.length);
queryOffset = vectorsCopy.getScoreCorrectionConstant(node);
}

@Override
public float score(int node) throws IOException {
byte[] nodeVector = vectorsCopy.vectorValue(node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;

/**
* Writes vector values to index segments.
Expand Down Expand Up @@ -507,7 +507,7 @@ static final class FlatCloseableRandomVectorScorerSupplier
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
public UpdateableRandomVectorScorer scorer(int ord) throws IOException {
return supplier.scorer(ord);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicWriter;

/**
Expand Down Expand Up @@ -561,6 +562,8 @@ private static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
private int lastDocID = -1;
private int node = 0;
private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;
private final RandomVectorScorerSupplier scorerSupplier;
private UpdateableRandomVectorScorer scorer;

@SuppressWarnings("unchecked")
static FieldWriter<?> create(
Expand Down Expand Up @@ -601,7 +604,7 @@ static FieldWriter<?> create(
InfoStream infoStream)
throws IOException {
this.fieldInfo = fieldInfo;
RandomVectorScorerSupplier scorerSupplier =
scorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE ->
scorer.getRandomVectorScorerSupplier(
Expand Down Expand Up @@ -631,7 +634,12 @@ public void addValue(int docID, T vectorValue) throws IOException {
+ "\" appears more than once in this document (only one value is allowed per field)");
}
flatFieldVectorsWriter.addValue(docID, vectorValue);
hnswGraphBuilder.addGraphNode(node);
if (scorer == null) {
scorer = scorerSupplier.scorer(node);
} else {
scorer.setScoringOrdinal(node);
}
hnswGraphBuilder.addGraphNode(node, scorer);
node++;
lastDocID = docID;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;

Expand Down Expand Up @@ -87,7 +88,7 @@ public String toString() {
return "ScalarQuantizedVectorScorer(" + "nonQuantizedDelegate=" + nonQuantizedDelegate + ')';
}

static RandomVectorScorer fromVectorSimilarity(
static UpdateableRandomVectorScorer fromVectorSimilarity(
byte[] targetBytes,
float offsetCorrection,
VectorSimilarityFunction sim,
Expand Down Expand Up @@ -120,12 +121,13 @@ static void checkDimensions(int queryLen, int fieldLen) {
}
}

private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory(
byte[] targetBytes,
float offsetCorrection,
float constMultiplier,
QuantizedByteVectorValues values,
FloatToFloatFunction scoreAdjustmentFunction) {
private static UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer
dotProductFactory(
byte[] targetBytes,
float offsetCorrection,
float constMultiplier,
QuantizedByteVectorValues values,
FloatToFloatFunction scoreAdjustmentFunction) {
if (values.getScalarQuantizer().getBits() <= 4) {
if (values.getVectorByteLength() != values.dimension() && values.getSlice() != null) {
return new CompressedInt4DotProduct(
Expand All @@ -138,7 +140,8 @@ private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory(
values, constMultiplier, targetBytes, offsetCorrection, scoreAdjustmentFunction);
}

private static class Euclidean extends RandomVectorScorer.AbstractRandomVectorScorer {
private static class Euclidean
extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer {
private final float constMultiplier;
private final byte[] targetBytes;
private final QuantizedByteVectorValues values;
Expand All @@ -157,14 +160,20 @@ public float score(int node) throws IOException {
float adjustedDistance = squareDistance * constMultiplier;
return 1 / (1f + adjustedDistance);
}

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(values.vectorValue(node), 0, targetBytes, 0, targetBytes.length);
}
}

/** Calculates dot product on quantized vectors, applying the appropriate corrections */
private static class DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer {
private static class DotProduct
extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer {
private final float constMultiplier;
private final QuantizedByteVectorValues values;
private final byte[] targetBytes;
private final float offsetCorrection;
private float offsetCorrection;
private final FloatToFloatFunction scoreAdjustmentFunction;

public DotProduct(
Expand All @@ -191,15 +200,24 @@ public float score(int vectorOrdinal) throws IOException {
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
return scoreAdjustmentFunction.apply(adjustedDistance);
}

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(values.vectorValue(node), 0, targetBytes, 0, targetBytes.length);
offsetCorrection = values.getScoreCorrectionConstant(node);
}
}

// TODO consider splitting this into two classes. right now the "query" vector is always
// decompressed
// it could stay compressed if we had a compressed version of the target vector
private static class CompressedInt4DotProduct
extends RandomVectorScorer.AbstractRandomVectorScorer {
extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer {
private final float constMultiplier;
private final QuantizedByteVectorValues values;
private final byte[] compressedVector;
private final byte[] targetBytes;
private final float offsetCorrection;
private float offsetCorrection;
private final FloatToFloatFunction scoreAdjustmentFunction;

private CompressedInt4DotProduct(
Expand Down Expand Up @@ -230,13 +248,20 @@ public float score(int vectorOrdinal) throws IOException {
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
return scoreAdjustmentFunction.apply(adjustedDistance);
}

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(values.vectorValue(node), 0, targetBytes, 0, targetBytes.length);
offsetCorrection = values.getScoreCorrectionConstant(node);
}
}

private static class Int4DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer {
private static class Int4DotProduct
extends UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer {
private final float constMultiplier;
private final QuantizedByteVectorValues values;
private final byte[] targetBytes;
private final float offsetCorrection;
private float offsetCorrection;
private final FloatToFloatFunction scoreAdjustmentFunction;

public Int4DotProduct(
Expand All @@ -263,6 +288,12 @@ public float score(int vectorOrdinal) throws IOException {
float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
return scoreAdjustmentFunction.apply(adjustedDistance);
}

@Override
public void setScoringOrdinal(int node) throws IOException {
System.arraycopy(values.vectorValue(node), 0, targetBytes, 0, targetBytes.length);
offsetCorrection = values.getScoreCorrectionConstant(node);
}
}

@FunctionalInterface
Expand All @@ -276,27 +307,26 @@ private static final class ScalarQuantizedRandomVectorScorerSupplier
private final VectorSimilarityFunction vectorSimilarityFunction;
private final QuantizedByteVectorValues values;
private final QuantizedByteVectorValues values1;
private final QuantizedByteVectorValues values2;

public ScalarQuantizedRandomVectorScorerSupplier(
QuantizedByteVectorValues values, VectorSimilarityFunction vectorSimilarityFunction)
throws IOException {
this.values = values;
this.values1 = values.copy();
this.values2 = values.copy();
this.vectorSimilarityFunction = vectorSimilarityFunction;
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
byte[] vectorValue = values1.vectorValue(ord);
public UpdateableRandomVectorScorer scorer(int ord) throws IOException {
byte[] vectorValue = new byte[values.dimension()];
System.arraycopy(values1.vectorValue(ord), 0, vectorValue, 0, vectorValue.length);
float offsetCorrection = values1.getScoreCorrectionConstant(ord);
return fromVectorSimilarity(
vectorValue,
offsetCorrection,
vectorSimilarityFunction,
values.getScalarQuantizer().getConstantMultiplier(),
values2);
values1);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedVectorsReader;
import org.apache.lucene.util.quantization.ScalarQuantizer;
Expand Down Expand Up @@ -1128,7 +1128,7 @@ static final class ScalarQuantizedCloseableRandomVectorScorerSupplier
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
public UpdateableRandomVectorScorer scorer(int ord) throws IOException {
return supplier.scorer(ord);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import java.io.Closeable;

/**
* A supplier that creates {@link RandomVectorScorer} from an ordinal. Caller should be sure to
* close after use
* A supplier that creates {@link UpdateableRandomVectorScorer} from an ordinal. Caller should be
* sure to close after use
*
* <p>NOTE: the {@link #copy()} returned {@link RandomVectorScorerSupplier} is not necessarily
* closeable
Expand Down
Loading

0 comments on commit 59dab18

Please sign in to comment.