Skip to content

Commit 5de1760

Browse files
authored
Merge pull request #167 from Jammy2211/feature/unit_tests
Feature/unit tests
2 parents 2485318 + 72af86b commit 5de1760

99 files changed

Lines changed: 1494 additions & 4745 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

autoarray/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from . import fixtures
55
from . import mock as m
66
from .numba_util import profile_func
7-
from .preloads import Preloads
87
from .dataset import preprocess
98
from .dataset.abstract.dataset import AbstractDataset
109
from .dataset.abstract.w_tilde import AbstractWTilde
@@ -55,8 +54,6 @@
5554
from .mask.derive.grid_2d import DeriveGrid2D
5655
from .mask.mask_1d import Mask1D
5756
from .mask.mask_2d import Mask2D
58-
from .operators.convolver import Convolver
59-
from .operators.convolver import Convolver
6057
from .operators.transformer import TransformerDFT
6158
from .operators.transformer import TransformerNUFFT
6259
from .operators.over_sampling.decorator import over_sample

autoarray/abstract_ndarray.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55
from abc import ABC
66
from abc import abstractmethod
7-
import numpy as np
7+
import jax.numpy as jnp
88

99
from autoconf.fitsable import output_to_fits
1010

11-
from autoarray.numpy_wrapper import numpy as npw, register_pytree_node, Array
11+
from autoarray.numpy_wrapper import register_pytree_node, Array
1212

1313
from typing import TYPE_CHECKING
1414

@@ -83,7 +83,7 @@ def __init__(self, array):
8383

8484
def invert(self):
8585
new = self.copy()
86-
new._array = np.invert(new._array)
86+
new._array = jnp.invert(new._array)
8787
return new
8888

8989
@classmethod
@@ -105,7 +105,7 @@ def instance_flatten(cls, instance):
105105
@staticmethod
106106
def flip_hdu_for_ds9(values):
107107
if conf.instance["general"]["fits"]["flip_for_ds9"]:
108-
return np.flipud(values)
108+
return jnp.flipud(values)
109109
return values
110110

111111
@classmethod
@@ -114,11 +114,11 @@ def instance_unflatten(cls, aux_data, children):
114114
Unflatten a tuple of attributes (i.e. a pytree) into an instance of an autoarray class
115115
"""
116116
instance = cls.__new__(cls)
117-
for key, value in zip(aux_data, children[1:]):
117+
for key, value in zip(aux_data, children):
118118
setattr(instance, key, value)
119119
return instance
120120

121-
def with_new_array(self, array: np.ndarray) -> "AbstractNDArray":
121+
def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray":
122122
"""
123123
Copy this object but give it a new array.
124124
@@ -165,7 +165,7 @@ def __iter__(self):
165165

166166
@to_new_array
167167
def sqrt(self):
168-
return np.sqrt(self._array)
168+
return jnp.sqrt(self._array)
169169

170170
@property
171171
def array(self):
@@ -331,13 +331,13 @@ def __getitem__(self, item):
331331
result = self._array[item]
332332
if isinstance(item, slice):
333333
result = self.with_new_array(result)
334-
if isinstance(result, np.ndarray):
334+
if isinstance(result, jnp.ndarray):
335335
result = self.with_new_array(result)
336336
return result
337337

338338
def __setitem__(self, key, value):
339-
if isinstance(key, (np.ndarray, AbstractNDArray, Array)):
340-
self._array = npw.where(key, value, self._array)
339+
if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)):
340+
self._array = jnp.where(key, value, self._array)
341341
else:
342342
self._array[key] = value
343343

autoarray/config/general.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,3 @@ pixelization:
1616
voronoi_nn_max_interpolation_neighbors: 300
1717
structures:
1818
native_binned_only: false # If True, data structures are only stored in their native and binned format. This is used to reduce memory usage in autocti.
19-
test:
20-
preloads_check_threshold: 1.0 # If the figure of merit of a fit with and without preloads is greater than this threshold, the check preload test fails and an exception raised for a model-fit.
21-

autoarray/config/grids.yaml

Lines changed: 0 additions & 3 deletions
This file was deleted.

autoarray/dataset/imaging/dataset.py

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from autoarray.dataset.grids import GridsDataset
1010
from autoarray.dataset.imaging.w_tilde import WTildeImaging
1111
from autoarray.structures.arrays.uniform_2d import Array2D
12-
from autoarray.operators.convolver import Convolver
1312
from autoarray.structures.arrays.kernel_2d import Kernel2D
1413
from autoarray.mask.mask_2d import Mask2D
1514
from autoarray import type as ty
@@ -30,7 +29,7 @@ def __init__(
3029
noise_covariance_matrix: Optional[np.ndarray] = None,
3130
over_sample_size_lp: Union[int, Array2D] = 4,
3231
over_sample_size_pixelization: Union[int, Array2D] = 4,
33-
pad_for_convolver: bool = False,
32+
pad_for_psf: bool = False,
3433
use_normalized_psf: Optional[bool] = True,
3534
check_noise_map: bool = True,
3635
):
@@ -77,7 +76,7 @@ def __init__(
7776
over_sample_size_pixelization
7877
How over sampling is performed for the grid which is associated with a pixelization, which is therefore
7978
passed into the calculations performed in the `inversion` module.
80-
pad_for_convolver
79+
pad_for_psf
8180
The PSF convolution may extend beyond the edges of the image mask, which can lead to edge effects in the
8281
convolved image. If `True`, the image and noise-map are padded to ensure the PSF convolution does not
8382
extend beyond the edge of the image.
@@ -90,9 +89,9 @@ def __init__(
9089

9190
self.unmasked = None
9291

93-
self.pad_for_convolver = pad_for_convolver
92+
self.pad_for_psf = pad_for_psf
9493

95-
if pad_for_convolver and psf is not None:
94+
if pad_for_psf and psf is not None:
9695
try:
9796
data.mask.derive_mask.blurring_from(
9897
kernel_shape_native=psf.shape_native
@@ -162,11 +161,15 @@ def __init__(
162161

163162
if psf is not None and use_normalized_psf:
164163
psf = Kernel2D.no_mask(
165-
values=psf.native, pixel_scales=psf.pixel_scales, normalize=True
164+
values=psf.native._array, pixel_scales=psf.pixel_scales, normalize=True
166165
)
167166

168167
self.psf = psf
169168

169+
if psf is not None:
170+
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
171+
raise exc.KernelException("Kernel2D Kernel2D must be odd")
172+
170173
@cached_property
171174
def grids(self):
172175
return GridsDataset(
@@ -176,25 +179,6 @@ def grids(self):
176179
psf=self.psf,
177180
)
178181

179-
@cached_property
180-
def convolver(self):
181-
"""
182-
Returns a `Convolver` from a mask and 2D PSF kernel.
183-
184-
The `Convolver` stores in memory the array indexing between the mask and PSF, enabling efficient 2D PSF
185-
convolution of images and matrices used for linear algebra calculations (see `operators.convolver`).
186-
187-
This uses lazy allocation such that the calculation is only performed when the convolver is used, ensuring
188-
efficient set up of the `Imaging` class.
189-
190-
Returns
191-
-------
192-
Convolver
193-
The convolver given the masked imaging data's mask and PSF.
194-
"""
195-
196-
return Convolver(mask=self.mask, kernel=self.psf)
197-
198182
@cached_property
199183
def w_tilde(self):
200184
"""
@@ -220,9 +204,9 @@ def w_tilde(self):
220204
indexes,
221205
lengths,
222206
) = inversion_imaging_util.w_tilde_curvature_preload_imaging_from(
223-
noise_map_native=np.array(self.noise_map.native),
224-
kernel_native=np.array(self.psf.native),
225-
native_index_for_slim_index=self.mask.derive_indexes.native_for_slim,
207+
noise_map_native=np.array(self.noise_map.native.array).astype("float64"),
208+
kernel_native=np.array(self.psf.native.array).astype("float64"),
209+
native_index_for_slim_index=np.array(self.mask.derive_indexes.native_for_slim).astype("int"),
226210
)
227211

228212
return WTildeImaging(
@@ -370,7 +354,7 @@ def apply_mask(self, mask: Mask2D) -> "Imaging":
370354
noise_covariance_matrix=noise_covariance_matrix,
371355
over_sample_size_lp=over_sample_size_lp,
372356
over_sample_size_pixelization=over_sample_size_pixelization,
373-
pad_for_convolver=True,
357+
pad_for_psf=True,
374358
)
375359

376360
dataset.unmasked = unmasked_dataset
@@ -425,20 +409,20 @@ def apply_noise_scaling(
425409
"""
426410

427411
if signal_to_noise_value is None:
428-
noise_map = self.noise_map.native
429-
noise_map[mask == False] = noise_value
412+
noise_map = np.array(self.noise_map.native.array)
413+
noise_map[mask.array == False] = noise_value
430414
else:
431415
noise_map = np.where(
432416
mask == False,
433-
np.median(self.data.native[mask.derive_mask.edge == False])
417+
np.median(self.data.native.array[mask.derive_mask.edge == False])
434418
/ signal_to_noise_value,
435-
self.noise_map.native,
419+
self.noise_map.native.array,
436420
)
437421

438422
if should_zero_data:
439-
data = np.where(np.invert(mask), 0.0, self.data.native)
423+
data = np.where(np.invert(mask.array), 0.0, self.data.native.array)
440424
else:
441-
data = self.data.native
425+
data = self.data.native.array
442426

443427
data_unmasked = Array2D.no_mask(
444428
values=data,
@@ -463,7 +447,7 @@ def apply_noise_scaling(
463447
noise_covariance_matrix=self.noise_covariance_matrix,
464448
over_sample_size_lp=self.over_sample_size_lp,
465449
over_sample_size_pixelization=self.over_sample_size_pixelization,
466-
pad_for_convolver=False,
450+
pad_for_psf=False,
467451
check_noise_map=False,
468452
)
469453

@@ -511,7 +495,7 @@ def apply_over_sampling(
511495
over_sample_size_lp=over_sample_size_lp or self.over_sample_size_lp,
512496
over_sample_size_pixelization=over_sample_size_pixelization
513497
or self.over_sample_size_pixelization,
514-
pad_for_convolver=False,
498+
pad_for_psf=False,
515499
check_noise_map=False,
516500
)
517501

autoarray/dataset/interferometer/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,5 +276,5 @@ def output_to_fits(
276276
)
277277

278278
@property
279-
def convolver(self):
279+
def psf(self):
280280
return None

autoarray/dataset/preprocess.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -263,15 +263,15 @@ def edges_from(image, no_edges):
263263
edges = []
264264

265265
for edge_no in range(no_edges):
266-
top_edge = image.native[edge_no, edge_no : image.shape_native[1] - edge_no]
267-
bottom_edge = image.native[
266+
top_edge = image.native.array[edge_no, edge_no : image.shape_native[1] - edge_no]
267+
bottom_edge = image.native.array[
268268
image.shape_native[0] - 1 - edge_no,
269269
edge_no : image.shape_native[1] - edge_no,
270270
]
271-
left_edge = image.native[
271+
left_edge = image.native.array[
272272
edge_no + 1 : image.shape_native[0] - 1 - edge_no, edge_no
273273
]
274-
right_edge = image.native[
274+
right_edge = image.native.array[
275275
edge_no + 1 : image.shape_native[0] - 1 - edge_no,
276276
image.shape_native[1] - 1 - edge_no,
277277
]
@@ -328,7 +328,7 @@ def background_noise_map_via_edges_from(image, no_edges):
328328
def psf_with_odd_dimensions_from(psf):
329329
"""
330330
If the PSF kernel has one or two even-sized dimensions, return a PSF object where the kernel has odd-sized
331-
dimensions (odd-sized dimensions are required by a *Convolver*).
331+
dimensions (odd-sized dimensions are required for 2D convolution).
332332
333333
Kernels are rescaled using the scikit-image routine rescale, which performs rescaling via an interpolation
334334
routine. This may lead to loss of accuracy in the PSF kernel and it is advised that users, where possible,
@@ -517,8 +517,8 @@ def noise_map_with_signal_to_noise_limit_from(
517517
noise_map_limit = np.where(
518518
(signal_to_noise_map.native > signal_to_noise_limit)
519519
& (noise_limit_mask == False),
520-
np.abs(data.native) / signal_to_noise_limit,
521-
noise_map.native,
520+
np.abs(data.native.array) / signal_to_noise_limit,
521+
noise_map.native.array,
522522
)
523523

524524
mask = Mask2D.all_false(

autoarray/exc.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -106,17 +106,6 @@ class PlottingException(Exception):
106106
pass
107107

108108

109-
class PreloadsException(Exception):
110-
"""
111-
Raises exceptions associated with the `preloads.py` module and `Preloads` class.
112-
113-
For example if the preloaded quantities lead to a change in figure of merit of a fit compared to a fit without
114-
preloading.
115-
"""
116-
117-
pass
118-
119-
120109
class ProfilingException(Exception):
121110
"""
122111
Raises exceptions associated with in-built profiling tools (e.g. the `profile_func` decorator).

0 commit comments

Comments
 (0)