diff --git a/pytest_trio/__init__.py b/pytest_trio/__init__.py index ffd3b89..5fbd91f 100644 --- a/pytest_trio/__init__.py +++ b/pytest_trio/__init__.py @@ -1,3 +1,6 @@ """Top-level package for pytest-trio.""" from ._version import __version__ +from .plugin import trio_fixture + +__all__ = ["trio_fixture"] diff --git a/pytest_trio/_tests/test_fixture_mistakes.py b/pytest_trio/_tests/test_fixture_mistakes.py new file mode 100644 index 0000000..2e775f2 --- /dev/null +++ b/pytest_trio/_tests/test_fixture_mistakes.py @@ -0,0 +1,68 @@ +import pytest +from pytest_trio import trio_fixture + + +def test_trio_fixture_with_non_trio_test(testdir): + testdir.makepyfile( + """ + import trio + from pytest_trio import trio_fixture + import pytest + + @trio_fixture + def trio_time(): + return trio.current_time() + + @pytest.fixture + def indirect_trio_time(trio_time): + return trio_time + 1 + + @pytest.mark.trio + async def test_async(mock_clock, trio_time, indirect_trio_time): + assert trio_time == 0 + assert indirect_trio_time == 1 + + def test_sync(trio_time): + pass + + def test_sync_indirect(indirect_trio_time): + pass + """ + ) + + result = testdir.runpytest() + + result.assert_outcomes(passed=1, error=2) + result.stdout.fnmatch_lines( + ["*Trio fixtures can only be used by Trio tests*"] + ) + + +def test_trio_fixture_with_wrong_scope(testdir): + # There's a trick here: when you have a non-function-scope fixture, it's + # not instantiated for any particular function (obviously). So... when our + # pytest_fixture_setup hook tries to check for marks, it can't normally + # see @pytest.mark.trio. So... it's actually almost impossible to have an + # async fixture get treated as a Trio fixture *and* have it be + # non-function-scope. But, class-scoped fixtures can see marks on the + # class, so this is one way (the only way?) it can happen: + testdir.makepyfile( + """ + import pytest + import pytest_trio + + @pytest.fixture(scope="class") + async def async_class_fixture(): + pass + + @pytest.mark.trio + class TestFoo: + async def test_foo(self, async_class_fixture): + pass + """ + ) + + result = testdir.runpytest() + + result.assert_outcomes(error=1) + result.stdout.fnmatch_lines(["*must be function-scope*"]) diff --git a/pytest_trio/plugin.py b/pytest_trio/plugin.py index 59d79e0..af5b813 100644 --- a/pytest_trio/plugin.py +++ b/pytest_trio/plugin.py @@ -1,12 +1,14 @@ """pytest-trio implementation.""" import sys from traceback import format_exception +from collections.abc import Coroutine, Generator from inspect import iscoroutinefunction, isgeneratorfunction import pytest import trio from trio.testing import MockClock, trio_test from async_generator import ( - async_generator, yield_, asynccontextmanager, isasyncgenfunction + async_generator, yield_, asynccontextmanager, isasyncgen, + isasyncgenfunction ) ################################################################ @@ -106,7 +108,7 @@ async def _setup_async_fixtures_in(deps): __tracebackhide__ = True need_resolved_deps_stack = [ - (k, v) for k, v in deps.items() if isinstance(v, BaseAsyncFixture) + (k, v) for k, v in deps.items() if isinstance(v, TrioFixture) ] if not ORDERED_DICTS: # Make the fixture resolution order determinist @@ -136,7 +138,7 @@ async def _recursive_setup(deps_stack): await yield_({**deps, **dict(resolved_deps_stack)}) -class BaseAsyncFixture: +class TrioFixture: """ Represent a fixture that need to be run in a trio context to be resolved. """ @@ -155,102 +157,39 @@ async def setup(self): await yield_(self.result) else: async with _setup_async_fixtures_in(self.deps) as resolved_deps: - async with self._setup(resolved_deps) as self.result: - self.setup_done = True - await yield_(self.result) - - async def _setup(self): - raise NotImplementedError() - - -class AsyncYieldFixture(BaseAsyncFixture): - """ - Async generator fixture. - """ - - @asynccontextmanager - @async_generator - async def _setup(self, resolved_deps): - __tracebackhide__ = True - agen = self.fixturedef.func(**resolved_deps) - - try: - await yield_(await agen.asend(None)) - finally: - try: - await agen.asend(None) - except StopAsyncIteration: - pass - else: - raise RuntimeError('Only one yield in fixture is allowed') - - -class SyncFixtureWithAsyncDeps(BaseAsyncFixture): - """ - Synchronous function fixture with asynchronous dependencies fixtures. - """ - - @asynccontextmanager - @async_generator - async def _setup(self, resolved_deps): - __tracebackhide__ = True - await yield_(self.fixturedef.func(**resolved_deps)) - - -class SyncYieldFixtureWithAsyncDeps(BaseAsyncFixture): - """ - Synchronous generator fixture with asynchronous dependencies fixtures. - """ - - @asynccontextmanager - @async_generator - async def _setup(self, resolved_deps): - __tracebackhide__ = True - gen = self.fixturedef.func(**resolved_deps) - - try: - await yield_(gen.send(None)) - finally: - try: - gen.send(None) - except StopIteration: - pass - else: - raise RuntimeError('Only one yield in fixture is allowed') - - -class AsyncFixture(BaseAsyncFixture): - """ - Regular async fixture (i.e. coroutine). - """ - - @asynccontextmanager - @async_generator - async def _setup(self, resolved_deps): - __tracebackhide__ = True - await yield_(await self.fixturedef.func(**resolved_deps)) - + retval = self.fixturedef.func(**resolved_deps) + if isinstance(retval, Coroutine): + self.result = await retval + elif isasyncgen(retval): + self.result = await retval.asend(None) + elif isinstance(retval, Generator): + self.result = retval.send(None) + else: + # Regular synchronous function + self.result = retval -def _install_async_fixture_if_needed(fixturedef, request): - asyncfix = None - deps = {dep: request.getfixturevalue(dep) for dep in fixturedef.argnames} - if iscoroutinefunction(fixturedef.func): - asyncfix = AsyncFixture(fixturedef, deps) - elif isasyncgenfunction(fixturedef.func): - asyncfix = AsyncYieldFixture(fixturedef, deps) - elif any(isinstance(dep, BaseAsyncFixture) for dep in deps.values()): - if isgeneratorfunction(fixturedef.func): - asyncfix = SyncYieldFixtureWithAsyncDeps(fixturedef, deps) - else: - asyncfix = SyncFixtureWithAsyncDeps(fixturedef, deps) - if asyncfix: - fixturedef.cached_result = (asyncfix, request.param_index, None) - return asyncfix + try: + await yield_(self.result) + finally: + if isasyncgen(retval): + try: + await retval.asend(None) + except StopAsyncIteration: + pass + else: + raise RuntimeError("too many yields in fixture") + elif isinstance(retval, Generator): + try: + retval.send(None) + except StopIteration: + pass + else: + raise RuntimeError("too many yields in fixture") @pytest.hookimpl(hookwrapper=True) def pytest_runtest_call(item): - if 'trio' in item.keywords: + if item.get_closest_marker("trio") is not None: if hasattr(item.obj, 'hypothesis'): # If it's a Hypothesis test, we go in a layer. item.obj.hypothesis.inner_test = _trio_test_runner_factory( @@ -267,10 +206,34 @@ def pytest_runtest_call(item): yield -@pytest.hookimpl() +def trio_fixture(func): + func._force_trio_fixture = True + return pytest.fixture(func) + + +def _is_trio_fixture(func, is_trio_test, deps): + if getattr(func, "_force_trio_fixture", False): + return True + if is_trio_test: + if iscoroutinefunction(func) or isasyncgenfunction(func): + return True + if any(isinstance(dep, TrioFixture) for dep in deps.values()): + return True + return False + + +@pytest.hookimpl def pytest_fixture_setup(fixturedef, request): - if 'trio' in request.keywords: - return _install_async_fixture_if_needed(fixturedef, request) + is_trio_test = (request.node.get_closest_marker("trio") is not None) + deps = {dep: request.getfixturevalue(dep) for dep in fixturedef.argnames} + if _is_trio_fixture(fixturedef.func, is_trio_test, deps): + if request.scope != "function": + raise RuntimeError("Trio fixtures must be function-scope") + if not is_trio_test: + raise RuntimeError("Trio fixtures can only be used by Trio tests") + fixture = TrioFixture(fixturedef, deps) + fixturedef.cached_result = (fixture, request.param_index, None) + return fixture ################################################################ @@ -308,6 +271,6 @@ def autojump_clock(): return MockClock(autojump_threshold=0) -@pytest.fixture -async def nursery(request): +@trio_fixture +def nursery(request): return request.node._trio_nursery diff --git a/setup.py b/setup.py index 03a577e..6b25497 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,8 @@ install_requires=[ "trio", "async_generator >= 1.9", + # For node.get_closest_marker + "pytest >= 3.6" ], keywords=[ 'async',