Skip to content

Commit 3b7b2ff

Browse files
authored
Merge pull request #534 from PyAutoLabs/feature/latent-module-autolens
feat: first-class lensing latent variable API in PyAutoLens
2 parents 6a80e9e + 368f292 commit 3b7b2ff

5 files changed

Lines changed: 521 additions & 0 deletions

File tree

autolens/analysis/latent.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
"""
2+
Latent variables for PyAutoLens analyses.
3+
4+
All latents take a generic ``fit`` argument and access ``fit.tracer``,
5+
``fit.galaxy_image_dict`` and ``fit.dataset.grids.lp`` — APIs that exist
6+
identically on both ``FitImaging`` (``autolens/imaging/fit_imaging.py:176``)
7+
and ``FitInterferometer`` (``autolens/interferometer/fit_interferometer.py:176``).
8+
The registry is dataset-agnostic; a future ``AnalysisInterferometer``
9+
wiring can reuse it without code duplication.
10+
11+
User-level enable/disable: each key in ``autolens/config/latent.yaml`` maps
12+
to a bool. All five default ``false`` because ``compute_latent_samples``
13+
runs on every fit (``latent_after_fit: true`` in autofit's default
14+
``output.yaml``) and the latents that require ``magzero`` would otherwise
15+
crash existing fits where ``magzero`` is not passed.
16+
"""
17+
import logging
18+
from typing import Callable, Dict, List, Optional
19+
20+
import numpy as np
21+
22+
from autoconf import conf
23+
from autogalaxy.imaging.model.latent import (
24+
ab_mag_via_flux_from,
25+
flux_mujy_via_ab_mag_from,
26+
)
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
def _require_magzero(magzero, name):
32+
if magzero is None:
33+
raise ValueError(
34+
f"magzero must be passed to the Analysis via kwargs to compute "
35+
f"the '{name}' latent. Disable it in config/latent.yaml or "
36+
f"pass magzero=<value>."
37+
)
38+
39+
40+
def total_lens_flux_mujy(fit, magzero, xp=np):
41+
"""
42+
Total integrated flux of the lens galaxy (``fit.tracer.galaxies[0]``),
43+
magzero-converted to microjanskies.
44+
45+
Returns NaN when galaxy 0 has no light profile (raises ``KeyError`` /
46+
``AttributeError`` inside ``fit.galaxy_image_dict``).
47+
"""
48+
_require_magzero(magzero, "total_lens_flux_mujy")
49+
try:
50+
image = fit.galaxy_image_dict[fit.tracer.galaxies[0]]
51+
except (AttributeError, KeyError, IndexError):
52+
return xp.nan
53+
total_flux = xp.sum(image.array)
54+
return flux_mujy_via_ab_mag_from(
55+
ab_mag=ab_mag_via_flux_from(flux=total_flux, magzero=magzero, xp=xp),
56+
xp=xp,
57+
)
58+
59+
60+
def total_lensed_source_flux_mujy(fit, magzero, xp=np):
61+
"""
62+
Image-plane integrated flux of the source galaxy after lensing
63+
(``fit.galaxy_image_dict[fit.tracer.galaxies[-1]]``).
64+
"""
65+
_require_magzero(magzero, "total_lensed_source_flux_mujy")
66+
try:
67+
image = fit.galaxy_image_dict[fit.tracer.galaxies[-1]]
68+
except (AttributeError, KeyError, IndexError):
69+
return xp.nan
70+
total_flux = xp.sum(image.array)
71+
return flux_mujy_via_ab_mag_from(
72+
ab_mag=ab_mag_via_flux_from(flux=total_flux, magzero=magzero, xp=xp),
73+
xp=xp,
74+
)
75+
76+
77+
def total_source_flux_mujy(fit, magzero, xp=np):
78+
"""
79+
Source-plane intrinsic flux of the source galaxy, via
80+
``fit.tracer.galaxies[-1].image_2d_from(grid=fit.dataset.grids.lp)``.
81+
"""
82+
_require_magzero(magzero, "total_source_flux_mujy")
83+
try:
84+
source_image = fit.tracer.galaxies[-1].image_2d_from(
85+
grid=fit.dataset.grids.lp, xp=xp
86+
)
87+
except (AttributeError, IndexError):
88+
return xp.nan
89+
total_flux = xp.sum(source_image.array)
90+
return flux_mujy_via_ab_mag_from(
91+
ab_mag=ab_mag_via_flux_from(flux=total_flux, magzero=magzero, xp=xp),
92+
xp=xp,
93+
)
94+
95+
96+
def magnification(fit, magzero, xp=np):
97+
"""
98+
Ratio of image-plane to source-plane source flux — the integrated
99+
magnification implied by the lens model and source light profile.
100+
101+
``magzero`` is accepted but unused (the µJy conversions cancel in the
102+
ratio). It's still required in the signature so the dispatcher can
103+
pass a uniform context dict to every latent function.
104+
"""
105+
lensed = total_lensed_source_flux_mujy(fit=fit, magzero=magzero, xp=xp)
106+
intrinsic = total_source_flux_mujy(fit=fit, magzero=magzero, xp=xp)
107+
return lensed / intrinsic
108+
109+
110+
def effective_einstein_radius(fit, magzero, xp=np):
111+
"""
112+
Effective Einstein radius via the tangential critical curve.
113+
114+
JAX path: ``LensCalc.einstein_radius_jit_from(init_guess=fan)``, where
115+
``fan`` is a fixed 4-seed fan at ±1 arcsec from the lens centre — the
116+
JIT-compatible variant required because ``ZeroSolver`` (line 1520 of
117+
``autogalaxy/operate/lens_calc.py``) uses ``lax.cond`` /
118+
``lax.while_loop`` early termination that is incompatible with
119+
``jax.vmap`` but fine under ``jax.jit``.
120+
121+
NumPy path: ``LensCalc.einstein_radius_from(grid=fit.dataset.grids.lp)``.
122+
"""
123+
from autogalaxy.operate.lens_calc import LensCalc
124+
125+
try:
126+
lens_calc = LensCalc.from_mass_obj(fit.tracer)
127+
if xp is not np:
128+
import jax.numpy as jnp
129+
init_guess = jnp.array(
130+
[[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
131+
)
132+
return lens_calc.einstein_radius_jit_from(init_guess=init_guess)
133+
return lens_calc.einstein_radius_from(grid=fit.dataset.grids.lp)
134+
except (ValueError, AttributeError):
135+
return xp.nan
136+
137+
138+
LATENT_FUNCTIONS: Dict[str, Callable] = {
139+
"total_lens_flux_mujy": total_lens_flux_mujy,
140+
"total_lensed_source_flux_mujy": total_lensed_source_flux_mujy,
141+
"total_source_flux_mujy": total_source_flux_mujy,
142+
"magnification": magnification,
143+
"effective_einstein_radius": effective_einstein_radius,
144+
}
145+
146+
147+
def latent_keys_enabled(yaml_config: Optional[Dict[str, bool]] = None) -> List[str]:
148+
"""
149+
Return the ordered list of enabled latent keys.
150+
151+
Reads ``conf.instance["latent"]`` (a flat ``key: bool`` dict from
152+
``autolens/config/latent.yaml``) unless ``yaml_config`` is passed
153+
explicitly — tests pass a literal dict to avoid pushing a temporary
154+
config directory.
155+
156+
Unknown keys (present in the yaml but not in :data:`LATENT_FUNCTIONS`)
157+
are dropped with a logger warning rather than raising — yaml carries
158+
forward-compat entries for latents that ship in later releases.
159+
"""
160+
if yaml_config is None:
161+
yaml_config = dict(conf.instance["latent"])
162+
163+
enabled: List[str] = []
164+
for key, on in yaml_config.items():
165+
if not on:
166+
continue
167+
if key not in LATENT_FUNCTIONS:
168+
logger.warning(
169+
"latent.yaml lists '%s' but no such latent is registered; "
170+
"dropping. Known latents: %s",
171+
key,
172+
sorted(LATENT_FUNCTIONS),
173+
)
174+
continue
175+
enabled.append(key)
176+
return enabled

autolens/config/latent.yaml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Toggles for the catalogue of latent variables computed by `AnalysisImaging`
2+
# (and, when wired in a follow-up, `AnalysisInterferometer`).
3+
#
4+
# Each entry maps a registered latent name (see
5+
# `autolens/analysis/latent.py::LATENT_FUNCTIONS`) to a bool. Setting `false`
6+
# excludes that latent from `LATENT_KEYS` so it is neither computed nor
7+
# written to `latent/samples.csv` / `latent/latent_summary.json`.
8+
#
9+
# Workspaces should mirror this file in their own `config/latent.yaml` to
10+
# override defaults locally (workspace values shadow library values).
11+
#
12+
# All keys ship `false` because:
13+
# - `compute_latent_samples` runs on every fit (`latent_after_fit: true`
14+
# in autofit's default output.yaml), so on-by-default would crash any
15+
# existing fit that doesn't pass `magzero` to the Analysis.
16+
# - autoconf lowercases yaml keys at read time, so the registry/yaml
17+
# names must be snake_case-lowercase (this leaks into the `latent.csv`
18+
# column header — e.g. `total_lens_flux_mujy`, not `_muJy`).
19+
20+
# `total_lens_flux_mujy` — total integrated flux of the lens galaxy
21+
# (`fit.tracer.galaxies[0]`) in microjanskies. Requires `magzero` via
22+
# Analysis kwargs. Returns NaN when the lens has no light profile.
23+
total_lens_flux_mujy: false
24+
25+
# `total_lensed_source_flux_mujy` — image-plane integrated flux of the
26+
# source galaxy after lensing (`fit.galaxy_image_dict[tracer.galaxies[-1]]`)
27+
# in microjanskies. Requires `magzero`.
28+
total_lensed_source_flux_mujy: false
29+
30+
# `total_source_flux_mujy` — source-plane intrinsic flux of the source
31+
# galaxy (computed via the source's light profile on `fit.dataset.grids.lp`)
32+
# in microjanskies. Requires `magzero`.
33+
total_source_flux_mujy: false
34+
35+
# `magnification` — ratio of image-plane lensed source flux to source-plane
36+
# intrinsic source flux. Dimensionless; `magzero` is accepted but unused.
37+
magnification: false
38+
39+
# `effective_einstein_radius` — effective Einstein radius in arcseconds,
40+
# from the tangential critical curve via
41+
# `LensCalc.einstein_radius_jit_from` (JAX) or `einstein_radius_from`
42+
# (numpy). Does NOT require `magzero`.
43+
effective_einstein_radius: false

autolens/imaging/model/analysis.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,38 @@ class AnalysisImaging(AnalysisDataset):
3232
Result = ResultImaging
3333
Visualizer = VisualizerImaging
3434

35+
@property
36+
def LATENT_KEYS(self):
37+
from autolens.analysis.latent import latent_keys_enabled
38+
return latent_keys_enabled()
39+
40+
def compute_latent_variables(self, parameters, model):
41+
"""
42+
Compute the catalogue of lensing latent variables enabled in
43+
``config/latent.yaml`` for the given parameter vector.
44+
45+
Returns a tuple positionally aligned with :attr:`LATENT_KEYS` —
46+
PyAutoFit zips it with the keys at
47+
``autofit/non_linear/analysis/analysis.py:285`` and stacks per
48+
sample for the JIT batch path at lines 223-234.
49+
50+
Raises ``NotImplementedError`` when no latents are enabled so
51+
PyAutoFit's outer ``except NotImplementedError`` short-circuits
52+
the latent pipeline cleanly (no empty ``latent.csv`` written).
53+
"""
54+
from autolens.analysis.latent import LATENT_FUNCTIONS
55+
56+
keys = self.LATENT_KEYS
57+
if not keys:
58+
raise NotImplementedError
59+
60+
xp = self._xp
61+
instance = model.instance_from_vector(vector=parameters)
62+
fit = self.fit_from(instance=instance)
63+
magzero = self.kwargs.get("magzero", None)
64+
context = {"fit": fit, "magzero": magzero, "xp": xp}
65+
return tuple(LATENT_FUNCTIONS[k](**context) for k in keys)
66+
3567
def log_likelihood_function(self, instance: af.ModelInstance) -> float:
3668
"""
3769
Given an instance of the model, where the model parameters are set via a non-linear search, fit the model

0 commit comments

Comments
 (0)