Skip to content

Commit 4880f48

Browse files
Jammy2211claude
authored andcommitted
fix: make array_2d_via_indexes_from jit-traceable
Replace tuple(native_index_for_slim_index_2d.T) with 2D advanced indexing on the JAX path — the tuple() call iterates the outermost axis of a traced array, triggering TracerArrayConversionError under @jax.jit. Remove the temporary if/else workaround in the imaging simulator and the "jit blocked" docstring caveats. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent c2ab390 commit 4880f48

3 files changed

Lines changed: 6 additions & 23 deletions

File tree

autoarray/dataset/imaging/simulator.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,6 @@ def __init__(
9595
The returned ``Imaging`` carries ``jax.Array`` data — useful for
9696
JAX-eager batched simulation (parameter sweeps, mock-data studies).
9797
Defaults to ``False``.
98-
99-
Note: ``@jax.jit`` wrapping of ``via_tracer_from`` /
100-
``via_galaxies_from`` is currently blocked by an unrelated
101-
pre-existing limitation — ``Array2D.native`` is not jit-traceable
102-
because it goes through indexed assignment in
103-
``array_2d_via_indexes_from``. A separate PyAutoArray task is
104-
needed to refactor the slim/native reshape to be jit-friendly.
105-
Eager JAX usage works today.
10698
"""
10799

108100
if psf is not None:
@@ -232,16 +224,7 @@ def via_image_from(
232224
origin=image.origin,
233225
)
234226

235-
# Re-wrap the image against the all-false mask. Use ``.array`` rather
236-
# than ``.native`` on the JAX path: ``.native`` routes through
237-
# ``array_2d_via_indexes_from`` which is not jit-safe (it builds a
238-
# native-shape array by Python-iterating an index tuple). ``.array``
239-
# returns the raw backing array which the ``Array2D`` constructor
240-
# accepts in either slim or native shape.
241-
if xp is np:
242-
image = Array2D(values=image.native, mask=mask)
243-
else:
244-
image = Array2D(values=image.array, mask=mask)
227+
image = Array2D(values=image.native, mask=mask)
245228

246229
dataset = Imaging(
247230
data=image,

autoarray/dataset/interferometer/simulator.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,7 @@ def __init__(
6464
If ``True``, ``via_image_from`` defaults ``xp`` to ``jax.numpy`` and
6565
the simulator's internal complex-Gaussian noise generation routes
6666
through ``jax.random``. The returned ``Interferometer`` carries
67-
``jax.Array`` visibilities. Mirror of ``SimulatorImaging.use_jax``;
68-
same caveat applies — ``@jax.jit`` wrapping is currently blocked
69-
by autoarray's pre-existing ``.native`` reshape limitation in the
70-
transformer / dataset construction path. Eager JAX usage works.
67+
``jax.Array`` visibilities. Mirror of ``SimulatorImaging.use_jax``.
7168
"""
7269

7370
self.uv_wavelengths = uv_wavelengths

autoarray/structures/arrays/array_2d_util.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,10 @@ def array_2d_via_indexes_from(
557557
array = xp.zeros(shape, dtype=array_2d_slim.dtype)
558558

559559
if xp.__name__.startswith("jax"):
560-
array = array.at[tuple(native_index_for_slim_index_2d.T)].set(array_2d_slim)
560+
array = array.at[
561+
native_index_for_slim_index_2d[:, 0],
562+
native_index_for_slim_index_2d[:, 1],
563+
].set(array_2d_slim)
561564
else:
562565
array[tuple(native_index_for_slim_index_2d.T)] = array_2d_slim
563566

0 commit comments

Comments
 (0)