feat: SimulatorInterferometer(use_jax=True) + xp-aware preprocess Gaussian noise#336
Merged
Merged
Conversation
Adds use_jax constructor flag to aa.SimulatorInterferometer and threads xp through via_image_from. When use_jax=True, complex Gaussian noise routes through jax.random.PRNGKey + jax.random.normal instead of numpy's RNG. Also fixes the pre-existing signature asymmetry: aa.SimulatorImaging.via_image_from accepts xp=None but aa.SimulatorInterferometer.via_image_from did not. Both now match: via_image_from(image, xp=None). preprocess.gaussian_noise_via_shape_and_sigma_from, data_with_gaussian_noise_added, and data_with_complex_gaussian_noise_added all gain xp=np parameter, with JAX path routing through jax.random. The NaN-noise-map runtime guard is now NumPy-side only (same pattern as SimulatorImaging PR — Python `if <tracer>:` triggers TracerBoolConversionError under JAX). Eager JAX path works. @jax.jit wrap of via_image_from is currently blocked by the same pre-existing autoarray .native limitation flagged in the SimulatorImaging PR (slim/native reshape uses indexed assignment, not jit-traceable). Separate task needed for that. Part of Phase 2 PR 3 of z_features/jax_user_intro.md. Companion PRs add SimulatorInterferometer subclass overrides in PyAutoLens and PyAutoGalaxy. Design doc: admin_jammy/notes/jax_interface.md Issue: PyAutoArray#334 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
PR 3 of Phase 2 (
z_features/jax_user_intro.md). Addsuse_jax=Truetoaa.SimulatorInterferometerand routes the complex-Gaussian visibility noise pipeline throughjax.randomon the JAX path. Also fixes the pre-existing signature asymmetry —via_image_fromnow acceptsxp=Nonematchingaa.SimulatorImaging.Eager JAX path works:
simulator.via_image_from(image)withuse_jax=Truereturns anInterferometerwithjax.Arrayvisibilities.@jax.jitwrap is currently blocked by the sameArray2D.nativejit-incompatibility flagged in the SimulatorImaging PR — affects all simulators until a separate slim/native reshape refactor lands.Companion changes in PyAutoLens (
autolens/interferometer/simulator.py) and PyAutoGalaxy (autogalaxy/interferometer/simulator.py) forwardxpfrom the new_xpproperty on the parent.API Changes
aa.SimulatorInterferometer(..., use_jax=False)constructor flag.aa.SimulatorInterferometer._xpproperty.aa.SimulatorInterferometer.via_image_from(image, xp=None)— fixed asymmetry withSimulatorImagingby adding thexpparameter.xp=Nonedefaults toself._xp.preprocess.gaussian_noise_via_shape_and_sigma_from(..., xp=np),data_with_gaussian_noise_added(..., xp=np),data_with_complex_gaussian_noise_added(..., xp=np)all gainxpparameter.See full details below.
Test Plan
test_autoarray/dataset/interferometer/test_simulator_use_jax.py— constructor wiring, signature symmetry.autolens_workspace_test/scripts/interferometer/simulator_use_jax_parity.pyconfirms eager NumPy and JAX paths produce identical visibilities toatol=1e-8(200-visibility noise-free Sersic + Isothermal lens).Full API Changes
Added
aa.SimulatorInterferometer(use_jax: bool = False)constructor parameter.aa.SimulatorInterferometer._xpproperty — returnsjax.numpyifself.use_jax, elsenumpy.Changed signature
aa.SimulatorInterferometer.via_image_from(image, xp=None)— fixed signature asymmetry vsaa.SimulatorImaging.via_image_from(which already hadxp=).xp=Nonedefaults toself._xp.preprocess.gaussian_noise_via_shape_and_sigma_from(shape, sigma, seed=-1, xp=np).preprocess.data_with_gaussian_noise_added(data, sigma, seed=-1, xp=np).preprocess.data_with_complex_gaussian_noise_added(data, sigma, seed=-1, xp=np).Changed behaviour
aa.SimulatorInterferometer.via_image_from: NaN-noise-map runtime guard now NumPy-side only (Pythonif <tracer>:triggersTracerBoolConversionErrorunder JAX).preprocess.gaussian_noise_via_shape_and_sigma_from: JAX path usesjax.random.PRNGKey(seed)+jax.random.normal(scaled bysigma).seed=-1derives a time-based key.Migration
xp=continue to work —use_jaxdefaults toFalse.Known limitations
@jax.jitwrap ofvia_image_fromis currently blocked by the sameArray2D.nativejit-incompatibility flagged in theSimulatorImagingPR. Eager JAX usage works today.Out of scope (separate PRs)
xp=np+ jnp-backed-grid mismatchValueErrorinAbstractMaker.__init__(Phase 2 PR 4).Array2D.nativejit-safety refactor (unblocks @jax.jit for all simulators).🤖 Generated with Claude Code