Skip to content

Commit 58b314c

Browse files
Jammy2211claude
authored andcommitted
Fix InterpolatorKNNBarycentric writability + split-mappings padding
Bugfixes surfaced while integrating InterpolatorKNNBarycentric into the autolens FitImaging eager numpy path: - `_mappings_sizes_weights` was returning np.asarray() of jax outputs, which gives a read-only view. ConstantSplit's `reg_split_np_from` uses in-place assignment and raised "assignment destination is read-only". On the numpy path, materialize with np.array() to get a writable buffer. - `_mappings_sizes_weights_split` returned a k=3 mappings array. The ConstantSplit code writes `splitted_mappings[i][j+1]` for the central pixel insertion, requiring an extra reserved column. Match InterpolatorDelaunay's hstack-append pattern. - Add `import numpy as np` at module scope (was implicitly available through `jax.numpy as jnp` aliasing for Wendland-only paths). - New regression-guard test confirms numpy-path mappings/weights are writable. Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
1 parent f59ae1d commit 58b314c

2 files changed

Lines changed: 64 additions & 2 deletions

File tree

autoarray/inversion/mesh/interpolator/knn.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
from autoconf import cached_property
24

35
from autoarray.inversion.mesh.interpolator.delaunay import InterpolatorDelaunay
@@ -357,8 +359,16 @@ def _mappings_sizes_weights(self):
357359
xp=self._xp,
358360
)
359361

360-
mappings = self._xp.asarray(mappings)
361-
weights = self._xp.asarray(weights)
362+
# On the numpy path, materialize with `np.array(...)` so the regularization
363+
# code (which uses in-place assignment, e.g. `reg_split_np_from`) gets a
364+
# writable buffer rather than a read-only view of a jax.Array. On the jax
365+
# path, asarray is the right cast (no copy in a JIT trace).
366+
if self._xp is np:
367+
mappings = np.array(mappings)
368+
weights = np.array(weights)
369+
else:
370+
mappings = self._xp.asarray(mappings)
371+
weights = self._xp.asarray(weights)
362372

363373
sizes = self._xp.full(
364374
(mappings.shape[0],),
@@ -404,9 +414,20 @@ def _mappings_sizes_weights_split(self):
404414
mappings = interpolator.mappings
405415
weights = interpolator.weights
406416

417+
# `reg_split_np_from` writes `splitted_mappings[i][j+1] = pixel_index`
418+
# for the "flag-zero" insertion of the central pixel, so the buffer
419+
# must have an extra column reserved past the k=3 mappings — matching
420+
# `InterpolatorDelaunay._mappings_sizes_weights_split`'s hstack-append.
421+
# `sizes` reports 3 (the actual mappings); `reg_split_np_from` grows it
422+
# to 4 in-place when it inserts.
407423
sizes = self._xp.full(
408424
(mappings.shape[0],),
409425
mappings.shape[1],
410426
)
411427

428+
pad_int = self._xp.full((mappings.shape[0], 1), -1, dtype=mappings.dtype)
429+
pad_float = self._xp.zeros((weights.shape[0], 1), dtype=weights.dtype)
430+
mappings = self._xp.hstack((mappings, pad_int))
431+
weights = self._xp.hstack((weights, pad_float))
432+
412433
return mappings, sizes, weights

test_autoarray/inversion/pixelization/interpolator/test_knn_barycentric.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,44 @@ def test__mesh_class__inherits_knn_knobs():
127127
assert mesh.radius_scale == 1.5
128128
assert mesh.areas_factor == 0.5
129129
assert mesh.split_neighbor_division == 2
130+
131+
132+
def test__interpolator__numpy_path_returns_writable_mappings_and_weights():
133+
"""
134+
Regression guard: on the numpy path, `mappings` and `weights` must be
135+
writable numpy arrays — not read-only views of jax.Arrays. The
136+
`ConstantSplit` / `AdaptSplit` regularization code uses in-place
137+
assignment (e.g. `splitted_mappings[i][j+1] = ...` in `reg_split_np_from`),
138+
which raises `ValueError: assignment destination is read-only` if the
139+
array is a numpy view onto a jax.Array.
140+
"""
141+
import autoarray as aa
142+
from autoarray.inversion.mesh.interpolator.knn import (
143+
InterpolatorKNNBarycentric,
144+
)
145+
146+
mesh_grid = aa.Grid2D.no_mask(
147+
values=[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5], [-0.5, -0.5]],
148+
shape_native=(3, 2),
149+
pixel_scales=1.0,
150+
over_sample_size=1,
151+
)
152+
153+
class _Raw:
154+
def __init__(self, arr):
155+
self.array = arr
156+
157+
queries = np.array([[0.3, 0.3], [0.0, 0.0], [1.0, 1.0], [0.5, 0.5]])
158+
159+
interp = InterpolatorKNNBarycentric(
160+
mesh=aa.mesh.KNNBarycentric(pixels=6),
161+
mesh_grid=mesh_grid,
162+
data_grid=_Raw(queries),
163+
xp=np,
164+
)
165+
mappings, _, weights = interp._mappings_sizes_weights
166+
167+
assert isinstance(mappings, np.ndarray)
168+
assert isinstance(weights, np.ndarray)
169+
assert mappings.flags.writeable, "mappings must be writable (regularization mutates it)"
170+
assert weights.flags.writeable, "weights must be writable (regularization mutates it)"

0 commit comments

Comments
 (0)