Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

typing: add middleware protocols #2390

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions falcon/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,147 @@ def __call__(self, media: Any, content_type: Optional[str] = ...) -> bytes: ...
DeserializeSync = Callable[[bytes], Any]

Responder = Union[ResponderMethod, AsgiResponderMethod]


# Middleware
class MiddlewareWithProcessRequest(Protocol):
"""WSGI Middleware with request handler."""

def process_request(self, req: Request, resp: Response) -> None: ...


class MiddlewareWithProcessResource(Protocol):
"""WSGI Middleware with resource handler."""

def process_resource(
self,
req: Request,
resp: Response,
resource: object,
params: Dict[str, Any],
) -> None: ...


class MiddlewareWithProcessResponse(Protocol):
"""WSGI Middleware with response handler."""

def process_response(
self, req: Request, resp: Response, resource: object, req_succeeded: bool
) -> None: ...


class AsgiMiddlewareWithProcessStartup(Protocol):
"""ASGI middleware with startup handler."""

async def process_startup(
self, scope: Mapping[str, Any], event: Mapping[str, Any]
) -> None: ...


class AsgiMiddlewareWithProcessShutdown(Protocol):
"""ASGI middleware with shutdown handler."""

async def process_shutdown(
self, scope: Mapping[str, Any], event: Mapping[str, Any]
) -> None: ...


class AsgiMiddlewareWithProcessRequest(Protocol):
"""ASGI middleware with request handler."""

async def process_request(self, req: AsgiRequest, resp: AsgiResponse) -> None: ...


class AsgiMiddlewareWithProcessResource(Protocol):
"""ASGI middleware with resource handler."""

async def process_resource(
self,
req: AsgiRequest,
resp: AsgiResponse,
resource: object,
params: Mapping[str, Any],
) -> None: ...


class AsgiMiddlewareWithProcessResponse(Protocol):
"""ASGI middleware with response handler."""

async def process_response(
self,
req: AsgiRequest,
resp: AsgiResponse,
resource: object,
req_succeeded: bool,
) -> None: ...


class MiddlewareWithAsyncProcessRequestWs(Protocol):
"""ASGI middleware with WebSocket request handler."""

async def process_request_ws(self, req: AsgiRequest, ws: WebSocket) -> None: ...


class MiddlewareWithAsyncProcessResourceWs(Protocol):
"""ASGI middleware with WebSocket resource handler."""

async def process_resource_ws(
self,
req: AsgiRequest,
ws: WebSocket,
resource: object,
params: Mapping[str, Any],
) -> None: ...


class UniversalMiddlewareWithProcessRequest(MiddlewareWithProcessRequest, Protocol):
"""WSGI/ASGI middleware with request handler."""

async def process_request_async(
self, req: AsgiRequest, resp: AsgiResponse
) -> None: ...


class UniversalMiddlewareWithProcessResource(MiddlewareWithProcessResource, Protocol):
"""WSGI/ASGI middleware with resource handler."""

async def process_resource_async(
self,
req: AsgiRequest,
resp: AsgiResponse,
resource: object,
params: Mapping[str, Any],
) -> None: ...


class UniversalMiddlewareWithProcessResponse(MiddlewareWithProcessResponse, Protocol):
"""WSGI/ASGI middleware with response handler."""

async def process_response_async(
self,
req: AsgiRequest,
resp: AsgiResponse,
resource: object,
req_succeeded: bool,
) -> None: ...


# NOTE(jkmnt): This typing is far from perfect due to the Python typing limitations,
# but better than nothing. Middleware conforming to any protocol of the union
# will pass the type check. Other protocols violations are not checked.
Middleware = Union[
MiddlewareWithProcessRequest,
MiddlewareWithProcessResource,
MiddlewareWithProcessResponse,
]

AsgiMiddleware = Union[
AsgiMiddlewareWithProcessRequest,
AsgiMiddlewareWithProcessResource,
AsgiMiddlewareWithProcessResponse,
AsgiMiddlewareWithProcessStartup,
AsgiMiddlewareWithProcessShutdown,
UniversalMiddlewareWithProcessRequest,
UniversalMiddlewareWithProcessResource,
UniversalMiddlewareWithProcessResponse,
]
29 changes: 16 additions & 13 deletions falcon/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from falcon._typing import ErrorHandler
from falcon._typing import ErrorSerializer
from falcon._typing import FindMethod
from falcon._typing import Middleware
from falcon._typing import ProcessResponseMethod
from falcon._typing import ResponderCallable
from falcon._typing import SinkCallable
Expand Down Expand Up @@ -286,7 +287,7 @@ def process_response(
_static_routes: List[
Tuple[routing.StaticRoute, routing.StaticRoute, Literal[False]]
]
_unprepared_middleware: List[object]
_unprepared_middleware: List[Middleware]

# Attributes
req_options: RequestOptions
Expand All @@ -305,7 +306,7 @@ def __init__(
media_type: str = constants.DEFAULT_MEDIA_TYPE,
request_type: Optional[Type[Request]] = None,
response_type: Optional[Type[Response]] = None,
middleware: Union[object, Iterable[object]] = None,
middleware: Optional[Union[Middleware, Iterable[Middleware]]] = None,
router: Optional[routing.CompiledRouter] = None,
independent_middleware: bool = True,
cors_enable: bool = False,
Expand All @@ -327,17 +328,17 @@ def __init__(
# NOTE(kgriffs): Check to see if middleware is an
# iterable, and if so, append the CORSMiddleware
# instance.
middleware = list(middleware) # type: ignore[arg-type]
middleware.append(cm) # type: ignore[arg-type]
middleware = list(cast(Iterable[Middleware], middleware))
middleware.append(cm)
except TypeError:
# NOTE(kgriffs): Assume the middleware kwarg references
# a single middleware component.
middleware = [middleware, cm]
middleware = [cast(Middleware, middleware), cm]

# set middleware
self._unprepared_middleware = []
self._independent_middleware = independent_middleware
self.add_middleware(middleware)
self.add_middleware(middleware or [])

self._router = router or routing.DefaultRouter()
self._router_search = self._router.find
Expand Down Expand Up @@ -524,7 +525,9 @@ def router_options(self) -> routing.CompiledRouterOptions:
"""
return self._router.options

def add_middleware(self, middleware: Union[object, Iterable[object]]) -> None:
def add_middleware(
self, middleware: Union[Middleware, Iterable[Middleware]]
) -> None:
"""Add one or more additional middleware components.

Arguments:
Expand All @@ -535,20 +538,20 @@ def add_middleware(self, middleware: Union[object, Iterable[object]]) -> None:
"""

# NOTE(kgriffs): Since this is called by the initializer, there is
# the chance that middleware may be None.
# the chance that middleware may be empty.
if middleware:
try:
middleware = list(middleware) # type: ignore[call-overload]
middleware = list(cast(Iterable[Middleware], middleware))
except TypeError:
# middleware is not iterable; assume it is just one bare component
middleware = [middleware]
middleware = [cast(Middleware, middleware)]

if (
self._cors_enable
and len(
[
mc
for mc in self._unprepared_middleware + middleware # type: ignore[operator]
for mc in self._unprepared_middleware + middleware
if isinstance(mc, CORSMiddleware)
]
)
Expand All @@ -559,7 +562,7 @@ def add_middleware(self, middleware: Union[object, Iterable[object]]) -> None:
'cors_enable (which already constructs one instance)'
)

self._unprepared_middleware += middleware # type: ignore[arg-type]
self._unprepared_middleware += middleware

# NOTE(kgriffs): Even if middleware is None or an empty list, we still
# need to make sure self._middleware is initialized if this is the
Expand Down Expand Up @@ -1012,7 +1015,7 @@ def my_serializer(
# ------------------------------------------------------------------------

def _prepare_middleware(
self, middleware: List[object], independent_middleware: bool = False
self, middleware: List[Middleware], independent_middleware: bool = False
) -> helpers.PreparedMiddlewareResult:
return helpers.prepare_middleware(
middleware=middleware, independent_middleware=independent_middleware
Expand Down
19 changes: 14 additions & 5 deletions falcon/app_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
from typing import IO, Iterable, List, Literal, Optional, overload, Tuple, Union

from falcon import util
from falcon._typing import AsgiMiddleware
from falcon._typing import AsgiProcessRequestMethod as APRequest
from falcon._typing import AsgiProcessRequestWsMethod
from falcon._typing import AsgiProcessResourceMethod as APResource
from falcon._typing import AsgiProcessResourceWsMethod
from falcon._typing import AsgiProcessResponseMethod as APResponse
from falcon._typing import Middleware
from falcon._typing import ProcessRequestMethod as PRequest
from falcon._typing import ProcessResourceMethod as PResource
from falcon._typing import ProcessResponseMethod as PResponse
Expand Down Expand Up @@ -62,24 +64,31 @@

@overload
def prepare_middleware(
middleware: Iterable, independent_middleware: bool = ..., asgi: Literal[False] = ...
middleware: Iterable[Middleware],
independent_middleware: bool = ...,
asgi: Literal[False] = ...,
) -> PreparedMiddlewareResult: ...


@overload
def prepare_middleware(
middleware: Iterable, independent_middleware: bool = ..., *, asgi: Literal[True]
middleware: Iterable[AsgiMiddleware],
independent_middleware: bool = ...,
*,
asgi: Literal[True],
) -> AsyncPreparedMiddlewareResult: ...


@overload
def prepare_middleware(
middleware: Iterable, independent_middleware: bool = ..., asgi: bool = ...
middleware: Union[Iterable[Middleware], Iterable[AsgiMiddleware]],
independent_middleware: bool = ...,
asgi: bool = ...,
) -> Union[PreparedMiddlewareResult, AsyncPreparedMiddlewareResult]: ...


def prepare_middleware(
middleware: Iterable[object],
middleware: Union[Iterable[Middleware], Iterable[AsgiMiddleware]],
independent_middleware: bool = False,
asgi: bool = False,
) -> Union[PreparedMiddlewareResult, AsyncPreparedMiddlewareResult]:
Expand Down Expand Up @@ -214,7 +223,7 @@ def prepare_middleware(


def prepare_middleware_ws(
middleware: Iterable[object],
middleware: Iterable[AsgiMiddleware],
) -> AsyncPreparedMiddlewareWsResult:
"""Check middleware interfaces and prepare WebSocket methods for request handling.

Expand Down
8 changes: 5 additions & 3 deletions falcon/asgi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from falcon import routing
from falcon._typing import _UNSET
from falcon._typing import AsgiErrorHandler
from falcon._typing import AsgiMiddleware
from falcon._typing import AsgiReceive
from falcon._typing import AsgiResponderCallable
from falcon._typing import AsgiResponderWsCallable
Expand Down Expand Up @@ -356,6 +357,7 @@ async def process_resource_ws(
_middleware_ws: AsyncPreparedMiddlewareWsResult
_request_type: Type[Request]
_response_type: Type[Response]
_unprepared_middleware: List[AsgiMiddleware] # type: ignore[assignment]

ws_options: WebSocketOptions
"""A set of behavioral options related to WebSocket connections.
Expand All @@ -368,7 +370,7 @@ def __init__(
media_type: str = constants.DEFAULT_MEDIA_TYPE,
request_type: Optional[Type[Request]] = None,
response_type: Optional[Type[Response]] = None,
middleware: Union[object, Iterable[object]] = None,
middleware: Optional[Union[AsgiMiddleware, Iterable[AsgiMiddleware]]] = None,
router: Optional[routing.CompiledRouter] = None,
independent_middleware: bool = True,
cors_enable: bool = False,
Expand All @@ -378,7 +380,7 @@ def __init__(
media_type,
request_type or Request,
response_type or Response,
middleware,
middleware, # type: ignore[arg-type]
router,
independent_middleware,
cors_enable,
Expand Down Expand Up @@ -1163,7 +1165,7 @@ async def _handle_websocket(
raise

def _prepare_middleware( # type: ignore[override]
self, middleware: List[object], independent_middleware: bool = False
self, middleware: List[AsgiMiddleware], independent_middleware: bool = False
) -> AsyncPreparedMiddlewareResult:
self._middleware_ws = prepare_middleware_ws(middleware)

Expand Down
23 changes: 17 additions & 6 deletions falcon/middleware.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from __future__ import annotations

from typing import Any, Iterable, Optional, Union
from typing import Iterable, Optional, TYPE_CHECKING, Union

from .request import Request
from .response import Response
from ._typing import UniversalMiddlewareWithProcessResponse

if TYPE_CHECKING:
from .asgi.request import Request as AsgiRequest
from .asgi.response import Response as AsgiResponse
from .request import Request
from .response import Response

class CORSMiddleware(object):

class CORSMiddleware(UniversalMiddlewareWithProcessResponse):
"""CORS Middleware.

This middleware provides a simple out-of-the box CORS policy, including handling
Expand Down Expand Up @@ -141,5 +146,11 @@ def process_response(
resp.set_header('Access-Control-Allow-Headers', allow_headers)
resp.set_header('Access-Control-Max-Age', '86400') # 24 hours

async def process_response_async(self, *args: Any) -> None:
self.process_response(*args)
async def process_response_async(
self,
req: AsgiRequest,
resp: AsgiResponse,
resource: object,
req_succeeded: bool,
) -> None:
self.process_response(req, resp, resource, req_succeeded)
Loading