Skip to content

feat: migrate jax_profiling/imaging/mge.py to pytree inputs#11

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/imaging-mge-pytree-migration
Apr 16, 2026
Merged

feat: migrate jax_profiling/imaging/mge.py to pytree inputs#11
Jammy2211 merged 1 commit into
mainfrom
feature/imaging-mge-pytree-migration

Conversation

@Jammy2211

Copy link
Copy Markdown
Contributor

Summary

Replaces the flat 1D parameter vector used by the MGE JAX profiling script with a pytree ModelInstance built via the new autofit.jax.register_model(model) API. The full-pipeline closure now runs analysis.log_likelihood_function(instance=params_tree) directly under jax.jit + jax.vmap — no instance_from_vector inside the trace.

Fitness.call and model.instance_from_vector are untouched, so the production sampler path is unaffected. This is a workspace-side demonstration of the opt-in pytree API landing upstream.

Scripts Changed

  • jax_profiling/imaging/mge.py — register pytrees once, then pass the ModelInstance pytree directly into jit/vmap; replaces Fitness.call(param_vector) with analysis.log_likelihood_function(instance=params_tree) via al.AnalysisImaging(dataset, use_jax=True); updates the vmap section to broadcast leaves with jax.tree_util.tree_map instead of stacking a flat vector.
  • jax_profiling/imaging/results/mge_likelihood_summary_hst_v2026.4.13.6.json — regenerated summary from the pytree pipeline.
  • jax_profiling/imaging/results/mge_likelihood_summary_hst_v2026.4.13.6.png — regenerated bar chart.

Upstream PRs

Test Plan

  • python jax_profiling/imaging/mge.py runs end-to-end on hst instrument
  • Step-by-step assertion passes: per-step chain matches FitImaging.log_likelihood
  • vmap correctness check passes: batched results identical across the batch
  • Full-pipeline JIT compile + steady-state timings recorded and written to the results summary

🤖 Generated with Claude Code

Replaces the flat 1D parameter vector passed into jit/vmap with a
ModelInstance pytree built via autofit.jax.register_model(model). The
full-pipeline closure now takes the pytree directly and runs
analysis.log_likelihood_function(instance=params_tree) under jit+vmap.

Fitness.call and model.instance_from_vector are unchanged — the
production path is untouched; this is a workspace-side demonstration of
the new opt-in pytree API landing in PyAutoConf #93 and PyAutoFit #1220.

Regenerates the corresponding results summary.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@Jammy2211 Jammy2211 added the pending-release Pending release build label Apr 16, 2026
@Jammy2211 Jammy2211 merged commit 53764c1 into main Apr 16, 2026
@Jammy2211 Jammy2211 deleted the feature/imaging-mge-pytree-migration branch April 16, 2026 20:56
Jammy2211 pushed a commit that referenced this pull request Apr 17, 2026
Matches the pattern shipped in mge.py (#11) and pixelization.py (#12):
pass a ModelInstance pytree through jit/value_and_grad instead of a
flat 1D vector + Fitness/instance_from_vector inside the trace.

- register_model(model) once after eager instance build
- params_tree = jax.tree_util.tree_map(jnp.asarray, instance)
- per-step closures take params_tree and build Tracer from
  params_tree.galaxies (dropping inner instance_from_vector calls)
- test_grad tree-flattens gradient output for statistics
- PART C full pipeline uses AnalysisImaging(use_jax=True)
  .log_likelihood_function(instance=params_tree) in place of Fitness
- NNLS kappa diagnostic (_diagnose_kappa) uses pytree params

Verified: all 9 gradient tests PASS, all 4 NNLS kappa values report
FULLY FINITE GRADIENTS.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release Pending release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant