|
1 | 1 | from typing import Union, Optional |
2 | 2 |
|
3 | 3 | from autofit.graphical.declarative.factor.hierarchical import HierarchicalFactor |
4 | | -from autofit.mapper.model import ModelInstance |
| 4 | + |
5 | 5 | from autofit.tools.namer import namer |
6 | 6 | from .abstract import AbstractDeclarativeFactor |
7 | 7 | from autofit.non_linear.paths.abstract import AbstractPaths |
8 | 8 | from autofit.non_linear.samples.pdf import SamplesPDF |
9 | 9 | from autofit.non_linear.samples.summary import SamplesSummary |
10 | 10 | from autofit.non_linear.analysis.combined import CombinedResult |
11 | 11 |
|
| 12 | +from autofit.mapper.model import ModelInstance |
| 13 | +from autofit.mapper.prior_model.prior_model import Model |
| 14 | + |
12 | 15 | from autofit.jax_wrapper import register_pytree_node_class |
13 | 16 |
|
14 | 17 |
|
@@ -165,3 +168,114 @@ def make_result( |
165 | 168 | search_internal=search_internal, |
166 | 169 | analysis=analysis, |
167 | 170 | ) |
| 171 | + |
| 172 | + def _for_each_analysis( |
| 173 | + self, |
| 174 | + name, |
| 175 | + paths, |
| 176 | + *args, |
| 177 | + **kwargs, |
| 178 | + ): |
| 179 | + """ |
| 180 | + Convenience function to call an underlying function for each |
| 181 | + analysis with a paths object with an integer attached to the |
| 182 | + end. |
| 183 | +
|
| 184 | + Parameters |
| 185 | + ---------- |
| 186 | + func |
| 187 | + Some function of the analysis class |
| 188 | + paths |
| 189 | + An object describing the paths for saving data (e.g. hard-disk directories or entries in sqlite database). |
| 190 | + """ |
| 191 | + results = [] |
| 192 | + for (i, analysis), *args in zip( |
| 193 | + enumerate(self.model_factors), |
| 194 | + *args, |
| 195 | + ): |
| 196 | + child_paths = paths.for_sub_analysis(analysis_name=f"analyses/analysis_{i}") |
| 197 | + func = getattr(analysis, name) |
| 198 | + results.append( |
| 199 | + func( |
| 200 | + child_paths, |
| 201 | + *args, |
| 202 | + **kwargs, |
| 203 | + ) |
| 204 | + ) |
| 205 | + |
| 206 | + return results |
| 207 | + |
| 208 | + def visualize( |
| 209 | + self, |
| 210 | + paths: AbstractPaths, |
| 211 | + instance: ModelInstance, |
| 212 | + during_analysis: bool, |
| 213 | + ): |
| 214 | + """ |
| 215 | + Visualise the instances provided using each factor. |
| 216 | +
|
| 217 | + Instances in the ModelInstance must have the same order as the factors. |
| 218 | +
|
| 219 | + Parameters |
| 220 | + ---------- |
| 221 | + paths |
| 222 | + Object describing where data should be saved to |
| 223 | + instance |
| 224 | + A collection of instances, each corresponding to a factor |
| 225 | + during_analysis |
| 226 | + Is this visualisation during analysis? |
| 227 | + """ |
| 228 | + self._for_each_analysis( |
| 229 | + "visualize", |
| 230 | + paths, |
| 231 | + instance, |
| 232 | + during_analysis=during_analysis, |
| 233 | + ) |
| 234 | + |
| 235 | + def visualize_before_fit( |
| 236 | + self, |
| 237 | + paths: AbstractPaths, |
| 238 | + model: Model, |
| 239 | + ): |
| 240 | + """ |
| 241 | + Visualise the model provided using each factor. |
| 242 | +
|
| 243 | + Models in the ModelInstance must have the same order as the factors. |
| 244 | +
|
| 245 | + Parameters |
| 246 | + ---------- |
| 247 | + paths |
| 248 | + Object describing where data should be saved to |
| 249 | + model |
| 250 | + A collection of models, each corresponding to a factor |
| 251 | + """ |
| 252 | + self._for_each_analysis( |
| 253 | + "visualize_before_fit", |
| 254 | + paths, |
| 255 | + model, |
| 256 | + ) |
| 257 | + |
| 258 | + def save_attributes(self, paths: AbstractPaths): |
| 259 | + """ |
| 260 | + Save the attributes of the analysis to the paths object. |
| 261 | + """ |
| 262 | + self._for_each_analysis("save_attributes", paths) |
| 263 | + |
| 264 | + def save_results(self, paths: AbstractPaths, result): |
| 265 | + """ |
| 266 | + Save the results of the analysis to the paths object. |
| 267 | + """ |
| 268 | + self._for_each_analysis("save_results", paths, result) |
| 269 | + |
| 270 | + def visualize_combined( |
| 271 | + self, |
| 272 | + instance, |
| 273 | + paths: AbstractPaths, |
| 274 | + during_analysis, |
| 275 | + ): |
| 276 | + self.model_factors[0].visualize_combined( |
| 277 | + self.model_factors, |
| 278 | + paths, |
| 279 | + instance, |
| 280 | + during_analysis=during_analysis, |
| 281 | + ) |
0 commit comments