Skip to content

Commit 5373b60

Browse files
authored
Merge pull request #360 from Jammy2211/feature/jax_interferometer_linear
Feature/jax interferometer linear
2 parents f241b35 + acf48dd commit 5373b60

18 files changed

Lines changed: 124 additions & 272 deletions

File tree

autolens/analysis/analysis/lens.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,10 @@ def log_likelihood_penalty_from(
127127

128128
try:
129129
for positions_likelihood in self.positions_likelihood_list:
130-
log_likelihood_penalty = positions_likelihood.log_likelihood_penalty_from(
131-
instance=instance, analysis=self
130+
log_likelihood_penalty = (
131+
positions_likelihood.log_likelihood_penalty_from(
132+
instance=instance, analysis=self
133+
)
132134
)
133135

134136
log_likelihood_penalty += log_likelihood_penalty
@@ -137,4 +139,4 @@ def log_likelihood_penalty_from(
137139
except (ValueError, np.linalg.LinAlgError) as e:
138140
raise exc.FitException from e
139141

140-
return log_likelihood_penalty
142+
return log_likelihood_penalty

autolens/analysis/positions.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,9 @@ def output_positions_info(
136136
)
137137
f.write("")
138138

139-
def log_likelihood_penalty_from(self, instance: af.ModelInstance, analysis: AnalysisDataset) -> jnp.array:
139+
def log_likelihood_penalty_from(
140+
self, instance: af.ModelInstance, analysis: AnalysisDataset
141+
) -> jnp.array:
140142
"""
141143
Returns a log-likelihood penalty used to constrain lens models where multiple image-plane
142144
positions do not trace to within a threshold distance of one another in the source-plane.
@@ -174,7 +176,7 @@ def log_likelihood_penalty_from(self, instance: af.ModelInstance, analysis: Anal
174176
tracer = analysis.tracer_via_instance_from(instance=instance)
175177

176178
if not tracer.has(cls=ag.mp.MassProfile) or len(tracer.planes) == 1:
177-
return jnp.array(0.0),
179+
return (jnp.array(0.0),)
178180

179181
positions_fit = SourceMaxSeparation(
180182
data=self.positions,
@@ -183,11 +185,11 @@ def log_likelihood_penalty_from(self, instance: af.ModelInstance, analysis: Anal
183185
plane_redshift=self.plane_redshift,
184186
)
185187

186-
max_separation = jnp.max(positions_fit.furthest_separations_of_plane_positions.array)
188+
max_separation = jnp.max(
189+
positions_fit.furthest_separations_of_plane_positions.array
190+
)
187191

188-
penalty = self.log_likelihood_penalty_factor * (
189-
max_separation - self.threshold
190-
)
192+
penalty = self.log_likelihood_penalty_factor * (max_separation - self.threshold)
191193

192194
return jax.lax.cond(
193195
max_separation > self.threshold,

autolens/analysis/result.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,9 +310,7 @@ def positions_likelihood_from(
310310

311311
mask = np.isfinite(positions.array).all(axis=1)
312312

313-
positions = aa.Grid2DIrregular(
314-
positions[mask]
315-
)
313+
positions = aa.Grid2DIrregular(positions[mask])
316314

317315
threshold = self.positions_threshold_from(
318316
factor=factor,

autolens/interferometer/model/analysis.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,15 @@ def log_likelihood_function(self, instance):
152152
"""
153153

154154
try:
155-
log_likelihood_penalty = self.log_likelihood_penalty_from(
156-
instance=instance
157-
)
155+
log_likelihood_penalty = self.log_likelihood_penalty_from(instance=instance)
158156
except Exception as e:
159157
raise e
160158

161159
try:
162-
return self.fit_from(instance=instance).figure_of_merit + log_likelihood_penalty
160+
return (
161+
self.fit_from(instance=instance).figure_of_merit
162+
+ log_likelihood_penalty
163+
)
163164
except (
164165
PixelizationException,
165166
exc.PixelizationException,
@@ -171,6 +172,8 @@ def log_likelihood_function(self, instance):
171172
np.linalg.LinAlgError,
172173
OverflowError,
173174
) as e:
175+
print(e)
176+
fggdfg
174177
raise exc.FitException from e
175178

176179
def fit_from(

autolens/mock.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from autofit.jax_wrapper import register_pytree_node_class, numpy as np
1+
import jax.numpy as jnp
2+
3+
from autofit.jax_wrapper import register_pytree_node_class
24
from autofit.mock import * # noqa
35
from autoarray.mock import * # noqa
46
from autogalaxy.mock import * # noqa
@@ -17,15 +19,15 @@ def __init__(self):
1719
super().__init__([])
1820

1921
def deflections_yx_2d_from(self, grid):
20-
return np.zeros_like(grid.array)
22+
return jnp.zeros_like(grid.array)
2123

2224
def deflections_between_planes_from(self, grid, plane_i=0, plane_j=-1):
23-
return np.zeros_like(grid.array)
25+
return jnp.zeros_like(grid.array)
2426

2527
def magnification_2d_via_hessian_from(
2628
self, grid, buffer: float = 0.01, deflections_func=None
2729
) -> aa.ArrayIrregular:
28-
return aa.ArrayIrregular(values=np.ones(grid.shape[0]))
30+
return aa.ArrayIrregular(values=jnp.ones(grid.shape[0]))
2931

3032
def tree_flatten(self):
3133
"""

autolens/point/fit/fluxes.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,12 @@ def model_data(self):
102102
are used.
103103
"""
104104
return aa.ArrayIrregular(
105-
values=jnp.array([
106-
magnification * self.profile.flux
107-
for magnification in self.magnifications_at_positions
108-
])
105+
values=jnp.array(
106+
[
107+
magnification * self.profile.flux
108+
for magnification in self.magnifications_at_positions
109+
]
110+
)
109111
)
110112

111113
@property

autolens/point/fit/positions/image/pair_all.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,16 +107,18 @@ def all_permutations_log_likelihoods(self) -> np.ndarray:
107107
[
108108
jnp.log(
109109
jnp.sum(
110-
jnp.array([
111-
jnp.exp(
112-
self.log_p(
113-
data_position,
114-
model_position,
115-
sigma,
110+
jnp.array(
111+
[
112+
jnp.exp(
113+
self.log_p(
114+
data_position,
115+
model_position,
116+
sigma,
117+
)
116118
)
117-
)
118-
for model_position in model_data
119-
])
119+
for model_position in model_data
120+
]
121+
)
120122
)
121123
)
122124
for data_position, sigma in zip(self.data, self.noise_map)

autolens/point/max_separation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def __init__(
4242
except TypeError:
4343
plane_index = -1
4444

45-
self.plane_positions = aa.Grid2DIrregular(values=tracer.traced_grid_2d_list_from(grid=data)[plane_index])
45+
self.plane_positions = aa.Grid2DIrregular(
46+
values=tracer.traced_grid_2d_list_from(grid=data)[plane_index]
47+
)
4648

4749
@property
4850
def furthest_separations_of_plane_positions(self) -> aa.ArrayIrregular:

autolens/point/solver/point_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,4 @@ def solve(
6262
sentinel = jnp.full_like(solution[0], fill_value=jnp.inf)
6363
solution = jnp.where(is_nan[:, None], sentinel, solution)
6464

65-
return aa.Grid2DIrregular(solution)
65+
return aa.Grid2DIrregular(solution)

autolens/point/solver/shape_solver.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
import logging
33
import math
44

5-
from typing import Tuple, List, Iterator, Type, Optional
5+
from typing import Tuple, List, Iterator, Optional
66

77
import autoarray as aa
88

99
from autoarray.structures.triangles.shape import Shape
1010
from autofit.jax_wrapper import register_pytree_node_class
1111

12-
from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import (
12+
from autoarray.structures.triangles.coordinate_array import (
1313
CoordinateArrayTriangles,
1414
)
1515
from autoarray.structures.triangles.abstract import AbstractTriangles
@@ -59,7 +59,6 @@ def for_grid(
5959
grid: aa.Grid2D,
6060
pixel_scale_precision: float,
6161
magnification_threshold=0.1,
62-
array_triangles_cls: Type[AbstractTriangles] = CoordinateArrayTriangles,
6362
neighbor_degree: int = 1,
6463
):
6564
"""
@@ -75,9 +74,6 @@ def for_grid(
7574
The precision to which the triangles should be subdivided.
7675
magnification_threshold
7776
The threshold for the magnification under which multiple images are filtered.
78-
array_triangles_cls
79-
The class to use for the triangles. JAX is used implicitly if USE_JAX=1 and
80-
jax is installed.
8177
max_containing_size
8278
Only applies to JAX. This is the maximum number of multiple images expected.
8379
We need to know this in advance to allocate memory for the JAX array.
@@ -106,7 +102,6 @@ def for_grid(
106102
scale=scale,
107103
pixel_scale_precision=pixel_scale_precision,
108104
magnification_threshold=magnification_threshold,
109-
array_triangles_cls=array_triangles_cls,
110105
neighbor_degree=neighbor_degree,
111106
)
112107

@@ -120,7 +115,6 @@ def for_limits_and_scale(
120115
scale=0.1,
121116
pixel_scale_precision: float = 0.001,
122117
magnification_threshold=0.1,
123-
array_triangles_cls: Type[AbstractTriangles] = CoordinateArrayTriangles,
124118
neighbor_degree: int = 1,
125119
):
126120
"""
@@ -141,17 +135,14 @@ def for_limits_and_scale(
141135
The precision to which the triangles should be subdivided.
142136
magnification_threshold
143137
The threshold for the magnification under which multiple images are filtered.
144-
array_triangles_cls
145-
The class to use for the triangles. JAX is used implicitly if USE_JAX=1 and
146-
jax is installed.
147138
neighbor_degree
148139
The number of times recursively add neighbors for the triangles that contain
149140
150141
Returns
151142
-------
152143
The solver.
153144
"""
154-
initial_triangles = array_triangles_cls.for_limits_and_scale(
145+
initial_triangles = CoordinateArrayTriangles.for_limits_and_scale(
155146
y_min=y_min,
156147
y_max=y_max,
157148
x_min=x_min,
@@ -310,6 +301,7 @@ def steps(
310301
An iterator over the steps of the triangle solver algorithm.
311302
"""
312303
initial_triangles = self.initial_triangles
304+
313305
for number in range(self.n_steps):
314306
plane_triangles = self._plane_triangles(
315307
tracer=tracer,

0 commit comments

Comments
 (0)