feat: end to end jax.jit (and ideally vmap) for a multi plane substructure simulator #542

@mwiet

What I'm trying to do

I'm building a forward modelling pipeline on top of PyAutoLens for dark matter substructure. The forward model is a fairly standard galaxy scale strong lens with a couple of populations of perturbers layered on top:

  • one macro lens (power law plus external shear) at z_lens = 0.5,
  • a handful of lens plane subhaloes (typically 5 to 50, NFWTruncatedSph for CDM and WDM, a mix of cored cNFWSph / cNFWMCRLudlowSph and core collapsed NFWTruncatedSph for SIDM),
  • a multi plane line of sight (LOS) population of order 1000 halos spread across 8 planes between observer and source, built via al.Tracer and the existing LOSSampler helpers,
  • a Gaussian PSF (al.Kernel2D style), Poisson noise via al.SimulatorImaging, an over sampled al.Grid2D.uniform.

The hot loop is conceptually one function, theta -> noisy image, and the calling code wants to evaluate it of order 10^6 times with different theta. On the CPU this is the wall clock bottleneck of the whole project by a comfortable margin. A jit compiled, ideally vmap batched and GPU runnable, version of that path would change the practical scale of what we can do by a couple of orders of magnitude.

The point of this issue is to ask whether the simulator side specifically can be added to the JAX roadmap, and to flag the bits that this kind of substructure model actually needs so the priority ordering can be informed by a concrete use case.

What already exists (no need to redo any of this)

Skimming the issue trackers, quite a lot of the groundwork is either done or already in flight. Listing what I found so this issue can focus on the gaps.

PyAutoArray:

PyAutoGalaxy:

PyAutoLens:

So the big picture is: the inference side (analysis.fit_from) is where most of the open jax.jit work lives at the moment, plus the last few xp propagation fixes around Grid2DIrregular and the CSE port. The bit that doesn't seem to be explicitly tracked yet is the substructure forward simulator path: al.SimulatorImaging.via_tracer_from(tracer=...) for a multi plane Tracer carrying a long, possibly variable length, list of dark matter perturbers.

What I'm asking for

A roadmap covering end to end jax.jit of the substructure forward simulator. Concretely, the following pieces would need to be JAX traceable and jit friendly so that this kind of multi plane substructure simulation can run inside jax.jit.

1. Dark matter mass profiles under convergence_2d_from and deflections_yx_2d_from

The relevant profiles are:

  • autogalaxy.profiles.mass.dark.nfw_truncated.NFWTruncatedSph,
  • autogalaxy.profiles.mass.dark.nfw_truncated_mcr.NFWTruncatedMCRLudlowSph,
  • autogalaxy.profiles.mass.dark.cnfw.cNFWSph,
  • autogalaxy.profiles.mass.dark.cnfw_mcr.cNFWMCRLudlowSph (cored SIDM, lens plane).

The closed #403 / #397 already give us a JAX native Ludlow16, so the MCR branch is presumably partway there. What's missing for the substructure path is vmap friendliness across a batch of N halos with different mass_at_200, concentration, centre, f_c, plus the same on the deflection field. Bluntly: vmap(profile.deflections_yx_2d_from) over the halo axis ought to give us the contribution of every halo with a single GPU launch, then a sum gives the plane deflection. Currently this is a Python loop over halos inside Tracer.
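A minimal sketch of the vmap-then-sum pattern described above. The function toy_deflections_yx is a hypothetical stand-in (a softened point mass), not any of the PyAutoGalaxy NFW profiles, and its signature is an assumption made for illustration only; the point is the batching over the halo axis, not the physics.

```python
import jax
import jax.numpy as jnp


def toy_deflections_yx(grid, centre, einstein_radius):
    """Hypothetical stand-in deflector (softened point mass), NOT an NFW profile."""
    dyx = grid - centre                                      # (n_pix, 2)
    r2 = jnp.sum(dyx**2, axis=-1, keepdims=True) + 1e-6      # softened radius^2
    return einstein_radius**2 * dyx / r2                     # (n_pix, 2)


# Map over the halo axis of `centre` and `einstein_radius`; the grid is shared.
per_halo = jax.vmap(toy_deflections_yx, in_axes=(None, 0, 0))


@jax.jit
def plane_deflections(grid, centres, einstein_radii):
    # (n_halo, n_pix, 2) -> sum over halos -> (n_pix, 2)
    return jnp.sum(per_halo(grid, centres, einstein_radii), axis=0)


axis = jnp.linspace(-1.0, 1.0, 8)
grid = jnp.stack(jnp.meshgrid(axis, axis, indexing="ij"), axis=-1).reshape(-1, 2)
centres = jnp.array([[0.3, -0.2], [-0.5, 0.4], [0.1, 0.1]])
einstein_radii = jnp.array([0.05, 0.02, 0.01])

alpha = plane_deflections(grid, centres, einstein_radii)     # shape (64, 2)
```

The same in_axes pattern would extend to further per-halo parameters (for example concentration or a core fraction) by adding one mapped argument per parameter.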

A precondition for the cNFW bits of this is the deflection sign flip and zero convergence bug in autogalaxy.profiles.mass.dark.cnfw that I've reported separately: PyAutoLabs/PyAutoGalaxy#451. That one should land first, otherwise jit'ing the broken branch just makes the broken behaviour faster.

2. Multi plane Tracer traceable under jit with O(1000) galaxies

A typical realisation builds a Tracer(galaxies=[macro, *subhaloes, *los, source]) with often 1000+ entries. Two things make this jit hostile today:

  • The galaxy list length is theta dependent (different draws of the SHMF give different N), and jit specialises on shape, so each new N would trigger a recompile.
  • The multi plane recursion inside Tracer uses Python iteration over per plane galaxy lists.

The usual JAX idioms (pad to a max N, mask the unused slots, jax.lax.scan over planes) would address both. I'd happily prototype this against a stripped down branch if it would help focus the design.
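To make the pad / mask / scan idiom concrete, here is a hedged sketch. Everything in it is hypothetical: MAX_N, pad_plane, stack_planes, the toy deflector and especially the per plane update x - s * alpha, which is a schematic stand-in and not the full multi plane lens equation. What it does show is that two realisations with different halo counts share one compiled trace once they are padded to the same shape.

```python
import jax
import jax.numpy as jnp

MAX_N = 4  # hypothetical cap on halos per plane


def toy_deflections_yx(grid, centre, einstein_radius):
    """Hypothetical stand-in deflector (softened point mass)."""
    dyx = grid - centre
    r2 = jnp.sum(dyx**2, axis=-1, keepdims=True) + 1e-6
    return einstein_radius**2 * dyx / r2


def pad_plane(centres, radii):
    """Pad one plane to MAX_N halos and return a validity mask (runs outside jit)."""
    n = centres.shape[0]
    pad = MAX_N - n
    return (jnp.pad(centres, ((0, pad), (0, 0))),
            jnp.pad(radii, (0, pad)),
            jnp.arange(MAX_N) < n)


def stack_planes(planes):
    """Stack padded planes into arrays with a leading plane axis."""
    padded = [pad_plane(c, r) for c, r in planes]
    return tuple(jnp.stack(parts) for parts in zip(*padded))


trace_count = 0  # incremented only while JAX traces the function


@jax.jit
def ray_trace(grid, centres, radii, mask, scales):
    global trace_count
    trace_count += 1

    def step(x, plane):
        c, r, m, s = plane
        alpha_all = jax.vmap(toy_deflections_yx, in_axes=(None, 0, 0))(x, c, r)
        # Zero out the padded slots before summing over the halo axis.
        alpha = jnp.sum(jnp.where(m[:, None, None], alpha_all, 0.0), axis=0)
        return x - s * alpha, None  # schematic update, not the real recursion

    x_final, _ = jax.lax.scan(step, grid, (centres, radii, mask, scales))
    return x_final


grid = jnp.array([[0.5, 0.5], [-0.5, 0.25], [0.0, -0.75]])
scales = jnp.array([0.6, 1.0])

# Two realisations with different halo counts per plane.
real_a = [(jnp.array([[0.1, 0.0]]), jnp.array([0.05])),
          (jnp.array([[0.2, 0.2], [-0.3, 0.1], [0.0, 0.4]]),
           jnp.array([0.02, 0.03, 0.01]))]
real_b = [(jnp.array([[0.1, 0.0], [0.3, -0.3]]), jnp.array([0.05, 0.02])),
          (jnp.array([[0.0, 0.4]]), jnp.array([0.01]))]

out_a = ray_trace(grid, *stack_planes(real_a), scales)
out_b = ray_trace(grid, *stack_planes(real_b), scales)  # reuses the first trace
```

The cost of this design is that every plane pays for MAX_N halos regardless of how many are real, so the cap trades wasted compute against recompilation.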

3. LOSSampler and friends

The sampler itself doesn't need to be jit'd: it's called once per realisation and accounts for a negligible share of the wall clock. But the output of the sampler currently feeds galaxies straight into the Tracer, so whatever decision is made on padded vs variable length galaxy lists in (2) drives the API here too. The closed #420 (LOS test slimming) suggests this area has been getting some love already.

4. Convolver / PSF convolution under JAX

The forward model uses a Gaussian PSF (al.Kernel2D) plus an over sampled al.Grid2D.uniform. JAX's own jax.scipy.signal.fftconvolve would be the obvious backend on the JAX path, but the wiring isn't there today as far as I can tell (there's no convolver specific issue I could find). Even a small use_jax aware shim around Convolver.convolved_image_from would cover the relevant use case.
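As a rough illustration of the JAX side of such a shim, the sketch below convolves an image with a normalised Gaussian kernel using jax.scipy.signal.fftconvolve. The helper names gaussian_kernel and convolve_psf are hypothetical, and nothing here touches the actual Convolver or al.Kernel2D classes; how the kernel array would be extracted from them is left open.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve


def gaussian_kernel(size, sigma):
    """Normalised 2D Gaussian on a size x size pixel grid (hypothetical helper)."""
    ax = jnp.arange(size) - (size - 1) / 2.0
    yy, xx = jnp.meshgrid(ax, ax, indexing="ij")
    kernel = jnp.exp(-(yy**2 + xx**2) / (2.0 * sigma**2))
    return kernel / jnp.sum(kernel)


@jax.jit
def convolve_psf(image, kernel):
    # mode="same" keeps the output on the image's pixel grid.
    return fftconvolve(image, kernel, mode="same")


image = jnp.zeros((32, 32)).at[16, 16].set(1.0)   # single bright pixel
kernel = gaussian_kernel(size=7, sigma=1.5)
blurred = convolve_psf(image, kernel)             # shape (32, 32)
```

Because the kernel is normalised and the bright pixel sits well inside the frame, the blurred image conserves total flux up to floating point error.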

5. Poisson noise with a JAX PRNGKey

al.SimulatorImaging(..., add_poisson_noise=True, noise_seed=...) currently takes a numpy seed. To make theta -> noisy image referentially transparent, jit friendly, and vmap able over a batch of keys, the simulator wants to thread a PRNGKey through instead. jax.random.poisson already exists, so this is fundamentally a plumbing change rather than a numerical one.
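A small sketch of what threading a key could look like on the JAX side. The function add_poisson_noise and its exposure_time argument are assumptions for illustration and do not mirror the al.SimulatorImaging signature; only jax.random.poisson, jax.random.split and jax.vmap are real JAX calls.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_poisson_noise(key, image, exposure_time):
    """Draw Poisson counts for an image in counts/s, return it in counts/s."""
    counts = jax.random.poisson(key, image * exposure_time)
    return counts.astype(image.dtype) / exposure_time


image = jnp.full((16, 16), 2.0)                      # noiseless image
keys = jax.random.split(jax.random.PRNGKey(0), 4)    # one key per realisation

# Batch over keys only; the image and exposure time are shared.
noisy_batch = jax.vmap(add_poisson_noise, in_axes=(0, None, None))(
    keys, image, 300.0
)                                                    # shape (4, 16, 16)
```

With the key as an explicit argument, the same (key, image) pair always gives the same noisy image, which is what makes the function safe to jit and to batch.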

6. Macro lens (mp.PowerLaw plus mp.ExternalShear) and source (lp.SersicCore)

These are simpler than the dark profiles and I'd expect them to be mostly there once the broader mass / light profile JAX work is wrapped up. Flagging here just so the priority list captures them too.

7. (Stretch) vmap over theta for batched evaluation

The headline win is vmap(jit(simulate))(thetas, keys) running of order 1024 lensed images per GPU launch. Everything above is a prerequisite, but it's worth naming the stretch goal explicitly because some of the design choices (especially padded vs ragged galaxy lists in 2) make a much bigger difference once vmap is on the table.
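The call shape vmap(jit(simulate))(thetas, keys) can be tried on a toy model. In the sketch below simulate is a hypothetical placeholder that renders a Gaussian blob and adds Poisson noise; it involves no lensing and no PyAutoLens objects, and the layout of theta (centre y, centre x, intensity, sigma) is an assumption made only for this example.

```python
import jax
import jax.numpy as jnp

_axis = jnp.linspace(-1.0, 1.0, 16)
GRID = jnp.stack(jnp.meshgrid(_axis, _axis, indexing="ij"), axis=-1)  # (16, 16, 2)


def simulate(theta, key):
    """Toy forward model: theta = (centre_y, centre_x, intensity, sigma)."""
    centre, intensity, sigma = theta[:2], theta[2], theta[3]
    r2 = jnp.sum((GRID - centre) ** 2, axis=-1)
    expected_counts = intensity * jnp.exp(-r2 / (2.0 * sigma**2))
    return jax.random.poisson(key, expected_counts).astype(jnp.float32)


batched_simulate = jax.vmap(jax.jit(simulate))

thetas = jnp.array([[0.0, 0.0, 50.0, 0.3],
                    [0.2, -0.1, 80.0, 0.2],
                    [-0.3, 0.3, 20.0, 0.5]])
keys = jax.random.split(jax.random.PRNGKey(42), thetas.shape[0])

images = batched_simulate(thetas, keys)   # shape (3, 16, 16)
```

This only works because every theta produces arrays of the same shape, which is why the padded vs ragged decision in item 2 matters so much for the batched case.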

Suggested ordering

If it's useful, here's how I'd sequence the chunks above against the in flight work:

  1. Land the cNFW deflection bug fix (separate issue) so the cored branch isn't a moving target.
  2. Land "feature/vectorised triangles" #286 (Grid2DIrregular xp propagation) and "remove preloads" #306 (regular grid interpolator JAX path); both look like they unblock several downstream simulator pieces.
  3. Open a new issue for the dark profile vmap friendliness (item 1 above), since that's the bit none of the existing JAX issues explicitly cover.
  4. Decide on the padded vs variable length galaxy list strategy (item 2). This is the single design decision that drives most of the rest.
  5. Convolver and Poisson noise on the JAX path (items 4 and 5).
  6. End to end jit(simulate) smoke test on a representative substructure configuration as a regression target, then the vmap extension (item 7).
