From 0c0172963ee9fd1b40e6c39f340639d5511e8282 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 16 May 2026 10:16:59 +0100 Subject: [PATCH] perf(jax-profiling): enable apply_sparse_operator on Delaunay profilers Wire the now-unblocked apply_sparse_operator(use_jax=True) path into the two Delaunay JAX profilers so per-fit curvature assembly uses the FFT-based sparse precision-matrix preload instead of dense DFT per source pixel. Unblocked by PyAutoLabs/PyAutoArray#316 (Pmax > 1 extent-indexing fix); the path was previously guarded with NotImplementedError on Delaunay by PR #315. - jax_profiling/jit/interferometer/delaunay.py: new "apply_sparse_operator" timer section right after dataset_load. - jax_profiling/jit/datacube/delaunay.py: chain .apply_sparse_operator(use_jax=True, show_progress=False) onto the per-channel from_fits in the dataset_list comprehension. Validation: - jit/interferometer/delaunay.py: end-to-end run, eager log_likelihood = -3152.876, figure_of_merit = -3167.526, all 11 per-step JIT stages completed. - jit/datacube/delaunay.py: end-to-end run on the hannah 34-channel cube, cube log_evidence = -6964494.694 (regression assertion PASSED), all 8 per-step JIT stages completed. Co-Authored-By: Claude Opus 4.7 (1M context) --- jax_profiling/jit/datacube/delaunay.py | 7 ++++++- jax_profiling/jit/interferometer/delaunay.py | 7 +++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/jax_profiling/jit/datacube/delaunay.py b/jax_profiling/jit/datacube/delaunay.py index 3d7918a..01ff97d 100644 --- a/jax_profiling/jit/datacube/delaunay.py +++ b/jax_profiling/jit/datacube/delaunay.py @@ -211,6 +211,11 @@ def jit_profile(func, label, *args, n_repeats=10): ) with timer.section("dataset_list_load"): + # apply_sparse_operator: precompute the NUFFT precision-matrix preload per + # channel so per-fit curvature assembly uses the FFT-based sparse path + # instead of dense DFT for every source pixel. Unblocked by PyAutoArray#316 + # (the Pmax > 1 extent-indexing fix); on Delaunay this was previously + # guarded with NotImplementedError. dataset_list = [ al.Interferometer.from_fits( data_path=dataset_path / "data.fits", @@ -222,7 +227,7 @@ def jit_profile(func, label, *args, n_repeats=10): # the JAX-traceable path is the goal, NUFFT (pynufft) is not yet # JIT-friendly. raise_error_dft_visibilities_limit=False, - ) + ).apply_sparse_operator(use_jax=True, show_progress=False) for _ in range(n_channels) ] diff --git a/jax_profiling/jit/interferometer/delaunay.py b/jax_profiling/jit/interferometer/delaunay.py index 34c1d9b..4fec68c 100644 --- a/jax_profiling/jit/interferometer/delaunay.py +++ b/jax_profiling/jit/interferometer/delaunay.py @@ -212,6 +212,13 @@ def jit_profile(func, label, *args, n_repeats=10): raise_error_dft_visibilities_limit=False, ) +with timer.section("apply_sparse_operator"): + # Precompute the NUFFT precision-matrix preload so per-fit curvature + # assembly uses the FFT-based sparse path instead of dense DFT for every + # source pixel. Unblocked by PyAutoArray#316 (the Pmax > 1 extent-indexing + # fix); on Delaunay this was previously guarded with NotImplementedError. + dataset = dataset.apply_sparse_operator(use_jax=True, show_progress=True) + n_visibilities = dataset.uv_wavelengths.shape[0] print(f" Total visibilities: {n_visibilities}")