File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -95,14 +95,6 @@ def __init__(
9595 The returned ``Imaging`` carries ``jax.Array`` data — useful for
9696 JAX-eager batched simulation (parameter sweeps, mock-data studies).
9797 Defaults to ``False``.
98-
99- Note: ``@jax.jit`` wrapping of ``via_tracer_from`` /
100- ``via_galaxies_from`` is currently blocked by an unrelated
101- pre-existing limitation — ``Array2D.native`` is not jit-traceable
102- because it goes through indexed assignment in
103- ``array_2d_via_indexes_from``. A separate PyAutoArray task is
104- needed to refactor the slim/native reshape to be jit-friendly.
105- Eager JAX usage works today.
10698 """
10799
108100 if psf is not None :
@@ -232,16 +224,7 @@ def via_image_from(
232224 origin = image .origin ,
233225 )
234226
235- # Re-wrap the image against the all-false mask. Use ``.array`` rather
236- # than ``.native`` on the JAX path: ``.native`` routes through
237- # ``array_2d_via_indexes_from`` which is not jit-safe (it builds a
238- # native-shape array by Python-iterating an index tuple). ``.array``
239- # returns the raw backing array which the ``Array2D`` constructor
240- # accepts in either slim or native shape.
241- if xp is np :
242- image = Array2D (values = image .native , mask = mask )
243- else :
244- image = Array2D (values = image .array , mask = mask )
227+ image = Array2D (values = image .native , mask = mask )
245228
246229 dataset = Imaging (
247230 data = image ,
Original file line number Diff line number Diff line change @@ -64,10 +64,7 @@ def __init__(
6464 If ``True``, ``via_image_from`` defaults ``xp`` to ``jax.numpy`` and
6565 the simulator's internal complex-Gaussian noise generation routes
6666 through ``jax.random``. The returned ``Interferometer`` carries
67- ``jax.Array`` visibilities. Mirror of ``SimulatorImaging.use_jax``;
68- same caveat applies — ``@jax.jit`` wrapping is currently blocked
69- by autoarray's pre-existing ``.native`` reshape limitation in the
70- transformer / dataset construction path. Eager JAX usage works.
67+ ``jax.Array`` visibilities. Mirror of ``SimulatorImaging.use_jax``.
7168 """
7269
7370 self .uv_wavelengths = uv_wavelengths
Original file line number Diff line number Diff line change @@ -557,7 +557,10 @@ def array_2d_via_indexes_from(
557557 array = xp .zeros (shape , dtype = array_2d_slim .dtype )
558558
559559 if xp .__name__ .startswith ("jax" ):
560- array = array .at [tuple (native_index_for_slim_index_2d .T )].set (array_2d_slim )
560+ array = array .at [
561+ native_index_for_slim_index_2d [:, 0 ],
562+ native_index_for_slim_index_2d [:, 1 ],
563+ ].set (array_2d_slim )
561564 else :
562565 array [tuple (native_index_for_slim_index_2d .T )] = array_2d_slim
563566
You can’t perform that action at this time.
0 commit comments