Skip to content

Commit a15e097

Browse files
authored
Merge pull request #342 from Jammy2211/feature/jax_unit_tests
Feature/jax unit tests
2 parents 310c42d + 8eb945d commit a15e097

30 files changed

Lines changed: 247 additions & 373 deletions

autolens/analysis/positions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,12 @@ def log_likelihood_penalty_base_from(
224224
residual_map=residual_map, noise_map=dataset.noise_map
225225
)
226226

227-
chi_squared = aa.util.fit.chi_squared_from(chi_squared_map=chi_squared_map)
227+
chi_squared = aa.util.fit.chi_squared_from(
228+
chi_squared_map=chi_squared_map.array
229+
)
228230

229231
noise_normalization = aa.util.fit.noise_normalization_from(
230-
noise_map=dataset.noise_map
232+
noise_map=dataset.noise_map.array
231233
)
232234

233235
else:

autolens/imaging/fit_imaging.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def tracer_to_inversion(self) -> TracerToInversion:
104104
noise_map=self.noise_map,
105105
grids=self.grids,
106106
psf=self.dataset.psf,
107+
convolver=self.dataset.convolver,
107108
w_tilde=self.w_tilde,
108109
)
109110

autolens/lens/to_inversion.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def lp_linear_func_list_galaxy_dict(
181181
noise_map=self.dataset.noise_map,
182182
grids=grids,
183183
psf=self.psf,
184+
convolver=self.dataset.convolver,
184185
transformer=self.transformer,
185186
w_tilde=self.dataset.w_tilde,
186187
)

autolens/lens/tracer.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from abc import ABC
22
import numpy as np
3-
from functools import wraps
43
from scipy.interpolate import griddata
54
from typing import Dict, List, Optional, Type, Union
65

@@ -549,9 +548,9 @@ def image_2d_via_input_plane_image_from(
549548
)[plane_index]
550549

551550
image = griddata(
552-
points=plane_grid,
553-
values=plane_image,
554-
xi=traced_grid.over_sampled,
551+
points=plane_grid.array,
552+
values=plane_image.array,
553+
xi=traced_grid.over_sampled.array,
555554
fill_value=0.0,
556555
method="linear",
557556
)

autolens/point/fit/positions/source/separations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from autoarray.numpy_wrapper import numpy as npw
1+
import jax.numpy as jnp
22
import numpy as np
33
from typing import Optional
44

@@ -126,8 +126,8 @@ def noise_normalization(self) -> float:
126126
"""
127127
Returns the normalization of the noise-map, which is the sum of the noise-map values squared.
128128
"""
129-
return npw.sum(
130-
npw.log(
129+
return jnp.sum(
130+
jnp.log(
131131
2
132132
* np.pi
133133
* (self.magnifications_at_positions**-2.0 * self.noise_map**2.0)

autolens/point/solver/point_solver.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import logging
22
from typing import Tuple, Optional
33

4-
from autoarray.numpy_wrapper import np
5-
64
import autoarray as aa
7-
from autoarray.numpy_wrapper import use_jax
85
from autoarray.structures.triangles.shape import Point
96

107
from autofit.jax_wrapper import jit, register_pytree_node_class
@@ -56,23 +53,23 @@ def solve(
5653
filtered_means = self._filter_low_magnification(
5754
tracer=tracer, points=kept_triangles.means
5855
)
59-
if use_jax:
60-
return aa.Grid2DIrregular([pair for pair in filtered_means])
61-
62-
filtered_means = [
63-
pair for pair in filtered_means if not np.any(np.isnan(pair)).all()
64-
]
6556

66-
difference = len(kept_triangles.means) - len(filtered_means)
67-
if difference > 0:
68-
logger.debug(
69-
f"Filtered one multiple-image with magnification below threshold."
70-
)
71-
elif difference > 1:
72-
logger.warning(
73-
f"Filtered {difference} multiple-images with magnification below threshold."
74-
)
57+
return aa.Grid2DIrregular([pair for pair in filtered_means])
7558

76-
return aa.Grid2DIrregular(
77-
[pair for pair in filtered_means if not np.isnan(pair).all()]
78-
)
59+
# filtered_means = [
60+
# pair for pair in filtered_means if not np.any(np.isnan(pair)).all()
61+
# ]
62+
#
63+
# difference = len(kept_triangles.means) - len(filtered_means)
64+
# if difference > 0:
65+
# logger.debug(
66+
# f"Filtered one multiple-image with magnification below threshold."
67+
# )
68+
# elif difference > 1:
69+
# logger.warning(
70+
# f"Filtered {difference} multiple-images with magnification below threshold."
71+
# )
72+
#
73+
# return aa.Grid2DIrregular(
74+
# [pair for pair in filtered_means if not np.isnan(pair).all()]
75+
# )

autolens/point/solver/shape_solver.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import jax.numpy as jnp
2+
from jax import jit
13
import logging
24
import math
35

@@ -6,23 +8,11 @@
68
import autoarray as aa
79

810
from autoarray.structures.triangles.shape import Shape
9-
from autofit.jax_wrapper import jit, use_jax, numpy as np, register_pytree_node_class
10-
11-
try:
12-
if use_jax:
13-
from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import (
14-
CoordinateArrayTriangles,
15-
)
16-
else:
17-
from autoarray.structures.triangles.coordinate_array.coordinate_array import (
18-
CoordinateArrayTriangles,
19-
)
20-
21-
except ImportError:
22-
from autoarray.structures.triangles.coordinate_array.coordinate_array import (
23-
CoordinateArrayTriangles,
24-
)
11+
from autofit.jax_wrapper import register_pytree_node_class
2512

13+
from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import (
14+
CoordinateArrayTriangles,
15+
)
2616
from autoarray.structures.triangles.abstract import AbstractTriangles
2717

2818
from autogalaxy import OperateDeflections
@@ -278,13 +268,13 @@ def _filter_low_magnification(
278268
-------
279269
The points with an absolute magnification above the threshold.
280270
"""
281-
points = np.array(points)
271+
points = jnp.array(points)
282272
magnifications = tracer.magnification_2d_via_hessian_from(
283273
grid=aa.Grid2DIrregular(points),
284274
buffer=self.scale,
285275
)
286-
mask = np.abs(magnifications.array) > self.magnification_threshold
287-
return np.where(mask[:, None], points, np.nan)
276+
mask = jnp.abs(magnifications.array) > self.magnification_threshold
277+
return jnp.where(mask[:, None], points, jnp.nan)
288278

289279
def _source_triangles(
290280
self,

docs/installation/conda.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ For interferometer analysis there are two optional dependencies that must be ins
105105
.. code-block:: bash
106106
107107
pip install pynufft
108-
pip install pylops==2.3.1
109108
110109
**PyAutoLens** will run without these libraries and it is recommended that you only install them if you intend to
111110
do interferometer analysis.

docs/installation/overview.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,4 @@ Dependencies
6666

6767
And the following optional dependencies:
6868

69-
**pynufft**: https://github.com/jyhmiinlin/pynufft
70-
71-
**PyLops**: https://github.com/PyLops/pylops
69+
**pynufft**: https://github.com/jyhmiinlin/pynufft

docs/installation/pip.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ For interferometer analysis there are two optional dependencies that must be ins
8686
.. code-block:: bash
8787
8888
pip install pynufft
89-
pip install pylops==2.3.1
9089
9190
**PyAutoLens** will run without these libraries and it is recommended that you only install them if you intend to
9291
do interferometer analysis.

0 commit comments

Comments
 (0)