diff --git a/jax_profiling/point_source/results/image_plane_summary_v2026.4.13.6.json b/jax_profiling/point_source/results/image_plane_summary_v2026.4.13.6.json index 9fa2a0a..6b65177 100644 --- a/jax_profiling/point_source/results/image_plane_summary_v2026.4.13.6.json +++ b/jax_profiling/point_source/results/image_plane_summary_v2026.4.13.6.json @@ -7,14 +7,14 @@ "positions_noise_sigma": 0.2, "free_parameters": 8 }, - "eager_per_call": 1.4605535200000304, + "eager_per_call": 2.956794569999329, "eager_log_likelihood": 0.3936326580483207, - "full_pipeline_single_jit": 0.029288589999850956, + "full_pipeline_single_jit": 0.04489453000060166, "full_pipeline_log_likelihood": 0.39363265804832137, "vmap": { "batch_size": 3, - "batch_time": 0.041269269999975225, - "per_call": 0.013756423333325074, - "speedup_vs_single_jit": 2.1 + "batch_time": 0.07011199000044144, + "per_call": 0.02337066333348048, + "speedup_vs_single_jit": 1.9 } } \ No newline at end of file diff --git a/jax_profiling/point_source/results/image_plane_summary_v2026.4.13.6.png b/jax_profiling/point_source/results/image_plane_summary_v2026.4.13.6.png index 1d745d7..c2a0a50 100644 Binary files a/jax_profiling/point_source/results/image_plane_summary_v2026.4.13.6.png and b/jax_profiling/point_source/results/image_plane_summary_v2026.4.13.6.png differ diff --git a/jax_profiling/point_source/results/source_plane_summary_v2026.4.13.6.json b/jax_profiling/point_source/results/source_plane_summary_v2026.4.13.6.json index 1ed292d..02e2c71 100644 --- a/jax_profiling/point_source/results/source_plane_summary_v2026.4.13.6.json +++ b/jax_profiling/point_source/results/source_plane_summary_v2026.4.13.6.json @@ -7,19 +7,19 @@ "positions_noise_sigma": 0.2, "free_parameters": 8 }, - "eager_per_call": 0.0028598109999802544, - "eager_log_likelihood": -4496.798984131583, + "eager_per_call": 0.007736787000030745, + "eager_log_likelihood": -4491.83220547254, "full_pipeline_jits": true, "full_pipeline_blocker": null, - "full_pipeline_single_jit": 0.00038716000017302576, + "full_pipeline_single_jit": 0.0005248499997833278, "jit_able_prefix": { "name": "ray-trace observed positions to source plane", - "per_call": 0.00034986999999091494 + "per_call": 0.0004281299996364396 }, "vmap_prefix": { "batch_size": 3, - "batch_time": 0.0002824099999997998, - "per_call": 9.413666666659993e-05, - "speedup_vs_single_jit_prefix": 3.7 + "batch_time": 0.0004249800003890414, + "per_call": 0.00014166000012968047, + "speedup_vs_single_jit_prefix": 3.0 } } \ No newline at end of file diff --git a/jax_profiling/point_source/results/source_plane_summary_v2026.4.13.6.png b/jax_profiling/point_source/results/source_plane_summary_v2026.4.13.6.png index e287e9e..5b70fa1 100644 Binary files a/jax_profiling/point_source/results/source_plane_summary_v2026.4.13.6.png and b/jax_profiling/point_source/results/source_plane_summary_v2026.4.13.6.png differ diff --git a/jax_profiling/point_source/source_plane.py b/jax_profiling/point_source/source_plane.py index da4300f..2451fb0 100644 --- a/jax_profiling/point_source/source_plane.py +++ b/jax_profiling/point_source/source_plane.py @@ -11,23 +11,6 @@ ray-traced positions and the model source position. No image-plane solver is required. -LIBRARY-LEVEL NUMERICAL DRIFT (np vs jnp magnifications) ---------------------------------------------------------- - -The earlier ``TracerArrayConversionError`` JIT blocker (``Grid2DIrregular. -grid_2d_via_deflection_grid_from`` not propagating ``xp``) has been fixed -upstream — the full source-plane likelihood now JIT-traces end-to-end. - -A separate numerical issue remains: the eager-np and eager-jnp paths of -``LensCalc.magnification_2d_via_hessian_from`` disagree at ~5e-4 relative, -which amplifies into a ~0.1% drift in the final log-likelihood. See the -follow-up prompt ``admin_jammy/prompt/autolens/lens_calc_magnification_xp_divergence.md`` -for the full analysis and plan. - -Until that lands, the JIT regression assertion below is loosened from -``rtol=1e-4`` to ``rtol=2e-3`` so the script stays green. The eager-only -assertion remains tight. - Pytree-native parameter inputs ------------------------------ @@ -496,14 +479,15 @@ def ray_trace_to_source_plane(params_tree, positions_raw): # =================================================================== -# Regression assertions (eager-only — see BLOCKER above) +# Regression assertions (eager and full-pipeline JIT) # =================================================================== # # Seeded simulator (noise_seed=1 in simulators/point_source.py) + prior-median -# parameter vector make the eager source-plane log-likelihood deterministic. -# The full-JIT regression assertion is gated on full_pipeline_jits — once the -# library blocker is fixed, both branches will fire. -EXPECTED_LOG_LIKELIHOOD_SOURCE_PLANE = -4496.798984131583 +# parameter vector make the source-plane log-likelihood deterministic. Both +# the eager numpy and the full-pipeline JIT paths now agree to float64 +# precision, following the Richardson-extrapolation fix to +# LensCalc.hessian_from in PyAutoGalaxy (PR #358). +EXPECTED_LOG_LIKELIHOOD_SOURCE_PLANE = -4491.83220547254 np.testing.assert_allclose( log_likelihood_ref, @@ -520,13 +504,10 @@ def ray_trace_to_source_plane(params_tree, positions_raw): ) if full_pipeline_jits: - # rtol loosened pending the np/jnp parity fix in - # LensCalc.magnification_2d_via_hessian_from — see - # admin_jammy/prompt/autolens/lens_calc_magnification_xp_divergence.md np.testing.assert_allclose( float(full_result), EXPECTED_LOG_LIKELIHOOD_SOURCE_PLANE, - rtol=2e-3, + rtol=1e-4, err_msg=( f"point_source/source_plane: regression — JIT log_likelihood drifted " f"(got {float(full_result)}, expected {EXPECTED_LOG_LIKELIHOOD_SOURCE_PLANE})"