Skip to content

Commit 4e0876b

Browse files
Jammy2211 and claude
authored and committed
Add EPAnalysisFactor for cavity-message injection
EPAnalysisFactor is a thin AnalysisFactor subclass that exposes the EP cavity distribution to its Analysis on each fit, via a guarded ``set_cavity_dist`` hook called from ``factor_step``.

This lets a hierarchical / population-level Analysis read per-variable cavity messages directly inside ``log_likelihood_function`` — useful for "global" likelihoods that compare predictions to the per-dataset posterior summaries produced by upstream local fits, e.g. the EP-message-comparison form documented in the cancer-IC50 use case:

    log L = -0.5 * sum_i || (pred_i - cavity_mean_i) / cavity_sigma_i ||^2

The hook is duck-typed, so plain AnalysisFactor behaviour is unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 8e24a11 commit 4e0876b

5 files changed

Lines changed: 157 additions & 1 deletion

File tree

autofit/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .database.aggregator.aggregator import GridSearchAggregator
1616
from .graphical.expectation_propagation.history import EPHistory
1717
from .graphical.declarative.factor.analysis import AnalysisFactor
18+
from .graphical.declarative.factor.analysis import EPAnalysisFactor
1819
from .graphical.declarative.collection import FactorGraphModel
1920
from .graphical.declarative.factor.hierarchical import HierarchicalFactor
2021
from .graphical.laplace import LaplaceOptimiser

autofit/graphical/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from . import utils
22
from .declarative.abstract import PriorFactor
33
from .declarative.collection import FactorGraphModel
4-
from .declarative.factor.analysis import AnalysisFactor
4+
from .declarative.factor.analysis import AnalysisFactor, EPAnalysisFactor
55
from .declarative.factor.hierarchical import _HierarchicalFactor, HierarchicalFactor
66
from .expectation_propagation.ep_mean_field import EPMeanField
77
from .expectation_propagation.optimiser import EPOptimiser

autofit/graphical/declarative/factor/analysis.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,52 @@ def save_results(self, paths: AbstractPaths, result):
187187

188188
def log_likelihood_function(self, instance: ModelInstance) -> float:
189189
return self.analysis.log_likelihood_function(instance)
190+
191+
192+
class EPAnalysisFactor(AnalysisFactor):
    """
    ``AnalysisFactor`` variant that hands the EP cavity distribution to
    the ``Analysis`` it wraps.

    At each iteration of the EP optimiser a cavity distribution ``q⁻ᵃ``
    is built for the factor being updated: the product of the posterior
    approximations of all *other* factors over the variables they share
    with this one (see
    :class:`autofit.graphical.mean_field.FactorApproximation`). Ordinary
    factors only use it implicitly, as the prior the search samples from.

    Hierarchical / population-level analyses sometimes need the cavity
    messages themselves. The motivating case is a "global" Analysis
    whose log-likelihood scores model predictions against the
    per-dataset Gaussian posterior summaries produced by upstream local
    fits, e.g.::

        log L = -0.5 * sum_i || (pred_i - cavity_mean_i) / cavity_sigma_i ||^2

    For that purpose this factor stores the current cavity distribution
    on its ``Analysis`` as ``_cavity_mean_field`` just before the
    factor is optimised. ``log_likelihood_function`` can then look up
    the cavity message of each shared variable in that mapping and read
    its ``.mean`` and ``.scale``.

    The hook is discovered by
    :func:`autofit.graphical.expectation_propagation.optimiser.factor_step`
    through duck-typing (``hasattr(factor, "set_cavity_dist")``), which
    leaves plain ``AnalysisFactor`` instances untouched.
    """

    def set_cavity_dist(self, cavity_dist):
        """
        Attach ``cavity_dist`` to the wrapped ``Analysis``.

        :func:`factor_step` calls this once per EP iteration, ahead of
        this factor's local search. The value is stored unchanged as
        ``analysis._cavity_mean_field``; it is expected to be a
        ``MeanField`` mapping each shared :class:`Variable` (i.e.
        ``Prior``) to an ``AbstractMessage`` whose ``.mean`` and
        ``.scale`` summarise the cavity Gaussian.
        """
        wrapped_analysis = self.analysis
        wrapped_analysis._cavity_mean_field = cavity_dist

autofit/graphical/expectation_propagation/optimiser.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,13 @@ def factor_step(factor_approx, optimiser):
108108
factor = factor_approx.factor
109109
factor_logger = logging.getLogger(factor.name)
110110
factor_logger.debug("Optimising...")
111+
# Cavity-message opt-in: factors that implement ``set_cavity_dist``
112+
# (e.g. ``EPAnalysisFactor``) receive the current cavity distribution
113+
# before optimisation so their Analysis can read per-variable cavity
114+
# messages inside ``log_likelihood_function``. Default factors lack
115+
# the method, so this is a no-op for them.
116+
if hasattr(factor, "set_cavity_dist"):
117+
factor.set_cavity_dist(factor_approx.cavity_dist)
111118
try:
112119
with LogWarnings(
113120
logger=factor_logger.debug, action="always"
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""
2+
Tests for EPAnalysisFactor cavity-message injection.
3+
4+
EPAnalysisFactor is a thin AnalysisFactor subclass that exposes the EP
5+
cavity distribution to its Analysis on each fit, via the hook in
6+
``autofit.graphical.expectation_propagation.optimiser.factor_step``.
7+
The Analysis can then read per-variable cavity messages inside
8+
``log_likelihood_function`` — the canonical use case is a "global"
9+
Analysis that compares predictions to per-dataset posterior summaries
10+
produced by upstream local fits.
11+
"""
12+
from unittest.mock import MagicMock
13+
14+
import autofit as af
15+
from autofit.graphical.expectation_propagation.optimiser import factor_step
16+
from autofit.graphical.utils import Status, StatusFlag
17+
18+
19+
class _RecordingAnalysis(af.Analysis):
    """Test double that logs the cavity visible at each likelihood call."""

    def __init__(self):
        super().__init__()
        # One entry per ``log_likelihood_function`` call; ``None`` is
        # recorded when no cavity has been attached to this analysis.
        self.observed_cavities = []

    def log_likelihood_function(self, instance):
        current_cavity = getattr(self, "_cavity_mean_field", None)
        self.observed_cavities.append(current_cavity)
        return 0.0
29+
30+
31+
def test_set_cavity_dist_attaches_to_analysis():
    """Calling ``set_cavity_dist`` stores its argument on the analysis."""
    recording_analysis = _RecordingAnalysis()
    ep_factor = af.EPAnalysisFactor(
        prior_model=af.Model(af.ex.Gaussian),
        analysis=recording_analysis,
    )
    # A unique object lets the assertion check identity, not equality.
    marker = object()

    ep_factor.set_cavity_dist(marker)

    assert recording_analysis._cavity_mean_field is marker
41+
42+
43+
def test_plain_analysis_factor_has_no_set_cavity_dist():
    """The base ``AnalysisFactor`` must not expose the cavity hook."""
    plain_factor = af.AnalysisFactor(
        prior_model=af.Model(af.ex.Gaussian),
        analysis=_RecordingAnalysis(),
    )

    assert hasattr(plain_factor, "set_cavity_dist") is False
50+
51+
52+
def test_factor_step_invokes_set_cavity_dist():
    """
    ``factor_step`` should call ``set_cavity_dist`` before optimisation
    so the Analysis sees the cavity during every likelihood evaluation.

    The ordering is checked explicitly: the mocked ``optimise`` records
    which cavity is attached to the analysis at the moment it is called,
    so the test fails if the hook only runs after optimisation.
    """
    model = af.Model(af.ex.Gaussian)
    analysis = _RecordingAnalysis()
    factor = af.EPAnalysisFactor(prior_model=model, analysis=analysis)

    cavity_sentinel = object()

    factor_approx = MagicMock()
    factor_approx.factor = factor
    factor_approx.cavity_dist = cavity_sentinel
    factor_approx.model_dist = MagicMock()

    # Cavity attached to the analysis each time ``optimise`` is entered.
    cavity_seen_by_optimise = []

    def _recording_optimise(*args, **kwargs):
        cavity_seen_by_optimise.append(
            getattr(analysis, "_cavity_mean_field", None)
        )
        return (
            MagicMock(),
            Status(success=True, messages=(), flag=StatusFlag.SUCCESS),
        )

    optimiser = MagicMock()
    optimiser.optimise.side_effect = _recording_optimise

    factor_step(factor_approx, optimiser)

    optimiser.optimise.assert_called_once_with(factor_approx)
    # The hook must already have run when the optimiser was invoked.
    assert len(cavity_seen_by_optimise) == 1
    assert cavity_seen_by_optimise[0] is cavity_sentinel
    assert analysis._cavity_mean_field is cavity_sentinel
78+
79+
80+
def test_factor_step_no_op_for_plain_analysis_factor():
    """``factor_step`` runs cleanly on a factor without the cavity hook."""
    recording_analysis = _RecordingAnalysis()
    plain_factor = af.AnalysisFactor(
        prior_model=af.Model(af.ex.Gaussian),
        analysis=recording_analysis,
    )

    # Keyword arguments to ``MagicMock`` are set as attributes on the mock.
    factor_approx = MagicMock(
        factor=plain_factor,
        cavity_dist=object(),
        model_dist=MagicMock(),
    )

    successful_status = Status(
        success=True, messages=(), flag=StatusFlag.SUCCESS
    )
    optimiser = MagicMock()
    optimiser.optimise.return_value = (MagicMock(), successful_status)

    factor_step(factor_approx, optimiser)

    # Without the hook, nothing may have been attached to the analysis.
    assert not hasattr(recording_analysis, "_cavity_mean_field")

0 commit comments

Comments
 (0)