Overview
Production lens-modelling runs nested sampling, which proposes parameter
vectors in batches and so always goes through the JAX likelihood under vmap (batch=3 today). PR #60 in autolens_workspace_developer made the
new bottleneck picture concrete at the production-fiducial setup
(MGE-60 lens + pixelized source, K ≈ 1285–1291): A100 Delaunay
vmap-per-call is 438 ms, an 8.84× regression vs the 49.6 ms
single-JIT timing. Per-step probes show jaxnnls's PDIP solver is the
universal vmap bottleneck, while the previously-suspected scipy.spatial.Delaunay pure_callback actually scales sublinearly
under vmap. The headline science improvement is hitting the ≤ 200 ms
Delaunay A100 vmap target, a 2× production speedup on the hot path.
This task is more than a MAX_ITER sweep. NNLS is structurally required
for the physics, but the algorithm is not pinned to PDIP. The plan
opens with an algorithmic survey (PDIP, ADMM, FISTA, Schur-complement),
resolves the unexplained 7× A100 vmap-regress factor before structural
changes, sweeps iter/tolerance on whichever algorithm wins, vendors it
into autoarray/inversion/nnls/ to drop the external jaxnnls
dependency, and validates the change across all 6 imaging +
interferometer production pipelines plus the well/ill-conditioned
matrix in autolens_workspace_test/scripts/jax_assertions/nnls.py.
Plan
1. Algorithmic survey: prototype PDIP / ADMM / FISTA on synthetic
production-shaped Q, time under single + vmap, and pick the winner based
on wall-clock × convergence × numerical stability.
2. Resolve the 8.84× A100 mystery first by porting the per-step vmap
probe to A100. The ~7× unexplained factor decides whether the lever
lives in the NNLS algorithm or in XLA scheduling / memory layout.
3. Sweep iter / tolerance on the chosen algorithm at production K
against the full cross-pipeline matrix, not just HST.
4. Audit whether the relaxed-KKT backward pass is actually live in
production sampling; if not, expose a forward-only entry point.
5. Vendor the chosen NNLS into autoarray/inversion/nnls/ with full
attribution to Tracy (qpax) and Krawczyk (jaxnnls modifications),
drop jaxnnls from pyproject.toml, and switch inversion_util.reconstruction_positive_only_from to the vendored
module.
6. Fix the lying docstring in reconstruction_positive_only_from
(it currently claims fnnls / Bro & De Jong, but the JAX branch actually calls jaxnnls
PDIP).
7. Refresh regression artifacts in autolens_workspace_developer across
all 6 imaging + interferometer pipelines.
8. Land as a coordinated trio of PRs: library first, then workspace
artifact refresh + test-script update.
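To make item 1 concrete, here is a minimal NumPy/SciPy sketch of a textbook primal-dual interior-point method for the NNLS quadratic program min 0.5 xᵀQx − qᵀx subject to x ≥ 0. This is the (Q, q) form the plan assumes for solve_nnls. It is not jaxnnls's PDIP, which descends from qpax and adds a relaxed-KKT backward pass. The fixed centring parameter sigma, the all-ones starting point and the name pdip_nnls are illustrative choices. What the sketch does show is the structural cost the survey is weighing: every iteration builds and factorises a fresh K×K matrix Q + diag(s/x). At K ≈ 1291 that means one O(K³) Cholesky per iteration per batch element.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve
from scipy.optimize import nnls


def pdip_nnls(Q, q, max_iter=50, tol=1e-10, sigma=0.1):
    """Textbook primal-dual interior point for  min 0.5 x^T Q x - q^T x  s.t. x >= 0.

    KKT conditions: Q x - q - s = 0,  x * s = 0,  x >= 0,  s >= 0,
    where s is the dual variable for the bound x >= 0.
    """
    K = q.shape[0]
    x = np.ones(K)
    s = np.ones(K)
    for _ in range(max_iter):
        r_dual = Q @ x - q - s          # stationarity residual
        mu = x @ s / K                  # duality measure (average complementarity)
        if np.linalg.norm(r_dual) < tol and mu < tol:
            break
        # Newton step on the centred KKT system with ds eliminated:
        # (Q + diag(s / x)) dx = -r_dual + (sigma * mu - x * s) / x
        H = Q + np.diag(s / x)
        rhs = -r_dual + (sigma * mu - x * s) / x
        dx = cho_solve(cho_factor(H), rhs)   # a fresh K x K Cholesky every iteration
        ds = (sigma * mu - x * s - s * dx) / x
        # Fraction-to-boundary rule keeps x and s strictly positive
        alpha = 1.0
        for v, dv in ((x, dx), (s, ds)):
            shrinking = dv < 0
            if np.any(shrinking):
                alpha = min(alpha, 0.99 * np.min(-v[shrinking] / dv[shrinking]))
        x = x + alpha * dx
        s = s + alpha * ds
    return x


# Parity check against scipy's active-set NNLS on a small synthetic problem
rng = np.random.default_rng(0)
A = rng.standard_normal((40, 20))
b = rng.standard_normal(40)
x_pdip = pdip_nnls(A.T @ A, A.T @ b, max_iter=100)
x_ref, _ = nnls(A, b)
print("max |x_pdip - x_scipy| =", np.max(np.abs(x_pdip - x_ref)))
```

The scipy.optimize.nnls comparison at the bottom has the same shape as the planned scipy-NNLS parity test at small K.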
Detailed implementation plan
Work Classification
Library (PyAutoArray is primary; workspace companions are artifact and
test-script edits, no production-script logic changes).
Branch Survey

| Repository | Current Branch | Dirty? |
|---|---|---|
| ./PyAutoArray | main | clean |
| ./autolens_workspace_developer | main | dirty (unrelated euclid_bug/ work; will not be picked up by the worktree branch) |
| ./autolens_workspace_test | main | (to be verified at /start_library time) |

Suggested branch: feature/nnls-vmap-speedup
Worktree root: ~/Code/PyAutoLabs-wt/nnls-vmap-speedup/ (created later by /start_library)
Implementation Steps
1. Algorithmic survey in z_projects/profiling/scripts/nnls_prototypes/. Prototype pdip_baseline.py, admm_prototype.py, fista_prototype.py and an nnls_bench.py that times all candidates on synthetic Q matching production shape/conditioning under single-JIT + vmap=3. Decide based on whether a candidate (1) hits the Delaunay A100 ≤ 200 ms target, (2) holds log-evidence to rtol=1e-4 across the cross-pipeline matrix, (3) produces no NaN gradients across condition classes, and (4) has the smallest diff against vendored jaxnnls.
2. A100 vmap probe. Add submit_{delaunay,rectangular}_vmap_probe_fp64 SLURM scripts, then hpc/sync push, submit, pull, and inspect per-step ratios. If A100 NNLS-vmap is ~1.27× like RTX 2060, the 7× lives in XLA scheduling; diagnose via jax.profiler.start_trace and memory_analysis().
3. Iter/tolerance sweep on the chosen algorithm. PDIP: MAX_ITER ∈ {10, 15, 20, 30, 50}. ADMM: (ρ, max_iter) over {1e-2..10} × {30, 50, 100, 200}. FISTA: max_iter ∈ {100, 200, 500, 1000}. Record wall-clock + Δlog-evidence + reconstruction-vector L2 vs main across all 6 production pipelines.
4. Gradient infrastructure audit. Dump compiled HLO from jax.jit(jax.vmap(likelihood)) and grep for pdip_relaxed / diff_nnls. If absent, expose solve_nnls_forward(Q, q) returning just the primal and skipping custom_vjp. Switch inversion_util to the forward-only entry under an nnls_forward_only config flag.
5. Vendor the chosen algorithm into PyAutoArray/autoarray/inversion/nnls/:
   - Algorithm files copied verbatim from jaxnnls 1.0.1 (PDIP path) plus any new files (ADMM/FISTA).
   - __init__.py re-exports solve_nnls, solve_nnls_primal, and possibly solve_nnls_forward.
   - LICENSE.txt with Tracy/Krawczyk attribution.
   - Module-scope constants (MAX_ITER, etc.) replaced with autoconf reads at the entry function level (not inside jax.jit-traced loops).
   - Tests at test_autoarray/inversion/inversion/nnls/test_solver.py: scipy-NNLS parity at K=20, primal-forward parity, max_iter config honoured, well-conditioned K=1000, ill-conditioned K=1000 with Jacobi preconditioning (no NaN).
6. Switch the call site in inversion_util.py:275-321 from import jaxnnls to from autoarray.inversion import nnls. Fix the docstring at lines 241-273 to describe both the numpy fnnls and JAX PDIP/ADMM/FISTA branches accurately.
7. Config in autoarray/config/general.yaml: keep the existing nnls_jacobi_preconditioning and nnls_target_kappa; add nnls_algorithm, nnls_max_iter, nnls_solver_tol, and nnls_forward_only (if Step 4 produced such a switch).
8. Drop jaxnnls from pyproject.toml.
9. Cross-pipeline correctness matrix, local RTX 2060 fp64 first, then A100 via HPC: imaging-mge, imaging-pixelization (HST 24746.105672366088), imaging-delaunay (HST 26288.321397232066), interferometer-mge, interferometer-pixelization (SMA -3165.251161569041), interferometer-delaunay. Plus the autolens_workspace_test/scripts/jax_assertions/nnls.py well-conditioned match and ill-conditioned finite-gradient checks.
10. Refresh artifacts in autolens_workspace_developer/jax_profiling/results/jit/{imaging,interferometer}/{mge,pixelization,delaunay}/ and update EXPECTED_LOG_EVIDENCE_* constants if drift is detected (artifact mtimes advance regardless).
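For comparison with PDIP, here is a minimal NumPy/SciPy sketch of the ADMM candidate from steps 1 and 3, assuming the same (Q, q) objective. The split x = z, z ≥ 0 means Q + ρI is factorised once per solve. Every iteration after that costs only two triangular solves plus a projection, which is the property that might make ADMM friendlier under vmap. ADMM typically needs more iterations than PDIP, and whether the cheaper iterations outweigh that at K ≈ 1291 on an A100 is exactly what nnls_bench.py is meant to measure. The function name admm_nnls is hypothetical. The ρ = √(λ_min λ_max) choice in the usage lines is a commonly used heuristic for quadratic problems, not a tuned value; the (ρ, max_iter) grid in step 3 is where ρ would actually be chosen.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve
from scipy.optimize import nnls


def admm_nnls(Q, q, rho=1.0, max_iter=200, tol=1e-9):
    """ADMM for  min 0.5 x^T Q x - q^T x  s.t. x >= 0, using the split x = z with z >= 0."""
    K = q.shape[0]
    chol = cho_factor(Q + rho * np.eye(K))   # factorised ONCE per solve, reused below
    z = np.zeros(K)
    u = np.zeros(K)                           # scaled dual variable for x = z
    for _ in range(max_iter):
        x = cho_solve(chol, q + rho * (z - u))   # x-update: two triangular solves, O(K^2)
        z_prev = z
        z = np.maximum(0.0, x + u)               # z-update: projection onto x >= 0
        u = u + x - z                            # dual update on the x = z constraint
        primal_res = np.linalg.norm(x - z)
        dual_res = rho * np.linalg.norm(z - z_prev)
        if primal_res < tol and dual_res < tol:
            break
    return z


# Parity check against scipy's active-set NNLS on a small synthetic problem
rng = np.random.default_rng(1)
A = rng.standard_normal((60, 30))
b = rng.standard_normal(60)
Q, q = A.T @ A, A.T @ b
eigs = np.linalg.eigvalsh(Q)                  # ascending eigenvalues of Q
rho = float(np.sqrt(eigs[0] * eigs[-1]))      # heuristic step size; step 3 would sweep it
x_admm = admm_nnls(Q, q, rho=rho, max_iter=1000)
x_ref, _ = nnls(A, b)
print("max |x_admm - x_scipy| =", np.max(np.abs(x_admm - x_ref)))
```

Because z is a projection, the returned vector is exactly non-negative even when the loop stops early at a small max_iter, which matters for the low-iteration corners of the sweep.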
Validation gates (PR cannot land without all)
All 6 EXPECTED_LOG_EVIDENCE_* assertions pass at rtol=1e-4 fp64 / 1e-2 mp.
Delaunay A100 vmap_per_call ≤ 200 ms (conservative; headline target ≤ 100 ms). Rect A100 vmap_per_call ≤ 25 ms.
autolens_workspace_test/scripts/jax_assertions/nnls.py passes well-conditioned-match and ill-conditioned-finite-gradient checks.
pytest test_autoarray/inversion/ clean.
After pip uninstall jaxnnls, from autoarray.inversion import nnls; nnls.solve_nnls_primal(...) still works.
Backward differentiation through solve_nnls_primal on the ill-conditioned case from jax_assertions/nnls.py returns finite gradients.
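Several gates lean on the ill-conditioned case, and the existing nnls_jacobi_preconditioning knob suggests diagonal scaling is already part of the story. The sketch below shows standard symmetric Jacobi scaling and why it is safe for NNLS specifically. With D = diag(Q), substituting x = D^(-1/2) y keeps the constraint set intact, because D^(-1/2) is a positive diagonal matrix and so y ≥ 0 exactly when x ≥ 0. This assumes the autoarray flag does something along these lines. The helper jacobi_scale and the synthetic column scaling are illustrative only.

```python
import numpy as np
from scipy.optimize import nnls


def jacobi_scale(Q, q):
    """Symmetric Jacobi scaling: Q_s = D^-1/2 Q D^-1/2, q_s = D^-1/2 q with D = diag(Q).

    x = D^-1/2 y maps y >= 0 onto x >= 0, so the scaled problem is still NNLS
    and its solution maps back to the original one.
    """
    d_inv_sqrt = 1.0 / np.sqrt(np.diag(Q))
    Q_s = Q * np.outer(d_inv_sqrt, d_inv_sqrt)
    q_s = q * d_inv_sqrt
    return Q_s, q_s, d_inv_sqrt


# Synthetic ill-conditioned problem: column scales spanning six orders of magnitude
rng = np.random.default_rng(2)
A = rng.standard_normal((80, 25)) * np.logspace(-3, 3, 25)
b = rng.standard_normal(80)
Q, q = A.T @ A, A.T @ b

Q_s, q_s, d_inv_sqrt = jacobi_scale(Q, q)
print(f"cond(Q) = {np.linalg.cond(Q):.1e}, cond(Q_s) = {np.linalg.cond(Q_s):.1e}")

# Solve the scaled problem in least-squares form (A_s = A D^-1/2), then map back
y, _ = nnls(A * d_inv_sqrt, b)
x_precond = d_inv_sqrt * y
x_direct, _ = nnls(A, b)
```

Because diag(AᵀA) holds the squared column norms, the least-squares form of the scaling is simply normalising each column of A. That is why A * d_inv_sqrt is enough here.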
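Key Files
PyAutoArray/autoarray/inversion/nnls/ (new directory; vendored solver)
PyAutoArray/autoarray/inversion/inversion/inversion_util.py:275-321 (call site + docstring)
PyAutoArray/autoarray/config/general.yaml:9-10 (config knobs)
PyAutoArray/pyproject.toml (drop jaxnnls dep)
PyAutoArray/test_autoarray/inversion/inversion/nnls/test_solver.py (new; unit tests)
autolens_workspace_developer/jax_profiling/jit/{imaging,interferometer}/{mge,pixelization,delaunay}.py (EXPECTED_LOG_EVIDENCE_* if drift detected)
autolens_workspace_test/scripts/jax_assertions/nnls.py (import switch + matrix extension)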
Original Prompt
Speed up the JAX NNLS solver under vmap at production-fiducial size
This supersedes autoarray/nnls_gpu_bottleneck.md; that earlier prompt
was written from MGE-only profiling at a 40-element problem and prioritised
single-JIT GPU performance. Production runs use vmap (nested-sampling
proposes batches of parameter vectors) at much larger problem sizes, and
PR #60 (autolens_workspace_developer) makes the new bottleneck picture
concrete.
The earlier prompt's 5-lever structure (MAX_ITER, Cholesky vectorisation,
gradient-infrastructure check, K≤N closed-form bypass, vendor jaxnnls) is
all still valid — keep it as the "how to attack" reference. This prompt
re-frames "what to attack first" given the new vmap-relevant data.
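What the latest profiling found (PR #60)
Production-fiducial setup:
Lens: MGE-60 lens light (the production setup, not single-Sersic)
Delaunay source mesh: 39×39 overlay → 1231 vertices
Combined K (NNLS dimension) = ~1285 for rectangular, ~1291 for Delaunay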
A100 fp64 timings at this fiducial:

| Likelihood | Single-JIT full | vmap (batch=3) per call | vmap regress |
|---|---|---|---|
| Rectangular | 25.1 ms | 34.8 ms | 1.39× |
| Delaunay | 49.6 ms | 438 ms | 8.84× |
The Delaunay-vmap regression is the science-throughput wall. RTX 2060
shows the same pattern less dramatically (Delaunay vmap 590 → 1217 ms mp).
Per-step vmap probes confirmed scipy is NOT the issue
Earlier rounds suspected scipy.spatial.Delaunay (inside jax.pure_callback
with vmap_method="sequential") was the problem. Per-step vmap probes
on RTX 2060 (committed in z_projects/profiling/scripts/{delaunay,rectangular}_vmap_probe.py) flip the story:
| Step | single (ms) | vmap (ms/call) | ratio |
|---|---|---|---|
| Inversion setup (incl. scipy pure_callback) | 127 | 108 | 0.85× |
| NNLS reconstruction | 111 | 141 | 1.27× |
| Mapped + log_ev (slogdet) | 33 | 32 | 0.98× |
pure_callback actually scales sublinearly under vmap — its dispatch
overhead amortises across batch elements. NNLS scales superlinearly
because its iterations are inherently serial per batch element.
So: NNLS is the universal vmap target. Pure-JAX-Delaunay (the previous
"long-term win") is no longer the highest-value lever.
Cross-likelihood comparison (where NNLS sits in each)
The three imaging-family likelihoods all use jaxnnls for the reconstruction_positive_only_from step. Their setups differ in
problem size K and in what other steps share the budget.
Production-fiducial summary (all from PR #56 / PR #60 measurements):
*MGE PR #56 did not separate "lens light" from "source MGE" — both galaxies
use MGE bulges (20 Gaussians each = 40 total linear obj). It's the small-K
reference but not directly comparable to the rect/Delaunay production setup.
A100 fp64 timings (latest measurements per PR):

| Config | MGE (K=40) | Rectangular (K=1285) | Delaunay (K=1291) |
|---|---|---|---|
| Full pipeline single | 5.7 ms | 25.1 ms | 49.6 ms |
| Full pipeline vmap (per call) | 2.4 ms | 34.8 ms | 438 ms |
| NNLS step single | 2.0 ms | ~16 ms | ~16 ms |
| NNLS share of single | 35% | 64% | 31% |
| vmap regress vs single | 0.4× (faster!) | 1.39× | 8.84× |
RTX 2060 fp64 timings:

| Config | MGE | Rectangular | Delaunay |
|---|---|---|---|
| Full pipeline single | 43.7 ms | 537 ms | 590 ms |
| Full pipeline vmap | 23.9 ms | 567 ms | 954 ms |
| NNLS step single | ~12 ms | ~125 ms | ~110 ms |
| vmap regress | 0.55× | 1.06× | 1.62× |
The pattern:
- MGE benefits from vmap (small K, kernel-launch-bound NNLS; see the prior
prompt's analysis). Lever 1 (MAX_ITER) maps to "fewer launches".
- Rectangular at production K barely regresses under vmap (1.39× A100,
1.06× RTX 2060). NNLS dominates the budget but it's compute-bound.
- Delaunay at production K explodes under vmap (8.84× A100). The
combined K=1291 with the linear lens MGE columns hits something nasty
in NNLS-under-vmap that doesn't appear at smaller K. This is the
unexplained 7× factor that needs diagnosis (see the "Open question"
section below).
If the optimisation work successfully fixes NNLS under vmap, all three
likelihoods benefit: MGE keeps its existing "vmap is faster" status,
rectangular drops vmap regress to ~1.0×, and Delaunay hopefully drops
toward 2× or better (still won't beat single-JIT due to the inherent
batch-3 overhead, but 438 ms → 100 ms would be a 4× production speedup).
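What to measure (updated regression baselines)
Use the canonical references at:
autolens_workspace_developer/jax_profiling/jit/imaging/pixelization.py
autolens_workspace_developer/jax_profiling/jit/imaging/delaunay.py
Both now ship at the MGE-60 fiducial (PR #60). Their EXPECTED_LOG_EVIDENCE_HST constants:
Rectangular: 24746.105672366088
Delaunay: 26288.321397232066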
Track these on every NNLS variant — log-evidence drift > rtol=1e-4 means
the optimisation broke something.
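It helps to see what rtol=1e-4 means in absolute terms at these magnitudes. The dictionary below is a hypothetical stand-in: the real values live as EXPECTED_LOG_EVIDENCE_HST constants in the workspace scripts. The check is a plain relative tolerance, without numpy.isclose's small absolute floor.

```python
# Hypothetical stand-in for the per-script EXPECTED_LOG_EVIDENCE_HST constants
EXPECTED_LOG_EVIDENCE_HST = {
    "rectangular": 24746.105672366088,
    "delaunay": 26288.321397232066,
}


def within_rtol(measured, expected, rtol=1e-4):
    """Pure relative check: |measured - expected| <= rtol * |expected| (no absolute floor)."""
    return abs(measured - expected) <= rtol * abs(expected)


for name, expected in EXPECTED_LOG_EVIDENCE_HST.items():
    for label, rtol in (("fp64", 1e-4), ("mp", 1e-2)):
        allowed = rtol * abs(expected)
        print(f"{name:<12} {label:<4} rtol={rtol:g}: allowed |dlogE| = {allowed:.3f}")
```

At ~2.5 × 10⁴, the fp64 gate tolerates a log-evidence shift of about 2.5, and the mp gate about 250. If the 1e-3 Δlog-evidence inflection used for the MAX_ITER lever is read as an absolute difference, it is by far the stricter of the two checks.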
For wall-clock targets, the per-config profiles in z_projects/profiling/scripts/{pixelization,delaunay}_profile.py give
side-by-side comparisons. Specific targets:
| Config | Current | Target | Rationale |
|---|---|---|---|
| Rect A100 fp64 vmap-per-call | 34.8 ms | ≤ 25 ms | Within 1.0× of single-JIT (currently 1.39×) |
| Delaunay A100 fp64 vmap-per-call | 438 ms | ≤ 100 ms | Within 2.0× of single-JIT (currently 8.84×) |
| Rect RTX 2060 fp64 vmap | 567 ms | ≤ 540 ms | Within 1.0× of single-JIT |
| Delaunay RTX 2060 fp64 vmap | 954 ms | ≤ 700 ms | NNLS share is biggest regress vector |
The Delaunay A100 vmap target is the headline science improvement —
4.4× speedup on production hot path. Anything that hits 200 ms (2× from
current) is already meaningful.
What's likely to work (priority order, post-PR-#60 evidence)
1. MAX_ITER reduction (still the obvious first lever). The previous
prompt's analysis of pdip.py:MAX_ITER = 50 → ~5 kernel launches/iter → 20 ms was at K=40. At K=1285 each Cholesky kernel does
real work, so the kernel-launch model is less dominant. But fewer
iterations still help, both at the launch level AND at the linalg
level. Test MAX_ITER ∈ {10, 15, 20, 30} against the rectangular and
Delaunay log-evidence. Look for the inflection where Δlog-evidence
exceeds 1e-3.
2. Audit the gradient infrastructure (lever 3 from prior prompt).
For sampler use we don't need NNLS gradients, only the forward
solve. Confirm solve_nnls isn't dragging pdip_relaxed.py machinery
through the JIT trace. If it is, splitting forward/backward at the
API level should give a clean speedup with no numerical change.
3. The vmap pathology specifically. This is the new highest-value
item, not in the prior prompt. The 8.8× Delaunay vmap regression vs
the 1.27× per-step probe ratio means the full-pipeline vmap is hitting
something worse than just NNLS×3-batch. Hypotheses worth testing:
   - PDIP's per-iter Cholesky on (1291, 1291) under vmap might be
     causing XLA to materialise (3, 1291, 1291, 1291) workspace tensors
     that don't fit nicely. Use jax.jit(jax.vmap(...)).lower(...).compile().memory_analysis() to compare.
   - The pure_callback for Delaunay triangulation runs sequentially
     (3 calls), but XLA may stall on the pure_callback's host
     synchronisation. Try removing the callback's vmap_method
     constraint, or bundling all 3 mesh grids into a single callback
     call.
   - Profile the full vmap path with jax.profiler.start_trace(...) to
     see where the wall-time actually goes; the 8.8× factor over the
     per-step decomposition is unexplained.
4. Pallas kernel for the per-iter Cholesky (lever 2 from prior).
At K=1285 a jax.lax.linalg.cholesky is real work (not launch-bound),
but a fused triangular factorisation could still help. Lower-priority
than items 1-3 because it's invasive and only nibbles at the cost.
5. Vendor jaxnnls (lever 5 from prior). Last resort if items 1-4
produce code changes worth maintaining.
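A quick size check helps frame the first hypothesis under item 3. Assuming fp64 (8 bytes per element) and the Delaunay K = 1291 at batch=3, the snippet compares two tensors: the batched Cholesky factor that PDIP legitimately needs, and the (3, 1291, 1291, 1291) workspace the hypothesis describes. The helper name tensor_gib is illustrative.

```python
FP64_BYTES = 8


def tensor_gib(shape, itemsize=FP64_BYTES):
    """Size in GiB of a dense array of the given shape (fp64 by default)."""
    n_elements = 1
    for dim in shape:
        n_elements *= dim
    return n_elements * itemsize / 2**30


BATCH, K = 3, 1291
factor = tensor_gib((BATCH, K, K))          # one K x K Cholesky factor per batch element
workspace = tensor_gib((BATCH, K, K, K))    # the shape named in the hypothesis
print(f"batched Cholesky factor: {factor:.3f} GiB")
print(f"hypothesised workspace:  {workspace:.1f} GiB")
```

The batched factor is about 0.04 GiB. The hypothesised workspace would be about 48 GiB, on the order of an A100's entire device memory (40 or 80 GB depending on the variant). If memory_analysis() reports temporaries anywhere near that scale, that would strongly support the hypothesis. If it reports tens of megabytes, the explanation probably lies elsewhere.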
What to NOT spend time on (deprioritised by PR #60)
- Pure-JAX Delaunay triangulation. This was previously framed as
the long-term win for unblocking A100. Per-step probes show the
pure_callback already scales sublinearly under vmap, so it's not the
vmap bottleneck. The single-JIT speedup it would provide is small
relative to the NNLS opportunity.
- vmap=8 batching tuning. PR #60's vmap regression measurements
use batch=3. Larger batch sizes are the sampler's lever, not NNLS's;
that is a sampler concern (PyAutoFit), not PyAutoArray.
How to validate (full operational runbook)
Local-only validation (RTX 2060 + CPU)
1. Apply the NNLS change. If you're testing in-place, edit /home/jammy/venv/PyAutoGPU/lib/python3.10/site-packages/jaxnnls/pdip.py
directly. If you're testing a vendored copy, use from autoarray.inversion.nnls import solve_nnls
in inversion_util.py. Either way, pip show jaxnnls confirms which
version is installed.
2. Activate venv + worktree (or canonical), in this order:
```bash
source /home/jammy/venv/PyAutoGPU/bin/activate
source /home/jammy/Code/PyAutoLabs-wt/<your-task>/activate.sh  # if worktreed
```
The PyAutoGPU venv has Python 3.10 + JAX-CUDA12 (the default python
may resolve to the CPU-only PyAuto venv, so activate the GPU venv explicitly first).
3. Run the per-config profilers from anywhere (they take absolute paths):
```bash
# GPU fp64: the headline config
PYTHONUNBUFFERED=1 python -u \
  /home/jammy/Code/PyAutoLabs/z_projects/profiling/scripts/pixelization_profile.py \
  --config-name local_gpu_fp64 \
  --output-dir "$PYAUTO_ROOT/autolens_workspace_developer/jax_profiling/results/jit/imaging/pixelization"
PYTHONUNBUFFERED=1 python -u \
  /home/jammy/Code/PyAutoLabs/z_projects/profiling/scripts/delaunay_profile.py \
  --config-name local_gpu_fp64 \
  --output-dir "$PYAUTO_ROOT/autolens_workspace_developer/jax_profiling/results/jit/imaging/delaunay"

# mp variants:
#   ... add --use-mixed-precision --config-name local_gpu_mp ...

# CPU variants. GOTCHA: use JAX_PLATFORM_NAME=cpu (NOT JAX_PLATFORMS=cpu;
# JAX 0.4.38 has a bug with the new env var that errors with
# "Unknown backend cuda" because pre-existing CUDA arrays from
# register_model can't move).
PYTHONUNBUFFERED=1 JAX_PLATFORM_NAME=cpu \
  NUMBA_CACHE_DIR=/tmp/numba_cache_cpu_$$ \
  python -u .../pixelization_profile.py --config-name local_cpu_fp64 ...
```
4. Per-step vmap probe (the diagnostic that lets you see WHERE the NNLS
speedup lands). Each probe times the inversion-setup, NNLS, and log-ev steps both
single and under vmap=3, and prints per-step vmap/single ratios.
For a useful NNLS speedup you want the NNLS ratio to drop from
1.27× toward 1.0×.
5. Confirm the EXPECTED_LOG_EVIDENCE_HST assertion still passes: every
per-config profiler asserts against the constant on every config
that runs.
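Every number in this prompt is a ratio of per-call timings. The sketch below shows one way to take such a measurement. It assumes that "vmap per call" means the wall-clock of one vmapped call divided by the batch size, which is consistent with how the tables compare it directly against single-JIT timings. The helper names time_per_call and vmap_ratio are hypothetical and not the probe scripts' API.

```python
import statistics
import time


def time_per_call(fn, args, batch=1, n_warmup=2, n_repeat=10, block=lambda out: out):
    """Median wall-clock milliseconds per likelihood evaluation.

    Warm-up calls absorb JIT compilation. `block` must force asynchronous
    results to finish before the clock stops (for JAX outputs, pass
    jax.block_until_ready). For a vmapped function pass batch=B so the time
    is reported per batch element.
    """
    for _ in range(n_warmup):
        block(fn(*args))
    samples = []
    for _ in range(n_repeat):
        start = time.perf_counter()
        block(fn(*args))
        samples.append((time.perf_counter() - start) * 1e3 / batch)
    return statistics.median(samples)


def vmap_ratio(single_ms, vmap_per_call_ms):
    """Per-step vmap/single ratio: > 1.0 means vmap is slower per evaluation."""
    return vmap_per_call_ms / single_ms
```

With JAX you would compare time_per_call(jax.jit(f), (theta,), block=jax.block_until_ready) against time_per_call(jax.jit(jax.vmap(f)), (thetas,), batch=3, block=jax.block_until_ready). Without the block, asynchronous dispatch can make the GPU look almost free.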
HPC validation (A100, where production lives)
The HPC profiling harness is a self-contained z_projects/profiling/
project. The hpc/sync tool wraps push/submit/pull/jobs. Key paths:
HPC project dir: /mnt/ral/jnightin/profiling
HPC venv: /mnt/ral/jnightin/PyAuto/PyAuto/ (sourced by the project's activate.sh at submit time)
Local working dir: /home/jammy/Code/PyAutoLabs/z_projects/profiling
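Push the scripts, then submit the four profiling jobs:
```bash
hpc/sync push
hpc/sync submit gpu submit_pixelization_profile_fp64
hpc/sync submit gpu submit_pixelization_profile_mp
hpc/sync submit gpu submit_delaunay_profile_fp64
hpc/sync submit gpu submit_delaunay_profile_mp
```
Each prints Submitted batch job <ID>. Job IDs are sequential. Each
takes ~1–4 min A100 wall + queue wait (often instant if no other
user's array is in the queue, occasionally 15–30 min if there's
competition).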
Wait for completion in the background (the sleep pattern works
even if you exit the shell; jobs run on the cluster regardless):
```bash
# Foreground if you're going to wait actively
until ! hpc/sync jobs 2>/dev/null | grep -qE "(<ID1>|<ID2>|<ID3>|<ID4>)"; do
  sleep 60
done; echo "ALL DONE"
```
Pull results + consolidate:
```bash
hpc/sync pull  # rsync down output/ + hpc/batch_gpu/{output,error}/

# Move HPC json+png pairs into the canonical worktree dir:
python scripts/pixelization_aggregate.py \
  --consolidate-from output/imaging/pixelization
python scripts/delaunay_aggregate.py \
  --consolidate-from output/imaging/delaunay

# Generate comparison.json + comparison.png across all 4–6 configs:
python scripts/pixelization_aggregate.py
python scripts/delaunay_aggregate.py
```
The aggregator honours PYAUTO_ROOT from the worktree's activate.sh —
canonical results land on the feature branch's worktree, not on
canonical-main. Without PYAUTO_ROOT, results go to canonical-main
(wrong if you're working on a branch).
After modifying any script in z_projects/profiling/scripts/, you
MUST re-run hpc/sync push before resubmitting — otherwise HPC
runs the OLD script and you get stale results.
SLURM jobs report exit 0:0 even if Python crashes inside, because
the bash submit script's epilogue (echo "Finished."; date) always
runs. Verify success by reading hpc/batch_gpu/output/output.<ID>.out
AND checking the JSON file mtime in output/imaging/<lik>/.
hpc/sync push skips dataset files that already exist on HPC
(--ignore-existing). If the local dataset was regenerated and
HPC needs the new version, force-push the dataset manually.
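Because of the exit 0:0 gotcha, success has to be read from the log and the result mtimes. Below is a minimal sketch of that check. It assumes the paths named above (hpc/batch_gpu/output/output.<ID>.out and output/imaging/<lik>/), taken relative to the local project directory after hpc/sync pull. It also assumes a crash shows up as a Python traceback in that log. If stderr lands in hpc/batch_gpu/error/ instead, the mtime check is the signal to rely on. The function name job_succeeded is hypothetical.

```python
from pathlib import Path


def job_succeeded(job_id, submitted_at, project_root=".", likelihood="delaunay"):
    """Check a profiling job without trusting SLURM's exit code.

    Success means (a) the stdout log exists and contains no Python traceback and
    (b) at least one result JSON under output/imaging/<likelihood>/ was modified
    after `submitted_at` (a Unix timestamp taken just before submitting).
    """
    root = Path(project_root)
    log = root / "hpc" / "batch_gpu" / "output" / f"output.{job_id}.out"
    if not log.exists():
        return False
    if "Traceback (most recent call last)" in log.read_text(errors="replace"):
        return False
    result_dir = root / "output" / "imaging" / likelihood
    return any(p.stat().st_mtime > submitted_at for p in result_dir.glob("*.json"))
```

Typical use: record submitted_at = time.time() just before hpc/sync submit, then call job_succeeded("<ID>", submitted_at, likelihood="pixelization") after hpc/sync pull.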
A complete "is the optimisation working?" loop
```bash
# 0. Set up
cd /home/jammy/Code/PyAutoLabs/z_projects/profiling
source /home/jammy/venv/PyAutoGPU/bin/activate
source /home/jammy/Code/PyAutoLabs-wt/<task>/activate.sh

# 1. Local first: fast feedback (~5 min total for 4 GPU configs)
WORKTREE_OUT_PIX="$PYAUTO_ROOT/autolens_workspace_developer/jax_profiling/results/jit/imaging/pixelization"
WORKTREE_OUT_DEL="$PYAUTO_ROOT/autolens_workspace_developer/jax_profiling/results/jit/imaging/delaunay"
PYTHONUNBUFFERED=1 python -u scripts/pixelization_profile.py --config-name local_gpu_fp64 --output-dir "$WORKTREE_OUT_PIX"
PYTHONUNBUFFERED=1 python -u scripts/pixelization_profile.py --use-mixed-precision --config-name local_gpu_mp --output-dir "$WORKTREE_OUT_PIX"
PYTHONUNBUFFERED=1 python -u scripts/delaunay_profile.py --config-name local_gpu_fp64 --output-dir "$WORKTREE_OUT_DEL"
PYTHONUNBUFFERED=1 python -u scripts/delaunay_profile.py --use-mixed-precision --config-name local_gpu_mp --output-dir "$WORKTREE_OUT_DEL"

# 2. Probe shows per-step vmap-vs-single; confirms NNLS share dropped
PYTHONUNBUFFERED=1 python -u scripts/delaunay_vmap_probe.py --config-name local_gpu_fp64 --output-dir /tmp/probe
PYTHONUNBUFFERED=1 python -u scripts/rectangular_vmap_probe.py --config-name local_gpu_fp64 --output-dir /tmp/probe

# 3. HPC for the production-fiducial number
hpc/sync push
hpc/sync submit gpu submit_pixelization_profile_fp64
hpc/sync submit gpu submit_pixelization_profile_mp
hpc/sync submit gpu submit_delaunay_profile_fp64
hpc/sync submit gpu submit_delaunay_profile_mp
# (wait for completion: see the watcher snippet above)
hpc/sync pull
python scripts/pixelization_aggregate.py --consolidate-from output/imaging/pixelization
python scripts/delaunay_aggregate.py --consolidate-from output/imaging/delaunay
python scripts/pixelization_aggregate.py
python scripts/delaunay_aggregate.py

# 4. The tables in comparison.json's headline section + the printed
#    stdout summary are what you compare against the targets in this
#    prompt's "What to measure" section.
```
Decision point: if MAX_ITER reduction alone gets us to the targets,
ship as a config change (no vendor). If it requires structural
changes, vendor jaxnnls under autoarray/inversion/nnls/.
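The decision point above can be phrased as a small selection rule over the sweep results. Take the smallest MAX_ITER whose log-evidence drift stays under the 1e-3 inflection threshold, then ask whether that setting also meets the wall-clock target. The SweepPoint record and the pick_max_iter helper are hypothetical; the thresholds come from this prompt.

```python
from dataclasses import dataclass


@dataclass
class SweepPoint:
    max_iter: int
    vmap_per_call_ms: float      # e.g. Delaunay A100 fp64 vmap-per-call
    delta_log_evidence: float    # |log-evidence(max_iter) - log-evidence(main)|


def pick_max_iter(points, target_ms, max_delta=1e-3):
    """Return (smallest accurate SweepPoint, whether it meets target_ms).

    "Accurate" means the log-evidence drift is at or below max_delta, the
    inflection threshold used for the MAX_ITER lever. Returns (None, False)
    when no setting is accurate enough.
    """
    accurate = [p for p in points if p.delta_log_evidence <= max_delta]
    if not accurate:
        return None, False
    best = min(accurate, key=lambda p: p.max_iter)
    return best, best.vmap_per_call_ms <= target_ms
```

If the returned flag is True, the change can ship as config. If it is False, or no setting is accurate enough, the structural/vendoring path applies.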
Open question: the 8.8× vmap-regress mystery
The per-step probes on RTX 2060 show NNLS scaling at 1.27× under vmap.
The full-pipeline Delaunay vmap on A100 shows 8.84× regress. Even if
NNLS is the bottleneck, 1.27× × 1.0 (for the rest) ≠ 8.84×. There's a
~7× factor unaccounted for.
Possible explanations:
A100 NNLS vmap behaves much worse than RTX 2060 NNLS vmap (test:
port delaunay_vmap_probe.py to HPC and run on A100)
XLA scheduler pessimisation at this matrix size only kicks in for
the full-pipeline graph, not for per-step sub-graphs
Some other step (slogdet? pure_callback?) regresses much harder
under the full graph than in isolation
Resolving this is a precondition for the optimisation work — without
it, we can fix NNLS and still see most of the 8.8× regression remaining.
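The ~7× figure follows directly from the numbers above. The sketch below redoes the arithmetic with the RTX 2060 per-step probe values and the quoted A100 full-pipeline regress, and adds a time-weighted version of the per-step prediction. Like the paragraph above, it deliberately mixes hardware, which is why porting the probe to A100 comes first.

```python
# RTX 2060 fp64 per-step vmap probe (ms): step -> (single, vmap per call)
PER_STEP = {
    "inversion_setup": (127.0, 108.0),
    "nnls": (111.0, 141.0),
    "mapped_log_ev": (33.0, 32.0),
}
A100_DELAUNAY_FULL_REGRESS = 8.84   # full-pipeline vmap / single-JIT, A100 fp64

nnls_ratio = PER_STEP["nnls"][1] / PER_STEP["nnls"][0]

# Time-weighted prediction for the decomposed pipeline: total vmap / total single
total_single = sum(single for single, _ in PER_STEP.values())
total_vmap = sum(vmap for _, vmap in PER_STEP.values())
weighted_ratio = total_vmap / total_single

print(f"NNLS ratio alone:            {nnls_ratio:.2f}x")
print(f"time-weighted prediction:    {weighted_ratio:.2f}x")
print(f"unexplained (NNLS only):     {A100_DELAUNAY_FULL_REGRESS / nnls_ratio:.1f}x")
print(f"unexplained (time-weighted): {A100_DELAUNAY_FULL_REGRESS / weighted_ratio:.1f}x")
```

The NNLS-only reading leaves about 7.0× unexplained. Weighting each step by its time predicts only about 1.04× for the whole decomposition, which leaves about 8.5× unexplained.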
Files to touch / read
Read-only first:
- /home/jammy/venv/PyAutoGPU/lib/python3.10/site-packages/jaxnnls/pdip.py
(the PDIP solver; MAX_ITER = 50 is at module scope)
- autoarray/nnls_gpu_bottleneck.md (this repo): the prior framing,
superseded, but its 5-lever structure is still valid as a how-to.
- z_projects/profiling/scripts/delaunay_vmap_probe.py and rectangular_vmap_probe.py: the per-step decomposition probes that
produced the "scipy pure_callback is fine, NNLS is the wall"
finding.
Out of scope
Same as the prior prompt:
- Replacing PDIP with active-set NNLS (much bigger change with its own
gradient story).
- Anything that silently reduces correctness (clipping negatives, etc.).
- vmap-batch sizing on the sampler side (PyAutoFit concern).
Plus newly:
- Pure-JAX Delaunay triangulation (deprioritised; see "What to NOT
spend time on").
- Mixed-precision micro-optimisation (not a lever at production scale).
Definition of done
PR to a feature branch on PyAutoArray (and possibly a small companion
to autolens_workspace_developer to bump the regression artifacts at
the new NNLS implementation), passing:
- EXPECTED_LOG_EVIDENCE_HST assertions for both rectangular + Delaunay
at MGE-60 fiducial (rtol=1e-4 fp64, rtol=1e-2 mp)
- Delaunay A100 vmap_per_call ≤ 200 ms (2× improvement, conservative
target; the headline is ≤ 100 ms but anything significant counts)
- Rect A100 vmap_per_call ≤ 25 ms (within 1.0× of single-JIT)
- No regression on the existing autolens_workspace_test JAX likelihood
functions (rtol=1e-4 vs current main)
Click to expand starting prompt
Speed up the JAX NNLS solver under vmap at production-fiducial size
This supersedes
autoarray/nnls_gpu_bottleneck.md— that earlier promptwas written from MGE-only profiling at a 40-element problem and prioritised
single-JIT GPU performance. Production runs use vmap (nested-sampling
proposes batches of parameter vectors) at much larger problem sizes, and
PR #60 (autolens_workspace_developer) makes the new bottleneck picture
concrete.
The earlier prompt's 5-lever structure (
MAX_ITER, Cholesky vectorisation,gradient-infrastructure check, K≤N closed-form bypass, vendor jaxnnls) is
all still valid — keep it as the "how to attack" reference. This prompt
re-frames "what to attack first" given the new vmap-relevant data.
What the latest profiling found (PR #60)
Production-fiducial setup:
setup, not single-Sersic
(39×39 overlay → 1231 vertices)
A100 fp64 timings at this fiducial:
The Delaunay-vmap regression is the science-throughput wall. RTX 2060
shows the same pattern less dramatically (Delaunay vmap 590 → 1217 ms mp).
Per-step vmap probes confirmed scipy is NOT the issue
Earlier rounds suspected
scipy.spatial.Delaunay(insidejax.pure_callbackwith
vmap_method="sequential") was the problem. Per-step vmap probeson RTX 2060 (committed in
z_projects/profiling/scripts/{delaunay, rectangular}_vmap_probe.py) flip the story:pure_callbackactually scales sublinearly under vmap — its dispatchoverhead amortises across batch elements. NNLS scales superlinearly
because its iterations are inherently serial per batch element.
So: NNLS is the universal vmap target. Pure-JAX-Delaunay (the previous
"long-term win") is no longer the highest-value lever.
Cross-likelihood comparison (where NNLS sits in each)
The three imaging-family likelihoods all use jaxnnls for the
reconstruction_positive_only_fromstep. Their setups differ inproblem size K and in what other steps share the budget.
Production-fiducial summary (all from PR #56 / PR #60 measurements):
pure_callback?*MGE PR #56 did not separate "lens light" from "source MGE" — both galaxies
use MGE bulges (20 Gaussians each = 40 total linear obj). It's the small-K
reference but not directly comparable to the rect/Delaunay production setup.
A100 fp64 timings (latest measurements per PR):
RTX 2060 fp64 timings:
The pattern:
prompt's analysis). Lever 1 (MAX_ITER) maps to "fewer launches".
1.06× RTX 2060). NNLS dominates the budget but it's compute-bound.
combined K=1291 with the linear lens MGE columns hits something nasty
in NNLS-under-vmap that doesn't appear at smaller K. This is the
unexplained 7× factor that needs diagnosis (see "Open question"
section below).
If the optimisation work successfully fixes NNLS under vmap, all three
likelihoods benefit: MGE keeps its existing "vmap is faster" status,
rectangular drops vmap regress to ~1.0×, and Delaunay hopefully drops
toward 2× or better (still won't beat single-JIT due to the inherent
batch-3 overhead, but 438 ms → 100 ms would be a 4× production speedup).
What to measure (updated regression baselines)
Use the canonical references at:
autolens_workspace_developer/jax_profiling/jit/imaging/pixelization.py
autolens_workspace_developer/jax_profiling/jit/imaging/delaunay.py
Both now ship at the MGE-60 fiducial (PR #60). Their
EXPECTED_LOG_EVIDENCE_HST constants: 24746.105672366088 and 26288.321397232066. Track these on every NNLS variant: log-evidence drift > rtol=1e-4 means
the optimisation broke something.
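A minimal sketch of that drift check, assuming the two constants are listed in the same order as the two files (pixelization.py for rectangular, then delaunay.py). The helper names are hypothetical; the real profilers already assert against their own constants.

```python
# Assumed mapping: constants listed in file order (pixelization.py, delaunay.py).
EXPECTED_LOG_EVIDENCE_HST = {
    "rectangular": 24746.105672366088,
    "delaunay": 26288.321397232066,
}


def log_evidence_drift(actual, expected):
    """Relative drift |actual - expected| / |expected|."""
    return abs(actual - expected) / abs(expected)


def check_log_evidence(actual, config, rtol=1e-4):
    """Raise if an NNLS variant moved log-evidence past rtol (1e-4 fp64, 1e-2 mp)."""
    drift = log_evidence_drift(actual, EXPECTED_LOG_EVIDENCE_HST[config])
    if drift > rtol:
        raise AssertionError(f"{config}: log-evidence drift {drift:.2e} > rtol={rtol:.0e}")
    return drift


# A shift of 1.0 on ~24746 is ~4e-5 relative: inside the fp64 tolerance.
print(check_log_evidence(24746.105672366088 + 1.0, "rectangular"))
```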
For wall-clock targets, the per-config profiles in
z_projects/profiling/scripts/{pixelization,delaunay}_profile.py give side-by-side comparisons. Specific targets:
The Delaunay A100 vmap target is the headline science improvement:
a 4.4× speedup on the production hot path. Anything that hits 200 ms (2× from
current) is already meaningful.
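The target arithmetic, using the PR #60 A100 fp64 Delaunay numbers (438 ms vmap per call, 49.6 ms single-JIT):

```python
CURRENT_DELAUNAY_VMAP_MS = 438.0  # A100 fp64, vmap per call (batch=3), PR #60
SINGLE_JIT_MS = 49.6              # A100 fp64, single-JIT, PR #60


def speedup(current_ms, new_ms):
    return current_ms / new_ms


for target_ms in (200.0, 100.0):
    print(f"{target_ms:.0f} ms target -> {speedup(CURRENT_DELAUNAY_VMAP_MS, target_ms):.2f}x")
# → 200 ms target -> 2.19x
# → 100 ms target -> 4.38x
print(f"vmap regress: {CURRENT_DELAUNAY_VMAP_MS / SINGLE_JIT_MS:.2f}x")  # → vmap regress: 8.83x
```

The last line gives ≈ 8.83×; the quoted 8.84× presumably comes from the unrounded timings.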
What's likely to work (priority order, post-PR-#60 evidence)
MAX_ITER reduction (still the obvious first lever). The previous
prompt's analysis of
pdip.py: MAX_ITER = 50 → ~5 kernel launches/iter → 20 ms was at K=40. At K=1285 each Cholesky kernel actually does real work, so the kernel-launch model is less dominant. But fewer
iterations still help, both at the launch level AND at the linalg
level. Test MAX_ITER ∈ {10, 15, 20, 30} against the rectangular and
Delaunay log-evidence. Look for the inflection where Δlog-evidence
exceeds 1e-3.
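One way to structure the sweep, as a sketch: temporarily override the module-level MAX_ITER, re-evaluate the log-evidence, and report the smallest value past which every larger candidate stays within tolerance of the MAX_ITER = 50 baseline. It assumes pdip.py reads MAX_ITER as a module global at trace time (so patching jaxnnls.pdip.MAX_ITER takes effect) and that the log_evidence_fn you pass re-traces the likelihood on every call (for example by building a fresh jax.jit, or calling jax.clear_caches()). Both are assumptions, and log_evidence_fn is hypothetical. The demo uses a fake module so it runs without jaxnnls.

```python
import contextlib
import types


@contextlib.contextmanager
def patched_constant(module, name, value):
    """Temporarily override a module-level constant (e.g. jaxnnls.pdip.MAX_ITER)."""
    old = getattr(module, name)
    setattr(module, name, value)
    try:
        yield
    finally:
        setattr(module, name, old)


def sweep_max_iter(module, log_evidence_fn, candidates, baseline=50, tol=1e-3):
    """Return ({max_iter: |Δlog-evidence|}, smallest safe max_iter or None)."""
    with patched_constant(module, "MAX_ITER", baseline):
        reference = log_evidence_fn()
    deltas = {}
    for n in sorted(candidates):
        with patched_constant(module, "MAX_ITER", n):
            deltas[n] = abs(log_evidence_fn() - reference)
    # Walk down from the largest candidate; stop at the first that drifts past tol.
    safe = None
    for n in sorted(candidates, reverse=True):
        if deltas[n] > tol:
            break
        safe = n
    return deltas, safe


# Demo: a fake solver module whose "log-evidence" converges like 1 / MAX_ITER**2.
fake_pdip = types.ModuleType("fake_pdip")
fake_pdip.MAX_ITER = 50
deltas, safe = sweep_max_iter(fake_pdip, lambda: -1.0 / fake_pdip.MAX_ITER**2, [10, 15, 20, 30])
print("smallest safe MAX_ITER:", safe)  # → smallest safe MAX_ITER: 30
```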
Audit the gradient infrastructure (lever 3 from prior prompt).
For sampler use we don't need NNLS gradients — only the forward
solve. Confirm
solve_nnls isn't dragging pdip_relaxed.py machinery through the JIT trace. If it is, splitting forward/backward at the
API level should give a clean speedup with no numerical change.
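A rough way to run that audit, sketched under assumptions: trace only the forward call with jax.make_jaxpr and count primitives recursively (through jit bodies, loops, cond branches and custom_vjp primal jaxprs). If the forward trace of solve_nnls contains more linalg primitives (names such as cholesky or triangular_solve) than the PDIP forward pass should need, the extra machinery is being traced. count_primitives and forward_primitive_counts are hypothetical helpers, and toy_forward_solve is a Cholesky stand-in, not jaxnnls.

```python
import collections

import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def count_primitives(jaxpr, counts=None):
    """Recursively count primitive names in a Jaxpr, including nested sub-jaxprs."""
    counts = collections.Counter() if counts is None else counts
    for eqn in jaxpr.eqns:
        counts[eqn.primitive.name] += 1
        for param in eqn.params.values():
            subs = param if isinstance(param, (tuple, list)) else (param,)
            for sub in subs:
                inner = getattr(sub, "jaxpr", sub)  # ClosedJaxpr -> Jaxpr
                if hasattr(inner, "eqns"):
                    count_primitives(inner, counts)
    return counts


def forward_primitive_counts(fn, *args):
    """Primitive counts of the forward-only trace of fn(*args)."""
    return count_primitives(jax.make_jaxpr(fn)(*args).jaxpr)


def toy_forward_solve(Q, q):
    # Stand-in for a solver's forward pass: one Cholesky factorisation + solve.
    return cho_solve(cho_factor(Q), q)


K = 8
counts = forward_primitive_counts(toy_forward_solve, 2.0 * jnp.eye(K), jnp.ones(K))
print(counts["cholesky"], counts["triangular_solve"])
```

Running the same count on the real forward call, and again on jax.grad of a scalar loss built from it, shows how much of the trace only exists for the backward path.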
The vmap pathology specifically. This is the new highest-value
item, not in the prior prompt. The 8.8× Delaunay vmap regression vs
1.27× per-step probe ratio means the full-pipeline vmap is hitting
something worse than just NNLS×3-batch. Hypotheses worth testing:
causing XLA to materialise (3, 1291, 1291, 1291) workspace tensors
that don't fit nicely. Use
jax.jit(jax.vmap(...)).lower(...).compile().memory_analysis() to compare.
(3 calls), but XLA may stall on the pure_callback's host
synchronisation. Try removing the callback's
vmap_method constraint, or bundling all 3 mesh grids into a single callback
call.
jax.profiler.start_trace(...) to see where the wall-time actually goes: the 8.8× factor over the
per-step decomposition is unexplained.
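The memory and trace checks can be wrapped as below. This is a sketch assuming a hypothetical step(Q, q) callable standing in for the likelihood step under test. memory_analysis() may return None, or be unsupported, on some backends; the helper reports that as None. The trace helper uses the jax.profiler.trace context manager, which is equivalent to start_trace/stop_trace.

```python
import jax
import jax.numpy as jnp
import jax.profiler


def temp_bytes(fn, *args):
    """XLA temp/workspace bytes of the compiled fn, or None if unavailable."""
    try:
        stats = jax.jit(fn).lower(*args).compile().memory_analysis()
    except (NotImplementedError, RuntimeError):
        return None
    return None if stats is None else int(stats.temp_size_in_bytes)


def compare_workspace(step, single_args, batch=3):
    """Workspace of one call vs one vmapped call over `batch` stacked copies."""
    batched_args = [jnp.stack([a] * batch) for a in single_args]
    return {
        "single": temp_bytes(step, *single_args),
        "vmap": temp_bytes(jax.vmap(step), *batched_args),
    }


def trace_vmapped_call(step, batched_args, log_dir="/tmp/jax-trace"):
    """Write a profiler trace of one warm vmapped call (view in TensorBoard/Perfetto)."""
    f = jax.jit(jax.vmap(step))
    f(*batched_args).block_until_ready()  # compile outside the trace
    with jax.profiler.trace(log_dir):
        f(*batched_args).block_until_ready()


# Toy stand-in step: a dense SPD solve at small K.
K = 64
step = lambda Q, q: jnp.linalg.solve(Q, q)
print(compare_workspace(step, [2.0 * jnp.eye(K), jnp.ones(K)]))
```

A vmap workspace far above 3× the single-call figure would support the "materialised workspace tensors" hypothesis; roughly 3× would point the search back at scheduling.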
Pallas kernel for the per-iter Cholesky (lever 2 from prior).
At K=1285 a
jax.lax.linalg.cholesky is real work (not launch-bound), but a fused triangular factorisation could still help. Lower-priority
than items 1-3 because it's invasive and only nibbles at the cost.
Vendor jaxnnls (lever 5 from prior). Do this last, and only if items 1-4
produce code changes worth maintaining.
What to NOT spend time on (deprioritised by PR #60)
the long-term win for unblocking A100. Per-step probes show the
pure_callback already scales sublinearly under vmap — it's not the
vmap bottleneck. The single-JIT speedup it would provide is small
relative to the NNLS opportunity.
mp is essentially a no-op on A100 and a marginal-to-no-op effect on
RTX 2060 at production scale. Not a useful lever.
use batch=3. Larger batch sizes are the sampler's lever, not NNLS's
— sampler concern (PyAutoFit), not PyAutoArray.
How to validate (full operational runbook)
Local-only validation (RTX 2060 + CPU)
/home/jammy/venv/PyAutoGPU/lib/python3.10/site-packages/jaxnnls/pdip.py directly. If you're testing a vendored copy, use
from autoarray.inversion.nnls import solve_nnls in
inversion_util.py. Either way, pip show jaxnnls confirms which version is installed.
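pip show reports the installed distribution, not the module actually imported. To make "which copy is loaded" explicit at runtime, a small resolver can prefer the vendored module, fall back to the external package, and report the file and version it picked. This is a sketch: resolve_solver and NNLS_CANDIDATES are hypothetical, and it assumes jaxnnls exposes solve_nnls at package level, as the vendored import path mirrors.

```python
import importlib
import importlib.metadata


def resolve_solver(candidates):
    """Return (callable, module_file, dist_version) for the first importable candidate."""
    for module_name, attr, dist_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue
        fn = getattr(module, attr, None)
        if fn is None:
            continue
        try:
            version = importlib.metadata.version(dist_name) if dist_name else None
        except importlib.metadata.PackageNotFoundError:
            version = None  # e.g. a vendored copy with no separate distribution
        return fn, getattr(module, "__file__", None), version
    raise ImportError("no NNLS solver found among candidates")


# Prefer the vendored copy, fall back to the external package.
NNLS_CANDIDATES = [
    ("autoarray.inversion.nnls", "solve_nnls", None),
    ("jaxnnls", "solve_nnls", "jaxnnls"),
]

# Stdlib demo of the fallback order (the first module does not exist).
print(resolve_solver([("no_such_module_xyz", "f", None), ("math", "sqrt", None)])[0](9.0))  # → 3.0
```

In the profilers, printing the module_file returned for NNLS_CANDIDATES next to each timing makes stale-copy mistakes visible in the logs.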
python may resolve to the CPU-only PyAuto venv; use the explicit GPU venv first).
speedup lands):
single and under vmap=3, and prints per-step vmap/single ratios.
For a useful NNLS speedup you want the NNLS ratio to drop from
1.27× toward 1.0×.
EXPECTED_LOG_EVIDENCE_HST assertion still passes: every per-config profiler asserts against the constant on every config
that runs.
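A sketch of one way to compute a vmap/single ratio: the median wall time of one vmapped call over a batch of 3, divided by the median time of one single call. That convention reproduces the 438 ms / 49.6 ms ≈ 8.8× Delaunay figure. The probe scripts may normalise per batch element instead, so check their definition before comparing numbers. vmap_over_single_ratio is hypothetical, and the solve below is a toy stand-in step.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def median_ms(fn, *args, repeats=20):
    fn(*args).block_until_ready()  # compile + warm up outside the timed loop
    samples = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args).block_until_ready()
        samples.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(samples)


def vmap_over_single_ratio(step, single_args, batch=3, repeats=20):
    """Batch-call time / single-call time. 1.0 means the batch is free,
    `batch` means no amortisation, above `batch` means vmap made things worse."""
    batched = [jnp.stack([a] * batch) for a in single_args]
    single_ms = median_ms(jax.jit(step), *single_args, repeats=repeats)
    vmap_ms = median_ms(jax.jit(jax.vmap(step)), *batched, repeats=repeats)
    return vmap_ms / single_ms


K = 128
ratio = vmap_over_single_ratio(lambda Q, q: jnp.linalg.solve(Q, q), [2.0 * jnp.eye(K), jnp.ones(K)])
print(f"vmap/single ratio: {ratio:.2f}x")
```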
HPC validation (A100, where production lives)
The HPC profiling harness is a self-contained
z_projects/profiling/project. The
hpc/sync tool wraps push/submit/pull/jobs. Key paths:
/mnt/ral/jnightin/profiling
/mnt/ral/jnightin/PyAuto/PyAuto/ (sourced by the project's activate.sh at submit time)
/home/jammy/Code/PyAutoLabs/z_projects/profiling
z_projects/profiling/hpc/sync.conf (already configured:
HPC_HOST=euclid_jump, etc.)
Operate from
/home/jammy/Code/PyAutoLabs/z_projects/profiling/.
Push code + check connection (does NOT need the worktree, sync
operates on the canonical z_projects/profiling/):
Submit jobs — one per (likelihood × precision) pair:
Each prints
Submitted batch job <ID>. Job IDs are sequential. Each takes ~1–4 min A100 wall + queue wait (often instant if no other
user's array is in the queue, occasionally 15–30 min if there's
competition).
Wait for completion in the background (the sleep pattern works
even if you exit the shell — jobs run on the cluster regardless):
Pull results + consolidate:
The aggregator honours
PYAUTO_ROOT from the worktree's activate.sh: canonical results land on the feature branch's worktree, not on
canonical-main. Without
PYAUTO_ROOT, results go to canonical-main (wrong if you're working on a branch).
Common gotchas observed during PR #60:
z_projects/profiling/scripts/, you MUST re-run
hpc/sync push before resubmitting; otherwise HPC runs the OLD script and you get stale results.
0:0 even if Python crashes inside, because the bash submit script's epilogue
(echo "Finished."; date) always runs. Verify success by reading
hpc/batch_gpu/output/output.<ID>.out AND checking the JSON file mtime in
output/imaging/<lik>/.
hpc/sync push skips dataset files that already exist on HPC
(--ignore-existing). If the local dataset was regenerated and HPC needs the new version, manually force-push the dataset.
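The "verify success" gotcha can be scripted. As a sketch: treat a job as good only if its log has no Python traceback and at least one JSON result in the output directory is newer than the submit time. job_succeeded is hypothetical, and it assumes you pass the pulled local copies of hpc/batch_gpu/output/output.<ID>.out and output/imaging/<lik>/.

```python
import pathlib


def job_succeeded(log_path, json_dir, submitted_at):
    """Exit code 0:0 is not trusted: the batch epilogue always runs."""
    log_path, json_dir = pathlib.Path(log_path), pathlib.Path(json_dir)
    if not log_path.is_file():
        return False
    if "Traceback (most recent call last)" in log_path.read_text(errors="replace"):
        return False  # Python crashed even though the job reported success
    # Require at least one result JSON written after submission (not a stale one).
    return any(p.stat().st_mtime >= submitted_at for p in json_dir.glob("*.json"))
```

Recording time.time() at submit and passing it as submitted_at catches the stale-script case too, since HPC running the OLD script still writes a fresh JSON only if it succeeds.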
A complete "is the optimisation working?" loop
Decision point: if MAX_ITER reduction alone gets us to the targets,
ship as a config change (no vendor). If it requires structural
changes, vendor
jaxnnls under autoarray/inversion/nnls/.
Open question: the 8.8× vmap-regress mystery
The per-step probes on RTX 2060 show NNLS scaling at 1.27× under vmap.
The full-pipeline Delaunay vmap on A100 shows 8.84× regress. Even if
NNLS is the bottleneck, 1.27× × 1.0 (for the rest) ≠ 8.84×. There's a
~7× factor unaccounted for.
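The ~7× figure follows from dividing the two measured factors, taking the rest of the pipeline at 1.0× as the text does (and assuming both ratios share the same normalisation):

```latex
\frac{8.84}{1.27 \times 1.0} \approx 6.96 \approx 7
```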
Possible explanations:
port
delaunay_vmap_probe.py to HPC and run on A100)
the full-pipeline graph, not for per-step sub-graphs
under the full graph than in isolation
Resolving this is a precondition for the optimisation work — without
it, we can fix NNLS and still see most of the 8.8× regression remaining.
Files to touch / read
Read-only first:
/home/jammy/venv/PyAutoGPU/lib/python3.10/site-packages/jaxnnls/pdip.py (the PDIP solver;
MAX_ITER = 50 is at module scope)
/home/jammy/venv/PyAutoGPU/lib/python3.10/site-packages/jaxnnls/pdip_relaxed.py (gradient path; check whether it's needed)
@PyAutoArray/autoarray/inversion/inversion/inversion_util.py
(reconstruction_positive_only_from: its docstring incorrectly claims fnnls/Bro-Jong, but it actually calls jaxnnls PDIP)
@autolens_workspace_developer/jax_profiling/jit/imaging/pixelization.py and delaunay.py (canonical refs at MGE-60 fiducial; the regression benchmarks with
EXPECTED_LOG_EVIDENCE_HST constants)
@z_projects/profiling/scripts/{pixelization,delaunay,delaunay_vmap_probe, rectangular_vmap_probe}_*.py (the per-config profilers + vmap probes)
If changes needed (in roughly this order):
jaxnnls, change MAX_ITER, install
@PyAutoArray/autoarray/inversion/inversion/inversion_util.py (fix the docstring lie at minimum)
jaxnnls to @PyAutoArray/autoarray/inversion/nnls/ (only if MAX_ITER reduction isn't enough)
Reference precedent
that produced the headline 8.8× Delaunay vmap regression.
Re-profile rectangular + Delaunay at full production fiducial (1225/1231 src + MGE-60 lens) autolens_workspace_developer#60
See its body for the headline numbers and the bottleneck-shift table
across the three iterations of "make it production-realistic".
autoarray/nnls_gpu_bottleneck.md (this repo): the prior framing, superseded but its 5-lever structure is still valid as a how-to.
z_projects/profiling/scripts/delaunay_vmap_probe.py and rectangular_vmap_probe.py: the per-step decomposition probes that produced the "scipy pure_callback is fine, NNLS is the wall"
finding.
Out of scope
Same as the prior prompt:
gradient story).
Plus newly:
spend time on").
Definition of done
PR to a feature branch on PyAutoArray (and possibly a small companion
to autolens_workspace_developer to bump the regression artifacts at
the new NNLS implementation), passing:
EXPECTED_LOG_EVIDENCE_HST assertions for both rectangular + Delaunay at MGE-60 fiducial (rtol=1e-4 fp64, rtol=1e-2 mp)
target — the headline is ≤ 100 ms but anything significant counts)
functions (rtol=1e-4 vs current main)