Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autofit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .database.aggregator.aggregator import GridSearchAggregator
from .graphical.expectation_propagation.history import EPHistory
from .graphical.declarative.factor.analysis import AnalysisFactor
from .graphical.declarative.factor.analysis import EPAnalysisFactor
from .graphical.declarative.collection import FactorGraphModel
from .graphical.declarative.factor.hierarchical import HierarchicalFactor
from .graphical.laplace import LaplaceOptimiser
Expand Down
2 changes: 1 addition & 1 deletion autofit/graphical/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from . import utils
from .declarative.abstract import PriorFactor
from .declarative.collection import FactorGraphModel
from .declarative.factor.analysis import AnalysisFactor
from .declarative.factor.analysis import AnalysisFactor, EPAnalysisFactor
from .declarative.factor.hierarchical import _HierarchicalFactor, HierarchicalFactor
from .expectation_propagation.ep_mean_field import EPMeanField
from .expectation_propagation.optimiser import EPOptimiser
Expand Down
49 changes: 49 additions & 0 deletions autofit/graphical/declarative/factor/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,52 @@ def save_results(self, paths: AbstractPaths, result):

def log_likelihood_function(self, instance: ModelInstance) -> float:
return self.analysis.log_likelihood_function(instance)


class EPAnalysisFactor(AnalysisFactor):
"""
An ``AnalysisFactor`` that exposes the EP cavity distribution to its
``Analysis`` on each likelihood evaluation.

On every iteration of the EP optimiser, the cavity distribution
``q⁻ᵃ`` — the product of the posterior approximations from all
*other* factors over the variables this factor shares with them —
is computed in
:class:`autofit.graphical.mean_field.FactorApproximation`. For most
factors that distribution is consumed implicitly: it becomes the
prior the search samples from.

Some hierarchical / population-level analyses want to read those
cavity messages directly. A canonical example is a "global"
Analysis whose log-likelihood compares model predictions to the
per-dataset Gaussian posterior summaries produced by upstream
local fits, e.g.::

log L = -0.5 * sum_i || (pred_i - cavity_mean_i) / cavity_sigma_i ||^2

To support that, ``EPAnalysisFactor`` attaches the current cavity
distribution to its ``Analysis`` immediately before optimisation,
via the attribute ``_cavity_mean_field``. The user's
``log_likelihood_function`` can then read each shared variable's
cavity message (``.mean`` and ``.scale`` on the
``AbstractMessage`` value) out of the dict.

The hook is invoked from
:func:`autofit.graphical.expectation_propagation.optimiser.factor_step`
via duck-typing (``hasattr(factor, "set_cavity_dist")``), so the
behaviour of plain ``AnalysisFactor`` is unaffected.
"""

def set_cavity_dist(self, cavity_dist):
"""
Store the cavity distribution on the wrapped ``Analysis``.

Called by :func:`factor_step` once per EP iteration, before this
factor's local search runs. The Analysis can read the messages
inside ``log_likelihood_function`` by inspecting
``self._cavity_mean_field`` — a ``MeanField`` mapping each
shared :class:`Variable` (i.e. ``Prior``) to an
``AbstractMessage`` whose ``.mean`` and ``.scale`` give the
cavity Gaussian summary.
"""
self.analysis._cavity_mean_field = cavity_dist
7 changes: 7 additions & 0 deletions autofit/graphical/expectation_propagation/optimiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ def factor_step(factor_approx, optimiser):
factor = factor_approx.factor
factor_logger = logging.getLogger(factor.name)
factor_logger.debug("Optimising...")
# Cavity-message opt-in: factors that implement ``set_cavity_dist``
# (e.g. ``EPAnalysisFactor``) receive the current cavity distribution
# before optimisation so their Analysis can read per-variable cavity
# messages inside ``log_likelihood_function``. Default factors lack
# the method, so this is a no-op for them.
if hasattr(factor, "set_cavity_dist"):
factor.set_cavity_dist(factor_approx.cavity_dist)
try:
with LogWarnings(
logger=factor_logger.debug, action="always"
Expand Down
99 changes: 99 additions & 0 deletions test_autofit/graphical/test_ep_analysis_factor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""
Tests for EPAnalysisFactor cavity-message injection.

EPAnalysisFactor is a thin AnalysisFactor subclass that exposes the EP
cavity distribution to its Analysis on each fit, via the hook in
``autofit.graphical.expectation_propagation.optimiser.factor_step``.
The Analysis can then read per-variable cavity messages inside
``log_likelihood_function`` — the canonical use case is a "global"
Analysis that compares predictions to per-dataset posterior summaries
produced by upstream local fits.
"""
from unittest.mock import MagicMock

import autofit as af
from autofit.graphical.expectation_propagation.optimiser import factor_step
from autofit.graphical.utils import Status, StatusFlag


class _RecordingAnalysis(af.Analysis):
"""Records every log_likelihood_function call's cavity state."""

def __init__(self):
super().__init__()
self.observed_cavities = []

def log_likelihood_function(self, instance):
self.observed_cavities.append(getattr(self, "_cavity_mean_field", None))
return 0.0


def test_set_cavity_dist_attaches_to_analysis():
"""``set_cavity_dist`` should populate ``analysis._cavity_mean_field``."""
model = af.Model(af.ex.Gaussian)
analysis = _RecordingAnalysis()
factor = af.EPAnalysisFactor(prior_model=model, analysis=analysis)

sentinel = object()
factor.set_cavity_dist(sentinel)

assert analysis._cavity_mean_field is sentinel


def test_plain_analysis_factor_has_no_set_cavity_dist():
"""Plain AnalysisFactor must remain untouched by the hook."""
model = af.Model(af.ex.Gaussian)
analysis = _RecordingAnalysis()
factor = af.AnalysisFactor(prior_model=model, analysis=analysis)

assert not hasattr(factor, "set_cavity_dist")


def test_factor_step_invokes_set_cavity_dist():
"""
``factor_step`` should call ``set_cavity_dist`` before optimisation
so the Analysis sees the cavity during every likelihood evaluation.
"""
model = af.Model(af.ex.Gaussian)
analysis = _RecordingAnalysis()
factor = af.EPAnalysisFactor(prior_model=model, analysis=analysis)

cavity_sentinel = object()

factor_approx = MagicMock()
factor_approx.factor = factor
factor_approx.cavity_dist = cavity_sentinel
factor_approx.model_dist = MagicMock()

optimiser = MagicMock()
optimiser.optimise.return_value = (
MagicMock(),
Status(success=True, messages=(), flag=StatusFlag.SUCCESS),
)

factor_step(factor_approx, optimiser)

assert analysis._cavity_mean_field is cavity_sentinel
optimiser.optimise.assert_called_once_with(factor_approx)


def test_factor_step_no_op_for_plain_analysis_factor():
"""No exception should be raised for plain ``AnalysisFactor``."""
model = af.Model(af.ex.Gaussian)
analysis = _RecordingAnalysis()
factor = af.AnalysisFactor(prior_model=model, analysis=analysis)

factor_approx = MagicMock()
factor_approx.factor = factor
factor_approx.cavity_dist = object()
factor_approx.model_dist = MagicMock()

optimiser = MagicMock()
optimiser.optimise.return_value = (
MagicMock(),
Status(success=True, messages=(), flag=StatusFlag.SUCCESS),
)

factor_step(factor_approx, optimiser)

assert not hasattr(analysis, "_cavity_mean_field")
Loading