
JAX Tracer #343

@Jammy2211


The feature/jax_wrapper branches now mostly work.

In PyAutoLens, many of the point-solver unit tests raise this exception:

autolens/point/solver/point_solver.py:57: in solve
    return aa.Grid2DIrregular([pair for pair in filtered_means])
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <[RecursionError('maximum recursion depth exceeded') raised in repr()] Grid2DIrregular object at 0x7fd9c5ea4bd0>
values = [Traced<ShapedArray(float64[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float64[2])>with<DynamicJaxprTr...float64[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float64[2])>with<DynamicJaxprTrace(level=1/0)>, ...]

    def __init__(self, values: Union[np.ndarray, List]):
        """
        An irregular grid of (y,x) coordinates.

        The `Grid2DIrregular` stores the (y,x) irregular grid of coordinates as 2D NumPy array of shape
        [total_coordinates, 2].

        Calculations should use the NumPy array structure wherever possible for efficient calculations.

        The coordinates input to this function can have any of the following forms (they will be converted to the
        1D NumPy array structure and can be converted back using the object's properties):

        ::

            [(y0,x0), (y1,x1)]
            [[y0,x0], [y1,x1]]

        If your grid lies on a 2D uniform grid of data the `Grid2D` data structure should be used.

        Parameters
        ----------
        values
            The irregular grid of (y,x) coordinates.
        """

        if len(values) == 0:
            super().__init__(values)
            return

        if type(values) is list:
            if isinstance(values[0], Grid2DIrregular):
                values = values
            else:
>               values = np.asarray(values)
E               jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[2].
E               The error occurred while tracing the function solve at /mnt/c/Users/Jammy/Code/PyAutoJAX/PyAutoLens/autolens/point/solver/point_solver.py:17 for jit. This concrete value was not available in Python because it depends on the values of the arguments source_plane_coordinate[0] and source_plane_coordinate[1].
E               See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
E               --------------------
E               For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

../PyAutoArray/autoarray/structures/grids/irregular_2d.py:47: TracerArrayConversionError
____________________________________________________________________________________ test_trivial[source_plane_coordinate6] ____________________________________________________________________________________

source_plane_coordinate = (-1.0, -1.0)
grid = Grid2D([[ 4.5, -4.5],
       [ 4.5, -3.5],
       [ 4.5, -2.5],
       [ 4.5, -1.5],
       [ 4.5, -0.5],
       [ 4.5...[-4.5, -0.5],
       [-4.5,  0.5],
       [-4.5,  1.5],
       [-4.5,  2.5],
       [-4.5,  3.5],
       [-4.5,  4.5]])

    @pytest.mark.parametrize(
        "source_plane_coordinate",
        [
            (0.0, 0.0),
            (0.0, 1.0),
            (1.0, 0.0),
            (1.0, 1.0),
            (0.5, 0.5),
            (0.1, 0.1),
            (-1.0, -1.0),
        ],
    )
    def test_trivial(
        source_plane_coordinate: Tuple[float, float],
        grid,
    ):
        solver = PointSolver.for_grid(
            grid=grid,
            pixel_scale_precision=0.01,
        )
>       coordinates = solver.solve(
            tracer=NullTracer(),
            source_plane_coordinate=source_plane_coordinate,
        )

test_autolens/point/triangles/test_solver.py:67:
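The sketch below reproduces the same error outside PyAutoLens and assumes only `jax` and `numpy`. Inside `jax.jit`, calling `np.asarray` on a Python list of traced arrays makes NumPy call `__array__()` on each tracer, and JAX refuses because the values are not concrete during tracing. `build_grid_numpy` is a hypothetical stand-in for `Grid2DIrregular.__init__`, not PyAutoArray code.

```python
import jax
import jax.numpy as jnp
import numpy as np


def build_grid_numpy(y, x):
    # Hypothetical stand-in for Grid2DIrregular.__init__: a Python list of
    # traced (y, x) pairs, like the filtered_means list built in solve.
    pairs = [jnp.array([y, x]), jnp.array([y + 1.0, x + 1.0])]
    # NumPy tries to call __array__() on each tracer, which JAX forbids.
    return np.asarray(pairs)


def raises_conversion_error() -> bool:
    try:
        jax.jit(build_grid_numpy)(0.0, 0.0)
    except jax.errors.TracerArrayConversionError:
        return True
    return False


print(raises_conversion_error())
```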

I suspect this isn't too complicated to fix; let me know what you think.
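One possible direction follows, sketched under assumptions. It keeps the NumPy path for plain Python coordinates, but when the list holds JAX arrays (tracers count as instances of `jax.Array`), it stacks them with `jax.numpy` so the conversion stays inside the trace. `coordinates_to_array` and `solve_sketch` are hypothetical names. This sketch does not confirm whether the parent class of `Grid2DIrregular` accepts a `jax.Array` in place of an `np.ndarray`.

```python
from typing import Sequence, Union

import jax
import jax.numpy as jnp
import numpy as np


def coordinates_to_array(values: Sequence) -> Union[np.ndarray, jax.Array]:
    """
    Hypothetical helper: convert a list of (y, x) pairs to an array of shape
    [total_coordinates, 2] without calling NumPy on JAX tracers.
    """
    if any(isinstance(v, jax.Array) for v in values):
        # Tracers (and concrete JAX arrays) are jax.Array instances; stack them
        # with jax.numpy so the conversion is staged into the traced program.
        return jnp.stack([jnp.asarray(v) for v in values])
    # Plain Python tuples/lists of floats: the original NumPy path still works.
    return np.asarray(values)


@jax.jit
def solve_sketch(y, x):
    # Hypothetical stand-in for the traced (y, x) means built in PointSolver.solve.
    filtered_means = [jnp.array([y, x]), jnp.array([y - 1.0, x - 1.0])]
    return coordinates_to_array(filtered_means)
```

Under `jit`, `solve_sketch(0.0, 0.0)` returns a `jax.Array` of shape (2, 2). Calling `coordinates_to_array([(0.0, 1.0), (1.0, 0.0)])` with plain floats still returns an `np.ndarray`, so non-JAX callers would see the same behaviour as before.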
