Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions autofit/mapper/prior/_erf_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Direct-`ndtr` primitives for hot prior-transform paths.

Replaces `scipy.stats.norm.cdf` / `norm.ppf` (and their `jax.scipy.stats`
counterparts) with direct calls to `scipy.special.ndtr` / `ndtri` — the
Cephes routines that scipy.stats wraps. Bit-exact equivalent on both
NumPy and JAX backends, but skips the
`scipy.stats._distn_infrastructure` wrapper overhead — which the
graphical-ep-scale-up cProfile baseline showed was the #1 hotspot in
`TruncatedGaussianPrior.value_for` (~33% of total wall time at N=10).

See PyAutoFit issue #1284 for the motivating measurements.
"""

import numpy as np


def _norm_cdf(z, xp):
    """Evaluate the standard-normal CDF at ``z`` with a direct ``ndtr`` call.

    Intended as a drop-in for ``scipy.stats.norm.cdf(z)``.

    Parameters
    ----------
    z
        Point(s) at which to evaluate the CDF.
    xp
        Array module selecting the backend. The ``numpy`` module picks
        ``scipy.special.ndtr``; anything else picks
        ``jax.scipy.special.ndtr``.
    """
    # The import is deferred so that JAX is only required on the JAX path.
    if xp is np:
        from scipy.special import ndtr as standard_normal_cdf

        return standard_normal_cdf(z)

    from jax.scipy.special import ndtr as standard_normal_cdf

    return standard_normal_cdf(z)


def _norm_ppf(p, xp):
    """Evaluate the standard-normal quantile function at ``p`` via ``ndtri``.

    Intended as a drop-in for ``scipy.stats.norm.ppf(p)``.

    Parameters
    ----------
    p
        Probability value(s) to invert.
    xp
        Array module selecting the backend. The ``numpy`` module picks
        ``scipy.special.ndtri``; anything else picks
        ``jax.scipy.special.ndtri``.
    """
    # The import is deferred so that JAX is only required on the JAX path.
    if xp is np:
        from scipy.special import ndtri as standard_normal_quantile

        return standard_normal_quantile(p)

    from jax.scipy.special import ndtri as standard_normal_quantile

    return standard_normal_quantile(p)


def truncated_normal_value_for(unit, mean, sigma, lower_limit, upper_limit, xp=np):
    """Map a unit-interval draw onto a truncated normal via its inverse CDF.

    Computes ``mean + sigma * Phi^{-1}(Phi(a) + unit * (Phi(b) - Phi(a)))``
    with ``a = (lower_limit - mean) / sigma`` and
    ``b = (upper_limit - mean) / sigma``, where ``Phi`` and ``Phi^{-1}``
    are evaluated by the module's ``_norm_cdf`` / ``_norm_ppf`` helpers
    (direct ``ndtr`` / ``ndtri`` calls).

    Parameters
    ----------
    unit
        Unit-cube draw(s) in ``[0, 1]``. Scalar or array.
    mean, sigma
        Mean and standard deviation of the underlying (untruncated) Gaussian.
    lower_limit, upper_limit
        Truncation bounds. ``-inf`` / ``+inf`` are supported.
    xp
        Array module: ``numpy`` (default) or ``jax.numpy``. Forwarded to the
        helpers, which use it to choose between ``scipy.special`` and
        ``jax.scipy.special``.

    Returns
    -------
    Physical sample(s) drawn from the truncated normal.
    """
    # CDF of the untruncated standard normal at each standardised bound.
    cdf_at_lower = _norm_cdf((lower_limit - mean) / sigma, xp)
    cdf_at_upper = _norm_cdf((upper_limit - mean) / sigma, xp)

    # Stretch the unit draw onto [Phi(a), Phi(b)]. The operation order is
    # kept exactly as written so floating-point results do not change.
    target_cdf = cdf_at_lower + unit * (cdf_at_upper - cdf_at_lower)

    # Invert to a standard-normal deviate, then rescale to physical units.
    return mean + sigma * _norm_ppf(target_cdf, xp)
25 changes: 8 additions & 17 deletions autofit/mapper/prior/truncated_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,21 +143,12 @@ def value_for(self, unit, xp=np):
A unit value between 0 and 1.
xp
Array-module to dispatch on (``numpy`` or ``jax.numpy``). Default ``numpy``.
Both paths share the standard truncated-normal inverse-CDF construction
via ``norm.cdf`` / ``norm.ppf`` from the matching ``scipy.stats`` /
``jax.scipy.stats`` namespace.
Delegates to ``_erf_helpers.truncated_normal_value_for``, which calls
``scipy.special.ndtr`` / ``ndtri`` (or the ``jax.scipy.special``
equivalents) directly — skipping the ``scipy.stats`` wrapper that
previously dominated this hot path.
"""
if xp is np:
from scipy.stats import norm
else:
from jax.scipy.stats import norm

a = (self.lower_limit - self.mean) / self.sigma
b = (self.upper_limit - self.mean) / self.sigma

lower_cdf = norm.cdf(a)
upper_cdf = norm.cdf(b)
truncated_cdf = lower_cdf + unit * (upper_cdf - lower_cdf)

x_standard = norm.ppf(truncated_cdf)
return self.mean + self.sigma * x_standard
from autofit.mapper.prior._erf_helpers import truncated_normal_value_for
return truncated_normal_value_for(
unit, self.mean, self.sigma, self.lower_limit, self.upper_limit, xp,
)
18 changes: 4 additions & 14 deletions autofit/messages/truncated_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,20 +463,10 @@ def value_for(self, unit, xp=np):
>>> prior = af.TruncatedNormalMessage(mean=1.0, sigma=2.0, lower_limit=0.0, upper_limit=2.0)
>>> physical_value = prior.value_for(unit=0.5)
"""
if xp is np:
from scipy.stats import norm
else:
from jax.scipy.stats import norm

a = (self.lower_limit - self.mean) / self.sigma
b = (self.upper_limit - self.mean) / self.sigma

lower_cdf = norm.cdf(a)
upper_cdf = norm.cdf(b)
truncated_cdf = lower_cdf + unit * (upper_cdf - lower_cdf)

x_standard = norm.ppf(truncated_cdf)
return self.mean + self.sigma * x_standard
from autofit.mapper.prior._erf_helpers import truncated_normal_value_for
return truncated_normal_value_for(
unit, self.mean, self.sigma, self.lower_limit, self.upper_limit, xp,
)

def log_prior_from_value(self, value: float, xp=np) -> float:
"""
Expand Down
107 changes: 107 additions & 0 deletions test_autofit/mapper/prior/test_truncated_gaussian.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import math

import numpy as np
import pytest
from scipy.stats import norm, truncnorm

import autofit as af
from autofit.messages.truncated_normal import TruncatedNormalMessage


@pytest.fixture(name="truncated_gaussian")
Expand Down Expand Up @@ -34,3 +38,106 @@ def test__log_prior_from_value(truncated_gaussian, unit, value):
assert truncated_gaussian.log_prior_from_value(unit) == pytest.approx(value, rel=0.1)


# --- Numerical equivalence: new direct-ndtr path vs the OLD scipy.stats.norm
# CDF/PPF composition that this PR replaces. They must be bit-exact equal —
# that's the "numerics don't change" gate.

# Truncated-Gaussian configurations shared by the parametrized tests below.
PARAMS = [
    # (mean, sigma, lower_limit, upper_limit)
    (0.0, 1.0, -3.0, 3.0),  # symmetric, moderate
    (0.0, 1.0, -10.0, 10.0),  # very wide
    (5.0, 2.0, 0.0, math.inf),  # half-bounded (matches toy normalization)
    (5.0, 5.0, 0.0, math.inf),  # half-bounded (matches toy sigma)
    (1.0, 2.0, 0.95, 1.05),  # narrow bracket
    (0.0, 1.0, -0.001, 0.001),  # very narrow
]

# Unit-interval draws: values close to 0 and 1 plus a spread across the middle.
UNITS = [1e-6, 1e-3, 0.1, 0.3, 0.5, 0.7, 0.9, 1 - 1e-3, 1 - 1e-6]


def _old_value_for(unit, mean, sigma, lower, upper):
    """Reference implementation of the pre-refactor inverse-CDF mapping.

    Builds the truncated-normal sample from ``scipy.stats.norm.cdf`` and
    ``scipy.stats.norm.ppf``. The tests treat its output as the value the
    refactored code path has to reproduce.
    """
    # CDF at each standardised truncation bound.
    cdf_lo = norm.cdf((lower - mean) / sigma)
    cdf_hi = norm.cdf((upper - mean) / sigma)

    # Same operation order as the old production code, so results are
    # comparable bit for bit.
    target = cdf_lo + unit * (cdf_hi - cdf_lo)
    return mean + sigma * norm.ppf(target)


@pytest.mark.parametrize("mean,sigma,lower,upper", PARAMS)
@pytest.mark.parametrize("unit", UNITS)
def test__prior_value_for_bit_exact_to_old_path(unit, mean, sigma, lower, upper):
    """`TruncatedGaussianPrior.value_for` must equal the old path exactly.

    The reference is `_old_value_for`, which reproduces the pre-refactor
    ``scipy.stats.norm.cdf`` / ``norm.ppf`` composition. This is the
    "numerics don't change" gate: the comparison is exact equality, with
    no tolerance.
    """
    prior = af.TruncatedGaussianPrior(
        mean=mean, sigma=sigma, lower_limit=lower, upper_limit=upper,
    )
    expected = float(_old_value_for(unit, mean, sigma, lower, upper))
    actual = float(prior.value_for(unit))

    # A single exact comparison covers every case. A zero expectation needs
    # no special branch because ``-0.0 == 0.0`` in Python, and keeping one
    # assertion means the diagnostic message is always reported on failure.
    assert actual == expected, f"new={actual!r} old={expected!r}"


@pytest.mark.parametrize("mean,sigma,lower,upper", PARAMS)
@pytest.mark.parametrize("unit", [0.1, 0.3, 0.5, 0.7, 0.9])
def test__prior_value_for_close_to_scipy_truncnorm(unit, mean, sigma, lower, upper):
    """`TruncatedGaussianPrior.value_for` agrees with ``truncnorm.ppf``.

    Only ``unit`` values between 0.1 and 0.9 are exercised. Per the original
    author's note, ``scipy.stats.truncnorm`` applies its own tail-safe
    branching that the plain ndtr/ndtri composition lacks, so the deep tails
    are deliberately left out. This documents the precision regime and is
    not a regression gate.
    """
    prior = af.TruncatedGaussianPrior(
        mean=mean, sigma=sigma, lower_limit=lower, upper_limit=upper,
    )
    reference = float(
        truncnorm.ppf(
            unit,
            a=(lower - mean) / sigma,
            b=(upper - mean) / sigma,
            loc=mean,
            scale=sigma,
        )
    )
    observed = float(prior.value_for(unit))

    # A relative tolerance is meaningless around zero, so that case is
    # checked against an absolute tolerance instead.
    if reference != 0.0:
        assert observed == pytest.approx(reference, rel=1e-10)
    else:
        assert observed == pytest.approx(0.0, abs=1e-12)


@pytest.mark.parametrize("mean,sigma,lower,upper", PARAMS)
@pytest.mark.parametrize("unit", UNITS)
def test__message_value_for_matches_prior(unit, mean, sigma, lower, upper):
    """`TruncatedNormalMessage.value_for` equals the prior's `value_for`.

    Both objects are built from the same parameters and their outputs are
    compared with exact equality. Bit-exact agreement is expected because
    both classes are meant to route through the same shared helper.
    """
    shared_kwargs = {
        "mean": mean,
        "sigma": sigma,
        "lower_limit": lower,
        "upper_limit": upper,
    }
    prior = af.TruncatedGaussianPrior(**shared_kwargs)
    message = TruncatedNormalMessage(**shared_kwargs)

    from_message = float(message.value_for(unit))
    from_prior = float(prior.value_for(unit))
    assert from_message == from_prior


def test__jax_value_for_parity():
    """The JAX backend must reproduce the NumPy backend's `value_for` output.

    Uses a single half-bounded prior (lower limit 0, no upper limit) and
    compares the two backends at three unit values with a relative tolerance
    of 1e-9. The test is skipped when ``jax`` cannot be imported.
    """
    jax = pytest.importorskip("jax")
    jnp = jax.numpy

    prior = af.TruncatedGaussianPrior(
        mean=5.0, sigma=2.0, lower_limit=0.0, upper_limit=math.inf,
    )

    # NOTE(review): rel=1e-9 presumes JAX is running with 64-bit floats;
    # confirm the test configuration enables ``jax_enable_x64``.
    for unit in (0.1, 0.5, 0.9):
        reference = float(prior.value_for(unit, xp=np))
        candidate = float(prior.value_for(jnp.asarray(unit), xp=jnp))
        assert candidate == pytest.approx(reference, rel=1e-9)
Loading