@@ -75,6 +75,25 @@ def test__grid_2d_via_deflection_grid_from():
7575 assert grid .in_list == [(0.0 , 1.0 ), (1.0 , 1.0 )]
7676
7777
78+ def test__grid_2d_via_deflection_grid_from__propagates_xp ():
79+ # numpy-backed receiver -> numpy-backed result
80+ grid_np = aa .Grid2DIrregular (values = [(1.0 , 1.0 ), (2.0 , 2.0 )])
81+ result_np = grid_np .grid_2d_via_deflection_grid_from (
82+ deflection_grid = np .array ([[1.0 , 0.0 ], [1.0 , 1.0 ]])
83+ )
84+ assert result_np ._xp is np
85+
86+ # jax-backed receiver -> jax-backed result (so downstream .square calls use jnp)
87+ jnp = pytest .importorskip ("jax.numpy" )
88+ grid_jax = aa .Grid2DIrregular (
89+ values = jnp .array ([[1.0 , 1.0 ], [2.0 , 2.0 ]]), xp = jnp
90+ )
91+ result_jax = grid_jax .grid_2d_via_deflection_grid_from (
92+ deflection_grid = jnp .array ([[1.0 , 0.0 ], [1.0 , 1.0 ]])
93+ )
94+ assert result_jax ._xp is jnp
95+
96+
7897def test__furthest_distances_to_other_coordinates ():
7998 grid = aa .Grid2DIrregular (values = [(0.0 , 0.0 ), (0.0 , 1.0 )])
8099
0 commit comments