Skip to content

Commit c19dd09

Browse files
Jammy2211Jammy2211
authored andcommitted
scipy debug checks
1 parent 8a9a797 commit c19dd09

1 file changed

Lines changed: 39 additions & 55 deletions

File tree

autoarray/structures/mesh/delaunay_2d.py

Lines changed: 39 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -13,71 +13,57 @@
1313
from autoarray import exc
1414

1515

16-
def scipy_delaunay(points_np, query_points_np, use_voronoi_areas, areas_factor):
17-
"""Compute Delaunay simplices (simplices_padded) and Voronoi areas in one call."""
18-
19-
max_simplices = 2 * points_np.shape[0]
20-
21-
# --- Delaunay mesh using source plane data grid ---
22-
tri = Delaunay(points_np)
23-
24-
points = tri.points.astype(points_np.dtype)
25-
simplices = tri.simplices.astype(np.int32)
26-
27-
# Pad simplices to max_simplices
28-
simplices_padded = -np.ones((max_simplices, 3), dtype=np.int32)
29-
simplices_padded[: simplices.shape[0]] = simplices
30-
31-
# ---------- find_simplex for source plane data grid ----------
32-
simplex_idx = tri.find_simplex(query_points_np).astype(np.int32) # (Q,)
33-
34-
mappings = pix_indexes_for_sub_slim_index_delaunay_from(
35-
source_plane_data_grid=query_points_np,
36-
simplex_index_for_sub_slim_index=simplex_idx,
37-
pix_indexes_for_simplex_index=simplices,
38-
delaunay_points=points_np,
16+
def _debug_host_array(name: str, x):
17+
"""
18+
Print shape/dtype/strides and nan/inf counts for a host-side array-like.
19+
Safe to call inside pure_callback.
20+
"""
21+
arr = np.asarray(x) # IMPORTANT: raw view BEFORE any dtype/contiguous conversion
22+
finite = np.isfinite(arr)
23+
n_nan = np.isnan(arr).sum()
24+
n_inf = np.isinf(arr).sum()
25+
n_bad = (~finite).sum()
26+
27+
print(
28+
f"[pure_callback] {name}: shape={arr.shape} dtype={arr.dtype} strides={arr.strides} "
29+
f"n_bad={n_bad} n_nan={n_nan} n_inf={n_inf}"
3930
)
4031

41-
# ---------- Voronoi or Barycentric Areas used to weight split points ----------
32+
# Print a tiny sample for sanity (won't crash on empty)
33+
flat = arr.reshape(-1)
34+
head = flat[:10] if flat.size >= 10 else flat
35+
print(f"[pure_callback] {name}: head={head}")
4236

43-
if use_voronoi_areas:
37+
return arr
4438

45-
areas = voronoi_areas_numpy(
46-
points,
47-
)
4839

49-
max_area = np.percentile(areas, 90.0)
5040

51-
areas[areas == -1] = max_area
52-
areas[areas > max_area] = max_area
41+
def scipy_delaunay(points_np, query_points_np, use_voronoi_areas, areas_factor):
42+
"""Compute Delaunay simplices (simplices_padded) and Voronoi areas in one call."""
5343

54-
else:
44+
# --- Debug: what did the callback actually receive? ---
45+
points_raw = _debug_host_array("points_raw", points_np)
46+
qpts_raw = _debug_host_array("qpts_raw", query_points_np)
5547

56-
areas = barycentric_dual_area_from(
57-
points,
58-
simplices,
59-
xp=np,
48+
# If anything is non-finite, save inputs to replay and crash immediately.
49+
if (not np.isfinite(points_raw).all()) or (not np.isfinite(qpts_raw).all()):
50+
np.savez("callback_bad_inputs.npz", points=points_raw, qpts=qpts_raw)
51+
raise FloatingPointError(
52+
"Non-finite values at pure_callback entry; saved callback_bad_inputs.npz"
6053
)
6154

62-
split_point_areas = areas_factor * np.sqrt(areas)
55+
# (Optional but helpful) enforce expected rank early:
56+
if points_raw.ndim != 2 or points_raw.shape[1] != 2:
57+
raise ValueError(f"points_raw unexpected shape {points_raw.shape}")
58+
if qpts_raw.ndim != 2 or qpts_raw.shape[1] != 2:
59+
raise ValueError(f"qpts_raw unexpected shape {qpts_raw.shape}")
6360

64-
# ---------- Compute split cross points for Split regularization ----------
65-
split_points = split_points_from(
66-
points=points_np,
67-
area_weights=split_point_areas,
68-
)
69-
70-
# ---------- find_simplex for split cross points ----------
71-
split_points_idx = tri.find_simplex(split_points)
61+
# Continue using the raw arrays
62+
points_np = points_raw
63+
query_points_np = qpts_raw
7264

73-
splitted_mappings = pix_indexes_for_sub_slim_index_delaunay_from(
74-
source_plane_data_grid=split_points,
75-
simplex_index_for_sub_slim_index=split_points_idx,
76-
pix_indexes_for_simplex_index=simplices,
77-
delaunay_points=points_np,
78-
)
65+
max_simplices = 2 * points_np.shape[0]
7966

80-
return points, simplices_padded, mappings, split_points, splitted_mappings
8167

8268

8369
def jax_delaunay(points, query_points, use_voronoi_areas, areas_factor=0.5):
@@ -95,9 +81,7 @@ def jax_delaunay(points, query_points, use_voronoi_areas, areas_factor=0.5):
9581
splitted_mappings_shape = jax.ShapeDtypeStruct((N * 4, 3), jnp.int32)
9682

9783
return jax.pure_callback(
98-
lambda points, qpts: scipy_delaunay(
99-
np.asarray(points), np.asarray(qpts), use_voronoi_areas, areas_factor
100-
),
84+
lambda points, qpts: scipy_delaunay(points, qpts, use_voronoi_areas, areas_factor),
10185
(
10286
points_shape,
10387
simplices_padded_shape,

0 commit comments

Comments
 (0)