diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index aecf6e96..5eadd110 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -21,13 +21,13 @@ Iterator, Sequence, ) +from types import AsyncGeneratorType, CoroutineType from typing import ( Any, Callable, Literal, TypeVar, Union, - cast, overload, ) @@ -35,7 +35,6 @@ import pytest from _pytest.scope import Scope from pytest import ( - Collector, Config, FixtureDef, FixtureRequest, @@ -43,6 +42,7 @@ Item, Mark, Metafunc, + MonkeyPatch, Parser, PytestCollectionWarning, PytestDeprecationWarning, @@ -50,9 +50,9 @@ ) if sys.version_info >= (3, 10): - from typing import ParamSpec + from typing import Concatenate, ParamSpec else: - from typing_extensions import ParamSpec + from typing_extensions import Concatenate, ParamSpec _ScopeName = Literal["session", "package", "module", "class", "function"] @@ -230,45 +230,16 @@ def pytest_report_header(config: Config) -> list[str]: ] -def _preprocess_async_fixtures( - collector: Collector, - processed_fixturedefs: set[FixtureDef], -) -> None: - config = collector.config - default_loop_scope = config.getini("asyncio_default_fixture_loop_scope") - asyncio_mode = _get_asyncio_mode(config) - fixturemanager = config.pluginmanager.get_plugin("funcmanage") - assert fixturemanager is not None - for fixtures in fixturemanager._arg2fixturedefs.values(): - for fixturedef in fixtures: - func = fixturedef.func - if fixturedef in processed_fixturedefs or not _is_coroutine_or_asyncgen( - func - ): - continue - if asyncio_mode == Mode.STRICT and not _is_asyncio_fixture_function(func): - # Ignore async fixtures without explicit asyncio mark in strict mode - # This applies to pytest_trio fixtures, for example - continue - loop_scope = ( - getattr(func, "_loop_scope", None) - or default_loop_scope - or fixturedef.scope - ) - _make_asyncio_fixture_function(func, loop_scope) - if "request" not in fixturedef.argnames: - fixturedef.argnames += ("request",) - _synchronize_async_fixture(fixturedef) - assert _is_asyncio_fixture_function(fixturedef.func) - processed_fixturedefs.add(fixturedef) - - -def _synchronize_async_fixture(fixturedef: FixtureDef) -> None: - """Wraps the fixture function of an async fixture in a synchronous function.""" +def _fixture_synchronizer( + fixturedef: FixtureDef, event_loop: AbstractEventLoop +) -> Callable: + """Returns a synchronous function evaluating the specified fixture.""" if inspect.isasyncgenfunction(fixturedef.func): - _wrap_asyncgen_fixture(fixturedef) + return _wrap_asyncgen_fixture(fixturedef.func, event_loop) elif inspect.iscoroutinefunction(fixturedef.func): - _wrap_async_fixture(fixturedef) + return _wrap_async_fixture(fixturedef.func, event_loop) + else: + return fixturedef.func def _add_kwargs( @@ -299,18 +270,26 @@ def _perhaps_rebind_fixture_func(func: _T, instance: Any | None) -> _T: return func -def _wrap_asyncgen_fixture(fixturedef: FixtureDef) -> None: - fixture = fixturedef.func +AsyncGenFixtureParams = ParamSpec("AsyncGenFixtureParams") +AsyncGenFixtureYieldType = TypeVar("AsyncGenFixtureYieldType") - @functools.wraps(fixture) - def _asyncgen_fixture_wrapper(request: FixtureRequest, **kwargs: Any): - func = _perhaps_rebind_fixture_func(fixture, request.instance) - event_loop_fixture_id = _get_event_loop_fixture_id_for_async_fixture( - request, func - ) - event_loop = request.getfixturevalue(event_loop_fixture_id) - kwargs.pop(event_loop_fixture_id, None) - gen_obj = func(**_add_kwargs(func, kwargs, request)) + +def _wrap_asyncgen_fixture( + fixture_function: Callable[ + AsyncGenFixtureParams, AsyncGeneratorType[AsyncGenFixtureYieldType, Any] + ], + event_loop: AbstractEventLoop, +) -> Callable[ + Concatenate[FixtureRequest, AsyncGenFixtureParams], AsyncGenFixtureYieldType +]: + @functools.wraps(fixture_function) + def _asyncgen_fixture_wrapper( + request: FixtureRequest, + *args: AsyncGenFixtureParams.args, + **kwargs: AsyncGenFixtureParams.kwargs, + ): + func = _perhaps_rebind_fixture_func(fixture_function, request.instance) + gen_obj = func(*args, **_add_kwargs(func, kwargs, request)) async def setup(): res = await gen_obj.__anext__() # type: ignore[union-attr] @@ -343,23 +322,30 @@ async def async_finalizer() -> None: request.addfinalizer(finalizer) return result - fixturedef.func = _asyncgen_fixture_wrapper # type: ignore[misc] + return _asyncgen_fixture_wrapper -def _wrap_async_fixture(fixturedef: FixtureDef) -> None: - fixture = fixturedef.func +AsyncFixtureParams = ParamSpec("AsyncFixtureParams") +AsyncFixtureReturnType = TypeVar("AsyncFixtureReturnType") - @functools.wraps(fixture) - def _async_fixture_wrapper(request: FixtureRequest, **kwargs: Any): - func = _perhaps_rebind_fixture_func(fixture, request.instance) - event_loop_fixture_id = _get_event_loop_fixture_id_for_async_fixture( - request, func - ) - event_loop = request.getfixturevalue(event_loop_fixture_id) - kwargs.pop(event_loop_fixture_id, None) + +def _wrap_async_fixture( + fixture_function: Callable[ + AsyncFixtureParams, CoroutineType[Any, Any, AsyncFixtureReturnType] + ], + event_loop: AbstractEventLoop, +) -> Callable[Concatenate[FixtureRequest, AsyncFixtureParams], AsyncFixtureReturnType]: + + @functools.wraps(fixture_function) # type: ignore[arg-type] + def _async_fixture_wrapper( + request: FixtureRequest, + *args: AsyncFixtureParams.args, + **kwargs: AsyncFixtureParams.kwargs, + ): + func = _perhaps_rebind_fixture_func(fixture_function, request.instance) async def setup(): - res = await func(**_add_kwargs(func, kwargs, request)) + res = await func(*args, **_add_kwargs(func, kwargs, request)) return res context = contextvars.copy_context() @@ -380,19 +366,7 @@ async def setup(): return result - fixturedef.func = _async_fixture_wrapper # type: ignore[misc] - - -def _get_event_loop_fixture_id_for_async_fixture( - request: FixtureRequest, func: Any -) -> str: - default_loop_scope = cast( - _ScopeName, request.config.getini("asyncio_default_fixture_loop_scope") - ) - loop_scope = ( - getattr(func, "_loop_scope", None) or default_loop_scope or request.scope - ) - return f"_{loop_scope}_event_loop" + return _async_fixture_wrapper def _create_task_in_context( @@ -573,22 +547,6 @@ def runtest(self) -> None: super().runtest() -_HOLDER: set[FixtureDef] = set() - - -# The function name needs to start with "pytest_" -# see https://github.com/pytest-dev/pytest/issues/11307 -@pytest.hookimpl(specname="pytest_pycollect_makeitem", tryfirst=True) -def pytest_pycollect_makeitem_preprocess_async_fixtures( - collector: pytest.Module | pytest.Class, name: str, obj: object -) -> pytest.Item | pytest.Collector | list[pytest.Item | pytest.Collector] | None: - """A pytest hook to collect asyncio coroutines.""" - if not collector.funcnamefilter(name): - return None - _preprocess_async_fixtures(collector, _HOLDER) - return None - - # The function name needs to start with "pytest_" # see https://github.com/pytest-dev/pytest/issues/11307 @pytest.hookimpl(specname="pytest_pycollect_makeitem", hookwrapper=True) @@ -803,6 +761,34 @@ def pytest_runtest_setup(item: pytest.Item) -> None: ) +@pytest.hookimpl(wrapper=True) +def pytest_fixture_setup(fixturedef: FixtureDef, request) -> object | None: + asyncio_mode = _get_asyncio_mode(request.config) + if not _is_asyncio_fixture_function(fixturedef.func): + if asyncio_mode == Mode.STRICT: + # Ignore async fixtures without explicit asyncio mark in strict mode + # This applies to pytest_trio fixtures, for example + return (yield) + if not _is_coroutine_or_asyncgen(fixturedef.func): + return (yield) + default_loop_scope = request.config.getini("asyncio_default_fixture_loop_scope") + loop_scope = ( + getattr(fixturedef.func, "_loop_scope", None) + or default_loop_scope + or fixturedef.scope + ) + event_loop_fixture_id = f"_{loop_scope}_event_loop" + event_loop = request.getfixturevalue(event_loop_fixture_id) + synchronizer = _fixture_synchronizer(fixturedef, event_loop) + _make_asyncio_fixture_function(synchronizer, loop_scope) + with MonkeyPatch.context() as c: + if "request" not in fixturedef.argnames: + c.setattr(fixturedef, "argnames", (*fixturedef.argnames, "request")) + c.setattr(fixturedef, "func", synchronizer) + hook_result = yield + return hook_result + + _DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR = """\ An asyncio pytest marker defines both "scope" and "loop_scope", \ but it should only use "loop_scope".