Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use read advice consistently in the knn vector formats #14076

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.ReadAdvice;

class Lucene99RWHnswScalarQuantizationVectorsFormat
extends Lucene99HnswScalarQuantizedVectorsFormat {
Expand All @@ -54,15 +55,16 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException

static class Lucene99RWScalarQuantizedFormat extends Lucene99ScalarQuantizedVectorsFormat {
private static final FlatVectorsFormat rawVectorFormat =
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer(), ReadAdvice.RANDOM);

@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99ScalarQuantizedVectorsWriter(
state,
null,
rawVectorFormat.fieldsWriter(state),
new ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()));
new ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()),
ReadAdvice.RANDOM);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.util.hnsw.HnswGraph;

/**
Expand Down Expand Up @@ -128,7 +129,8 @@ public HnswBitVectorsFormat(
} else {
this.mergeExec = null;
}
this.flatVectorsFormat = new Lucene99FlatVectorsFormat(new FlatBitVectorsScorer());
this.flatVectorsFormat =
new Lucene99FlatVectorsFormat(new FlatBitVectorsScorer(), ReadAdvice.RANDOM);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,4 @@ public abstract void search(
public KnnVectorsReader getMergeInstance() {
return this;
}

/**
* Optional: reset or close merge resources used in the reader
*
* <p>The default implementation is empty
*/
public void finishMerge() throws IOException {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,9 @@ public final void merge(MergeState mergeState) throws IOException {
}
}
}
finishMerge(mergeState);
finish();
}

private void finishMerge(MergeState mergeState) throws IOException {
for (KnnVectorsReader reader : mergeState.knnVectorsReaders) {
if (reader != null) {
reader.finishMerge();
}
}
}

/** Tracks state of one sub-reader that we are merging */
private static class FloatVectorValuesSub extends DocIDMerger.Sub {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,4 @@ public abstract RandomVectorScorer getRandomVectorScorer(String field, float[] t
*/
public abstract RandomVectorScorer getRandomVectorScorer(String field, byte[] target)
throws IOException;

/**
* Returns an instance optimized for merging. This instance may only be consumed in the thread
* that called {@link #getMergeInstance()}.
*
* <p>The default implementation returns {@code this}
*/
@Override
public FlatVectorsReader getMergeInstance() {
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.ReadAdvice;

/**
* Lucene 9.9 flat vector format, which encodes numeric vector values
Expand Down Expand Up @@ -78,21 +79,23 @@ public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat {

static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
private final FlatVectorsScorer vectorsScorer;
private final ReadAdvice readAdvice;

/** Constructs a format */
public Lucene99FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
public Lucene99FlatVectorsFormat(FlatVectorsScorer vectorsScorer, ReadAdvice readAdvice) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++ Allowing the read advice to be passed in here is good, since the higher-level usage of this format really should dictate the intended usage.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the idea of passing the read advice. If the top-level format can dictate the read advice, then that makes the code better. +1 on this idea.

super(NAME);
this.vectorsScorer = vectorsScorer;
this.readAdvice = readAdvice;
}

@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99FlatVectorsWriter(state, vectorsScorer);
return new Lucene99FlatVectorsWriter(state, vectorsScorer, readAdvice);
}

@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99FlatVectorsReader(state, vectorsScorer);
return new Lucene99FlatVectorsReader(state, vectorsScorer, readAdvice);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;

import java.io.IOException;
import java.io.UncheckedIOException;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
Expand Down Expand Up @@ -60,8 +59,8 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
private final IndexInput vectorData;
private final FieldInfos fieldInfos;

public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer)
throws IOException {
public Lucene99FlatVectorsReader(
SegmentReadState state, FlatVectorsScorer scorer, ReadAdvice readAdvice) throws IOException {
super(scorer);
int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
Expand All @@ -73,9 +72,7 @@ public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer score
versionMeta,
Lucene99FlatVectorsFormat.VECTOR_DATA_EXTENSION,
Lucene99FlatVectorsFormat.VECTOR_DATA_CODEC_NAME,
// Flat formats are used to randomly access vectors from their node ID that is stored
// in the HNSW graph.
state.context.withReadAdvice(ReadAdvice.RANDOM));
state.context.withReadAdvice(readAdvice));
success = true;
} finally {
if (success == false) {
Expand Down Expand Up @@ -171,17 +168,6 @@ public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorData);
}

@Override
public FlatVectorsReader getMergeInstance() {
try {
// Update the read advice since vectors are guaranteed to be accessed sequentially for merge
this.vectorData.updateReadAdvice(ReadAdvice.SEQUENTIAL);
return this;
} catch (IOException exception) {
throw new UncheckedIOException(exception);
}
}

private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
Expand Down Expand Up @@ -262,13 +248,6 @@ public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) thr
target);
}

@Override
public void finishMerge() throws IOException {
// This makes sure that the access pattern hint is reverted back since HNSW implementation
// needs it
this.vectorData.updateReadAdvice(ReadAdvice.RANDOM);
}

@Override
public void close() throws IOException {
IOUtils.close(vectorData);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,18 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsWriter.class);

private final SegmentWriteState segmentWriteState;
private final ReadAdvice readAdvice;
private final IndexOutput meta, vectorData;

private final List<FieldWriter<?>> fields = new ArrayList<>();

private boolean finished;

public Lucene99FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scorer)
throws IOException {
public Lucene99FlatVectorsWriter(
SegmentWriteState state, FlatVectorsScorer scorer, ReadAdvice readAdvice) throws IOException {
super(scorer);
segmentWriteState = state;
this.readAdvice = readAdvice;
String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name, state.segmentSuffix, Lucene99FlatVectorsFormat.META_EXTENSION);
Expand Down Expand Up @@ -282,7 +285,7 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
// to perform random reads.
vectorDataInput =
segmentWriteState.directory.openInput(
tempVectorData.getName(), IOContext.DEFAULT.withReadAdvice(ReadAdvice.RANDOM));
tempVectorData.getName(), IOContext.DEFAULT.withReadAdvice(readAdvice));
jimczi marked this conversation as resolved.
Show resolved Hide resolved
// copy the temporary file vectors to the actual data file
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
CodecUtil.retrieveChecksum(vectorDataInput);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.util.hnsw.HnswGraph;

/**
Expand Down Expand Up @@ -134,8 +135,15 @@ public Lucene99HnswScalarQuantizedVectorsFormat(
} else {
this.mergeExec = null;
}

/**
* Defines the format used for storing, reading, and merging vectors on disk. Flat formats
* enable random access to vectors based on their node ID, as recorded in the HNSW graph. To
* ensure consistent access, the {@link ReadAdvice#RANDOM} read advice is used.
*/
this.flatVectorsFormat =
new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval, bits, compress);
new Lucene99ScalarQuantizedVectorsFormat(
confidenceInterval, bits, compress, ReadAdvice.RANDOM);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;

Expand Down Expand Up @@ -130,9 +131,14 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
*/
private final int beamWidth;

/** The format for storing, reading, and merging vectors on disk. */
/**
* Defines the format used for storing, reading, and merging vectors on disk. Flat formats enable
* random access to vectors based on their node ID, as recorded in the HNSW graph. To ensure
* consistent access, the {@link ReadAdvice#RANDOM} read advice is used.
*/
private static final FlatVectorsFormat flatVectorsFormat =
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
new Lucene99FlatVectorsFormat(
FlatVectorScorerUtil.getLucene99FlatVectorsScorer(), ReadAdvice.RANDOM);

private final int numMergeWorkers;
private final TaskExecutor mergeExec;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,11 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader

private final FlatVectorsReader flatVectorsReader;
private final FieldInfos fieldInfos;
private final IntObjectHashMap<FieldEntry> fields;
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorIndex;

public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader)
throws IOException {
this.fields = new IntObjectHashMap<>();
this.flatVectorsReader = flatVectorsReader;
boolean success = false;
this.fieldInfos = state.fieldInfos;
Expand Down Expand Up @@ -114,24 +113,6 @@ public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatV
}
}

private Lucene99HnswVectorsReader(
Lucene99HnswVectorsReader reader, FlatVectorsReader flatVectorsReader) {
this.flatVectorsReader = flatVectorsReader;
this.fieldInfos = reader.fieldInfos;
this.fields = reader.fields;
this.vectorIndex = reader.vectorIndex;
}

@Override
public KnnVectorsReader getMergeInstance() {
return new Lucene99HnswVectorsReader(this, this.flatVectorsReader.getMergeInstance());
}

@Override
public void finishMerge() throws IOException {
flatVectorsReader.finishMerge();
}

private static IndexInput openDataInput(
SegmentReadState state,
int versionMeta,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.ReadAdvice;

/**
* Format supporting vector quantization, storage, and retrieval
Expand All @@ -50,8 +51,14 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
static final String META_EXTENSION = "vemq";
static final String VECTOR_DATA_EXTENSION = "veq";

/**
* Defines the format used for storing, reading, and merging raw vectors on disk. For this format,
* the {@link ReadAdvice#SEQUENTIAL} read advice is employed, as nearest neighbors are retrieved
* exclusively using a brute-force approach.
*/
private static final FlatVectorsFormat rawVectorFormat =
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
new Lucene99FlatVectorsFormat(
FlatVectorScorerUtil.getLucene99FlatVectorsScorer(), ReadAdvice.SEQUENTIAL);

/** The minimum confidence interval */
private static final float MINIMUM_CONFIDENCE_INTERVAL = 0.9f;
Expand All @@ -71,10 +78,15 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
final byte bits;
final boolean compress;
final Lucene99ScalarQuantizedVectorScorer flatVectorScorer;
final ReadAdvice readAdvice;

/** Constructs a format using default graph construction parameters */
public Lucene99ScalarQuantizedVectorsFormat() {
this(null, 7, false);
/**
* For this format, the {@link ReadAdvice#SEQUENTIAL} read advice is employed, as nearest
* neighbors are retrieved exclusively using a brute-force approach.
*/
this(null, 7, false, ReadAdvice.SEQUENTIAL);
}

/**
Expand All @@ -91,7 +103,7 @@ public Lucene99ScalarQuantizedVectorsFormat() {
* during searching, at some decode speed penalty.
*/
public Lucene99ScalarQuantizedVectorsFormat(
Float confidenceInterval, int bits, boolean compress) {
Float confidenceInterval, int bits, boolean compress, ReadAdvice readAdvice) {
super(NAME);
if (confidenceInterval != null
&& confidenceInterval != DYNAMIC_CONFIDENCE_INTERVAL
Expand Down Expand Up @@ -119,6 +131,7 @@ public Lucene99ScalarQuantizedVectorsFormat(
this.compress = compress;
this.flatVectorScorer =
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
this.readAdvice = readAdvice;
}

public static float calculateDefaultConfidenceInterval(int vectorDimension) {
Expand Down Expand Up @@ -151,12 +164,13 @@ public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOExceptio
bits,
compress,
rawVectorFormat.fieldsWriter(state),
flatVectorScorer);
flatVectorScorer,
readAdvice);
}

@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99ScalarQuantizedVectorsReader(
state, rawVectorFormat.fieldsReader(state), flatVectorScorer);
state, rawVectorFormat.fieldsReader(state), flatVectorScorer, readAdvice);
}
}
Loading
Loading