feat: register FitImaging, DatasetModel, Galaxies as pytrees in AnalysisImaging#364
Merged
Merged
Conversation
…sisImaging Adds `_register_fit_imaging_pytrees` staticmethod to `AnalysisImaging`, called from `fit_from` under the existing `use_jax` gate. Mirrors autolens's `AnalysisImaging._register_fit_imaging_pytrees`. Unblocks `jax.jit(analysis.fit_from)(instance)` returning a `FitImaging` with `jax.Array` leaves on the autogalaxy imaging path. Part of PyAutoLabs/autogalaxy_workspace_test#8 (epic #5, task 3/9). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`Galaxies` inherits from `list`, so the generic `__dict__`-based flatten in `register_instance_pytree` dropped the list contents across the JIT boundary. Register `Galaxies` with a custom flatten/unflatten that carries list items as dynamic children and any `__dict__` state as aux. Without this, `jax.jit(analysis.fit_from)(instance)` round-tripped but the resulting `FitImaging.galaxies` was an empty `Galaxies`, causing `galaxies.image_2d_from` to return `sum([])` (int 0) and fail in autoarray's `Array2D` constructor on `fit.log_likelihood` access. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Merged
3 tasks
Collaborator
Author
|
Workspace PR: PyAutoLabs/autogalaxy_workspace_test#9 |
This was referenced Apr 26, 2026
This was referenced May 8, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Register
FitImaging,DatasetModel, andGalaxiesas JAX pytrees inAnalysisImaging.fit_from, mirroring autolens's_register_fit_imaging_pytreesimplementation. Unblocksjax.jit(analysis.fit_from)(instance)returning aFitImagingwithjax.Arrayleaves on the autogalaxy imaging path. Required by autogalaxy_workspace_test's upcomingjax_likelihood_functions/imaging/scripts (PyAutoLabs/autogalaxy_workspace_test#8).API Changes
None — internal changes only.
_register_fit_imaging_pytreesis a new private staticmethod onAnalysisImaging;fit_fromnow calls it whenuse_jax=True(no behaviour change on the NumPy path). See full details below.Test Plan
test_autogalaxy/imaging/model/test_analysis_imaging.py— new test assertsFitImaging,DatasetModel,Galaxiesregister into_pytree_registered_classestest_autogalaxy/imaging/passes (69/69)test_autogalaxy/analysis/passes (33/33)Full API Changes (for automation & release notes)
Added
autogalaxy.imaging.model.analysis.AnalysisImaging._register_fit_imaging_pytrees()(staticmethod, private) — registersag.FitImaging,aa.DatasetModel, andag.Galaxieswithautoarray.abstract_ndarray.register_instance_pytree.FitImagingis registered withno_flatten=("dataset", "adapt_images", "settings")so per-fit constants ride as pytree aux data.DatasetModelregistration is idempotent — it is already registered byautolens.imaging.model.analysis.AnalysisImaging._register_fit_imaging_pytreesin the autolens package;register_instance_pytreeis guarded by_pytree_registered_classesand skips duplicate calls.Changed Behaviour
AnalysisImaging.fit_from(instance)— now callsself._register_fit_imaging_pytrees()on entry whenself._use_jaxis True. No effect on the NumPy path (use_jax=False).Migration
None — no existing user-visible API is altered.
🤖 Generated with Claude Code