Skip to content

Commit 7640ac6

Browse files
authored
Merge pull request #1130 from rhayes777/feature/graphical_analysis_output
feature/graphical analysis output
2 parents 8a5a431 + 47d0315 commit 7640ac6

6 files changed

Lines changed: 479 additions & 301 deletions

File tree

autofit/config/output.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,4 @@ covariance: true # `covariance.csv`: The [free parameters x free parameters] cov
9797
data: true # `data.json`: The value of every data point in the data.
9898
noise_map: true # `noise_map.json`: The value of every RMS noise map value.
9999

100-
search_log: true # `search.log`: logging produced whilst running the fit or fit_sequential method
100+
search_log: true # `search.log`: logging produced whilst running the fit method

autofit/graphical/declarative/abstract.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from autofit.graphical.declarative.graph import DeclarativeFactorGraph
88
from autofit.graphical.expectation_propagation import AbstractFactorOptimiser
99
from autofit.graphical.expectation_propagation import EPMeanField, EPOptimiser
10-
from autofit.mapper.model import ModelInstance
1110
from autofit.mapper.prior.abstract import Prior
1211
from autofit.mapper.prior_model.collection import Collection
1312
from autofit.mapper.variable import Plate
@@ -168,7 +167,7 @@ def optimise(
168167
optimiser: AbstractFactorOptimiser,
169168
paths: Optional[AbstractPaths] = None,
170169
ep_history: Optional = None,
171-
**kwargs
170+
**kwargs,
172171
):
173172
"""
174173
Use an EP Optimiser to optimise the graph associated with this collection
@@ -198,29 +197,6 @@ def optimise(
198197
updated_ep_mean_field=updated_ep_mean_field,
199198
)
200199

201-
# TODO : Visualize method before fit?
202-
203-
def visualize(
204-
self, paths: AbstractPaths, instance: ModelInstance, during_analysis: bool
205-
):
206-
"""
207-
Visualise the instances provided using each factor.
208-
209-
Instances in the ModelInstance must have the same order as the factors.
210-
211-
Parameters
212-
----------
213-
paths
214-
Object describing where data should be saved to
215-
instance
216-
A collection of instances, each corresponding to a factor
217-
during_analysis
218-
Is this visualisation during analysis?
219-
"""
220-
for model_factor, instance in zip(self.model_factors, instance):
221-
model_factor.visualize(paths, instance, during_analysis)
222-
model_factor.visualize_combined(None, paths, instance, during_analysis)
223-
224200
@property
225201
def global_prior_model(self) -> Collection:
226202
"""

autofit/graphical/declarative/collection.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
from typing import Union, Optional
22

33
from autofit.graphical.declarative.factor.hierarchical import HierarchicalFactor
4-
from autofit.mapper.model import ModelInstance
4+
55
from autofit.tools.namer import namer
66
from .abstract import AbstractDeclarativeFactor
77
from autofit.non_linear.paths.abstract import AbstractPaths
88
from autofit.non_linear.samples.pdf import SamplesPDF
99
from autofit.non_linear.samples.summary import SamplesSummary
1010
from autofit.non_linear.analysis.combined import CombinedResult
1111

12+
from autofit.mapper.model import ModelInstance
13+
from autofit.mapper.prior_model.prior_model import Model
14+
1215
from autofit.jax_wrapper import register_pytree_node_class
1316

1417

@@ -165,3 +168,114 @@ def make_result(
165168
search_internal=search_internal,
166169
analysis=analysis,
167170
)
171+
172+
def _for_each_analysis(
173+
self,
174+
name,
175+
paths,
176+
*args,
177+
**kwargs,
178+
):
179+
"""
180+
Convenience function to call an underlying function for each
181+
analysis with a paths object with an integer attached to the
182+
end.
183+
184+
Parameters
185+
----------
186+
func
187+
Some function of the analysis class
188+
paths
189+
An object describing the paths for saving data (e.g. hard-disk directories or entries in sqlite database).
190+
"""
191+
results = []
192+
for (i, analysis), *args in zip(
193+
enumerate(self.model_factors),
194+
*args,
195+
):
196+
child_paths = paths.for_sub_analysis(analysis_name=f"analyses/analysis_{i}")
197+
func = getattr(analysis, name)
198+
results.append(
199+
func(
200+
child_paths,
201+
*args,
202+
**kwargs,
203+
)
204+
)
205+
206+
return results
207+
208+
def visualize(
209+
self,
210+
paths: AbstractPaths,
211+
instance: ModelInstance,
212+
during_analysis: bool,
213+
):
214+
"""
215+
Visualise the instances provided using each factor.
216+
217+
Instances in the ModelInstance must have the same order as the factors.
218+
219+
Parameters
220+
----------
221+
paths
222+
Object describing where data should be saved to
223+
instance
224+
A collection of instances, each corresponding to a factor
225+
during_analysis
226+
Is this visualisation during analysis?
227+
"""
228+
self._for_each_analysis(
229+
"visualize",
230+
paths,
231+
instance,
232+
during_analysis=during_analysis,
233+
)
234+
235+
def visualize_before_fit(
236+
self,
237+
paths: AbstractPaths,
238+
model: Model,
239+
):
240+
"""
241+
Visualise the model provided using each factor.
242+
243+
Models in the ModelInstance must have the same order as the factors.
244+
245+
Parameters
246+
----------
247+
paths
248+
Object describing where data should be saved to
249+
model
250+
A collection of models, each corresponding to a factor
251+
"""
252+
self._for_each_analysis(
253+
"visualize_before_fit",
254+
paths,
255+
model,
256+
)
257+
258+
def save_attributes(self, paths: AbstractPaths):
259+
"""
260+
Save the attributes of the analysis to the paths object.
261+
"""
262+
self._for_each_analysis("save_attributes", paths)
263+
264+
def save_results(self, paths: AbstractPaths, result):
265+
"""
266+
Save the results of the analysis to the paths object.
267+
"""
268+
self._for_each_analysis("save_results", paths, result)
269+
270+
def visualize_combined(
271+
self,
272+
instance,
273+
paths: AbstractPaths,
274+
during_analysis,
275+
):
276+
self.model_factors[0].visualize_combined(
277+
self.model_factors,
278+
paths,
279+
instance,
280+
during_analysis=during_analysis,
281+
)

autofit/graphical/declarative/factor/analysis.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44

5+
from autofit.mapper.prior_model.prior_model import Model
56
from autofit.graphical.expectation_propagation import AbstractFactorOptimiser
67
from autofit.mapper.model import ModelInstance
78
from autofit.mapper.prior_model.prior_model import AbstractPriorModel
@@ -144,12 +145,49 @@ def visualize(
144145
self.analysis.visualize(
145146
paths=paths, instance=instance, during_analysis=during_analysis
146147
)
147-
self.analysis.visualize_combined(
148-
analyses=None,
149-
paths=paths,
150-
instance=instance,
151-
during_analysis=during_analysis,
152-
)
148+
149+
def visualize_before_fit(
150+
self,
151+
paths: AbstractPaths,
152+
model: Model,
153+
):
154+
"""
155+
Visualise the model provided using each factor.
156+
157+
Models in the ModelInstance must have the same order as the factors.
158+
159+
Parameters
160+
----------
161+
paths
162+
Object describing where data should be saved to
163+
model
164+
A collection of models, each corresponding to a factor
165+
"""
166+
self.analysis.visualize_before_fit(paths=paths, model=model)
167+
168+
def save_attributes(self, paths: AbstractPaths):
169+
"""
170+
Save the attributes of the analysis object to a file.
171+
172+
Parameters
173+
----------
174+
paths
175+
Object describing where data should be saved to
176+
"""
177+
self.analysis.save_attributes(paths=paths)
178+
179+
def save_results(self, paths: AbstractPaths, result):
180+
"""
181+
Save the results of the analysis to a file.
182+
183+
Parameters
184+
----------
185+
paths
186+
Object describing where data should be saved to
187+
result
188+
The result of the analysis
189+
"""
190+
self.analysis.save_results(paths=paths, result=result)
153191

154192
def log_likelihood_function(self, instance: ModelInstance) -> float:
155193
return self.analysis.log_likelihood_function(instance)

0 commit comments

Comments
 (0)