diff --git a/likelihood/imaging/delaunay.py b/likelihood/imaging/delaunay.py index b49ccdf..7221aff 100644 --- a/likelihood/imaging/delaunay.py +++ b/likelihood/imaging/delaunay.py @@ -756,11 +756,29 @@ def compute_reconstruction(data_vector, curvature_matrix, regularization_matrix) def compute_log_evidence( data, noise_map, blurred_image, blurred_mapping_matrix, reconstruction, - curvature_matrix, regularization_matrix, + reduced_indices, reg_reduced, curv_reg_reduced, ): """Compute the full log evidence including all five terms: -2 ln e = chi^2 + s^T H s + ln[det(F+H)] - ln[det(H)] + noise_norm + + Mirrors the reference implementation in PyAutoArray's + ``Inversion.log_evidence`` chain: + + - chi^2 and the noise-normalisation term are computed over the *full* + reconstruction (lens-MGE linear params + source-Delaunay pixels) + because they're per-pixel data terms over the masked image. + - s^T H s and the two log-det terms operate on the *reduced* (rank- + stripped) regularisation block, which slices out the non-mapper + rows/columns whose regularisation entries are zero. The full + regularisation matrix is rank-deficient by construction (the lens + MGE bulge is linear but not regularised), so `slogdet` on the full + matrix returns -inf; the reduced block is positive-definite and + Cholesky-safe. + - Log-det terms use ``2 * sum(log(diag(cholesky(M))))`` to match the + reference inversion (see PyAutoArray's + ``Inversion.log_det_regularization_matrix_term`` / + ``log_det_curvature_reg_matrix_term``). 
""" # Map reconstruction to image mapped_recon = al.util.inversion.mapped_reconstructed_data_via_mapping_matrix_from( @@ -772,23 +790,27 @@ def compute_log_evidence( # model_data = lens light + pixelized source model_data = blurred_image + mapped_recon - # Chi-squared + # Chi-squared (over full reconstruction → full mapping matrix) residual = data - model_data chi_squared = jnp.sum((residual / noise_map) ** 2) - # Regularization term: s^T H s - regularization_term = jnp.dot( - reconstruction, jnp.dot(regularization_matrix, reconstruction) - ) + # Reduced reconstruction (source-pixel block only) for the regularised + # terms. + s_reduced = reconstruction[reduced_indices] - # Curvature + regularization matrix - curvature_reg_matrix = curvature_matrix + regularization_matrix + # Regularization term: s^T H s on the reduced block + regularization_term = jnp.dot(s_reduced, jnp.dot(reg_reduced, s_reduced)) - # Log determinant terms - sign_cr, log_det_curvature_reg = jnp.linalg.slogdet(curvature_reg_matrix) - sign_r, log_det_regularization = jnp.linalg.slogdet(regularization_matrix) + # Log-determinant terms via Cholesky on the reduced (PD) matrices — + # matches PyAutoArray's reference. slogdet on the full regularisation + # matrix returns -inf because it contains zero rows for the + # non-regularised lens MGE linear parameters. + L_cr = jnp.linalg.cholesky(curv_reg_reduced) + log_det_curvature_reg = 2.0 * jnp.sum(jnp.log(jnp.diag(L_cr))) + L_r = jnp.linalg.cholesky(reg_reduced) + log_det_regularization = 2.0 * jnp.sum(jnp.log(jnp.diag(L_r))) - # Noise normalization + # Noise normalization (over full masked image) noise_normalization = jnp.sum(jnp.log(2 * jnp.pi * noise_map ** 2)) return -0.5 * ( @@ -799,40 +821,47 @@ def compute_log_evidence( + noise_normalization ) -# For the JIT profiling we use the step-by-step matrices for timing. 
-# For the correctness assertion we use the inversion's own matrices, because -# cumulative floating-point differences between JIT-compiled and eager paths -# (especially through ill-conditioned solves) can compound significantly. +# For the JIT profiling we use the step-by-step reconstruction for timing. +# For the correctness assertion we use the inversion's own reconstruction, +# because cumulative floating-point differences between JIT-compiled and +# eager paths (especially through ill-conditioned solves) can compound +# significantly. +# +# The reduced (rank-stripped) regularisation block and curvature+reg matrix +# are precomputed eagerly from the inversion. These are constant across +# calls within this script's lens/source configuration, so the reduction +# work itself is not part of the per-call timed cost. blurred_img_jnp = jnp.array(blurred_image.array) recon_jnp = jnp.array(reconstruction) -curv_jnp = jnp.array(curvature_matrix) -reg_jnp = jnp.array(regularization_matrix) +reduced_indices_jnp = jnp.array(inversion.mapper_indices) +reg_reduced_jnp = jnp.array(inversion.regularization_matrix_reduced) +curv_reg_reduced_jnp = jnp.array(inversion.curvature_reg_matrix_reduced) with timer.section("log_evidence_eager"): log_evidence = compute_log_evidence( data_array, noise_jnp, blurred_img_jnp, bmm_jnp, - recon_jnp, curv_jnp, reg_jnp, + recon_jnp, reduced_indices_jnp, reg_reduced_jnp, curv_reg_reduced_jnp, ) block(log_evidence) _, log_evidence = jit_profile( compute_log_evidence, "log_evidence_jit", data_array, noise_jnp, blurred_img_jnp, bmm_jnp, - recon_jnp, curv_jnp, reg_jnp, + recon_jnp, reduced_indices_jnp, reg_reduced_jnp, curv_reg_reduced_jnp, ) likelihood_steps.append(("Mapped recon + log evidence", timer.records[-1][1] / 10)) print(f" log_evidence (step-by-step) = {log_evidence}") # Correctness check: recompute log_evidence using the inversion's own -# reconstruction and curvature matrix to avoid accumulated FP drift. 
+# reconstruction to avoid accumulated FP drift from the JIT-compiled +# reconstruction step. inv_recon_jnp = jnp.array(inversion.reconstruction) -inv_curv_jnp = jnp.array(inversion.curvature_matrix) log_evidence_check = compute_log_evidence( data_array, noise_jnp, blurred_img_jnp, bmm_jnp, - inv_recon_jnp, inv_curv_jnp, reg_jnp, + inv_recon_jnp, reduced_indices_jnp, reg_reduced_jnp, curv_reg_reduced_jnp, ) print(f" log_evidence (inv matrices) = {log_evidence_check}") print(f" log_evidence (reference) = {log_evidence_ref}")