Skip to content

Migrate mge_gradients.py JAX profiling to pytree inputs#13

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/mge-gradients-pytree-migration
Apr 17, 2026
Merged

Migrate mge_gradients.py JAX profiling to pytree inputs#13
Jammy2211 merged 1 commit into
mainfrom
feature/mge-gradients-pytree-migration

Conversation

@Jammy2211

Copy link
Copy Markdown
Contributor

Summary

Migrates jax_profiling/imaging/mge_gradients.py to pass a ModelInstance pytree through jit/value_and_grad instead of a flat 1D parameter vector. Matches the pattern shipped in #11 (mge.py) and #12 (pixelization.py), now that autofit.jax.register_model is on main (PyAutoFit#1220, #1221).

Removes Fitness + Model.instance_from_vector from JIT traces so gradient profiling reflects how a real pytree-native sampler would run.

Scripts Changed

  • jax_profiling/imaging/mge_gradients.py — per-step JIT closures take params_tree (built once via jax.tree_util.tree_map(jnp.asarray, instance)) instead of a flat jnp_params vector; test_grad helper tree-flattens the gradient output; PART C full-pipeline gradient now calls AnalysisImaging(use_jax=True).log_likelihood_function(instance=params_tree) in place of Fitness.call; NNLS _diagnose_kappa inner loss updated the same way.

Test Plan

  • Script runs end-to-end from the workspace root
  • All 9 per-step gradient tests PASS (norms match pre-migration)
  • All 4 NNLS kappa diagnostics (1e-3, 1e-2, 1e-1, 1.0) report "FULLY FINITE GRADIENTS"
  • Full-pipeline gradient (PART C) PASSES via AnalysisImaging

Related: follow-up to #10.

🤖 Generated with Claude Code

Matches the pattern shipped in mge.py (#11) and pixelization.py (#12):
pass a ModelInstance pytree through jit/value_and_grad instead of a
flat 1D vector + Fitness/instance_from_vector inside the trace.

- register_model(model) once after eager instance build
- params_tree = jax.tree_util.tree_map(jnp.asarray, instance)
- per-step closures take params_tree and build Tracer from
  params_tree.galaxies (dropping inner instance_from_vector calls)
- test_grad tree-flattens gradient output for statistics
- PART C full pipeline uses AnalysisImaging(use_jax=True)
  .log_likelihood_function(instance=params_tree) in place of Fitness
- NNLS kappa diagnostic (_diagnose_kappa) uses pytree params

Verified: all 9 gradient tests PASS, all 4 NNLS kappa values report
FULLY FINITE GRADIENTS.
@Jammy2211 Jammy2211 merged commit d811108 into main Apr 17, 2026
@Jammy2211 Jammy2211 deleted the feature/mge-gradients-pytree-migration branch April 17, 2026 09:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant