Skip to content

Commit b12b2b9

Browse files
Jammy2211Jammy2211
authored andcommitted
fix mapped valued tests by converting to JAX
1 parent 89d3065 commit b12b2b9

3 files changed

Lines changed: 8 additions & 6 deletions

File tree

autoarray/inversion/inversion/mapper_valued.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def values_masked(self):
5757
values = self.values
5858

5959
if self.mesh_pixel_mask is not None:
60-
values[self.mesh_pixel_mask] = 0.0
60+
values = values.at[self.mesh_pixel_mask].set(0.0)
6161

6262
return values
6363

@@ -187,7 +187,7 @@ def mapped_reconstructed_image_from(
187187
mapping_matrix = self.mapper.mapping_matrix
188188

189189
if self.mesh_pixel_mask is not None:
190-
mapping_matrix[:, self.mesh_pixel_mask] = 0.0
190+
mapping_matrix = mapping_matrix.at[:, self.mesh_pixel_mask].set(0.0)
191191

192192
return Array2D(
193193
values=inversion_util.mapped_reconstructed_data_via_mapping_matrix_from(

autoarray/inversion/pixelization/mappers/abstract.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def pixel_signals_from(self, signal_scale: float) -> np.ndarray:
282282
A factor which controls how rapidly the smoothness of regularization varies from high signal regions to
283283
low signal regions.
284284
"""
285+
285286
return mapper_util.adaptive_pixel_signals_from(
286287
pixels=self.pixels,
287288
signal_scale=signal_scale,

test_autoarray/inversion/inversion/test_mapper_valued.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import jax.numpy as jnp
12
import numpy as np
23
import pytest
34

@@ -145,7 +146,7 @@ def test__magnification_via_mesh_from__with_pixel_mask():
145146
pixel_scales=(0.5, 0.5),
146147
)
147148

148-
magnification = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
149+
magnification = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
149150

150151
source_plane_mesh_grid = aa.Mesh2DVoronoi(
151152
values=np.array(
@@ -168,15 +169,15 @@ def test__magnification_via_mesh_from__with_pixel_mask():
168169
parameters=3,
169170
source_plane_mesh_grid=source_plane_mesh_grid,
170171
mask=mask,
171-
mapping_matrix=np.ones((12, 10)),
172+
mapping_matrix=jnp.ones((12, 10)),
172173
)
173174

174175
mesh_pixel_mask = np.array(
175176
[True, True, True, True, True, True, True, True, False, False]
176177
)
177178

178179
mapper_valued = aa.MapperValued(
179-
values=np.array(magnification), mapper=mapper, mesh_pixel_mask=mesh_pixel_mask
180+
values=magnification, mapper=mapper, mesh_pixel_mask=mesh_pixel_mask
180181
)
181182

182183
magnification = mapper_valued.magnification_via_mesh_from()
@@ -199,7 +200,7 @@ def test__magnification_via_interpolation_from():
199200
parameters=4,
200201
mask=mask,
201202
interpolated_array=magnification,
202-
mapping_matrix=np.ones((4, 4)),
203+
mapping_matrix=jnp.ones((4, 4)),
203204
)
204205

205206
mapper_valued = aa.MapperValued(values=np.array(magnification), mapper=mapper)

0 commit comments

Comments
 (0)