Skip to content

Commit 4051190

Browse files
authored
Merge pull request #361 from Jammy2211/feature/jax_remove_profiling
Feature/jax remove profiling
2 parents 5373b60 + 9543df9 commit 4051190

21 files changed

Lines changed: 211 additions & 539 deletions

autolens/analysis/analysis/lens.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,10 @@ def __init__(
4646
def tracer_via_instance_from(
4747
self,
4848
instance: af.ModelInstance,
49-
run_time_dict: Optional[Dict] = None,
5049
) -> Tracer:
5150
"""
5251
Create a `Tracer` from the galaxies contained in a model instance.
5352
54-
If PyAutoFit's profiling tools are used with the analysis class, this function may receive a `run_time_dict`
55-
which times how long each set of the model-fit takes to perform.
56-
5753
Parameters
5854
----------
5955
instance
@@ -90,13 +86,11 @@ def tracer_via_instance_from(
9086
if getattr(instance, "extra_galaxies", None) is not None:
9187
return Tracer(
9288
galaxies=instance.galaxies + instance.extra_galaxies,
93-
run_time_dict=run_time_dict,
9489
)
9590

9691
return Tracer(
9792
galaxies=instance.galaxies,
9893
cosmology=cosmology,
99-
run_time_dict=run_time_dict,
10094
)
10195

10296
def log_likelihood_penalty_from(

autolens/config/general.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
jax:
2+
use_jax: true # If True, uses JAX internally, whereas False uses normal Numpy.
13
output:
24
fit_dill: false
35
test:

autolens/imaging/fit_imaging.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def __init__(
2323
dataset_model : Optional[aa.DatasetModel] = None,
2424
adapt_images: Optional[ag.AdaptImages] = None,
2525
settings_inversion: aa.SettingsInversion = aa.SettingsInversion(),
26-
run_time_dict: Optional[Dict] = None,
2726
):
2827
"""
2928
Fits an imaging dataset using a `Tracer` object.
@@ -62,12 +61,9 @@ def __init__(
6261
reconstructed galaxy's morphology.
6362
settings_inversion
6463
Settings controlling how an inversion is fitted for example which linear algebra formalism is used.
65-
run_time_dict
66-
A dictionary which if passed to the fit records how long function calls which have the `profile_func`
67-
decorator take to run.
6864
"""
6965

70-
super().__init__(dataset=dataset, dataset_model=dataset_model, run_time_dict=run_time_dict)
66+
super().__init__(dataset=dataset, dataset_model=dataset_model)
7167
AbstractFitInversion.__init__(
7268
self=self, model_obj=tracer, settings_inversion=settings_inversion
7369
)
@@ -120,7 +116,6 @@ def tracer_to_inversion(self) -> TracerToInversion:
120116
tracer=self.tracer,
121117
adapt_images=self.adapt_images,
122118
settings_inversion=self.settings_inversion,
123-
run_time_dict=self.run_time_dict
124119
)
125120

126121
@cached_property

autolens/imaging/model/analysis.py

Lines changed: 6 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -92,34 +92,15 @@ def log_likelihood_function(self, instance: af.ModelInstance) -> float:
9292
The log likelihood indicating how well this model instance fitted the imaging data.
9393
"""
9494

95-
try:
96-
log_likelihood_penalty = self.log_likelihood_penalty_from(
97-
instance=instance
98-
)
99-
except Exception as e:
100-
raise e
101-
102-
try:
103-
return self.fit_from(instance=instance).figure_of_merit + log_likelihood_penalty
104-
except (
105-
PixelizationException,
106-
exc.PixelizationException,
107-
exc.InversionException,
108-
exc.GridException,
109-
exc.MeshException,
110-
ValueError,
111-
TypeError,
112-
np.linalg.LinAlgError,
113-
OverflowError,
114-
) as e:
115-
print(e)
116-
fggdfg
117-
raise exc.FitException from e
95+
log_likelihood_penalty = self.log_likelihood_penalty_from(
96+
instance=instance
97+
)
98+
99+
return self.fit_from(instance=instance).figure_of_merit + log_likelihood_penalty
118100

119101
def fit_from(
120102
self,
121103
instance: af.ModelInstance,
122-
run_time_dict: Optional[Dict] = None,
123104
) -> FitImaging:
124105
"""
125106
Given a model instance create a `FitImaging` object.
@@ -135,8 +116,6 @@ def fit_from(
135116
check_positions
136117
Whether the multiple image positions of the lensed source should be checked, i.e. whether they trace
137118
within the position threshold of one another in the source plane.
138-
run_time_dict
139-
A dictionary which times functions called to fit the model to data, for profiling.
140119
141120
Returns
142121
-------
@@ -145,7 +124,7 @@ def fit_from(
145124
"""
146125

147126
tracer = self.tracer_via_instance_from(
148-
instance=instance, run_time_dict=run_time_dict
127+
instance=instance,
149128
)
150129

151130
dataset_model = self.dataset_model_via_instance_from(instance=instance)
@@ -158,7 +137,6 @@ def fit_from(
158137
dataset_model=dataset_model,
159138
adapt_images=adapt_images,
160139
settings_inversion=self.settings_inversion,
161-
run_time_dict=run_time_dict,
162140
)
163141

164142
def save_attributes(self, paths: af.DirectoryPaths):
@@ -201,42 +179,3 @@ def save_attributes(self, paths: af.DirectoryPaths):
201179
)
202180

203181
analysis.save_attributes(paths=paths)
204-
205-
def profile_log_likelihood_function(
206-
self, instance: af.ModelInstance, paths: Optional[af.DirectoryPaths] = None
207-
) -> Tuple[Dict, Dict]:
208-
"""
209-
This function is optionally called throughout a model-fit to profile the log likelihood function.
210-
211-
All function calls inside the `log_likelihood_function` that are decorated with the `profile_func` are timed
212-
with their times stored in a dictionary called the `run_time_dict`.
213-
214-
An `info_dict` is also created which stores information on aspects of the model and dataset that dictate
215-
run times, so the profiled times can be interpreted with this context.
216-
217-
The results of this profiling are then output to hard-disk in the `profiling` folder of the model-fit results,
218-
which they can be inspected to ensure run-times are as expected.
219-
220-
Parameters
221-
----------
222-
instance
223-
An instance of the model that is being fitted to the data by this analysis (whose parameters have been set
224-
via a non-linear search).
225-
paths
226-
The paths object which manages all paths, e.g. where the non-linear search outputs are stored,
227-
visualization and the pickled objects used by the aggregator output by this function.
228-
229-
Returns
230-
-------
231-
Two dictionaries, the profiling dictionary and info dictionary, which contain the profiling times of the
232-
`log_likelihood_function` and information on the model and dataset used to perform the profiling.
233-
"""
234-
run_time_dict, info_dict = super().profile_log_likelihood_function(
235-
instance=instance,
236-
)
237-
238-
info_dict["psf_shape_2d"] = self.dataset.psf.shape_native
239-
240-
self.output_profiling_info(paths=paths, run_time_dict=run_time_dict, info_dict=info_dict)
241-
242-
return run_time_dict, info_dict

autolens/interferometer/fit_interferometer.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def __init__(
2020
dataset_model: Optional[aa.DatasetModel] = None,
2121
adapt_images: Optional[ag.AdaptImages] = None,
2222
settings_inversion: aa.SettingsInversion = aa.SettingsInversion(),
23-
run_time_dict: Optional[Dict] = None,
2423
):
2524
"""
2625
Fits an interferometer dataset using a `Tracer` object.
@@ -60,9 +59,6 @@ def __init__(
6059
reconstructed galaxy's morphology.
6160
settings_inversion
6261
Settings controlling how an inversion is fitted for example which linear algebra formalism is used.
63-
run_time_dict
64-
A dictionary which if passed to the fit records how long function calls which have the `profile_func`
65-
decorator take to run.
6662
"""
6763

6864
try:
@@ -76,10 +72,9 @@ def __init__(
7672

7773
self.settings_inversion = settings_inversion
7874

79-
self.run_time_dict = run_time_dict
80-
8175
super().__init__(
82-
dataset=dataset, dataset_model=dataset_model, run_time_dict=run_time_dict
76+
dataset=dataset,
77+
dataset_model=dataset_model,
8378
)
8479
AbstractFitInversion.__init__(
8580
self=self, model_obj=tracer, settings_inversion=settings_inversion

autolens/interferometer/model/analysis.py

Lines changed: 4 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -151,35 +151,13 @@ def log_likelihood_function(self, instance):
151151
The log likelihood indicating how well this model instance fitted the interferometer data.
152152
"""
153153

154-
try:
155-
log_likelihood_penalty = self.log_likelihood_penalty_from(instance=instance)
156-
except Exception as e:
157-
raise e
158-
159-
try:
160-
return (
161-
self.fit_from(instance=instance).figure_of_merit
162-
+ log_likelihood_penalty
163-
)
164-
except (
165-
PixelizationException,
166-
exc.PixelizationException,
167-
exc.InversionException,
168-
exc.GridException,
169-
exc.MeshException,
170-
ValueError,
171-
TypeError,
172-
np.linalg.LinAlgError,
173-
OverflowError,
174-
) as e:
175-
print(e)
176-
fggdfg
177-
raise exc.FitException from e
154+
log_likelihood_penalty = self.log_likelihood_penalty_from(instance=instance)
155+
156+
return self.fit_from(instance=instance).figure_of_merit + log_likelihood_penalty
178157

179158
def fit_from(
180159
self,
181160
instance: af.ModelInstance,
182-
run_time_dict: Optional[Dict] = None,
183161
) -> FitInterferometer:
184162
"""
185163
Given a model instance create a `FitInterferometer` object.
@@ -205,7 +183,7 @@ def fit_from(
205183
"""
206184

207185
tracer = self.tracer_via_instance_from(
208-
instance=instance, run_time_dict=run_time_dict
186+
instance=instance,
209187
)
210188

211189
adapt_images = self.adapt_images_via_instance_from(instance=instance)
@@ -215,7 +193,6 @@ def fit_from(
215193
tracer=tracer,
216194
adapt_images=adapt_images,
217195
settings_inversion=self.settings_inversion,
218-
run_time_dict=run_time_dict,
219196
)
220197

221198
def save_attributes(self, paths: af.DirectoryPaths):
@@ -257,45 +234,3 @@ def save_attributes(self, paths: af.DirectoryPaths):
257234
)
258235

259236
analysis.save_attributes(paths=paths)
260-
261-
def profile_log_likelihood_function(
262-
self, instance: af.ModelInstance, paths: Optional[af.DirectoryPaths] = None
263-
) -> Tuple[Dict, Dict]:
264-
"""
265-
This function is optionally called throughout a model-fit to profile the log likelihood function.
266-
267-
All function calls inside the `log_likelihood_function` that are decorated with the `profile_func` are timed
268-
with their times stored in a dictionary called the `run_time_dict`.
269-
270-
An `info_dict` is also created which stores information on aspects of the model and dataset that dictate
271-
run times, so the profiled times can be interpreted with this context.
272-
273-
The results of this profiling are then output to hard-disk in the `profiling` folder of the model-fit results,
274-
which they can be inspected to ensure run-times are as expected.
275-
276-
Parameters
277-
----------
278-
instance
279-
An instance of the model that is being fitted to the data by this analysis (whose parameters have been set
280-
via a non-linear search).
281-
paths
282-
The paths object which manages all paths, e.g. where the non-linear search outputs are stored,
283-
visualization and the pickled objects used by the aggregator output by this function.
284-
285-
Returns
286-
-------
287-
Two dictionaries, the profiling dictionary and info dictionary, which contain the profiling times of the
288-
`log_likelihood_function` and information on the model and dataset used to perform the profiling.
289-
"""
290-
run_time_dict, info_dict = super().profile_log_likelihood_function(
291-
instance=instance,
292-
)
293-
294-
info_dict["number_of_visibilities"] = self.dataset.data.shape[0]
295-
info_dict["transformer_cls"] = self.dataset.transformer.__class__.__name__
296-
297-
self.output_profiling_info(
298-
paths=paths, run_time_dict=run_time_dict, info_dict=info_dict
299-
)
300-
301-
return run_time_dict, info_dict

autolens/lens/mock/mock_to_inversion.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@ def __init__(
66
self,
77
tracer,
88
image_plane_mesh_grid_pg_list=None,
9-
run_time_dict: Optional[Dict] = None,
109
):
1110
self.tracer = tracer
1211

1312
self.image_plane_mesh_grid_pg_list = image_plane_mesh_grid_pg_list
1413

15-
self.run_time_dict = run_time_dict
16-
1714
def image_plane_mesh_grid_pg_list(self):
1815
return self.image_plane_mesh_grid_pg_list

0 commit comments

Comments
 (0)