Skip to content

Commit 545c864

Browse files
Jammy2211Jammy2211
authored andcommitted
fix casting in abstract fit
1 parent 229857b commit 545c864

3 files changed

Lines changed: 28 additions & 3 deletions

File tree

autogalaxy/abstract_fit.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22
from typing import TYPE_CHECKING, Dict, Optional
33

4+
import numpy as np
5+
46
from autofit import ModelInstance
57

68
if TYPE_CHECKING:
@@ -17,6 +19,7 @@ def __init__(
1719
self,
1820
model_obj,
1921
settings_inversion: aa.SettingsInversion,
22+
xp=np
2023
):
2124
"""
2225
An abstract fit object which fits to datasets (e.g. imaging, interferometer) inherit from.
@@ -35,6 +38,15 @@ def __init__(
3538
"""
3639
self.model_obj = model_obj
3740
self.settings_inversion = settings_inversion
41+
self.use_jax = xp is not np
42+
43+
@property
44+
def _xp(self):
45+
if self.use_jax:
46+
import jax.numpy as jnp
47+
48+
return jnp
49+
return np
3850

3951
@property
4052
def total_mappers(self) -> int:
@@ -95,6 +107,14 @@ def linear_light_profile_intensity_dict(
95107
96108
This function returns a dictionary which maps every linear light profile instance to its solved for
97109
`intensity` value in the inversion, so that the intensity value of every light profile can be accessed.
110+
111+
Type casting is complicated by JAX. When this function is used in a JAX.jit (e.g. computed latent varialbes)
112+
it requires the reconstruction values to be JAX arrays, but when it is used outside of JAX certain taks
113+
requires the reconstruction values to be floats.
114+
115+
An example of the latter is using a tracer inferred in one search to pass the solved for intensity values of
116+
linear light profiles to a subsequent search, for example setting up the intensities of the mass components
117+
of a light dark model.
98118
"""
99119

100120
if self.inversion is None:
@@ -110,9 +130,12 @@ def linear_light_profile_intensity_dict(
110130
reconstruction = self.inversion.reconstruction_dict[linear_obj_func]
111131

112132
for i, light_profile in enumerate(linear_obj_func.light_profile_list):
113-
linear_light_profile_intensity_dict[light_profile] = float(
114-
reconstruction[i]
115-
)
133+
if self.use_jax:
134+
linear_light_profile_intensity_dict[light_profile] = reconstruction[i]
135+
else:
136+
linear_light_profile_intensity_dict[light_profile] = float(
137+
reconstruction[i]
138+
)
116139

117140
return linear_light_profile_intensity_dict
118141

autogalaxy/imaging/fit_imaging.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(
7878
self=self,
7979
model_obj=self.galaxies,
8080
settings_inversion=settings_inversion,
81+
xp=xp
8182
)
8283

8384
self.adapt_images = adapt_images

autogalaxy/interferometer/fit_interferometer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def __init__(
7272
self=self,
7373
model_obj=self.galaxies,
7474
settings_inversion=settings_inversion,
75+
xp=xp
7576
)
7677

7778
self.adapt_images = adapt_images

0 commit comments

Comments
 (0)