Skip to content

Commit aa418a5

Browse files
authored
Merge pull request #316 from PyAutoLabs/feature/fix-interferometer-sparse-curvature
fix(interferometer): correct sparse curvature for Pmax > 1 (Delaunay)
2 parents dabe993 + 7a00566 commit aa418a5

4 files changed

Lines changed: 156 additions & 189 deletions

File tree

autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py

Lines changed: 95 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,7 @@ class InterferometerSparseOperator:
554554
batch_size: int
555555
w_dtype: "jax.numpy.dtype"
556556
Khat: "jax.Array" # (2y, 2x), complex
557+
col_offsets: "jax.Array" # (batch_size,) int32
557558
"""
558559
Cached FFT operator state for fast interferometer curvature-matrix assembly.
559560
@@ -672,168 +673,120 @@ def from_nufft_precision_operator(
672673
batch_size=int(batch_size),
673674
w_dtype=nufft_precision_operator.dtype,
674675
Khat=Khat,
676+
col_offsets=jnp.arange(int(batch_size), dtype=jnp.int32),
675677
)
676678

677-
def curvature_matrix_via_sparse_operator_from(
678-
self,
679-
pix_indexes_for_sub_slim_index: np.ndarray,
680-
pix_weights_for_sub_slim_index: np.ndarray,
681-
pix_pixels: int,
682-
fft_index_for_masked_pixel: np.ndarray,
683-
):
679+
def apply_operator(self, Fbatch_flat):
684680
"""
685-
Assemble the curvature matrix C = Aᵀ W A using sparse triplets and the FFT W~ operator.
686-
687-
This method computes the mapper (pixelization) curvature matrix without
688-
forming a dense mapping matrix. Instead, it uses fixed-length mapping
689-
arrays (pixel indexes + weights per masked pixel) which define a sparse
690-
mapping operator A in COO-like form.
691-
692-
Algorithm outline
693-
-----------------
694-
Let S be the number of source pixels and M be the number of rectangular
695-
real-space pixels.
696-
697-
1) Build a fixed-length COO stream from the mapping arrays:
698-
rows_rect[k] : rectangular pixel index (0..M-1)
699-
cols[k] : source pixel index (0..S-1)
700-
vals[k] : mapping weight
701-
Invalid mappings (cols < 0 or cols >= S) are masked out.
702-
703-
2) Process source-pixel columns in blocks of width `batch_size`:
704-
- Scatter the block’s source columns into a dense (M, batch_size) array F.
705-
- Apply the W~ operator by FFT:
706-
G = apply_W(F)
707-
- Project back with Aᵀ via segmented reductions:
708-
C[:, start:start+B] = Aᵀ G
709-
710-
3) Symmetrize the result:
711-
C <- 0.5 * (C + Cᵀ)
681+
Apply the interferometer W~ operator to a batch of vectors.
682+
683+
Given an input matrix of shape (M, B) on the rectangular real-space
684+
grid (M = y_shape * x_shape), this method computes
685+
686+
G = W~ Fbatch_flat
687+
688+
via FFT-based convolution with the cached `Khat` kernel:
689+
690+
apply_W(F) = Re( IFFT( FFT(F_pad) * Khat ) )[:y, :x]
691+
692+
where `F_pad` is the (2y, 2x) zero-padded version of `F`.
712693
713694
Parameters
714695
----------
715-
pix_indexes_for_sub_slim_index
716-
Integer array of shape (M_masked, Pmax).
717-
For each masked (slim) image pixel, stores the source-pixel indices
718-
involved in the interpolation / mapping stencil. Invalid entries
719-
should be set to -1.
720-
pix_weights_for_sub_slim_index
721-
Floating array of shape (M_masked, Pmax).
722-
Weights corresponding to `pix_indexes_for_sub_slim_index`.
723-
These should already include any oversampling normalisation (e.g.
724-
sub-pixel fractions) required by the mapper.
725-
pix_pixels
726-
Number of source pixels, S.
727-
fft_index_for_masked_pixel
728-
Integer array of shape (M_masked,).
729-
Maps each masked (slim) image pixel index to its corresponding
730-
rectangular-grid flat index (0..M-1). This embeds the masked pixel
731-
ordering into the FFT-friendly rectangular grid.
696+
Fbatch_flat
697+
Array of shape (M, B) representing B vectors on the rectangular grid.
732698
733699
Returns
734700
-------
735-
jax.Array
736-
Curvature matrix of shape (S, S), symmetric.
701+
ndarray
702+
Array of shape (M, B) equal to W~ applied to the batch.
703+
"""
704+
import jax.numpy as jnp
705+
706+
y_shape, x_shape = self.y_shape, self.x_shape
707+
M = y_shape * x_shape
708+
Khat = self.Khat
737709

738-
Notes
739-
-----
740-
- The inner computation is written in JAX and is intended to be jitted.
741-
For best performance, keep `batch_size` fixed (static) across calls.
742-
- Choosing `batch_size` as a divisor of S avoids a smaller tail block,
743-
but correctness does not require that if the implementation masks the tail.
744-
- This method uses FFTs on padded (2y, 2x) arrays; memory use scales with
745-
batch_size and grid size.
710+
B = Fbatch_flat.shape[1]
711+
F_img = Fbatch_flat.T.reshape((B, y_shape, x_shape))
712+
F_pad = jnp.pad(F_img, ((0, 0), (0, y_shape), (0, x_shape)))
713+
Fhat = jnp.fft.fft2(F_pad)
714+
Ghat = Fhat * Khat[None, :, :]
715+
G_pad = jnp.fft.ifft2(Ghat)
716+
G = jnp.real(G_pad[:, :y_shape, :x_shape])
717+
return G.reshape((B, M)).T
718+
719+
def curvature_matrix_diag_from(self, rows, cols, vals, *, S: int):
746720
"""
721+
Compute the diagonal (mapper-mapper) curvature matrix block F = Aᵀ W~ A.
722+
723+
This method mirrors `ImagingSparseOperator.curvature_matrix_diag_from`
724+
and is the structural counterpart for the interferometer W~ operator.
725+
726+
Given a sparse mapping operator A in COO triplet form (rows, cols, vals)
727+
with `S` source pixels, it computes
728+
729+
F = Aᵀ W~ A
747730
731+
in column blocks of width `batch_size`:
732+
733+
1) Assemble Fbatch = A[:, start:start+B] on the rectangular grid via scatter-add.
734+
2) Apply W~ to the block via FFT: Gbatch = W~(Fbatch).
735+
3) Project back with Aᵀ via segment_sum over `cols`.
736+
737+
Parameters
738+
----------
739+
rows, cols, vals
740+
COO triplets encoding the sparse mapping operator A.
741+
- `rows`: rectangular-grid pixel indices (flat) in [0, M), shape (nnz,)
742+
- `cols`: source pixel indices in [0, S), shape (nnz,)
743+
- `vals`: mapping weights (interpolation + any sub-fraction normalisation),
744+
shape (nnz,)
745+
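For the interferometer path these are built by `mapper_util.sparse_triplets_from` with `mask.extent_index_for_masked_pixel`, so that `rows` index this operator's own (y_shape * x_shape) rectangular grid rather than the full native grid.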
746+
S
747+
Number of source pixels / parameters for this mapper.
748+
749+
Returns
750+
-------
751+
ndarray
752+
Curvature matrix of shape (S, S), symmetric.
753+
"""
748754
import jax.numpy as jnp
755+
from jax import lax
749756
from jax.ops import segment_sum
750757

751-
# -------------------------
752-
# Pull static quantities from state
753-
# -------------------------
754-
y_shape = self.y_shape
755-
x_shape = self.x_shape
758+
rows = jnp.asarray(rows, dtype=jnp.int32)
759+
cols = jnp.asarray(cols, dtype=jnp.int32)
760+
vals = jnp.asarray(vals, dtype=jnp.float64)
761+
756762
M = self.M
757-
batch_size = self.batch_size
758-
Khat = self.Khat
759-
w_dtype = self.w_dtype
760-
761-
# -------------------------
762-
# Basic shape checks (NumPy side, safe)
763-
# -------------------------
764-
M_masked, Pmax = pix_indexes_for_sub_slim_index.shape
765-
S = int(pix_pixels)
766-
767-
# -------------------------
768-
# JAX core (unchanged COO logic)
769-
# -------------------------
770-
def _curvature_rect_jax(
771-
pix_idx: jnp.ndarray, # (M_masked, Pmax)
772-
pix_wts: jnp.ndarray, # (M_masked, Pmax)
773-
rect_map: jnp.ndarray, # (M_masked,)
774-
) -> jnp.ndarray:
775-
rect_map = jnp.asarray(rect_map)
776-
777-
nnz_full = M_masked * Pmax
778-
779-
# Flatten mapping arrays into a fixed-length COO stream
780-
rows_mask = jnp.repeat(
781-
jnp.arange(M_masked, dtype=jnp.int32), Pmax
782-
) # (nnz_full,)
783-
cols = pix_idx.reshape((nnz_full,)).astype(jnp.int32)
784-
vals = pix_wts.reshape((nnz_full,)).astype(w_dtype)
785-
786-
# Validity mask
787-
valid = (cols >= 0) & (cols < S)
788-
789-
# Embed masked rows into rectangular rows
790-
rows_rect = rect_map[rows_mask].astype(jnp.int32)
791-
792-
# Make cols / vals safe
793-
cols_safe = jnp.where(valid, cols, 0)
794-
vals_safe = jnp.where(valid, vals, 0.0)
795-
796-
def apply_operator_fft_batch(Fbatch_flat: jnp.ndarray) -> jnp.ndarray:
797-
B = Fbatch_flat.shape[1]
798-
F_img = Fbatch_flat.T.reshape((B, y_shape, x_shape))
799-
F_pad = jnp.pad(
800-
F_img, ((0, 0), (0, y_shape), (0, x_shape))
801-
) # (B,2y,2x)
802-
Fhat = jnp.fft.fft2(F_pad)
803-
Ghat = Fhat * Khat[None, :, :]
804-
G_pad = jnp.fft.ifft2(Ghat)
805-
G = jnp.real(G_pad[:, :y_shape, :x_shape])
806-
return G.reshape((B, M)).T # (M,B)
807-
808-
def compute_block(start_col: int) -> jnp.ndarray:
809-
in_block = (cols_safe >= start_col) & (
810-
cols_safe < start_col + batch_size
811-
)
812-
in_use = valid & in_block
763+
B = self.batch_size
813764

814-
bc = jnp.where(in_use, cols_safe - start_col, 0).astype(jnp.int32)
815-
v = jnp.where(in_use, vals_safe, 0.0)
765+
n_blocks = (S + B - 1) // B
766+
S_pad = n_blocks * B
816767

817-
Fbatch = jnp.zeros((M, batch_size), dtype=w_dtype)
818-
Fbatch = Fbatch.at[rows_rect, bc].add(v)
768+
C0 = jnp.zeros((S, S_pad), dtype=jnp.float64)
819769

820-
Gbatch = apply_operator_fft_batch(Fbatch)
821-
G_at_rows = Gbatch[rows_rect, :]
770+
def body(block_i, C):
771+
start = block_i * B
822772

823-
contrib = vals_safe[:, None] * G_at_rows
824-
return segment_sum(contrib, cols_safe, num_segments=S)
773+
in_block = (cols >= start) & (cols < (start + B))
774+
bc = jnp.where(in_block, cols - start, 0).astype(jnp.int32)
775+
v = jnp.where(in_block, vals, 0.0)
825776

826-
# Assemble curvature
827-
C = jnp.zeros((S, S), dtype=w_dtype)
828-
for start in range(0, S, batch_size):
829-
Cblock = compute_block(start)
830-
width = min(batch_size, S - start)
831-
C = C.at[:, start : start + width].set(Cblock[:, :width])
777+
F = jnp.zeros((M, B), dtype=jnp.float64)
778+
F = F.at[rows, bc].add(v)
832779

833-
return 0.5 * (C + C.T)
780+
G = self.apply_operator(F) # (M, B)
834781

835-
return _curvature_rect_jax(
836-
pix_indexes_for_sub_slim_index,
837-
pix_weights_for_sub_slim_index,
838-
fft_index_for_masked_pixel,
839-
)
782+
contrib = vals[:, None] * G[rows, :]
783+
Cblock = segment_sum(contrib, cols, num_segments=S) # (S, B)
784+
785+
width = jnp.minimum(B, jnp.maximum(0, S - start))
786+
Cblock = Cblock * (self.col_offsets < width)[None, :]
787+
788+
return lax.dynamic_update_slice(C, Cblock, (0, start))
789+
790+
C_pad = lax.fori_loop(0, n_blocks, body, C0)
791+
C = C_pad[:, :S]
792+
return 0.5 * (C + C.T)

autoarray/inversion/inversion/interferometer/sparse.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,11 @@
77
AbstractInversionInterferometer,
88
)
99
from autoarray.inversion.linear_obj.linear_obj import LinearObj
10-
from autoarray.inversion.mesh.mesh.delaunay import Delaunay
10+
from autoarray.inversion.mappers import mapper_util
1111
from autoarray.settings import Settings
1212
from autoarray.inversion.mappers.abstract import Mapper
1313
from autoarray.structures.visibilities import Visibilities
1414

15-
from autoarray.inversion.inversion.interferometer import inversion_interferometer_util
16-
1715

1816
class InversionInterferometerSparse(AbstractInversionInterferometer):
1917
def __init__(
@@ -99,35 +97,27 @@ def curvature_matrix_diag(self) -> np.ndarray:
9997
10098
This function computes the diagonal terms of F using the sparse linear algebra formalism.
10199
"""
102-
103100
mapper = self.cls_list_from(cls=Mapper)[0]
104101

105-
# The interferometer sparse-operator curvature path
106-
# (``InterferometerSparseOperator.curvature_matrix_via_sparse_operator_from``)
107-
# has only been validated against ``Rectangular*`` meshes (single source
108-
# pixel per image pixel, weight=1). When given a ``Delaunay`` mapper
109-
# (three source pixels per image pixel via barycentric interpolation,
110-
# weights summing to 1) the returned curvature matrix disagrees with the
111-
# mapping path by ~34% Frobenius and the regularized matrix loses
112-
# positive-definiteness, raising a numpy ``LinAlgError`` at the Cholesky
113-
# call site in ``Inversion.log_det_curvature_reg_matrix_term``. Guard
114-
# rather than silently mis-computing.
115-
if isinstance(mapper.mesh, Delaunay):
116-
raise NotImplementedError(
117-
"Interferometer.apply_sparse_operator() is not implemented for "
118-
"Delaunay-mesh pixelizations: the sparse curvature math has only "
119-
"been validated against Rectangular meshes (Pmax=1, weight=1) "
120-
"and is structurally wrong for barycentric-interpolated mappers "
121-
"(Pmax=3). For Delaunay interferometer fits, use the plain DFT "
122-
"or NUFFT path (i.e. omit the apply_sparse_operator step). "
123-
"Tracking issue: https://github.com/PyAutoLabs/PyAutoArray/issues/314"
124-
)
102+
# The interferometer W~ operator lives on the unmasked-extent rectangular
103+
# grid (shape_native_masked_pixels), not the full native grid used by
104+
# the imaging path. Build sparse triplets with extent-flat row indices
105+
# so they match the operator's (M = extent_y * extent_x, B) scatter buffer.
106+
rows, cols, vals = mapper_util.sparse_triplets_from(
107+
pix_indexes_for_sub=mapper.pix_indexes_for_sub_slim_index,
108+
pix_weights_for_sub=mapper.pix_weights_for_sub_slim_index,
109+
slim_index_for_sub=mapper.slim_index_for_sub_slim_index,
110+
fft_index_for_masked_pixel=self.mask.extent_index_for_masked_pixel,
111+
sub_fraction_slim=mapper.over_sampler.sub_fraction.array,
112+
return_rows_slim=False,
113+
xp=self._xp,
114+
)
125115

126-
return self.dataset.sparse_operator.curvature_matrix_via_sparse_operator_from(
127-
pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index,
128-
pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index,
129-
pix_pixels=self.linear_obj_list[0].params,
130-
fft_index_for_masked_pixel=self.mask.fft_index_for_masked_pixel,
116+
return self.dataset.sparse_operator.curvature_matrix_diag_from(
117+
rows=rows,
118+
cols=cols,
119+
vals=vals,
120+
S=mapper.params,
131121
)
132122

133123
@property

autoarray/mask/mask_2d.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,37 @@ def fft_index_for_masked_pixel(self) -> np.ndarray:
712712
# Convert (y, x) coordinates to flat row-major indices
713713
return (ys * width + xs).astype(np.int32)
714714

715+
@cached_property
def extent_index_for_masked_pixel(self) -> np.ndarray:
    """
    Return a mapping from masked-pixel (slim) indices to flat indices on
    the *unmasked-extent* rectangular FFT grid.

    The unmasked extent is the bounding box of unmasked pixels
    (``shape_native_masked_pixels``). This index is the interferometer
    counterpart of `fft_index_for_masked_pixel`, which uses the full native
    grid: the interferometer W~ kernel is computed on the (extent_y,
    extent_x) grid because it is translation-invariant and only the offsets
    between pairs of unmasked pixels matter, so the surrounding masked
    region contributes nothing.

    Each unmasked pixel at native coordinates ``(y, x)`` is shifted by the
    minimum unmasked row / column and flattened row-major using the extent
    width: ``(y - y_min) * extent_x + (x - x_min)``.

    Returns
    -------
    np.ndarray
        A 1D int32 array with one entry per unmasked pixel, in the order
        returned by ``np.where(~self)``, suitable as row indices into the
        (extent_y * extent_x, batch) scatter buffer used by
        ``InterferometerSparseOperator.curvature_matrix_diag_from``.
        An empty mask (no unmasked pixels) yields an empty int32 array.
    """
    ys, xs = np.where(~self)
    if ys.size == 0:
        # No unmasked pixels: there is no bounding box to index into.
        return np.zeros((0,), dtype=np.int32)

    # Only the extent width is needed for row-major flattening.
    # NOTE(review): assumes `shape_native_masked_pixels` is the bounding box
    # whose origin is (min(ys), min(xs)) - confirm against its definition.
    _, extent_x = self.shape_native_masked_pixels
    width = int(extent_x)

    y_offsets = ys - int(np.min(ys))
    x_offsets = xs - int(np.min(xs))
    return (y_offsets * width + x_offsets).astype(np.int32)
715746
def trimmed_array_from(self, padded_array, image_shape) -> Array2D:
716747
"""
717748
Map a padded 1D array of values to its original 2D array, trimming all edge values.

0 commit comments

Comments
 (0)