Skip to content

Commit 1897037

Browse files
authored
Merge pull request #287 from PyAutoLabs/feature/grid-irregular-xp-propagation
Propagate xp through Grid2DIrregular.grid_2d_via_deflection_grid_from
2 parents becbabd + a6b4402 commit 1897037

2 files changed

Lines changed: 20 additions & 1 deletion

File tree

autoarray/structures/grids/irregular_2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def grid_2d_via_deflection_grid_from(
182182
deflection_grid
183183
The grid of (y,x) coordinates which is subtracted from this grid.
184184
"""
185-
return Grid2DIrregular(values=self - deflection_grid)
185+
return Grid2DIrregular(values=self - deflection_grid, xp=self._xp)
186186

187187
def squared_distances_to_coordinate_from(
188188
self, coordinate: Tuple[float, float] = (0.0, 0.0)

test_autoarray/structures/grids/test_irregular_2d.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,25 @@ def test__grid_2d_via_deflection_grid_from():
7575
assert grid.in_list == [(0.0, 1.0), (1.0, 1.0)]
7676

7777

78+
def test__grid_2d_via_deflection_grid_from__propagates_xp():
79+
# numpy-backed receiver -> numpy-backed result
80+
grid_np = aa.Grid2DIrregular(values=[(1.0, 1.0), (2.0, 2.0)])
81+
result_np = grid_np.grid_2d_via_deflection_grid_from(
82+
deflection_grid=np.array([[1.0, 0.0], [1.0, 1.0]])
83+
)
84+
assert result_np._xp is np
85+
86+
# jax-backed receiver -> jax-backed result (so downstream .square calls use jnp)
87+
jnp = pytest.importorskip("jax.numpy")
88+
grid_jax = aa.Grid2DIrregular(
89+
values=jnp.array([[1.0, 1.0], [2.0, 2.0]]), xp=jnp
90+
)
91+
result_jax = grid_jax.grid_2d_via_deflection_grid_from(
92+
deflection_grid=jnp.array([[1.0, 0.0], [1.0, 1.0]])
93+
)
94+
assert result_jax._xp is jnp
95+
96+
7897
def test__furthest_distances_to_other_coordinates():
7998
grid = aa.Grid2DIrregular(values=[(0.0, 0.0), (0.0, 1.0)])
8099

0 commit comments

Comments
 (0)