From 77e5779d8813e9eb7c76a6cc43cc810cb008a508 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sun, 20 Oct 2024 15:57:00 +0200 Subject: [PATCH 1/8] fix loss of context/cause on exceptions raised inside open_websocket --- tests/test_connection.py | 37 ++++++++++++++++++++++++++++++++++++- trio_websocket/_impl.py | 23 +++++++++++++++++------ 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 6cccefa..e1292a8 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -513,6 +513,8 @@ async def handler(request): server_ws = await request.accept() await server_ws.ping(b"a") user_cancelled = None + user_cancelled_cause = None + user_cancelled_context = None server = await nursery.start(serve_websocket, handler, HOST, 0, None) with trio.move_on_after(2): @@ -522,8 +524,18 @@ async def handler(request): await trio.sleep_forever() except trio.Cancelled as e: user_cancelled = e + user_cancelled_cause = e.__cause__ + user_cancelled_context = e.__context__ raise - assert exc_info.value is user_cancelled + + # a copy of user_cancelled is reraised + assert exc_info.value is not user_cancelled + # with the same cause + assert exc_info.value.__cause__ is user_cancelled_cause + # the context is the exception group, which contains the original user_cancelled + assert exc_info.value.__context__.exceptions[1] is user_cancelled + assert exc_info.value.__context__.exceptions[1].__cause__ is user_cancelled_cause + assert exc_info.value.__context__.exceptions[1].__context__ is user_cancelled_context def _trio_default_non_strict_exception_groups() -> bool: assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme" @@ -560,6 +572,29 @@ async def handler(request): RaisesGroup(ValueError)))).matches(exc.value) +async def test_user_exception_cause(nursery) -> None: + async def handler(request): + await request.accept() + server = await nursery.start(serve_websocket, handler, HOST, 0, None) + e_context = TypeError("foo") + e_primary = ValueError("bar") + e_cause = RuntimeError("zee") + with pytest.raises(ValueError) as exc_info: + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): + try: + raise e_context + except TypeError: + raise e_primary from e_cause + e = exc_info.value + # a copy is reraised + assert e is not e_primary + assert e.__cause__ is e_cause + + # the nursery-internal group is injected as context + assert isinstance(e.__context__, ExceptionGroup) + assert e.__context__.exceptions[0] is e_primary + assert e.__context__.exceptions[0].__context__ is e_context + @fail_after(1) async def test_reject_handshake(nursery): async def handler(request): diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index a71e0be..e69959b 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import sys from collections import OrderedDict from contextlib import asynccontextmanager @@ -91,6 +92,16 @@ def __exit__(self, ty, value, tb): filtered_exception = _ignore_cancel(value) return filtered_exception is None +def copy_exc(e: BaseException) -> BaseException: + """Copy an exception. + + `copy.copy` fails on `trio.Cancelled`, and on exceptions with a custom `__init__` + that calls `super().__init__()`. It may be the case that this also fails on something. + """ + cls = type(e) + result = cls.__new__(cls) + result.__dict__ = copy.copy(e.__dict__) + return result @asynccontextmanager async def open_websocket( @@ -205,7 +216,7 @@ async def _close_connection(connection: WebSocketConnection) -> None: except _TRIO_EXC_GROUP_TYPE as e: # user_error, or exception bubbling up from _reader_task if len(e.exceptions) == 1: - raise e.exceptions[0] + raise copy_exc(e.exceptions[0]) from e.exceptions[0].__cause__ # contains at most 1 non-cancelled exceptions exception_to_raise: BaseException|None = None @@ -222,21 +233,21 @@ async def _close_connection(connection: WebSocketConnection) -> None: if user_error is not None: # no reason to raise from e, just to include a bunch of extra # cancelleds. - raise user_error # pylint: disable=raise-missing-from + raise copy_exc(user_error) from user_error.__cause__ # multiple internal Cancelled is not possible afaik - raise e.exceptions[0] # pragma: no cover # pylint: disable=raise-missing-from - raise exception_to_raise + raise copy_exc(e.exceptions[0]) from e # pragma: no cover + raise copy_exc(exception_to_raise) from exception_to_raise.__cause__ # if we have any KeyboardInterrupt in the group, make sure to raise it. for sub_exc in e.exceptions: if isinstance(sub_exc, KeyboardInterrupt): - raise sub_exc from e + raise copy_exc(sub_exc) from e # Both user code and internal code raised non-cancelled exceptions. # We "hide" the internal exception(s) in the __cause__ and surface # the user_error. if user_error is not None: - raise user_error from e + raise copy_exc(user_error) from e raise TrioWebsocketInternalError( "The trio-websocket API is not expected to raise multiple exceptions. " From df0f56d06c8462b75aacae4443ffafb8ac001346 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sun, 20 Oct 2024 17:07:14 +0200 Subject: [PATCH 2/8] fix test on non-strict --- tests/test_connection.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index e1292a8..4058409 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -586,14 +586,19 @@ async def handler(request): except TypeError: raise e_primary from e_cause e = exc_info.value - # a copy is reraised - assert e is not e_primary - assert e.__cause__ is e_cause - - # the nursery-internal group is injected as context - assert isinstance(e.__context__, ExceptionGroup) - assert e.__context__.exceptions[0] is e_primary - assert e.__context__.exceptions[0].__context__ is e_context + if _trio_default_non_strict_exception_groups(): + assert e is e_primary + assert e.__cause__ is e_cause + assert e.__context__ is e_context + else: + # a copy is reraised to avoid losing e_context + assert e is not e_primary + assert e.__cause__ is e_cause + + # the nursery-internal group is injected as context + assert isinstance(e.__context__, ExceptionGroup) + assert e.__context__.exceptions[0] is e_primary + assert e.__context__.exceptions[0].__context__ is e_context @fail_after(1) async def test_reject_handshake(nursery): From a4e0562ce46f39ab761aaa00718754a1c1953a7c Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sun, 20 Oct 2024 17:10:34 +0200 Subject: [PATCH 3/8] fix another test fail --- tests/test_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 4058409..61c2848 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -596,7 +596,7 @@ async def handler(request): assert e.__cause__ is e_cause # the nursery-internal group is injected as context - assert isinstance(e.__context__, ExceptionGroup) + assert isinstance(e.__context__, _TRIO_EXC_GROUP_TYPE) assert e.__context__.exceptions[0] is e_primary assert e.__context__.exceptions[0].__context__ is e_context From 5defd0341677029efbfdaf640cd254425f6c1d1f Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 23 Oct 2024 12:19:30 +0200 Subject: [PATCH 4/8] no-copy solution that completely hides the exceptiongroup in most cases --- tests/test_connection.py | 51 ++++++++++++++-------------------- trio_websocket/_impl.py | 60 +++++++++++++++++++++++++++++++--------- 2 files changed, 67 insertions(+), 44 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 61c2848..304390d 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -452,7 +452,6 @@ async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock): Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup """ async def ki_raising_ping_handler(*args, **kwargs) -> None: - print("raising ki") raise KeyboardInterrupt monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler) async def handler(request): @@ -474,11 +473,14 @@ async def handler(request): async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock): """_reader_task._handle_ping_event triggers ValueError. user code also raises exception. - internal exception is in __cause__ exceptiongroup and user exc is delivered + internal exception is in __context__ exceptiongroup and user exc is delivered """ - my_value_error = ValueError() + internal_error = ValueError() + internal_error.__context__ = TypeError() + user_error = NameError() + user_error_context = KeyError() async def raising_ping_event(*args, **kwargs) -> None: - raise my_value_error + raise internal_error monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event) async def handler(request): @@ -486,15 +488,17 @@ async def handler(request): await server_ws.ping(b"a") server = await nursery.start(serve_websocket, handler, HOST, 0, None) - with pytest.raises(trio.TooSlowError) as exc_info: + with pytest.raises(type(user_error)) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): - with trio.fail_after(1) as cs: - cs.shield = True - await trio.sleep(2) + await trio.lowlevel.checkpoint() + user_error.__context__ = user_error_context + raise user_error - e_cause = exc_info.value.__cause__ - assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE) - assert my_value_error in e_cause.exceptions + assert exc_info.value is user_error + e_context = exc_info.value.__context__ + assert isinstance(e_context, BaseExceptionGroup) + assert internal_error in e_context.exceptions + assert user_error_context in e_context.exceptions @fail_after(5) async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock): @@ -528,14 +532,9 @@ async def handler(request): user_cancelled_context = e.__context__ raise - # a copy of user_cancelled is reraised - assert exc_info.value is not user_cancelled - # with the same cause + assert exc_info.value is user_cancelled assert exc_info.value.__cause__ is user_cancelled_cause - # the context is the exception group, which contains the original user_cancelled - assert exc_info.value.__context__.exceptions[1] is user_cancelled - assert exc_info.value.__context__.exceptions[1].__cause__ is user_cancelled_cause - assert exc_info.value.__context__.exceptions[1].__context__ is user_cancelled_context + assert exc_info.value.__context__ is user_cancelled_context def _trio_default_non_strict_exception_groups() -> bool: assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme" @@ -586,19 +585,9 @@ async def handler(request): except TypeError: raise e_primary from e_cause e = exc_info.value - if _trio_default_non_strict_exception_groups(): - assert e is e_primary - assert e.__cause__ is e_cause - assert e.__context__ is e_context - else: - # a copy is reraised to avoid losing e_context - assert e is not e_primary - assert e.__cause__ is e_cause - - # the nursery-internal group is injected as context - assert isinstance(e.__context__, _TRIO_EXC_GROUP_TYPE) - assert e.__context__.exceptions[0] is e_primary - assert e.__context__.exceptions[0].__context__ is e_context + assert e is e_primary + assert e.__cause__ is e_cause + assert e.__context__ is e_context @fail_after(1) async def test_reject_handshake(nursery): diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index e69959b..4ae7338 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -12,7 +12,7 @@ import ssl import struct import urllib.parse -from typing import Iterable, List, Optional, Union +from typing import Iterable, List, NoReturn, Optional, Union import outcome import trio @@ -192,10 +192,29 @@ async def _close_connection(connection: WebSocketConnection) -> None: except trio.TooSlowError: raise DisconnectionTimeout from None + def _raise(exc: BaseException) -> NoReturn: + __tracebackhide__ = True + context = exc.__context__ + try: + raise exc + finally: + exc.__context__ = context + del exc, context + connection: WebSocketConnection|None=None close_result: outcome.Maybe[None] | None = None user_error = None + # Unwrapping exception groups has a lot of pitfalls, one of them stemming from + # the exception we raise also being inside the group that's set as the context. + # This leads to loss of info unless properly handled. + # See https://github.com/python-trio/flake8-async/issues/298 + # We therefore save the exception before raising it, and save our intended context, + # so they can be modified in the `finally`. + exc_to_raise = None + exc_context = None + # by avoiding use of `raise .. from ..` we leave the original __cause__ + try: async with trio.open_nursery() as new_nursery: result = await outcome.acapture(_open_connection, new_nursery) @@ -216,7 +235,7 @@ async def _close_connection(connection: WebSocketConnection) -> None: except _TRIO_EXC_GROUP_TYPE as e: # user_error, or exception bubbling up from _reader_task if len(e.exceptions) == 1: - raise copy_exc(e.exceptions[0]) from e.exceptions[0].__cause__ + _raise(e.exceptions[0]) # contains at most 1 non-cancelled exceptions exception_to_raise: BaseException|None = None @@ -229,25 +248,40 @@ async def _close_connection(connection: WebSocketConnection) -> None: else: if exception_to_raise is None: # all exceptions are cancelled - # prefer raising the one from the user, for traceback reasons + # we reraise the user exception and throw out internal if user_error is not None: - # no reason to raise from e, just to include a bunch of extra - # cancelleds. - raise copy_exc(user_error) from user_error.__cause__ + _raise(user_error) # multiple internal Cancelled is not possible afaik - raise copy_exc(e.exceptions[0]) from e # pragma: no cover - raise copy_exc(exception_to_raise) from exception_to_raise.__cause__ + # but if so we just raise one of them + _raise(e.exceptions[0]) + # raise the non-cancelled exception + _raise(exception_to_raise) - # if we have any KeyboardInterrupt in the group, make sure to raise it. + # if we have any KeyboardInterrupt in the group, raise a new KeyboardInterrupt + # with the group as cause & context for sub_exc in e.exceptions: if isinstance(sub_exc, KeyboardInterrupt): - raise copy_exc(sub_exc) from e + raise KeyboardInterrupt from e # Both user code and internal code raised non-cancelled exceptions. - # We "hide" the internal exception(s) in the __cause__ and surface - # the user_error. + # We set the context to be an exception group containing internal exceptions + # and, if not None, `user_error.__context__` if user_error is not None: - raise copy_exc(user_error) from e + exceptions = [subexc for subexc in e.exceptions if subexc is not user_error] + eg_substr = '' + # there's technically loss of info here, with __suppress_context__=True you + # still have original __context__ available, just not printed. But we delete + # it completely because we can't partially suppress the group + if user_error.__context__ is not None and not user_error.__suppress_context__: + exceptions.append(user_error.__context__) + eg_substr = ' and the context for the user exception' + eg_str = ( + "Both internal and user exceptions encountered. This group contains " + "the internal exception(s)" + eg_substr + "." + ) + user_error.__context__ = BaseExceptionGroup(eg_str, exceptions) + user_error.__suppress_context__ = False + _raise(user_error) raise TrioWebsocketInternalError( "The trio-websocket API is not expected to raise multiple exceptions. " From 1c5bf53130d95d9ea2b8a231ed920940374af6ce Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 23 Oct 2024 12:36:48 +0200 Subject: [PATCH 5/8] fix pylint --- tests/test_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 304390d..326a133 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -496,7 +496,7 @@ async def handler(request): assert exc_info.value is user_error e_context = exc_info.value.__context__ - assert isinstance(e_context, BaseExceptionGroup) + assert isinstance(e_context, BaseExceptionGroup) # pylint: disable=possibly-used-before-assignment assert internal_error in e_context.exceptions assert user_error_context in e_context.exceptions From 16d8d9f1dbfba049f0e00596b22c94e3388f3e09 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 25 Oct 2024 11:59:07 +0200 Subject: [PATCH 6/8] remove unused _copy_exc --- trio_websocket/_impl.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 4ae7338..fa16cfb 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -92,16 +92,6 @@ def __exit__(self, ty, value, tb): filtered_exception = _ignore_cancel(value) return filtered_exception is None -def copy_exc(e: BaseException) -> BaseException: - """Copy an exception. - - `copy.copy` fails on `trio.Cancelled`, and on exceptions with a custom `__init__` - that calls `super().__init__()`. It may be the case that this also fails on something. - """ - cls = type(e) - result = cls.__new__(cls) - result.__dict__ = copy.copy(e.__dict__) - return result @asynccontextmanager async def open_websocket( From 1b0be0583c2195f4d69f67617399b9ed125d7483 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 25 Oct 2024 12:24:11 +0200 Subject: [PATCH 7/8] small cleanups --- trio_websocket/_impl.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index fa16cfb..62c7ff7 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -1,6 +1,5 @@ from __future__ import annotations -import copy import sys from collections import OrderedDict from contextlib import asynccontextmanager @@ -152,14 +151,14 @@ async def open_websocket( # yield to user code. If only one of those raise a non-cancelled exception # we will raise that non-cancelled exception. # If we get multiple cancelled, we raise the user's cancelled. - # If both raise exceptions, we raise the user code's exception with the entire - # exception group as the __cause__. + # If both raise exceptions, we raise the user code's exception with __context__ + # set to a group containing internal exception(s) + any user exception __context__ # If we somehow get multiple exceptions, but no user exception, then we raise # TrioWebsocketInternalError. # If closing the connection fails, then that will be raised as the top # exception in the last `finally`. If we encountered exceptions in user code - # or in reader task then they will be set as the `__cause__`. + # or in reader task then they will be set as the `__context__`. async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection: @@ -183,6 +182,8 @@ async def _close_connection(connection: WebSocketConnection) -> None: raise DisconnectionTimeout from None def _raise(exc: BaseException) -> NoReturn: + """This helper allows re-raising an exception without __context__ being set.""" + # cause does not need special handlng, we simply avoid using `raise .. from ..` __tracebackhide__ = True context = exc.__context__ try: @@ -199,11 +200,7 @@ def _raise(exc: BaseException) -> NoReturn: # the exception we raise also being inside the group that's set as the context. # This leads to loss of info unless properly handled. # See https://github.com/python-trio/flake8-async/issues/298 - # We therefore save the exception before raising it, and save our intended context, - # so they can be modified in the `finally`. - exc_to_raise = None - exc_context = None - # by avoiding use of `raise .. from ..` we leave the original __cause__ + # We therefore avoid having the exceptiongroup included as either cause or context try: async with trio.open_nursery() as new_nursery: @@ -243,7 +240,7 @@ def _raise(exc: BaseException) -> NoReturn: _raise(user_error) # multiple internal Cancelled is not possible afaik # but if so we just raise one of them - _raise(e.exceptions[0]) + _raise(e.exceptions[0]) # pragma: no cover # raise the non-cancelled exception _raise(exception_to_raise) From b8d1fc7fea4cadd70619cfe852711e02a3095c42 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 29 Oct 2024 12:24:52 +0100 Subject: [PATCH 8/8] make exceptions copy- and pickleable --- tests/test_connection.py | 14 ++++++++++++++ trio_websocket/_impl.py | 4 ++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 326a133..0837aa5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -31,6 +31,7 @@ ''' from __future__ import annotations +import copy from functools import partial, wraps import re import ssl @@ -1205,3 +1206,16 @@ async def server(): async with trio.open_nursery() as nursery: nursery.start_soon(server) nursery.start_soon(client) + + +def test_copy_exceptions(): + # test that exceptions are copy- and pickleable + copy.copy(HandshakeError()) + copy.copy(ConnectionTimeout()) + copy.copy(DisconnectionTimeout()) + assert copy.copy(ConnectionClosed("foo")).reason == "foo" + + rej_copy = copy.copy(ConnectionRejected(404, (("a", "b"),), b"c")) + assert rej_copy.status_code == 404 + assert rej_copy.headers == (("a", "b"),) + assert rej_copy.body == b"c" diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index 62c7ff7..5f3a9d4 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -608,7 +608,7 @@ def __init__(self, reason): :param reason: :type reason: CloseReason ''' - super().__init__() + super().__init__(reason) self.reason = reason def __repr__(self): @@ -628,7 +628,7 @@ def __init__(self, status_code, headers, body): :param reason: :type reason: CloseReason ''' - super().__init__() + super().__init__(status_code, headers, body) #: a 3 digit HTTP status code self.status_code = status_code #: a tuple of 2-tuples containing header key/value pairs