|
18 | 18 |
|
19 | 19 | import java.io.IOException;
|
20 | 20 | import java.util.Collections;
|
| 21 | +import java.util.List; |
21 | 22 | import java.util.Map;
|
22 | 23 | import org.apache.lucene.codecs.CodecUtil;
|
23 | 24 | import org.apache.lucene.codecs.DocValuesFormat;
|
|
111 | 112 | * <li>0: EUCLIDEAN distance. ({@link VectorSimilarityFunction#EUCLIDEAN})
|
112 | 113 | * <li>1: DOT_PRODUCT similarity. ({@link VectorSimilarityFunction#DOT_PRODUCT})
|
113 | 114 | * <li>2: COSINE similarity. ({@link VectorSimilarityFunction#COSINE})
|
| 115 | + * <li>3: MAXIMUM_INNER_PRODUCT similarity. ({@link |
| 116 | + * VectorSimilarityFunction#MAXIMUM_INNER_PRODUCT}) |
114 | 117 | * </ul>
|
115 | 118 | * </ul>
|
116 | 119 | *
|
@@ -284,10 +287,38 @@ private static VectorEncoding getVectorEncoding(IndexInput input, byte b) throws
|
284 | 287 | }
|
285 | 288 |
|
286 | 289 | private static VectorSimilarityFunction getDistFunc(IndexInput input, byte b) throws IOException {
|
287 |
| - if (b < 0 || b >= VectorSimilarityFunction.values().length) { |
288 |
| - throw new CorruptIndexException("invalid distance function: " + b, input); |
| 290 | + try { |
| 291 | + return distOrdToFunc(b); |
| 292 | + } catch (IllegalArgumentException e) { |
| 293 | + throw new CorruptIndexException("invalid distance function: " + b, input, e); |
289 | 294 | }
|
290 |
| - return VectorSimilarityFunction.values()[b]; |
| 295 | + } |
| 296 | + |
| 297 | + // List of vector similarity functions. This list is defined here, in order |
| 298 | + // to avoid an undesirable dependency on the declaration and order of values |
| 299 | + // in VectorSimilarityFunction. The list values and order have been chosen to |
| 300 | + // match that of VectorSimilarityFunction in, at least, Lucene 9.10. Values |
| 301 | + static final List<VectorSimilarityFunction> SIMILARITY_FUNCTIONS = |
| 302 | + List.of( |
| 303 | + VectorSimilarityFunction.EUCLIDEAN, |
| 304 | + VectorSimilarityFunction.DOT_PRODUCT, |
| 305 | + VectorSimilarityFunction.COSINE, |
| 306 | + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT); |
| 307 | + |
| 308 | + static VectorSimilarityFunction distOrdToFunc(byte i) { |
| 309 | + if (i < 0 || i >= SIMILARITY_FUNCTIONS.size()) { |
| 310 | + throw new IllegalArgumentException("invalid distance function: " + i); |
| 311 | + } |
| 312 | + return SIMILARITY_FUNCTIONS.get(i); |
| 313 | + } |
| 314 | + |
| 315 | + static byte distFuncToOrd(VectorSimilarityFunction func) { |
| 316 | + for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) { |
| 317 | + if (SIMILARITY_FUNCTIONS.get(i).equals(func)) { |
| 318 | + return (byte) i; |
| 319 | + } |
| 320 | + } |
| 321 | + throw new IllegalArgumentException("invalid distance function: " + func); |
291 | 322 | }
|
292 | 323 |
|
293 | 324 | static {
|
@@ -378,7 +409,7 @@ public void write(
|
378 | 409 | }
|
379 | 410 | output.writeVInt(fi.getVectorDimension());
|
380 | 411 | output.writeByte((byte) fi.getVectorEncoding().ordinal());
|
381 |
| - output.writeByte((byte) fi.getVectorSimilarityFunction().ordinal()); |
| 412 | + output.writeByte(distFuncToOrd(fi.getVectorSimilarityFunction())); |
382 | 413 | }
|
383 | 414 | CodecUtil.writeFooter(output);
|
384 | 415 | }
|
|
0 commit comments