Use Vector API to decode BKD docIds #14203

Open · wants to merge 32 commits into base: main
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
@@ -105,6 +105,8 @@ Optimizations

* GITHUB#14176: Reduce overhead when visiting bpv24-encoded doc ids in BKD leaves. (Guo Feng)

* GITHUB#14203: Use Vector API to decode BKD docIds. (Guo Feng)

Bug Fixes
---------------------

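For orientation before the code: the CHANGES entry above boils down to replacing scalar shift-and-mask decode loops with explicit jdk.incubator.vector calls. A minimal sketch of the technique, not the PR's actual code (the class and method names are invented, and it must run with --add-modules jdk.incubator.vector):

import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

/** Illustrative sketch of a vectorized shift-and-mask decode kernel. */
final class VectorShiftMaskSketch {
  private static final VectorSpecies<Integer> SPECIES = IntVector.SPECIES_PREFERRED;

  /** Computes dst[i] = (src[i] >>> shift) & mask for i in [0, count). */
  static void shiftMask(int[] src, int[] dst, int count, int shift, int mask) {
    int i = 0;
    for (int bound = SPECIES.loopBound(count); i < bound; i += SPECIES.length()) {
      IntVector.fromArray(SPECIES, src, i)
          .lanewise(VectorOperators.LSHR, shift)
          .lanewise(VectorOperators.AND, mask)
          .intoArray(dst, i);
    }
    for (; i < count; i++) {
      dst[i] = (src[i] >>> shift) & mask; // scalar tail for leftover lanes
    }
  }
}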
127 changes: 127 additions & 0 deletions lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/BKDCodecBenchmark.java
@@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.benchmark.jmh;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.bkd.BKDConfig;
import org.apache.lucene.util.bkd.BKDWriter;
import org.apache.lucene.util.bkd.DocIdsWriter;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 3, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(value = 1)
public class BKDCodecBenchmark {

private static final int SIZE = BKDConfig.DEFAULT_MAX_POINTS_IN_LEAF_NODE;

@Param({"16", "24"})
public int bpv;

private Directory dir;
private DocIdsWriter legacy;
private IndexInput legacyIn;
private DocIdsWriter vector;
private IndexInput vectorIn;
private int[] docs;

@Setup(Level.Trial)
public void setupTrial() throws IOException {
Path path = Files.createTempDirectory("bkd");
dir = MMapDirectory.open(path);
docs = new int[SIZE];
legacy = new DocIdsWriter(SIZE, BKDWriter.VERSION_META_FILE);
legacyIn = writeDocIds("legacy", docs, legacy);
vector = new DocIdsWriter(SIZE, BKDWriter.VERSION_VECTORIZED_DOCID);
vectorIn = writeDocIds("current", docs, vector);
}

private IndexInput writeDocIds(String file, int[] docs, DocIdsWriter writer) throws IOException {
try (IndexOutput out = dir.createOutput(file, IOContext.DEFAULT)) {
Random r = new Random(0);
// span the full bpv range so the writer picks the raw bpv encoding, not a cluster/delta encoding
docs[0] = 1;
docs[1] = (1 << bpv) - 1;
for (int i = 2; i < SIZE; ++i) {
docs[i] = r.nextInt(1 << bpv);
}
writer.writeDocIds(docs, 0, SIZE, out);
}
return dir.openInput(file, IOContext.DEFAULT);
}

@Setup(Level.Invocation)
public void setupInvocation() throws IOException {
legacyIn.seek(0);
vectorIn.seek(0);
}

@TearDown(Level.Trial)
public void tearDownTrial() throws IOException {
IOUtils.close(legacyIn, vectorIn, dir);
}

private int count(int iter) {
// read a partial block (SIZE - 1) every 20th iteration to exercise the tail path
return iter % 20 == 0 ? SIZE - 1 : SIZE;
}

@Benchmark
public void scalar(Blackhole bh) throws IOException {
for (int i = 0; i <= 100; i++) {
int count = count(i);
legacy.readInts(legacyIn, count, docs);
bh.consume(docs);
setupInvocation();
}
}

@Benchmark
public void vector(Blackhole bh) throws IOException {
for (int i = 0; i <= 100; i++) {
int count = count(i);
vector.readInts(vectorIn, count, docs);
bh.consume(docs);
setupInvocation();
}
}
}
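If the benchmark-jmh module's usual workflow applies, a local run looks roughly like java -jar lucene/benchmark-jmh/build/benchmarks/*.jar BKDCodecBenchmark (the exact jar path depends on the build and is not verified here). If the vectorized DocIdsWriter path is gated on the Panama incubator module, the JVM also needs --add-modules jdk.incubator.vector; otherwise the scalar fallback is what gets measured.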
78 changes: 78 additions & 0 deletions lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/ShiftMaskBenchmark.java
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.benchmark.jmh;

import java.io.IOException;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 3, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(value = 1)
// no-commit: remove before merge
public class ShiftMaskBenchmark {

private int[] counts;
private int[] source;
private int[] dest;

@Setup(Level.Trial)
public void setupTrial() throws IOException {
Random r = new Random(0);
source = new int[1024];
dest = new int[1024];
for (int i = 0; i < 512; i++) {
source[i] = r.nextInt(1 << 24);
}
counts = new int[] {255, 256, 511, 512};
}

@Benchmark
public void varOffset(Blackhole bh) throws IOException {
for (int count : counts) {
shiftMask(source, dest, count & 0x1, count, 8, 0xFF);
}
bh.consume(dest); // keep dest live so the work cannot be dead-code eliminated
}

@Benchmark
public void fixOffset(Blackhole bh) throws IOException {
for (int count : counts) {
shiftMask(source, dest, 1, count, 8, 0xFF);
}
bh.consume(dest); // keep dest live so the work cannot be dead-code eliminated
}

private static void shiftMask(int[] src, int[] dst, int offset, int count, int shift, int mask) {
for (int i = 0; i < count; i++) {
dst[i] = (src[i + offset] >> shift) & mask;
}
}
}
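Presumably what this pair isolates: both variants compute dst[i] = (src[i + offset] >> 8) & 0xFF over identical data, but fixOffset hands the JIT a constant source offset while varOffset derives it from count at run time. Any throughput gap between them therefore shows how much C2's auto-vectorization suffers when the read offset is not a compile-time constant, which is the same shape as the offset-dependent reads in the vectorized doc-ID decode.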
lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90PointsFormat.java
@@ -17,11 +17,13 @@
package org.apache.lucene.codecs.lucene90;

import java.io.IOException;
import java.util.Map;
import org.apache.lucene.codecs.PointsFormat;
import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.PointsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.util.bkd.BKDWriter;

/**
* Lucene 9.0 point format, which encodes dimensional values in a block KD-tree structure for fast
@@ -59,18 +61,39 @@ public final class Lucene90PointsFormat extends PointsFormat {
public static final String META_EXTENSION = "kdm";

static final int VERSION_START = 0;
static final int VERSION_CURRENT = VERSION_START;
static final int VERSION_BKD_VECTORIZED_BPV24 = 1;
static final int VERSION_CURRENT = VERSION_BKD_VECTORIZED_BPV24;

private static final Map<Integer, Integer> VERSION_TO_BKD_VERSION =
Map.of(
VERSION_START, BKDWriter.VERSION_META_FILE,
VERSION_BKD_VECTORIZED_BPV24, BKDWriter.VERSION_VECTORIZED_DOCID);

private final int version;

/** Default constructor, writing the current format version */
public Lucene90PointsFormat() {}
public Lucene90PointsFormat() {
this(VERSION_CURRENT);
}

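/** Expert constructor, taking an explicit on-disk format version (mainly useful for testing). */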
public Lucene90PointsFormat(int version) {
if (VERSION_TO_BKD_VERSION.containsKey(version) == false) {
throw new IllegalArgumentException("Invalid version: " + version);
}
this.version = version;
}

@Override
public PointsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene90PointsWriter(state);
return new Lucene90PointsWriter(state, version);
}

@Override
public PointsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene90PointsReader(state);
}

static int bkdVersion(int version) {
return VERSION_TO_BKD_VERSION.get(version);
}
}
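A note on the asymmetry above: the writer must know the version up front to choose an encoding, while fieldsReader takes no version, presumably because the reader recovers it from the file headers on open. A hedged usage sketch, assuming the literal 0 continues to mirror the package-private VERSION_START:

// Choosing the on-disk layout explicitly, e.g. for a back-compat test.
PointsFormat legacyLayout = new Lucene90PointsFormat(0);  // maps to BKDWriter.VERSION_META_FILE
PointsFormat currentLayout = new Lucene90PointsFormat();  // maps to BKDWriter.VERSION_VECTORIZED_DOCID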
lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90PointsWriter.java
@@ -46,16 +46,18 @@ public class Lucene90PointsWriter extends PointsWriter {
final SegmentWriteState writeState;
final int maxPointsInLeafNode;
final double maxMBSortInHeap;
final int version;
private boolean finished;

/** Full constructor */
public Lucene90PointsWriter(
SegmentWriteState writeState, int maxPointsInLeafNode, double maxMBSortInHeap)
SegmentWriteState writeState, int maxPointsInLeafNode, double maxMBSortInHeap, int version)
throws IOException {
assert writeState.fieldInfos.hasPointValues();
this.writeState = writeState;
this.maxPointsInLeafNode = maxPointsInLeafNode;
this.maxMBSortInHeap = maxMBSortInHeap;
this.version = version;
String dataFileName =
IndexFileNames.segmentFileName(
writeState.segmentInfo.name,
@@ -105,15 +107,22 @@ public Lucene90PointsWriter(
}
}

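/** Convenience constructor that writes with the current format version. */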
public Lucene90PointsWriter(
SegmentWriteState writeState, int maxPointsInLeafNode, double maxMBSortInHeap)
throws IOException {
this(writeState, maxPointsInLeafNode, maxMBSortInHeap, Lucene90PointsFormat.VERSION_CURRENT);
}

/**
* Uses the default values for {@code maxPointsInLeafNode} (512) and {@code maxMBSortInHeap}
* (16.0)
*/
public Lucene90PointsWriter(SegmentWriteState writeState) throws IOException {
public Lucene90PointsWriter(SegmentWriteState writeState, int version) throws IOException {
this(
writeState,
BKDConfig.DEFAULT_MAX_POINTS_IN_LEAF_NODE,
BKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP);
BKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP,
version);
}

@Override
@@ -135,7 +144,8 @@ public void writeField(FieldInfo fieldInfo, PointsReader reader) throws IOException {
writeState.segmentInfo.name,
config,
maxMBSortInHeap,
values.size())) {
values.size(),
Lucene90PointsFormat.bkdVersion(version))) {

if (values instanceof MutablePointTree) {
IORunnable finalizer =
@@ -233,7 +243,8 @@ public void merge(MergeState mergeState) throws IOException {
writeState.segmentInfo.name,
config,
maxMBSortInHeap,
totMaxSize)) {
totMaxSize,
Lucene90PointsFormat.bkdVersion(version))) {
List<PointValues> pointValues = new ArrayList<>();
List<MergeState.DocMap> docMaps = new ArrayList<>();
for (int i = 0; i < mergeState.pointsReaders.length; i++) {
10 changes: 5 additions & 5 deletions lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java
@@ -33,7 +33,6 @@
* @lucene.experimental
*/
public class BKDReader extends PointValues {

final BKDConfig config;
final int numLeaves;
final IndexInput in;
@@ -261,7 +260,7 @@ private BKDPointTree(
1,
minPackedValue,
maxPackedValue,
new BKDReaderDocIDSetIterator(config.maxPointsInLeafNode()),
new BKDReaderDocIDSetIterator(config.maxPointsInLeafNode(), version),
new byte[config.packedBytesLength()],
new byte[config.packedIndexBytesLength()],
new byte[config.packedIndexBytesLength()],
@@ -590,7 +589,8 @@ public void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws IOException {
// How many points are stored in this leaf cell:
int count = leafNodes.readVInt();
// No need to call grow(), it has been called up-front
docIdsWriter.readInts(leafNodes, count, visitor);
// Borrow scratchIterator.docIDs as the decoding buffer
docIdsWriter.readInts(leafNodes, count, visitor, scratchIterator.docIDs);
} else {
pushLeft();
addAll(visitor, grown);
@@ -1028,9 +1028,9 @@ private static class BKDReaderDocIDSetIterator extends DocIdSetIterator {
final int[] docIDs;
private final DocIdsWriter docIdsWriter;

public BKDReaderDocIDSetIterator(int maxPointsInLeafNode) {
public BKDReaderDocIDSetIterator(int maxPointsInLeafNode, int version) {
this.docIDs = new int[maxPointsInLeafNode];
this.docIdsWriter = new DocIdsWriter(maxPointsInLeafNode);
this.docIdsWriter = new DocIdsWriter(maxPointsInLeafNode, version);
}

@Override
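The int[] threaded into readInts above points at a decode-then-visit pattern: the leaf's doc IDs are decoded into the reusable scratch buffer first (where a vectorized kernel can run over plain ints) and handed to the visitor afterwards. A sketch of the plausible shape of that overload, not DocIdsWriter's verified internals:

// Assumed body; the real implementation likely dispatches per encoding and
// may use bulk IntersectVisitor callbacks rather than a per-int loop.
void readInts(IndexInput in, int count, IntersectVisitor visitor, int[] scratch)
    throws IOException {
  readInts(in, count, scratch);   // decode the whole leaf into the buffer
  for (int i = 0; i < count; i++) {
    visitor.visit(scratch[i]);    // then visit from the plain array
  }
}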