Skip to content

Commit 8d6089e

Browse files
Jammy2211Jammy2211
authored andcommitted
test__voronoi_areas_via_delaunay_from
1 parent bb5bc27 commit 8d6089e

2 files changed

Lines changed: 199 additions & 38 deletions

File tree

autoarray/structures/mesh/triangulation_2d.py

Lines changed: 157 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -103,54 +103,178 @@ def find_simplex_from(query_points, points, simplices):
103103
return simplex_idx
104104

105105

106-
def vertex_areas_from_delaunay(points, simplices, xp=np):
106+
def circumcenters_from(points, simplices, xp=np):
107107
"""
108-
Compute per-vertex areas using:
109-
vertex_area[v] = sum(area(triangles incident to v)) / 3
108+
Compute triangle circumcenters for a Delaunay triangulation, using either NumPy or JAX.
110109
111110
Parameters
112111
----------
113-
points : (N_pts, 2) array
114-
simplices : (N_tris, 3) array of triangle vertex indices
112+
points : (N, 2)
113+
simplices : (M, 3)
115114
xp : np or jnp
116-
Backend to use.
117115
118116
Returns
119117
-------
120-
vertex_area : (N_pts,) array
121-
Estimated area associated with each vertex.
118+
circumcenters : (M, 2)
122119
"""
120+
pts = xp.asarray(points)
121+
tris = xp.asarray(simplices, dtype=int)
123122

124-
# Triangle vertices
125-
p0 = points[simplices[:, 0]] # (N_tris, 2)
126-
p1 = points[simplices[:, 1]]
127-
p2 = points[simplices[:, 2]]
123+
tri_pts = pts[tris] # (M, 3, 2)
128124

129-
# Compute triangle areas (vectorized)
130-
tri_area = 0.5 * xp.abs(
131-
(p1[:, 0] - p0[:, 0]) * (p2[:, 1] - p0[:, 1])
132-
- (p1[:, 1] - p0[:, 1]) * (p2[:, 0] - p0[:, 0])
133-
) # (N_tris,)
125+
x0 = tri_pts[:, 0, 0]
126+
y0 = tri_pts[:, 0, 1]
127+
x1 = tri_pts[:, 1, 0]
128+
y1 = tri_pts[:, 1, 1]
129+
x2 = tri_pts[:, 2, 0]
130+
y2 = tri_pts[:, 2, 1]
134131

135-
# Area contribution to each vertex: (N_tris, 3)
136-
contrib = (tri_area / 3.0)[:, None] * xp.ones((1, 3))
132+
a = 2.0 * (x1 - x0)
133+
b = 2.0 * (y1 - y0)
134+
c = 2.0 * (x2 - x0)
135+
d = 2.0 * (y2 - y0)
137136

138-
# Flatten for scatter:
139-
# Each triangle contributes 3 entries, one per vertex.
140-
scatter_idx = simplices.reshape(-1) # (3*N_tris,)
141-
scatter_vals = contrib.reshape(-1) # (3*N_tris,)
137+
v1 = (x1**2 + y1**2) - (x0**2 + y0**2)
138+
v2 = (x2**2 + y2**2) - (x0**2 + y0**2)
142139

143-
# Allocate output
144-
n_pts = points.shape[0]
145-
vertex_area = xp.zeros(n_pts)
140+
det = a * d - b * c
141+
detx = v1 * d - v2 * b
142+
dety = a * v2 - c * v1
146143

147-
# Scatter-add: NumPy and JAX both support this API!
148-
if xp.__name__.startswith("jax"):
149-
vertex_area = vertex_area.at[scatter_idx].add(scatter_vals)
144+
x = detx / det
145+
y = dety / det
146+
147+
centers = xp.stack([x, y], axis=1) # (M, 2)
148+
return centers
149+
150+
151+
def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
152+
"""
153+
Compute 'Voronoi-ish' cell areas for each vertex in a 2D Delaunay triangulation.
154+
155+
For each vertex v:
156+
- find all incident triangles
157+
- get their circumcenters
158+
- sort these circumcenters by angle around v
159+
- compute polygon area by the shoelace formula
160+
161+
Parameters
162+
----------
163+
points : (N, 2)
164+
Delaunay vertices.
165+
simplices : (M, 3)
166+
Delaunay triangles (indices into `points`).
167+
xp : np or jnp
168+
Backend.
169+
170+
Returns
171+
-------
172+
areas : (N,)
173+
Voronoi-ish area associated with each vertex.
174+
"""
175+
176+
pts = xp.asarray(points)
177+
tris = xp.asarray(simplices, dtype=int)
178+
N = pts.shape[0]
179+
M = tris.shape[0]
180+
181+
# 1) Circumcenters for all triangles
182+
centers = circumcenters_from(pts, tris, xp=xp) # (M, 2)
183+
184+
# 2) Build a flattened vertex-triangle incidence
185+
# Each triangle contributes its circumcenter to 3 vertices.
186+
vert_ids = tris.reshape(-1) # (3M,)
187+
centers_rep = xp.repeat(centers, 3, axis=0) # (3M, 2)
188+
189+
# 3) Sort by vertex id so all entries for a given vertex are contiguous
190+
order = xp.argsort(vert_ids)
191+
vert_sorted = vert_ids[order] # (3M,)
192+
centers_sorted = centers_rep[order] # (3M, 2)
193+
194+
# 4) Compute how many triangles are incident to each vertex
195+
if xp is np:
196+
counts = np.bincount(vert_sorted, minlength=N) # (N,)
197+
else:
198+
counts = jnp.bincount(vert_sorted, length=N) # (N,)
199+
200+
max_deg = int(counts.max()) if xp is np else int(counts.max().item())
201+
202+
# 5) Compute start index for each vertex's block in vert_sorted
203+
# start[v] = cumulative sum of counts up to v
204+
if xp is np:
205+
start = np.concatenate([np.array([0], dtype=int), np.cumsum(counts[:-1])])
150206
else:
151-
np.add.at(vertex_area, scatter_idx, scatter_vals)
207+
start = jnp.concatenate([jnp.array([0], dtype=int), jnp.cumsum(counts[:-1])])
152208

153-
return vertex_area
209+
# Global indices 0..3M-1
210+
if xp is np:
211+
arange_all = np.arange(3 * M, dtype=int)
212+
else:
213+
arange_all = jnp.arange(3 * M, dtype=int)
214+
215+
# Position within each vertex block: pos = i - start[vertex]
216+
start_per_entry = start[vert_sorted] # (3M,)
217+
pos = arange_all - start_per_entry # (3M,)
218+
219+
# 6) Scatter into a padded (N, max_deg, 2) array of circumcenters
220+
if xp is np:
221+
circum_padded = np.zeros((N, max_deg, 2), dtype=pts.dtype)
222+
circum_padded[vert_sorted, pos, :] = centers_sorted
223+
else:
224+
circum_padded = jnp.zeros((N, max_deg, 2), dtype=pts.dtype)
225+
circum_padded = circum_padded.at[vert_sorted, pos, :].set(centers_sorted)
226+
227+
# 7) For each vertex, sort its circumcenters by angle around the vertex
228+
# Compute angles: (N, max_deg)
229+
dx = circum_padded[..., 0] - pts[:, None, 0]
230+
dy = circum_padded[..., 1] - pts[:, None, 1]
231+
angles = xp.arctan2(dy, dx)
232+
233+
# Mark which slots are valid (j < count[v])
234+
if xp is np:
235+
j_idx = np.arange(max_deg)[None, :] # (1, max_deg)
236+
else:
237+
j_idx = jnp.arange(max_deg)[None, :]
238+
valid_mask = j_idx < counts[:, None] # (N, max_deg)
239+
240+
# For invalid entries, set angle to a big constant so they go to the end
241+
big_angle = xp.array(1e9, dtype=angles.dtype)
242+
angles_masked = xp.where(valid_mask, angles, big_angle)
243+
244+
# Sort indices by angle for each vertex
245+
order_angles = xp.argsort(angles_masked, axis=1) # (N, max_deg)
246+
247+
# Reorder circumcenters accordingly
248+
if xp is np:
249+
centers_sorted2 = np.take_along_axis(
250+
circum_padded,
251+
order_angles[..., None].repeat(2, axis=2),
252+
axis=1,
253+
) # (N, max_deg, 2)
254+
else:
255+
centers_sorted2 = jnp.take_along_axis(
256+
circum_padded,
257+
jnp.repeat(order_angles[..., None], 2, axis=2),
258+
axis=1,
259+
)
260+
261+
# 8) Compute polygon area with shoelace formula per vertex
262+
x = centers_sorted2[..., 0] # (N, max_deg)
263+
y = centers_sorted2[..., 1] # (N, max_deg)
264+
265+
# roll by -1 so j+1 wraps around
266+
x_next = xp.roll(x, shift=-1, axis=1)
267+
y_next = xp.roll(y, shift=-1, axis=1)
268+
269+
# A contribution is valid if both current and next vertices are valid
270+
valid_pair = valid_mask & xp.roll(valid_mask, shift=-1, axis=1)
271+
272+
cross = x * y_next - x_next * y
273+
cross = xp.where(valid_pair, cross, 0.0)
274+
275+
area = 0.5 * xp.abs(xp.sum(cross, axis=1)) # (N,)
276+
277+
return area
154278

155279

156280
def split_points_from(points, area_weights, xp=np):
@@ -209,11 +333,10 @@ def split_points_from(points, area_weights, xp=np):
209333

210334
class DelaunayInterface:
211335

212-
def __init__(self, ppoints, simplices, areas, vertex_neighbor_vertices):
336+
def __init__(self, ppoints, simplices, vertex_neighbor_vertices):
213337

214338
self.points = ppoints
215339
self.simplices = simplices
216-
self.areas = areas
217340
self.vertex_neighbor_vertices = vertex_neighbor_vertices
218341

219342
def find_simplex(self, query_points):
@@ -309,11 +432,8 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
309432
simplices = delaunay.simplices.astype(np.int32)
310433
vertex_neighbor_vertices = delaunay.vertex_neighbor_vertices
311434

312-
areas = vertex_areas_from_delaunay(
313-
points=points, simplices=simplices, xp=self._xp
314-
)
315435

316-
return DelaunayInterface(points, simplices, areas, vertex_neighbor_vertices)
436+
return DelaunayInterface(points, simplices, vertex_neighbor_vertices)
317437

318438
@property
319439
def voronoi(self) -> "scipy.spatial.Voronoi":

test_autoarray/structures/mesh/test_delaunay.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import autoarray as aa
55

6-
from autoarray.structures.mesh.triangulation_2d import vertex_areas_from_delaunay
6+
from autoarray.structures.mesh.triangulation_2d import voronoi_areas_via_delaunay_from
77

88
def test__edge_pixel_list():
99
grid = np.array(
@@ -60,3 +60,44 @@ def test__interpolated_array_from():
6060
np.array([[1.0, 1.907216], [1.0, 1.0], [1.0, 1.0]]), 1.0e-4
6161
)
6262

63+
64+
def test__voronoi_areas_via_delaunay_from():
65+
66+
import scipy.spatial
67+
68+
mesh_grid = np.array([[0.0, 0.0], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]])
69+
70+
delaunay = scipy.spatial.Delaunay(mesh_grid)
71+
72+
voronoi_areas = voronoi_areas_via_delaunay_from(
73+
mesh_grid, delaunay.simplices,
74+
)
75+
76+
voronoi = scipy.spatial.Voronoi(
77+
mesh_grid,
78+
qhull_options="Qbb Qc Qx Qm",
79+
)
80+
81+
voronoi_vertices = voronoi.vertices
82+
voronoi_regions = voronoi.regions
83+
voronoi_point_region = voronoi.point_region
84+
85+
pixels = mesh_grid.shape[0]
86+
87+
region_areas = np.zeros(pixels)
88+
89+
for i in range(pixels):
90+
region_vertices_indexes = voronoi_regions[voronoi_point_region[i]]
91+
if -1 in region_vertices_indexes:
92+
region_areas[i] = -1
93+
else:
94+
region_areas[i] = aa.util.grid_2d.compute_polygon_area(
95+
voronoi_vertices[region_vertices_indexes]
96+
)
97+
98+
assert voronoi_areas[1] == pytest.approx(region_areas[1], 1.0e-4)
99+
assert voronoi_areas[3] == pytest.approx(region_areas[3], 1.0e-4)
100+
101+
# Old Voronoi cell code put -1 in edge pixels, new code puts large area
102+
103+
assert voronoi_areas[4] == pytest.approx(32.83847776, 1.0e-4)

0 commit comments

Comments
 (0)