Skip to content

Commit 7a1748f

Browse files
Jammy2211Jammy2211
authored andcommitted
black
1 parent 17b38ce commit 7a1748f

21 files changed

Lines changed: 76 additions & 71 deletions

File tree

autolens/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from autoconf import jax_wrapper
12
from autoconf.dictable import from_dict, from_json, output_to_json, to_dict
23
from autoarray import preprocess
34
from autoarray.dataset.imaging.w_tilde import WTildeImaging

autolens/analysis/analysis/dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import os
21
import logging
2+
import numpy as np
3+
import os
34
from typing import List, Optional
45

56
from autoconf import conf
@@ -33,6 +34,7 @@ def __init__(
3334
preloads: aa.Preloads = None,
3435
raise_inversion_positions_likelihood_exception: bool = True,
3536
title_prefix: str = None,
37+
use_jax: bool = True,
3638
**kwargs,
3739
):
3840
"""
@@ -75,6 +77,7 @@ def __init__(
7577
settings_inversion=settings_inversion,
7678
preloads=preloads,
7779
title_prefix=title_prefix,
80+
use_jax=use_jax,
7881
**kwargs,
7982
)
8083

autolens/analysis/positions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import jax
21
import numpy as np
32
from typing import Optional
43
from os import path
@@ -189,6 +188,8 @@ def log_likelihood_penalty_from(
189188

190189
penalty = self.log_likelihood_penalty_factor * (max_separation - self.threshold)
191190

191+
import jax
192+
192193
return jax.lax.cond(
193194
max_separation > self.threshold,
194195
lambda: penalty,

autolens/config/general.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
jax:
2-
use_jax: true # If True, uses JAX internally, whereas False uses normal Numpy.
31
output:
42
fit_dill: false
53
test:

autolens/fixtures.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,14 @@ def make_analysis_imaging_7x7():
145145
analysis = al.AnalysisImaging(
146146
dataset=make_masked_imaging_7x7(),
147147
settings_inversion=aa.SettingsInversion(use_w_tilde=False),
148+
use_jax=False,
148149
)
149150
analysis._adapt_images = make_adapt_images_7x7()
150151
return analysis
151152

152153

153154
def make_analysis_interferometer_7():
154-
analysis = al.AnalysisInterferometer(
155-
dataset=make_interferometer_7(),
156-
)
155+
analysis = al.AnalysisInterferometer(dataset=make_interferometer_7(), use_jax=False)
157156
analysis._adapt_images = make_adapt_images_7x7()
158157
return analysis
159158

@@ -162,4 +161,5 @@ def make_analysis_point_x2():
162161
return al.AnalysisPoint(
163162
point_dict=make_point_dict(),
164163
solver=al.m.MockPointSolver(model_positions=make_positions_x2()),
164+
use_jax=False,
165165
)

autolens/imaging/model/analysis.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def modify_before_fit(self, paths: af.DirectoryPaths, model: af.Collection):
4646

4747
return self
4848

49-
def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float:
49+
def log_likelihood_function(self, instance: af.ModelInstance) -> float:
5050
"""
5151
Given an instance of the model, where the model parameters are set via a non-linear search, fit the model
5252
instance to the imaging dataset.
@@ -87,15 +87,14 @@ def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float:
8787

8888
log_likelihood_penalty = self.log_likelihood_penalty_from(
8989
instance=instance,
90-
xp=xp
90+
xp=self._xp
9191
)
9292

93-
return self.fit_from(instance=instance, xp=xp).figure_of_merit - log_likelihood_penalty
93+
return self.fit_from(instance=instance).figure_of_merit - log_likelihood_penalty
9494

9595
def fit_from(
9696
self,
9797
instance: af.ModelInstance,
98-
xp=np
9998
) -> FitImaging:
10099
"""
101100
Given a model instance create a `FitImaging` object.
@@ -133,7 +132,7 @@ def fit_from(
133132
adapt_images=adapt_images,
134133
settings_inversion=self.settings_inversion,
135134
preloads=self.preloads,
136-
xp=xp
135+
xp=self._xp
137136
)
138137

139138
def save_attributes(self, paths: af.DirectoryPaths):

autolens/interferometer/model/analysis.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
11
import logging
22
import numpy as np
3-
from typing import Dict, Optional, Tuple, Union
3+
from typing import Optional
44

55
import autofit as af
66
import autoarray as aa
77
import autogalaxy as ag
88

9-
from autoarray.exc import PixelizationException
10-
119
from autolens.analysis.analysis.dataset import AnalysisDataset
1210
from autolens.analysis.positions import PositionsLH
1311
from autolens.interferometer.model.result import ResultInterferometer
1412
from autolens.interferometer.model.visualizer import VisualizerInterferometer
1513
from autolens.interferometer.fit_interferometer import FitInterferometer
1614

17-
from autolens import exc
18-
1915
logger = logging.getLogger(__name__)
2016

2117
logger.setLevel(level="INFO")
@@ -35,6 +31,7 @@ def __init__(
3531
preloads: aa.Preloads = None,
3632
raise_inversion_positions_likelihood_exception: bool = True,
3733
title_prefix: str = None,
34+
use_jax: bool = True,
3835
):
3936
"""
4037
Analysis classes are used by PyAutoFit to fit a model to a dataset via a non-linear search.
@@ -86,6 +83,7 @@ def __init__(
8683
preloads=preloads,
8784
raise_inversion_positions_likelihood_exception=raise_inversion_positions_likelihood_exception,
8885
title_prefix=title_prefix,
86+
use_jax=use_jax,
8987
)
9088

9189
@property
@@ -114,7 +112,7 @@ def modify_before_fit(self, paths: af.DirectoryPaths, model: af.Collection):
114112

115113
return self
116114

117-
def log_likelihood_function(self, instance, xp=np):
115+
def log_likelihood_function(self, instance):
118116
"""
119117
Given an instance of the model, where the model parameters are set via a non-linear search, fit the model
120118
instance to the interferometer dataset.
@@ -154,15 +152,12 @@ def log_likelihood_function(self, instance, xp=np):
154152
"""
155153

156154
log_likelihood_penalty = self.log_likelihood_penalty_from(
157-
instance=instance, xp=xp
155+
instance=instance, xp=self._xp
158156
)
159157

160-
return (
161-
self.fit_from(instance=instance, xp=xp).figure_of_merit
162-
- log_likelihood_penalty
163-
)
158+
return self.fit_from(instance=instance).figure_of_merit - log_likelihood_penalty
164159

165-
def fit_from(self, instance: af.ModelInstance, xp=np) -> FitInterferometer:
160+
def fit_from(self, instance: af.ModelInstance) -> FitInterferometer:
166161
"""
167162
Given a model instance create a `FitInterferometer` object.
168163
@@ -198,7 +193,7 @@ def fit_from(self, instance: af.ModelInstance, xp=np) -> FitInterferometer:
198193
adapt_images=adapt_images,
199194
settings_inversion=self.settings_inversion,
200195
preloads=self.preloads,
201-
xp=xp,
196+
xp=self._xp,
202197
)
203198

204199
def save_attributes(self, paths: af.DirectoryPaths):

autolens/lens/tracer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ def deflections_between_planes_from(
709709
The 2D (y, x) coordinates where values of the deflections are evaluated.
710710
"""
711711

712-
traced_grids_list = self.traced_grid_2d_list_from(grid=grid)
712+
traced_grids_list = self.traced_grid_2d_list_from(grid=grid, xp=xp)
713713

714714
return traced_grids_list[plane_i] - traced_grids_list[plane_j]
715715

autolens/point/fit/positions/source/separations.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ def model_data(self) -> aa.Grid2DIrregular:
9393
at its specific redshift are used.
9494
"""
9595
if len(self.tracer.planes) <= 2:
96-
deflections = self.tracer.deflections_yx_2d_from(grid=self.data)
96+
deflections = self.tracer.deflections_yx_2d_from(
97+
grid=self.data, xp=self._xp
98+
)
9799
else:
98100
deflections = self.tracer.deflections_between_planes_from(
99101
grid=self.data, xp=self._xp, plane_i=0, plane_j=self.plane_index

autolens/point/model/analysis.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional
1+
import numpy as np
22

33
import autofit as af
44
import autogalaxy as ag
@@ -13,15 +13,6 @@
1313
from autolens.point.model.visualizer import VisualizerPoint
1414
from autolens.point.solver import PointSolver
1515

16-
from autolens import exc
17-
18-
try:
19-
import numba
20-
21-
NumbaException = numba.errors.TypingError
22-
except ModuleNotFoundError:
23-
NumbaException = ValueError
24-
2516

2617
class AnalysisPoint(AgAnalysis, AnalysisLens):
2718
Visualizer = VisualizerPoint
@@ -35,6 +26,7 @@ def __init__(
3526
image=None,
3627
cosmology: ag.cosmo.LensingCosmology = None,
3728
title_prefix: str = None,
29+
use_jax: bool = True,
3830
):
3931
"""
4032
Fits a lens model to a point source dataset (e.g. positions, fluxes, time delays) via a non-linear search.
@@ -69,7 +61,7 @@ def __init__(
6961
A string that is added before the title of all figures output by visualization, for example to
7062
put the name of the dataset and galaxy in the title.
7163
"""
72-
super().__init__(cosmology=cosmology)
64+
super().__init__(cosmology=cosmology, use_jax=use_jax)
7365

7466
AnalysisLens.__init__(self=self, cosmology=cosmology)
7567

@@ -127,9 +119,6 @@ def fit_from(
127119
self,
128120
instance,
129121
) -> FitPointDataset:
130-
tracer = self.tracer_via_instance_from(
131-
instance=instance,
132-
)
133122
"""
134123
Given a model instance create a `FitPointDataset` object.
135124
@@ -146,11 +135,16 @@ def fit_from(
146135
-------
147136
The fit of the lens model to the point source dataset.
148137
"""
138+
tracer = self.tracer_via_instance_from(
139+
instance=instance,
140+
)
141+
149142
return FitPointDataset(
150143
dataset=self.dataset,
151144
tracer=tracer,
152145
solver=self.solver,
153146
fit_positions_cls=self.fit_positions_cls,
147+
xp=self._xp,
154148
)
155149

156150
def save_attributes(self, paths: af.DirectoryPaths):

0 commit comments

Comments
 (0)