Skip to content

Commit 2ff2fc0

Browse files
Jammy2211claude
authored andcommitted
fix: vmap-blocker bugs in convert.py and FitEllipse cached_property
Two bugs uncovered when validating fitness._vmap(parameters) on the new JAX path: 1. autogalaxy/convert.py line ~65: (a == 0) & (b == 0) on JAX scalar tracers triggers TracerArrayConversionError via Python's __and__. Switch to xp.logical_and so the bool combination stays in JAX-land. Fixed all four occurrences: axis_ratio_and_angle_from, two spots in shear_magnitude_and_angle_from, and multipole_k_m_and_phi_m_from. 2. autogalaxy/ellipse/fit_ellipse.py: _points_from_major_axis was decorated @cached_property. Caching a tracer in self.__dict__ breaks under vmap (different batch elements share one stale cache). Switch to @Property and recompute on access. The numpy path takes a ~2x hit on this property but the absolute cost is negligible. Numpy-path numerics unchanged. 870/870 unit tests pass. Issue PyAutoGalaxy#411. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 1f4a81a commit 2ff2fc0

2 files changed

Lines changed: 6 additions & 6 deletions

File tree

autogalaxy/convert.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def axis_ratio_and_angle_from(
6262
"""
6363
angle = 0.5 * xp.arctan2(
6464
ell_comps[0],
65-
xp.where((ell_comps[0] == 0) & (ell_comps[1] == 0), 1.0, ell_comps[1]),
65+
xp.where(xp.logical_and(ell_comps[0] == 0, ell_comps[1] == 0), 1.0, ell_comps[1]),
6666
)
6767
angle *= 180.0 / xp.pi
6868

@@ -202,15 +202,15 @@ def shear_magnitude_and_angle_from(
202202
"""
203203
angle = (
204204
0.5
205-
* xp.arctan2(gamma_2, xp.where((gamma_1 == 0) & (gamma_2 == 0), 1.0, gamma_1))
205+
* xp.arctan2(gamma_2, xp.where(xp.logical_and(gamma_1 == 0, gamma_2 == 0), 1.0, gamma_1))
206206
* 180.0
207207
/ xp.pi
208208
)
209209
magnitude = xp.sqrt(gamma_1**2 + gamma_2**2)
210210

211211
angle = xp.where(angle < 0, angle + 180.0, angle)
212212
angle = xp.where(
213-
(xp.abs(angle - 90.0) > 45.0) & (angle > 90.0), angle - 180.0, angle
213+
xp.logical_and(xp.abs(angle - 90.0) > 45.0, angle > 90.0), angle - 180.0, angle
214214
)
215215

216216
return magnitude, angle
@@ -309,7 +309,7 @@ def multipole_k_m_and_phi_m_from(
309309
xp.arctan2(
310310
multipole_comps[0],
311311
xp.where(
312-
(multipole_comps[0] == 0) & (multipole_comps[1] == 0),
312+
xp.logical_and(multipole_comps[0] == 0, multipole_comps[1] == 0),
313313
1.0,
314314
multipole_comps[1],
315315
),

autogalaxy/ellipse/fit_ellipse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,10 @@ def points_from_major_axis_from(self) -> "np.ndarray | jax.Array":
123123
keep = mask_values == 0
124124
return xp.where(keep[:, None], points, xp.nan)
125125

126-
@cached_property
126+
@property
127127
def _points_from_major_axis(self) -> np.ndarray:
128128
"""
129-
Returns cached (y,x) coordinates on the ellipse that are used to interpolate the data and noise-map values.
129+
Returns (y,x) coordinates on the ellipse that are used to interpolate the data and noise-map values.
130130
131131
Returns
132132
-------

0 commit comments

Comments
 (0)