Skip to content

Commit

Permalink
Merge pull request #192 from jakkdl/exc_group_cause_context
Browse files Browse the repository at this point in the history
fix loss of context/cause on exceptions raised inside open_websocket
  • Loading branch information
jakkdl authored Nov 25, 2024
2 parents f5fd6d7 + b8d1fc7 commit e7706f4
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 29 deletions.
65 changes: 54 additions & 11 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
'''
from __future__ import annotations

import copy
from functools import partial, wraps
import re
import ssl
Expand Down Expand Up @@ -452,7 +453,6 @@ async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock):
Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup
"""
async def ki_raising_ping_handler(*args, **kwargs) -> None:
print("raising ki")
raise KeyboardInterrupt
monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler)
async def handler(request):
Expand All @@ -474,27 +474,32 @@ async def handler(request):
async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock):
"""_reader_task._handle_ping_event triggers ValueError.
user code also raises exception.
internal exception is in __cause__ exceptiongroup and user exc is delivered
internal exception is in __context__ exceptiongroup and user exc is delivered
"""
my_value_error = ValueError()
internal_error = ValueError()
internal_error.__context__ = TypeError()
user_error = NameError()
user_error_context = KeyError()
async def raising_ping_event(*args, **kwargs) -> None:
raise my_value_error
raise internal_error

monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event)
async def handler(request):
server_ws = await request.accept()
await server_ws.ping(b"a")

server = await nursery.start(serve_websocket, handler, HOST, 0, None)
with pytest.raises(trio.TooSlowError) as exc_info:
with pytest.raises(type(user_error)) as exc_info:
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False):
with trio.fail_after(1) as cs:
cs.shield = True
await trio.sleep(2)
await trio.lowlevel.checkpoint()
user_error.__context__ = user_error_context
raise user_error

e_cause = exc_info.value.__cause__
assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE)
assert my_value_error in e_cause.exceptions
assert exc_info.value is user_error
e_context = exc_info.value.__context__
assert isinstance(e_context, BaseExceptionGroup) # pylint: disable=possibly-used-before-assignment
assert internal_error in e_context.exceptions
assert user_error_context in e_context.exceptions

@fail_after(5)
async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock):
Expand All @@ -513,6 +518,8 @@ async def handler(request):
server_ws = await request.accept()
await server_ws.ping(b"a")
user_cancelled = None
user_cancelled_cause = None
user_cancelled_context = None

server = await nursery.start(serve_websocket, handler, HOST, 0, None)
with trio.move_on_after(2):
Expand All @@ -522,8 +529,13 @@ async def handler(request):
await trio.sleep_forever()
except trio.Cancelled as e:
user_cancelled = e
user_cancelled_cause = e.__cause__
user_cancelled_context = e.__context__
raise

assert exc_info.value is user_cancelled
assert exc_info.value.__cause__ is user_cancelled_cause
assert exc_info.value.__context__ is user_cancelled_context

def _trio_default_non_strict_exception_groups() -> bool:
assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme"
Expand Down Expand Up @@ -560,6 +572,24 @@ async def handler(request):
RaisesGroup(ValueError)))).matches(exc.value)


async def test_user_exception_cause(nursery) -> None:
async def handler(request):
await request.accept()
server = await nursery.start(serve_websocket, handler, HOST, 0, None)
e_context = TypeError("foo")
e_primary = ValueError("bar")
e_cause = RuntimeError("zee")
with pytest.raises(ValueError) as exc_info:
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False):
try:
raise e_context
except TypeError:
raise e_primary from e_cause
e = exc_info.value
assert e is e_primary
assert e.__cause__ is e_cause
assert e.__context__ is e_context

@fail_after(1)
async def test_reject_handshake(nursery):
async def handler(request):
Expand Down Expand Up @@ -1176,3 +1206,16 @@ async def server():
async with trio.open_nursery() as nursery:
nursery.start_soon(server)
nursery.start_soon(client)


def test_copy_exceptions():
# test that exceptions are copy- and pickleable
copy.copy(HandshakeError())
copy.copy(ConnectionTimeout())
copy.copy(DisconnectionTimeout())
assert copy.copy(ConnectionClosed("foo")).reason == "foo"

rej_copy = copy.copy(ConnectionRejected(404, (("a", "b"),), b"c"))
assert rej_copy.status_code == 404
assert rej_copy.headers == (("a", "b"),)
assert rej_copy.body == b"c"
68 changes: 50 additions & 18 deletions trio_websocket/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ssl
import struct
import urllib.parse
from typing import Iterable, List, Optional, Union
from typing import Iterable, List, NoReturn, Optional, Union

import outcome
import trio
Expand Down Expand Up @@ -151,14 +151,14 @@ async def open_websocket(
# yield to user code. If only one of those raise a non-cancelled exception
# we will raise that non-cancelled exception.
# If we get multiple cancelled, we raise the user's cancelled.
# If both raise exceptions, we raise the user code's exception with the entire
# exception group as the __cause__.
# If both raise exceptions, we raise the user code's exception with __context__
# set to a group containing internal exception(s) + any user exception __context__
# If we somehow get multiple exceptions, but no user exception, then we raise
# TrioWebsocketInternalError.

# If closing the connection fails, then that will be raised as the top
# exception in the last `finally`. If we encountered exceptions in user code
# or in reader task then they will be set as the `__cause__`.
# or in reader task then they will be set as the `__context__`.


async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection:
Expand All @@ -181,10 +181,27 @@ async def _close_connection(connection: WebSocketConnection) -> None:
except trio.TooSlowError:
raise DisconnectionTimeout from None

def _raise(exc: BaseException) -> NoReturn:
"""This helper allows re-raising an exception without __context__ being set."""
# cause does not need special handlng, we simply avoid using `raise .. from ..`
__tracebackhide__ = True
context = exc.__context__
try:
raise exc
finally:
exc.__context__ = context
del exc, context

connection: WebSocketConnection|None=None
close_result: outcome.Maybe[None] | None = None
user_error = None

# Unwrapping exception groups has a lot of pitfalls, one of them stemming from
# the exception we raise also being inside the group that's set as the context.
# This leads to loss of info unless properly handled.
# See https://github.com/python-trio/flake8-async/issues/298
# We therefore avoid having the exceptiongroup included as either cause or context

try:
async with trio.open_nursery() as new_nursery:
result = await outcome.acapture(_open_connection, new_nursery)
Expand All @@ -205,7 +222,7 @@ async def _close_connection(connection: WebSocketConnection) -> None:
except _TRIO_EXC_GROUP_TYPE as e:
# user_error, or exception bubbling up from _reader_task
if len(e.exceptions) == 1:
raise e.exceptions[0]
_raise(e.exceptions[0])

# contains at most 1 non-cancelled exceptions
exception_to_raise: BaseException|None = None
Expand All @@ -218,25 +235,40 @@ async def _close_connection(connection: WebSocketConnection) -> None:
else:
if exception_to_raise is None:
# all exceptions are cancelled
# prefer raising the one from the user, for traceback reasons
# we reraise the user exception and throw out internal
if user_error is not None:
# no reason to raise from e, just to include a bunch of extra
# cancelleds.
raise user_error # pylint: disable=raise-missing-from
_raise(user_error)
# multiple internal Cancelled is not possible afaik
raise e.exceptions[0] # pragma: no cover # pylint: disable=raise-missing-from
raise exception_to_raise
# but if so we just raise one of them
_raise(e.exceptions[0]) # pragma: no cover
# raise the non-cancelled exception
_raise(exception_to_raise)

# if we have any KeyboardInterrupt in the group, make sure to raise it.
# if we have any KeyboardInterrupt in the group, raise a new KeyboardInterrupt
# with the group as cause & context
for sub_exc in e.exceptions:
if isinstance(sub_exc, KeyboardInterrupt):
raise sub_exc from e
raise KeyboardInterrupt from e

# Both user code and internal code raised non-cancelled exceptions.
# We "hide" the internal exception(s) in the __cause__ and surface
# the user_error.
# We set the context to be an exception group containing internal exceptions
# and, if not None, `user_error.__context__`
if user_error is not None:
raise user_error from e
exceptions = [subexc for subexc in e.exceptions if subexc is not user_error]
eg_substr = ''
# there's technically loss of info here, with __suppress_context__=True you
# still have original __context__ available, just not printed. But we delete
# it completely because we can't partially suppress the group
if user_error.__context__ is not None and not user_error.__suppress_context__:
exceptions.append(user_error.__context__)
eg_substr = ' and the context for the user exception'
eg_str = (
"Both internal and user exceptions encountered. This group contains "
"the internal exception(s)" + eg_substr + "."
)
user_error.__context__ = BaseExceptionGroup(eg_str, exceptions)
user_error.__suppress_context__ = False
_raise(user_error)

raise TrioWebsocketInternalError(
"The trio-websocket API is not expected to raise multiple exceptions. "
Expand Down Expand Up @@ -576,7 +608,7 @@ def __init__(self, reason):
:param reason:
:type reason: CloseReason
'''
super().__init__()
super().__init__(reason)
self.reason = reason

def __repr__(self):
Expand All @@ -596,7 +628,7 @@ def __init__(self, status_code, headers, body):
:param reason:
:type reason: CloseReason
'''
super().__init__()
super().__init__(status_code, headers, body)
#: a 3 digit HTTP status code
self.status_code = status_code
#: a tuple of 2-tuples containing header key/value pairs
Expand Down

0 comments on commit e7706f4

Please sign in to comment.