|
| 1 | +import math |
| 2 | + |
1 | 3 | import numpy as np |
2 | 4 | import pytest |
| 5 | +from scipy.stats import norm, truncnorm |
3 | 6 |
|
4 | 7 | import autofit as af |
| 8 | +from autofit.messages.truncated_normal import TruncatedNormalMessage |
5 | 9 |
|
6 | 10 |
|
7 | 11 | @pytest.fixture(name="truncated_gaussian") |
@@ -34,3 +38,106 @@ def test__log_prior_from_value(truncated_gaussian, unit, value): |
34 | 38 | assert truncated_gaussian.log_prior_from_value(unit) == pytest.approx(value, rel=0.1) |
35 | 39 |
|
36 | 40 |
|
| 41 | +# --- Numerical equivalence: new direct-ndtr path vs the OLD scipy.stats.norm |
| 42 | +# CDF/PPF composition that this PR replaces. They must be bit-exact equal — |
| 43 | +# that's the "numerics don't change" gate. |
| 44 | + |
# Each entry is (mean, sigma, lower_limit, upper_limit).
PARAMS = [
    # symmetric, moderate
    (0.0, 1.0, -3.0, 3.0),
    # very wide
    (0.0, 1.0, -10.0, 10.0),
    # half-bounded (matches toy normalization)
    (5.0, 2.0, 0.0, math.inf),
    # half-bounded (matches toy sigma)
    (5.0, 5.0, 0.0, math.inf),
    # narrow bracket
    (1.0, 2.0, 0.95, 1.05),
    # very narrow
    (0.0, 1.0, -0.001, 0.001),
]

# Unit values fed to `value_for`: both extremes of (0, 1) plus the bulk.
UNITS = [
    1e-6,
    1e-3,
    0.1,
    0.3,
    0.5,
    0.7,
    0.9,
    1 - 1e-3,
    1 - 1e-6,
]
| 56 | + |
| 57 | + |
def _old_value_for(unit, mean, sigma, lower, upper):
    """Reference implementation of the pre-refactor unit -> value mapping.

    Composes ``scipy.stats.norm.cdf`` and ``scipy.stats.norm.ppf``; the
    results of this algorithm are the ones that must be preserved.

    Parameters
    ----------
    unit
        Value scaled linearly between the CDF at the lower limit (unit = 0)
        and the CDF at the upper limit (unit = 1).
    mean, sigma
        Location and scale of the underlying (untruncated) normal.
    lower, upper
        Truncation limits, in the same units as ``mean``.

    Returns
    -------
    ``mean + sigma * z``, where ``z`` is the standard-normal quantile of the
    interpolated CDF value.
    """
    # Express the truncation limits in standard-normal units.
    z_lo = (lower - mean) / sigma
    z_hi = (upper - mean) / sigma

    cdf_lo = norm.cdf(z_lo)
    cdf_hi = norm.cdf(z_hi)

    # Interpolate between the two CDF values, then invert the CDF.
    z = norm.ppf(cdf_lo + unit * (cdf_hi - cdf_lo))
    return mean + sigma * z
| 68 | + |
| 69 | + |
@pytest.mark.parametrize("mean,sigma,lower,upper", PARAMS)
@pytest.mark.parametrize("unit", UNITS)
def test__prior_value_for_bit_exact_to_old_path(unit, mean, sigma, lower, upper):
    """`TruncatedGaussianPrior.value_for` must exactly equal the old path.

    The reference value comes from `_old_value_for`, the pre-refactor
    scipy.stats.norm.cdf/ppf composition that this PR replaces. This is the
    "numerics don't change" gate at the algorithmic level — both paths share
    the same ndtr/ndtri Cephes routines, only the Python-side wrapper
    differs.
    """
    prior = af.TruncatedGaussianPrior(
        mean=mean, sigma=sigma, lower_limit=lower, upper_limit=upper,
    )
    expected = float(_old_value_for(unit, mean, sigma, lower, upper))
    actual = float(prior.value_for(unit))

    # Exact float equality, no tolerance. No special case is needed for
    # expected == 0.0: `==` already covers it, and keeping a single assertion
    # means the diagnostic message is reported for every failing case.
    assert actual == expected, f"new={actual!r} old={expected!r}"
| 89 | + |
| 90 | + |
@pytest.mark.parametrize("mean,sigma,lower,upper", PARAMS)
@pytest.mark.parametrize("unit", [0.1, 0.3, 0.5, 0.7, 0.9])
def test__prior_value_for_close_to_scipy_truncnorm(unit, mean, sigma, lower, upper):
    """`TruncatedGaussianPrior.value_for` matches scipy.stats.truncnorm.ppf
    away from the deep tails.

    scipy.stats.truncnorm uses its own tail-safe branching that the simple
    ndtr/ndtri composition does not — so this test deliberately covers only
    ``unit in [0.1, 0.9]`` where both paths are stable. Documents the
    precision regime; not a regression gate.
    """
    prior = af.TruncatedGaussianPrior(
        mean=mean, sigma=sigma, lower_limit=lower, upper_limit=upper,
    )
    # truncnorm takes its limits in standard-normal units.
    a = (lower - mean) / sigma
    b = (upper - mean) / sigma
    expected = float(truncnorm.ppf(unit, a=a, b=b, loc=mean, scale=sigma))
    actual = float(prior.value_for(unit))

    # pytest.approx accepts a value when either tolerance is met, so one
    # assertion handles every case: rel=1e-10 governs non-zero expectations,
    # and abs=1e-12 (pytest's default, spelled out here) covers
    # expected == 0.0, where a relative tolerance alone would be zero.
    assert actual == pytest.approx(expected, rel=1e-10, abs=1e-12)
| 111 | + |
| 112 | + |
@pytest.mark.parametrize("mean,sigma,lower,upper", PARAMS)
@pytest.mark.parametrize("unit", UNITS)
def test__message_value_for_matches_prior(unit, mean, sigma, lower, upper):
    """`TruncatedNormalMessage.value_for` and `TruncatedGaussianPrior.value_for`
    must agree exactly when built from the same parameters.

    Both are expected to route through the shared helper, so the comparison
    is exact float equality with no tolerance.
    """
    # Identical keyword arguments are passed to both constructors.
    shared_kwargs = {
        "mean": mean,
        "sigma": sigma,
        "lower_limit": lower,
        "upper_limit": upper,
    }
    prior = af.TruncatedGaussianPrior(**shared_kwargs)
    message = TruncatedNormalMessage(**shared_kwargs)

    message_value = float(message.value_for(unit))
    prior_value = float(prior.value_for(unit))
    assert message_value == prior_value
| 126 | + |
| 127 | + |
def test__jax_value_for_parity():
    """The JAX path must agree with the numpy path to within float64
    rounding noise.

    Uses moderate (half-bounded) parameters representative of the toy model.
    Skipped if jax is not installed; CI / dev installs both.

    NOTE(review): rel=1e-9 presumably relies on JAX running in 64-bit mode
    (``jax_enable_x64``); nothing in this test enables it — confirm it is
    set elsewhere.
    """
    jax = pytest.importorskip("jax")
    jnp = jax.numpy

    half_bounded = {
        "mean": 5.0,
        "sigma": 2.0,
        "lower_limit": 0.0,
        "upper_limit": math.inf,
    }
    prior = af.TruncatedGaussianPrior(**half_bounded)

    for unit in (0.1, 0.5, 0.9):
        # The numpy result is the reference; the same unit goes through JAX.
        reference = float(prior.value_for(unit, xp=np))
        via_jax = float(prior.value_for(jnp.asarray(unit), xp=jnp))
        assert via_jax == pytest.approx(reference, rel=1e-9)
0 commit comments