diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java index 3d037ecf749db..05372851768e7 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -82,7 +82,7 @@ private static boolean stepLloyd( float[][] centroids, float[][] nextCentroids, int[] assignments, - List neighborhoods + List neighborhoods ) throws IOException { boolean changed = false; int dim = vectors.dimension(); @@ -124,11 +124,20 @@ private static boolean stepLloyd( return changed; } - private static int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) { + private static int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, NeighborHood neighborhood) { int bestCentroidOffset = centroidIdx; assert centroidIdx >= 0 && centroidIdx < centroids.length; float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]); - for (int offset : centroidOffsets) { + for (int i = 0; i < neighborhood.neighbors.length; i++) { + int offset = neighborhood.neighbors[i]; + // float score = neighborhood.scores[i]; + assert offset >= 0 && offset < centroids.length : "Invalid neighbor offset: " + offset; + if (minDsq < neighborhood.maxIntraDistance) { + // if the distance found is smaller than the maximum intra-cluster distance + // we don't consider it for further re-assignment + return bestCentroidOffset; + } + // compute the distance to the centroid float dsq = VectorUtil.squareDistance(vector, centroids[offset]); if (dsq < minDsq) { minDsq = dsq; @@ -151,7 +160,7 @@ private static int getBestCentroid(float[][] centroids, float[] vector) { return bestCentroidOffset; } - private void computeNeighborhoods(float[][] centers, List neighborhoods, int clustersPerNeighborhood) { + private void computeNeighborhoods(float[][] centers, List neighborhoods, int clustersPerNeighborhood) { int k = neighborhoods.size(); if (k == 0 || clustersPerNeighborhood <= 0) { @@ -172,14 +181,24 @@ private void computeNeighborhoods(float[][] centers, List neighborhoods, for (int i = 0; i < k; i++) { NeighborQueue queue = neighborQueues.get(i); - int neighborCount = queue.size(); - int[] neighbors = new int[neighborCount]; - queue.consumeNodes(neighbors); - neighborhoods.set(i, neighbors); + if (queue.size() == 0) { + // no neighbors, skip + neighborhoods.set(i, NeighborHood.EMPTY); + continue; + } + // consume the queue into the neighbors array and get the maximum intra-cluster distance + int[] neighbors = new int[queue.size()]; + float maxIntraDistance = queue.topScore(); + int iter = 0; + while (queue.size() > 0) { + neighbors[neighbors.length - ++iter] = queue.pop(); + } + NeighborHood neighborHood = new NeighborHood(neighbors, maxIntraDistance); + neighborhoods.set(i, neighborHood); } } - private int[] assignSpilled(FloatVectorValues vectors, List neighborhoods, float[][] centroids, int[] assignments) + private int[] assignSpilled(FloatVectorValues vectors, List neighborhoods, float[][] centroids, int[] assignments) throws IOException { // SOAR uses an adjusted distance for assigning spilled documents which is // given by: @@ -200,7 +219,6 @@ private int[] assignSpilled(FloatVectorValues vectors, List neighborhoods float[] currentCentroid = centroids[currAssignment]; // TODO: cache these? - // float vectorCentroidDist = assignmentDistances[i]; float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid); if (vectorCentroidDist > SOAR_MIN_DISTANCE) { @@ -212,24 +230,33 @@ private int[] assignSpilled(FloatVectorValues vectors, List neighborhoods int bestAssignment = -1; float minSoar = Float.MAX_VALUE; - assert neighborhoods.get(currAssignment) != null; - for (int neighbor : neighborhoods.get(currAssignment)) { - if (neighbor == currAssignment) { - continue; + int centroidCount = centroids.length; + IntToIntFunction centroidOrds = c -> c; + if (neighborhoods != null) { + assert neighborhoods.get(currAssignment) != null; + NeighborHood neighborhood = neighborhoods.get(currAssignment); + centroidCount = neighborhood.neighbors.length; + centroidOrds = c -> neighborhood.neighbors[c]; + } + for (int j = 0; j < centroidCount; j++) { + int centroidOrd = centroidOrds.apply(j); + if (centroidOrd == currAssignment) { + continue; // skip the current assignment } - float[] neighborCentroid = centroids[neighbor]; - final float soar; + float[] centroid = centroids[centroidOrd]; + float soar; if (vectorCentroidDist > SOAR_MIN_DISTANCE) { - soar = ESVectorUtil.soarDistance(vector, neighborCentroid, diffs, soarLambda, vectorCentroidDist); + soar = ESVectorUtil.soarDistance(vector, centroid, diffs, soarLambda, vectorCentroidDist); } else { // if the vector is very close to the centroid, we look for the second-nearest centroid - soar = VectorUtil.squareDistance(vector, neighborCentroid); + soar = VectorUtil.squareDistance(vector, centroid); } if (soar < minSoar) { - bestAssignment = neighbor; minSoar = soar; + bestAssignment = centroidOrd; } } + assert bestAssignment != -1 : "Failed to assign soar vector to centroid"; spilledAssignments[i] = bestAssignment; } @@ -250,6 +277,10 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t cluster(vectors, kMeansIntermediate, false); } + record NeighborHood(int[] neighbors, float maxIntraDistance) { + static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY); + } + /** * cluster using a lloyd kmeans algorithm that also considers prior clustered neighborhoods when adjusting centroids * this also is used to generate the neighborhood aware additional (SOAR) assignments @@ -266,8 +297,9 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) t void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, boolean neighborAware) throws IOException { float[][] centroids = kMeansIntermediate.centroids(); - List neighborhoods = null; - if (neighborAware) { + List neighborhoods = null; + // if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering + if (neighborAware && centroids.length > clustersPerNeighborhood) { int k = centroids.length; neighborhoods = new ArrayList<>(k); for (int i = 0; i < k; ++i) { @@ -284,7 +316,8 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b } } - private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List neighborhoods) throws IOException { + private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List neighborhoods) + throws IOException { float[][] centroids = kMeansIntermediate.centroids(); int k = centroids.length; int n = vectors.size(); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborQueue.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborQueue.java index 48aa3c5004843..89bc942374117 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborQueue.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/NeighborQueue.java @@ -121,33 +121,6 @@ public int pop() { return decodeNodeId(heap.pop()); } - public void consumeNodes(int[] dest) { - if (dest.length < size()) { - throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements."); - } - for (int i = 0; i < size(); i++) { - dest[i] = decodeNodeId(heap.get(i + 1)); - } - } - - public int consumeNodesAndScoresMin(int[] dest, float[] scores) { - if (dest.length < size() || scores.length < size()) { - throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements."); - } - float bestScore = Float.POSITIVE_INFINITY; - int bestIdx = 0; - for (int i = 0; i < size(); i++) { - long heapValue = heap.get(i + 1); - scores[i] = decodeScore(heapValue); - dest[i] = decodeNodeId(heapValue); - if (scores[i] < bestScore) { - bestScore = scores[i]; - bestIdx = i; - } - } - return bestIdx; - } - public void clear() { heap.clear(); }