Skip to content

Commit

Permalink
Rework fixtures
Browse files Browse the repository at this point in the history
- Add @pytest_trio.trio_fixture for explicitly marking a fixture as
  being a trio fixture
- Make the nursery fixture a @trio_fixture
- Refactor Trio fixture classes into one class
- Check for trio marker instead of trio keyword (fixes gh-43)
  - This also raises the minimum pytest version to 3.6
- Raise an error if a Trio fixture is used with a non-function
  scope (fixes gh-18)
- Raise an error if a Trio fixture is used with a non-Trio test

I think this also closes gh-10's discussion, though we still need to
convince pytest-asyncio to fix their side of things.
  • Loading branch information
njsmith committed Jul 25, 2018
1 parent c905f15 commit 29be010
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 99 deletions.
3 changes: 3 additions & 0 deletions pytest_trio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""Top-level package for pytest-trio."""

from ._version import __version__
from .plugin import trio_fixture

__all__ = ["trio_fixture"]
68 changes: 68 additions & 0 deletions pytest_trio/_tests/test_fixture_mistakes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import pytest
from pytest_trio import trio_fixture


def test_trio_fixture_with_non_trio_test(testdir):
testdir.makepyfile(
"""
import trio
from pytest_trio import trio_fixture
import pytest
@trio_fixture
def trio_time():
return trio.current_time()
@pytest.fixture
def indirect_trio_time(trio_time):
return trio_time + 1
@pytest.mark.trio
async def test_async(mock_clock, trio_time, indirect_trio_time):
assert trio_time == 0
assert indirect_trio_time == 1
def test_sync(trio_time):
pass
def test_sync_indirect(indirect_trio_time):
pass
"""
)

result = testdir.runpytest()

result.assert_outcomes(passed=1, error=2)
result.stdout.fnmatch_lines(
["*Trio fixtures can only be used by Trio tests*"]
)


def test_trio_fixture_with_wrong_scope(testdir):
# There's a trick here: when you have a non-function-scope fixture, it's
# not instantiated for any particular function (obviously). So... when our
# pytest_fixture_setup hook tries to check for marks, it can't normally
# see @pytest.mark.trio. So... it's actually almost impossible to have an
# async fixture get treated as a Trio fixture *and* have it be
# non-function-scope. But, class-scoped fixtures can see marks on the
# class, so this is one way (the only way?) it can happen:
testdir.makepyfile(
"""
import pytest
import pytest_trio
@pytest.fixture(scope="class")
async def async_class_fixture():
pass
@pytest.mark.trio
class TestFoo:
async def test_foo(self, async_class_fixture):
pass
"""
)

result = testdir.runpytest()

result.assert_outcomes(error=1)
result.stdout.fnmatch_lines(["*must be function-scope*"])
161 changes: 62 additions & 99 deletions pytest_trio/plugin.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""pytest-trio implementation."""
import sys
from traceback import format_exception
from collections.abc import Coroutine, Generator
from inspect import iscoroutinefunction, isgeneratorfunction
import pytest
import trio
from trio.testing import MockClock, trio_test
from async_generator import (
async_generator, yield_, asynccontextmanager, isasyncgenfunction
async_generator, yield_, asynccontextmanager, isasyncgen,
isasyncgenfunction
)

################################################################
Expand Down Expand Up @@ -106,7 +108,7 @@ async def _setup_async_fixtures_in(deps):
__tracebackhide__ = True

need_resolved_deps_stack = [
(k, v) for k, v in deps.items() if isinstance(v, BaseAsyncFixture)
(k, v) for k, v in deps.items() if isinstance(v, TrioFixture)
]
if not ORDERED_DICTS:
# Make the fixture resolution order determinist
Expand Down Expand Up @@ -136,7 +138,7 @@ async def _recursive_setup(deps_stack):
await yield_({**deps, **dict(resolved_deps_stack)})


class BaseAsyncFixture:
class TrioFixture:
"""
Represent a fixture that need to be run in a trio context to be resolved.
"""
Expand All @@ -155,102 +157,39 @@ async def setup(self):
await yield_(self.result)
else:
async with _setup_async_fixtures_in(self.deps) as resolved_deps:
async with self._setup(resolved_deps) as self.result:
self.setup_done = True
await yield_(self.result)

async def _setup(self):
raise NotImplementedError()


class AsyncYieldFixture(BaseAsyncFixture):
"""
Async generator fixture.
"""

@asynccontextmanager
@async_generator
async def _setup(self, resolved_deps):
__tracebackhide__ = True
agen = self.fixturedef.func(**resolved_deps)

try:
await yield_(await agen.asend(None))
finally:
try:
await agen.asend(None)
except StopAsyncIteration:
pass
else:
raise RuntimeError('Only one yield in fixture is allowed')


class SyncFixtureWithAsyncDeps(BaseAsyncFixture):
"""
Synchronous function fixture with asynchronous dependencies fixtures.
"""

@asynccontextmanager
@async_generator
async def _setup(self, resolved_deps):
__tracebackhide__ = True
await yield_(self.fixturedef.func(**resolved_deps))


class SyncYieldFixtureWithAsyncDeps(BaseAsyncFixture):
"""
Synchronous generator fixture with asynchronous dependencies fixtures.
"""

@asynccontextmanager
@async_generator
async def _setup(self, resolved_deps):
__tracebackhide__ = True
gen = self.fixturedef.func(**resolved_deps)

try:
await yield_(gen.send(None))
finally:
try:
gen.send(None)
except StopIteration:
pass
else:
raise RuntimeError('Only one yield in fixture is allowed')


class AsyncFixture(BaseAsyncFixture):
"""
Regular async fixture (i.e. coroutine).
"""

@asynccontextmanager
@async_generator
async def _setup(self, resolved_deps):
__tracebackhide__ = True
await yield_(await self.fixturedef.func(**resolved_deps))

retval = self.fixturedef.func(**resolved_deps)
if isinstance(retval, Coroutine):
self.result = await retval
elif isasyncgen(retval):
self.result = await retval.asend(None)
elif isinstance(retval, Generator):
self.result = retval.send(None)
else:
# Regular synchronous function
self.result = retval

def _install_async_fixture_if_needed(fixturedef, request):
asyncfix = None
deps = {dep: request.getfixturevalue(dep) for dep in fixturedef.argnames}
if iscoroutinefunction(fixturedef.func):
asyncfix = AsyncFixture(fixturedef, deps)
elif isasyncgenfunction(fixturedef.func):
asyncfix = AsyncYieldFixture(fixturedef, deps)
elif any(isinstance(dep, BaseAsyncFixture) for dep in deps.values()):
if isgeneratorfunction(fixturedef.func):
asyncfix = SyncYieldFixtureWithAsyncDeps(fixturedef, deps)
else:
asyncfix = SyncFixtureWithAsyncDeps(fixturedef, deps)
if asyncfix:
fixturedef.cached_result = (asyncfix, request.param_index, None)
return asyncfix
try:
await yield_(self.result)
finally:
if isasyncgen(retval):
try:
await retval.asend(None)
except StopAsyncIteration:
pass
else:
raise RuntimeError("too many yields in fixture")
elif isinstance(retval, Generator):
try:
retval.send(None)
except StopIteration:
pass
else:
raise RuntimeError("too many yields in fixture")


@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item):
if 'trio' in item.keywords:
if item.get_closest_marker("trio") is not None:
if hasattr(item.obj, 'hypothesis'):
# If it's a Hypothesis test, we go in a layer.
item.obj.hypothesis.inner_test = _trio_test_runner_factory(
Expand All @@ -267,10 +206,34 @@ def pytest_runtest_call(item):
yield


@pytest.hookimpl()
def trio_fixture(func):
func._force_trio_fixture = True
return pytest.fixture(func)


def _is_trio_fixture(func, is_trio_test, deps):
if getattr(func, "_force_trio_fixture", False):
return True
if is_trio_test:
if iscoroutinefunction(func) or isasyncgenfunction(func):
return True
if any(isinstance(dep, TrioFixture) for dep in deps.values()):
return True
return False


@pytest.hookimpl
def pytest_fixture_setup(fixturedef, request):
if 'trio' in request.keywords:
return _install_async_fixture_if_needed(fixturedef, request)
is_trio_test = (request.node.get_closest_marker("trio") is not None)
deps = {dep: request.getfixturevalue(dep) for dep in fixturedef.argnames}
if _is_trio_fixture(fixturedef.func, is_trio_test, deps):
if request.scope != "function":
raise RuntimeError("Trio fixtures must be function-scope")
if not is_trio_test:
raise RuntimeError("Trio fixtures can only be used by Trio tests")
fixture = TrioFixture(fixturedef, deps)
fixturedef.cached_result = (fixture, request.param_index, None)
return fixture


################################################################
Expand Down Expand Up @@ -308,6 +271,6 @@ def autojump_clock():
return MockClock(autojump_threshold=0)


@pytest.fixture
async def nursery(request):
@trio_fixture
def nursery(request):
return request.node._trio_nursery
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
install_requires=[
"trio",
"async_generator >= 1.9",
# For node.get_closest_marker
"pytest >= 3.6"
],
keywords=[
'async',
Expand Down

0 comments on commit 29be010

Please sign in to comment.