@@ -14,21 +14,11 @@ private static Vector3[] KMeansCluster(Span<Vector3> points, int k, out int[] co
14
14
int [ ] clusterIds = new int [ points . Length ] ;
15
15
16
16
// Track the centroids of each cluster and its member count
17
- // TODO: stackalloc is great here, but pooling should be thresholded
18
- // just in case
19
17
Span < Vector3 > centroids = stackalloc Vector3 [ k ] ;
20
18
counts = new int [ k ] ;
21
-
19
+
22
20
// Split the points into arbitrary clusters
23
- // NOTE: Can this be rearranged to converge faster?
24
- #if NET6_0_OR_GREATER
25
- var offset = Random . Shared . Next ( k ) ; // Mathematically true random sampling
26
- #else
27
- var rand = new Random ( ) ;
28
- var offset = rand . Next ( k ) ;
29
- #endif
30
- for ( int i = 0 ; i < clusterIds . Length ; i ++ )
31
- clusterIds [ i ] = ( i + offset ) % k ;
21
+ Split ( k , clusterIds ) ;
32
22
33
23
bool converged = false ;
34
24
while ( ! converged )
@@ -37,83 +27,16 @@ private static Vector3[] KMeansCluster(Span<Vector3> points, int k, out int[] co
37
27
// to false when adjust the clusters
38
28
converged = true ;
39
29
40
- // KMeans Loop Step 1:
41
- // Calculate/Recalculate the centroids of each cluster
42
-
43
- // Clear centroids and counts before recalculation
44
- for ( int i = 0 ; i < centroids . Length ; i ++ )
45
- {
46
- centroids [ i ] = Vector3 . Zero ;
47
- counts [ i ] = 0 ;
48
- }
49
-
50
- // Accumulate step in centroid calculation
51
- for ( int i = 0 ; i < clusterIds . Length ; i ++ )
52
- {
53
- int id = clusterIds [ i ] ;
54
- centroids [ id ] += points [ i ] ;
55
- counts [ id ] ++ ;
56
- }
30
+ // Calculate/Recalculate centroids
31
+ CalculateCentroidsAndPrune ( ref centroids , ref counts , points , clusterIds ) ;
57
32
58
- // Prune empty clusters
59
- // All empty clusters are swapped to the end of the span
60
- // then a slice is taken with only the remaining populated clusters
61
- int pivot = counts . Length ;
62
- for ( int i = 0 ; i < pivot ; )
63
- {
64
- // Increment and continue if populated
65
- if ( counts [ i ] != 0 )
66
- {
67
- i ++ ;
68
- continue ;
69
- }
70
-
71
- // The item is not populated. Swap to end and move pivot
72
- // NOTE: This is a oneway swap. We're discarding the 0 anyways.
73
- pivot -- ;
74
- counts [ i ] = counts [ pivot ] ;
75
- }
76
-
77
- #if ! WINDOWS_UWP
78
- counts = counts [ ..pivot ] ;
79
- centroids = centroids [ ..pivot ] ;
80
- #elif WINDOWS_UWP
81
- Array . Resize ( ref counts , pivot ) ;
82
- centroids = centroids . Slice ( 0 , pivot ) ;
83
- #endif
84
-
85
- // Division step in centroid calculation
86
- for ( int i = 0 ; i < centroids . Length ; i ++ )
87
- centroids [ i ] /= counts [ i ] ;
88
-
89
- // KMeans Loop Step 2:
90
33
// Move each point's clusterId to the nearest cluster centroid
91
34
for ( int i = 0 ; i < points . Length ; i ++ )
92
35
{
93
- Vector3 point = points [ i ] ;
94
- var oldId = clusterIds [ i ] ;
95
-
96
- // Track the nearest centroid's distance and the index of that centroid
97
- float nearestDistance = float . PositiveInfinity ;
98
- int nearestIndex = - 1 ;
99
-
100
- for ( int j = 0 ; j < centroids . Length ; j ++ )
101
- {
102
- // Compare the point to the jth centroid
103
- float distance = Vector3 . DistanceSquared ( point , centroids [ j ] ) ;
104
-
105
- // Skip the cluster if further than the nearest seen cluster
106
- if ( nearestDistance < distance )
107
- continue ;
108
-
109
- // This is the nearest cluster
110
- // Update the distance and index
111
- nearestDistance = distance ;
112
- nearestIndex = j ;
113
- }
36
+ var nearestIndex = FindNearestClusterIndex ( points [ i ] , centroids ) ;
114
37
115
38
// The nearest cluster hasn't changed. Do nothing
116
- if ( oldId == nearestIndex )
39
+ if ( clusterIds [ i ] == nearestIndex )
117
40
continue ;
118
41
119
42
// Update the cluster id and note that we have not converged
@@ -125,6 +48,105 @@ private static Vector3[] KMeansCluster(Span<Vector3> points, int k, out int[] co
125
48
return centroids . ToArray ( ) ;
126
49
}
127
50
51
/// <summary>
/// Seeds the clustering by giving every point an initial cluster id.
/// </summary>
/// <remarks>
/// Ids are dealt out round-robin (0, 1, ..., k - 1, 0, ...) starting from a randomly
/// chosen offset, so only the starting id varies between runs.
/// </remarks>
/// <param name="k">The number of clusters to deal ids across. Expected to be positive.</param>
/// <param name="clusterIds">Receives one cluster id per entry; every entry is overwritten.</param>
private static void Split(int k, int[] clusterIds)
{
    // Pick the random source available on the target framework
#if NET6_0_OR_GREATER
    Random source = Random.Shared;
#else
    Random source = new Random();
#endif

    // Random rotation applied to the round-robin sequence
    int shift = source.Next(k);

    // Deal the ids out in order, wrapping back to 0 after k - 1
    int position = 0;
    while (position < clusterIds.Length)
    {
        clusterIds[position] = (position + shift) % k;
        position++;
    }
}
68
+
69
/// <summary>
/// Calculates the centroid of each cluster, and prunes empty clusters.
/// </summary>
/// <remarks>
/// Empty clusters are removed by moving the last remaining cluster into the empty slot,
/// so surviving clusters may change index, and both <paramref name="centroids"/> and
/// <paramref name="counts"/> may come back shorter. <paramref name="clusterIds"/> is not
/// remapped here; ids that referred to a moved cluster are stale until the caller reassigns them.
/// </remarks>
/// <param name="centroids">
/// Centroid storage. Overwritten with the mean of each populated cluster and sliced down
/// to the populated clusters only.
/// </param>
/// <param name="counts">
/// Member count per cluster. Recomputed on every call, and replaced by a shorter array
/// only when at least one cluster is pruned.
/// </param>
/// <param name="points">The points being clustered.</param>
/// <param name="clusterIds">The cluster index of each point. Each id must index into <paramref name="centroids"/>.</param>
private static void CalculateCentroidsAndPrune(ref Span<Vector3> centroids, ref int[] counts, Span<Vector3> points, int[] clusterIds)
{
    // Clear centroids and counts before recalculation
    for (int i = 0; i < centroids.Length; i++)
    {
        centroids[i] = Vector3.Zero;
        counts[i] = 0;
    }

    // Accumulate step in centroid calculation
    for (int i = 0; i < clusterIds.Length; i++)
    {
        int id = clusterIds[i];
        centroids[id] += points[i];
        counts[id]++;
    }

    // Prune empty clusters
    // All empty clusters are overwritten by clusters taken from the end,
    // then a slice is taken with only the remaining populated clusters
    int pivot = counts.Length;
    for (int i = 0; i < pivot;)
    {
        // Increment and continue if populated
        if (counts[i] != 0)
        {
            i++;
            continue;
        }

        // The item is not populated. Pull the last cluster into its slot and move the pivot.
        // NOTE: This is a one-way "swap". We're discarding the 0s anyways.
        // The slot is re-checked on the next pass since the moved cluster may be empty too.
        pivot--;
        centroids[i] = centroids[pivot];
        counts[i] = counts[pivot];
    }

    // Slicing a span never allocates, so it is done unconditionally
    centroids = centroids.Slice(0, pivot);

    // Shrinking the counts array allocates a copy, and this method runs once per
    // iteration of the clustering loop, so only do it when a cluster was pruned
    if (pivot != counts.Length)
    {
#if !WINDOWS_UWP
        counts = counts[..pivot];
#else
        Array.Resize(ref counts, pivot);
#endif
    }

    // Division step in centroid calculation
    // Every remaining cluster is populated, so the count is never zero here
    for (int i = 0; i < centroids.Length; i++)
        centroids[i] /= counts[i];
}
122
+
123
/// <summary>
/// Locates the centroid closest to a point.
/// </summary>
/// <param name="point">The point to find a cluster for.</param>
/// <param name="centroids">The candidate cluster centroids.</param>
/// <returns>
/// The index in <paramref name="centroids"/> of the closest centroid, or -1 when
/// <paramref name="centroids"/> is empty. When several centroids are equally close,
/// the one with the highest index is returned.
/// </returns>
private static int FindNearestClusterIndex(Vector3 point, Span<Vector3> centroids)
{
    // Best candidate seen so far. Nothing has been seen yet.
    int bestIndex = -1;
    float bestDistance = float.PositiveInfinity;

    int candidateIndex = 0;
    foreach (Vector3 centroid in centroids)
    {
        // Squared distance is enough for comparing, and skips the square root
        float candidateDistance = Vector3.DistanceSquared(point, centroid);

        // Take the candidate unless it is strictly further than the best so far.
        // Deliberately phrased as a negated "<" so that ties (and NaN) are accepted.
        bool isFurther = bestDistance < candidateDistance;
        if (!isFurther)
        {
            bestDistance = candidateDistance;
            bestIndex = candidateIndex;
        }

        candidateIndex++;
    }

    return bestIndex;
}
149
+
128
150
private static float FindColorfulness ( Vector3 color )
129
151
{
130
152
var rg = color . X - color . Y ;
0 commit comments