What I'm trying to do
I'm building a forward modelling pipeline on top of PyAutoLens for dark matter substructure. The forward model is a fairly standard galaxy scale strong lens with a couple of populations of perturbers layered on top:
one macro lens (power law plus external shear) at z_lens = 0.5,
a handful of lens plane subhaloes (typically 5 to 50, NFWTruncatedSph for CDM and WDM, a mix of cored cNFWSph / cNFWMCRLudlowSph and core collapsed NFWTruncatedSph for SIDM),
a multi plane line of sight (LOS) population of order 1000 halos spread across 8 planes between observer and source, built via al.Tracer and the existing LOSSampler helpers,
a Gaussian PSF (al.Kernel2D style), Poisson noise via al.SimulatorImaging, an over sampled al.Grid2D.uniform.
The hot loop is conceptually one function, theta -> noisy image, and the calling code wants to evaluate it of order 10^6 times with different theta. On the CPU this is the wall clock bottleneck of the whole project by a comfortable margin. A jit compiled, ideally vmap batched and GPU runnable, version of that path would change the practical scale of what we can do by a couple of orders of magnitude.
The point of this issue is to ask whether the simulator side specifically can be added to the JAX roadmap, and to flag the bits that this kind of substructure model actually needs so the priority ordering can be informed by a concrete use case.
What already exists (no need to redo any of this)
Skimming the issue trackers, quite a lot of the groundwork is either done or already in flight. Listing what I found so this issue can focus on the gaps.
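PyAutoArray:
use_jax=True flag on Simulators and PointSolver.
Array2D.native jit traceable for the JAX simulator path.
xp interface for user facing introduction.
ShapeSolver and Circle #320: AbstractMeshGeometry picklable, xp module replaced by _use_jax bool plus property.
vmap.
jaxnnls backward pass.
remove preloads #306: 2D regular grid interpolator with JAX path, wires into DatasetInterp.
feature/vectorised triangles #286: propagate xp through Grid2DIrregular derived constructors (called out as the thing that unblocks the source plane JIT).
TransformerNUFFT migration.
PyAutoGalaxy:
self.centre JAX traceable under vmap.
pure_callback with a JAX native implementation.
xp parameter.
EllipseMultipoleScaled JAX traceable (move derivation out of __init__).
print_vram_use fails on lmp.Sersic lenses under a JAX tracer.
AnalysisEllipse for JAX via fit_from and pytree registration.
PyAutoLens:
TracerBoolConversionError under JAX.
AnalysisPoint auto configures PointSolver when use_jax=True.
FitImaging for jax.jit returnable fits (Path A).
FitImaging.
jax.jit(analysis.fit_from) for FitInterferometer plus MGE source.
jax.jit(analysis.fit_from) for FitPointDataset.
jit(fit_from) round trip for double source plane plus rectangular, and Delaunay source plus MGE lens.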
So the big picture is: the inference side (analysis.fit_from) is where most of the open jax.jit work lives at the moment, plus the last few xp propagation fixes around Grid2DIrregular and the CSE port. The bit that doesn't seem to be explicitly tracked yet is the substructure forward simulator path: al.SimulatorImaging.via_tracer_from(tracer=...) for a multi plane Tracer carrying a long, possibly variable length, list of dark matter perturbers.
What I'm asking for
A roadmap covering end to end jax.jit of the substructure forward simulator. Concretely, the following pieces would need to be JAX traceable and jit friendly so that this kind of multi plane substructure simulation can run inside jax.jit.
1. Dark matter mass profiles under convergence_2d_from and deflections_yx_2d_from
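The relevant profiles are:
autogalaxy.profiles.mass.dark.nfw_truncated.NFWTruncatedSph,
autogalaxy.profiles.mass.dark.nfw_truncated_mcr.NFWTruncatedMCRLudlowSph,
autogalaxy.profiles.mass.dark.cnfw.cNFWSph,
autogalaxy.profiles.mass.dark.cnfw_mcr.cNFWMCRLudlowSph (cored SIDM, lens plane).
The closed #403 / #397 already give us a JAX native Ludlow16, so the MCR branch is presumably partway there. What's missing for the substructure path is vmap friendliness across a batch of N halos with different mass_at_200, concentration, centre, f_c, plus the same on the deflection field. Bluntly: vmap(profile.deflections_yx_2d_from) over the halo axis ought to give us the contribution of every halo with a single GPU launch, then a sum gives the plane deflection. Currently this is a Python loop over halos inside Tracer.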
A precondition for the cNFW bits of this is the deflection sign flip and zero convergence bug in autogalaxy.profiles.mass.dark.cnfw that I've reported separately: PyAutoLabs/PyAutoGalaxy#451. That one should land first, otherwise jit'ing the broken branch just makes the broken behaviour faster.
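To make the "vmap over the halo axis, then sum" idea concrete, here is a minimal sketch in plain JAX. It does not use PyAutoGalaxy: point_mass_deflection is a hypothetical stand-in for profile.deflections_yx_2d_from (a toy point mass rather than an NFW or cNFW profile), and the grid, centres and Einstein radii are made-up values.

```python
import jax
import jax.numpy as jnp


def point_mass_deflection(grid, centre, einstein_radius):
    """Deflection (y, x) of a toy point mass at every grid position.

    Hypothetical stand-in for profile.deflections_yx_2d_from; not PyAutoGalaxy code.
    """
    offset = grid - centre                                   # (n_pix, 2)
    r2 = jnp.sum(offset**2, axis=-1, keepdims=True) + 1e-8   # soften r = 0
    return einstein_radius**2 * offset / r2                  # (n_pix, 2)


# Map over the halo axis of the parameters; the grid is shared by all halos.
batched_deflection = jax.vmap(point_mass_deflection, in_axes=(None, 0, 0))


@jax.jit
def plane_deflection(grid, centres, einstein_radii):
    per_halo = batched_deflection(grid, centres, einstein_radii)  # (n_halo, n_pix, 2)
    return per_halo.sum(axis=0)                                    # (n_pix, 2)


axis = jnp.linspace(-1.0, 1.0, 8)
grid = jnp.stack(jnp.meshgrid(axis, axis, indexing="ij"), axis=-1).reshape(-1, 2)
centres = jnp.array([[0.3, -0.2], [-0.5, 0.4], [0.1, 0.7]])
einstein_radii = jnp.array([0.05, 0.02, 0.03])

total = plane_deflection(grid, centres, einstein_radii)
```

The per halo parameters live in stacked arrays rather than in a list of profile objects, which is the main change a real implementation would need on the profile side.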
2. Multi plane Tracer traceable under jit with O(1000) galaxies
A typical realisation builds a Tracer(galaxies=[macro, *subhaloes, *los, source]) with often 1000+ entries. Two things make this jit hostile today:
The galaxy list length is theta dependent (different draws of the SHMF give different N), and jit specialises on shape, so each new N would trigger a recompile.
The multi plane recursion inside Tracer uses Python iteration over per plane galaxy lists.
The usual JAX idioms (pad to a max N, mask the unused slots, jax.lax.scan over planes) would address both. I'd happily prototype this against a stripped down branch if it would help focus the design.
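As a rough illustration of those idioms, the sketch below pads every plane to a fixed N_MAX, masks the unused slots, and runs jax.lax.scan over the planes. Everything in it is an assumption made for illustration: the names (N_MAX, pad_plane, trace, build_planes) are hypothetical, the halos are toy point masses, and the per plane weights are placeholders for the distance ratios of the real multi plane lens equation, so this shows the shape handling rather than the physics.

```python
import jax
import jax.numpy as jnp

N_MAX = 8  # hypothetical cap on halos per plane; a real run would size this from the SHMF


def pad_plane(centres, strengths, n_max=N_MAX):
    """Pad one plane's halo arrays to n_max rows and return a validity mask."""
    n = centres.shape[0]
    centres = jnp.pad(centres, ((0, n_max - n), (0, 0)))
    strengths = jnp.pad(strengths, (0, n_max - n))
    return centres, strengths, jnp.arange(n_max) < n


def plane_deflection(positions, centres, strengths, mask):
    """Summed toy point-mass deflection of one plane; masked slots contribute zero."""
    offset = positions[None, :, :] - centres[:, None, :]      # (n_max, n_pix, 2)
    r2 = jnp.sum(offset**2, axis=-1, keepdims=True) + 1e-8    # soften r = 0
    per_halo = strengths[:, None, None] * offset / r2
    return jnp.sum(jnp.where(mask[:, None, None], per_halo, 0.0), axis=0)


@jax.jit
def trace(grid, centres, strengths, masks, weights):
    """Scan over planes. The per-plane weights are placeholders for distance ratios."""
    def step(positions, plane):
        c, s, m, w = plane
        return positions - w * plane_deflection(positions, c, s, m), None

    final, _ = jax.lax.scan(step, grid, (centres, strengths, masks, weights))
    return final


def build_planes(counts, key):
    """Draw a toy realisation with counts[i] halos on plane i, padded to N_MAX."""
    planes = []
    for i, n in enumerate(counts):
        k_c, k_s = jax.random.split(jax.random.fold_in(key, i))
        centres = jax.random.uniform(k_c, (n, 2), minval=-1.0, maxval=1.0)
        strengths = 1e-3 * jax.random.uniform(k_s, (n,))
        planes.append(pad_plane(centres, strengths))
    return tuple(jnp.stack(arrays) for arrays in zip(*planes))


axis = jnp.linspace(-1.0, 1.0, 4)
grid = jnp.stack(jnp.meshgrid(axis, axis, indexing="ij"), axis=-1).reshape(-1, 2)
weights = jnp.array([0.5, 0.3, 0.2])

# Two realisations with different halo counts share one set of array shapes,
# so the second call reuses the compiled function instead of retracing.
traced_a = trace(grid, *build_planes([3, 5, 2], jax.random.PRNGKey(0)), weights)
traced_b = trace(grid, *build_planes([7, 1, 4], jax.random.PRNGKey(1)), weights)
```

The cost of padding is that every call pays for N_MAX halos per plane, so the cap trades wasted compute against recompiles.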
3. LOSSampler and friends
The sampler itself doesn't need to be jit'd, it's called once per realisation and dominates approximately none of the wall clock. But the output of the sampler currently feeds galaxies straight into the Tracer, so whatever decision is made on padded vs variable length galaxy lists in (2) drives the API here too. The closed #420 (LOS test slimming) suggests this area has been getting some love already.
4. Convolver / PSF convolution under JAX
The forward model uses a Gaussian PSF (al.Kernel2D) plus an over sampled al.Grid2D.uniform. JAX's own jax.scipy.signal.fftconvolve would be the obvious backend on the JAX path, but the wiring isn't there today as far as I can tell (there's no convolver specific issue I could find). Even a small use_jax aware shim around Convolver.convolved_image_from would cover the relevant use case.
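For reference, a minimal sketch of what the JAX side of such a shim could call. It uses only jax.scipy.signal.fftconvolve; the Gaussian kernel is built by hand rather than through al.Kernel2D, and gaussian_kernel and blur are hypothetical names, not existing PyAutoArray functions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve


def gaussian_kernel(size, sigma):
    """Normalised 2D Gaussian kernel on a (size, size) pixel grid, size odd."""
    half = (size - 1) / 2.0
    coords = jnp.arange(size) - half
    y, x = jnp.meshgrid(coords, coords, indexing="ij")
    kernel = jnp.exp(-(x**2 + y**2) / (2.0 * sigma**2))
    return kernel / kernel.sum()


@jax.jit
def blur(image, kernel):
    # mode="same" keeps the image shape; pixels outside the image count as zero.
    return fftconvolve(image, kernel, mode="same")


image = jnp.zeros((32, 32)).at[16, 16].set(100.0)  # single bright pixel
kernel = gaussian_kernel(size=11, sigma=2.0)
blurred = blur(image, kernel)
```

The zero padding at the edges is the part most likely to differ from an existing convolver, so a real shim would need to check the boundary handling against the NumPy path.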
5. Poisson noise with a JAX PRNGKey
al.SimulatorImaging(..., add_poisson_noise=True, noise_seed=...) currently takes a numpy seed. To make theta -> noisy image referentially transparent, jit friendly, and vmap able over a batch of keys, the simulator wants to thread a PRNGKey through instead. jax.random.poisson already exists, so this is fundamentally a plumbing change rather than a numerical one.
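A small sketch of the key threading, assuming a noiseless image in counts per second and a scalar exposure time. add_poisson_noise here is a hypothetical free function, not the al.SimulatorImaging API; only jax.random.PRNGKey, jax.random.split and jax.random.poisson are real JAX calls.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_poisson_noise(key, image, exposure_time):
    """Draw Poisson counts from image * exposure_time and convert back to a rate."""
    counts = jax.random.poisson(key, image * exposure_time)
    return counts / exposure_time


image = jnp.full((16, 16), 2.0)   # toy noiseless image, counts per second
key = jax.random.PRNGKey(42)

# The same key gives the same draw, so the function is referentially transparent.
noisy_a = add_poisson_noise(key, image, 300.0)
noisy_b = add_poisson_noise(key, image, 300.0)

# A batch of independent realisations: split the key and vmap over the key axis.
keys = jax.random.split(key, 8)
batch = jax.vmap(add_poisson_noise, in_axes=(0, None, None))(keys, image, 300.0)
```

Because the key is an explicit argument, the caller decides how realisations are seeded, which is what makes batching over keys straightforward.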
6. Macro lens (mp.PowerLaw plus mp.ExternalShear) and source (lp.SersicCore)
These are simpler than the dark profiles and I'd expect them to be mostly there once the broader mass / light profile JAX work is wrapped up. Flagging here just so the priority list captures them too.
7. (Stretch) vmap over theta for batched evaluation
The headline win is vmap(jit(simulate))(thetas, keys) running of order 1024 lensed images per GPU launch. Everything above is a prerequisite, but it's worth naming the stretch goal explicitly because some of the design choices (especially padded vs ragged galaxy lists in 2) make a much bigger difference once vmap is on the table.
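The call pattern itself is short. The sketch below uses a toy simulate (a Gaussian blob plus Poisson noise) in place of the real lensing forward model; the four element layout of theta, the parameter ranges and the 32 by 32 grid are all assumptions made up for the example.

```python
import jax
import jax.numpy as jnp

AXIS = jnp.linspace(-1.0, 1.0, 32)
GRID = jnp.stack(jnp.meshgrid(AXIS, AXIS, indexing="ij"), axis=-1)  # (32, 32, 2)


def simulate(theta, key):
    """Toy theta -> noisy image: a Gaussian blob plus Poisson noise.

    theta = (centre_y, centre_x, width, amplitude) is a hypothetical layout.
    """
    centre, width, amplitude = theta[:2], theta[2], theta[3]
    r2 = jnp.sum((GRID - centre) ** 2, axis=-1)
    clean = amplitude * jnp.exp(-r2 / (2.0 * width**2))
    return jax.random.poisson(key, clean).astype(jnp.float32)


batched_simulate = jax.vmap(jax.jit(simulate))

theta_key, noise_key = jax.random.split(jax.random.PRNGKey(0))
low = jnp.array([-0.5, -0.5, 0.1, 50.0])
high = jnp.array([0.5, 0.5, 0.3, 200.0])
thetas = jax.random.uniform(theta_key, (1024, 4), minval=low, maxval=high)
keys = jax.random.split(noise_key, 1024)

images = batched_simulate(thetas, keys)  # (1024, 32, 32)
```

This only works if every theta maps to arrays of the same shape, which is why the padded versus ragged decision in item 2 matters so much here.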
Suggested ordering
If it's useful, here's how I'd sequence the chunks above against the in flight work:
Land the cNFW deflection bug fix (separate issue) so the cored branch isn't a moving target.
[...] (xp propagation) and remove preloads #306 (regular grid interpolator JAX path); both look like they unblock several downstream simulator pieces.
vmap friendliness (item 1 above), since that's the bit none of the existing JAX issues explicitly cover.
jit(simulate) smoke test on a representative substructure configuration as a regression target, then the vmap extension (item 7).