1313from autoarray .inversion .pixelization .mappers .abstract import AbstractMapper
1414from autoarray .inversion .regularization .abstract import AbstractRegularization
1515from autoarray .inversion .inversion .settings import SettingsInversion
16+ from autoarray .preloads import Preloads
1617from autoarray .structures .arrays .uniform_2d import Array2D
1718from autoarray .structures .visibilities import Visibilities
1819
@@ -27,6 +28,7 @@ def __init__(
2728 dataset : Union [Imaging , Interferometer , DatasetInterface ],
2829 linear_obj_list : List [LinearObj ],
2930 settings : SettingsInversion = SettingsInversion (),
31+ preloads : Preloads = None ,
3032 ):
3133 """
3234 An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions
@@ -83,6 +85,8 @@ def __init__(
8385
8486 self .settings = settings
8587
88+ self .preloads = preloads or Preloads ()
89+
8690 @property
8791 def data (self ):
8892 return self .dataset .data
@@ -267,6 +271,22 @@ def no_regularization_index_list(self) -> List[int]:
267271
268272 return no_regularization_index_list
269273
274+ @property
275+ def mapper_index_list (self ) -> List [int ]:
276+
277+ if self .preloads .mapper_index_list is not None :
278+ return self .preloads .mapper_index_list
279+
280+ mapper_index_list = []
281+
282+ param_range_list = self .param_range_list_from (cls = AbstractMapper )
283+
284+ for param_range in param_range_list :
285+
286+ mapper_index_list += range (param_range [0 ], param_range [1 ])
287+
288+ return mapper_index_list
289+
270290 @property
271291 def mask (self ) -> Array2D :
272292 return self .data .mask
@@ -358,14 +378,10 @@ def regularization_matrix_reduced(self) -> Optional[np.ndarray]:
358378 return self .regularization_matrix
359379
360380 # ids of values which are on edge so zero-d and not solved for.
361- ids_to_not_solve_for = jnp .array (self .no_regularization_index_list , dtype = int )
362-
363- # Create a boolean mask: True = keep, False = ignore
364- mask = jnp .ones (self .data_vector .shape [0 ], dtype = bool ).at [ids_to_not_solve_for ].set (False )
381+ ids_to_keep = jnp .array (self .mapper_index_list , dtype = int )
365382
366383 # Zero rows and columns in the matrix we want to ignore
367- mask_matrix = mask [:, None ] * mask [None , :]
368- return self .regularization_matrix * mask_matrix
384+ return self .regularization_matrix [ids_to_keep ][:, ids_to_keep ]
369385
370386 @cached_property
371387 def curvature_reg_matrix (self ) -> np .ndarray :
@@ -383,25 +399,28 @@ def curvature_reg_matrix(self) -> np.ndarray:
383399 return jnp .add (self .curvature_matrix , self .regularization_matrix )
384400
385401 @cached_property
386- def curvature_reg_matrix_reduced (self ) -> np .ndarray :
402+ def curvature_reg_matrix_reduced (self ) -> Optional [ np .ndarray ] :
387403 """
388- The linear system of equations solves for F + regularization_coefficient*H, which is computed below.
404+ The regularization matrix H is used to impose smoothness on our inversion's reconstruction. This enters the
405+ linear algebra system we solve for using D and F above and is given by
406+ equation (12) in https://arxiv.org/pdf/astro-ph/0302587.pdf.
389407
390- This is the curvature reg matrix for only the mappers, which is necessary for computing the log det
391- term without the linear light profiles included.
408+ A complete description of regularization is given in the `regularization.py` and `regularization_util.py`
409+ modules.
410+
411+ For multiple mappers, the regularization matrix is computed as the block diagonal of each individual mapper.
412+ The scipy function `block_diag` has an overhead associated with it and if there is only one mapper and
413+ regularization it is bypassed.
392414 """
415+
393416 if self .all_linear_obj_have_regularization :
394417 return self .curvature_reg_matrix
395418
396419 # ids of values which are on edge so zero-d and not solved for.
397- ids_to_not_solve_for = jnp .array (self .no_regularization_index_list , dtype = int )
398-
399- # Create a boolean mask: True = keep, False = ignore
400- mask = jnp .ones (self .data_vector .shape [0 ], dtype = bool ).at [ids_to_not_solve_for ].set (False )
420+ ids_to_keep = jnp .array (self .mapper_index_list , dtype = int )
401421
402422 # Zero rows and columns in the matrix we want to ignore
403- mask_matrix = mask [:, None ] * mask [None , :]
404- return self .curvature_reg_matrix * mask_matrix
423+ return self .regularization_matrix [ids_to_keep ][:, ids_to_keep ]
405424
406425 @property
407426 def mapper_zero_pixel_list (self ) -> np .ndarray :
@@ -454,10 +473,14 @@ def reconstruction(self) -> np.ndarray:
454473 ):
455474
456475 # ids of values which are on edge so zero-d and not solved for.
457- ids_to_not_solve_for = jnp .array (self .mapper_edge_pixel_list , dtype = int )
476+ ids_to_remove = jnp .array (self .mapper_edge_pixel_list , dtype = int )
458477
459478 # Create a boolean mask: True = keep, False = ignore
460- mask = jnp .ones (self .data_vector .shape [0 ], dtype = bool ).at [ids_to_not_solve_for ].set (False )
479+ mask = (
480+ jnp .ones (self .data_vector .shape [0 ], dtype = bool )
481+ .at [ids_to_remove ]
482+ .set (False )
483+ )
461484
462485 # Zero out entries we don't want to solve for
463486 data_vector_masked = self .data_vector * mask
@@ -502,13 +525,10 @@ def reconstruction_reduced(self) -> np.ndarray:
502525 return self .reconstruction
503526
504527 # ids of values which are on edge so zero-d and not solved for.
505- ids_to_not_solve_for = jnp .array (self .no_regularization_index_list , dtype = int )
506-
507- # Create a boolean mask: True = keep, False = ignore
508- mask = jnp .ones (self .reconstruction .shape [0 ], dtype = bool ).at [ids_to_not_solve_for ].set (False )
528+ ids_to_keep = jnp .array (self .mapper_index_list , dtype = int )
509529
510- # Zero out entries we don't want to solve for
511- return self .reconstruction * mask
530+ # Zero rows and columns in the matrix we want to ignore
531+ return self .reconstruction [ ids_to_keep ]
512532
513533 @property
514534 def reconstruction_dict (self ) -> Dict [LinearObj , np .ndarray ]:
@@ -651,12 +671,6 @@ def regularization_term(self) -> float:
651671 if not self .has (cls = AbstractRegularization ):
652672 return 0.0
653673
654- print (self .reconstruction_reduced )
655- print (self .regularization_matrix_reduced )
656-
657- print (self .reconstruction_reduced .shape )
658- print (self .regularization_matrix_reduced .shape )
659-
660674 return np .matmul (
661675 self .reconstruction_reduced .T ,
662676 np .matmul (self .regularization_matrix_reduced , self .reconstruction_reduced ),
@@ -674,7 +688,7 @@ def log_det_curvature_reg_matrix_term(self) -> float:
674688
675689 try :
676690 return 2.0 * np .sum (
677- np .log (np .diag (np .linalg .cholesky (self .curvature_reg_matrix_reduced )))
691+ jnp .log (jnp .diag (jnp .linalg .cholesky (self .curvature_reg_matrix_reduced )))
678692 )
679693 except np .linalg .LinAlgError as e :
680694 raise exc .InversionException () from e
0 commit comments