Skip to content

Adjust IVF fixup phase to sometimes bypass some of the neighborhood calculations #130490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ private static boolean stepLloyd(
float[][] centroids,
float[][] nextCentroids,
int[] assignments,
List<int[]> neighborhoods
List<NeighborHood> neighborhoods
) throws IOException {
boolean changed = false;
int dim = vectors.dimension();
Expand Down Expand Up @@ -124,11 +124,20 @@ private static boolean stepLloyd(
return changed;
}

private static int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
private static int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, NeighborHood neighborhood) {
int bestCentroidOffset = centroidIdx;
assert centroidIdx >= 0 && centroidIdx < centroids.length;
float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
for (int offset : centroidOffsets) {
for (int i = 0; i < neighborhood.neighbors.length; i++) {
int offset = neighborhood.neighbors[i];
// float score = neighborhood.scores[i];
assert offset >= 0 && offset < centroids.length : "Invalid neighbor offset: " + offset;
if (minDsq < neighborhood.maxIntraDistance) {
// if the distance found is smaller than the maximum intra-cluster distance
// we don't consider it for further re-assignment
return bestCentroidOffset;
}
// compute the distance to the centroid
float dsq = VectorUtil.squareDistance(vector, centroids[offset]);
if (dsq < minDsq) {
minDsq = dsq;
Expand All @@ -151,7 +160,7 @@ private static int getBestCentroid(float[][] centroids, float[] vector) {
return bestCentroidOffset;
}

private void computeNeighborhoods(float[][] centers, List<int[]> neighborhoods, int clustersPerNeighborhood) {
private void computeNeighborhoods(float[][] centers, List<NeighborHood> neighborhoods, int clustersPerNeighborhood) {
int k = neighborhoods.size();

if (k == 0 || clustersPerNeighborhood <= 0) {
Expand All @@ -172,14 +181,24 @@ private void computeNeighborhoods(float[][] centers, List<int[]> neighborhoods,

for (int i = 0; i < k; i++) {
NeighborQueue queue = neighborQueues.get(i);
int neighborCount = queue.size();
int[] neighbors = new int[neighborCount];
queue.consumeNodes(neighbors);
neighborhoods.set(i, neighbors);
if (queue.size() == 0) {
// no neighbors, skip
neighborhoods.set(i, NeighborHood.EMPTY);
continue;
}
// consume the queue into the neighbors array and get the maximum intra-cluster distance
int[] neighbors = new int[queue.size()];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks much nicer now

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and faster!

float maxIntraDistance = queue.topScore();
int iter = 0;
while (queue.size() > 0) {
neighbors[neighbors.length - ++iter] = queue.pop();
}
NeighborHood neighborHood = new NeighborHood(neighbors, maxIntraDistance);
neighborhoods.set(i, neighborHood);
}
}

private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods, float[][] centroids, int[] assignments)
private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighborhoods, float[][] centroids, int[] assignments)
throws IOException {
// SOAR uses an adjusted distance for assigning spilled documents which is
// given by:
Expand All @@ -200,7 +219,6 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
float[] currentCentroid = centroids[currAssignment];

// TODO: cache these?
// float vectorCentroidDist = assignmentDistances[i];
float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid);

if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
Expand All @@ -212,24 +230,33 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods

int bestAssignment = -1;
float minSoar = Float.MAX_VALUE;
assert neighborhoods.get(currAssignment) != null;
for (int neighbor : neighborhoods.get(currAssignment)) {
if (neighbor == currAssignment) {
continue;
int centroidCount = centroids.length;
IntToIntFunction centroidOrds = c -> c;
if (neighborhoods != null) {
assert neighborhoods.get(currAssignment) != null;
NeighborHood neighborhood = neighborhoods.get(currAssignment);
centroidCount = neighborhood.neighbors.length;
centroidOrds = c -> neighborhood.neighbors[c];
}
for (int j = 0; j < centroidCount; j++) {
int centroidOrd = centroidOrds.apply(j);
if (centroidOrd == currAssignment) {
continue; // skip the current assignment
}
float[] neighborCentroid = centroids[neighbor];
final float soar;
float[] centroid = centroids[centroidOrd];
float soar;
if (vectorCentroidDist > SOAR_MIN_DISTANCE) {
soar = ESVectorUtil.soarDistance(vector, neighborCentroid, diffs, soarLambda, vectorCentroidDist);
soar = ESVectorUtil.soarDistance(vector, centroid, diffs, soarLambda, vectorCentroidDist);
} else {
// if the vector is very close to the centroid, we look for the second-nearest centroid
soar = VectorUtil.squareDistance(vector, neighborCentroid);
soar = VectorUtil.squareDistance(vector, centroid);
}
if (soar < minSoar) {
bestAssignment = neighbor;
minSoar = soar;
bestAssignment = centroidOrd;
}
}

assert bestAssignment != -1 : "Failed to assign soar vector to centroid";
spilledAssignments[i] = bestAssignment;
}
Expand All @@ -250,6 +277,10 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
cluster(vectors, kMeansIntermediate, false);
}

record NeighborHood(int[] neighbors, float maxIntraDistance) {
static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY);
}

/**
* cluster using a lloyd kmeans algorithm that also considers prior clustered neighborhoods when adjusting centroids
* this also is used to generate the neighborhood aware additional (SOAR) assignments
Expand All @@ -266,8 +297,9 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, boolean neighborAware) throws IOException {
float[][] centroids = kMeansIntermediate.centroids();

List<int[]> neighborhoods = null;
if (neighborAware) {
List<NeighborHood> neighborhoods = null;
// if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering
if (neighborAware && centroids.length > clustersPerNeighborhood) {
int k = centroids.length;
neighborhoods = new ArrayList<>(k);
for (int i = 0; i < k; ++i) {
Expand All @@ -284,7 +316,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
}
}

private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<int[]> neighborhoods) throws IOException {
private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<NeighborHood> neighborhoods)
throws IOException {
float[][] centroids = kMeansIntermediate.centroids();
int k = centroids.length;
int n = vectors.size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,33 +121,6 @@ public int pop() {
return decodeNodeId(heap.pop());
}

public void consumeNodes(int[] dest) {
if (dest.length < size()) {
throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements.");
}
for (int i = 0; i < size(); i++) {
dest[i] = decodeNodeId(heap.get(i + 1));
}
}

public int consumeNodesAndScoresMin(int[] dest, float[] scores) {
if (dest.length < size() || scores.length < size()) {
throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements.");
}
float bestScore = Float.POSITIVE_INFINITY;
int bestIdx = 0;
for (int i = 0; i < size(); i++) {
long heapValue = heap.get(i + 1);
scores[i] = decodeScore(heapValue);
dest[i] = decodeNodeId(heapValue);
if (scores[i] < bestScore) {
bestScore = scores[i];
bestIdx = i;
}
}
return bestIdx;
}

public void clear() {
heap.clear();
}
Expand Down