@@ -554,6 +554,7 @@ class InterferometerSparseOperator:
554554 batch_size : int
555555 w_dtype : "jax.numpy.dtype"
556556 Khat : "jax.Array" # (2y, 2x), complex
557+ col_offsets : "jax.Array" # (batch_size,) int32
557558 """
558559 Cached FFT operator state for fast interferometer curvature-matrix assembly.
559560
@@ -672,168 +673,120 @@ def from_nufft_precision_operator(
672673 batch_size = int (batch_size ),
673674 w_dtype = nufft_precision_operator .dtype ,
674675 Khat = Khat ,
676+ col_offsets = jnp .arange (int (batch_size ), dtype = jnp .int32 ),
675677 )
676678
677- def curvature_matrix_via_sparse_operator_from (
678- self ,
679- pix_indexes_for_sub_slim_index : np .ndarray ,
680- pix_weights_for_sub_slim_index : np .ndarray ,
681- pix_pixels : int ,
682- fft_index_for_masked_pixel : np .ndarray ,
683- ):
def apply_operator(self, Fbatch_flat):
    """
    Apply the cached W~ operator to a batch of rectangular-grid vectors.

    Every column of `Fbatch_flat` is reshaped (row-major) into a
    (y_shape, x_shape) image, zero-padded on the high side of both image
    axes to (2 * y_shape, 2 * x_shape), multiplied by `self.Khat` in
    Fourier space, transformed back and cropped to its leading
    (y_shape, x_shape) corner. Only the real part is kept:

        G = Re( IFFT2( FFT2(F_pad) * Khat ) )[:y_shape, :x_shape]

    Parameters
    ----------
    Fbatch_flat
        Array of shape (M, B) with M = y_shape * x_shape. Each of the B
        columns is one flattened image on the rectangular grid.

    Returns
    -------
    jax.Array
        Array of shape (M, B): the operator applied to every column.
    """
    import jax.numpy as jnp

    ny, nx = self.y_shape, self.x_shape
    n_pix = ny * nx
    n_vec = Fbatch_flat.shape[1]

    # (M, B) -> (B, y, x): one image per batch entry.
    images = jnp.reshape(Fbatch_flat.T, (n_vec, ny, nx))

    # Double both image axes with trailing zeros; the batch axis is untouched.
    padded = jnp.pad(images, ((0, 0), (0, ny), (0, nx)))

    # Pointwise product in Fourier space; Khat is broadcast over the batch.
    spectrum = jnp.fft.fft2(padded) * self.Khat[None, :, :]

    # Back to real space, keep only the un-padded corner and its real part.
    cropped = jnp.real(jnp.fft.ifft2(spectrum)[:, :ny, :nx])

    # (B, y, x) -> (M, B), matching the input layout.
    return jnp.reshape(cropped, (n_vec, n_pix)).T
718+
def curvature_matrix_diag_from(self, rows, cols, vals, *, S: int):
    """
    Assemble the curvature matrix C = Aᵀ W~ A from COO triplets of A.

    The sparse mapping operator A is given as (rows, cols, vals) triplets.
    The result is built in column blocks of width `self.batch_size`:

    1) Scatter-add the block's columns of A into a dense
       (M, batch_size) array.
    2) Apply W~ to that array with `self.apply_operator`.
    3) Project back with Aᵀ using `segment_sum` over `cols`.

    The assembled matrix is symmetrised as 0.5 * (C + Cᵀ) before returning.

    Parameters
    ----------
    rows
        Flat rectangular-grid pixel index of every triplet, shape (nnz,).
    cols
        Source-pixel index of every triplet, shape (nnz,).
    vals
        Mapping weight of every triplet, shape (nnz,).
    S
        Number of source pixels; the output has shape (S, S).

    Returns
    -------
    jax.Array
        Symmetrised curvature matrix of shape (S, S).
    """
    import jax.numpy as jnp
    from jax import lax
    from jax.ops import segment_sum

    # NOTE(review): float64 is hard-coded throughout this method although
    # the class also stores `w_dtype` — confirm this is intended.
    row_idx = jnp.asarray(rows, dtype=jnp.int32)
    col_idx = jnp.asarray(cols, dtype=jnp.int32)
    weights = jnp.asarray(vals, dtype=jnp.float64)

    n_rect = self.M
    block_width = self.batch_size

    # Ceil-divide S into blocks; the accumulator is padded to a whole
    # number of blocks so every full-width block write fits inside it.
    n_blocks = (S + block_width - 1) // block_width
    padded_cols = n_blocks * block_width

    def fill_block(block_index, acc):
        first_col = block_index * block_width

        # Triplets outside this block are redirected to local column 0
        # with weight 0, so the scatter-add below ignores them.
        selected = (col_idx >= first_col) & (col_idx < (first_col + block_width))
        local_col = jnp.where(selected, col_idx - first_col, 0).astype(jnp.int32)
        local_val = jnp.where(selected, weights, 0.0)

        dense_block = (
            jnp.zeros((n_rect, block_width), dtype=jnp.float64)
            .at[row_idx, local_col]
            .add(local_val)
        )

        filtered = self.apply_operator(dense_block)  # (M, block_width)

        # Aᵀ projection: weight the filtered rows and sum per source pixel.
        projected = segment_sum(
            weights[:, None] * filtered[row_idx, :],
            col_idx,
            num_segments=S,
        )  # (S, block_width)

        # Zero the columns of the last block whose global index is >= S.
        n_live = jnp.minimum(block_width, jnp.maximum(0, S - first_col))
        live = self.col_offsets < n_live
        projected = projected * live[None, :]

        return lax.dynamic_update_slice(acc, projected, (0, first_col))

    accumulator = jnp.zeros((S, padded_cols), dtype=jnp.float64)
    filled = lax.fori_loop(0, n_blocks, fill_block, accumulator)

    # Drop the padding columns, then symmetrise.
    curvature = filled[:, :S]
    return 0.5 * (curvature + curvature.T)
0 commit comments