Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
"positions_noise_sigma": 0.2,
"free_parameters": 8
},
"eager_per_call": 1.4605535200000304,
"eager_per_call": 2.956794569999329,
"eager_log_likelihood": 0.3936326580483207,
"full_pipeline_single_jit": 0.029288589999850956,
"full_pipeline_single_jit": 0.04489453000060166,
"full_pipeline_log_likelihood": 0.39363265804832137,
"vmap": {
"batch_size": 3,
"batch_time": 0.041269269999975225,
"per_call": 0.013756423333325074,
"speedup_vs_single_jit": 2.1
"batch_time": 0.07011199000044144,
"per_call": 0.02337066333348048,
"speedup_vs_single_jit": 1.9
}
}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@
"positions_noise_sigma": 0.2,
"free_parameters": 8
},
"eager_per_call": 0.0028598109999802544,
"eager_log_likelihood": -4496.798984131583,
"eager_per_call": 0.007736787000030745,
"eager_log_likelihood": -4491.83220547254,
"full_pipeline_jits": true,
"full_pipeline_blocker": null,
"full_pipeline_single_jit": 0.00038716000017302576,
"full_pipeline_single_jit": 0.0005248499997833278,
"jit_able_prefix": {
"name": "ray-trace observed positions to source plane",
"per_call": 0.00034986999999091494
"per_call": 0.0004281299996364396
},
"vmap_prefix": {
"batch_size": 3,
"batch_time": 0.0002824099999997998,
"per_call": 9.413666666659993e-05,
"speedup_vs_single_jit_prefix": 3.7
"batch_time": 0.0004249800003890414,
"per_call": 0.00014166000012968047,
"speedup_vs_single_jit_prefix": 3.0
}
}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
33 changes: 7 additions & 26 deletions jax_profiling/point_source/source_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,6 @@
ray-traced positions and the model source position. No image-plane solver
is required.

LIBRARY-LEVEL NUMERICAL DRIFT (np vs jnp magnifications)
---------------------------------------------------------

The earlier ``TracerArrayConversionError`` JIT blocker (``Grid2DIrregular.
grid_2d_via_deflection_grid_from`` not propagating ``xp``) has been fixed
upstream — the full source-plane likelihood now JIT-traces end-to-end.

A separate numerical issue remains: the eager-np and eager-jnp paths of
``LensCalc.magnification_2d_via_hessian_from`` disagree at ~5e-4 relative,
which amplifies into a ~0.1% drift in the final log-likelihood. See the
follow-up prompt ``admin_jammy/prompt/autolens/lens_calc_magnification_xp_divergence.md``
for the full analysis and plan.

Until that lands, the JIT regression assertion below is loosened from
``rtol=1e-4`` to ``rtol=2e-3`` so the script stays green. The eager-only
assertion remains tight.

Pytree-native parameter inputs
------------------------------

Expand Down Expand Up @@ -496,14 +479,15 @@ def ray_trace_to_source_plane(params_tree, positions_raw):


# ===================================================================
# Regression assertions (eager-only — see BLOCKER above)
# Regression assertions (eager and full-pipeline JIT)
# ===================================================================
#
# Seeded simulator (noise_seed=1 in simulators/point_source.py) + prior-median
# parameter vector make the eager source-plane log-likelihood deterministic.
# The full-JIT regression assertion is gated on full_pipeline_jits — once the
# library blocker is fixed, both branches will fire.
EXPECTED_LOG_LIKELIHOOD_SOURCE_PLANE = -4496.798984131583
# parameter vector make the source-plane log-likelihood deterministic. Both
# the eager numpy and the full-pipeline JIT paths now agree to float64
# precision, following the Richardson-extrapolation fix to
# LensCalc.hessian_from in PyAutoGalaxy (PR #358).
EXPECTED_LOG_LIKELIHOOD_SOURCE_PLANE = -4491.83220547254

np.testing.assert_allclose(
log_likelihood_ref,
Expand All @@ -520,13 +504,10 @@ def ray_trace_to_source_plane(params_tree, positions_raw):
)

if full_pipeline_jits:
# rtol loosened pending the np/jnp parity fix in
# LensCalc.magnification_2d_via_hessian_from — see
# admin_jammy/prompt/autolens/lens_calc_magnification_xp_divergence.md
np.testing.assert_allclose(
float(full_result),
EXPECTED_LOG_LIKELIHOOD_SOURCE_PLANE,
rtol=2e-3,
rtol=1e-4,
err_msg=(
f"point_source/source_plane: regression — JIT log_likelihood drifted "
f"(got {float(full_result)}, expected {EXPECTED_LOG_LIKELIHOOD_SOURCE_PLANE})"
Expand Down