diff --git a/dataset/imaging/hst/data.fits b/dataset/imaging/hst/data.fits index f0f9c28..056d1e5 100644 Binary files a/dataset/imaging/hst/data.fits and b/dataset/imaging/hst/data.fits differ diff --git a/dataset/imaging/hst/noise_map.fits b/dataset/imaging/hst/noise_map.fits index fee51bf..dd44044 100644 Binary files a/dataset/imaging/hst/noise_map.fits and b/dataset/imaging/hst/noise_map.fits differ diff --git a/dataset/imaging/hst/tracer.json b/dataset/imaging/hst/tracer.json index 60f4998..1e1554e 100644 --- a/dataset/imaging/hst/tracer.json +++ b/dataset/imaging/hst/tracer.json @@ -9,8 +9,8 @@ "type": "instance", "class_path": "autogalaxy.galaxy.galaxy.Galaxy", "arguments": { - "redshift": 0.5, "label": null, + "redshift": 0.5, "bulge": { "type": "instance", "class_path": "autogalaxy.profiles.light.standard.sersic.Sersic", @@ -23,6 +23,7 @@ 3.2227547345982974e-18 ] }, + "effective_radius": 0.6, "centre": { "type": "tuple", "values": [ @@ -30,8 +31,7 @@ 0.0 ] }, - "sersic_index": 3.0, - "effective_radius": 0.6 + "sersic_index": 3.0 } }, "mass": { @@ -45,22 +45,22 @@ 3.2227547345982974e-18 ] }, + "einstein_radius": 1.6, "centre": { "type": "tuple", "values": [ 0.0, 0.0 ] - }, - "einstein_radius": 1.6 + } } }, "shear": { "type": "instance", "class_path": "autogalaxy.profiles.mass.sheets.external_shear.ExternalShear", "arguments": { - "gamma_2": 0.05, - "gamma_1": 0.05 + "gamma_1": 0.05, + "gamma_2": 0.05 } } } @@ -69,14 +69,14 @@ "type": "instance", "class_path": "autogalaxy.galaxy.galaxy.Galaxy", "arguments": { - "redshift": 1.0, "label": null, + "redshift": 1.0, "bulge": { "type": "instance", "class_path": "autogalaxy.profiles.light.standard.sersic_core.SersicCore", "arguments": { "intensity": 4.0, - "gamma": 0.25, + "alpha": 3.0, "ell_comps": { "type": "tuple", "values": [ @@ -84,6 +84,9 @@ -0.05555555555555551 ] }, + "gamma": 0.25, + "radius_break": 0.025, + "effective_radius": 0.1, "centre": { "type": "tuple", "values": [ @@ -91,10 +94,7 @@ 0.0 ] }, - "sersic_index": 1.0, - "effective_radius": 0.1, - "alpha": 3.0, - "radius_break": 0.025 + "sersic_index": 1.0 } } } diff --git a/dataset/interferometer/hannah/data.fits b/dataset/interferometer/hannah/data.fits new file mode 100644 index 0000000..1a1054a Binary files /dev/null and b/dataset/interferometer/hannah/data.fits differ diff --git a/dataset/interferometer/hannah/noise_map.fits b/dataset/interferometer/hannah/noise_map.fits new file mode 100644 index 0000000..18a3313 Binary files /dev/null and b/dataset/interferometer/hannah/noise_map.fits differ diff --git a/dataset/interferometer/hannah/positions.json b/dataset/interferometer/hannah/positions.json new file mode 100644 index 0000000..b00a365 --- /dev/null +++ b/dataset/interferometer/hannah/positions.json @@ -0,0 +1,20 @@ +{ + "type": "instance", + "class_path": "autoarray.structures.grids.irregular_2d.Grid2DIrregular", + "arguments": { + "values": { + "type": "ndarray", + "array": [ + [ + 0.77734375, + 1.6655202101427289 + ], + [ + -1.3515625, + -0.5051814855409225 + ] + ], + "dtype": "float64" + } + } +} \ No newline at end of file diff --git a/dataset/interferometer/hannah/tracer.json b/dataset/interferometer/hannah/tracer.json new file mode 100644 index 0000000..dba4964 --- /dev/null +++ b/dataset/interferometer/hannah/tracer.json @@ -0,0 +1,87 @@ +{ + "type": "instance", + "class_path": "autolens.lens.tracer.Tracer", + "arguments": { + "galaxies": { + "type": "list", + "values": [ + { + "type": "instance", + "class_path": "autogalaxy.galaxy.galaxy.Galaxy", + "arguments": { + "label": null, + "redshift": 0.5, + "mass": { + "type": "instance", + "class_path": "autogalaxy.profiles.mass.total.isothermal.Isothermal", + "arguments": { + "ell_comps": { + "type": "tuple", + "values": [ + 0.05263157894736841, + 3.2227547345982974e-18 + ] + }, + "centre": { + "type": "tuple", + "values": [ + 0.0, + 0.0 + ] + }, + "einstein_radius": 1.6 + } + }, + "shear": { + "type": "instance", + "class_path": "autogalaxy.profiles.mass.sheets.external_shear.ExternalShear", + "arguments": { + "gamma_1": 0.05, + "gamma_2": 0.05 + } + } + } + }, + { + "type": "instance", + "class_path": "autogalaxy.galaxy.galaxy.Galaxy", + "arguments": { + "label": null, + "redshift": 1.0, + "bulge": { + "type": "instance", + "class_path": "autogalaxy.profiles.light.standard.sersic_core.SersicCore", + "arguments": { + "sersic_index": 2.5, + "radius_break": 0.025, + "effective_radius": 1.0, + "ell_comps": { + "type": "tuple", + "values": [ + 0.0962250448649376, + -0.05555555555555551 + ] + }, + "intensity": 0.3, + "gamma": 0.25, + "centre": { + "type": "tuple", + "values": [ + 0.1, + 0.1 + ] + }, + "alpha": 3.0 + } + } + } + } + ] + }, + "cosmology": { + "type": "instance", + "class_path": "autogalaxy.cosmology.model.Planck15", + "arguments": {} + } + } +} \ No newline at end of file diff --git a/dataset/interferometer/hannah/uv_wavelengths.fits b/dataset/interferometer/hannah/uv_wavelengths.fits new file mode 100644 index 0000000..7fdf935 Binary files /dev/null and b/dataset/interferometer/hannah/uv_wavelengths.fits differ diff --git a/likelihood/datacube/delaunay.py b/likelihood/datacube/delaunay.py index 67c3a81..7051d1c 100644 --- a/likelihood/datacube/delaunay.py +++ b/likelihood/datacube/delaunay.py @@ -90,6 +90,7 @@ import numpy as np import jax import jax.numpy as jnp +import os import time import subprocess import sys @@ -106,13 +107,16 @@ # --------------------------------------------------------------------------- INSTRUMENTS = { - "sma": {"pixel_scale": 0.1, "real_space_shape": (256, 256)}, - "alma": {"pixel_scale": 0.05, "real_space_shape": (256, 256)}, + "sma": {"pixel_scale": 0.1, "real_space_shape": (256, 256), "mask_radius": 3.0}, + "alma": {"pixel_scale": 0.05, "real_space_shape": (256, 256), "mask_radius": 3.0}, + "hannah": {"pixel_scale": 0.125, "real_space_shape": (40, 40), "mask_radius": 2.3}, } -instrument = "sma" # <-- change this to profile a different instrument +instrument = "hannah" # <-- realistic ALMA settings for Hannah's science case -n_channels = 4 +# n_channels = 34 matches Hannah's real ALMA cube. For quick iteration on the +# smaller SMA dataset, drop this to 4 (also flip ``instrument`` back to "sma"). +n_channels = 34 overlay_shape = (26, 26) edge_n_points = 30 regularization_coefficient = 1.0 @@ -195,7 +199,7 @@ def jit_profile(func, label, *args, n_repeats=10): f"then copy the result into autolens_profiling/dataset/." ) -mask_radius = 3.0 +mask_radius = INSTRUMENTS[instrument]["mask_radius"] real_space_mask = al.Mask2D.circular( shape_native=real_space_shape, @@ -204,6 +208,11 @@ def jit_profile(func, label, *args, n_repeats=10): ) with timer.section("dataset_list_load"): + # apply_sparse_operator: precompute the NUFFT precision-matrix preload per + # channel so per-fit curvature assembly uses the FFT-based sparse path + # instead of dense DFT for every source pixel. Unblocked by PyAutoArray#316 + # (the Pmax > 1 extent-indexing fix); on Delaunay this was previously + # guarded with NotImplementedError. dataset_list = [ al.Interferometer.from_fits( data_path=dataset_path / "data.fits", @@ -211,7 +220,11 @@ def jit_profile(func, label, *args, n_repeats=10): uv_wavelengths_path=dataset_path / "uv_wavelengths.fits", real_space_mask=real_space_mask, transformer_class=al.TransformerDFT, - ) + # DFT is intentional even at ALMA-scale visibility counts — profiling + # the JAX-traceable path is the goal, NUFFT (pynufft) is not yet + # JIT-friendly. + raise_error_dft_visibilities_limit=False, + ).apply_sparse_operator(use_jax=True, show_progress=False) for _ in range(n_channels) ] @@ -756,39 +769,59 @@ def compute_log_evidence( print("FULL-PIPELINE CUBE JIT (for comparison)") print("=" * 70) -analysis_list = [ - al.AnalysisInterferometer(dataset=d, adapt_images=adapt_images, use_jax=True) - for d in dataset_list -] - - -def full_cube_pipeline_from_params(params_tree): - """Cube log-evidence via the explicit per-channel sum. - - Same shape as the user-facing ``datacube/likelihood_function.py``: - feeds the shared instance to every per-channel - ``AnalysisInterferometer.log_likelihood_function`` and sums. - """ - total = jnp.zeros(()) - for analysis in analysis_list: - total = total + analysis.log_likelihood_function(instance=params_tree) - return total - +# Part C is expensive at large n_channels: lower + compile build a graph +# proportional to n_channels (e.g. ~70s for n_channels=34 on a laptop CPU), +# and the steady-state first-call follows. Default to skipping; opt in with +# CUBE_FULL_JIT=1 when the full-pipeline timing matters (e.g. comparing +# step-by-step total against single-JIT). +_run_full_cube_jit = os.environ.get("CUBE_FULL_JIT") == "1" + +if _run_full_cube_jit: + analysis_list = [ + al.AnalysisInterferometer(dataset=d, adapt_images=adapt_images, use_jax=True) + for d in dataset_list + ] -_, full_cube_result = jit_profile( - full_cube_pipeline_from_params, "full_cube_pipeline", params_tree -) -full_pipeline_per_call = timer.records[-1][1] / 10 + def full_cube_pipeline_from_params(params_tree): + """Cube log-evidence via the explicit per-channel sum. + + Same shape as the user-facing ``datacube/likelihood_function.py``: + feeds the shared instance to every per-channel + ``AnalysisInterferometer.log_likelihood_function`` and sums. + """ + total = jnp.zeros(()) + for analysis in analysis_list: + total = total + analysis.log_likelihood_function(instance=params_tree) + return total + + _full_cube_n_repeats = 3 + _, full_cube_result = jit_profile( + full_cube_pipeline_from_params, + "full_cube_pipeline", + params_tree, + n_repeats=_full_cube_n_repeats, + ) + full_pipeline_per_call = timer.records[-1][1] / _full_cube_n_repeats -print(f" full cube log_evidence (JIT) = {full_cube_result}") + print(f" full cube log_evidence (JIT) = {full_cube_result}") -np.testing.assert_allclose( - float(full_cube_result), - cube_log_evidence_ref, - rtol=1e-4, - err_msg="Full-pipeline cube JIT log_evidence does not match summed eager FitInterferometer.log_evidence", -) -print(" Eager-vs-JIT cube correctness PASSED") + np.testing.assert_allclose( + float(full_cube_result), + cube_log_evidence_ref, + rtol=1e-4, + err_msg="Full-pipeline cube JIT log_evidence does not match summed eager FitInterferometer.log_evidence", + ) + print(" Eager-vs-JIT cube correctness PASSED") +else: + full_cube_result = None + full_pipeline_per_call = float("nan") + print( + " Full-pipeline cube JIT SKIPPED — opt-in via CUBE_FULL_JIT=1. " + f"At n_channels={n_channels} the lower + compile alone is on the order of " + f"{n_channels * 2}-{n_channels * 3}s, so it's gated to keep the default " + "runtime usable; the per-step Part B JIT data above is what feeds the " + "shared-Lᵀ W̃ L analysis." + ) # =================================================================== # PART D — vmap (skipped for cube) @@ -829,7 +862,10 @@ def full_cube_pipeline_from_params(params_tree): print(f" Edge zeroed pixels: {edge_pixels_total}") print("-" * 70) print(f" Cube reference log_evidence: {cube_log_evidence_ref}") -print(f" Cube JIT log_evidence: {float(full_cube_result)}") +if full_cube_result is not None: + print(f" Cube JIT log_evidence: {float(full_cube_result)}") +else: + print(f" Cube JIT log_evidence: SKIPPED (CUBE_FULL_JIT=1 to enable)") print("-" * 70) max_label = max(len(label) for label, _ in likelihood_steps) @@ -845,7 +881,10 @@ def full_cube_pipeline_from_params(params_tree): print("-" * 70) print(f" {'TOTAL (step-by-step cube cost)':<{max_label}} {step_total:>12.6f} s") -print(f" {'Full pipeline cube (single JIT)':<{max_label}} {full_pipeline_per_call:>12.6f} s") +if np.isfinite(full_pipeline_per_call): + print(f" {'Full pipeline cube (single JIT)':<{max_label}} {full_pipeline_per_call:>12.6f} s") +else: + print(f" {'Full pipeline cube (single JIT)':<{max_label}} SKIPPED") print(f" {f'Shared-Lᵀ W̃ L savings (curvature only, est.)':<{max_label}} {shared_lwl_savings:>12.6f} s") print("=" * 70) @@ -868,7 +907,9 @@ def full_cube_pipeline_from_params(params_tree): "regularization_coefficient": regularization_coefficient, }, "cube_log_evidence_eager": cube_log_evidence_ref, - "cube_log_evidence_jit": float(full_cube_result), + "cube_log_evidence_jit": ( + float(full_cube_result) if full_cube_result is not None else None + ), "log_evidence_per_channel_eager": [float(le) for le in log_evidence_per_channel], "steps_cube_cost": {label: per_call for label, per_call in likelihood_steps}, "per_channel_costs": { @@ -911,13 +952,14 @@ def full_cube_pipeline_from_params(params_tree): fontsize=9, ) -ax.axvline( - full_pipeline_per_call, - color="#C44E52", - linestyle="--", - linewidth=1.5, - label=f"Full pipeline cube (single JIT): {full_pipeline_per_call:.6f} s", -) +if np.isfinite(full_pipeline_per_call): + ax.axvline( + full_pipeline_per_call, + color="#C44E52", + linestyle="--", + linewidth=1.5, + label=f"Full pipeline cube (single JIT): {full_pipeline_per_call:.6f} s", + ) ax.axvline( shared_lwl_savings, color="#8172B2", @@ -956,26 +998,47 @@ def full_cube_pipeline_from_params(params_tree): # Regression assertion — deterministic cube log-evidence # =================================================================== # -# Identical channels = exact N × single-channel log-evidence. -EXPECTED_LOG_EVIDENCE_CUBE_SMA = n_channels * -3167.5258928840763 +# Identical channels = exact N × single-channel log-evidence (for "sma"). +# For "hannah" the per-channel literal isn't pinned yet, so the assertion is +# skipped until the value below is filled in from a clean run. +EXPECTED_LOG_EVIDENCE_PER_CHANNEL = { + "sma": -3167.5258928840763, + "alma": None, + "hannah": -204838.07924622478, +} -np.testing.assert_allclose( - cube_log_evidence_ref, - EXPECTED_LOG_EVIDENCE_CUBE_SMA, - rtol=1e-4, - err_msg=( - f"datacube/delaunay[{instrument}]: regression — eager cube log_evidence " - f"drifted (got {cube_log_evidence_ref}, expected {EXPECTED_LOG_EVIDENCE_CUBE_SMA})" - ), -) -print( - f"\n Eager cube regression assertion PASSED: log_evidence matches " - f"{EXPECTED_LOG_EVIDENCE_CUBE_SMA:.6f}" +_per_channel = EXPECTED_LOG_EVIDENCE_PER_CHANNEL.get(instrument) +expected_cube_log_evidence = ( + n_channels * _per_channel if _per_channel is not None else None ) -np.testing.assert_allclose( - float(full_cube_result), - EXPECTED_LOG_EVIDENCE_CUBE_SMA, - rtol=1e-4, - err_msg=f"datacube/delaunay[{instrument}]: regression — full cube log_evidence drifted", -) -print(f" Full-pipeline cube regression assertion PASSED") + +if expected_cube_log_evidence is None: + print( + f"\n Cube regression assertion SKIPPED for [{instrument}] — " + f"capture this run's eager cube log_evidence ({cube_log_evidence_ref}), " + f"divide by n_channels ({n_channels}) to get the per-channel value " + f"({cube_log_evidence_ref / n_channels}), and paste that into " + f"EXPECTED_LOG_EVIDENCE_PER_CHANNEL[{instrument!r}]." + ) +else: + np.testing.assert_allclose( + cube_log_evidence_ref, + expected_cube_log_evidence, + rtol=1e-4, + err_msg=( + f"datacube/delaunay[{instrument}]: regression — eager cube log_evidence " + f"drifted (got {cube_log_evidence_ref}, expected {expected_cube_log_evidence})" + ), + ) + print( + f"\n Eager cube regression assertion PASSED: log_evidence matches " + f"{expected_cube_log_evidence:.6f}" + ) + if full_cube_result is not None: + np.testing.assert_allclose( + float(full_cube_result), + expected_cube_log_evidence, + rtol=1e-4, + err_msg=f"datacube/delaunay[{instrument}]: regression — full cube log_evidence drifted", + ) + print(f" Full-pipeline cube regression assertion PASSED") diff --git a/likelihood/imaging/delaunay.py b/likelihood/imaging/delaunay.py index fd0892e..64e47a0 100644 --- a/likelihood/imaging/delaunay.py +++ b/likelihood/imaging/delaunay.py @@ -72,7 +72,6 @@ # Profiling helpers # --------------------------------------------------------------------------- - class Timer: """Accumulates named timing measurements and prints a summary.""" @@ -204,10 +203,7 @@ def jit_profile(func, label, *args, n_repeats=10): print("\n--- Image mesh construction (Delaunay) ---") -overlay_shape = ( - 39, - 39, -) # calibrated → 1231 mesh vertices (1201 inside + 30 edge), science fiducial near 1250 +overlay_shape = (39, 39) # calibrated → 1231 mesh vertices (1201 inside + 30 edge), science fiducial near 1250 edge_n_points = 30 with timer.section("image_mesh_overlay"): @@ -260,7 +256,9 @@ def jit_profile(func, label, *args, n_repeats=10): shear.gamma_1 = af.GaussianPrior(mean=0.05, sigma=0.005) shear.gamma_2 = af.GaussianPrior(mean=0.05, sigma=0.005) - lens = af.Model(al.Galaxy, redshift=0.5, bulge=lens_bulge, mass=mass, shear=shear) + lens = af.Model( + al.Galaxy, redshift=0.5, bulge=lens_bulge, mass=mass, shear=shear + ) mesh = al.mesh.Delaunay( pixels=n_mesh_vertices, @@ -384,14 +382,12 @@ def jit_profile(func, label, *args, n_repeats=10): print(f" Number of planes traced: {len(traced_grids)}") - def ray_trace_data_raw(grid_raw): """Wraps ray-tracing so inputs/outputs are raw arrays.""" grid = aa.Grid2DIrregular(values=grid_raw, xp=jnp) traced = tracer.traced_grid_2d_list_from(grid=grid, xp=jnp) return jnp.stack([tg.array for tg in traced]) - _, traced_data_grids_raw = jit_profile( ray_trace_data_raw, "ray_trace_data_jit", grid_pix_raw ) @@ -412,14 +408,12 @@ def ray_trace_data_raw(grid_raw): for tg in traced_mesh: block(tg) - def ray_trace_mesh_raw(mesh_raw): """Ray-trace image-plane mesh vertices to source plane.""" grid = aa.Grid2DIrregular(values=mesh_raw, xp=jnp) traced = tracer.traced_grid_2d_list_from(grid=grid, xp=jnp) return jnp.stack([tg.array for tg in traced]) - _, traced_mesh_grids_raw = jit_profile( ray_trace_mesh_raw, "ray_trace_mesh_jit", mesh_grid_raw ) @@ -433,7 +427,6 @@ def ray_trace_mesh_raw(mesh_raw): print("\n--- Step 3: Blurred image (lens light profiles) ---") - # Sub-step 3a: Compute raw lens light images (JIT-profiled) def lens_image_raw(grid_raw, blurring_grid_raw): """Compute lens light images on masked + blurring grids (no PSF).""" @@ -443,7 +436,6 @@ def lens_image_raw(grid_raw, blurring_grid_raw): blurring_image = tracer.image_2d_from(grid=blurring_grid, xp=jnp) return image.array, blurring_image.array - with timer.section("lens_image_eager"): img_eager, blur_img_eager = lens_image_raw(grid_lp_raw, grid_blurring_raw) block(img_eager) @@ -466,19 +458,14 @@ def lens_image_raw(grid_raw, blurring_grid_raw): print(f" blurred_image shape: {blurred_image.array.shape}") - def blurred_image_from_params(params_tree): """Compute blurred image directly from a pytree ModelInstance — fully JIT-traceable.""" t = al.Tracer(galaxies=list(params_tree.galaxies)) result = t.blurred_image_2d_from( - grid=grid_lp, - psf=dataset.psf, - blurring_grid=grid_blurring, - xp=jnp, + grid=grid_lp, psf=dataset.psf, blurring_grid=grid_blurring, xp=jnp, ) return result.array - _, blurred_img_jit = jit_profile( blurred_image_from_params, "blurred_image_jit", params_tree ) @@ -490,11 +477,9 @@ def blurred_image_from_params(params_tree): print("\n--- Step 4: Profile-subtracted image ---") - def profile_subtract(data, blurred_image): return data - blurred_image - with timer.section("profile_subtract_eager"): blurred_img_jnp = jnp.array(blurred_image.array) profile_subtracted = profile_subtract(data_array, blurred_img_jnp) @@ -528,8 +513,7 @@ def profile_subtract(data, blurred_image): with timer.section("border_relocation_eager"): relocated_grid = border_relocator.relocated_grid_from(grid=traced_source_grid) relocated_mesh_grid = border_relocator.relocated_mesh_grid_from( - grid=traced_source_grid, - mesh_grid=traced_mesh_source, + grid=traced_source_grid, mesh_grid=traced_mesh_source, ) block(relocated_grid) block(relocated_mesh_grid) @@ -613,7 +597,6 @@ def profile_subtract(data, blurred_image): # border relocation → Delaunay triangulation → interpolation → mapper → mapping matrix → PSF convolution. # These steps are tightly sequential; the full pipeline JIT-compiles them all together. - def blurred_mm_from_params(params_tree): """Compute blurred mapping matrix via full inversion setup from a pytree ModelInstance.""" t = al.Tracer(galaxies=list(params_tree.galaxies)) @@ -627,19 +610,13 @@ def blurred_mm_from_params(params_tree): }, ) fit_jax = al.FitImaging( - dataset=dataset, - tracer=t, - adapt_images=adapt_images_jax, - settings=al.Settings(use_border_relocator=True), - xp=jnp, + dataset=dataset, tracer=t, adapt_images=adapt_images_jax, + settings=al.Settings(use_border_relocator=True), xp=jnp, ) return jnp.array(fit_jax.inversion.operated_mapping_matrix) - _, bmm_jit = jit_profile(blurred_mm_from_params, "inversion_setup_jit", params_tree) -likelihood_steps.append( - ("Inversion setup (steps 5-8 combined)", timer.records[-1][1] / 10) -) +likelihood_steps.append(("Inversion setup (steps 5-8 combined)", timer.records[-1][1] / 10)) print(f" blurred_mapping_matrix (JIT) shape: {bmm_jit.shape}") @@ -652,7 +629,6 @@ def blurred_mm_from_params(params_tree): print("\n--- Step 9: Data vector ---") - def compute_data_vector(blurred_mapping_matrix, image, noise_map): return al.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( blurred_mapping_matrix=blurred_mapping_matrix, @@ -660,7 +636,6 @@ def compute_data_vector(blurred_mapping_matrix, image, noise_map): noise_map=noise_map, ) - profile_sub_jnp = jnp.array(fit.profile_subtracted_image.array) noise_jnp = jnp.array(dataset.noise_map.array) @@ -683,7 +658,6 @@ def compute_data_vector(blurred_mapping_matrix, image, noise_map): no_reg_list = list(inversion.no_regularization_index_list) - def compute_curvature_matrix(blurred_mapping_matrix, noise_map): return al.util.inversion.curvature_matrix_via_mapping_matrix_from( mapping_matrix=blurred_mapping_matrix, @@ -694,7 +668,6 @@ def compute_curvature_matrix(blurred_mapping_matrix, noise_map): xp=jnp, ) - with timer.section("curvature_matrix_eager"): curvature_matrix = compute_curvature_matrix(bmm_jnp, noise_jnp) block(curvature_matrix) @@ -739,7 +712,6 @@ def compute_curvature_matrix(blurred_mapping_matrix, noise_map): print("\n--- Step 12: Regularized reconstruction ---") - def compute_reconstruction(data_vector, curvature_matrix, regularization_matrix): curvature_reg_matrix = curvature_matrix + regularization_matrix return al.util.inversion.reconstruction_positive_only_from( @@ -748,7 +720,6 @@ def compute_reconstruction(data_vector, curvature_matrix, regularization_matrix) xp=jnp, ) - with timer.section("reconstruction_eager"): reconstruction = compute_reconstruction( jnp.array(data_vector), @@ -758,8 +729,7 @@ def compute_reconstruction(data_vector, curvature_matrix, regularization_matrix) block(reconstruction) _, reconstruction = jit_profile( - compute_reconstruction, - "reconstruction_jit", + compute_reconstruction, "reconstruction_jit", jnp.array(data_vector), jnp.array(curvature_matrix), jnp.array(regularization_matrix), @@ -774,28 +744,13 @@ def compute_reconstruction(data_vector, curvature_matrix, regularization_matrix) print("\n--- Step 13: Mapped reconstruction + log evidence ---") - def compute_log_evidence( - data, - noise_map, - blurred_image, - blurred_mapping_matrix, - reconstruction, - curvature_matrix, - regularization_matrix, - mapper_indices, + data, noise_map, blurred_image, blurred_mapping_matrix, reconstruction, + curvature_matrix, regularization_matrix, ): """Compute the full log evidence including all five terms: -2 ln e = chi^2 + s^T H s + ln[det(F+H)] - ln[det(H)] + noise_norm - - Matches the production formula in - ``autoarray/inversion/inversion/abstract.py:log_det_*`` — - reduces both the curvature_reg_matrix and the regularization_matrix - to the rows/cols indexed by ``mapper_indices`` before the log_det. - This drops the no-regularization rows (e.g. MGE Basis linear - components) which otherwise make ``det(H) = 0`` and the slogdet - return -inf, then uses Cholesky for a numerically stable log_det. """ # Map reconstruction to image mapped_recon = al.util.inversion.mapped_reconstructed_data_via_mapping_matrix_from( @@ -819,24 +774,12 @@ def compute_log_evidence( # Curvature + regularization matrix curvature_reg_matrix = curvature_matrix + regularization_matrix - # Reduce to the pixelization rows/cols only, matching the production - # ``regularization_matrix_reduced`` / ``curvature_reg_matrix_reduced`` - # properties — necessary for models with non-regularized linear - # components (e.g. MGE lens light). - creg_reduced = curvature_reg_matrix[mapper_indices][:, mapper_indices] - reg_reduced = regularization_matrix[mapper_indices][:, mapper_indices] - - # Cholesky-based log_det (matches production): - # 2 * sum(log(diag(L))) where L is the lower Cholesky factor. - log_det_curvature_reg = 2.0 * jnp.sum( - jnp.log(jnp.diag(jnp.linalg.cholesky(creg_reduced))) - ) - log_det_regularization = 2.0 * jnp.sum( - jnp.log(jnp.diag(jnp.linalg.cholesky(reg_reduced))) - ) + # Log determinant terms + sign_cr, log_det_curvature_reg = jnp.linalg.slogdet(curvature_reg_matrix) + sign_r, log_det_regularization = jnp.linalg.slogdet(regularization_matrix) # Noise normalization - noise_normalization = jnp.sum(jnp.log(2 * jnp.pi * noise_map**2)) + noise_normalization = jnp.sum(jnp.log(2 * jnp.pi * noise_map ** 2)) return -0.5 * ( chi_squared @@ -846,7 +789,6 @@ def compute_log_evidence( + noise_normalization ) - # For the JIT profiling we use the step-by-step matrices for timing. # For the correctness assertion we use the inversion's own matrices, because # cumulative floating-point differences between JIT-compiled and eager paths @@ -856,32 +798,18 @@ def compute_log_evidence( recon_jnp = jnp.array(reconstruction) curv_jnp = jnp.array(curvature_matrix) reg_jnp = jnp.array(regularization_matrix) -mapper_indices_jnp = jnp.array(np.asarray(inversion.mapper_indices)) with timer.section("log_evidence_eager"): log_evidence = compute_log_evidence( - data_array, - noise_jnp, - blurred_img_jnp, - bmm_jnp, - recon_jnp, - curv_jnp, - reg_jnp, - mapper_indices_jnp, + data_array, noise_jnp, blurred_img_jnp, bmm_jnp, + recon_jnp, curv_jnp, reg_jnp, ) block(log_evidence) _, log_evidence = jit_profile( - compute_log_evidence, - "log_evidence_jit", - data_array, - noise_jnp, - blurred_img_jnp, - bmm_jnp, - recon_jnp, - curv_jnp, - reg_jnp, - mapper_indices_jnp, + compute_log_evidence, "log_evidence_jit", + data_array, noise_jnp, blurred_img_jnp, bmm_jnp, + recon_jnp, curv_jnp, reg_jnp, ) likelihood_steps.append(("Mapped recon + log evidence", timer.records[-1][1] / 10)) @@ -893,14 +821,8 @@ def compute_log_evidence( inv_curv_jnp = jnp.array(inversion.curvature_matrix) log_evidence_check = compute_log_evidence( - data_array, - noise_jnp, - blurred_img_jnp, - bmm_jnp, - inv_recon_jnp, - inv_curv_jnp, - reg_jnp, - mapper_indices_jnp, + data_array, noise_jnp, blurred_img_jnp, bmm_jnp, + inv_recon_jnp, inv_curv_jnp, reg_jnp, ) print(f" log_evidence (inv matrices) = {log_evidence_check}") print(f" log_evidence (reference) = {log_evidence_ref}") @@ -911,9 +833,7 @@ def compute_log_evidence( rtol=1e-4, err_msg="Log_evidence from inversion matrices does not match FitImaging.log_evidence", ) -print( - " Assertion PASSED: inversion-matrix log_evidence matches FitImaging.log_evidence" -) +print(" Assertion PASSED: inversion-matrix log_evidence matches FitImaging.log_evidence") # =================================================================== # PART C — Full-pipeline JIT for comparison @@ -925,11 +845,9 @@ def compute_log_evidence( analysis = al.AnalysisImaging(dataset=dataset, adapt_images=adapt_images, use_jax=True) - def full_pipeline_from_params(params_tree): return analysis.log_likelihood_function(instance=params_tree) - _, full_result = jit_profile(full_pipeline_from_params, "full_pipeline", params_tree) full_pipeline_per_call = timer.records[-1][1] / 10 @@ -953,7 +871,6 @@ def full_pipeline_from_params(params_tree): # requested via DELAUNAY_VMAP=1 environment variable. import os - run_vmap = os.environ.get("DELAUNAY_VMAP", "0") == "1" if not run_vmap: @@ -1022,7 +939,6 @@ def full_pipeline_from_params(params_tree): import json import matplotlib - matplotlib.use("Agg") import matplotlib.pyplot as plt @@ -1048,9 +964,7 @@ def full_pipeline_from_params(params_tree): print("-" * 70) print(f" {'TOTAL (step-by-step)':<{max_label}} {step_total:>12.6f} s") -print( - f" {'Full pipeline (single JIT)':<{max_label}} {full_pipeline_per_call:>12.6f} s" -) +print(f" {'Full pipeline (single JIT)':<{max_label}} {full_pipeline_per_call:>12.6f} s") if vmap_per_call is not None: print(f" {f'vmap batch (per call)':<{max_label}} {vmap_per_call:>12.6f} s") print(f" {f'vmap speedup vs single JIT':<{max_label}} {vmap_speedup:>11.1f}x") @@ -1136,7 +1050,7 @@ def full_pipeline_from_params(params_tree): fontweight="bold", ) ax.set_title( - f'AutoLens v{al_version} | {pixel_scale}"/px | {n_image_pixels} pixels | ' + f"AutoLens v{al_version} | {pixel_scale}\"/px | {n_image_pixels} pixels | " f"{n_over_sampled_pixels} over-sampled | {n_source_pixels} Delaunay vertices | " f"total: {step_total:.6f} s", fontsize=9, @@ -1158,7 +1072,7 @@ def full_pipeline_from_params(params_tree): # Simulator truth parameters via GaussianPrior(mean=truth, sigma=small) # make the full-pipeline log-evidence deterministic at the prior median. # vmap result asserted only when DELAUNAY_VMAP=1 (vmap compile takes 20+ min). -EXPECTED_LOG_EVIDENCE_HST = 26642.278160003658 # MGE-60 lens light + 39x39 overlay + ConstantSplit (rebaselined 2026-05-12) +EXPECTED_LOG_EVIDENCE_HST = 26288.321397232066 # 39x39 overlay → 1231 vertices, MGE-60 lens np.testing.assert_allclose( log_evidence_ref, @@ -1186,6 +1100,4 @@ def full_pipeline_from_params(params_tree): rtol=1e-4, err_msg=f"imaging/delaunay[{instrument}]: regression — vmap log_evidence drifted", ) -print( - f" Regression assertion PASSED: log_evidence matches {EXPECTED_LOG_EVIDENCE_HST:.6f}" -) +print(f" Regression assertion PASSED: log_evidence matches {EXPECTED_LOG_EVIDENCE_HST:.6f}") diff --git a/likelihood/imaging/mge.py b/likelihood/imaging/mge.py index 0f5ceb1..ed4a50b 100644 --- a/likelihood/imaging/mge.py +++ b/likelihood/imaging/mge.py @@ -81,7 +81,6 @@ # Profiling helpers # --------------------------------------------------------------------------- - class Timer: """Accumulates named timing measurements and prints a summary.""" @@ -232,7 +231,9 @@ def jit_profile(func, label, *args, n_repeats=10): shear.gamma_1 = 0.05 shear.gamma_2 = 0.05 - lens = af.Model(al.Galaxy, redshift=0.5, bulge=lens_bulge, mass=mass, shear=shear) + lens = af.Model( + al.Galaxy, redshift=0.5, bulge=lens_bulge, mass=mass, shear=shear + ) source_bulge = al.model_util.mge_model_from( mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=False @@ -337,14 +338,12 @@ def jit_profile(func, label, *args, n_repeats=10): print(f" Number of planes traced: {len(traced_grids)}") - def ray_trace_raw(grid_raw): """Wraps ray-tracing so inputs/outputs are raw arrays.""" grid = aa.Grid2DIrregular(values=grid_raw, xp=jnp) traced = tracer.traced_grid_2d_list_from(grid=grid, xp=jnp) return jnp.stack([tg.array for tg in traced]) - _, traced_grids_raw = jit_profile(ray_trace_raw, "ray_trace_jit", grid_lp_raw) likelihood_steps.append(("Ray-trace grids", timer.records[-1][1] / 10)) @@ -376,15 +375,10 @@ def ray_trace_raw(grid_raw): # mapping_matrix and operated_mapping_matrix_override already return raw arrays. with timer.section("mapping_matrix"): mapping_matrices = [func.mapping_matrix for func in lp_linear_funcs] - mapping_matrix = ( - np.hstack(mapping_matrices) - if len(mapping_matrices) > 1 - else mapping_matrices[0] - ) + mapping_matrix = np.hstack(mapping_matrices) if len(mapping_matrices) > 1 else mapping_matrices[0] print(f" mapping_matrix shape: {mapping_matrix.shape}") - def mapping_matrix_from_params(params_tree): """Compute mapping matrix from a pytree-shaped ``ModelInstance``. @@ -408,7 +402,6 @@ def mapping_matrix_from_params(params_tree): matrices = [f.mapping_matrix for f in funcs] return jnp.hstack(matrices) if len(matrices) > 1 else matrices[0] - _, mm_jit = jit_profile(mapping_matrix_from_params, "mapping_matrix_jit", params_tree) likelihood_steps.append(("Mapping matrix", timer.records[-1][1] / 10)) @@ -421,18 +414,11 @@ def mapping_matrix_from_params(params_tree): print("\n--- Step 3: Blurred mapping matrix ---") with timer.section("blurred_mapping_matrix"): - blurred_matrices = [ - func.operated_mapping_matrix_override for func in lp_linear_funcs - ] - blurred_mapping_matrix = ( - np.hstack(blurred_matrices) - if len(blurred_matrices) > 1 - else blurred_matrices[0] - ) + blurred_matrices = [func.operated_mapping_matrix_override for func in lp_linear_funcs] + blurred_mapping_matrix = np.hstack(blurred_matrices) if len(blurred_matrices) > 1 else blurred_matrices[0] print(f" blurred_mapping_matrix shape: {blurred_mapping_matrix.shape}") - def blurred_mm_from_params(params_tree): """Compute blurred mapping matrix from a pytree-shaped ``ModelInstance``.""" t = al.Tracer(galaxies=list(params_tree.galaxies)) @@ -452,7 +438,6 @@ def blurred_mm_from_params(params_tree): matrices = [f.operated_mapping_matrix_override for f in funcs] return jnp.hstack(matrices) if len(matrices) > 1 else matrices[0] - _, bmm_jit = jit_profile(blurred_mm_from_params, "blurred_mm_jit", params_tree) likelihood_steps.append(("Blurred mapping matrix", timer.records[-1][1] / 10)) @@ -464,7 +449,6 @@ def blurred_mm_from_params(params_tree): print("\n--- Step 4: Data vector ---") - def compute_data_vector(blurred_mapping_matrix, image, noise_map): return al.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( blurred_mapping_matrix=blurred_mapping_matrix, @@ -472,7 +456,6 @@ def compute_data_vector(blurred_mapping_matrix, image, noise_map): noise_map=noise_map, ) - bmm_jnp = jnp.array(blurred_mapping_matrix) noise_jnp = jnp.array(dataset.noise_map.array) @@ -495,7 +478,6 @@ def compute_data_vector(blurred_mapping_matrix, image, noise_map): n_linear = bmm_jnp.shape[1] - def compute_curvature_matrix(blurred_mapping_matrix, noise_map): return al.util.inversion.curvature_matrix_via_mapping_matrix_from( mapping_matrix=blurred_mapping_matrix, @@ -505,7 +487,6 @@ def compute_curvature_matrix(blurred_mapping_matrix, noise_map): xp=jnp, ) - with timer.section("curvature_matrix_eager"): curvature_matrix = compute_curvature_matrix(bmm_jnp, noise_jnp) block(curvature_matrix) @@ -523,7 +504,6 @@ def compute_curvature_matrix(blurred_mapping_matrix, noise_map): print("\n--- Step 6: Reconstruction (NNLS) ---") - def compute_reconstruction(data_vector, curvature_matrix): return al.util.inversion.reconstruction_positive_only_from( data_vector=data_vector, @@ -531,7 +511,6 @@ def compute_reconstruction(data_vector, curvature_matrix): xp=jnp, ) - with timer.section("reconstruction_eager"): reconstruction = compute_reconstruction( jnp.array(data_vector), jnp.array(curvature_matrix) @@ -539,10 +518,8 @@ def compute_reconstruction(data_vector, curvature_matrix): block(reconstruction) _, reconstruction = jit_profile( - compute_reconstruction, - "reconstruction_jit", - jnp.array(data_vector), - jnp.array(curvature_matrix), + compute_reconstruction, "reconstruction_jit", + jnp.array(data_vector), jnp.array(curvature_matrix) ) likelihood_steps.append(("Reconstruction (NNLS)", timer.records[-1][1] / 10)) @@ -554,7 +531,6 @@ def compute_reconstruction(data_vector, curvature_matrix): print("\n--- Step 7: Mapped reconstructed image ---") - def compute_mapped_recon(blurred_mapping_matrix, reconstruction): return al.util.inversion.mapped_reconstructed_data_via_mapping_matrix_from( mapping_matrix=blurred_mapping_matrix, @@ -562,7 +538,6 @@ def compute_mapped_recon(blurred_mapping_matrix, reconstruction): xp=jnp, ) - with timer.section("mapped_recon_eager"): mapped_recon = compute_mapped_recon(bmm_jnp, jnp.array(reconstruction)) block(mapped_recon) @@ -580,26 +555,23 @@ def compute_mapped_recon(blurred_mapping_matrix, reconstruction): print("\n--- Step 8: Chi-squared & log likelihood ---") - def compute_log_likelihood(data, noise_map, mapped_recon): residual = data - mapped_recon chi_squared = jnp.sum((residual / noise_map) ** 2) - noise_norm = jnp.sum(jnp.log(2 * jnp.pi * noise_map**2)) + noise_norm = jnp.sum(jnp.log(2 * jnp.pi * noise_map ** 2)) return -0.5 * (chi_squared + noise_norm) - mapped_recon_jnp = jnp.array(mapped_recon) with timer.section("log_likelihood_eager"): - log_like = compute_log_likelihood(data_array, noise_jnp, mapped_recon_jnp) + log_like = compute_log_likelihood( + data_array, noise_jnp, mapped_recon_jnp + ) block(log_like) _, log_like = jit_profile( - compute_log_likelihood, - "log_likelihood_jit", - data_array, - noise_jnp, - mapped_recon_jnp, + compute_log_likelihood, "log_likelihood_jit", + data_array, noise_jnp, mapped_recon_jnp ) likelihood_steps.append(("Chi-squared & log likelihood", timer.records[-1][1] / 10)) @@ -630,7 +602,6 @@ def compute_log_likelihood(data, noise_map, mapped_recon): # instead of going through ``model.instance_from_vector(parameters, xp=jnp)``. analysis = al.AnalysisImaging(dataset=dataset, use_jax=True) - def full_pipeline_from_params(params_tree): """Full likelihood from a pytree-shaped ``ModelInstance``. @@ -640,7 +611,6 @@ def full_pipeline_from_params(params_tree): """ return analysis.log_likelihood_function(instance=params_tree) - _, full_result = jit_profile(full_pipeline_from_params, "full_pipeline", params_tree) full_pipeline_per_call = timer.records[-1][1] / 10 @@ -716,7 +686,6 @@ def full_pipeline_from_params(params_tree): import json import matplotlib - matplotlib.use("Agg") import matplotlib.pyplot as plt @@ -741,12 +710,8 @@ def full_pipeline_from_params(params_tree): print("-" * 70) print(f" {'TOTAL (step-by-step)':<{max_label}} {step_total:>12.6f} s") -print( - f" {'Full pipeline (single JIT)':<{max_label}} {full_pipeline_per_call:>12.6f} s" -) -print( - f" {f'vmap batch={batch_size} (per call)':<{max_label}} {vmap_per_call:>12.6f} s" -) +print(f" {'Full pipeline (single JIT)':<{max_label}} {full_pipeline_per_call:>12.6f} s") +print(f" {f'vmap batch={batch_size} (per call)':<{max_label}} {vmap_per_call:>12.6f} s") print(f" {f'vmap speedup vs single JIT':<{max_label}} {vmap_speedup:>11.1f}x") print("=" * 70) @@ -823,7 +788,7 @@ def full_pipeline_from_params(params_tree): fontweight="bold", ) ax.set_title( - f'AutoLens v{al_version} | {pixel_scale}"/px | {n_image_pixels} pixels | ' + f"AutoLens v{al_version} | {pixel_scale}\"/px | {n_image_pixels} pixels | " f"{n_over_sampled_pixels} over-sampled | {n_linear_gaussians} Gaussians | " f"total: {step_total:.6f} s", fontsize=9, @@ -873,6 +838,4 @@ def full_pipeline_from_params(params_tree): rtol=1e-4, err_msg=f"imaging/mge[{instrument}]: regression — vmap log_likelihood drifted", ) -print( - f" Regression assertion PASSED: log_likelihood matches {EXPECTED_LOG_LIKELIHOOD_HST:.6f}" -) +print(f" Regression assertion PASSED: log_likelihood matches {EXPECTED_LOG_LIKELIHOOD_HST:.6f}") diff --git a/likelihood/imaging/pixelization.py b/likelihood/imaging/pixelization.py index 8ecf0c4..bce0580 100644 --- a/likelihood/imaging/pixelization.py +++ b/likelihood/imaging/pixelization.py @@ -63,7 +63,6 @@ # Profiling helpers # --------------------------------------------------------------------------- - class Timer: """Accumulates named timing measurements and prints a summary.""" @@ -223,7 +222,9 @@ def jit_profile(func, label, *args, n_repeats=10): shear.gamma_1 = af.GaussianPrior(mean=0.05, sigma=0.005) shear.gamma_2 = af.GaussianPrior(mean=0.05, sigma=0.005) - lens = af.Model(al.Galaxy, redshift=0.5, bulge=lens_bulge, mass=mass, shear=shear) + lens = af.Model( + al.Galaxy, redshift=0.5, bulge=lens_bulge, mass=mass, shear=shear + ) pixelization = al.Pixelization( mesh=al.mesh.RectangularAdaptDensity(shape=mesh_shape), @@ -321,22 +322,18 @@ def jit_profile(func, label, *args, n_repeats=10): print("\n--- Step 1: Ray-trace grids ---") with timer.section("ray_trace_eager"): - traced_grids = tracer.traced_grid_2d_list_from( - grid=dataset.grids.pixelization, xp=jnp - ) + traced_grids = tracer.traced_grid_2d_list_from(grid=dataset.grids.pixelization, xp=jnp) for tg in traced_grids: block(tg) print(f" Number of planes traced: {len(traced_grids)}") - def ray_trace_raw(grid_raw): """Wraps ray-tracing so inputs/outputs are raw arrays.""" grid = aa.Grid2DIrregular(values=grid_raw, xp=jnp) traced = tracer.traced_grid_2d_list_from(grid=grid, xp=jnp) return jnp.stack([tg.array for tg in traced]) - _, traced_grids_raw = jit_profile(ray_trace_raw, "ray_trace_jit", grid_pix_raw) likelihood_steps.append(("Ray-trace grids", timer.records[-1][1] / 10)) @@ -348,7 +345,6 @@ def ray_trace_raw(grid_raw): print("\n--- Step 2: Blurred image (lens light profiles) ---") - # Sub-step 2a: Compute raw lens light images (JIT-profiled) def lens_image_raw(grid_raw, blurring_grid_raw): """Compute lens light images on masked + blurring grids (no PSF).""" @@ -358,7 +354,6 @@ def lens_image_raw(grid_raw, blurring_grid_raw): blurring_image = tracer.image_2d_from(grid=blurring_grid, xp=jnp) return image.array, blurring_image.array - with timer.section("lens_image_eager"): img_eager, blur_img_eager = lens_image_raw(grid_lp_raw, grid_blurring_raw) block(img_eager) @@ -381,19 +376,14 @@ def lens_image_raw(grid_raw, blurring_grid_raw): print(f" blurred_image shape: {blurred_image.array.shape}") - def blurred_image_from_params(params_tree): """Compute blurred image directly from a pytree ModelInstance — fully JIT-traceable.""" t = al.Tracer(galaxies=list(params_tree.galaxies)) result = t.blurred_image_2d_from( - grid=grid_lp, - psf=dataset.psf, - blurring_grid=grid_blurring, - xp=jnp, + grid=grid_lp, psf=dataset.psf, blurring_grid=grid_blurring, xp=jnp, ) return result.array - _, blurred_img_jit = jit_profile( blurred_image_from_params, "blurred_image_jit", params_tree ) @@ -405,11 +395,9 @@ def blurred_image_from_params(params_tree): print("\n--- Step 3: Profile-subtracted image ---") - def profile_subtract(data, blurred_image): return data - blurred_image - with timer.section("profile_subtract_eager"): blurred_img_jnp = jnp.array(blurred_image.array) profile_subtracted = profile_subtract(data_array, blurred_img_jnp) @@ -463,18 +451,14 @@ def profile_subtract(data, blurred_image): ) block(mesh_grid) - def overlay_grid_raw_fn(relocated_grid_raw): grid = al.Grid2DIrregular(values=relocated_grid_raw, xp=jnp) return overlay_grid_from(shape_native=mesh_shape, grid=grid, xp=jnp) - _, mesh_grid_raw = jit_profile( overlay_grid_raw_fn, "overlay_grid_jit", relocated_grid_raw ) -likelihood_steps.append( - ("Overlay grid (source pixel centres)", timer.records[-1][1] / 10) -) +likelihood_steps.append(("Overlay grid (source pixel centres)", timer.records[-1][1] / 10)) print(f" mesh_grid shape: {mesh_grid_raw.shape}") @@ -553,23 +537,17 @@ def overlay_grid_raw_fn(relocated_grid_raw): # border relocation → overlay grid → interpolation → mapper → mapping matrix → PSF convolution. # These steps are tightly sequential; the full pipeline JIT-compiles them all together. - def blurred_mm_from_params(params_tree): """Compute blurred mapping matrix via full inversion setup from a pytree ModelInstance.""" t = al.Tracer(galaxies=list(params_tree.galaxies)) fit_jax = al.FitImaging( - dataset=dataset, - tracer=t, - settings=al.Settings(use_border_relocator=True), - xp=jnp, + dataset=dataset, tracer=t, + settings=al.Settings(use_border_relocator=True), xp=jnp, ) return jnp.array(fit_jax.inversion.operated_mapping_matrix) - _, bmm_jit = jit_profile(blurred_mm_from_params, "inversion_setup_jit", params_tree) -likelihood_steps.append( - ("Inversion setup (steps 4-8 combined)", timer.records[-1][1] / 10) -) +likelihood_steps.append(("Inversion setup (steps 4-8 combined)", timer.records[-1][1] / 10)) print(f" blurred_mapping_matrix (JIT) shape: {bmm_jit.shape}") @@ -582,7 +560,6 @@ def blurred_mm_from_params(params_tree): print("\n--- Step 9: Data vector ---") - def compute_data_vector(blurred_mapping_matrix, image, noise_map): return al.util.inversion_imaging.data_vector_via_blurred_mapping_matrix_from( blurred_mapping_matrix=blurred_mapping_matrix, @@ -590,7 +567,6 @@ def compute_data_vector(blurred_mapping_matrix, image, noise_map): noise_map=noise_map, ) - profile_sub_jnp = jnp.array(fit.profile_subtracted_image.array) noise_jnp = jnp.array(dataset.noise_map.array) @@ -614,7 +590,6 @@ def compute_data_vector(blurred_mapping_matrix, image, noise_map): # Match the FitImaging inversion: add_to_curvature_diag=True, with settings no_reg_list = list(inversion.no_regularization_index_list) - def compute_curvature_matrix(blurred_mapping_matrix, noise_map): return al.util.inversion.curvature_matrix_via_mapping_matrix_from( mapping_matrix=blurred_mapping_matrix, @@ -625,7 +600,6 @@ def compute_curvature_matrix(blurred_mapping_matrix, noise_map): xp=jnp, ) - with timer.section("curvature_matrix_eager"): curvature_matrix = compute_curvature_matrix(bmm_jnp, noise_jnp) block(curvature_matrix) @@ -643,7 +617,6 @@ def compute_curvature_matrix(blurred_mapping_matrix, noise_map): print("\n--- Step 11: Regularization matrix ---") - def compute_regularization_matrix(neighbors_array, neighbors_sizes): return al.util.regularization.constant_regularization_matrix_from( coefficient=reg_coefficient, @@ -652,7 +625,6 @@ def compute_regularization_matrix(neighbors_array, neighbors_sizes): xp=jnp, ) - with timer.section("regularization_matrix_eager"): regularization_matrix = compute_regularization_matrix( neighbors_array, neighbors_sizes @@ -660,10 +632,8 @@ def compute_regularization_matrix(neighbors_array, neighbors_sizes): block(regularization_matrix) _, regularization_matrix = jit_profile( - compute_regularization_matrix, - "regularization_matrix_jit", - neighbors_array, - neighbors_sizes, + compute_regularization_matrix, "regularization_matrix_jit", + neighbors_array, neighbors_sizes ) likelihood_steps.append(("Regularization matrix (H)", timer.records[-1][1] / 10)) @@ -675,7 +645,6 @@ def compute_regularization_matrix(neighbors_array, neighbors_sizes): print("\n--- Step 12: Regularized reconstruction ---") - def compute_reconstruction(data_vector, curvature_matrix, regularization_matrix): curvature_reg_matrix = curvature_matrix + regularization_matrix return al.util.inversion.reconstruction_positive_only_from( @@ -684,7 +653,6 @@ def compute_reconstruction(data_vector, curvature_matrix, regularization_matrix) xp=jnp, ) - with timer.section("reconstruction_eager"): reconstruction = compute_reconstruction( jnp.array(data_vector), @@ -694,8 +662,7 @@ def compute_reconstruction(data_vector, curvature_matrix, regularization_matrix) block(reconstruction) _, reconstruction = jit_profile( - compute_reconstruction, - "reconstruction_jit", + compute_reconstruction, "reconstruction_jit", jnp.array(data_vector), jnp.array(curvature_matrix), jnp.array(regularization_matrix), @@ -710,28 +677,13 @@ def compute_reconstruction(data_vector, curvature_matrix, regularization_matrix) print("\n--- Step 13: Mapped reconstruction + log evidence ---") - def compute_log_evidence( - data, - noise_map, - blurred_image, - blurred_mapping_matrix, - reconstruction, - curvature_matrix, - regularization_matrix, - mapper_indices, + data, noise_map, blurred_image, blurred_mapping_matrix, reconstruction, + curvature_matrix, regularization_matrix, ): """Compute the full log evidence including all five terms: -2 ln e = chi^2 + s^T H s + ln[det(F+H)] - ln[det(H)] + noise_norm - - Matches the production formula in - ``autoarray/inversion/inversion/abstract.py:log_det_*`` — - reduces both the curvature_reg_matrix and the regularization_matrix - to the rows/cols indexed by ``mapper_indices`` before the log_det. - This drops the no-regularization rows (e.g. MGE Basis linear - components) which otherwise make ``det(H) = 0`` and the slogdet - return -inf, then uses Cholesky for a numerically stable log_det. """ # Map reconstruction to image mapped_recon = al.util.inversion.mapped_reconstructed_data_via_mapping_matrix_from( @@ -755,21 +707,12 @@ def compute_log_evidence( # Curvature + regularization matrix curvature_reg_matrix = curvature_matrix + regularization_matrix - # Reduce to pixelization rows/cols only (matches production - # ``*_matrix_reduced``): required for models with non-regularised - # linear components (e.g. MGE lens light). - creg_reduced = curvature_reg_matrix[mapper_indices][:, mapper_indices] - reg_reduced = regularization_matrix[mapper_indices][:, mapper_indices] - - log_det_curvature_reg = 2.0 * jnp.sum( - jnp.log(jnp.diag(jnp.linalg.cholesky(creg_reduced))) - ) - log_det_regularization = 2.0 * jnp.sum( - jnp.log(jnp.diag(jnp.linalg.cholesky(reg_reduced))) - ) + # Log determinant terms + sign_cr, log_det_curvature_reg = jnp.linalg.slogdet(curvature_reg_matrix) + sign_r, log_det_regularization = jnp.linalg.slogdet(regularization_matrix) # Noise normalization - noise_normalization = jnp.sum(jnp.log(2 * jnp.pi * noise_map**2)) + noise_normalization = jnp.sum(jnp.log(2 * jnp.pi * noise_map ** 2)) return -0.5 * ( chi_squared @@ -779,7 +722,6 @@ def compute_log_evidence( + noise_normalization ) - # For the JIT profiling we use the step-by-step matrices for timing. # For the correctness assertion we use the inversion's own matrices, because # cumulative floating-point differences between JIT-compiled and eager paths @@ -789,32 +731,18 @@ def compute_log_evidence( recon_jnp = jnp.array(reconstruction) curv_jnp = jnp.array(curvature_matrix) reg_jnp = jnp.array(regularization_matrix) -mapper_indices_jnp = jnp.array(np.asarray(inversion.mapper_indices)) with timer.section("log_evidence_eager"): log_evidence = compute_log_evidence( - data_array, - noise_jnp, - blurred_img_jnp, - bmm_jnp, - recon_jnp, - curv_jnp, - reg_jnp, - mapper_indices_jnp, + data_array, noise_jnp, blurred_img_jnp, bmm_jnp, + recon_jnp, curv_jnp, reg_jnp, ) block(log_evidence) _, log_evidence = jit_profile( - compute_log_evidence, - "log_evidence_jit", - data_array, - noise_jnp, - blurred_img_jnp, - bmm_jnp, - recon_jnp, - curv_jnp, - reg_jnp, - mapper_indices_jnp, + compute_log_evidence, "log_evidence_jit", + data_array, noise_jnp, blurred_img_jnp, bmm_jnp, + recon_jnp, curv_jnp, reg_jnp, ) likelihood_steps.append(("Mapped recon + log evidence", timer.records[-1][1] / 10)) @@ -826,14 +754,8 @@ def compute_log_evidence( inv_curv_jnp = jnp.array(inversion.curvature_matrix) log_evidence_check = compute_log_evidence( - data_array, - noise_jnp, - blurred_img_jnp, - bmm_jnp, - inv_recon_jnp, - inv_curv_jnp, - reg_jnp, - mapper_indices_jnp, + data_array, noise_jnp, blurred_img_jnp, bmm_jnp, + inv_recon_jnp, inv_curv_jnp, reg_jnp, ) print(f" log_evidence (inv matrices) = {log_evidence_check}") print(f" log_evidence (reference) = {log_evidence_ref}") @@ -844,9 +766,7 @@ def compute_log_evidence( rtol=1e-4, err_msg="Log_evidence from inversion matrices does not match FitImaging.log_evidence", ) -print( - " Assertion PASSED: inversion-matrix log_evidence matches FitImaging.log_evidence" -) +print(" Assertion PASSED: inversion-matrix log_evidence matches FitImaging.log_evidence") # =================================================================== # PART C — Full-pipeline JIT for comparison @@ -858,11 +778,9 @@ def compute_log_evidence( analysis = al.AnalysisImaging(dataset=dataset, use_jax=True) - def full_pipeline_from_params(params_tree): return analysis.log_likelihood_function(instance=params_tree) - _, full_result = jit_profile(full_pipeline_from_params, "full_pipeline", params_tree) full_pipeline_per_call = timer.records[-1][1] / 10 @@ -886,10 +804,8 @@ def full_pipeline_from_params(params_tree): _n_leaves = len(jax.tree_util.tree_leaves(params_tree)) if _n_leaves == 0: - print( - f" SKIPPED: model has 0 free parameters (all fixed to truth); " - f"vmap requires at least one array leaf." - ) + print(f" SKIPPED: model has 0 free parameters (all fixed to truth); " + f"vmap requires at least one array leaf.") else: parameters = jax.tree_util.tree_map( lambda leaf: jnp.broadcast_to(leaf, (batch_size, *leaf.shape)), @@ -954,7 +870,6 @@ def full_pipeline_from_params(params_tree): import json import matplotlib - matplotlib.use("Agg") import matplotlib.pyplot as plt @@ -980,13 +895,9 @@ def full_pipeline_from_params(params_tree): print("-" * 70) print(f" {'TOTAL (step-by-step)':<{max_label}} {step_total:>12.6f} s") -print( - f" {'Full pipeline (single JIT)':<{max_label}} {full_pipeline_per_call:>12.6f} s" -) +print(f" {'Full pipeline (single JIT)':<{max_label}} {full_pipeline_per_call:>12.6f} s") if vmap_per_call is not None: - print( - f" {f'vmap batch={batch_size} (per call)':<{max_label}} {vmap_per_call:>12.6f} s" - ) + print(f" {f'vmap batch={batch_size} (per call)':<{max_label}} {vmap_per_call:>12.6f} s") print(f" {f'vmap speedup vs single JIT':<{max_label}} {vmap_speedup:>11.1f}x") else: print(f" {'vmap':<{max_label}} {'SKIPPED (0 free params)':>12}") @@ -1008,24 +919,18 @@ def full_pipeline_from_params(params_tree): "steps": {label: per_call for label, per_call in likelihood_steps}, "total_step_by_step": step_total, "full_pipeline_single_jit": full_pipeline_per_call, - "vmap": ( - "SKIPPED — model has 0 free parameters (all fixed to truth)" - if vmap_per_call is None - else { - "batch_size": batch_size, - "batch_time": vmap_batch_time, - "per_call": vmap_per_call, - "speedup_vs_single_jit": round(vmap_speedup, 1), - } - ), + "vmap": "SKIPPED — model has 0 free parameters (all fixed to truth)" if vmap_per_call is None else { + "batch_size": batch_size, + "batch_time": vmap_batch_time, + "per_call": vmap_per_call, + "speedup_vs_single_jit": round(vmap_speedup, 1), + }, } results_dir = _workspace_root / "results" / "likelihood" / "imaging" results_dir.mkdir(parents=True, exist_ok=True) -dict_path = ( - results_dir / f"pixelization_likelihood_summary_{instrument}_v{al_version}.json" -) +dict_path = results_dir / f"pixelization_likelihood_summary_{instrument}_v{al_version}.json" dict_path.write_text(json.dumps(likelihood_summary, indent=2)) print(f"\n Results dict saved to: {dict_path}") @@ -1073,7 +978,7 @@ def full_pipeline_from_params(params_tree): fontweight="bold", ) ax.set_title( - f'AutoLens v{al_version} | {pixel_scale}"/px | {n_image_pixels} pixels | ' + f"AutoLens v{al_version} | {pixel_scale}\"/px | {n_image_pixels} pixels | " f"{n_over_sampled_pixels} over-sampled | {mesh_shape[0]}x{mesh_shape[1]} mesh | " f"total: {step_total:.6f} s", fontsize=9, @@ -1082,9 +987,7 @@ def full_pipeline_from_params(params_tree): ax.margins(x=0.15) fig.tight_layout() -chart_path = ( - results_dir / f"pixelization_likelihood_summary_{instrument}_v{al_version}.png" -) +chart_path = results_dir / f"pixelization_likelihood_summary_{instrument}_v{al_version}.png" fig.savefig(chart_path, dpi=150) plt.close(fig) print(f" Bar chart saved to: {chart_path}") @@ -1097,9 +1000,7 @@ def full_pipeline_from_params(params_tree): # RectangularAdaptDensity at prior medians is deterministic across the # eager / full-JIT / vmap paths to within rtol=1e-4 — the constant below # is the value those three paths agree on. -EXPECTED_LOG_EVIDENCE_HST = ( - 24746.105672366088 # 35x35 = 1225 source pixels, MGE-60 lens light -) +EXPECTED_LOG_EVIDENCE_HST = 24746.105672366088 # 35x35 = 1225 source pixels, MGE-60 lens light np.testing.assert_allclose( log_evidence_ref, @@ -1126,6 +1027,4 @@ def full_pipeline_from_params(params_tree): rtol=1e-4, err_msg=f"imaging/pixelization[{instrument}]: regression — vmap log_evidence drifted", ) -print( - f" Regression assertion PASSED: log_evidence matches {EXPECTED_LOG_EVIDENCE_HST:.6f}" -) +print(f" Regression assertion PASSED: log_evidence matches {EXPECTED_LOG_EVIDENCE_HST:.6f}") diff --git a/likelihood/interferometer/delaunay.py b/likelihood/interferometer/delaunay.py index 32b3e1f..33abf51 100644 --- a/likelihood/interferometer/delaunay.py +++ b/likelihood/interferometer/delaunay.py @@ -98,8 +98,9 @@ # --------------------------------------------------------------------------- INSTRUMENTS = { - "sma": {"pixel_scale": 0.1, "real_space_shape": (256, 256)}, - "alma": {"pixel_scale": 0.05, "real_space_shape": (256, 256)}, + "sma": {"pixel_scale": 0.1, "real_space_shape": (256, 256), "mask_radius": 3.0}, + "alma": {"pixel_scale": 0.05, "real_space_shape": (256, 256), "mask_radius": 3.0}, + "hannah": {"pixel_scale": 0.125, "real_space_shape": (40, 40), "mask_radius": 2.3}, } instrument = "sma" # <-- change this to profile a different instrument @@ -187,7 +188,7 @@ def jit_profile(func, label, *args, n_repeats=10): f"then copy the result into autolens_profiling/dataset/." ) -mask_radius = 3.0 +mask_radius = INSTRUMENTS[instrument]["mask_radius"] real_space_mask = al.Mask2D.circular( shape_native=real_space_shape, @@ -202,8 +203,19 @@ def jit_profile(func, label, *args, n_repeats=10): uv_wavelengths_path=dataset_path / "uv_wavelengths.fits", real_space_mask=real_space_mask, transformer_class=al.TransformerDFT, + # DFT is intentional even at ALMA-scale visibility counts — profiling + # the JAX-traceable path is the goal, NUFFT (pynufft) is not yet + # JIT-friendly. + raise_error_dft_visibilities_limit=False, ) +with timer.section("apply_sparse_operator"): + # Precompute the NUFFT precision-matrix preload so per-fit curvature + # assembly uses the FFT-based sparse path instead of dense DFT for every + # source pixel. Unblocked by PyAutoArray#316 (the Pmax > 1 extent-indexing + # fix); on Delaunay this was previously guarded with NotImplementedError. + dataset = dataset.apply_sparse_operator(use_jax=True, show_progress=True) + n_visibilities = dataset.uv_wavelengths.shape[0] print(f" Total visibilities: {n_visibilities}") @@ -1105,31 +1117,39 @@ def full_pipeline_from_params(params_tree): # # Simulator truth parameters via GaussianPrior(mean=truth, sigma=small) # make the full-pipeline log-evidence deterministic at the prior median. -EXPECTED_LOG_EVIDENCE_SMA = -3167.5258928840763 +# Pinned empirically per instrument; ``None`` means "skip the assertion and +# print the value so it can be pasted in here on a clean run". +EXPECTED_LOG_EVIDENCE = { + "sma": -3167.5258928840763, + "alma": None, + "hannah": -204838.07924622478, +} + +expected_log_evidence = EXPECTED_LOG_EVIDENCE.get(instrument) -if EXPECTED_LOG_EVIDENCE_SMA is None: +if expected_log_evidence is None: print( - f"\n Regression assertion SKIPPED — " + f"\n Regression assertion SKIPPED for [{instrument}] — " f"capture this run's eager log_evidence ({figure_of_merit_ref}) " - f"and set EXPECTED_LOG_EVIDENCE_SMA in this script." + f"and paste it into EXPECTED_LOG_EVIDENCE[{instrument!r}]." ) else: np.testing.assert_allclose( figure_of_merit_ref, - EXPECTED_LOG_EVIDENCE_SMA, + expected_log_evidence, rtol=1e-4, err_msg=( f"interferometer/delaunay[{instrument}]: regression — eager log_evidence " - f"drifted (got {figure_of_merit_ref}, expected {EXPECTED_LOG_EVIDENCE_SMA})" + f"drifted (got {figure_of_merit_ref}, expected {expected_log_evidence})" ), ) print( f" Eager regression assertion PASSED: log_evidence matches " - f"{EXPECTED_LOG_EVIDENCE_SMA:.6f}" + f"{expected_log_evidence:.6f}" ) np.testing.assert_allclose( float(full_result), - EXPECTED_LOG_EVIDENCE_SMA, + expected_log_evidence, rtol=1e-4, err_msg=f"interferometer/delaunay[{instrument}]: regression — full log_evidence drifted", ) @@ -1137,7 +1157,7 @@ def full_pipeline_from_params(params_tree): if result_vmap is not None: np.testing.assert_allclose( np.array(result_vmap), - EXPECTED_LOG_EVIDENCE_SMA, + expected_log_evidence, rtol=1e-4, err_msg=f"interferometer/delaunay[{instrument}]: regression — vmap log_evidence drifted", ) diff --git a/likelihood/interferometer/mge.py b/likelihood/interferometer/mge.py index 0710e38..7d5ab25 100644 --- a/likelihood/interferometer/mge.py +++ b/likelihood/interferometer/mge.py @@ -68,7 +68,6 @@ # Profiling helpers (copied verbatim from imaging/mge.py) # --------------------------------------------------------------------------- - class Timer: """Accumulates named timing measurements and prints a summary.""" @@ -229,7 +228,6 @@ def jit_profile(func, label, *args, n_repeats=10): # --------------------------------------------------------------------------- from autogalaxy.profiles.basis import Basis as _Basis - _basis_list = [b for g in instance.galaxies for b in g.cls_list_from(cls=_Basis)] n_linear_gaussians = sum(len(b.profile_list) for b in _basis_list) @@ -270,7 +268,6 @@ def jit_profile(func, label, *args, n_repeats=10): analysis = al.AnalysisInterferometer(dataset=dataset, use_jax=True) - def full_pipeline_from_params(params_tree): """Full interferometer likelihood from a pytree-shaped ``ModelInstance``. @@ -280,7 +277,6 @@ def full_pipeline_from_params(params_tree): """ return analysis.log_likelihood_function(instance=params_tree) - _, full_result = jit_profile(full_pipeline_from_params, "full_pipeline", params_tree) full_pipeline_per_call = timer.records[-1][1] / 10 @@ -364,7 +360,6 @@ def full_pipeline_from_params(params_tree): import json import matplotlib - matplotlib.use("Agg") import matplotlib.pyplot as plt @@ -433,9 +428,7 @@ def full_pipeline_from_params(params_tree): fig, ax = plt.subplots(figsize=(10, 3.5)) y_pos = range(len(labels)) -bars = ax.barh( - y_pos, times, color=["#4C72B0", "#55A868"], edgecolor="white", height=0.55 -) +bars = ax.barh(y_pos, times, color=["#4C72B0", "#55A868"], edgecolor="white", height=0.55) for bar, t in zip(bars, times): ax.text( @@ -456,7 +449,7 @@ def full_pipeline_from_params(params_tree): fontweight="bold", ) ax.set_title( - f'AutoLens v{al_version} | {pixel_scale}"/px | ' + f"AutoLens v{al_version} | {pixel_scale}\"/px | " f"{real_space_shape[0]}x{real_space_shape[1]} real-space | " f"{n_visibilities} visibilities | {n_linear_gaussians} Gaussians | " f"vmap speedup: {vmap_speedup:.1f}x", @@ -504,6 +497,4 @@ def full_pipeline_from_params(params_tree): rtol=1e-4, err_msg=f"interferometer/mge[{instrument}]: regression — vmap log_likelihood drifted", ) -print( - f" Regression assertion PASSED: log_likelihood matches {EXPECTED_LOG_LIKELIHOOD_SMA:.6f}" -) +print(f" Regression assertion PASSED: log_likelihood matches {EXPECTED_LOG_LIKELIHOOD_SMA:.6f}") diff --git a/likelihood/interferometer/pixelization.py b/likelihood/interferometer/pixelization.py index 384df58..4b67b59 100644 --- a/likelihood/interferometer/pixelization.py +++ b/likelihood/interferometer/pixelization.py @@ -76,7 +76,6 @@ # Profiling helpers # --------------------------------------------------------------------------- - class Timer: """Accumulates named timing measurements and prints a summary.""" @@ -273,7 +272,6 @@ def jit_profile(func, label, *args, n_repeats=10): analysis = al.AnalysisInterferometer(dataset=dataset, use_jax=True) - def full_pipeline_from_params(params_tree): """Full interferometer likelihood from a pytree-shaped ``ModelInstance``. @@ -283,7 +281,6 @@ def full_pipeline_from_params(params_tree): """ return analysis.log_likelihood_function(instance=params_tree) - _, full_result = jit_profile(full_pipeline_from_params, "full_pipeline", params_tree) full_pipeline_per_call = timer.records[-1][1] / 10 @@ -307,10 +304,8 @@ def full_pipeline_from_params(params_tree): _n_leaves = len(jax.tree_util.tree_leaves(params_tree)) if _n_leaves == 0: - print( - f" SKIPPED: model has 0 free parameters (all fixed to truth); " - f"vmap requires at least one array leaf." - ) + print(f" SKIPPED: model has 0 free parameters (all fixed to truth); " + f"vmap requires at least one array leaf.") else: parameters = jax.tree_util.tree_map( lambda leaf: jnp.broadcast_to(leaf, (batch_size, *leaf.shape)), @@ -388,7 +383,6 @@ def full_pipeline_from_params(params_tree): import json import matplotlib - matplotlib.use("Agg") import matplotlib.pyplot as plt @@ -436,32 +430,22 @@ def full_pipeline_from_params(params_tree): "figure_of_merit_eager": float(figure_of_merit_ref), "log_evidence_jit": float(full_result), "full_pipeline_single_jit": full_pipeline_per_call, - "vmap": ( - "SKIPPED — model has 0 free parameters (all fixed to truth)" - if vmap_per_call is None - else { - "batch_size": batch_size, - "batch_time": vmap_batch_time, - "per_call": vmap_per_call, - "speedup_vs_single_jit": round(vmap_speedup, 1), - } - ), - "memory_mb": ( - None - if memory_analysis is None - else { - "output": memory_analysis.output_size_in_bytes / 1024**2, - "temp": memory_analysis.temp_size_in_bytes / 1024**2, - } - ), + "vmap": "SKIPPED — model has 0 free parameters (all fixed to truth)" if vmap_per_call is None else { + "batch_size": batch_size, + "batch_time": vmap_batch_time, + "per_call": vmap_per_call, + "speedup_vs_single_jit": round(vmap_speedup, 1), + }, + "memory_mb": None if memory_analysis is None else { + "output": memory_analysis.output_size_in_bytes / 1024**2, + "temp": memory_analysis.temp_size_in_bytes / 1024**2, + }, } results_dir = _workspace_root / "results" / "likelihood" / "interferometer" results_dir.mkdir(parents=True, exist_ok=True) -dict_path = ( - results_dir / f"pixelization_likelihood_summary_{instrument}_v{al_version}.json" -) +dict_path = results_dir / f"pixelization_likelihood_summary_{instrument}_v{al_version}.json" dict_path.write_text(json.dumps(likelihood_summary, indent=2)) print(f"\n Results dict saved to: {dict_path}") @@ -497,13 +481,9 @@ def full_pipeline_from_params(params_tree): fontsize=12, fontweight="bold", ) -_vmap_title = ( - f"vmap speedup: {vmap_speedup:.1f}x" - if vmap_speedup is not None - else "vmap: SKIPPED" -) +_vmap_title = f"vmap speedup: {vmap_speedup:.1f}x" if vmap_speedup is not None else "vmap: SKIPPED" ax.set_title( - f'AutoLens v{al_version} | {pixel_scale}"/px | ' + f"AutoLens v{al_version} | {pixel_scale}\"/px | " f"{real_space_shape[0]}x{real_space_shape[1]} real-space | " f"{n_visibilities} visibilities | {mesh_shape[0]}x{mesh_shape[1]} mesh | " f"{_vmap_title}", @@ -512,9 +492,7 @@ def full_pipeline_from_params(params_tree): ax.margins(x=0.2) fig.tight_layout() -chart_path = ( - results_dir / f"pixelization_likelihood_summary_{instrument}_v{al_version}.png" -) +chart_path = results_dir / f"pixelization_likelihood_summary_{instrument}_v{al_version}.png" fig.savefig(chart_path, dpi=150) plt.close(fig) print(f" Bar chart saved to: {chart_path}") @@ -553,6 +531,4 @@ def full_pipeline_from_params(params_tree): rtol=1e-4, err_msg=f"interferometer/pixelization[{instrument}]: regression — vmap log_evidence drifted", ) -print( - f" Regression assertion PASSED: log_evidence matches {EXPECTED_LOG_EVIDENCE_SMA:.6f}" -) +print(f" Regression assertion PASSED: log_evidence matches {EXPECTED_LOG_EVIDENCE_SMA:.6f}") diff --git a/likelihood/point_source/image_plane.py b/likelihood/point_source/image_plane.py index b8278fc..cd51051 100644 --- a/likelihood/point_source/image_plane.py +++ b/likelihood/point_source/image_plane.py @@ -129,7 +129,9 @@ def jit_profile(func, label, *args, n_repeats=10): _script_dir = Path(__file__).resolve().parent _workspace_root = _script_dir.parents[1] -dataset_path = Path("dataset") / "point_source" / dataset_name +dataset_path = ( + Path("dataset") / "point_source" / dataset_name +) if al.util.dataset.should_simulate(str(dataset_path)): raise FileNotFoundError( @@ -341,13 +343,9 @@ def full_pipeline_from_params(params_tree): print(f" Observed image positions: {n_observed_positions}") print(f" Position noise sigma: {positions_noise_sigma}") print(f" Free parameters: {model.total_free_parameters}") -print( - f" fit_positions_cls: FitPositionsImagePairAll (image-plane chi-squared)" -) +print(f" fit_positions_cls: FitPositionsImagePairAll (image-plane chi-squared)") print("-" * 70) -print( - f" Eager full likelihood: {eager_per_call:.6f} s/call ({log_likelihood_ref:.6f})" -) +print(f" Eager full likelihood: {eager_per_call:.6f} s/call ({log_likelihood_ref:.6f})") print(f" Full pipeline (JIT): {full_pipeline_per_call:.6f} s/call") print(f" vmap per-call (batch={batch_size}): {vmap_per_call:.6f} s") print(f" vmap speedup vs single JIT: {vmap_speedup:.1f}x") diff --git a/likelihood/point_source/source_plane.py b/likelihood/point_source/source_plane.py index 43ab685..660e010 100644 --- a/likelihood/point_source/source_plane.py +++ b/likelihood/point_source/source_plane.py @@ -112,7 +112,9 @@ def jit_profile(func, label, *args, n_repeats=10): _script_dir = Path(__file__).resolve().parent _workspace_root = _script_dir.parents[1] -dataset_path = Path("dataset") / "point_source" / dataset_name +dataset_path = ( + Path("dataset") / "point_source" / dataset_name +) if al.util.dataset.should_simulate(str(dataset_path)): raise FileNotFoundError( @@ -379,9 +381,7 @@ def ray_trace_to_source_plane(params_tree, positions_raw): print(f" Free parameters: {model.total_free_parameters}") print(f" fit_positions_cls: FitPositionsSource (source-plane chi-squared)") print("-" * 70) -print( - f" Eager full likelihood: {eager_per_call:.6f} s/call ({log_likelihood_ref:.6f})" -) +print(f" Eager full likelihood: {eager_per_call:.6f} s/call ({log_likelihood_ref:.6f})") if full_pipeline_jits: print(f" Full pipeline (JIT): {full_pipeline_per_call:.6f} s/call") else: @@ -460,7 +460,9 @@ def ray_trace_to_source_plane(params_tree, positions_raw): fontsize=12, fontweight="bold", ) -title_extra = " | full pipeline JIT BLOCKED" if not full_pipeline_jits else "" +title_extra = ( + " | full pipeline JIT BLOCKED" if not full_pipeline_jits else "" +) ax.set_title( f"AutoLens v{al_version} | {n_observed_positions} positions | " f"{model.total_free_parameters} free params{title_extra}",