Skip to content

Commit 2cddc73

Browse files
Jammy2211claude
authored and committed
feat: implement gradient/Hessian for TruncatedNormalMessage
TruncatedNormalMessage._normal_gradient_hessian previously raised NotImplementedError, which silently broke EP + Laplace optimisation for any GaussianPrior with finite limits — the EP outer loop catches the exception per factor, so runs looked successful but truncated parameters never received their cavity update.

Implement the gradient and Hessian w.r.t. x: inside the truncation support these match the untruncated Gaussian (Z and log sigma are constant in x); logl additionally subtracts log Z. Outside the support, return logl = -inf and grad = 0 so the optimiser sees a flat region rather than NaNs.

Override on TruncatedNaturalNormal because its mean / sigma properties return *truncated* moments (via scipy.stats.truncnorm), not the underlying Gaussian's parameters that the gradient formula needs; the override reconstructs the underlying (mu, sigma) from the natural parameters.

Closes #1237

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 4d93bf0 commit 2cddc73

3 files changed

Lines changed: 246 additions & 1 deletion

File tree

autofit/messages/truncated_normal.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,62 @@ def from_mode(
349349
def _normal_gradient_hessian(
    self, x: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute the log-pdf, gradient, and Hessian of the truncated Gaussian
    with respect to ``x``.

    Inside the truncation support the gradient and Hessian are identical
    to the untruncated Gaussian's (the truncation normalisation ``Z`` and
    ``log σ`` are constants in ``x``). The ``logl`` value picks up an
    extra ``-log Z`` correction. Outside the support, ``logl`` is ``-inf``
    and the gradient is zeroed so the optimiser sees a flat region rather
    than NaNs that would crash linesearch.

    Parameters
    ----------
    x
        Point(s) at which to evaluate the truncated Gaussian.

    Returns
    -------
    The tuple ``(logl, grad_logl, hess_logl)`` produced by
    ``_normal_gradient_hessian_from`` for this message's ``mean`` and
    ``sigma``.
    """
    # The pre-commit body (``raise NotImplementedError``) is gone: this
    # method now simply delegates with the message's own (mean, sigma).
    return self._normal_gradient_hessian_from(self.mean, self.sigma, x)
364+
365+
def _normal_gradient_hessian_from(
    self,
    mean: Union[float, np.ndarray],
    sigma: Union[float, np.ndarray],
    x: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Log-pdf, gradient and Hessian w.r.t. ``x`` of a Gaussian with the given
    ``mean`` and ``sigma`` truncated to ``[self.lower_limit, self.upper_limit]``.

    Parameters
    ----------
    mean
        Mean of the underlying (untruncated) Gaussian.
    sigma
        Standard deviation of the underlying (untruncated) Gaussian.
    x
        Scalar or array of evaluation points.

    Returns
    -------
    logl
        ``self.log_base_measure - (x - mean)**2 / (2 sigma**2) - log(sigma)
        - log Z``; ``-inf`` wherever ``x`` lies outside the limits.
    grad_logl
        ``-(x - mean) / sigma**2``; ``0`` wherever ``x`` lies outside the
        limits.
    hess_logl
        ``-1 / sigma**2``. For array ``x`` whose trailing shape equals
        ``self.shape`` it is repeated along a new leading axis of length
        ``x.shape[0]``.
    """
    from scipy.stats import norm

    a = (self.lower_limit - mean) / sigma
    b = (self.upper_limit - mean) / sigma

    # P(a < Z < b) == P(-b < Z < -a) by symmetry. When both limits sit in
    # the upper tail, cdf(b) - cdf(a) cancels to 0 (both round to 1.0), so
    # evaluate the mass in the mirrored lower tail instead.
    upper_tail = a > 0
    lo = np.where(upper_tail, -b, a)
    hi = np.where(upper_tail, -a, b)
    Z = norm.cdf(hi) - norm.cdf(lo)

    # Element-wise guard: a plain ``if Z > 0`` is ambiguous (raises) when
    # ``mean`` / ``sigma`` are arrays. The inner ``where`` keeps ``log``
    # away from non-positive values so no runtime warning is emitted.
    positive_mass = Z > 0
    log_Z = np.where(
        positive_mass, np.log(np.where(positive_mass, Z, 1.0)), -np.inf
    )

    shape = np.shape(x)
    if shape:
        x = np.asanyarray(x)

    deltax = x - mean
    hess_logl = -sigma ** -2
    grad_logl = deltax * hess_logl
    eta_t = 0.5 * grad_logl * deltax
    logl = self.log_base_measure + eta_t - np.log(sigma) - log_Z

    if shape:
        # Array input: mask out-of-support entries element-wise.
        in_bounds = (x >= self.lower_limit) & (x <= self.upper_limit)
        logl = np.where(in_bounds, logl, -np.inf)
        grad_logl = np.where(in_bounds, grad_logl, 0.0)

        if shape[1:] == self.shape:
            # Leading axis of ``x`` indexes samples: give the (constant)
            # Hessian one entry per sample.
            hess_logl = np.repeat(
                np.reshape(hess_logl, (1,) + np.shape(hess_logl)), shape[0], 0
            )
    elif not (self.lower_limit <= x <= self.upper_limit):
        # Scalar input outside the support.
        logl = -np.inf
        grad_logl = 0.0

    return logl, grad_logl, hess_logl
353408

354409
def logpdf_gradient(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
355410
"""
@@ -644,6 +699,20 @@ def natural_parameters(self, xp=np) -> np.ndarray:
644699
"""
645700
return self.calc_natural_parameters(*self.parameters, self.lower_limit, self.upper_limit, xp=xp)
646701

702+
def _normal_gradient_hessian(
    self, x: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Log-pdf, gradient and Hessian w.r.t. ``x``, evaluated with the
    underlying Gaussian's (μ, σ).

    ``self.mean`` and ``self.sigma`` here are the truncated moments
    (scipy.stats.truncnorm.mean / .std), not the underlying Gaussian. The
    gradient/Hessian formula needs the underlying (μ, σ), which are
    reconstructed from the natural parameters held in ``self.parameters``.

    Returns ``(nan, nan, nan)`` when the implied precision is not a
    strictly positive, finite number.
    """
    eta_2 = self.parameters[1]
    precision = -2 * eta_2

    # Every entry must be strictly positive and finite to define a Gaussian.
    precision_is_valid = np.all((precision > 0) & np.isfinite(precision))
    if not precision_is_valid:
        return np.nan, np.nan, np.nan

    eta_1 = self.parameters[0]
    mean_underlying = -eta_1 / (2 * eta_2)
    sigma_underlying = precision ** -0.5
    return self._normal_gradient_hessian_from(mean_underlying, sigma_underlying, x)
715+
647716
@classmethod
648717
def invert_sufficient_statistics(
649718
cls,

test_autofit/messages/__init__.py

Whitespace-only changes.
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import numpy as np
2+
import pytest
3+
from scipy.stats import norm
4+
5+
from autofit.messages.normal import NormalMessage
6+
from autofit.messages.truncated_normal import (
7+
TruncatedNaturalNormal,
8+
TruncatedNormalMessage,
9+
)
10+
11+
12+
@pytest.fixture
def truncated_message():
    """Unit Gaussian (mean 0, sigma 1) truncated to [-1.0, 0.5]."""
    # Bounds chosen so the inner test points are well inside the support
    # but Z is meaningfully less than 1 (Z = Φ(0.5) - Φ(-1) ≈ 0.53).
    return TruncatedNormalMessage(
        mean=0.0, sigma=1.0, lower_limit=-1.0, upper_limit=0.5
    )
17+
18+
19+
@pytest.fixture
def reference_normal():
    """Untruncated unit Gaussian used as the comparison baseline."""
    untruncated = NormalMessage(mean=0.0, sigma=1.0)
    return untruncated
22+
23+
24+
def _expected_log_Z(message):
    """
    Reference value of ``log Z`` for ``message``: the log of the standard
    normal mass between its standardised lower and upper limits.
    """
    standardised_lower = (message.lower_limit - message.mean) / message.sigma
    standardised_upper = (message.upper_limit - message.mean) / message.sigma
    mass = norm.cdf(standardised_upper) - norm.cdf(standardised_lower)
    return float(np.log(mass))
28+
29+
30+
def test_in_support_gradient_and_hessian_match_normal(
    truncated_message, reference_normal
):
    """Inside the support, gradient and Hessian equal the untruncated ones."""
    points = np.array([-0.5, 0.0, 0.25])

    _, truncated_grad, truncated_hess = truncated_message.logpdf_gradient_hessian(
        points
    )
    _, normal_grad, normal_hess = reference_normal.logpdf_gradient_hessian(points)

    assert np.allclose(truncated_grad, normal_grad)
    # Hessian for a univariate Gaussian is the same scalar at every x.
    assert np.allclose(truncated_hess, normal_hess)
41+
42+
43+
def test_in_support_logl_differs_by_minus_log_Z(
    truncated_message, reference_normal
):
    """Inside the support, the truncated log-pdf is the normal one minus log Z."""
    points = np.array([-0.5, 0.0, 0.25])

    truncated_logl, _, _ = truncated_message.logpdf_gradient_hessian(points)
    normal_logl, _, _ = reference_normal.logpdf_gradient_hessian(points)

    expected_logl = normal_logl - _expected_log_Z(truncated_message)
    assert np.allclose(truncated_logl, expected_logl)
52+
53+
54+
def test_out_of_support_array(truncated_message):
    """Array input: out-of-support entries get logl = -inf and zero gradient."""
    samples = np.array([-2.0, -0.5, 1.5])  # below, inside, above
    below, inside, above = 0, 1, 2

    log_density, gradient, _ = truncated_message.logpdf_gradient_hessian(samples)

    assert np.isneginf(log_density[below])
    assert np.isfinite(log_density[inside])
    assert np.isneginf(log_density[above])

    assert gradient[below] == 0.0
    assert gradient[above] == 0.0
    assert gradient[inside] != 0.0
66+
67+
68+
def test_out_of_support_scalar_below(truncated_message):
    """A scalar below the lower limit yields logl = -inf and zero gradient."""
    result = truncated_message.logpdf_gradient_hessian(-2.0)
    log_density, gradient, _ = result
    assert np.isneginf(log_density)
    assert gradient == 0.0
72+
73+
74+
def test_out_of_support_scalar_above(truncated_message):
    """A scalar above the upper limit yields logl = -inf and zero gradient."""
    result = truncated_message.logpdf_gradient_hessian(1.5)
    log_density, gradient, _ = result
    assert np.isneginf(log_density)
    assert gradient == 0.0
78+
79+
80+
def test_scalar_returns_scalar(truncated_message):
    """Scalar input produces zero-dimensional logl, gradient and Hessian."""
    log_density, gradient, hessian = truncated_message.logpdf_gradient_hessian(0.0)

    for value in (log_density, gradient, hessian):
        assert np.ndim(value) == 0
86+
87+
88+
def test_array_returns_array(truncated_message):
    """Array input produces logl and gradient with the input's shape."""
    points = np.array([-0.5, 0.0, 0.25])
    log_density, gradient, _ = truncated_message.logpdf_gradient_hessian(points)

    expected_shape = points.shape
    assert log_density.shape == expected_shape
    assert gradient.shape == expected_shape
94+
95+
96+
def test_logpdf_gradient_returns_two(truncated_message):
    """``logpdf_gradient`` returns a pair of values."""
    points = np.array([-0.5, 0.0])
    outputs = truncated_message.logpdf_gradient(points)
    assert len(outputs) == 2
99+
100+
101+
def test_numerical_gradient_agreement(truncated_message):
    """Analytic derivatives agree with their finite-difference counterparts."""
    # Stay strictly inside the support so the truncation indicator's
    # discontinuity at the boundary doesn't leak into the finite-difference
    # estimate.
    points = np.array([-0.6, -0.2, 0.3])

    # Check the gradient-only API first, then the gradient + Hessian API.
    for method_name in ("logpdf_gradient", "logpdf_gradient_hessian"):
        analytic_parts = getattr(truncated_message, method_name)(points)
        numerical_parts = getattr(truncated_message, f"numerical_{method_name}")(
            points
        )
        for analytic, numerical in zip(analytic_parts, numerical_parts):
            assert np.allclose(analytic, numerical, rtol=1e-2, atol=1e-2)
116+
117+
118+
def test_no_truncation_matches_normal():
    """Without explicit limits the truncated message matches the normal one."""
    # With infinite bounds, log Z = 0 and the truncated message should agree
    # with the untruncated one on logl, gradient, and Hessian.
    truncated = TruncatedNormalMessage(mean=0.5, sigma=1.3)
    normal = NormalMessage(mean=0.5, sigma=1.3)

    points = np.array([-1.0, 0.0, 1.0, 2.0])
    truncated_logl, truncated_grad, truncated_hess = (
        truncated.logpdf_gradient_hessian(points)
    )
    normal_logl, normal_grad, normal_hess = normal.logpdf_gradient_hessian(points)

    comparisons = (
        (truncated_logl, normal_logl),
        (truncated_grad, normal_grad),
        (truncated_hess, normal_hess),
    )
    for from_truncated, from_normal in comparisons:
        assert np.allclose(from_truncated, from_normal)
131+
132+
133+
def test_truncated_natural_normal_finite_gradients():
    """In-support evaluations on a natural-parameter message are all finite."""
    # Build via natural parameters: eta1 = mu/sigma^2, eta2 = -1/(2 sigma^2).
    mu_underlying, sigma_underlying = 0.0, 1.0
    variance = sigma_underlying ** 2
    eta1 = mu_underlying / variance
    eta2 = -0.5 / variance
    message = TruncatedNaturalNormal(eta1, eta2, lower_limit=-1.0, upper_limit=0.5)

    points = np.array([-0.5, 0.0, 0.25])
    log_density, gradient, hessian = message.logpdf_gradient_hessian(points)

    for values in (log_density, gradient, hessian):
        assert np.all(np.isfinite(values))
148+
149+
150+
def test_truncated_natural_normal_uses_underlying_mu_sigma():
    """Natural and standard parameterisations give the same derivatives."""
    # The override on TruncatedNaturalNormal must reconstruct the underlying
    # (mu, sigma) from the natural parameters rather than using
    # self.mean / self.sigma (which return *truncated* moments). Verify by
    # comparing against an equivalent TruncatedNormalMessage built from the
    # same underlying parameters.
    mu_underlying, sigma_underlying = 0.2, 0.8
    eta1 = mu_underlying / sigma_underlying ** 2
    eta2 = -0.5 / sigma_underlying ** 2
    limits = {"lower_limit": -1.0, "upper_limit": 0.5}

    natural = TruncatedNaturalNormal(eta1, eta2, **limits)
    standard = TruncatedNormalMessage(
        mean=mu_underlying, sigma=sigma_underlying, **limits
    )

    points = np.array([-0.4, 0.0, 0.3])
    natural_logl, natural_grad, natural_hess = natural.logpdf_gradient_hessian(
        points
    )
    standard_logl, standard_grad, standard_hess = standard.logpdf_gradient_hessian(
        points
    )

    assert np.allclose(natural_logl, standard_logl)
    assert np.allclose(natural_grad, standard_grad)
    assert np.allclose(natural_hess, standard_hess)

0 commit comments

Comments
 (0)