Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions jax_profiling/jit/imaging/delaunay.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,14 +702,27 @@ def compute_curvature_matrix(blurred_mapping_matrix, noise_map):
print(f" regularization_matrix shape: {regularization_matrix.shape}")

# ---------------------------------------------------------------------------
# Step 12: Regularized reconstruction: s = (F + H)^{-1} D
# Step 12: Regularized reconstruction: s = NNLS(F + H, D)
# ---------------------------------------------------------------------------
#
# Uses ``reconstruction_positive_only_from`` (NNLS) to match production
# AnalysisImaging behaviour. An earlier version of this script used
# ``jnp.linalg.solve(F+H, D)`` which under-reports the per-step
# reconstruction cost (~5 ms vs ~36 ms NNLS on RTX 2060). The two
# solvers happen to produce identical reconstructions for the
# well-conditioned ConstantSplit setup at prior medians (no negative
# source pixels, NNLS reduces to linear solve), so the downstream
# log-evidence value is unchanged within rtol=1e-4.

print("\n--- Step 12: Regularized reconstruction ---")

def compute_reconstruction(data_vector, curvature_matrix, regularization_matrix):
curvature_reg_matrix = curvature_matrix + regularization_matrix
return jnp.linalg.solve(curvature_reg_matrix, data_vector)
return al.util.inversion.reconstruction_positive_only_from(
data_vector=data_vector,
curvature_reg_matrix=curvature_reg_matrix,
xp=jnp,
)

with timer.section("reconstruction_eager"):
reconstruction = compute_reconstruction(
Expand Down