1313from autoarray import exc
1414
1515
16- def scipy_delaunay (points_np , query_points_np , use_voronoi_areas , areas_factor ):
17- """Compute Delaunay simplices (simplices_padded) and Voronoi areas in one call."""
18-
19- max_simplices = 2 * points_np .shape [0 ]
20-
21- # --- Delaunay mesh using source plane data grid ---
22- tri = Delaunay (points_np )
23-
24- points = tri .points .astype (points_np .dtype )
25- simplices = tri .simplices .astype (np .int32 )
26-
27- # Pad simplices to max_simplices
28- simplices_padded = - np .ones ((max_simplices , 3 ), dtype = np .int32 )
29- simplices_padded [: simplices .shape [0 ]] = simplices
30-
31- # ---------- find_simplex for source plane data grid ----------
32- simplex_idx = tri .find_simplex (query_points_np ).astype (np .int32 ) # (Q,)
33-
34- mappings = pix_indexes_for_sub_slim_index_delaunay_from (
35- source_plane_data_grid = query_points_np ,
36- simplex_index_for_sub_slim_index = simplex_idx ,
37- pix_indexes_for_simplex_index = simplices ,
38- delaunay_points = points_np ,
16+ def _debug_host_array (name : str , x ):
17+ """
18+ Print shape/dtype/strides and nan/inf counts for a host-side array-like.
19+ Safe to call inside pure_callback.
20+ """
21+ arr = np .asarray (x ) # IMPORTANT: raw view BEFORE any dtype/contiguous conversion
22+ finite = np .isfinite (arr )
23+ n_nan = np .isnan (arr ).sum ()
24+ n_inf = np .isinf (arr ).sum ()
25+ n_bad = (~ finite ).sum ()
26+
27+ print (
28+ f"[pure_callback] { name } : shape={ arr .shape } dtype={ arr .dtype } strides={ arr .strides } "
29+ f"n_bad={ n_bad } n_nan={ n_nan } n_inf={ n_inf } "
3930 )
4031
41- # ---------- Voronoi or Barycentric Areas used to weight split points ----------
32+ # Print a tiny sample for sanity (won't crash on empty)
33+ flat = arr .reshape (- 1 )
34+ head = flat [:10 ] if flat .size >= 10 else flat
35+ print (f"[pure_callback] { name } : head={ head } " )
4236
43- if use_voronoi_areas :
37+ return arr
4438
45- areas = voronoi_areas_numpy (
46- points ,
47- )
4839
49- max_area = np .percentile (areas , 90.0 )
5040
51- areas [ areas == - 1 ] = max_area
52- areas [ areas > max_area ] = max_area
41+ def scipy_delaunay ( points_np , query_points_np , use_voronoi_areas , areas_factor ):
42+ """Compute Delaunay simplices (simplices_padded) and Voronoi areas in one call."""
5343
54- else :
44+ # --- Debug: what did the callback actually receive? ---
45+ points_raw = _debug_host_array ("points_raw" , points_np )
46+ qpts_raw = _debug_host_array ("qpts_raw" , query_points_np )
5547
56- areas = barycentric_dual_area_from (
57- points ,
58- simplices ,
59- xp = np ,
48+ # If anything is non-finite, save inputs to replay and crash immediately.
49+ if (not np .isfinite (points_raw ).all ()) or (not np .isfinite (qpts_raw ).all ()):
50+ np .savez ("callback_bad_inputs.npz" , points = points_raw , qpts = qpts_raw )
51+ raise FloatingPointError (
52+ "Non-finite values at pure_callback entry; saved callback_bad_inputs.npz"
6053 )
6154
62- split_point_areas = areas_factor * np .sqrt (areas )
55+ # (Optional but helpful) enforce expected rank early:
56+ if points_raw .ndim != 2 or points_raw .shape [1 ] != 2 :
57+ raise ValueError (f"points_raw unexpected shape { points_raw .shape } " )
58+ if qpts_raw .ndim != 2 or qpts_raw .shape [1 ] != 2 :
59+ raise ValueError (f"qpts_raw unexpected shape { qpts_raw .shape } " )
6360
64- # ---------- Compute split cross points for Split regularization ----------
65- split_points = split_points_from (
66- points = points_np ,
67- area_weights = split_point_areas ,
68- )
69-
70- # ---------- find_simplex for split cross points ----------
71- split_points_idx = tri .find_simplex (split_points )
61+ # Continue using the raw arrays
62+ points_np = points_raw
63+ query_points_np = qpts_raw
7264
73- splitted_mappings = pix_indexes_for_sub_slim_index_delaunay_from (
74- source_plane_data_grid = split_points ,
75- simplex_index_for_sub_slim_index = split_points_idx ,
76- pix_indexes_for_simplex_index = simplices ,
77- delaunay_points = points_np ,
78- )
65+ max_simplices = 2 * points_np .shape [0 ]
7966
80- return points , simplices_padded , mappings , split_points , splitted_mappings
8167
8268
8369def jax_delaunay (points , query_points , use_voronoi_areas , areas_factor = 0.5 ):
@@ -95,9 +81,7 @@ def jax_delaunay(points, query_points, use_voronoi_areas, areas_factor=0.5):
9581 splitted_mappings_shape = jax .ShapeDtypeStruct ((N * 4 , 3 ), jnp .int32 )
9682
9783 return jax .pure_callback (
98- lambda points , qpts : scipy_delaunay (
99- np .asarray (points ), np .asarray (qpts ), use_voronoi_areas , areas_factor
100- ),
84+ lambda points , qpts : scipy_delaunay (points , qpts , use_voronoi_areas , areas_factor ),
10185 (
10286 points_shape ,
10387 simplices_padded_shape ,
0 commit comments