
feat: assess JAX-native replacement for Ludlow16 concentration pure_callback #397

@Jammy2211


Overview

The _ludlow16_cosmology_callback in autogalaxy/profiles/mass/dark/mcr_util.py is wrapped in jax.pure_callback because the colossus library is not JAX-native. A pure_callback hands control back to the Python host at runtime, so XLA cannot optimise across it, and jax.grad cannot differentiate through it because pure_callback defines no differentiation rule. This blocks JIT optimisation and autodiff through every NFW MCR profile (mp.NFWMCRDuffyPhys, mp.NFWMCRLudlow, mp.NFWMCRScatterLudlow, mp.NFWTruncatedMCRScatterLudlow, mp.GNFWMCRLudlow, mp.CNFWMCRScatterLudlow). This issue assesses whether the colossus dependency can be removed in favour of a JAX-native implementation.
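A minimal sketch of the blocker, using a hypothetical stand-in for the colossus call (_host_concentration and its power-law coefficients are illustrative, not Ludlow16). The forward pass through jax.pure_callback compiles under jax.jit, but jax.grad raises because the callback has no differentiation rule, whereas the same formula written in jax.numpy differentiates normally.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_concentration(mass):
    # Hypothetical stand-in for a non-JAX library call (e.g. colossus); runs in NumPy on the host.
    mass = np.asarray(mass)
    return np.asarray(5.0 * (mass / 1e12) ** -0.1, dtype=mass.dtype)


def concentration_via_callback(mass):
    # Same pattern as ludlow16_cosmology_jax: JAX only sees the declared output shape/dtype.
    return jax.pure_callback(
        _host_concentration,
        jax.ShapeDtypeStruct(jnp.shape(mass), jnp.result_type(mass)),
        mass,
    )


def concentration_native(mass):
    # The same toy formula in jax.numpy, so JAX can trace and differentiate it.
    return 5.0 * (mass / 1e12) ** -0.1


mass = jnp.asarray(1e12)

forward = jax.jit(concentration_via_callback)(mass)  # forward pass under JIT works
native_grad = jax.grad(concentration_native)(mass)   # autodiff through jnp works

try:
    jax.grad(concentration_via_callback)(mass)
    callback_grad_ok = True
except Exception as err:  # pure_callback has no JVP rule, so differentiation fails here
    callback_grad_ok = False
    print(type(err).__name__, err)
```

This is why the plan below focuses on replacing the callback body rather than wrapping it differently: JIT alone tolerates the callback, autodiff does not.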

Plan

  • Audit the dependency surface of colossus.halo.concentration.modelLudlow16 and identify which pieces are truly JAX-hostile (σ(R) via the matter power spectrum is the only real blocker — D(z), Einasto enclosed mass, and the interpolation loop are all straightforward).
  • Note that three of the four values returned by the current callback (cosmic_average_density, critical_surface_density, kpc_per_arcsec) are already JAX-native via autogalaxy.cosmology.model.Planck15 — only the concentration call needs replacing.
  • Prototype Approach A first: a full JAX port of modelLudlow16 including Eisenstein-Hu '98 σ(R). If A is clean (accuracy ≤ 1% in the lensing regime, code size reasonable, JIT/grad work), stop and write up — skip B and C.
  • Only if A turns out to be messy, fall back in order to: (B) a precomputed 2-D lookup table c(M, z) for Planck15 interpolated in JAX at runtime, then (C) a closed-form analytic fit calibrated to Ludlow16 over the lensing parameter regime.
  • Deliver a written feasibility report with a recommendation. A separate follow-up issue will implement the chosen approach, remove the pure_callback, and collapse the JAX/NumPy branches in kappa_s_and_scale_radius_for_ludlow and kappa_s_scale_radius_and_core_radius_for_ludlow.
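To illustrate why the non-σ(R) pieces are considered straightforward, here is a minimal JAX sketch under stated assumptions: the linear growth factor D(z) for flat ΛCDM via the Heath (1977) integral, evaluated with a fixed Gauss-Legendre rule (radiation neglected, normalised to D(0) = 1); the Lagrangian radius; the Einasto enclosed mass via jax.scipy.special.gammainc; and the grid-inversion pattern that replaces a per-mass np.interp loop. All function names are hypothetical, and none of this has been checked against colossus's exact conventions.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import gammainc, gammaln

# Fixed 64-point Gauss-Legendre rule mapped from [-1, 1] to [0, 1]; fixed nodes keep it jit/grad-friendly.
_GL_NODES, _GL_WEIGHTS = np.polynomial.legendre.leggauss(64)
_U = jnp.asarray(0.5 * (_GL_NODES + 1.0))
_W = jnp.asarray(0.5 * _GL_WEIGHTS)


def _growth_unnormalised(a, Om0):
    """Heath (1977): D(a) = (5/2) Om0 E(a) int_0^a da' / (a' E(a'))^3 (flat LCDM, no radiation, scalar a)."""
    Ode0 = 1.0 - Om0
    a_prime = a * _U
    # (a' E(a'))^-3 rewritten as a'^{3/2} (Om0 + Ode0 a'^3)^{-3/2}, which is finite at a' = 0.
    integrand = a_prime**1.5 * (Om0 + Ode0 * a_prime**3) ** -1.5
    integral = a * jnp.sum(_W * integrand)
    E = jnp.sqrt(Om0 * a**-3 + Ode0)
    return 2.5 * Om0 * E * integral


def growth_factor(z, Om0):
    """Linear growth factor normalised so that D(z=0) = 1; use jax.vmap for arrays of z."""
    a = 1.0 / (1.0 + z)
    return _growth_unnormalised(a, Om0) / _growth_unnormalised(1.0, Om0)


def lagrangian_radius(M, rho_m0):
    """Radius enclosing mass M at mean matter density rho_m0: (3M / (4 pi rho_m0))^(1/3)."""
    return (3.0 * M / (4.0 * jnp.pi * rho_m0)) ** (1.0 / 3.0)


def einasto_enclosed_mass(r, rho_s, r_s, alpha):
    """M(<r) for rho = rho_s exp(-(2/alpha) [(r/r_s)^alpha - 1]).

    M = 4 pi rho_s r_s^3 e^{2/alpha} alpha^{-1} (2/alpha)^{-3/alpha} gamma(3/alpha, x),
    x = (2/alpha)(r/r_s)^alpha, gamma = unregularised lower incomplete gamma.
    jax.scipy.special.gammainc is the regularised P(a, x), so Gamma(a) is folded in (in log space).
    """
    a = 3.0 / alpha
    x = (2.0 / alpha) * (r / r_s) ** alpha
    log_prefactor = 2.0 / alpha - jnp.log(alpha) - a * jnp.log(2.0 / alpha) + gammaln(a)
    return 4.0 * jnp.pi * rho_s * r_s**3 * jnp.exp(log_prefactor) * gammainc(a, x)


def invert_on_grid(target, c_grid, f):
    """Solve f(c) = target by tabulating f on c_grid (f must be increasing there) and interpolating."""
    return jnp.interp(target, f(c_grid), c_grid)


# Example: sensitivity of D(z=1) to Om0, which a pure_callback could not provide.
dD_dOm0 = jax.grad(growth_factor, argnums=1)(1.0, 0.3)
```

The growth integrand is rewritten so it stays finite (and differentiable) at a' = 0, and the Einasto prefactor is combined in log space to avoid overflow in Γ(3/α) for small α.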
Detailed implementation plan

Affected Repositories

  • PyAutoGalaxy (primary — feasibility report; later implementation lives here)

Work Classification

Library (research / feasibility — Phase 1 produces a report, not code changes in production paths)

Branch Survey

| Repository | Current Branch | Dirty? |
|---|---|---|
| ./PyAutoGalaxy | main | clean (modulo a benign CLAUDE.md reformat) |

Suggested branch: feature/nfw-jax-port
Worktree root: ~/Code/PyAutoLabs-wt/nfw-jax-port/ (created later by /start_library)

Note: This task is currently queued in planned.md behind jax-interp-2d, which holds PyAutoGalaxy. Worktree creation deferred until that task ships.

Dependency audit of colossus.halo.concentration.modelLudlow16 (lines 1104-1192)

| Dependency | JAX porting difficulty |
|---|---|
| cosmology.sigma(R, z=0): RMS linear density fluctuation σ(R) | Hard: requires the P(k) transfer function (Eisenstein-Hu 1998) plus a window-function integral; ~100-200 lines. |
| cosmology.growthFactor(z): linear growth D(z) | Easy: for flat ΛCDM, D(z) is a 1-D integral (or a closed-form approximation). |
| peaks.lagrangianR(M) = (3M / (4π ρ_m,0))^(1/3) | Trivial. |
| profile_einasto.EinastoProfile.enclosedMassInner | Easy: uses scipy.special.gammainc, mirrored by jax.scipy.special.gammainc. |
| scipy.special.erfc | Trivial: jax.scipy.special.erfc. |
| 200-point c_array brute-force + np.interp per mass | Trivial: jnp.interp, fully vectorisable. |
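A compact sketch of the "Hard" row, assuming the smooth "no-wiggle" fitting form of the Eisenstein & Hu (1998) transfer function (no BAO wiggles) rather than their full fit; colossus's default power-spectrum model may differ, so agreement with it would have to be measured, not assumed. σ(R) is computed as σ²(R) = (1/2π²) ∫ k³ P(k) W²(kR) d ln k with P(k) ∝ k^ns T²(k), a real-space top-hat window, a fixed log-k trapezoid grid, and normalisation to σ8 at R = 8 Mpc/h. Units are k in h/Mpc and R in Mpc/h; the function names, grid limits and Tcmb default are hypothetical choices.

```python
import jax
import jax.numpy as jnp

# The existing callback returns float64, so enable double precision here as well (global setting).
jax.config.update("jax_enable_x64", True)


def eh98_nowiggle_transfer(k, Om0, Ob0, h, Tcmb=2.7255):
    """Eisenstein & Hu (1998) no-wiggle transfer function; k in h/Mpc."""
    omh2 = Om0 * h**2
    fb = Ob0 / Om0
    theta27 = Tcmb / 2.7
    # Approximate sound horizon in Mpc.
    s = 44.5 * jnp.log(9.83 / omh2) / jnp.sqrt(1.0 + 10.0 * (Ob0 * h**2) ** 0.75)
    alpha_gamma = (
        1.0
        - 0.328 * jnp.log(431.0 * omh2) * fb
        + 0.38 * jnp.log(22.3 * omh2) * fb**2
    )
    k_mpc = k * h  # 1/Mpc, needed for the k*s term
    gamma_eff = Om0 * h * (alpha_gamma + (1.0 - alpha_gamma) / (1.0 + (0.43 * k_mpc * s) ** 4))
    q = k * theta27**2 / gamma_eff
    L0 = jnp.log(2.0 * jnp.e + 1.8 * q)
    C0 = 14.2 + 731.0 / (1.0 + 62.5 * q)
    return L0 / (L0 + C0 * q**2)


def tophat_window(x):
    """Fourier transform of a real-space top hat: 3 (sin x - x cos x) / x^3."""
    small = x < 0.1
    xs = jnp.where(small, 1.0, x)  # keeps the unused branch finite so gradients stay NaN-free
    exact = 3.0 * (jnp.sin(xs) - xs * jnp.cos(xs)) / xs**3
    series = 1.0 - x**2 / 10.0 + x**4 / 280.0  # Taylor expansion avoids cancellation at small x
    return jnp.where(small, series, exact)


_LNK = jnp.linspace(jnp.log(1e-5), jnp.log(1e3), 4096)  # fixed ln k grid, k in h/Mpc


def _sigma2_unnormalised(R, Om0, Ob0, h, ns):
    k = jnp.exp(_LNK)
    T = eh98_nowiggle_transfer(k, Om0, Ob0, h)
    f = k**3 * k**ns * T**2 * tophat_window(k * R) ** 2  # k^3 P(k) W^2(kR), P = k^ns T^2
    trapezoid = jnp.sum(0.5 * (f[1:] + f[:-1]) * jnp.diff(_LNK))
    return trapezoid / (2.0 * jnp.pi**2)


def sigma_R(R, Om0, Ob0, h, ns, sigma8):
    """sigma(R, z=0) with R in Mpc/h (scalar), normalised so that sigma(8 Mpc/h) = sigma8."""
    norm = sigma8**2 / _sigma2_unnormalised(8.0, Om0, Ob0, h, ns)
    return jnp.sqrt(norm * _sigma2_unnormalised(R, Om0, Ob0, h, ns))
```

Because every step is plain jax.numpy on fixed grids, sigma_R composes with jax.jit and jax.grad (for example with respect to Om0), which is exactly what the pure_callback path cannot offer. σ(R, z) would then follow as σ(R, 0) D(z).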

Phase 1 — Feasibility report (A-first)

  1. Prototype Approach A — full JAX port. Primary track.

    • Port Eisenstein-Hu 1998 transfer function T(k, Ω_m, Ω_b, h) into a single JAX function.
    • Implement sigma2(R, z=0) = (1/2π²) ∫ k² P(k) W²(kR) dk, with W the Fourier transform of a real-space top-hat, via fixed-grid log-k quadrature (jax.numpy.trapezoid or Gauss-Legendre).
    • Implement closed-form D(z) for flat ΛCDM (Heath 1977 / Carroll-Press-Turner 1992).
    • Implement Einasto enclosedMassInner using jax.scipy.special.gammainc.
    • Vectorise the 200-point c_array solver with jnp.interp.
    • Validate against the current callback to ≤ 1% relative error in c200c over the lensing regime (z ∈ [0.1, 2.5], log M ∈ [10, 14]).
    • Output: a single ludlow16_concentration_jax(M200c_Msun, z, h, Om0, Ode0, Ob0, sigma8, ns) function (Ob0 is required by the EH98 transfer function), JIT-compilable and differentiable with jax.grad end-to-end.
  2. Decision gate after A prototype:

    • If A is clean (accuracy meets target, code is reasonable size, JIT/grad work) → write the report recommending A and move directly to Phase 2 implementation. Skip B and C entirely.
    • If A is messy (σ(R) port balloons, accuracy is hard to hit, numerical pathologies under JIT) → fall back to Approach B (lookup table) as the next prototype. Document why A failed.
    • If B is also unsuitable (accuracy demands too fine a grid, cosmology-parameter dependence unacceptable) → fall back to Approach C (analytic fit).
  3. Approach B — precomputed lookup table (fallback).

    • At install or first-import, call the existing colossus path to populate a 2-D grid of c(M, z) for the Planck15 cosmology over (log M ∈ [9, 16], z ∈ [0, 3]) at e.g. 64×64.
    • Cache the grid to disk under autogalaxy/config/cache/ludlow16_planck15.npz so subsequent imports skip colossus entirely.
    • At runtime, 2-D interpolate with a JAX-friendly routine.
  4. Approach C — analytic fit (last-resort fallback).

    • Sample c(M, z) from colossus and fit a Duffy-style power law c = A (1+z)^B (M/M_pivot)^C (or richer functional form if residuals demand it).
    • Trade fidelity for simplicity — only if A and B both fail.
  5. Benchmark whichever approach lands against _ludlow16_cosmology_callback:

    • Max relative error in c200c over (log M ∈ [10, 14], z ∈ [0.1, 2.5]).
    • Max relative error in downstream kappa_s, scale_radius, radius_at_200 for a representative NFW model.
    • Single-call wall time (callback vs JIT-compiled).
    • JIT trace + compile time.
    • Lines of code added / dependencies introduced.
  6. Write a feasibility report summarising:

    • Why the callback exists (only the concentration needs it; the other three returned values are already JAX-native).
    • Which approach was tried and why (chain of fallbacks if any).
    • Quantitative comparison vs the current callback.
    • Implementation plan for Phase 2.
    • Saved under docs/research/nfw_ludlow16_jax_assessment.md or pasted into the issue if short.
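For Approach B, the runtime side reduces to interpolating a fixed table. A minimal sketch, assuming a uniform (log M, z) grid like the 64×64 one suggested above; bilinear_regular_grid is a hypothetical helper, not the routine from the jax-interp-2d task, and the table is a synthetic placeholder rather than colossus output. Queries outside the grid are clamped to the edge, which is one possible policy among several.

```python
import jax
import jax.numpy as jnp


def bilinear_regular_grid(x, y, x_grid, y_grid, table):
    """Bilinear interpolation of table[i, j] = f(x_grid[i], y_grid[j]) on uniformly spaced grids.

    Out-of-range queries are clamped to the grid edge (no extrapolation).
    """
    nx, ny = x_grid.shape[0], y_grid.shape[0]
    # Continuous cell coordinates, clamped to [0, n - 1].
    fx = jnp.clip((x - x_grid[0]) / (x_grid[1] - x_grid[0]), 0.0, nx - 1)
    fy = jnp.clip((y - y_grid[0]) / (y_grid[1] - y_grid[0]), 0.0, ny - 1)
    # Lower-left cell index; clipping to n - 2 keeps i + 1 and j + 1 in bounds at the upper edge.
    i = jnp.clip(jnp.floor(fx).astype(jnp.int32), 0, nx - 2)
    j = jnp.clip(jnp.floor(fy).astype(jnp.int32), 0, ny - 2)
    tx, ty = fx - i, fy - j
    return (
        (1 - tx) * (1 - ty) * table[i, j]
        + tx * (1 - ty) * table[i + 1, j]
        + (1 - tx) * ty * table[i, j + 1]
        + tx * ty * table[i + 1, j + 1]
    )


# Synthetic placeholder table on the Approach B ranges (log M in [9, 16], z in [0, 3], 64 x 64).
log_m_grid = jnp.linspace(9.0, 16.0, 64)
z_grid = jnp.linspace(0.0, 3.0, 64)
LM, Z = jnp.meshgrid(log_m_grid, z_grid, indexing="ij")
table = 5.0 - 0.1 * (LM - 12.0) - 0.3 * Z + 0.02 * (LM - 12.0) * Z

c_of = jax.jit(lambda lm, z: bilinear_regular_grid(lm, z, log_m_grid, z_grid, table))
dc_dlogm = jax.grad(c_of, argnums=0)(12.3, 0.7)
```

Gradients flow through the fractional offsets tx and ty, so the interpolated c(M, z) is piecewise-linear and differentiable almost everywhere; a 64×64 grid's accuracy against the 1% target is what the benchmark step would need to establish.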

Phase 2 — Implementation (separate follow-up issue)

After the user reviews the report and picks an approach, a follow-up task will:

  1. Implement the chosen function in mcr_util.py.
  2. Collapse _ludlow16_cosmology_callback and ludlow16_cosmology_jax into a single xp-aware ludlow16_cosmology(mass_at_200, redshift_object, redshift_source, xp=np).
  3. Drop the if xp is np: ... else: ... branching in kappa_s_and_scale_radius_for_ludlow and kappa_s_scale_radius_and_core_radius_for_ludlow.
  4. Make colossus an optional dependency (only needed by the lookup-table generator under Approach B, removable entirely under A and C).
  5. Add unit tests for the JAX path (JIT correctness, jax.grad smoke test).
  6. Validate the five caller profiles under JIT in autogalaxy_workspace_test/.
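The unit tests in step 5 could follow the pattern below: a minimal pytest-style sketch in which toy_concentration is a hypothetical stand-in with a made-up power law (not the Ludlow16 model) that takes the same xp argument proposed for ludlow16_cosmology. The NumPy path serves as the reference, the jax.jit path must match it, and jax.grad must return a finite value.

```python
import jax
import jax.numpy as jnp
import numpy as np


def toy_concentration(mass_at_200, redshift, xp=np):
    """Hypothetical xp-aware stand-in for a JAX-native c(M, z); the power law is made up."""
    return 5.0 * xp.exp(-0.1 * xp.log(mass_at_200 / 1e12)) * (1.0 + redshift) ** -0.5


def test_jit_matches_numpy_path():
    mass, z = 3e12, 0.5
    expected = toy_concentration(mass, z, xp=np)
    jitted = jax.jit(lambda m, zz: toy_concentration(m, zz, xp=jnp))
    np.testing.assert_allclose(float(jitted(mass, z)), expected, rtol=1e-5)


def test_grad_smoke():
    dc_dmass = jax.grad(lambda m: toy_concentration(m, 0.5, xp=jnp))(3e12)
    assert np.isfinite(float(dc_dmass))
    assert float(dc_dmass) < 0.0  # concentration falls with mass in this toy model
```

Swapping toy_concentration for the real function is the only change these tests would need once the chosen approach lands.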

Key Files

  • autogalaxy/profiles/mass/dark/mcr_util.py — the callback and its callers.
  • autogalaxy/cosmology/model.py — already-JAX-native Planck15 (no changes; just confirm we read from here rather than colossus' Planck15).
  • Callers (read-only in Phase 1):
    • autogalaxy/profiles/mass/dark/nfw_mcr.py
    • autogalaxy/profiles/mass/dark/nfw_mcr_scatter.py
    • autogalaxy/profiles/mass/dark/nfw_truncated_mcr_scatter.py
    • autogalaxy/profiles/mass/dark/gnfw_mcr.py
    • autogalaxy/profiles/mass/dark/cnfw_mcr_scatter.py
  • Reference (read-only):
    • colossus/halo/concentration.py::modelLudlow16 (lines 1104-1192)
    • colossus/halo/profile_einasto.py::EinastoProfile.enclosedMassInner
    • colossus/cosmology/cosmology.py::Cosmology.sigma, .growthFactor, .transferFunctionEH98

Original Prompt


Certain dark matter profiles in mcr_util.py (autogalaxy/profiles/mass/dark) use a pure_callback to do a calculation:

```python
def ludlow16_cosmology_jax(
    mass_at_200,
    redshift_object,
    redshift_source,
):
    """
    JAX-safe wrapper around Colossus + Astropy cosmology.
    """
    import jax
    import jax.numpy as jnp
    from jax import ShapeDtypeStruct

    return jax.pure_callback(
        _ludlow16_cosmology_callback,
        (
            ShapeDtypeStruct((), jnp.float64),  # concentration
            ShapeDtypeStruct((), jnp.float64),  # rho_crit(z)
            ShapeDtypeStruct((), jnp.float64),  # Sigma_crit
            ShapeDtypeStruct((), jnp.float64),  # kpc/arcsec
        ),
        mass_at_200,
        redshift_object,
        redshift_source,
        vmap_method="sequential",
    )
```

This is required because the colossus library is not JAX-native, and it is not easy to make it JAX-native.

Can you inspect the colossus source code and work out whether extracting the specific functionality we need and making it JAX-native is feasible? Alternatively, can you assess whether the same calculation can be done using a simpler Python function that doesn't necessarily use the colossus code?
