From 397124d521fd8c3f74a70dd3d138922a61b5c408 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 14 May 2026 13:24:26 +0100 Subject: [PATCH] fix: register DatasetModel pytree in _register_fit_quantity_pytrees MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The imaging and interferometer analogues both register DatasetModel alongside their Fit* class (autogalaxy/imaging/model/analysis.py:186, autogalaxy/interferometer/model/analysis.py:183). The quantity registration shipped in #401 omitted it, leaving a TypeError trap: when jax.jit flattens a FitQuantity, the DatasetModel attribute (reachable via the aa.FitImaging base-class) hits "not a valid JAX type" because it's neither static-aux nor a registered pytree. Surfaced when writing autogalaxy_workspace_test/scripts/quantity/ visualization_jax.py — the script needed a `register_instance_pytree (DatasetModel)` workaround at module level just to JIT-flatten the fit returned by `analysis.fit_for_visualization`. Fix: add `register_instance_pytree(DatasetModel)` to `_register_fit_quantity_pytrees`, mirroring the imaging and interferometer registrations. Workaround in the workspace_test JAX script will be removed in the follow-up PR. 18/18 test_autogalaxy/quantity/ tests pass — registration only fires under `use_jax=True`, no impact on existing NumPy-path tests. --- autogalaxy/quantity/model/analysis.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/autogalaxy/quantity/model/analysis.py b/autogalaxy/quantity/model/analysis.py index f964190b..2001168f 100644 --- a/autogalaxy/quantity/model/analysis.py +++ b/autogalaxy/quantity/model/analysis.py @@ -78,12 +78,14 @@ def _register_fit_quantity_pytrees() -> None: carry the traced model arrays and ride as pytree children. """ from autoarray.abstract_ndarray import register_instance_pytree + from autoarray.dataset.dataset_model import DatasetModel from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree register_instance_pytree( FitQuantity, no_flatten=("dataset", "func_str", "use_mask_in_fit"), ) + register_instance_pytree(DatasetModel) register_galaxies_pytree() def log_likelihood_function(self, instance: af.ModelInstance) -> float: