Skip to content

feat: register FitImaging, DatasetModel, Galaxies as pytrees in AnalysisImaging#364

Merged
Jammy2211 merged 2 commits into
mainfrom
feature/autogalaxy-wst-jax-lh-imaging
Apr 22, 2026
Merged

feat: register FitImaging, DatasetModel, Galaxies as pytrees in AnalysisImaging#364
Jammy2211 merged 2 commits into
mainfrom
feature/autogalaxy-wst-jax-lh-imaging

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Register FitImaging, DatasetModel, and Galaxies as JAX pytrees in AnalysisImaging.fit_from, mirroring autolens's _register_fit_imaging_pytrees implementation. Unblocks jax.jit(analysis.fit_from)(instance) returning a FitImaging with jax.Array leaves on the autogalaxy imaging path. Required by autogalaxy_workspace_test's upcoming jax_likelihood_functions/imaging/ scripts (PyAutoLabs/autogalaxy_workspace_test#8).

API Changes

None — internal changes only. _register_fit_imaging_pytrees is a new private staticmethod on AnalysisImaging; fit_from now calls it when use_jax=True (no behaviour change on the NumPy path). See full details below.

Test Plan

  • test_autogalaxy/imaging/model/test_analysis_imaging.py — new test asserts FitImaging, DatasetModel, Galaxies register into _pytree_registered_classes
  • test_autogalaxy/imaging/ passes (69/69)
  • test_autogalaxy/analysis/ passes (33/33)
  • Integration: autogalaxy_workspace_test/scripts/jax_likelihood_functions/imaging/ scripts (follow-up PR)
Full API Changes (for automation & release notes)

Added

  • autogalaxy.imaging.model.analysis.AnalysisImaging._register_fit_imaging_pytrees() (staticmethod, private) — registers ag.FitImaging, aa.DatasetModel, and ag.Galaxies with autoarray.abstract_ndarray.register_instance_pytree. FitImaging is registered with no_flatten=("dataset", "adapt_images", "settings") so per-fit constants ride as pytree aux data. DatasetModel registration is idempotent — it is already registered by autolens.imaging.model.analysis.AnalysisImaging._register_fit_imaging_pytrees in the autolens package; register_instance_pytree is guarded by _pytree_registered_classes and skips duplicate calls.

Changed Behaviour

  • AnalysisImaging.fit_from(instance) — now calls self._register_fit_imaging_pytrees() on entry when self._use_jax is True. No effect on the NumPy path (use_jax=False).

Migration

None — no existing user-visible API is altered.

🤖 Generated with Claude Code

…sisImaging

Adds `_register_fit_imaging_pytrees` staticmethod to `AnalysisImaging`, called
from `fit_from` under the existing `use_jax` gate. Mirrors autolens's
`AnalysisImaging._register_fit_imaging_pytrees`. Unblocks
`jax.jit(analysis.fit_from)(instance)` returning a `FitImaging` with
`jax.Array` leaves on the autogalaxy imaging path.

Part of PyAutoLabs/autogalaxy_workspace_test#8 (epic #5, task 3/9).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`Galaxies` inherits from `list`, so the generic `__dict__`-based flatten in
`register_instance_pytree` dropped the list contents across the JIT boundary.
Register `Galaxies` with a custom flatten/unflatten that carries list items
as dynamic children and any `__dict__` state as aux.

Without this, `jax.jit(analysis.fit_from)(instance)` round-tripped but the
resulting `FitImaging.galaxies` was an empty `Galaxies`, causing
`galaxies.image_2d_from` to return `sum([])` (int 0) and fail in autoarray's
`Array2D` constructor on `fit.log_likelihood` access.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@Jammy2211
Copy link
Copy Markdown
Collaborator Author

Workspace PR: PyAutoLabs/autogalaxy_workspace_test#9

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant