diff --git a/muted-tests.yml b/muted-tests.yml index e51f39c997a88..73e0407f40651 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -545,9 +545,6 @@ tests: - class: org.elasticsearch.compute.aggregation.TopIntAggregatorFunctionTests method: testManyInitialManyPartialFinalRunnerThrowing issue: https://github.com/elastic/elasticsearch/issues/130145 -- class: org.elasticsearch.index.codec.vectors.cluster.KMeansLocalTests - method: testKMeansNeighbors - issue: https://github.com/elastic/elasticsearch/issues/130258 - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=esql/10_basic/basic with documents_found} issue: https://github.com/elastic/elasticsearch/issues/130256 @@ -575,9 +572,6 @@ tests: - class: org.elasticsearch.test.rest.yaml.RcsCcsCommonYamlTestSuiteIT method: test {p0=msearch/20_typed_keys/Multisearch test with typed_keys parameter for sampler and significant terms} issue: https://github.com/elastic/elasticsearch/issues/130472 -- class: org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeansTests - method: testHKmeans - issue: https://github.com/elastic/elasticsearch/issues/130497 - class: org.elasticsearch.xpack.esql.action.EsqlActionBreakerIT method: testProjectWhere issue: https://github.com/elastic/elasticsearch/issues/130504 diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java index 34d802e5f0aaf..b0709417b350f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java @@ -68,8 +68,8 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize); if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < 
vectors.size()) { int localSampleSize = Math.min(kMeansIntermediate.centroids().length * samplesPerCluster / 2, vectors.size()); - KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA); - kMeansLocal.cluster(vectors, kMeansIntermediate, true); + KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations); + kMeansLocal.cluster(vectors, kMeansIntermediate, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA); } return kMeansIntermediate; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java index 05372851768e7..5f2ff53e2a4f6 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -35,18 +35,10 @@ class KMeansLocal { final int sampleSize; final int maxIterations; - final int clustersPerNeighborhood; - final float soarLambda; - KMeansLocal(int sampleSize, int maxIterations, int clustersPerNeighborhood, float soarLambda) { + KMeansLocal(int sampleSize, int maxIterations) { this.sampleSize = sampleSize; this.maxIterations = maxIterations; - this.clustersPerNeighborhood = clustersPerNeighborhood; - this.soarLambda = soarLambda; - } - - KMeansLocal(int sampleSize, int maxIterations) { - this(sampleSize, maxIterations, -1, -1f); } /** @@ -198,8 +190,13 @@ private void computeNeighborhoods(float[][] centers, List neighbor } } - private int[] assignSpilled(FloatVectorValues vectors, List neighborhoods, float[][] centroids, int[] assignments) - throws IOException { + private int[] assignSpilled( + FloatVectorValues vectors, + List neighborhoods, + float[][] centroids, + int[] assignments, + float soarLambda + ) throws IOException { // SOAR uses an adjusted distance for assigning spilled documents which is // given by: // @@ -264,6 +261,10 @@ 
private int[] assignSpilled(FloatVectorValues vectors, List neighb return spilledAssignments; } + record NeighborHood(int[] neighbors, float maxIntraDistance) { + static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY); + } + /** * cluster using a lloyd k-means algorithm that is not neighbor aware * @@ -274,11 +275,7 @@ private int[] assignSpilled(FloatVectorValues vectors, List neighb * @throws IOException is thrown if vectors is inaccessible */ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) throws IOException { - cluster(vectors, kMeansIntermediate, false); - } - - record NeighborHood(int[] neighbors, float maxIntraDistance) { - static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY); + doCluster(vectors, kMeansIntermediate, -1, -1); } /** @@ -290,12 +287,23 @@ record NeighborHood(int[] neighbors, float maxIntraDistance) { * the prior assignments of the given vectors; care should be taken in * passing in a valid output object with a centroids array that is the size of centroids expected * and assignments that are the same size as the vectors. The SOAR assignments are overwritten by this operation. - * @param neighborAware whether nearby neighboring centroids and their vectors should be used to update the centroid positions, - * implies SOAR assignments - * @throws IOException is thrown if vectors is inaccessible + * @param clustersPerNeighborhood number of nearby neighboring centroids to be used to update the centroid positions. 
+ * @param soarLambda lambda used for SOAR assignments + * @throws IllegalArgumentException if clustersPerNeighborhood is less than 2 + * @throws IOException is thrown if vectors is inaccessible */ - void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, boolean neighborAware) throws IOException { + void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda) + throws IOException { + if (clustersPerNeighborhood < 2) { + throw new IllegalArgumentException("clustersPerNeighborhood must be at least 2, got [" + clustersPerNeighborhood + "]"); + } + doCluster(vectors, kMeansIntermediate, clustersPerNeighborhood, soarLambda); + } + + private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, int clustersPerNeighborhood, float soarLambda) + throws IOException { float[][] centroids = kMeansIntermediate.centroids(); + boolean neighborAware = clustersPerNeighborhood != -1 && centroids.length > 1; List neighborhoods = null; // if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering @@ -308,11 +316,11 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b computeNeighborhoods(centroids, neighborhoods, clustersPerNeighborhood); } cluster(vectors, kMeansIntermediate, neighborhoods); - if (neighborAware && clustersPerNeighborhood > 0) { + if (neighborAware) { int[] assignments = kMeansIntermediate.assignments(); assert assignments != null; assert assignments.length == vectors.size(); - kMeansIntermediate.setSoarAssignments(assignSpilled(vectors, neighborhoods, centroids, assignments)); + kMeansIntermediate.setSoarAssignments(assignSpilled(vectors, neighborhoods, centroids, assignments, soarLambda)); } } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java index 
4c481ca4a5f36..e8578338f09a3 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java @@ -23,7 +23,7 @@ public void testHKmeans() throws IOException { int dims = random().nextInt(2, 20); int sampleSize = random().nextInt(100, nVectors + 1); int maxIterations = random().nextInt(0, 100); - int clustersPerNeighborhood = random().nextInt(0, 512); + int clustersPerNeighborhood = random().nextInt(2, 512); float soarLambda = random().nextFloat(0.5f, 1.5f); FloatVectorValues vectors = generateData(nVectors, dims, nClusters); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java index c0a0ca8341129..a2d34d28f3784 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java @@ -17,15 +17,32 @@ import java.util.ArrayList; import java.util.List; +import static org.hamcrest.Matchers.containsString; + public class KMeansLocalTests extends ESTestCase { + public void testIllegalClustersPerNeighborhood() { + KMeansLocal kMeansLocal = new KMeansLocal(randomInt(), randomInt()); + KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(new float[0][], new int[0], i -> i); + IllegalArgumentException ex = expectThrows( + IllegalArgumentException.class, + () -> kMeansLocal.cluster( + FloatVectorValues.fromFloats(List.of(), randomInt(1024)), + kMeansIntermediate, + randomIntBetween(Integer.MIN_VALUE, 1), + randomFloat() + ) + ); + assertThat(ex.getMessage(), containsString("clustersPerNeighborhood must be at least 2")); + } + public void testKMeansNeighbors() throws IOException { int nClusters = random().nextInt(1, 10); int nVectors = random().nextInt(nClusters 
* 100, nClusters * 200); int dims = random().nextInt(2, 20); int sampleSize = random().nextInt(100, nVectors + 1); int maxIterations = random().nextInt(0, 100); - int clustersPerNeighborhood = random().nextInt(0, 512); + int clustersPerNeighborhood = random().nextInt(2, 512); float soarLambda = random().nextFloat(0.5f, 1.5f); FloatVectorValues vectors = generateData(nVectors, dims, nClusters); @@ -49,8 +66,8 @@ public void testKMeansNeighbors() throws IOException { } KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> assignmentOrdinals[i]); - KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations, clustersPerNeighborhood, soarLambda); - kMeansLocal.cluster(vectors, kMeansIntermediate, true); + KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations); + kMeansLocal.cluster(vectors, kMeansIntermediate, clustersPerNeighborhood, soarLambda); assertEquals(nClusters, centroids.length); assertNotNull(kMeansIntermediate.soarAssignments()); @@ -90,8 +107,8 @@ public void testKMeansNeighborsAllZero() throws IOException { } KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, i -> assignmentOrdinals[i]); - KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations, clustersPerNeighborhood, soarLambda); - kMeansLocal.cluster(fvv, kMeansIntermediate, true); + KMeansLocal kMeansLocal = new KMeansLocal(sampleSize, maxIterations); + kMeansLocal.cluster(fvv, kMeansIntermediate, clustersPerNeighborhood, soarLambda); assertEquals(nClusters, centroids.length); assertNotNull(kMeansIntermediate.soarAssignments());