Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix loss of context/cause on exceptions raised inside open_websocket #192

Merged
merged 8 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 54 additions & 11 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
'''
from __future__ import annotations

import copy
from functools import partial, wraps
import re
import ssl
Expand Down Expand Up @@ -452,7 +453,6 @@ async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock):
Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup
"""
async def ki_raising_ping_handler(*args, **kwargs) -> None:
print("raising ki")
raise KeyboardInterrupt
monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler)
async def handler(request):
Expand All @@ -474,27 +474,32 @@ async def handler(request):
async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock):
"""_reader_task._handle_ping_event triggers ValueError.
user code also raises exception.
internal exception is in __cause__ exceptiongroup and user exc is delivered
internal exception is in __context__ exceptiongroup and user exc is delivered
"""
my_value_error = ValueError()
internal_error = ValueError()
internal_error.__context__ = TypeError()
user_error = NameError()
user_error_context = KeyError()
async def raising_ping_event(*args, **kwargs) -> None:
raise my_value_error
raise internal_error

monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event)
async def handler(request):
server_ws = await request.accept()
await server_ws.ping(b"a")

server = await nursery.start(serve_websocket, handler, HOST, 0, None)
with pytest.raises(trio.TooSlowError) as exc_info:
with pytest.raises(type(user_error)) as exc_info:
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False):
with trio.fail_after(1) as cs:
cs.shield = True
await trio.sleep(2)
await trio.lowlevel.checkpoint()
user_error.__context__ = user_error_context
raise user_error

e_cause = exc_info.value.__cause__
assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE)
assert my_value_error in e_cause.exceptions
assert exc_info.value is user_error
e_context = exc_info.value.__context__
assert isinstance(e_context, BaseExceptionGroup) # pylint: disable=possibly-used-before-assignment
assert internal_error in e_context.exceptions
assert user_error_context in e_context.exceptions

@fail_after(5)
async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock):
Expand All @@ -513,6 +518,8 @@ async def handler(request):
server_ws = await request.accept()
await server_ws.ping(b"a")
user_cancelled = None
user_cancelled_cause = None
user_cancelled_context = None

server = await nursery.start(serve_websocket, handler, HOST, 0, None)
with trio.move_on_after(2):
Expand All @@ -522,8 +529,13 @@ async def handler(request):
await trio.sleep_forever()
except trio.Cancelled as e:
user_cancelled = e
user_cancelled_cause = e.__cause__
user_cancelled_context = e.__context__
raise

assert exc_info.value is user_cancelled
assert exc_info.value.__cause__ is user_cancelled_cause
assert exc_info.value.__context__ is user_cancelled_context

def _trio_default_non_strict_exception_groups() -> bool:
assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme"
Expand Down Expand Up @@ -560,6 +572,24 @@ async def handler(request):
RaisesGroup(ValueError)))).matches(exc.value)


async def test_user_exception_cause(nursery) -> None:
async def handler(request):
await request.accept()
server = await nursery.start(serve_websocket, handler, HOST, 0, None)
e_context = TypeError("foo")
e_primary = ValueError("bar")
e_cause = RuntimeError("zee")
with pytest.raises(ValueError) as exc_info:
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False):
try:
raise e_context
except TypeError:
raise e_primary from e_cause
e = exc_info.value
assert e is e_primary
assert e.__cause__ is e_cause
assert e.__context__ is e_context

@fail_after(1)
async def test_reject_handshake(nursery):
async def handler(request):
Expand Down Expand Up @@ -1176,3 +1206,16 @@ async def server():
async with trio.open_nursery() as nursery:
nursery.start_soon(server)
nursery.start_soon(client)


def test_copy_exceptions():
# test that exceptions are copy- and pickleable
copy.copy(HandshakeError())
copy.copy(ConnectionTimeout())
copy.copy(DisconnectionTimeout())
assert copy.copy(ConnectionClosed("foo")).reason == "foo"

rej_copy = copy.copy(ConnectionRejected(404, (("a", "b"),), b"c"))
assert rej_copy.status_code == 404
assert rej_copy.headers == (("a", "b"),)
assert rej_copy.body == b"c"
68 changes: 50 additions & 18 deletions trio_websocket/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ssl
import struct
import urllib.parse
from typing import Iterable, List, Optional, Union
from typing import Iterable, List, NoReturn, Optional, Union

import outcome
import trio
Expand Down Expand Up @@ -151,14 +151,14 @@ async def open_websocket(
# yield to user code. If only one of those raise a non-cancelled exception
# we will raise that non-cancelled exception.
# If we get multiple cancelled, we raise the user's cancelled.
# If both raise exceptions, we raise the user code's exception with the entire
# exception group as the __cause__.
# If both raise exceptions, we raise the user code's exception with __context__
# set to a group containing internal exception(s) + any user exception __context__
# If we somehow get multiple exceptions, but no user exception, then we raise
# TrioWebsocketInternalError.

# If closing the connection fails, then that will be raised as the top
# exception in the last `finally`. If we encountered exceptions in user code
# or in reader task then they will be set as the `__cause__`.
# or in reader task then they will be set as the `__context__`.


async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection:
Expand All @@ -181,10 +181,27 @@ async def _close_connection(connection: WebSocketConnection) -> None:
except trio.TooSlowError:
raise DisconnectionTimeout from None

def _raise(exc: BaseException) -> NoReturn:
"""This helper allows re-raising an exception without __context__ being set."""
# cause does not need special handlng, we simply avoid using `raise .. from ..`
__tracebackhide__ = True
context = exc.__context__
try:
raise exc
finally:
exc.__context__ = context
del exc, context
graingert marked this conversation as resolved.
Show resolved Hide resolved

connection: WebSocketConnection|None=None
close_result: outcome.Maybe[None] | None = None
user_error = None

# Unwrapping exception groups has a lot of pitfalls, one of them stemming from
# the exception we raise also being inside the group that's set as the context.
# This leads to loss of info unless properly handled.
# See https://github.com/python-trio/flake8-async/issues/298
# We therefore avoid having the exceptiongroup included as either cause or context

try:
async with trio.open_nursery() as new_nursery:
result = await outcome.acapture(_open_connection, new_nursery)
Expand All @@ -205,7 +222,7 @@ async def _close_connection(connection: WebSocketConnection) -> None:
except _TRIO_EXC_GROUP_TYPE as e:
# user_error, or exception bubbling up from _reader_task
if len(e.exceptions) == 1:
raise e.exceptions[0]
_raise(e.exceptions[0])

# contains at most 1 non-cancelled exceptions
exception_to_raise: BaseException|None = None
Expand All @@ -218,25 +235,40 @@ async def _close_connection(connection: WebSocketConnection) -> None:
else:
if exception_to_raise is None:
# all exceptions are cancelled
# prefer raising the one from the user, for traceback reasons
# we reraise the user exception and throw out internal
if user_error is not None:
# no reason to raise from e, just to include a bunch of extra
# cancelleds.
raise user_error # pylint: disable=raise-missing-from
_raise(user_error)
# multiple internal Cancelled is not possible afaik
raise e.exceptions[0] # pragma: no cover # pylint: disable=raise-missing-from
raise exception_to_raise
# but if so we just raise one of them
_raise(e.exceptions[0]) # pragma: no cover
# raise the non-cancelled exception
_raise(exception_to_raise)

# if we have any KeyboardInterrupt in the group, make sure to raise it.
# if we have any KeyboardInterrupt in the group, raise a new KeyboardInterrupt
# with the group as cause & context
for sub_exc in e.exceptions:
if isinstance(sub_exc, KeyboardInterrupt):
raise sub_exc from e
raise KeyboardInterrupt from e

# Both user code and internal code raised non-cancelled exceptions.
# We "hide" the internal exception(s) in the __cause__ and surface
# the user_error.
# We set the context to be an exception group containing internal exceptions
# and, if not None, `user_error.__context__`
if user_error is not None:
raise user_error from e
exceptions = [subexc for subexc in e.exceptions if subexc is not user_error]
eg_substr = ''
# there's technically loss of info here, with __suppress_context__=True you
# still have original __context__ available, just not printed. But we delete
# it completely because we can't partially suppress the group
if user_error.__context__ is not None and not user_error.__suppress_context__:
exceptions.append(user_error.__context__)
eg_substr = ' and the context for the user exception'
eg_str = (
"Both internal and user exceptions encountered. This group contains "
"the internal exception(s)" + eg_substr + "."
)
user_error.__context__ = BaseExceptionGroup(eg_str, exceptions)
user_error.__suppress_context__ = False
graingert marked this conversation as resolved.
Show resolved Hide resolved
_raise(user_error)

raise TrioWebsocketInternalError(
"The trio-websocket API is not expected to raise multiple exceptions. "
Expand Down Expand Up @@ -576,7 +608,7 @@ def __init__(self, reason):
:param reason:
:type reason: CloseReason
'''
super().__init__()
super().__init__(reason)
self.reason = reason

def __repr__(self):
Expand All @@ -596,7 +628,7 @@ def __init__(self, status_code, headers, body):
:param reason:
:type reason: CloseReason
'''
super().__init__()
super().__init__(status_code, headers, body)
#: a 3 digit HTTP status code
self.status_code = status_code
#: a tuple of 2-tuples containing header key/value pairs
Expand Down
Loading