11from __future__ import annotations
22from typing import TYPE_CHECKING , Dict , Optional
33
4+ import numpy as np
5+
46from autofit import ModelInstance
57
68if TYPE_CHECKING :
@@ -17,6 +19,7 @@ def __init__(
1719 self ,
1820 model_obj ,
1921 settings_inversion : aa .SettingsInversion ,
22+ xp = np
2023 ):
2124 """
2225 An abstract fit object which fits to datasets (e.g. imaging, interferometer) inherit from.
@@ -35,6 +38,15 @@ def __init__(
3538 """
3639 self .model_obj = model_obj
3740 self .settings_inversion = settings_inversion
41+ self .use_jax = xp is not np
42+
43+ @property
44+ def _xp (self ):
45+ if self .use_jax :
46+ import jax .numpy as jnp
47+
48+ return jnp
49+ return np
3850
3951 @property
4052 def total_mappers (self ) -> int :
@@ -95,6 +107,14 @@ def linear_light_profile_intensity_dict(
95107
96108 This function returns a dictionary which maps every linear light profile instance to its solved for
97109 `intensity` value in the inversion, so that the intensity value of every light profile can be accessed.
110+
111+ Type casting is complicated by JAX. When this function is used in a JAX.jit (e.g. computed latent varialbes)
112+ it requires the reconstruction values to be JAX arrays, but when it is used outside of JAX certain taks
113+ requires the reconstruction values to be floats.
114+
115+ An example of the latter is using a tracer inferred in one search to pass the solved for intensity values of
116+ linear light profiles to a subsequent search, for example setting up the intensities of the mass components
117+ of a light dark model.
98118 """
99119
100120 if self .inversion is None :
@@ -110,9 +130,12 @@ def linear_light_profile_intensity_dict(
110130 reconstruction = self .inversion .reconstruction_dict [linear_obj_func ]
111131
112132 for i , light_profile in enumerate (linear_obj_func .light_profile_list ):
113- linear_light_profile_intensity_dict [light_profile ] = float (
114- reconstruction [i ]
115- )
133+ if self .use_jax :
134+ linear_light_profile_intensity_dict [light_profile ] = reconstruction [i ]
135+ else :
136+ linear_light_profile_intensity_dict [light_profile ] = float (
137+ reconstruction [i ]
138+ )
116139
117140 return linear_light_profile_intensity_dict
118141
0 commit comments