Skip to content

Commit 17cbedc

Browse files
authored
FieldInfosFormat translation should be independent of VectorSimilarityFunction enum (#13119)
This commit updates the FieldInfosFormat translation of vector similarity functions to be independent of the VectorSimilarityFunction enum. The VectorSimilarityFunction enum lives outside of the codec format, and the format should not inadvertently depend upon the declaration order or values in VectorSimilarityFunction. The format should be in charge of the translation of similarity function to format ordinal (and vice versa). In reality, and for now, the translation remains the same as the declaration order, but this may not be the case in the future.
1 parent 6732d2b commit 17cbedc

File tree

5 files changed

+113
-11
lines changed

5 files changed

+113
-11
lines changed

lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94FieldInfosFormat.java

+35-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.io.IOException;
2020
import java.util.Collections;
21+
import java.util.List;
2122
import java.util.Map;
2223
import org.apache.lucene.codecs.CodecUtil;
2324
import org.apache.lucene.codecs.DocValuesFormat;
@@ -111,6 +112,8 @@
111112
* <li>0: EUCLIDEAN distance. ({@link VectorSimilarityFunction#EUCLIDEAN})
112113
* <li>1: DOT_PRODUCT similarity. ({@link VectorSimilarityFunction#DOT_PRODUCT})
113114
* <li>2: COSINE similarity. ({@link VectorSimilarityFunction#COSINE})
115+
* <li>3: MAXIMUM_INNER_PRODUCT similarity. ({@link
116+
* VectorSimilarityFunction#MAXIMUM_INNER_PRODUCT})
114117
* </ul>
115118
* </ul>
116119
*
@@ -284,10 +287,38 @@ private static VectorEncoding getVectorEncoding(IndexInput input, byte b) throws
284287
}
285288

286289
private static VectorSimilarityFunction getDistFunc(IndexInput input, byte b) throws IOException {
287-
if (b < 0 || b >= VectorSimilarityFunction.values().length) {
288-
throw new CorruptIndexException("invalid distance function: " + b, input);
290+
try {
291+
return distOrdToFunc(b);
292+
} catch (IllegalArgumentException e) {
293+
throw new CorruptIndexException("invalid distance function: " + b, input, e);
289294
}
290-
return VectorSimilarityFunction.values()[b];
295+
}
296+
297+
// List of vector similarity functions. This list is defined here, in order
298+
// to avoid an undesirable dependency on the declaration and order of values
299+
// in VectorSimilarityFunction. The list values and order have been chosen to
300+
// match that of VectorSimilarityFunction in, at least, Lucene 9.10.
301+
static final List<VectorSimilarityFunction> SIMILARITY_FUNCTIONS =
302+
List.of(
303+
VectorSimilarityFunction.EUCLIDEAN,
304+
VectorSimilarityFunction.DOT_PRODUCT,
305+
VectorSimilarityFunction.COSINE,
306+
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);
307+
308+
static VectorSimilarityFunction distOrdToFunc(byte i) {
309+
if (i < 0 || i >= SIMILARITY_FUNCTIONS.size()) {
310+
throw new IllegalArgumentException("invalid distance function: " + i);
311+
}
312+
return SIMILARITY_FUNCTIONS.get(i);
313+
}
314+
315+
static byte distFuncToOrd(VectorSimilarityFunction func) {
316+
for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) {
317+
if (SIMILARITY_FUNCTIONS.get(i).equals(func)) {
318+
return (byte) i;
319+
}
320+
}
321+
throw new IllegalArgumentException("invalid distance function: " + func);
291322
}
292323

293324
static {
@@ -378,7 +409,7 @@ public void write(
378409
}
379410
output.writeVInt(fi.getVectorDimension());
380411
output.writeByte((byte) fi.getVectorEncoding().ordinal());
381-
output.writeByte((byte) fi.getVectorSimilarityFunction().ordinal());
412+
output.writeByte(distFuncToOrd(fi.getVectorSimilarityFunction()));
382413
}
383414
CodecUtil.writeFooter(output);
384415
}

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java

+16-6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.io.IOException;
2323
import java.util.Arrays;
2424
import java.util.HashMap;
25+
import java.util.List;
2526
import java.util.Map;
2627
import org.apache.lucene.codecs.CodecUtil;
2728
import org.apache.lucene.codecs.FlatVectorsReader;
@@ -171,15 +172,24 @@ private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
171172
}
172173
}
173174

175+
// List of vector similarity functions. This list is defined here, in order
176+
// to avoid an undesirable dependency on the declaration and order of values
177+
// in VectorSimilarityFunction. The list values and order must be identical
178+
// to that of {@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat#SIMILARITY_FUNCTIONS}.
179+
public static final List<VectorSimilarityFunction> SIMILARITY_FUNCTIONS =
180+
List.of(
181+
VectorSimilarityFunction.EUCLIDEAN,
182+
VectorSimilarityFunction.DOT_PRODUCT,
183+
VectorSimilarityFunction.COSINE,
184+
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);
185+
174186
public static VectorSimilarityFunction readSimilarityFunction(DataInput input)
175187
throws IOException {
176-
int similarityFunctionId = input.readInt();
177-
if (similarityFunctionId < 0
178-
|| similarityFunctionId >= VectorSimilarityFunction.values().length) {
179-
throw new CorruptIndexException(
180-
"Invalid similarity function id: " + similarityFunctionId, input);
188+
int i = input.readInt();
189+
if (i < 0 || i >= SIMILARITY_FUNCTIONS.size()) {
190+
throw new IllegalArgumentException("invalid distance function: " + i);
181191
}
182-
return VectorSimilarityFunction.values()[similarityFunctionId];
192+
return SIMILARITY_FUNCTIONS.get(i);
183193
}
184194

185195
public static VectorEncoding readVectorEncoding(DataInput input) throws IOException {

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java

+12-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.lucene.codecs.lucene99;
1919

2020
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
21+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
2122

2223
import java.io.IOException;
2324
import java.util.ArrayList;
@@ -33,6 +34,7 @@
3334
import org.apache.lucene.index.MergeState;
3435
import org.apache.lucene.index.SegmentWriteState;
3536
import org.apache.lucene.index.Sorter;
37+
import org.apache.lucene.index.VectorSimilarityFunction;
3638
import org.apache.lucene.search.DocIdSetIterator;
3739
import org.apache.lucene.search.TaskExecutor;
3840
import org.apache.lucene.store.IndexOutput;
@@ -436,7 +438,7 @@ private void writeMeta(
436438
throws IOException {
437439
meta.writeInt(field.number);
438440
meta.writeInt(field.getVectorEncoding().ordinal());
439-
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
441+
meta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction()));
440442
meta.writeVLong(vectorIndexOffset);
441443
meta.writeVLong(vectorIndexLength);
442444
meta.writeVInt(field.getVectorDimension());
@@ -500,6 +502,15 @@ public void close() throws IOException {
500502
IOUtils.close(meta, vectorIndex, flatVectorWriter);
501503
}
502504

505+
static int distFuncToOrd(VectorSimilarityFunction func) {
506+
for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) {
507+
if (SIMILARITY_FUNCTIONS.get(i).equals(func)) {
508+
return (byte) i;
509+
}
510+
}
511+
throw new IllegalArgumentException("invalid distance function: " + func);
512+
}
513+
503514
private static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
504515

505516
private static final long SHALLOW_SIZE =
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.lucene.codecs.lucene94;
18+
19+
import java.util.Arrays;
20+
import org.apache.lucene.codecs.Codec;
21+
import org.apache.lucene.index.VectorSimilarityFunction;
22+
import org.apache.lucene.tests.index.BaseFieldInfoFormatTestCase;
23+
import org.apache.lucene.tests.util.TestUtil;
24+
25+
public class TestLucene94FieldInfosFormat extends BaseFieldInfoFormatTestCase {
26+
@Override
27+
protected Codec getCodec() {
28+
return TestUtil.getDefaultCodec();
29+
}
30+
31+
// Ensures that all expected vector similarity functions are translatable
32+
// in the format.
33+
public void testVectorSimilarityFuncs() {
34+
// This does not necessarily have to be all similarity functions, but
35+
// differences should be considered carefully.
36+
var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList();
37+
38+
assertEquals(Lucene94FieldInfosFormat.SIMILARITY_FUNCTIONS, expectedValues);
39+
}
40+
}

lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java

+10
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
2020

2121
import java.util.ArrayList;
22+
import java.util.Arrays;
2223
import java.util.List;
2324
import org.apache.lucene.codecs.Codec;
2425
import org.apache.lucene.codecs.FilterCodec;
@@ -186,4 +187,13 @@ public void testLimits() {
186187
new Lucene99HnswScalarQuantizedVectorsFormat(
187188
20, 100, 1, null, new SameThreadExecutorService()));
188189
}
190+
191+
// Ensures that all expected vector similarity functions are translatable
192+
// in the format.
193+
public void testVectorSimilarityFuncs() {
194+
// This does not necessarily have to be all similarity functions, but
195+
// differences should be considered carefully.
196+
var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList();
197+
assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues);
198+
}
189199
}

0 commit comments

Comments
 (0)