Skip to content

Commit 6410f2d

Browse files
Jammy2211claude
authored andcommitted
perf: direct-ndtr fast path for TruncatedGaussianPrior.value_for
Replace scipy.stats.norm.cdf/ppf inside the truncated-normal inverse-CDF path with direct scipy.special.ndtr / ndtri (and the jax.scipy.special equivalents on the JAX branch). The wrapper-free path skips scipy.stats._distn_infrastructure dispatch -- which the graphical-ep-scale-up cProfile baseline (autofit_workspace_developer PR #17) showed was the #1 hotspot in TruncatedGaussianPrior.value_for (~33% of total wall time at N=10). The new helper module autofit.mapper.prior._erf_helpers exposes truncated_normal_value_for(...) which both TruncatedGaussianPrior and TruncatedNormalMessage now route through. ndtr/ndtri are bit-exact equivalents of scipy.stats.norm.cdf/ppf (same Cephes routines). Measured on the autofit_workspace_developer toy 1D Gaussian baseline: - graphical N=3: 22.8s -> 5.6s (4.04x, 75% reduction) - EP N=3: 251.9s -> 76.3s (3.30x, 70% reduction) Sanity blocks PASS on both; max log L matches pre-fix within Dynesty stochastic noise (~1e-3 rel - same scale as re-running with a different seed). Closes #1284. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a89fa59 commit 6410f2d

4 files changed

Lines changed: 189 additions & 31 deletions

File tree

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""Direct-`ndtr` primitives for hot prior-transform paths.
2+
3+
Replaces `scipy.stats.norm.cdf` / `norm.ppf` (and their `jax.scipy.stats`
4+
counterparts) with direct calls to `scipy.special.ndtr` / `ndtri` — the
5+
Cephes routines that scipy.stats wraps. Bit-exact equivalent on both
6+
NumPy and JAX backends, but skips the
7+
`scipy.stats._distn_infrastructure` wrapper overhead — which the
8+
graphical-ep-scale-up cProfile baseline showed was the #1 hotspot in
9+
`TruncatedGaussianPrior.value_for` (~33% of total wall time at N=10).
10+
11+
See PyAutoFit issue #1284 for the motivating measurements.
12+
"""
13+
14+
import numpy as np
15+
16+
17+
def _norm_cdf(z, xp):
18+
"""Standard-normal CDF (== ``scipy.stats.norm.cdf(z)`` to ULPs)."""
19+
if xp is np:
20+
from scipy.special import ndtr
21+
else:
22+
from jax.scipy.special import ndtr
23+
return ndtr(z)
24+
25+
26+
def _norm_ppf(p, xp):
27+
"""Standard-normal PPF (== ``scipy.stats.norm.ppf(p)`` to ULPs)."""
28+
if xp is np:
29+
from scipy.special import ndtri
30+
else:
31+
from jax.scipy.special import ndtri
32+
return ndtri(p)
33+
34+
35+
def truncated_normal_value_for(unit, mean, sigma, lower_limit, upper_limit, xp=np):
36+
"""Inverse-CDF mapping for a truncated normal distribution.
37+
38+
Returns ``mean + sigma * Phi^{-1}(Phi(a) + unit * (Phi(b) - Phi(a)))``
39+
where ``a = (lower_limit - mean) / sigma`` and
40+
``b = (upper_limit - mean) / sigma``.
41+
42+
Used by ``TruncatedGaussianPrior.value_for`` and
43+
``TruncatedNormalMessage.value_for`` to share a single
44+
`scipy.special.erf`-based code path on both NumPy and JAX backends.
45+
46+
Parameters
47+
----------
48+
unit
49+
Unit-cube draw(s) in ``[0, 1]``. Scalar or array.
50+
mean, sigma
51+
Underlying-Gaussian mean and standard deviation.
52+
lower_limit, upper_limit
53+
Truncation bounds. ``-inf`` / ``+inf`` are supported.
54+
xp
55+
Array module: ``numpy`` (default) or ``jax.numpy``. Determines
56+
whether ``scipy.special`` or ``jax.scipy.special`` is used.
57+
58+
Returns
59+
-------
60+
Physical sample(s) drawn from the truncated normal.
61+
"""
62+
a = (lower_limit - mean) / sigma
63+
b = (upper_limit - mean) / sigma
64+
65+
lower_cdf = _norm_cdf(a, xp)
66+
upper_cdf = _norm_cdf(b, xp)
67+
truncated_cdf = lower_cdf + unit * (upper_cdf - lower_cdf)
68+
69+
x_standard = _norm_ppf(truncated_cdf, xp)
70+
return mean + sigma * x_standard

autofit/mapper/prior/truncated_gaussian.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -143,21 +143,12 @@ def value_for(self, unit, xp=np):
143143
A unit value between 0 and 1.
144144
xp
145145
Array-module to dispatch on (``numpy`` or ``jax.numpy``). Default ``numpy``.
146-
Both paths share the standard truncated-normal inverse-CDF construction
147-
via ``norm.cdf`` / ``norm.ppf`` from the matching ``scipy.stats`` /
148-
``jax.scipy.stats`` namespace.
146+
Delegates to ``_erf_helpers.truncated_normal_value_for``, which uses
147+
``scipy.special.erf`` / ``erfinv`` (or the ``jax.scipy.special``
148+
equivalents) directly — skipping the ``scipy.stats`` wrapper that
149+
previously dominated this hot path.
149150
"""
150-
if xp is np:
151-
from scipy.stats import norm
152-
else:
153-
from jax.scipy.stats import norm
154-
155-
a = (self.lower_limit - self.mean) / self.sigma
156-
b = (self.upper_limit - self.mean) / self.sigma
157-
158-
lower_cdf = norm.cdf(a)
159-
upper_cdf = norm.cdf(b)
160-
truncated_cdf = lower_cdf + unit * (upper_cdf - lower_cdf)
161-
162-
x_standard = norm.ppf(truncated_cdf)
163-
return self.mean + self.sigma * x_standard
151+
from autofit.mapper.prior._erf_helpers import truncated_normal_value_for
152+
return truncated_normal_value_for(
153+
unit, self.mean, self.sigma, self.lower_limit, self.upper_limit, xp,
154+
)

autofit/messages/truncated_normal.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -463,20 +463,10 @@ def value_for(self, unit, xp=np):
463463
>>> prior = af.TruncatedNormalMessage(mean=1.0, sigma=2.0, lower_limit=0.0, upper_limit=2.0)
464464
>>> physical_value = prior.value_for(unit=0.5)
465465
"""
466-
if xp is np:
467-
from scipy.stats import norm
468-
else:
469-
from jax.scipy.stats import norm
470-
471-
a = (self.lower_limit - self.mean) / self.sigma
472-
b = (self.upper_limit - self.mean) / self.sigma
473-
474-
lower_cdf = norm.cdf(a)
475-
upper_cdf = norm.cdf(b)
476-
truncated_cdf = lower_cdf + unit * (upper_cdf - lower_cdf)
477-
478-
x_standard = norm.ppf(truncated_cdf)
479-
return self.mean + self.sigma * x_standard
466+
from autofit.mapper.prior._erf_helpers import truncated_normal_value_for
467+
return truncated_normal_value_for(
468+
unit, self.mean, self.sigma, self.lower_limit, self.upper_limit, xp,
469+
)
480470

481471
def log_prior_from_value(self, value: float, xp=np) -> float:
482472
"""

test_autofit/mapper/prior/test_truncated_gaussian.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import math
2+
13
import numpy as np
24
import pytest
5+
from scipy.stats import norm, truncnorm
36

47
import autofit as af
8+
from autofit.messages.truncated_normal import TruncatedNormalMessage
59

610

711
@pytest.fixture(name="truncated_gaussian")
@@ -34,3 +38,106 @@ def test__log_prior_from_value(truncated_gaussian, unit, value):
3438
assert truncated_gaussian.log_prior_from_value(unit) == pytest.approx(value, rel=0.1)
3539

3640

41+
# --- Numerical equivalence: new direct-ndtr path vs the OLD scipy.stats.norm
42+
# CDF/PPF composition that this PR replaces. They must be bit-exact equal —
43+
# that's the "numerics don't change" gate.
44+
45+
PARAMS = [
46+
# (mean, sigma, lower_limit, upper_limit)
47+
(0.0, 1.0, -3.0, 3.0), # symmetric, moderate
48+
(0.0, 1.0, -10.0, 10.0), # very wide
49+
(5.0, 2.0, 0.0, math.inf), # half-bounded (matches toy normalization)
50+
(5.0, 5.0, 0.0, math.inf), # half-bounded (matches toy sigma)
51+
(1.0, 2.0, 0.95, 1.05), # narrow bracket
52+
(0.0, 1.0, -0.001, 0.001), # very narrow
53+
]
54+
55+
UNITS = [1e-6, 1e-3, 0.1, 0.3, 0.5, 0.7, 0.9, 1 - 1e-3, 1 - 1e-6]
56+
57+
58+
def _old_value_for(unit, mean, sigma, lower, upper):
59+
"""Reproduces the pre-refactor scipy.stats.norm.cdf/ppf composition.
60+
This is the algorithm whose results must be preserved."""
61+
a = (lower - mean) / sigma
62+
b = (upper - mean) / sigma
63+
lower_cdf = norm.cdf(a)
64+
upper_cdf = norm.cdf(b)
65+
truncated_cdf = lower_cdf + unit * (upper_cdf - lower_cdf)
66+
x_standard = norm.ppf(truncated_cdf)
67+
return mean + sigma * x_standard
68+
69+
70+
@pytest.mark.parametrize("mean,sigma,lower,upper", PARAMS)
71+
@pytest.mark.parametrize("unit", UNITS)
72+
def test__prior_value_for_bit_exact_to_old_path(unit, mean, sigma, lower, upper):
73+
"""`TruncatedGaussianPrior.value_for` must produce results bit-exact to
74+
the pre-refactor scipy.stats.norm.cdf/ppf composition that this PR
75+
replaces. This is the "numerics don't change" gate at the algorithmic
76+
level — both paths share the same ndtr/ndtri Cephes routines, only the
77+
Python-side wrapper differs."""
78+
prior = af.TruncatedGaussianPrior(
79+
mean=mean, sigma=sigma, lower_limit=lower, upper_limit=upper,
80+
)
81+
expected = float(_old_value_for(unit, mean, sigma, lower, upper))
82+
actual = float(prior.value_for(unit))
83+
84+
if expected == 0.0:
85+
assert actual == 0.0
86+
else:
87+
# Same Cephes routines under the hood — must be bit-exact.
88+
assert actual == expected, f"new={actual!r} old={expected!r}"
89+
90+
91+
@pytest.mark.parametrize("mean,sigma,lower,upper", PARAMS)
92+
@pytest.mark.parametrize("unit", [0.1, 0.3, 0.5, 0.7, 0.9])
93+
def test__prior_value_for_close_to_scipy_truncnorm(unit, mean, sigma, lower, upper):
94+
"""`TruncatedGaussianPrior.value_for` matches scipy.stats.truncnorm.ppf
95+
away from the deep tails. scipy.stats.truncnorm uses its own tail-safe
96+
branching that the simple ndtr/ndtri composition does not — so this
97+
test deliberately covers only ``unit in [0.1, 0.9]`` where both paths
98+
are stable. Documents the precision regime; not a regression gate."""
99+
prior = af.TruncatedGaussianPrior(
100+
mean=mean, sigma=sigma, lower_limit=lower, upper_limit=upper,
101+
)
102+
a = (lower - mean) / sigma
103+
b = (upper - mean) / sigma
104+
expected = float(truncnorm.ppf(unit, a=a, b=b, loc=mean, scale=sigma))
105+
actual = float(prior.value_for(unit))
106+
107+
if expected == 0.0:
108+
assert actual == pytest.approx(0.0, abs=1e-12)
109+
else:
110+
assert actual == pytest.approx(expected, rel=1e-10)
111+
112+
113+
@pytest.mark.parametrize("mean,sigma,lower,upper", PARAMS)
114+
@pytest.mark.parametrize("unit", UNITS)
115+
def test__message_value_for_matches_prior(unit, mean, sigma, lower, upper):
116+
"""`TruncatedNormalMessage.value_for` must produce the same output as
117+
`TruncatedGaussianPrior.value_for` for matching parameters — both now
118+
route through the shared helper, so the equality is bit-exact."""
119+
prior = af.TruncatedGaussianPrior(
120+
mean=mean, sigma=sigma, lower_limit=lower, upper_limit=upper,
121+
)
122+
message = TruncatedNormalMessage(
123+
mean=mean, sigma=sigma, lower_limit=lower, upper_limit=upper,
124+
)
125+
assert float(message.value_for(unit)) == float(prior.value_for(unit))
126+
127+
128+
def test__jax_value_for_parity():
129+
"""JAX path must match the numpy path to within float64 rounding noise.
130+
131+
Uses moderate (half-bounded) parameters representative of the toy model.
132+
Skipped if jax is not installed; CI / dev installs both.
133+
"""
134+
jax = pytest.importorskip("jax")
135+
jnp = jax.numpy
136+
137+
prior = af.TruncatedGaussianPrior(
138+
mean=5.0, sigma=2.0, lower_limit=0.0, upper_limit=math.inf,
139+
)
140+
for unit in [0.1, 0.5, 0.9]:
141+
numpy_val = float(prior.value_for(unit, xp=np))
142+
jax_val = float(prior.value_for(jnp.asarray(unit), xp=jnp))
143+
assert jax_val == pytest.approx(numpy_val, rel=1e-9)

0 commit comments

Comments
 (0)