Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 52 additions & 43 deletions jax_profiling/imaging/mge_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,20 @@ def test_grad(label, func, params):
# Force evaluation.
if hasattr(value, "block_until_ready"):
value.block_until_ready()
if hasattr(grad, "block_until_ready"):
grad.block_until_ready()

val_f = float(value)
grad_np = np.array(grad)
grad_leaves = jax.tree_util.tree_leaves(grad)
for leaf in grad_leaves:
if hasattr(leaf, "block_until_ready"):
leaf.block_until_ready()
grad_np = (
np.concatenate([np.asarray(l).ravel() for l in grad_leaves])
if grad_leaves
else np.array([])
)

print(f" value = {val_f:.8g}")
print(f" grad leaves = {len(grad_leaves)}")
print(f" grad shape = {grad_np.shape}")
print(f" grad norm = {np.linalg.norm(grad_np):.8g}")
print(f" grad min = {grad_np.min():.8g}")
Expand Down Expand Up @@ -244,6 +251,14 @@ def test_grad(label, func, params):
data_array = jnp.array(dataset.data.array)
noise_map_array = jnp.array(dataset.noise_map.array)

# Register pytree support for the model and convert the sampled ModelInstance
# into a JAX-array-valued pytree. Each step closure differentiates w.r.t. this
# tree directly rather than rebuilding the instance from a flat vector inside
# the trace. Matches the pattern used by mge.py and pixelization.py.
from autofit.jax import register_model as _register_model_pytrees
_register_model_pytrees(model)
params_tree = jax.tree_util.tree_map(jnp.asarray, instance)


# ===================================================================
# PART B -- Per-step gradient testing
Expand All @@ -258,22 +273,20 @@ def test_grad(label, func, params):
# ---------------------------------------------------------------------------

def step_ray_trace(params):
inst = model.instance_from_vector(vector=params, xp=jnp)
t = al.Tracer(galaxies=list(inst.galaxies))
t = al.Tracer(galaxies=list(params.galaxies))
grid_raw = jnp.array(grid_lp.array)
grid = aa.Grid2DIrregular(values=grid_raw, xp=jnp)
traced = t.traced_grid_2d_list_from(grid=grid, xp=jnp)
return jnp.sum(jnp.stack([tg.array for tg in traced]))

test_grad("Step 1: Ray-trace grids", step_ray_trace, jnp_params)
test_grad("Step 1: Ray-trace grids", step_ray_trace, params_tree)

# ---------------------------------------------------------------------------
# Step 2: Mapping matrix (linear profile images)
# ---------------------------------------------------------------------------

def step_mapping_matrix(params):
inst = model.instance_from_vector(vector=params, xp=jnp)
t = al.Tracer(galaxies=list(inst.galaxies))
t = al.Tracer(galaxies=list(params.galaxies))
tti = al.TracerToInversion(
dataset=aa.DatasetInterface(
data=fit.profile_subtracted_image,
Expand All @@ -291,15 +304,14 @@ def step_mapping_matrix(params):
mm = jnp.hstack(matrices) if len(matrices) > 1 else matrices[0]
return jnp.sum(mm)

test_grad("Step 2: Mapping matrix", step_mapping_matrix, jnp_params)
test_grad("Step 2: Mapping matrix", step_mapping_matrix, params_tree)

# ---------------------------------------------------------------------------
# Step 3: Blurred mapping matrix (PSF convolution of each profile)
# ---------------------------------------------------------------------------

def step_blurred_mapping_matrix(params):
inst = model.instance_from_vector(vector=params, xp=jnp)
t = al.Tracer(galaxies=list(inst.galaxies))
t = al.Tracer(galaxies=list(params.galaxies))
tti = al.TracerToInversion(
dataset=aa.DatasetInterface(
data=fit.profile_subtracted_image,
Expand All @@ -317,15 +329,14 @@ def step_blurred_mapping_matrix(params):
bmm = jnp.hstack(matrices) if len(matrices) > 1 else matrices[0]
return jnp.sum(bmm)

test_grad("Step 3: Blurred mapping matrix", step_blurred_mapping_matrix, jnp_params)
test_grad("Step 3: Blurred mapping matrix", step_blurred_mapping_matrix, params_tree)

# ---------------------------------------------------------------------------
# Step 4: Data vector (D)
# ---------------------------------------------------------------------------

def step_data_vector(params):
inst = model.instance_from_vector(vector=params, xp=jnp)
t = al.Tracer(galaxies=list(inst.galaxies))
t = al.Tracer(galaxies=list(params.galaxies))

tti = al.TracerToInversion(
dataset=aa.DatasetInterface(
Expand All @@ -350,15 +361,14 @@ def step_data_vector(params):
)
return jnp.sum(data_vector)

test_grad("Step 4: Data vector (D)", step_data_vector, jnp_params)
test_grad("Step 4: Data vector (D)", step_data_vector, params_tree)

# ---------------------------------------------------------------------------
# Step 5: Curvature matrix (F)
# ---------------------------------------------------------------------------

def step_curvature_matrix(params):
inst = model.instance_from_vector(vector=params, xp=jnp)
t = al.Tracer(galaxies=list(inst.galaxies))
t = al.Tracer(galaxies=list(params.galaxies))
tti = al.TracerToInversion(
dataset=aa.DatasetInterface(
data=fit.profile_subtracted_image,
Expand All @@ -385,15 +395,14 @@ def step_curvature_matrix(params):
)
return jnp.sum(curvature)

test_grad("Step 5: Curvature matrix (F)", step_curvature_matrix, jnp_params)
test_grad("Step 5: Curvature matrix (F)", step_curvature_matrix, params_tree)

# ---------------------------------------------------------------------------
# Step 6: Reconstruction (NNLS)
# ---------------------------------------------------------------------------

def step_reconstruction(params):
inst = model.instance_from_vector(vector=params, xp=jnp)
t = al.Tracer(galaxies=list(inst.galaxies))
t = al.Tracer(galaxies=list(params.galaxies))

tti = al.TracerToInversion(
dataset=aa.DatasetInterface(
Expand Down Expand Up @@ -433,15 +442,14 @@ def step_reconstruction(params):
)
return jnp.sum(reconstruction)

test_grad("Step 6: Reconstruction (NNLS)", step_reconstruction, jnp_params)
test_grad("Step 6: Reconstruction (NNLS)", step_reconstruction, params_tree)

# ---------------------------------------------------------------------------
# Step 7: Mapped reconstructed image
# ---------------------------------------------------------------------------

def step_mapped_recon(params):
inst = model.instance_from_vector(vector=params, xp=jnp)
t = al.Tracer(galaxies=list(inst.galaxies))
t = al.Tracer(galaxies=list(params.galaxies))

tti = al.TracerToInversion(
dataset=aa.DatasetInterface(
Expand Down Expand Up @@ -487,15 +495,14 @@ def step_mapped_recon(params):
)
return jnp.sum(mapped_recon)

test_grad("Step 7: Mapped reconstructed image", step_mapped_recon, jnp_params)
test_grad("Step 7: Mapped reconstructed image", step_mapped_recon, params_tree)

# ---------------------------------------------------------------------------
# Step 8: Log likelihood (chi-squared)
# ---------------------------------------------------------------------------

def step_log_likelihood(params):
inst = model.instance_from_vector(vector=params, xp=jnp)
t = al.Tracer(galaxies=list(inst.galaxies))
t = al.Tracer(galaxies=list(params.galaxies))

tti = al.TracerToInversion(
dataset=aa.DatasetInterface(
Expand Down Expand Up @@ -545,7 +552,7 @@ def step_log_likelihood(params):
noise_norm = jnp.sum(jnp.log(2 * jnp.pi * noise_map_array ** 2))
return -0.5 * (chi_squared + noise_norm)

test_grad("Step 8: Log likelihood", step_log_likelihood, jnp_params)
test_grad("Step 8: Log likelihood", step_log_likelihood, params_tree)


# ===================================================================
Expand All @@ -571,8 +578,7 @@ def step_log_likelihood(params):

def _build_Q_q(params):
"""Rebuild (curvature_reg_matrix, data_vector) for the given params."""
inst = model.instance_from_vector(vector=params, xp=jnp)
t = al.Tracer(galaxies=list(inst.galaxies))
t = al.Tracer(galaxies=list(params.galaxies))
tti = al.TracerToInversion(
dataset=aa.DatasetInterface(
data=fit.profile_subtracted_image,
Expand Down Expand Up @@ -605,7 +611,7 @@ def _build_Q_q(params):
return Q, q


Q_eval, q_eval = _build_Q_q(jnp_params)
Q_eval, q_eval = _build_Q_q(params_tree)
Q_np = np.array(Q_eval)
q_np = np.array(q_eval)

Expand Down Expand Up @@ -684,8 +690,13 @@ def _loss(params):
return jnp.sum(x_p)

try:
val, grad = jax.value_and_grad(_loss)(jnp_params)
grad_np = np.array(grad)
val, grad = jax.value_and_grad(_loss)(params_tree)
grad_leaves = jax.tree_util.tree_leaves(grad)
grad_np = (
np.concatenate([np.asarray(l).ravel() for l in grad_leaves])
if grad_leaves
else np.array([])
)
n_nan = int(np.sum(~np.isfinite(grad_np)))
print(f" grad finite entries: {grad_np.size - n_nan}/{grad_np.size}")
if n_nan < grad_np.size:
Expand Down Expand Up @@ -735,26 +746,24 @@ def _loss(params):


# ===================================================================
# PART C -- Full pipeline gradient (via Fitness)
# PART C -- Full pipeline gradient (via AnalysisImaging)
# ===================================================================

print("\n" + "=" * 70)
print("PART C -- FULL PIPELINE GRADIENT (via Fitness)")
print("PART C -- FULL PIPELINE GRADIENT (via AnalysisImaging)")
print("=" * 70)

from autofit.non_linear.fitness import Fitness
analysis = al.AnalysisImaging(dataset=dataset, use_jax=True)

analysis = al.AnalysisImaging(dataset=dataset)
def full_pipeline_from_params(params_tree):
return analysis.log_likelihood_function(instance=params_tree)

fitness = Fitness(
model=model,
analysis=analysis,
fom_is_log_likelihood=True,
resample_figure_of_merit=-1.0e99,
test_grad(
"Full pipeline (AnalysisImaging.log_likelihood)",
full_pipeline_from_params,
params_tree,
)

test_grad("Full pipeline (Fitness.call)", fitness.call, jnp_params)


# ===================================================================
# PART D -- Summary
Expand Down