From e69a7d580a30bb4c14f61e1f61a8ce7820a8ea6f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 20 Dec 2024 21:05:15 +0100 Subject: [PATCH] linting++ --- pyproject.toml | 48 ++++- src/h2/__init__.py | 4 +- src/h2/config.py | 46 +++-- src/h2/connection.py | 416 ++++++++++++++++++++++------------------- src/h2/errors.py | 9 +- src/h2/events.py | 216 +++++++++++---------- src/h2/exceptions.py | 30 +-- src/h2/frame_buffer.py | 71 ++++--- src/h2/settings.py | 62 +++--- src/h2/stream.py | 278 ++++++++++++++------------- src/h2/utilities.py | 285 +++++++++++++--------------- src/h2/windows.py | 29 ++- 12 files changed, 777 insertions(+), 717 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a77110d3..f7b38e6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,8 +91,54 @@ h2 = [ "py.typed" ] version = { attr = "h2.__version__" } [tool.ruff] -line-length = 140 +line-length = 150 target-version = "py39" +format.preview = true +format.docstring-code-line-length = 100 +format.docstring-code-format = true +lint.select = [ + "ALL", +] +lint.ignore = [ + "PYI034", # PEP 673 not yet available in Python 3.9 - only in 3.11+ + "ANN001", # args with typing.Any + "ANN002", # args with typing.Any + "ANN003", # kwargs with typing.Any + "ANN401", # kwargs with typing.Any + "SLF001", # implementation detail + "CPY", # not required + "D101", # docs readability + "D102", # docs readability + "D105", # docs readability + "D107", # docs readability + "D200", # docs readability + "D205", # docs readability + "D205", # docs readability + "D203", # docs readability + "D212", # docs readability + "D400", # docs readability + "D401", # docs readability + "D415", # docs readability + "PLR2004", # readability + "SIM108", # readability + "RUF012", # readability + "FBT001", # readability + "FBT002", # readability + "PGH003", # readability + "PGH004", # readability + "FIX001", # readability + "FIX002", # readability + "TD001", # readability + "TD002", # readability + "TD003", # readability + "S101", # readability + "PD901", # readability + "ERA001", # readability + "ARG001", # readability + "ARG002", # readability + "PLR0913", # readability +] +lint.isort.required-imports = [ "from __future__ import annotations" ] [tool.mypy] show_error_codes = true diff --git a/src/h2/__init__.py b/src/h2/__init__.py index 1cd9b80e..edb8be5b 100644 --- a/src/h2/__init__.py +++ b/src/h2/__init__.py @@ -4,4 +4,6 @@ A HTTP/2 implementation. """ -__version__ = '4.1.0' +from __future__ import annotations + +__version__ = "4.1.0" diff --git a/src/h2/config.py b/src/h2/config.py index e4d29fa2..cbc3b1ea 100644 --- a/src/h2/config.py +++ b/src/h2/config.py @@ -4,9 +4,10 @@ Objects for controlling the configuration of the HTTP/2 stack. """ +from __future__ import annotations import sys -from typing import Any, Optional, Union +from typing import Any class _BooleanConfigOption: @@ -14,16 +15,18 @@ class _BooleanConfigOption: Descriptor for handling a boolean config option. This will block attempts to set boolean config options to non-bools. """ + def __init__(self, name: str) -> None: self.name = name - self.attr_name = '_%s' % self.name + self.attr_name = f"_{self.name}" def __get__(self, instance: Any, owner: Any) -> bool: return getattr(instance, self.attr_name) # type: ignore def __set__(self, instance: Any, value: bool) -> None: if not isinstance(value, bool): - raise ValueError("%s must be a bool" % self.name) + msg = f"{self.name} must be a bool" + raise ValueError(msg) # noqa: TRY004 setattr(instance, self.attr_name, value) @@ -35,6 +38,7 @@ class DummyLogger: conditionals being sprinkled throughout the h2 code for calls to logging functions when no logger is passed into the corresponding object. """ + def __init__(self, *vargs) -> None: # type: ignore pass @@ -42,13 +46,11 @@ def debug(self, *vargs, **kwargs) -> None: # type: ignore """ No-op logging. Only level needed for now. """ - pass def trace(self, *vargs, **kwargs) -> None: # type: ignore """ No-op logging. Only level needed for now. """ - pass class OutputLogger: @@ -61,15 +63,16 @@ class OutputLogger: Defaults to ``sys.stderr``. :param trace: Enables trace-level output. Defaults to ``False``. """ - def __init__(self, file=None, trace_level=False): # type: ignore + + def __init__(self, file=None, trace_level=False) -> None: # type: ignore super().__init__() self.file = file or sys.stderr self.trace_level = trace_level - def debug(self, fmtstr, *args): # type: ignore + def debug(self, fmtstr, *args) -> None: # type: ignore print(f"h2 (debug): {fmtstr % args}", file=self.file) - def trace(self, fmtstr, *args): # type: ignore + def trace(self, fmtstr, *args) -> None: # type: ignore if self.trace_level: print(f"h2 (trace): {fmtstr % args}", file=self.file) @@ -147,32 +150,33 @@ class H2Configuration: :type logger: ``logging.Logger`` """ - client_side = _BooleanConfigOption('client_side') + + client_side = _BooleanConfigOption("client_side") validate_outbound_headers = _BooleanConfigOption( - 'validate_outbound_headers' + "validate_outbound_headers", ) normalize_outbound_headers = _BooleanConfigOption( - 'normalize_outbound_headers' + "normalize_outbound_headers", ) split_outbound_cookies = _BooleanConfigOption( - 'split_outbound_cookies' + "split_outbound_cookies", ) validate_inbound_headers = _BooleanConfigOption( - 'validate_inbound_headers' + "validate_inbound_headers", ) normalize_inbound_headers = _BooleanConfigOption( - 'normalize_inbound_headers' + "normalize_inbound_headers", ) def __init__(self, client_side: bool = True, - header_encoding: Optional[Union[bool, str]] = None, + header_encoding: bool | str | None = None, validate_outbound_headers: bool = True, normalize_outbound_headers: bool = True, split_outbound_cookies: bool = False, validate_inbound_headers: bool = True, normalize_inbound_headers: bool = True, - logger: Optional[Union[DummyLogger, OutputLogger]] = None) -> None: + logger: DummyLogger | OutputLogger | None = None) -> None: self.client_side = client_side self.header_encoding = header_encoding self.validate_outbound_headers = validate_outbound_headers @@ -183,7 +187,7 @@ def __init__(self, self.logger = logger or DummyLogger(__name__) @property - def header_encoding(self) -> Optional[Union[bool, str]]: + def header_encoding(self) -> bool | str | None: """ Controls whether the headers emitted by this object in events are transparently decoded to ``unicode`` strings, and what encoding is used @@ -195,12 +199,14 @@ def header_encoding(self) -> Optional[Union[bool, str]]: return self._header_encoding @header_encoding.setter - def header_encoding(self, value: Optional[Union[bool, str]]) -> None: + def header_encoding(self, value: bool | str | None) -> None: """ Enforces constraints on the value of header encoding. """ if not isinstance(value, (bool, str, type(None))): - raise ValueError("header_encoding must be bool, string, or None") + msg = "header_encoding must be bool, string, or None" + raise ValueError(msg) # noqa: TRY004 if value is True: - raise ValueError("header_encoding cannot be True") + msg = "header_encoding cannot be True" + raise ValueError(msg) self._header_encoding = value diff --git a/src/h2/connection.py b/src/h2/connection.py index c7347b26..28be9fca 100644 --- a/src/h2/connection.py +++ b/src/h2/connection.py @@ -4,45 +4,71 @@ An implementation of a HTTP/2 connection. """ -import base64 +from __future__ import annotations +import base64 from enum import Enum, IntEnum +from typing import TYPE_CHECKING, Any, Callable +from hpack.exceptions import HPACKError, OversizedHeaderListError +from hpack.hpack import Decoder, Encoder from hyperframe.exceptions import InvalidPaddingError from hyperframe.frame import ( + AltSvcFrame, + ContinuationFrame, + DataFrame, + ExtensionFrame, Frame, - AltSvcFrame, ContinuationFrame, DataFrame, ExtensionFrame, GoAwayFrame, - HeadersFrame, PingFrame, PriorityFrame, PushPromiseFrame, - RstStreamFrame, SettingsFrame, WindowUpdateFrame + GoAwayFrame, + HeadersFrame, + PingFrame, + PriorityFrame, + PushPromiseFrame, + RstStreamFrame, + SettingsFrame, + WindowUpdateFrame, ) -from hpack.hpack import Encoder, Decoder -from hpack.struct import Header, HeaderWeaklyTyped -from hpack.exceptions import HPACKError, OversizedHeaderListError from .config import H2Configuration from .errors import ErrorCodes, _error_code_from_int from .events import ( + AlternativeServiceAvailable, + ConnectionTerminated, Event, - AlternativeServiceAvailable, ConnectionTerminated, InformationalResponseReceived, - PingAckReceived, PingReceived, PriorityUpdated, - RemoteSettingsChanged, RequestReceived, ResponseReceived, + PingAckReceived, + PingReceived, + PriorityUpdated, + RemoteSettingsChanged, + RequestReceived, + ResponseReceived, SettingsAcknowledged, - TrailersReceived, UnknownFrameReceived, - WindowUpdated + TrailersReceived, + UnknownFrameReceived, + WindowUpdated, ) from .exceptions import ( - ProtocolError, NoSuchStreamError, FlowControlError, FrameTooLargeError, - TooManyStreamsError, StreamClosedError, StreamIDTooLowError, - NoAvailableStreamIDError, RFC1122Error, DenialOfServiceError + DenialOfServiceError, + FlowControlError, + FrameTooLargeError, + NoAvailableStreamIDError, + NoSuchStreamError, + ProtocolError, + RFC1122Error, + StreamClosedError, + StreamIDTooLowError, + TooManyStreamsError, ) from .frame_buffer import FrameBuffer -from .settings import Settings, SettingCodes, ChangedSetting +from .settings import ChangedSetting, SettingCodes, Settings from .stream import H2Stream, StreamClosedBy from .utilities import SizeLimitDict, guard_increment_window from .windows import WindowManager -from typing import Any, Callable, Optional, Union, Iterable +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Iterable + + from hpack.struct import Header, HeaderWeaklyTyped class ConnectionState(Enum): @@ -89,6 +115,7 @@ class H2ConnectionStateMachine: maintains very little state directly, instead focusing entirely on managing state transitions. """ + # For the purposes of this state machine we treat HEADERS and their # associated CONTINUATION frames as a single jumbo frame. The protocol # allows/requires this by preventing other frames from being interleved in @@ -226,16 +253,16 @@ def process_input(self, input_: ConnectionInputs) -> list[Event]: Process a specific input in the state machine. """ if not isinstance(input_, ConnectionInputs): - raise ValueError("Input must be an instance of ConnectionInputs") + msg = "Input must be an instance of ConnectionInputs" + raise ValueError(msg) # noqa: TRY004 try: func, target_state = self._transitions[(self.state, input_)] - except KeyError: + except KeyError as e: old_state = self.state self.state = ConnectionState.CLOSED - raise ProtocolError( - "Invalid input %s in state %s" % (input_, old_state) - ) + msg = f"Invalid input {input_} in state {old_state}" + raise ProtocolError(msg) from e else: self.state = target_state if func is not None: # pragma: no cover @@ -272,6 +299,7 @@ class H2Connection: :type config: :class:`H2Configuration ` """ + # The initial maximum outbound frame size. This can be changed by receiving # a settings frame. DEFAULT_MAX_OUTBOUND_FRAME_SIZE = 65535 @@ -292,7 +320,7 @@ class H2Connection: # Keep in memory limited amount of results for streams closes MAX_CLOSED_STREAMS = 2**16 - def __init__(self, config: Optional[H2Configuration] = None) -> None: + def __init__(self, config: H2Configuration | None = None) -> None: self.state_machine = H2ConnectionStateMachine() self.streams: dict[int, H2Stream] = {} self.highest_inbound_stream_id = 0 @@ -328,7 +356,7 @@ def __init__(self, config: Optional[H2Configuration] = None) -> None: SettingCodes.MAX_CONCURRENT_STREAMS: 100, SettingCodes.MAX_HEADER_LIST_SIZE: self.DEFAULT_MAX_HEADER_LIST_SIZE, - } + }, ) self.remote_settings = Settings(client=not self.config.client_side) @@ -362,13 +390,13 @@ def __init__(self, config: Optional[H2Configuration] = None) -> None: # Also used to determine whether we should consider a frame received # while a stream is closed as either a stream error or a connection # error. - self._closed_streams: dict[int, Optional[StreamClosedBy]] = SizeLimitDict( - size_limit=self.MAX_CLOSED_STREAMS + self._closed_streams: dict[int, StreamClosedBy | None] = SizeLimitDict( + size_limit=self.MAX_CLOSED_STREAMS, ) # The flow control window manager for the connection. self._inbound_flow_control_window_manager = WindowManager( - max_window_size=self.local_settings.initial_window_size + max_window_size=self.local_settings.initial_window_size, ) # When in doubt use dict-dispatch. @@ -384,13 +412,13 @@ def __init__(self, config: Optional[H2Configuration] = None) -> None: GoAwayFrame: self._receive_goaway_frame, ContinuationFrame: self._receive_naked_continuation, AltSvcFrame: self._receive_alt_svc_frame, - ExtensionFrame: self._receive_unknown_frame + ExtensionFrame: self._receive_unknown_frame, } def _prepare_for_sending(self, frames: list[Frame]) -> None: if not frames: return - self._data_to_send += b''.join(f.serialize() for f in frames) + self._data_to_send += b"".join(f.serialize() for f in frames) assert all(f.body_len <= self.max_outbound_frame_size for f in frames) def _open_streams(self, remainder: int) -> int: @@ -451,7 +479,7 @@ def _begin_new_stream(self, stream_id: int, allowed_ids: AllowedStreamIDs) -> H2 :param allowed_ids: What kind of stream ID is allowed. """ self.config.logger.debug( - "Attempting to initiate stream ID %d", stream_id + "Attempting to initiate stream ID %d", stream_id, ) outbound = self._stream_id_is_outbound(stream_id) highest_stream_id = ( @@ -463,15 +491,14 @@ def _begin_new_stream(self, stream_id: int, allowed_ids: AllowedStreamIDs) -> H2 raise StreamIDTooLowError(stream_id, highest_stream_id) if (stream_id % 2) != int(allowed_ids): - raise ProtocolError( - "Invalid stream ID for peer." - ) + msg = "Invalid stream ID for peer." + raise ProtocolError(msg) s = H2Stream( stream_id, config=self.config, inbound_window_size=self.local_settings.initial_window_size, - outbound_window_size=self.remote_settings.initial_window_size + outbound_window_size=self.remote_settings.initial_window_size, ) self.config.logger.debug("Stream ID %d created", stream_id) s.max_outbound_frame_size = self.max_outbound_frame_size @@ -494,20 +521,20 @@ def initiate_connection(self) -> None: self.config.logger.debug("Initializing connection") self.state_machine.process_input(ConnectionInputs.SEND_SETTINGS) if self.config.client_side: - preamble = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' + preamble = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" else: - preamble = b'' + preamble = b"" f = SettingsFrame(0) for setting, value in self.local_settings.items(): f.settings[setting] = value self.config.logger.debug( - "Send Settings frame: %s", self.local_settings + "Send Settings frame: %s", self.local_settings, ) self._data_to_send += preamble + f.serialize() - def initiate_upgrade_connection(self, settings_header: Optional[bytes] = None) -> Optional[bytes]: + def initiate_upgrade_connection(self, settings_header: bytes | None = None) -> bytes | None: """ Call to initialise the connection object for use with an upgraded HTTP/2 connection (i.e. a connection negotiated using the @@ -541,7 +568,7 @@ def initiate_upgrade_connection(self, settings_header: Optional[bytes] = None) - :rtype: ``bytes`` or ``None`` """ self.config.logger.debug( - "Upgrade connection. Current settings: %s", self.local_settings + "Upgrade connection. Current settings: %s", self.local_settings, ) frame_data = None @@ -594,7 +621,7 @@ def _get_or_create_stream(self, stream_id: int, allowed_ids: AllowedStreamIDs) - except KeyError: return self._begin_new_stream(stream_id, allowed_ids) - def _get_stream_by_id(self, stream_id: Optional[int]) -> H2Stream: + def _get_stream_by_id(self, stream_id: int | None) -> H2Stream: """ Gets a stream by its stream ID. Raises NoSuchStreamError if the stream ID does not correspond to a known stream and is higher than the current @@ -607,7 +634,7 @@ def _get_stream_by_id(self, stream_id: Optional[int]) -> H2Stream: raise NoSuchStreamError(-1) # pragma: no cover try: return self.streams[stream_id] - except KeyError: + except KeyError as e: outbound = self._stream_id_is_outbound(stream_id) highest_stream_id = ( self.highest_outbound_stream_id if outbound else @@ -615,9 +642,8 @@ def _get_stream_by_id(self, stream_id: Optional[int]) -> H2Stream: ) if stream_id > highest_stream_id: - raise NoSuchStreamError(stream_id) - else: - raise StreamClosedError(stream_id) + raise NoSuchStreamError(stream_id) from e + raise StreamClosedError(stream_id) from e def get_next_available_stream_id(self) -> int: """ @@ -648,10 +674,11 @@ def get_next_available_stream_id(self) -> int: else: next_stream_id = self.highest_outbound_stream_id + 2 self.config.logger.debug( - "Next available stream ID %d", next_stream_id + "Next available stream ID %d", next_stream_id, ) if next_stream_id > self.HIGHEST_ALLOWED_STREAM_ID: - raise NoAvailableStreamIDError("Exhausted allowed stream IDs") + msg = "Exhausted allowed stream IDs" + raise NoAvailableStreamIDError(msg) return next_stream_id @@ -659,9 +686,9 @@ def send_headers(self, stream_id: int, headers: Iterable[HeaderWeaklyTyped], end_stream: bool = False, - priority_weight: Optional[int] = None, - priority_depends_on: Optional[int] = None, - priority_exclusive: Optional[bool] = None) -> None: + priority_weight: int | None = None, + priority_depends_on: int | None = None, + priority_exclusive: bool | None = None) -> None: """ Send headers on a given stream. @@ -760,26 +787,24 @@ def send_headers(self, :returns: Nothing """ self.config.logger.debug( - "Send headers on stream ID %d", stream_id + "Send headers on stream ID %d", stream_id, ) # Check we can open the stream. if stream_id not in self.streams: max_open_streams = self.remote_settings.max_concurrent_streams if (self.open_outbound_streams + 1) > max_open_streams: - raise TooManyStreamsError( - "Max outbound streams is %d, %d open" % - (max_open_streams, self.open_outbound_streams) - ) + msg = f"Max outbound streams is {max_open_streams}, {self.open_outbound_streams} open" + raise TooManyStreamsError(msg) self.state_machine.process_input(ConnectionInputs.SEND_HEADERS) stream = self._get_or_create_stream( - stream_id, AllowedStreamIDs(self.config.client_side) + stream_id, AllowedStreamIDs(self.config.client_side), ) frames: list[Frame] = [] frames.extend(stream.send_headers( - headers, self.encoder, end_stream + headers, self.encoder, end_stream, )) # We may need to send priority information. @@ -791,24 +816,25 @@ def send_headers(self, if priority_present: if not self.config.client_side: - raise RFC1122Error("Servers SHOULD NOT prioritize streams.") + msg = "Servers SHOULD NOT prioritize streams." + raise RFC1122Error(msg) headers_frame = frames[0] assert isinstance(headers_frame, HeadersFrame) - headers_frame.flags.add('PRIORITY') + headers_frame.flags.add("PRIORITY") frames[0] = _add_frame_priority( headers_frame, priority_weight, priority_depends_on, - priority_exclusive + priority_exclusive, ) self._prepare_for_sending(frames) def send_data(self, stream_id: int, - data: Union[bytes, memoryview], + data: bytes | memoryview, end_stream: bool = False, pad_length: Any = None) -> None: """ @@ -845,34 +871,32 @@ def send_data(self, :returns: Nothing """ self.config.logger.debug( - "Send data on stream ID %d with len %d", stream_id, len(data) + "Send data on stream ID %d with len %d", stream_id, len(data), ) frame_size = len(data) if pad_length is not None: if not isinstance(pad_length, int): - raise TypeError("pad_length must be an int") + msg = "pad_length must be an int" + raise TypeError(msg) if pad_length < 0 or pad_length > 255: - raise ValueError("pad_length must be within range: [0, 255]") + msg = "pad_length must be within range: [0, 255]" + raise ValueError(msg) # Account for padding bytes plus the 1-byte padding length field. frame_size += pad_length + 1 self.config.logger.debug( - "Frame size on stream ID %d is %d", stream_id, frame_size + "Frame size on stream ID %d is %d", stream_id, frame_size, ) if frame_size > self.local_flow_control_window(stream_id): - raise FlowControlError( - "Cannot send %d bytes, flow control window is %d." % - (frame_size, self.local_flow_control_window(stream_id)) - ) - elif frame_size > self.max_outbound_frame_size: - raise FrameTooLargeError( - "Cannot send frame size %d, max frame size is %d" % - (frame_size, self.max_outbound_frame_size) - ) + msg = f"Cannot send {frame_size} bytes, flow control window is {self.local_flow_control_window(stream_id)}" + raise FlowControlError(msg) + if frame_size > self.max_outbound_frame_size: + msg = f"Cannot send frame size {frame_size}, max frame size is {self.max_outbound_frame_size}" + raise FrameTooLargeError(msg) self.state_machine.process_input(ConnectionInputs.SEND_DATA) frames = self.streams[stream_id].send_data( - data, end_stream, pad_length=pad_length + data, end_stream, pad_length=pad_length, ) self._prepare_for_sending(frames) @@ -880,7 +904,7 @@ def send_data(self, self.outbound_flow_control_window -= frame_size self.config.logger.debug( "Outbound flow control window size is %d", - self.outbound_flow_control_window + self.outbound_flow_control_window, ) assert self.outbound_flow_control_window >= 0 @@ -900,7 +924,7 @@ def end_stream(self, stream_id: int) -> None: frames = self.streams[stream_id].end_stream() self._prepare_for_sending(frames) - def increment_flow_control_window(self, increment: int, stream_id: Optional[int] = None) -> None: + def increment_flow_control_window(self, increment: int, stream_id: int | None = None) -> None: """ Increment a flow control window, optionally for a single stream. Allows the remote peer to send more data. @@ -919,22 +943,20 @@ def increment_flow_control_window(self, increment: int, stream_id: Optional[int] :raises: ``ValueError`` """ if not (1 <= increment <= self.MAX_WINDOW_INCREMENT): - raise ValueError( - "Flow control increment must be between 1 and %d" % - self.MAX_WINDOW_INCREMENT - ) + msg = f"Flow control increment must be between 1 and {self.MAX_WINDOW_INCREMENT}" + raise ValueError(msg) self.state_machine.process_input(ConnectionInputs.SEND_WINDOW_UPDATE) if stream_id is not None: stream = self.streams[stream_id] frames = stream.increase_flow_control_window( - increment + increment, ) self.config.logger.debug( "Increase stream ID %d flow control window by %d", - stream_id, increment + stream_id, increment, ) else: self._inbound_flow_control_window_manager.window_opened(increment) @@ -943,7 +965,7 @@ def increment_flow_control_window(self, increment: int, stream_id: Optional[int] frames = [f] self.config.logger.debug( - "Increase connection flow control window by %d", increment + "Increase connection flow control window by %d", increment, ) self._prepare_for_sending(frames) @@ -974,11 +996,12 @@ def push_stream(self, :returns: Nothing """ self.config.logger.debug( - "Send Push Promise frame on stream ID %d", stream_id + "Send Push Promise frame on stream ID %d", stream_id, ) if not self.remote_settings.enable_push: - raise ProtocolError("Remote peer has disabled stream push") + msg = "Remote peer has disabled stream push" + raise ProtocolError(msg) self.state_machine.process_input(ConnectionInputs.SEND_PUSH_PROMISE) stream = self._get_stream_by_id(stream_id) @@ -989,20 +1012,21 @@ def push_stream(self, # this shortcut works because only servers can push and the state # machine will enforce this. if (stream_id % 2) == 0: - raise ProtocolError("Cannot recursively push streams.") + msg = "Cannot recursively push streams." + raise ProtocolError(msg) new_stream = self._begin_new_stream( - promised_stream_id, AllowedStreamIDs.EVEN + promised_stream_id, AllowedStreamIDs.EVEN, ) self.streams[promised_stream_id] = new_stream frames = stream.push_stream_in_band( - promised_stream_id, request_headers, self.encoder + promised_stream_id, request_headers, self.encoder, ) new_frames = new_stream.locally_pushed() self._prepare_for_sending(frames + new_frames) - def ping(self, opaque_data: Union[bytes, str]) -> None: + def ping(self, opaque_data: bytes | str) -> None: """ Send a PING frame. @@ -1013,14 +1037,15 @@ def ping(self, opaque_data: Union[bytes, str]) -> None: self.config.logger.debug("Send Ping frame") if not isinstance(opaque_data, bytes) or len(opaque_data) != 8: - raise ValueError("Invalid value for ping data: %r" % opaque_data) + msg = f"Invalid value for ping data: {opaque_data!r}" + raise ValueError(msg) self.state_machine.process_input(ConnectionInputs.SEND_PING) f = PingFrame(0) f.opaque_data = opaque_data self._prepare_for_sending([f]) - def reset_stream(self, stream_id: int, error_code: Union[ErrorCodes, int] = 0) -> None: + def reset_stream(self, stream_id: int, error_code: ErrorCodes | int = 0) -> None: """ Reset a stream. @@ -1044,10 +1069,9 @@ def reset_stream(self, stream_id: int, error_code: Union[ErrorCodes, int] = 0) - self._prepare_for_sending(frames) def close_connection(self, - error_code: Union[ErrorCodes, int] = 0, - additional_data: Optional[bytes] = None, - last_stream_id: Optional[int] = None) -> None: - + error_code: ErrorCodes | int = 0, + additional_data: bytes | None = None, + last_stream_id: int | None = None) -> None: """ Close a connection, emitting a GOAWAY frame. @@ -1076,11 +1100,11 @@ def close_connection(self, stream_id=0, last_stream_id=last_stream_id, error_code=error_code, - additional_data=(additional_data or b'') + additional_data=(additional_data or b""), ) self._prepare_for_sending([f]) - def update_settings(self, new_settings: dict[Union[SettingCodes, int], int]) -> None: + def update_settings(self, new_settings: dict[SettingCodes | int, int]) -> None: """ Update the local settings. This will prepare and emit the appropriate SETTINGS frame. @@ -1088,7 +1112,7 @@ def update_settings(self, new_settings: dict[Union[SettingCodes, int], int]) -> :param new_settings: A dictionary of {setting: new value} """ self.config.logger.debug( - "Update connection settings to %s", new_settings + "Update connection settings to %s", new_settings, ) self.state_machine.process_input(ConnectionInputs.SEND_SETTINGS) self.local_settings.update(new_settings) @@ -1097,9 +1121,9 @@ def update_settings(self, new_settings: dict[Union[SettingCodes, int], int]) -> self._prepare_for_sending([s]) def advertise_alternative_service(self, - field_value: Union[bytes, str], - origin: Optional[bytes] = None, - stream_id: Optional[int] = None) -> None: + field_value: bytes | str, + origin: bytes | None = None, + stream_id: int | None = None) -> None: """ Notify a client about an available Alternative Service. @@ -1154,13 +1178,15 @@ def advertise_alternative_service(self, :returns: Nothing. """ if not isinstance(field_value, bytes): - raise ValueError("Field must be bytestring.") + msg = "Field must be bytestring." + raise ValueError(msg) # noqa: TRY004 if origin is not None and stream_id is not None: - raise ValueError("Must not provide both origin and stream_id") + msg = "Must not provide both origin and stream_id" + raise ValueError(msg) self.state_machine.process_input( - ConnectionInputs.SEND_ALTERNATIVE_SERVICE + ConnectionInputs.SEND_ALTERNATIVE_SERVICE, ) if origin is not None: @@ -1177,9 +1203,9 @@ def advertise_alternative_service(self, def prioritize(self, stream_id: int, - weight: Optional[int] = None, - depends_on: Optional[int] = None, - exclusive: Optional[bool] = None) -> None: + weight: int | None = None, + depends_on: int | None = None, + exclusive: bool | None = None) -> None: """ Notify a server about the priority of a stream. @@ -1243,10 +1269,11 @@ def prioritize(self, :type exclusive: ``bool`` """ if not self.config.client_side: - raise RFC1122Error("Servers SHOULD NOT prioritize streams.") + msg = "Servers SHOULD NOT prioritize streams." + raise RFC1122Error(msg) self.state_machine.process_input( - ConnectionInputs.SEND_PRIORITY + ConnectionInputs.SEND_PRIORITY, ) frame = PriorityFrame(stream_id) @@ -1278,7 +1305,7 @@ def local_flow_control_window(self, stream_id: int) -> int: stream = self._get_stream_by_id(stream_id) return min( self.outbound_flow_control_window, - stream.outbound_flow_control_window + stream.outbound_flow_control_window, ) def remote_flow_control_window(self, stream_id: int) -> int: @@ -1305,7 +1332,7 @@ def remote_flow_control_window(self, stream_id: int) -> int: stream = self._get_stream_by_id(stream_id) return min( self.inbound_flow_control_window, - stream.inbound_flow_control_window + stream.inbound_flow_control_window, ) def acknowledge_received_data(self, acknowledged_size: int, stream_id: int) -> None: @@ -1328,15 +1355,14 @@ def acknowledge_received_data(self, acknowledged_size: int, stream_id: int) -> N """ self.config.logger.debug( "Ack received data on stream ID %d with size %d", - stream_id, acknowledged_size + stream_id, acknowledged_size, ) if stream_id <= 0: - raise ValueError( - "Stream ID %d is not valid for acknowledge_received_data" % - stream_id - ) + msg = f"Stream ID {stream_id} is not valid for acknowledge_received_data" + raise ValueError(msg) if acknowledged_size < 0: - raise ValueError("Cannot acknowledge negative data") + msg = "Cannot acknowledge negative data" + raise ValueError(msg) frames: list[Frame] = [] @@ -1357,12 +1383,12 @@ def acknowledge_received_data(self, acknowledged_size: int, stream_id: int) -> N # No point incrementing the windows of closed streams. if stream.open: frames.extend( - stream.acknowledge_received_data(acknowledged_size) + stream.acknowledge_received_data(acknowledged_size), ) self._prepare_for_sending(frames) - def data_to_send(self, amount: Optional[int] = None) -> bytes: + def data_to_send(self, amount: int | None = None) -> bytes: """ Returns some data for sending out of the internal data buffer. @@ -1381,10 +1407,9 @@ def data_to_send(self, amount: Optional[int] = None) -> bytes: data = bytes(self._data_to_send) self._data_to_send = bytearray() return data - else: - data = bytes(self._data_to_send[:amount]) - self._data_to_send = self._data_to_send[amount:] - return data + data = bytes(self._data_to_send[:amount]) + self._data_to_send = self._data_to_send[amount:] + return data def clear_outbound_data_buffer(self) -> None: """ @@ -1432,10 +1457,10 @@ def _acknowledge_settings(self) -> list[Frame]: stream.max_outbound_frame_size = setting.new_value f = SettingsFrame(0) - f.flags.add('ACK') + f.flags.add("ACK") return [f] - def _flow_control_change_from_settings(self, old_value: Optional[int], new_value: int) -> None: + def _flow_control_change_from_settings(self, old_value: int | None, new_value: int) -> None: """ Update flow control windows in response to a change in the value of SETTINGS_INITIAL_WINDOW_SIZE. @@ -1450,10 +1475,10 @@ def _flow_control_change_from_settings(self, old_value: Optional[int], new_value for stream in self.streams.values(): stream.outbound_flow_control_window = guard_increment_window( stream.outbound_flow_control_window, - delta + delta, ) - def _inbound_flow_control_change_from_settings(self, old_value: Optional[int], new_value: int) -> None: + def _inbound_flow_control_change_from_settings(self, old_value: int | None, new_value: int) -> None: """ Update remote flow control windows in response to a change in the value of SETTINGS_INITIAL_WINDOW_SIZE. @@ -1476,7 +1501,7 @@ def receive_data(self, data: bytes) -> list[Event]: this data. """ self.config.logger.trace( - "Process received data on connection. Received data: %r", data + "Process received data on connection. Received data: %r", data, ) events: list[Event] = [] @@ -1486,9 +1511,10 @@ def receive_data(self, data: bytes) -> list[Event]: try: for frame in self.incoming_buffer: events.extend(self._receive_frame(frame)) - except InvalidPaddingError: + except InvalidPaddingError as e: self._terminate_connection(ErrorCodes.PROTOCOL_ERROR) - raise ProtocolError("Received frame with invalid padding.") + msg = "Received frame with invalid padding." + raise ProtocolError(msg) from e except ProtocolError as e: # For whatever reason, receiving the frame caused a protocol error. # We should prepare to emit a GoAway frame before throwing the @@ -1538,7 +1564,7 @@ def _receive_frame(self, frame: Frame) -> list[Event]: events = [] elif self._stream_is_closed_by_end(e.stream_id): # Closed by END_STREAM is a connection error. - raise StreamClosedError(e.stream_id) + raise StreamClosedError(e.stream_id) from e else: # Closed implicitly, also a connection error, but of type # PROTOCOL_ERROR. @@ -1568,10 +1594,8 @@ def _receive_headers_frame(self, frame: HeadersFrame) -> tuple[list[Frame], list if frame.stream_id not in self.streams: max_open_streams = self.local_settings.max_concurrent_streams if (self.open_inbound_streams + 1) > max_open_streams: - raise TooManyStreamsError( - "Max outbound streams is %d, %d open" % - (max_open_streams, self.open_outbound_streams) - ) + msg = f"Max outbound streams is {max_open_streams}, {self.open_outbound_streams} open" + raise TooManyStreamsError(msg) # Let's decode the headers. We handle headers as bytes internally up # until we hang them off the event, at which point we may optionally @@ -1579,18 +1603,18 @@ def _receive_headers_frame(self, frame: HeadersFrame) -> tuple[list[Frame], list headers = _decode_headers(self.decoder, frame.data) events = self.state_machine.process_input( - ConnectionInputs.RECV_HEADERS + ConnectionInputs.RECV_HEADERS, ) stream = self._get_or_create_stream( - frame.stream_id, AllowedStreamIDs(not self.config.client_side) + frame.stream_id, AllowedStreamIDs(not self.config.client_side), ) frames, stream_events = stream.receive_headers( headers, - 'END_STREAM' in frame.flags, - self.config.header_encoding + "END_STREAM" in frame.flags, + self.config.header_encoding, ) - if 'PRIORITY' in frame.flags: + if "PRIORITY" in frame.flags: p_frames, p_events = self._receive_priority_frame(frame) expected_frame_types = (RequestReceived, ResponseReceived, TrailersReceived, InformationalResponseReceived) assert isinstance(stream_events[0], expected_frame_types) @@ -1606,17 +1630,18 @@ def _receive_push_promise_frame(self, frame: PushPromiseFrame) -> tuple[list[Fra Receive a push-promise frame on the connection. """ if not self.local_settings.enable_push: - raise ProtocolError("Received pushed stream") + msg = "Received pushed stream" + raise ProtocolError(msg) pushed_headers = _decode_headers(self.decoder, frame.data) events = self.state_machine.process_input( - ConnectionInputs.RECV_PUSH_PROMISE + ConnectionInputs.RECV_PUSH_PROMISE, ) try: stream = self._get_stream_by_id(frame.stream_id) - except NoSuchStreamError: + except NoSuchStreamError as e: # We need to check if the parent stream was reset by us. If it was # then we presume that the PUSH_PROMISE was in flight when we reset # the parent stream. Rather than accept the new stream, just reset @@ -1632,7 +1657,8 @@ def _receive_push_promise_frame(self, frame: PushPromiseFrame) -> tuple[list[Fra f.error_code = ErrorCodes.REFUSED_STREAM return [f], events - raise ProtocolError("Attempted to push on closed stream.") + msg = "Attempted to push on closed stream." + raise ProtocolError(msg) from e # We need to prevent peers pushing streams in response to streams that # they themselves have already pushed: see #163 and RFC 7540 § 6.6. The @@ -1640,7 +1666,8 @@ def _receive_push_promise_frame(self, frame: PushPromiseFrame) -> tuple[list[Fra # this shortcut works because only servers can push and the state # machine will enforce this. if (frame.stream_id % 2) == 0: - raise ProtocolError("Cannot recursively push streams.") + msg = "Cannot recursively push streams." + raise ProtocolError(msg) try: frames, stream_events = stream.receive_push_promise_in_band( @@ -1657,7 +1684,7 @@ def _receive_push_promise_frame(self, frame: PushPromiseFrame) -> tuple[list[Fra return [f], events new_stream = self._begin_new_stream( - frame.promised_stream_id, AllowedStreamIDs.EVEN + frame.promised_stream_id, AllowedStreamIDs.EVEN, ) self.streams[frame.promised_stream_id] = new_stream new_stream.remotely_pushed(pushed_headers) @@ -1676,7 +1703,7 @@ def _handle_data_on_closed_stream(self, frames: list[Frame] = [] conn_manager = self._inbound_flow_control_window_manager conn_increment = conn_manager.process_bytes( - frame.flow_controlled_length + frame.flow_controlled_length, ) if conn_increment: @@ -1686,15 +1713,15 @@ def _handle_data_on_closed_stream(self, self.config.logger.debug( "Received DATA frame on closed stream %d - " "auto-emitted a WINDOW_UPDATE by %d", - frame.stream_id, conn_increment + frame.stream_id, conn_increment, ) rst_stream_frame = RstStreamFrame(exc.stream_id) rst_stream_frame.error_code = exc.error_code frames.append(rst_stream_frame) self.config.logger.debug( - "Stream %d already CLOSED or cleaned up - " - "auto-emitted a RST_FRAME" % frame.stream_id + "Stream %s already CLOSED or cleaned up - auto-emitted a RST_FRAME", + frame.stream_id, ) return frames, events + exc._events @@ -1705,18 +1732,18 @@ def _receive_data_frame(self, frame: DataFrame) -> tuple[list[Frame], list[Event flow_controlled_length = frame.flow_controlled_length events = self.state_machine.process_input( - ConnectionInputs.RECV_DATA + ConnectionInputs.RECV_DATA, ) self._inbound_flow_control_window_manager.window_consumed( - flow_controlled_length + flow_controlled_length, ) try: stream = self._get_stream_by_id(frame.stream_id) frames, stream_events = stream.receive_data( frame.data, - 'END_STREAM' in frame.flags, - flow_controlled_length + "END_STREAM" in frame.flags, + flow_controlled_length, ) except StreamClosedError as e: # This stream is either marked as CLOSED or already gone from our @@ -1730,11 +1757,11 @@ def _receive_settings_frame(self, frame: SettingsFrame) -> tuple[list[Frame], li Receive a SETTINGS frame on the connection. """ events = self.state_machine.process_input( - ConnectionInputs.RECV_SETTINGS + ConnectionInputs.RECV_SETTINGS, ) # This is an ack of the local settings. - if 'ACK' in frame.flags: + if "ACK" in frame.flags: changed_settings = self._local_settings_acked() ack_event = SettingsAcknowledged() ack_event.changed_settings = changed_settings @@ -1745,8 +1772,8 @@ def _receive_settings_frame(self, frame: SettingsFrame) -> tuple[list[Frame], li self.remote_settings.update(frame.settings) events.append( RemoteSettingsChanged.from_settings( - self.remote_settings, frame.settings - ) + self.remote_settings, frame.settings, + ), ) frames = self._acknowledge_settings() @@ -1760,14 +1787,14 @@ def _receive_window_update_frame(self, frame: WindowUpdateFrame) -> tuple[list[F # If we reach in here, we can assume a valid value. events = self.state_machine.process_input( - ConnectionInputs.RECV_WINDOW_UPDATE + ConnectionInputs.RECV_WINDOW_UPDATE, ) if frame.stream_id: try: stream = self._get_stream_by_id(frame.stream_id) frames, stream_events = stream.receive_window_update( - frame.window_increment + frame.window_increment, ) except StreamClosedError: return [], events @@ -1775,7 +1802,7 @@ def _receive_window_update_frame(self, frame: WindowUpdateFrame) -> tuple[list[F # Increment our local flow control window. self.outbound_flow_control_window = guard_increment_window( self.outbound_flow_control_window, - frame.window_increment + frame.window_increment, ) # FIXME: Should we split this into one event per active stream? @@ -1792,19 +1819,19 @@ def _receive_ping_frame(self, frame: PingFrame) -> tuple[list[Frame], list[Event Receive a PING frame on the connection. """ events = self.state_machine.process_input( - ConnectionInputs.RECV_PING + ConnectionInputs.RECV_PING, ) frames: list[Frame] = [] - evt: Union[PingReceived, PingAckReceived] - if 'ACK' in frame.flags: + evt: PingReceived | PingAckReceived + if "ACK" in frame.flags: evt = PingAckReceived() else: evt = PingReceived() # automatically ACK the PING with the same 'opaque data' f = PingFrame(0) - f.flags.add('ACK') + f.flags.add("ACK") f.opaque_data = frame.opaque_data frames.append(f) @@ -1818,7 +1845,7 @@ def _receive_rst_stream_frame(self, frame: RstStreamFrame) -> tuple[list[Frame], Receive a RST_STREAM frame on the connection. """ events = self.state_machine.process_input( - ConnectionInputs.RECV_RST_STREAM + ConnectionInputs.RECV_RST_STREAM, ) try: stream = self._get_stream_by_id(frame.stream_id) @@ -1831,12 +1858,12 @@ def _receive_rst_stream_frame(self, frame: RstStreamFrame) -> tuple[list[Frame], return stream_frames, events + stream_events - def _receive_priority_frame(self, frame: Union[HeadersFrame, PriorityFrame]) -> tuple[list[Frame], list[Event]]: + def _receive_priority_frame(self, frame: HeadersFrame | PriorityFrame) -> tuple[list[Frame], list[Event]]: """ Receive a PRIORITY frame on the connection. """ events = self.state_machine.process_input( - ConnectionInputs.RECV_PRIORITY + ConnectionInputs.RECV_PRIORITY, ) event = PriorityUpdated() @@ -1850,9 +1877,8 @@ def _receive_priority_frame(self, frame: Union[HeadersFrame, PriorityFrame]) -> # A stream may not depend on itself. if event.depends_on == frame.stream_id: - raise ProtocolError( - "Stream %d may not depend on itself" % frame.stream_id - ) + msg = f"Stream {frame.stream_id} may not depend on itself" + raise ProtocolError(msg) events.append(event) return [], events @@ -1862,7 +1888,7 @@ def _receive_goaway_frame(self, frame: GoAwayFrame) -> tuple[list[Frame], list[E Receive a GOAWAY frame on the connection. """ events = self.state_machine.process_input( - ConnectionInputs.RECV_GOAWAY + ConnectionInputs.RECV_GOAWAY, ) # Clear the outbound data buffer: we cannot send further data now. @@ -1887,7 +1913,8 @@ def _receive_naked_continuation(self, frame: ContinuationFrame) -> None: """ stream = self._get_stream_by_id(frame.stream_id) stream.receive_continuation() - assert False, "Should not be reachable" + msg = "Should not be reachable" # pragma: no cover + raise AssertionError(msg) # pragma: no cover def _receive_alt_svc_frame(self, frame: AltSvcFrame) -> tuple[list[Frame], list[Event]]: """ @@ -1899,7 +1926,7 @@ def _receive_alt_svc_frame(self, frame: AltSvcFrame) -> tuple[list[Frame], list[ 0, and its semantics are different in each case. """ events = self.state_machine.process_input( - ConnectionInputs.RECV_ALTERNATIVE_SERVICE + ConnectionInputs.RECV_ALTERNATIVE_SERVICE, ) frames = [] @@ -1945,13 +1972,13 @@ def _receive_unknown_frame(self, frame: ExtensionFrame) -> tuple[list[Frame], li """ # All we do here is log. self.config.logger.debug( - "Received unknown extension frame (ID %d)", frame.stream_id + "Received unknown extension frame (ID %d)", frame.stream_id, ) event = UnknownFrameReceived() event.frame = frame return [], [event] - def _local_settings_acked(self) -> dict[Union[SettingCodes, int], ChangedSetting]: + def _local_settings_acked(self) -> dict[SettingCodes | int, ChangedSetting]: """ Handle the local settings being ACKed, update internal state. """ @@ -1987,7 +2014,7 @@ def _stream_id_is_outbound(self, stream_id: int) -> bool: """ return (stream_id % 2 == int(self.config.client_side)) - def _stream_closed_by(self, stream_id: int) -> Optional[StreamClosedBy]: + def _stream_closed_by(self, stream_id: int) -> StreamClosedBy | None: """ Returns how the stream was closed. @@ -2008,7 +2035,7 @@ def _stream_is_closed_by_reset(self, stream_id: int) -> bool: RST_STREAM frame. Returns ``False`` otherwise. """ return self._stream_closed_by(stream_id) in ( - StreamClosedBy.RECV_RST_STREAM, StreamClosedBy.SEND_RST_STREAM + StreamClosedBy.RECV_RST_STREAM, StreamClosedBy.SEND_RST_STREAM, ) def _stream_is_closed_by_end(self, stream_id: int) -> bool: @@ -2018,14 +2045,14 @@ def _stream_is_closed_by_end(self, stream_id: int) -> bool: otherwise. """ return self._stream_closed_by(stream_id) in ( - StreamClosedBy.RECV_END_STREAM, StreamClosedBy.SEND_END_STREAM + StreamClosedBy.RECV_END_STREAM, StreamClosedBy.SEND_END_STREAM, ) -def _add_frame_priority(frame: Union[PriorityFrame, HeadersFrame], - weight: Optional[int] = None, - depends_on: Optional[int] = None, - exclusive: Optional[bool] = None) -> Union[PriorityFrame, HeadersFrame]: +def _add_frame_priority(frame: PriorityFrame | HeadersFrame, + weight: int | None = None, + depends_on: int | None = None, + exclusive: bool | None = None) -> PriorityFrame | HeadersFrame: """ Adds priority data to a given frame. Does not change any flags set on that frame: if the caller is adding priority information to a HEADERS frame they @@ -2037,20 +2064,17 @@ def _add_frame_priority(frame: Union[PriorityFrame, HeadersFrame], """ # A stream may not depend on itself. if depends_on == frame.stream_id: - raise ProtocolError( - "Stream %d may not depend on itself" % frame.stream_id - ) + msg = f"Stream {frame.stream_id} may not depend on itself" + raise ProtocolError(msg) # Weight must be between 1 and 256. if weight is not None: if weight > 256 or weight < 1: - raise ProtocolError( - "Weight must be between 1 and 256, not %d" % weight - ) - else: - # Weight is an integer between 1 and 256, but the byte only allows - # 0 to 255: subtract one. - weight -= 1 + msg = f"Weight must be between 1 and 256, not {weight}" + raise ProtocolError(msg) + # Weight is an integer between 1 and 256, but the byte only allows + # 0 to 255: subtract one. + weight -= 1 # Set defaults for anything not provided. weight = weight if weight is not None else 15 @@ -2078,9 +2102,11 @@ def _decode_headers(decoder: Decoder, encoded_header_block: bytes) -> Iterable[H # This is a symptom of a HPACK bomb attack: the user has # disregarded our requirements on how large a header block we'll # accept. - raise DenialOfServiceError("Oversized header block: %s" % e) + msg = f"Oversized header block: {e}" + raise DenialOfServiceError(msg) from e except (HPACKError, IndexError, TypeError, UnicodeDecodeError) as e: # We should only need HPACKError here, but versions of HPACK older # than 2.1.0 throw all three others as well. For maximum # compatibility, catch all of them. - raise ProtocolError("Error decoding header block: %s" % e) + msg = f"Error decoding header block: {e}" + raise ProtocolError(msg) from e diff --git a/src/h2/errors.py b/src/h2/errors.py index ed0d754a..24ebe00f 100644 --- a/src/h2/errors.py +++ b/src/h2/errors.py @@ -7,9 +7,9 @@ The current registry is available at: https://tools.ietf.org/html/rfc7540#section-11.4 """ -import enum +from __future__ import annotations -from typing import Union +import enum class ErrorCodes(enum.IntEnum): @@ -18,6 +18,7 @@ class ErrorCodes(enum.IntEnum): .. versionadded:: 2.5.0 """ + #: Graceful shutdown. NO_ERROR = 0x0 @@ -61,7 +62,7 @@ class ErrorCodes(enum.IntEnum): HTTP_1_1_REQUIRED = 0xd -def _error_code_from_int(code: int) -> Union[ErrorCodes, int]: +def _error_code_from_int(code: int) -> ErrorCodes | int: """ Given an integer error code, returns either one of :class:`ErrorCodes ` or, if not present in the known set of codes, @@ -73,4 +74,4 @@ def _error_code_from_int(code: int) -> Union[ErrorCodes, int]: return code -__all__ = ['ErrorCodes'] +__all__ = ["ErrorCodes"] diff --git a/src/h2/events.py b/src/h2/events.py index ee0c6911..c7804d56 100644 --- a/src/h2/events.py +++ b/src/h2/events.py @@ -8,22 +8,25 @@ track of events triggered by receiving data. Each time data is provided to the H2 state machine it processes the data and returns a list of Event objects. """ +from __future__ import annotations + import binascii +from typing import TYPE_CHECKING -from hpack import HeaderTuple -from hyperframe.frame import Frame +from .settings import ChangedSetting, SettingCodes, Settings, _setting_code_from_int -from .errors import ErrorCodes -from .settings import SettingCodes, Settings, ChangedSetting, _setting_code_from_int +if TYPE_CHECKING: # pragma: no cover + from hpack import HeaderTuple + from hyperframe.frame import Frame -from typing import Union, Optional + from .errors import ErrorCodes class Event: """ Base class for h2 events. """ - pass + class RequestReceived(Event): @@ -39,31 +42,30 @@ class RequestReceived(Event): .. versionchanged:: 2.4.0 Added ``stream_ended`` and ``priority_updated`` properties. """ + def __init__(self) -> None: #: The Stream ID for the stream this request was made on. - self.stream_id: Optional[int] = None + self.stream_id: int | None = None #: The request headers. - self.headers: Optional[list[HeaderTuple]] = None + self.headers: list[HeaderTuple] | None = None #: If this request also ended the stream, the associated #: :class:`StreamEnded ` event will be available #: here. #: #: .. versionadded:: 2.4.0 - self.stream_ended: Optional[StreamEnded] = None + self.stream_ended: StreamEnded | None = None #: If this request also had associated priority information, the #: associated :class:`PriorityUpdated ` #: event will be available here. #: #: .. versionadded:: 2.4.0 - self.priority_updated: Optional[PriorityUpdated] = None + self.priority_updated: PriorityUpdated | None = None def __repr__(self) -> str: - return "" % ( - self.stream_id, self.headers - ) + return f"" class ResponseReceived(Event): @@ -76,34 +78,33 @@ class ResponseReceived(Event): Changed the type of ``headers`` to :class:`HeaderTuple `. This has no effect on current users. - .. versionchanged:: 2.4.0 + .. versionchanged:: 2.4.0 Added ``stream_ended`` and ``priority_updated`` properties. """ + def __init__(self) -> None: #: The Stream ID for the stream this response was made on. - self.stream_id: Optional[int] = None + self.stream_id: int | None = None #: The response headers. - self.headers: Optional[list[HeaderTuple]] = None + self.headers: list[HeaderTuple] | None = None #: If this response also ended the stream, the associated #: :class:`StreamEnded ` event will be available #: here. #: #: .. versionadded:: 2.4.0 - self.stream_ended: Optional[StreamEnded] = None + self.stream_ended: StreamEnded | None = None #: If this response also had associated priority information, the #: associated :class:`PriorityUpdated ` #: event will be available here. #: #: .. versionadded:: 2.4.0 - self.priority_updated: Optional[PriorityUpdated] = None + self.priority_updated: PriorityUpdated | None = None def __repr__(self) -> str: - return "" % ( - self.stream_id, self.headers - ) + return f"" class TrailersReceived(Event): @@ -122,30 +123,29 @@ class TrailersReceived(Event): .. versionchanged:: 2.4.0 Added ``stream_ended`` and ``priority_updated`` properties. """ + def __init__(self) -> None: #: The Stream ID for the stream on which these trailers were received. - self.stream_id: Optional[int] = None + self.stream_id: int | None = None #: The trailers themselves. - self.headers: Optional[list[HeaderTuple]] = None + self.headers: list[HeaderTuple] | None = None #: Trailers always end streams. This property has the associated #: :class:`StreamEnded ` in it. #: #: .. versionadded:: 2.4.0 - self.stream_ended: Optional[StreamEnded] = None + self.stream_ended: StreamEnded | None = None #: If the trailers also set associated priority information, the #: associated :class:`PriorityUpdated ` #: event will be available here. #: #: .. versionadded:: 2.4.0 - self.priority_updated: Optional[PriorityUpdated] = None + self.priority_updated: PriorityUpdated | None = None def __repr__(self) -> str: - return "" % ( - self.stream_id, self.headers - ) + return f"" class _HeadersSent(Event): @@ -155,7 +155,7 @@ class _HeadersSent(Event): This is an internal event, used to determine validation steps on outgoing header blocks. """ - pass + class _ResponseSent(_HeadersSent): @@ -166,7 +166,7 @@ class _ResponseSent(_HeadersSent): This is an internal event, used to determine validation steps on outgoing header blocks. """ - pass + class _RequestSent(_HeadersSent): @@ -177,7 +177,7 @@ class _RequestSent(_HeadersSent): This is an internal event, used to determine validation steps on outgoing header blocks. """ - pass + class _TrailersSent(_HeadersSent): @@ -190,7 +190,7 @@ class _TrailersSent(_HeadersSent): This is an internal event, used to determine validation steps on outgoing header blocks. """ - pass + class _PushedRequestSent(_HeadersSent): @@ -201,7 +201,7 @@ class _PushedRequestSent(_HeadersSent): This is an internal event, used to determine validation steps on outgoing header blocks. """ - pass + class InformationalResponseReceived(Event): @@ -226,25 +226,24 @@ class InformationalResponseReceived(Event): .. versionchanged:: 2.4.0 Added ``priority_updated`` property. """ + def __init__(self) -> None: #: The Stream ID for the stream this informational response was made #: on. - self.stream_id: Optional[int] = None + self.stream_id: int | None = None #: The headers for this informational response. - self.headers: Optional[list[HeaderTuple]] = None + self.headers: list[HeaderTuple] | None = None #: If this response also had associated priority information, the #: associated :class:`PriorityUpdated ` #: event will be available here. #: #: .. versionadded:: 2.4.0 - self.priority_updated: Optional[PriorityUpdated] = None + self.priority_updated: PriorityUpdated | None = None def __repr__(self) -> str: - return "" % ( - self.stream_id, self.headers - ) + return f"" class DataReceived(Event): @@ -256,31 +255,32 @@ class DataReceived(Event): .. versionchanged:: 2.4.0 Added ``stream_ended`` property. """ + def __init__(self) -> None: #: The Stream ID for the stream this data was received on. - self.stream_id: Optional[int] = None + self.stream_id: int | None = None #: The data itself. - self.data: Optional[bytes] = None + self.data: bytes | None = None #: The amount of data received that counts against the flow control #: window. Note that padding counts against the flow control window, so #: when adjusting flow control you should always use this field rather #: than ``len(data)``. - self.flow_controlled_length: Optional[int] = None + self.flow_controlled_length: int | None = None #: If this data chunk also completed the stream, the associated #: :class:`StreamEnded ` event will be available #: here. #: #: .. versionadded:: 2.4.0 - self.stream_ended: Optional[StreamEnded] = None + self.stream_ended: StreamEnded | None = None def __repr__(self) -> str: return ( - "" % ( + "".format( self.stream_id, self.flow_controlled_length, _bytes_representation(self.data[:20]) if self.data else "", @@ -296,18 +296,17 @@ class WindowUpdated(Event): the stream to which it applies (set to zero if the window update applies to the connection), and the delta in the window size. """ + def __init__(self) -> None: #: The Stream ID of the stream whose flow control window was changed. #: May be ``0`` if the connection window was changed. - self.stream_id: Optional[int] = None + self.stream_id: int | None = None #: The window delta. - self.delta: Optional[int] = None + self.delta: int | None = None def __repr__(self) -> str: - return "" % ( - self.stream_id, self.delta - ) + return f"" class RemoteSettingsChanged(Event): @@ -330,6 +329,7 @@ class RemoteSettingsChanged(Event): This is no longer the case: h2 now automatically acknowledges them. """ + def __init__(self) -> None: #: A dictionary of setting byte to #: :class:`ChangedSetting `, representing @@ -338,8 +338,8 @@ def __init__(self) -> None: @classmethod def from_settings(cls, - old_settings: Union[Settings, dict[int, int]], - new_settings: dict[int, int]) -> "RemoteSettingsChanged": + old_settings: Settings | dict[int, int], + new_settings: dict[int, int]) -> RemoteSettingsChanged: """ Build a RemoteSettingsChanged event from a set of changed settings. @@ -350,15 +350,15 @@ def from_settings(cls, """ e = cls() for setting, new_value in new_settings.items(): - setting = _setting_code_from_int(setting) - original_value = old_settings.get(setting) - change = ChangedSetting(setting, original_value, new_value) - e.changed_settings[setting] = change + s = _setting_code_from_int(setting) + original_value = old_settings.get(s) + change = ChangedSetting(s, original_value, new_value) + e.changed_settings[s] = change return e def __repr__(self) -> str: - return "" % ( + return "".format( ", ".join(repr(cs) for cs in self.changed_settings.values()), ) @@ -371,14 +371,13 @@ class PingReceived(Event): .. versionadded:: 3.1.0 """ + def __init__(self) -> None: #: The data included on the ping. - self.ping_data: Optional[bytes] = None + self.ping_data: bytes | None = None def __repr__(self) -> str: - return "" % ( - _bytes_representation(self.ping_data), - ) + return f"" class PingAckReceived(Event): @@ -392,14 +391,13 @@ class PingAckReceived(Event): .. versionchanged:: 4.0.0 Removed deprecated but equivalent ``PingAcknowledged``. """ + def __init__(self) -> None: #: The data included on the ping. - self.ping_data: Optional[bytes] = None + self.ping_data: bytes | None = None def __repr__(self) -> str: - return "" % ( - _bytes_representation(self.ping_data), - ) + return f"" class StreamEnded(Event): @@ -408,12 +406,13 @@ class StreamEnded(Event): party. The stream may not be fully closed if it has not been closed locally, but no further data or headers should be expected on that stream. """ + def __init__(self) -> None: #: The Stream ID of the stream that was closed. - self.stream_id: Optional[int] = None + self.stream_id: int | None = None def __repr__(self) -> str: - return "" % self.stream_id + return f"" class StreamReset(Event): @@ -426,21 +425,20 @@ class StreamReset(Event): .. versionchanged:: 2.0.0 This event is now fired when h2 automatically resets a stream. """ + def __init__(self) -> None: #: The Stream ID of the stream that was reset. - self.stream_id: Optional[int] = None + self.stream_id: int | None = None #: The error code given. Either one of :class:`ErrorCodes #: ` or ``int`` - self.error_code: Optional[ErrorCodes] = None + self.error_code: ErrorCodes | None = None #: Whether the remote peer sent a RST_STREAM or we did. self.remote_reset = True def __repr__(self) -> str: - return "" % ( - self.stream_id, self.error_code, self.remote_reset - ) + return f"" class PushedStreamReceived(Event): @@ -449,24 +447,21 @@ class PushedStreamReceived(Event): received from a remote peer. The event carries on it the new stream ID, the ID of the parent stream, and the request headers pushed by the remote peer. """ + def __init__(self) -> None: #: The Stream ID of the stream created by the push. - self.pushed_stream_id: Optional[int] = None + self.pushed_stream_id: int | None = None #: The Stream ID of the stream that the push is related to. - self.parent_stream_id: Optional[int] = None + self.parent_stream_id: int | None = None #: The request headers, sent by the remote party in the push. - self.headers: Optional[list[HeaderTuple]] = None + self.headers: list[HeaderTuple] | None = None def __repr__(self) -> str: return ( - "" % ( - self.pushed_stream_id, - self.parent_stream_id, - self.headers, - ) + f"" ) @@ -477,16 +472,16 @@ class SettingsAcknowledged(Event): acknowedged, in the same format as :class:`h2.events.RemoteSettingsChanged`. """ + def __init__(self) -> None: #: A dictionary of setting byte to #: :class:`ChangedSetting `, representing #: the changed settings. - self.changed_settings: dict[Union[SettingCodes, int], ChangedSetting] = {} + self.changed_settings: dict[SettingCodes | int, ChangedSetting] = {} def __repr__(self) -> str: - return "" % ( - ", ".join(repr(cs) for cs in self.changed_settings.values()), - ) + s = ", ".join(repr(cs) for cs in self.changed_settings.values()) + return f"" class PriorityUpdated(Event): @@ -499,31 +494,27 @@ class PriorityUpdated(Event): .. versionadded:: 2.0.0 """ + def __init__(self) -> None: #: The ID of the stream whose priority information is being updated. - self.stream_id: Optional[int] = None + self.stream_id: int | None = None #: The new stream weight. May be the same as the original stream #: weight. An integer between 1 and 256. - self.weight: Optional[int] = None + self.weight: int | None = None #: The stream ID this stream now depends on. May be ``0``. - self.depends_on: Optional[int] = None + self.depends_on: int | None = None #: Whether the stream *exclusively* depends on the parent stream. If it #: does, this stream should inherit the current children of its new #: parent. - self.exclusive: Optional[bool] = None + self.exclusive: bool | None = None def __repr__(self) -> str: return ( - "" % ( - self.stream_id, - self.weight, - self.depends_on, - self.exclusive - ) + f"" ) @@ -533,29 +524,30 @@ class ConnectionTerminated(Event): the remote peer using a GOAWAY frame. Once received, no further action may be taken on the connection: a new connection must be established. """ + def __init__(self) -> None: #: The error code cited when tearing down the connection. Should be #: one of :class:`ErrorCodes `, but may not be if #: unknown HTTP/2 extensions are being used. - self.error_code: Optional[Union[ErrorCodes, int]] = None + self.error_code: ErrorCodes | int | None = None #: The stream ID of the last stream the remote peer saw. This can #: provide an indication of what data, if any, never reached the remote #: peer and so can safely be resent. - self.last_stream_id: Optional[int] = None + self.last_stream_id: int | None = None #: Additional debug data that can be appended to GOAWAY frame. - self.additional_data: Optional[bytes] = None + self.additional_data: bytes | None = None def __repr__(self) -> str: return ( - "" % ( + "".format( self.error_code, self.last_stream_id, _bytes_representation( self.additional_data[:20] - if self.additional_data else None) + if self.additional_data else None), ) ) @@ -580,26 +572,27 @@ class AlternativeServiceAvailable(Event): .. versionadded:: 2.3.0 """ + def __init__(self) -> None: #: The origin to which the alternative service field value applies. #: This field is either supplied by the server directly, or inferred by #: h2 from the ``:authority`` pseudo-header field that was sent #: by the user when initiating the stream on which the frame was #: received. - self.origin: Optional[bytes] = None + self.origin: bytes | None = None #: The ALTSVC field value. This contains information about the HTTP #: alternative service being advertised by the server. h2 does #: not parse this field: it is left exactly as sent by the server. The #: structure of the data in this field is given by `RFC 7838 Section 3 #: `_. - self.field_value: Optional[bytes] = None + self.field_value: bytes | None = None def __repr__(self) -> str: return ( - "" % ( - (self.origin or b"").decode('utf-8', 'ignore'), - (self.field_value or b"").decode('utf-8', 'ignore'), + "".format( + (self.origin or b"").decode("utf-8", "ignore"), + (self.field_value or b"").decode("utf-8", "ignore"), ) ) @@ -618,15 +611,16 @@ class UnknownFrameReceived(Event): .. versionadded:: 2.7.0 """ + def __init__(self) -> None: #: The hyperframe Frame object that encapsulates the received frame. - self.frame: Optional[Frame] = None + self.frame: Frame | None = None def __repr__(self) -> str: return "" -def _bytes_representation(data: Optional[bytes]) -> Optional[str]: +def _bytes_representation(data: bytes | None) -> str | None: """ Converts a bytestring into something that is safe to print on all Python platforms. @@ -638,4 +632,4 @@ def _bytes_representation(data: Optional[bytes]) -> Optional[str]: if data is None: return None - return binascii.hexlify(data).decode('ascii') + return binascii.hexlify(data).decode("ascii") diff --git a/src/h2/exceptions.py b/src/h2/exceptions.py index e43e5715..e4776795 100644 --- a/src/h2/exceptions.py +++ b/src/h2/exceptions.py @@ -4,6 +4,8 @@ Exceptions for the HTTP/2 module. """ +from __future__ import annotations + from .errors import ErrorCodes @@ -17,6 +19,7 @@ class ProtocolError(H2Error): """ An action was attempted in violation of the HTTP/2 protocol. """ + #: The error code corresponds to this kind of Protocol Error. error_code = ErrorCodes.PROTOCOL_ERROR @@ -25,6 +28,7 @@ class FrameTooLargeError(ProtocolError): """ The frame that we tried to send or that we received was too large. """ + #: The error code corresponds to this kind of Protocol Error. error_code = ErrorCodes.FRAME_SIZE_ERROR @@ -35,6 +39,7 @@ class FrameDataMissingError(ProtocolError): .. versionadded:: 2.0.0 """ + #: The error code corresponds to this kind of Protocol Error. error_code = ErrorCodes.FRAME_SIZE_ERROR @@ -44,13 +49,14 @@ class TooManyStreamsError(ProtocolError): An attempt was made to open a stream that would lead to too many concurrent streams. """ - pass + class FlowControlError(ProtocolError): """ An attempted action violates flow control constraints. """ + #: The error code corresponds to this kind of Protocol Error. error_code = ErrorCodes.FLOW_CONTROL_ERROR @@ -60,6 +66,7 @@ class StreamIDTooLowError(ProtocolError): An attempt was made to open a stream that had an ID that is lower than the highest ID we have seen on this connection. """ + def __init__(self, stream_id: int, max_stream_id: int) -> None: #: The ID of the stream that we attempted to open. self.stream_id = stream_id @@ -68,9 +75,7 @@ def __init__(self, stream_id: int, max_stream_id: int) -> None: self.max_stream_id = max_stream_id def __str__(self) -> str: - return "StreamIDTooLowError: %d is lower than %d" % ( - self.stream_id, self.max_stream_id - ) + return f"StreamIDTooLowError: {self.stream_id} is lower than {self.max_stream_id}" class NoAvailableStreamIDError(ProtocolError): @@ -80,7 +85,7 @@ class NoAvailableStreamIDError(ProtocolError): .. versionadded:: 2.0.0 """ - pass + class NoSuchStreamError(ProtocolError): @@ -91,6 +96,7 @@ class NoSuchStreamError(ProtocolError): Became a subclass of :class:`ProtocolError ` """ + def __init__(self, stream_id: int) -> None: #: The stream ID corresponds to the non-existent stream. self.stream_id = stream_id @@ -103,6 +109,7 @@ class StreamClosedError(NoSuchStreamError): that the stream has since been closed, and that all state relating to that stream has been removed. """ + def __init__(self, stream_id: int) -> None: #: The stream ID corresponds to the nonexistent stream. self.stream_id = stream_id @@ -121,8 +128,9 @@ class InvalidSettingsValueError(ProtocolError, ValueError): .. versionadded:: 2.0.0 """ + def __init__(self, msg: str, error_code: ErrorCodes) -> None: - super(InvalidSettingsValueError, self).__init__(msg) + super().__init__(msg) self.error_code = error_code @@ -133,14 +141,13 @@ class InvalidBodyLengthError(ProtocolError): .. versionadded:: 2.0.0 """ + def __init__(self, expected: int, actual: int) -> None: self.expected_length = expected self.actual_length = actual def __str__(self) -> str: - return "InvalidBodyLengthError: Expected %d bytes, received %d" % ( - self.expected_length, self.actual_length - ) + return f"InvalidBodyLengthError: Expected {self.expected_length} bytes, received {self.actual_length}" class UnsupportedFrameError(ProtocolError): @@ -152,7 +159,7 @@ class UnsupportedFrameError(ProtocolError): .. versionchanged:: 4.0.0 Removed deprecated KeyError parent class. """ - pass + class RFC1122Error(H2Error): @@ -167,9 +174,9 @@ class RFC1122Error(H2Error): .. versionadded:: 2.4.0 """ + # shazow says I'm going to regret naming the exception this way. If that # turns out to be true, TELL HIM NOTHING. - pass class DenialOfServiceError(ProtocolError): @@ -181,6 +188,7 @@ class DenialOfServiceError(ProtocolError): .. versionadded:: 2.5.0 """ + #: The error code corresponds to this kind of #: :class:`ProtocolError ` error_code = ErrorCodes.ENHANCE_YOUR_CALM diff --git a/src/h2/frame_buffer.py b/src/h2/frame_buffer.py index 7d8e6622..30d96e81 100644 --- a/src/h2/frame_buffer.py +++ b/src/h2/frame_buffer.py @@ -5,16 +5,12 @@ A data structure that provides a way to iterate over a byte buffer in terms of frames. """ -from hyperframe.exceptions import InvalidFrameError, InvalidDataError -from hyperframe.frame import ( - Frame, HeadersFrame, ContinuationFrame, PushPromiseFrame -) +from __future__ import annotations -from .exceptions import ( - ProtocolError, FrameTooLargeError, FrameDataMissingError -) +from hyperframe.exceptions import InvalidDataError, InvalidFrameError +from hyperframe.frame import ContinuationFrame, Frame, HeadersFrame, PushPromiseFrame -from typing import Optional, Union +from .exceptions import FrameDataMissingError, FrameTooLargeError, ProtocolError # To avoid a DOS attack based on sending loads of continuation frames, we limit # the maximum number we're perpared to receive. In this case, we'll set the @@ -29,15 +25,16 @@ class FrameBuffer: """ - This is a data structure that expects to act as a buffer for HTTP/2 data - that allows iteraton in terms of H2 frames. + A buffer data structure for HTTP/2 data that allows iteraton in terms of + H2 frames. """ + def __init__(self, server: bool = False) -> None: - self.data = b'' + self.data = b"" self.max_frame_size = 0 - self._preamble = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' if server else b'' + self._preamble = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" if server else b"" self._preamble_len = len(self._preamble) - self._headers_buffer: list[Union[HeadersFrame, ContinuationFrame, PushPromiseFrame]] = [] + self._headers_buffer: list[HeadersFrame | ContinuationFrame | PushPromiseFrame] = [] def add_data(self, data: bytes) -> None: """ @@ -50,7 +47,8 @@ def add_data(self, data: bytes) -> None: of_which_preamble = min(self._preamble_len, data_len) if self._preamble[:of_which_preamble] != data[:of_which_preamble]: - raise ProtocolError("Invalid HTTP/2 preamble.") + msg = "Invalid HTTP/2 preamble." + raise ProtocolError(msg) data = data[of_which_preamble:] self._preamble_len -= of_which_preamble @@ -63,12 +61,10 @@ def _validate_frame_length(self, length: int) -> None: Confirm that the frame is an appropriate length. """ if length > self.max_frame_size: - raise FrameTooLargeError( - "Received overlong frame: length %d, max %d" % - (length, self.max_frame_size) - ) + msg = f"Received overlong frame: length {length}, max {self.max_frame_size}" + raise FrameTooLargeError(msg) - def _update_header_buffer(self, f: Optional[Frame]) -> Optional[Frame]: + def _update_header_buffer(self, f: Frame | None) -> Frame | None: """ Updates the internal header buffer. Returns a frame that should replace the current one. May throw exceptions if this frame is invalid. @@ -86,27 +82,29 @@ def _update_header_buffer(self, f: Optional[Frame]) -> Optional[Frame]: f.stream_id == stream_id ) if not valid_frame: - raise ProtocolError("Invalid frame during header block.") + msg = "Invalid frame during header block." + raise ProtocolError(msg) assert isinstance(f, ContinuationFrame) # Append the frame to the buffer. self._headers_buffer.append(f) if len(self._headers_buffer) > CONTINUATION_BACKLOG: - raise ProtocolError("Too many continuation frames received.") + msg = "Too many continuation frames received." + raise ProtocolError(msg) # If this is the end of the header block, then we want to build a # mutant HEADERS frame that's massive. Use the original one we got, # then set END_HEADERS and set its data appopriately. If it's not # the end of the block, lose the current frame: we can't yield it. - if 'END_HEADERS' in f.flags: + if "END_HEADERS" in f.flags: f = self._headers_buffer[0] - f.flags.add('END_HEADERS') - f.data = b''.join(x.data for x in self._headers_buffer) + f.flags.add("END_HEADERS") + f.data = b"".join(x.data for x in self._headers_buffer) self._headers_buffer = [] else: f = None elif (isinstance(f, (HeadersFrame, PushPromiseFrame)) and - 'END_HEADERS' not in f.flags): + "END_HEADERS" not in f.flags): # This is the start of a headers block! Save the frame off and then # act like we didn't receive one. self._headers_buffer.append(f) @@ -115,26 +113,25 @@ def _update_header_buffer(self, f: Optional[Frame]) -> Optional[Frame]: return f # The methods below support the iterator protocol. - def __iter__(self) -> "FrameBuffer": + def __iter__(self) -> FrameBuffer: return self def __next__(self) -> Frame: # First, check that we have enough data to successfully parse the # next frame header. If not, bail. Otherwise, parse it. if len(self.data) < 9: - raise StopIteration() + raise StopIteration try: f, length = Frame.parse_frame_header(memoryview(self.data[:9])) - except (InvalidDataError, InvalidFrameError) as e: # pragma: no cover - raise ProtocolError( - "Received frame with invalid header: %s" % str(e) - ) + except (InvalidDataError, InvalidFrameError) as err: # pragma: no cover + msg = f"Received frame with invalid header: {err!s}" + raise ProtocolError(msg) from err # Next, check that we have enough length to parse the frame body. If # not, bail, leaving the frame header data in the buffer for next time. if len(self.data) < length + 9: - raise StopIteration() + raise StopIteration # Confirm the frame has an appropriate length. self._validate_frame_length(length) @@ -142,10 +139,12 @@ def __next__(self) -> Frame: # Try to parse the frame body try: f.parse_body(memoryview(self.data[9:9+length])) - except InvalidDataError: - raise ProtocolError("Received frame with non-compliant data") - except InvalidFrameError: - raise FrameDataMissingError("Frame data missing or invalid") + except InvalidDataError as err: + msg = "Received frame with non-compliant data" + raise ProtocolError(msg) from err + except InvalidFrameError as err: + msg = "Frame data missing or invalid" + raise FrameDataMissingError(msg) from err # At this point, as we know we'll use or discard the entire frame, we # can update the data. diff --git a/src/h2/settings.py b/src/h2/settings.py index 2649f30d..c1be953b 100644 --- a/src/h2/settings.py +++ b/src/h2/settings.py @@ -6,17 +6,18 @@ API for manipulating HTTP/2 settings, keeping track of both the current active state of the settings and the unacknowledged future values of the settings. """ +from __future__ import annotations + import collections -from collections.abc import MutableMapping import enum +from collections.abc import Iterator, MutableMapping +from typing import Union from hyperframe.frame import SettingsFrame from .errors import ErrorCodes from .exceptions import InvalidSettingsValueError -from typing import Iterator, Optional, Union - class SettingCodes(enum.IntEnum): """ @@ -56,7 +57,7 @@ class SettingCodes(enum.IntEnum): ENABLE_CONNECT_PROTOCOL = SettingsFrame.ENABLE_CONNECT_PROTOCOL -def _setting_code_from_int(code: int) -> Union[SettingCodes, int]: +def _setting_code_from_int(code: int) -> SettingCodes | int: """ Given an integer setting code, returns either one of :class:`SettingCodes ` or, if not present in the known set of codes, @@ -70,7 +71,7 @@ def _setting_code_from_int(code: int) -> Union[SettingCodes, int]: class ChangedSetting: - def __init__(self, setting: Union[SettingCodes, int], original_value: Optional[int], new_value: int) -> None: + def __init__(self, setting: SettingCodes | int, original_value: int | None, new_value: int) -> None: #: The setting code given. Either one of :class:`SettingCodes #: ` or ``int`` #: @@ -85,12 +86,7 @@ def __init__(self, setting: Union[SettingCodes, int], original_value: Optional[i def __repr__(self) -> str: return ( - "ChangedSetting(setting=%s, original_value=%s, " - "new_value=%s)" - ) % ( - self.setting, - self.original_value, - self.new_value + f"ChangedSetting(setting={self.setting!s}, original_value={self.original_value}, new_value={self.new_value})" ) @@ -129,14 +125,15 @@ class Settings(MutableMapping[Union[SettingCodes, int], int]): set, rather than RFC 7540's defaults. :type initial_vales: ``MutableMapping`` """ - def __init__(self, client: bool = True, initial_values: Optional[dict[SettingCodes, int]] = None) -> None: + + def __init__(self, client: bool = True, initial_values: dict[SettingCodes, int] | None = None) -> None: # Backing object for the settings. This is a dictionary of # (setting: [list of values]), where the first value in the list is the # current value of the setting. Strictly this doesn't use lists but # instead uses collections.deque to avoid repeated memory allocations. # # This contains the default values for HTTP/2. - self._settings: dict[Union[SettingCodes, int], collections.deque[int]] = { + self._settings: dict[SettingCodes | int, collections.deque[int]] = { SettingCodes.HEADER_TABLE_SIZE: collections.deque([4096]), SettingCodes.ENABLE_PUSH: collections.deque([int(client)]), SettingCodes.INITIAL_WINDOW_SIZE: collections.deque([65535]), @@ -147,20 +144,21 @@ def __init__(self, client: bool = True, initial_values: Optional[dict[SettingCod for key, value in initial_values.items(): invalid = _validate_setting(key, value) if invalid: + msg = f"Setting {key} has invalid value {value}" raise InvalidSettingsValueError( - "Setting %d has invalid value %d" % (key, value), - error_code=invalid + msg, + error_code=invalid, ) self._settings[key] = collections.deque([value]) - def acknowledge(self) -> dict[Union[SettingCodes, int], ChangedSetting]: + def acknowledge(self) -> dict[SettingCodes | int, ChangedSetting]: """ The settings have been acknowledged, either by the user (remote settings) or by the remote peer (local settings). :returns: A dict of {setting: ChangedSetting} that were applied. """ - changed_settings: dict[Union[SettingCodes, int], ChangedSetting] = {} + changed_settings: dict[SettingCodes | int, ChangedSetting] = {} # If there is more than one setting in the list, we have a setting # value outstanding. Update them. @@ -169,7 +167,7 @@ def acknowledge(self) -> dict[Union[SettingCodes, int], ChangedSetting]: old_setting = v.popleft() new_setting = v[0] changed_settings[k] = ChangedSetting( - k, old_setting, new_setting + k, old_setting, new_setting, ) return changed_settings @@ -236,7 +234,7 @@ def max_concurrent_streams(self, value: int) -> None: self[SettingCodes.MAX_CONCURRENT_STREAMS] = value @property - def max_header_list_size(self) -> Optional[int]: + def max_header_list_size(self) -> int | None: """ The current value of the :data:`MAX_HEADER_LIST_SIZE ` setting. If not set, @@ -263,7 +261,7 @@ def enable_connect_protocol(self, value: int) -> None: self[SettingCodes.ENABLE_CONNECT_PROTOCOL] = value # Implement the MutableMapping API. - def __getitem__(self, key: Union[SettingCodes, int]) -> int: + def __getitem__(self, key: SettingCodes | int) -> int: val = self._settings[key][0] # Things that were created when a setting was received should stay @@ -273,12 +271,13 @@ def __getitem__(self, key: Union[SettingCodes, int]) -> int: return val - def __setitem__(self, key: Union[SettingCodes, int], value: int) -> None: + def __setitem__(self, key: SettingCodes | int, value: int) -> None: invalid = _validate_setting(key, value) if invalid: + msg = f"Setting {key} has invalid value {value}" raise InvalidSettingsValueError( - "Setting %d has invalid value %d" % (key, value), - error_code=invalid + msg, + error_code=invalid, ) try: @@ -289,10 +288,10 @@ def __setitem__(self, key: Union[SettingCodes, int], value: int) -> None: items.append(value) - def __delitem__(self, key: Union[SettingCodes, int]) -> None: + def __delitem__(self, key: SettingCodes | int) -> None: del self._settings[key] - def __iter__(self) -> Iterator[Union[SettingCodes, int]]: + def __iter__(self) -> Iterator[SettingCodes | int]: return self._settings.__iter__() def __len__(self) -> int: @@ -301,17 +300,15 @@ def __len__(self) -> int: def __eq__(self, other: object) -> bool: if isinstance(other, Settings): return self._settings == other._settings - else: - return NotImplemented + return NotImplemented def __ne__(self, other: object) -> bool: if isinstance(other, Settings): return not self == other - else: - return NotImplemented + return NotImplemented -def _validate_setting(setting: Union[SettingCodes, int], value: int) -> ErrorCodes: # noqa: C901 +def _validate_setting(setting: SettingCodes | int, value: int) -> ErrorCodes: """ Confirms that a specific setting has a well-formed value. If the setting is invalid, returns an error code. Otherwise, returns 0 (NO_ERROR). @@ -328,8 +325,7 @@ def _validate_setting(setting: Union[SettingCodes, int], value: int) -> ErrorCod elif setting == SettingCodes.MAX_HEADER_LIST_SIZE: if value < 0: return ErrorCodes.PROTOCOL_ERROR - elif setting == SettingCodes.ENABLE_CONNECT_PROTOCOL: - if value not in (0, 1): - return ErrorCodes.PROTOCOL_ERROR + elif setting == SettingCodes.ENABLE_CONNECT_PROTOCOL and value not in (0, 1): + return ErrorCodes.PROTOCOL_ERROR return ErrorCodes.NO_ERROR diff --git a/src/h2/stream.py b/src/h2/stream.py index 9002d02a..7d4a12e3 100644 --- a/src/h2/stream.py +++ b/src/h2/stream.py @@ -4,41 +4,54 @@ An implementation of a HTTP/2 stream. """ +from __future__ import annotations + from enum import Enum, IntEnum -from collections.abc import Generator +from typing import TYPE_CHECKING, Any from hpack import HeaderTuple -from hpack.hpack import Encoder -from hpack.struct import Header, HeaderWeaklyTyped - -from hyperframe.frame import ( - Frame, - AltSvcFrame, ContinuationFrame, DataFrame, HeadersFrame, PushPromiseFrame, - RstStreamFrame, WindowUpdateFrame -) +from hyperframe.frame import AltSvcFrame, ContinuationFrame, DataFrame, Frame, HeadersFrame, PushPromiseFrame, RstStreamFrame, WindowUpdateFrame from .errors import ErrorCodes, _error_code_from_int from .events import ( + AlternativeServiceAvailable, + DataReceived, Event, - RequestReceived, ResponseReceived, DataReceived, WindowUpdated, - StreamEnded, PushedStreamReceived, StreamReset, TrailersReceived, - InformationalResponseReceived, AlternativeServiceAvailable, - _ResponseSent, _RequestSent, _TrailersSent, _PushedRequestSent -) -from .exceptions import ( - ProtocolError, StreamClosedError, InvalidBodyLengthError, FlowControlError + InformationalResponseReceived, + PushedStreamReceived, + RequestReceived, + ResponseReceived, + StreamEnded, + StreamReset, + TrailersReceived, + WindowUpdated, + _PushedRequestSent, + _RequestSent, + _ResponseSent, + _TrailersSent, ) +from .exceptions import FlowControlError, InvalidBodyLengthError, ProtocolError, StreamClosedError from .utilities import ( - guard_increment_window, is_informational_response, authority_from_headers, - validate_headers, validate_outbound_headers, normalize_outbound_headers, - HeaderValidationFlags, extract_method_header, normalize_inbound_headers, - utf8_encode_headers + HeaderValidationFlags, + authority_from_headers, + extract_method_header, + guard_increment_window, + is_informational_response, + normalize_inbound_headers, + normalize_outbound_headers, + utf8_encode_headers, + validate_headers, + validate_outbound_headers, ) from .windows import WindowManager -from .config import H2Configuration +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Generator, Iterable -from typing import Any, Iterable, Optional, Union + from hpack.hpack import Encoder + from hpack.struct import Header, HeaderWeaklyTyped + + from .config import H2Configuration class StreamState(IntEnum): @@ -85,7 +98,7 @@ class StreamClosedBy(Enum): # this is that we potentially check whether a stream in a given state is open # quite frequently: given that we check so often, we should do so in the # fastest and most performant way possible. -STREAM_OPEN = [False for _ in range(0, len(StreamState))] +STREAM_OPEN = [False for _ in range(len(StreamState))] STREAM_OPEN[StreamState.OPEN] = True STREAM_OPEN[StreamState.HALF_CLOSED_LOCAL] = True STREAM_OPEN[StreamState.HALF_CLOSED_REMOTE] = True @@ -101,37 +114,38 @@ class H2StreamStateMachine: :param stream_id: The stream ID of this stream. This is stored primarily for logging purposes. """ + def __init__(self, stream_id: int) -> None: self.state = StreamState.IDLE self.stream_id = stream_id #: Whether this peer is the client side of this stream. - self.client: Optional[bool] = None + self.client: bool | None = None # Whether trailers have been sent/received on this stream or not. - self.headers_sent: Optional[bool] = None - self.trailers_sent: Optional[bool] = None - self.headers_received: Optional[bool] = None - self.trailers_received: Optional[bool] = None + self.headers_sent: bool | None = None + self.trailers_sent: bool | None = None + self.headers_received: bool | None = None + self.trailers_received: bool | None = None # How the stream was closed. One of StreamClosedBy. - self.stream_closed_by: Optional[StreamClosedBy] = None + self.stream_closed_by: StreamClosedBy | None = None def process_input(self, input_: StreamInputs) -> Any: """ Process a specific input in the state machine. """ if not isinstance(input_, StreamInputs): - raise ValueError("Input must be an instance of StreamInputs") + msg = "Input must be an instance of StreamInputs" + raise ValueError(msg) # noqa: TRY004 try: func, target_state = _transitions[(self.state, input_)] - except KeyError: + except KeyError as err: old_state = self.state self.state = StreamState.CLOSED - raise ProtocolError( - "Invalid input %s in state %s" % (input_, old_state) - ) + msg = f"Invalid input {input_} in state {old_state}" + raise ProtocolError(msg) from err else: previous_state = self.state self.state = target_state @@ -141,9 +155,9 @@ def process_input(self, input_: StreamInputs) -> Any: except ProtocolError: self.state = StreamState.CLOSED raise - except AssertionError as e: # pragma: no cover + except AssertionError as err: # pragma: no cover self.state = StreamState.CLOSED - raise ProtocolError(e) + raise ProtocolError(err) from err return [] @@ -164,13 +178,13 @@ def response_sent(self, previous_state: StreamState) -> list[Event]: """ if not self.headers_sent: if self.client is True or self.client is None: - raise ProtocolError("Client cannot send responses.") + msg = "Client cannot send responses." + raise ProtocolError(msg) self.headers_sent = True return [_ResponseSent()] - else: - assert not self.trailers_sent - self.trailers_sent = True - return [_TrailersSent()] + assert not self.trailers_sent + self.trailers_sent = True + return [_TrailersSent()] def request_received(self, previous_state: StreamState) -> list[Event]: """ @@ -190,7 +204,7 @@ def response_received(self, previous_state: StreamState) -> list[Event]: Fires when a response is received. Also disambiguates between responses and trailers. """ - event: Union[ResponseReceived, TrailersReceived] + event: ResponseReceived | TrailersReceived if not self.headers_received: assert self.client is True self.headers_received = True @@ -208,7 +222,8 @@ def data_received(self, previous_state: StreamState) -> list[Event]: Fires when data is received. """ if not self.headers_received: - raise ProtocolError("cannot receive data before headers") + msg = "cannot receive data before headers" + raise ProtocolError(msg) event = DataReceived() event.stream_id = self.stream_id return [event] @@ -276,7 +291,8 @@ def send_push_promise(self, previous_state: StreamState) -> list[Event]: We may only send PUSH_PROMISE frames if we're a server. """ if self.client is True: - raise ProtocolError("Cannot push streams from client peers.") + msg = "Cannot push streams from client peers." + raise ProtocolError(msg) event = _PushedRequestSent() return [event] @@ -374,8 +390,8 @@ def recv_push_on_closed_stream(self, previous_state: StreamState) -> None: if self.stream_closed_by == StreamClosedBy.SEND_RST_STREAM: raise StreamClosedError(self.stream_id) - else: - raise ProtocolError("Attempted to push on closed stream.") + msg = "Attempted to push on closed stream." + raise ProtocolError(msg) def send_push_on_closed_stream(self, previous_state: StreamState) -> None: """ @@ -387,7 +403,8 @@ def send_push_on_closed_stream(self, previous_state: StreamState) -> None: allowed to be there. The only recourse is to tear the whole connection down. """ - raise ProtocolError("Attempted to push on closed stream.") + msg = "Attempted to push on closed stream." + raise ProtocolError(msg) def send_informational_response(self, previous_state: StreamState) -> list[Event]: """ @@ -397,7 +414,8 @@ def send_informational_response(self, previous_state: StreamState) -> list[Event Only enforces that these are sent *before* final headers are sent. """ if self.headers_sent: - raise ProtocolError("Information response after final response") + msg = "Information response after final response" + raise ProtocolError(msg) event = _ResponseSent() return [event] @@ -408,7 +426,8 @@ def recv_informational_response(self, previous_state: StreamState) -> list[Event where the :status header has a 1XX value). """ if self.headers_received: - raise ProtocolError("Informational response after final response") + msg = "Informational response after final response" + raise ProtocolError(msg) event = InformationalResponseReceived() event.stream_id = self.stream_id @@ -468,11 +487,9 @@ def send_alt_svc(self, previous_state: StreamState) -> None: # We should not send ALTSVC after we've sent response headers, as the # client may have disposed of its state. if self.headers_sent: - raise ProtocolError( - "Cannot send ALTSVC after sending response headers." - ) + msg = "Cannot send ALTSVC after sending response headers." + raise ProtocolError(msg) - return # STATE MACHINE @@ -755,6 +772,7 @@ class H2Stream: Attempts to create frames that cannot be sent will raise a ``ProtocolError``. """ + def __init__(self, stream_id: int, config: H2Configuration, @@ -762,8 +780,8 @@ def __init__(self, outbound_window_size: int) -> None: self.state_machine = H2StreamStateMachine(stream_id) self.stream_id = stream_id - self.max_outbound_frame_size: Optional[int] = None - self.request_method: Optional[bytes] = None + self.max_outbound_frame_size: int | None = None + self.request_method: bytes | None = None # The current value of the outbound stream flow control window self.outbound_flow_control_window = outbound_window_size @@ -772,23 +790,19 @@ def __init__(self, self._inbound_window_manager = WindowManager(inbound_window_size) # The expected content length, if any. - self._expected_content_length: Optional[int] = None + self._expected_content_length: int | None = None # The actual received content length. Always tracked. self._actual_content_length = 0 # The authority we believe this stream belongs to. - self._authority: Optional[bytes] = None + self._authority: bytes | None = None # The configuration for this stream. self.config = config def __repr__(self) -> str: - return "<%s id:%d state:%r>" % ( - type(self).__name__, - self.stream_id, - self.state_machine.state - ) + return f"<{type(self).__name__} id:{self.stream_id} state:{self.state_machine.state!r}>" @property def inbound_flow_control_window(self) -> int: @@ -821,7 +835,7 @@ def closed(self) -> bool: return self.state_machine.state == StreamState.CLOSED @property - def closed_by(self) -> Optional[StreamClosedBy]: + def closed_by(self) -> StreamClosedBy | None: """ Returns how the stream was closed, as one of StreamClosedBy. """ @@ -843,12 +857,11 @@ def upgrade(self, client_side: bool) -> None: # This may return events, we deliberately don't want them. self.state_machine.process_input(input_) - return def send_headers(self, headers: Iterable[HeaderWeaklyTyped], encoder: Encoder, - end_stream: bool = False) -> list[Union[HeadersFrame, ContinuationFrame, PushPromiseFrame]]: + end_stream: bool = False) -> list[HeadersFrame | ContinuationFrame | PushPromiseFrame]: """ Returns a list of HEADERS/CONTINUATION frames to emit as either headers or trailers. @@ -869,9 +882,8 @@ def send_headers(self, if ((not self.state_machine.client) and is_informational_response(bytes_headers)): if end_stream: - raise ProtocolError( - "Cannot set END_STREAM on informational responses." - ) + msg = "Cannot set END_STREAM on informational responses." + raise ProtocolError(msg) input_ = StreamInputs.SEND_INFORMATIONAL_HEADERS @@ -880,17 +892,18 @@ def send_headers(self, hf = HeadersFrame(self.stream_id) hdr_validation_flags = self._build_hdr_validation_flags(events) frames = self._build_headers_frames( - bytes_headers, encoder, hf, hdr_validation_flags + bytes_headers, encoder, hf, hdr_validation_flags, ) if end_stream: # Not a bug: the END_STREAM flag is valid on the initial HEADERS # frame, not the CONTINUATION frames that follow. self.state_machine.process_input(StreamInputs.SEND_END_STREAM) - frames[0].flags.add('END_STREAM') + frames[0].flags.add("END_STREAM") if self.state_machine.trailers_sent and not end_stream: - raise ProtocolError("Trailers must have END_STREAM set.") + msg = "Trailers must have END_STREAM set." + raise ProtocolError(msg) if self.state_machine.client and self._authority is None: self._authority = authority_from_headers(bytes_headers) @@ -903,7 +916,7 @@ def send_headers(self, def push_stream_in_band(self, related_stream_id: int, headers: Iterable[HeaderWeaklyTyped], - encoder: Encoder) -> list[Union[HeadersFrame, ContinuationFrame, PushPromiseFrame]]: + encoder: Encoder) -> list[HeadersFrame | ContinuationFrame | PushPromiseFrame]: """ Returns a list of PUSH_PROMISE/CONTINUATION frames to emit as a pushed stream header. Called on the stream that has the PUSH_PROMISE frame @@ -915,7 +928,7 @@ def push_stream_in_band(self, # compression context, we make the state transition *first*. events = self.state_machine.process_input( - StreamInputs.SEND_PUSH_PROMISE + StreamInputs.SEND_PUSH_PROMISE, ) ppf = PushPromiseFrame(self.stream_id) @@ -924,11 +937,10 @@ def push_stream_in_band(self, bytes_headers = utf8_encode_headers(headers) - frames = self._build_headers_frames( - bytes_headers, encoder, ppf, hdr_validation_flags + return self._build_headers_frames( + bytes_headers, encoder, ppf, hdr_validation_flags, ) - return frames def locally_pushed(self) -> list[Frame]: """ @@ -938,22 +950,22 @@ def locally_pushed(self) -> list[Frame]: """ # This does not trigger any events. events = self.state_machine.process_input( - StreamInputs.SEND_PUSH_PROMISE + StreamInputs.SEND_PUSH_PROMISE, ) assert not events return [] def send_data(self, - data: Union[bytes, memoryview], + data: bytes | memoryview, end_stream: bool = False, - pad_length: Optional[int] = None) -> list[Frame]: + pad_length: int | None = None) -> list[Frame]: """ Prepare some data frames. Optionally end the stream. .. warning:: Does not perform flow control checks. """ self.config.logger.debug( - "Send data on %r with end stream set to %s", self, end_stream + "Send data on %r with end stream set to %s", self, end_stream, ) self.state_machine.process_input(StreamInputs.SEND_DATA) @@ -962,9 +974,9 @@ def send_data(self, df.data = data if end_stream: self.state_machine.process_input(StreamInputs.SEND_END_STREAM) - df.flags.add('END_STREAM') + df.flags.add("END_STREAM") if pad_length is not None: - df.flags.add('PADDED') + df.flags.add("PADDED") df.pad_length = pad_length # Subtract flow_controlled_length to account for possible padding @@ -981,7 +993,7 @@ def end_stream(self) -> list[Frame]: self.state_machine.process_input(StreamInputs.SEND_END_STREAM) df = DataFrame(self.stream_id) - df.flags.add('END_STREAM') + df.flags.add("END_STREAM") return [df] def advertise_alternative_service(self, field_value: bytes) -> list[Frame]: @@ -990,7 +1002,7 @@ def advertise_alternative_service(self, field_value: bytes) -> list[Frame]: better documented in the ``H2Connection`` class. """ self.config.logger.debug( - "Advertise alternative service of %r for %r", field_value, self + "Advertise alternative service of %r for %r", field_value, self, ) self.state_machine.process_input(StreamInputs.SEND_ALTERNATIVE_SERVICE) asf = AltSvcFrame(self.stream_id) @@ -1003,7 +1015,7 @@ def increase_flow_control_window(self, increment: int) -> list[Frame]: """ self.config.logger.debug( "Increase flow control window for %r by %d", - self, increment + self, increment, ) self.state_machine.process_input(StreamInputs.SEND_WINDOW_UPDATE) self._inbound_window_manager.window_opened(increment) @@ -1015,7 +1027,7 @@ def increase_flow_control_window(self, increment: int) -> list[Frame]: def receive_push_promise_in_band(self, promised_stream_id: int, headers: Iterable[Header], - header_encoding: Optional[Union[bool, str]]) -> tuple[list[Frame], list[Event]]: + header_encoding: bool | str | None) -> tuple[list[Frame], list[Event]]: """ Receives a push promise frame sent on this stream, pushing a remote stream. This is called on the stream that has the PUSH_PROMISE sent @@ -1023,16 +1035,16 @@ def receive_push_promise_in_band(self, """ self.config.logger.debug( "Receive Push Promise on %r for remote stream %d", - self, promised_stream_id + self, promised_stream_id, ) events = self.state_machine.process_input( - StreamInputs.RECV_PUSH_PROMISE + StreamInputs.RECV_PUSH_PROMISE, ) events[0].pushed_stream_id = promised_stream_id hdr_validation_flags = self._build_hdr_validation_flags(events) events[0].headers = self._process_received_headers( - headers, hdr_validation_flags, header_encoding + headers, hdr_validation_flags, header_encoding, ) return [], events @@ -1044,7 +1056,7 @@ def remotely_pushed(self, pushed_headers: Iterable[Header]) -> tuple[list[Frame] """ self.config.logger.debug("%r pushed by remote peer", self) events = self.state_machine.process_input( - StreamInputs.RECV_PUSH_PROMISE + StreamInputs.RECV_PUSH_PROMISE, ) self._authority = authority_from_headers(pushed_headers) return [], events @@ -1052,15 +1064,14 @@ def remotely_pushed(self, pushed_headers: Iterable[Header]) -> tuple[list[Frame] def receive_headers(self, headers: Iterable[Header], end_stream: bool, - header_encoding: Optional[Union[bool, str]]) -> tuple[list[Frame], list[Event]]: + header_encoding: bool | str | None) -> tuple[list[Frame], list[Event]]: """ Receive a set of headers (or trailers). """ if is_informational_response(headers): if end_stream: - raise ProtocolError( - "Cannot set END_STREAM on informational responses" - ) + msg = "Cannot set END_STREAM on informational responses" + raise ProtocolError(msg) input_ = StreamInputs.RECV_INFORMATIONAL_HEADERS else: input_ = StreamInputs.RECV_HEADERS @@ -1069,20 +1080,20 @@ def receive_headers(self, if end_stream: es_events = self.state_machine.process_input( - StreamInputs.RECV_END_STREAM + StreamInputs.RECV_END_STREAM, ) events[0].stream_ended = es_events[0] events += es_events self._initialize_content_length(headers) - if isinstance(events[0], TrailersReceived): - if not end_stream: - raise ProtocolError("Trailers must have END_STREAM set") + if isinstance(events[0], TrailersReceived) and not end_stream: + msg = "Trailers must have END_STREAM set" + raise ProtocolError(msg) hdr_validation_flags = self._build_hdr_validation_flags(events) events[0].headers = self._process_received_headers( - headers, hdr_validation_flags, header_encoding + headers, hdr_validation_flags, header_encoding, ) return [], events @@ -1092,7 +1103,7 @@ def receive_data(self, data: bytes, end_stream: bool, flow_control_len: int) -> """ self.config.logger.debug( "Receive data on %r with end stream %s and flow control length " - "set to %d", self, end_stream, flow_control_len + "set to %d", self, end_stream, flow_control_len, ) events = self.state_machine.process_input(StreamInputs.RECV_DATA) self._inbound_window_manager.window_consumed(flow_control_len) @@ -1100,7 +1111,7 @@ def receive_data(self, data: bytes, end_stream: bool, flow_control_len: int) -> if end_stream: es_events = self.state_machine.process_input( - StreamInputs.RECV_END_STREAM + StreamInputs.RECV_END_STREAM, ) events[0].stream_ended = es_events[0] events.extend(es_events) @@ -1115,10 +1126,10 @@ def receive_window_update(self, increment: int) -> tuple[list[Frame], list[Event """ self.config.logger.debug( "Receive Window Update on %r for increment of %d", - self, increment + self, increment, ) events = self.state_machine.process_input( - StreamInputs.RECV_WINDOW_UPDATE + StreamInputs.RECV_WINDOW_UPDATE, ) frames = [] @@ -1130,7 +1141,7 @@ def receive_window_update(self, increment: int) -> tuple[list[Frame], list[Event try: self.outbound_flow_control_window = guard_increment_window( self.outbound_flow_control_window, - increment + increment, ) except FlowControlError: # Ok, this is bad. We're going to need to perform a local @@ -1153,9 +1164,10 @@ def receive_continuation(self) -> None: """ self.config.logger.debug("Receive Continuation frame on %r", self) self.state_machine.process_input( - StreamInputs.RECV_CONTINUATION + StreamInputs.RECV_CONTINUATION, ) - assert False, "Should not be reachable" + msg = "Should not be reachable" # pragma: no cover + raise AssertionError(msg) # pragma: no cover def receive_alt_svc(self, frame: AltSvcFrame) -> tuple[list[Frame], list[Event]]: """ @@ -1163,7 +1175,7 @@ def receive_alt_svc(self, frame: AltSvcFrame) -> tuple[list[Frame], list[Event]] inherits the origin associated with this stream. """ self.config.logger.debug( - "Receive Alternative Service frame on stream %r", self + "Receive Alternative Service frame on stream %r", self, ) # If the origin is present, RFC 7838 says we have to ignore it. @@ -1171,7 +1183,7 @@ def receive_alt_svc(self, frame: AltSvcFrame) -> tuple[list[Frame], list[Event]] return [], [] events = self.state_machine.process_input( - StreamInputs.RECV_ALTERNATIVE_SERVICE + StreamInputs.RECV_ALTERNATIVE_SERVICE, ) # There are lots of situations where we want to ignore the ALTSVC @@ -1184,12 +1196,12 @@ def receive_alt_svc(self, frame: AltSvcFrame) -> tuple[list[Frame], list[Event]] return [], events - def reset_stream(self, error_code: Union[ErrorCodes, int] = 0) -> list[Frame]: + def reset_stream(self, error_code: ErrorCodes | int = 0) -> list[Frame]: """ Close the stream locally. Reset the stream with an error code. """ self.config.logger.debug( - "Local reset %r with error code: %d", self, error_code + "Local reset %r with error code: %d", self, error_code, ) self.state_machine.process_input(StreamInputs.SEND_RST_STREAM) @@ -1202,7 +1214,7 @@ def stream_reset(self, frame: RstStreamFrame) -> tuple[list[Frame], list[Event]] Handle a stream being reset remotely. """ self.config.logger.debug( - "Remote reset %r with error code: %d", self, frame.error_code + "Remote reset %r with error code: %d", self, frame.error_code, ) events = self.state_machine.process_input(StreamInputs.RECV_RST_STREAM) @@ -1220,10 +1232,10 @@ def acknowledge_received_data(self, acknowledged_size: int) -> list[Frame]: """ self.config.logger.debug( "Acknowledge received data with size %d on %r", - acknowledged_size, self + acknowledged_size, self, ) increment = self._inbound_window_manager.process_bytes( - acknowledged_size + acknowledged_size, ) if increment: f = WindowUpdateFrame(self.stream_id) @@ -1238,22 +1250,22 @@ def _build_hdr_validation_flags(self, events: Any) -> HeaderValidationFlags: and validating header blocks. """ is_trailer = isinstance( - events[0], (_TrailersSent, TrailersReceived) + events[0], (_TrailersSent, TrailersReceived), ) is_response_header = isinstance( events[0], ( _ResponseSent, ResponseReceived, - InformationalResponseReceived - ) + InformationalResponseReceived, + ), ) is_push_promise = isinstance( - events[0], (PushedStreamReceived, _PushedRequestSent) + events[0], (PushedStreamReceived, _PushedRequestSent), ) return HeaderValidationFlags( - is_client=self.state_machine.client, + is_client=self.state_machine.client or False, is_trailer=is_trailer, is_response_header=is_response_header, is_push_promise=is_push_promise, @@ -1262,13 +1274,12 @@ def _build_hdr_validation_flags(self, events: Any) -> HeaderValidationFlags: def _build_headers_frames(self, headers: Iterable[Header], encoder: Encoder, - first_frame: Union[HeadersFrame, PushPromiseFrame], + first_frame: HeadersFrame | PushPromiseFrame, hdr_validation_flags: HeaderValidationFlags) \ - -> list[Union[HeadersFrame, ContinuationFrame, PushPromiseFrame]]: + -> list[HeadersFrame | ContinuationFrame | PushPromiseFrame]: """ Helper method to build headers or push promise frames. """ - # We need to lowercase the header names, and to ensure that secure # header fields are kept out of compression contexts. if self.config.normalize_outbound_headers: @@ -1277,11 +1288,11 @@ def _build_headers_frames(self, should_split_outbound_cookies = self.config.split_outbound_cookies headers = normalize_outbound_headers( - headers, hdr_validation_flags, should_split_outbound_cookies + headers, hdr_validation_flags, should_split_outbound_cookies, ) if self.config.validate_outbound_headers: headers = validate_outbound_headers( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) encoded_headers = encoder.encode(headers) @@ -1292,11 +1303,11 @@ def _build_headers_frames(self, header_blocks = [ encoded_headers[i:i+(self.max_outbound_frame_size or 0)] for i in range( - 0, len(encoded_headers), (self.max_outbound_frame_size or 0) + 0, len(encoded_headers), (self.max_outbound_frame_size or 0), ) ] - frames: list[Union[HeadersFrame, ContinuationFrame, PushPromiseFrame]] = [] + frames: list[HeadersFrame | ContinuationFrame | PushPromiseFrame] = [] first_frame.data = header_blocks[0] frames.append(first_frame) @@ -1305,13 +1316,13 @@ def _build_headers_frames(self, cf.data = block frames.append(cf) - frames[-1].flags.add('END_HEADERS') + frames[-1].flags.add("END_HEADERS") return frames def _process_received_headers(self, headers: Iterable[Header], header_validation_flags: HeaderValidationFlags, - header_encoding: Optional[Union[bool, str]]) -> Iterable[Header]: + header_encoding: bool | str | None) -> Iterable[Header]: """ When headers have been received from the remote peer, run a processing pipeline on them to transform them into the appropriate form for @@ -1319,7 +1330,7 @@ def _process_received_headers(self, """ if self.config.normalize_inbound_headers: headers = normalize_inbound_headers( - headers, header_validation_flags + headers, header_validation_flags, ) if self.config.validate_inbound_headers: @@ -1338,18 +1349,17 @@ def _initialize_content_length(self, headers: Iterable[Header]) -> None: _expected_content_length field from it. It's not an error for no Content-Length header to be present. """ - if self.request_method == b'HEAD': + if self.request_method == b"HEAD": self._expected_content_length = 0 return for n, v in headers: - if n == b'content-length': + if n == b"content-length": try: self._expected_content_length = int(v, 10) - except ValueError: - raise ProtocolError( - f"Invalid content-length header: {repr(v)}" - ) + except ValueError as err: + msg = f"Invalid content-length header: {v!r}" + raise ProtocolError(msg) from err return diff --git a/src/h2/utilities.py b/src/h2/utilities.py index 4c377335..f706b1ef 100644 --- a/src/h2/utilities.py +++ b/src/h2/utilities.py @@ -4,72 +4,75 @@ Utility functions that do not belong in a separate module. """ +from __future__ import annotations + import collections import re from string import whitespace +from typing import TYPE_CHECKING, Any, NamedTuple -from hpack.struct import HeaderTuple, NeverIndexedHeaderTuple, Header, HeaderWeaklyTyped +from hpack.struct import Header, HeaderTuple, HeaderWeaklyTyped, NeverIndexedHeaderTuple -from .exceptions import ProtocolError, FlowControlError +from .exceptions import FlowControlError, ProtocolError -from typing import Any, Dict, Iterable, Optional, Set, Union -from collections.abc import Generator +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Generator, Iterable UPPER_RE = re.compile(b"[A-Z]") -SIGIL = ord(b':') -INFORMATIONAL_START = ord(b'1') +SIGIL = ord(b":") +INFORMATIONAL_START = ord(b"1") # A set of headers that are hop-by-hop or connection-specific and thus # forbidden in HTTP/2. This list comes from RFC 7540 § 8.1.2.2. CONNECTION_HEADERS = frozenset([ - b'connection', - b'proxy-connection', - b'keep-alive', - b'transfer-encoding', - b'upgrade', + b"connection", + b"proxy-connection", + b"keep-alive", + b"transfer-encoding", + b"upgrade", ]) _ALLOWED_PSEUDO_HEADER_FIELDS = frozenset([ - b':method', - b':scheme', - b':authority', - b':path', - b':status', - b':protocol', + b":method", + b":scheme", + b":authority", + b":path", + b":status", + b":protocol", ]) _SECURE_HEADERS = frozenset([ # May have basic credentials which are vulnerable to dictionary attacks. - b'authorization', - b'proxy-authorization', + b"authorization", + b"proxy-authorization", ]) _REQUEST_ONLY_HEADERS = frozenset([ - b':scheme', - b':path', - b':authority', - b':method', - b':protocol', + b":scheme", + b":path", + b":authority", + b":method", + b":protocol", ]) -_RESPONSE_ONLY_HEADERS = frozenset([b':status']) +_RESPONSE_ONLY_HEADERS = frozenset([b":status"]) # A Set of pseudo headers that are only valid if the method is # CONNECT, see RFC 8441 § 5 -_CONNECT_REQUEST_ONLY_HEADERS = frozenset([b':protocol']) +_CONNECT_REQUEST_ONLY_HEADERS = frozenset([b":protocol"]) _WHITESPACE = frozenset(map(ord, whitespace)) def _secure_headers(headers: Iterable[Header], - hdr_validation_flags: Optional["HeaderValidationFlags"]) -> Generator[Header, None, None]: + hdr_validation_flags: HeaderValidationFlags | None) -> Generator[Header, None, None]: """ Certain headers are at risk of being attacked during the header compression phase, and so need to be kept out of header compression contexts. This @@ -89,23 +92,21 @@ def _secure_headers(headers: Iterable[Header], """ for header in headers: assert isinstance(header[0], bytes) - if header[0] in _SECURE_HEADERS: - yield NeverIndexedHeaderTuple(header[0], header[1]) - elif header[0] in b'cookie' and len(header[1]) < 20: + if header[0] in _SECURE_HEADERS or (header[0] in b"cookie" and len(header[1]) < 20): yield NeverIndexedHeaderTuple(header[0], header[1]) else: yield header -def extract_method_header(headers: Iterable[Header]) -> Optional[bytes]: +def extract_method_header(headers: Iterable[Header]) -> bytes | None: """ Extracts the request method from the headers list. """ for k, v in headers: - if isinstance(v, bytes) and k == b':method': + if isinstance(v, bytes) and k == b":method": return v - elif isinstance(v, str) and k == ':method': - return v.encode('utf-8') + if isinstance(v, str) and k == ":method": + return v.encode("utf-8") # pragma: no cover return None @@ -121,13 +122,13 @@ def is_informational_response(headers: Iterable[Header]) -> bool: :returns: A boolean indicating if this is an informational response. """ for n, v in headers: - if not n.startswith(b':'): + if not n.startswith(b":"): return False - if n != b':status': + if n != b":status": # If we find a non-special header, we're done here: stop looping. continue # If the first digit is a 1, we've got informational headers. - return v.startswith(b'1') + return v.startswith(b"1") return False @@ -142,20 +143,18 @@ def guard_increment_window(current: int, increment: int) -> int: :raises: ``FlowControlError`` """ # The largest value the flow control window may take. - LARGEST_FLOW_CONTROL_WINDOW = 2**31 - 1 + LARGEST_FLOW_CONTROL_WINDOW = 2**31 - 1 # noqa: N806 new_size = current + increment if new_size > LARGEST_FLOW_CONTROL_WINDOW: - raise FlowControlError( - "May not increment flow control window past %d" % - LARGEST_FLOW_CONTROL_WINDOW - ) + msg = f"May not increment flow control window past {LARGEST_FLOW_CONTROL_WINDOW}" + raise FlowControlError(msg) return new_size -def authority_from_headers(headers: Iterable[Header]) -> Optional[bytes]: +def authority_from_headers(headers: Iterable[Header]) -> bytes | None: """ Given a header set, searches for the authority header and returns the value. @@ -169,7 +168,7 @@ def authority_from_headers(headers: Iterable[Header]) -> Optional[bytes]: :rtype: ``bytes`` or ``None``. """ for n, v in headers: - if n == b':authority': + if n == b":authority": return v return None @@ -177,10 +176,11 @@ def authority_from_headers(headers: Iterable[Header]) -> Optional[bytes]: # Flags used by the validate_headers pipeline to determine which checks # should be applied to a given set of headers. -HeaderValidationFlags = collections.namedtuple( - 'HeaderValidationFlags', - ['is_client', 'is_trailer', 'is_response_header', 'is_push_promise'] -) +class HeaderValidationFlags(NamedTuple): + is_client: bool + is_trailer: bool + is_response_header: bool + is_push_promise: bool def validate_headers(headers: Iterable[Header], hdr_validation_flags: HeaderValidationFlags) -> Iterable[Header]: @@ -200,29 +200,28 @@ def validate_headers(headers: Iterable[Header], hdr_validation_flags: HeaderVali # fixed cost that we don't want to spend, instead indexing into the header # tuples. headers = _reject_empty_header_names( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _reject_uppercase_header_fields( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _reject_surrounding_whitespace( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _reject_te( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _reject_connection_header( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _reject_pseudo_header_fields( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _check_host_authority_header( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) - headers = _check_path_header(headers, hdr_validation_flags) + return _check_path_header(headers, hdr_validation_flags) - return headers def _reject_empty_header_names(headers: Iterable[Header], @@ -235,7 +234,8 @@ def _reject_empty_header_names(headers: Iterable[Header], """ for header in headers: if len(header[0]) == 0: - raise ProtocolError("Received header name with zero length.") + msg = "Received header name with zero length." + raise ProtocolError(msg) yield header @@ -247,9 +247,8 @@ def _reject_uppercase_header_fields(headers: Iterable[Header], """ for header in headers: if UPPER_RE.search(header[0]): - raise ProtocolError( - f"Received uppercase header name {repr(header[0])}." - ) + msg = f"Received uppercase header name {header[0]!r}." + raise ProtocolError(msg) yield header @@ -266,13 +265,12 @@ def _reject_surrounding_whitespace(headers: Iterable[Header], # doesn't. for header in headers: if header[0][0] in _WHITESPACE or header[0][-1] in _WHITESPACE: - raise ProtocolError( - "Received header name surrounded by whitespace %r" % header[0]) + msg = f"Received header name surrounded by whitespace {header[0]!r}" + raise ProtocolError(msg) if header[1] and ((header[1][0] in _WHITESPACE) or (header[1][-1] in _WHITESPACE)): - raise ProtocolError( - "Received header value surrounded by whitespace %r" % header[1] - ) + msg = f"Received header value surrounded by whitespace {header[1]!r}" + raise ProtocolError(msg) yield header @@ -282,11 +280,9 @@ def _reject_te(headers: Iterable[Header], hdr_validation_flags: HeaderValidation its value is anything other than "trailers". """ for header in headers: - if header[0] == b'te': - if header[1].lower() != b'trailers': - raise ProtocolError( - f"Invalid value for TE header: {repr(header[1])}" - ) + if header[0] == b"te" and header[1].lower() != b"trailers": + msg = f"Invalid value for TE header: {header[1]!r}" + raise ProtocolError(msg) yield header @@ -298,24 +294,22 @@ def _reject_connection_header(headers: Iterable[Header], hdr_validation_flags: H """ for header in headers: if header[0] in CONNECTION_HEADERS: - raise ProtocolError( - f"Connection-specific header field present: {repr(header[0])}." - ) + msg = f"Connection-specific header field present: {header[0]!r}." + raise ProtocolError(msg) yield header def _assert_header_in_set(bytes_header: bytes, - header_set: Union[Set[Union[bytes, str]], Set[bytes], Set[str]]) -> None: + header_set: set[bytes | str] | set[bytes] | set[str]) -> None: """ Given a set of header names, checks whether the string or byte version of the header name is present. Raises a Protocol error with the appropriate error if it's missing. """ if bytes_header not in header_set: - raise ProtocolError( - f"Header block missing mandatory {repr(bytes_header)} header" - ) + msg = f"Header block missing mandatory {bytes_header!r} header" + raise ProtocolError(msg) def _reject_pseudo_header_fields(headers: Iterable[Header], @@ -334,23 +328,20 @@ def _reject_pseudo_header_fields(headers: Iterable[Header], for header in headers: if header[0][0] == SIGIL: if header[0] in seen_pseudo_header_fields: - raise ProtocolError( - f"Received duplicate pseudo-header field {repr(header[0])}" - ) + msg = f"Received duplicate pseudo-header field {header[0]!r}" + raise ProtocolError(msg) seen_pseudo_header_fields.add(header[0]) if seen_regular_header: - raise ProtocolError( - f"Received pseudo-header field out of sequence: {repr(header[0])}" - ) + msg = f"Received pseudo-header field out of sequence: {header[0]!r}" + raise ProtocolError(msg) if header[0] not in _ALLOWED_PSEUDO_HEADER_FIELDS: - raise ProtocolError( - f"Received custom pseudo-header field {repr(header[0])}" - ) + msg = f"Received custom pseudo-header field {header[0]!r}" + raise ProtocolError(msg) - if header[0] in b':method': + if header[0] in b":method": method = header[1] else: @@ -360,12 +351,12 @@ def _reject_pseudo_header_fields(headers: Iterable[Header], # Check the pseudo-headers we got to confirm they're acceptable. _check_pseudo_header_field_acceptability( - seen_pseudo_header_fields, method, hdr_validation_flags + seen_pseudo_header_fields, method, hdr_validation_flags, ) -def _check_pseudo_header_field_acceptability(pseudo_headers: Union[Set[Union[bytes, str]], Set[bytes], Set[str]], - method: Optional[bytes], +def _check_pseudo_header_field_acceptability(pseudo_headers: set[bytes | str] | set[bytes] | set[str], + method: bytes | None, hdr_validation_flags: HeaderValidationFlags) -> None: """ Given the set of pseudo-headers present in a header block and the @@ -373,9 +364,8 @@ def _check_pseudo_header_field_acceptability(pseudo_headers: Union[Set[Union[byt """ # Pseudo-header fields MUST NOT appear in trailers - RFC 7540 § 8.1.2.1 if hdr_validation_flags.is_trailer and pseudo_headers: - raise ProtocolError( - "Received pseudo-header in trailer %s" % pseudo_headers - ) + msg = f"Received pseudo-header in trailer {pseudo_headers}" + raise ProtocolError(msg) # If ':status' pseudo-header is not there in a response header, reject it. # Similarly, if ':path', ':method', or ':scheme' are not there in a request @@ -384,32 +374,27 @@ def _check_pseudo_header_field_acceptability(pseudo_headers: Union[Set[Union[byt # Relevant RFC section: RFC 7540 § 8.1.2.4 # https://tools.ietf.org/html/rfc7540#section-8.1.2.4 if hdr_validation_flags.is_response_header: - _assert_header_in_set(b':status', pseudo_headers) + _assert_header_in_set(b":status", pseudo_headers) invalid_response_headers = pseudo_headers & _REQUEST_ONLY_HEADERS if invalid_response_headers: - raise ProtocolError( - "Encountered request-only headers %s" % - invalid_response_headers - ) + msg = f"Encountered request-only headers {invalid_response_headers}" + raise ProtocolError(msg) elif (not hdr_validation_flags.is_response_header and not hdr_validation_flags.is_trailer): # This is a request, so we need to have seen :path, :method, and # :scheme. - _assert_header_in_set(b':path', pseudo_headers) - _assert_header_in_set(b':method', pseudo_headers) - _assert_header_in_set(b':scheme', pseudo_headers) + _assert_header_in_set(b":path", pseudo_headers) + _assert_header_in_set(b":method", pseudo_headers) + _assert_header_in_set(b":scheme", pseudo_headers) invalid_request_headers = pseudo_headers & _RESPONSE_ONLY_HEADERS if invalid_request_headers: - raise ProtocolError( - "Encountered response-only headers %s" % - invalid_request_headers - ) - if method != b'CONNECT': + msg = f"Encountered response-only headers {invalid_request_headers}" + raise ProtocolError(msg) + if method != b"CONNECT": invalid_headers = pseudo_headers & _CONNECT_REQUEST_ONLY_HEADERS if invalid_headers: - raise ProtocolError( - f"Encountered connect-request-only headers {repr(invalid_headers)}" - ) + msg = f"Encountered connect-request-only headers {invalid_headers!r}" + raise ProtocolError(msg) def _validate_host_authority_header(headers: Iterable[Header]) -> Generator[Header, None, None]: @@ -433,9 +418,9 @@ def _validate_host_authority_header(headers: Iterable[Header]) -> Generator[Head host_header_val = None for header in headers: - if header[0] == b':authority': + if header[0] == b":authority": authority_header_val = header[1] - elif header[0] == b'host': + elif header[0] == b"host": host_header_val = header[1] yield header @@ -448,18 +433,16 @@ def _validate_host_authority_header(headers: Iterable[Header]) -> Generator[Head # It is an error for a request header block to contain neither # an :authority header nor a Host header. if not authority_present and not host_present: - raise ProtocolError( - "Request header block does not have an :authority or Host header." - ) + msg = "Request header block does not have an :authority or Host header." + raise ProtocolError(msg) # If we receive both headers, they should definitely match. - if authority_present and host_present: - if authority_header_val != host_header_val: - raise ProtocolError( - "Request header block has mismatched :authority and " - "Host headers: %r / %r" - % (authority_header_val, host_header_val) - ) + if authority_present and host_present and authority_header_val != host_header_val: + msg = ( + "Request header block has mismatched :authority and " + f"Host headers: {authority_header_val!r} / {host_header_val!r}" + ) + raise ProtocolError(msg) def _check_host_authority_header(headers: Iterable[Header], @@ -490,9 +473,9 @@ def _check_path_header(headers: Iterable[Header], """ def inner() -> Generator[Header, None, None]: for header in headers: - if header[0] == b':path': - if not header[1]: - raise ProtocolError("An empty :path header is forbidden") + if header[0] == b":path" and not header[1]: + msg = "An empty :path header is forbidden" + raise ProtocolError(msg) yield header @@ -505,17 +488,16 @@ def inner() -> Generator[Header, None, None]: ) if skip_validation: return (h for h in headers) - else: - return inner() + return inner() -def _to_bytes(v: Union[bytes, str]) -> bytes: +def _to_bytes(v: bytes | str) -> bytes: """ Given an assumed `str` (or anything that supports `.encode()`), encodes it using utf-8 into bytes. Returns the unmodified object if it is already a `bytes` object. """ - return v if isinstance(v, bytes) else v.encode('utf-8') + return v if isinstance(v, bytes) else v.encode("utf-8") def utf8_encode_headers(headers: Iterable[HeaderWeaklyTyped]) -> list[Header]: @@ -536,7 +518,7 @@ def utf8_encode_headers(headers: Iterable[HeaderWeaklyTyped]) -> list[Header]: def _lowercase_header_names(headers: Iterable[Header], - hdr_validation_flags: Optional[HeaderValidationFlags]) -> Generator[Header, None, None]: + hdr_validation_flags: HeaderValidationFlags | None) -> Generator[Header, None, None]: """ Given an iterable of header two-tuples, rebuilds that iterable with the header names lowercased. This generator produces tuples that preserve the @@ -550,7 +532,7 @@ def _lowercase_header_names(headers: Iterable[Header], def _strip_surrounding_whitespace(headers: Iterable[Header], - hdr_validation_flags: Optional[HeaderValidationFlags]) -> Generator[Header, None, None]: + hdr_validation_flags: HeaderValidationFlags | None) -> Generator[Header, None, None]: """ Given an iterable of header two-tuples, strip both leading and trailing whitespace from both header names and header values. This generator @@ -565,7 +547,7 @@ def _strip_surrounding_whitespace(headers: Iterable[Header], def _strip_connection_headers(headers: Iterable[Header], - hdr_validation_flags: Optional[HeaderValidationFlags]) -> Generator[Header, None, None]: + hdr_validation_flags: HeaderValidationFlags | None) -> Generator[Header, None, None]: """ Strip any connection headers as per RFC7540 § 8.1.2.2. """ @@ -608,17 +590,17 @@ def _combine_cookie_fields(headers: Iterable[Header], hdr_validation_flags: Head # logic and make them never-indexed. cookies: list[bytes] = [] for header in headers: - if header[0] == b'cookie': + if header[0] == b"cookie": cookies.append(header[1]) else: yield header if cookies: - cookie_val = b'; '.join(cookies) - yield NeverIndexedHeaderTuple(b'cookie', cookie_val) + cookie_val = b"; ".join(cookies) + yield NeverIndexedHeaderTuple(b"cookie", cookie_val) def _split_outbound_cookie_fields(headers: Iterable[Header], - hdr_validation_flags: Optional[HeaderValidationFlags]) -> Generator[Header, None, None]: + hdr_validation_flags: HeaderValidationFlags | None) -> Generator[Header, None, None]: """ RFC 7540 § 8.1.2.5 allows for better compression efficiency, to split the Cookie header field into separate header fields @@ -629,8 +611,8 @@ def _split_outbound_cookie_fields(headers: Iterable[Header], for header in headers: assert isinstance(header[0], bytes) assert isinstance(header[1], bytes) - if header[0] == b'cookie': - for cookie_val in header[1].split(b'; '): + if header[0] == b"cookie": + for cookie_val in header[1].split(b"; "): if isinstance(header, HeaderTuple): yield header.__class__(header[0], cookie_val) else: @@ -640,7 +622,7 @@ def _split_outbound_cookie_fields(headers: Iterable[Header], def normalize_outbound_headers(headers: Iterable[Header], - hdr_validation_flags: Optional[HeaderValidationFlags], + hdr_validation_flags: HeaderValidationFlags | None, should_split_outbound_cookies: bool) -> Generator[Header, None, None]: """ Normalizes a header sequence that we are about to send. @@ -654,9 +636,8 @@ def normalize_outbound_headers(headers: Iterable[Header], headers = _split_outbound_cookie_fields(headers, hdr_validation_flags) headers = _strip_surrounding_whitespace(headers, hdr_validation_flags) headers = _strip_connection_headers(headers, hdr_validation_flags) - headers = _secure_headers(headers, hdr_validation_flags) + return _secure_headers(headers, hdr_validation_flags) - return headers def normalize_inbound_headers(headers: Iterable[Header], @@ -667,8 +648,7 @@ def normalize_inbound_headers(headers: Iterable[Header], :param headers: The HTTP header set. :param hdr_validation_flags: An instance of HeaderValidationFlags """ - headers = _combine_cookie_fields(headers, hdr_validation_flags) - return headers + return _combine_cookie_fields(headers, hdr_validation_flags) def validate_outbound_headers(headers: Iterable[Header], @@ -680,32 +660,31 @@ def validate_outbound_headers(headers: Iterable[Header], :param hdr_validation_flags: An instance of HeaderValidationFlags. """ headers = _reject_te( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _reject_connection_header( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _reject_pseudo_header_fields( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _check_sent_host_authority_header( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) - headers = _check_path_header(headers, hdr_validation_flags) + return _check_path_header(headers, hdr_validation_flags) - return headers class SizeLimitDict(collections.OrderedDict[int, Any]): - def __init__(self, *args: Dict[int, int], **kwargs: Any) -> None: + def __init__(self, *args: dict[int, int], **kwargs: Any) -> None: self._size_limit = kwargs.pop("size_limit", None) - super(SizeLimitDict, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self._check_size_limit() - def __setitem__(self, key: int, value: Union[Any, int]) -> None: - super(SizeLimitDict, self).__setitem__(key, value) + def __setitem__(self, key: int, value: Any | int) -> None: + super().__setitem__(key, value) self._check_size_limit() diff --git a/src/h2/windows.py b/src/h2/windows.py index d6af62ed..0efdd9fe 100644 --- a/src/h2/windows.py +++ b/src/h2/windows.py @@ -11,13 +11,10 @@ to manage the flow control window without user input, trying to ensure that it does not emit too many WINDOW_UPDATE frames. """ -from __future__ import division +from __future__ import annotations from .exceptions import FlowControlError -from typing import Optional - - # The largest acceptable value for a HTTP/2 flow control window. LARGEST_FLOW_CONTROL_WINDOW = 2**31 - 1 @@ -29,6 +26,7 @@ class WindowManager: :param max_window_size: The maximum size of the flow control window. :type max_window_size: ``int`` """ + def __init__(self, max_window_size: int) -> None: assert max_window_size <= LARGEST_FLOW_CONTROL_WINDOW self.max_window_size = max_window_size @@ -48,7 +46,8 @@ def window_consumed(self, size: int) -> None: """ self.current_window_size -= size if self.current_window_size < 0: - raise FlowControlError("Flow control window shrunk below 0") + msg = "Flow control window shrunk below 0" + raise FlowControlError(msg) def window_opened(self, size: int) -> None: """ @@ -68,15 +67,12 @@ def window_opened(self, size: int) -> None: self.current_window_size += size if self.current_window_size > LARGEST_FLOW_CONTROL_WINDOW: - raise FlowControlError( - "Flow control window mustn't exceed %d" % - LARGEST_FLOW_CONTROL_WINDOW - ) + msg = f"Flow control window mustn't exceed {LARGEST_FLOW_CONTROL_WINDOW}" + raise FlowControlError(msg) - if self.current_window_size > self.max_window_size: - self.max_window_size = self.current_window_size + self.max_window_size = max(self.current_window_size, self.max_window_size) - def process_bytes(self, size: int) -> Optional[int]: + def process_bytes(self, size: int) -> int | None: """ The application has informed us that it has processed a certain number of bytes. This may cause us to want to emit a window update frame. If @@ -93,7 +89,7 @@ def process_bytes(self, size: int) -> Optional[int]: self._bytes_processed += size return self._maybe_update_window() - def _maybe_update_window(self) -> Optional[int]: + def _maybe_update_window(self) -> int | None: """ Run the algorithm. @@ -128,11 +124,8 @@ def _maybe_update_window(self) -> Optional[int]: # Note that, even though we may increment less than _bytes_processed, # we still want to set it to zero whenever we emit an increment. This # is because we'll always increment up to the maximum we can. - if (self.current_window_size == 0) and ( - self._bytes_processed > min(1024, self.max_window_size // 4)): - increment = min(self._bytes_processed, max_increment) - self._bytes_processed = 0 - elif self._bytes_processed >= (self.max_window_size // 2): + if ((self.current_window_size == 0) and ( + self._bytes_processed > min(1024, self.max_window_size // 4))) or self._bytes_processed >= (self.max_window_size // 2): increment = min(self._bytes_processed, max_increment) self._bytes_processed = 0