@@ -103,54 +103,178 @@ def find_simplex_from(query_points, points, simplices):
103103 return simplex_idx
104104
105105
106- def vertex_areas_from_delaunay (points , simplices , xp = np ):
106+ def circumcenters_from (points , simplices , xp = np ):
107107 """
108- Compute per-vertex areas using:
109- vertex_area[v] = sum(area(triangles incident to v)) / 3
108+ Compute triangle circumcenters for a Delaunay triangulation, using either NumPy or JAX.
110109
111110 Parameters
112111 ----------
113- points : (N_pts , 2) array
114- simplices : (N_tris , 3) array of triangle vertex indices
112+ points : (N , 2)
113+ simplices : (M , 3)
115114 xp : np or jnp
116- Backend to use.
117115
118116 Returns
119117 -------
120- vertex_area : (N_pts,) array
121- Estimated area associated with each vertex.
118+ circumcenters : (M, 2)
122119 """
120+ pts = xp .asarray (points )
121+ tris = xp .asarray (simplices , dtype = int )
123122
124- # Triangle vertices
125- p0 = points [simplices [:, 0 ]] # (N_tris, 2)
126- p1 = points [simplices [:, 1 ]]
127- p2 = points [simplices [:, 2 ]]
123+ tri_pts = pts [tris ] # (M, 3, 2)
128124
129- # Compute triangle areas (vectorized)
130- tri_area = 0.5 * xp .abs (
131- (p1 [:, 0 ] - p0 [:, 0 ]) * (p2 [:, 1 ] - p0 [:, 1 ])
132- - (p1 [:, 1 ] - p0 [:, 1 ]) * (p2 [:, 0 ] - p0 [:, 0 ])
133- ) # (N_tris,)
125+ x0 = tri_pts [:, 0 , 0 ]
126+ y0 = tri_pts [:, 0 , 1 ]
127+ x1 = tri_pts [:, 1 , 0 ]
128+ y1 = tri_pts [:, 1 , 1 ]
129+ x2 = tri_pts [:, 2 , 0 ]
130+ y2 = tri_pts [:, 2 , 1 ]
134131
135- # Area contribution to each vertex: (N_tris, 3)
136- contrib = (tri_area / 3.0 )[:, None ] * xp .ones ((1 , 3 ))
132+ a = 2.0 * (x1 - x0 )
133+ b = 2.0 * (y1 - y0 )
134+ c = 2.0 * (x2 - x0 )
135+ d = 2.0 * (y2 - y0 )
137136
138- # Flatten for scatter:
139- # Each triangle contributes 3 entries, one per vertex.
140- scatter_idx = simplices .reshape (- 1 ) # (3*N_tris,)
141- scatter_vals = contrib .reshape (- 1 ) # (3*N_tris,)
137+ v1 = (x1 ** 2 + y1 ** 2 ) - (x0 ** 2 + y0 ** 2 )
138+ v2 = (x2 ** 2 + y2 ** 2 ) - (x0 ** 2 + y0 ** 2 )
142139
143- # Allocate output
144- n_pts = points . shape [ 0 ]
145- vertex_area = xp . zeros ( n_pts )
140+ det = a * d - b * c
141+ detx = v1 * d - v2 * b
142+ dety = a * v2 - c * v1
146143
147- # Scatter-add: NumPy and JAX both support this API!
148- if xp .__name__ .startswith ("jax" ):
149- vertex_area = vertex_area .at [scatter_idx ].add (scatter_vals )
144+ x = detx / det
145+ y = dety / det
146+
147+ centers = xp .stack ([x , y ], axis = 1 ) # (M, 2)
148+ return centers
149+
150+
151+ def voronoi_areas_via_delaunay_from (points , simplices , xp = np ):
152+ """
153+ Compute 'Voronoi-ish' cell areas for each vertex in a 2D Delaunay triangulation.
154+
155+ For each vertex v:
156+ - find all incident triangles
157+ - get their circumcenters
158+ - sort these circumcenters by angle around v
159+ - compute polygon area by the shoelace formula
160+
161+ Parameters
162+ ----------
163+ points : (N, 2)
164+ Delaunay vertices.
165+ simplices : (M, 3)
166+ Delaunay triangles (indices into `points`).
167+ xp : np or jnp
168+ Backend.
169+
170+ Returns
171+ -------
172+ areas : (N,)
173+ Voronoi-ish area associated with each vertex.
174+ """
175+
176+ pts = xp .asarray (points )
177+ tris = xp .asarray (simplices , dtype = int )
178+ N = pts .shape [0 ]
179+ M = tris .shape [0 ]
180+
181+ # 1) Circumcenters for all triangles
182+ centers = circumcenters_from (pts , tris , xp = xp ) # (M, 2)
183+
184+ # 2) Build a flattened vertex-triangle incidence
185+ # Each triangle contributes its circumcenter to 3 vertices.
186+ vert_ids = tris .reshape (- 1 ) # (3M,)
187+ centers_rep = xp .repeat (centers , 3 , axis = 0 ) # (3M, 2)
188+
189+ # 3) Sort by vertex id so all entries for a given vertex are contiguous
190+ order = xp .argsort (vert_ids )
191+ vert_sorted = vert_ids [order ] # (3M,)
192+ centers_sorted = centers_rep [order ] # (3M, 2)
193+
194+ # 4) Compute how many triangles are incident to each vertex
195+ if xp is np :
196+ counts = np .bincount (vert_sorted , minlength = N ) # (N,)
197+ else :
198+ counts = jnp .bincount (vert_sorted , length = N ) # (N,)
199+
200+ max_deg = int (counts .max ()) if xp is np else int (counts .max ().item ())
201+
202+ # 5) Compute start index for each vertex's block in vert_sorted
203+ # start[v] = cumulative sum of counts up to v
204+ if xp is np :
205+ start = np .concatenate ([np .array ([0 ], dtype = int ), np .cumsum (counts [:- 1 ])])
150206 else :
151- np . add . at ( vertex_area , scatter_idx , scatter_vals )
207+ start = jnp . concatenate ([ jnp . array ([ 0 ], dtype = int ), jnp . cumsum ( counts [: - 1 ])] )
152208
153- return vertex_area
209+ # Global indices 0..3M-1
210+ if xp is np :
211+ arange_all = np .arange (3 * M , dtype = int )
212+ else :
213+ arange_all = jnp .arange (3 * M , dtype = int )
214+
215+ # Position within each vertex block: pos = i - start[vertex]
216+ start_per_entry = start [vert_sorted ] # (3M,)
217+ pos = arange_all - start_per_entry # (3M,)
218+
219+ # 6) Scatter into a padded (N, max_deg, 2) array of circumcenters
220+ if xp is np :
221+ circum_padded = np .zeros ((N , max_deg , 2 ), dtype = pts .dtype )
222+ circum_padded [vert_sorted , pos , :] = centers_sorted
223+ else :
224+ circum_padded = jnp .zeros ((N , max_deg , 2 ), dtype = pts .dtype )
225+ circum_padded = circum_padded .at [vert_sorted , pos , :].set (centers_sorted )
226+
227+ # 7) For each vertex, sort its circumcenters by angle around the vertex
228+ # Compute angles: (N, max_deg)
229+ dx = circum_padded [..., 0 ] - pts [:, None , 0 ]
230+ dy = circum_padded [..., 1 ] - pts [:, None , 1 ]
231+ angles = xp .arctan2 (dy , dx )
232+
233+ # Mark which slots are valid (j < count[v])
234+ if xp is np :
235+ j_idx = np .arange (max_deg )[None , :] # (1, max_deg)
236+ else :
237+ j_idx = jnp .arange (max_deg )[None , :]
238+ valid_mask = j_idx < counts [:, None ] # (N, max_deg)
239+
240+ # For invalid entries, set angle to a big constant so they go to the end
241+ big_angle = xp .array (1e9 , dtype = angles .dtype )
242+ angles_masked = xp .where (valid_mask , angles , big_angle )
243+
244+ # Sort indices by angle for each vertex
245+ order_angles = xp .argsort (angles_masked , axis = 1 ) # (N, max_deg)
246+
247+ # Reorder circumcenters accordingly
248+ if xp is np :
249+ centers_sorted2 = np .take_along_axis (
250+ circum_padded ,
251+ order_angles [..., None ].repeat (2 , axis = 2 ),
252+ axis = 1 ,
253+ ) # (N, max_deg, 2)
254+ else :
255+ centers_sorted2 = jnp .take_along_axis (
256+ circum_padded ,
257+ jnp .repeat (order_angles [..., None ], 2 , axis = 2 ),
258+ axis = 1 ,
259+ )
260+
261+ # 8) Compute polygon area with shoelace formula per vertex
262+ x = centers_sorted2 [..., 0 ] # (N, max_deg)
263+ y = centers_sorted2 [..., 1 ] # (N, max_deg)
264+
265+ # roll by -1 so j+1 wraps around
266+ x_next = xp .roll (x , shift = - 1 , axis = 1 )
267+ y_next = xp .roll (y , shift = - 1 , axis = 1 )
268+
269+ # A contribution is valid if both current and next vertices are valid
270+ valid_pair = valid_mask & xp .roll (valid_mask , shift = - 1 , axis = 1 )
271+
272+ cross = x * y_next - x_next * y
273+ cross = xp .where (valid_pair , cross , 0.0 )
274+
275+ area = 0.5 * xp .abs (xp .sum (cross , axis = 1 )) # (N,)
276+
277+ return area
154278
155279
156280def split_points_from (points , area_weights , xp = np ):
@@ -209,11 +333,10 @@ def split_points_from(points, area_weights, xp=np):
209333
210334class DelaunayInterface :
211335
212- def __init__ (self , ppoints , simplices , areas , vertex_neighbor_vertices ):
336+ def __init__ (self , ppoints , simplices , vertex_neighbor_vertices ):
213337
214338 self .points = ppoints
215339 self .simplices = simplices
216- self .areas = areas
217340 self .vertex_neighbor_vertices = vertex_neighbor_vertices
218341
219342 def find_simplex (self , query_points ):
@@ -309,11 +432,8 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
309432 simplices = delaunay .simplices .astype (np .int32 )
310433 vertex_neighbor_vertices = delaunay .vertex_neighbor_vertices
311434
312- areas = vertex_areas_from_delaunay (
313- points = points , simplices = simplices , xp = self ._xp
314- )
315435
316- return DelaunayInterface (points , simplices , areas , vertex_neighbor_vertices )
436+ return DelaunayInterface (points , simplices , vertex_neighbor_vertices )
317437
318438 @property
319439 def voronoi (self ) -> "scipy.spatial.Voronoi" :
0 commit comments