Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement StapledMemoryChannel #1784

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions trio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
open_memory_channel,
MemorySendChannel,
MemoryReceiveChannel,
StapledMemoryChannel,
)

from ._signals import open_signal_receiver
Expand Down
64 changes: 63 additions & 1 deletion trio/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from outcome import Error, Value

from .abc import SendChannel, ReceiveChannel, Channel
from ._util import generic_function, NoPublicConstructor
from ._util import generic_function, NoPublicConstructor, Final

import trio
from ._core import enable_ki_protection
Expand Down Expand Up @@ -344,3 +344,65 @@ async def aclose(self):
self._state.send_tasks.clear()
self._state.data.clear()
await trio.lowlevel.checkpoint()


@attr.s(auto_attribs=True, eq=False, hash=False)
class StapledMemoryChannel(Channel, metaclass=Final):
"""This class `staples <https://en.wikipedia.org/wiki/Staple_(fastener)>`__
together two memory channel halves to make a bidirectional channel.

Args:
send_channel (~trio.MemorySendChannel): The channel to use for sending.
receive_channel (~trio.MemoryReceiveChannel): The channel to use for
receiving.

Example:

The channel halves from :function:~trio.open_memory_channel can
be bound together and accessed by a simple API::

channel = StapledMemoryChannel(*open_memory_channel(1))
await channel.send("x")
assert await channel.receive() == "x"

:class:`StapledMemoryChannel` objects implement the methods in the
:class:`~trio.abc.Channel` interface, as well as the "nowait" variants
of send and receive. They also have two additional public attributes:

.. attribute:: send_channel

The underlying :class:`~trio.MemorySendChannel`. :meth:`send` and
:meth:`send_nowait` are delegated to this object.

.. attribute:: receive_channel

The underlying :class:`~trio.MemoryReceiveChannel`. :meth:`receive()`
and :meth:`receive_nowait` are delegated to this object.

"""

send_channel: MemorySendChannel
receive_channel: MemoryReceiveChannel

async def send(self, value):
"""Calls ``self.send_channel.send``."""
await self.send_channel.send(value)

def send_nowait(self, value):
"""Calls ``self.send_channel.send_nowait``."""
self.send_channel.send_nowait(value)

async def receive(self):
"""Calls ``self.receive_channel.receive``."""
return await self.receive_channel.receive()

def receive_nowait(self):
"""Calls ``self.receive_channel.receive_nowait``."""
return self.receive_channel.receive_nowait()

async def aclose(self):
"""Calls ``aclose`` on both underlying channels."""
try:
await self.send_channel.aclose()
finally:
await self.receive_channel.aclose()
26 changes: 25 additions & 1 deletion trio/tests/test_channel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pytest

from .._abc import Channel
from ..testing import wait_all_tasks_blocked, assert_checkpoints
import trio
from trio import open_memory_channel, EndOfChannel
from trio import open_memory_channel, EndOfChannel, StapledMemoryChannel


async def test_channel():
Expand Down Expand Up @@ -350,3 +351,26 @@ async def do_send(s, v):
assert await r.receive() == 1
with pytest.raises(trio.WouldBlock):
r.receive_nowait()


async def test_stapled_memory_channel():
assert issubclass(StapledMemoryChannel, Channel)
stapled = StapledMemoryChannel(*open_memory_channel(0))

async with trio.open_nursery() as nursery:

@nursery.start_soon
async def _listener():
assert await stapled.receive() == 10

await wait_all_tasks_blocked()
await stapled.send(10)

with pytest.raises(trio.WouldBlock):
stapled.send_nowait(10)
with pytest.raises(trio.WouldBlock):
stapled.receive_nowait()

assert not (stapled.send_channel._closed or stapled.receive_channel._closed)
await stapled.aclose()
assert stapled.send_channel._closed and stapled.receive_channel._closed