From f48b149c7de969b2b60ca575ca3f67218c42059d Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 16 May 2026 18:42:35 +0100 Subject: [PATCH] fix: imaging/delaunay log_evidence -inf via rank-stripped regularisation block (#69) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Root cause ---------- `compute_log_evidence` was computing `slogdet` on the full `regularization_matrix`, which is rank-deficient: the script's lens model includes linear MGE light profiles whose linear parameters share the inversion's linear system but are not regularised. Their rows / columns in the regularisation matrix are zero, so the full matrix is singular and `slogdet` returns -inf for log_det(H). The F + H term is likewise evaluated on the reduced block (where it is well-conditioned), so that both log-det terms cover the same mapper-pixel rows and columns. PyAutoArray's reference `Inversion.log_evidence` chain handles this correctly by computing the log-det terms on the REDUCED block — the `regularization_matrix_reduced` / `curvature_reg_matrix_reduced` properties slice the matrices down to just the mapper-pixel rows and columns (positive-definite, Cholesky-safe). See `autoarray/inversion/inversion/abstract.py::log_det_regularization_matrix_term`. Fix --- - Precompute `inversion.regularization_matrix_reduced`, `inversion.curvature_reg_matrix_reduced`, and `inversion.mapper_indices` eagerly outside the timed region. - Change `compute_log_evidence` to slice the reconstruction by `mapper_indices` for the regularisation term, and use Cholesky-based log-det on the reduced PD matrices — matching the reference implementation in `PyAutoArray.Inversion.log_det_*_term`. - chi² and the noise-normalisation term still use the full reconstruction and full noise map, since those are per-pixel data terms over the masked image, not regularised matrix terms. 
Validation ---------- Locally, on PyAutoLens 2026.5.14.2 / Intel i9-10885H CPU: log_evidence (step-by-step) — now finite, matches reference 26288.321397 log_evidence (inv matrices) — same log_evidence (reference) = 26288.321397 Eager regression assertion PASSED: log_evidence matches 26288.321397 Regression assertion PASSED: log_evidence matches 26288.321397 Whole script exits 0 (was exit 1 with -inf vs 26288 mismatch). Refs ---- - Bug filed at PyAutoLabs/autolens_workspace_developer#69; the upstream source-of-truth `_developer/jax_profiling/jit/imaging/delaunay.py` has the same bug. A follow-up PR there should mirror this fix. - Surfaced by today's full-script run against the canonical autolens_profiling repo (Phase 5 ship validation). Co-Authored-By: Claude Opus 4.7 (1M context) --- likelihood/imaging/delaunay.py | 75 +++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 23 deletions(-) diff --git a/likelihood/imaging/delaunay.py b/likelihood/imaging/delaunay.py index b49ccdf..7221aff 100644 --- a/likelihood/imaging/delaunay.py +++ b/likelihood/imaging/delaunay.py @@ -756,11 +756,29 @@ def compute_reconstruction(data_vector, curvature_matrix, regularization_matrix) def compute_log_evidence( data, noise_map, blurred_image, blurred_mapping_matrix, reconstruction, - curvature_matrix, regularization_matrix, + reduced_indices, reg_reduced, curv_reg_reduced, ): """Compute the full log evidence including all five terms: -2 ln e = chi^2 + s^T H s + ln[det(F+H)] - ln[det(H)] + noise_norm + + Mirrors the reference implementation in PyAutoArray's + ``Inversion.log_evidence`` chain: + + - chi^2 and the noise-normalisation term are computed over the *full* + reconstruction (lens-MGE linear params + source-Delaunay pixels) + because they're per-pixel data terms over the masked image. 
+ - s^T H s and the two log-det terms operate on the *reduced* (rank- + stripped) regularisation block, which slices out the non-mapper + rows/columns whose regularisation entries are zero. The full + regularisation matrix is rank-deficient by construction (the lens + MGE bulge is linear but not regularised), so `slogdet` on the full + matrix returns -inf; the reduced block is positive-definite and + Cholesky-safe. + - Log-det terms use ``2 * sum(log(diag(cholesky(M))))`` to match the + reference inversion (see PyAutoArray's + ``Inversion.log_det_regularization_matrix_term`` / + ``log_det_curvature_reg_matrix_term``). """ # Map reconstruction to image mapped_recon = al.util.inversion.mapped_reconstructed_data_via_mapping_matrix_from( @@ -772,23 +790,27 @@ def compute_log_evidence( # model_data = lens light + pixelized source model_data = blurred_image + mapped_recon - # Chi-squared + # Chi-squared (over full reconstruction → full mapping matrix) residual = data - model_data chi_squared = jnp.sum((residual / noise_map) ** 2) - # Regularization term: s^T H s - regularization_term = jnp.dot( - reconstruction, jnp.dot(regularization_matrix, reconstruction) - ) + # Reduced reconstruction (source-pixel block only) for the regularised + # terms. + s_reduced = reconstruction[reduced_indices] - # Curvature + regularization matrix - curvature_reg_matrix = curvature_matrix + regularization_matrix + # Regularization term: s^T H s on the reduced block + regularization_term = jnp.dot(s_reduced, jnp.dot(reg_reduced, s_reduced)) - # Log determinant terms - sign_cr, log_det_curvature_reg = jnp.linalg.slogdet(curvature_reg_matrix) - sign_r, log_det_regularization = jnp.linalg.slogdet(regularization_matrix) + # Log-determinant terms via Cholesky on the reduced (PD) matrices — + # matches PyAutoArray's reference. slogdet on the full regularisation + # matrix returns -inf because it contains zero rows for the + # non-regularised lens MGE linear parameters. 
+ L_cr = jnp.linalg.cholesky(curv_reg_reduced) + log_det_curvature_reg = 2.0 * jnp.sum(jnp.log(jnp.diag(L_cr))) + L_r = jnp.linalg.cholesky(reg_reduced) + log_det_regularization = 2.0 * jnp.sum(jnp.log(jnp.diag(L_r))) - # Noise normalization + # Noise normalization (over full masked image) noise_normalization = jnp.sum(jnp.log(2 * jnp.pi * noise_map ** 2)) return -0.5 * ( @@ -799,40 +821,47 @@ def compute_log_evidence( + noise_normalization ) -# For the JIT profiling we use the step-by-step matrices for timing. -# For the correctness assertion we use the inversion's own matrices, because -# cumulative floating-point differences between JIT-compiled and eager paths -# (especially through ill-conditioned solves) can compound significantly. +# For the JIT profiling we use the step-by-step reconstruction for timing. +# For the correctness assertion we use the inversion's own reconstruction, +# because cumulative floating-point differences between JIT-compiled and +# eager paths (especially through ill-conditioned solves) can compound +# significantly. +# +# The reduced (rank-stripped) regularisation block and curvature+reg matrix +# are precomputed eagerly from the inversion. These are constant across +# calls within this script's lens/source configuration, so the reduction +# work itself is not part of the per-call timed cost. 
blurred_img_jnp = jnp.array(blurred_image.array) recon_jnp = jnp.array(reconstruction) -curv_jnp = jnp.array(curvature_matrix) -reg_jnp = jnp.array(regularization_matrix) +reduced_indices_jnp = jnp.array(inversion.mapper_indices) +reg_reduced_jnp = jnp.array(inversion.regularization_matrix_reduced) +curv_reg_reduced_jnp = jnp.array(inversion.curvature_reg_matrix_reduced) with timer.section("log_evidence_eager"): log_evidence = compute_log_evidence( data_array, noise_jnp, blurred_img_jnp, bmm_jnp, - recon_jnp, curv_jnp, reg_jnp, + recon_jnp, reduced_indices_jnp, reg_reduced_jnp, curv_reg_reduced_jnp, ) block(log_evidence) _, log_evidence = jit_profile( compute_log_evidence, "log_evidence_jit", data_array, noise_jnp, blurred_img_jnp, bmm_jnp, - recon_jnp, curv_jnp, reg_jnp, + recon_jnp, reduced_indices_jnp, reg_reduced_jnp, curv_reg_reduced_jnp, ) likelihood_steps.append(("Mapped recon + log evidence", timer.records[-1][1] / 10)) print(f" log_evidence (step-by-step) = {log_evidence}") # Correctness check: recompute log_evidence using the inversion's own -# reconstruction and curvature matrix to avoid accumulated FP drift. +# reconstruction to avoid accumulated FP drift from the JIT-compiled +# reconstruction step. inv_recon_jnp = jnp.array(inversion.reconstruction) -inv_curv_jnp = jnp.array(inversion.curvature_matrix) log_evidence_check = compute_log_evidence( data_array, noise_jnp, blurred_img_jnp, bmm_jnp, - inv_recon_jnp, inv_curv_jnp, reg_jnp, + inv_recon_jnp, reduced_indices_jnp, reg_reduced_jnp, curv_reg_reduced_jnp, ) print(f" log_evidence (inv matrices) = {log_evidence_check}") print(f" log_evidence (reference) = {log_evidence_ref}")