File tree Expand file tree Collapse file tree
autoarray/structures/arrays Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -484,22 +484,29 @@ def convolved_array_from(self, array: Array2D) -> Array2D:
484484 return Array2D (values = convolved_array_1d , mask = array_2d .mask )
485485
486486 def jax_convolve (self , image , blurring_image , method = "auto" ):
487- slim_to_2D_index_image = jnp .nonzero (
487+
488+ slim_to_native = jnp .nonzero (
488489 jnp .logical_not (self .mask .array ), size = image .shape [0 ]
489490 )
490- slim_to_2D_index_blurring = jnp .nonzero (
491+ slim_to_native_blurring = jnp .nonzero (
491492 jnp .logical_not (self .blurring_mask ), size = blurring_image .shape [0 ]
492493 )
494+
493495 expanded_image_native = jnp .zeros (self .mask .shape )
494- expanded_image_native = expanded_image_native .at [slim_to_2D_index_image ].set (
496+
497+ expanded_image_native = expanded_image_native .at [slim_to_native ].set (
495498 image .array
496499 )
497- expanded_image_native = expanded_image_native .at [slim_to_2D_index_blurring ].set (
500+ expanded_image_native = expanded_image_native .at [slim_to_native_blurring ].set (
498501 blurring_image .array
499502 )
503+
500504 kernel = np .array (self .kernel .native .array )
505+
501506 convolve_native = jax .scipy .signal .convolve (
502507 expanded_image_native , kernel , mode = "same" , method = method
503508 )
504- convolve_slim = convolve_native [slim_to_2D_index_image ]
509+
510+ convolve_slim = convolve_native [slim_to_native ]
511+
505512 return convolve_slim
You can’t perform that action at this time.
0 commit comments