diff --git a/docs/_newsfragments/2146.bugfix.rst b/docs/_newsfragments/2146.bugfix.rst new file mode 100644 index 000000000..ab27e52ba --- /dev/null +++ b/docs/_newsfragments/2146.bugfix.rst @@ -0,0 +1,6 @@ + +:ref:`WebSocket ` implementation has been fixed to properly handle +:class:`~falcon.HTTPError` and :class:`~falcon.HTTPStatus` exceptions raised by +custom :func:`error handlers `. +The WebSocket connection is now correctly closed with an appropriate code +instead of bubbling up an unhandled error to the application server. diff --git a/falcon/asgi/app.py b/falcon/asgi/app.py index 5f8cc4c90..0f63ad5ce 100644 --- a/falcon/asgi/app.py +++ b/falcon/asgi/app.py @@ -45,6 +45,7 @@ from .request import Request from .response import Response from .structures import SSEvent +from .ws import http_status_to_ws_code from .ws import WebSocket from .ws import WebSocketOptions @@ -1027,30 +1028,47 @@ def _prepare_middleware(self, middleware=None, independent_middleware=False): asgi=True, ) - async def _http_status_handler(self, req, resp, status, params): - self._compose_status_response(req, resp, status) + async def _http_status_handler(self, req, resp, status, params, ws=None): + if resp: + self._compose_status_response(req, resp, status) + elif ws: + code = http_status_to_ws_code(status.status) + falcon._logger.error( + '[FALCON] HTTPStatus %s raised while handling WebSocket. ' + 'Closing with code %s', + status, + code, + ) + await ws.close(code) + else: + raise NotImplementedError('resp or ws expected') async def _http_error_handler(self, req, resp, error, params, ws=None): if resp: self._compose_error_response(req, resp, error) - - if ws: + elif ws: + # NOTE(vytas): error.status_code is not yet in this backport. + # code = http_status_to_ws_code(error.status_code) + code = http_status_to_ws_code(http_status_to_code(error.status)) falcon._logger.error( - '[FALCON] WebSocket handshake rejected due to raised HTTP error: %s', + '[FALCON] HTTPError %s raised while handling WebSocket. ' + 'Closing with code %s', error, + code, ) - - code = 3000 + falcon.util.http_status_to_code(error.status) await ws.close(code) + else: + raise NotImplementedError('resp or ws expected') async def _python_error_handler(self, req, resp, error, params, ws=None): falcon._logger.error('[FALCON] Unhandled exception in ASGI app', exc_info=error) if resp: self._compose_error_response(req, resp, falcon.HTTPInternalServerError()) - - if ws: + elif ws: await self._ws_cleanup_on_error(ws) + else: + raise NotImplementedError('resp or ws expected') async def _ws_disconnected_error_handler(self, req, resp, error, params, ws): falcon._logger.debug( @@ -1095,9 +1113,9 @@ async def _handle_exception(self, req, resp, ex, params, ws=None): await err_handler(req, resp, ex, params, **kwargs) except HTTPStatus as status: - self._compose_status_response(req, resp, status) + await self._http_status_handler(req, resp, status, params, ws=ws) except HTTPError as error: - self._compose_error_response(req, resp, error) + await self._http_error_handler(req, resp, error, params, ws=ws) return True diff --git a/falcon/asgi/ws.py b/falcon/asgi/ws.py index 19dd2f2f1..e7db3d5fa 100644 --- a/falcon/asgi/ws.py +++ b/falcon/asgi/ws.py @@ -701,3 +701,8 @@ async def _pump(self): if self._pop_message_waiter is not None: self._pop_message_waiter.set_result(None) self._pop_message_waiter = None + + +def http_status_to_ws_code(http_status: int) -> int: + """Convert the provided http status to a websocket close code by adding 3000.""" + return http_status + 3000 diff --git a/falcon/util/misc.py b/falcon/util/misc.py index 4cdc7c5b2..b7c7746b1 100644 --- a/falcon/util/misc.py +++ b/falcon/util/misc.py @@ -456,7 +456,9 @@ def code_to_http_status(status): if isinstance(status, http.HTTPStatus): return '{} {}'.format(status.value, status.phrase) - if isinstance(status, str): + # NOTE(kgriffs): If it is a str but does not have a space, assume it is + # just the number by itself. + if isinstance(status, str) and ' ' in status: return status if isinstance(status, bytes): @@ -464,10 +466,10 @@ def code_to_http_status(status): try: code = int(status) - if not 100 <= code <= 999: - raise ValueError('{} is not a valid status code'.format(status)) except (ValueError, TypeError): raise ValueError('{!r} is not a valid status code'.format(status)) + if not 100 <= code <= 999: + raise ValueError('{} is not a valid status code'.format(status)) try: # NOTE(kgriffs): We do this instead of using http.HTTPStatus since diff --git a/tests/asgi/test_misc.py b/tests/asgi/test_misc.py new file mode 100644 index 000000000..1a609db3e --- /dev/null +++ b/tests/asgi/test_misc.py @@ -0,0 +1,32 @@ +# misc test for 100% coverage + +from unittest.mock import MagicMock + +import pytest + +from falcon.asgi import App +from falcon.http_error import HTTPError +from falcon.http_status import HTTPStatus + + +@pytest.mark.asyncio +async def test_http_status_not_impl(): + app = App() + with pytest.raises(NotImplementedError): + await app._http_status_handler(MagicMock(), None, HTTPStatus(200), {}, None) + + +@pytest.mark.asyncio +async def test_http_error_not_impl(): + app = App() + with pytest.raises(NotImplementedError): + await app._http_error_handler(MagicMock(), None, HTTPError(400), {}, None) + + +@pytest.mark.asyncio +async def test_python_error_not_impl(): + app = App() + with pytest.raises(NotImplementedError): + await app._python_error_handler( + MagicMock(), None, ValueError('error'), {}, None + ) diff --git a/tests/asgi/test_ws.py b/tests/asgi/test_ws.py index dd327b405..7ada6cb24 100644 --- a/tests/asgi/test_ws.py +++ b/tests/asgi/test_ws.py @@ -1093,3 +1093,126 @@ def test_msgpack_missing(): with pytest.raises(RuntimeError): handler.deserialize(b'{}') + + +@pytest.mark.asyncio +@pytest.mark.parametrize('status', [200, 500, 422, 400]) +@pytest.mark.parametrize('thing', [falcon.HTTPStatus, falcon.HTTPError]) +@pytest.mark.parametrize('accept', [True, False]) +async def test_ws_http_error_or_status_response(conductor, status, thing, accept): + class Resource: + async def on_websocket(self, req, ws): + if accept: + await ws.accept() + raise thing(status) + + conductor.app.add_route('/', Resource()) + exp_code = 3000 + status + + async with conductor as c: + if accept: + async with c.simulate_ws() as ws: + assert ws.closed + assert ws.close_code == exp_code + else: + with pytest.raises(falcon.WebSocketDisconnected) as err: + async with c.simulate_ws(): + pass + assert err.value.code == exp_code + + +@pytest.mark.asyncio +@pytest.mark.parametrize('status', [200, 500, 422, 400]) +@pytest.mark.parametrize( + 'thing', + [ + falcon.HTTPStatus, + falcon.HTTPError, + ], +) +@pytest.mark.parametrize('place', ['request', 'resource']) +async def test_ws_http_error_or_status_middleware(conductor, status, thing, place): + called = False + + class Resource: + async def on_websocket(self, req, ws): + nonlocal called + called = True + + class Middleware: + async def process_request_ws(self, req, ws): + if place == 'request': + raise thing(status) + + async def process_resource_ws(self, req, ws, res, params): + if place == 'resource': + raise thing(status) + + conductor.app.add_route('/', Resource()) + conductor.app.add_middleware(Middleware()) + exp_code = 3000 + status + + async with conductor as c: + with pytest.raises(falcon.WebSocketDisconnected) as err: + async with c.simulate_ws(): + pass + assert err.value.code == exp_code + assert not called + + +class FooBarError(Exception): + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize('status', [200, 500, 422, 400]) +@pytest.mark.parametrize('thing', [falcon.HTTPStatus, falcon.HTTPError]) +@pytest.mark.parametrize( + 'place', ['request', 'resource', 'ws_before_accept', 'ws_after_accept'] +) +@pytest.mark.parametrize('handler_has_ws', [True, False]) +async def test_ws_http_error_or_status_error_handler( + conductor, status, thing, place, handler_has_ws +): + class Resource: + async def on_websocket(self, req, ws): + if place == 'ws_before_accept': + raise FooBarError + await ws.accept() + if place == 'ws_after_accept': + raise FooBarError + + class Middleware: + async def process_request_ws(self, req, ws): + if place == 'request': + raise FooBarError + + async def process_resource_ws(self, req, ws, res, params): + if place == 'resource': + raise FooBarError + + if handler_has_ws: + + async def handle_foobar(req, resp, ex, param, ws=None): # type: ignore + raise thing(status) + + else: + + async def handle_foobar(req, resp, ex, param): # type: ignore + raise thing(status) + + conductor.app.add_route('/', Resource()) + conductor.app.add_middleware(Middleware()) + conductor.app.add_error_handler(FooBarError, handle_foobar) + exp_code = 3000 + status + + async with conductor as c: + if place == 'ws_after_accept': + async with c.simulate_ws() as ws: + assert ws.closed + assert ws.close_code == exp_code + else: + with pytest.raises(falcon.WebSocketDisconnected) as err: + async with c.simulate_ws(): + pass + assert err.value.code == exp_code diff --git a/tests/test_utils.py b/tests/test_utils.py index 1e618fd40..13d0a6cb1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -522,7 +522,23 @@ def test_get_http_status(self): def test_code_to_http_status(self, v_in, v_out): assert falcon.code_to_http_status(v_in) == v_out - @pytest.mark.parametrize('v', [0, 13, 99, 1000, 1337.01, -99, -404.3, -404, -404.3]) + @pytest.mark.parametrize( + 'v', + [ + 0, + 13, + 99, + 1000, + 1337.01, + -99, + -404.3, + -404, + -404.3, + 'Successful', + 'Failed', + None, + ], + ) def test_code_to_http_status_value_error(self, v): with pytest.raises(ValueError): falcon.code_to_http_status(v)