Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified jax_profiling/imaging/dataset/imaging/hst/data.fits
Binary file not shown.
Binary file modified jax_profiling/imaging/dataset/imaging/hst/noise_map.fits
Binary file not shown.
48 changes: 24 additions & 24 deletions jax_profiling/imaging/dataset/imaging/hst/tracer.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,19 @@
"type": "instance",
"class_path": "autolens.lens.tracer.Tracer",
"arguments": {
"cosmology": {
"type": "instance",
"class_path": "autogalaxy.cosmology.model.Planck15",
"arguments": {}
},
"galaxies": {
"type": "list",
"values": [
{
"type": "instance",
"class_path": "autogalaxy.galaxy.galaxy.Galaxy",
"arguments": {
"redshift": 0.5,
"label": null,
"redshift": 0.5,
"bulge": {
"type": "instance",
"class_path": "autogalaxy.profiles.light.standard.sersic.Sersic",
"arguments": {
"sersic_index": 3.0,
"intensity": 2.0,
"ell_comps": {
"type": "tuple",
Expand All @@ -29,43 +23,44 @@
3.2227547345982974e-18
]
},
"effective_radius": 0.6,
"centre": {
"type": "tuple",
"values": [
0.0,
0.0
]
},
"effective_radius": 0.6
"sersic_index": 3.0
}
},
"mass": {
"type": "instance",
"class_path": "autogalaxy.profiles.mass.total.isothermal.Isothermal",
"arguments": {
"centre": {
"type": "tuple",
"values": [
0.0,
0.0
]
},
"ell_comps": {
"type": "tuple",
"values": [
0.05263157894736841,
3.2227547345982974e-18
]
},
"einstein_radius": 1.6
"einstein_radius": 1.6,
"centre": {
"type": "tuple",
"values": [
0.0,
0.0
]
}
}
},
"shear": {
"type": "instance",
"class_path": "autogalaxy.profiles.mass.sheets.external_shear.ExternalShear",
"arguments": {
"gamma_2": 0.05,
"gamma_1": 0.05
"gamma_1": 0.05,
"gamma_2": 0.05
}
}
}
Expand All @@ -74,37 +69,42 @@
"type": "instance",
"class_path": "autogalaxy.galaxy.galaxy.Galaxy",
"arguments": {
"redshift": 1.0,
"label": null,
"redshift": 1.0,
"bulge": {
"type": "instance",
"class_path": "autogalaxy.profiles.light.standard.sersic_core.SersicCore",
"arguments": {
"sersic_index": 1.0,
"radius_break": 0.025,
"intensity": 4.0,
"gamma": 0.25,
"alpha": 3.0,
"ell_comps": {
"type": "tuple",
"values": [
0.0962250448649376,
-0.05555555555555551
]
},
"gamma": 0.25,
"radius_break": 0.025,
"effective_radius": 0.1,
"centre": {
"type": "tuple",
"values": [
0.0,
0.0
]
},
"effective_radius": 0.1,
"alpha": 3.0
"sersic_index": 1.0
}
}
}
}
]
},
"cosmology": {
"type": "instance",
"class_path": "autogalaxy.cosmology.model.Planck15",
"arguments": {}
}
}
}
26 changes: 26 additions & 0 deletions jax_profiling/imaging/delaunay.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,3 +1030,29 @@ def compute_log_evidence(
fig.savefig(chart_path, dpi=150)
plt.close(fig)
print(f" Bar chart saved to: {chart_path}")


# ===================================================================
# Regression assertion — realistic-scale deterministic log-evidence
# ===================================================================
#
# Seeded simulator (noise_seed=1 in simulators/imaging.py) + fixed model
# parameters make the full-pipeline log-evidence deterministic at this
# HST-scale Delaunay-pixelization dataset. vmap result asserted only when
# DELAUNAY_VMAP=1 (vmap compile takes 20+ min otherwise).
EXPECTED_LOG_EVIDENCE_HST = -1802826962.700122

np.testing.assert_allclose(
float(full_result),
EXPECTED_LOG_EVIDENCE_HST,
rtol=1e-4,
err_msg=f"imaging/delaunay[{instrument}]: regression — full log_evidence drifted",
)
if run_vmap:
np.testing.assert_allclose(
np.array(result_vmap),
EXPECTED_LOG_EVIDENCE_HST,
rtol=1e-4,
err_msg=f"imaging/delaunay[{instrument}]: regression — vmap log_evidence drifted",
)
print(f" Regression assertion PASSED: log_evidence matches {EXPECTED_LOG_EVIDENCE_HST:.6f}")
24 changes: 24 additions & 0 deletions jax_profiling/imaging/mge.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,3 +786,27 @@ def full_pipeline_from_params(params_tree):
plt.close(fig)
print(f" Bar chart saved to: {chart_path}")


# ===================================================================
# Regression assertion — realistic-scale deterministic likelihood
# ===================================================================
#
# Seeded simulator (noise_seed=1 in simulators/imaging.py) + fixed model
# parameters make the full-pipeline log-likelihood deterministic at this
# HST-scale dataset. Hardcoded value guards against silent regressions in
# the light-profile / blurring / chi-squared stack.
EXPECTED_LOG_LIKELIHOOD_HST = -159736.35504220804

np.testing.assert_allclose(
float(full_result),
EXPECTED_LOG_LIKELIHOOD_HST,
rtol=1e-4,
err_msg=f"imaging/mge[{instrument}]: regression — full log_likelihood drifted",
)
np.testing.assert_allclose(
np.array(result_vmap),
EXPECTED_LOG_LIKELIHOOD_HST,
rtol=1e-4,
err_msg=f"imaging/mge[{instrument}]: regression — vmap log_likelihood drifted",
)
print(f" Regression assertion PASSED: log_likelihood matches {EXPECTED_LOG_LIKELIHOOD_HST:.6f}")
15 changes: 15 additions & 0 deletions jax_profiling/imaging/mge_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,3 +785,18 @@ def full_pipeline_from_params(params_tree):
print("-" * 70)
print(f" {n_pass} passed, {n_fail} failed, {n_error} errors out of {len(results)} tests")
print("=" * 70)

assert n_error == 0, (
f"Regression: {n_error} gradient step(s) errored on stock library defaults "
f"(nnls_target_kappa=1e-11, no_regularization_add_to_curvature_diag_value=1e-3)."
)
assert n_fail == 0, (
f"Regression: {n_fail} gradient step(s) produced non-finite values on stock "
f"library defaults (nnls_target_kappa=1e-11, "
f"no_regularization_add_to_curvature_diag_value=1e-3)."
)
assert n_pass == len(results), (
f"Regression: expected all {len(results)} gradient steps to PASS on stock "
f"library defaults, got {n_pass}."
)
print(f" Regression assertion PASSED: all {n_pass}/{len(results)} gradient steps finite")
25 changes: 25 additions & 0 deletions jax_profiling/imaging/pixelization.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,3 +950,28 @@ def full_pipeline_from_params(params_tree):
fig.savefig(chart_path, dpi=150)
plt.close(fig)
print(f" Bar chart saved to: {chart_path}")


# ===================================================================
# Regression assertion — realistic-scale deterministic log-evidence
# ===================================================================
#
# Seeded simulator (noise_seed=1 in simulators/imaging.py) + fixed model
# parameters make the full-pipeline log-evidence deterministic at this
# HST-scale rectangular-pixelization dataset. Guards against regressions
# in the mapper / curvature / NNLS / regularization stack.
EXPECTED_LOG_EVIDENCE_HST = -1338521802.3596945

np.testing.assert_allclose(
float(full_result),
EXPECTED_LOG_EVIDENCE_HST,
rtol=1e-4,
err_msg=f"imaging/pixelization[{instrument}]: regression — full log_evidence drifted",
)
np.testing.assert_allclose(
np.array(result_vmap),
EXPECTED_LOG_EVIDENCE_HST,
rtol=1e-4,
err_msg=f"imaging/pixelization[{instrument}]: regression — vmap log_evidence drifted",
)
print(f" Regression assertion PASSED: log_evidence matches {EXPECTED_LOG_EVIDENCE_HST:.6f}")
3 changes: 2 additions & 1 deletion jax_profiling/imaging/simulators/imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,13 @@ def simulate(instrument: str, mask_radius: float = 3.5):
pixel_scales=grid.pixel_scales,
)

# Simulator
# Simulator — seeded so likelihood assertions downstream are deterministic
simulator = al.SimulatorImaging(
exposure_time=300.0,
psf=psf,
background_sky_level=0.1,
add_poisson_noise_to_data=True,
noise_seed=1,
)

# Galaxies — lens with Sersic light + Isothermal mass, source with cored Sersic
Expand Down
22 changes: 11 additions & 11 deletions jax_profiling/interferometer/dataset/interferometer/sma/tracer.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
"type": "instance",
"class_path": "autogalaxy.profiles.mass.total.isothermal.Isothermal",
"arguments": {
"einstein_radius": 1.6,
"centre": {
"type": "tuple",
"values": [
0.0,
0.0
]
},
"einstein_radius": 1.6,
"ell_comps": {
"type": "tuple",
"values": [
Expand Down Expand Up @@ -52,26 +52,26 @@
"type": "instance",
"class_path": "autogalaxy.profiles.light.standard.sersic_core.SersicCore",
"arguments": {
"alpha": 3.0,
"gamma": 0.25,
"ell_comps": {
"effective_radius": 1.0,
"centre": {
"type": "tuple",
"values": [
0.0962250448649376,
-0.05555555555555551
0.1,
0.1
]
},
"intensity": 0.3,
"alpha": 3.0,
"centre": {
"radius_break": 0.025,
"ell_comps": {
"type": "tuple",
"values": [
0.1,
0.1
0.0962250448649376,
-0.05555555555555551
]
},
"effective_radius": 1.0,
"sersic_index": 2.5,
"radius_break": 0.025
"sersic_index": 2.5
}
}
}
Expand Down
25 changes: 25 additions & 0 deletions jax_profiling/interferometer/mge.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,28 @@ def full_pipeline_from_params(params_tree):
fig.savefig(chart_path, dpi=150)
plt.close(fig)
print(f" Bar chart saved to: {chart_path}")


# ===================================================================
# Regression assertion — realistic-scale deterministic log-likelihood
# ===================================================================
#
# Seeded simulator (noise_seed=1 in simulators/interferometer.py) + fixed
# SMA uv-coverage + fixed model parameters make the full-pipeline
# log-likelihood deterministic. Guards against regressions in the
# visibility transform / MGE inversion / chi-squared stack.
EXPECTED_LOG_LIKELIHOOD_SMA = -3154.8053574023816

np.testing.assert_allclose(
float(full_result),
EXPECTED_LOG_LIKELIHOOD_SMA,
rtol=1e-4,
err_msg=f"interferometer/mge[{instrument}]: regression — full log_likelihood drifted",
)
np.testing.assert_allclose(
np.array(result_vmap),
EXPECTED_LOG_LIKELIHOOD_SMA,
rtol=1e-4,
err_msg=f"interferometer/mge[{instrument}]: regression — vmap log_likelihood drifted",
)
print(f" Regression assertion PASSED: log_likelihood matches {EXPECTED_LOG_LIKELIHOOD_SMA:.6f}")
Loading