Skip to content

Conversation

@steppi
Copy link

@steppi steppi commented Jan 3, 2026

Closes #488

In #488 I asked about applying something like lazy_xp_function to classes. This was a half-baked idea motivated by wanting xp_capabilities in SciPy, which everyone here should be familiar with, to always be applied at the class level instead of the individual method level. In scipy/scipy#24267 I created a workflow that keeps xp_capabilities applied to classes, but still allows capabilities to be defined separately for individual methods if needed, and allows lazy_xp_function to be applied to individual methods separately. This means no nonsense about trying to automatically determine which methods of a class should have lazy_xp_function applied to them.

While working on this, I settled on specifying methods with tuples of the form Tuple[type, str] specifying an (uninstantiated) class and a method name. This is to allow distinguishing things like A.f from B.f when B is a subclass of A that inherits f from A, since capabilities may differ at different levels of the inheritance hierarchy. Through changes in SciPy, I was able to allow precise declarations of capabilities for class methods in the presence of inheritance, but a separate change is needed here to allow things like applying lazy_xp_function to B.f but not A.f when f is inherited from A.

The change here modifies lazy_xp_function to also take tuples of the form Tuple[type, str]. When this is done, for say (B, "f"), then B.f is replaced with a shallow clone of itself before adding the tags. This allowsB.f gets the tags without A.f to get them. If replacing an inherited method with a shallow clone is too obtrusive, I have a workaround that keeps more obtrusive modifications only within patch_lazy_xp_functions, but it makes things considerably more complicated, so I hope that won't be necessary.

https://github.com/scipy/scipy/blob/dfa1b87e4af7cf7ee5a8b8faf5c4360b63c86b36/scipy/_lib/tests/test_xp_capabilities.py has an (xfailed) test giving an example of what can go wrong with inheritance currently which can be made to pass with the change made in this PR.

@lucascolley lucascolley added the enhancement New feature or request label Jan 3, 2026
@lucascolley
Copy link
Member

a few CI failures otherwise looks pretty good

@lucascolley
Copy link
Member

cc @crusaderky

@lucascolley lucascolley added this to the 0.10.0 milestone Jan 3, 2026
@steppi
Copy link
Author

steppi commented Jan 3, 2026

Thanks @lucascolley. I've pushed a commit which fixes the failures locally.

Copy link
Member

@lucascolley lucascolley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks Albert LGTM

@steppi
Copy link
Author

steppi commented Jan 3, 2026

Coverage misses some lines in the diff because ufuncs allow setting attributes as of NumPy 2.2 (numpy/numpy#27735) so those lines only get hit for older NumPy versions.

image

@lucascolley
Copy link
Member

Are they hit by pixi run -e tests-numpy1 tests-ci or not? If not, is there something else than just the NumPy version?

@steppi
Copy link
Author

steppi commented Jan 3, 2026

Are they hit by pixi run -e tests-numpy1 tests-ci or not? If not, is there something else than just the NumPy version?

Interesting, I tried it locally and it seems like they are still not hit. But they should be hit by this test if an older NumPy is installed::

try:
    # Test an arbitrary Cython ufunc (@cython.vectorize).
    # When SCIPY_ARRAY_API is not set, this is the same as
    # scipy.special.erf.
    from scipy.special._ufuncs import erf  # type: ignore[import-untyped]

    lazy_xp_function(erf)
except ImportError:
    erf = None


@pytest.mark.skip_xp_backend(Backend.TORCH_GPU, reason="device->host copy")
@pytest.mark.filterwarnings("ignore:__array_wrap__:DeprecationWarning")  # PyTorch
def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend):
    pytest.importorskip("scipy")
    assert erf is not None
    x = xp.asarray([6.0, 7.0])
    if library.like(Backend.ARRAY_API_STRICT, Backend.JAX):
        # array-api-strict arrays are auto-converted to NumPy
        # which results in an assertion error for mismatched namespaces
        # eager JAX arrays are auto-converted to NumPy in eager JAX
        # and fail in jax.jit (which lazy_xp_function tests here)
        with pytest.raises((TypeError, AssertionError)):
            xp_assert_equal(cast(Array, erf(x)), xp.asarray([1.0, 1.0]))
    else:
        # CuPy, Dask and sparse define __array_ufunc__ and dispatch accordingly
        # note that when sparse reduces to scalar it returns a np.generic, which
        # would make xp_assert_equal fail.
        xp_assert_equal(cast(Array, erf(x)), xp.asarray([1.0, 1.0]))

@steppi
Copy link
Author

steppi commented Jan 3, 2026

There's no mystery. The test that would hit those lines depends on SciPy and there's no SciPy in the tests-numpy1 env.

@lucascolley
Copy link
Member

Wanna try pixi add -f tests scipy? If there are any difficulties then probably fine to just ignore it.

@steppi
Copy link
Author

steppi commented Jan 3, 2026

Wanna try pixi add -f tests scipy? If there are any difficulties then probably fine to just ignore it.

I can confirm that adding scipy causes those lines to get hit locally.

Copy link
Contributor

@crusaderky crusaderky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO the API would feel more user-friendly if one could just write

lazy_xp_function(B.g)

is there a strong reason against it?

Also, is there a draft scipy PR that shows how this gets integrated in scipy's xp_capabilities?

@steppi
Copy link
Author

steppi commented Jan 7, 2026

Also, is there a draft scipy PR that shows how this gets integrated in scipy's xp_capabilities?

The draft PR is here scipy/scipy#24267.

there a strong reason against it?

In situations where B inherits g from A, B.g is a reference to the parent method and it is impossible to determine within the body of lazy_xp_function which class is intended from just B.g as far as I'm aware. Tuples allow unambiguous determination of which class is intended within the class hierarchy. This sidesteps issues that can arise when lazy capabilities differ at different levels of the hierarchy.

Co-authored-by: Guido Imperiale <crusaderky@gmail.com>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now obsolete. Could you either

  • clarify that you need to explicitly list classes as well as modules
  • change patch_lazy_xp_functions to descend into classes from the modules (preferrable)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently classes need to be explicitly listed as well as modules but yes, it would be nice if patch_lazy_xp_functions descended into modules so I'm +1 for that.

Comment on lines 353 to 355
def test_lazy_xp_function_class_inheritance():
assert hasattr(B.g, "_lazy_xp_function")
assert not hasattr(A.g, "_lazy_xp_function")
Copy link
Contributor

@crusaderky crusaderky Jan 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What this new test doesn't verify is that B.g actually runs with the JAX/Dask wrapper in a test with the xp fixture. In fact, it doesn't because patch_lazy_xp_functions doesn't descend into the class.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, because the class needs to be added to lazy_xp_modules and the test needs the xp fixture. Things will work in that case, and there are tests in SciPy that work after doing this. Maybe I should just update patch_lazy_xp_functions here too to sidestep that though.

Copy link
Contributor

@crusaderky crusaderky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.

@steppi
Copy link
Author

steppi commented Jan 7, 2026

Thanks @crusaderky, patch_lazy_xp_functions should correctly search classes within modules now.

@lucascolley lucascolley requested a review from crusaderky January 7, 2026 19:02
foo = B(x)
observed = foo.g(y, z)
expected = xp.asarray(44.0)[()]
xp_assert_close(observed, expected)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This runs the function, but it doesn't test that it's been wrapped.
You need to write another function that will fail when wrapped, e.g.

def w(self):
    return bool(self._xp.any(self.x))`

See tests earlier in this same test module for examples.

@steppi
Copy link
Author

steppi commented Jan 8, 2026

Thanks @crusaderky; the test is actually doing what it's supposed to now.

@lucascolley lucascolley requested a review from crusaderky January 8, 2026 17:23
foo = A(x)
bar = B(x)

if library.like(Backend.JAX):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not just JAX, sparse raises a RuntimeError, see CI

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but for different reasons orthogonal from what's being tested so I'm just going to add a skip for it. By the way, what's the pixi command to run tests with all backends?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

array-api-extra on 🎋 lazy-class-methods is 📦 v0.10.0.dev0 via 🐍 took 4spixi run tests
? The task 'tests' can be run in multiple environments.

Please select an environment to run the task in: ›tests
  tests-py313
  dev
  tests-backends
  tests-backends-py311
  dev-cuda
  tests-cuda
  tests-cuda-py311
  tests-numpy1
  tests-py311
  tests-nogil

tests-backends for all (CPU) backends!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if in doubt the CI log will always show the exact Pixi task you need to reproduce

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Also, is there a convenient way to add a skip for all GPU backends? In SciPy skip_xp_backends has a cpu_only=True option, but here it looks like I'd need to add a skip for each GPU backend separately? I guess I could also just find a way to test the same behavior without needing the skips.

Copy link
Member

@lucascolley lucascolley Jan 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that exists yet, however it should be somewhat easy to add a method similar to

def like(self, *others: Backend) -> bool: # numpydoc ignore=PR01,RT01
"""Check if this backend uses the same module as others."""
return any(self.modname == other.modname for other in others)
which checks if self.name ends with :gpu (or something more general).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. That's probably out of scope for this PR so I just added separate skips for all of the GPU backends.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request xpx.testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

ENH: Add support for something like lazy_xp_function for classes

3 participants