Overview
The _ludlow16_cosmology_callback in autogalaxy/profiles/mass/dark/mcr_util.py is wrapped in jax.pure_callback because the colossus library is not JAX-native. This blocks JIT optimisation and autodiff through any NFW MCR profile (mp.NFWMCRDuffyPhys, mp.NFWMCRLudlow, mp.NFWMCRScatterLudlow, mp.NFWTruncatedMCRScatterLudlow, mp.GNFWMCRLudlow, mp.CNFWMCRScatterLudlow). This issue assesses whether the colossus dependency can be removed in favour of a JAX-native implementation.
Plan
- Audit the dependency surface of colossus.halo.concentration.modelLudlow16 and identify which pieces are truly JAX-hostile. σ(R) via the matter power spectrum is the only real blocker; D(z), the Einasto enclosed mass, and the interpolation loop are all straightforward.
- Note that three of the four values returned by the current callback (cosmic_average_density, critical_surface_density, kpc_per_arcsec) are already JAX-native via autogalaxy.cosmology.model.Planck15. Only the concentration call needs replacing.
- Prototype Approach A first: a full JAX port of modelLudlow16, including the Eisenstein-Hu '98 σ(R). If A is clean (accuracy ≤ 1% in the lensing regime, reasonable code size, JIT and grad both work), stop and write it up, skipping B and C.
- Only if A turns out to be messy, fall back in order to: (B) a precomputed 2-D lookup table c(M, z) for Planck15 interpolated in JAX at runtime, then (C) a closed-form analytic fit calibrated to Ludlow16 over the lensing parameter regime.
- Deliver a written feasibility report with a recommendation. A separate follow-up issue will implement the chosen approach, remove the pure_callback, and collapse the JAX/NumPy branches in kappa_s_and_scale_radius_for_ludlow and kappa_s_scale_radius_and_core_radius_for_ludlow.
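For reference, the runtime half of Approach B is bilinear interpolation on a regular (log M, z) grid. The sketch below assumes uniform spacing along both axes and clamps out-of-range queries to the grid edges. `interp_regular_2d` is a hypothetical helper, not an existing autogalaxy function, and may overlap with whatever the queued jax-interp-2d task provides. It takes the array module as an `xp` argument so the same code could run under NumPy or `jax.numpy`; only the NumPy path is exercised here, and the table values in the usage lines are placeholders, not Ludlow16 concentrations.

```python
import numpy as np


def interp_regular_2d(x, y, x0, dx, y0, dy, table, xp=np):
    """Bilinear interpolation of table[i, j] = f(x0 + i * dx, y0 + j * dy).

    Hypothetical helper. Uses only operations that exist in both numpy and
    jax.numpy; queries outside the grid are clamped to the nearest edge.
    """
    nx, ny = table.shape
    # Fractional grid coordinates, clamped to the tabulated range.
    fx = xp.clip((xp.asarray(x) - x0) / dx, 0.0, nx - 1.0)
    fy = xp.clip((xp.asarray(y) - y0) / dy, 0.0, ny - 1.0)
    # Lower-left corner of the enclosing cell; keep i + 1 and j + 1 in bounds.
    ix = xp.clip(xp.floor(fx).astype(xp.int32), 0, nx - 2)
    iy = xp.clip(xp.floor(fy).astype(xp.int32), 0, ny - 2)
    tx = fx - ix
    ty = fy - iy
    return (
        (1 - tx) * (1 - ty) * table[ix, iy]
        + tx * (1 - ty) * table[ix + 1, iy]
        + (1 - tx) * ty * table[ix, iy + 1]
        + tx * ty * table[ix + 1, iy + 1]
    )


# Usage: a 64x64 grid over log M in [9, 16] and z in [0, 3].
log_m_grid = np.linspace(9.0, 16.0, 64)
z_grid = np.linspace(0.0, 3.0, 64)
# Placeholder values only; the real table would be filled from colossus's modelLudlow16.
c_table = 5.0 * (1.0 + z_grid[None, :]) ** -0.5 * 10.0 ** (-0.1 * (log_m_grid[:, None] - 12.0))
c_example = interp_regular_2d(
    12.3, 0.5,
    log_m_grid[0], log_m_grid[1] - log_m_grid[0],
    z_grid[0], z_grid[1] - z_grid[0],
    c_table,
)
```

Keeping the grid origin and spacing as plain Python floats means that, under JIT, only the query points (and optionally the table) would be traced.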
Detailed implementation plan
Affected Repositories
- PyAutoGalaxy (primary — feasibility report; later implementation lives here)
Work Classification
Library (research / feasibility — Phase 1 produces a report, not code changes in production paths)
Branch Survey
| Repository | Current Branch | Dirty? |
|---|---|---|
| ./PyAutoGalaxy | main | clean (modulo a benign CLAUDE.md reformat) |
Suggested branch: feature/nfw-jax-port
Worktree root: ~/Code/PyAutoLabs-wt/nfw-jax-port/ (created later by /start_library)
Note: This task is currently queued in planned.md behind jax-interp-2d, which holds PyAutoGalaxy. Worktree creation deferred until that task ships.
Dependency audit of colossus.halo.concentration.modelLudlow16 (lines 1104-1192)
| Dependency | JAX porting difficulty |
|---|---|
| cosmology.sigma(R, z=0): RMS linear density fluctuation σ(R) | Hard. Requires a P(k) transfer function (Eisenstein-Hu 1998) plus a window-function integral; roughly 100-200 lines. |
| cosmology.growthFactor(z): linear growth D(z) | Easy. Flat ΛCDM has a closed-form / 1-D quadrature expression. |
| peaks.lagrangianR(M) = (3M / (4π ρ_m,0))^(1/3) | Trivial. |
| profile_einasto.EinastoProfile.enclosedMassInner | Easy. Uses scipy.special.gammainc, mirrored by jax.scipy.special.gammainc. |
| scipy.special.erfc | Trivial: jax.scipy.special.erfc. |
| 200-point c_array brute-force + np.interp per mass | Trivial: jnp.interp, fully vectorisable. |
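To make the "Easy" and "Trivial" rows concrete, here is a minimal NumPy sketch written against an `xp` module argument, so it could be handed `jax.numpy` instead. It covers the Lagrangian radius, the flat-ΛCDM growth factor via the Heath (1977) integral on a fixed grid, and a loop-free replacement for a per-mass np.interp inversion. All function names are hypothetical. The unit convention (M in M☉/h, R in Mpc/h, ρ_crit,0 ≈ 2.775×10¹¹ h² M☉ Mpc⁻³) is an assumption to check against colossus, and `invert_on_grid` is a generic tabulate-and-invert helper, not a transcription of what modelLudlow16 tabulates internally.

```python
import numpy as np

# Critical density today, ~2.775e11 (Msun/h) / (Mpc/h)^3 (assumed units).
RHO_CRIT_0 = 2.775e11


def lagrangian_radius(mass, Om0, xp=np):
    """R_L = (3 M / (4 pi rho_m,0))^(1/3); M in Msun/h, R_L in Mpc/h."""
    rho_m0 = Om0 * RHO_CRIT_0
    return (3.0 * xp.asarray(mass) / (4.0 * np.pi * rho_m0)) ** (1.0 / 3.0)


def growth_factor(z, Om0, Ode0, xp=np, n=2048):
    """Linear growth D(z) / D(0) for flat LCDM via the Heath (1977) integral.

    D(a) is proportional to E(a) * int_0^a da' / (a' E(a'))^3 with
    E(a) = sqrt(Om0 a^-3 + Ode0). The integrand equals
    a'^1.5 / (Om0 + Ode0 a'^3)^1.5, which is finite at a' = 0.
    """
    u = xp.linspace(0.0, 1.0, n)
    du = u[1] - u[0]

    def d_unnorm(a):
        a = xp.asarray(a)
        ap = a[..., None] * u  # a' = u * a, so the grid rescales with each a
        f = ap ** 1.5 / (Om0 + Ode0 * ap ** 3) ** 1.5
        # Trapezoid rule on the uniform u grid (da' = a du).
        integral = a * du * (xp.sum(f, axis=-1) - 0.5 * (f[..., 0] + f[..., -1]))
        return xp.sqrt(Om0 * a ** -3 + Ode0) * integral

    a = 1.0 / (1.0 + xp.asarray(z, dtype=float))
    return d_unnorm(a) / d_unnorm(1.0)


def invert_on_grid(x_grid, f_rows, targets, xp=np):
    """For each row i, solve f_rows[i](x) = targets[i] by linear interpolation.

    f_rows has shape (n_rows, n_grid) and must increase along the last axis;
    targets outside a row's range are extrapolated from the end segment.
    """
    n = x_grid.shape[0]
    targets = xp.asarray(targets)
    # Index of the last grid point whose value lies below the target.
    idx = xp.clip(xp.sum(f_rows < targets[:, None], axis=-1) - 1, 0, n - 2)
    f_lo = xp.take_along_axis(f_rows, idx[:, None], axis=-1)[:, 0]
    f_hi = xp.take_along_axis(f_rows, idx[:, None] + 1, axis=-1)[:, 0]
    t = (targets - f_lo) / (f_hi - f_lo)
    return x_grid[idx] + t * (x_grid[idx + 1] - x_grid[idx])
```

Under JAX, a more literal alternative for the last step is `jax.vmap(jnp.interp, in_axes=(0, 0, None))(targets, f_rows, x_grid)`, which mirrors the per-mass np.interp call directly (but clamps rather than extrapolates at the ends).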
Phase 1 — Feasibility report (A-first)
- Prototype Approach A: full JAX port (primary track).
  - Port the Eisenstein-Hu 1998 transfer function T(k, Ω_m, Ω_b, h) into a single JAX function.
  - Implement sigma2(R, z=0) = (1/2π²) ∫ k² P(k) W²(kR) dk, where W is the Fourier transform of the real-space top-hat, via fixed-grid log-k quadrature (jax.numpy.trapezoid or Gauss-Legendre).
  - Implement closed-form D(z) for flat ΛCDM (Heath 1977 / Carroll-Press-Turner 1992).
  - Implement Einasto enclosedMassInner using jax.scipy.special.gammainc.
  - Vectorise the 200-point c_array solver with jnp.interp.
  - Validate against the current callback to ≤ 1% relative error in c200c over the lensing regime (z ∈ [0.1, 2.5], log M ∈ [10, 14]).
  - Output: a single ludlow16_concentration_jax(M200c_Msun, z, h, Om0, Ode0, Ob0, sigma8, ns) function (Ob0 is needed because the EH98 transfer function depends on the baryon density) that JIT-compiles and differentiates end-to-end.
- Decision gate after the A prototype:
  - If A is clean (accuracy meets target, code is a reasonable size, JIT/grad work) → write the report recommending A and move directly to Phase 2 implementation. Skip B and C entirely.
  - If A is messy (σ(R) port balloons, accuracy is hard to hit, numerical pathologies under JIT) → fall back to Approach B (lookup table) as the next prototype. Document why A failed.
  - If B is also unsuitable (accuracy demands too fine a grid, or the cosmology-parameter dependence is unacceptable) → fall back to Approach C (analytic fit).
- Approach B: precomputed lookup table (fallback).
  - At install or first import, call the existing colossus path to populate a 2-D grid of c(M, z) for the Planck15 cosmology over log M ∈ [9, 16], z ∈ [0, 3] at, e.g., 64×64 resolution.
  - Cache the grid to disk under autogalaxy/config/cache/ludlow16_planck15.npz so subsequent imports skip colossus entirely.
  - At runtime, interpolate in 2-D with a JAX-friendly routine.
- Approach C: analytic fit (last-resort fallback).
  - Sample c(M, z) from colossus and fit a Duffy-style power law c = A (1+z)^B (M/M_pivot)^C (or a richer functional form if the residuals demand it).
  - Trade fidelity for simplicity; use only if A and B both fail.
- Benchmark whichever approach lands against _ludlow16_cosmology_callback:
  - Max relative error in c200c over log M ∈ [10, 14], z ∈ [0.1, 2.5].
  - Max relative error in downstream kappa_s, scale_radius, radius_at_200 for a representative NFW model.
  - Single-call wall time (callback vs JIT-compiled).
  - JIT trace + compile time.
  - Lines of code added / dependencies introduced.
- Write a feasibility report summarising:
  - Why the callback exists (only the concentration needs colossus; the other three outputs are already JAX-native).
  - Which approach was tried and why (including the chain of fallbacks, if any).
  - Quantitative comparison against the current callback.
  - Implementation plan for Phase 2.
  - Save it under docs/research/nfw_ludlow16_jax_assessment.md, or paste it into the issue if short.
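The one genuinely hard row of the audit is σ(R). As a rough sizing exercise for Approach A, the sketch below computes σ(R, z=0) = sqrt((1/2π²) ∫ k² P(k) W²(kR) dk) with a top-hat window and a fixed-grid trapezoid rule in ln k, normalised so that σ(8 Mpc/h) = σ8. Note the swap: it uses the simpler Eisenstein & Hu (1998) "no-wiggle" fit rather than the full baryon-oscillation fit in colossus's transferFunctionEH98, so it should not be expected to meet the 1% target without further work. T_CMB = 2.7255 K, the k range and the grid size are untuned assumptions, and the function names are hypothetical.

```python
import math

import numpy as np


def eh98_nowiggle_transfer(k, h, Om0, Ob0, Tcmb0=2.7255, xp=np):
    """Eisenstein & Hu (1998) no-wiggle transfer function; k in h/Mpc."""
    theta2 = (Tcmb0 / 2.7) ** 2
    omh2 = Om0 * h ** 2
    obh2 = Ob0 * h ** 2
    fb = Ob0 / Om0
    # Fitting formula for the sound horizon, in Mpc.
    s = 44.5 * xp.log(9.83 / omh2) / xp.sqrt(1.0 + 10.0 * obh2 ** 0.75)
    alpha = 1.0 - 0.328 * xp.log(431.0 * omh2) * fb + 0.38 * xp.log(22.3 * omh2) * fb ** 2
    # Effective shape parameter; k * h converts k to 1/Mpc to pair with s.
    gamma_eff = Om0 * h * (alpha + (1.0 - alpha) / (1.0 + (0.43 * k * h * s) ** 4))
    q = k * theta2 / gamma_eff
    l0 = xp.log(2.0 * math.e + 1.8 * q)
    c0 = 14.2 + 731.0 / (1.0 + 62.5 * q)
    return l0 / (l0 + c0 * q ** 2)


def tophat_window(x, xp=np):
    """W(x) = 3 (sin x - x cos x) / x^3, with a series branch near x = 0."""
    small = x < 1e-3
    xs = xp.where(small, 1.0, x)  # keeps the unused branch finite (and grad-safe)
    exact = 3.0 * (xp.sin(xs) - xs * xp.cos(xs)) / xs ** 3
    return xp.where(small, 1.0 - x ** 2 / 10.0, exact)


def sigma_tophat(R, h, Om0, Ob0, sigma8, ns, xp=np, n_k=4096):
    """sigma(R) at z = 0 for top-hat radius R in Mpc/h, normalised so sigma(8) = sigma8."""
    lnk = xp.linspace(math.log(1e-5), math.log(1e4), n_k)
    k = xp.exp(lnk)
    dlnk = lnk[1] - lnk[0]
    # k^3 P(k) / (2 pi^2) up to a constant, with P(k) proportional to k^ns T(k)^2.
    delta2 = k ** (3.0 + ns) * eh98_nowiggle_transfer(k, h, Om0, Ob0, xp=xp) ** 2 / (2.0 * math.pi ** 2)

    def sigma2_unnorm(radius):
        w = tophat_window(k * xp.asarray(radius, dtype=float)[..., None], xp=xp)
        y = delta2 * w ** 2
        # Trapezoid rule in ln k on the uniform grid.
        return dlnk * (xp.sum(y, axis=-1) - 0.5 * (y[..., 0] + y[..., -1]))

    return xp.sqrt(sigma8 ** 2 * sigma2_unnorm(R) / sigma2_unnorm(8.0))
```

Together with the growth factor and the Lagrangian radius listed in the audit table, this gives σ(M, z) = D(z) σ(R_L(M)). The parameter values used to exercise it should come from the same cosmology object as the rest of the code rather than being hard-coded.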
Phase 2 — Implementation (separate follow-up issue)
After the user reviews the report and picks an approach, a follow-up task will:
- Implement the chosen function in mcr_util.py.
- Collapse _ludlow16_cosmology_callback and ludlow16_cosmology_jax into a single xp-aware ludlow16_cosmology(mass_at_200, redshift_object, redshift_source, xp=np).
- Drop the if xp is np: ... else: ... branching in kappa_s_and_scale_radius_for_ludlow and kappa_s_scale_radius_and_core_radius_for_ludlow.
- Make colossus an optional dependency (only needed by the lookup-table generator under Approach B; removable entirely under A and C).
- Add unit tests for the JAX path (JIT correctness, jax.grad smoke test).
- Validate the profiles in the five caller modules under JIT in autogalaxy_workspace_test/.
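Collapsing the NumPy/JAX branches amounts to writing each step once against an `xp` module. As an illustration of the pattern only, here is the standard NFW conversion from (c, M200) to κ_s, r_s and r_200 in that style. The function name, argument order and units (M☉, kpc, ρ_crit(z) in M☉ kpc⁻³, Σ_crit in M☉ kpc⁻²) are assumptions, and it assumes the common convention κ_s = ρ_s r_s / Σ_crit; check both against kappa_s_and_scale_radius_for_ludlow before reusing any of it.

```python
import numpy as np


def nfw_kappa_s_and_radii(concentration, mass_at_200, rho_crit, sigma_crit, kpc_per_arcsec, xp=np):
    """Hypothetical xp-aware NFW conversion: one code path for numpy and jax.numpy.

    Assumed units: mass in Msun, rho_crit in Msun/kpc^3, sigma_crit in Msun/kpc^2.
    Returns (kappa_s, scale_radius_arcsec, radius_at_200_arcsec).
    """
    c = xp.asarray(concentration)
    # r_200 encloses a mean density of 200 * rho_crit(z).
    r200_kpc = (3.0 * mass_at_200 / (4.0 * np.pi * 200.0 * rho_crit)) ** (1.0 / 3.0)
    rs_kpc = r200_kpc / c
    # Characteristic overdensity of an NFW halo defined at 200c.
    m_c = xp.log(1.0 + c) - c / (1.0 + c)
    delta_c = (200.0 / 3.0) * c ** 3 / m_c
    rho_s = delta_c * rho_crit
    kappa_s = rho_s * rs_kpc / sigma_crit  # assumed convention: kappa_s = rho_s r_s / Sigma_crit
    return kappa_s, rs_kpc / kpc_per_arcsec, r200_kpc / kpc_per_arcsec
```

Under JAX the same function would be wrapped as, e.g., `jax.jit(functools.partial(nfw_kappa_s_and_radii, xp=jax.numpy))`, with no `if xp is np` branch needed.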
Key Files
- autogalaxy/profiles/mass/dark/mcr_util.py: the callback and its callers.
- autogalaxy/cosmology/model.py: the already JAX-native Planck15 (no changes; just confirm we read from here rather than from colossus's Planck15).
- Callers (read-only in Phase 1):
  - autogalaxy/profiles/mass/dark/nfw_mcr.py
  - autogalaxy/profiles/mass/dark/nfw_mcr_scatter.py
  - autogalaxy/profiles/mass/dark/nfw_truncated_mcr_scatter.py
  - autogalaxy/profiles/mass/dark/gnfw_mcr.py
  - autogalaxy/profiles/mass/dark/cnfw_mcr_scatter.py
- Reference (read-only):
  - colossus/halo/concentration.py::modelLudlow16 (lines 1104-1192)
  - colossus/halo/profile_einasto.py::EinastoProfile.enclosedMassInner
  - colossus/cosmology/cosmology.py::Cosmology.sigma, .growthFactor, .transferFunctionEH98
Original Prompt
Certain dark matter profiles in mcr_util.py (autogalaxy/profiles/mass/dark) use a pure_callback to do a calculation:
```python
def ludlow16_cosmology_jax(
    mass_at_200,
    redshift_object,
    redshift_source,
):
    """
    JAX-safe wrapper around Colossus + Astropy cosmology.
    """
    import jax
    import jax.numpy as jnp
    from jax import ShapeDtypeStruct

    return jax.pure_callback(
        _ludlow16_cosmology_callback,
        (
            ShapeDtypeStruct((), jnp.float64),  # concentration
            ShapeDtypeStruct((), jnp.float64),  # rho_crit(z)
            ShapeDtypeStruct((), jnp.float64),  # Sigma_crit
            ShapeDtypeStruct((), jnp.float64),  # kpc/arcsec
        ),
        mass_at_200,
        redshift_object,
        redshift_source,
        vmap_method="sequential",
    )
```
This is required because the colossus library is not JAX-native, and it is not easy to make it JAX-native.
Can you inspect the colossus source code and work out whether extracting the specific functionality we need and making it JAX-native is feasible? Alternatively, can you assess whether the same calculation can be done using a simpler Python function that doesn't necessarily use the colossus code?