Skip to content

Commit

Permalink
fix loss of context/cause on exceptions raised inside open_websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl committed Oct 20, 2024
1 parent f5fd6d7 commit 77e5779
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 7 deletions.
37 changes: 36 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,8 @@ async def handler(request):
server_ws = await request.accept()
await server_ws.ping(b"a")
user_cancelled = None
user_cancelled_cause = None
user_cancelled_context = None

server = await nursery.start(serve_websocket, handler, HOST, 0, None)
with trio.move_on_after(2):
Expand All @@ -522,8 +524,18 @@ async def handler(request):
await trio.sleep_forever()
except trio.Cancelled as e:
user_cancelled = e
user_cancelled_cause = e.__cause__
user_cancelled_context = e.__context__
raise
assert exc_info.value is user_cancelled

# a copy of user_cancelled is reraised
assert exc_info.value is not user_cancelled
# with the same cause
assert exc_info.value.__cause__ is user_cancelled_cause
# the context is the exception group, which contains the original user_cancelled
assert exc_info.value.__context__.exceptions[1] is user_cancelled
assert exc_info.value.__context__.exceptions[1].__cause__ is user_cancelled_cause
assert exc_info.value.__context__.exceptions[1].__context__ is user_cancelled_context

def _trio_default_non_strict_exception_groups() -> bool:
assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme"
Expand Down Expand Up @@ -560,6 +572,29 @@ async def handler(request):
RaisesGroup(ValueError)))).matches(exc.value)


async def test_user_exception_cause(nursery) -> None:
async def handler(request):
await request.accept()
server = await nursery.start(serve_websocket, handler, HOST, 0, None)
e_context = TypeError("foo")
e_primary = ValueError("bar")
e_cause = RuntimeError("zee")
with pytest.raises(ValueError) as exc_info:
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False):
try:
raise e_context
except TypeError:
raise e_primary from e_cause
e = exc_info.value
# a copy is reraised
assert e is not e_primary
assert e.__cause__ is e_cause

# the nursery-internal group is injected as context
assert isinstance(e.__context__, ExceptionGroup)
assert e.__context__.exceptions[0] is e_primary
assert e.__context__.exceptions[0].__context__ is e_context

@fail_after(1)
async def test_reject_handshake(nursery):
async def handler(request):
Expand Down
23 changes: 17 additions & 6 deletions trio_websocket/_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import sys
from collections import OrderedDict
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -91,6 +92,16 @@ def __exit__(self, ty, value, tb):
filtered_exception = _ignore_cancel(value)
return filtered_exception is None

def copy_exc(e: BaseException) -> BaseException:
"""Copy an exception.
`copy.copy` fails on `trio.Cancelled`, and on exceptions with a custom `__init__`
that calls `super().__init__()`. It may be the case that this also fails on something.
"""
cls = type(e)
result = cls.__new__(cls)
result.__dict__ = copy.copy(e.__dict__)
return result

@asynccontextmanager
async def open_websocket(
Expand Down Expand Up @@ -205,7 +216,7 @@ async def _close_connection(connection: WebSocketConnection) -> None:
except _TRIO_EXC_GROUP_TYPE as e:
# user_error, or exception bubbling up from _reader_task
if len(e.exceptions) == 1:
raise e.exceptions[0]
raise copy_exc(e.exceptions[0]) from e.exceptions[0].__cause__

# contains at most 1 non-cancelled exceptions
exception_to_raise: BaseException|None = None
Expand All @@ -222,21 +233,21 @@ async def _close_connection(connection: WebSocketConnection) -> None:
if user_error is not None:
# no reason to raise from e, just to include a bunch of extra
# cancelleds.
raise user_error # pylint: disable=raise-missing-from
raise copy_exc(user_error) from user_error.__cause__
# multiple internal Cancelled is not possible afaik
raise e.exceptions[0] # pragma: no cover # pylint: disable=raise-missing-from
raise exception_to_raise
raise copy_exc(e.exceptions[0]) from e # pragma: no cover
raise copy_exc(exception_to_raise) from exception_to_raise.__cause__

# if we have any KeyboardInterrupt in the group, make sure to raise it.
for sub_exc in e.exceptions:
if isinstance(sub_exc, KeyboardInterrupt):
raise sub_exc from e
raise copy_exc(sub_exc) from e

# Both user code and internal code raised non-cancelled exceptions.
# We "hide" the internal exception(s) in the __cause__ and surface
# the user_error.
if user_error is not None:
raise user_error from e
raise copy_exc(user_error) from e

raise TrioWebsocketInternalError(
"The trio-websocket API is not expected to raise multiple exceptions. "
Expand Down

0 comments on commit 77e5779

Please sign in to comment.