From 07d3be59af8b57190608e5bdadaa5da2dba70014 Mon Sep 17 00:00:00 2001
From: Benjamin Trent
Date: Tue, 2 Apr 2024 13:38:40 -0400
Subject: [PATCH] Expand scalar quantization by adding half-byte (int4)
 quantization (#13197)

This PR is the culmination of several streams of work:

 - Confidence interval optimizations, which unlock even smaller quantized byte sizes.
 - The ability to quantize below int8 or int7.
 - An optimized int4 (half-byte) vector comparison API for dot-product.

Further scalar quantization gives users the choice between:

 - Quantizing further to gain space, compressing the bits into single-byte values.
 - Allowing quantization to give guarantees around maximal values, which afford faster vector operations.

I didn't add more Panama vector APIs, as I think trying to micro-optimize int4 for
anything other than dot-product was a fool's errand. Additionally, I only focused on
ARM. I experimented with other architectures, but didn't get very far, so those fall
back to dotProduct.
---
 lucene/CHANGES.txt                            |   3 +
 .../TestBasicBackwardsCompatibility.java      |   2 +-
 .../TestGenerateBwcIndices.java               |  10 +
 .../TestInt8HnswBackwardsCompatibility.java   | 150 +++++++
 .../backward_index/int8_hnsw.9.10.1.zip       | Bin 0 -> 4790 bytes
 .../benchmark/jmh/VectorUtilBenchmark.java    |  21 +
 ...ene99HnswScalarQuantizedVectorsFormat.java |  14 +-
 .../Lucene99ScalarQuantizedVectorsFormat.java |  33 +-
 .../Lucene99ScalarQuantizedVectorsReader.java |  53 ++-
 .../Lucene99ScalarQuantizedVectorsWriter.java | 206 ++++++---
 .../OffHeapQuantizedByteVectorValues.java     |  75 +++-
 .../DefaultVectorUtilSupport.java             |   5 +
 .../vectorization/VectorUtilSupport.java      |   3 +
 .../org/apache/lucene/util/VectorUtil.java    |   7 +
 .../ScalarQuantizedRandomVectorScorer.java    |   2 +-
 ...arQuantizedRandomVectorScorerSupplier.java |   2 +-
 .../ScalarQuantizedVectorSimilarity.java      |  26 +-
 .../util/quantization/ScalarQuantizer.java    | 421 +++++++++++++++---
 .../PanamaVectorUtilSupport.java              |  47 ++
 ...estLucene99HnswQuantizedVectorsFormat.java |  70 ++-
 .../TestScalarQuantizedVectorSimilarity.java  |  27 +-
 .../quantization/TestScalarQuantizer.java     |  32 +-
 .../index/BaseKnnVectorsFormatTestCase.java   |  76 ++--
 23 files changed, 1074 insertions(+), 211 deletions(-)
 create mode 100644 lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestInt8HnswBackwardsCompatibility.java
 create mode 100644 lucene/backward-codecs/src/test/org/apache/lucene/backward_index/int8_hnsw.9.10.1.zip

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index ea81069b3b6c..16e289c71bff 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -218,6 +218,9 @@ New Features
   This may improve paging logic especially when large segments are merged under memory
   pressure. (Uwe Schindler, Chris Hegarty, Robert Muir, Adrien Grand)
 
+* GITHUB#13197: Expand support for new scalar bit levels for HNSW vectors. This includes 4-bit vectors and an option
+  to compress them to gain a 50% reduction in memory usage.
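(A usage sketch only, not part of the change: the new `bits` and `compress` options are
opted into through the expanded Lucene99HnswScalarQuantizedVectorsFormat constructor added
in this PR. The graph parameters below are illustrative, and the usual Lucene imports are
assumed.)

```java
// maxConn=16, beamWidth=100, one merge worker (so mergeExec may be null),
// bits=4 for half-byte quantization, compress=true to pack two values per byte,
// confidenceInterval=null so it is computed dynamically.
Lucene99Codec codec =
    new Lucene99Codec() {
      @Override
      public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
        return new Lucene99HnswScalarQuantizedVectorsFormat(16, 100, 1, 4, true, null, null);
      }
    };
```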
(Ben Trent) + Improvements --------------------- diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java index b03a7e202eef..8d35a1128be9 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java @@ -511,7 +511,7 @@ private static void doTestHits(ScoreDoc[] hits, int expectedCount, IndexReader r } } - private static ScoreDoc[] assertKNNSearch( + static ScoreDoc[] assertKNNSearch( IndexSearcher searcher, float[] queryVector, int k, diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestGenerateBwcIndices.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestGenerateBwcIndices.java index c7b1ea3fb4a9..936a4c28cf24 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestGenerateBwcIndices.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestGenerateBwcIndices.java @@ -82,6 +82,16 @@ public void testCreateSortedIndex() throws IOException { sortedTest.createBWCIndex(); } + public void testCreateInt8HNSWIndices() throws IOException { + TestInt8HnswBackwardsCompatibility int8HnswBackwardsCompatibility = + new TestInt8HnswBackwardsCompatibility( + Version.LATEST, + createPattern( + TestInt8HnswBackwardsCompatibility.INDEX_NAME, + TestInt8HnswBackwardsCompatibility.SUFFIX)); + int8HnswBackwardsCompatibility.createBWCIndex(); + } + private boolean isInitialMajorVersionRelease() { return Version.LATEST.equals(Version.fromBits(Version.LATEST.major, 0, 0)); } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestInt8HnswBackwardsCompatibility.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestInt8HnswBackwardsCompatibility.java new file mode 100644 index 000000000000..c2b229ad2be5 --- /dev/null +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestInt8HnswBackwardsCompatibility.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.backward_index; + +import static org.apache.lucene.backward_index.TestBasicBackwardsCompatibility.assertKNNSearch; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import java.io.IOException; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99Codec; +import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.Version; + +public class TestInt8HnswBackwardsCompatibility extends BackwardsCompatibilityTestBase { + + static final String INDEX_NAME = "int8_hnsw"; + static final String SUFFIX = ""; + private static final Version FIRST_INT8_HNSW_VERSION = Version.LUCENE_9_10_1; + private static final String KNN_VECTOR_FIELD = "knn_field"; + private static final int DOC_COUNT = 30; + private static final FieldType KNN_VECTOR_FIELD_TYPE = + KnnFloatVectorField.createFieldType(3, VectorSimilarityFunction.COSINE); + private static final float[] KNN_VECTOR = {0.2f, -0.1f, 0.1f}; + + public TestInt8HnswBackwardsCompatibility(Version version, String pattern) { + super(version, pattern); + } + + /** Provides all sorted versions to the test-framework */ + @ParametersFactory(argumentFormatting = "Lucene-Version:%1$s; Pattern: %2$s") + public static Iterable testVersionsFactory() throws IllegalAccessException { + return allVersion(INDEX_NAME, SUFFIX); + } + + protected Codec getCodec() { + return new Lucene99Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new Lucene99HnswScalarQuantizedVectorsFormat( + Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, + Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); + } + }; + } + + @Override + protected boolean supportsVersion(Version version) { + return version.onOrAfter(FIRST_INT8_HNSW_VERSION); + } + + @Override + void verifyUsesDefaultCodec(Directory dir, String name) throws IOException { + // We don't use the default codec + } + + public void testInt8HnswIndexAndSearch() throws Exception { + IndexWriterConfig indexWriterConfig = + newIndexWriterConfig(new MockAnalyzer(random())) + .setOpenMode(IndexWriterConfig.OpenMode.APPEND) + .setCodec(getCodec()) + .setMergePolicy(newLogMergePolicy()); + try (IndexWriter writer = new IndexWriter(directory, indexWriterConfig)) { + // add 10 docs + for (int i = 0; i < 10; i++) { + writer.addDocument(knnDocument(i + DOC_COUNT)); + if (random().nextBoolean()) { + writer.flush(); + } + } + if (random().nextBoolean()) { + writer.forceMerge(1); + } + writer.commit(); + try (IndexReader reader = DirectoryReader.open(directory)) { + IndexSearcher searcher = new IndexSearcher(reader); + assertKNNSearch(searcher, 
KNN_VECTOR, 1000, DOC_COUNT + 10, "0"); + assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0"); + } + } + // This will confirm the docs are really sorted + TestUtil.checkIndex(directory); + } + + @Override + protected void createIndex(Directory dir) throws IOException { + IndexWriterConfig conf = + new IndexWriterConfig(new MockAnalyzer(random())) + .setMaxBufferedDocs(10) + .setCodec(TestUtil.getDefaultCodec()) + .setMergePolicy(NoMergePolicy.INSTANCE); + try (IndexWriter writer = new IndexWriter(dir, conf)) { + for (int i = 0; i < DOC_COUNT; i++) { + writer.addDocument(knnDocument(i)); + } + writer.forceMerge(1); + } + try (DirectoryReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = new IndexSearcher(reader); + assertKNNSearch(searcher, KNN_VECTOR, 1000, DOC_COUNT, "0"); + assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0"); + } + } + + private static Document knnDocument(int id) { + Document doc = new Document(); + float[] vector = {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * id}; + doc.add(new KnnFloatVectorField(KNN_VECTOR_FIELD, vector, KNN_VECTOR_FIELD_TYPE)); + doc.add(new StringField("id", Integer.toString(id), Field.Store.YES)); + return doc; + } + + public void testReadOldIndices() throws Exception { + try (DirectoryReader reader = DirectoryReader.open(directory)) { + IndexSearcher searcher = new IndexSearcher(reader); + assertKNNSearch(searcher, KNN_VECTOR, 1000, DOC_COUNT, "0"); + assertKNNSearch(searcher, KNN_VECTOR, 10, 10, "0"); + } + } +} diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/int8_hnsw.9.10.1.zip b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/int8_hnsw.9.10.1.zip new file mode 100644 index 0000000000000000000000000000000000000000..168a22d63a70527daf9e2edbe743bd38b51e01f9 GIT binary patch literal 4790 zcmc(jcT`i^8pRVz2)zX%fHb8Uhz6}H)RE|1(iwm$DKbN*3s z-l~YcK2u(#p@|0l1Bo9S$*Yw^Gq<-3oa(j(|u z9Rn)0;DSVQJ$J4xH@q@9!3r{tcWC4-TOKG3&VRvXGf)52q5XO7v z`l8r{)<|n%%^OX1nV(~ycUDDv0N+YU^g(A65W|V_joG}FWu76Exs4+O4L{d8jjjk< zxBo~U+f`aK_PA?y-!zzW*tEl@FYc#7PZa&JJ~jJ~NuTq_ZzHU_@L(65wRgv~=NNS9 z&Sj@Fy-mh4>NWqmsG(Pc4)INoW%7wus3;Lg3?f+L1{M0hjn9@-;W;PI5PQMmhj^z{|-rklYh$s&A@ zgmLxgXG2N#wxE?9KD*(i>8G-k5?03weu`iHh-}$tXIo!52 z_Woxj{ABiWG$DI=xz4Ju#>>BU>DjGY(B3CijJLdU^DwbI;=QweOD^WO>M)Cb&2jyVph!AGDi*~R+K@-C(pAL4PH zo3DaqED#o;BwB5*#M|;?+!{p$Mr4?5tdA2ZGQ9FU-Bk-GYcZ2h2feT)TsXqX-{@W> zHVmN=hHzr+r{k;)K)5^TUi3cJV+B8=>4ibsA3DS8I`w*erN-hm%~~15aAdR39VB>0{4DAk?JrLCJG&X_-jJQ z{ugvS_D|@D&0l5m4f-C6F{Ggcrn@l`cR@I5eLlQMmeAciF(A@^P$KUw0GeFd%JQa1 z{JCsBLmvMj%(}g|9*9}%8gCLv!n5)` zBKjYB=-5r0+&_@pqIfl;BPB&u5-r(_uO>2Z21pXdd-nkO>DkNdji!wzMW53Sl-D?9Bf4x_qi6)=_Q?I??l(-Ox-o$*>1Qf}f9aypN7%LIH! zs~DuDw0i3l&smW-dlS?LV?x7|E(#8kxJ%3p&J~EWYG=7KT3g71=pwGi({KR5G=e}S zCZ!u!AcRMsoE$4J$_j;&m6N@d?hP9$B-I{PWq9Dn7Q>+f1WP%@cdLUX;<_Pd4?Gc! 
zM3m^@`CUA25QquVvHk;W}ja)=A(n#C#x=c2~TPxJ^^V*_%{tQi}g3d66PvG;hFfTiq<=3~N^|OnB!qMRHAehIs#U&iUs#ev~ZRQ@S-z;qcLo z()=&G%a){;7ZpaW*SY#$!oXT$NsMR=?kueOTB*uF(gJ4UD|zO{#K{;gJM-NW6)|oW znRKWp-h+-9m>{TtCjM4cpq?1PAAQ)$`h~Ciw7XGFrzK%$xHa)f=k#PS-1e)?%Z;Eq zxvTDVrknR|$8R!kZH}b5&XX!s&hZD$kHepQ7OfRSx9E~SoYYcE~z36oN$Yy}on74OpUC5w}onwlSKAAYP|;x&CG zoLP=;jotjBeM%PEY$eKBnwlxn@pxh*Hmq>9We*=?1|EIuwCw_-V^RAhNhUxH&`T$41t4b<_(d{e5a)78DkLdkO#Km^W>k!p?qKgby$1a&>0^9sj>^Gn zjt}@T%7KH4vN~KUwLh#E@WRFHRK6$<_8v#RAh)G2W!9a0AM!qPxyS$mSK$QBg-IVNb`Xx{6K`^_)H+&k`>kky@(WXWRUH6qMu zx~K3&^wa7|VMV@X&&H{_L-InRP%@8a;6v7sHvT>geBrVmSL$=N#6js#u6Amb*|s|h zeQa@xZ{{)(Nd;YPs!cgt?R}Svaz<}tGw*1WZo%5m3O(Wz2n4fFb7nm%99senUH&D9 zi$`p}Axy2RV=WtUUE=WWieMLg`TQOG3OuWpEdLEh{I2l>OJe~(%;CX(H<5OYVHfO0osGPWY;U4i+!c(DMi?s5Q#<}~{2sBcdr>)ss`;k5W4AqN>(H zpD{H!==?g!g`Zeft^+lh=IKvu;%ec@aKawlUiEE6W1}m;?s9Dv6}8Rpmrt*oGunO}LOsJ8;~!tFNl#Hx9Tu*%j-G1sb3sIkpgnT-l;v2tj$DZHho ztTv@mdR8;vi_bp4srj|Qi+~u@hOt4e)1qirMo+ylGd)%GWb0aLrCQdq$>rWVf=^xO zDXeru1!U1Do{7(qzmAlDRC@m8Iim?Q!?w#)lKussWwd$v>$G)KvoDzHHE_Roa~P+g zfsTC{K5n?%5a%=vPd|G5!TlpGYH@R6U+MSq5zHo4r{(XlQtqFz67+v(Wj*vySov7s zEsrY+*fn$C3~b)>PgvPDR4W`rbpf{?^{!_kM4ZgF11J{VippPfvBv3}6OmPy{H-rvSje0f#?WNB{r; literal 0 HcmV?d00001 diff --git a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java index 0ba817d410c8..b2c2653baf37 100644 --- a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java +++ b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java @@ -36,6 +36,8 @@ public class VectorUtilBenchmark { private byte[] bytesA; private byte[] bytesB; + private byte[] halfBytesA; + private byte[] halfBytesB; private float[] floatsA; private float[] floatsB; @@ -51,6 +53,14 @@ public void init() { bytesB = new byte[size]; random.nextBytes(bytesA); random.nextBytes(bytesB); + // random half byte arrays for binary methods + // this means that all values must be between 0 and 15 + halfBytesA = new byte[size]; + halfBytesB = new byte[size]; + for (int i = 0; i < size; ++i) { + halfBytesA[i] = (byte) random.nextInt(16); + halfBytesB[i] = (byte) random.nextInt(16); + } // random float arrays for float methods floatsA = new float[size]; @@ -94,6 +104,17 @@ public int binarySquareVector() { return VectorUtil.squareDistance(bytesA, bytesB); } + @Benchmark + public int binaryHalfByteScalar() { + return VectorUtil.int4DotProduct(halfBytesA, halfBytesB); + } + + @Benchmark + @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) + public int binaryHalfByteVector() { + return VectorUtil.int4DotProduct(halfBytesA, halfBytesB); + } + @Benchmark public float floatCosineScalar() { return VectorUtil.cosine(floatsA, floatsB); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java index 7c2b66b49f5b..5ebcd2b51792 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java @@ -65,7 +65,7 @@ public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFo /** Constructs a format using default graph construction parameters */ public Lucene99HnswScalarQuantizedVectorsFormat() { - this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, 
DEFAULT_NUM_MERGE_WORKER, null, null); + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, 7, true, null, null); } /** @@ -75,7 +75,7 @@ public Lucene99HnswScalarQuantizedVectorsFormat() { * @param beamWidth the size of the queue maintained during graph construction. */ public Lucene99HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) { - this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, null); + this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, 7, true, null, null); } /** @@ -85,6 +85,11 @@ public Lucene99HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) { * @param beamWidth the size of the queue maintained during graph construction. * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param bits the number of bits to use for scalar quantization (must be between 1 and 8, + * inclusive) + * @param compress whether to compress the vectors, if true, the vectors that are quantized with + * lte 4 bits will be compressed into a single byte. If false, the vectors will be stored as + * is. This provides a trade-off of memory usage and speed. * @param confidenceInterval the confidenceInterval for scalar quantizing the vectors, when `null` * it is calculated based on the vector field dimensions. * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are @@ -94,6 +99,8 @@ public Lucene99HnswScalarQuantizedVectorsFormat( int maxConn, int beamWidth, int numMergeWorkers, + int bits, + boolean compress, Float confidenceInterval, ExecutorService mergeExec) { super("Lucene99HnswScalarQuantizedVectorsFormat"); @@ -127,7 +134,8 @@ public Lucene99HnswScalarQuantizedVectorsFormat( } else { this.mergeExec = null; } - this.flatVectorsFormat = new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval); + this.flatVectorsFormat = + new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval, bits, compress); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java index c051d88fe701..52e89c673acf 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java @@ -30,12 +30,17 @@ * @lucene.experimental */ public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat { + + // The bits that are allowed for scalar quantization + // We only allow unsigned byte (8), signed byte (7), and half-byte (4) + private static final int ALLOWED_BITS = (1 << 8) | (1 << 7) | (1 << 4); public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC"; static final String NAME = "Lucene99ScalarQuantizedVectorsFormat"; static final int VERSION_START = 0; - static final int VERSION_CURRENT = VERSION_START; + static final int VERSION_ADD_BITS = 1; + static final int VERSION_CURRENT = VERSION_ADD_BITS; static final String META_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatMeta"; static final String VECTOR_DATA_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatData"; static final String META_EXTENSION = "vemq"; @@ -55,18 +60,27 @@ public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsForma */ final Float confidenceInterval; + final byte bits; + final boolean compress; + /** 
Constructs a format using default graph construction parameters */ public Lucene99ScalarQuantizedVectorsFormat() { - this(null); + this(null, 7, true); } /** * Constructs a format using the given graph construction parameters. * * @param confidenceInterval the confidenceInterval for scalar quantizing the vectors, when `null` - * it is calculated based on the vector field dimensions. + * it is calculated dynamically. + * @param bits the number of bits to use for scalar quantization (must be between 1 and 8, + * inclusive) + * @param compress whether to compress the vectors, if true, the vectors that are quantized with + * lte 4 bits will be compressed into a single byte. If false, the vectors will be stored as + * is. This provides a trade-off of memory usage and speed. */ - public Lucene99ScalarQuantizedVectorsFormat(Float confidenceInterval) { + public Lucene99ScalarQuantizedVectorsFormat( + Float confidenceInterval, int bits, boolean compress) { if (confidenceInterval != null && (confidenceInterval < MINIMUM_CONFIDENCE_INTERVAL || confidenceInterval > MAXIMUM_CONFIDENCE_INTERVAL)) { @@ -78,7 +92,12 @@ public Lucene99ScalarQuantizedVectorsFormat(Float confidenceInterval) { + "; confidenceInterval=" + confidenceInterval); } + if (bits < 1 || bits > 8 || (ALLOWED_BITS & (1 << bits)) == 0) { + throw new IllegalArgumentException("bits must be one of: 4, 7, 8; bits=" + bits); + } + this.bits = (byte) bits; this.confidenceInterval = confidenceInterval; + this.compress = compress; } public static float calculateDefaultConfidenceInterval(int vectorDimension) { @@ -92,6 +111,10 @@ public String toString() { + NAME + ", confidenceInterval=" + confidenceInterval + + ", bits=" + + bits + + ", compress=" + + compress + ", rawVectorFormat=" + rawVectorFormat + ")"; @@ -100,7 +123,7 @@ public String toString() { @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { return new Lucene99ScalarQuantizedVectorsWriter( - state, confidenceInterval, rawVectorFormat.fieldsWriter(state)); + state, confidenceInterval, bits, compress, rawVectorFormat.fieldsWriter(state)); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java index 96bef956a3f6..3c8b8f0c490a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java @@ -82,7 +82,7 @@ public Lucene99ScalarQuantizedVectorsReader( Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT, state.segmentInfo.getId(), state.segmentSuffix); - readFields(meta, state.fieldInfos); + readFields(meta, versionMeta, state.fieldInfos); } catch (Throwable exception) { priorE = exception; } finally { @@ -102,13 +102,14 @@ public Lucene99ScalarQuantizedVectorsReader( } } - private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + private void readFields(ChecksumIndexInput meta, int versionMeta, FieldInfos infos) + throws IOException { for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { FieldInfo info = infos.fieldInfo(fieldNumber); if (info == null) { throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); } - FieldEntry fieldEntry = readField(meta, info); + FieldEntry fieldEntry = readField(meta, versionMeta, info); validateFieldEntry(info, 
fieldEntry); fields.put(info.name, fieldEntry); } @@ -126,8 +127,13 @@ static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { + fieldEntry.dimension); } - // int8 quantized and calculated stored offset. - long quantizedVectorBytes = dimension + Float.BYTES; + final long quantizedVectorBytes; + if (fieldEntry.bits <= 4 && fieldEntry.compress) { + quantizedVectorBytes = ((dimension + 1) >> 1) + Float.BYTES; + } else { + // int8 quantized and calculated stored offset. + quantizedVectorBytes = dimension + Float.BYTES; + } long numQuantizedVectorBytes = Math.multiplyExact(quantizedVectorBytes, fieldEntry.size); if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) { throw new IllegalStateException( @@ -209,6 +215,8 @@ public RandomVectorScorer getRandomVectorScorer(String field, float[] target) th fieldEntry.ordToDoc, fieldEntry.dimension, fieldEntry.size, + fieldEntry.bits, + fieldEntry.compress, fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength, quantizedVectorData); @@ -236,7 +244,8 @@ public long ramBytesUsed() { return size; } - private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { + private FieldEntry readField(IndexInput input, int versionMeta, FieldInfo info) + throws IOException { VectorEncoding vectorEncoding = readVectorEncoding(input); VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); if (similarityFunction != info.getVectorSimilarityFunction()) { @@ -248,7 +257,7 @@ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOExceptio + " != " + info.getVectorSimilarityFunction()); } - return new FieldEntry(input, vectorEncoding, info.getVectorSimilarityFunction()); + return new FieldEntry(input, versionMeta, vectorEncoding, info.getVectorSimilarityFunction()); } @Override @@ -261,6 +270,8 @@ public QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) thro fieldEntry.ordToDoc, fieldEntry.dimension, fieldEntry.size, + fieldEntry.bits, + fieldEntry.compress, fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength, quantizedVectorData); @@ -285,10 +296,13 @@ private static class FieldEntry implements Accountable { final long vectorDataLength; final ScalarQuantizer scalarQuantizer; final int size; + final byte bits; + final boolean compress; final OrdToDocDISIReaderConfiguration ordToDoc; FieldEntry( IndexInput input, + int versionMeta, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) throws IOException { @@ -299,12 +313,29 @@ private static class FieldEntry implements Accountable { dimension = input.readVInt(); size = input.readInt(); if (size > 0) { - float confidenceInterval = Float.intBitsToFloat(input.readInt()); - float minQuantile = Float.intBitsToFloat(input.readInt()); - float maxQuantile = Float.intBitsToFloat(input.readInt()); - scalarQuantizer = new ScalarQuantizer(minQuantile, maxQuantile, confidenceInterval); + if (versionMeta < Lucene99ScalarQuantizedVectorsFormat.VERSION_ADD_BITS) { + int floatBits = input.readInt(); // confidenceInterval, unused + if (floatBits == -1) { + throw new CorruptIndexException( + "Missing confidence interval for scalar quantizer", input); + } + this.bits = (byte) 7; + this.compress = false; + float minQuantile = Float.intBitsToFloat(input.readInt()); + float maxQuantile = Float.intBitsToFloat(input.readInt()); + scalarQuantizer = new ScalarQuantizer(minQuantile, maxQuantile, (byte) 7); + } else { + input.readInt(); // confidenceInterval, unused + this.bits = input.readByte(); + this.compress = 
input.readByte() == 1; + float minQuantile = Float.intBitsToFloat(input.readInt()); + float maxQuantile = Float.intBitsToFloat(input.readInt()); + scalarQuantizer = new ScalarQuantizer(minQuantile, maxQuantile, bits); + } } else { scalarQuantizer = null; + this.bits = (byte) 7; + this.compress = false; } ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index b36291bdbb10..acab61adbf29 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -96,12 +96,20 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite private final IndexOutput meta, quantizedVectorData; private final Float confidenceInterval; private final FlatVectorsWriter rawVectorDelegate; + private final byte bits; + private final boolean compress; private boolean finished; public Lucene99ScalarQuantizedVectorsWriter( - SegmentWriteState state, Float confidenceInterval, FlatVectorsWriter rawVectorDelegate) + SegmentWriteState state, + Float confidenceInterval, + byte bits, + boolean compress, + FlatVectorsWriter rawVectorDelegate) throws IOException { this.confidenceInterval = confidenceInterval; + this.bits = bits; + this.compress = compress; segmentWriteState = state; String metaFileName = IndexFileNames.segmentFileName( @@ -145,12 +153,21 @@ public Lucene99ScalarQuantizedVectorsWriter( public FlatFieldVectorsWriter addField( FieldInfo fieldInfo, KnnFieldVectorsWriter indexWriter) throws IOException { if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { - float confidenceInterval = - this.confidenceInterval == null - ? calculateDefaultConfidenceInterval(fieldInfo.getVectorDimension()) - : this.confidenceInterval; + if (bits <= 4 && fieldInfo.getVectorDimension() % 2 != 0) { + throw new IllegalArgumentException( + "bits=" + + bits + + " is not supported for odd vector dimensions; vector dimension=" + + fieldInfo.getVectorDimension()); + } FieldWriter quantizedWriter = - new FieldWriter(confidenceInterval, fieldInfo, segmentWriteState.infoStream, indexWriter); + new FieldWriter( + confidenceInterval, + bits, + compress, + fieldInfo, + segmentWriteState.infoStream, + indexWriter); fields.add(quantizedWriter); indexWriter = quantizedWriter; } @@ -164,24 +181,23 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE // the vectors directly to the new segment. 
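(Review note, not part of the patch: with VERSION_ADD_BITS the per-field metadata gains the
`bits` and `compress` flags, and the confidence interval becomes optional. A sketch of the
entry layout, mirroring readField above and writeMeta below; `in` is a hypothetical
IndexInput positioned at a field entry whose size > 0.)

```java
int ciBits = in.readInt();              // confidenceInterval float bits, or -1 when dynamic
byte bits = in.readByte();              // 4, 7, or 8
boolean compress = in.readByte() == 1;  // true: two int4 values packed per stored byte
float lowerQuantile = Float.intBitsToFloat(in.readInt());
float upperQuantile = Float.intBitsToFloat(in.readInt());
```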
// No need to use temporary file as we don't have to re-open for reading if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { - ScalarQuantizer mergedQuantizationState = mergeQuantiles(fieldInfo, mergeState); + ScalarQuantizer mergedQuantizationState = + mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval, bits); MergedQuantizedVectorValues byteVectorValues = MergedQuantizedVectorValues.mergeQuantizedByteVectorValues( fieldInfo, mergeState, mergedQuantizationState); long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); DocsWithFieldSet docsWithField = - writeQuantizedVectorData(quantizedVectorData, byteVectorValues); + writeQuantizedVectorData(quantizedVectorData, byteVectorValues, bits, compress); long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset; - float confidenceInterval = - this.confidenceInterval == null - ? calculateDefaultConfidenceInterval(fieldInfo.getVectorDimension()) - : this.confidenceInterval; writeMeta( fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, confidenceInterval, + bits, + compress, mergedQuantizationState.getLowerQuantile(), mergedQuantizationState.getUpperQuantile(), docsWithField); @@ -195,7 +211,8 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( // Simply merge the underlying delegate, which just copies the raw vector data to a new // segment file rawVectorDelegate.mergeOneField(fieldInfo, mergeState); - ScalarQuantizer mergedQuantizationState = mergeQuantiles(fieldInfo, mergeState); + ScalarQuantizer mergedQuantizationState = + mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval, bits); return mergeOneFieldToIndex( segmentWriteState, fieldInfo, mergeState, mergedQuantizationState); } @@ -255,6 +272,8 @@ private void writeField(FieldWriter fieldData, int maxDoc) throws IOException { vectorDataOffset, vectorDataLength, confidenceInterval, + bits, + compress, fieldData.minQuantile, fieldData.maxQuantile, fieldData.docsWithField); @@ -266,6 +285,8 @@ private void writeMeta( long vectorDataOffset, long vectorDataLength, Float confidenceInterval, + byte bits, + boolean compress, Float lowerQuantile, Float upperQuantile, DocsWithFieldSet docsWithField) @@ -280,11 +301,9 @@ private void writeMeta( meta.writeInt(count); if (count > 0) { assert Float.isFinite(lowerQuantile) && Float.isFinite(upperQuantile); - meta.writeInt( - Float.floatToIntBits( - confidenceInterval != null - ? confidenceInterval - : calculateDefaultConfidenceInterval(field.getVectorDimension()))); + meta.writeInt(confidenceInterval == null ? -1 : Float.floatToIntBits(confidenceInterval)); + meta.writeByte(bits); + meta.writeByte(compress ? (byte) 1 : (byte) 0); meta.writeInt(Float.floatToIntBits(lowerQuantile)); meta.writeInt(Float.floatToIntBits(upperQuantile)); } @@ -296,6 +315,11 @@ private void writeMeta( private void writeQuantizedVectors(FieldWriter fieldData) throws IOException { ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()]; + byte[] compressedVector = + fieldData.compress + ? OffHeapQuantizedByteVectorValues.compressedArray( + fieldData.fieldInfo.getVectorDimension(), bits) + : null; final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); float[] copy = fieldData.normalize ? 
new float[fieldData.fieldInfo.getVectorDimension()] : null; for (float[] v : fieldData.floatVectors) { @@ -307,7 +331,12 @@ private void writeQuantizedVectors(FieldWriter fieldData) throws IOException { float offsetCorrection = scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction()); - quantizedVectorData.writeBytes(vector, vector.length); + if (compressedVector != null) { + OffHeapQuantizedByteVectorValues.compressBytes(vector, compressedVector); + quantizedVectorData.writeBytes(compressedVector, compressedVector.length); + } else { + quantizedVectorData.writeBytes(vector, vector.length); + } offsetBuffer.putFloat(offsetCorrection); quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length); offsetBuffer.rewind(); @@ -348,6 +377,8 @@ private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap vectorDataOffset, quantizedVectorLength, confidenceInterval, + bits, + compress, fieldData.minQuantile, fieldData.maxQuantile, newDocsWithField); @@ -356,6 +387,11 @@ private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap private void writeSortedQuantizedVectors(FieldWriter fieldData, int[] ordMap) throws IOException { ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()]; + byte[] compressedVector = + fieldData.compress + ? OffHeapQuantizedByteVectorValues.compressedArray( + fieldData.fieldInfo.getVectorDimension(), bits) + : null; final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null; for (int ordinal : ordMap) { @@ -367,29 +403,35 @@ private void writeSortedQuantizedVectors(FieldWriter fieldData, int[] ordMap) th } float offsetCorrection = scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction()); - quantizedVectorData.writeBytes(vector, vector.length); + if (compressedVector != null) { + OffHeapQuantizedByteVectorValues.compressBytes(vector, compressedVector); + quantizedVectorData.writeBytes(compressedVector, compressedVector.length); + } else { + quantizedVectorData.writeBytes(vector, vector.length); + } offsetBuffer.putFloat(offsetCorrection); quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length); offsetBuffer.rewind(); } } - private ScalarQuantizer mergeQuantiles(FieldInfo fieldInfo, MergeState mergeState) - throws IOException { - assert fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32; - float confidenceInterval = - this.confidenceInterval == null - ? 
calculateDefaultConfidenceInterval(fieldInfo.getVectorDimension()) - : this.confidenceInterval; - return mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval); - } - private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex( SegmentWriteState segmentWriteState, FieldInfo fieldInfo, MergeState mergeState, ScalarQuantizer mergedQuantizationState) throws IOException { + if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message( + QUANTIZED_VECTOR_COMPONENT, + "quantized field=" + + " confidenceInterval=" + + confidenceInterval + + " minQuantile=" + + mergedQuantizationState.getLowerQuantile() + + " maxQuantile=" + + mergedQuantizationState.getUpperQuantile()); + } long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); IndexOutput tempQuantizedVectorData = segmentWriteState.directory.createTempOutput( @@ -401,7 +443,7 @@ private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex( MergedQuantizedVectorValues.mergeQuantizedByteVectorValues( fieldInfo, mergeState, mergedQuantizationState); DocsWithFieldSet docsWithField = - writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues); + writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues, bits, compress); CodecUtil.writeFooter(tempQuantizedVectorData); IOUtils.close(tempQuantizedVectorData); quantizationDataInput = @@ -421,6 +463,8 @@ private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex( vectorDataOffset, vectorDataLength, confidenceInterval, + bits, + compress, mergedQuantizationState.getLowerQuantile(), mergedQuantizationState.getUpperQuantile(), docsWithField); @@ -438,6 +482,8 @@ private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex( new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues( fieldInfo.getVectorDimension(), docsWithField.cardinality(), + bits, + compress, quantizationDataInput))); } finally { if (success == false) { @@ -449,9 +495,7 @@ private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex( } static ScalarQuantizer mergeQuantiles( - List quantizationStates, - List segmentSizes, - float confidenceInterval) { + List quantizationStates, List segmentSizes, byte bits) { assert quantizationStates.size() == segmentSizes.size(); if (quantizationStates.isEmpty()) { return null; @@ -466,10 +510,13 @@ static ScalarQuantizer mergeQuantiles( lowerQuantile += quantizationStates.get(i).getLowerQuantile() * segmentSizes.get(i); upperQuantile += quantizationStates.get(i).getUpperQuantile() * segmentSizes.get(i); totalCount += segmentSizes.get(i); + if (quantizationStates.get(i).getBits() != bits) { + return null; + } } lowerQuantile /= totalCount; upperQuantile /= totalCount; - return new ScalarQuantizer(lowerQuantile, upperQuantile, confidenceInterval); + return new ScalarQuantizer(lowerQuantile, upperQuantile, bits); } /** @@ -531,11 +578,14 @@ private static ScalarQuantizer getQuantizedState( * @param mergeState The merge state * @param fieldInfo The field info * @param confidenceInterval The confidence interval + * @param bits The number of bits * @return The merged quantiles * @throws IOException If there is a low-level I/O error */ public static ScalarQuantizer mergeAndRecalculateQuantiles( - MergeState mergeState, FieldInfo fieldInfo, float confidenceInterval) throws IOException { + MergeState mergeState, FieldInfo fieldInfo, Float confidenceInterval, byte bits) + throws IOException { + assert 
fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32); List quantizationStates = new ArrayList<>(mergeState.liveDocs.length); List segmentSizes = new ArrayList<>(mergeState.liveDocs.length); for (int i = 0; i < mergeState.liveDocs.length; i++) { @@ -550,14 +600,17 @@ public static ScalarQuantizer mergeAndRecalculateQuantiles( segmentSizes.add(fvv.size()); } } - ScalarQuantizer mergedQuantiles = - mergeQuantiles(quantizationStates, segmentSizes, confidenceInterval); + ScalarQuantizer mergedQuantiles = mergeQuantiles(quantizationStates, segmentSizes, bits); // Segments no providing quantization state indicates that their quantiles were never // calculated. // To be safe, we should always recalculate given a sample set over all the float vectors in the // merged // segment view - if (mergedQuantiles == null || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) { + if (mergedQuantiles == null + // For smaller `bits` values, we should always recalculate the quantiles + // TODO: this is very conservative, could we reuse information for even int4 quantization? + || bits <= 4 + || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) { int numVectors = 0; FloatVectorValues vectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); @@ -568,10 +621,17 @@ public static ScalarQuantizer mergeAndRecalculateQuantiles( numVectors++; } mergedQuantiles = - ScalarQuantizer.fromVectors( - KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState), - confidenceInterval, - numVectors); + confidenceInterval == null + ? ScalarQuantizer.fromVectorsAutoInterval( + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState), + fieldInfo.getVectorSimilarityFunction(), + numVectors, + bits) + : ScalarQuantizer.fromVectors( + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState), + confidenceInterval, + numVectors, + bits); } return mergedQuantiles; } @@ -600,8 +660,17 @@ static boolean shouldRequantize(ScalarQuantizer existingQuantiles, ScalarQuantiz * Writes the vector values to the output and returns a set of documents that contains vectors. */ public static DocsWithFieldSet writeQuantizedVectorData( - IndexOutput output, QuantizedByteVectorValues quantizedByteVectorValues) throws IOException { + IndexOutput output, + QuantizedByteVectorValues quantizedByteVectorValues, + byte bits, + boolean compress) + throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + final byte[] compressedVector = + compress + ? 
OffHeapQuantizedByteVectorValues.compressedArray( + quantizedByteVectorValues.dimension(), bits) + : null; for (int docV = quantizedByteVectorValues.nextDoc(); docV != NO_MORE_DOCS; docV = quantizedByteVectorValues.nextDoc()) { @@ -609,7 +678,12 @@ public static DocsWithFieldSet writeQuantizedVectorData( byte[] binaryValue = quantizedByteVectorValues.vectorValue(); assert binaryValue.length == quantizedByteVectorValues.dimension() : "dim=" + quantizedByteVectorValues.dimension() + " len=" + binaryValue.length; - output.writeBytes(binaryValue, binaryValue.length); + if (compressedVector != null) { + OffHeapQuantizedByteVectorValues.compressBytes(binaryValue, compressedVector); + output.writeBytes(compressedVector, compressedVector.length); + } else { + output.writeBytes(binaryValue, binaryValue.length); + } output.writeInt(Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant())); docsWithField.add(docV); } @@ -625,7 +699,9 @@ static class FieldWriter extends FlatFieldVectorsWriter { private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class); private final List floatVectors; private final FieldInfo fieldInfo; - private final float confidenceInterval; + private final Float confidenceInterval; + private final byte bits; + private final boolean compress; private final InfoStream infoStream; private final boolean normalize; private float minQuantile = Float.POSITIVE_INFINITY; @@ -635,17 +711,21 @@ static class FieldWriter extends FlatFieldVectorsWriter { @SuppressWarnings("unchecked") FieldWriter( - float confidenceInterval, + Float confidenceInterval, + byte bits, + boolean compress, FieldInfo fieldInfo, InfoStream infoStream, KnnFieldVectorsWriter indexWriter) { super((KnnFieldVectorsWriter) indexWriter); this.confidenceInterval = confidenceInterval; + this.bits = bits; this.fieldInfo = fieldInfo; this.normalize = fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE; this.floatVectors = new ArrayList<>(); this.infoStream = infoStream; this.docsWithField = new DocsWithFieldSet(); + this.compress = compress; } void finish() throws IOException { @@ -656,13 +736,16 @@ void finish() throws IOException { finished = true; return; } + FloatVectorValues floatVectorValues = new FloatVectorWrapper(floatVectors, normalize); ScalarQuantizer quantizer = - ScalarQuantizer.fromVectors( - new FloatVectorWrapper( - floatVectors, - fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE), - confidenceInterval, - floatVectors.size()); + confidenceInterval == null + ? 
ScalarQuantizer.fromVectorsAutoInterval( + floatVectorValues, + fieldInfo.getVectorSimilarityFunction(), + floatVectors.size(), + bits) + : ScalarQuantizer.fromVectors( + floatVectorValues, confidenceInterval, floatVectors.size(), bits); minQuantile = quantizer.getLowerQuantile(); maxQuantile = quantizer.getUpperQuantile(); if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { @@ -671,6 +754,8 @@ void finish() throws IOException { "quantized field=" + " confidenceInterval=" + confidenceInterval + + " bits=" + + bits + " minQuantile=" + minQuantile + " maxQuantile=" @@ -681,7 +766,7 @@ void finish() throws IOException { ScalarQuantizer createQuantizer() { assert finished; - return new ScalarQuantizer(minQuantile, maxQuantile, confidenceInterval); + return new ScalarQuantizer(minQuantile, maxQuantile, bits); } @Override @@ -765,7 +850,7 @@ public int advance(int target) throws IOException { } } - private static class QuantizedByteVectorValueSub extends DocIDMerger.Sub { + static class QuantizedByteVectorValueSub extends DocIDMerger.Sub { private final QuantizedByteVectorValues values; QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) { @@ -799,6 +884,10 @@ public static MergedQuantizedVectorValues mergeQuantizedByteVectorValues( // Or we have never been quantized. if (reader == null || reader.getQuantizationState(fieldInfo.name) == null + // For smaller `bits` values, we should always recalculate the quantiles + // TODO: this is very conservative, could we reuse information for even int4 + // quantization? + || scalarQuantizer.getBits() <= 4 || shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) { sub = new QuantizedByteVectorValueSub( @@ -884,7 +973,7 @@ public float getScoreCorrectionConstant() throws IOException { } } - private static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { + static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { private final FloatVectorValues values; private final ScalarQuantizer quantizer; private final byte[] quantizedVector; @@ -999,14 +1088,13 @@ public int totalVectorCount() { } } - private static final class OffsetCorrectedQuantizedByteVectorValues - extends QuantizedByteVectorValues { + static final class OffsetCorrectedQuantizedByteVectorValues extends QuantizedByteVectorValues { private final QuantizedByteVectorValues in; private final VectorSimilarityFunction vectorSimilarityFunction; private final ScalarQuantizer scalarQuantizer, oldScalarQuantizer; - private OffsetCorrectedQuantizedByteVectorValues( + OffsetCorrectedQuantizedByteVectorValues( QuantizedByteVectorValues in, VectorSimilarityFunction vectorSimilarityFunction, ScalarQuantizer scalarQuantizer, diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 5d7e493e98f9..872dc4fb23ff 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -36,6 +36,10 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect protected final int dimension; protected final int size; + protected final int numBytes; + protected final byte bits; + protected final boolean compress; + protected final IndexInput slice; protected final byte[] binaryValue; protected final ByteBuffer byteBuffer; 
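(Not part of the patch — a round-trip sketch of the packing helpers added in the next hunk,
assuming bits <= 4 and an even dimension. Note that compressBytes pairs element i with
element i + dim/2 in one byte, rather than adjacent elements, which is what lets
decompressBytes expand in place. These helpers are package-private, so this is illustrative
only.)

```java
byte[] vector = {1, 2, 3, 13, 14, 15};        // quantized int4 values, each in [0, 15]
byte[] packed = new byte[vector.length / 2];  // same size as compressedArray(6, (byte) 4)
OffHeapQuantizedByteVectorValues.compressBytes(vector, packed);
// packed[0] == (byte) ((1 << 4) | 13); packed[1] == (byte) ((2 << 4) | 14); ...

byte[] buffer = new byte[vector.length];      // full-size scratch, packed bytes at the front
System.arraycopy(packed, 0, buffer, 0, packed.length);
OffHeapQuantizedByteVectorValues.decompressBytes(buffer, packed.length);
// buffer now holds {1, 2, 3, 13, 14, 15} again
```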
@@ -43,11 +47,52 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect protected int lastOrd = -1; protected final float[] scoreCorrectionConstant = new float[1]; - OffHeapQuantizedByteVectorValues(int dimension, int size, IndexInput slice) { + static void decompressBytes(byte[] compressed, int numBytes) { + if (numBytes == compressed.length) { + return; + } + if (numBytes << 1 != compressed.length) { + throw new IllegalArgumentException( + "numBytes: " + numBytes + " does not match compressed length: " + compressed.length); + } + for (int i = 0; i < numBytes; ++i) { + compressed[numBytes + i] = (byte) (compressed[i] & 0x0F); + compressed[i] = (byte) ((compressed[i] & 0xFF) >> 4); + } + } + + static byte[] compressedArray(int dimension, byte bits) { + if (bits <= 4) { + return new byte[(dimension + 1) >> 1]; + } else { + return null; + } + } + + static void compressBytes(byte[] raw, byte[] compressed) { + if (compressed.length != ((raw.length + 1) >> 1)) { + throw new IllegalArgumentException( + "compressed length: " + compressed.length + " does not match raw length: " + raw.length); + } + for (int i = 0; i < compressed.length; ++i) { + int v = (raw[i] << 4) | raw[compressed.length + i]; + compressed[i] = (byte) v; + } + } + + OffHeapQuantizedByteVectorValues( + int dimension, int size, byte bits, boolean compress, IndexInput slice) { this.dimension = dimension; this.size = size; this.slice = slice; - this.byteSize = dimension + Float.BYTES; + this.bits = bits; + this.compress = compress; + if (bits <= 4 && compress) { + this.numBytes = (dimension + 1) >> 1; + } else { + this.numBytes = dimension; + } + this.byteSize = this.numBytes + Float.BYTES; byteBuffer = ByteBuffer.allocate(dimension); binaryValue = byteBuffer.array(); } @@ -68,8 +113,9 @@ public byte[] vectorValue(int targetOrd) throws IOException { return binaryValue; } slice.seek((long) targetOrd * byteSize); - slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), dimension); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), numBytes); slice.readFloats(scoreCorrectionConstant, 0, 1); + decompressBytes(binaryValue, numBytes); lastOrd = targetOrd; return binaryValue; } @@ -83,6 +129,8 @@ public static OffHeapQuantizedByteVectorValues load( OrdToDocDISIReaderConfiguration configuration, int dimension, int size, + byte bits, + boolean compress, long quantizedVectorDataOffset, long quantizedVectorDataLength, IndexInput vectorData) @@ -94,9 +142,10 @@ public static OffHeapQuantizedByteVectorValues load( vectorData.slice( "quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength); if (configuration.isDense()) { - return new DenseOffHeapVectorValues(dimension, size, bytesSlice); + return new DenseOffHeapVectorValues(dimension, size, bits, compress, bytesSlice); } else { - return new SparseOffHeapVectorValues(configuration, dimension, size, vectorData, bytesSlice); + return new SparseOffHeapVectorValues( + configuration, dimension, size, bits, compress, vectorData, bytesSlice); } } @@ -108,8 +157,9 @@ public static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorV private int doc = -1; - public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice) { - super(dimension, size, slice); + public DenseOffHeapVectorValues( + int dimension, int size, byte bits, boolean compress, IndexInput slice) { + super(dimension, size, bits, compress, slice); } @Override @@ -138,7 +188,7 @@ public int advance(int target) throws IOException { @Override public 
DenseOffHeapVectorValues copy() throws IOException { - return new DenseOffHeapVectorValues(dimension, size, slice.clone()); + return new DenseOffHeapVectorValues(dimension, size, bits, compress, slice.clone()); } @Override @@ -158,10 +208,12 @@ public SparseOffHeapVectorValues( OrdToDocDISIReaderConfiguration configuration, int dimension, int size, + byte bits, + boolean compress, IndexInput dataIn, IndexInput slice) throws IOException { - super(dimension, size, slice); + super(dimension, size, bits, compress, slice); this.configuration = configuration; this.dataIn = dataIn; this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); @@ -191,7 +243,8 @@ public int advance(int target) throws IOException { @Override public SparseOffHeapVectorValues copy() throws IOException { - return new SparseOffHeapVectorValues(configuration, dimension, size, dataIn, slice.clone()); + return new SparseOffHeapVectorValues( + configuration, dimension, size, bits, compress, dataIn, slice.clone()); } @Override @@ -221,7 +274,7 @@ public int length() { private static class EmptyOffHeapVectorValues extends OffHeapQuantizedByteVectorValues { public EmptyOffHeapVectorValues(int dimension) { - super(dimension, 0, null); + super(dimension, 0, (byte) 7, false, null); } private int doc = -1; diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java index 750e0ee136ae..e56d6b97f314 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java @@ -151,6 +151,11 @@ public int dotProduct(byte[] a, byte[] b) { return total; } + @Override + public int int4DotProduct(byte[] a, byte[] b) { + return dotProduct(a, b); + } + @Override public float cosine(byte[] a, byte[] b) { // Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14. diff --git a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java index 44943473fa4a..246cbdf95bc7 100644 --- a/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java +++ b/lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java @@ -36,6 +36,9 @@ public interface VectorUtilSupport { /** Returns the dot product computed over signed bytes. */ int dotProduct(byte[] a, byte[] b); + /** Returns the dot product over the computed bytes, assuming the values are int4 encoded. */ + int int4DotProduct(byte[] a, byte[] b); + /** Returns the cosine similarity between the two byte vectors. */ float cosine(byte[] a, byte[] b); diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 4a792c182441..97f8c0473835 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -175,6 +175,13 @@ public static int dotProduct(byte[] a, byte[] b) { return IMPL.dotProduct(a, b); } + public static int int4DotProduct(byte[] a, byte[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); + } + return IMPL.int4DotProduct(a, b); + } + /** * Dot product score computed over signed bytes, scaled to be in [0, 1]. 
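(Not part of the patch — a small sketch of the new entry point added above. The scalar
fallback in DefaultVectorUtilSupport simply delegates to dotProduct, while the Panama
implementation can exploit the guarantee that every value fits in [0, 15].)

```java
byte[] q = {1, 2, 3, 4};  // int4-range values, one per byte (uncompressed form)
byte[] d = {4, 3, 2, 1};
int dp = VectorUtil.int4DotProduct(q, d);  // 1*4 + 2*3 + 3*2 + 4*1 = 20
```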
* diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java index 2f825ccd972a..0c07e1e971c6 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java @@ -76,7 +76,7 @@ public ScalarQuantizedRandomVectorScorer( this.queryOffset = correction; this.similarity = ScalarQuantizedVectorSimilarity.fromVectorSimilarity( - similarityFunction, scalarQuantizer.getConstantMultiplier()); + similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits()); this.values = values; } diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java index 21b0667e47f0..b3b1d4adee37 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java @@ -37,7 +37,7 @@ public ScalarQuantizedRandomVectorScorerSupplier( RandomAccessQuantizedByteVectorValues values) { this.similarity = ScalarQuantizedVectorSimilarity.fromVectorSimilarity( - similarityFunction, scalarQuantizer.getConstantMultiplier()); + similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits()); this.values = values; } diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedVectorSimilarity.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedVectorSimilarity.java index 6e93dba381b5..77ad4103e49b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedVectorSimilarity.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedVectorSimilarity.java @@ -33,14 +33,17 @@ public interface ScalarQuantizedVectorSimilarity { * * @param sim similarity function * @param constMultiplier constant multiplier used for quantization + * @param bits number of bits used for quantization * @return a {@link ScalarQuantizedVectorSimilarity} that applies the appropriate corrections */ static ScalarQuantizedVectorSimilarity fromVectorSimilarity( - VectorSimilarityFunction sim, float constMultiplier) { + VectorSimilarityFunction sim, float constMultiplier, byte bits) { return switch (sim) { case EUCLIDEAN -> new Euclidean(constMultiplier); - case COSINE, DOT_PRODUCT -> new DotProduct(constMultiplier); - case MAXIMUM_INNER_PRODUCT -> new MaximumInnerProduct(constMultiplier); + case COSINE, DOT_PRODUCT -> new DotProduct( + constMultiplier, bits <= 4 ? VectorUtil::int4DotProduct : VectorUtil::dotProduct); + case MAXIMUM_INNER_PRODUCT -> new MaximumInnerProduct( + constMultiplier, bits <= 4 ? 
VectorUtil::int4DotProduct : VectorUtil::dotProduct); }; } @@ -66,15 +69,17 @@ public float score( /** Calculates dot product on quantized vectors, applying the appropriate corrections */ class DotProduct implements ScalarQuantizedVectorSimilarity { private final float constMultiplier; + private final ByteVectorComparator comparator; - public DotProduct(float constMultiplier) { + public DotProduct(float constMultiplier, ByteVectorComparator comparator) { this.constMultiplier = constMultiplier; + this.comparator = comparator; } @Override public float score( byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) { - int dotProduct = VectorUtil.dotProduct(storedVector, queryVector); + int dotProduct = comparator.compare(storedVector, queryVector); float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset; return (1 + adjustedDistance) / 2; } @@ -83,17 +88,24 @@ public float score( /** Calculates max inner product on quantized vectors, applying the appropriate corrections */ class MaximumInnerProduct implements ScalarQuantizedVectorSimilarity { private final float constMultiplier; + private final ByteVectorComparator comparator; - public MaximumInnerProduct(float constMultiplier) { + public MaximumInnerProduct(float constMultiplier, ByteVectorComparator comparator) { this.constMultiplier = constMultiplier; + this.comparator = comparator; } @Override public float score( byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) { - int dotProduct = VectorUtil.dotProduct(storedVector, queryVector); + int dotProduct = comparator.compare(storedVector, queryVector); float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset; return scaleMaxInnerProductScore(adjustedDistance); } } + + /** Compares two byte vectors */ + interface ByteVectorComparator { + int compare(byte[] v1, byte[] v2); + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java index 89b2f1ed3ae7..95dd32a1dc7f 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java @@ -19,11 +19,15 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Random; import java.util.stream.IntStream; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.HitQueue; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.util.IntroSelector; import org.apache.lucene.util.Selector; @@ -74,20 +78,23 @@ public class ScalarQuantizer { private final float alpha; private final float scale; - private final float minQuantile, maxQuantile, confidenceInterval; + private final byte bits; + private final float minQuantile, maxQuantile; /** * @param minQuantile the lower quantile of the distribution * @param maxQuantile the upper quantile of the distribution - * @param confidenceInterval The configured confidence interval used to calculate the quantiles. 
+ * @param bits the number of bits to use for quantization */ - public ScalarQuantizer(float minQuantile, float maxQuantile, float confidenceInterval) { + public ScalarQuantizer(float minQuantile, float maxQuantile, byte bits) { assert maxQuantile >= minQuantile; + assert bits > 0 && bits <= 8; this.minQuantile = minQuantile; this.maxQuantile = maxQuantile; - this.scale = 127f / (maxQuantile - minQuantile); - this.alpha = (maxQuantile - minQuantile) / 127f; - this.confidenceInterval = confidenceInterval; + this.bits = bits; + final float divisor = (float) ((1 << bits) - 1); + this.scale = divisor / (maxQuantile - minQuantile); + this.alpha = (maxQuantile - minQuantile) / divisor; } /** @@ -100,31 +107,38 @@ public ScalarQuantizer(float minQuantile, float maxQuantile, float confidenceInt */ public float quantize(float[] src, byte[] dest, VectorSimilarityFunction similarityFunction) { assert src.length == dest.length; - float correctiveOffset = 0f; + float correction = 0; for (int i = 0; i < src.length; i++) { - float v = src[i]; - // Make sure the value is within the quantile range, cutting off the tails - // see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile - - // minQuantile) - float dx = Math.max(minQuantile, Math.min(maxQuantile, src[i])) - minQuantile; - // Scale the value to the range [0, 127], this is our quantized value - // scale = 127/(maxQuantile - minQuantile) - float dxs = scale * dx; - // We multiply by `alpha` here to get the quantized value back into the original range - // to aid in calculating the corrective offset - float dxq = Math.round(dxs) * alpha; - // Calculate the corrective offset that needs to be applied to the score - // in addition to the `byte * minQuantile * alpha` term in the equation - // we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value - // will be rounded to the nearest whole number and lose some accuracy - // Additionally, we account for the global correction of `minQuantile^2` in the equation - correctiveOffset += minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq; - dest[i] = (byte) Math.round(dxs); + correction += quantizeFloat(src[i], dest, i); } if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) { return 0; } - return correctiveOffset; + return correction; + } + + private float quantizeFloat(float v, byte[] dest, int destIndex) { + assert dest == null || destIndex < dest.length; + // Make sure the value is within the quantile range, cutting off the tails + // see first parenthesis in equation: byte = (float - minQuantile) * ((1 << bits) - 1)/(maxQuantile - + // minQuantile) + float dx = v - minQuantile; + float dxc = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile; + // Scale the value to the range [0, (1 << bits) - 1], this is our quantized value + // scale = ((1 << bits) - 1)/(maxQuantile - minQuantile) + float dxs = scale * dxc; + // We multiply by `alpha` here to get the quantized value back into the original range + // to aid in calculating the corrective offset + float dxq = Math.round(dxs) * alpha; + if (dest != null) { + dest[destIndex] = (byte) Math.round(dxs); + } + // Calculate the corrective offset that needs to be applied to the score + // in addition to the `byte * minQuantile * alpha` term in the equation + // we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value + // will be rounded to the nearest whole number and lose some accuracy + // Additionally, we account for the global correction of `minQuantile^2` in the equation + return
minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq; } /** @@ -146,10 +160,7 @@ public float recalculateCorrectiveOffset( for (int i = 0; i < quantizedVector.length; i++) { // dequantize the old value in order to recalculate the corrective offset float v = (oldQuantizer.alpha * quantizedVector[i]) + oldQuantizer.minQuantile; - float dx = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile; - float dxs = scale * dx; - float dxq = Math.round(dxs) * alpha; - correctiveOffset += minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq; + correctiveOffset += quantizeFloat(v, null, 0); } return correctiveOffset; } @@ -160,7 +171,7 @@ public float recalculateCorrectiveOffset( * @param src the source vector * @param dest the destination vector */ - public void deQuantize(byte[] src, float[] dest) { + void deQuantize(byte[] src, float[] dest) { assert src.length == dest.length; for (int i = 0; i < src.length; i++) { dest[i] = (alpha * src[i]) + minQuantile; @@ -175,14 +186,14 @@ public float getUpperQuantile() { return maxQuantile; } - public float getConfidenceInterval() { - return confidenceInterval; - } - public float getConstantMultiplier() { return alpha * alpha; } + public byte getBits() { + return bits; + } + @Override public String toString() { return "ScalarQuantizer{" @@ -190,14 +201,14 @@ public String toString() { + minQuantile + ", maxQuantile=" + maxQuantile - + ", confidenceInterval=" - + confidenceInterval + + ", bits=" + + bits + '}'; } private static final Random random = new Random(42); - static int[] reservoirSampleIndices(int numFloatVecs, int sampleSize) { + private static int[] reservoirSampleIndices(int numFloatVecs, int sampleSize) { int[] vectorsToTake = IntStream.range(0, sampleSize).toArray(); for (int i = sampleSize; i < numFloatVecs; i++) { int j = random.nextInt(i + 1); @@ -220,26 +231,35 @@ static int[] reservoirSampleIndices(int numFloatVecs, int sampleSize) { * @param confidenceInterval the confidence interval used to calculate the quantiles * @param totalVectorCount the total number of live float vectors in the index. This is vital for * accounting for deleted documents when calculating the quantiles. 
+ * @param bits the number of bits to use for quantization * @return A new {@link ScalarQuantizer} instance * @throws IOException if there is an error reading the float vector values */ public static ScalarQuantizer fromVectors( - FloatVectorValues floatVectorValues, float confidenceInterval, int totalVectorCount) + FloatVectorValues floatVectorValues, + float confidenceInterval, + int totalVectorCount, + byte bits) throws IOException { return fromVectors( - floatVectorValues, confidenceInterval, totalVectorCount, SCALAR_QUANTIZATION_SAMPLE_SIZE); + floatVectorValues, + confidenceInterval, + totalVectorCount, + bits, + SCALAR_QUANTIZATION_SAMPLE_SIZE); } static ScalarQuantizer fromVectors( FloatVectorValues floatVectorValues, float confidenceInterval, int totalVectorCount, + byte bits, int quantizationSampleSize) throws IOException { assert 0.9f <= confidenceInterval && confidenceInterval <= 1f; assert quantizationSampleSize > SCRATCH_SIZE; if (totalVectorCount == 0) { - return new ScalarQuantizer(0f, 0f, confidenceInterval); + return new ScalarQuantizer(0f, 0f, bits); } if (confidenceInterval == 1f) { float min = Float.POSITIVE_INFINITY; @@ -250,13 +270,14 @@ static ScalarQuantizer fromVectors( max = Math.max(max, v); } } - return new ScalarQuantizer(min, max, confidenceInterval); + return new ScalarQuantizer(min, max, bits); } final float[] quantileGatheringScratch = new float[floatVectorValues.dimension() * Math.min(SCRATCH_SIZE, totalVectorCount)]; int count = 0; - double upperSum = 0; - double lowerSum = 0; + double[] upperSum = new double[1]; + double[] lowerSum = new double[1]; + float[] confidenceIntervals = new float[] {confidenceInterval}; if (totalVectorCount <= quantizationSampleSize) { int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount); int i = 0; @@ -266,10 +287,7 @@ static ScalarQuantizer fromVectors( vectorValue, 0, quantileGatheringScratch, i * vectorValue.length, vectorValue.length); i++; if (i == scratchSize) { - float[] upperAndLower = - getUpperAndLowerQuantile(quantileGatheringScratch, confidenceInterval); - upperSum += upperAndLower[1]; - lowerSum += upperAndLower[0]; + extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum); i = 0; count++; } @@ -277,10 +295,8 @@ static ScalarQuantizer fromVectors( // Note, we purposefully don't use the rest of the scratch state if we have fewer than // `SCRATCH_SIZE` vectors, mainly because if we are sampling so few vectors then we don't // want to be adversely affected by the extreme confidence intervals over small sample sizes - return new ScalarQuantizer( - (float) lowerSum / count, (float) upperSum / count, confidenceInterval); + return new ScalarQuantizer((float) lowerSum[0] / count, (float) upperSum[0] / count, bits); } - // Reservoir sample the vector ordinals we want to read int[] vectorsToTake = reservoirSampleIndices(totalVectorCount, quantizationSampleSize); int index = 0; int idx = 0; @@ -296,16 +312,213 @@ static ScalarQuantizer fromVectors( vectorValue, 0, quantileGatheringScratch, idx * vectorValue.length, vectorValue.length); idx++; if (idx == SCRATCH_SIZE) { - float[] upperAndLower = - getUpperAndLowerQuantile(quantileGatheringScratch, confidenceInterval); - upperSum += upperAndLower[1]; - lowerSum += upperAndLower[0]; + extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum); count++; idx = 0; } } - return new ScalarQuantizer( - (float) lowerSum / count, (float) upperSum / count, confidenceInterval); + return new ScalarQuantizer((float) lowerSum[0] / 
count, (float) upperSum[0] / count, bits); + } + + public static ScalarQuantizer fromVectorsAutoInterval( + FloatVectorValues floatVectorValues, + VectorSimilarityFunction function, + int totalVectorCount, + byte bits) + throws IOException { + if (totalVectorCount == 0) { + return new ScalarQuantizer(0f, 0f, bits); + } + + int sampleSize = Math.min(totalVectorCount, 1000); + final float[] quantileGatheringScratch = + new float[floatVectorValues.dimension() * Math.min(SCRATCH_SIZE, totalVectorCount)]; + int count = 0; + double[] upperSum = new double[2]; + double[] lowerSum = new double[2]; + final List<float[]> sampledDocs = new ArrayList<>(sampleSize); + float[] confidenceIntervals = + new float[] { + 1 + - Math.min(32, floatVectorValues.dimension() / 10f) + / (floatVectorValues.dimension() + 1), + 1 - 1f / (floatVectorValues.dimension() + 1) + }; + if (totalVectorCount <= sampleSize) { + int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount); + int i = 0; + while (floatVectorValues.nextDoc() != NO_MORE_DOCS) { + gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, i); + i++; + if (i == scratchSize) { + extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum); + i = 0; + count++; + } + } + } else { + // Reservoir sample the vector ordinals we want to read + int[] vectorsToTake = reservoirSampleIndices(totalVectorCount, 1000); + // TODO make this faster by .advance()ing & dual iterator + int index = 0; + int idx = 0; + for (int i : vectorsToTake) { + while (index <= i) { + // We cannot use `advance(docId)` as MergedVectorValues does not support it + floatVectorValues.nextDoc(); + index++; + } + assert floatVectorValues.docID() != NO_MORE_DOCS; + gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, idx); + idx++; + if (idx == SCRATCH_SIZE) { + extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum); + count++; + idx = 0; + } + } + } + + // Here we gather the upper and lower bounds for the quantile grid search + float al = (float) lowerSum[1] / count; + float bu = (float) upperSum[1] / count; + final float au = (float) lowerSum[0] / count; + final float bl = (float) upperSum[0] / count; + final float[] lowerCandidates = new float[16]; + final float[] upperCandidates = new float[16]; + int idx = 0; + for (float i = 0f; i < 32f; i += 2f) { + lowerCandidates[idx] = al + i * (au - al) / 32f; + upperCandidates[idx] = bl + i * (bu - bl) / 32f; + idx++; + } + // Now we need to find the best candidate pair by correlating the true quantized nearest + // neighbor scores + // with the float vector scores + List<ScoreDocsAndScoreVariance> nearestNeighbors = findNearestNeighbors(sampledDocs, function); + float[] bestPair = + candidateGridSearch( + nearestNeighbors, sampledDocs, lowerCandidates, upperCandidates, function, bits); + return new ScalarQuantizer(bestPair[0], bestPair[1], bits); + } + + private static void extractQuantiles( + float[] confidenceIntervals, + float[] quantileGatheringScratch, + double[] upperSum, + double[] lowerSum) { + assert confidenceIntervals.length == upperSum.length + && confidenceIntervals.length == lowerSum.length; + for (int i = 0; i < confidenceIntervals.length; i++) { + float[] upperAndLower = + getUpperAndLowerQuantile(quantileGatheringScratch, confidenceIntervals[i]); + upperSum[i] += upperAndLower[1]; + lowerSum[i] += upperAndLower[0]; + } + } + + private static void gatherSample( + FloatVectorValues floatVectorValues, + float[] quantileGatheringScratch, + List<float[]> sampledDocs, + int i) + throws IOException { +
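// Each sampled vector is consumed twice below: a defensive copy is kept in `sampledDocs` for the later score-correlation grid search, and the raw values are also appended into the shared scratch buffer used for quantile extraction. +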
float[] vectorValue = floatVectorValues.vectorValue(); + float[] copy = new float[vectorValue.length]; + System.arraycopy(vectorValue, 0, copy, 0, vectorValue.length); + sampledDocs.add(copy); + System.arraycopy( + vectorValue, 0, quantileGatheringScratch, i * vectorValue.length, vectorValue.length); + } + + private static float[] candidateGridSearch( + List<ScoreDocsAndScoreVariance> nearestNeighbors, + List<float[]> vectors, + float[] lowerCandidates, + float[] upperCandidates, + VectorSimilarityFunction function, + byte bits) { + double maxCorr = Double.NEGATIVE_INFINITY; + float bestLower = 0f; + float bestUpper = 0f; + ScoreErrorCorrelator scoreErrorCorrelator = + new ScoreErrorCorrelator(function, nearestNeighbors, vectors, bits); + // first do a coarse-grained search to find the initial best candidate pair + int bestQuadrantLower = 0; + int bestQuadrantUpper = 0; + for (int i = 0; i < lowerCandidates.length; i += 4) { + float lower = lowerCandidates[i]; + for (int j = 0; j < upperCandidates.length; j += 4) { + float upper = upperCandidates[j]; + if (upper <= lower) { + continue; + } + double mean = scoreErrorCorrelator.scoreErrorCorrelation(lower, upper); + if (mean > maxCorr) { + maxCorr = mean; + bestLower = lower; + bestUpper = upper; + bestQuadrantLower = i; + bestQuadrantUpper = j; + } + } + } + // Now search within the best quadrant + for (int i = bestQuadrantLower + 1; i < bestQuadrantLower + 4; i++) { + for (int j = bestQuadrantUpper + 1; j < bestQuadrantUpper + 4; j++) { + float lower = lowerCandidates[i]; + float upper = upperCandidates[j]; + if (upper <= lower) { + continue; + } + double mean = scoreErrorCorrelator.scoreErrorCorrelation(lower, upper); + if (mean > maxCorr) { + maxCorr = mean; + bestLower = lower; + bestUpper = upper; + } + } + } + return new float[] {bestLower, bestUpper}; + } + + /** + * @param vectors the vectors among which to find each vector's nearest neighbors + * @param similarityFunction the similarity function to use + * @return the top 10 nearest neighbors for each vector in the list + */ + private static List<ScoreDocsAndScoreVariance> findNearestNeighbors( + List<float[]> vectors, VectorSimilarityFunction similarityFunction) { + List<HitQueue> queues = new ArrayList<>(vectors.size()); + queues.add(new HitQueue(10, false)); + for (int i = 0; i < vectors.size(); i++) { + float[] vector = vectors.get(i); + for (int j = i + 1; j < vectors.size(); j++) { + float[] otherVector = vectors.get(j); + float score = similarityFunction.compare(vector, otherVector); + // initialize the rest of the queues + if (queues.size() <= j) { + queues.add(new HitQueue(10, false)); + } + queues.get(i).insertWithOverflow(new ScoreDoc(j, score)); + queues.get(j).insertWithOverflow(new ScoreDoc(i, score)); + } + } + // Extract the top 10 from each queue + List<ScoreDocsAndScoreVariance> result = new ArrayList<>(vectors.size()); + OnlineMeanAndVar meanAndVar = new OnlineMeanAndVar(); + for (int i = 0; i < vectors.size(); i++) { + HitQueue queue = queues.get(i); + ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()]; + for (int j = queue.size() - 1; j >= 0; j--) { + scoreDocs[j] = queue.pop(); + assert scoreDocs[j] != null; + meanAndVar.add(scoreDocs[j].score); + } + result.add(new ScoreDocsAndScoreVariance(scoreDocs, meanAndVar.var())); + meanAndVar.reset(); + } + return result; } /** @@ -319,9 +532,8 @@ static ScalarQuantizer fromVectors( * @return lower and upper quantile values */ static float[] getUpperAndLowerQuantile(float[] arr, float confidenceInterval) { - assert 0.9f <= confidenceInterval && confidenceInterval <= 1f; int selectorIndex = (int) (arr.length * (1f -
confidenceInterval) / 2f + 0.5f); - if (selectorIndex > 0) { + if (selectorIndex > 0 && arr.length > 2) { Selector selector = new FloatSelector(arr); selector.select(0, arr.length, arr.length - selectorIndex); selector.select(0, arr.length - selectorIndex, selectorIndex); @@ -361,4 +573,95 @@ protected void swap(int i, int j) { arr[j] = tmp; } } + + private static class ScoreDocsAndScoreVariance { + private final ScoreDoc[] scoreDocs; + private final float scoreVariance; + + public ScoreDocsAndScoreVariance(ScoreDoc[] scoreDocs, float scoreVariance) { + this.scoreDocs = scoreDocs; + this.scoreVariance = scoreVariance; + } + + public ScoreDoc[] getScoreDocs() { + return scoreDocs; + } + } + + private static class OnlineMeanAndVar { + private double mean = 0.0; + private double var = 0.0; + private int n = 0; + + void reset() { + mean = 0.0; + var = 0.0; + n = 0; + } + + void add(double x) { + n++; + double delta = x - mean; + mean += delta / n; + var += delta * (x - mean); + } + + float var() { + return (float) (var / (n - 1)); + } + } + + /** + * This class is used to correlate the scores of the nearest neighbors with the errors in the + * scores. This is used to find the best quantile pair for the scalar quantizer. + */ + private static class ScoreErrorCorrelator { + private final OnlineMeanAndVar corr = new OnlineMeanAndVar(); + private final OnlineMeanAndVar errors = new OnlineMeanAndVar(); + private final VectorSimilarityFunction function; + private final List<ScoreDocsAndScoreVariance> nearestNeighbors; + private final List<float[]> vectors; + private final byte[] query; + private final byte[] vector; + private final byte bits; + + public ScoreErrorCorrelator( + VectorSimilarityFunction function, + List<ScoreDocsAndScoreVariance> nearestNeighbors, + List<float[]> vectors, + byte bits) { + this.function = function; + this.nearestNeighbors = nearestNeighbors; + this.vectors = vectors; + this.query = new byte[vectors.get(0).length]; + this.vector = new byte[vectors.get(0).length]; + this.bits = bits; + } + + double scoreErrorCorrelation(float lowerQuantile, float upperQuantile) { + corr.reset(); + ScalarQuantizer quantizer = new ScalarQuantizer(lowerQuantile, upperQuantile, bits); + ScalarQuantizedVectorSimilarity scalarQuantizedVectorSimilarity = + ScalarQuantizedVectorSimilarity.fromVectorSimilarity( + function, quantizer.getConstantMultiplier(), quantizer.bits); + for (int i = 0; i < nearestNeighbors.size(); i++) { + float queryCorrection = quantizer.quantize(vectors.get(i), query, function); + ScoreDocsAndScoreVariance scoreDocsAndScoreVariance = nearestNeighbors.get(i); + ScoreDoc[] scoreDocs = scoreDocsAndScoreVariance.getScoreDocs(); + float scoreVariance = scoreDocsAndScoreVariance.scoreVariance; + // calculate the score for the vector against its nearest neighbors but with quantized + // scores now + errors.reset(); + for (ScoreDoc scoreDoc : scoreDocs) { + float vectorCorrection = quantizer.quantize(vectors.get(scoreDoc.doc), vector, function); + float qScore = + scalarQuantizedVectorSimilarity.score( + query, queryCorrection, vector, vectorCorrection); + errors.add(qScore - scoreDoc.score); + } + corr.add(1 - errors.var() / scoreVariance); + } + return corr.mean; + } + } } diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index ccd838cb8dd0..96dafcf2c1af 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++
b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -389,6 +389,53 @@ private int dotProductBody128(byte[] a, byte[] b, int limit) { return acc.reduceLanes(ADD); } + @Override + public int int4DotProduct(byte[] a, byte[] b) { + int i = 0; + int res = 0; + if (VECTOR_BITSIZE >= 512 || VECTOR_BITSIZE == 256) { + return dotProduct(a, b); + } else if (a.length >= 32 && HAS_FAST_INTEGER_VECTORS) { + i += ByteVector.SPECIES_128.loopBound(a.length); + res += int4DotProductBody128(a, b, i); + } + // scalar tail + for (; i < a.length; i++) { + res += b[i] * a[i]; + } + return res; + } + + private int int4DotProductBody128(byte[] a, byte[] b, int limit) { + int sum = 0; + // iterate in chunks of 1024 items to ensure we don't overflow the short accumulator + for (int i = 0; i < limit; i += 1024) { + ShortVector acc0 = ShortVector.zero(ShortVector.SPECIES_128); + ShortVector acc1 = ShortVector.zero(ShortVector.SPECIES_128); + int innerLimit = Math.min(limit - i, 1024); + for (int j = 0; j < innerLimit; j += ByteVector.SPECIES_128.length()) { + ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j); + ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j); + ByteVector prod8 = va8.mul(vb8); + ShortVector prod16 = + prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); + acc0 = acc0.add(prod16.and((short) 0xFF)); + + va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i + j + 8); + vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i + j + 8); + prod8 = va8.mul(vb8); + prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); + acc1 = acc1.add(prod16.and((short) 0xFF)); + } + IntVector intAcc0 = acc0.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, IntVector.SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, IntVector.SPECIES_128, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + return sum; + } + @Override public float cosine(byte[] a, byte[] b) { int i = 0; diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index 66ae7bc68d16..0dc7261acb6c 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -42,17 +42,37 @@ import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizer; +import org.junit.Before; public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + KnnVectorsFormat format; + Float confidenceInterval; + int bits; + + @Before + @Override + public void setUp() throws Exception { + bits = random().nextBoolean() ? 4 : 7; + confidenceInterval = random().nextBoolean() ? 
0.99f : null; + format = + new Lucene99HnswScalarQuantizedVectorsFormat( + Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, + Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, + 1, + bits, + random().nextBoolean(), + confidenceInterval, + null); + super.setUp(); + } + @Override protected Codec getCodec() { return new Lucene99Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene99HnswScalarQuantizedVectorsFormat( - Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, - Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); + return format; } }; } @@ -63,17 +83,25 @@ public void testQuantizedVectorsWriteAndRead() throws Exception { VectorSimilarityFunction similarityFunction = randomSimilarity(); boolean normalize = similarityFunction == VectorSimilarityFunction.COSINE; int dim = random().nextInt(64) + 1; + if (dim % 2 == 1) { + dim++; + } List vectors = new ArrayList<>(numVectors); for (int i = 0; i < numVectors; i++) { vectors.add(randomVector(dim)); } - float confidenceInterval = - Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval(dim); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors( - new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize), - confidenceInterval, - numVectors); + confidenceInterval == null + ? ScalarQuantizer.fromVectorsAutoInterval( + new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize), + similarityFunction, + numVectors, + (byte) bits) + : ScalarQuantizer.fromVectors( + new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize), + confidenceInterval, + numVectors, + (byte) bits); float[] expectedCorrections = new float[numVectors]; byte[][] expectedVectors = new byte[numVectors][]; for (int i = 0; i < numVectors; i++) { @@ -149,11 +177,12 @@ public void testToString() { new FilterCodec("foo", Codec.getDefault()) { @Override public KnnVectorsFormat knnVectorsFormat() { - return new Lucene99HnswScalarQuantizedVectorsFormat(10, 20, 1, 0.9f, null); + return new Lucene99HnswScalarQuantizedVectorsFormat( + 10, 20, 1, (byte) 4, false, 0.9f, null); } }; String expectedString = - "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, rawVectorFormat=Lucene99FlatVectorsFormat()))"; + "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, rawVectorFormat=Lucene99FlatVectorsFormat()))"; assertEquals(expectedString, customCodec.knnVectorsFormat().toString()); } @@ -174,15 +203,28 @@ public void testLimits() { () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 3201)); expectThrows( IllegalArgumentException.class, - () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 1.1f, null)); + () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 7, false, 1.1f, null)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, -1, false, null, null)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 5, false, null, null)); + + expectThrows( + IllegalArgumentException.class, + () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 
0, 9, false, null, null)); + expectThrows( + IllegalArgumentException.class, + () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 7, false, 0.8f, null)); expectThrows( IllegalArgumentException.class, - () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 0, 0.8f, null)); + () -> new Lucene99HnswScalarQuantizedVectorsFormat(20, 100, 100, 7, false, null, null)); expectThrows( IllegalArgumentException.class, () -> new Lucene99HnswScalarQuantizedVectorsFormat( - 20, 100, 1, null, new SameThreadExecutorService())); + 20, 100, 1, 7, false, null, new SameThreadExecutorService())); } // Ensures that all expected vector similarity functions are translatable diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java index 14d93f5ad114..468046090a53 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java @@ -39,14 +39,17 @@ public void testToEuclidean() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f); FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length); + ScalarQuantizer.fromVectors( + floatVectorValues, confidenceInterval, floats.length, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.EUCLIDEAN); float[] query = ArrayUtil.copyOfSubArray(floats[0], 0, dims); ScalarQuantizedVectorSimilarity quantizedSimilarity = ScalarQuantizedVectorSimilarity.fromVectorSimilarity( - VectorSimilarityFunction.EUCLIDEAN, scalarQuantizer.getConstantMultiplier()); + VectorSimilarityFunction.EUCLIDEAN, + scalarQuantizer.getConstantMultiplier(), + scalarQuantizer.getBits()); assertQuantizedScores( floats, quantized, @@ -69,7 +72,8 @@ public void testToCosine() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f); FloatVectorValues floatVectorValues = fromFloatsNormalized(floats, null); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length); + ScalarQuantizer.fromVectors( + floatVectorValues, confidenceInterval, floats.length, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = quantizeVectorsNormalized( @@ -78,7 +82,9 @@ public void testToCosine() throws IOException { VectorUtil.l2normalize(query); ScalarQuantizedVectorSimilarity quantizedSimilarity = ScalarQuantizedVectorSimilarity.fromVectorSimilarity( - VectorSimilarityFunction.COSINE, scalarQuantizer.getConstantMultiplier()); + VectorSimilarityFunction.COSINE, + scalarQuantizer.getConstantMultiplier(), + scalarQuantizer.getBits()); assertQuantizedScores( floats, quantized, @@ -103,7 +109,8 @@ public void testToDotProduct() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f); FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length); + ScalarQuantizer.fromVectors( + floatVectorValues, confidenceInterval, floats.length, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = 
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.DOT_PRODUCT); @@ -111,7 +118,9 @@ public void testToDotProduct() throws IOException { VectorUtil.l2normalize(query); ScalarQuantizedVectorSimilarity quantizedSimilarity = ScalarQuantizedVectorSimilarity.fromVectorSimilarity( - VectorSimilarityFunction.DOT_PRODUCT, scalarQuantizer.getConstantMultiplier()); + VectorSimilarityFunction.DOT_PRODUCT, + scalarQuantizer.getConstantMultiplier(), + scalarQuantizer.getBits()); assertQuantizedScores( floats, quantized, @@ -133,7 +142,8 @@ public void testToMaxInnerProduct() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.5f, 0.5f); FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, floats.length); + ScalarQuantizer.fromVectors( + floatVectorValues, confidenceInterval, floats.length, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = quantizeVectors( @@ -142,7 +152,8 @@ public void testToMaxInnerProduct() throws IOException { ScalarQuantizedVectorSimilarity quantizedSimilarity = ScalarQuantizedVectorSimilarity.fromVectorSimilarity( VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, - scalarQuantizer.getConstantMultiplier()); + scalarQuantizer.getConstantMultiplier(), + scalarQuantizer.getBits()); assertQuantizedScores( floats, quantized, diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index 926b985f5388..97d313eb4388 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -34,7 +34,8 @@ public void testQuantizeAndDeQuantize() throws IOException { float[][] floats = randomFloats(numVecs, dims); FloatVectorValues floatVectorValues = fromFloats(floats); - ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, 1, numVecs); + ScalarQuantizer scalarQuantizer = + ScalarQuantizer.fromVectors(floatVectorValues, 1, numVecs, (byte) 7); float[] dequantized = new float[dims]; byte[] quantized = new byte[dims]; byte[] requantized = new byte[dims]; @@ -87,6 +88,7 @@ public void testScalarWithSampling() throws IOException { floatVectorValues, 0.99f, floatVectorValues.numLiveVectors, + (byte) 7, Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1)); } { @@ -96,6 +98,7 @@ public void testScalarWithSampling() throws IOException { floatVectorValues, 0.99f, floatVectorValues.numLiveVectors, + (byte) 7, Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1)); } { @@ -105,6 +108,7 @@ public void testScalarWithSampling() throws IOException { floatVectorValues, 0.99f, floatVectorValues.numLiveVectors, + (byte) 7, Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1)); } { @@ -114,10 +118,36 @@ public void testScalarWithSampling() throws IOException { floatVectorValues, 0.99f, floatVectorValues.numLiveVectors, + (byte) 7, Math.max(random().nextInt(floatVectorValues.floats.length - 1) + 1, SCRATCH_SIZE + 1)); } } + public void testFromVectorsAutoInterval() throws IOException { + int dims = 128; + int numVecs = 100; + VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; + + float[][] floats = randomFloats(numVecs, dims); + FloatVectorValues floatVectorValues = fromFloats(floats); + ScalarQuantizer 
scalarQuantizer = + ScalarQuantizer.fromVectorsAutoInterval( + floatVectorValues, similarityFunction, numVecs, (byte) 4); + assertNotNull(scalarQuantizer); + float[] dequantized = new float[dims]; + byte[] quantized = new byte[dims]; + byte[] requantized = new byte[dims]; + for (int i = 0; i < numVecs; i++) { + scalarQuantizer.quantize(floats[i], quantized, similarityFunction); + scalarQuantizer.deQuantize(quantized, dequantized); + scalarQuantizer.quantize(dequantized, requantized, similarityFunction); + for (int j = 0; j < dims; j++) { + assertEquals(dequantized[j], floats[i][j], 0.2); + assertEquals(quantized[j], requantized[j]); + } + } + } + static void shuffleArray(float[] ar) { for (int i = ar.length - 1; i > 0; i--) { int index = random().nextInt(i + 1); diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 29a1fe0fa086..b6e3da77224f 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -127,12 +127,12 @@ public void testIllegalDimChangeTwoDocs() throws Exception { w.addDocument(doc); Document doc2 = new Document(); - doc2.add(new KnnFloatVectorField("f", new float[3], VectorSimilarityFunction.DOT_PRODUCT)); + doc2.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT)); IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2)); String errMsg = "Inconsistency of field data structures across documents for field [f] of doc [1]." - + " vector dimension: expected '4', but it has '3'."; + + " vector dimension: expected '4', but it has '6'."; assertEquals(errMsg, expected.getMessage()); } @@ -145,12 +145,12 @@ public void testIllegalDimChangeTwoDocs() throws Exception { w.commit(); Document doc2 = new Document(); - doc2.add(new KnnFloatVectorField("f", new float[3], VectorSimilarityFunction.DOT_PRODUCT)); + doc2.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT)); IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2)); String errMsg = "cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT " - + "to inconsistent vector dimension=3, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT"; + + "to inconsistent vector dimension=6, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT"; assertEquals(errMsg, expected.getMessage()); } } @@ -202,12 +202,12 @@ public void testIllegalDimChangeTwoWriters() throws Exception { try (IndexWriter w2 = new IndexWriter(dir, newIndexWriterConfig())) { Document doc2 = new Document(); - doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.DOT_PRODUCT)); + doc2.add(new KnnFloatVectorField("f", new float[2], VectorSimilarityFunction.DOT_PRODUCT)); IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2)); assertEquals( "cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT " - + "to inconsistent vector dimension=1, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT", + + "to inconsistent vector dimension=2, vector encoding=FLOAT32, vector similarity 
function=DOT_PRODUCT", expected.getMessage()); } } @@ -284,7 +284,7 @@ public void testAddIndexesDirectory1() throws Exception { public void testAddIndexesDirectory01() throws Exception { String fieldName = "field"; - float[] vector = new float[1]; + float[] vector = new float[2]; Document doc = new Document(); doc.add(new KnnFloatVectorField(fieldName, vector, VectorSimilarityFunction.DOT_PRODUCT)); try (Directory dir = newDirectory(); @@ -294,6 +294,7 @@ public void testAddIndexesDirectory01() throws Exception { } try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) { vector[0] = 1; + vector[1] = 1; w2.addDocument(doc); w2.addIndexes(dir); w2.forceMerge(1); @@ -322,13 +323,13 @@ public void testIllegalDimChangeViaAddIndexesDirectory() throws Exception { } try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new KnnFloatVectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT)); + doc.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT)); w2.addDocument(doc); IllegalArgumentException expected = expectThrows( IllegalArgumentException.class, () -> w2.addIndexes(new Directory[] {dir})); assertEquals( - "cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT " + "cannot change field \"f\" from vector dimension=6, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT " + "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT", expected.getMessage()); } @@ -367,7 +368,7 @@ public void testIllegalDimChangeViaAddIndexesCodecReader() throws Exception { } try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new KnnFloatVectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT)); + doc.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT)); w2.addDocument(doc); try (DirectoryReader r = DirectoryReader.open(dir)) { IllegalArgumentException expected = @@ -375,7 +376,7 @@ public void testIllegalDimChangeViaAddIndexesCodecReader() throws Exception { IllegalArgumentException.class, () -> w2.addIndexes(new CodecReader[] {(CodecReader) getOnlyLeafReader(r)})); assertEquals( - "cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT " + "cannot change field \"f\" from vector dimension=6, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT " + "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT", expected.getMessage()); } @@ -419,13 +420,13 @@ public void testIllegalDimChangeViaAddIndexesSlowCodecReader() throws Exception } try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new KnnFloatVectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT)); + doc.add(new KnnFloatVectorField("f", new float[6], VectorSimilarityFunction.DOT_PRODUCT)); w2.addDocument(doc); try (DirectoryReader r = DirectoryReader.open(dir)) { IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> TestUtil.addIndexesSlowly(w2, r)); assertEquals( - "cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT " + "cannot change field \"f\" from vector dimension=6, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT " + "to 
inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT", expected.getMessage()); } @@ -486,7 +487,7 @@ public void testIllegalDimensionTooLarge() throws Exception { .contains("vector's dimensions must be <= [" + getVectorsMaxDimensions("f") + "]")); Document doc2 = new Document(); - doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.DOT_PRODUCT)); + doc2.add(new KnnFloatVectorField("f", new float[2], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc2); Document doc3 = new Document(); @@ -531,7 +532,7 @@ public void testIllegalEmptyVector() throws Exception { assertEquals("cannot index an empty vector", e.getMessage()); Document doc2 = new Document(); - doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.EUCLIDEAN)); + doc2.add(new KnnFloatVectorField("f", new float[2], VectorSimilarityFunction.EUCLIDEAN)); w.addDocument(doc2); } } @@ -592,7 +593,7 @@ public void testDeleteAllVectorDocs() throws Exception { doc.add(new StringField("id", "0", Field.Store.NO)); doc.add( new KnnFloatVectorField( - "v", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT)); + "v", new float[] {2, 3, 5, 6}, VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); w.addDocument(new Document()); w.commit(); @@ -613,7 +614,7 @@ public void testDeleteAllVectorDocs() throws Exception { // assert that knn search doesn't fail on a field with all deleted docs TopDocs results = leafReader.searchNearestVectors( - "v", randomNormalizedVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE); + "v", randomNormalizedVector(4), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE); assertEquals(0, results.scoreDocs.length); } } @@ -626,14 +627,14 @@ public void testKnnVectorFieldMissingFromOneSegment() throws Exception { doc.add(new StringField("id", "0", Field.Store.NO)); doc.add( new KnnFloatVectorField( - "v0", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT)); + "v0", new float[] {2, 3, 5, 6}, VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); w.commit(); doc = new Document(); doc.add( new KnnFloatVectorField( - "v1", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT)); + "v1", new float[] {2, 3, 5, 6}, VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); w.forceMerge(1); } @@ -649,6 +650,9 @@ public void testSparseVectors() throws Exception { VectorEncoding[] fieldVectorEncodings = new VectorEncoding[numFields]; for (int i = 0; i < numFields; i++) { fieldDims[i] = random().nextInt(20) + 1; + if (fieldDims[i] % 2 != 0) { + fieldDims[i]++; + } fieldSimilarityFunctions[i] = randomSimilarity(); fieldVectorEncodings[i] = randomVectorEncoding(); } @@ -731,7 +735,7 @@ public void testIndexedValueNotAliased() throws Exception { // We copy indexed values (as for BinaryDocValues) so the input float[] can be reused across // calls to IndexWriter.addDocument. 
String fieldName = "field"; - float[] v = {0}; + float[] v = {0, 0}; try (Directory dir = newDirectory(); IndexWriter iw = new IndexWriter( @@ -829,25 +833,25 @@ public void testIndexMultipleKnnVectorFields() throws Exception { IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()))) { Document doc = new Document(); - float[] v = new float[] {1}; + float[] v = new float[] {1, 2}; doc.add(new KnnFloatVectorField("field1", v, VectorSimilarityFunction.EUCLIDEAN)); doc.add( new KnnFloatVectorField( - "field2", new float[] {1, 2, 3}, VectorSimilarityFunction.EUCLIDEAN)); + "field2", new float[] {1, 2, 3, 4}, VectorSimilarityFunction.EUCLIDEAN)); iw.addDocument(doc); v[0] = 2; iw.addDocument(doc); doc = new Document(); doc.add( new KnnFloatVectorField( - "field3", new float[] {1, 2, 3}, VectorSimilarityFunction.DOT_PRODUCT)); + "field3", new float[] {1, 2, 3, 4}, VectorSimilarityFunction.DOT_PRODUCT)); iw.addDocument(doc); iw.forceMerge(1); try (IndexReader reader = DirectoryReader.open(iw)) { LeafReader leaf = reader.leaves().get(0).reader(); FloatVectorValues vectorValues = leaf.getFloatVectorValues("field1"); - assertEquals(1, vectorValues.dimension()); + assertEquals(2, vectorValues.dimension()); assertEquals(2, vectorValues.size()); vectorValues.nextDoc(); assertEquals(1f, vectorValues.vectorValue()[0], 0); @@ -856,7 +860,7 @@ public void testIndexMultipleKnnVectorFields() throws Exception { assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); FloatVectorValues vectorValues2 = leaf.getFloatVectorValues("field2"); - assertEquals(3, vectorValues2.dimension()); + assertEquals(4, vectorValues2.dimension()); assertEquals(2, vectorValues2.size()); vectorValues2.nextDoc(); assertEquals(2f, vectorValues2.vectorValue()[1], 0); @@ -865,7 +869,7 @@ public void testIndexMultipleKnnVectorFields() throws Exception { assertEquals(NO_MORE_DOCS, vectorValues2.nextDoc()); FloatVectorValues vectorValues3 = leaf.getFloatVectorValues("field3"); - assertEquals(3, vectorValues3.dimension()); + assertEquals(4, vectorValues3.dimension()); assertEquals(1, vectorValues3.size()); vectorValues3.nextDoc(); assertEquals(1f, vectorValues3.vectorValue()[0], 0.1); @@ -889,6 +893,9 @@ public void testRandom() throws Exception { IndexWriter iw = new IndexWriter(dir, iwc)) { int numDoc = atLeast(100); int dimension = atLeast(10); + if (dimension % 2 != 0) { + dimension++; + } float[] scratch = new float[dimension]; int numValues = 0; float[][] values = new float[numDoc][]; @@ -965,6 +972,9 @@ public void testRandomBytes() throws Exception { IndexWriter iw = new IndexWriter(dir, iwc)) { int numDoc = atLeast(100); int dimension = atLeast(10); + if (dimension % 2 != 0) { + dimension++; + } byte[] scratch = new byte[dimension]; int numValues = 0; BytesRef[] values = new BytesRef[numDoc]; @@ -1101,6 +1111,9 @@ public void testRandomWithUpdatesAndGraph() throws Exception { IndexWriter iw = new IndexWriter(dir, iwc)) { int numDoc = atLeast(100); int dimension = atLeast(10); + if (dimension % 2 != 0) { + dimension++; + } float[][] id2value = new float[numDoc][]; for (int i = 0; i < numDoc; i++) { int id = random().nextInt(numDoc); @@ -1252,7 +1265,7 @@ protected float[] randomNormalizedVector(int dim) { return v; } - private byte[] randomVector8(int dim) { + protected byte[] randomVector8(int dim) { assert dim > 0; float[] v = randomNormalizedVector(dim); byte[] b = new byte[dim]; @@ -1268,12 +1281,12 @@ public void testCheckIndexIncludesVectors() throws Exception { Document doc = new 
Document(); doc.add( new KnnFloatVectorField( - "v1", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN)); + "v1", randomNormalizedVector(4), VectorSimilarityFunction.EUCLIDEAN)); w.addDocument(doc); doc.add( new KnnFloatVectorField( - "v2", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN)); + "v2", randomNormalizedVector(4), VectorSimilarityFunction.EUCLIDEAN)); w.addDocument(doc); } @@ -1360,7 +1373,10 @@ public void testAdvance() throws Exception { public void testVectorValuesReportCorrectDocs() throws Exception { final int numDocs = atLeast(1000); - final int dim = random().nextInt(20) + 1; + int dim = random().nextInt(20) + 1; + if (dim % 2 != 0) { + dim++; + } double fieldValuesCheckSum = 0; int fieldDocCount = 0; long fieldSumDocIDs = 0;
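To make the quantization arithmetic in ScalarQuantizer above concrete, here is a minimal standalone sketch of the generalized mapping (illustrative only; the class name QuantizationSketch and the sample values are hypothetical and not part of this patch): a float is clamped to [minQuantile, maxQuantile], shifted to zero, scaled into [0, (1 << bits) - 1], and rounded, while alpha maps a quantized level back into the original float range.

    // Illustrative sketch mirroring the alpha/scale definitions in ScalarQuantizer;
    // not part of the patch.
    final class QuantizationSketch {
      public static void main(String[] args) {
        float minQuantile = -1f;
        float maxQuantile = 1f;
        byte bits = 4; // int4: the top quantized level is (1 << 4) - 1 = 15
        float divisor = (1 << bits) - 1;
        float alpha = (maxQuantile - minQuantile) / divisor; // dequantization step size
        float scale = divisor / (maxQuantile - minQuantile); // quantization multiplier

        float v = 0.37f;
        // Clamp to the quantile range (cutting off the tails), shift to zero,
        // then scale into [0, (1 << bits) - 1] and round to the nearest level.
        float dxc = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
        byte quantized = (byte) Math.round(scale * dxc);
        // Dequantize: alpha walks the level back into the original float range.
        float dequantized = alpha * quantized + minQuantile;
        System.out.println("quantized=" + quantized + ", dequantized=" + dequantized); // 10, ~0.333
      }
    }

With bits = 4 the levels occupy [0, 15], i.e. half a byte, which is what makes the optional compressed layout (packing two 4-bit values into a single byte) possible.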