
Commit d7cf9d4

Jammy2211 and claude authored and committed
feat: cross-Analysis shared per-evaluation state in FactorGraphModel
Add an opt-in mechanism for the per-factor Analysis objects of a FactorGraphModel to compute a model-dependent object once per likelihood evaluation and reuse it across every factor, instead of each factor recomputing identical work.

- Analysis.shared_state_from(instance) -> None: opt-in hook (default None), the per-evaluation cross-factor sibling of modify_before_fit.
- Analysis.log_likelihood_function gains an optional shared= kwarg.
- FactorGraphModel computes the shared object once from the lead factor and forwards it to each factor only when non-None, so existing graphs are byte-for-byte unchanged.
- AnalysisFactor forwards shared=; PriorFactor and HierarchicalFactor accept and ignore it for a uniform calling convention.
- af.ex.Analysis demonstrates it on the 1D Gaussian toy via an opt-in share_model_data flag.
- New unit tests in test_autofit/graphical/test_shared_state.py.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 682d673 commit d7cf9d4

7 files changed

Lines changed: 268 additions & 9 deletions

autofit/example/analysis.py

Lines changed: 49 additions & 3 deletions
@@ -36,7 +36,13 @@ class Analysis(af.Analysis):

     LATENT_KEYS = ["gaussian.fwhm"]

-    def __init__(self, data: np.ndarray, noise_map: np.ndarray, use_jax=False):
+    def __init__(
+        self,
+        data: np.ndarray,
+        noise_map: np.ndarray,
+        use_jax=False,
+        share_model_data=False,
+    ):
         """
         In this example the `Analysis` object only contains the data and noise-map. It can be easily extended,
         for more complex data-sets and model fitting problems.
@@ -48,26 +54,66 @@ def __init__(self, data: np.ndarray, noise_map: np.ndarray, use_jax=False):
         noise_map
             A 1D numpy array containing the noise values of the data, used for computing the goodness of fit
             metric.
+        share_model_data
+            If `True`, opt this `Analysis` into the `FactorGraphModel` cross-factor shared-state mechanism
+            (see `shared_state_from`). This is only valid when the *entire* model is shared across every
+            factor, so the model data is identical for all of them and can be computed once instead of being
+            rebuilt by each factor. It is `False` by default, so the standard per-analysis behaviour is
+            unchanged.
         """
         super().__init__(use_jax=use_jax)

         self.data = data
         self.noise_map = noise_map
+        self.share_model_data = share_model_data
+
+    def shared_state_from(self, instance: af.ModelInstance):
+        """
+        Compute the model data once so that it can be shared across the factors of a `FactorGraphModel`.
+
+        This is the worked example of `Analysis.shared_state_from` (see that method). When every factor of
+        the graph shares the *entire* model - for example several datasets fit by the same 1D profile via
+        shared priors - the model data is identical for every factor, so it is wasteful to rebuild it once
+        per factor. Returning it here means the `FactorGraphModel` computes it a single time on the lead
+        factor and reuses it for all the others.
+
+        In this toy the model data is cheap, but it stands in for an expensive shared computation: it is the
+        1D analog of the lensing case, where the shared work (ray-tracing, the source-plane mapper, the
+        mapping matrix and the curvature matrix) dominates the per-factor cost.
+
+        Sharing is opt-in (`share_model_data`) because it is only correct when the model really is fully
+        shared. If only some parameters are shared the model data differs between factors and this returns
+        `None`, so each factor computes its own as usual.
+        """
+        if not self.share_model_data:
+            return None
+
+        return self.model_data_1d_from(instance=instance)

-    def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float:
+    def log_likelihood_function(
+        self, instance: af.ModelInstance, shared=None, xp=np
+    ) -> float:
         """
         Determine the log likelihood of a fit of multiple profiles to the dataset.

         Parameters
         ----------
         instance : af.Collection
             The model instances of the profiles.
+        shared
+            The model data shared across the factors of a `FactorGraphModel`, computed once by
+            `shared_state_from` (see that method). When provided it is used directly instead of being
+            recomputed here; when `None` (the default, e.g. a standalone fit) the model data is computed
+            as normal.

         Returns
         -------
         The log likelihood value indicating how well this model fit the dataset.
         """
-        model_data_1d = self.model_data_1d_from(instance=instance)
+        if shared is None:
+            model_data_1d = self.model_data_1d_from(instance=instance)
+        else:
+            model_data_1d = shared

         residual_map = self.data - model_data_1d
         chi_squared_map = (residual_map / self.noise_map) ** 2.0
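
A minimal, self-contained sketch of the idea behind `share_model_data`, written against plain NumPy rather than the library. The function names `gaussian_1d` and `log_likelihood`, the parameter values and the noise-normalization term are assumptions made for illustration; the remainder of the real `af.ex.Analysis.log_likelihood_function` is not shown in this diff. The sketch only illustrates that, when every dataset is fit by the same parameters, supplying precomputed model data gives the same summed likelihood as rebuilding it per dataset.

```python
import numpy as np


def gaussian_1d(xvalues, centre, normalization, sigma):
    """Evaluate a 1D Gaussian profile on a grid (toy stand-in for the model data)."""
    return (normalization / (sigma * np.sqrt(2.0 * np.pi))) * np.exp(
        -0.5 * ((xvalues - centre) / sigma) ** 2.0
    )


def log_likelihood(data, noise_map, params, shared=None):
    """Gaussian log likelihood; reuse `shared` model data when it is supplied."""
    if shared is None:
        # Standalone path: this dataset builds its own model data.
        xvalues = np.arange(data.shape[0], dtype=float)
        model_data = gaussian_1d(xvalues, **params)
    else:
        # Shared path: the model data was computed once elsewhere.
        model_data = shared
    chi_squared = np.sum(((data - model_data) / noise_map) ** 2.0)
    # Assumed noise-normalization term; not taken from the diff.
    noise_normalization = np.sum(np.log(2.0 * np.pi * noise_map**2.0))
    return -0.5 * (chi_squared + noise_normalization)


# Hypothetical parameter values, identical for every dataset (a fully shared model).
params = {"centre": 5.0, "normalization": 5.0, "sigma": 5.0}
datasets = [(np.arange(10, dtype=float) + i, np.ones(10)) for i in range(3)]

# Unshared: every dataset rebuilds the same model data.
unshared_total = sum(log_likelihood(d, n, params) for d, n in datasets)

# Shared: build the model data once and hand it to every dataset.
shared_model_data = gaussian_1d(np.arange(10, dtype=float), **params)
shared_total = sum(
    log_likelihood(d, n, params, shared=shared_model_data) for d, n in datasets
)
```

If the datasets were fit by different parameters, `shared_model_data` would be wrong for all but one of them, which is why the flag defaults to `False`.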

autofit/graphical/declarative/collection.py

Lines changed: 37 additions & 1 deletion
@@ -91,6 +91,17 @@ def log_likelihood_function(self, instance: ModelInstance) -> float:
         Compute the combined likelihood of each factor from a collection of instances
         with the same ordering as the factors.

+        Before the per-factor loop, the lead factor is asked to compute a `shared`
+        object via `Analysis.shared_state_from` (see that method). If it returns a
+        non-`None` value - the opt-in case - that object is forwarded as the `shared`
+        keyword argument to every factor's `log_likelihood_function`, so that work
+        which is identical for all factors at this point in parameter space is computed
+        once and reused rather than recomputed for each factor.
+
+        When no factor provides a shared object (the default) the loop calls each
+        factor's `log_likelihood_function` exactly as it would without this mechanism,
+        so existing graphs are unchanged.
+
         Parameters
         ----------
         instance
@@ -100,12 +111,37 @@ def log_likelihood_function(self, instance: ModelInstance) -> float:
         -------
         The combined likelihood of all factors
         """
+        shared = self._shared_state_from(instance)
+
         log_likelihood = 0
         for model_factor, instance_ in zip(self.model_factors, instance):
-            log_likelihood += model_factor.log_likelihood_function(instance_)
+            if shared is None:
+                log_likelihood += model_factor.log_likelihood_function(instance_)
+            else:
+                log_likelihood += model_factor.log_likelihood_function(
+                    instance_, shared=shared
+                )

         return log_likelihood

+    def _shared_state_from(self, instance: ModelInstance):
+        """
+        Compute the per-evaluation object shared across factors, by asking each factor's
+        `Analysis` in turn (via `Analysis.shared_state_from`) until one returns a
+        non-`None` value - the "lead" factor. Returns `None` if no factor opts in, in
+        which case no state is shared this evaluation.
+        """
+        for model_factor, instance_ in zip(self.model_factors, instance):
+            analysis = getattr(model_factor, "analysis", None)
+            shared_state_from = getattr(analysis, "shared_state_from", None)
+            if shared_state_from is None:
+                continue
+            shared = shared_state_from(instance_)
+            if shared is not None:
+                return shared
+
+        return None
+
     @property
     def model_factors(self):
         model_factors = list()
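
The lead-factor lookup and conditional forwarding in this hunk can be sketched outside the library as follows. All class and function names here (`LegacyFactor`, `ToyAnalysis`, `ToyFactor`, `graph_log_likelihood`) are hypothetical stand-ins, not PyAutoFit classes, and the "likelihoods" are arbitrary numbers. The sketch mirrors the control flow above: ask each factor's analysis in turn, stop at the first non-`None` result, and pass `shared=` only when something was returned.

```python
class LegacyFactor:
    """Hypothetical factor with no `analysis` attribute and no `shared` kwarg."""

    def log_likelihood_function(self, instance):
        return -1.0


class ToyAnalysis:
    """Hypothetical analysis whose 'expensive' state is simply instance * 2."""

    def __init__(self, opt_in):
        self.opt_in = opt_in

    def shared_state_from(self, instance):
        return instance * 2.0 if self.opt_in else None

    def log_likelihood_function(self, instance, shared=None):
        value = instance * 2.0 if shared is None else shared
        return -value


class ToyFactor:
    """Hypothetical wrapper that forwards `shared` only when it is provided."""

    def __init__(self, analysis):
        self.analysis = analysis

    def log_likelihood_function(self, instance, shared=None):
        if shared is None:
            return self.analysis.log_likelihood_function(instance)
        return self.analysis.log_likelihood_function(instance, shared=shared)


def shared_state_from(factors, instances):
    """Return the first non-None shared object offered by any factor's analysis."""
    for factor, instance in zip(factors, instances):
        analysis = getattr(factor, "analysis", None)
        hook = getattr(analysis, "shared_state_from", None)
        if hook is None:
            continue
        shared = hook(instance)
        if shared is not None:
            return shared
    return None


def graph_log_likelihood(factors, instances):
    """Sum the factor likelihoods, forwarding `shared` only in the opt-in case."""
    shared = shared_state_from(factors, instances)
    total = 0.0
    for factor, instance in zip(factors, instances):
        if shared is None:
            total += factor.log_likelihood_function(instance)
        else:
            total += factor.log_likelihood_function(instance, shared=shared)
    return total


# No factor opts in: the legacy factor is called exactly as before.
legacy_graph = [LegacyFactor(), ToyFactor(ToyAnalysis(opt_in=False))]
legacy_total = graph_log_likelihood(legacy_graph, [1.0, 3.0])

# The second factor opts in and becomes the lead; both factors receive its state.
opted_graph = [ToyFactor(ToyAnalysis(opt_in=False)), ToyFactor(ToyAnalysis(opt_in=True))]
opted_total = graph_log_likelihood(opted_graph, [3.0, 3.0])
```

In this sketch, mixing `LegacyFactor` with an opted-in analysis raises a `TypeError`, because the legacy signature does not accept `shared`. That is the situation the `shared=None` parameters added to the other factor classes in this commit avoid.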

autofit/graphical/declarative/factor/analysis.py

Lines changed: 4 additions & 2 deletions
@@ -250,8 +250,10 @@ def save_results(self, paths: AbstractPaths, result):
         """
         self.analysis.save_results(paths=paths, result=result)

-    def log_likelihood_function(self, instance: ModelInstance) -> float:
-        return self.analysis.log_likelihood_function(instance)
+    def log_likelihood_function(self, instance: ModelInstance, shared=None) -> float:
+        if shared is None:
+            return self.analysis.log_likelihood_function(instance)
+        return self.analysis.log_likelihood_function(instance, shared=shared)


 class EPAnalysisFactor(AnalysisFactor):

autofit/graphical/declarative/factor/hierarchical.py

Lines changed: 1 addition & 1 deletion
@@ -190,7 +190,7 @@ def message_dict(self) -> Dict[Prior, NormalMessage]:
     def variable(self):
         return self.drawn_prior

-    def log_likelihood_function(self, instance):
+    def log_likelihood_function(self, instance, shared=None):
         return instance.distribution_model.message(instance.drawn_prior, xp=self._xp)

     @property

autofit/graphical/declarative/factor/prior.py

Lines changed: 5 additions & 1 deletion
@@ -47,13 +47,17 @@ def analysis(self) -> "PriorFactor":
         """
         return self

-    def log_likelihood_function(self, instance) -> float:
+    def log_likelihood_function(self, instance, shared=None) -> float:
         """
         Compute the likelihood.

         The instance is a collection with a single argument expressing a
         possible value for this prior. The likelihood is computed by simply
         evaluating the prior's PDF for the given value.
+
+        The `shared` argument (the cross-factor shared state of a
+        `FactorGraphModel`) is accepted for a uniform calling convention but is
+        not used by a prior factor.
         """
         return self.prior.factor(instance[0])
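
A short sketch of the "accept and ignore" convention adopted for `PriorFactor` and `HierarchicalFactor`. `ToyPriorFactor` is a hypothetical stand-in that evaluates a uniform log-density; what the real `self.prior.factor(...)` returns is not shown in this diff and is not assumed here. The only point illustrated is that the result does not depend on whether `shared` is passed.

```python
import math


class ToyPriorFactor:
    """Hypothetical stand-in for a prior factor over a uniform range."""

    def __init__(self, lower, upper):
        self.lower = lower
        self.upper = upper

    def log_likelihood_function(self, instance, shared=None):
        # `shared` is accepted so a graph can pass it to every factor alike;
        # a prior factor has no use for it, so it is deliberately ignored.
        value = instance[0]
        if self.lower <= value <= self.upper:
            return -math.log(self.upper - self.lower)
        return -math.inf


factor = ToyPriorFactor(lower=0.0, upper=10.0)
without_shared = factor.log_likelihood_function([5.0])
with_shared = factor.log_likelihood_function([5.0], shared={"model_data": [1.0, 2.0]})
```

Keeping one signature across factor types lets the graph loop stay a single call site instead of branching on the factor class.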

autofit/non_linear/analysis/analysis.py

Lines changed: 41 additions & 1 deletion
@@ -305,9 +305,49 @@ def with_model(self, model):

         return ModelAnalysis(analysis=self, model=model)

-    def log_likelihood_function(self, instance):
+    def log_likelihood_function(self, instance, shared=None):
         raise NotImplementedError()

+    def shared_state_from(self, instance):
+        """
+        Optionally compute a per-evaluation object that is shared across the factors
+        of a `FactorGraphModel`.
+
+        This is the per-evaluation, cross-factor sibling of `modify_before_fit`. Where
+        `modify_before_fit` runs once before sampling to precompute analysis state that
+        does not depend on the model, `shared_state_from` runs once per likelihood
+        evaluation (the model parameters change every sample) and computes state that
+        is identical for every factor at the current point in parameter space.
+
+        When a `FactorGraphModel` evaluates its likelihood it calls this method on its
+        lead factor's `Analysis`. If the returned value is not `None` it is forwarded as
+        the `shared` keyword argument to every factor's `log_likelihood_function`, so
+        that work which is identical for all factors (because they share model
+        parameters) is computed once and reused rather than recomputed `N` times.
+
+        The default implementation returns `None`, meaning no state is shared and every
+        factor's `log_likelihood_function` runs exactly as it does without this
+        mechanism. An `Analysis` opts in by overriding this method.
+
+        The returned object must be a valid JAX pytree of traced arrays when the fit is
+        JIT-compiled: it is recomputed inside the jitted region each evaluation (it
+        depends on the traced model parameters) and must not be memoised on the instance.
+
+        Correctness is the responsibility of the overriding `Analysis`: only return a
+        shared object when the parameters it depends on really are shared across every
+        factor. If they are not, return `None` and let each factor compute its own state.
+
+        Parameters
+        ----------
+        instance
+            The model instance of the factor whose `Analysis` is acting as the lead.
+
+        Returns
+        -------
+        An object shared across all factors for this evaluation, or `None` for no sharing.
+        """
+        return None
+
     def save_attributes(self, paths: AbstractPaths):
         pass
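
The docstring's contrast between `modify_before_fit` (once, before sampling) and `shared_state_from` (once per likelihood evaluation) can be made concrete with a toy loop. Everything in this sketch is hypothetical: `ToyAnalysis`, its argument-free `modify_before_fit`, the scalar "instance" and the hand-written sampling loop are simplifications for illustration and do not reproduce the library's signatures or its samplers.

```python
class ToyAnalysis:
    """Hypothetical analysis that counts how often each hook runs."""

    def __init__(self):
        self.modify_calls = 0
        self.shared_calls = 0
        self.grid = None

    def modify_before_fit(self):
        """Model-independent setup: runs once, before any sampling."""
        self.modify_calls += 1
        self.grid = [0.0, 1.0, 2.0]
        return self

    def shared_state_from(self, instance):
        """Model-dependent work: recomputed from the current parameters each evaluation."""
        self.shared_calls += 1
        return [instance * x for x in self.grid]

    def log_likelihood_function(self, instance, shared=None):
        model = shared if shared is not None else [instance * x for x in self.grid]
        return -sum(value**2 for value in model)


analysis = ToyAnalysis().modify_before_fit()

samples = [0.5, 1.0, 1.5]  # stand-in for the points a sampler would propose
n_factors = 3

log_likelihoods = []
for sample in samples:
    # Computed once per evaluation, then reused by every factor.
    shared = analysis.shared_state_from(sample)
    log_likelihoods.append(
        sum(
            analysis.log_likelihood_function(sample, shared=shared)
            for _ in range(n_factors)
        )
    )
```

The shared object is rebuilt on every pass through the loop rather than cached on `analysis`, in line with the docstring's requirement that it not be memoised on the instance.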

test_autofit/graphical/test_shared_state.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+import itertools
+
+import numpy as np
+import pytest
+
+import autofit as af
+import autofit.graphical as g
+
+
+@pytest.fixture(autouse=True)
+def reset_ids():
+    af.Prior._ids = itertools.count()
+
+
+class CountingAnalysis(af.ex.Analysis):
+    """
+    An example `Analysis` that counts how many times the (notionally expensive) model
+    data computation runs, so a test can prove the `FactorGraphModel` shared-state
+    mechanism computes it once per evaluation rather than once per factor.
+    """
+
+    def __init__(self, data, noise_map, share_model_data=True):
+        super().__init__(
+            data=data, noise_map=noise_map, share_model_data=share_model_data
+        )
+        self.model_data_calls = 0
+
+    def model_data_1d_from(self, instance):
+        self.model_data_calls += 1
+        return super().model_data_1d_from(instance=instance)
+
+
+def _shared_gaussian_graph(analyses):
+    """
+    Build a `FactorGraphModel` whose factors share the *entire* Gaussian model via
+    shared prior objects, so the model data is identical for every factor.
+    """
+    centre = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+    normalization = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+    sigma = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+
+    factors = []
+    for analysis in analyses:
+        gaussian = af.Model(af.ex.Gaussian)
+        gaussian.centre = centre
+        gaussian.normalization = normalization
+        gaussian.sigma = sigma
+        factors.append(af.AnalysisFactor(gaussian, analysis))
+
+    return g.FactorGraphModel(*factors)
+
+
+def _datasets(n=3, size=10):
+    """`n` distinct 1D datasets sharing a common noise map of ones."""
+    return [
+        (np.arange(size, dtype=float) + float(i), np.ones(size))
+        for i in range(n)
+    ]
+
+
+def _instance(collection):
+    prior_count = collection.global_prior_model.prior_count
+    return collection.global_prior_model.instance_from_unit_vector(
+        [0.5] * prior_count
+    )
+
+
+def _reference_log_likelihood(collection, instance):
+    """Sum each factor's likelihood with no sharing (each computes its own model data)."""
+    return sum(
+        factor.analysis.log_likelihood_function(instance_)
+        for factor, instance_ in zip(collection.model_factors, instance)
+    )
+
+
+def test_shared_state_computed_once_per_evaluation():
+    analyses = [
+        CountingAnalysis(data, noise_map) for data, noise_map in _datasets(n=3)
+    ]
+    collection = _shared_gaussian_graph(analyses)
+    instance = _instance(collection)
+
+    collection.log_likelihood_function(instance)
+
+    total_calls = sum(analysis.model_data_calls for analysis in analyses)
+    assert total_calls == 1
+
+
+def test_shared_likelihood_equals_unshared_sum():
+    analyses = [
+        CountingAnalysis(data, noise_map) for data, noise_map in _datasets(n=3)
+    ]
+    collection = _shared_gaussian_graph(analyses)
+    instance = _instance(collection)
+
+    shared_log_likelihood = collection.log_likelihood_function(instance)
+    reference_log_likelihood = _reference_log_likelihood(collection, instance)
+
+    assert shared_log_likelihood == pytest.approx(reference_log_likelihood)
+
+
+def test_no_provider_graph_is_unchanged():
+    """
+    With `share_model_data=False` no factor opts in, so no state is shared: each factor
+    computes its own model data (N calls) and the summed likelihood is unchanged.
+    """
+    analyses = [
+        CountingAnalysis(data, noise_map, share_model_data=False)
+        for data, noise_map in _datasets(n=3)
+    ]
+    collection = _shared_gaussian_graph(analyses)
+    instance = _instance(collection)
+
+    log_likelihood = collection.log_likelihood_function(instance)
+    reference_log_likelihood = _reference_log_likelihood(collection, instance)
+
+    total_calls = sum(analysis.model_data_calls for analysis in analyses)
+    # one call per factor from the graph evaluation, plus one per factor from the
+    # reference sum - the graph did not share, so it computed all three itself.
+    assert total_calls == 2 * len(analyses)
+    assert log_likelihood == pytest.approx(reference_log_likelihood)
+
+
+def test_shared_state_from_default_returns_none():
+    analysis = af.ex.Analysis(
+        data=np.arange(10, dtype=float), noise_map=np.ones(10)
+    )
+    model = af.Model(af.ex.Gaussian)
+    instance = model.instance_from_unit_vector([0.5] * model.prior_count)
+
+    assert analysis.shared_state_from(instance) is None
