Skip to content

Commit 9da6598

Browse files
Jammy2211Jammy2211
authored andcommitted
fix slow JAX compile in geometry util
1 parent ea21935 commit 9da6598

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

autoarray/geometry/geometry_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def transform_grid_2d_to_reference_frame(
381381

382382
shifted_grid_2d = grid_2d - jnp.array(centre)
383383

384-
radius = jnp.sqrt(jnp.sum(shifted_grid_2d**2.0, axis=1))
384+
radius = jnp.sqrt(jnp.sum(jnp.square(shifted_grid_2d), axis=1))
385385
theta_coordinate_to_profile = jnp.arctan2(
386386
shifted_grid_2d[:, 0], shifted_grid_2d[:, 1]
387387
) - jnp.radians(angle)

0 commit comments

Comments
 (0)