diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java
index 123c18e00c08..d428b34eecbb 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java
@@ -17,6 +17,7 @@
package org.apache.lucene.codecs.hnsw;
+import org.apache.lucene.codecs.lucene104.AsymmetricScalarQuantizeFlatVectorsScorer;
import org.apache.lucene.internal.vectorization.VectorizationProvider;
/**
@@ -41,4 +42,14 @@ public static FlatVectorsScorer getLucene99FlatVectorsScorer() {
public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
return IMPL.getLucene99ScalarQuantizedVectorsScorer();
}
+
+ /**
+ * Returns a FlatVectorsScorer that supports the Lucene104 scalar quantized format. Scorers
+ * retrieved through this method may be optimized on certain platforms. Otherwise, a
+ * DefaultFlatVectorScorer is returned.
+ */
+ public static AsymmetricScalarQuantizeFlatVectorsScorer
+ getLucene104ScalarQuantizedFlatVectorsScorer() {
+ return IMPL.getLucene104ScalarQuantizedVectorsScorer();
+ }
}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/AsymmetricScalarQuantizeFlatVectorsScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/AsymmetricScalarQuantizeFlatVectorsScorer.java
new file mode 100644
index 000000000000..a18f8e411f26
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/AsymmetricScalarQuantizeFlatVectorsScorer.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// XXX DO NOT MERGE FIX NAME
+
+package org.apache.lucene.codecs.lucene104;
+
+import java.io.IOException;
+import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+
+/**
+ * Extension of {@link FlatVectorsScorer} that allows using two different vector codings for the
+ * "scoring" or "query" vectors and the "target" or "doc" vectors.
+ */
+public interface AsymmetricScalarQuantizeFlatVectorsScorer extends FlatVectorsScorer {
+ /**
+ * Returns a {@link RandomVectorScorerSupplier} that can be used to score asymmetric vector
+ * representations, typically a higher fidelity "scoring" vector against a lower fidelity "target"
+ * vector. This is used during indexing to improve the quality of the index data structure during
+ * build/merge; only the targetVectors are saved.
+ *
+ *
This may only be used when ScalarEncoding.isAsymmetric().
+ *
+ * @param similarityFunction the similarity function to use
+ * @param scoringVectors higher fidelity scoring vectors to use as queries.
+ * @param targetVectors lower fidelity vectors to use as documents.
+ * @return a {@link RandomVectorScorerSupplier} that can be used to score vectors
+ * @throws IOException if an I/O error occurs
+ */
+ RandomVectorScorerSupplier getRandomVectorScorerSupplier(
+ VectorSimilarityFunction similarityFunction,
+ QuantizedByteVectorValues scoringVectors,
+ QuantizedByteVectorValues targetVectors)
+ throws IOException;
+}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java
index db2021a2a0e5..1e1baa0593d6 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java
@@ -22,6 +22,7 @@
import java.io.IOException;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.ArrayUtil;
@@ -32,7 +33,8 @@
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
/** Vector scorer over OptimizedScalarQuantized vectors */
-public class Lucene104ScalarQuantizedVectorScorer implements FlatVectorsScorer {
+public class Lucene104ScalarQuantizedVectorScorer
+ implements AsymmetricScalarQuantizeFlatVectorsScorer {
private final FlatVectorsScorer nonQuantizedDelegate;
public Lucene104ScalarQuantizedVectorScorer(FlatVectorsScorer nonQuantizedDelegate) {
@@ -106,7 +108,8 @@ public RandomVectorScorer getRandomVectorScorer(
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
}
- RandomVectorScorerSupplier getRandomVectorScorerSupplier(
+ @Override
+ public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction,
QuantizedByteVectorValues scoringVectors,
QuantizedByteVectorValues targetVectors) {
@@ -239,7 +242,7 @@ private static float quantizedScore(
int targetOrd,
VectorSimilarityFunction similarityFunction)
throws IOException {
- var scalarEncoding = targetVectors.getScalarEncoding();
+ ScalarEncoding scalarEncoding = targetVectors.getScalarEncoding();
byte[] quantizedDoc = targetVectors.vectorValue(targetOrd);
float qcDist =
switch (scalarEncoding) {
@@ -249,8 +252,32 @@ private static float quantizedScore(
case SINGLE_BIT_QUERY_NIBBLE ->
VectorUtil.int4BitDotProduct(quantizedQuery, quantizedDoc);
};
- OptimizedScalarQuantizer.QuantizationResult indexCorrections =
- targetVectors.getCorrectiveTerms(targetOrd);
+ return quantizedScore(
+ similarityFunction,
+ targetVectors,
+ qcDist,
+ queryCorrections,
+ targetVectors.getCorrectiveTerms(targetOrd));
+ }
+
+ /**
+ * Transforms the dotProduct of a query and index vector into a score.
+ *
+ * @param similarityFunction similarity function used to compute the score
+ * @param targetVectors target vector set; used for metadata
+ * @param dotProduct dot product of query and index vectors.
+ * @param queryCorrections corrective terms for the query vector
+ * @param indexCorrections corrective terms for the index vector
+ * @return a score value greater than or equal to 0
+ */
+ public static float quantizedScore(
+ VectorSimilarityFunction similarityFunction,
+ QuantizedByteVectorValues targetVectors,
+ float dotProduct,
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+ OptimizedScalarQuantizer.QuantizationResult indexCorrections)
+ throws IOException {
+ ScalarEncoding scalarEncoding = targetVectors.getScalarEncoding();
float queryScale = SCALE_LUT[scalarEncoding.getQueryBits() - 1];
float scale = SCALE_LUT[scalarEncoding.getBits() - 1];
float x1 = indexCorrections.quantizedComponentSum();
@@ -261,7 +288,7 @@ private static float quantizedScore(
float ly = (queryCorrections.upperInterval() - ay) * queryScale;
float y1 = queryCorrections.quantizedComponentSum();
float score =
- ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist;
+ ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * dotProduct;
// For euclidean, we need to invert the score and apply the additional correction, which is
// assumed to be the squared l2norm of the centroid centered vectors.
if (similarityFunction == EUCLIDEAN) {
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java
index 8015751fa422..7326927f7db4 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java
@@ -105,8 +105,8 @@ public class Lucene104ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
private static final FlatVectorsFormat rawVectorFormat =
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
- private static final Lucene104ScalarQuantizedVectorScorer scorer =
- new Lucene104ScalarQuantizedVectorScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
+ private static final AsymmetricScalarQuantizeFlatVectorsScorer scorer =
+ FlatVectorScorerUtil.getLucene104ScalarQuantizedFlatVectorsScorer();
private final ScalarEncoding encoding;
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java
index f98c0f630c80..78e4fc22db34 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java
@@ -27,6 +27,7 @@
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.index.ByteVectorValues;
@@ -66,12 +67,10 @@ class Lucene104ScalarQuantizedVectorsReader extends FlatVectorsReader
private final Map fields = new HashMap<>();
private final IndexInput quantizedVectorData;
private final FlatVectorsReader rawVectorsReader;
- private final Lucene104ScalarQuantizedVectorScorer vectorScorer;
+ private final FlatVectorsScorer vectorScorer;
Lucene104ScalarQuantizedVectorsReader(
- SegmentReadState state,
- FlatVectorsReader rawVectorsReader,
- Lucene104ScalarQuantizedVectorScorer vectorsScorer)
+ SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer vectorsScorer)
throws IOException {
super(vectorsScorer);
this.vectorScorer = vectorsScorer;
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java
index 69d23bc95df3..10a28a7f0022 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java
@@ -69,7 +69,7 @@ public class Lucene104ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
private final IndexOutput meta, vectorData;
private final ScalarEncoding encoding;
private final FlatVectorsWriter rawVectorDelegate;
- private final Lucene104ScalarQuantizedVectorScorer vectorsScorer;
+ private final AsymmetricScalarQuantizeFlatVectorsScorer vectorsScorer;
private boolean finished;
/**
@@ -81,7 +81,7 @@ protected Lucene104ScalarQuantizedVectorsWriter(
SegmentWriteState state,
ScalarEncoding encoding,
FlatVectorsWriter rawVectorDelegate,
- Lucene104ScalarQuantizedVectorScorer vectorsScorer)
+ AsymmetricScalarQuantizeFlatVectorsScorer vectorsScorer)
throws IOException {
super(vectorsScorer);
this.encoding = encoding;
@@ -831,6 +831,11 @@ public int dimension() {
return values.dimension();
}
+ @Override
+ public IndexInput getSlice() {
+ return null;
+ }
+
@Override
public OptimizedScalarQuantizer getQuantizer() {
throw new UnsupportedOperationException();
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java
index 440e96d40418..be3647d60f32 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java
@@ -169,7 +169,12 @@ public float[] getCentroid() {
@Override
public int getVectorByteLength() {
- return dimension;
+ return this.encoding.getDocPackedLength(dimension);
+ }
+
+ @Override
+ public IndexInput getSlice() {
+ return slice;
}
static void packNibbles(byte[] unpacked, byte[] packed) {
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java
index 02b9c0748996..55901ed7d0f7 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java
@@ -18,12 +18,13 @@
import java.io.IOException;
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
+import org.apache.lucene.codecs.lucene95.HasIndexSlice;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
/** Scalar quantized byte vector values */
-abstract class QuantizedByteVectorValues extends ByteVectorValues {
+public abstract class QuantizedByteVectorValues extends ByteVectorValues implements HasIndexSlice {
/**
* Retrieve the corrective terms for the given vector ordinal. For the dot-product family of
diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java
index 21977fa3dc77..01dc74b3e5b7 100644
--- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java
+++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java
@@ -19,6 +19,8 @@
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.codecs.lucene104.AsymmetricScalarQuantizeFlatVectorsScorer;
+import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorScorer;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer;
import org.apache.lucene.store.IndexInput;
@@ -46,6 +48,11 @@ public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
return new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
}
+ @Override
+ public AsymmetricScalarQuantizeFlatVectorsScorer getLucene104ScalarQuantizedVectorsScorer() {
+ return new Lucene104ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
+ }
+
@Override
public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) {
return new PostingDecodingUtil(input);
diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java
index cf9c56c59774..e1aec3c012db 100644
--- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java
+++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java
@@ -29,6 +29,7 @@
import java.util.logging.Logger;
import java.util.stream.Stream;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.codecs.lucene104.AsymmetricScalarQuantizeFlatVectorsScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.VectorUtil;
@@ -112,6 +113,10 @@ public static VectorizationProvider getInstance() {
/** Returns a FlatVectorsScorer that supports the Lucene99 format. */
public abstract FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer();
+ /** Returns a FlatVectorsScorer that supports the Lucene104 quantized format. */
+ public abstract AsymmetricScalarQuantizeFlatVectorsScorer
+ getLucene104ScalarQuantizedVectorsScorer();
+
/** Create a new {@link PostingDecodingUtil} for the given {@link IndexInput}. */
public abstract PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException;
@@ -136,7 +141,8 @@ static VectorizationProvider lookup(boolean testMode) {
"Java runtime is using JVMCI Compiler; Java vector incubator API can't be enabled.");
return new DefaultVectorizationProvider();
}
- // is the incubator module present and readable (JVM providers may to exclude them or it is
+ // is the incubator module present and readable (JVM providers may to exclude
+ // them or it is
// build with jlink)
final var vectorMod = lookupVectorModule();
if (vectorMod.isEmpty()) {
@@ -158,7 +164,8 @@ static VectorizationProvider lookup(boolean testMode) {
}
}
try {
- // we use method handles with lookup, so we do not need to deal with setAccessible as we
+ // we use method handles with lookup, so we do not need to deal with
+ // setAccessible as we
// have private access through the lookup:
final var lookup = MethodHandles.lookup();
final var cls =
diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene104MemorySegmentScalarQuantizedVectorScorer.java b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene104MemorySegmentScalarQuantizedVectorScorer.java
new file mode 100644
index 000000000000..296259210903
--- /dev/null
+++ b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene104MemorySegmentScalarQuantizedVectorScorer.java
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.internal.vectorization;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
+import org.apache.lucene.codecs.lucene104.AsymmetricScalarQuantizeFlatVectorsScorer;
+import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorScorer;
+import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat;
+import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
+import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.MemorySegmentAccessInput;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+class Lucene104MemorySegmentScalarQuantizedVectorScorer
+ implements AsymmetricScalarQuantizeFlatVectorsScorer {
+ static final Lucene104MemorySegmentScalarQuantizedVectorScorer INSTANCE =
+ new Lucene104MemorySegmentScalarQuantizedVectorScorer();
+
+ private static final Lucene104ScalarQuantizedVectorScorer DELEGATE =
+ new Lucene104ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
+
+ private static final int CORRECTIVE_TERMS_SIZE = Float.BYTES * 3 + Integer.BYTES;
+
+ private Lucene104MemorySegmentScalarQuantizedVectorScorer() {}
+
+ @Override
+ public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
+ throws IOException {
+ if (vectorValues instanceof QuantizedByteVectorValues quantized
+ && quantized.getSlice() instanceof MemorySegmentAccessInput input
+ && quantized.getScalarEncoding() != ScalarEncoding.SINGLE_BIT_QUERY_NIBBLE) {
+ return new RandomVectorScorerSupplierImpl(similarityFunction, quantized, input);
+ }
+ return DELEGATE.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
+ }
+
+ @Override
+ public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
+ VectorSimilarityFunction similarityFunction,
+ QuantizedByteVectorValues scoringVectors,
+ QuantizedByteVectorValues targetVectors)
+ throws IOException {
+ // We do not yet support acceleration for any asymmetric formats.
+ return DELEGATE.getRandomVectorScorerSupplier(
+ similarityFunction, scoringVectors, targetVectors);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
+ throws IOException {
+ if (vectorValues instanceof QuantizedByteVectorValues quantized
+ && quantized.getSlice() instanceof MemorySegmentAccessInput input
+ && quantized.getScalarEncoding() != ScalarEncoding.SINGLE_BIT_QUERY_NIBBLE) {
+ return new RandomVectorScorerImpl(similarityFunction, quantized, input, target);
+ }
+ return DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(
+ VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
+ throws IOException {
+ return DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target);
+ }
+
+ @Override
+ public String toString() {
+ return "Lucene104MemorySegmentScalarQuantizedVectorScorer()";
+ }
+
+ private abstract static class RandomVectorScorerBase
+ extends RandomVectorScorer.AbstractRandomVectorScorer {
+
+ private final QuantizedByteVectorValues values;
+ private final MemorySegmentAccessInput input;
+ private final int vectorByteSize;
+ private final int nodeSize;
+ private final VectorSimilarityFunction similarity;
+ private byte[] scratch = null;
+
+ RandomVectorScorerBase(
+ VectorSimilarityFunction similarityFunction,
+ QuantizedByteVectorValues values,
+ MemorySegmentAccessInput input)
+ throws IOException {
+ super(values);
+
+ this.values = values;
+ this.input = input;
+ this.vectorByteSize = values.getVectorByteLength();
+ this.nodeSize = this.vectorByteSize + CORRECTIVE_TERMS_SIZE;
+ this.similarity = similarityFunction;
+ checkInvariants();
+ }
+
+ final void checkInvariants() {
+ if (input.length() < (long) nodeSize * maxOrd()) {
+ throw new IllegalArgumentException("input length is less than expected vector data");
+ }
+ }
+
+ final void checkOrdinal(int ord) {
+ if (ord < 0 || ord >= maxOrd()) {
+ throw new IllegalArgumentException("illegal ordinal: " + ord);
+ }
+ }
+
+ @SuppressWarnings("restricted")
+ MemorySegment getVector(int ord) throws IOException {
+ checkOrdinal(ord);
+ long byteOffset = (long) ord * nodeSize;
+ MemorySegment vector = input.segmentSliceOrNull(byteOffset, vectorByteSize);
+ if (vector != null) {
+ return vector;
+ }
+
+ if (scratch == null) {
+ scratch = new byte[vectorByteSize];
+ }
+ input.readBytes(byteOffset, scratch, 0, vectorByteSize);
+ return MemorySegment.ofArray(scratch);
+ }
+
+ OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) throws IOException {
+ checkOrdinal(ord);
+ long byteOffset = (long) ord * nodeSize + vectorByteSize;
+ return new OptimizedScalarQuantizer.QuantizationResult(
+ Float.intBitsToFloat(input.readInt(byteOffset)),
+ Float.intBitsToFloat(input.readInt(byteOffset + Integer.BYTES)),
+ Float.intBitsToFloat(input.readInt(byteOffset + Integer.BYTES * 2)),
+ input.readInt(byteOffset + Integer.BYTES * 3));
+ }
+
+ VectorSimilarityFunction getSimilarity() {
+ return similarity;
+ }
+
+ QuantizedByteVectorValues getValues() {
+ return values;
+ }
+ }
+
+ private static class RandomVectorScorerImpl extends RandomVectorScorerBase {
+ private final byte[] query;
+ private final OptimizedScalarQuantizer.QuantizationResult queryCorrectiveTerms;
+ private final byte[] scratch;
+
+ RandomVectorScorerImpl(
+ VectorSimilarityFunction similarityFunction,
+ QuantizedByteVectorValues values,
+ MemorySegmentAccessInput input,
+ float[] target)
+ throws IOException {
+ super(similarityFunction, values, input);
+ Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding scalarEncoding =
+ values.getScalarEncoding();
+ OptimizedScalarQuantizer quantizer = values.getQuantizer();
+ scratch = new byte[values.getVectorByteLength()];
+ query = new byte[scalarEncoding.getDiscreteDimensions(target.length)];
+ // We make a copy as the quantization process mutates the input
+ float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length);
+ if (similarityFunction == COSINE) {
+ VectorUtil.l2normalize(copy);
+ }
+ target = copy;
+ queryCorrectiveTerms =
+ quantizer.scalarQuantize(target, query, scalarEncoding.getBits(), values.getCentroid());
+ }
+
+ @Override
+ public float score(int node) throws IOException {
+ MemorySegment docVector = getVector(node);
+ float dotProduct =
+ switch (getValues().getScalarEncoding()) {
+ case UNSIGNED_BYTE -> PanamaVectorUtilSupport.uint8DotProduct(query, docVector);
+ case SEVEN_BIT -> PanamaVectorUtilSupport.dotProduct(query, docVector);
+ case PACKED_NIBBLE ->
+ PanamaVectorUtilSupport.int4DotProductSinglePacked(query, docVector);
+ case SINGLE_BIT_QUERY_NIBBLE ->
+ throw new IllegalStateException(
+ "this should be handled by the default implementation");
+ };
+ // Call getCorrectiveTerms() after computing dot product since corrective terms
+ // bytes appear after the vector bytes, so this sequence of calls is more cache
+ // friendly.
+ return Lucene104ScalarQuantizedVectorScorer.quantizedScore(
+ getSimilarity(), getValues(), dotProduct, queryCorrectiveTerms, getCorrectiveTerms(node));
+ }
+ }
+
+ record RandomVectorScorerSupplierImpl(
+ VectorSimilarityFunction similarityFunction,
+ QuantizedByteVectorValues values,
+ MemorySegmentAccessInput input)
+ implements RandomVectorScorerSupplier {
+
+ @Override
+ public UpdateableRandomVectorScorer scorer() throws IOException {
+ return new UpdateableRandomVectorScorerImpl(similarityFunction, values, input);
+ }
+
+ @Override
+ public RandomVectorScorerSupplier copy() {
+ return new RandomVectorScorerSupplierImpl(similarityFunction, values, input);
+ }
+ }
+
+ private static class UpdateableRandomVectorScorerImpl extends RandomVectorScorerBase
+ implements UpdateableRandomVectorScorer {
+ private MemorySegment query;
+ private OptimizedScalarQuantizer.QuantizationResult queryCorrectiveTerms;
+
+ UpdateableRandomVectorScorerImpl(
+ VectorSimilarityFunction similarityFunction,
+ QuantizedByteVectorValues values,
+ MemorySegmentAccessInput input)
+ throws IOException {
+ super(similarityFunction, values, input);
+ }
+
+ @Override
+ public void setScoringOrdinal(int ord) throws IOException {
+ checkOrdinal(ord);
+ query = getVector(ord);
+ queryCorrectiveTerms = getCorrectiveTerms(ord);
+ }
+
+ @Override
+ public float score(int node) throws IOException {
+ MemorySegment doc = getVector(node);
+ float dotProduct =
+ switch (getValues().getScalarEncoding()) {
+ case UNSIGNED_BYTE -> PanamaVectorUtilSupport.uint8DotProduct(query, doc);
+ case SEVEN_BIT -> PanamaVectorUtilSupport.dotProduct(query, doc);
+ case PACKED_NIBBLE -> PanamaVectorUtilSupport.int4DotProductBothPacked(query, doc);
+ case SINGLE_BIT_QUERY_NIBBLE ->
+ throw new IllegalStateException(
+ "this should be handled by the default implementation");
+ };
+ // Call getCorrectiveTerms() after computing dot product since corrective terms
+ // bytes appear after the vector bytes, so this sequence of calls is more cache
+ // friendly.
+ return Lucene104ScalarQuantizedVectorScorer.quantizedScore(
+ getSimilarity(), getValues(), dotProduct, queryCorrectiveTerms, getCorrectiveTerms(node));
+ }
+ }
+}
diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
index d2e104f92f70..425b0e6630be 100644
--- a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
+++ b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java
@@ -154,7 +154,8 @@ private float dotProductBody(float[] a, float[] b, int limit) {
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
acc4 = fma(vg, vh, acc4);
}
- // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
+ // vector tail: less scalar computations for unaligned sizes, esp with big
+ // vector sizes
for (; i < limit; i += FLOAT_SPECIES.length()) {
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java
index cf3ab94f417c..0ec71b6b7f2f 100644
--- a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java
+++ b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java
@@ -22,6 +22,7 @@
import java.util.logging.Logger;
import jdk.incubator.vector.FloatVector;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.codecs.lucene104.AsymmetricScalarQuantizeFlatVectorsScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.Constants;
@@ -83,6 +84,11 @@ public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() {
return Lucene99MemorySegmentScalarQuantizedVectorScorer.INSTANCE;
}
+ @Override
+ public AsymmetricScalarQuantizeFlatVectorsScorer getLucene104ScalarQuantizedVectorsScorer() {
+ return Lucene104MemorySegmentScalarQuantizedVectorScorer.INSTANCE;
+ }
+
@Override
public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException {
if (input instanceof MemorySegmentAccessInput msai) {
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java
index 10fc424c44ee..0e8084391ab5 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java
@@ -88,16 +88,20 @@ public KnnVectorsFormat knnVectorsFormat() {
"Lucene104HnswScalarQuantizedVectorsFormat(name=Lucene104HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20,"
+ " flatVectorFormat=Lucene104ScalarQuantizedVectorsFormat(name=Lucene104ScalarQuantizedVectorsFormat,"
+ " encoding=UNSIGNED_BYTE,"
- + " flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s()),"
+ + " flatVectorScorer=%s,"
+ " rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s())))";
var defaultScorer =
- format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer", "DefaultFlatVectorScorer");
+ format(
+ Locale.ROOT,
+ expectedPattern,
+ "Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())",
+ "DefaultFlatVectorScorer");
var memSegScorer =
format(
Locale.ROOT,
expectedPattern,
- "Lucene99MemorySegmentFlatVectorsScorer",
+ "Lucene104MemorySegmentScalarQuantizedVectorScorer()",
"Lucene99MemorySegmentFlatVectorsScorer");
assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
}
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java
index 150294df12ca..d9b6b4e19b3a 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java
@@ -110,15 +110,19 @@ public KnnVectorsFormat knnVectorsFormat() {
"Lucene104ScalarQuantizedVectorsFormat("
+ "name=Lucene104ScalarQuantizedVectorsFormat, "
+ "encoding=UNSIGNED_BYTE, "
- + "flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s()), "
+ + "flatVectorScorer=%s, "
+ "rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))";
var defaultScorer =
- format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer", "DefaultFlatVectorScorer");
+ format(
+ Locale.ROOT,
+ expectedPattern,
+ "Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())",
+ "DefaultFlatVectorScorer");
var memSegScorer =
format(
Locale.ROOT,
expectedPattern,
- "Lucene99MemorySegmentFlatVectorsScorer",
+ "Lucene104MemorySegmentScalarQuantizedVectorScorer()",
"Lucene99MemorySegmentFlatVectorsScorer");
assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
}