Overview
Phase 2 follow-up to #397. The Phase 1 feasibility report (docs/research/nfw_ludlow16_jax_assessment.md, shipped in #402) confirmed that the jax.pure_callback wrapping colossus.halo.concentration in autogalaxy/profiles/mass/dark/mcr_util.py can be replaced by a JAX-native port of modelLudlow16 (Approach A — full Eisenstein-Hu '98 transfer + Heath '77 growth factor + Einasto gammainc mass ratio + 200-point Ludlow c-solver, ~330 lines of straight-line JAX).
Validated numbers from the Phase 1 prototype at docs/research/nfw_ludlow16_jax/ludlow16_jax.py:
- Max c200c relative error vs colossus over the lensing grid: 7.5 × 10⁻⁴
- Max kappa_s relative error end-to-end: 1.07 × 10⁻³
- Max NFW convergence / deflection per-pixel relative error: 8.21 × 10⁻⁴
- These errors are ~350× below the intrinsic Ludlow16 scatter (σ_log10(c) = 0.13 dex), so they are scientifically invisible.
- Single-call post-JIT: 0.69 ms vs colossus' 0.83 ms.
- jax.grad works end-to-end (agreement to 7 × 10⁻⁴ vs finite-diff).
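The gradient check in the last bullet can be illustrated with a minimal sketch. The function `toy_concentration` below is a hypothetical stand-in (a smooth power law), not the Ludlow16 model or the prototype's code; only the comparison pattern between `jax.grad` and a central finite difference is the point.

```python
import jax
import jax.numpy as jnp

# Enable float64 so the finite-difference comparison is not limited by float32 round-off.
jax.config.update("jax_enable_x64", True)


def toy_concentration(log10_mass):
    # Hypothetical stand-in, NOT the Ludlow16 model: concentration falls
    # as a shallow power law of halo mass.
    return 10.0 * jnp.power(10.0, -0.1 * (log10_mass - 12.0))


def central_difference(f, x, h=1e-5):
    # Second-order accurate finite-difference estimate of df/dx.
    return (f(x + h) - f(x - h)) / (2.0 * h)


x0 = 13.0
grad_ad = float(jax.grad(toy_concentration)(x0))
grad_fd = float(central_difference(toy_concentration, x0))

# Relative disagreement between automatic and numerical differentiation.
rel_diff = abs(grad_ad - grad_fd) / abs(grad_fd)
```

For the real port the same comparison would be run through the full concentration solver, where the finite-difference step size and the solver tolerance both limit how closely the two gradients can agree.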
This issue implements the swap-in: move the prototype into production, collapse the JAX/NumPy branches in the two callers, drop colossus as a runtime dependency, add tests.
Plan
- Promote docs/research/nfw_ludlow16_jax/ludlow16_jax.py into autogalaxy/profiles/mass/dark/ludlow16.py (or fold it into mcr_util.py if size stays reasonable). Convert the jnp.* calls to be xp-aware so the same function runs under both numpy and JAX paths (the numpy path is needed for xp=np callers; both branches already exist in the helpers).
- Replace _ludlow16_cosmology_callback and ludlow16_cosmology_jax in mcr_util.py with a single xp-aware ludlow16_cosmology(mass_at_200, redshift_object, redshift_source, xp=np) that calls the new ludlow16_concentration_xp(...) for the c200c part and continues to use the existing JAX-native Planck15 for the other three returned values.
- Drop the if xp is np: ... else: ... branching in kappa_s_and_scale_radius_for_ludlow and kappa_s_scale_radius_and_core_radius_for_ludlow; both become straight xp calls.
- Make colossus an optional dependency: only the unit-test cross-check against the original implementation needs it at runtime. Update pyproject.toml (and any conda/eden recipe) accordingly.
- Add unit tests covering:
  - Numpy path regression vs colossus on a small (log M, z) grid (the existing prototype's validate.py numbers; keep them as the regression baseline).
  - JAX path jit correctness (numpy/JAX paths agree to ~1e-12).
  - jax.grad smoke test (finite, non-NaN gradient through ludlow16_concentration_xp).
- Validate the five MCR-Ludlow callers end-to-end under jax.jit in autogalaxy_workspace_test/: nfw_mcr, nfw_mcr_scatter, nfw_truncated_mcr_scatter, gnfw_mcr, cnfw_mcr_scatter.
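As a rough illustration of the xp-aware pattern and of the jit-agreement and jax.grad tests listed in the plan, here is a minimal sketch. `toy_concentration_xp` is a hypothetical placeholder with an invented closed form; it is not `ludlow16_concentration_xp`, and the real function's signature may differ.

```python
import numpy as np
import jax
import jax.numpy as jnp

# float64 on the JAX side, otherwise a ~1e-12 agreement target is unreachable.
jax.config.update("jax_enable_x64", True)


def toy_concentration_xp(log10_mass, redshift, xp=np):
    """Hypothetical stand-in for an xp-aware concentration function.

    The same body runs under numpy (xp=np) or JAX (xp=jnp) because it only
    uses calls that exist with matching semantics in both namespaces.
    """
    log10_mass = xp.asarray(log10_mass, dtype=xp.float64)
    redshift = xp.asarray(redshift, dtype=xp.float64)
    return 8.0 * xp.power(10.0, -0.1 * (log10_mass - 12.0)) / xp.sqrt(1.0 + redshift)


log10_mass_grid = np.linspace(11.0, 15.0, 9)
redshift = 0.5

# Numpy path.
c_numpy = toy_concentration_xp(log10_mass_grid, redshift, xp=np)

# JAX path under jit.
concentration_jit = jax.jit(lambda m, z: toy_concentration_xp(m, z, xp=jnp))
c_jax = np.asarray(concentration_jit(log10_mass_grid, redshift))

# Test 1: the two paths agree to near machine precision.
max_rel_diff = float(np.max(np.abs(c_jax / c_numpy - 1.0)))

# Test 2: gradient smoke test, finite and non-NaN.
grad_fn = jax.grad(lambda m: toy_concentration_xp(m, redshift, xp=jnp))
gradient = float(grad_fn(13.0))
gradient_is_finite = bool(np.isfinite(gradient))
```

In a pytest suite the two checks would become assertions such as `max_rel_diff < 1e-12` and `gradient_is_finite`. For the colossus regression test, `pytest.importorskip("colossus")` is one way to skip cleanly when the optional dependency is not installed.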
Out of scope (separate work)
Two pre-existing issues surfaced during the Phase 1 science check but are unrelated to this swap-in:
- cNFWSph.convergence_2d_from returns zeros; it is marked "not yet implemented" in cnfw.py:239-244.
- The Penarrubia mcr formula in kappa_s_scale_radius_and_core_radius_for_ludlow produces a negative kappa_s for f_c ≳ 0.18 due to a sign flip in the denominator. Either restrict f_c priors or guard the formula.
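For the second item, "guard the formula" could take the shape sketched below. This is only a generic pattern under assumptions: `guarded_ratio`, `numerator` and `denominator` are hypothetical names, and the actual Penarrubia expression in mcr_util.py is not reproduced here.

```python
import numpy as np


def guarded_ratio(numerator, denominator, xp=np):
    """Return numerator / denominator where denominator > 0, NaN elsewhere.

    Hypothetical helper: it flags the unphysical region instead of silently
    returning a negative value. Pass xp=jax.numpy to use it under jax.jit.
    """
    numerator = xp.asarray(numerator, dtype=xp.float64)
    denominator = xp.asarray(denominator, dtype=xp.float64)
    valid = denominator > 0.0
    # Substitute a harmless value before dividing so masked entries never
    # evaluate a division by zero or by a negative number.
    safe_denominator = xp.where(valid, denominator, 1.0)
    return xp.where(valid, numerator / safe_denominator, xp.nan)


# Placeholder inputs: one valid entry, one zero and one negative denominator.
numerator = np.array([1.0, 1.0, 1.0])
denominator = np.array([0.5, 0.0, -0.25])
guarded = guarded_ratio(numerator, denominator)
```

Dividing by the substituted denominator rather than the raw one keeps the masked branch finite, which matters under jax.grad because both branches of `where` are differentiated. Whether NaN, a clipped value, or a restricted f_c prior is the right response is a modelling decision left to the separate follow-up.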
References
- docs/research/nfw_ludlow16_jax_assessment.md
- docs/research/nfw_ludlow16_jax/ludlow16_jax.py
🤖 Generated with Claude Code