diff --git a/autofit/mapper/prior/_erf_helpers.py b/autofit/mapper/prior/_erf_helpers.py new file mode 100644 index 000000000..b07fd2bad --- /dev/null +++ b/autofit/mapper/prior/_erf_helpers.py @@ -0,0 +1,70 @@ +"""Direct-`ndtr` primitives for hot prior-transform paths. + +Replaces `scipy.stats.norm.cdf` / `norm.ppf` (and their `jax.scipy.stats` +counterparts) with direct calls to `scipy.special.ndtr` / `ndtri` — the +Cephes routines that scipy.stats wraps. Bit-exact equivalent on both +NumPy and JAX backends, but skips the +`scipy.stats._distn_infrastructure` wrapper overhead — which the +graphical-ep-scale-up cProfile baseline showed was the #1 hotspot in +`TruncatedGaussianPrior.value_for` (~33% of total wall time at N=10). + +See PyAutoFit issue #1284 for the motivating measurements. +""" + +import numpy as np + + +def _norm_cdf(z, xp): + """Standard-normal CDF (== ``scipy.stats.norm.cdf(z)`` to ULPs).""" + if xp is np: + from scipy.special import ndtr + else: + from jax.scipy.special import ndtr + return ndtr(z) + + +def _norm_ppf(p, xp): + """Standard-normal PPF (== ``scipy.stats.norm.ppf(p)`` to ULPs).""" + if xp is np: + from scipy.special import ndtri + else: + from jax.scipy.special import ndtri + return ndtri(p) + + +def truncated_normal_value_for(unit, mean, sigma, lower_limit, upper_limit, xp=np): + """Inverse-CDF mapping for a truncated normal distribution. + + Returns ``mean + sigma * Phi^{-1}(Phi(a) + unit * (Phi(b) - Phi(a)))`` + where ``a = (lower_limit - mean) / sigma`` and + ``b = (upper_limit - mean) / sigma``. + + Used by ``TruncatedGaussianPrior.value_for`` and + ``TruncatedNormalMessage.value_for`` to share a single + `scipy.special.erf`-based code path on both NumPy and JAX backends. + + Parameters + ---------- + unit + Unit-cube draw(s) in ``[0, 1]``. Scalar or array. + mean, sigma + Underlying-Gaussian mean and standard deviation. + lower_limit, upper_limit + Truncation bounds. ``-inf`` / ``+inf`` are supported. + xp + Array module: ``numpy`` (default) or ``jax.numpy``. Determines + whether ``scipy.special`` or ``jax.scipy.special`` is used. + + Returns + ------- + Physical sample(s) drawn from the truncated normal. + """ + a = (lower_limit - mean) / sigma + b = (upper_limit - mean) / sigma + + lower_cdf = _norm_cdf(a, xp) + upper_cdf = _norm_cdf(b, xp) + truncated_cdf = lower_cdf + unit * (upper_cdf - lower_cdf) + + x_standard = _norm_ppf(truncated_cdf, xp) + return mean + sigma * x_standard diff --git a/autofit/mapper/prior/truncated_gaussian.py b/autofit/mapper/prior/truncated_gaussian.py index 77a5d182b..0728bde47 100644 --- a/autofit/mapper/prior/truncated_gaussian.py +++ b/autofit/mapper/prior/truncated_gaussian.py @@ -143,21 +143,12 @@ def value_for(self, unit, xp=np): A unit value between 0 and 1. xp Array-module to dispatch on (``numpy`` or ``jax.numpy``). Default ``numpy``. - Both paths share the standard truncated-normal inverse-CDF construction - via ``norm.cdf`` / ``norm.ppf`` from the matching ``scipy.stats`` / - ``jax.scipy.stats`` namespace. + Delegates to ``_erf_helpers.truncated_normal_value_for``, which uses + ``scipy.special.erf`` / ``erfinv`` (or the ``jax.scipy.special`` + equivalents) directly — skipping the ``scipy.stats`` wrapper that + previously dominated this hot path. """ - if xp is np: - from scipy.stats import norm - else: - from jax.scipy.stats import norm - - a = (self.lower_limit - self.mean) / self.sigma - b = (self.upper_limit - self.mean) / self.sigma - - lower_cdf = norm.cdf(a) - upper_cdf = norm.cdf(b) - truncated_cdf = lower_cdf + unit * (upper_cdf - lower_cdf) - - x_standard = norm.ppf(truncated_cdf) - return self.mean + self.sigma * x_standard + from autofit.mapper.prior._erf_helpers import truncated_normal_value_for + return truncated_normal_value_for( + unit, self.mean, self.sigma, self.lower_limit, self.upper_limit, xp, + ) diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py index 609ad160b..812f80384 100644 --- a/autofit/messages/truncated_normal.py +++ b/autofit/messages/truncated_normal.py @@ -463,20 +463,10 @@ def value_for(self, unit, xp=np): >>> prior = af.TruncatedNormalMessage(mean=1.0, sigma=2.0, lower_limit=0.0, upper_limit=2.0) >>> physical_value = prior.value_for(unit=0.5) """ - if xp is np: - from scipy.stats import norm - else: - from jax.scipy.stats import norm - - a = (self.lower_limit - self.mean) / self.sigma - b = (self.upper_limit - self.mean) / self.sigma - - lower_cdf = norm.cdf(a) - upper_cdf = norm.cdf(b) - truncated_cdf = lower_cdf + unit * (upper_cdf - lower_cdf) - - x_standard = norm.ppf(truncated_cdf) - return self.mean + self.sigma * x_standard + from autofit.mapper.prior._erf_helpers import truncated_normal_value_for + return truncated_normal_value_for( + unit, self.mean, self.sigma, self.lower_limit, self.upper_limit, xp, + ) def log_prior_from_value(self, value: float, xp=np) -> float: """ diff --git a/test_autofit/mapper/prior/test_truncated_gaussian.py b/test_autofit/mapper/prior/test_truncated_gaussian.py index 2efbc5a17..78ac12dd4 100644 --- a/test_autofit/mapper/prior/test_truncated_gaussian.py +++ b/test_autofit/mapper/prior/test_truncated_gaussian.py @@ -1,7 +1,11 @@ +import math + import numpy as np import pytest +from scipy.stats import norm, truncnorm import autofit as af +from autofit.messages.truncated_normal import TruncatedNormalMessage @pytest.fixture(name="truncated_gaussian") @@ -34,3 +38,106 @@ def test__log_prior_from_value(truncated_gaussian, unit, value): assert truncated_gaussian.log_prior_from_value(unit) == pytest.approx(value, rel=0.1) +# --- Numerical equivalence: new direct-ndtr path vs the OLD scipy.stats.norm +# CDF/PPF composition that this PR replaces. They must be bit-exact equal — +# that's the "numerics don't change" gate. + +PARAMS = [ + # (mean, sigma, lower_limit, upper_limit) + (0.0, 1.0, -3.0, 3.0), # symmetric, moderate + (0.0, 1.0, -10.0, 10.0), # very wide + (5.0, 2.0, 0.0, math.inf), # half-bounded (matches toy normalization) + (5.0, 5.0, 0.0, math.inf), # half-bounded (matches toy sigma) + (1.0, 2.0, 0.95, 1.05), # narrow bracket + (0.0, 1.0, -0.001, 0.001), # very narrow +] + +UNITS = [1e-6, 1e-3, 0.1, 0.3, 0.5, 0.7, 0.9, 1 - 1e-3, 1 - 1e-6] + + +def _old_value_for(unit, mean, sigma, lower, upper): + """Reproduces the pre-refactor scipy.stats.norm.cdf/ppf composition. + This is the algorithm whose results must be preserved.""" + a = (lower - mean) / sigma + b = (upper - mean) / sigma + lower_cdf = norm.cdf(a) + upper_cdf = norm.cdf(b) + truncated_cdf = lower_cdf + unit * (upper_cdf - lower_cdf) + x_standard = norm.ppf(truncated_cdf) + return mean + sigma * x_standard + + +@pytest.mark.parametrize("mean,sigma,lower,upper", PARAMS) +@pytest.mark.parametrize("unit", UNITS) +def test__prior_value_for_bit_exact_to_old_path(unit, mean, sigma, lower, upper): + """`TruncatedGaussianPrior.value_for` must produce results bit-exact to + the pre-refactor scipy.stats.norm.cdf/ppf composition that this PR + replaces. This is the "numerics don't change" gate at the algorithmic + level — both paths share the same ndtr/ndtri Cephes routines, only the + Python-side wrapper differs.""" + prior = af.TruncatedGaussianPrior( + mean=mean, sigma=sigma, lower_limit=lower, upper_limit=upper, + ) + expected = float(_old_value_for(unit, mean, sigma, lower, upper)) + actual = float(prior.value_for(unit)) + + if expected == 0.0: + assert actual == 0.0 + else: + # Same Cephes routines under the hood — must be bit-exact. + assert actual == expected, f"new={actual!r} old={expected!r}" + + +@pytest.mark.parametrize("mean,sigma,lower,upper", PARAMS) +@pytest.mark.parametrize("unit", [0.1, 0.3, 0.5, 0.7, 0.9]) +def test__prior_value_for_close_to_scipy_truncnorm(unit, mean, sigma, lower, upper): + """`TruncatedGaussianPrior.value_for` matches scipy.stats.truncnorm.ppf + away from the deep tails. scipy.stats.truncnorm uses its own tail-safe + branching that the simple ndtr/ndtri composition does not — so this + test deliberately covers only ``unit in [0.1, 0.9]`` where both paths + are stable. Documents the precision regime; not a regression gate.""" + prior = af.TruncatedGaussianPrior( + mean=mean, sigma=sigma, lower_limit=lower, upper_limit=upper, + ) + a = (lower - mean) / sigma + b = (upper - mean) / sigma + expected = float(truncnorm.ppf(unit, a=a, b=b, loc=mean, scale=sigma)) + actual = float(prior.value_for(unit)) + + if expected == 0.0: + assert actual == pytest.approx(0.0, abs=1e-12) + else: + assert actual == pytest.approx(expected, rel=1e-10) + + +@pytest.mark.parametrize("mean,sigma,lower,upper", PARAMS) +@pytest.mark.parametrize("unit", UNITS) +def test__message_value_for_matches_prior(unit, mean, sigma, lower, upper): + """`TruncatedNormalMessage.value_for` must produce the same output as + `TruncatedGaussianPrior.value_for` for matching parameters — both now + route through the shared helper, so the equality is bit-exact.""" + prior = af.TruncatedGaussianPrior( + mean=mean, sigma=sigma, lower_limit=lower, upper_limit=upper, + ) + message = TruncatedNormalMessage( + mean=mean, sigma=sigma, lower_limit=lower, upper_limit=upper, + ) + assert float(message.value_for(unit)) == float(prior.value_for(unit)) + + +def test__jax_value_for_parity(): + """JAX path must match the numpy path to within float64 rounding noise. + + Uses moderate (half-bounded) parameters representative of the toy model. + Skipped if jax is not installed; CI / dev installs both. + """ + jax = pytest.importorskip("jax") + jnp = jax.numpy + + prior = af.TruncatedGaussianPrior( + mean=5.0, sigma=2.0, lower_limit=0.0, upper_limit=math.inf, + ) + for unit in [0.1, 0.5, 0.9]: + numpy_val = float(prior.value_for(unit, xp=np)) + jax_val = float(prior.value_for(jnp.asarray(unit), xp=jnp)) + assert jax_val == pytest.approx(numpy_val, rel=1e-9)