Skip to content

Add Delaunay imaging profiling: A100 + RTX 2060 sweep + three-way comparison#58

Merged
Jammy2211 merged 2 commits into
mainfrom
feature/delaunay-profiling-a100
May 10, 2026
Merged

Add Delaunay imaging profiling: A100 + RTX 2060 sweep + three-way comparison#58
Jammy2211 merged 2 commits into
mainfrom
feature/delaunay-profiling-a100

Conversation

@Jammy2211
Copy link
Copy Markdown
Contributor

Summary

Adds long-term tracking artifacts for the Delaunay imaging likelihood under jax_profiling/results/jit/imaging/delaunay/ — four configs (RTX 2060 + A100, fp64 + mp). Generated by new tooling in z_projects/profiling/scripts/ (separate local-only commit, no PR target).

Likelihood: Sersic + Isothermal + ExternalShear lens with a Delaunay source mesh of ~706 vertices (Overlay 26×26 + 30 circular edge points) + ConstantSplit regularization. Mirrors the canonical reference at jax_profiling/jit/imaging/delaunay.py.

This is the third entry in the imaging family after MGE (#56) and rectangular pixelization (#57), so the framing here is the three-way cross-likelihood comparison rather than a standalone Delaunay write-up.

Headline numbers — Delaunay alone

Config Full pipeline vmap per call
hpc_a100_fp64 17.3 ms 70.6 ms
hpc_a100_mp 17.3 ms 70.6 ms
local_gpu_fp64 (RTX 2060) 162.5 ms 227.6 ms
local_gpu_mp 123.4 ms 173.8 ms

Local CPU configs (local_cpu_fp64 / local_cpu_mp) were attempted but both runs hung indefinitely in the dataset / mask oversampling setup (sub-2% CPU for 18–24 min, no progress past mask padding). This is a Delaunay-on-CPU stall not present in the prior MGE or rectangular pixelization sweeps; suspect numba JIT cache contention between a prior GPU-mode run and forced-CPU mode in the same shell. Skipped for this PR; flagged for followup.

Three-way cross-likelihood comparison (full pipeline per call)

Likelihood A100 fp64 A100 mp RTX 2060 fp64 RTX 2060 mp A100 vs RTX
MGE (#56) 5.7 ms 5.4 ms 43.7 ms 43.0 ms 7.7×
Rectangular pixelization (#57) 9.7 ms 10.1 ms 212.2 ms 192.6 ms 22×
Delaunay (this PR) 17.3 ms 17.3 ms 162.5 ms 123.4 ms 9.4×

Three-way cross-likelihood comparison (vmap per call)

Likelihood A100 fp64 RTX 2060 fp64 vmap behaviour
MGE 2.4 ms 23.9 ms 2.4× speedup vs single JIT (good)
Rectangular pixelization 12.3 ms 233.1 ms 0.8× — regresses (NNLS serial)
Delaunay 70.6 ms 227.6 ms 0.25× — 4× regress on A100 (NNLS + something else)

Key findings

  • A100 vs RTX 2060 speedup is non-monotonic across likelihoods. MGE = 7.7×, rectangular = 22×, Delaunay = 9.4×. The pattern is best explained by what fraction of each pipeline JIT-compiles to GPU vs falls back to host CPU:

    • MGE is GPU-native end-to-end → modest A100 win because the consumer GPU was already adequate.
    • Rectangular pixelization's bottleneck (F construction) collapses 200× on A100 thanks to tensor-core dense GEMM → biggest A100 win.
    • Delaunay is bottlenecked by the regularization matrix H (ConstantSplit), which uses interpolator-derived sparse weights and runs on host CPU regardless of A100. F still collapses (51 ms → 0.5 ms, ~100×), but H stays at ~17 ms on both hardware classes — capping A100's headline win.
  • vmap is increasingly hostile as we move through the imaging family. MGE benefits 2.4×; rectangular regresses 0.8×; Delaunay regresses 0.25× on A100 — i.e. batch=3 vmap is 4× slower per call than single-JIT. NNLS being serial explains the rectangular regression, but that doesn't explain why Delaunay is 4× worse than rectangular on A100. Likely the Delaunay triangulation step (scipy on host) doesn't vmap usefully and adds overhead per batch element rather than amortising.

  • Mixed precision behaviour shifts further toward "GPU lever" with Delaunay. MGE: ~0% on either GPU. Rectangular: ~10% RTX 2060, ~0% A100. Delaunay: 24% RTX 2060 (biggest mp benefit yet on consumer GPU), ~0% A100. The pattern matches: more dense linalg on consumer GPU → more headroom for fp32 to help. But A100's fp64 throughput is so high that mp never moves the needle there.

  • Reconstruction NNLS cost scales sublinearly across likelihoods. RTX 2060 fp64: rectangular 60 ms, Delaunay 36 ms — Delaunay is faster despite similar source pixel counts (~706 vs 784). On A100: rectangular 6.8 ms, Delaunay 4.5 ms — same ordering. NNLS converges faster when the curvature matrix is better conditioned (Delaunay's edge-point + zero-pixel scheme tightens it).

Caveats

  • NNLS-vs-linear-solve discrepancy in canonical reference. The canonical jax_profiling/jit/imaging/delaunay.py uses jnp.linalg.solve(F+H, D) for its per-step "Regularized reconstruction" timing, which under-reports cost (~5 ms vs ~36 ms NNLS on RTX 2060). This per-config script uses NNLS to match production AnalysisImaging behaviour — the full-pipeline JIT path is unaffected (it always uses production NNLS). A separate one-line PR should switch the canonical's step 12 to reconstruction_positive_only_from.

  • Regularization matrix (H) is an eager wall-clock measurement, not a JIT per-call average. ConstantSplit's interpolator-derived sparse weights aren't easily JIT-traced, so H is extracted once from the reference inversion. The reported time can include cold-start/setup costs and shouldn't be summed naively into "total step-by-step." The full-pipeline JIT path inside analysis.log_likelihood_function handles H differently and the 17.3 ms full-pipeline number is the trustworthy per-call cost.

  • Delaunay-on-CPU stall. Both local_cpu_* configs hung in setup (sub-2% CPU for 18–24 min). Reproducible on this laptop. Not seen for MGE or rectangular pixelization. Suspect numba cache contention; needs isolated investigation. PR ships with 4 configs instead of 6 as a result.

  • A100 JIT log-evidence shows fp32-level truncation (29181.09 vs eager 29179.95). Same root cause as PR Add MGE imaging profiling: A100 + RTX 2060 + CPU sweep #56 + Add pixelization imaging profiling: A100 + RTX 2060 + CPU sweep #57: HPC PyAutoNSS venv lacks jax_enable_x64. Doesn't affect timing data here; assertion uses rtol=1e-4 which still passes.

  • Local sweep timings vary across sessions due to JAX cache + GPU thermal state. Cross-platform comparisons (A100 vs RTX 2060) are robust; cross-likelihood comparisons in the table above use values from each PR's own sweep, not a single-session run, so cross-session noise applies.

Generated by

  • z_projects/profiling/scripts/delaunay_profile.py — single-config 11-step JIT profiler (per-step timings + full pipeline + vmap + memory analysis). Argparse-driven, honours PYAUTO_ROOT for worktree-aware canonical writes.
  • z_projects/profiling/scripts/delaunay_aggregate.py--ingest-pre-fix /tmp (no-op unless artifacts present); --consolidate-from <staging> to move HPC pulls into this canonical dir; default to emit comparison.json + comparison.png.
  • z_projects/profiling/scripts/_setup_delaunay.py — shared build_dataset / build_image_plane_mesh_grid / build_model / build_adapt_images / build_analysis so the canonical reference's EXPECTED_LOG_EVIDENCE_HST = 29179.9490711974 constant carries through asserted on every run.
  • z_projects/profiling/hpc/batch_gpu/submit_delaunay_profile_{fp64,mp} — A100 SLURM submits.

Test plan

  • All 4 JSON files schema-valid (parsed cleanly by delaunay_aggregate.py)
  • comparison.json + comparison.png regenerated end-to-end
  • No untracked or modified files in jax_profiling/results/jit/imaging/ outside the new delaunay/ subdir
  • Existing legacy per-version flat summary files (delaunay_likelihood_summary_hst_v*.{json,png}, delaunay_sparse_cpu_*) untouched
  • Eager log_evidence regression assertion: 29179.949 matches EXPECTED_LOG_EVIDENCE_HST (rtol=1e-4) on all 4 configs
  • Full-pipeline JIT + vmap log_evidence assertions pass (rtol=1e-4 fp64, rtol=1e-2 mp)

Followups

  • One-line PR to switch jax_profiling/jit/imaging/delaunay.py step 12 from jnp.linalg.solve to al.util.inversion.reconstruction_positive_only_from so the canonical reference's per-step reconstruction timing reflects production cost.
  • Investigate Delaunay-on-CPU stall (likely numba cache contention with prior GPU-mode runs).

🤖 Generated with Claude Code

… 2060 sweep

Four configs side-by-side for the Delaunay imaging likelihood (Sersic
+ Isothermal + ExternalShear lens with a Delaunay source mesh of ~706
vertices + ConstantSplit regularization), covering the consumer RTX
2060 Max-Q + i9-10885H laptop and production A100, in both fp64 and
mixed-precision variants.

Local CPU configs (local_cpu_fp64 / local_cpu_mp) were attempted but
both runs hung indefinitely in the dataset / mask oversampling setup
(sub-2% CPU usage for 18-24 minutes, no progress past mask padding).
This is a Delaunay-on-CPU specific stall not present in the prior MGE
or rectangular pixelization sweeps; root cause likely in numba JIT
cache contention between the prior GPU-mode run and the forced-CPU
mode in the same shell. Skipped for this PR; will investigate as a
followup.

Generated by new tooling in z_projects/profiling/scripts/ (separate
local-only commit, no PR target).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…dness

Investigated the "Delaunay-on-CPU stall" flagged in the original PR
caveat. Root cause: my output filter pipeline (tee | tail -25, then
grep) block-buffered Python's print() calls, making a healthy 90-second
CPU fp64 run look hung at 1-2% CPU for 18+ minutes. With
PYTHONUNBUFFERED=1 and a tee target the script reads directly, the run
completes end-to-end in ~90 sec.

Adding the local_cpu_fp64 row to the canonical dir + updated
comparison.json + comparison.png. Now 5 configs side-by-side.

local_cpu_mp still hangs at full_pipeline_first_call after compile
(different failure mode from the buffering issue — main thread blocks
on futex_wait_queue_me, JAX worker threads also on futex). Likely a
real but separate issue specific to mixed-precision JAX on CPU. Left
as a followup investigation; ships 5 configs instead of 6.

Companion script fix on the z_projects/profiling side:
delaunay_profile.py now forces line-buffered stdout so future runs
flush per-section progress regardless of downstream pipe.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit 5a1c98e into main May 10, 2026
Jammy2211 added a commit that referenced this pull request May 10, 2026
…59)

Production AnalysisImaging uses NNLS (reconstruction_positive_only_from)
for the source reconstruction; the canonical step-by-step profiler
inadvertently used the cheaper jnp.linalg.solve, under-reporting the
per-step "Regularized reconstruction" cost by roughly an order of
magnitude (5 ms vs 47 ms on a consumer RTX 2060).

The downstream log-evidence value is unchanged within rtol=1e-4 — at
prior medians the well-conditioned ConstantSplit problem yields no
negative source pixels, so NNLS reduces to the linear solve.

Verified end-to-end against EXPECTED_LOG_EVIDENCE_HST = 29179.9490711974.

Followup to #58 (Delaunay profiling sweep), which already uses NNLS
in the per-config delaunay_profile.py and called this discrepancy out
in its caveats.

Co-authored-by: Jammy2211 <JNightingale2211@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant