Skip to content

Commit 349bf8f

Browse files
authored
Merge pull request #472 from PyAutoLabs/feature/latent-class-phase2
refactor(latent): LatentGalaxy class + declare Analysis.Latent (Phase 2)
2 parents 1cc6eb2 + b5101b9 commit 349bf8f

4 files changed

Lines changed: 56 additions & 48 deletions

File tree

autogalaxy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@
7474
from .gui.scribbler import Scribbler
7575
from .imaging.fit_imaging import FitImaging
7676
from .imaging.model.analysis import AnalysisImaging
77+
from autofit import Latent
78+
from .imaging.model.latent import LatentGalaxy
7779
from .imaging.simulator import SimulatorImaging
7880
from .interferometer.simulator import SimulatorInterferometer
7981
from .interferometer.fit_interferometer import FitInterferometer

autogalaxy/imaging/model/analysis.py

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from autogalaxy.analysis.adapt_images.adapt_images import AdaptImages
2323
from autogalaxy.analysis.analysis.dataset import AnalysisDataset
2424
from autogalaxy.cosmology.model import LensingCosmology
25+
from autogalaxy.imaging.model.latent import LatentGalaxy
2526
from autogalaxy.imaging.model.result import ResultImaging
2627
from autogalaxy.imaging.model.visualizer import VisualizerImaging
2728
from autogalaxy.imaging.fit_imaging import FitImaging
@@ -33,6 +34,7 @@
3334
class AnalysisImaging(AnalysisDataset):
3435
Result = ResultImaging
3536
Visualizer = VisualizerImaging
37+
Latent = LatentGalaxy
3638

3739
def __init__(
3840
self,
@@ -88,11 +90,6 @@ def __init__(
8890
def imaging(self):
8991
return self.dataset
9092

91-
@property
92-
def LATENT_KEYS(self):
93-
from autogalaxy.imaging.model.latent import latent_keys_enabled
94-
return latent_keys_enabled()
95-
9693
def log_likelihood_function(self, instance: af.ModelInstance) -> float:
9794
"""
9895
Given an instance of the model, where the model parameters are set via a non-linear search, fit the model
@@ -176,33 +173,6 @@ def fit_from(self, instance: af.ModelInstance) -> FitImaging:
176173
xp=self._xp,
177174
)
178175

179-
def compute_latent_variables(self, parameters, model):
180-
"""
181-
Compute the catalogue of latent variables enabled in
182-
``config/latent.yaml`` for the given parameter vector.
183-
184-
Returns a tuple positionally aligned with :attr:`LATENT_KEYS` —
185-
PyAutoFit zips it with the keys at
186-
``autofit/non_linear/analysis/analysis.py:285`` and stacks per
187-
sample for the JIT batch path at lines 223-234.
188-
189-
Raises ``NotImplementedError`` when no latents are enabled so
190-
PyAutoFit's outer ``except NotImplementedError`` short-circuits
191-
the latent pipeline cleanly (no empty ``latent.csv`` written).
192-
"""
193-
from autogalaxy.imaging.model.latent import LATENT_FUNCTIONS
194-
195-
keys = self.LATENT_KEYS
196-
if not keys:
197-
raise NotImplementedError
198-
199-
xp = self._xp
200-
instance = model.instance_from_vector(vector=parameters)
201-
fit = self.fit_from(instance=instance)
202-
magzero = self.kwargs.get("magzero", None)
203-
context = {"fit": fit, "magzero": magzero, "xp": xp}
204-
return tuple(LATENT_FUNCTIONS[k](**context) for k in keys)
205-
206176
@staticmethod
207177
def _register_fit_imaging_pytrees() -> None:
208178
"""Register every type reachable from a ``FitImaging`` return value

autogalaxy/imaging/model/latent.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import numpy as np
1818

19+
import autofit as af
1920
from autoconf import conf
2021

2122
logger = logging.getLogger(__name__)
@@ -160,3 +161,38 @@ def latent_keys_enabled(yaml_config: Optional[Dict[str, bool]] = None) -> List[s
160161
continue
161162
enabled.append(key)
162163
return enabled
164+
165+
166+
class LatentGalaxy(af.Latent):
167+
"""
168+
Latent-variable catalogue for galaxy imaging fits, declared on
169+
``AnalysisImaging`` as ``Latent = LatentGalaxy`` (mirrors ``Visualizer`` /
170+
``Result``).
171+
172+
:meth:`keys` returns the config-enabled latent names; :meth:`variables`
173+
dispatches the :data:`LATENT_FUNCTIONS` registry on a per-sample fit.
174+
Subclass to add project-specific latents, composing the library values via
175+
``LatentGalaxy.variables(analysis, parameters, model)``.
176+
"""
177+
178+
# Preserves the previous ``AnalysisDataset.LATENT_BATCH_MODE = "jit"``: the
179+
# per-sample jit path is used because the lensing Einstein-radius latent
180+
# (shared dataset base) routes through ``ZeroSolver``, which is
181+
# vmap-incompatible.
182+
BATCH_MODE = "jit"
183+
184+
@staticmethod
185+
def keys(analysis) -> List[str]:
186+
return latent_keys_enabled()
187+
188+
@staticmethod
189+
def variables(analysis, parameters, model):
190+
keys = latent_keys_enabled()
191+
if not keys:
192+
raise NotImplementedError
193+
xp = analysis._xp
194+
instance = model.instance_from_vector(vector=parameters)
195+
fit = analysis.fit_from(instance=instance)
196+
magzero = analysis.kwargs.get("magzero", None)
197+
context = {"fit": fit, "magzero": magzero, "xp": xp}
198+
return tuple(LATENT_FUNCTIONS[k](**context) for k in keys)

test_autogalaxy/imaging/model/test_latent.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from autogalaxy.imaging.model import latent as _latent_module
1111
from autogalaxy.imaging.model.latent import (
1212
LATENT_FUNCTIONS,
13+
LatentGalaxy,
1314
ab_mag_via_flux_from,
1415
flux_mujy_via_ab_mag_from,
1516
latent_keys_enabled,
@@ -139,7 +140,7 @@ def test_latent_keys_enabled_drops_unknown_with_warning(caplog):
139140
assert any("never_registered_latent" in rec.message for rec in caplog.records)
140141

141142

142-
def test_analysis_imaging_compute_latent_variables_aligns_with_latent_keys(
143+
def test_latent_galaxy_variables_aligns_with_keys(
143144
masked_imaging_7x7,
144145
):
145146
galaxy = ag.Galaxy(redshift=0.5, light=ag.lp.Sersic(intensity=0.1))
@@ -150,35 +151,34 @@ def test_analysis_imaging_compute_latent_variables_aligns_with_latent_keys(
150151
)
151152

152153
parameters = np.array(model.physical_values_from_prior_medians)
153-
values = analysis.compute_latent_variables(parameters=parameters, model=model)
154+
values = LatentGalaxy.variables(analysis, parameters=parameters, model=model)
155+
keys = LatentGalaxy.keys(analysis)
154156

155157
assert isinstance(values, tuple)
156-
assert len(values) == len(analysis.LATENT_KEYS)
158+
assert len(values) == len(keys)
157159
# test_autogalaxy/config/latent.yaml enables both keys, raw flux first.
158-
assert analysis.LATENT_KEYS == ["total_galaxy_0_flux", "total_galaxy_0_flux_mujy"]
160+
assert keys == ["total_galaxy_0_flux", "total_galaxy_0_flux_mujy"]
159161
assert all(np.isfinite(v) for v in values)
160162

161163

162-
def test_analysis_imaging_compute_latent_variables_raises_when_empty(monkeypatch):
164+
def test_latent_galaxy_variables_raises_when_empty(monkeypatch):
163165
# When no latents are enabled, autofit's `except NotImplementedError`
164-
# at autofit/non_linear/analysis/analysis.py:304 short-circuits the
165-
# latent pipeline. We match that contract by raising explicitly.
166-
monkeypatch.setattr(
167-
ag.AnalysisImaging,
168-
"LATENT_KEYS",
169-
property(lambda self: []),
170-
)
166+
# short-circuits the latent pipeline. LatentGalaxy.variables matches that
167+
# contract by raising explicitly.
168+
monkeypatch.setattr(_latent_module, "latent_keys_enabled", lambda *a, **k: [])
171169
analysis = ag.AnalysisImaging(dataset=MagicMock(), use_jax=False)
172170

173171
with pytest.raises(NotImplementedError):
174-
analysis.compute_latent_variables(parameters=np.array([]), model=MagicMock())
172+
LatentGalaxy.variables(analysis, parameters=np.array([]), model=MagicMock())
175173

176174

177-
def test_analysis_imaging_latent_keys_property_reads_config():
175+
def test_analysis_imaging_declares_latent_galaxy_and_keys_read_config():
178176
# The autouse fixture in test_autogalaxy/conftest.py pushes the test
179177
# config dir whose latent.yaml enables both keys.
180178
dataset = MagicMock()
181179
analysis = ag.AnalysisImaging(dataset=dataset, use_jax=False)
182180

183-
assert analysis.LATENT_KEYS == ["total_galaxy_0_flux", "total_galaxy_0_flux_mujy"]
184-
assert set(analysis.LATENT_KEYS).issubset(LATENT_FUNCTIONS.keys())
181+
assert ag.AnalysisImaging.Latent is LatentGalaxy
182+
keys = ag.AnalysisImaging.Latent.keys(analysis)
183+
assert keys == ["total_galaxy_0_flux", "total_galaxy_0_flux_mujy"]
184+
assert set(keys).issubset(LATENT_FUNCTIONS.keys())

0 commit comments

Comments
 (0)