Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 49 additions & 3 deletions autofit/example/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@ class Analysis(af.Analysis):

LATENT_KEYS = ["gaussian.fwhm"]

def __init__(self, data: np.ndarray, noise_map: np.ndarray, use_jax=False):
def __init__(
self,
data: np.ndarray,
noise_map: np.ndarray,
use_jax=False,
share_model_data=False,
):
"""
In this example the `Analysis` object only contains the data and noise-map. It can be easily extended,
for more complex data-sets and model fitting problems.
Expand All @@ -48,26 +54,66 @@ def __init__(self, data: np.ndarray, noise_map: np.ndarray, use_jax=False):
noise_map
A 1D numpy array containing the noise values of the data, used for computing the goodness of fit
metric.
share_model_data
If `True`, opt this `Analysis` into the `FactorGraphModel` cross-factor shared-state mechanism
(see `shared_state_from`). This is only valid when the *entire* model is shared across every
factor, so the model data is identical for all of them and can be computed once instead of being
rebuilt by each factor. It is `False` by default, so the standard per-analysis behaviour is
unchanged.
"""
super().__init__(use_jax=use_jax)

self.data = data
self.noise_map = noise_map
self.share_model_data = share_model_data

def shared_state_from(self, instance: af.ModelInstance):
"""
Compute the model data once so that it can be shared across the factors of a `FactorGraphModel`.

This is the worked example of `Analysis.shared_state_from` (see that method). When every factor of
the graph shares the *entire* model — for example several datasets fit by the same 1D profile via
shared priors — the model data is identical for every factor, so it is wasteful to rebuild it once
per factor. Returning it here means the `FactorGraphModel` computes it a single time on the lead
factor and reuses it for all the others.

In this toy the model data is cheap, but it stands in for an expensive shared computation: it is the
1D analog of the lensing case, where the shared work (ray-tracing, the source-plane mapper, the
mapping matrix and the curvature matrix) dominates the per-factor cost.

Sharing is opt-in (`share_model_data`) because it is only correct when the model really is fully
shared. If only some parameters are shared the model data differs between factors and this returns
`None`, so each factor computes its own as usual.
"""
if not self.share_model_data:
return None

return self.model_data_1d_from(instance=instance)

def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float:
def log_likelihood_function(
self, instance: af.ModelInstance, shared=None, xp=np
) -> float:
"""
Determine the log likelihood of a fit of multiple profiles to the dataset.

Parameters
----------
instance : af.Collection
The model instances of the profiles.
shared
The model data shared across the factors of a `FactorGraphModel`, computed once by
`shared_state_from` (see that method). When provided it is used directly instead of being
recomputed here; when `None` (the default, e.g. a standalone fit) the model data is computed
as normal.

Returns
-------
The log likelihood value indicating how well this model fit the dataset.
"""
model_data_1d = self.model_data_1d_from(instance=instance)
if shared is None:
model_data_1d = self.model_data_1d_from(instance=instance)
else:
model_data_1d = shared

residual_map = self.data - model_data_1d
chi_squared_map = (residual_map / self.noise_map) ** 2.0
Expand Down
38 changes: 37 additions & 1 deletion autofit/graphical/declarative/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,17 @@ def log_likelihood_function(self, instance: ModelInstance) -> float:
Compute the combined likelihood of each factor from a collection of instances
with the same ordering as the factors.

Before the per-factor loop, the lead factor is asked to compute a `shared`
object via `Analysis.shared_state_from` (see that method). If it returns a
non-`None` value — the opt-in case — that object is forwarded as the `shared`
keyword argument to every factor's `log_likelihood_function`, so that work
which is identical for all factors at this point in parameter space is computed
once and reused rather than recomputed for each factor.

When no factor provides a shared object (the default) the loop calls each
factor's `log_likelihood_function` exactly as it would without this mechanism,
so existing graphs are unchanged.

Parameters
----------
instance
Expand All @@ -100,12 +111,37 @@ def log_likelihood_function(self, instance: ModelInstance) -> float:
-------
The combined likelihood of all factors
"""
shared = self._shared_state_from(instance)

log_likelihood = 0
for model_factor, instance_ in zip(self.model_factors, instance):
log_likelihood += model_factor.log_likelihood_function(instance_)
if shared is None:
log_likelihood += model_factor.log_likelihood_function(instance_)
else:
log_likelihood += model_factor.log_likelihood_function(
instance_, shared=shared
)

return log_likelihood

def _shared_state_from(self, instance: ModelInstance):
"""
Compute the per-evaluation object shared across factors, by asking each factor's
`Analysis` in turn (via `Analysis.shared_state_from`) until one returns a
non-`None` value — the "lead" factor. Returns `None` if no factor opts in, in
which case no state is shared this evaluation.
"""
for model_factor, instance_ in zip(self.model_factors, instance):
analysis = getattr(model_factor, "analysis", None)
shared_state_from = getattr(analysis, "shared_state_from", None)
if shared_state_from is None:
continue
shared = shared_state_from(instance_)
if shared is not None:
return shared

return None

@property
def model_factors(self):
model_factors = list()
Expand Down
6 changes: 4 additions & 2 deletions autofit/graphical/declarative/factor/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,10 @@ def save_results(self, paths: AbstractPaths, result):
"""
self.analysis.save_results(paths=paths, result=result)

def log_likelihood_function(self, instance: ModelInstance) -> float:
return self.analysis.log_likelihood_function(instance)
def log_likelihood_function(self, instance: ModelInstance, shared=None) -> float:
if shared is None:
return self.analysis.log_likelihood_function(instance)
return self.analysis.log_likelihood_function(instance, shared=shared)


class EPAnalysisFactor(AnalysisFactor):
Expand Down
2 changes: 1 addition & 1 deletion autofit/graphical/declarative/factor/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def message_dict(self) -> Dict[Prior, NormalMessage]:
def variable(self):
return self.drawn_prior

def log_likelihood_function(self, instance):
def log_likelihood_function(self, instance, shared=None):
return instance.distribution_model.message(instance.drawn_prior, xp=self._xp)

@property
Expand Down
6 changes: 5 additions & 1 deletion autofit/graphical/declarative/factor/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,17 @@ def analysis(self) -> "PriorFactor":
"""
return self

def log_likelihood_function(self, instance) -> float:
def log_likelihood_function(self, instance, shared=None) -> float:
"""
Compute the likelihood.

The instance is a collection with a single argument expressing a
possible value for this prior. The likelihood is computed by simply
evaluating the prior's PDF for the given value.

The `shared` argument (the cross-factor shared state of a
`FactorGraphModel`) is accepted for a uniform calling convention but is
not used by a prior factor.
"""
return self.prior.factor(instance[0])

Expand Down
42 changes: 41 additions & 1 deletion autofit/non_linear/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,49 @@ def with_model(self, model):

return ModelAnalysis(analysis=self, model=model)

def log_likelihood_function(self, instance):
def log_likelihood_function(self, instance, shared=None):
raise NotImplementedError()

def shared_state_from(self, instance):
"""
Optionally compute a per-evaluation object that is shared across the factors
of a `FactorGraphModel`.

This is the per-evaluation, cross-factor sibling of `modify_before_fit`. Where
`modify_before_fit` runs once before sampling to precompute analysis state that
does not depend on the model, `shared_state_from` runs once per likelihood
evaluation (the model parameters change every sample) and computes state that
is identical for every factor at the current point in parameter space.

When a `FactorGraphModel` evaluates its likelihood it calls this method on its
lead factor's `Analysis`. If the returned value is not `None` it is forwarded as
the `shared` keyword argument to every factor's `log_likelihood_function`, so
that work which is identical for all factors (because they share model
parameters) is computed once and reused rather than recomputed `N` times.

The default implementation returns `None`, meaning no state is shared and every
factor's `log_likelihood_function` runs exactly as it does without this
mechanism. An `Analysis` opts in by overriding this method.

The returned object must be a valid JAX pytree of traced arrays when the fit is
JIT-compiled: it is recomputed inside the jitted region each evaluation (it
depends on the traced model parameters) and must not be memoised on the instance.

Correctness is the responsibility of the overriding `Analysis`: only return a
shared object when the parameters it depends on really are shared across every
factor. If they are not, return `None` and let each factor compute its own state.

Parameters
----------
instance
The model instance of the factor whose `Analysis` is acting as the lead.

Returns
-------
An object shared across all factors for this evaluation, or `None` for no sharing.
"""
return None

def save_attributes(self, paths: AbstractPaths):
pass

Expand Down
131 changes: 131 additions & 0 deletions test_autofit/graphical/test_shared_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import itertools

import numpy as np
import pytest

import autofit as af
import autofit.graphical as g


@pytest.fixture(autouse=True)
def reset_ids():
af.Prior._ids = itertools.count()


class CountingAnalysis(af.ex.Analysis):
"""
An example `Analysis` that counts how many times the (notionally expensive) model
data computation runs, so a test can prove the `FactorGraphModel` shared-state
mechanism computes it once per evaluation rather than once per factor.
"""

def __init__(self, data, noise_map, share_model_data=True):
super().__init__(
data=data, noise_map=noise_map, share_model_data=share_model_data
)
self.model_data_calls = 0

def model_data_1d_from(self, instance):
self.model_data_calls += 1
return super().model_data_1d_from(instance=instance)


def _shared_gaussian_graph(analyses):
"""
Build a `FactorGraphModel` whose factors share the *entire* Gaussian model via
shared prior objects, so the model data is identical for every factor.
"""
centre = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
normalization = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
sigma = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)

factors = []
for analysis in analyses:
gaussian = af.Model(af.ex.Gaussian)
gaussian.centre = centre
gaussian.normalization = normalization
gaussian.sigma = sigma
factors.append(af.AnalysisFactor(gaussian, analysis))

return g.FactorGraphModel(*factors)


def _datasets(n=3, size=10):
"""`n` distinct 1D datasets sharing a common noise map of ones."""
return [
(np.arange(size, dtype=float) + float(i), np.ones(size))
for i in range(n)
]


def _instance(collection):
prior_count = collection.global_prior_model.prior_count
return collection.global_prior_model.instance_from_unit_vector(
[0.5] * prior_count
)


def _reference_log_likelihood(collection, instance):
"""Sum each factor's likelihood with no sharing (each computes its own model data)."""
return sum(
factor.analysis.log_likelihood_function(instance_)
for factor, instance_ in zip(collection.model_factors, instance)
)


def test_shared_state_computed_once_per_evaluation():
analyses = [
CountingAnalysis(data, noise_map) for data, noise_map in _datasets(n=3)
]
collection = _shared_gaussian_graph(analyses)
instance = _instance(collection)

collection.log_likelihood_function(instance)

total_calls = sum(analysis.model_data_calls for analysis in analyses)
assert total_calls == 1


def test_shared_likelihood_equals_unshared_sum():
analyses = [
CountingAnalysis(data, noise_map) for data, noise_map in _datasets(n=3)
]
collection = _shared_gaussian_graph(analyses)
instance = _instance(collection)

shared_log_likelihood = collection.log_likelihood_function(instance)
reference_log_likelihood = _reference_log_likelihood(collection, instance)

assert shared_log_likelihood == pytest.approx(reference_log_likelihood)


def test_no_provider_graph_is_unchanged():
"""
With `share_model_data=False` no factor opts in, so no state is shared: each factor
computes its own model data (N calls) and the summed likelihood is unchanged.
"""
analyses = [
CountingAnalysis(data, noise_map, share_model_data=False)
for data, noise_map in _datasets(n=3)
]
collection = _shared_gaussian_graph(analyses)
instance = _instance(collection)

log_likelihood = collection.log_likelihood_function(instance)
reference_log_likelihood = _reference_log_likelihood(collection, instance)

total_calls = sum(analysis.model_data_calls for analysis in analyses)
# one call per factor from the graph evaluation, plus one per factor from the
# reference sum — the graph did not share, so it computed all three itself.
assert total_calls == 2 * len(analyses)
assert log_likelihood == pytest.approx(reference_log_likelihood)


def test_shared_state_from_default_returns_none():
analysis = af.ex.Analysis(
data=np.arange(10, dtype=float), noise_map=np.ones(10)
)
model = af.Model(af.ex.Gaussian)
instance = model.instance_from_unit_vector([0.5] * model.prior_count)

assert analysis.shared_state_from(instance) is None
Loading