@@ -86,88 +86,94 @@ class OperateDeflections:
8686 The majority of methods are those which from the 2D deflection angle map compute lensing quantities like a 2D
8787 shear field, magnification map or the Einstein Radius.
8888
89- The methods in `CalcLens` are passed to the mass object to provide a concise API.
90-
9189 Parameters
9290 ----------
9391 deflections_yx_2d_from
94- The function which returns the mass object's 2D deflection angles.
92+ A callable with signature ``(grid, xp=np, **kwargs)`` that returns the 2D deflection angles on the given
93+ grid. Typically a bound method of a ``MassProfile``, ``Galaxy``, or ``Galaxies`` instance.
9594 """
9695
97- def deflections_yx_2d_from (self , grid : aa . type . Grid2DLike , ** kwargs ):
98- raise NotImplementedError
96+ def __init__ (self , deflections_yx_2d_from ):
97+ self . deflections_yx_2d_from = deflections_yx_2d_from
9998
100- def __eq__ (self , other ):
101- return self .__dict__ == other .__dict__ and self .__class__ is other .__class__
99+ @classmethod
100+ def from_mass_obj (cls , mass_obj ):
101+ """Construct from any object that has a ``deflections_yx_2d_from`` method."""
102+ return cls (deflections_yx_2d_from = mass_obj .deflections_yx_2d_from )
102103
103- def time_delay_geometry_term_from (self , grid , xp = np ) -> aa .Array2D :
104+ @classmethod
105+ def from_tracer (cls , tracer , use_multi_plane : bool = True , plane_i : int = 0 , plane_j : int = - 1 ):
104106 """
105- Returns the geometric time delay term of the Fermat potential for a given grid of image-plane positions.
106-
107- This term is given by:
108-
109- .. math::
110- \[\t au_{\t ext{geom}}(\b oldsymbol{\t heta}) = \f rac{1}{2} |\b oldsymbol{\t heta} - \b oldsymbol{\b eta}|^2\]
107+ Construct from a PyAutoLens ``Tracer`` object.
111108
112- where:
113- - \( \b oldsymbol{\t heta} \) is the image-plane coordinate,
114- - \( \b oldsymbol{\b eta} = \b oldsymbol{\t heta} - \b oldsymbol{\a lpha}(\b oldsymbol{\t heta}) \) is the source-plane coordinate,
115- - \( \b oldsymbol{\a lpha} \) is the deflection angle at each image-plane coordinate.
116-
117- Parameters
118- ----------
119- grid
120- The 2D grid of (y,x) arc-second coordinates the deflection angles and time delay geometric term are computed
121- on.
122-
123- Returns
124- -------
125- The geometric time delay term at each grid position.
126- """
127- deflections = self .deflections_yx_2d_from (grid = grid , xp = xp )
109+ The ``Tracer`` type is deliberately left unannotated: ``autogalaxy`` does not
110+ depend on ``autolens``, so no import of ``Tracer`` is performed here. Callers
111+ (which live inside ``autolens``) are responsible for passing the correct object.
128112
129- src_y = grid [:, 0 ] - deflections [:, 0 ]
130- src_x = grid [:, 1 ] - deflections [:, 1 ]
131-
132- delay = 0.5 * ((grid [:, 0 ] - src_y ) ** 2 + (grid [:, 1 ] - src_x ) ** 2 )
133-
134- if isinstance (grid , aa .Grid2DIrregular ):
135- return aa .ArrayIrregular (values = delay )
136- return aa .Array2D (values = delay , mask = grid .mask )
113+ Parameters
114+ ----------
115+ tracer
116+ A PyAutoLens ``Tracer`` instance. Must expose ``deflections_yx_2d_from``
117+ and, when ``use_multi_plane=True``, ``deflections_between_planes_from``.
118+ use_multi_plane
119+ If ``True`` the stored callable is
120+ ``tracer.deflections_between_planes_from`` with ``plane_i`` and ``plane_j``
121+ bound via ``functools.partial``, matching the multi-plane ray-tracing path.
122+ If ``False`` the stored callable is ``tracer.deflections_yx_2d_from``,
123+ i.e. the standard two-plane path.
124+ plane_i
125+ Index of the first plane used by ``deflections_between_planes_from``.
126+ Ignored when ``use_multi_plane=False``. Defaults to ``0`` (image plane).
127+ plane_j
128+ Index of the second plane used by ``deflections_between_planes_from``.
129+ Ignored when ``use_multi_plane=False``. Defaults to ``-1`` (source plane).
130+ """
131+ if use_multi_plane :
132+ from functools import partial
133+
134+ return cls (
135+ deflections_yx_2d_from = partial (
136+ tracer .deflections_between_planes_from ,
137+ plane_i = plane_i ,
138+ plane_j = plane_j ,
139+ )
140+ )
141+ return cls (deflections_yx_2d_from = tracer .deflections_yx_2d_from )
137142
138- def fermat_potential_from (self , grid , xp = np ) -> aa .Array2D :
143+ def time_delay_geometry_term_from (self , grid , xp = np ) -> aa .Array2D :
139144 """
140- Returns the Fermat potential for a given grid of image-plane positions.
145+ Returns the geometric time delay term of the Fermat potential for a given grid of image-plane positions.
141146
142- This is the sum of the geometric time delay term and the gravitational (Shapiro) delay term (i.e. the lensing
143- potential), and is given by:
147+ This term is given by:
144148
145149 .. math::
146- \[\phi (\b oldsymbol{\t heta}) = \f rac{1}{2} |\b oldsymbol{\t heta} - \b oldsymbol{\b eta}|^2 - \psi( \b oldsymbol{ \t heta}) \]
150+ \[\t au_{ \t ext{geom}} (\b oldsymbol{\t heta}) = \f rac{1}{2} |\b oldsymbol{\t heta} - \b oldsymbol{\b eta}|^2\]
147151
148152 where:
149153 - \( \b oldsymbol{\t heta} \) is the image-plane coordinate,
150154 - \( \b oldsymbol{\b eta} = \b oldsymbol{\t heta} - \b oldsymbol{\a lpha}(\b oldsymbol{\t heta}) \) is the source-plane coordinate,
151- - \( \psi(\b oldsymbol{\t heta}) \) is the lensing potential,
152- - \( \phi(\b oldsymbol{\t heta}) \) is the Fermat potential.
155+ - \( \b oldsymbol{\a lpha} \) is the deflection angle at each image-plane coordinate.
153156
154157 Parameters
155158 ----------
156159 grid
157- The 2D grid of (y,x) arc-second coordinates the Fermat potential is computed on.
160+ The 2D grid of (y,x) arc-second coordinates the deflection angles and time delay geometric term are computed
161+ on.
158162
159163 Returns
160164 -------
161- The Fermat potential at each grid position.
165+ The geometric time delay term at each grid position.
162166 """
163- time_delay_geometry_term = self .time_delay_geometry_term_from (grid = grid , xp = xp )
164- potential = self .potential_2d_from (grid = grid , xp = xp )
167+ deflections = self .deflections_yx_2d_from (grid = grid , xp = xp )
168+
169+ src_y = grid [:, 0 ] - deflections [:, 0 ]
170+ src_x = grid [:, 1 ] - deflections [:, 1 ]
165171
166- fermat_potential = time_delay_geometry_term - potential
172+ delay = 0.5 * (( grid [:, 0 ] - src_y ) ** 2 + ( grid [:, 1 ] - src_x ) ** 2 )
167173
168174 if isinstance (grid , aa .Grid2DIrregular ):
169- return aa .ArrayIrregular (values = fermat_potential )
170- return aa .Array2D (values = fermat_potential , mask = grid .mask )
175+ return aa .ArrayIrregular (values = delay )
176+ return aa .Array2D (values = delay , mask = grid .mask )
171177
172178 def tangential_eigen_value_from (self , grid , xp = np ) -> aa .Array2D :
173179 """
@@ -188,9 +194,6 @@ def tangential_eigen_value_from(self, grid, xp=np) -> aa.Array2D:
188194 convergence = self .convergence_2d_via_hessian_from (grid = grid , xp = xp )
189195 shear_yx = self .shear_yx_2d_via_hessian_from (grid = grid , xp = xp )
190196
191- if xp is not np :
192- shear_magnitudes = xp .sqrt (shear_yx [:, 0 ] ** 2 + shear_yx [:, 1 ] ** 2 )
193- return xp .array (1 - convergence - shear_magnitudes )
194197 return aa .Array2D (values = 1 - convergence - shear_yx .magnitudes , mask = grid .mask )
195198
196199 def radial_eigen_value_from (self , grid , xp = np ) -> aa .Array2D :
@@ -211,9 +214,6 @@ def radial_eigen_value_from(self, grid, xp=np) -> aa.Array2D:
211214 convergence = self .convergence_2d_via_hessian_from (grid = grid , xp = xp )
212215 shear = self .shear_yx_2d_via_hessian_from (grid = grid , xp = xp )
213216
214- if xp is not np :
215- shear_magnitudes = xp .sqrt (shear [:, 0 ] ** 2 + shear [:, 1 ] ** 2 )
216- return xp .array (1 - convergence + shear_magnitudes )
217217 return aa .Array2D (values = 1 - convergence + shear .magnitudes , mask = grid .mask )
218218
219219 def magnification_2d_from (self , grid , xp = np ) -> aa .Array2D :
@@ -235,8 +235,6 @@ def magnification_2d_from(self, grid, xp=np) -> aa.Array2D:
235235
236236 det_A = (1 - hessian_xx ) * (1 - hessian_yy ) - hessian_xy * hessian_yx
237237
238- if xp is not np :
239- return xp .array (1 / det_A )
240238 return aa .Array2D (values = 1 / det_A , mask = grid .mask )
241239
242240 def deflections_yx_scalar (self , y , x , pixel_scales ):
@@ -295,10 +293,9 @@ def hessian_from(self, grid, xp=np) -> Tuple:
295293 The array module (``numpy`` or ``jax.numpy``). Controls which computational path is
296294 used and the type of the returned arrays.
297295 """
298- if xp is not np :
299- return self ._hessian_via_jax (grid = grid , xp = xp )
300-
301- return self ._hessian_via_finite_difference (grid = grid )
296+ if xp is np :
297+ return self ._hessian_via_finite_difference (grid = grid )
298+ return self ._hessian_via_jax (grid = grid , xp = xp )
302299
303300 def _hessian_via_jax (self , grid , xp ) -> Tuple :
304301 import jax
@@ -417,8 +414,6 @@ def convergence_2d_via_hessian_from(
417414
418415 convergence = 0.5 * (hessian_yy + hessian_xx )
419416
420- if xp is not np :
421- return xp .array (convergence )
422417 return aa .ArrayIrregular (values = convergence )
423418
424419 def shear_yx_2d_via_hessian_from (
@@ -460,13 +455,7 @@ def shear_yx_2d_via_hessian_from(
460455 gamma_1 = 0.5 * (hessian_xx - hessian_yy )
461456 gamma_2 = hessian_xy
462457
463- if xp is not np :
464- return xp .stack ([gamma_2 , gamma_1 ], axis = - 1 )
465-
466- shear_yx_2d = np .zeros (shape = (grid .shape [0 ], 2 ))
467-
468- shear_yx_2d [:, 0 ] = gamma_2
469- shear_yx_2d [:, 1 ] = gamma_1
458+ shear_yx_2d = xp .stack ([gamma_2 , gamma_1 ], axis = - 1 )
470459
471460 return ShearYX2DIrregular (values = shear_yx_2d , grid = grid )
472461
@@ -498,8 +487,6 @@ def magnification_2d_via_hessian_from(
498487
499488 det_A = (1 - hessian_xx ) * (1 - hessian_yy ) - hessian_xy * hessian_yx
500489
501- if xp is not np :
502- return xp .array (1.0 / det_A )
503490 return aa .ArrayIrregular (values = 1.0 / det_A )
504491
505492 def contour_list_from (self , grid , contour_array ):
0 commit comments