diff --git a/autofit/__init__.py b/autofit/__init__.py index b76f21199..6f42e0789 100644 --- a/autofit/__init__.py +++ b/autofit/__init__.py @@ -15,6 +15,7 @@ from .database.aggregator.aggregator import GridSearchAggregator from .graphical.expectation_propagation.history import EPHistory from .graphical.declarative.factor.analysis import AnalysisFactor +from .graphical.declarative.factor.analysis import EPAnalysisFactor from .graphical.declarative.collection import FactorGraphModel from .graphical.declarative.factor.hierarchical import HierarchicalFactor from .graphical.laplace import LaplaceOptimiser diff --git a/autofit/graphical/__init__.py b/autofit/graphical/__init__.py index 99eb3c7d7..a3e6ae235 100644 --- a/autofit/graphical/__init__.py +++ b/autofit/graphical/__init__.py @@ -1,7 +1,7 @@ from . import utils from .declarative.abstract import PriorFactor from .declarative.collection import FactorGraphModel -from .declarative.factor.analysis import AnalysisFactor +from .declarative.factor.analysis import AnalysisFactor, EPAnalysisFactor from .declarative.factor.hierarchical import _HierarchicalFactor, HierarchicalFactor from .expectation_propagation.ep_mean_field import EPMeanField from .expectation_propagation.optimiser import EPOptimiser diff --git a/autofit/graphical/declarative/factor/analysis.py b/autofit/graphical/declarative/factor/analysis.py index 5ea5ab50a..ed9cdf16d 100644 --- a/autofit/graphical/declarative/factor/analysis.py +++ b/autofit/graphical/declarative/factor/analysis.py @@ -187,3 +187,52 @@ def save_results(self, paths: AbstractPaths, result): def log_likelihood_function(self, instance: ModelInstance) -> float: return self.analysis.log_likelihood_function(instance) + + +class EPAnalysisFactor(AnalysisFactor): + """ + An ``AnalysisFactor`` that exposes the EP cavity distribution to its + ``Analysis`` on each likelihood evaluation. + + On every iteration of the EP optimiser, the cavity distribution + ``q⁻ᵃ`` — the product of the posterior approximations from all + *other* factors over the variables this factor shares with them — + is computed in + :class:`autofit.graphical.mean_field.FactorApproximation`. For most + factors that distribution is consumed implicitly: it becomes the + prior the search samples from. + + Some hierarchical / population-level analyses want to read those + cavity messages directly. A canonical example is a "global" + Analysis whose log-likelihood compares model predictions to the + per-dataset Gaussian posterior summaries produced by upstream + local fits, e.g.:: + + log L = -0.5 * sum_i || (pred_i - cavity_mean_i) / cavity_sigma_i ||^2 + + To support that, ``EPAnalysisFactor`` attaches the current cavity + distribution to its ``Analysis`` immediately before optimisation, + via the attribute ``_cavity_mean_field``. The user's + ``log_likelihood_function`` can then read each shared variable's + cavity message (``.mean`` and ``.scale`` on the + ``AbstractMessage`` value) out of the dict. + + The hook is invoked from + :func:`autofit.graphical.expectation_propagation.optimiser.factor_step` + via duck-typing (``hasattr(factor, "set_cavity_dist")``), so the + behaviour of plain ``AnalysisFactor`` is unaffected. + """ + + def set_cavity_dist(self, cavity_dist): + """ + Store the cavity distribution on the wrapped ``Analysis``. + + Called by :func:`factor_step` once per EP iteration, before this + factor's local search runs. The Analysis can read the messages + inside ``log_likelihood_function`` by inspecting + ``self._cavity_mean_field`` — a ``MeanField`` mapping each + shared :class:`Variable` (i.e. ``Prior``) to an + ``AbstractMessage`` whose ``.mean`` and ``.scale`` give the + cavity Gaussian summary. + """ + self.analysis._cavity_mean_field = cavity_dist diff --git a/autofit/graphical/expectation_propagation/optimiser.py b/autofit/graphical/expectation_propagation/optimiser.py index b3c090f5e..2a000bec3 100644 --- a/autofit/graphical/expectation_propagation/optimiser.py +++ b/autofit/graphical/expectation_propagation/optimiser.py @@ -108,6 +108,13 @@ def factor_step(factor_approx, optimiser): factor = factor_approx.factor factor_logger = logging.getLogger(factor.name) factor_logger.debug("Optimising...") + # Cavity-message opt-in: factors that implement ``set_cavity_dist`` + # (e.g. ``EPAnalysisFactor``) receive the current cavity distribution + # before optimisation so their Analysis can read per-variable cavity + # messages inside ``log_likelihood_function``. Default factors lack + # the method, so this is a no-op for them. + if hasattr(factor, "set_cavity_dist"): + factor.set_cavity_dist(factor_approx.cavity_dist) try: with LogWarnings( logger=factor_logger.debug, action="always" diff --git a/test_autofit/graphical/test_ep_analysis_factor.py b/test_autofit/graphical/test_ep_analysis_factor.py new file mode 100644 index 000000000..f05c0856c --- /dev/null +++ b/test_autofit/graphical/test_ep_analysis_factor.py @@ -0,0 +1,99 @@ +""" +Tests for EPAnalysisFactor cavity-message injection. + +EPAnalysisFactor is a thin AnalysisFactor subclass that exposes the EP +cavity distribution to its Analysis on each fit, via the hook in +``autofit.graphical.expectation_propagation.optimiser.factor_step``. +The Analysis can then read per-variable cavity messages inside +``log_likelihood_function`` — the canonical use case is a "global" +Analysis that compares predictions to per-dataset posterior summaries +produced by upstream local fits. +""" +from unittest.mock import MagicMock + +import autofit as af +from autofit.graphical.expectation_propagation.optimiser import factor_step +from autofit.graphical.utils import Status, StatusFlag + + +class _RecordingAnalysis(af.Analysis): + """Records every log_likelihood_function call's cavity state.""" + + def __init__(self): + super().__init__() + self.observed_cavities = [] + + def log_likelihood_function(self, instance): + self.observed_cavities.append(getattr(self, "_cavity_mean_field", None)) + return 0.0 + + +def test_set_cavity_dist_attaches_to_analysis(): + """``set_cavity_dist`` should populate ``analysis._cavity_mean_field``.""" + model = af.Model(af.ex.Gaussian) + analysis = _RecordingAnalysis() + factor = af.EPAnalysisFactor(prior_model=model, analysis=analysis) + + sentinel = object() + factor.set_cavity_dist(sentinel) + + assert analysis._cavity_mean_field is sentinel + + +def test_plain_analysis_factor_has_no_set_cavity_dist(): + """Plain AnalysisFactor must remain untouched by the hook.""" + model = af.Model(af.ex.Gaussian) + analysis = _RecordingAnalysis() + factor = af.AnalysisFactor(prior_model=model, analysis=analysis) + + assert not hasattr(factor, "set_cavity_dist") + + +def test_factor_step_invokes_set_cavity_dist(): + """ + ``factor_step`` should call ``set_cavity_dist`` before optimisation + so the Analysis sees the cavity during every likelihood evaluation. + """ + model = af.Model(af.ex.Gaussian) + analysis = _RecordingAnalysis() + factor = af.EPAnalysisFactor(prior_model=model, analysis=analysis) + + cavity_sentinel = object() + + factor_approx = MagicMock() + factor_approx.factor = factor + factor_approx.cavity_dist = cavity_sentinel + factor_approx.model_dist = MagicMock() + + optimiser = MagicMock() + optimiser.optimise.return_value = ( + MagicMock(), + Status(success=True, messages=(), flag=StatusFlag.SUCCESS), + ) + + factor_step(factor_approx, optimiser) + + assert analysis._cavity_mean_field is cavity_sentinel + optimiser.optimise.assert_called_once_with(factor_approx) + + +def test_factor_step_no_op_for_plain_analysis_factor(): + """No exception should be raised for plain ``AnalysisFactor``.""" + model = af.Model(af.ex.Gaussian) + analysis = _RecordingAnalysis() + factor = af.AnalysisFactor(prior_model=model, analysis=analysis) + + factor_approx = MagicMock() + factor_approx.factor = factor + factor_approx.cavity_dist = object() + factor_approx.model_dist = MagicMock() + + optimiser = MagicMock() + optimiser.optimise.return_value = ( + MagicMock(), + Status(success=True, messages=(), flag=StatusFlag.SUCCESS), + ) + + factor_step(factor_approx, optimiser) + + assert not hasattr(analysis, "_cavity_mean_field")