Skip to content

Commit 77f8d73

Browse files
authored
Merge pull request #1172 from rhayes777/feature/model_docs
Add expanded model mapping unit tests
2 parents d3638a7 + 2d91219 commit 77f8d73

7 files changed

Lines changed: 927 additions & 0 deletions

File tree

autofit/mapper/prior/abstract.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@ def unit_value_for(self, physical_value: float) -> float:
6161
return self.message.cdf(physical_value)
6262

6363
def with_message(self, message):
64+
"""Return a copy of this prior with a different message (distribution).
65+
66+
Parameters
67+
----------
68+
message
69+
The new message object defining the prior's distribution.
70+
71+
Returns
72+
-------
73+
Prior
74+
A copy of this prior using the new message.
75+
"""
6476
new = copy(self)
6577
new.message = message
6678
return new
@@ -88,6 +100,23 @@ def factor(self):
88100

89101
@staticmethod
90102
def for_class_and_attribute_name(cls, attribute_name):
103+
"""Create a prior from the configuration for a given class and attribute.
104+
105+
Looks up the prior type and parameters in the prior config files
106+
for the specified class and attribute name.
107+
108+
Parameters
109+
----------
110+
cls
111+
The model class whose config is looked up.
112+
attribute_name
113+
The name of the attribute on that class.
114+
115+
Returns
116+
-------
117+
Prior
118+
A prior instance constructed from the config entry.
119+
"""
91120
prior_dict = conf.instance.prior_config.for_class_and_suffix_path(
92121
cls, [attribute_name]
93122
)
@@ -129,10 +158,31 @@ def instance_for_arguments(
129158
arguments,
130159
ignore_assertions=False,
131160
):
161+
"""Look up this prior's value in an arguments dictionary.
162+
163+
Parameters
164+
----------
165+
arguments
166+
A dictionary mapping Prior objects to physical values.
167+
ignore_assertions
168+
Unused for priors (present for interface compatibility).
169+
"""
132170
_ = ignore_assertions
133171
return arguments[self]
134172

135173
def project(self, samples, weights):
174+
"""Project this prior given samples and log weights from a search.
175+
176+
Returns a copy of this prior whose message has been updated to
177+
reflect the posterior information from the samples.
178+
179+
Parameters
180+
----------
181+
samples
182+
Array of sample values for this parameter.
183+
weights
184+
Log weights for each sample.
185+
"""
136186
result = copy(self)
137187
result.message = self.message.project(
138188
samples=samples,
@@ -170,6 +220,11 @@ def __str__(self):
170220
@property
171221
@abstractmethod
172222
def parameter_string(self) -> str:
223+
"""A human-readable string summarizing this prior's parameters.
224+
225+
Subclasses must implement this to return a description such as
226+
``"mean = 0.0, sigma = 1.0"`` or ``"lower_limit = 0.0, upper_limit = 1.0"``.
227+
"""
173228
pass
174229

175230
def __float__(self):
@@ -254,7 +309,22 @@ def name_of_class(cls) -> str:
254309

255310
@property
256311
def limits(self) -> Tuple[float, float]:
312+
"""The (lower, upper) bounds of this prior.
313+
314+
Returns (-inf, inf) by default. Subclasses with finite bounds
315+
(e.g. UniformPrior) override this.
316+
"""
257317
return (float("-inf"), float("inf"))
258318

259319
def gaussian_prior_model_for_arguments(self, arguments):
320+
"""Look up this prior in an arguments dict and return the mapped value.
321+
322+
Used during prior replacement workflows where each prior is mapped
323+
to a new prior or fixed value via an arguments dictionary.
324+
325+
Parameters
326+
----------
327+
arguments
328+
A dictionary mapping Prior objects to their replacement values.
329+
"""
260330
return arguments[self]

autofit/mapper/prior/gaussian.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ def __init__(
5757
)
5858

5959
def tree_flatten(self):
60+
"""Flatten this prior into a JAX-compatible PyTree representation.
61+
62+
Returns
63+
-------
64+
tuple
65+
A (children, aux_data) pair where children are (mean, sigma, id).
66+
"""
6067
return (self.mean, self.sigma, self.id), ()
6168

6269
@classmethod

autofit/mapper/prior/uniform.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,23 +64,50 @@ def __init__(
6464
)
6565

6666
def tree_flatten(self):
67+
"""Flatten this prior into a JAX-compatible PyTree representation.
68+
69+
Returns
70+
-------
71+
tuple
72+
A (children, aux_data) pair where children are (lower_limit, upper_limit, id).
73+
"""
6774
return (self.lower_limit, self.upper_limit, self.id), ()
6875

6976
@property
7077
def width(self):
78+
"""The width of the uniform distribution (upper_limit - lower_limit)."""
7179
return self.upper_limit - self.lower_limit
7280

7381
def with_limits(
7482
self,
7583
lower_limit: float,
7684
upper_limit: float,
7785
) -> "Prior":
86+
"""Create a new UniformPrior with different bounds.
87+
88+
Parameters
89+
----------
90+
lower_limit
91+
The new lower bound.
92+
upper_limit
93+
The new upper bound.
94+
"""
7895
return UniformPrior(
7996
lower_limit=lower_limit,
8097
upper_limit=upper_limit,
8198
)
8299

83100
def logpdf(self, x):
101+
"""Compute the log probability density at x.
102+
103+
Adjusts boundary values by epsilon to avoid evaluating exactly at
104+
the distribution edges where the PDF is undefined.
105+
106+
Parameters
107+
----------
108+
x
109+
The value at which to evaluate the log PDF.
110+
"""
84111
# TODO: handle x as a numpy array
85112
if x == self.lower_limit:
86113
x += epsilon
@@ -102,6 +129,7 @@ def dict(self) -> dict:
102129

103130
@property
104131
def parameter_string(self) -> str:
132+
"""A human-readable string summarizing the prior's lower and upper limits."""
105133
return f"lower_limit = {self.lower_limit}, upper_limit = {self.upper_limit}"
106134

107135
def value_for(self, unit: float) -> float:
@@ -142,4 +170,5 @@ def log_prior_from_value(self, value, xp=np):
142170

143171
@property
144172
def limits(self) -> Tuple[float, float]:
173+
"""The (lower_limit, upper_limit) bounds of this uniform prior."""
145174
return self.lower_limit, self.upper_limit

0 commit comments

Comments
 (0)