Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 52 additions & 23 deletions likelihood/imaging/delaunay.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,11 +756,29 @@ def compute_reconstruction(data_vector, curvature_matrix, regularization_matrix)

def compute_log_evidence(
data, noise_map, blurred_image, blurred_mapping_matrix, reconstruction,
curvature_matrix, regularization_matrix,
reduced_indices, reg_reduced, curv_reg_reduced,
):
"""Compute the full log evidence including all five terms:

-2 ln e = chi^2 + s^T H s + ln[det(F+H)] - ln[det(H)] + noise_norm

Mirrors the reference implementation in PyAutoArray's
``Inversion.log_evidence`` chain:

- chi^2 is computed from the *full* reconstruction (lens-MGE linear
params + source-Delaunay pixels) and the noise-normalisation term
from the noise map alone; both are per-pixel data terms over the
masked image.
- s^T H s and the two log-det terms operate on the *reduced* (rank-
stripped) regularisation block, which slices out the non-mapper
rows/columns whose regularisation entries are zero. The full
regularisation matrix is rank-deficient by construction (the lens
MGE bulge is linear but not regularised), so `slogdet` on the full
matrix returns -inf; the reduced block is positive-definite and
Cholesky-safe.
- Log-det terms use ``2 * sum(log(diag(cholesky(M))))`` to match the
reference inversion (see PyAutoArray's
``Inversion.log_det_regularization_matrix_term`` /
``log_det_curvature_reg_matrix_term``).
"""
# Map reconstruction to image
mapped_recon = al.util.inversion.mapped_reconstructed_data_via_mapping_matrix_from(
Expand All @@ -772,23 +790,27 @@ def compute_log_evidence(
# model_data = lens light + pixelized source
model_data = blurred_image + mapped_recon

# Chi-squared
# Chi-squared (over full reconstruction → full mapping matrix)
residual = data - model_data
chi_squared = jnp.sum((residual / noise_map) ** 2)

# Regularization term: s^T H s
regularization_term = jnp.dot(
reconstruction, jnp.dot(regularization_matrix, reconstruction)
)
# Reduced reconstruction (source-pixel block only) for the regularised
# terms.
s_reduced = reconstruction[reduced_indices]

# Curvature + regularization matrix
curvature_reg_matrix = curvature_matrix + regularization_matrix
# Regularization term: s^T H s on the reduced block
regularization_term = jnp.dot(s_reduced, jnp.dot(reg_reduced, s_reduced))

# Log determinant terms
sign_cr, log_det_curvature_reg = jnp.linalg.slogdet(curvature_reg_matrix)
sign_r, log_det_regularization = jnp.linalg.slogdet(regularization_matrix)
# Log-determinant terms via Cholesky on the reduced (PD) matrices —
# matches PyAutoArray's reference. slogdet on the full regularisation
# matrix returns -inf because it contains zero rows for the
# non-regularised lens MGE linear parameters.
L_cr = jnp.linalg.cholesky(curv_reg_reduced)
log_det_curvature_reg = 2.0 * jnp.sum(jnp.log(jnp.diag(L_cr)))
L_r = jnp.linalg.cholesky(reg_reduced)
log_det_regularization = 2.0 * jnp.sum(jnp.log(jnp.diag(L_r)))

# Noise normalization
# Noise normalization (over full masked image)
noise_normalization = jnp.sum(jnp.log(2 * jnp.pi * noise_map ** 2))

return -0.5 * (
Expand All @@ -799,40 +821,47 @@ def compute_log_evidence(
+ noise_normalization
)

# For the JIT profiling we use the step-by-step matrices for timing.
# For the correctness assertion we use the inversion's own matrices, because
# cumulative floating-point differences between JIT-compiled and eager paths
# (especially through ill-conditioned solves) can compound significantly.
# For the JIT profiling we use the step-by-step reconstruction for timing.
# For the correctness assertion we use the inversion's own reconstruction,
# because cumulative floating-point differences between JIT-compiled and
# eager paths (especially through ill-conditioned solves) can compound
# significantly.
#
# The reduced (rank-stripped) regularisation block and curvature+reg matrix
# are precomputed eagerly from the inversion. These are constant across
# calls within this script's lens/source configuration, so the reduction
# work itself is not part of the per-call timed cost.

blurred_img_jnp = jnp.array(blurred_image.array)
recon_jnp = jnp.array(reconstruction)
curv_jnp = jnp.array(curvature_matrix)
reg_jnp = jnp.array(regularization_matrix)
reduced_indices_jnp = jnp.array(inversion.mapper_indices)
reg_reduced_jnp = jnp.array(inversion.regularization_matrix_reduced)
curv_reg_reduced_jnp = jnp.array(inversion.curvature_reg_matrix_reduced)

with timer.section("log_evidence_eager"):
log_evidence = compute_log_evidence(
data_array, noise_jnp, blurred_img_jnp, bmm_jnp,
recon_jnp, curv_jnp, reg_jnp,
recon_jnp, reduced_indices_jnp, reg_reduced_jnp, curv_reg_reduced_jnp,
)
block(log_evidence)

_, log_evidence = jit_profile(
compute_log_evidence, "log_evidence_jit",
data_array, noise_jnp, blurred_img_jnp, bmm_jnp,
recon_jnp, curv_jnp, reg_jnp,
recon_jnp, reduced_indices_jnp, reg_reduced_jnp, curv_reg_reduced_jnp,
)
likelihood_steps.append(("Mapped recon + log evidence", timer.records[-1][1] / 10))

print(f" log_evidence (step-by-step) = {log_evidence}")

# Correctness check: recompute log_evidence using the inversion's own
# reconstruction and curvature matrix to avoid accumulated FP drift.
# reconstruction to avoid accumulated FP drift from the JIT-compiled
# reconstruction step.
inv_recon_jnp = jnp.array(inversion.reconstruction)
inv_curv_jnp = jnp.array(inversion.curvature_matrix)

log_evidence_check = compute_log_evidence(
data_array, noise_jnp, blurred_img_jnp, bmm_jnp,
inv_recon_jnp, inv_curv_jnp, reg_jnp,
inv_recon_jnp, reduced_indices_jnp, reg_reduced_jnp, curv_reg_reduced_jnp,
)
print(f" log_evidence (inv matrices) = {log_evidence_check}")
print(f" log_evidence (reference) = {log_evidence_ref}")
Expand Down
Loading