diff --git a/jax_profiling/dataset/interferometer/hannah/data.fits b/jax_profiling/dataset/interferometer/hannah/data.fits new file mode 100644 index 0000000..1a1054a Binary files /dev/null and b/jax_profiling/dataset/interferometer/hannah/data.fits differ diff --git a/jax_profiling/dataset/interferometer/hannah/noise_map.fits b/jax_profiling/dataset/interferometer/hannah/noise_map.fits new file mode 100644 index 0000000..18a3313 Binary files /dev/null and b/jax_profiling/dataset/interferometer/hannah/noise_map.fits differ diff --git a/jax_profiling/dataset/interferometer/hannah/positions.json b/jax_profiling/dataset/interferometer/hannah/positions.json new file mode 100644 index 0000000..b00a365 --- /dev/null +++ b/jax_profiling/dataset/interferometer/hannah/positions.json @@ -0,0 +1,20 @@ +{ + "type": "instance", + "class_path": "autoarray.structures.grids.irregular_2d.Grid2DIrregular", + "arguments": { + "values": { + "type": "ndarray", + "array": [ + [ + 0.77734375, + 1.6655202101427289 + ], + [ + -1.3515625, + -0.5051814855409225 + ] + ], + "dtype": "float64" + } + } +} \ No newline at end of file diff --git a/jax_profiling/dataset/interferometer/hannah/tracer.json b/jax_profiling/dataset/interferometer/hannah/tracer.json new file mode 100644 index 0000000..dba4964 --- /dev/null +++ b/jax_profiling/dataset/interferometer/hannah/tracer.json @@ -0,0 +1,87 @@ +{ + "type": "instance", + "class_path": "autolens.lens.tracer.Tracer", + "arguments": { + "galaxies": { + "type": "list", + "values": [ + { + "type": "instance", + "class_path": "autogalaxy.galaxy.galaxy.Galaxy", + "arguments": { + "label": null, + "redshift": 0.5, + "mass": { + "type": "instance", + "class_path": "autogalaxy.profiles.mass.total.isothermal.Isothermal", + "arguments": { + "ell_comps": { + "type": "tuple", + "values": [ + 0.05263157894736841, + 3.2227547345982974e-18 + ] + }, + "centre": { + "type": "tuple", + "values": [ + 0.0, + 0.0 + ] + }, + "einstein_radius": 1.6 + } + }, + "shear": { + "type": "instance", + "class_path": "autogalaxy.profiles.mass.sheets.external_shear.ExternalShear", + "arguments": { + "gamma_1": 0.05, + "gamma_2": 0.05 + } + } + } + }, + { + "type": "instance", + "class_path": "autogalaxy.galaxy.galaxy.Galaxy", + "arguments": { + "label": null, + "redshift": 1.0, + "bulge": { + "type": "instance", + "class_path": "autogalaxy.profiles.light.standard.sersic_core.SersicCore", + "arguments": { + "sersic_index": 2.5, + "radius_break": 0.025, + "effective_radius": 1.0, + "ell_comps": { + "type": "tuple", + "values": [ + 0.0962250448649376, + -0.05555555555555551 + ] + }, + "intensity": 0.3, + "gamma": 0.25, + "centre": { + "type": "tuple", + "values": [ + 0.1, + 0.1 + ] + }, + "alpha": 3.0 + } + } + } + } + ] + }, + "cosmology": { + "type": "instance", + "class_path": "autogalaxy.cosmology.model.Planck15", + "arguments": {} + } + } +} \ No newline at end of file diff --git a/jax_profiling/dataset/interferometer/hannah/uv_wavelengths.fits b/jax_profiling/dataset/interferometer/hannah/uv_wavelengths.fits new file mode 100644 index 0000000..7fdf935 Binary files /dev/null and b/jax_profiling/dataset/interferometer/hannah/uv_wavelengths.fits differ diff --git a/jax_profiling/dataset_setup/interferometer.py b/jax_profiling/dataset_setup/interferometer.py index 0575f09..fa1942f 100644 --- a/jax_profiling/dataset_setup/interferometer.py +++ b/jax_profiling/dataset_setup/interferometer.py @@ -48,6 +48,7 @@ "shape_native": (256, 256), "noise_sigma": 1000.0, "seed": 1, + "transformer_class": "dft", }, "alma": { "n_visibilities": 1000, @@ -56,6 +57,16 @@ "shape_native": (256, 256), "noise_sigma": 100.0, "seed": 1, + "transformer_class": "dft", + }, + "hannah": { + "n_visibilities": 16984, + "uv_scale": 2.0e6, + "pixel_scale": 0.125, + "shape_native": (40, 40), + "noise_sigma": 100.0, + "seed": 1, + "transformer_class": "nufft", }, } @@ -102,11 +113,17 @@ def simulate(instrument: str): print(f" Total visibilities: {uv_wavelengths.shape[0]}") + transformer_choice = config.get("transformer_class", "dft").lower() + transformer_class = { + "dft": al.TransformerDFT, + "nufft": al.TransformerNUFFT, + }[transformer_choice] + simulator = al.SimulatorInterferometer( uv_wavelengths=uv_wavelengths, exposure_time=300.0, noise_sigma=config["noise_sigma"], - transformer_class=al.TransformerDFT, + transformer_class=transformer_class, noise_seed=1, ) diff --git a/jax_profiling/jit/datacube/delaunay.py b/jax_profiling/jit/datacube/delaunay.py index 15735da..3d7918a 100644 --- a/jax_profiling/jit/datacube/delaunay.py +++ b/jax_profiling/jit/datacube/delaunay.py @@ -90,6 +90,7 @@ import numpy as np import jax import jax.numpy as jnp +import os import time import subprocess import sys @@ -106,13 +107,16 @@ # --------------------------------------------------------------------------- INSTRUMENTS = { - "sma": {"pixel_scale": 0.1, "real_space_shape": (256, 256)}, - "alma": {"pixel_scale": 0.05, "real_space_shape": (256, 256)}, + "sma": {"pixel_scale": 0.1, "real_space_shape": (256, 256), "mask_radius": 3.0}, + "alma": {"pixel_scale": 0.05, "real_space_shape": (256, 256), "mask_radius": 3.0}, + "hannah": {"pixel_scale": 0.125, "real_space_shape": (40, 40), "mask_radius": 2.3}, } -instrument = "sma" # <-- change this to profile a different instrument +instrument = "hannah" # <-- realistic ALMA settings for Hannah's science case -n_channels = 4 +# n_channels = 34 matches Hannah's real ALMA cube. For quick iteration on the +# smaller SMA dataset, drop this to 4 (also flip ``instrument`` back to "sma"). +n_channels = 34 overlay_shape = (26, 26) edge_n_points = 30 regularization_coefficient = 1.0 @@ -198,7 +202,7 @@ def jit_profile(func, label, *args, n_repeats=10): check=True, ) -mask_radius = 3.0 +mask_radius = INSTRUMENTS[instrument]["mask_radius"] real_space_mask = al.Mask2D.circular( shape_native=real_space_shape, @@ -214,6 +218,10 @@ def jit_profile(func, label, *args, n_repeats=10): uv_wavelengths_path=dataset_path / "uv_wavelengths.fits", real_space_mask=real_space_mask, transformer_class=al.TransformerDFT, + # DFT is intentional even at ALMA-scale visibility counts — profiling + # the JAX-traceable path is the goal, NUFFT (pynufft) is not yet + # JIT-friendly. + raise_error_dft_visibilities_limit=False, ) for _ in range(n_channels) ] @@ -759,39 +767,59 @@ def compute_log_evidence( print("FULL-PIPELINE CUBE JIT (for comparison)") print("=" * 70) -analysis_list = [ - al.AnalysisInterferometer(dataset=d, adapt_images=adapt_images, use_jax=True) - for d in dataset_list -] - - -def full_cube_pipeline_from_params(params_tree): - """Cube log-evidence via the explicit per-channel sum. - - Same shape as the user-facing ``datacube/likelihood_function.py``: - feeds the shared instance to every per-channel - ``AnalysisInterferometer.log_likelihood_function`` and sums. - """ - total = jnp.zeros(()) - for analysis in analysis_list: - total = total + analysis.log_likelihood_function(instance=params_tree) - return total - +# Part C is expensive at large n_channels: lower + compile build a graph +# proportional to n_channels (e.g. ~70s for n_channels=34 on a laptop CPU), +# and the steady-state first-call follows. Default to skipping; opt in with +# CUBE_FULL_JIT=1 when the full-pipeline timing matters (e.g. comparing +# step-by-step total against single-JIT). +_run_full_cube_jit = os.environ.get("CUBE_FULL_JIT") == "1" + +if _run_full_cube_jit: + analysis_list = [ + al.AnalysisInterferometer(dataset=d, adapt_images=adapt_images, use_jax=True) + for d in dataset_list + ] -_, full_cube_result = jit_profile( - full_cube_pipeline_from_params, "full_cube_pipeline", params_tree -) -full_pipeline_per_call = timer.records[-1][1] / 10 + def full_cube_pipeline_from_params(params_tree): + """Cube log-evidence via the explicit per-channel sum. + + Same shape as the user-facing ``datacube/likelihood_function.py``: + feeds the shared instance to every per-channel + ``AnalysisInterferometer.log_likelihood_function`` and sums. + """ + total = jnp.zeros(()) + for analysis in analysis_list: + total = total + analysis.log_likelihood_function(instance=params_tree) + return total + + _full_cube_n_repeats = 3 + _, full_cube_result = jit_profile( + full_cube_pipeline_from_params, + "full_cube_pipeline", + params_tree, + n_repeats=_full_cube_n_repeats, + ) + full_pipeline_per_call = timer.records[-1][1] / _full_cube_n_repeats -print(f" full cube log_evidence (JIT) = {full_cube_result}") + print(f" full cube log_evidence (JIT) = {full_cube_result}") -np.testing.assert_allclose( - float(full_cube_result), - cube_log_evidence_ref, - rtol=1e-4, - err_msg="Full-pipeline cube JIT log_evidence does not match summed eager FitInterferometer.log_evidence", -) -print(" Eager-vs-JIT cube correctness PASSED") + np.testing.assert_allclose( + float(full_cube_result), + cube_log_evidence_ref, + rtol=1e-4, + err_msg="Full-pipeline cube JIT log_evidence does not match summed eager FitInterferometer.log_evidence", + ) + print(" Eager-vs-JIT cube correctness PASSED") +else: + full_cube_result = None + full_pipeline_per_call = float("nan") + print( + " Full-pipeline cube JIT SKIPPED — opt-in via CUBE_FULL_JIT=1. " + f"At n_channels={n_channels} the lower + compile alone is on the order of " + f"{n_channels * 2}-{n_channels * 3}s, so it's gated to keep the default " + "runtime usable; the per-step Part B JIT data above is what feeds the " + "shared-Lᵀ W̃ L analysis." + ) # =================================================================== # PART D — vmap (skipped for cube) @@ -832,7 +860,10 @@ def full_cube_pipeline_from_params(params_tree): print(f" Edge zeroed pixels: {edge_pixels_total}") print("-" * 70) print(f" Cube reference log_evidence: {cube_log_evidence_ref}") -print(f" Cube JIT log_evidence: {float(full_cube_result)}") +if full_cube_result is not None: + print(f" Cube JIT log_evidence: {float(full_cube_result)}") +else: + print(f" Cube JIT log_evidence: SKIPPED (CUBE_FULL_JIT=1 to enable)") print("-" * 70) max_label = max(len(label) for label, _ in likelihood_steps) @@ -848,7 +879,10 @@ def full_cube_pipeline_from_params(params_tree): print("-" * 70) print(f" {'TOTAL (step-by-step cube cost)':<{max_label}} {step_total:>12.6f} s") -print(f" {'Full pipeline cube (single JIT)':<{max_label}} {full_pipeline_per_call:>12.6f} s") +if np.isfinite(full_pipeline_per_call): + print(f" {'Full pipeline cube (single JIT)':<{max_label}} {full_pipeline_per_call:>12.6f} s") +else: + print(f" {'Full pipeline cube (single JIT)':<{max_label}} SKIPPED") print(f" {f'Shared-Lᵀ W̃ L savings (curvature only, est.)':<{max_label}} {shared_lwl_savings:>12.6f} s") print("=" * 70) @@ -871,7 +905,9 @@ def full_cube_pipeline_from_params(params_tree): "regularization_coefficient": regularization_coefficient, }, "cube_log_evidence_eager": cube_log_evidence_ref, - "cube_log_evidence_jit": float(full_cube_result), + "cube_log_evidence_jit": ( + float(full_cube_result) if full_cube_result is not None else None + ), "log_evidence_per_channel_eager": [float(le) for le in log_evidence_per_channel], "steps_cube_cost": {label: per_call for label, per_call in likelihood_steps}, "per_channel_costs": { @@ -914,13 +950,14 @@ def full_cube_pipeline_from_params(params_tree): fontsize=9, ) -ax.axvline( - full_pipeline_per_call, - color="#C44E52", - linestyle="--", - linewidth=1.5, - label=f"Full pipeline cube (single JIT): {full_pipeline_per_call:.6f} s", -) +if np.isfinite(full_pipeline_per_call): + ax.axvline( + full_pipeline_per_call, + color="#C44E52", + linestyle="--", + linewidth=1.5, + label=f"Full pipeline cube (single JIT): {full_pipeline_per_call:.6f} s", + ) ax.axvline( shared_lwl_savings, color="#8172B2", @@ -959,26 +996,47 @@ def full_cube_pipeline_from_params(params_tree): # Regression assertion — deterministic cube log-evidence # =================================================================== # -# Identical channels = exact N × single-channel log-evidence. -EXPECTED_LOG_EVIDENCE_CUBE_SMA = n_channels * -3167.5258928840763 +# Identical channels = exact N × single-channel log-evidence (for "sma"). +# For "hannah" the per-channel literal isn't pinned yet, so the assertion is +# skipped until the value below is filled in from a clean run. +EXPECTED_LOG_EVIDENCE_PER_CHANNEL = { + "sma": -3167.5258928840763, + "alma": None, + "hannah": -204838.07924622478, +} -np.testing.assert_allclose( - cube_log_evidence_ref, - EXPECTED_LOG_EVIDENCE_CUBE_SMA, - rtol=1e-4, - err_msg=( - f"datacube/delaunay[{instrument}]: regression — eager cube log_evidence " - f"drifted (got {cube_log_evidence_ref}, expected {EXPECTED_LOG_EVIDENCE_CUBE_SMA})" - ), -) -print( - f"\n Eager cube regression assertion PASSED: log_evidence matches " - f"{EXPECTED_LOG_EVIDENCE_CUBE_SMA:.6f}" +_per_channel = EXPECTED_LOG_EVIDENCE_PER_CHANNEL.get(instrument) +expected_cube_log_evidence = ( + n_channels * _per_channel if _per_channel is not None else None ) -np.testing.assert_allclose( - float(full_cube_result), - EXPECTED_LOG_EVIDENCE_CUBE_SMA, - rtol=1e-4, - err_msg=f"datacube/delaunay[{instrument}]: regression — full cube log_evidence drifted", -) -print(f" Full-pipeline cube regression assertion PASSED") + +if expected_cube_log_evidence is None: + print( + f"\n Cube regression assertion SKIPPED for [{instrument}] — " + f"capture this run's eager cube log_evidence ({cube_log_evidence_ref}), " + f"divide by n_channels ({n_channels}) to get the per-channel value " + f"({cube_log_evidence_ref / n_channels}), and paste that into " + f"EXPECTED_LOG_EVIDENCE_PER_CHANNEL[{instrument!r}]." + ) +else: + np.testing.assert_allclose( + cube_log_evidence_ref, + expected_cube_log_evidence, + rtol=1e-4, + err_msg=( + f"datacube/delaunay[{instrument}]: regression — eager cube log_evidence " + f"drifted (got {cube_log_evidence_ref}, expected {expected_cube_log_evidence})" + ), + ) + print( + f"\n Eager cube regression assertion PASSED: log_evidence matches " + f"{expected_cube_log_evidence:.6f}" + ) + if full_cube_result is not None: + np.testing.assert_allclose( + float(full_cube_result), + expected_cube_log_evidence, + rtol=1e-4, + err_msg=f"datacube/delaunay[{instrument}]: regression — full cube log_evidence drifted", + ) + print(f" Full-pipeline cube regression assertion PASSED") diff --git a/jax_profiling/jit/interferometer/delaunay.py b/jax_profiling/jit/interferometer/delaunay.py index df2a046..34c1d9b 100644 --- a/jax_profiling/jit/interferometer/delaunay.py +++ b/jax_profiling/jit/interferometer/delaunay.py @@ -98,8 +98,9 @@ # --------------------------------------------------------------------------- INSTRUMENTS = { - "sma": {"pixel_scale": 0.1, "real_space_shape": (256, 256)}, - "alma": {"pixel_scale": 0.05, "real_space_shape": (256, 256)}, + "sma": {"pixel_scale": 0.1, "real_space_shape": (256, 256), "mask_radius": 3.0}, + "alma": {"pixel_scale": 0.05, "real_space_shape": (256, 256), "mask_radius": 3.0}, + "hannah": {"pixel_scale": 0.125, "real_space_shape": (40, 40), "mask_radius": 2.3}, } instrument = "sma" # <-- change this to profile a different instrument @@ -190,7 +191,7 @@ def jit_profile(func, label, *args, n_repeats=10): check=True, ) -mask_radius = 3.0 +mask_radius = INSTRUMENTS[instrument]["mask_radius"] real_space_mask = al.Mask2D.circular( shape_native=real_space_shape, @@ -205,6 +206,10 @@ def jit_profile(func, label, *args, n_repeats=10): uv_wavelengths_path=dataset_path / "uv_wavelengths.fits", real_space_mask=real_space_mask, transformer_class=al.TransformerDFT, + # DFT is intentional even at ALMA-scale visibility counts — profiling + # the JAX-traceable path is the goal, NUFFT (pynufft) is not yet + # JIT-friendly. + raise_error_dft_visibilities_limit=False, ) n_visibilities = dataset.uv_wavelengths.shape[0] @@ -1108,31 +1113,39 @@ def full_pipeline_from_params(params_tree): # # Simulator truth parameters via GaussianPrior(mean=truth, sigma=small) # make the full-pipeline log-evidence deterministic at the prior median. -EXPECTED_LOG_EVIDENCE_SMA = -3167.5258928840763 +# Pinned empirically per instrument; ``None`` means "skip the assertion and +# print the value so it can be pasted in here on a clean run". +EXPECTED_LOG_EVIDENCE = { + "sma": -3167.5258928840763, + "alma": None, + "hannah": -204838.07924622478, +} + +expected_log_evidence = EXPECTED_LOG_EVIDENCE.get(instrument) -if EXPECTED_LOG_EVIDENCE_SMA is None: +if expected_log_evidence is None: print( - f"\n Regression assertion SKIPPED — " + f"\n Regression assertion SKIPPED for [{instrument}] — " f"capture this run's eager log_evidence ({figure_of_merit_ref}) " - f"and set EXPECTED_LOG_EVIDENCE_SMA in this script." + f"and paste it into EXPECTED_LOG_EVIDENCE[{instrument!r}]." ) else: np.testing.assert_allclose( figure_of_merit_ref, - EXPECTED_LOG_EVIDENCE_SMA, + expected_log_evidence, rtol=1e-4, err_msg=( f"interferometer/delaunay[{instrument}]: regression — eager log_evidence " - f"drifted (got {figure_of_merit_ref}, expected {EXPECTED_LOG_EVIDENCE_SMA})" + f"drifted (got {figure_of_merit_ref}, expected {expected_log_evidence})" ), ) print( f" Eager regression assertion PASSED: log_evidence matches " - f"{EXPECTED_LOG_EVIDENCE_SMA:.6f}" + f"{expected_log_evidence:.6f}" ) np.testing.assert_allclose( float(full_result), - EXPECTED_LOG_EVIDENCE_SMA, + expected_log_evidence, rtol=1e-4, err_msg=f"interferometer/delaunay[{instrument}]: regression — full log_evidence drifted", ) @@ -1140,7 +1153,7 @@ def full_pipeline_from_params(params_tree): if result_vmap is not None: np.testing.assert_allclose( np.array(result_vmap), - EXPECTED_LOG_EVIDENCE_SMA, + expected_log_evidence, rtol=1e-4, err_msg=f"interferometer/delaunay[{instrument}]: regression — vmap log_evidence drifted", ) diff --git a/jax_profiling/results/jit/datacube/delaunay_likelihood_summary_hannah_v2026.5.14.2.json b/jax_profiling/results/jit/datacube/delaunay_likelihood_summary_hannah_v2026.5.14.2.json new file mode 100644 index 0000000..8f06028 --- /dev/null +++ b/jax_profiling/results/jit/datacube/delaunay_likelihood_summary_hannah_v2026.5.14.2.json @@ -0,0 +1,82 @@ +{ + "autolens_version": "2026.5.14.2", + "instrument": "hannah", + "model": "delaunay", + "n_channels": 34, + "configuration": { + "pixel_scale_arcsec": 0.125, + "mask_radius_arcsec": 2.3, + "real_space_shape": [ + 40, + 40 + ], + "visibilities_per_channel": 16984, + "overlay_shape": [ + 26, + 26 + ], + "edge_n_points": 30, + "delaunay_vertices": 578, + "edge_zeroed_pixels": 30, + "regularization_coefficient": 1.0 + }, + "cube_log_evidence_eager": -6964494.6943716435, + "cube_log_evidence_jit": null, + "log_evidence_per_channel_eager": [ + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478, + -204838.07924622478 + ], + "steps_cube_cost": { + "Ray-trace data grid (shared)": 0.0004436899998836452, + "Ray-trace mesh grid (shared)": 0.0004176900001766626, + "Inversion setup, incl. NUFFT (per channel \u00d7 34)": 163.414330899989, + "Data vector D (per channel \u00d7 34)": 4.0696905199954925, + "Curvature matrix F (per channel \u00d7 34)": 35.87535703999529, + "Regularization matrix H (shared)": 0.015476199998374796, + "Reconstruction NNLS (per channel \u00d7 34)": 1.985523500007548, + "Mapped recon + log evidence (per channel \u00d7 34)": 0.5584676800041053 + }, + "per_channel_costs": { + "inversion_setup": 4.806303849999677, + "data_vector": 0.11969677999986743, + "curvature_matrix": 1.0551575599998615, + "reconstruction": 0.058397750000222, + "log_evidence": 0.016425520000120743 + }, + "total_step_by_step_cube": 205.9197072199899, + "full_pipeline_cube_single_jit": NaN, + "shared_lwl_savings_estimate": 34.82019947999543, + "vmap": "SKIPPED \u2014 cube batching axis is 'datasets', not 'parameters'" +} \ No newline at end of file diff --git a/jax_profiling/results/jit/datacube/delaunay_likelihood_summary_hannah_v2026.5.14.2.png b/jax_profiling/results/jit/datacube/delaunay_likelihood_summary_hannah_v2026.5.14.2.png new file mode 100644 index 0000000..96ad5b3 Binary files /dev/null and b/jax_profiling/results/jit/datacube/delaunay_likelihood_summary_hannah_v2026.5.14.2.png differ diff --git a/jax_profiling/results/jit/interferometer/delaunay_likelihood_summary_hannah_v2026.5.14.2.json b/jax_profiling/results/jit/interferometer/delaunay_likelihood_summary_hannah_v2026.5.14.2.json new file mode 100644 index 0000000..160bd7f --- /dev/null +++ b/jax_profiling/results/jit/interferometer/delaunay_likelihood_summary_hannah_v2026.5.14.2.json @@ -0,0 +1,39 @@ +{ + "autolens_version": "2026.5.14.2", + "instrument": "hannah", + "model": "delaunay", + "configuration": { + "pixel_scale_arcsec": 0.125, + "mask_radius_arcsec": 2.3, + "real_space_shape": [ + 40, + 40 + ], + "visibilities": 16984, + "overlay_shape": [ + 26, + 26 + ], + "edge_n_points": 30, + "delaunay_vertices": 578, + "edge_zeroed_pixels": 30, + "regularization_coefficient": 1.0 + }, + "log_likelihood_eager": -204561.286468125, + "figure_of_merit_eager": -204838.07924622478, + "log_evidence_jit": -204838.07924622117, + "steps": { + "Ray-trace data grid": 0.0003818699999101227, + "Ray-trace mesh grid": 0.00034264000023540573, + "Inversion setup (steps 5-8 combined, incl. NUFFT)": 4.5931885300000435, + "Data vector (D)": 0.11071577000002435, + "Curvature matrix (F)": 1.0810229799997615, + "Regularization matrix (H)": 0.018836100000044098, + "Regularized reconstruction": 0.07121871999988798, + "Mapped recon + log evidence": 0.021166530000118654 + }, + "total_step_by_step": 5.896873140000025, + "full_pipeline_single_jit": 5.6682252599999625, + "vmap": "SKIPPED \u2014 opt-in via DELAUNAY_VMAP=1", + "memory_mb": null +} \ No newline at end of file diff --git a/jax_profiling/results/jit/interferometer/delaunay_likelihood_summary_hannah_v2026.5.14.2.png b/jax_profiling/results/jit/interferometer/delaunay_likelihood_summary_hannah_v2026.5.14.2.png new file mode 100644 index 0000000..207ce84 Binary files /dev/null and b/jax_profiling/results/jit/interferometer/delaunay_likelihood_summary_hannah_v2026.5.14.2.png differ