@@ -148,6 +148,8 @@ def circumcenters_from(points, simplices, xp=np):
148148 return centers
149149
150150
151+ MAX_DEG_JAX = 128
152+
151153def voronoi_areas_via_delaunay_from (points , simplices , xp = np ):
152154 """
153155 Compute 'Voronoi-ish' cell areas for each vertex in a 2D Delaunay triangulation.
@@ -172,6 +174,8 @@ def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
172174 areas : (N,)
173175 Voronoi-ish area associated with each vertex.
174176 """
177+ import jax
178+ import jax .numpy as jnp
175179
176180 pts = xp .asarray (points )
177181 tris = xp .asarray (simplices , dtype = int )
@@ -193,35 +197,28 @@ def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
193197
194198 # 4) Compute how many triangles are incident to each vertex
195199 if xp is np :
196- counts = np .bincount (vert_sorted , minlength = N ) # (N,)
200+ counts = xp .bincount (vert_sorted , minlength = N ) # (N,)
201+ max_deg = int (counts .max ())
197202 else :
198- counts = jnp .bincount (vert_sorted , length = N ) # (N,)
199-
200- max_deg = int (counts .max ()) if xp is np else int (counts .max ().item ())
203+ counts = xp .bincount (vert_sorted , length = N ) # (N,)
204+ max_deg = MAX_DEG_JAX # static upper bound
201205
202206 # 5) Compute start index for each vertex's block in vert_sorted
203207 # start[v] = cumulative sum of counts up to v
204- if xp is np :
205- start = np .concatenate ([np .array ([0 ], dtype = int ), np .cumsum (counts [:- 1 ])])
206- else :
207- start = jnp .concatenate ([jnp .array ([0 ], dtype = int ), jnp .cumsum (counts [:- 1 ])])
208+ start = xp .concatenate ([xp .array ([0 ], dtype = int ), xp .cumsum (counts [:- 1 ])])
208209
209210 # Global indices 0..3M-1
210- if xp is np :
211- arange_all = np .arange (3 * M , dtype = int )
212- else :
213- arange_all = jnp .arange (3 * M , dtype = int )
211+ arange_all = xp .arange (3 * M , dtype = int )
214212
215213 # Position within each vertex block: pos = i - start[vertex]
216214 start_per_entry = start [vert_sorted ] # (3M,)
217215 pos = arange_all - start_per_entry # (3M,)
218216
219217 # 6) Scatter into a padded (N, max_deg, 2) array of circumcenters
218+ circum_padded = xp .zeros ((N , max_deg , 2 ), dtype = pts .dtype )
220219 if xp is np :
221- circum_padded = np .zeros ((N , max_deg , 2 ), dtype = pts .dtype )
222220 circum_padded [vert_sorted , pos , :] = centers_sorted
223221 else :
224- circum_padded = jnp .zeros ((N , max_deg , 2 ), dtype = pts .dtype )
225222 circum_padded = circum_padded .at [vert_sorted , pos , :].set (centers_sorted )
226223
227224 # 7) For each vertex, sort its circumcenters by angle around the vertex
@@ -231,10 +228,7 @@ def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
231228 angles = xp .arctan2 (dy , dx )
232229
233230 # Mark which slots are valid (j < count[v])
234- if xp is np :
235- j_idx = np .arange (max_deg )[None , :] # (1, max_deg)
236- else :
237- j_idx = jnp .arange (max_deg )[None , :]
231+ j_idx = xp .arange (max_deg )[None , :]
238232 valid_mask = j_idx < counts [:, None ] # (N, max_deg)
239233
240234 # For invalid entries, set angle to a big constant so they go to the end
@@ -245,18 +239,11 @@ def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
245239 order_angles = xp .argsort (angles_masked , axis = 1 ) # (N, max_deg)
246240
247241 # Reorder circumcenters accordingly
248- if xp is np :
249- centers_sorted2 = np .take_along_axis (
250- circum_padded ,
251- order_angles [..., None ].repeat (2 , axis = 2 ),
252- axis = 1 ,
253- ) # (N, max_deg, 2)
254- else :
255- centers_sorted2 = jnp .take_along_axis (
256- circum_padded ,
257- jnp .repeat (order_angles [..., None ], 2 , axis = 2 ),
258- axis = 1 ,
259- )
242+ centers_sorted2 = xp .take_along_axis (
243+ circum_padded ,
244+ order_angles [..., None ].repeat (2 , axis = 2 ),
245+ axis = 1 ,
246+ ) # (N, max_deg, 2)
260247
261248 # 8) Compute polygon area with shoelace formula per vertex
262249 x = centers_sorted2 [..., 0 ] # (N, max_deg)
@@ -432,7 +419,6 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
432419 simplices = delaunay .simplices .astype (np .int32 )
433420 vertex_neighbor_vertices = delaunay .vertex_neighbor_vertices
434421
435-
436422 return DelaunayInterface (points , simplices , vertex_neighbor_vertices )
437423
438424 @property
@@ -487,7 +473,7 @@ def split_cross(self) -> np.ndarray:
487473 gradient regularization to an `Inversion` using a Delaunay triangulation or Voronoi mesh.
488474 """
489475
490- half_region_area_sqrt_lengths = 0.5 * np .sqrt (
476+ half_region_area_sqrt_lengths = 0.5 * self . _xp .sqrt (
491477 self .voronoi_pixel_areas_for_split
492478 )
493479
@@ -507,8 +493,9 @@ def voronoi_pixel_areas(self) -> np.ndarray:
507493 calculations.
508494 """
509495 return voronoi_areas_via_delaunay_from (
510- self .delaunay .points ,
511- self .delaunay .simplices ,
496+ points = self .delaunay .points ,
497+ simplices = self .delaunay .simplices ,
498+ xp = self ._xp ,
512499 )
513500
514501 @property
@@ -524,13 +511,20 @@ def voronoi_pixel_areas_for_split(self) -> np.ndarray:
524511 with large regularization coefficients, which is preferred at the edge of the mesh where the reconstruction
525512 goes to zero.
526513 """
527- areas = self .voronoi_pixel_areas
514+ areas = self ._xp . asarray ( self . voronoi_pixel_areas )
528515
529- max_area = np .percentile (areas , 90.0 )
516+ # 90th percentile
517+ max_area = self ._xp .percentile (areas , 90.0 )
530518
531- areas [areas == - 1 ] = max_area
532- areas [areas > max_area ] = max_area
519+ if self ._xp is np :
520+ # NumPy allows in-place mutation
521+ areas [areas == - 1 ] = max_area
522+ areas [areas > max_area ] = max_area
523+ return areas
533524
525+ # JAX arrays are immutable → use .at[]
526+ areas = self ._xp .where (areas == - 1 , max_area , areas )
527+ areas = self ._xp .where (areas > max_area , max_area , areas )
534528 return areas
535529
536530 @property
0 commit comments