Re-profile rectangular + Delaunay at full production fiducial (1225/1231 src + MGE-60 lens) #60

Merged. Jammy2211 merged 2 commits into main from feature/profiling-1250-fiducial-rerun on May 11, 2026.

Jammy2211 commented on May 10, 2026

Summary

Replaces the rectangular-pixelization (#57) and Delaunay (#58) imaging-likelihood profiling sweeps with the full production fiducial setup — both the ~1250 source pixel mesh AND the MGE-60 lens light (60 linear Gaussians).

This PR went through three iterations of "make it more production-realistic":

| Iteration | Rect source px | Delaunay vertices | Lens light | Status |
| --- | --- | --- | --- | --- |
| Initial PRs #57/58 | 784 | 570 | Single Sersic | merged |
| Mid-iteration (this branch) | 1225 | 1231 | Single Sersic | superseded |
| Final (this PR) | 1225 | 1231 | MGE-60 | shipping |

The two earlier states are still in this PR's git history but main lands the final MGE-60 fiducial. Updates canonical references at jax_profiling/jit/imaging/{pixelization,delaunay}.py to match.

35×35 was chosen for rectangular (closest perfect square to 1250; mesh required square). 39×39 Overlay was empirically calibrated for Delaunay so post-mask-filtering yields ~1250 vertices (1201 inside + 30 edge = 1231 total).
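
The mesh-size arithmetic above can be checked with a few lines of Python. This is a minimal sketch: `nearest_square_side` is a hypothetical helper written for illustration, not a function from the profiling scripts, and the Delaunay vertex counts are simply the numbers quoted in the text.

```python
import math


def nearest_square_side(target_pixels):
    """Return the side length n whose square n*n is closest to target_pixels."""
    low = math.isqrt(target_pixels)  # largest n with n*n <= target_pixels
    return min((low, low + 1), key=lambda side: abs(side * side - target_pixels))


# Rectangular mesh: the square pixel count closest to the ~1250 target.
side = nearest_square_side(1250)
rect_pixels = side * side

# Delaunay mesh: vertex counts quoted above for the 39x39 Overlay after masking.
delaunay_vertices = 1201 + 30

print(f"rectangular: {side}x{side} = {rect_pixels} source pixels")
print(f"Delaunay: {delaunay_vertices} vertices")
```

The next square up, 36×36 = 1296, overshoots the target by 46 pixels, whereas 35×35 undershoots by 25.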

Headline at the full fiducial

| Config | Rect (1225 + MGE-60) full | Rect vmap | Delaunay (1231 + MGE-60) full | Delaunay vmap |
| --- | --- | --- | --- | --- |
| A100 fp64 | 25.1 ms | 34.8 ms | 49.6 ms | 438 ms |
| A100 mp | 24.9 ms | 34.4 ms | 49.6 ms | 441 ms |
| RTX 2060 fp64 | 537 ms | 567 ms | 590 ms | 954 ms |
| RTX 2060 mp | 495 ms | 528 ms | 557 ms | 1217 ms |
| CPU fp64 | 4443 ms | 4991 ms | n/a | hang |
| CPU mp | 3803 ms | 4308 ms | n/a | hang |

Delaunay vmap regression on A100 is 8.8× single-call (438 / 49.6 ms) when MGE-60 lens is added. The combined F+H is (1291, 1291) — 1231 source pixels + 60 lens MGE columns — and NNLS doesn't batch usefully at this size.

How the bottleneck picture shifts vs the earlier iterations

| Likelihood | Metric | Initial (784/570 + Sersic) | Mid (1225/1231 + Sersic) | Final (1225/1231 + MGE-60) |
| --- | --- | --- | --- | --- |
| Rect | A100 full | 9.7 ms | 15.3 ms | 25.1 ms |
| Rect | A100 vmap | 12.3 ms | 21.4 ms | 34.8 ms |
| Rect | A100/RTX | 22× | 29× | 21× |
| Delaunay | A100 full | 17.3 ms | 34.9 ms | 49.6 ms |
| Delaunay | A100 vmap | 70.6 ms | 95.8 ms | 438 ms |
| Delaunay | vmap/single ratio | 4.08× | 2.74× | 8.84× |

The A100 vs RTX 2060 gap widens then narrows for rectangular as the model grows: 22× at 784 single-Sersic, 29× at 1225 single-Sersic, and 21× at the production fiducial. The MGE-60 mapping-matrix columns are a cheap GPU operation but the F construction grows faster on CPU/consumer GPU than A100.

For Delaunay, the vmap/single ratio dips from 4.08× to 2.74× when only the mesh grows, then jumps to 8.84× once the lens MGE columns are added. The combined-size NNLS is the wall.
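
The ratios in this section are plain divisions of the tabulated timings. The sketch below recomputes them from the rounded values in the tables, so the last digit can differ slightly from the quoted ratios (which were presumably computed from unrounded timings). The dictionary layout and function name are illustrative assumptions, not part of the profiling scripts.

```python
# A100 fp64 timings in ms, copied from the comparison table above.
a100_ms = {
    "initial": {"rect_full": 9.7, "rect_vmap": 12.3,
                "delaunay_full": 17.3, "delaunay_vmap": 70.6},
    "mid": {"rect_full": 15.3, "rect_vmap": 21.4,
            "delaunay_full": 34.9, "delaunay_vmap": 95.8},
    "final": {"rect_full": 25.1, "rect_vmap": 34.8,
              "delaunay_full": 49.6, "delaunay_vmap": 438.0},
}

# RTX 2060 fp64 rectangular full-pipeline timing at the final fiducial (ms).
rtx_rect_full_final_ms = 537.0


def vmap_over_single(timings, likelihood):
    """Ratio of the vmap timing to the single-call timing for one likelihood."""
    return timings[f"{likelihood}_vmap"] / timings[f"{likelihood}_full"]


ratios = {
    stage: {lk: vmap_over_single(t, lk) for lk in ("rect", "delaunay")}
    for stage, t in a100_ms.items()
}
rtx_over_a100 = rtx_rect_full_final_ms / a100_ms["final"]["rect_full"]

for stage, by_likelihood in ratios.items():
    print(stage, {lk: round(r, 2) for lk, r in by_likelihood.items()})
print("RTX 2060 / A100 (rect, final):", round(rtx_over_a100, 1))
```

The rectangular vmap/single ratio stays between roughly 1.3 and 1.4 across all three iterations, which is what makes the Delaunay jump stand out.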

Key findings at the production fiducial

  • NNLS reconstruction is now the dominant cost everywhere under vmap. A100 single-JIT: NNLS ≈ 16 ms of 25 ms rect total (64%), 16 ms of 50 ms Delaunay total (31%). Under vmap the share grows since NNLS doesn't batch and the rest does. Per-step probes on RTX 2060 confirm: scipy.spatial.Delaunay pure_callback scales SUBLINEARLY under vmap (0.85× ratio) while NNLS scales SUPERlinearly (1.27×). Production runs use vmap → NNLS is the lever.

  • scipy pure_callback is not the production-relevant Delaunay bottleneck. Single-JIT timings would say it is (Inversion-setup-combined dominates), but production uses vmap and the pure_callback amortises across the batch. The new vmap_probe scripts (in z_projects/profiling/scripts/) make this concrete.

  • MGE-60 lens light adds ~10 ms single + ~14 ms vmap on A100 rect, and ~15 ms single + ~340 ms vmap on A100 Delaunay. The Delaunay vmap penalty is grossly out of proportion to the model-size change. Hypothesis (not yet verified): the (1291, 1291) F+H matrix solve under vmap hits an XLA scheduling pathology that the (1231, 1231) single-call path doesn't. Worth a follow-up trace.

  • Mixed precision is essentially a no-op on A100 everywhere (rect 24.9 vs 25.1 ms, Delaunay 49.6 vs 49.6 ms). On RTX 2060: rect mp gives 8% off (495 vs 537), Delaunay mp gives 5% off — much smaller than at single-Sersic where Delaunay-mp saved 24%. The MGE-60 inversion-setup work is mp-friendly but NNLS isn't. On CPU: rect-mp saves 14% (3803 vs 4443) — still positive at this scale, opposite of what we saw at single-Sersic 1225 (where CPU-mp regressed).

  • Delaunay-on-CPU hangs at the bigger meshes. Both fp64 and mp hang at full_pipeline_first_call after compile (futex_wait_queue_me signature). Reproducible, and specific to the combination CPU + Delaunay + larger mesh. Task #24 followup. Rectangular CPU at 1225 + MGE-60 works fine.
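
The mixed-precision percentages in the findings above can be re-derived from the headline table. A minimal sketch using the rounded full-pipeline timings; the helper name and dictionary layout are assumptions for illustration, and the results agree with the quoted figures only to within rounding.

```python
def percent_saved(fp64_ms, mp_ms):
    """Percentage of runtime saved by mixed precision relative to fp64."""
    return 100.0 * (fp64_ms - mp_ms) / fp64_ms


# (fp64 ms, mp ms) full-pipeline timings from the headline table.
full_pipeline_ms = {
    ("A100", "rect"): (25.1, 24.9),
    ("A100", "delaunay"): (49.6, 49.6),
    ("RTX 2060", "rect"): (537.0, 495.0),
    ("RTX 2060", "delaunay"): (590.0, 557.0),
    ("CPU", "rect"): (4443.0, 3803.0),
}

savings = {key: percent_saved(*pair) for key, pair in full_pipeline_ms.items()}

for (device, likelihood), saved in savings.items():
    print(f"{device:9s} {likelihood:9s} mp saves {saved:5.1f}%")
```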

Caveats

  • No Delaunay CPU rows; see the Task #24 followup.
  • A100 JIT log-evidence shows fp32 truncation (24745.88 vs eager 24746.11; 26289.22 vs eager 26288.32). Same root cause as in #56 (Add MGE imaging profiling: A100 + RTX 2060 + CPU sweep), #57 and #58: the HPC PyAutoNSS venv runs without jax_enable_x64. rtol=1e-4 assertions still pass.
  • Step-by-step log_evidence prints -inf in per-config script output for all configs in this PR. This is a cosmetic bug in the script's hand-rolled compute_log_evidence: slogdet(H) is taken over the regularization matrix padded to the full F shape, which includes the all-zeros lens-MGE block (singular), so the slogdet is -inf. Does not affect any assertion (which all use eager and full-pipeline-JIT paths). Production AnalysisImaging handles this correctly. Will fix the script in a followup.
  • RTX 2060 hits OOM warnings on rect vmap at 1225 + MGE-60 — 6 GB consumer card straining. Run completes.
  • The pixelization HPC jobs had to be resubmitted once mid-iteration after I bumped MGE-30 → MGE-60 and forgot to re-push the pixelization_profile.py fix to HPC. SLURM jobs reported exit 0:0 because the bash epilogue ran even though Python crashed.
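
The -inf in the step-by-step log_evidence caveat follows directly from how a log-determinant treats an all-zeros block. Below is a minimal pure-Python sketch: a toy 3×3 matrix stands in for the source regularization block, and two zero rows/columns stand in for the 60 lens-MGE columns. The helper names, the toy values, and the block ordering are assumptions for illustration and are not taken from the profiling script; `slogdet` here is a small stand-in that follows the same singular-matrix convention as `numpy.linalg.slogdet`.

```python
import math


def slogdet(matrix):
    """Return (sign, log|det|) via Gaussian elimination with partial pivoting."""
    a = [[float(x) for x in row] for row in matrix]
    n = len(a)
    sign, logabsdet = 1.0, 0.0
    for col in range(n):
        pivot_row = max(range(col, n), key=lambda r: abs(a[r][col]))
        pivot = a[pivot_row][col]
        if pivot == 0.0:
            return 0.0, -math.inf  # singular matrix: determinant is exactly zero
        if pivot_row != col:
            a[col], a[pivot_row] = a[pivot_row], a[col]
            sign = -sign  # a row swap flips the determinant's sign
        if pivot < 0:
            sign = -sign
        logabsdet += math.log(abs(pivot))
        for r in range(col + 1, n):
            factor = a[r][col] / pivot
            for c in range(col, n):
                a[r][c] -= factor * a[col][c]
    return sign, logabsdet


def pad_with_zero_block(h_source, n_lens):
    """Embed h_source in a larger matrix whose extra rows/columns are all zero."""
    n_src = len(h_source)
    size = n_src + n_lens
    padded = [[0.0] * size for _ in range(size)]
    for i in range(n_src):
        for j in range(n_src):
            padded[i][j] = h_source[i][j]
    return padded


# Toy positive-definite "source-only" regularization block (determinant 4).
h_source = [[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 2.0]]
h_padded = pad_with_zero_block(h_source, n_lens=2)

print("source-only block:", slogdet(h_source))
print("zero-padded matrix:", slogdet(h_padded))
```

Taking the log-determinant over the source-only block, as the followup proposes, keeps the term finite.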

Generated by

  • z_projects/profiling/scripts/pixelization_profile.py + delaunay_profile.py — the main per-config profilers (now with MGE-60 lens via _setup_*.py).
  • z_projects/profiling/scripts/{delaunay,rectangular}_vmap_probe.py — new per-step vmap/single decomposition probes used in the body's analysis.
  • Same SLURM submits.

Test plan

  • All 10 JSON files schema-valid (parsed cleanly by both aggregators)
  • Both comparison.json + comparison.png regenerated end-to-end
  • Eager log_evidence assertions: 24746.106 + 26288.321 match EXPECTED_LOG_EVIDENCE_HST (rtol=1e-4) on every config that ran
  • Full-pipeline JIT + vmap log_evidence assertions pass on every shipping config
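
To see how much headroom rtol=1e-4 leaves for the fp32-level drift noted in the caveats, the relative errors can be computed by hand. This is a sketch under assumptions: the expected values are the EXPECTED_LOG_EVIDENCE_HST numbers from the commit message, the JIT values are the A100 figures quoted in the caveats, and the tolerance formula (relative to the expected value, no absolute term) is an assumption about how the scripts compare values, since the assertion helper itself is not shown in this PR.

```python
RTOL = 1e-4

# Expected (eager) values as recorded in the commit message.
expected_log_evidence = {
    "rect": 24746.105672366088,
    "delaunay": 26288.321397232066,
}

# A100 JIT values quoted in the caveats.
a100_jit_log_evidence = {
    "rect": 24745.88,
    "delaunay": 26289.22,
}


def relative_error(actual, expected):
    """Absolute difference scaled by the magnitude of the expected value."""
    return abs(actual - expected) / abs(expected)


errors = {
    name: relative_error(a100_jit_log_evidence[name], expected_log_evidence[name])
    for name in expected_log_evidence
}

for name, err in errors.items():
    status = "pass" if err <= RTOL else "FAIL"
    print(f"{name:9s} relative error {err:.2e} (rtol {RTOL:.0e}): {status}")
```

Both errors come out well below 1e-4, so the assertions pass, but a noticeably tighter tolerance would flag the A100 JIT values.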

Followups

  • Task #24: Delaunay-on-CPU hang for both precisions at the bigger mesh + MGE-60. Reproducer is JAX_PLATFORM_NAME=cpu python z_projects/profiling/scripts/delaunay_profile.py --config-name local_cpu_fp64.
  • Cosmetic fix: per-config script's hand-rolled compute_log_evidence should use the source-only H block in slogdet(H) instead of the zero-padded full matrix. Trivial diff once we know it's the only consumer.
  • The next optimization conversation: NNLS replacement for vmap. The 8.8× Delaunay vmap regression at production fiducial is the science-throughput wall. The pure_callback we worried about earlier turns out to be a non-issue under vmap.

🤖 Generated with Claude Code

Jammy2211 and others added 2 commits May 10, 2026 21:25
…e fiducial

Bumps mesh sizes:
- Rectangular: 28x28 = 784 → 35x35 = 1225 source pixels
- Delaunay: 26x26 Overlay = 570 vertices → 39x39 Overlay = 1231 vertices

EXPECTED_LOG_EVIDENCE_HST recomputed:
- Rectangular: 25918.02569499014 (was 26232.07)
- Delaunay: 27433.90296505439 (was 29179.95)

Both canonical references in jax_profiling/jit/imaging/{pixelization,
delaunay}.py and the result artifacts under jax_profiling/results/jit/
imaging/ have been updated together so canonical and per-config stay
aligned at the new fiducial.

Configs shipped:
- Rectangular: 6 (CPU+GPU x fp64+mp, A100 fp64+mp) — full coverage
- Delaunay:    4 (GPU fp64+mp, A100 fp64+mp) — local CPU configs hang

local_cpu_fp64 + local_cpu_mp for Delaunay at 1231 vertices both hang
indefinitely at full_pipeline_first_call after compile succeeds.
Identical futex_wait_queue_me signature for both precisions —
extends Task #24 (was thought to be mp-specific; size-related).
Rectangular CPU configs at 1225 work fine. PR ships Delaunay without
CPU rows.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Lens light: single Sersic → MGE-60 (60 linear Gaussians) in both
canonical refs (jax_profiling/jit/imaging/{pixelization,delaunay}.py)
and corresponding result artifacts. The MGE columns enter the
inversion mapping matrix, growing F+H from NxN to (N+60)x(N+60).

EXPECTED_LOG_EVIDENCE_HST recomputed:
- Pixelization (1225 + MGE-60): 24746.105672366088 (was 25918.026 single-Sersic)
- Delaunay (1231 + MGE-60):     26288.321397232066 (was 27433.903)

Configs shipped:
- Rectangular: 6 (CPU+GPU x fp64+mp, A100 fp64+mp) — full coverage
- Delaunay:    4 (GPU fp64+mp, A100 fp64+mp) — local CPU still hangs (Task #24)

Companion z_projects/profiling commit on local main updates
_setup_pixelization.py + _setup_delaunay.py to MGE-60, fixes
pixelization_profile.py to use eager-extracted H from inversion (the
old source-only constant_regularization_matrix_from gave a (N,N) shape
that mismatched F's (N+60,N+60) when lens is linear-MGE), and adds
delaunay_vmap_probe.py + rectangular_vmap_probe.py for per-step
vmap/single decomposition.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Jammy2211 changed the title from "Re-profile rectangular + Delaunay at 1250-pixel science fiducial" to "Re-profile rectangular + Delaunay at full production fiducial (1225/1231 src + MGE-60 lens)" on May 11, 2026
Jammy2211 merged commit 49fb4f6 into main on May 11, 2026