|
| 1 | +import numpy as np |
| 2 | +import pytest |
| 3 | +from scipy.stats import norm |
| 4 | + |
| 5 | +from autofit.messages.normal import NormalMessage |
| 6 | +from autofit.messages.truncated_normal import ( |
| 7 | + TruncatedNaturalNormal, |
| 8 | + TruncatedNormalMessage, |
| 9 | +) |
| 10 | + |
| 11 | + |
| 12 | +@pytest.fixture |
| 13 | +def truncated_message(): |
| 14 | + # Bounds chosen so the inner test points are well inside the support |
| 15 | + # but Z is meaningfully less than 1 (Z ≈ 0.34). |
| 16 | + return TruncatedNormalMessage(mean=0.0, sigma=1.0, lower_limit=-1.0, upper_limit=0.5) |
| 17 | + |
| 18 | + |
| 19 | +@pytest.fixture |
| 20 | +def reference_normal(): |
| 21 | + return NormalMessage(mean=0.0, sigma=1.0) |
| 22 | + |
| 23 | + |
| 24 | +def _expected_log_Z(message): |
| 25 | + a = (message.lower_limit - message.mean) / message.sigma |
| 26 | + b = (message.upper_limit - message.mean) / message.sigma |
| 27 | + return float(np.log(norm.cdf(b) - norm.cdf(a))) |
| 28 | + |
| 29 | + |
| 30 | +def test_in_support_gradient_and_hessian_match_normal( |
| 31 | + truncated_message, reference_normal |
| 32 | +): |
| 33 | + x = np.array([-0.5, 0.0, 0.25]) |
| 34 | + |
| 35 | + _, grad_t, hess_t = truncated_message.logpdf_gradient_hessian(x) |
| 36 | + _, grad_n, hess_n = reference_normal.logpdf_gradient_hessian(x) |
| 37 | + |
| 38 | + assert np.allclose(grad_t, grad_n) |
| 39 | + # Hessian for a univariate Gaussian is the same scalar at every x. |
| 40 | + assert np.allclose(hess_t, hess_n) |
| 41 | + |
| 42 | + |
| 43 | +def test_in_support_logl_differs_by_minus_log_Z( |
| 44 | + truncated_message, reference_normal |
| 45 | +): |
| 46 | + x = np.array([-0.5, 0.0, 0.25]) |
| 47 | + |
| 48 | + logl_t, _, _ = truncated_message.logpdf_gradient_hessian(x) |
| 49 | + logl_n, _, _ = reference_normal.logpdf_gradient_hessian(x) |
| 50 | + |
| 51 | + assert np.allclose(logl_t, logl_n - _expected_log_Z(truncated_message)) |
| 52 | + |
| 53 | + |
| 54 | +def test_out_of_support_array(truncated_message): |
| 55 | + x = np.array([-2.0, -0.5, 1.5]) # below, inside, above |
| 56 | + |
| 57 | + logl, grad, _ = truncated_message.logpdf_gradient_hessian(x) |
| 58 | + |
| 59 | + assert np.isneginf(logl[0]) |
| 60 | + assert np.isfinite(logl[1]) |
| 61 | + assert np.isneginf(logl[2]) |
| 62 | + |
| 63 | + assert grad[0] == 0.0 |
| 64 | + assert grad[2] == 0.0 |
| 65 | + assert grad[1] != 0.0 |
| 66 | + |
| 67 | + |
| 68 | +def test_out_of_support_scalar_below(truncated_message): |
| 69 | + logl, grad, _ = truncated_message.logpdf_gradient_hessian(-2.0) |
| 70 | + assert np.isneginf(logl) |
| 71 | + assert grad == 0.0 |
| 72 | + |
| 73 | + |
| 74 | +def test_out_of_support_scalar_above(truncated_message): |
| 75 | + logl, grad, _ = truncated_message.logpdf_gradient_hessian(1.5) |
| 76 | + assert np.isneginf(logl) |
| 77 | + assert grad == 0.0 |
| 78 | + |
| 79 | + |
| 80 | +def test_scalar_returns_scalar(truncated_message): |
| 81 | + logl, grad, hess = truncated_message.logpdf_gradient_hessian(0.0) |
| 82 | + |
| 83 | + assert np.ndim(logl) == 0 |
| 84 | + assert np.ndim(grad) == 0 |
| 85 | + assert np.ndim(hess) == 0 |
| 86 | + |
| 87 | + |
| 88 | +def test_array_returns_array(truncated_message): |
| 89 | + x = np.array([-0.5, 0.0, 0.25]) |
| 90 | + logl, grad, _ = truncated_message.logpdf_gradient_hessian(x) |
| 91 | + |
| 92 | + assert logl.shape == x.shape |
| 93 | + assert grad.shape == x.shape |
| 94 | + |
| 95 | + |
| 96 | +def test_logpdf_gradient_returns_two(truncated_message): |
| 97 | + result = truncated_message.logpdf_gradient(np.array([-0.5, 0.0])) |
| 98 | + assert len(result) == 2 |
| 99 | + |
| 100 | + |
| 101 | +def test_numerical_gradient_agreement(truncated_message): |
| 102 | + # Stay strictly inside the support so the truncation indicator's |
| 103 | + # discontinuity at the boundary doesn't leak into the finite-difference |
| 104 | + # estimate. |
| 105 | + x = np.array([-0.6, -0.2, 0.3]) |
| 106 | + |
| 107 | + res = truncated_message.logpdf_gradient(x) |
| 108 | + nres = truncated_message.numerical_logpdf_gradient(x) |
| 109 | + for analytic, numerical in zip(res, nres): |
| 110 | + assert np.allclose(analytic, numerical, rtol=1e-2, atol=1e-2) |
| 111 | + |
| 112 | + res = truncated_message.logpdf_gradient_hessian(x) |
| 113 | + nres = truncated_message.numerical_logpdf_gradient_hessian(x) |
| 114 | + for analytic, numerical in zip(res, nres): |
| 115 | + assert np.allclose(analytic, numerical, rtol=1e-2, atol=1e-2) |
| 116 | + |
| 117 | + |
| 118 | +def test_no_truncation_matches_normal(): |
| 119 | + # With infinite bounds, log Z = 0 and the truncated message should agree |
| 120 | + # with the untruncated one on logl, gradient, and Hessian. |
| 121 | + truncated = TruncatedNormalMessage(mean=0.5, sigma=1.3) |
| 122 | + normal = NormalMessage(mean=0.5, sigma=1.3) |
| 123 | + |
| 124 | + x = np.array([-1.0, 0.0, 1.0, 2.0]) |
| 125 | + logl_t, grad_t, hess_t = truncated.logpdf_gradient_hessian(x) |
| 126 | + logl_n, grad_n, hess_n = normal.logpdf_gradient_hessian(x) |
| 127 | + |
| 128 | + assert np.allclose(logl_t, logl_n) |
| 129 | + assert np.allclose(grad_t, grad_n) |
| 130 | + assert np.allclose(hess_t, hess_n) |
| 131 | + |
| 132 | + |
| 133 | +def test_truncated_natural_normal_finite_gradients(): |
| 134 | + # Build via natural parameters: eta1 = mu/sigma^2, eta2 = -1/(2 sigma^2). |
| 135 | + mu_underlying, sigma_underlying = 0.0, 1.0 |
| 136 | + eta1 = mu_underlying / sigma_underlying ** 2 |
| 137 | + eta2 = -0.5 / sigma_underlying ** 2 |
| 138 | + msg = TruncatedNaturalNormal( |
| 139 | + eta1, eta2, lower_limit=-1.0, upper_limit=0.5 |
| 140 | + ) |
| 141 | + |
| 142 | + x = np.array([-0.5, 0.0, 0.25]) |
| 143 | + logl, grad, hess = msg.logpdf_gradient_hessian(x) |
| 144 | + |
| 145 | + assert np.all(np.isfinite(logl)) |
| 146 | + assert np.all(np.isfinite(grad)) |
| 147 | + assert np.all(np.isfinite(hess)) |
| 148 | + |
| 149 | + |
| 150 | +def test_truncated_natural_normal_uses_underlying_mu_sigma(): |
| 151 | + # The override on TruncatedNaturalNormal must reconstruct the underlying |
| 152 | + # (mu, sigma) from the natural parameters rather than using |
| 153 | + # self.mean / self.sigma (which return *truncated* moments). Verify by |
| 154 | + # comparing against an equivalent TruncatedNormalMessage built from the |
| 155 | + # same underlying parameters. |
| 156 | + mu_underlying, sigma_underlying = 0.2, 0.8 |
| 157 | + eta1 = mu_underlying / sigma_underlying ** 2 |
| 158 | + eta2 = -0.5 / sigma_underlying ** 2 |
| 159 | + |
| 160 | + natural = TruncatedNaturalNormal( |
| 161 | + eta1, eta2, lower_limit=-1.0, upper_limit=0.5 |
| 162 | + ) |
| 163 | + standard = TruncatedNormalMessage( |
| 164 | + mean=mu_underlying, |
| 165 | + sigma=sigma_underlying, |
| 166 | + lower_limit=-1.0, |
| 167 | + upper_limit=0.5, |
| 168 | + ) |
| 169 | + |
| 170 | + x = np.array([-0.4, 0.0, 0.3]) |
| 171 | + logl_n, grad_n, hess_n = natural.logpdf_gradient_hessian(x) |
| 172 | + logl_s, grad_s, hess_s = standard.logpdf_gradient_hessian(x) |
| 173 | + |
| 174 | + assert np.allclose(logl_n, logl_s) |
| 175 | + assert np.allclose(grad_n, grad_s) |
| 176 | + assert np.allclose(hess_n, hess_s) |
0 commit comments