From 8cac0e47b6251b65ddfee434949e2f277c001f19 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 14 Mar 2025 10:19:33 +0000 Subject: [PATCH] ENH: `lazy_xp_function` namespaces support --- src/array_api_extra/testing.py | 101 ++++++++++++++++++++++++--------- tests/test_testing.py | 37 ++++++++++++ 2 files changed, 110 insertions(+), 28 deletions(-) diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/testing.py index 8aca160f..a0f97a81 100644 --- a/src/array_api_extra/testing.py +++ b/src/array_api_extra/testing.py @@ -56,9 +56,8 @@ def lazy_xp_function( # type: ignore[no-any-explicit] """ Tag a function to be tested on lazy backends. - Tag a function, which must be imported in the test module globals, so that when any - tests defined in the same module are executed with ``xp=jax.numpy`` the function is - replaced with a jitted version of itself, and when it is executed with + Tag a function so that when any tests are executed with ``xp=jax.numpy`` the + function is replaced with a jitted version of itself, and when it is executed with ``xp=dask.array`` the function will raise if it attempts to materialize the graph. This will be later expanded to provide test coverage for other lazy backends. @@ -120,19 +119,59 @@ def test_myfunc(xp): Notes ----- - A test function can circumvent this monkey-patching system by calling `func` as an - attribute of the original module. You need to sanitize your code to make sure this - does not happen. + In order for this tag to be effective, the test function must be imported into the + test module globals without its namespace; alternatively its namespace must be + declared in a ``lazy_xp_modules`` list in the test module globals. - Example:: + Example 1:: - import mymodule from mymodule import myfunc + from mymodule import myfunc lazy_xp_function(myfunc) def test_myfunc(xp): - a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c = - mymodule.myfunc(a) # This is not + x = myfunc(xp.asarray([1, 2])) + + Example 2:: + + import mymodule + + lazy_xp_modules = [mymodule] + lazy_xp_function(mymodule.myfunc) + + def test_myfunc(xp): + x = mymodule.myfunc(xp.asarray([1, 2])) + + A test function can circumvent this monkey-patching system by using a namespace + outside of the two above patterns. You need to sanitize your code to make sure this + only happens intentionally. + + Example 1:: + + import mymodule + from mymodule import myfunc + + lazy_xp_function(myfunc) + + def test_myfunc(xp): + a = xp.asarray([1, 2]) + b = myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array + c = mymodule.myfunc(a) # This is not + + Example 2:: + + import mymodule + + class naked: + myfunc = mymodule.myfunc + + lazy_xp_modules = [mymodule] + lazy_xp_function(mymodule.myfunc) + + def test_myfunc(xp): + a = xp.asarray([1, 2]) + b = mymodule.myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array + c = naked.myfunc(a) # This is not """ tags = { "allow_dask_compute": allow_dask_compute, @@ -153,11 +192,13 @@ def patch_lazy_xp_functions( Test lazy execution of functions tagged with :func:`lazy_xp_function`. If ``xp==jax.numpy``, search for all functions which have been tagged with - :func:`lazy_xp_function` in the globals of the module that defines the current test + :func:`lazy_xp_function` in the globals of the module that defines the current test, + as well as in the ``lazy_xp_modules`` list in the globals of the same module, and wrap them with :func:`jax.jit`. Unwrap them at the end of the test. If ``xp==dask.array``, wrap the functions with a decorator that disables - ``compute()`` and ``persist()``. + ``compute()`` and ``persist()`` and ensures that exceptions and warnings are raised + eagerly. This function should be typically called by your library's `xp` fixture that runs tests on multiple backends:: @@ -183,29 +224,33 @@ def xp(request, monkeypatch): lazy_xp_function : Tag a function to be tested on lazy backends. pytest.FixtureRequest : `request` test function parameter. """ - globals_ = cast("dict[str, Any]", request.module.__dict__) # type: ignore[no-any-explicit] - - def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]: # type: ignore[no-any-explicit] - for name, func in globals_.items(): - tags: dict[str, Any] | None = None # type: ignore[no-any-explicit] - with contextlib.suppress(AttributeError): - tags = func._lazy_xp_function # pylint: disable=protected-access - if tags is None: - with contextlib.suppress(KeyError, TypeError): - tags = _ufuncs_tags[func] - if tags is not None: - yield name, func, tags + mod = cast(ModuleType, request.module) + mods = [mod, *cast(list[ModuleType], getattr(mod, "lazy_xp_modules", []))] + + def iter_tagged() -> ( # type: ignore[no-any-explicit] + Iterator[tuple[ModuleType, str, Callable[..., Any], dict[str, Any]]] + ): + for mod in mods: + for name, func in mod.__dict__.items(): + tags: dict[str, Any] | None = None # type: ignore[no-any-explicit] + with contextlib.suppress(AttributeError): + tags = func._lazy_xp_function # pylint: disable=protected-access + if tags is None: + with contextlib.suppress(KeyError, TypeError): + tags = _ufuncs_tags[func] + if tags is not None: + yield mod, name, func, tags if is_dask_namespace(xp): - for name, func, tags in iter_tagged(): + for mod, name, func, tags in iter_tagged(): n = tags["allow_dask_compute"] wrapped = _dask_wrap(func, n) - monkeypatch.setitem(globals_, name, wrapped) + monkeypatch.setattr(mod, name, wrapped) elif is_jax_namespace(xp): import jax - for name, func, tags in iter_tagged(): + for mod, name, func, tags in iter_tagged(): if tags["jax_jit"]: # suppress unused-ignore to run mypy in -e lint as well as -e dev wrapped = cast( # type: ignore[no-any-explicit] @@ -216,7 +261,7 @@ def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]: static_argnames=tags["static_argnames"], ), ) - monkeypatch.setitem(globals_, name, wrapped) + monkeypatch.setattr(mod, name, wrapped) class CountingDaskScheduler(SchedulerGetCallable): diff --git a/tests/test_testing.py b/tests/test_testing.py index 1649dd86..ed21feb2 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -108,6 +108,7 @@ def non_materializable(x: Array) -> Array: and it will trigger an expensive computation in dask. """ xp = array_namespace(x) + # Crashes inside jax.jit # On dask, this triggers two computations of the whole graph if xp.any(x < 0.0) or xp.any(x > 10.0): msg = "Values must be in the [0, 10] range" @@ -261,3 +262,39 @@ def test_lazy_xp_function_eagerly_raises(da: ModuleType): x = da.arange(3) with pytest.raises(ValueError, match="Hello world"): dask_raises(x) + + +class Wrapped: + def f(x: Array) -> Array: # noqa: N805 # pyright: ignore[reportSelfClsParameterName] + xp = array_namespace(x) + # Crash in jax.jit and trigger compute() on dask + if not xp.all(x): + msg = "Values must be non-zero" + raise ValueError(msg) + return x + + +class Naked: + f = Wrapped.f # pyright: ignore[reportUnannotatedClassAttribute] + + +lazy_xp_function(Wrapped.f) +lazy_xp_modules = [Wrapped] + + +def test_lazy_xp_modules(xp: ModuleType, library: Backend): + x = xp.asarray([1.0, 2.0]) + y = Naked.f(x) + xp_assert_equal(y, x) + + if library is Backend.JAX: + with pytest.raises( + TypeError, match="Attempted boolean conversion of traced array" + ): + Wrapped.f(x) + elif library is Backend.DASK: + with pytest.raises(AssertionError, match=r"dask\.compute"): + Wrapped.f(x) + else: + y = Wrapped.f(x) + xp_assert_equal(y, x)