diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java index 123c18e00c08..d428b34eecbb 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorScorerUtil.java @@ -17,6 +17,7 @@ package org.apache.lucene.codecs.hnsw; +import org.apache.lucene.codecs.lucene104.AsymmetricScalarQuantizeFlatVectorsScorer; import org.apache.lucene.internal.vectorization.VectorizationProvider; /** @@ -41,4 +42,14 @@ public static FlatVectorsScorer getLucene99FlatVectorsScorer() { public static FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { return IMPL.getLucene99ScalarQuantizedVectorsScorer(); } + + /** + * Returns a FlatVectorsScorer that supports the Lucene104 scalar quantized format. Scorers + * retrieved through this method may be optimized on certain platforms. Otherwise, a + * Lucene104ScalarQuantizedVectorScorer wrapping DefaultFlatVectorScorer is returned. + */ + public static AsymmetricScalarQuantizeFlatVectorsScorer + getLucene104ScalarQuantizedFlatVectorsScorer() { + return IMPL.getLucene104ScalarQuantizedVectorsScorer(); + } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/AsymmetricScalarQuantizeFlatVectorsScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/AsymmetricScalarQuantizeFlatVectorsScorer.java new file mode 100644 index 000000000000..a18f8e411f26 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/AsymmetricScalarQuantizeFlatVectorsScorer.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// XXX DO NOT MERGE FIX NAME + +package org.apache.lucene.codecs.lucene104; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; + +/** + * Extension of {@link FlatVectorsScorer} that allows using two different vector codings for the + * "scoring" or "query" vectors and the "target" or "doc" vectors. + */ +public interface AsymmetricScalarQuantizeFlatVectorsScorer extends FlatVectorsScorer { + /** + * Returns a {@link RandomVectorScorerSupplier} that can be used to score asymmetric vector + * representations, typically a higher fidelity "scoring" vector against a lower fidelity "target" + * vector. This is used during indexing to improve the quality of the index data structure during + * build/merge; only the targetVectors are saved. + * + *

<p>This may only be used when ScalarEncoding.isAsymmetric(). + * + * @param similarityFunction the similarity function to use + * @param scoringVectors higher fidelity scoring vectors to use as queries. + * @param targetVectors lower fidelity vectors to use as documents. + * @return a {@link RandomVectorScorerSupplier} that can be used to score vectors + * @throws IOException if an I/O error occurs + */ + RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + QuantizedByteVectorValues scoringVectors, + QuantizedByteVectorValues targetVectors) + throws IOException; +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java index db2021a2a0e5..1e1baa0593d6 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorScorer.java @@ -22,6 +22,7 @@ import java.io.IOException; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.ArrayUtil; @@ -32,7 +33,8 @@ import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; /** Vector scorer over OptimizedScalarQuantized vectors */ -public class Lucene104ScalarQuantizedVectorScorer implements FlatVectorsScorer { +public class Lucene104ScalarQuantizedVectorScorer + implements AsymmetricScalarQuantizeFlatVectorsScorer { private final FlatVectorsScorer nonQuantizedDelegate; public Lucene104ScalarQuantizedVectorScorer(FlatVectorsScorer nonQuantizedDelegate) { @@ -106,7 +108,8 @@ public RandomVectorScorer getRandomVectorScorer( return
nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); } - RandomVectorScorerSupplier getRandomVectorScorerSupplier( + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, QuantizedByteVectorValues scoringVectors, QuantizedByteVectorValues targetVectors) { @@ -239,7 +242,7 @@ private static float quantizedScore( int targetOrd, VectorSimilarityFunction similarityFunction) throws IOException { - var scalarEncoding = targetVectors.getScalarEncoding(); + ScalarEncoding scalarEncoding = targetVectors.getScalarEncoding(); byte[] quantizedDoc = targetVectors.vectorValue(targetOrd); float qcDist = switch (scalarEncoding) { @@ -249,8 +252,32 @@ private static float quantizedScore( case SINGLE_BIT_QUERY_NIBBLE -> VectorUtil.int4BitDotProduct(quantizedQuery, quantizedDoc); }; - OptimizedScalarQuantizer.QuantizationResult indexCorrections = - targetVectors.getCorrectiveTerms(targetOrd); + return quantizedScore( + similarityFunction, + targetVectors, + qcDist, + queryCorrections, + targetVectors.getCorrectiveTerms(targetOrd)); + } + + /** + * Transforms the dotProduct of a query and index vector into a score. + * + * @param similarityFunction similarity function used to compute the score + * @param targetVectors target vector set; used for metadata + * @param dotProduct dot product of query and index vectors. 
+ * @param queryCorrections corrective terms for the query vector + * @param indexCorrections corrective terms for the index vector + * @return a score value greater than or equal to 0 + */ + public static float quantizedScore( + VectorSimilarityFunction similarityFunction, + QuantizedByteVectorValues targetVectors, + float dotProduct, + OptimizedScalarQuantizer.QuantizationResult queryCorrections, + OptimizedScalarQuantizer.QuantizationResult indexCorrections) + throws IOException { + ScalarEncoding scalarEncoding = targetVectors.getScalarEncoding(); float queryScale = SCALE_LUT[scalarEncoding.getQueryBits() - 1]; float scale = SCALE_LUT[scalarEncoding.getBits() - 1]; float x1 = indexCorrections.quantizedComponentSum(); @@ -261,7 +288,7 @@ private static float quantizedScore( float ly = (queryCorrections.upperInterval() - ay) * queryScale; float y1 = queryCorrections.quantizedComponentSum(); float score = - ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist; + ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * dotProduct; // For euclidean, we need to invert the score and apply the additional correction, which is // assumed to be the squared l2norm of the centroid centered vectors. 
if (similarityFunction == EUCLIDEAN) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java index 8015751fa422..7326927f7db4 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsFormat.java @@ -105,8 +105,8 @@ public class Lucene104ScalarQuantizedVectorsFormat extends FlatVectorsFormat { private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); - private static final Lucene104ScalarQuantizedVectorScorer scorer = - new Lucene104ScalarQuantizedVectorScorer(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + private static final AsymmetricScalarQuantizeFlatVectorsScorer scorer = + FlatVectorScorerUtil.getLucene104ScalarQuantizedFlatVectorsScorer(); private final ScalarEncoding encoding; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java index f98c0f630c80..78e4fc22db34 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsReader.java @@ -27,6 +27,7 @@ import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.index.ByteVectorValues; 
@@ -66,12 +67,10 @@ class Lucene104ScalarQuantizedVectorsReader extends FlatVectorsReader private final Map fields = new HashMap<>(); private final IndexInput quantizedVectorData; private final FlatVectorsReader rawVectorsReader; - private final Lucene104ScalarQuantizedVectorScorer vectorScorer; + private final FlatVectorsScorer vectorScorer; Lucene104ScalarQuantizedVectorsReader( - SegmentReadState state, - FlatVectorsReader rawVectorsReader, - Lucene104ScalarQuantizedVectorScorer vectorsScorer) + SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer vectorsScorer) throws IOException { super(vectorsScorer); this.vectorScorer = vectorsScorer; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java index 69d23bc95df3..10a28a7f0022 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/Lucene104ScalarQuantizedVectorsWriter.java @@ -69,7 +69,7 @@ public class Lucene104ScalarQuantizedVectorsWriter extends FlatVectorsWriter { private final IndexOutput meta, vectorData; private final ScalarEncoding encoding; private final FlatVectorsWriter rawVectorDelegate; - private final Lucene104ScalarQuantizedVectorScorer vectorsScorer; + private final AsymmetricScalarQuantizeFlatVectorsScorer vectorsScorer; private boolean finished; /** @@ -81,7 +81,7 @@ protected Lucene104ScalarQuantizedVectorsWriter( SegmentWriteState state, ScalarEncoding encoding, FlatVectorsWriter rawVectorDelegate, - Lucene104ScalarQuantizedVectorScorer vectorsScorer) + AsymmetricScalarQuantizeFlatVectorsScorer vectorsScorer) throws IOException { super(vectorsScorer); this.encoding = encoding; @@ -831,6 +831,11 @@ public int dimension() { return values.dimension(); } + @Override + public IndexInput getSlice() { + 
return null; + } + @Override public OptimizedScalarQuantizer getQuantizer() { throw new UnsupportedOperationException(); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java index 440e96d40418..be3647d60f32 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/OffHeapScalarQuantizedVectorValues.java @@ -169,7 +169,12 @@ public float[] getCentroid() { @Override public int getVectorByteLength() { - return dimension; + return this.encoding.getDocPackedLength(dimension); + } + + @Override + public IndexInput getSlice() { + return slice; } static void packNibbles(byte[] unpacked, byte[] packed) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java index 02b9c0748996..55901ed7d0f7 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene104/QuantizedByteVectorValues.java @@ -18,12 +18,13 @@ import java.io.IOException; import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; /** Scalar quantized byte vector values */ -abstract class QuantizedByteVectorValues extends ByteVectorValues { +public abstract class QuantizedByteVectorValues extends ByteVectorValues implements HasIndexSlice { /** * Retrieve the corrective terms for the given vector ordinal. 
For the dot-product family of diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java index 21977fa3dc77..01dc74b3e5b7 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorizationProvider.java @@ -19,6 +19,8 @@ import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene104.AsymmetricScalarQuantizeFlatVectorsScorer; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorScorer; import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; import org.apache.lucene.store.IndexInput; @@ -46,6 +48,11 @@ public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { return new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE); } + @Override + public AsymmetricScalarQuantizeFlatVectorsScorer getLucene104ScalarQuantizedVectorsScorer() { + return new Lucene104ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE); + } + @Override public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) { return new PostingDecodingUtil(input); diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java index cf9c56c59774..e1aec3c012db 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java @@ -29,6 +29,7 @@ import java.util.logging.Logger; import java.util.stream.Stream; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import 
org.apache.lucene.codecs.lucene104.AsymmetricScalarQuantizeFlatVectorsScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Constants; import org.apache.lucene.util.VectorUtil; @@ -112,6 +113,10 @@ public static VectorizationProvider getInstance() { /** Returns a FlatVectorsScorer that supports the Lucene99 format. */ public abstract FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer(); + /** Returns a FlatVectorsScorer that supports the Lucene104 quantized format. */ + public abstract AsymmetricScalarQuantizeFlatVectorsScorer + getLucene104ScalarQuantizedVectorsScorer(); + /** Create a new {@link PostingDecodingUtil} for the given {@link IndexInput}. */ public abstract PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException; @@ -136,7 +141,8 @@ static VectorizationProvider lookup(boolean testMode) { "Java runtime is using JVMCI Compiler; Java vector incubator API can't be enabled."); return new DefaultVectorizationProvider(); } - // is the incubator module present and readable (JVM providers may to exclude them or it is + // is the incubator module present and readable (JVM providers may to exclude + // them or it is // build with jlink) final var vectorMod = lookupVectorModule(); if (vectorMod.isEmpty()) { @@ -158,7 +164,8 @@ static VectorizationProvider lookup(boolean testMode) { } } try { - // we use method handles with lookup, so we do not need to deal with setAccessible as we + // we use method handles with lookup, so we do not need to deal with + // setAccessible as we // have private access through the lookup: final var lookup = MethodHandles.lookup(); final var cls = diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene104MemorySegmentScalarQuantizedVectorScorer.java b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene104MemorySegmentScalarQuantizedVectorScorer.java new file mode 100644 index 000000000000..296259210903 --- /dev/null +++ 
b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/Lucene104MemorySegmentScalarQuantizedVectorScorer.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.internal.vectorization; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.lucene104.AsymmetricScalarQuantizeFlatVectorsScorer; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorScorer; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat; +import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding; +import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; 
+import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; + +class Lucene104MemorySegmentScalarQuantizedVectorScorer + implements AsymmetricScalarQuantizeFlatVectorsScorer { + static final Lucene104MemorySegmentScalarQuantizedVectorScorer INSTANCE = + new Lucene104MemorySegmentScalarQuantizedVectorScorer(); + + private static final Lucene104ScalarQuantizedVectorScorer DELEGATE = + new Lucene104ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE); + + private static final int CORRECTIVE_TERMS_SIZE = Float.BYTES * 3 + Integer.BYTES; + + private Lucene104MemorySegmentScalarQuantizedVectorScorer() {} + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof QuantizedByteVectorValues quantized + && quantized.getSlice() instanceof MemorySegmentAccessInput input + && quantized.getScalarEncoding() != ScalarEncoding.SINGLE_BIT_QUERY_NIBBLE) { + return new RandomVectorScorerSupplierImpl(similarityFunction, quantized, input); + } + return DELEGATE.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + QuantizedByteVectorValues scoringVectors, + QuantizedByteVectorValues targetVectors) + throws IOException { + // We do not yet support acceleration for any asymmetric formats. 
+ return DELEGATE.getRandomVectorScorerSupplier( + similarityFunction, scoringVectors, targetVectors); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) + throws IOException { + if (vectorValues instanceof QuantizedByteVectorValues quantized + && quantized.getSlice() instanceof MemorySegmentAccessInput input + && quantized.getScalarEncoding() != ScalarEncoding.SINGLE_BIT_QUERY_NIBBLE) { + return new RandomVectorScorerImpl(similarityFunction, quantized, input, target); + } + return DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) + throws IOException { + return DELEGATE.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public String toString() { + return "Lucene104MemorySegmentScalarQuantizedVectorScorer()"; + } + + private abstract static class RandomVectorScorerBase + extends RandomVectorScorer.AbstractRandomVectorScorer { + + private final QuantizedByteVectorValues values; + private final MemorySegmentAccessInput input; + private final int vectorByteSize; + private final int nodeSize; + private final VectorSimilarityFunction similarity; + private byte[] scratch = null; + + RandomVectorScorerBase( + VectorSimilarityFunction similarityFunction, + QuantizedByteVectorValues values, + MemorySegmentAccessInput input) + throws IOException { + super(values); + + this.values = values; + this.input = input; + this.vectorByteSize = values.getVectorByteLength(); + this.nodeSize = this.vectorByteSize + CORRECTIVE_TERMS_SIZE; + this.similarity = similarityFunction; + checkInvariants(); + } + + final void checkInvariants() { + if (input.length() < (long) nodeSize * maxOrd()) { + throw new IllegalArgumentException("input length is less than expected 
vector data"); + } + } + + final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd()) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + @SuppressWarnings("restricted") + MemorySegment getVector(int ord) throws IOException { + checkOrdinal(ord); + long byteOffset = (long) ord * nodeSize; + MemorySegment vector = input.segmentSliceOrNull(byteOffset, vectorByteSize); + if (vector != null) { + return vector; + } + + if (scratch == null) { + scratch = new byte[vectorByteSize]; + } + input.readBytes(byteOffset, scratch, 0, vectorByteSize); + return MemorySegment.ofArray(scratch); + } + + OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) throws IOException { + checkOrdinal(ord); + long byteOffset = (long) ord * nodeSize + vectorByteSize; + return new OptimizedScalarQuantizer.QuantizationResult( + Float.intBitsToFloat(input.readInt(byteOffset)), + Float.intBitsToFloat(input.readInt(byteOffset + Integer.BYTES)), + Float.intBitsToFloat(input.readInt(byteOffset + Integer.BYTES * 2)), + input.readInt(byteOffset + Integer.BYTES * 3)); + } + + VectorSimilarityFunction getSimilarity() { + return similarity; + } + + QuantizedByteVectorValues getValues() { + return values; + } + } + + private static class RandomVectorScorerImpl extends RandomVectorScorerBase { + private final byte[] query; + private final OptimizedScalarQuantizer.QuantizationResult queryCorrectiveTerms; + private final byte[] scratch; + + RandomVectorScorerImpl( + VectorSimilarityFunction similarityFunction, + QuantizedByteVectorValues values, + MemorySegmentAccessInput input, + float[] target) + throws IOException { + super(similarityFunction, values, input); + Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding scalarEncoding = + values.getScalarEncoding(); + OptimizedScalarQuantizer quantizer = values.getQuantizer(); + scratch = new byte[values.getVectorByteLength()]; + query = new byte[scalarEncoding.getDiscreteDimensions(target.length)]; + // We 
make a copy as the quantization process mutates the input + float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length); + if (similarityFunction == COSINE) { + VectorUtil.l2normalize(copy); + } + target = copy; + queryCorrectiveTerms = + quantizer.scalarQuantize(target, query, scalarEncoding.getBits(), values.getCentroid()); + } + + @Override + public float score(int node) throws IOException { + MemorySegment docVector = getVector(node); + float dotProduct = + switch (getValues().getScalarEncoding()) { + case UNSIGNED_BYTE -> PanamaVectorUtilSupport.uint8DotProduct(query, docVector); + case SEVEN_BIT -> PanamaVectorUtilSupport.dotProduct(query, docVector); + case PACKED_NIBBLE -> + PanamaVectorUtilSupport.int4DotProductSinglePacked(query, docVector); + case SINGLE_BIT_QUERY_NIBBLE -> + throw new IllegalStateException( + "this should be handled by the default implementation"); + }; + // Call getCorrectiveTerms() after computing dot product since corrective terms + // bytes appear after the vector bytes, so this sequence of calls is more cache + // friendly. 
+ return Lucene104ScalarQuantizedVectorScorer.quantizedScore( + getSimilarity(), getValues(), dotProduct, queryCorrectiveTerms, getCorrectiveTerms(node)); + } + } + + record RandomVectorScorerSupplierImpl( + VectorSimilarityFunction similarityFunction, + QuantizedByteVectorValues values, + MemorySegmentAccessInput input) + implements RandomVectorScorerSupplier { + + @Override + public UpdateableRandomVectorScorer scorer() throws IOException { + return new UpdateableRandomVectorScorerImpl(similarityFunction, values, input); + } + + @Override + public RandomVectorScorerSupplier copy() { + return new RandomVectorScorerSupplierImpl(similarityFunction, values, input); + } + } + + private static class UpdateableRandomVectorScorerImpl extends RandomVectorScorerBase + implements UpdateableRandomVectorScorer { + private MemorySegment query; + private OptimizedScalarQuantizer.QuantizationResult queryCorrectiveTerms; + + UpdateableRandomVectorScorerImpl( + VectorSimilarityFunction similarityFunction, + QuantizedByteVectorValues values, + MemorySegmentAccessInput input) + throws IOException { + super(similarityFunction, values, input); + } + + @Override + public void setScoringOrdinal(int ord) throws IOException { + checkOrdinal(ord); + query = getVector(ord); + queryCorrectiveTerms = getCorrectiveTerms(ord); + } + + @Override + public float score(int node) throws IOException { + MemorySegment doc = getVector(node); + float dotProduct = + switch (getValues().getScalarEncoding()) { + case UNSIGNED_BYTE -> PanamaVectorUtilSupport.uint8DotProduct(query, doc); + case SEVEN_BIT -> PanamaVectorUtilSupport.dotProduct(query, doc); + case PACKED_NIBBLE -> PanamaVectorUtilSupport.int4DotProductBothPacked(query, doc); + case SINGLE_BIT_QUERY_NIBBLE -> + throw new IllegalStateException( + "this should be handled by the default implementation"); + }; + // Call getCorrectiveTerms() after computing dot product since corrective terms + // bytes appear after the vector bytes, so this sequence 
of calls is more cache + // friendly. + return Lucene104ScalarQuantizedVectorScorer.quantizedScore( + getSimilarity(), getValues(), dotProduct, queryCorrectiveTerms, getCorrectiveTerms(node)); + } + } +} diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index d2e104f92f70..425b0e6630be 100644 --- a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -154,7 +154,8 @@ private float dotProductBody(float[] a, float[] b, int limit) { FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length()); acc4 = fma(vg, vh, acc4); } - // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes + // vector tail: less scalar computations for unaligned sizes, esp with big + // vector sizes for (; i < limit; i += FLOAT_SPECIES.length()) { FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); diff --git a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java index cf3ab94f417c..0ec71b6b7f2f 100644 --- a/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java +++ b/lucene/core/src/java25/org/apache/lucene/internal/vectorization/PanamaVectorizationProvider.java @@ -22,6 +22,7 @@ import java.util.logging.Logger; import jdk.incubator.vector.FloatVector; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene104.AsymmetricScalarQuantizeFlatVectorsScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; import 
org.apache.lucene.util.Constants; @@ -83,6 +84,11 @@ public FlatVectorsScorer getLucene99ScalarQuantizedVectorsScorer() { return Lucene99MemorySegmentScalarQuantizedVectorScorer.INSTANCE; } + @Override + public AsymmetricScalarQuantizeFlatVectorsScorer getLucene104ScalarQuantizedVectorsScorer() { + return Lucene104MemorySegmentScalarQuantizedVectorScorer.INSTANCE; + } + @Override public PostingDecodingUtil newPostingDecodingUtil(IndexInput input) throws IOException { if (input instanceof MemorySegmentAccessInput msai) { diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java index 10fc424c44ee..0e8084391ab5 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104HnswScalarQuantizedVectorsFormat.java @@ -88,16 +88,20 @@ public KnnVectorsFormat knnVectorsFormat() { "Lucene104HnswScalarQuantizedVectorsFormat(name=Lucene104HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20," + " flatVectorFormat=Lucene104ScalarQuantizedVectorsFormat(name=Lucene104ScalarQuantizedVectorsFormat," + " encoding=UNSIGNED_BYTE," - + " flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s())," + + " flatVectorScorer=%s," + " rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s())))"; var defaultScorer = - format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer", "DefaultFlatVectorScorer"); + format( + Locale.ROOT, + expectedPattern, + "Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())", + "DefaultFlatVectorScorer"); var memSegScorer = format( Locale.ROOT, expectedPattern, - "Lucene99MemorySegmentFlatVectorsScorer", + "Lucene104MemorySegmentScalarQuantizedVectorScorer()", 
"Lucene99MemorySegmentFlatVectorsScorer"); assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java index 150294df12ca..d9b6b4e19b3a 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene104/TestLucene104ScalarQuantizedVectorsFormat.java @@ -110,15 +110,19 @@ public KnnVectorsFormat knnVectorsFormat() { "Lucene104ScalarQuantizedVectorsFormat(" + "name=Lucene104ScalarQuantizedVectorsFormat, " + "encoding=UNSIGNED_BYTE, " - + "flatVectorScorer=Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=%s()), " + + "flatVectorScorer=%s, " + "rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))"; var defaultScorer = - format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer", "DefaultFlatVectorScorer"); + format( + Locale.ROOT, + expectedPattern, + "Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())", + "DefaultFlatVectorScorer"); var memSegScorer = format( Locale.ROOT, expectedPattern, - "Lucene99MemorySegmentFlatVectorsScorer", + "Lucene104MemorySegmentScalarQuantizedVectorScorer()", "Lucene99MemorySegmentFlatVectorsScorer"); assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); }