Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions autoarray/dataset/interferometer/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
noise_sigma=0.1,
noise_if_add_noise_false=0.1,
noise_seed=-1,
use_jax: bool = False,
):
"""
Simulates observations of `Interferometer` data, including transforming a real-space image to
Expand Down Expand Up @@ -59,6 +60,14 @@ def __init__(
noise_seed
The random seed used for noise generation. A value of -1 uses a different random seed
on every run, producing different noise realisations each time.
use_jax
If ``True``, ``via_image_from`` defaults ``xp`` to ``jax.numpy`` and
the simulator's internal complex-Gaussian noise generation routes
through ``jax.random``. The returned ``Interferometer`` carries
``jax.Array`` visibilities. Mirror of ``SimulatorImaging.use_jax``;
same caveat applies — ``@jax.jit`` wrapping is currently blocked
by autoarray's pre-existing ``.native`` reshape limitation in the
transformer / dataset construction path. Eager JAX usage works.
"""

self.uv_wavelengths = uv_wavelengths
Expand All @@ -67,8 +76,20 @@ def __init__(
self.noise_sigma = noise_sigma
self.noise_if_add_noise_false = noise_if_add_noise_false
self.noise_seed = noise_seed
self.use_jax = use_jax

def via_image_from(self, image):
@property
def _xp(self):
"""The array module the simulator runs against by default. ``jnp`` when
``use_jax=True``, ``np`` otherwise. ``via_image_from`` falls back to
this when the caller does not pass ``xp=`` explicitly."""
if self.use_jax:
import jax.numpy as jnp

return jnp
return np

def via_image_from(self, image, xp=None):
"""
Simulate an `Interferometer` dataset from an input real-space image.

Expand All @@ -83,13 +104,19 @@ def via_image_from(self, image):
The 2D real-space image from which the interferometer dataset is simulated (e.g. the
surface brightness of a galaxy or lens system). Must be an `Array2D` with an associated
mask that defines the real-space region used for the Fourier transform.
xp
The array module. When ``None`` (the default), falls back to ``self._xp`` —
``jnp`` if the simulator was constructed with ``use_jax=True``, ``np``
otherwise. Pass explicitly to override.

Returns
-------
Interferometer
The simulated interferometer dataset containing visibilities, noise map, uv_wavelengths
and the real-space mask derived from the input image.
"""
if xp is None:
xp = self._xp

transformer = self.transformer_class(
uv_wavelengths=self.uv_wavelengths, real_space_mask=image.mask
Expand All @@ -99,7 +126,7 @@ def via_image_from(self, image):

if self.noise_sigma is not None:
visibilities = preprocess.data_with_complex_gaussian_noise_added(
data=visibilities, sigma=self.noise_sigma, seed=self.noise_seed
data=visibilities, sigma=self.noise_sigma, seed=self.noise_seed, xp=xp
)
noise_map = VisibilitiesNoiseMap.full(
fill_value=self.noise_sigma, shape_slim=(visibilities.shape[0],)
Expand All @@ -110,7 +137,10 @@ def via_image_from(self, image):
shape_slim=(visibilities.shape[0],),
)

if np.isnan(noise_map).any():
# NaN-noise guard is NumPy-side only — Python `if <tracer>:` triggers
# TracerBoolConversionError under JAX. JAX users must confirm
# noise_sigma is positive themselves.
if xp is np and np.isnan(noise_map).any():
raise exc.DatasetException(
"The noise-map has NaN values in it. This suggests your exposure time and / or"
"background sky levels are too low, creating signal counts at or close to 0.0."
Expand Down
33 changes: 22 additions & 11 deletions autoarray/dataset/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def data_eps_with_poisson_noise_added(data_eps, exposure_time_map, seed=-1, xp=n
)


def gaussian_noise_via_shape_and_sigma_from(shape, sigma, seed=-1):
def gaussian_noise_via_shape_and_sigma_from(shape, sigma, seed=-1, xp=np):
"""Generate a two-dimensional read noises-map, generating values from a Gaussian distribution with mean 0.0.

Params
Expand All @@ -536,24 +536,35 @@ def gaussian_noise_via_shape_and_sigma_from(shape, sigma, seed=-1):
Standard deviation of the 1D Gaussian that each noises value is drawn from
seed
The seed of the random number generator, used for the random noises maps.
xp
The array module (``numpy`` or ``jax.numpy``). On the JAX path the seed
is used to construct a ``jax.random.PRNGKey``; ``seed=-1`` (random per
call) falls back to a time-derived 32-bit integer key.
"""
if seed == -1:
# Use one seed, so all regions have identical column non-uniformity.
seed = np.random.randint(0, int(1e9))
np.random.seed(seed)
read_noise_map = np.random.normal(loc=0.0, scale=sigma, size=shape)
return read_noise_map
if xp is np:
if seed == -1:
# Use one seed, so all regions have identical column non-uniformity.
seed = np.random.randint(0, int(1e9))
np.random.seed(seed)
read_noise_map = np.random.normal(loc=0.0, scale=sigma, size=shape)
return read_noise_map

import jax.random

effective_seed = seed if seed != -1 else int(time.time() * 1e6) & 0xFFFFFFFF
key = jax.random.PRNGKey(effective_seed)
return sigma * jax.random.normal(key, shape)


def data_with_gaussian_noise_added(data, sigma, seed=-1):
def data_with_gaussian_noise_added(data, sigma, seed=-1, xp=np):
return data + gaussian_noise_via_shape_and_sigma_from(
shape=data.shape, sigma=sigma, seed=seed
shape=data.shape, sigma=sigma, seed=seed, xp=xp
)


def data_with_complex_gaussian_noise_added(data, sigma, seed=-1):
def data_with_complex_gaussian_noise_added(data, sigma, seed=-1, xp=np):
gaussian_noise = gaussian_noise_via_shape_and_sigma_from(
shape=(data.shape[0], 2), sigma=sigma, seed=seed
shape=(data.shape[0], 2), sigma=sigma, seed=seed, xp=xp
)

return data + gaussian_noise[:, 0] + 1.0j * gaussian_noise[:, 1]
Expand Down
35 changes: 35 additions & 0 deletions test_autoarray/dataset/interferometer/test_simulator_use_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Unit tests for ``SimulatorInterferometer(use_jax=True)`` constructor wiring.

Library unit tests stay NumPy-only per [[feedback_no_jax_in_unit_tests]];
cross-xp numerical parity lives in the workspace_test parity script.
"""
import numpy as np

import autoarray as aa


def test_use_jax_defaults_false():
simulator = aa.SimulatorInterferometer(
uv_wavelengths=np.array([[0.1, 0.2], [0.3, 0.4]]),
exposure_time=300.0,
)
assert simulator.use_jax is False
assert simulator._xp is np


def test_use_jax_true_flag_stored():
simulator = aa.SimulatorInterferometer(
uv_wavelengths=np.array([[0.1, 0.2], [0.3, 0.4]]),
exposure_time=300.0,
use_jax=True,
)
assert simulator.use_jax is True


def test_via_image_from_accepts_xp_param():
"""via_image_from now accepts xp= (signature symmetry with SimulatorImaging)."""
import inspect

sig = inspect.signature(aa.SimulatorInterferometer.via_image_from)
assert "xp" in sig.parameters
assert sig.parameters["xp"].default is None
Loading