From aee5e94d3948fe01aa7837fee1a5dc9b6aa65be8 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 16 May 2026 14:22:31 +0100 Subject: [PATCH] fix(inversion): make AbstractMeshGeometry picklable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Carve-out from PyAutoFit #1279 Q2 (Phase 4 of the JAX visualization roadmap). A picklability spike found FitImaging cannot be pickled today because AbstractMeshGeometry.__init__ stores `self._xp = xp` — the literal numpy or jax.numpy module — and Python's pickle cannot serialise module references. This blocked sending a FitImaging over an mp.Process+Queue or ProcessPoolExecutor boundary, which is the production target for Phase 4 subprocess visualization. Replace the module-attribute pattern with `self._use_jax: bool` + `_xp` as a property derived from that flag. Same pattern already used in Analysis._xp (PyAutoFit) and AbstractMaker._xp (PyAutoArray decorators per CLAUDE.md). All existing `self._xp` reads continue to work transparently via the property. End-to-end verified: a populated FitImaging round-trips through pickle.dumps/loads with log_likelihood delta=0.00e+00 on both numpy and JAX backends. Pickle size ~4.6 MB for a Rectangular-adaptive- density pixelization fit. Closes #320. Carve-out from #1279. --- .../inversion/mesh/mesh_geometry/abstract.py | 9 ++- .../mesh_geometry/test_picklability.py | 77 +++++++++++++++++++ 2 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 test_autoarray/inversion/pixelization/mesh_geometry/test_picklability.py diff --git a/autoarray/inversion/mesh/mesh_geometry/abstract.py b/autoarray/inversion/mesh/mesh_geometry/abstract.py index bc448fcc1..8ac7f1dc5 100644 --- a/autoarray/inversion/mesh/mesh_geometry/abstract.py +++ b/autoarray/inversion/mesh/mesh_geometry/abstract.py @@ -20,4 +20,11 @@ def __init__( # When non-None, rectangular geometry uses the spline-CDF helpers # instead of the linear-interp CDF (areas / edges transforms only). self.spline_deg = spline_deg - self._xp = xp + self._use_jax = xp is not np + + @property + def _xp(self): + if self._use_jax: + import jax.numpy as jnp + return jnp + return np diff --git a/test_autoarray/inversion/pixelization/mesh_geometry/test_picklability.py b/test_autoarray/inversion/pixelization/mesh_geometry/test_picklability.py new file mode 100644 index 000000000..d869febed --- /dev/null +++ b/test_autoarray/inversion/pixelization/mesh_geometry/test_picklability.py @@ -0,0 +1,77 @@ +"""Pickle round-trip tests for AbstractMeshGeometry subclasses. + +Required by subprocess visualization (PyAutoFit #1279, Phase 4 of the JAX +visualization roadmap). A populated FitImaging must be sendable over an +mp.Process+Queue or ProcessPoolExecutor boundary — historically blocked +by `self._xp = xp` (module attribute) on AbstractMeshGeometry, since +Python's pickle cannot serialise module references. + +Fix: `_xp` is now a property derived from a boolean `_use_jax` flag. +This file is the regression test for that invariant. +""" + +import importlib.util +import pickle + +import numpy as np +import pytest + +from autoarray.inversion.mesh.mesh_geometry.rectangular import MeshGeometryRectangular +from autoarray.inversion.mesh.mesh_geometry.delaunay import MeshGeometryDelaunay + + +def _jax_installed() -> bool: + return importlib.util.find_spec("jax") is not None + + +@pytest.mark.parametrize("cls", [MeshGeometryRectangular, MeshGeometryDelaunay]) +def test_pickle_round_trip_numpy_backend(cls): + """A numpy-backed MeshGeometry instance must round-trip through pickle + with `_xp` restored to the numpy module.""" + mg = cls.__new__(cls) + mg._use_jax = False + + restored = pickle.loads(pickle.dumps(mg)) + + assert restored._use_jax is False + assert restored._xp is np + + +@pytest.mark.skipif(not _jax_installed(), reason="jax not installed") +@pytest.mark.parametrize("cls", [MeshGeometryRectangular, MeshGeometryDelaunay]) +def test_pickle_round_trip_jax_backend(cls): + """A JAX-backed MeshGeometry instance must round-trip through pickle + with `_xp` restored to the jax.numpy module.""" + import jax.numpy as jnp + + mg = cls.__new__(cls) + mg._use_jax = True + + restored = pickle.loads(pickle.dumps(mg)) + + assert restored._use_jax is True + assert restored._xp is jnp + + +def test_use_jax_inferred_from_xp_kwarg_in_init(): + """The __init__ continues to accept an `xp=` kwarg but stores it as + a boolean — modules never land on the instance.""" + import types + + # Use __new__ + manual __init__ call with stubbed positional args to + # avoid pulling in Mesh / Grid construction. The invariant under test + # is post-__init__ state. + mg = MeshGeometryRectangular.__new__(MeshGeometryRectangular) + MeshGeometryRectangular.__init__( + mg, + mesh=None, + mesh_grid=None, + data_grid=None, + xp=np, + ) + assert mg._use_jax is False + # No module attribute should exist on the instance. + module_attrs = [ + name for name, val in vars(mg).items() if isinstance(val, types.ModuleType) + ] + assert module_attrs == [], f"unexpected module attrs on instance: {module_attrs}"