diff --git a/test_autoarray/conftest.py b/test_autoarray/conftest.py index 8f7684dd2..7a8a32de6 100644 --- a/test_autoarray/conftest.py +++ b/test_autoarray/conftest.py @@ -1,4 +1,3 @@ -import importlib.util import os from os import path import pytest @@ -7,13 +6,6 @@ from autoarray import fixtures from autoconf import conf -# Skip JAX-only tests when jax isn't installed. find_spec checks availability -# WITHOUT importing the module, so this conftest stays numpy-only per the -# "library unit tests stay numpy-only" rule. -collect_ignore_glob = [] -if importlib.util.find_spec("jax") is None: - collect_ignore_glob = ["test_jax_*.py", "**/test_jax_*.py"] - class PlotPatch: def __init__(self): diff --git a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py index f48980987..183823814 100644 --- a/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py +++ b/test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py @@ -1,324 +1,183 @@ -import autoarray as aa -import numpy as np -import pytest - - -def test__psf_weighted_noise_imaging_from(): - noise_map = np.array( - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 2.0, 4.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ] - ) - - kernel = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 2.0], [0.0, 1.0, 2.0]]) - - native_index_for_slim_index = np.array([[1, 1], [1, 2], [2, 1], [2, 2]]) - - psf_weighted_noise = aa.util.inversion_imaging_numba.psf_precision_operator_from( - noise_map_native=noise_map, - kernel_native=kernel, - native_index_for_slim_index=native_index_for_slim_index, - ) - - assert psf_weighted_noise == pytest.approx( - np.array( - [ - [2.5, 1.625, 0.5, 0.375], - [1.625, 1.3125, 0.125, 0.0625], - [0.5, 0.125, 0.5, 0.375], - [0.375, 0.0625, 0.375, 0.3125], - ] - ), - 1.0e-4, - ) - - -def test__psf_weighted_data_from(): - - mask = aa.Mask2D( - mask=[ - [True, True, True, True], - [True, False, False, True], - [True, False, False, True], - [True, True, True, True], - ], - pixel_scales=(1.0, 1.0), - ) - - data = aa.Array2D( - values=[ - [0.0, 0.0, 0.0, 0.0], - [0.0, 2.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], - mask=mask, - ) - - noise_map = aa.Array2D( - values=[ - [0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ], - mask=mask, - ) - - kernel = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 2.0, 0.0]]) - - native_index_for_slim_index = np.array([[1, 1], [1, 2], [2, 1], [2, 2]]) - - weight_map = data / (noise_map**2) - weight_map = aa.Array2D(values=weight_map, mask=mask) - - psf_weighted_data = aa.util.inversion_imaging.psf_weighted_data_from( - weight_map_native=weight_map.native.array, - kernel_native=kernel, - native_index_for_slim_index=native_index_for_slim_index, - ) - - assert (psf_weighted_data == np.array([5.0, 5.0, 1.5, 1.5])).all() - - -def test__psf_precision_operator_sparse_from(): - noise_map = np.array( - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 2.0, 0.0], - [0.0, 2.0, 4.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - ] - ) - - kernel = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 2.0], [0.0, 1.0, 2.0]]) - - native_index_for_slim_index = np.array([[1, 1], [1, 2], [2, 1], [2, 2]]) - - ( - psf_weighted_noise_preload, - psf_weighted_noise_indexes, - psf_weighted_noise_lengths, - ) = aa.util.inversion_imaging_numba.psf_precision_operator_sparse_from( - noise_map_native=noise_map, - kernel_native=kernel, - native_index_for_slim_index=native_index_for_slim_index, - ) - - assert psf_weighted_noise_preload == pytest.approx( - np.array( - [1.25, 1.625, 0.5, 0.375, 0.65625, 0.125, 0.0625, 0.25, 0.375, 0.15625] - ), - 1.0e-4, - ) - assert psf_weighted_noise_indexes == pytest.approx( - np.array([0, 1, 2, 3, 1, 2, 3, 2, 3, 3]), 1.0e-4 - ) - - assert psf_weighted_noise_lengths == pytest.approx(np.array([4, 3, 2, 1]), 1.0e-4) - - -def test__data_vector_via_blurred_mapping_matrix_from(): - blurred_mapping_matrix = np.array( - [ - [1.0, 1.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 1.0, 1.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ] - ) - - image = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) - noise_map = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) - - data_vector = aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( - blurred_mapping_matrix=blurred_mapping_matrix, image=image, noise_map=noise_map - ) - - assert (data_vector == np.array([2.0, 3.0, 1.0])).all() - - blurred_mapping_matrix = np.array( - [ - [1.0, 1.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 1.0, 1.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ] - ) - - image = np.array([3.0, 1.0, 1.0, 10.0, 1.0, 1.0]) - noise_map = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) - - data_vector = aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( - blurred_mapping_matrix=blurred_mapping_matrix, image=image, noise_map=noise_map - ) - - assert (data_vector == np.array([4.0, 14.0, 10.0])).all() - - blurred_mapping_matrix = np.array( - [ - [1.0, 1.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 1.0, 1.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ] - ) - - image = np.array([4.0, 1.0, 1.0, 16.0, 1.0, 1.0]) - noise_map = np.array([2.0, 1.0, 1.0, 4.0, 1.0, 1.0]) - - data_vector = aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( - blurred_mapping_matrix=blurred_mapping_matrix, image=image, noise_map=noise_map - ) - - assert (data_vector == np.array([2.0, 3.0, 1.0])).all() - - -def test__data_vector_via_weighted_data_two_methods_agree(): - mask = aa.Mask2D.circular(shape_native=(51, 51), pixel_scales=0.1, radius=2.0) - - image = np.random.uniform(size=mask.shape_native) - image = aa.Array2D(values=image, mask=mask) - - noise_map = np.random.uniform(size=mask.shape_native) - noise_map = aa.Array2D(values=noise_map, mask=mask) - - convolver = aa.Convolver.from_gaussian( - shape_native=(7, 7), pixel_scales=mask.pixel_scales, sigma=1.0, normalize=True - ) - - psf = convolver - - mesh = aa.mesh.RectangularUniform(shape=(20, 20)) - - # TODO : Use pytest.parameterize - - for sub_size in range(1, 3): - - grid = aa.Grid2D.from_mask(mask=mask, over_sample_size=sub_size) - - interpolator = mesh.interpolator_from( - source_plane_data_grid=grid, source_plane_mesh_grid=None - ) - - mapper = aa.Mapper(interpolator=interpolator) - - mapping_matrix = mapper.mapping_matrix - - blurred_mapping_matrix = psf.convolved_mapping_matrix_from( - mapping_matrix=mapping_matrix, mask=mask - ) - - data_vector = ( - aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( - blurred_mapping_matrix=blurred_mapping_matrix, - image=image, - noise_map=noise_map, - ) - ) - - rows, cols, vals = aa.util.mapper.sparse_triplets_from( - pix_indexes_for_sub=mapper.pix_indexes_for_sub_slim_index, - pix_weights_for_sub=mapper.pix_weights_for_sub_slim_index, - slim_index_for_sub=mapper.slim_index_for_sub_slim_index, - fft_index_for_masked_pixel=mask.fft_index_for_masked_pixel, - sub_fraction_slim=mapper.over_sampler.sub_fraction.array, - ) - - weight_map = image.array / (noise_map.array**2) - weight_map = aa.Array2D(values=weight_map, mask=noise_map.mask) - - psf_weighted_data = aa.util.inversion_imaging.psf_weighted_data_from( - weight_map_native=weight_map.native.array, - kernel_native=convolver.kernel.native.array, - native_index_for_slim_index=mask.derive_indexes.native_for_slim.astype( - "int" - ), - ) - - data_vector_via_psf_weighted_noise = ( - aa.util.inversion_imaging.data_vector_via_psf_weighted_data_from( - psf_weighted_data=psf_weighted_data, - rows=rows, - cols=cols, - vals=vals, - S=mesh.pixels, - ) - ) - - assert data_vector_via_psf_weighted_noise == pytest.approx(data_vector, 1.0e-4) - - -def test__curvature_matrix_via_psf_weighted_noise_two_methods_agree(): - - mask = aa.Mask2D.circular(shape_native=(21, 21), pixel_scales=0.1, radius=0.8) - - noise_map = np.random.uniform(size=mask.shape_native) - noise_map = aa.Array2D(values=noise_map, mask=mask) - - kernel = aa.Convolver.from_gaussian( - shape_native=(5, 5), pixel_scales=mask.pixel_scales, sigma=1.0, normalize=True - ) - - psf = kernel - - sparse_operator = aa.ImagingSparseOperator.from_noise_map_and_psf( - data=noise_map, - noise_map=noise_map, - psf=psf.kernel.native, - ) - - mesh = aa.mesh.RectangularAdaptDensity(shape=(8, 8)) - - interpolator = mesh.interpolator_from( - source_plane_data_grid=mask.derive_grid.unmasked, - source_plane_mesh_grid=None, - ) - - mapper = aa.Mapper(interpolator=interpolator) - - mapping_matrix = mapper.mapping_matrix - - rows, cols, vals = aa.util.mapper.sparse_triplets_from( - pix_indexes_for_sub=mapper.pix_indexes_for_sub_slim_index, - pix_weights_for_sub=mapper.pix_weights_for_sub_slim_index, - slim_index_for_sub=mapper.slim_index_for_sub_slim_index, - fft_index_for_masked_pixel=mask.fft_index_for_masked_pixel, - sub_fraction_slim=mapper.over_sampler.sub_fraction.array, - return_rows_slim=False, - ) - - curvature_matrix_via_sparse_operator = sparse_operator.curvature_matrix_diag_from( - rows, - cols, - vals, - S=mesh.shape[0] * mesh.shape[1], - ) - - curvature_matrix_via_sparse_operator = ( - aa.util.inversion_imaging.curvature_matrix_mirrored_from( - curvature_matrix=curvature_matrix_via_sparse_operator, - ) - ) - - blurred_mapping_matrix = psf.convolved_mapping_matrix_from( - mapping_matrix=mapping_matrix, mask=mask - ) - - curvature_matrix = aa.util.inversion.curvature_matrix_via_mapping_matrix_from( - mapping_matrix=blurred_mapping_matrix, - noise_map=noise_map, - ) - - assert curvature_matrix_via_sparse_operator == pytest.approx( - curvature_matrix, abs=1.0e-4 - ) +import autoarray as aa +import numpy as np +import pytest + + +def test__psf_weighted_noise_imaging_from(): + noise_map = np.array( + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 2.0, 0.0], + [0.0, 2.0, 4.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ] + ) + + kernel = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 2.0], [0.0, 1.0, 2.0]]) + + native_index_for_slim_index = np.array([[1, 1], [1, 2], [2, 1], [2, 2]]) + + psf_weighted_noise = aa.util.inversion_imaging_numba.psf_precision_operator_from( + noise_map_native=noise_map, + kernel_native=kernel, + native_index_for_slim_index=native_index_for_slim_index, + ) + + assert psf_weighted_noise == pytest.approx( + np.array( + [ + [2.5, 1.625, 0.5, 0.375], + [1.625, 1.3125, 0.125, 0.0625], + [0.5, 0.125, 0.5, 0.375], + [0.375, 0.0625, 0.375, 0.3125], + ] + ), + 1.0e-4, + ) + + +def test__psf_weighted_data_from(): + + mask = aa.Mask2D( + mask=[ + [True, True, True, True], + [True, False, False, True], + [True, False, False, True], + [True, True, True, True], + ], + pixel_scales=(1.0, 1.0), + ) + + data = aa.Array2D( + values=[ + [0.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 1.0, 0.0], + [0.0, 1.0, 2.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + mask=mask, + ) + + noise_map = aa.Array2D( + values=[ + [0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 0.0], + [0.0, 1.0, 2.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + mask=mask, + ) + + kernel = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 2.0, 0.0]]) + + native_index_for_slim_index = np.array([[1, 1], [1, 2], [2, 1], [2, 2]]) + + weight_map = data / (noise_map**2) + weight_map = aa.Array2D(values=weight_map, mask=mask) + + psf_weighted_data = aa.util.inversion_imaging.psf_weighted_data_from( + weight_map_native=weight_map.native.array, + kernel_native=kernel, + native_index_for_slim_index=native_index_for_slim_index, + ) + + assert (psf_weighted_data == np.array([5.0, 5.0, 1.5, 1.5])).all() + + +def test__psf_precision_operator_sparse_from(): + noise_map = np.array( + [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 2.0, 0.0], + [0.0, 2.0, 4.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ] + ) + + kernel = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 2.0], [0.0, 1.0, 2.0]]) + + native_index_for_slim_index = np.array([[1, 1], [1, 2], [2, 1], [2, 2]]) + + ( + psf_weighted_noise_preload, + psf_weighted_noise_indexes, + psf_weighted_noise_lengths, + ) = aa.util.inversion_imaging_numba.psf_precision_operator_sparse_from( + noise_map_native=noise_map, + kernel_native=kernel, + native_index_for_slim_index=native_index_for_slim_index, + ) + + assert psf_weighted_noise_preload == pytest.approx( + np.array( + [1.25, 1.625, 0.5, 0.375, 0.65625, 0.125, 0.0625, 0.25, 0.375, 0.15625] + ), + 1.0e-4, + ) + assert psf_weighted_noise_indexes == pytest.approx( + np.array([0, 1, 2, 3, 1, 2, 3, 2, 3, 3]), 1.0e-4 + ) + + assert psf_weighted_noise_lengths == pytest.approx(np.array([4, 3, 2, 1]), 1.0e-4) + + +def test__data_vector_via_blurred_mapping_matrix_from(): + blurred_mapping_matrix = np.array( + [ + [1.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + ) + + image = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + noise_map = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + + data_vector = aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( + blurred_mapping_matrix=blurred_mapping_matrix, image=image, noise_map=noise_map + ) + + assert (data_vector == np.array([2.0, 3.0, 1.0])).all() + + blurred_mapping_matrix = np.array( + [ + [1.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + ) + + image = np.array([3.0, 1.0, 1.0, 10.0, 1.0, 1.0]) + noise_map = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + + data_vector = aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( + blurred_mapping_matrix=blurred_mapping_matrix, image=image, noise_map=noise_map + ) + + assert (data_vector == np.array([4.0, 14.0, 10.0])).all() + + blurred_mapping_matrix = np.array( + [ + [1.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + ) + + image = np.array([4.0, 1.0, 1.0, 16.0, 1.0, 1.0]) + noise_map = np.array([2.0, 1.0, 1.0, 4.0, 1.0, 1.0]) + + data_vector = aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( + blurred_mapping_matrix=blurred_mapping_matrix, image=image, noise_map=noise_map + ) + + assert (data_vector == np.array([2.0, 3.0, 1.0])).all() diff --git a/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py b/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py index a23ac182e..9a527b09b 100644 --- a/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py +++ b/test_autoarray/inversion/inversion/interferometer/test_inversion_interferometer_util.py @@ -1,141 +1,68 @@ -import autoarray as aa -import numpy as np -import pytest - - -def test__data_vector_via_transformed_mapping_matrix_from(): - mapping_matrix = np.array( - [ - [1.0, 1.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 1.0, 1.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ] - ) - - data_real = np.array([4.0, 1.0, 1.0, 16.0, 1.0, 1.0]) - noise_map_real = np.array([2.0, 1.0, 1.0, 4.0, 1.0, 1.0]) - - data_vector_real_via_blurred = ( - aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( - blurred_mapping_matrix=mapping_matrix, - image=data_real, - noise_map=noise_map_real, - ) - ) - - data_imag = np.array([4.0, 1.0, 1.0, 16.0, 1.0, 1.0]) - noise_map_imag = np.array([2.0, 1.0, 1.0, 4.0, 1.0, 1.0]) - - data_vector_imag_via_blurred = ( - aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( - blurred_mapping_matrix=mapping_matrix, - image=data_imag, - noise_map=noise_map_imag, - ) - ) - - data_vector_complex_via_blurred = ( - data_vector_real_via_blurred + data_vector_imag_via_blurred - ) - - transformed_mapping_matrix = np.array( - [ - [1.0 + 1.0j, 1.0 + 1.0j, 0.0 + 0.0j], - [1.0 + 1.0j, 0.0 + 0.0j, 0.0 + 0.0j], - [0.0 + 0.0j, 1.0 + 1.0j, 0.0 + 0.0j], - [0.0 + 0.0j, 1.0 + 1.0j, 1.0 + 1.0j], - [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j], - [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j], - ] - ) - - data = np.array( - [4.0 + 4.0j, 1.0 + 1.0j, 1.0 + 1.0j, 16.0 + 16.0j, 1.0 + 1.0j, 1.0 + 1.0j] - ) - noise_map = np.array( - [2.0 + 2.0j, 1.0 + 1.0j, 1.0 + 1.0j, 4.0 + 4.0j, 1.0 + 1.0j, 1.0 + 1.0j] - ) - - data_vector_via_transformed = aa.util.inversion_interferometer.data_vector_via_transformed_mapping_matrix_from( - transformed_mapping_matrix=transformed_mapping_matrix, - visibilities=data, - noise_map=noise_map, - ) - - assert (data_vector_complex_via_blurred == data_vector_via_transformed).all() - - -def test__curvature_matrix_via_psf_precision_operator_from(): - noise_map = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - uv_wavelengths = np.array( - [[0.0001, 2.0, 3000.0, 50000.0, 200000.0], [3000.0, 2.0, 0.0001, 10.0, 5000.0]] - ) - - grid = aa.Grid2D.uniform(shape_native=(3, 3), pixel_scales=0.0005) - - mapping_matrix = np.array( - [ - [1.0, 0.0, 0.0], - [0.0, 0.0, 1.0], - [0.0, 1.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 0.0, 1.0], - [0.0, 0.0, 1.0], - [1.0, 0.0, 0.0], - [0.0, 0.0, 1.0], - [1.0, 0.0, 0.0], - ] - ) - - nufft_precision_operator = ( - aa.util.inversion_interferometer.nufft_precision_operator_from( - noise_map_real=noise_map, - uv_wavelengths=uv_wavelengths, - shape_masked_pixels_2d=(3, 3), - grid_radians_2d=np.array(grid.native), - ) - ) - - native_index_for_slim_index = np.array( - [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [2, 0], [2, 1], [2, 2]] - ) - - psf_weighted_noise = ( - aa.util.inversion_interferometer.nufft_weighted_noise_via_sparse_operator_from( - translation_invariant_kernel=nufft_precision_operator, - native_index_for_slim_index=native_index_for_slim_index, - ) - ) - - curvature_matrix_via_nufft_weighted_noise = ( - aa.util.inversion.curvature_matrix_diag_via_psf_weighted_noise_from( - psf_weighted_noise=psf_weighted_noise, mapping_matrix=mapping_matrix - ) - ) - - pix_indexes_for_sub_slim_index = np.array( - [[0], [2], [1], [1], [2], [2], [0], [2], [0]] - ) - - pix_weights_for_sub_slim_index = np.ones(shape=(9, 1)) - - sparse_operator = aa.InterferometerSparseOperator.from_nufft_precision_operator( - nufft_precision_operator=nufft_precision_operator, - dirty_image=None, - ) - - curvature_matrix_via_preload = ( - sparse_operator.curvature_matrix_via_sparse_operator_from( - pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, - pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index, - fft_index_for_masked_pixel=grid.mask.fft_index_for_masked_pixel, - pix_pixels=3, - ) - ) - - assert curvature_matrix_via_nufft_weighted_noise == pytest.approx( - curvature_matrix_via_preload, 1.0e-4 - ) +import autoarray as aa +import numpy as np +import pytest + + +def test__data_vector_via_transformed_mapping_matrix_from(): + mapping_matrix = np.array( + [ + [1.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + ) + + data_real = np.array([4.0, 1.0, 1.0, 16.0, 1.0, 1.0]) + noise_map_real = np.array([2.0, 1.0, 1.0, 4.0, 1.0, 1.0]) + + data_vector_real_via_blurred = ( + aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( + blurred_mapping_matrix=mapping_matrix, + image=data_real, + noise_map=noise_map_real, + ) + ) + + data_imag = np.array([4.0, 1.0, 1.0, 16.0, 1.0, 1.0]) + noise_map_imag = np.array([2.0, 1.0, 1.0, 4.0, 1.0, 1.0]) + + data_vector_imag_via_blurred = ( + aa.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( + blurred_mapping_matrix=mapping_matrix, + image=data_imag, + noise_map=noise_map_imag, + ) + ) + + data_vector_complex_via_blurred = ( + data_vector_real_via_blurred + data_vector_imag_via_blurred + ) + + transformed_mapping_matrix = np.array( + [ + [1.0 + 1.0j, 1.0 + 1.0j, 0.0 + 0.0j], + [1.0 + 1.0j, 0.0 + 0.0j, 0.0 + 0.0j], + [0.0 + 0.0j, 1.0 + 1.0j, 0.0 + 0.0j], + [0.0 + 0.0j, 1.0 + 1.0j, 1.0 + 1.0j], + [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j], + [0.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j], + ] + ) + + data = np.array( + [4.0 + 4.0j, 1.0 + 1.0j, 1.0 + 1.0j, 16.0 + 16.0j, 1.0 + 1.0j, 1.0 + 1.0j] + ) + noise_map = np.array( + [2.0 + 2.0j, 1.0 + 1.0j, 1.0 + 1.0j, 4.0 + 4.0j, 1.0 + 1.0j, 1.0 + 1.0j] + ) + + data_vector_via_transformed = aa.util.inversion_interferometer.data_vector_via_transformed_mapping_matrix_from( + transformed_mapping_matrix=transformed_mapping_matrix, + visibilities=data, + noise_map=noise_map, + ) + + assert (data_vector_complex_via_blurred == data_vector_via_transformed).all() diff --git a/test_autoarray/inversion/inversion/test_factory.py b/test_autoarray/inversion/inversion/test_factory.py index be751d07e..3674d1c89 100644 --- a/test_autoarray/inversion/inversion/test_factory.py +++ b/test_autoarray/inversion/inversion/test_factory.py @@ -125,108 +125,6 @@ def test__inversion_imaging__via_mapper( ) -def test__inversion_imaging__via_mapper_knn( - masked_imaging_7x7_no_blur, - knn_mapper_9_3x3, - regularization_adaptive_brightness_split, -): - - inversion = aa.Inversion( - dataset=masked_imaging_7x7_no_blur, - linear_obj_list=[knn_mapper_9_3x3], - ) - - assert knn_mapper_9_3x3.pix_indexes_for_sub_slim_index[0, :] == pytest.approx( - [1, 0, 4, 6, 2, 5, 3, 7, 8], 1.0e-4 - ) - assert knn_mapper_9_3x3.pix_indexes_for_sub_slim_index[1, :] == pytest.approx( - [1, 0, 2, 4, 6, 3, 5, 7, 8], 1.0e-4 - ) - assert knn_mapper_9_3x3.pix_indexes_for_sub_slim_index[2, :] == pytest.approx( - [1, 0, 4, 6, 2, 5, 3, 7, 8], 1.0e-4 - ) - - assert knn_mapper_9_3x3.pix_weights_for_sub_slim_index[0, :] == pytest.approx( - [ - 0.24139248, - 0.20182463, - 0.13465525, - 0.12882639, - 0.12169429, - 0.08682546, - 0.07062276, - 0.00982079, - 0.00433794, - ], - 1.0e-4, - ) - assert knn_mapper_9_3x3.pix_weights_for_sub_slim_index[1, :] == pytest.approx( - [ - 0.23255487, - 0.22727716, - 0.14466056, - 0.11643257, - 0.09868897, - 0.08878719, - 0.07744259, - 0.01010399, - 0.0040521, - ], - 1.0e-4, - ) - assert knn_mapper_9_3x3.pix_weights_for_sub_slim_index[2, :] == pytest.approx( - [ - 0.2334672, - 0.1785593, - 0.153417, - 0.15099354, - 0.11075057, - 0.09986048, - 0.06060822, - 0.00869774, - 0.00364596, - ], - 1.0e-4, - ) - - assert isinstance(inversion, aa.InversionImagingMapping) - - assert inversion.regularization_matrix[0:3, 0] == pytest.approx( - [4.00000001, -1.0, -1.0], 1.0e-4 - ) - assert inversion.regularization_matrix[0:3, 1] == pytest.approx( - [-1.0, 3.00000001, 0.0], 1.0e-4 - ) - assert inversion.regularization_matrix[0:3, 2] == pytest.approx( - [-1.0, 0.0, 4.00000001], 1.0e-4 - ) - - assert inversion.log_det_curvature_reg_matrix_term == pytest.approx( - 10.417803331712355, 1.0e-4 - ) - assert inversion.mapped_reconstructed_operated_data == pytest.approx( - np.ones(9), 1.0e-4 - ) - - mapper = copy.copy(knn_mapper_9_3x3) - mapper.regularization = regularization_adaptive_brightness_split - - inversion = aa.Inversion( - dataset=masked_imaging_7x7_no_blur, - linear_obj_list=[mapper], - ) - - assert inversion.regularization_matrix[0:3, 0] == pytest.approx( - [22.47519068, -16.373819, 8.39424766], 1.0e-4 - ) - assert inversion.regularization_matrix[0:3, 1] == pytest.approx( - [-16.373819, 112.1402519, -13.56808248], 1.0e-4 - ) - assert inversion.regularization_matrix[0:3, 2] == pytest.approx( - [8.39424766, -13.56808248, 26.10743213], 1.0e-4 - ) - - def test__inversion_imaging__via_regularizations( masked_imaging_7x7_no_blur, delaunay_mapper_9_3x3, diff --git a/test_autoarray/inversion/inversion/test_inversion_util.py b/test_autoarray/inversion/inversion/test_inversion_util.py index 851701912..87964d82e 100644 --- a/test_autoarray/inversion/inversion/test_inversion_util.py +++ b/test_autoarray/inversion/inversion/test_inversion_util.py @@ -1,301 +1,230 @@ -import autoarray as aa -import numpy as np -import pytest - - -def test__curvature_matrix_diag_via_psf_weighted_noise_from(): - psf_weighted_noise = np.array( - [ - [1.0, 2.0, 3.0, 4.0], - [2.0, 1.0, 2.0, 3.0], - [3.0, 2.0, 1.0, 2.0], - [4.0, 3.0, 2.0, 1.0], - ] - ) - - mapping_matrix = np.array( - [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]] - ) - - curvature_matrix = ( - aa.util.inversion.curvature_matrix_diag_via_psf_weighted_noise_from( - psf_weighted_noise=psf_weighted_noise, mapping_matrix=mapping_matrix - ) - ) - - assert ( - curvature_matrix - == np.array([[6.0, 8.0, 0.0], [8.0, 8.0, 0.0], [0.0, 0.0, 0.0]]) - ).all() - - -def test__curvature_matrix_via_mapping_matrix_from(): - blurred_mapping_matrix = np.array( - [ - [1.0, 1.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 1.0, 1.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ] - ) - - noise_map = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) - - curvature_matrix = aa.util.inversion.curvature_matrix_via_mapping_matrix_from( - mapping_matrix=blurred_mapping_matrix, noise_map=noise_map - ) - - assert ( - curvature_matrix - == np.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 1.0]]) - ).all() - - blurred_mapping_matrix = np.array( - [ - [1.0, 1.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 1.0, 1.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ] - ) - - noise_map = np.array([2.0, 1.0, 1.0, 1.0, 1.0, 1.0]) - - curvature_matrix = aa.util.inversion.curvature_matrix_via_mapping_matrix_from( - mapping_matrix=blurred_mapping_matrix, noise_map=noise_map - ) - - assert ( - curvature_matrix - == np.array([[1.25, 0.25, 0.0], [0.25, 2.25, 1.0], [0.0, 1.0, 1.0]]) - ).all() - - -def test__reconstruction_positive_negative_from(): - data_vector = np.array([1.0, 1.0, 2.0]) - - curvature_reg_matrix = np.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 1.0]]) - - reconstruction = aa.util.inversion.reconstruction_positive_negative_from( - data_vector=data_vector, - curvature_reg_matrix=curvature_reg_matrix, - ) - - assert reconstruction == pytest.approx(np.array([1.0, -1.0, 3.0]), 1.0e-4) - - -def test__mapped_reconstructed_data_via_mapping_matrix_from(): - mapping_matrix = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) - - reconstruction = np.array([1.0, 1.0, 2.0]) - - mapped_reconstructed_operated_data = ( - aa.util.inversion.mapped_reconstructed_data_via_mapping_matrix_from( - mapping_matrix=mapping_matrix, reconstruction=reconstruction - ) - ) - - assert (mapped_reconstructed_operated_data == np.array([1.0, 1.0, 2.0])).all() - - mapping_matrix = np.array([[0.25, 0.50, 0.25], [0.0, 1.0, 0.0], [0.0, 0.25, 0.75]]) - - reconstruction = np.array([1.0, 1.0, 2.0]) - - mapped_reconstructed_operated_data = ( - aa.util.inversion.mapped_reconstructed_data_via_mapping_matrix_from( - mapping_matrix=mapping_matrix, reconstruction=reconstruction - ) - ) - - assert (mapped_reconstructed_operated_data == np.array([1.25, 1.0, 1.75])).all() - - -def test__mapped_reconstructed_data_via_image_to_pix_unique_from(): - pix_indexes_for_sub_slim_index = np.array([[0], [1], [2]]) - pix_indexes_for_sub_slim_index_sizes = np.array([1, 1, 1]).astype("int") - pix_weights_for_sub_slim_index = np.array([[1.0], [1.0], [1.0]]) - - ( - data_to_pix_unique, - data_weights, - pix_lengths, - ) = aa.util.mapper_numba.data_slim_to_pixelization_unique_from( - data_pixels=3, - pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, - pix_sizes_for_sub_slim_index=pix_indexes_for_sub_slim_index_sizes, - pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index, - pix_pixels=3, - sub_size=np.array([1, 1, 1]), - ) - - reconstruction = np.array([1.0, 1.0, 2.0]) - - mapped_reconstructed_operated_data = aa.util.inversion_imaging_numba.mapped_reconstructed_data_via_image_to_pix_unique_from( - data_to_pix_unique=data_to_pix_unique.astype("int"), - data_weights=data_weights, - pix_lengths=pix_lengths.astype("int"), - reconstruction=reconstruction, - ) - - assert (mapped_reconstructed_operated_data == np.array([1.0, 1.0, 2.0])).all() - - pix_indexes_for_sub_slim_index = np.array( - [[0], [1], [1], [2], [1], [1], [1], [1], [1], [2], [2], [2]] - ) - pix_indexes_for_sub_slim_index_sizes = np.ones(shape=(12,)).astype("int") - pix_weights_for_sub_slim_index = np.ones(shape=(12, 1)) - - ( - data_to_pix_unique, - data_weights, - pix_lengths, - ) = aa.util.mapper_numba.data_slim_to_pixelization_unique_from( - data_pixels=3, - pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, - pix_sizes_for_sub_slim_index=pix_indexes_for_sub_slim_index_sizes, - pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index, - pix_pixels=3, - sub_size=np.array([2, 2, 2]), - ) - - reconstruction = np.array([1.0, 1.0, 2.0]) - - mapped_reconstructed_operated_data = aa.util.inversion_imaging_numba.mapped_reconstructed_data_via_image_to_pix_unique_from( - data_to_pix_unique=data_to_pix_unique.astype("int"), - data_weights=data_weights, - pix_lengths=pix_lengths.astype("int"), - reconstruction=reconstruction, - ) - - assert (mapped_reconstructed_operated_data == np.array([1.25, 1.0, 1.75])).all() - - -def test__preconditioner_matrix_via_mapping_matrix_from(): - mapping_matrix = np.array( - [ - [1.0, 0.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 0.0, 1.0], - [0.0, 0.0, 1.0], - ] - ) - - preconditioner_matrix = ( - aa.util.inversion.preconditioner_matrix_via_mapping_matrix_from( - mapping_matrix=mapping_matrix, - preconditioner_noise_normalization=1.0, - regularization_matrix=np.zeros((3, 3)), - ) - ) - - assert ( - preconditioner_matrix - == np.array([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]) - ).all() - - preconditioner_matrix = ( - aa.util.inversion.preconditioner_matrix_via_mapping_matrix_from( - mapping_matrix=mapping_matrix, - preconditioner_noise_normalization=2.0, - regularization_matrix=np.zeros((3, 3)), - ) - ) - - assert ( - preconditioner_matrix - == np.array([[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]) - ).all() - - regularization_matrix = np.array( - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] - ) - - preconditioner_matrix = ( - aa.util.inversion.preconditioner_matrix_via_mapping_matrix_from( - mapping_matrix=mapping_matrix, - preconditioner_noise_normalization=2.0, - regularization_matrix=regularization_matrix, - ) - ) - - assert ( - preconditioner_matrix - == np.array([[5.0, 2.0, 3.0], [4.0, 9.0, 6.0], [7.0, 8.0, 13.0]]) - ).all() - - -def test__reconstruction_positive_only_from__jax_ill_conditioned_grad_is_finite(): - """ - On ill-conditioned curvature matrices the jaxnnls backward pass used to - return NaN gradients, because the relaxed-KKT solver diverged. Jacobi - preconditioning inside `reconstruction_positive_only_from` re-parameterises - the NNLS problem so the solve converges and `jax.value_and_grad` produces - finite gradients. Skip the test if jax / jaxnnls are not available. - """ - jax = pytest.importorskip("jax") - import jax.numpy as jnp - pytest.importorskip("jaxnnls") - - # A small deliberately ill-conditioned symmetric positive-definite Q, - # cond(Q) ~ 1e7, which is enough to break the raw jaxnnls backward pass. - rng = np.random.default_rng(0) - n = 10 - U, _ = np.linalg.qr(rng.standard_normal((n, n))) - eigs = np.logspace(-4, 3, n) - Q_np = (U * eigs) @ U.T - Q_np = 0.5 * (Q_np + Q_np.T) - q_np = rng.standard_normal(n) - - Q = jnp.array(Q_np) - q = jnp.array(q_np) - - def loss(q_in): - x = aa.util.inversion.reconstruction_positive_only_from( - data_vector=q_in, curvature_reg_matrix=Q, xp=jnp, +import autoarray as aa +import numpy as np +import pytest + + +def test__curvature_matrix_diag_via_psf_weighted_noise_from(): + psf_weighted_noise = np.array( + [ + [1.0, 2.0, 3.0, 4.0], + [2.0, 1.0, 2.0, 3.0], + [3.0, 2.0, 1.0, 2.0], + [4.0, 3.0, 2.0, 1.0], + ] + ) + + mapping_matrix = np.array( + [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]] + ) + + curvature_matrix = ( + aa.util.inversion.curvature_matrix_diag_via_psf_weighted_noise_from( + psf_weighted_noise=psf_weighted_noise, mapping_matrix=mapping_matrix ) - return jnp.sum(x) + ) + + assert ( + curvature_matrix + == np.array([[6.0, 8.0, 0.0], [8.0, 8.0, 0.0], [0.0, 0.0, 0.0]]) + ).all() + + +def test__curvature_matrix_via_mapping_matrix_from(): + blurred_mapping_matrix = np.array( + [ + [1.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + ) + + noise_map = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + + curvature_matrix = aa.util.inversion.curvature_matrix_via_mapping_matrix_from( + mapping_matrix=blurred_mapping_matrix, noise_map=noise_map + ) + + assert ( + curvature_matrix + == np.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 1.0]]) + ).all() + + blurred_mapping_matrix = np.array( + [ + [1.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + ) - value, grad = jax.value_and_grad(loss)(q) + noise_map = np.array([2.0, 1.0, 1.0, 1.0, 1.0, 1.0]) - assert np.isfinite(float(value)) - grad_np = np.array(grad) - assert np.all(np.isfinite(grad_np)), ( - f"gradient has {np.sum(~np.isfinite(grad_np))} non-finite entries" + curvature_matrix = aa.util.inversion.curvature_matrix_via_mapping_matrix_from( + mapping_matrix=blurred_mapping_matrix, noise_map=noise_map ) + assert ( + curvature_matrix + == np.array([[1.25, 0.25, 0.0], [0.25, 2.25, 1.0], [0.0, 1.0, 1.0]]) + ).all() + + +def test__reconstruction_positive_negative_from(): + data_vector = np.array([1.0, 1.0, 2.0]) + + curvature_reg_matrix = np.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 1.0]]) + + reconstruction = aa.util.inversion.reconstruction_positive_negative_from( + data_vector=data_vector, + curvature_reg_matrix=curvature_reg_matrix, + ) + + assert reconstruction == pytest.approx(np.array([1.0, -1.0, 3.0]), 1.0e-4) + + +def test__mapped_reconstructed_data_via_mapping_matrix_from(): + mapping_matrix = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + + reconstruction = np.array([1.0, 1.0, 2.0]) + + mapped_reconstructed_operated_data = ( + aa.util.inversion.mapped_reconstructed_data_via_mapping_matrix_from( + mapping_matrix=mapping_matrix, reconstruction=reconstruction + ) + ) + + assert (mapped_reconstructed_operated_data == np.array([1.0, 1.0, 2.0])).all() + + mapping_matrix = np.array([[0.25, 0.50, 0.25], [0.0, 1.0, 0.0], [0.0, 0.25, 0.75]]) + + reconstruction = np.array([1.0, 1.0, 2.0]) + + mapped_reconstructed_operated_data = ( + aa.util.inversion.mapped_reconstructed_data_via_mapping_matrix_from( + mapping_matrix=mapping_matrix, reconstruction=reconstruction + ) + ) + + assert (mapped_reconstructed_operated_data == np.array([1.25, 1.0, 1.75])).all() + + +def test__mapped_reconstructed_data_via_image_to_pix_unique_from(): + pix_indexes_for_sub_slim_index = np.array([[0], [1], [2]]) + pix_indexes_for_sub_slim_index_sizes = np.array([1, 1, 1]).astype("int") + pix_weights_for_sub_slim_index = np.array([[1.0], [1.0], [1.0]]) + + ( + data_to_pix_unique, + data_weights, + pix_lengths, + ) = aa.util.mapper_numba.data_slim_to_pixelization_unique_from( + data_pixels=3, + pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, + pix_sizes_for_sub_slim_index=pix_indexes_for_sub_slim_index_sizes, + pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index, + pix_pixels=3, + sub_size=np.array([1, 1, 1]), + ) + + reconstruction = np.array([1.0, 1.0, 2.0]) + + mapped_reconstructed_operated_data = aa.util.inversion_imaging_numba.mapped_reconstructed_data_via_image_to_pix_unique_from( + data_to_pix_unique=data_to_pix_unique.astype("int"), + data_weights=data_weights, + pix_lengths=pix_lengths.astype("int"), + reconstruction=reconstruction, + ) + + assert (mapped_reconstructed_operated_data == np.array([1.0, 1.0, 2.0])).all() + + pix_indexes_for_sub_slim_index = np.array( + [[0], [1], [1], [2], [1], [1], [1], [1], [1], [2], [2], [2]] + ) + pix_indexes_for_sub_slim_index_sizes = np.ones(shape=(12,)).astype("int") + pix_weights_for_sub_slim_index = np.ones(shape=(12, 1)) + + ( + data_to_pix_unique, + data_weights, + pix_lengths, + ) = aa.util.mapper_numba.data_slim_to_pixelization_unique_from( + data_pixels=3, + pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, + pix_sizes_for_sub_slim_index=pix_indexes_for_sub_slim_index_sizes, + pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index, + pix_pixels=3, + sub_size=np.array([2, 2, 2]), + ) + + reconstruction = np.array([1.0, 1.0, 2.0]) + + mapped_reconstructed_operated_data = aa.util.inversion_imaging_numba.mapped_reconstructed_data_via_image_to_pix_unique_from( + data_to_pix_unique=data_to_pix_unique.astype("int"), + data_weights=data_weights, + pix_lengths=pix_lengths.astype("int"), + reconstruction=reconstruction, + ) + + assert (mapped_reconstructed_operated_data == np.array([1.25, 1.0, 1.75])).all() + + +def test__preconditioner_matrix_via_mapping_matrix_from(): + mapping_matrix = np.array( + [ + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 1.0], + ] + ) + + preconditioner_matrix = ( + aa.util.inversion.preconditioner_matrix_via_mapping_matrix_from( + mapping_matrix=mapping_matrix, + preconditioner_noise_normalization=1.0, + regularization_matrix=np.zeros((3, 3)), + ) + ) + + assert ( + preconditioner_matrix + == np.array([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]) + ).all() + + preconditioner_matrix = ( + aa.util.inversion.preconditioner_matrix_via_mapping_matrix_from( + mapping_matrix=mapping_matrix, + preconditioner_noise_normalization=2.0, + regularization_matrix=np.zeros((3, 3)), + ) + ) + + assert ( + preconditioner_matrix + == np.array([[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]) + ).all() + + regularization_matrix = np.array( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] + ) -def test__reconstruction_positive_only_from__jax_matches_unpreconditioned_primal(): - """ - Jacobi preconditioning is a change of coordinates; the forward primal - solution must match the raw jaxnnls solve to within solver tolerance for - a moderately-conditioned problem where the raw solver also converges. - """ - jax = pytest.importorskip("jax") - import jax.numpy as jnp - jaxnnls = pytest.importorskip("jaxnnls") - - rng = np.random.default_rng(1) - n = 8 - U, _ = np.linalg.qr(rng.standard_normal((n, n))) - eigs = np.linspace(0.5, 5.0, n) # well-conditioned - Q_np = (U * eigs) @ U.T - Q_np = 0.5 * (Q_np + Q_np.T) - q_np = rng.standard_normal(n) - - Q = jnp.array(Q_np) - q = jnp.array(q_np) - - x_raw = np.array(jaxnnls.solve_nnls_primal(Q, q)) - x_pc = np.array( - aa.util.inversion.reconstruction_positive_only_from( - data_vector=q, curvature_reg_matrix=Q, xp=jnp, + preconditioner_matrix = ( + aa.util.inversion.preconditioner_matrix_via_mapping_matrix_from( + mapping_matrix=mapping_matrix, + preconditioner_noise_normalization=2.0, + regularization_matrix=regularization_matrix, ) ) - np.testing.assert_allclose(x_pc, x_raw, rtol=1e-6, atol=1e-8) + assert ( + preconditioner_matrix + == np.array([[5.0, 2.0, 3.0], [4.0, 9.0, 6.0], [7.0, 8.0, 13.0]]) + ).all() diff --git a/test_autoarray/structures/grids/test_irregular_2d.py b/test_autoarray/structures/grids/test_irregular_2d.py index ca42e02bc..810ab1842 100644 --- a/test_autoarray/structures/grids/test_irregular_2d.py +++ b/test_autoarray/structures/grids/test_irregular_2d.py @@ -2,7 +2,6 @@ from os import path import shutil import numpy as np -import pytest import autoarray as aa @@ -76,23 +75,15 @@ def test__grid_2d_via_deflection_grid_from(): def test__grid_2d_via_deflection_grid_from__propagates_xp(): - # numpy-backed receiver -> numpy-backed result + # numpy-backed receiver -> numpy-backed result. + # The JAX half of this assertion lives in + # autolens_workspace_test/scripts/jax_assertions/grid_irregular.py. grid_np = aa.Grid2DIrregular(values=[(1.0, 1.0), (2.0, 2.0)]) result_np = grid_np.grid_2d_via_deflection_grid_from( deflection_grid=np.array([[1.0, 0.0], [1.0, 1.0]]) ) assert result_np._xp is np - # jax-backed receiver -> jax-backed result (so downstream .square calls use jnp) - jnp = pytest.importorskip("jax.numpy") - grid_jax = aa.Grid2DIrregular( - values=jnp.array([[1.0, 1.0], [2.0, 2.0]]), xp=jnp - ) - result_jax = grid_jax.grid_2d_via_deflection_grid_from( - deflection_grid=jnp.array([[1.0, 0.0], [1.0, 1.0]]) - ) - assert result_jax._xp is jnp - def test__furthest_distances_to_other_coordinates(): grid = aa.Grid2DIrregular(values=[(0.0, 0.0), (0.0, 1.0)]) diff --git a/test_autoarray/test_jax_changes.py b/test_autoarray/test_changes.py similarity index 100% rename from test_autoarray/test_jax_changes.py rename to test_autoarray/test_changes.py diff --git a/test_autoarray/test_jax_pytree.py b/test_autoarray/test_jax_pytree.py deleted file mode 100644 index da7b8dd66..000000000 --- a/test_autoarray/test_jax_pytree.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Tests for gated JAX pytree registration of ``AbstractNDArray`` subclasses. - -Follows the three-step pattern from ``autolens_workspace_test/scripts/hessian_jax.py``: -1. NumPy path — confirm autoarray type with ``np.ndarray`` backing, no pytree registration. -2. JAX path outside JIT — same autoarray type with ``jax.Array`` backing; pytree registered. -3. JAX path through ``jax.jit`` — round-trip the instance and assert the output carries - a ``jax.Array`` leaf. -""" - -import numpy as np -import numpy.testing as npt -import pytest - -jax = pytest.importorskip("jax") -jnp = pytest.importorskip("jax.numpy") - -from autoarray.abstract_ndarray import AbstractNDArray, _pytree_registered_classes - - -class _LeafArray(AbstractNDArray): - """Minimal concrete ``AbstractNDArray`` with no nested autoarray children. - - Isolates the pytree-registration machinery from the larger autoarray - hierarchy: a real ``Array2D`` also carries a ``Mask2D`` and other nested - ``AbstractNDArray`` children whose own registration is covered by - follow-up steps in the ``fit-imaging-pytree`` task. - """ - - @property - def native(self): - return self - - -def test_numpy_path_does_not_register_pytree(): - _pytree_registered_classes.discard(_LeafArray) - - arr = _LeafArray(np.array([1.0, 2.0, 3.0])) - - assert isinstance(arr._array, np.ndarray) - assert _LeafArray not in _pytree_registered_classes - - -def test_jax_path_registers_pytree_once(): - _pytree_registered_classes.discard(_LeafArray) - - arr_jax = _LeafArray(jnp.array([1.0, 2.0, 3.0]), xp=jnp) - - assert isinstance(arr_jax._array, jnp.ndarray) - assert _LeafArray in _pytree_registered_classes - - # Second construction on the JAX path is a no-op; class stays registered. - _LeafArray(jnp.array([4.0, 5.0]), xp=jnp) - assert _LeafArray in _pytree_registered_classes - - -def test_jax_jit_round_trip_returns_wrapper_with_jax_array(): - arr_jax = _LeafArray(jnp.array([1.0, 2.0, 3.0]), xp=jnp) - assert _LeafArray in _pytree_registered_classes - - result = jax.jit(lambda a: a)(arr_jax) - - assert isinstance(result, _LeafArray) - assert isinstance(result._array, jnp.ndarray) - npt.assert_allclose(np.asarray(result._array), np.asarray(arr_jax._array))