Skip to content

refactor(jax_profiling): migrate imaging/pixelization.py to pytree inputs#12

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/pixelization-pytree-migration
Apr 17, 2026
Merged

refactor(jax_profiling): migrate imaging/pixelization.py to pytree inputs#12
Jammy2211 merged 1 commit into
mainfrom
feature/pixelization-pytree-migration

Conversation

@Jammy2211

Copy link
Copy Markdown
Contributor

Summary

Migrates jax_profiling/imaging/pixelization.py to pass a ModelInstance pytree directly into jit/vmap, dropping the flat-vector + model.instance_from_vector pattern that used to live inside the trace. This brings pixelization.py in line with the mge.py migration shipped in #11 on 2026-04-16 and makes the profiling match how a real pytree-input sampler would actually run.

Required paired library fix: PyAutoFit#1221 — without it, Galaxy(pixelization=<Pixelization>) becomes a JAX tracer under JIT and fit.inversion silently resolves to None.

Scripts Changed

  • jax_profiling/imaging/pixelization.py — call autofit.jax.register_model(model) up front; build params_tree = jax.tree_util.tree_map(jnp.asarray, instance); rewrite the two per-step closures blurred_image_from_params and blurred_mm_from_params to take params_tree directly (no more instance_from_vector inside the trace); replace full-pipeline Fitness wrapper with AnalysisImaging(use_jax=True).log_likelihood_function(instance=...); broadcast pytree leaves via tree_map for the vmap step; reuse vmapped_full.lower(parameters) for static-memory analysis. Per-step raw-array closures (ray-trace, inversion linear algebra) are unchanged — they don't take params.
  • jax_profiling/imaging/results/pixelization_likelihood_summary_hst_v2026.4.13.6.json — regenerated.
  • jax_profiling/imaging/results/pixelization_likelihood_summary_hst_v2026.4.13.6.png — regenerated.

Upstream PR

PyAutoLabs/PyAutoFit#1221

Test Plan

  • Script runs end-to-end with NUMBA_CACHE_DIR=/tmp/numba_cache MPLCONFIGDIR=/tmp/matplotlib python jax_profiling/imaging/pixelization.py.
  • Step-by-step correctness assertion: inversion-matrix log_evidence matches FitImaging.log_evidence reference within rtol=1e-4.
  • vmap correctness assertion: batch-of-3 log_evidence matches single-JIT result within rtol=1e-4.
  • No TracerBoolConversionErrorregister_model partitions redshift, pixelization mesh shape, and any concrete Galaxy kwargs into aux_data.
  • Do NOT merge until PyAutoFit#1221 is merged — workspace depends on the library bug fix.

🤖 Generated with Claude Code

…puts

Replaces the flat-vector JIT/vmap pattern (model.instance_from_vector
called inside the trace) with direct pytree-instance inputs, matching
the mge.py migration shipped in autolens_workspace_developer#11.

Changes:
- Call autofit.jax.register_model(model) up front, then build a
  params_tree via jax.tree_util.tree_map(jnp.asarray, instance).
- Two per-step JIT closures (blurred_image_from_params,
  blurred_mm_from_params) now take params_tree directly. No more
  rebuilding the ModelInstance from a flat vector inside the trace.
- Full-pipeline PART C: Fitness wrapper replaced with
  AnalysisImaging(use_jax=True).log_likelihood_function(instance=...).
- vmap PART D: broadcast pytree leaves via tree_map rather than tiling
  a flat vector; use jax.jit(jax.vmap(full_pipeline_from_params)).
- Static memory PART E: reuse vmapped_full.lower(parameters).

Per-step raw-array closures (ray_trace_raw, inversion linear algebra
matrices, etc.) are unchanged — they don't take params.

Both correctness assertions continue to pass end-to-end:
  - step-by-step log_evidence vs reference: rtol=1e-4 OK
  - vmap batch=3 vs single JIT: rtol=1e-4 OK

Depends on PyAutoFit#1221 (fix for concrete Pixelization kwarg being
traced under JIT instead of kept as aux_data).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@Jammy2211 Jammy2211 added the pending-release Pending release build label Apr 17, 2026
@Jammy2211 Jammy2211 merged commit 2309ec1 into main Apr 17, 2026
@Jammy2211 Jammy2211 deleted the feature/pixelization-pytree-migration branch April 17, 2026 09:25
Jammy2211 pushed a commit that referenced this pull request Apr 17, 2026
Matches the pattern shipped in mge.py (#11) and pixelization.py (#12):
pass a ModelInstance pytree through jit/value_and_grad instead of a
flat 1D vector + Fitness/instance_from_vector inside the trace.

- register_model(model) once after eager instance build
- params_tree = jax.tree_util.tree_map(jnp.asarray, instance)
- per-step closures take params_tree and build Tracer from
  params_tree.galaxies (dropping inner instance_from_vector calls)
- test_grad tree-flattens gradient output for statistics
- PART C full pipeline uses AnalysisImaging(use_jax=True)
  .log_likelihood_function(instance=params_tree) in place of Fitness
- NNLS kappa diagnostic (_diagnose_kappa) uses pytree params

Verified: all 9 gradient tests PASS, all 4 NNLS kappa values report
FULLY FINITE GRADIENTS.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release Pending release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant