From 9ceda4e24da78ce1e165642cc1079faf56444c35 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 24 May 2026 16:56:47 +0100 Subject: [PATCH] feat: add use_jax=True to SimulatorInterferometer + xp-aware preprocess MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds use_jax constructor flag to aa.SimulatorInterferometer and threads xp through via_image_from. When use_jax=True, complex Gaussian noise routes through jax.random.PRNGKey + jax.random.normal instead of numpy's RNG. Also fixes the pre-existing signature asymmetry: aa.SimulatorImaging.via_image_from accepts xp=None but aa.SimulatorInterferometer.via_image_from did not. Both now match: via_image_from(image, xp=None). preprocess.gaussian_noise_via_shape_and_sigma_from, data_with_gaussian_noise_added, and data_with_complex_gaussian_noise_added all gain xp=np parameter, with JAX path routing through jax.random. The NaN-noise-map runtime guard is now NumPy-side only (same pattern as SimulatorImaging PR — Python `if :` triggers TracerBoolConversionError under JAX). Eager JAX path works. @jax.jit wrap of via_image_from is currently blocked by the same pre-existing autoarray .native limitation flagged in the SimulatorImaging PR (slim/native reshape uses indexed assignment, not jit-traceable). Separate task needed for that. Part of Phase 2 PR 3 of z_features/jax_user_intro.md. Companion PRs add SimulatorInterferometer subclass overrides in PyAutoLens and PyAutoGalaxy. Design doc: admin_jammy/notes/jax_interface.md Issue: PyAutoArray#334 Co-Authored-By: Claude Opus 4.7 (1M context) --- autoarray/dataset/interferometer/simulator.py | 36 +++++++++++++++++-- autoarray/dataset/preprocess.py | 33 +++++++++++------ .../interferometer/test_simulator_use_jax.py | 35 ++++++++++++++++++ 3 files changed, 90 insertions(+), 14 deletions(-) create mode 100644 test_autoarray/dataset/interferometer/test_simulator_use_jax.py diff --git a/autoarray/dataset/interferometer/simulator.py b/autoarray/dataset/interferometer/simulator.py index 52455cddb..e13dd15b5 100644 --- a/autoarray/dataset/interferometer/simulator.py +++ b/autoarray/dataset/interferometer/simulator.py @@ -17,6 +17,7 @@ def __init__( noise_sigma=0.1, noise_if_add_noise_false=0.1, noise_seed=-1, + use_jax: bool = False, ): """ Simulates observations of `Interferometer` data, including transforming a real-space image to @@ -59,6 +60,14 @@ def __init__( noise_seed The random seed used for noise generation. A value of -1 uses a different random seed on every run, producing different noise realisations each time. + use_jax + If ``True``, ``via_image_from`` defaults ``xp`` to ``jax.numpy`` and + the simulator's internal complex-Gaussian noise generation routes + through ``jax.random``. The returned ``Interferometer`` carries + ``jax.Array`` visibilities. Mirror of ``SimulatorImaging.use_jax``; + same caveat applies — ``@jax.jit`` wrapping is currently blocked + by autoarray's pre-existing ``.native`` reshape limitation in the + transformer / dataset construction path. Eager JAX usage works. """ self.uv_wavelengths = uv_wavelengths @@ -67,8 +76,20 @@ def __init__( self.noise_sigma = noise_sigma self.noise_if_add_noise_false = noise_if_add_noise_false self.noise_seed = noise_seed + self.use_jax = use_jax - def via_image_from(self, image): + @property + def _xp(self): + """The array module the simulator runs against by default. ``jnp`` when + ``use_jax=True``, ``np`` otherwise. ``via_image_from`` falls back to + this when the caller does not pass ``xp=`` explicitly.""" + if self.use_jax: + import jax.numpy as jnp + + return jnp + return np + + def via_image_from(self, image, xp=None): """ Simulate an `Interferometer` dataset from an input real-space image. @@ -83,6 +104,10 @@ def via_image_from(self, image): The 2D real-space image from which the interferometer dataset is simulated (e.g. the surface brightness of a galaxy or lens system). Must be an `Array2D` with an associated mask that defines the real-space region used for the Fourier transform. + xp + The array module. When ``None`` (the default), falls back to ``self._xp`` — + ``jnp`` if the simulator was constructed with ``use_jax=True``, ``np`` + otherwise. Pass explicitly to override. Returns ------- @@ -90,6 +115,8 @@ def via_image_from(self, image): The simulated interferometer dataset containing visibilities, noise map, uv_wavelengths and the real-space mask derived from the input image. """ + if xp is None: + xp = self._xp transformer = self.transformer_class( uv_wavelengths=self.uv_wavelengths, real_space_mask=image.mask @@ -99,7 +126,7 @@ def via_image_from(self, image): if self.noise_sigma is not None: visibilities = preprocess.data_with_complex_gaussian_noise_added( - data=visibilities, sigma=self.noise_sigma, seed=self.noise_seed + data=visibilities, sigma=self.noise_sigma, seed=self.noise_seed, xp=xp ) noise_map = VisibilitiesNoiseMap.full( fill_value=self.noise_sigma, shape_slim=(visibilities.shape[0],) @@ -110,7 +137,10 @@ def via_image_from(self, image): shape_slim=(visibilities.shape[0],), ) - if np.isnan(noise_map).any(): + # NaN-noise guard is NumPy-side only — Python `if :` triggers + # TracerBoolConversionError under JAX. JAX users must confirm + # noise_sigma is positive themselves. + if xp is np and np.isnan(noise_map).any(): raise exc.DatasetException( "The noise-map has NaN values in it. This suggests your exposure time and / or" "background sky levels are too low, creating signal counts at or close to 0.0." diff --git a/autoarray/dataset/preprocess.py b/autoarray/dataset/preprocess.py index a7b96a5d7..5b27b6f7f 100644 --- a/autoarray/dataset/preprocess.py +++ b/autoarray/dataset/preprocess.py @@ -525,7 +525,7 @@ def data_eps_with_poisson_noise_added(data_eps, exposure_time_map, seed=-1, xp=n ) -def gaussian_noise_via_shape_and_sigma_from(shape, sigma, seed=-1): +def gaussian_noise_via_shape_and_sigma_from(shape, sigma, seed=-1, xp=np): """Generate a two-dimensional read noises-map, generating values from a Gaussian distribution with mean 0.0. Params @@ -536,24 +536,35 @@ def gaussian_noise_via_shape_and_sigma_from(shape, sigma, seed=-1): Standard deviation of the 1D Gaussian that each noises value is drawn from seed The seed of the random number generator, used for the random noises maps. + xp + The array module (``numpy`` or ``jax.numpy``). On the JAX path the seed + is used to construct a ``jax.random.PRNGKey``; ``seed=-1`` (random per + call) falls back to a time-derived 32-bit integer key. """ - if seed == -1: - # Use one seed, so all regions have identical column non-uniformity. - seed = np.random.randint(0, int(1e9)) - np.random.seed(seed) - read_noise_map = np.random.normal(loc=0.0, scale=sigma, size=shape) - return read_noise_map + if xp is np: + if seed == -1: + # Use one seed, so all regions have identical column non-uniformity. + seed = np.random.randint(0, int(1e9)) + np.random.seed(seed) + read_noise_map = np.random.normal(loc=0.0, scale=sigma, size=shape) + return read_noise_map + + import jax.random + + effective_seed = seed if seed != -1 else int(time.time() * 1e6) & 0xFFFFFFFF + key = jax.random.PRNGKey(effective_seed) + return sigma * jax.random.normal(key, shape) -def data_with_gaussian_noise_added(data, sigma, seed=-1): +def data_with_gaussian_noise_added(data, sigma, seed=-1, xp=np): return data + gaussian_noise_via_shape_and_sigma_from( - shape=data.shape, sigma=sigma, seed=seed + shape=data.shape, sigma=sigma, seed=seed, xp=xp ) -def data_with_complex_gaussian_noise_added(data, sigma, seed=-1): +def data_with_complex_gaussian_noise_added(data, sigma, seed=-1, xp=np): gaussian_noise = gaussian_noise_via_shape_and_sigma_from( - shape=(data.shape[0], 2), sigma=sigma, seed=seed + shape=(data.shape[0], 2), sigma=sigma, seed=seed, xp=xp ) return data + gaussian_noise[:, 0] + 1.0j * gaussian_noise[:, 1] diff --git a/test_autoarray/dataset/interferometer/test_simulator_use_jax.py b/test_autoarray/dataset/interferometer/test_simulator_use_jax.py new file mode 100644 index 000000000..96e08e9b4 --- /dev/null +++ b/test_autoarray/dataset/interferometer/test_simulator_use_jax.py @@ -0,0 +1,35 @@ +"""Unit tests for ``SimulatorInterferometer(use_jax=True)`` constructor wiring. + +Library unit tests stay NumPy-only per [[feedback_no_jax_in_unit_tests]]; +cross-xp numerical parity lives in the workspace_test parity script. +""" +import numpy as np + +import autoarray as aa + + +def test_use_jax_defaults_false(): + simulator = aa.SimulatorInterferometer( + uv_wavelengths=np.array([[0.1, 0.2], [0.3, 0.4]]), + exposure_time=300.0, + ) + assert simulator.use_jax is False + assert simulator._xp is np + + +def test_use_jax_true_flag_stored(): + simulator = aa.SimulatorInterferometer( + uv_wavelengths=np.array([[0.1, 0.2], [0.3, 0.4]]), + exposure_time=300.0, + use_jax=True, + ) + assert simulator.use_jax is True + + +def test_via_image_from_accepts_xp_param(): + """via_image_from now accepts xp= (signature symmetry with SimulatorImaging).""" + import inspect + + sig = inspect.signature(aa.SimulatorInterferometer.via_image_from) + assert "xp" in sig.parameters + assert sig.parameters["xp"].default is None