Skip to content

Don't accept clustersPerNeighborhood lower than 2 #130526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions muted-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -545,9 +545,6 @@ tests:
- class: org.elasticsearch.compute.aggregation.TopIntAggregatorFunctionTests
method: testManyInitialManyPartialFinalRunnerThrowing
issue: https://github.com/elastic/elasticsearch/issues/130145
- class: org.elasticsearch.index.codec.vectors.cluster.KMeansLocalTests
method: testKMeansNeighbors
issue: https://github.com/elastic/elasticsearch/issues/130258
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
method: test {p0=esql/10_basic/basic with documents_found}
issue: https://github.com/elastic/elasticsearch/issues/130256
Expand Down Expand Up @@ -575,9 +572,6 @@ tests:
- class: org.elasticsearch.test.rest.yaml.RcsCcsCommonYamlTestSuiteIT
method: test {p0=msearch/20_typed_keys/Multisearch test with typed_keys parameter for sampler and significant terms}
issue: https://github.com/elastic/elasticsearch/issues/130472
- class: org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeansTests
method: testHKmeans
issue: https://github.com/elastic/elasticsearch/issues/130497
- class: org.elasticsearch.xpack.esql.action.EsqlActionBreakerIT
method: testProjectWhere
issue: https://github.com/elastic/elasticsearch/issues/130504
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) {
int localSampleSize = Math.min(kMeansIntermediate.centroids().length * samplesPerCluster / 2, vectors.size());
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA);
kMeansLocal.cluster(vectors, kMeansIntermediate, true);
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations);
kMeansLocal.cluster(vectors, kMeansIntermediate, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA);
}

return kMeansIntermediate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,10 @@ class KMeansLocal {

final int sampleSize;
final int maxIterations;
final int clustersPerNeighborhood;
final float soarLambda;

KMeansLocal(int sampleSize, int maxIterations, int clustersPerNeighborhood, float soarLambda) {
KMeansLocal(int sampleSize, int maxIterations) {
this.sampleSize = sampleSize;
this.maxIterations = maxIterations;
this.clustersPerNeighborhood = clustersPerNeighborhood;
this.soarLambda = soarLambda;
}

KMeansLocal(int sampleSize, int maxIterations) {
this(sampleSize, maxIterations, -1, -1f);
}

/**
Expand Down Expand Up @@ -198,8 +190,13 @@ private void computeNeighborhoods(float[][] centers, List<NeighborHood> neighbor
}
}

private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighborhoods, float[][] centroids, int[] assignments)
throws IOException {
private int[] assignSpilled(
FloatVectorValues vectors,
List<NeighborHood> neighborhoods,
float[][] centroids,
int[] assignments,
float soarLambda
) throws IOException {
// SOAR uses an adjusted distance for assigning spilled documents which is
// given by:
//
Expand Down Expand Up @@ -264,6 +261,10 @@ private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighb
return spilledAssignments;
}

/**
 * Immutable pair of an array of neighbor indices and a {@code maxIntraDistance} value.
 * NOTE(review): presumably one instance is kept per centroid and {@code neighbors} holds the
 * indices of its nearby centroids; confirm against computeNeighborhoods.
 */
record NeighborHood(int[] neighbors, float maxIntraDistance) {
// Shared instance with a zero-length neighbor array and an infinite maxIntraDistance.
static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY);
}

/**
* Clusters using Lloyd's k-means algorithm; this variant is not neighbor aware.
*
Expand All @@ -274,11 +275,7 @@ private int[] assignSpilled(FloatVectorValues vectors, List<NeighborHood> neighb
* @throws IOException is thrown if vectors is inaccessible
*/
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) throws IOException {
cluster(vectors, kMeansIntermediate, false);
}

record NeighborHood(int[] neighbors, float maxIntraDistance) {
static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY);
doCluster(vectors, kMeansIntermediate, -1, -1);
}

/**
Expand All @@ -290,12 +287,23 @@ record NeighborHood(int[] neighbors, float maxIntraDistance) {
* the prior assignments of the given vectors; care should be taken in
* passing in a valid output object with a centroids array that is the size of centroids expected
* and assignments that are the same size as the vectors. The SOAR assignments are overwritten by this operation.
* @param neighborAware whether nearby neighboring centroids and their vectors should be used to update the centroid positions,
* implies SOAR assignments
* @throws IOException is thrown if vectors is inaccessible
* @param clustersPerNeighborhood number of nearby neighboring centroids to be used to update the centroid positions; must be at least 2.
* @param soarLambda lambda used for SOAR assignments
*
* @throws IOException is thrown if vectors is inaccessible
* @throws IllegalArgumentException is thrown if clustersPerNeighborhood is less than 2
*/
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, boolean neighborAware) throws IOException {
void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda)
    throws IOException {
    // Validate up front: values below 2 are rejected before any clustering work starts.
    final boolean neighborhoodTooSmall = clustersPerNeighborhood < 2;
    if (neighborhoodTooSmall) {
        final String message = "clustersPerNeighborhood must be at least 2, got [" + clustersPerNeighborhood + "]";
        throw new IllegalArgumentException(message);
    }
    // Arguments are valid; hand off to the shared implementation.
    doCluster(vectors, kMeansIntermediate, clustersPerNeighborhood, soarLambda);
}

private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda)
throws IOException {
float[][] centroids = kMeansIntermediate.centroids();
boolean neighborAware = clustersPerNeighborhood != -1 && centroids.length > 1;

List<NeighborHood> neighborhoods = null;
// if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering
Expand All @@ -308,11 +316,11 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
computeNeighborhoods(centroids, neighborhoods, clustersPerNeighborhood);
}
cluster(vectors, kMeansIntermediate, neighborhoods);
if (neighborAware && clustersPerNeighborhood > 0) {
if (neighborAware) {
int[] assignments = kMeansIntermediate.assignments();
assert assignments != null;
assert assignments.length == vectors.size();
kMeansIntermediate.setSoarAssignments(assignSpilled(vectors, neighborhoods, centroids, assignments));
kMeansIntermediate.setSoarAssignments(assignSpilled(vectors, neighborhoods, centroids, assignments, soarLambda));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public void testHKmeans() throws IOException {
int dims = random().nextInt(2, 20);
int sampleSize = random().nextInt(100, nVectors + 1);
int maxIterations = random().nextInt(0, 100);
int clustersPerNeighborhood = random().nextInt(0, 512);
int clustersPerNeighborhood = random().nextInt(2, 512);
float soarLambda = random().nextFloat(0.5f, 1.5f);
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,32 @@
import java.util.ArrayList;
import java.util.List;

import static org.hamcrest.Matchers.containsString;

public class KMeansLocalTests extends ESTestCase {

// Verifies that cluster(...) rejects any clustersPerNeighborhood value of 1 or lower.
public void testIllegalClustersPerNeighborhood() {
    final KMeansLocal localKMeans = new KMeansLocal(randomInt(), randomInt());
    final KMeansIntermediate emptyIntermediate = new KMeansIntermediate(new float[0][], new int[0], i -> i);
    final IllegalArgumentException thrown = expectThrows(IllegalArgumentException.class, () -> {
        // Random values are drawn in the same order as the call arguments are evaluated.
        final FloatVectorValues noVectors = FloatVectorValues.fromFloats(List.of(), randomInt(1024));
        final int tooFewClusters = randomIntBetween(Integer.MIN_VALUE, 1);
        final float anyLambda = randomFloat();
        localKMeans.cluster(noVectors, emptyIntermediate, tooFewClusters, anyLambda);
    });
    assertThat(thrown.getMessage(), containsString("clustersPerNeighborhood must be at least 2"));
}

public void testKMeansNeighbors() throws IOException {
int nClusters = random().nextInt(1, 10);
int nVectors = random().nextInt(nClusters * 100, nClusters * 200);
int dims = random().nextInt(2, 20);
int sampleSize = random().nextInt(100, nVectors + 1);
int maxIterations = random().nextInt(0, 100);
int clustersPerNeighborhood = random().nextInt(0, 512);
int clustersPerNeighborhood = random().nextInt(2, 512);
float soarLambda = random().nextFloat(0.5f, 1.5f);
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);

Expand All @@ -49,8 +66,8 @@ public void testKMeansNeighbors() throws IOException {
}

KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> assignmentOrdinals[i]);
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations, clustersPerNeighborhood, soarLambda);
kMeansLocal.cluster(vectors, kMeansIntermediate, true);
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations);
kMeansLocal.cluster(vectors, kMeansIntermediate, clustersPerNeighborhood, soarLambda);

assertEquals(nClusters, centroids.length);
assertNotNull(kMeansIntermediate.soarAssignments());
Expand Down Expand Up @@ -90,8 +107,8 @@ public void testKMeansNeighborsAllZero() throws IOException {
}

KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> assignmentOrdinals[i]);
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations, clustersPerNeighborhood, soarLambda);
kMeansLocal.cluster(fvv, kMeansIntermediate, true);
KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations);
kMeansLocal.cluster(fvv, kMeansIntermediate, clustersPerNeighborhood, soarLambda);

assertEquals(nClusters, centroids.length);
assertNotNull(kMeansIntermediate.soarAssignments());
Expand Down