diff --git a/jax_profiling/imaging/mge_gradients.py b/jax_profiling/imaging/mge_gradients.py index 512bb80..76a819a 100644 --- a/jax_profiling/imaging/mge_gradients.py +++ b/jax_profiling/imaging/mge_gradients.py @@ -79,13 +79,20 @@ def test_grad(label, func, params): # Force evaluation. if hasattr(value, "block_until_ready"): value.block_until_ready() - if hasattr(grad, "block_until_ready"): - grad.block_until_ready() val_f = float(value) - grad_np = np.array(grad) + grad_leaves = jax.tree_util.tree_leaves(grad) + for leaf in grad_leaves: + if hasattr(leaf, "block_until_ready"): + leaf.block_until_ready() + grad_np = ( + np.concatenate([np.asarray(l).ravel() for l in grad_leaves]) + if grad_leaves + else np.array([]) + ) print(f" value = {val_f:.8g}") + print(f" grad leaves = {len(grad_leaves)}") print(f" grad shape = {grad_np.shape}") print(f" grad norm = {np.linalg.norm(grad_np):.8g}") print(f" grad min = {grad_np.min():.8g}") @@ -244,6 +251,14 @@ def test_grad(label, func, params): data_array = jnp.array(dataset.data.array) noise_map_array = jnp.array(dataset.noise_map.array) +# Register pytree support for the model and convert the sampled ModelInstance +# into a JAX-array-valued pytree. Each step closure differentiates w.r.t. this +# tree directly rather than rebuilding the instance from a flat vector inside +# the trace. Matches the pattern used by mge.py and pixelization.py. +from autofit.jax import register_model as _register_model_pytrees +_register_model_pytrees(model) +params_tree = jax.tree_util.tree_map(jnp.asarray, instance) + # =================================================================== # PART B -- Per-step gradient testing @@ -258,22 +273,20 @@ def test_grad(label, func, params): # --------------------------------------------------------------------------- def step_ray_trace(params): - inst = model.instance_from_vector(vector=params, xp=jnp) - t = al.Tracer(galaxies=list(inst.galaxies)) + t = al.Tracer(galaxies=list(params.galaxies)) grid_raw = jnp.array(grid_lp.array) grid = aa.Grid2DIrregular(values=grid_raw, xp=jnp) traced = t.traced_grid_2d_list_from(grid=grid, xp=jnp) return jnp.sum(jnp.stack([tg.array for tg in traced])) -test_grad("Step 1: Ray-trace grids", step_ray_trace, jnp_params) +test_grad("Step 1: Ray-trace grids", step_ray_trace, params_tree) # --------------------------------------------------------------------------- # Step 2: Mapping matrix (linear profile images) # --------------------------------------------------------------------------- def step_mapping_matrix(params): - inst = model.instance_from_vector(vector=params, xp=jnp) - t = al.Tracer(galaxies=list(inst.galaxies)) + t = al.Tracer(galaxies=list(params.galaxies)) tti = al.TracerToInversion( dataset=aa.DatasetInterface( data=fit.profile_subtracted_image, @@ -291,15 +304,14 @@ def step_mapping_matrix(params): mm = jnp.hstack(matrices) if len(matrices) > 1 else matrices[0] return jnp.sum(mm) -test_grad("Step 2: Mapping matrix", step_mapping_matrix, jnp_params) +test_grad("Step 2: Mapping matrix", step_mapping_matrix, params_tree) # --------------------------------------------------------------------------- # Step 3: Blurred mapping matrix (PSF convolution of each profile) # --------------------------------------------------------------------------- def step_blurred_mapping_matrix(params): - inst = model.instance_from_vector(vector=params, xp=jnp) - t = al.Tracer(galaxies=list(inst.galaxies)) + t = al.Tracer(galaxies=list(params.galaxies)) tti = al.TracerToInversion( dataset=aa.DatasetInterface( data=fit.profile_subtracted_image, @@ -317,15 +329,14 @@ def step_blurred_mapping_matrix(params): bmm = jnp.hstack(matrices) if len(matrices) > 1 else matrices[0] return jnp.sum(bmm) -test_grad("Step 3: Blurred mapping matrix", step_blurred_mapping_matrix, jnp_params) +test_grad("Step 3: Blurred mapping matrix", step_blurred_mapping_matrix, params_tree) # --------------------------------------------------------------------------- # Step 4: Data vector (D) # --------------------------------------------------------------------------- def step_data_vector(params): - inst = model.instance_from_vector(vector=params, xp=jnp) - t = al.Tracer(galaxies=list(inst.galaxies)) + t = al.Tracer(galaxies=list(params.galaxies)) tti = al.TracerToInversion( dataset=aa.DatasetInterface( @@ -350,15 +361,14 @@ def step_data_vector(params): ) return jnp.sum(data_vector) -test_grad("Step 4: Data vector (D)", step_data_vector, jnp_params) +test_grad("Step 4: Data vector (D)", step_data_vector, params_tree) # --------------------------------------------------------------------------- # Step 5: Curvature matrix (F) # --------------------------------------------------------------------------- def step_curvature_matrix(params): - inst = model.instance_from_vector(vector=params, xp=jnp) - t = al.Tracer(galaxies=list(inst.galaxies)) + t = al.Tracer(galaxies=list(params.galaxies)) tti = al.TracerToInversion( dataset=aa.DatasetInterface( data=fit.profile_subtracted_image, @@ -385,15 +395,14 @@ def step_curvature_matrix(params): ) return jnp.sum(curvature) -test_grad("Step 5: Curvature matrix (F)", step_curvature_matrix, jnp_params) +test_grad("Step 5: Curvature matrix (F)", step_curvature_matrix, params_tree) # --------------------------------------------------------------------------- # Step 6: Reconstruction (NNLS) # --------------------------------------------------------------------------- def step_reconstruction(params): - inst = model.instance_from_vector(vector=params, xp=jnp) - t = al.Tracer(galaxies=list(inst.galaxies)) + t = al.Tracer(galaxies=list(params.galaxies)) tti = al.TracerToInversion( dataset=aa.DatasetInterface( @@ -433,15 +442,14 @@ def step_reconstruction(params): ) return jnp.sum(reconstruction) -test_grad("Step 6: Reconstruction (NNLS)", step_reconstruction, jnp_params) +test_grad("Step 6: Reconstruction (NNLS)", step_reconstruction, params_tree) # --------------------------------------------------------------------------- # Step 7: Mapped reconstructed image # --------------------------------------------------------------------------- def step_mapped_recon(params): - inst = model.instance_from_vector(vector=params, xp=jnp) - t = al.Tracer(galaxies=list(inst.galaxies)) + t = al.Tracer(galaxies=list(params.galaxies)) tti = al.TracerToInversion( dataset=aa.DatasetInterface( @@ -487,15 +495,14 @@ def step_mapped_recon(params): ) return jnp.sum(mapped_recon) -test_grad("Step 7: Mapped reconstructed image", step_mapped_recon, jnp_params) +test_grad("Step 7: Mapped reconstructed image", step_mapped_recon, params_tree) # --------------------------------------------------------------------------- # Step 8: Log likelihood (chi-squared) # --------------------------------------------------------------------------- def step_log_likelihood(params): - inst = model.instance_from_vector(vector=params, xp=jnp) - t = al.Tracer(galaxies=list(inst.galaxies)) + t = al.Tracer(galaxies=list(params.galaxies)) tti = al.TracerToInversion( dataset=aa.DatasetInterface( @@ -545,7 +552,7 @@ def step_log_likelihood(params): noise_norm = jnp.sum(jnp.log(2 * jnp.pi * noise_map_array ** 2)) return -0.5 * (chi_squared + noise_norm) -test_grad("Step 8: Log likelihood", step_log_likelihood, jnp_params) +test_grad("Step 8: Log likelihood", step_log_likelihood, params_tree) # =================================================================== @@ -571,8 +578,7 @@ def step_log_likelihood(params): def _build_Q_q(params): """Rebuild (curvature_reg_matrix, data_vector) for the given params.""" - inst = model.instance_from_vector(vector=params, xp=jnp) - t = al.Tracer(galaxies=list(inst.galaxies)) + t = al.Tracer(galaxies=list(params.galaxies)) tti = al.TracerToInversion( dataset=aa.DatasetInterface( data=fit.profile_subtracted_image, @@ -605,7 +611,7 @@ def _build_Q_q(params): return Q, q -Q_eval, q_eval = _build_Q_q(jnp_params) +Q_eval, q_eval = _build_Q_q(params_tree) Q_np = np.array(Q_eval) q_np = np.array(q_eval) @@ -684,8 +690,13 @@ def _loss(params): return jnp.sum(x_p) try: - val, grad = jax.value_and_grad(_loss)(jnp_params) - grad_np = np.array(grad) + val, grad = jax.value_and_grad(_loss)(params_tree) + grad_leaves = jax.tree_util.tree_leaves(grad) + grad_np = ( + np.concatenate([np.asarray(l).ravel() for l in grad_leaves]) + if grad_leaves + else np.array([]) + ) n_nan = int(np.sum(~np.isfinite(grad_np))) print(f" grad finite entries: {grad_np.size - n_nan}/{grad_np.size}") if n_nan < grad_np.size: @@ -735,26 +746,24 @@ def _loss(params): # =================================================================== -# PART C -- Full pipeline gradient (via Fitness) +# PART C -- Full pipeline gradient (via AnalysisImaging) # =================================================================== print("\n" + "=" * 70) -print("PART C -- FULL PIPELINE GRADIENT (via Fitness)") +print("PART C -- FULL PIPELINE GRADIENT (via AnalysisImaging)") print("=" * 70) -from autofit.non_linear.fitness import Fitness +analysis = al.AnalysisImaging(dataset=dataset, use_jax=True) -analysis = al.AnalysisImaging(dataset=dataset) +def full_pipeline_from_params(params_tree): + return analysis.log_likelihood_function(instance=params_tree) -fitness = Fitness( - model=model, - analysis=analysis, - fom_is_log_likelihood=True, - resample_figure_of_merit=-1.0e99, +test_grad( + "Full pipeline (AnalysisImaging.log_likelihood)", + full_pipeline_from_params, + params_tree, ) -test_grad("Full pipeline (Fitness.call)", fitness.call, jnp_params) - # =================================================================== # PART D -- Summary