Skip to content

Commit d155be6

Browse files
Jammy2211Jammy2211
authored andcommitted
areas improved
1 parent e39d991 commit d155be6

2 files changed

Lines changed: 34 additions & 38 deletions

File tree

autoarray/structures/mesh/triangulation_2d.py

Lines changed: 32 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ def circumcenters_from(points, simplices, xp=np):
148148
return centers
149149

150150

151+
MAX_DEG_JAX = 128
152+
151153
def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
152154
"""
153155
Compute 'Voronoi-ish' cell areas for each vertex in a 2D Delaunay triangulation.
@@ -172,6 +174,8 @@ def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
172174
areas : (N,)
173175
Voronoi-ish area associated with each vertex.
174176
"""
177+
import jax
178+
import jax.numpy as jnp
175179

176180
pts = xp.asarray(points)
177181
tris = xp.asarray(simplices, dtype=int)
@@ -193,35 +197,28 @@ def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
193197

194198
# 4) Compute how many triangles are incident to each vertex
195199
if xp is np:
196-
counts = np.bincount(vert_sorted, minlength=N) # (N,)
200+
counts = xp.bincount(vert_sorted, minlength=N) # (N,)
201+
max_deg = int(counts.max())
197202
else:
198-
counts = jnp.bincount(vert_sorted, length=N) # (N,)
199-
200-
max_deg = int(counts.max()) if xp is np else int(counts.max().item())
203+
counts = xp.bincount(vert_sorted, length=N) # (N,)
204+
max_deg = MAX_DEG_JAX # static upper bound
201205

202206
# 5) Compute start index for each vertex's block in vert_sorted
203207
# start[v] = cumulative sum of counts up to v
204-
if xp is np:
205-
start = np.concatenate([np.array([0], dtype=int), np.cumsum(counts[:-1])])
206-
else:
207-
start = jnp.concatenate([jnp.array([0], dtype=int), jnp.cumsum(counts[:-1])])
208+
start = xp.concatenate([xp.array([0], dtype=int), xp.cumsum(counts[:-1])])
208209

209210
# Global indices 0..3M-1
210-
if xp is np:
211-
arange_all = np.arange(3 * M, dtype=int)
212-
else:
213-
arange_all = jnp.arange(3 * M, dtype=int)
211+
arange_all = xp.arange(3 * M, dtype=int)
214212

215213
# Position within each vertex block: pos = i - start[vertex]
216214
start_per_entry = start[vert_sorted] # (3M,)
217215
pos = arange_all - start_per_entry # (3M,)
218216

219217
# 6) Scatter into a padded (N, max_deg, 2) array of circumcenters
218+
circum_padded = xp.zeros((N, max_deg, 2), dtype=pts.dtype)
220219
if xp is np:
221-
circum_padded = np.zeros((N, max_deg, 2), dtype=pts.dtype)
222220
circum_padded[vert_sorted, pos, :] = centers_sorted
223221
else:
224-
circum_padded = jnp.zeros((N, max_deg, 2), dtype=pts.dtype)
225222
circum_padded = circum_padded.at[vert_sorted, pos, :].set(centers_sorted)
226223

227224
# 7) For each vertex, sort its circumcenters by angle around the vertex
@@ -231,10 +228,7 @@ def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
231228
angles = xp.arctan2(dy, dx)
232229

233230
# Mark which slots are valid (j < count[v])
234-
if xp is np:
235-
j_idx = np.arange(max_deg)[None, :] # (1, max_deg)
236-
else:
237-
j_idx = jnp.arange(max_deg)[None, :]
231+
j_idx = xp.arange(max_deg)[None, :]
238232
valid_mask = j_idx < counts[:, None] # (N, max_deg)
239233

240234
# For invalid entries, set angle to a big constant so they go to the end
@@ -245,18 +239,11 @@ def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
245239
order_angles = xp.argsort(angles_masked, axis=1) # (N, max_deg)
246240

247241
# Reorder circumcenters accordingly
248-
if xp is np:
249-
centers_sorted2 = np.take_along_axis(
250-
circum_padded,
251-
order_angles[..., None].repeat(2, axis=2),
252-
axis=1,
253-
) # (N, max_deg, 2)
254-
else:
255-
centers_sorted2 = jnp.take_along_axis(
256-
circum_padded,
257-
jnp.repeat(order_angles[..., None], 2, axis=2),
258-
axis=1,
259-
)
242+
centers_sorted2 = xp.take_along_axis(
243+
circum_padded,
244+
order_angles[..., None].repeat(2, axis=2),
245+
axis=1,
246+
) # (N, max_deg, 2)
260247

261248
# 8) Compute polygon area with shoelace formula per vertex
262249
x = centers_sorted2[..., 0] # (N, max_deg)
@@ -432,7 +419,6 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
432419
simplices = delaunay.simplices.astype(np.int32)
433420
vertex_neighbor_vertices = delaunay.vertex_neighbor_vertices
434421

435-
436422
return DelaunayInterface(points, simplices, vertex_neighbor_vertices)
437423

438424
@property
@@ -487,7 +473,7 @@ def split_cross(self) -> np.ndarray:
487473
gradient regularization to an `Inversion` using a Delaunay triangulation or Voronoi mesh.
488474
"""
489475

490-
half_region_area_sqrt_lengths = 0.5 * np.sqrt(
476+
half_region_area_sqrt_lengths = 0.5 * self._xp.sqrt(
491477
self.voronoi_pixel_areas_for_split
492478
)
493479

@@ -507,8 +493,9 @@ def voronoi_pixel_areas(self) -> np.ndarray:
507493
calculations.
508494
"""
509495
return voronoi_areas_via_delaunay_from(
510-
self.delaunay.points,
511-
self.delaunay.simplices,
496+
points=self.delaunay.points,
497+
simplices=self.delaunay.simplices,
498+
xp=self._xp,
512499
)
513500

514501
@property
@@ -524,13 +511,20 @@ def voronoi_pixel_areas_for_split(self) -> np.ndarray:
524511
with large regularization coefficients, which is preferred at the edge of the mesh where the reconstruction
525512
goes to zero.
526513
"""
527-
areas = self.voronoi_pixel_areas
514+
areas = self._xp.asarray(self.voronoi_pixel_areas)
528515

529-
max_area = np.percentile(areas, 90.0)
516+
# 90th percentile
517+
max_area = self._xp.percentile(areas, 90.0)
530518

531-
areas[areas == -1] = max_area
532-
areas[areas > max_area] = max_area
519+
if self._xp is np:
520+
# NumPy allows in-place mutation
521+
areas[areas == -1] = max_area
522+
areas[areas > max_area] = max_area
523+
return areas
533524

525+
# JAX arrays are immutable → use .at[]
526+
areas = self._xp.where(areas == -1, max_area, areas)
527+
areas = self._xp.where(areas > max_area, max_area, areas)
534528
return areas
535529

536530
@property

test_autoarray/structures/mesh/test_voronoi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def test__mesh_areas():
5454

5555
mesh = aa.Mesh2DVoronoi(values=grid)
5656

57+
print(mesh.voronoi_pixel_areas_for_split)
58+
5759
assert mesh.voronoi_pixel_areas_for_split == pytest.approx(
5860
np.array(
5961
[

0 commit comments

Comments
 (0)