
feat: replace Ludlow16 colossus pure_callback with JAX-native implementation #403

@Jammy2211


Overview

Phase 2 follow-up to #397. The Phase 1 feasibility report (docs/research/nfw_ludlow16_jax_assessment.md, shipped in #402) confirmed that the jax.pure_callback wrapping colossus.halo.concentration in autogalaxy/profiles/mass/dark/mcr_util.py can be replaced by a JAX-native port of modelLudlow16. The chosen route, Approach A, ports the full chain: the Eisenstein-Hu '98 transfer function, the Heath '77 growth factor, the Einasto gammainc mass ratio, and a 200-point Ludlow c-solver, in ~330 lines of straight-line JAX.
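Of the pieces listed above, the Heath '77 growth factor is the easiest to show in isolation. The sketch below is a hypothetical `growth_factor_heath77` helper, not the prototype's code. It assumes flat ΛCDM with matter and Λ only (no radiation), normalises D so that D(a) → a in matter domination, and uses a fixed 2048-point trapezoid rule instead of adaptive quadrature so array shapes stay static under `jax.jit`. Colossus's own normalisation and treatment of relativistic species may differ.

```python
import numpy as np


def growth_factor_heath77(z, omega_m, xp=np, n_grid=2048):
    """Linear growth factor D(z) from the Heath (1977) integral (sketch).

    Assumptions (illustrative, not taken from the prototype):
      * flat LambdaCDM, matter + Lambda only, no radiation;
      * D normalised so that D(a) -> a deep in matter domination;
      * scalar z (use jax.vmap for arrays).
    Only xp.* calls are used, so xp=np or xp=jax.numpy both work.
    """
    omega_l = 1.0 - omega_m
    a = 1.0 / (1.0 + z)
    # Fixed-size grid: no adaptive quadrature, so shapes are static under jit.
    a_grid = xp.linspace(0.0, a, n_grid)
    # 1 / (a' E(a'))^3 rewritten as a'^1.5 / (Om + OL a'^3)^1.5, finite at a' = 0.
    integrand = a_grid**1.5 / (omega_m + omega_l * a_grid**3) ** 1.5
    h = a / (n_grid - 1)
    # Composite trapezoid rule written out by hand (avoids trapz/trapezoid naming differences).
    integral = h * (xp.sum(integrand) - 0.5 * (integrand[0] + integrand[-1]))
    e_of_a = xp.sqrt(omega_m / a**3 + omega_l)
    # D(a) = (5 Om / 2) E(a) * integral_0^a da' / (a' E(a'))^3
    return 2.5 * omega_m * e_of_a * integral
```

For Ωm = 1 (Einstein-de Sitter) the integral is analytic and D(a) = a exactly, which makes a convenient self-check; with `xp=jax.numpy` the same body should trace under `jax.jit`. The 2048-point grid size is an arbitrary choice here.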

Validated numbers from the Phase 1 prototype at docs/research/nfw_ludlow16_jax/ludlow16_jax.py:

  • Max c200c relative error vs colossus over the lensing grid: 7.5 × 10⁻⁴
  • Max kappa_s relative error end-to-end: 1.07 × 10⁻³
  • Max NFW convergence / deflection per-pixel relative error: 8.21 × 10⁻⁴
  • These errors are ~350× below the intrinsic Ludlow16 scatter (σ_log10(c) = 0.13 dex), so they are scientifically invisible.
  • Single-call runtime after JIT compilation: 0.69 ms, vs 0.83 ms for colossus.
  • jax.grad works end-to-end, agreeing with finite differences to 7 × 10⁻⁴.
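The issue does not spell out how these error figures were computed. One plausible reading is a maximum element-wise relative error against the colossus reference, plus a central finite-difference check for the gradient. The helpers below (`max_rel_error`, `fd_grad_rel_error`) are hypothetical names for that reading, shown with NumPy so they run without JAX.

```python
import numpy as np


def max_rel_error(approx, reference):
    """Maximum element-wise relative error |approx - ref| / |ref|."""
    approx = np.asarray(approx, dtype=float)
    reference = np.asarray(reference, dtype=float)
    return float(np.max(np.abs(approx - reference) / np.abs(reference)))


def fd_grad_rel_error(f, grad_f, x, step=1e-6):
    """Relative error of a claimed derivative grad_f(x) vs a central difference.

    f: scalar function of a scalar; grad_f: its derivative, e.g. jax.grad(f).
    step: relative step size; the absolute step is step * max(|x|, 1).
    """
    h = step * max(abs(x), 1.0)
    # Central difference: truncation error O(h^2), rounding error O(eps / h).
    fd = (f(x + h) - f(x - h)) / (2.0 * h)
    return max_rel_error(grad_f(x), fd)
```

With JAX available, `grad_f` would be `jax.grad(f)` for a scalar-valued `f` wrapping the concentration call. The finite-difference step trades truncation against rounding error, so the tolerance in such a check should be chosen with the step size in mind.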

This issue implements the swap-in: move the prototype into production, collapse the JAX/NumPy branches in the two callers, drop colossus as a runtime dependency, and add tests.

Plan

  • Promote docs/research/nfw_ludlow16_jax/ludlow16_jax.py into autogalaxy/profiles/mass/dark/ludlow16.py (or fold it into mcr_util.py if it stays reasonably small). Make the jnp.* calls xp-aware so the same function runs on both the NumPy and JAX paths (the NumPy path is needed for xp=np callers; both branches already exist in the helpers).
  • Replace _ludlow16_cosmology_callback and ludlow16_cosmology_jax in mcr_util.py with a single xp-aware ludlow16_cosmology(mass_at_200, redshift_object, redshift_source, xp=np) that calls the new ludlow16_concentration_xp(...) for the c200c part and continues to use the existing JAX-native Planck15 for the other three returned values.
  • Drop the if xp is np: ... else: ... branching in kappa_s_and_scale_radius_for_ludlow and kappa_s_scale_radius_and_core_radius_for_ludlow — both become straight xp calls.
  • Make colossus an optional dependency: only the unit-test cross-check against the original implementation needs it at runtime. Update pyproject.toml (and any conda/eden recipe) accordingly.
  • Add unit tests covering:
    • NumPy path regression vs colossus on a small (log M, z) grid, keeping the existing prototype's validate.py numbers as the regression baseline.
    • JAX path jit correctness (numpy/JAX paths agree to ~1e-12).
    • jax.grad smoke test (finite, non-NaN gradient through ludlow16_concentration_xp).
  • Validate the five MCR-Ludlow callers end-to-end under jax.jit in autogalaxy_workspace_test/: nfw_mcr, nfw_mcr_scatter, nfw_truncated_mcr_scatter, gnfw_mcr, cnfw_mcr_scatter.
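Making the module xp-aware also has to cover the 200-point c-solver. One way to keep a fixed-grid solve jit-friendly is to evaluate the residual on the whole grid and select the bracketing interval with `argmax` rather than a Python loop. The sketch below is generic: `solve_on_grid` is a hypothetical helper, the residuals in the usage are stand-ins, and it is neither the Ludlow16 condition nor the prototype's solver.

```python
import numpy as np


def solve_on_grid(residual, c_grid, xp=np):
    """Root of residual(c) = 0 on a fixed grid, with no data-dependent Python branches.

    Assumes residual changes sign at most once on c_grid. Returns the linearly
    interpolated crossing between the two bracketing grid points. If there is
    no crossing, argmax returns 0 and the result is an extrapolation from the
    first interval, so callers should validate the bracket separately.
    """
    r = residual(c_grid)
    # 1.0 where consecutive grid values bracket the root, else 0.0.
    brackets = xp.where(r[:-1] * r[1:] <= 0.0, 1.0, 0.0)
    i = xp.argmax(brackets)  # first bracketing interval
    r0, r1 = r[i], r[i + 1]
    c0, c1 = c_grid[i], c_grid[i + 1]
    # Linear interpolation of the residual between c0 and c1.
    return c0 + r0 / (r0 - r1) * (c1 - c0)


c_grid = np.linspace(1.0, 40.0, 200)
c_root = solve_on_grid(lambda c: np.log(c) - np.log(5.0), c_grid)
```

Under JAX the grid should be built with `jnp.linspace` and `xp=jnp` passed, since indexing a NumPy array with a traced index fails. For the planned numpy/JAX agreement test at ~1e-12, note that JAX defaults to float32, so 64-bit mode (`jax.config.update("jax_enable_x64", True)`) would need to be enabled for that tolerance to be reachable.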

Out of scope (separate work)

Two pre-existing issues surfaced during the Phase 1 science check but are unrelated to this swap-in:

  • cNFWSph.convergence_2d_from returns zeros; it is marked "not yet implemented" in cnfw.py:239-244.
  • The Penarrubia mcr formula in kappa_s_scale_radius_and_core_radius_for_ludlow produces a negative kappa_s for f_c ≳ 0.18 due to a sign flip in the denominator. Either restrict f_c priors or guard the formula.
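If the formula is guarded rather than the f_c priors restricted, a jit-compatible option is to mask the invalid region with `xp.where` instead of raising. The sketch below is generic and hypothetical: `guarded_ratio` is not an autogalaxy function, and the actual Penarrubia numerator and denominator are not reproduced. It uses the "double where" idiom, replacing invalid denominators with a harmless value before dividing so the masked branch stays finite, which matters if `jax.grad` is taken through it.

```python
import numpy as np


def guarded_ratio(numerator, denominator, xp=np, fill_value=np.nan):
    """numerator / denominator, returning fill_value wherever denominator <= 0.

    Double-where: invalid denominators are swapped for 1.0 before dividing,
    so the untaken branch never divides by zero and cannot inject inf/NaN
    into gradients; the outer where then applies the mask.
    """
    valid = denominator > 0.0
    safe_den = xp.where(valid, denominator, 1.0)
    return xp.where(valid, numerator / safe_den, fill_value)


out = guarded_ratio(np.array([1.0, 1.0, 1.0]), np.array([2.0, 0.0, -1.0]))
```

Whether NaN, a large penalty, or a prior restriction is the better signal for a sampler depends on how the likelihood handles invalid values; this sketch only shows the masking mechanics.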


🤖 Generated with Claude Code
