Skip to content

Commit cde3ba4

Browse files
authored
Merge pull request #1263 from PyAutoLabs/feature/priors-jax-native
feat: JAX-native priors — xp dispatch on value_for / log_prior_from_value / vector_from_unit_vector
2 parents 4652ed8 + 2e35407 commit cde3ba4

9 files changed

Lines changed: 140 additions & 37 deletions

File tree

autofit/mapper/prior/abstract.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from copy import copy
55
from typing import Union, Tuple, Optional, Dict
66

7+
import numpy as np
8+
79
from autoconf import conf
810

911
from autofit.mapper.prior.arithmetic import ArithmeticMixin
@@ -137,7 +139,7 @@ def random(
137139
)
138140
)
139141

140-
def value_for(self, unit: float) -> float:
142+
def value_for(self, unit, xp=np):
141143
"""
142144
Return a physical value for a value between 0 and 1 with the transformation
143145
described by this prior.
@@ -146,6 +148,11 @@ def value_for(self, unit: float) -> float:
146148
----------
147149
unit
148150
A unit value between 0 and 1.
151+
xp
152+
Array-module to dispatch on (``numpy`` or ``jax.numpy``). Default ``numpy``.
153+
Concrete subclasses override this method to provide a JAX-traceable
154+
closed-form when ``xp`` is ``jax.numpy``; the base path delegates to
155+
the message stack which is scipy-backed and NumPy-only.
149156
150157
Returns
151158
-------

autofit/mapper/prior/gaussian.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,21 @@ def parameter_string(self) -> str:
113113
Return a human-readable string summarizing the GaussianPrior parameters.
114114
"""
115115
return f"mean = {self.mean}, sigma = {self.sigma}"
116+
117+
def value_for(self, unit, xp=np):
118+
"""
119+
Map a unit value in [0, 1] to a physical value drawn from this Gaussian prior.
120+
121+
Parameters
122+
----------
123+
unit
124+
A unit value between 0 and 1.
125+
xp
126+
Array-module to dispatch on (``numpy`` or ``jax.numpy``). Default ``numpy``.
127+
The NumPy path delegates to the message stack (``erfinv`` via scipy); the
128+
JAX path uses the same closed-form via ``jax.scipy.special.erfinv``.
129+
"""
130+
if xp is np:
131+
return self.message.value_for(unit)
132+
from jax.scipy.special import erfinv
133+
return self.mean + self.sigma * xp.sqrt(2.0) * erfinv(2.0 * unit - 1.0)

autofit/mapper/prior/log_gaussian.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def _new_for_base_message(self, message):
113113
id_=self.instance().id,
114114
)
115115

116-
def value_for(self, unit: float) -> float:
116+
def value_for(self, unit, xp=np):
117117
"""
118118
Return a physical value for a value between 0 and 1 with the transformation
119119
described by this prior.
@@ -122,21 +122,33 @@ def value_for(self, unit: float) -> float:
122122
----------
123123
unit
124124
A unit value between 0 and 1.
125+
xp
126+
Array-module to dispatch on (``numpy`` or ``jax.numpy``). Default ``numpy``.
127+
The NumPy path delegates to the message stack; the JAX path uses the
128+
closed-form ``exp(mean + sigma * sqrt(2) * erfinv(2*unit - 1))``.
125129
126130
Returns
127131
-------
128132
A physical value, mapped from the unit value accoridng to the prior.
129133
"""
130-
return super().value_for(unit)
134+
if xp is np:
135+
return super().value_for(unit)
136+
from jax.scipy.special import erfinv
137+
log_value = self.mean + self.sigma * xp.sqrt(2.0) * erfinv(2.0 * unit - 1.0)
138+
return xp.exp(log_value)
131139

132140
@property
133141
def parameter_string(self) -> str:
134142
return f"mean = {self.mean}, sigma = {self.sigma}"
135143

136144
def log_prior_from_value(self, value, xp=np):
137-
if value <= 0:
138-
return float("-inf")
139-
140-
return self.message.base_message.log_prior_from_value(np.log(value)) - np.log(
141-
value
142-
)
145+
if xp is np:
146+
if value <= 0:
147+
return float("-inf")
148+
return self.message.base_message.log_prior_from_value(
149+
np.log(value)
150+
) - np.log(value)
151+
152+
log_value = xp.log(value)
153+
base_log_prior = (log_value - self.mean) ** 2 / (2 * self.sigma ** 2)
154+
return xp.where(value > 0, base_log_prior - log_value, xp.inf)

autofit/mapper/prior/log_uniform.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,17 +121,23 @@ def log_prior_from_value(self, value, xp=np) -> float:
121121
----------
122122
value : float
123123
The physical value of this prior's corresponding parameter in a `NonLinearSearch` sample.
124+
xp
125+
Array-module to dispatch on (``numpy`` or ``jax.numpy``). Default ``numpy``.
124126
"""
125127
return 1.0 / value
126128

127-
def value_for(self, unit: float) -> float:
129+
def value_for(self, unit, xp=np):
128130
"""
129131
Returns a physical value from an input unit value according to the limits of the log10 uniform prior.
130132
131133
Parameters
132134
----------
133135
unit
134136
A unit value between 0 and 1.
137+
xp
138+
Array-module to dispatch on (``numpy`` or ``jax.numpy``). Default ``numpy``.
139+
The NumPy path delegates to the message stack (scipy-backed); the JAX
140+
path uses the closed-form ``lower * (upper / lower) ** unit``.
135141
136142
Returns
137143
-------
@@ -145,7 +151,9 @@ def value_for(self, unit: float) -> float:
145151
146152
physical_value = prior.value_for(unit=0.2)
147153
"""
148-
return super().value_for(unit)
154+
if xp is np:
155+
return super().value_for(unit)
156+
return self.lower_limit * (self.upper_limit / self.lower_limit) ** unit
149157

150158
def dict(self) -> dict:
151159
"""

autofit/mapper/prior/truncated_gaussian.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Optional, Tuple
22

3+
import numpy as np
4+
35
from autofit.messages.truncated_normal import TruncatedNormalMessage
46
from .abstract import Prior
57

@@ -130,3 +132,32 @@ def parameter_string(self) -> str:
130132
f"lower_limit = {self.lower_limit}, "
131133
f"upper_limit = {self.upper_limit}"
132134
)
135+
136+
def value_for(self, unit, xp=np):
137+
"""
138+
Map a unit value in [0, 1] to a physical value drawn from this truncated Gaussian prior.
139+
140+
Parameters
141+
----------
142+
unit
143+
A unit value between 0 and 1.
144+
xp
145+
Array-module to dispatch on (``numpy`` or ``jax.numpy``). Default ``numpy``.
146+
Both paths share the standard truncated-normal inverse-CDF construction
147+
via ``norm.cdf`` / ``norm.ppf`` from the matching ``scipy.stats`` /
148+
``jax.scipy.stats`` namespace.
149+
"""
150+
if xp is np:
151+
from scipy.stats import norm
152+
else:
153+
from jax.scipy.stats import norm
154+
155+
a = (self.lower_limit - self.mean) / self.sigma
156+
b = (self.upper_limit - self.mean) / self.sigma
157+
158+
lower_cdf = norm.cdf(a)
159+
upper_cdf = norm.cdf(b)
160+
truncated_cdf = lower_cdf + unit * (upper_cdf - lower_cdf)
161+
162+
x_standard = norm.ppf(truncated_cdf)
163+
return self.mean + self.sigma * x_standard

autofit/mapper/prior/uniform.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,19 @@ def parameter_string(self) -> str:
132132
"""A human-readable string summarizing the prior's lower and upper limits."""
133133
return f"lower_limit = {self.lower_limit}, upper_limit = {self.upper_limit}"
134134

135-
def value_for(self, unit: float) -> float:
135+
def value_for(self, unit, xp=np):
136136
"""
137137
Returns a physical value from an input unit value according to the limits of the uniform prior.
138138
139139
Parameters
140140
----------
141141
unit
142142
A unit value between 0 and 1.
143+
xp
144+
Array-module to dispatch on (``numpy`` or ``jax.numpy``). Default ``numpy``.
145+
The NumPy path preserves the historical ``float(round(..., 14))`` snap
146+
(used as a hash key in ``model.priors``); the JAX path uses the
147+
closed-form ``lower + (upper - lower) * unit`` so the trace stays symbolic.
143148
144149
Returns
145150
-------
@@ -153,9 +158,11 @@ def value_for(self, unit: float) -> float:
153158
154159
physical_value = prior.value_for(unit=0.2)
155160
"""
156-
return float(
157-
round(super().value_for(unit), 14)
158-
)
161+
if xp is np:
162+
return float(
163+
round(super().value_for(unit), 14)
164+
)
165+
return self.lower_limit + (self.upper_limit - self.lower_limit) * unit
159166

160167
def log_prior_from_value(self, value, xp=np):
161168
"""
@@ -166,7 +173,10 @@ def log_prior_from_value(self, value, xp=np):
166173
167174
For a UniformPrior this is always zero, provided the value is between the lower and upper limit.
168175
"""
169-
return 0.0
176+
if xp is np:
177+
return 0.0
178+
in_bounds = (value >= self.lower_limit) & (value <= self.upper_limit)
179+
return xp.where(in_bounds, xp.zeros_like(value), -xp.inf)
170180

171181
@property
172182
def limits(self) -> Tuple[float, float]:

autofit/mapper/prior_model/abstract.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -645,26 +645,41 @@ def priors_ordered_by_id(self):
645645
"""Unique priors sorted by their id, defining the canonical parameter ordering."""
646646
return [prior for _, prior in self.prior_tuples_ordered_by_id]
647647

648-
def vector_from_unit_vector(self, unit_vector):
648+
def vector_from_unit_vector(self, unit_vector, xp=np):
649649
"""
650650
Parameters
651651
----------
652652
unit_vector: [float]
653653
A unit hypercube vector
654+
xp
655+
Array-module to dispatch on (``numpy`` or ``jax.numpy``). Default ``numpy``.
656+
When ``xp is numpy`` the return type stays a plain ``list`` to preserve
657+
the existing contract used by Nautilus / Dynesty / Emcee / Zeus. When
658+
``xp is jax.numpy`` an ``xp.stack`` 1-D array is returned so the call
659+
is JIT-traceable end-to-end.
654660
655661
Returns
656662
-------
657663
values: [float]
658664
A vector with values output by priors
659665
"""
660-
return list(
661-
map(
662-
lambda prior_tuple, unit: prior_tuple.prior.value_for(
663-
unit,
664-
),
665-
self.prior_tuples_ordered_by_id,
666-
unit_vector,
666+
if xp is np:
667+
return list(
668+
map(
669+
lambda prior_tuple, unit: prior_tuple.prior.value_for(
670+
unit,
671+
),
672+
self.prior_tuples_ordered_by_id,
673+
unit_vector,
674+
)
667675
)
676+
return xp.stack(
677+
[
678+
prior_tuple.prior.value_for(unit, xp=xp)
679+
for prior_tuple, unit in zip(
680+
self.prior_tuples_ordered_by_id, unit_vector
681+
)
682+
]
668683
)
669684

670685
def random_unit_vector_within_limits(

autofit/messages/normal.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -399,14 +399,16 @@ def logpdf_gradient_hessian(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray
399399

400400
__default_fields__ = ("log_norm", "id_")
401401

402-
def value_for(self, unit: float) -> float:
402+
def value_for(self, unit, xp=np):
403403
"""
404404
Map a unit value in [0, 1] to a physical value drawn from this Gaussian prior.
405405
406406
Parameters
407407
----------
408408
unit
409409
A unit value between 0 and 1 representing a uniform draw.
410+
xp
411+
Array-module to dispatch on (``numpy`` or ``jax.numpy``). Default ``numpy``.
410412
411413
Returns
412414
-------
@@ -417,15 +419,13 @@ def value_for(self, unit: float) -> float:
417419
>>> prior = af.GaussianPrior(mean=1.0, sigma=2.0)
418420
>>> physical_value = prior.value_for(unit=0.5)
419421
"""
420-
if isinstance(unit, (np.ndarray, np.float64, float, int, list)):
421-
from scipy.special import erfinv as scipy_erfinv
422-
inv = scipy_erfinv(1 - 2.0 * (1.0 - unit))
422+
if xp is np:
423+
from scipy.special import erfinv
423424
else:
424-
import jax.numpy as jnp
425-
from jax._src.scipy.special import erfinv
426-
inv = erfinv(1 - 2.0 * (1.0 - unit))
425+
from jax.scipy.special import erfinv
427426

428-
return self.mean + (self.sigma * np.sqrt(2) * inv)
427+
inv = erfinv(2.0 * unit - 1.0)
428+
return self.mean + self.sigma * xp.sqrt(2.0) * inv
429429

430430
def log_prior_from_value(self, value: float, xp=np) -> float:
431431
"""

autofit/messages/truncated_normal.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def logpdf_gradient_hessian(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray
440440

441441
__default_fields__ = ("log_norm", "id_")
442442

443-
def value_for(self, unit: float) -> float:
443+
def value_for(self, unit, xp=np):
444444
"""
445445
Map a unit value in [0, 1] to a physical value drawn from this truncated Gaussian prior.
446446
@@ -451,6 +451,8 @@ def value_for(self, unit: float) -> float:
451451
----------
452452
unit
453453
A unit value between 0 and 1 representing a uniform draw.
454+
xp
455+
Array-module to dispatch on (``numpy`` or ``jax.numpy``). Default ``numpy``.
454456
455457
Returns
456458
-------
@@ -461,18 +463,18 @@ def value_for(self, unit: float) -> float:
461463
>>> prior = af.TruncatedNormalMessage(mean=1.0, sigma=2.0, lower_limit=0.0, upper_limit=2.0)
462464
>>> physical_value = prior.value_for(unit=0.5)
463465
"""
464-
from scipy.stats import norm
466+
if xp is np:
467+
from scipy.stats import norm
468+
else:
469+
from jax.scipy.stats import norm
465470

466-
# Standardized truncation bounds
467471
a = (self.lower_limit - self.mean) / self.sigma
468472
b = (self.upper_limit - self.mean) / self.sigma
469473

470-
# Interpolate unit into [Phi(a), Phi(b)]
471474
lower_cdf = norm.cdf(a)
472475
upper_cdf = norm.cdf(b)
473476
truncated_cdf = lower_cdf + unit * (upper_cdf - lower_cdf)
474477

475-
# Map back to x using inverse CDF, then rescale
476478
x_standard = norm.ppf(truncated_cdf)
477479
return self.mean + self.sigma * x_standard
478480

0 commit comments

Comments
 (0)