diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 10965e132b..f302cbaa2e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,8 +3,12 @@ name: CI on: push: branches-ignore: - - "dependabot/**" + # these branches always have another event associated + - gh-readonly-queue/** # GitHub's merge queue uses `merge_group` + - autodeps/** # autodeps always makes a PR + - pre-commit-ci-update-config # pre-commit.ci's updates always have a PR pull_request: + merge_group: concurrency: group: ${{ github.ref }}-${{ github.workflow }}-${{ github.event_name }}${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) && format('-{0}', github.sha) || '' }} @@ -18,7 +22,7 @@ jobs: strategy: fail-fast: false matrix: - python: ['pypy-3.10', '3.9', '3.10', '3.11', '3.12', '3.13'] + python: ['3.9', '3.10', '3.11', '3.12', '3.13'] arch: ['x86', 'x64'] lsp: [''] lsp_extract_file: [''] @@ -34,6 +38,11 @@ jobs: lsp: 'https://www.proxifier.com/download/legacy/ProxifierSetup342.exe' lsp_extract_file: '' extra_name: ', with IFS LSP' + - python: 'pypy-3.10' + arch: 'x64' + lsp: '' + lsp_extract_file: '' + extra_name: '' #- python: '3.9' # arch: 'x64' # lsp: 'http://download.pctools.com/mirror/updates/9.0.0.2308-SDavfree-lite_en.exe' @@ -113,16 +122,10 @@ jobs: uses: actions/checkout@v4 - name: Setup python uses: actions/setup-python@v5 - if: "!endsWith(matrix.python, '-dev')" with: python-version: ${{ fromJSON(format('["{0}", "{1}"]', format('{0}.0-alpha - {0}.X', matrix.python), matrix.python))[startsWith(matrix.python, 'pypy')] }} cache: pip cache-dependency-path: test-requirements.txt - - name: Setup python (dev) - uses: deadsnakes/action@v2.0.2 - if: endsWith(matrix.python, '-dev') - with: - python-version: '${{ matrix.python }}' - name: Run tests run: ./ci.sh env: @@ -184,7 +187,8 @@ jobs: # can't use setup-python because that python doesn't seem to work; # `python3-dev` (rather than `python:alpine`) for some ctypes reason, # `nodejs` for pyright (`node-env` pulls in nodejs but that takes a while and can time out the test). - run: apk update && apk add python3-dev bash nodejs + # `perl` for a platform independent `sed -i` alternative + run: apk update && apk add python3-dev bash nodejs perl - name: Enter virtual environment run: python -m venv .venv - name: Run tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 209d5e26a2..e06d1c7ac5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,5 @@ ci: - autofix_commit_msg: "[pre-commit.ci] auto fixes from pre-commit.com hooks" - autofix_prs: false - autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate" + autofix_prs: true autoupdate_schedule: weekly submodules: false skip: [regenerate-files] @@ -24,7 +22,7 @@ repos: hooks: - id: black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.8.2 hooks: - id: ruff types: [file] diff --git a/ci.sh b/ci.sh index ef3dee55ca..83ec65748b 100755 --- a/ci.sh +++ b/ci.sh @@ -116,13 +116,13 @@ else echo "::group::Setup for tests" # We run the tests from inside an empty directory, to make sure Python - # doesn't pick up any .py files from our working dir. Might have been - # pre-created by some of the code above. + # doesn't pick up any .py files from our working dir. Might have already + # been created by a previous run. mkdir empty || true cd empty INSTALLDIR=$(python -c "import os, trio; print(os.path.dirname(trio.__file__))") - cp ../pyproject.toml "$INSTALLDIR" + cp ../pyproject.toml "$INSTALLDIR" # TODO: remove this # get mypy tests a nice cache MYPYPATH=".." mypy --config-file= --cache-dir=./.mypy_cache -c "import trio" >/dev/null 2>/dev/null || true @@ -130,9 +130,15 @@ else # support subprocess spawning with coverage.py echo "import coverage; coverage.process_startup()" | tee -a "$INSTALLDIR/../sitecustomize.py" + perl -i -pe 's/-p trio\._tests\.pytest_plugin//' "$INSTALLDIR/pyproject.toml" + echo "::endgroup::" echo "::group:: Run Tests" - if COVERAGE_PROCESS_START=$(pwd)/../pyproject.toml coverage run --rcfile=../pyproject.toml -m pytest -ra --junitxml=../test-results.xml --run-slow "${INSTALLDIR}" --verbose --durations=10 $flags; then + if PYTHONPATH=../tests COVERAGE_PROCESS_START=$(pwd)/../pyproject.toml \ + coverage run --rcfile=../pyproject.toml -m \ + pytest -ra --junitxml=../test-results.xml \ + -p _trio_check_attrs_aliases --verbose --durations=10 \ + -p trio._tests.pytest_plugin --run-slow $flags "${INSTALLDIR}"; then PASSED=true else PASSED=false diff --git a/docs-requirements.txt b/docs-requirements.txt index ee4152f186..03cefbc9a4 100644 --- a/docs-requirements.txt +++ b/docs-requirements.txt @@ -1,6 +1,6 @@ # This file was autogenerated by uv via the following command: # uv pip compile --universal --python-version=3.11 docs-requirements.in -o docs-requirements.txt -alabaster==0.7.16 +alabaster==1.0.0 # via sphinx attrs==24.2.0 # via @@ -16,7 +16,7 @@ cffi==1.17.1 ; platform_python_implementation != 'PyPy' or os_name == 'nt' # via # -r docs-requirements.in # cryptography -charset-normalizer==3.3.2 +charset-normalizer==3.4.0 # via requests click==8.1.7 # via towncrier @@ -24,9 +24,9 @@ colorama==0.4.6 ; sys_platform == 'win32' or platform_system == 'Windows' # via # click # sphinx -cryptography==43.0.1 +cryptography==44.0.0 # via pyopenssl -docutils==0.20.1 +docutils==0.21.2 # via # sphinx # sphinx-rtd-theme @@ -38,24 +38,24 @@ idna==3.10 # requests imagesize==1.4.1 # via sphinx -immutables==0.20 +immutables==0.21 # via -r docs-requirements.in jinja2==3.1.4 # via # -r docs-requirements.in # sphinx # towncrier -markupsafe==2.1.5 +markupsafe==3.0.2 # via jinja2 outcome==1.3.0.post0 # via -r docs-requirements.in -packaging==24.1 +packaging==24.2 # via sphinx pycparser==2.22 ; platform_python_implementation != 'PyPy' or os_name == 'nt' # via cffi pygments==2.18.0 # via sphinx -pyopenssl==24.2.1 +pyopenssl==24.3.0 # via -r docs-requirements.in requests==2.32.3 # via sphinx @@ -67,7 +67,7 @@ sortedcontainers==2.4.0 # via -r docs-requirements.in soupsieve==2.6 # via beautifulsoup4 -sphinx==7.4.7 +sphinx==8.1.3 # via # -r docs-requirements.in # sphinx-codeautolink @@ -77,9 +77,9 @@ sphinx==7.4.7 # sphinxcontrib-trio sphinx-codeautolink==0.15.2 # via -r docs-requirements.in -sphinx-hoverxref==1.4.1 +sphinx-hoverxref==1.4.2 # via -r docs-requirements.in -sphinx-rtd-theme==3.0.0 +sphinx-rtd-theme==3.0.2 # via -r docs-requirements.in sphinxcontrib-applehelp==2.0.0 # via sphinx diff --git a/docs/source/conf.py b/docs/source/conf.py index b4a73aa7b0..fb8e60cdc5 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,7 +22,6 @@ import glob import os import sys -import types from pathlib import Path from typing import TYPE_CHECKING, cast @@ -152,16 +151,6 @@ def autodoc_process_signature( return_annotation: str, ) -> tuple[str, str]: """Modify found signatures to fix various issues.""" - if name == "trio.testing._raises_group._ExceptionInfo.type": - # This has the type "type[E]", which gets resolved into the property itself. - # That means Sphinx can't resolve it. Fix the issue by overwriting with a fully-qualified - # name. - assert isinstance(obj, property), obj - assert isinstance(obj.fget, types.FunctionType), obj.fget - assert ( - obj.fget.__annotations__["return"] == "type[MatchE]" - ), obj.fget.__annotations__ - obj.fget.__annotations__["return"] = "type[~trio.testing._raises_group.MatchE]" if signature is not None: signature = signature.replace("~_contextvars.Context", "~contextvars.Context") if name == "trio.lowlevel.RunVar": # Typevar is not useful here. @@ -170,16 +159,6 @@ def autodoc_process_signature( # Strip the type from the union, make it look like = ... signature = signature.replace(" | type[trio._core._local._NoValue]", "") signature = signature.replace("", "...") - if name in ("trio.testing.RaisesGroup", "trio.testing.Matcher") and ( - "+E" in signature or "+MatchE" in signature - ): - # This typevar being covariant isn't handled correctly in some cases, strip the + - # and insert the fully-qualified name. - signature = signature.replace("+E", "~trio.testing._raises_group.E") - signature = signature.replace( - "+MatchE", - "~trio.testing._raises_group.MatchE", - ) if "DTLS" in name: signature = signature.replace("SSL.Context", "OpenSSL.SSL.Context") # Don't specify PathLike[str] | PathLike[bytes], this is just for humans. diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index 10c1ddfdc0..46c8b4d485 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -377,6 +377,46 @@ These transitions are accomplished using two function decorators: poorly-timed :exc:`KeyboardInterrupt` could leave the lock in an inconsistent state and cause a deadlock. + Since KeyboardInterrupt protection is tracked per code object, any attempt to + conditionally protect the same block of code in different ways is unlikely to behave + how you expect. If you try to conditionally protect a closure, it will be + unconditionally protected instead:: + + def example(protect: bool) -> bool: + def inner() -> bool: + return trio.lowlevel.currently_ki_protected() + if protect: + inner = trio.lowlevel.enable_ki_protection(inner) + return inner() + + async def amain(): + assert example(False) == False + assert example(True) == True # once protected ... + assert example(False) == True # ... always protected + + trio.run(amain) + + If you really need conditional protection, you can achieve it by giving each + KI-protected instance of the closure its own code object:: + + def example(protect: bool) -> bool: + def inner() -> bool: + return trio.lowlevel.currently_ki_protected() + if protect: + inner.__code__ = inner.__code__.replace() + inner = trio.lowlevel.enable_ki_protection(inner) + return inner() + + async def amain(): + assert example(False) == False + assert example(True) == True + assert example(False) == False + + trio.run(amain) + + (This isn't done by default because it carries some memory overhead and reduces + the potential for specializing optimizations in recent versions of CPython.) + .. autofunction:: currently_ki_protected diff --git a/newsfragments/2670.bugfix.rst b/newsfragments/2670.bugfix.rst new file mode 100644 index 0000000000..cd5ed3b944 --- /dev/null +++ b/newsfragments/2670.bugfix.rst @@ -0,0 +1,2 @@ +:func:`inspect.iscoroutinefunction` and the like now give correct answers when +called on KI-protected functions. diff --git a/newsfragments/3087.doc.rst b/newsfragments/3087.doc.rst new file mode 100644 index 0000000000..68fa4b05ed --- /dev/null +++ b/newsfragments/3087.doc.rst @@ -0,0 +1 @@ +Improve error message when run after gevent's monkey patching. diff --git a/newsfragments/3097.removal.rst b/newsfragments/3097.removal.rst new file mode 100644 index 0000000000..1eca349d44 --- /dev/null +++ b/newsfragments/3097.removal.rst @@ -0,0 +1 @@ +Remove workaround for OpenSSL 1.1.1 DTLS ClientHello bug. diff --git a/newsfragments/3108.bugfix.rst b/newsfragments/3108.bugfix.rst new file mode 100644 index 0000000000..16cf46b960 --- /dev/null +++ b/newsfragments/3108.bugfix.rst @@ -0,0 +1,26 @@ +Rework KeyboardInterrupt protection to track code objects, rather than frames, +as protected or not. The new implementation no longer needs to access +``frame.f_locals`` dictionaries, so it won't artificially extend the lifetime of +local variables. Since KeyboardInterrupt protection is now imposed statically +(when a protected function is defined) rather than each time the function runs, +its previously-noticeable performance overhead should now be near zero. +The lack of a call-time wrapper has some other benefits as well: + +* :func:`inspect.iscoroutinefunction` and the like now give correct answers when + called on KI-protected functions. + +* Calling a synchronous KI-protected function no longer pushes an additional stack + frame, so tracebacks are clearer. + +* A synchronous KI-protected function invoked from C code (such as a weakref + finalizer) is now guaranteed to start executing; previously there would be a brief + window in which KeyboardInterrupt could be raised before the protection was + established. + +One minor drawback of the new approach is that multiple instances of the same +closure share a single KeyboardInterrupt protection state (because they share a +single code object). That means that if you apply +`@enable_ki_protection ` to some of them +and not others, you won't get the protection semantics you asked for. See the +documentation of `@enable_ki_protection ` +for more details and a workaround. diff --git a/newsfragments/3112.bugfix.rst b/newsfragments/3112.bugfix.rst new file mode 100644 index 0000000000..c34d035520 --- /dev/null +++ b/newsfragments/3112.bugfix.rst @@ -0,0 +1,5 @@ +Rework foreign async generator finalization to track async generator +ids rather than mutating ``ag_frame.f_locals``. This fixes an issue +with the previous implementation: locals' lifetimes will no longer be +extended by materialization in the ``ag_frame.f_locals`` dictionary that +the previous finalization dispatcher logic needed to access to do its work. diff --git a/newsfragments/3114.bugfix.rst b/newsfragments/3114.bugfix.rst new file mode 100644 index 0000000000..2f07712199 --- /dev/null +++ b/newsfragments/3114.bugfix.rst @@ -0,0 +1 @@ +Ensure that Pyright recognizes our underscore prefixed attributes for attrs classes. diff --git a/newsfragments/3121.misc.rst b/newsfragments/3121.misc.rst new file mode 100644 index 0000000000..731232877b --- /dev/null +++ b/newsfragments/3121.misc.rst @@ -0,0 +1 @@ +Improve type annotations in several places by removing `Any` usage. diff --git a/newsfragments/3141.bugfix.rst b/newsfragments/3141.bugfix.rst new file mode 100644 index 0000000000..36d378d5a3 --- /dev/null +++ b/newsfragments/3141.bugfix.rst @@ -0,0 +1 @@ +Fix `trio.testing.RaisesGroup`'s typing. diff --git a/notes-to-self/aio-guest-test.py b/notes-to-self/aio-guest-test.py index 3c607d0281..7bf07d5dd4 100644 --- a/notes-to-self/aio-guest-test.py +++ b/notes-to-self/aio-guest-test.py @@ -27,7 +27,7 @@ async def trio_main(): to_trio, from_aio = trio.open_memory_channel(float("inf")) from_trio = asyncio.Queue() - _task_ref = asyncio.create_task(aio_pingpong(from_trio, to_trio)) + task_ref = asyncio.create_task(aio_pingpong(from_trio, to_trio)) from_trio.put_nowait(0) @@ -37,7 +37,7 @@ async def trio_main(): from_trio.put_nowait(n + 1) if n >= 10: return - del _task_ref + del task_ref async def aio_pingpong(from_trio, to_trio): diff --git a/pyproject.toml b/pyproject.toml index be761bd18b..70ca13632f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,7 @@ allowed-confusables = ["–"] select = [ "A", # flake8-builtins + "ANN", # flake8-annotations "ASYNC", # flake8-async "B", # flake8-bugbear "C4", # flake8-comprehensions @@ -125,13 +126,14 @@ select = [ "RET", # flake8-return "RUF", # Ruff-specific rules "SIM", # flake8-simplify - "TCH", # flake8-type-checking + "TC", # flake8-type-checking "UP", # pyupgrade "W", # Warning "YTT", # flake8-2020 ] extend-ignore = [ 'A002', # builtin-argument-shadowing + 'ANN401', # any-type (mypy's `disallow_any_explicit` is better) 'E402', # module-import-not-at-top-of-file (usually OS-specific) 'E501', # line-too-long 'F403', # undefined-local-with-import-star @@ -161,6 +163,8 @@ extend-ignore = [ 'src/trio/_abc.py' = ['A005'] 'src/trio/_socket.py' = ['A005'] 'src/trio/_ssl.py' = ['A005'] +# Don't check annotations in notes-to-self +'notes-to-self/*.py' = ['ANN001', 'ANN002', 'ANN003', 'ANN201', 'ANN202', 'ANN204'] [tool.ruff.lint.isort] combine-as-imports = true @@ -184,6 +188,7 @@ warn_return_any = true # Avoid subtle backsliding disallow_any_decorated = true +disallow_any_explicit = true disallow_any_generics = true disallow_any_unimported = true disallow_incomplete_defs = true @@ -199,7 +204,7 @@ reportUnnecessaryTypeIgnoreComment = true typeCheckingMode = "strict" [tool.pytest.ini_options] -addopts = ["--strict-markers", "--strict-config", "-p trio._tests.pytest_plugin"] +addopts = ["--strict-markers", "--strict-config", "-p trio._tests.pytest_plugin", "--import-mode=importlib"] faulthandler_timeout = 60 filterwarnings = [ "error", diff --git a/src/trio/_channel.py b/src/trio/_channel.py index cb7df95ff6..6410d9120c 100644 --- a/src/trio/_channel.py +++ b/src/trio/_channel.py @@ -99,7 +99,7 @@ def __new__( # type: ignore[misc] # "must return a subtype" ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: return _open_memory_channel(max_buffer_size) - def __init__(self, max_buffer_size: int | float): # noqa: PYI041 + def __init__(self, max_buffer_size: int | float) -> None: # noqa: PYI041 ... else: diff --git a/src/trio/_core/_asyncgens.py b/src/trio/_core/_asyncgens.py index 21102ea7d8..b3b6895753 100644 --- a/src/trio/_core/_asyncgens.py +++ b/src/trio/_core/_asyncgens.py @@ -4,7 +4,7 @@ import sys import warnings import weakref -from typing import TYPE_CHECKING, NoReturn +from typing import TYPE_CHECKING, NoReturn, TypeVar import attrs @@ -16,14 +16,31 @@ ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors") if TYPE_CHECKING: + from collections.abc import Callable from types import AsyncGeneratorType + from typing_extensions import ParamSpec + + _P = ParamSpec("_P") + _WEAK_ASYNC_GEN_SET = weakref.WeakSet[AsyncGeneratorType[object, NoReturn]] _ASYNC_GEN_SET = set[AsyncGeneratorType[object, NoReturn]] else: _WEAK_ASYNC_GEN_SET = weakref.WeakSet _ASYNC_GEN_SET = set +_R = TypeVar("_R") + + +@_core.disable_ki_protection +def _call_without_ki_protection( + f: Callable[_P, _R], + /, + *args: _P.args, + **kwargs: _P.kwargs, +) -> _R: + return f(*args, **kwargs) + @attrs.define(eq=False) class AsyncGenerators: @@ -35,6 +52,11 @@ class AsyncGenerators: # regular set so we don't have to deal with GC firing at # unexpected times. alive: _WEAK_ASYNC_GEN_SET | _ASYNC_GEN_SET = attrs.Factory(_WEAK_ASYNC_GEN_SET) + # The ids of foreign async generators are added to this set when first + # iterated. Usually it is not safe to refer to ids like this, but because + # we're using a finalizer we can ensure ids in this set do not outlive + # their async generator. + foreign: set[int] = attrs.Factory(set) # This collects async generators that get garbage collected during # the one-tick window between the system nursery closing and the @@ -51,10 +73,10 @@ def firstiter(agen: AsyncGeneratorType[object, NoReturn]) -> None: # An async generator first iterated outside of a Trio # task doesn't belong to Trio. Probably we're in guest # mode and the async generator belongs to our host. - # The locals dictionary is the only good place to + # A strong set of ids is one of the only good places to # remember this fact, at least until - # https://bugs.python.org/issue40916 is implemented. - agen.ag_frame.f_locals["@trio_foreign_asyncgen"] = True + # https://github.com/python/cpython/issues/85093 is implemented. + self.foreign.add(id(agen)) if self.prev_hooks.firstiter is not None: self.prev_hooks.firstiter(agen) @@ -76,13 +98,16 @@ def finalize_in_trio_context( # have hit it. self.trailing_needs_finalize.add(agen) + @_core.enable_ki_protection def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None: - agen_name = name_asyncgen(agen) try: - is_ours = not agen.ag_frame.f_locals.get("@trio_foreign_asyncgen") - except AttributeError: # pragma: no cover + self.foreign.remove(id(agen)) + except KeyError: is_ours = True + else: + is_ours = False + agen_name = name_asyncgen(agen) if is_ours: runner.entry_queue.run_sync_soon( finalize_in_trio_context, @@ -105,8 +130,9 @@ def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None: ) else: # Not ours -> forward to the host loop's async generator finalizer - if self.prev_hooks.finalizer is not None: - self.prev_hooks.finalizer(agen) + finalizer = self.prev_hooks.finalizer + if finalizer is not None: + _call_without_ki_protection(finalizer, agen) else: # Host has no finalizer. Reimplement the default # Python behavior with no hooks installed: throw in @@ -116,7 +142,7 @@ def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None: try: # If the next thing is a yield, this will raise RuntimeError # which we allow to propagate - closer.send(None) + _call_without_ki_protection(closer.send, None) except StopIteration: pass else: diff --git a/src/trio/_core/_concat_tb.py b/src/trio/_core/_concat_tb.py index 2ddaf2e8e6..a1469618e1 100644 --- a/src/trio/_core/_concat_tb.py +++ b/src/trio/_core/_concat_tb.py @@ -1,7 +1,9 @@ from __future__ import annotations -from types import TracebackType -from typing import Any, ClassVar, cast +from typing import TYPE_CHECKING, ClassVar, cast + +if TYPE_CHECKING: + from types import TracebackType ################################################################ # concat_tb @@ -86,7 +88,9 @@ def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackT def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType: # tputil.ProxyOperation is PyPy-only, and there's no way to specify # cpython/pypy in current type checkers. - def controller(operation: tputil.ProxyOperation) -> Any | None: # type: ignore[no-any-unimported] + def controller( # type: ignore[no-any-unimported] + operation: tputil.ProxyOperation, + ) -> TracebackType | None: # Rationale for pragma: I looked fairly carefully and tried a few # things, and AFAICT it's not actually possible to get any # 'opname' that isn't __getattr__ or __getattribute__. So there's @@ -99,12 +103,13 @@ def controller(operation: tputil.ProxyOperation) -> Any | None: # type: ignore[ "__getattr__", } and operation.args[0] == "tb_next" - ): # pragma: no cover + ) or TYPE_CHECKING: # pragma: no cover return tb_next - return operation.delegate() # Delegate is reverting to original behaviour + # Delegate is reverting to original behaviour + return operation.delegate() # type: ignore[no-any-return] return cast( - TracebackType, + "TracebackType", tputil.make_proxy(controller, type(base_tb), base_tb), ) # Returns proxy to traceback diff --git a/src/trio/_core/_entry_queue.py b/src/trio/_core/_entry_queue.py index 332441a3a0..0691de3517 100644 --- a/src/trio/_core/_entry_queue.py +++ b/src/trio/_core/_entry_queue.py @@ -16,7 +16,8 @@ PosArgsT = TypeVarTuple("PosArgsT") -Function = Callable[..., object] +# Explicit "Any" is not allowed +Function = Callable[..., object] # type: ignore[misc] Job = tuple[Function, tuple[object, ...]] diff --git a/src/trio/_core/_generated_instrumentation.py b/src/trio/_core/_generated_instrumentation.py index 568b76dffa..d03ef9db7d 100644 --- a/src/trio/_core/_generated_instrumentation.py +++ b/src/trio/_core/_generated_instrumentation.py @@ -3,10 +3,9 @@ # ************************************************************* from __future__ import annotations -import sys from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: @@ -15,6 +14,7 @@ __all__ = ["add_instrument", "remove_instrument"] +@enable_ki_protection def add_instrument(instrument: Instrument) -> None: """Start instrumenting the current run loop with the given instrument. @@ -24,13 +24,13 @@ def add_instrument(instrument: Instrument) -> None: If ``instrument`` is already active, does nothing. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.instruments.add_instrument(instrument) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def remove_instrument(instrument: Instrument) -> None: """Stop instrumenting the current run loop with the given instrument. @@ -44,7 +44,6 @@ def remove_instrument(instrument: Instrument) -> None: deactivated. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.instruments.remove_instrument(instrument) except AttributeError: diff --git a/src/trio/_core/_generated_io_epoll.py b/src/trio/_core/_generated_io_epoll.py index 9f9ad59725..41cbb40650 100644 --- a/src/trio/_core/_generated_io_epoll.py +++ b/src/trio/_core/_generated_io_epoll.py @@ -6,7 +6,7 @@ import sys from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: @@ -18,6 +18,7 @@ __all__ = ["notify_closing", "wait_readable", "wait_writable"] +@enable_ki_protection async def wait_readable(fd: int | _HasFileNo) -> None: """Block until the kernel reports that the given object is readable. @@ -40,13 +41,13 @@ async def wait_readable(fd: int | _HasFileNo) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_writable(fd: int | _HasFileNo) -> None: """Block until the kernel reports that the given object is writable. @@ -59,13 +60,13 @@ async def wait_writable(fd: int | _HasFileNo) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def notify_closing(fd: int | _HasFileNo) -> None: """Notify waiters of the given object that it will be closed. @@ -91,7 +92,6 @@ def notify_closing(fd: int | _HasFileNo) -> None: step, so other tasks won't be able to tell what order they happened in anyway. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) except AttributeError: diff --git a/src/trio/_core/_generated_io_kqueue.py b/src/trio/_core/_generated_io_kqueue.py index b2bdfc5763..016704eac7 100644 --- a/src/trio/_core/_generated_io_kqueue.py +++ b/src/trio/_core/_generated_io_kqueue.py @@ -6,7 +6,7 @@ import sys from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: @@ -31,18 +31,19 @@ ] +@enable_ki_protection def current_kqueue() -> select.kqueue: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def monitor_kevent( ident: int, filter: int, @@ -51,13 +52,13 @@ def monitor_kevent( anything real. See `#26 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_kevent( ident: int, filter: int, @@ -67,7 +68,6 @@ async def wait_kevent( anything real. See `#26 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent( ident, @@ -78,6 +78,7 @@ async def wait_kevent( raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_readable(fd: int | _HasFileNo) -> None: """Block until the kernel reports that the given object is readable. @@ -100,13 +101,13 @@ async def wait_readable(fd: int | _HasFileNo) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_writable(fd: int | _HasFileNo) -> None: """Block until the kernel reports that the given object is writable. @@ -119,13 +120,13 @@ async def wait_writable(fd: int | _HasFileNo) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def notify_closing(fd: int | _HasFileNo) -> None: """Notify waiters of the given object that it will be closed. @@ -151,7 +152,6 @@ def notify_closing(fd: int | _HasFileNo) -> None: step, so other tasks won't be able to tell what order they happened in anyway. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) except AttributeError: diff --git a/src/trio/_core/_generated_io_windows.py b/src/trio/_core/_generated_io_windows.py index d06bb19e0e..745fa4fc4e 100644 --- a/src/trio/_core/_generated_io_windows.py +++ b/src/trio/_core/_generated_io_windows.py @@ -6,7 +6,7 @@ import sys from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: @@ -34,6 +34,7 @@ ] +@enable_ki_protection async def wait_readable(sock: _HasFileNo | int) -> None: """Block until the kernel reports that the given object is readable. @@ -56,13 +57,13 @@ async def wait_readable(sock: _HasFileNo | int) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_writable(sock: _HasFileNo | int) -> None: """Block until the kernel reports that the given object is writable. @@ -75,13 +76,13 @@ async def wait_writable(sock: _HasFileNo | int) -> None: if another task calls :func:`notify_closing` while this function is still working. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def notify_closing(handle: Handle | int | _HasFileNo) -> None: """Notify waiters of the given object that it will be closed. @@ -107,33 +108,32 @@ def notify_closing(handle: Handle | int | _HasFileNo) -> None: step, so other tasks won't be able to tell what order they happened in anyway. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def register_with_iocp(handle: int | CData) -> None: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_overlapped(handle_: int | CData, lpOverlapped: CData | int) -> object: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped( handle_, @@ -143,6 +143,7 @@ async def wait_overlapped(handle_: int | CData, lpOverlapped: CData | int) -> ob raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def write_overlapped( handle: int | CData, data: Buffer, @@ -153,7 +154,6 @@ async def write_overlapped( `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped( handle, @@ -164,6 +164,7 @@ async def write_overlapped( raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def readinto_overlapped( handle: int | CData, buffer: Buffer, @@ -174,7 +175,6 @@ async def readinto_overlapped( `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped( handle, @@ -185,19 +185,20 @@ async def readinto_overlapped( raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_iocp() -> int: """TODO: these are implemented, but are currently more of a sketch than anything real. See `#26 `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def monitor_completion_key() -> ( AbstractContextManager[tuple[int, UnboundedQueue[object]]] ): @@ -206,7 +207,6 @@ def monitor_completion_key() -> ( `__ and `#52 `__. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() except AttributeError: diff --git a/src/trio/_core/_generated_run.py b/src/trio/_core/_generated_run.py index b5957a134e..67d70d9077 100644 --- a/src/trio/_core/_generated_run.py +++ b/src/trio/_core/_generated_run.py @@ -3,10 +3,9 @@ # ************************************************************* from __future__ import annotations -import sys -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT, RunStatistics, Task if TYPE_CHECKING: @@ -33,6 +32,7 @@ ] +@enable_ki_protection def current_statistics() -> RunStatistics: """Returns ``RunStatistics``, which contains run-loop-level debugging information. @@ -56,13 +56,13 @@ def current_statistics() -> RunStatistics: other attributes vary between backends. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_statistics() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_time() -> float: """Returns the current time according to Trio's internal clock. @@ -73,36 +73,36 @@ def current_time() -> float: RuntimeError: if not inside a call to :func:`trio.run`. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_time() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_clock() -> Clock: """Returns the current :class:`~trio.abc.Clock`.""" - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_clock() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_root_task() -> Task | None: """Returns the current root :class:`Task`. This is the task that is the ultimate parent of all other tasks. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_root_task() except AttributeError: raise RuntimeError("must be called from async context") from None -def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None: +@enable_ki_protection +def reschedule(task: Task, next_send: Outcome[object] = _NO_SEND) -> None: """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -120,13 +120,13 @@ def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None: raise) from :func:`wait_task_rescheduled`. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send) except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection def spawn_system_task( async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], *args: Unpack[PosArgT], @@ -184,7 +184,6 @@ def spawn_system_task( Task: the newly spawned task """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.spawn_system_task( async_fn, @@ -196,18 +195,19 @@ def spawn_system_task( raise RuntimeError("must be called from async context") from None +@enable_ki_protection def current_trio_token() -> TrioToken: """Retrieve the :class:`TrioToken` for the current call to :func:`trio.run`. """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_trio_token() except AttributeError: raise RuntimeError("must be called from async context") from None +@enable_ki_protection async def wait_all_tasks_blocked(cushion: float = 0.0) -> None: """Block until there are no runnable tasks. @@ -266,7 +266,6 @@ async def test_lock_fairness(): print("FAIL") """ - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion) except AttributeError: diff --git a/src/trio/_core/_instrumentation.py b/src/trio/_core/_instrumentation.py index 905e81c37a..40bddd1a23 100644 --- a/src/trio/_core/_instrumentation.py +++ b/src/trio/_core/_instrumentation.py @@ -3,7 +3,7 @@ import logging import types from collections.abc import Callable, Sequence -from typing import Any, TypeVar +from typing import TypeVar from .._abc import Instrument @@ -11,12 +11,14 @@ INSTRUMENT_LOGGER = logging.getLogger("trio.abc.Instrument") -F = TypeVar("F", bound=Callable[..., Any]) +# Explicit "Any" is not allowed +F = TypeVar("F", bound=Callable[..., object]) # type: ignore[misc] # Decorator to mark methods public. This does nothing by itself, but # trio/_tools/gen_exports.py looks for it. -def _public(fn: F) -> F: +# Explicit "Any" is not allowed +def _public(fn: F) -> F: # type: ignore[misc] return fn @@ -32,7 +34,7 @@ class Instruments(dict[str, dict[Instrument, None]]): __slots__ = () - def __init__(self, incoming: Sequence[Instrument]): + def __init__(self, incoming: Sequence[Instrument]) -> None: self["_all"] = {} for instrument in incoming: self.add_instrument(instrument) @@ -89,7 +91,11 @@ def remove_instrument(self, instrument: Instrument) -> None: if not instruments: del self[hookname] - def call(self, hookname: str, *args: Any) -> None: + def call( + self, + hookname: str, + *args: object, + ) -> None: """Call hookname(*args) on each applicable instrument. You must first check whether there are any instruments installed for diff --git a/src/trio/_core/_io_windows.py b/src/trio/_core/_io_windows.py index 80b62d4777..1874f5c791 100644 --- a/src/trio/_core/_io_windows.py +++ b/src/trio/_core/_io_windows.py @@ -7,8 +7,8 @@ from contextlib import contextmanager from typing import ( TYPE_CHECKING, - Any, Literal, + Protocol, TypeVar, cast, ) @@ -24,6 +24,7 @@ AFDPollFlags, CData, CompletionModes, + CType, ErrorCodes, FileFlags, Handle, @@ -249,13 +250,28 @@ class AFDWaiters: current_op: AFDPollOp | None = None +# Just used for internal type checking. +class _AFDHandle(Protocol): + Handle: Handle + Status: int + Events: int + + +# Just used for internal type checking. +class _AFDPollInfo(Protocol): + Timeout: int + NumberOfHandles: int + Exclusive: int + Handles: list[_AFDHandle] + + # We also need to bundle up all the info for a single op into a standalone # object, because we need to keep all these objects alive until the operation # finishes, even if we're throwing it away. @attrs.frozen(eq=False) class AFDPollOp: lpOverlapped: CData - poll_info: Any + poll_info: _AFDPollInfo waiters: AFDWaiters afd_group: AFDGroup @@ -684,7 +700,7 @@ def _refresh_afd(self, base_handle: Handle) -> None: lpOverlapped = ffi.new("LPOVERLAPPED") - poll_info: Any = ffi.new("AFD_POLL_INFO *") + poll_info = cast("_AFDPollInfo", ffi.new("AFD_POLL_INFO *")) poll_info.Timeout = 2**63 - 1 # INT64_MAX poll_info.NumberOfHandles = 1 poll_info.Exclusive = 0 @@ -697,9 +713,9 @@ def _refresh_afd(self, base_handle: Handle) -> None: kernel32.DeviceIoControl( afd_group.handle, IoControlCodes.IOCTL_AFD_POLL, - poll_info, + cast("CType", poll_info), ffi.sizeof("AFD_POLL_INFO"), - poll_info, + cast("CType", poll_info), ffi.sizeof("AFD_POLL_INFO"), ffi.NULL, lpOverlapped, @@ -921,13 +937,13 @@ async def _perform_overlapped( # operation will not be cancellable, depending on how Windows is # feeling today. So we need to check for cancellation manually. await _core.checkpoint_if_cancelled() - lpOverlapped = cast(_Overlapped, ffi.new("LPOVERLAPPED")) + lpOverlapped = cast("_Overlapped", ffi.new("LPOVERLAPPED")) try: submit_fn(lpOverlapped) except OSError as exc: if exc.winerror != ErrorCodes.ERROR_IO_PENDING: raise - await self.wait_overlapped(handle, cast(CData, lpOverlapped)) + await self.wait_overlapped(handle, cast("CData", lpOverlapped)) return lpOverlapped @_public diff --git a/src/trio/_core/_ki.py b/src/trio/_core/_ki.py index 51e8a871e2..46a7fdf700 100644 --- a/src/trio/_core/_ki.py +++ b/src/trio/_core/_ki.py @@ -1,26 +1,21 @@ from __future__ import annotations -import inspect import signal import sys -from functools import wraps -from typing import TYPE_CHECKING, Final, Protocol, TypeVar +import types +import weakref +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar import attrs from .._util import is_main_thread - -CallableT = TypeVar("CallableT", bound="Callable[..., object]") -RetT = TypeVar("RetT") +from ._run_context import GLOBAL_RUN_CONTEXT if TYPE_CHECKING: import types from collections.abc import Callable - from typing_extensions import ParamSpec, TypeGuard - - ArgsT = ParamSpec("ArgsT") - + from typing_extensions import Self, TypeGuard # In ordinary single-threaded Python code, when you hit control-C, it raises # an exception and automatically does all the regular unwinding stuff. # @@ -83,20 +78,117 @@ # for any Python program that's written to catch and ignore # KeyboardInterrupt.) -# We use this special string as a unique key into the frame locals dictionary. -# The @ ensures it is not a valid identifier and can't clash with any possible -# real local name. See: https://github.com/python-trio/trio/issues/469 -LOCALS_KEY_KI_PROTECTION_ENABLED: Final = "@TRIO_KI_PROTECTION_ENABLED" +_T = TypeVar("_T") + + +class _IdRef(weakref.ref[_T]): + __slots__ = ("_hash",) + _hash: int + + def __new__( + cls, + ob: _T, + callback: Callable[[Self], object] | None = None, + /, + ) -> Self: + self: Self = weakref.ref.__new__(cls, ob, callback) + self._hash = object.__hash__(ob) + return self + + def __eq__(self, other: object) -> bool: + if self is other: + return True + + if not isinstance(other, _IdRef): + return NotImplemented + + my_obj = None + try: + my_obj = self() + return my_obj is not None and my_obj is other() + finally: + del my_obj + + # we're overriding a builtin so we do need this + def __ne__(self, other: object) -> bool: + return not self == other + + def __hash__(self) -> int: + return self._hash + + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +# see also: https://github.com/python/cpython/issues/88306 +class WeakKeyIdentityDictionary(Generic[_KT, _VT]): + def __init__(self) -> None: + self._data: dict[_IdRef[_KT], _VT] = {} + + def remove( + k: _IdRef[_KT], + selfref: weakref.ref[ + WeakKeyIdentityDictionary[_KT, _VT] + ] = weakref.ref( # noqa: B008 # function-call-in-default-argument + self, + ), + ) -> None: + self = selfref() + if self is not None: + try: # noqa: SIM105 # supressible-exception + del self._data[k] + except KeyError: + pass + + self._remove = remove + + def __getitem__(self, k: _KT) -> _VT: + return self._data[_IdRef(k)] + + def __setitem__(self, k: _KT, v: _VT) -> None: + self._data[_IdRef(k, self._remove)] = v + + +_CODE_KI_PROTECTION_STATUS_WMAP: WeakKeyIdentityDictionary[ + types.CodeType, + bool, +] = WeakKeyIdentityDictionary() + + +# This is to support the async_generator package necessary for aclosing on <3.10 +# functions decorated @async_generator are given this magic property that's a +# reference to the object itself +# see python-trio/async_generator/async_generator/_impl.py +def legacy_isasyncgenfunction( + obj: object, +) -> TypeGuard[Callable[..., types.AsyncGeneratorType[object, object]]]: + return getattr(obj, "_async_gen_function", None) == id(obj) # NB: according to the signal.signal docs, 'frame' can be None on entry to # this function: def ki_protection_enabled(frame: types.FrameType | None) -> bool: + try: + task = GLOBAL_RUN_CONTEXT.task + except AttributeError: + task_ki_protected = False + task_frame = None + else: + task_ki_protected = task._ki_protected + task_frame = task.coro.cr_frame + while frame is not None: - if LOCALS_KEY_KI_PROTECTION_ENABLED in frame.f_locals: - return bool(frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED]) + try: + v = _CODE_KI_PROTECTION_STATUS_WMAP[frame.f_code] + except KeyError: + pass + else: + return bool(v) if frame.f_code.co_name == "__del__": return True + if frame is task_frame: + return task_ki_protected frame = frame.f_back return True @@ -117,89 +209,33 @@ def currently_ki_protected() -> bool: return ki_protection_enabled(sys._getframe()) -# This is to support the async_generator package necessary for aclosing on <3.10 -# functions decorated @async_generator are given this magic property that's a -# reference to the object itself -# see python-trio/async_generator/async_generator/_impl.py -def legacy_isasyncgenfunction( - obj: object, -) -> TypeGuard[Callable[..., types.AsyncGeneratorType[object, object]]]: - return getattr(obj, "_async_gen_function", None) == id(obj) +class _SupportsCode(Protocol): + __code__: types.CodeType + + +_T_supports_code = TypeVar("_T_supports_code", bound=_SupportsCode) + + +def enable_ki_protection(f: _T_supports_code, /) -> _T_supports_code: + """Decorator to enable KI protection.""" + orig = f + + if legacy_isasyncgenfunction(f): + f = f.__wrapped__ # type: ignore + + _CODE_KI_PROTECTION_STATUS_WMAP[f.__code__] = True + return orig + + +def disable_ki_protection(f: _T_supports_code, /) -> _T_supports_code: + """Decorator to disable KI protection.""" + orig = f + if legacy_isasyncgenfunction(f): + f = f.__wrapped__ # type: ignore -def _ki_protection_decorator( - enabled: bool, -) -> Callable[[Callable[ArgsT, RetT]], Callable[ArgsT, RetT]]: - # The "ignore[return-value]" below is because the inspect functions cast away the - # original return type of fn, making it just CoroutineType[Any, Any, Any] etc. - # ignore[misc] is because @wraps() is passed a callable with Any in the return type. - def decorator(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]: - # In some version of Python, isgeneratorfunction returns true for - # coroutine functions, so we have to check for coroutine functions - # first. - if inspect.iscoroutinefunction(fn): - - @wraps(fn) - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc] - # See the comment for regular generators below - coro = fn(*args, **kwargs) - assert coro.cr_frame is not None, "Coroutine frame should exist" - coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return coro # type: ignore[return-value] - - return wrapper - if inspect.isgeneratorfunction(fn): - - @wraps(fn) - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc] - # It's important that we inject this directly into the - # generator's locals, as opposed to setting it here and then - # doing 'yield from'. The reason is, if a generator is - # throw()n into, then it may magically pop to the top of the - # stack. And @contextmanager generators in particular are a - # case where we often want KI protection, and which are often - # thrown into! See: - # https://bugs.python.org/issue29590 - gen = fn(*args, **kwargs) - gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return gen # type: ignore[return-value] - - return wrapper - if inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn): - - @wraps(fn) # type: ignore[arg-type] - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc] - # See the comment for regular generators above - agen = fn(*args, **kwargs) - agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return agen # type: ignore[return-value] - - return wrapper - - @wraps(fn) - def wrapper_(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled - return fn(*args, **kwargs) - - return wrapper_ - - return decorator - - -# pyright workaround: https://github.com/microsoft/pyright/issues/5866 -class KIProtectionSignature(Protocol): - __name__: str - - def __call__(self, f: CallableT, /) -> CallableT: - pass - - -# the following `type: ignore`s are because we use ParamSpec internally, but want to allow overloads -enable_ki_protection: KIProtectionSignature = _ki_protection_decorator(True) # type: ignore[assignment] -enable_ki_protection.__name__ = "enable_ki_protection" - -disable_ki_protection: KIProtectionSignature = _ki_protection_decorator(False) # type: ignore[assignment] -disable_ki_protection.__name__ = "disable_ki_protection" + _CODE_KI_PROTECTION_STATUS_WMAP[f.__code__] = False + return orig @attrs.define(slots=False) diff --git a/src/trio/_core/_local.py b/src/trio/_core/_local.py index 53cbfc135e..fff1234f59 100644 --- a/src/trio/_core/_local.py +++ b/src/trio/_core/_local.py @@ -38,13 +38,13 @@ class RunVar(Generic[T]): """ - _name: str - _default: T | type[_NoValue] = _NoValue + _name: str = attrs.field(alias="name") + _default: T | type[_NoValue] = attrs.field(default=_NoValue, alias="default") def get(self, default: T | type[_NoValue] = _NoValue) -> T: """Gets the value of this :class:`RunVar` for the current run call.""" try: - return cast(T, _run.GLOBAL_RUN_CONTEXT.runner._locals[self]) + return cast("T", _run.GLOBAL_RUN_CONTEXT.runner._locals[self]) except AttributeError: raise RuntimeError("Cannot be used outside of a run context") from None except KeyError: diff --git a/src/trio/_core/_mock_clock.py b/src/trio/_core/_mock_clock.py index 913c435695..d9f0a5afa5 100644 --- a/src/trio/_core/_mock_clock.py +++ b/src/trio/_core/_mock_clock.py @@ -63,7 +63,7 @@ class MockClock(Clock): """ - def __init__(self, rate: float = 0.0, autojump_threshold: float = inf): + def __init__(self, rate: float = 0.0, autojump_threshold: float = inf) -> None: # when the real clock said 'real_base', the virtual time was # 'virtual_base', and since then it's advanced at 'rate' virtual # seconds per real second. diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 2cde953ecf..53690b1f0d 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -7,7 +7,6 @@ import random import select import sys -import threading import warnings from collections import deque from contextlib import AbstractAsyncContextManager, contextmanager, suppress @@ -39,8 +38,9 @@ from ._entry_queue import EntryQueue, TrioToken from ._exceptions import Cancelled, RunFinishedError, TrioInternalError from ._instrumentation import Instruments -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED, KIManager, enable_ki_protection +from ._ki import KIManager, enable_ki_protection from ._parking_lot import GLOBAL_PARKING_LOT_BREAKER +from ._run_context import GLOBAL_RUN_CONTEXT as GLOBAL_RUN_CONTEXT from ._thread_cache import start_thread_soon from ._traps import ( Abort, @@ -82,14 +82,13 @@ StatusT = TypeVar("StatusT") StatusT_contra = TypeVar("StatusT_contra", contravariant=True) -FnT = TypeVar("FnT", bound="Callable[..., Any]") RetT = TypeVar("RetT") DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: Final = 1000 # Passed as a sentinel -_NO_SEND: Final[Outcome[Any]] = cast("Outcome[Any]", object()) +_NO_SEND: Final[Outcome[object]] = cast("Outcome[object]", object()) # Used to track if an exceptiongroup can be collapsed NONSTRICT_EXCEPTIONGROUP_NOTE = 'This is a "loose" ExceptionGroup, and may be collapsed by Trio if it only contains one exception - typically after `Cancelled` has been stripped from it. Note this has consequences for exception handling, and strict_exception_groups=True is recommended.' @@ -102,7 +101,7 @@ class _NoStatus(metaclass=NoPublicConstructor): # Decorator to mark methods public. This does nothing by itself, but # trio/_tools/gen_exports.py looks for it. -def _public(fn: FnT) -> FnT: +def _public(fn: RetT) -> RetT: return fn @@ -542,9 +541,13 @@ class CancelScope: cancelled_caught: bool = attrs.field(default=False, init=False) # Constructor arguments: - _relative_deadline: float = attrs.field(default=inf, kw_only=True) - _deadline: float = attrs.field(default=inf, kw_only=True) - _shield: bool = attrs.field(default=False, kw_only=True) + _relative_deadline: float = attrs.field( + default=inf, + kw_only=True, + alias="relative_deadline", + ) + _deadline: float = attrs.field(default=inf, kw_only=True, alias="deadline") + _shield: bool = attrs.field(default=False, kw_only=True, alias="shield") def __attrs_post_init__(self) -> None: if isnan(self._deadline): @@ -939,7 +942,7 @@ def started(self: _TaskStatus[StatusT], value: StatusT) -> None: ... def started(self, value: StatusT | None = None) -> None: if self._value is not _NoStatus: raise RuntimeError("called 'started' twice on the same task status") - self._value = cast(StatusT, value) # If None, StatusT == None + self._value = cast("StatusT", value) # If None, StatusT == None # If the old nursery is cancelled, then quietly quit now; the child # will eventually exit on its own, and we don't want to risk moving @@ -1125,7 +1128,7 @@ def __init__( parent_task: Task, cancel_scope: CancelScope, strict_exception_groups: bool, - ): + ) -> None: self._parent_task = parent_task self._strict_exception_groups = strict_exception_groups parent_task._child_nurseries.append(self) @@ -1168,7 +1171,11 @@ def _check_nursery_closed(self) -> None: self._parent_waiting_in_aexit = False GLOBAL_RUN_CONTEXT.runner.reschedule(self._parent_task) - def _child_finished(self, task: Task, outcome: Outcome[Any]) -> None: + def _child_finished( + self, + task: Task, + outcome: Outcome[object], + ) -> None: self._children.remove(task) if isinstance(outcome, Error): self._add_exc(outcome.error) @@ -1274,12 +1281,14 @@ def start_soon( """ GLOBAL_RUN_CONTEXT.runner.spawn_impl(async_fn, args, self, name) - async def start( + # Typing changes blocked by https://github.com/python/mypy/pull/17512 + # Explicit "Any" is not allowed + async def start( # type: ignore[misc] self, async_fn: Callable[..., Awaitable[object]], *args: object, name: object = None, - ) -> Any: + ) -> Any | None: r"""Creates and initializes a child task. Like :meth:`start_soon`, but blocks until the new task has @@ -1330,7 +1339,10 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): # set strict_exception_groups = True to make sure we always unwrap # *this* nursery's exceptiongroup async with open_nursery(strict_exception_groups=True) as old_nursery: - task_status: _TaskStatus[Any] = _TaskStatus(old_nursery, self) + task_status: _TaskStatus[object | None] = _TaskStatus( + old_nursery, + self, + ) thunk = functools.partial(async_fn, task_status=task_status) task = GLOBAL_RUN_CONTEXT.runner.spawn_impl( thunk, @@ -1371,13 +1383,15 @@ def __del__(self) -> None: @final @attrs.define(eq=False, repr=False) -class Task(metaclass=NoPublicConstructor): +class Task(metaclass=NoPublicConstructor): # type: ignore[misc] _parent_nursery: Nursery | None - coro: Coroutine[Any, Outcome[object], Any] + # Explicit "Any" is not allowed + coro: Coroutine[Any, Outcome[object], Any] # type: ignore[misc] _runner: Runner name: str context: contextvars.Context _counter: int = attrs.field(init=False, factory=itertools.count().__next__) + _ki_protected: bool # Invariant: # - for unscheduled tasks, _next_send_fn and _next_send are both None @@ -1390,10 +1404,11 @@ class Task(metaclass=NoPublicConstructor): # tracebacks with extraneous frames. # - for scheduled tasks, custom_sleep_data is None # Tasks start out unscheduled. - _next_send_fn: Callable[[Any], object] | None = None - _next_send: Outcome[Any] | None | BaseException = None + # Explicit "Any" is not allowed + _next_send_fn: Callable[[Any], object] | None = None # type: ignore[misc] + _next_send: Outcome[Any] | BaseException | None = None # type: ignore[misc] _abort_func: Callable[[_core.RaiseCancelT], Abort] | None = None - custom_sleep_data: Any = None + custom_sleep_data: Any = None # type: ignore[misc] # For introspection and nursery.start() _child_nurseries: list[Nursery] = attrs.Factory(list) @@ -1461,7 +1476,7 @@ def print_stack_for_task(task): """ # Ignore static typing as we're doing lots of dynamic introspection - coro: Any = self.coro + coro: Any = self.coro # type: ignore[misc] while coro is not None: if hasattr(coro, "cr_frame"): # A real coroutine @@ -1553,14 +1568,6 @@ def raise_cancel() -> NoReturn: ################################################################ -class RunContext(threading.local): - runner: Runner - task: Task - - -GLOBAL_RUN_CONTEXT: Final = RunContext() - - @attrs.frozen class RunStatistics: """An object containing run-loop-level debugging information. @@ -1614,13 +1621,16 @@ class RunStatistics: @attrs.define(eq=False) -class GuestState: +# Explicit "Any" is not allowed +class GuestState: # type: ignore[misc] runner: Runner run_sync_soon_threadsafe: Callable[[Callable[[], object]], object] run_sync_soon_not_threadsafe: Callable[[Callable[[], object]], object] - done_callback: Callable[[Outcome[Any]], object] + # Explicit "Any" is not allowed + done_callback: Callable[[Outcome[Any]], object] # type: ignore[misc] unrolled_run_gen: Generator[float, EventResult, None] - unrolled_run_next_send: Outcome[Any] = attrs.Factory(lambda: Value(None)) + # Explicit "Any" is not allowed + unrolled_run_next_send: Outcome[Any] = attrs.Factory(lambda: Value(None)) # type: ignore[misc] def guest_tick(self) -> None: prev_library, sniffio_library.name = sniffio_library.name, "trio" @@ -1665,7 +1675,8 @@ def in_main_thread() -> None: @attrs.define(eq=False) -class Runner: +# Explicit "Any" is not allowed +class Runner: # type: ignore[misc] clock: Clock instruments: Instruments io_manager: TheIOManager @@ -1673,7 +1684,8 @@ class Runner: strict_exception_groups: bool # Run-local values, see _local.py - _locals: dict[_core.RunVar[Any], Any] = attrs.Factory(dict) + # Explicit "Any" is not allowed + _locals: dict[_core.RunVar[Any], object] = attrs.Factory(dict) # type: ignore[misc] runq: deque[Task] = attrs.Factory(deque) tasks: set[Task] = attrs.Factory(set) @@ -1684,7 +1696,7 @@ class Runner: system_nursery: Nursery | None = None system_context: contextvars.Context = attrs.field(kw_only=True) main_task: Task | None = None - main_task_outcome: Outcome[Any] | None = None + main_task_outcome: Outcome[object] | None = None entry_queue: EntryQueue = attrs.Factory(EntryQueue) trio_token: TrioToken | None = None @@ -1776,12 +1788,8 @@ def current_root_task(self) -> Task | None: # Core task handling primitives ################ - @_public # Type-ignore due to use of Any here. - def reschedule( # type: ignore[misc] - self, - task: Task, - next_send: Outcome[Any] = _NO_SEND, - ) -> None: + @_public + def reschedule(self, task: Task, next_send: Outcome[object] = _NO_SEND) -> None: """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -1867,7 +1875,6 @@ async def python_wrapper(orig_coro: Awaitable[RetT]) -> RetT: coro = python_wrapper(coro) assert coro.cr_frame is not None, "Coroutine frame should exist" - coro.cr_frame.f_locals.setdefault(LOCALS_KEY_KI_PROTECTION_ENABLED, system_task) ###### # Set up the Task object @@ -1878,6 +1885,7 @@ async def python_wrapper(orig_coro: Awaitable[RetT]) -> RetT: runner=self, name=name, context=context, + ki_protected=system_task, ) self.tasks.add(task) @@ -1892,7 +1900,7 @@ async def python_wrapper(orig_coro: Awaitable[RetT]) -> RetT: self.reschedule(task, None) # type: ignore[arg-type] return task - def task_exited(self, task: Task, outcome: Outcome[Any]) -> None: + def task_exited(self, task: Task, outcome: Outcome[object]) -> None: # break parking lots associated with the exiting task if task in GLOBAL_PARKING_LOT_BREAKER: for lot in GLOBAL_PARKING_LOT_BREAKER[task]: @@ -2104,7 +2112,8 @@ def _deliver_ki_cb(self) -> None: # sortedcontainers doesn't have types, and is reportedly very hard to type: # https://github.com/grantjenks/python-sortedcontainers/issues/68 - waiting_for_idle: Any = attrs.Factory(SortedDict) + # Explicit "Any" is not allowed + waiting_for_idle: Any = attrs.Factory(SortedDict) # type: ignore[misc] @_public async def wait_all_tasks_blocked(self, cushion: float = 0.0) -> None: @@ -2398,13 +2407,15 @@ def run( # Inlined copy of runner.main_task_outcome.unwrap() to avoid # cluttering every single Trio traceback with an extra frame. if isinstance(runner.main_task_outcome, Value): - return cast(RetT, runner.main_task_outcome.value) + return cast("RetT", runner.main_task_outcome.value) if isinstance(runner.main_task_outcome, Error): raise runner.main_task_outcome.error - raise AssertionError(runner.main_task_outcome) # pragma: no cover + # pragma: no cover + raise AssertionError(runner.main_task_outcome) -def start_guest_run( +# Explicit .../"Any" not allowed +def start_guest_run( # type: ignore[misc] async_fn: Callable[..., Awaitable[RetT]], *args: object, run_sync_soon_threadsafe: Callable[[Callable[[], object]], object], @@ -2520,7 +2531,7 @@ def my_done_callback(run_outcome): # this time, so it shouldn't be possible to get an exception here, # except for a TrioInternalError. next_send = cast( - EventResult, + "EventResult", None, ) # First iteration must be `None`, every iteration after that is EventResult for _tick in range(5): # expected need is 2 iterations + leave some wiggle room @@ -2568,13 +2579,13 @@ def my_done_callback(run_outcome): # mode", where our core event loop gets unrolled into a series of callbacks on # the host loop. If you're doing a regular trio.run then this gets run # straight through. +@enable_ki_protection def unrolled_run( runner: Runner, async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], args: tuple[Unpack[PosArgT]], host_uses_signal_set_wakeup_fd: bool = False, ) -> Generator[float, EventResult, None]: - sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True __tracebackhide__ = True try: @@ -2708,7 +2719,7 @@ def unrolled_run( next_send_fn = task._next_send_fn next_send = task._next_send task._next_send_fn = task._next_send = None - final_outcome: Outcome[Any] | None = None + final_outcome: Outcome[object] | None = None try: # We used to unwrap the Outcome object here and send/throw # its contents in directly, but it turns out that .throw() @@ -2817,15 +2828,15 @@ def unrolled_run( ################################################################ -class _TaskStatusIgnored(TaskStatus[Any]): +class _TaskStatusIgnored(TaskStatus[object]): def __repr__(self) -> str: return "TASK_STATUS_IGNORED" - def started(self, value: Any = None) -> None: + def started(self, value: object = None) -> None: pass -TASK_STATUS_IGNORED: Final[TaskStatus[Any]] = _TaskStatusIgnored() +TASK_STATUS_IGNORED: Final[TaskStatus[object]] = _TaskStatusIgnored() def current_task() -> Task: @@ -2942,6 +2953,13 @@ async def checkpoint_if_cancelled() -> None: _KqueueStatistics as IOStatistics, ) else: # pragma: no cover + _patchers = sorted({"eventlet", "gevent"}.intersection(sys.modules)) + if _patchers: + raise NotImplementedError( + "unsupported platform or primitives trio depends on are monkey-patched out by " + + ", ".join(_patchers), + ) + raise NotImplementedError("unsupported platform") from ._generated_instrumentation import * diff --git a/src/trio/_core/_run_context.py b/src/trio/_core/_run_context.py new file mode 100644 index 0000000000..085bff9a34 --- /dev/null +++ b/src/trio/_core/_run_context.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING, Final + +if TYPE_CHECKING: + from ._run import Runner, Task + + +class RunContext(threading.local): + runner: Runner + task: Task + + +GLOBAL_RUN_CONTEXT: Final = RunContext() diff --git a/src/trio/_core/_tests/test_guest_mode.py b/src/trio/_core/_tests/test_guest_mode.py index a5a056855a..ddb435a704 100644 --- a/src/trio/_core/_tests/test_guest_mode.py +++ b/src/trio/_core/_tests/test_guest_mode.py @@ -2,7 +2,6 @@ import asyncio import contextlib -import contextvars import queue import signal import socket @@ -11,24 +10,25 @@ import time import traceback import warnings -from collections.abc import AsyncGenerator, Awaitable, Callable +import weakref +from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence from functools import partial from math import inf from typing import ( TYPE_CHECKING, - Any, NoReturn, TypeVar, + cast, ) import pytest +import sniffio from outcome import Outcome import trio import trio.testing -from trio.abc import Instrument +from trio.abc import Clock, Instrument -from ..._util import signal_raise from .tutil import gc_collect_harder, restore_unraisablehook if TYPE_CHECKING: @@ -37,7 +37,7 @@ from trio._channel import MemorySendChannel T = TypeVar("T") -InHost: TypeAlias = Callable[[object], None] +InHost: TypeAlias = Callable[[Callable[[], object]], None] # The simplest possible "host" loop. @@ -47,12 +47,16 @@ # - final result is returned # - any unhandled exceptions cause an immediate crash def trivial_guest_run( - trio_fn: Callable[..., Awaitable[T]], + trio_fn: Callable[[InHost], Awaitable[T]], *, in_host_after_start: Callable[[], None] | None = None, - **start_guest_run_kwargs: Any, + host_uses_signal_set_wakeup_fd: bool = False, + clock: Clock | None = None, + instruments: Sequence[Instrument] = (), + restrict_keyboard_interrupt_to_checkpoints: bool = False, + strict_exception_groups: bool = True, ) -> T: - todo: queue.Queue[tuple[str, Outcome[T] | Callable[..., object]]] = queue.Queue() + todo: queue.Queue[tuple[str, Outcome[T] | Callable[[], object]]] = queue.Queue() host_thread = threading.current_thread() @@ -86,7 +90,11 @@ def done_callback(outcome: Outcome[T]) -> None: run_sync_soon_threadsafe=run_sync_soon_threadsafe, run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe, done_callback=done_callback, - **start_guest_run_kwargs, + host_uses_signal_set_wakeup_fd=host_uses_signal_set_wakeup_fd, + clock=clock, + instruments=instruments, + restrict_keyboard_interrupt_to_checkpoints=restrict_keyboard_interrupt_to_checkpoints, + strict_exception_groups=strict_exception_groups, ) if in_host_after_start is not None: in_host_after_start() @@ -170,10 +178,16 @@ async def early_task() -> None: assert res == "ok" assert set(record) == {"system task ran", "main task ran", "run_sync_soon cb ran"} - class BadClock: + class BadClock(Clock): def start_clock(self) -> NoReturn: raise ValueError("whoops") + def current_time(self) -> float: + raise NotImplementedError() + + def deadline_to_sleep_time(self, deadline: float) -> float: + raise NotImplementedError() + def after_start_never_runs() -> None: # pragma: no cover pytest.fail("shouldn't get here") @@ -221,7 +235,8 @@ async def trio_main(in_host: InHost) -> str: def test_guest_mode_sniffio_integration() -> None: - from sniffio import current_async_library, thread_local as sniffio_library + current_async_library = sniffio.current_async_library + sniffio_library = sniffio.thread_local async def trio_main(in_host: InHost) -> str: async def synchronize() -> None: @@ -431,33 +446,46 @@ async def abandoned_main(in_host: InHost) -> None: def aiotrio_run( - trio_fn: Callable[..., Awaitable[T]], + trio_fn: Callable[[], Awaitable[T]], *, pass_not_threadsafe: bool = True, - **start_guest_run_kwargs: Any, + run_sync_soon_not_threadsafe: InHost | None = None, + host_uses_signal_set_wakeup_fd: bool = False, + clock: Clock | None = None, + instruments: Sequence[Instrument] = (), + restrict_keyboard_interrupt_to_checkpoints: bool = False, + strict_exception_groups: bool = True, ) -> T: loop = asyncio.new_event_loop() async def aio_main() -> T: - trio_done_fut = loop.create_future() + nonlocal run_sync_soon_not_threadsafe + trio_done_fut: asyncio.Future[Outcome[T]] = loop.create_future() - def trio_done_callback(main_outcome: Outcome[object]) -> None: + def trio_done_callback(main_outcome: Outcome[T]) -> None: print(f"trio_fn finished: {main_outcome!r}") trio_done_fut.set_result(main_outcome) if pass_not_threadsafe: - start_guest_run_kwargs["run_sync_soon_not_threadsafe"] = loop.call_soon + run_sync_soon_not_threadsafe = cast("InHost", loop.call_soon) trio.lowlevel.start_guest_run( trio_fn, run_sync_soon_threadsafe=loop.call_soon_threadsafe, done_callback=trio_done_callback, - **start_guest_run_kwargs, + run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe, + host_uses_signal_set_wakeup_fd=host_uses_signal_set_wakeup_fd, + clock=clock, + instruments=instruments, + restrict_keyboard_interrupt_to_checkpoints=restrict_keyboard_interrupt_to_checkpoints, + strict_exception_groups=strict_exception_groups, ) - return (await trio_done_fut).unwrap() # type: ignore[no-any-return] + return (await trio_done_fut).unwrap() try: + # can't use asyncio.run because that fails on Windows (3.8, x64, with + # Komodia LSP) and segfaults on Windows (3.9, x64, with Komodia LSP) return loop.run_until_complete(aio_main()) finally: loop.close() @@ -555,10 +583,13 @@ async def crash_in_worker_thread_io(in_host: InHost) -> None: t = threading.current_thread() old_get_events = trio._core._run.TheIOManager.get_events - def bad_get_events(*args: Any) -> object: + def bad_get_events( + self: trio._core._run.TheIOManager, + timeout: float, + ) -> trio._core._run.EventResult: if threading.current_thread() is not t: raise ValueError("oh no!") - return old_get_events(*args) + return old_get_events(self, timeout) m.setattr("trio._core._run.TheIOManager.get_events", bad_get_events) @@ -576,10 +607,10 @@ def test_guest_mode_ki() -> None: # Check SIGINT in Trio func and in host func async def trio_main(in_host: InHost) -> None: with pytest.raises(KeyboardInterrupt): - signal_raise(signal.SIGINT) + signal.raise_signal(signal.SIGINT) # Host SIGINT should get injected into Trio - in_host(partial(signal_raise, signal.SIGINT)) + in_host(partial(signal.raise_signal, signal.SIGINT)) await trio.sleep(10) with pytest.raises(KeyboardInterrupt) as excinfo: @@ -592,7 +623,7 @@ async def trio_main(in_host: InHost) -> None: final_exc = KeyError("whoa") async def trio_main_raising(in_host: InHost) -> NoReturn: - in_host(partial(signal_raise, signal.SIGINT)) + in_host(partial(signal.raise_signal, signal.SIGINT)) raise final_exc with pytest.raises(KeyboardInterrupt) as excinfo: @@ -627,8 +658,6 @@ async def trio_main(in_host: InHost) -> None: @restore_unraisablehook() def test_guest_mode_asyncgens() -> None: - import sniffio - record = set() async def agen(label: str) -> AsyncGenerator[int, None]: @@ -655,9 +684,49 @@ async def trio_main() -> None: gc_collect_harder() - # Ensure we don't pollute the thread-level context if run under - # an asyncio without contextvars support (3.6) - context = contextvars.copy_context() - context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True) + aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True) assert record == {("asyncio", "asyncio"), ("trio", "trio")} + + +@restore_unraisablehook() +def test_guest_mode_asyncgens_garbage_collection() -> None: + record: set[tuple[str, str, bool]] = set() + + async def agen(label: str) -> AsyncGenerator[int, None]: + class A: + pass + + a = A() + a_wr = weakref.ref(a) + assert sniffio.current_async_library() == label + try: + yield 1 + finally: + library = sniffio.current_async_library() + with contextlib.suppress(trio.Cancelled): + await sys.modules[library].sleep(0) + + del a + if sys.implementation.name == "pypy": + gc_collect_harder() + + record.add((label, library, a_wr() is None)) + + async def iterate_in_aio() -> None: + await agen("asyncio").asend(None) + + async def trio_main() -> None: + task = asyncio.ensure_future(iterate_in_aio()) + done_evt = trio.Event() + task.add_done_callback(lambda _: done_evt.set()) + with trio.fail_after(1): + await done_evt.wait() + + await agen("trio").asend(None) + + gc_collect_harder() + + aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True) + + assert record == {("asyncio", "asyncio", True), ("trio", "trio", True)} diff --git a/src/trio/_core/_tests/test_ki.py b/src/trio/_core/_tests/test_ki.py index 8582cc0b21..67c83e8358 100644 --- a/src/trio/_core/_tests/test_ki.py +++ b/src/trio/_core/_tests/test_ki.py @@ -3,14 +3,19 @@ import contextlib import inspect import signal +import sys import threading -from typing import TYPE_CHECKING +import weakref +from collections.abc import AsyncIterator, Iterator +from typing import TYPE_CHECKING, Callable, TypeVar import outcome import pytest from trio.testing import RaisesGroup +from .tutil import gc_collect_harder + try: from async_generator import async_generator, yield_ except ImportError: # pragma: no cover @@ -18,18 +23,24 @@ from ... import _core from ..._abc import Instrument +from ..._core import _ki from ..._timeouts import sleep -from ..._util import signal_raise from ...testing import wait_all_tasks_blocked if TYPE_CHECKING: - from collections.abc import AsyncIterator, Callable, Iterator + from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Callable, + Generator, + Iterator, + ) from ..._core import Abort, RaiseCancelT def ki_self() -> None: - signal_raise(signal.SIGINT) + signal.raise_signal(signal.SIGINT) def test_ki_self() -> None: @@ -517,3 +528,179 @@ async def inner() -> None: _core.run(inner) finally: threading._active[thread.ident] = original # type: ignore[attr-defined] + + +_T = TypeVar("_T") + + +def _identity(v: _T) -> _T: + return v + + +@pytest.mark.xfail( + strict=True, + raises=AssertionError, + reason=( + "it was decided not to protect against this case, see discussion in: " + "https://github.com/python-trio/trio/pull/3110#discussion_r1802123644" + ), +) +async def test_ki_does_not_leak_across_different_calls_to_inner_functions() -> None: + assert not _core.currently_ki_protected() + + def factory(enabled: bool) -> Callable[[], bool]: + @_core.enable_ki_protection if enabled else _identity + def decorated() -> bool: + return _core.currently_ki_protected() + + return decorated + + decorated_enabled = factory(True) + decorated_disabled = factory(False) + assert decorated_enabled() + assert not decorated_disabled() + + +async def test_ki_protection_check_does_not_freeze_locals() -> None: + class A: + pass + + a = A() + wr_a = weakref.ref(a) + assert not _core.currently_ki_protected() + del a + if sys.implementation.name == "pypy": + gc_collect_harder() + assert wr_a() is None + + +def test_identity_weakref_internals() -> None: + """To cover the parts WeakKeyIdentityDictionary won't ever reach.""" + + class A: + def __eq__(self, other: object) -> bool: + return False + + a = A() + assert a != a + wr = _ki._IdRef(a) + wr_other_is_self = wr + + # dict always checks identity before equality so we need to do it here + # to cover `if self is other` + assert wr == wr_other_is_self + + # we want to cover __ne__ and `return NotImplemented` + assert wr != object() + + +def test_weak_key_identity_dict_remove_callback_keyerror() -> None: + """We need to cover the KeyError in self._remove.""" + + class A: + def __eq__(self, other: object) -> bool: + return False + + a = A() + assert a != a + d: _ki.WeakKeyIdentityDictionary[A, bool] = _ki.WeakKeyIdentityDictionary() + + d[a] = True + + data_copy = d._data.copy() + d._data.clear() + del a + + gc_collect_harder() # would call sys.unraisablehook if there's a problem + assert data_copy + + +def test_weak_key_identity_dict_remove_callback_selfref_expired() -> None: + """We need to cover the KeyError in self._remove.""" + + class A: + def __eq__(self, other: object) -> bool: + return False + + a = A() + assert a != a + d: _ki.WeakKeyIdentityDictionary[A, bool] = _ki.WeakKeyIdentityDictionary() + + d[a] = True + + data_copy = d._data.copy() + wr_d = weakref.ref(d) + del d + gc_collect_harder() # would call sys.unraisablehook if there's a problem + assert wr_d() is None + del a + gc_collect_harder() + assert data_copy + + +@_core.enable_ki_protection +async def _protected_async_gen_fn() -> AsyncGenerator[None, None]: + yield + + +@_core.enable_ki_protection +async def _protected_async_fn() -> None: + pass + + +@_core.enable_ki_protection +def _protected_gen_fn() -> Generator[None, None, None]: + yield + + +@_core.disable_ki_protection +async def _unprotected_async_gen_fn() -> AsyncGenerator[None, None]: + yield + + +@_core.disable_ki_protection +async def _unprotected_async_fn() -> None: + pass + + +@_core.disable_ki_protection +def _unprotected_gen_fn() -> Generator[None, None, None]: + yield + + +async def _consume_async_generator(agen: AsyncGenerator[None, None]) -> None: + try: + with pytest.raises(StopAsyncIteration): + while True: + await agen.asend(None) + finally: + await agen.aclose() + + +# Explicit .../"Any" is not allowed +def _consume_function_for_coverage( # type: ignore[misc] + fn: Callable[..., object], +) -> None: + result = fn() + if inspect.isasyncgen(result): + result = _consume_async_generator(result) + + assert inspect.isgenerator(result) or inspect.iscoroutine(result) + with pytest.raises(StopIteration): + while True: + result.send(None) + + +def test_enable_disable_ki_protection_passes_on_inspect_flags() -> None: + assert inspect.isasyncgenfunction(_protected_async_gen_fn) + _consume_function_for_coverage(_protected_async_gen_fn) + assert inspect.iscoroutinefunction(_protected_async_fn) + _consume_function_for_coverage(_protected_async_fn) + assert inspect.isgeneratorfunction(_protected_gen_fn) + _consume_function_for_coverage(_protected_gen_fn) + assert inspect.isasyncgenfunction(_unprotected_async_gen_fn) + _consume_function_for_coverage(_unprotected_async_gen_fn) + assert inspect.iscoroutinefunction(_unprotected_async_fn) + _consume_function_for_coverage(_unprotected_async_fn) + assert inspect.isgeneratorfunction(_unprotected_gen_fn) + _consume_function_for_coverage(_unprotected_gen_fn) diff --git a/src/trio/_core/_tests/test_parking_lot.py b/src/trio/_core/_tests/test_parking_lot.py index d9afee83d4..809fb2824a 100644 --- a/src/trio/_core/_tests/test_parking_lot.py +++ b/src/trio/_core/_tests/test_parking_lot.py @@ -304,9 +304,10 @@ async def test_parking_lot_breaker_registration() -> None: # registering a task as breaker on an already broken lot is fine lot.break_lot() - child_task = None + child_task: _core.Task | None = None async with trio.open_nursery() as nursery: child_task = await nursery.start(dummy_task) + assert isinstance(child_task, _core.Task) add_parking_lot_breaker(child_task, lot) nursery.cancel_scope.cancel() assert lot.broken_by == [task, child_task] @@ -339,6 +340,9 @@ async def test_parking_lot_multiple_breakers_exit() -> None: child_task1 = await nursery.start(dummy_task) child_task2 = await nursery.start(dummy_task) child_task3 = await nursery.start(dummy_task) + assert isinstance(child_task1, _core.Task) + assert isinstance(child_task2, _core.Task) + assert isinstance(child_task3, _core.Task) add_parking_lot_breaker(child_task1, lot) add_parking_lot_breaker(child_task2, lot) add_parking_lot_breaker(child_task3, lot) @@ -350,9 +354,11 @@ async def test_parking_lot_multiple_breakers_exit() -> None: async def test_parking_lot_breaker_register_exited_task() -> None: lot = ParkingLot() - child_task = None + child_task: _core.Task | None = None async with trio.open_nursery() as nursery: - child_task = await nursery.start(dummy_task) + value = await nursery.start(dummy_task) + assert isinstance(value, _core.Task) + child_task = value nursery.cancel_scope.cancel() # trying to register an exited task as lot breaker errors with pytest.raises( diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index 23e38a1222..f2f2e7301b 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -10,7 +10,7 @@ import weakref from contextlib import ExitStack, contextmanager, suppress from math import inf, nan -from typing import TYPE_CHECKING, Any, NoReturn, TypeVar, cast +from typing import TYPE_CHECKING, NoReturn, TypeVar import outcome import pytest @@ -823,7 +823,9 @@ async def task3(task_status: _core.TaskStatus[_core.CancelScope]) -> None: await sleep_forever() async with _core.open_nursery() as nursery: - scope: _core.CancelScope = await nursery.start(task3) + value = await nursery.start(task3) + assert isinstance(value, _core.CancelScope) + scope: _core.CancelScope = value with pytest.raises(RuntimeError, match="from unrelated"): scope.__exit__(None, None, None) scope.cancel() @@ -1646,7 +1648,10 @@ async def func1(expected: str) -> None: async def func2() -> None: # pragma: no cover pass - async def check(spawn_fn: Callable[..., object]) -> None: + # Explicit .../"Any" is not allowed + async def check( # type: ignore[misc] + spawn_fn: Callable[..., object], + ) -> None: spawn_fn(func1, "func1") spawn_fn(func1, "func2", name=func2) spawn_fn(func1, "func3", name="func3") @@ -1681,13 +1686,14 @@ async def test_current_effective_deadline(mock_clock: _core.MockClock) -> None: def test_nice_error_on_bad_calls_to_run_or_spawn() -> None: - def bad_call_run( + # Explicit .../"Any" is not allowed + def bad_call_run( # type: ignore[misc] func: Callable[..., Awaitable[object]], *args: tuple[object, ...], ) -> None: _core.run(func, *args) - def bad_call_spawn( + def bad_call_spawn( # type: ignore[misc] func: Callable[..., Awaitable[object]], *args: tuple[object, ...], ) -> None: @@ -1959,7 +1965,9 @@ async def sleeping_children( # Cancelling the setup_nursery just *before* calling started() async with _core.open_nursery() as nursery: - target_nursery: _core.Nursery = await nursery.start(setup_nursery) + value = await nursery.start(setup_nursery) + assert isinstance(value, _core.Nursery) + target_nursery: _core.Nursery = value await target_nursery.start( sleeping_children, target_nursery.cancel_scope.cancel, @@ -1967,7 +1975,9 @@ async def sleeping_children( # Cancelling the setup_nursery just *after* calling started() async with _core.open_nursery() as nursery: - target_nursery = await nursery.start(setup_nursery) + value = await nursery.start(setup_nursery) + assert isinstance(value, _core.Nursery) + target_nursery = value await target_nursery.start(sleeping_children, lambda: None) target_nursery.cancel_scope.cancel() @@ -2286,7 +2296,8 @@ async def detachable_coroutine( await sleep(0) nonlocal task, pdco_outcome task = _core.current_task() - pdco_outcome = await outcome.acapture( + # `No overload variant of "acapture" matches argument types "Callable[[Outcome[object]], Coroutine[Any, Any, object]]", "Outcome[None]"` + pdco_outcome = await outcome.acapture( # type: ignore[call-overload] _core.permanently_detach_coroutine_object, task_outcome, ) @@ -2299,10 +2310,11 @@ async def detachable_coroutine( # is still iterable. At that point anything can be sent into the coroutine, so the .coro type # is wrong. assert pdco_outcome is None - assert not_none(task).coro.send(cast(Any, "be free!")) == "I'm free!" + # `Argument 1 to "send" of "Coroutine" has incompatible type "str"; expected "Outcome[object]"` + assert not_none(task).coro.send("be free!") == "I'm free!" # type: ignore[arg-type] assert pdco_outcome == outcome.Value("be free!") with pytest.raises(StopIteration): - not_none(task).coro.send(cast(Any, None)) + not_none(task).coro.send(None) # type: ignore[arg-type] # Check the exception paths too task = None @@ -2315,7 +2327,7 @@ async def detachable_coroutine( assert not_none(task).coro.throw(throw_in) == "uh oh" assert pdco_outcome == outcome.Error(throw_in) with pytest.raises(StopIteration): - task.coro.send(cast(Any, None)) + task.coro.send(None) async def bad_detach() -> None: async with _core.open_nursery(): @@ -2367,9 +2379,10 @@ def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: # pragma: no cover await wait_all_tasks_blocked() # Okay, it's detached. Here's our coroutine runner: - assert not_none(task).coro.send(cast(Any, "not trio!")) == 1 - assert not_none(task).coro.send(cast(Any, None)) == 2 - assert not_none(task).coro.send(cast(Any, None)) == "byebye" + # `Argument 1 to "send" of "Coroutine" has incompatible type "str"; expected "Outcome[object]"` + assert not_none(task).coro.send("not trio!") == 1 # type: ignore[arg-type] + assert not_none(task).coro.send(None) == 2 # type: ignore[arg-type] + assert not_none(task).coro.send(None) == "byebye" # type: ignore[arg-type] # Now it's been reattached, and we can leave the nursery @@ -2399,7 +2412,8 @@ def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: await wait_all_tasks_blocked() assert task is not None nursery.cancel_scope.cancel() - task.coro.send(cast(Any, None)) + # `Argument 1 to "send" of "Coroutine" has incompatible type "None"; expected "Outcome[object]"` + task.coro.send(None) # type: ignore[arg-type] assert abort_fn_called @@ -2779,3 +2793,47 @@ async def spawn_tasks_in_old_nursery(task_status: _core.TaskStatus[None]) -> Non with pytest.raises(_core.TrioInternalError) as excinfo: await nursery.start(spawn_tasks_in_old_nursery) assert RaisesGroup(ValueError, ValueError).matches(excinfo.value.__cause__) + + +if sys.version_info >= (3, 11): + + def no_other_refs() -> list[object]: + return [] + +else: + + def no_other_refs() -> list[object]: + return [sys._getframe(1)] + + +@pytest.mark.skipif( + sys.implementation.name != "cpython", + reason="Only makes sense with refcounting GC", +) +async def test_ki_protection_doesnt_leave_cyclic_garbage() -> None: + class MyException(Exception): + pass + + async def demo() -> None: + async def handle_error() -> None: + try: + raise MyException + except MyException as e: + exceptions.append(e) + + exceptions: list[MyException] = [] + try: + async with _core.open_nursery() as n: + n.start_soon(handle_error) + raise ExceptionGroup("errors", exceptions) + finally: + exceptions = [] + + exc: Exception | None = None + try: + await demo() + except ExceptionGroup as excs: + exc = excs.exceptions[0] + + assert isinstance(exc, MyException) + assert gc.get_referrers(exc) == no_other_refs() diff --git a/src/trio/_core/_tests/tutil.py b/src/trio/_core/_tests/tutil.py index 81370ed76e..063fa1dd80 100644 --- a/src/trio/_core/_tests/tutil.py +++ b/src/trio/_core/_tests/tutil.py @@ -12,7 +12,7 @@ import pytest -# See trio/_tests/conftest.py for the other half of this +# See trio/_tests/pytest_plugin.py for the other half of this from trio._tests.pytest_plugin import RUN_SLOW if TYPE_CHECKING: diff --git a/src/trio/_core/_thread_cache.py b/src/trio/_core/_thread_cache.py index c612222697..189d5a5836 100644 --- a/src/trio/_core/_thread_cache.py +++ b/src/trio/_core/_thread_cache.py @@ -7,10 +7,13 @@ from functools import partial from itertools import count from threading import Lock, Thread -from typing import Any, Callable, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar import outcome +if TYPE_CHECKING: + from collections.abc import Callable + RetT = TypeVar("RetT") @@ -126,6 +129,8 @@ def darwin_namefunc( class WorkerThread(Generic[RetT]): + __slots__ = ("_default_name", "_job", "_thread", "_thread_cache", "_worker_lock") + def __init__(self, thread_cache: ThreadCache) -> None: self._job: ( tuple[ @@ -207,8 +212,11 @@ def _work(self) -> None: class ThreadCache: + __slots__ = ("_idle_workers",) + def __init__(self) -> None: - self._idle_workers: dict[WorkerThread[Any], None] = {} + # Explicit "Any" not allowed + self._idle_workers: dict[WorkerThread[Any], None] = {} # type: ignore[misc] def start_thread_soon( self, diff --git a/src/trio/_core/_traps.py b/src/trio/_core/_traps.py index 27518406cb..1ddd5628ba 100644 --- a/src/trio/_core/_traps.py +++ b/src/trio/_core/_traps.py @@ -4,7 +4,9 @@ import enum import types -from typing import TYPE_CHECKING, Any, Callable, NoReturn + +# Jedi gets mad in test_static_tool_sees_class_members if we use collections Callable +from typing import TYPE_CHECKING, Any, Callable, NoReturn, Union, cast import attrs import outcome @@ -12,10 +14,40 @@ from . import _run if TYPE_CHECKING: + from collections.abc import Awaitable, Generator + from typing_extensions import TypeAlias from ._run import Task +RaiseCancelT: TypeAlias = Callable[[], NoReturn] + + +# This class object is used as a singleton. +# Not exported in the trio._core namespace, but imported directly by _run. +class CancelShieldedCheckpoint: + __slots__ = () + + +# Not exported in the trio._core namespace, but imported directly by _run. +@attrs.frozen(slots=False) +class WaitTaskRescheduled: + abort_func: Callable[[RaiseCancelT], Abort] + + +# Not exported in the trio._core namespace, but imported directly by _run. +@attrs.frozen(slots=False) +class PermanentlyDetachCoroutineObject: + final_outcome: outcome.Outcome[object] + + +MessageType: TypeAlias = Union[ + type[CancelShieldedCheckpoint], + WaitTaskRescheduled, + PermanentlyDetachCoroutineObject, + object, +] + # Helper for the bottommost 'yield'. You can't use 'yield' inside an async # function, but you can inside a generator, and if you decorate your generator @@ -25,14 +57,18 @@ # tracking machinery. Since our traps are public APIs, we make them real async # functions, and then this helper takes care of the actual yield: @types.coroutine -def _async_yield(obj: Any) -> Any: # type: ignore[misc] +def _real_async_yield( + obj: MessageType, +) -> Generator[MessageType, None, None]: return (yield obj) -# This class object is used as a singleton. -# Not exported in the trio._core namespace, but imported directly by _run. -class CancelShieldedCheckpoint: - pass +# Real yield value is from trio's main loop, but type checkers can't +# understand that, so we cast it to make type checkers understand. +_async_yield = cast( + "Callable[[MessageType], Awaitable[outcome.Outcome[object]]]", + _real_async_yield, +) async def cancel_shielded_checkpoint() -> None: @@ -66,18 +102,12 @@ class Abort(enum.Enum): FAILED = 2 -# Not exported in the trio._core namespace, but imported directly by _run. -@attrs.frozen(slots=False) -class WaitTaskRescheduled: - abort_func: Callable[[RaiseCancelT], Abort] - - -RaiseCancelT: TypeAlias = Callable[[], NoReturn] - - # Should always return the type a Task "expects", unless you willfully reschedule it # with a bad value. -async def wait_task_rescheduled(abort_func: Callable[[RaiseCancelT], Abort]) -> Any: +# Explicit "Any" is not allowed +async def wait_task_rescheduled( # type: ignore[misc] + abort_func: Callable[[RaiseCancelT], Abort], +) -> Any: """Put the current task to sleep, with cancellation support. This is the lowest-level API for blocking in Trio. Every time a @@ -179,15 +209,9 @@ def abort(inner_raise_cancel): return (await _async_yield(WaitTaskRescheduled(abort_func))).unwrap() -# Not exported in the trio._core namespace, but imported directly by _run. -@attrs.frozen(slots=False) -class PermanentlyDetachCoroutineObject: - final_outcome: outcome.Outcome[Any] - - async def permanently_detach_coroutine_object( - final_outcome: outcome.Outcome[Any], -) -> Any: + final_outcome: outcome.Outcome[object], +) -> object: """Permanently detach the current task from the Trio scheduler. Normally, a Trio task doesn't exit until its coroutine object exits. When @@ -220,7 +244,7 @@ async def permanently_detach_coroutine_object( async def temporarily_detach_coroutine_object( abort_func: Callable[[RaiseCancelT], Abort], -) -> Any: +) -> object: """Temporarily detach the current coroutine object from the Trio scheduler. diff --git a/src/trio/_core/_windows_cffi.py b/src/trio/_core/_windows_cffi.py index 453b4beda3..575fcb5601 100644 --- a/src/trio/_core/_windows_cffi.py +++ b/src/trio/_core/_windows_cffi.py @@ -395,9 +395,9 @@ class _Overlapped(Protocol): hEvent: Handle -kernel32 = cast(_Kernel32, ffi.dlopen("kernel32.dll")) -ntdll = cast(_Nt, ffi.dlopen("ntdll.dll")) -ws2_32 = cast(_Ws2, ffi.dlopen("ws2_32.dll")) +kernel32 = cast("_Kernel32", ffi.dlopen("kernel32.dll")) +ntdll = cast("_Nt", ffi.dlopen("ntdll.dll")) +ws2_32 = cast("_Ws2", ffi.dlopen("ws2_32.dll")) ################################################################ # Magic numbers diff --git a/src/trio/_dtls.py b/src/trio/_dtls.py index b3a6c75a9e..fdce94b191 100644 --- a/src/trio/_dtls.py +++ b/src/trio/_dtls.py @@ -19,7 +19,6 @@ from itertools import count from typing import ( TYPE_CHECKING, - Any, Generic, TypeVar, Union, @@ -37,9 +36,10 @@ from types import TracebackType # See DTLSEndpoint.__init__ for why this is imported here - from OpenSSL import SSL # noqa: TCH004 + from OpenSSL import SSL # noqa: TC004 from typing_extensions import Self, TypeAlias, TypeVarTuple, Unpack + from trio._socket import AddressFormat from trio.socket import SocketType PosArgsT = TypeVarTuple("PosArgsT") @@ -566,7 +566,7 @@ def _make_cookie( key: bytes, salt: bytes, tick: int, - address: Any, + address: AddressFormat, client_hello_bits: bytes, ) -> bytes: assert len(salt) == SALT_BYTES @@ -587,7 +587,7 @@ def _make_cookie( def valid_cookie( key: bytes, cookie: bytes, - address: Any, + address: AddressFormat, client_hello_bits: bytes, ) -> bool: if len(cookie) > SALT_BYTES: @@ -615,7 +615,7 @@ def valid_cookie( def challenge_for( key: bytes, - address: Any, + address: AddressFormat, epoch_seqno: int, client_hello_bits: bytes, ) -> bytes: @@ -661,7 +661,7 @@ def challenge_for( class _Queue(Generic[_T]): - def __init__(self, incoming_packets_buffer: int | float): # noqa: PYI041 + def __init__(self, incoming_packets_buffer: int | float) -> None: # noqa: PYI041 self.s, self.r = trio.open_memory_channel[_T](incoming_packets_buffer) @@ -678,7 +678,7 @@ def _read_loop(read_fn: Callable[[int], bytes]) -> bytes: async def handle_client_hello_untrusted( endpoint: DTLSEndpoint, - address: Any, + address: AddressFormat, packet: bytes, ) -> None: # it's trivial to write a simple function that directly calls this to @@ -730,23 +730,6 @@ async def handle_client_hello_untrusted( # after all. return - # Some old versions of OpenSSL have a bug with memory BIOs, where DTLSv1_listen - # consumes the ClientHello out of the BIO, but then do_handshake expects the - # ClientHello to still be in there (but not the one that ships with Ubuntu - # 20.04). In particular, this is known to affect the OpenSSL v1.1.1 that ships - # with Ubuntu 18.04. To work around this, we deliver a second copy of the - # ClientHello after DTLSv1_listen has completed. This is safe to do - # unconditionally, because on newer versions of OpenSSL, the second ClientHello - # is treated as a duplicate packet, which is a normal thing that can happen over - # UDP. For more details, see: - # - # https://github.com/pyca/pyopenssl/blob/e84e7b57d1838de70ab7a27089fbee78ce0d2106/tests/test_ssl.py#L4226-L4293 - # - # This was fixed in v1.1.1a, and all later versions. So maybe in 2024 or so we - # can delete this. The fix landed in OpenSSL master as 079ef6bd534d2, and then - # was backported to the 1.1.1 branch as d1bfd8076e28. - stream._ssl.bio_write(packet) - # Check if we have an existing association old_stream = endpoint._streams.get(address) if old_stream is not None: @@ -852,7 +835,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): def __init__( self, endpoint: DTLSEndpoint, - peer_address: Any, + peer_address: AddressFormat, ctx: SSL.Context, ) -> None: self.endpoint = endpoint @@ -1227,7 +1210,9 @@ def __init__( # as a peer provides a valid cookie, we can immediately tear down the # old connection. # {remote address: DTLSChannel} - self._streams: WeakValueDictionary[Any, DTLSChannel] = WeakValueDictionary() + self._streams: WeakValueDictionary[AddressFormat, DTLSChannel] = ( + WeakValueDictionary() + ) self._listening_context: SSL.Context | None = None self._listening_key: bytes | None = None self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) diff --git a/src/trio/_file_io.py b/src/trio/_file_io.py index 45f2475b84..1cd8696e49 100644 --- a/src/trio/_file_io.py +++ b/src/trio/_file_io.py @@ -31,6 +31,8 @@ ) from typing_extensions import Literal + from ._sync import CapacityLimiter + # This list is also in the docs, make sure to keep them in sync _FILE_SYNC_ATTRS: set[str] = { "closed", @@ -241,7 +243,10 @@ def __getattr__(self, name: str) -> object: meth = getattr(self._wrapped, name) @async_wraps(self.__class__, self._wrapped.__class__, name) - async def wrapper(*args, **kwargs): + async def wrapper( + *args: Callable[..., T], + **kwargs: object | str | bool | CapacityLimiter | None, + ) -> T: func = partial(meth, *args, **kwargs) return await trio.to_thread.run_sync(func) @@ -443,7 +448,7 @@ async def open_file( newline: str | None = None, closefd: bool = True, opener: _Opener | None = None, -) -> AsyncIOWrapper[Any]: +) -> AsyncIOWrapper[object]: """Asynchronous version of :func:`open`. Returns: diff --git a/src/trio/_highlevel_open_tcp_stream.py b/src/trio/_highlevel_open_tcp_stream.py index 6cb19266a0..43125dac89 100644 --- a/src/trio/_highlevel_open_tcp_stream.py +++ b/src/trio/_highlevel_open_tcp_stream.py @@ -11,6 +11,8 @@ from collections.abc import Generator from socket import AddressFamily, SocketKind + from trio._socket import AddressFormat + if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup, ExceptionGroup @@ -132,16 +134,9 @@ def close_all() -> Generator[set[SocketType], None, None]: raise BaseExceptionGroup("", errs) -def reorder_for_rfc_6555_section_5_4( - targets: list[ - tuple[ - AddressFamily, - SocketKind, - int, - str, - Any, - ] - ], +# Explicit "Any" is not allowed +def reorder_for_rfc_6555_section_5_4( # type: ignore[misc] + targets: list[tuple[AddressFamily, SocketKind, int, str, Any]], ) -> None: # RFC 6555 section 5.4 says that if getaddrinfo returns multiple address # families (e.g. IPv4 and IPv6), then you should make sure that your first @@ -301,7 +296,7 @@ async def open_tcp_stream( # face of crash or cancellation async def attempt_connect( socket_args: tuple[AddressFamily, SocketKind, int], - sockaddr: Any, + sockaddr: AddressFormat, attempt_failed: trio.Event, ) -> None: nonlocal winning_socket diff --git a/src/trio/_highlevel_serve_listeners.py b/src/trio/_highlevel_serve_listeners.py index 0a85c8ecb0..9b17f8d538 100644 --- a/src/trio/_highlevel_serve_listeners.py +++ b/src/trio/_highlevel_serve_listeners.py @@ -25,7 +25,8 @@ StreamT = TypeVar("StreamT", bound=trio.abc.AsyncResource) -ListenerT = TypeVar("ListenerT", bound=trio.abc.Listener[Any]) +# Explicit "Any" is not allowed +ListenerT = TypeVar("ListenerT", bound=trio.abc.Listener[Any]) # type: ignore[misc] Handler = Callable[[StreamT], Awaitable[object]] @@ -67,7 +68,8 @@ async def _serve_one_listener( # https://github.com/python/typing/issues/548 -async def serve_listeners( +# Explicit "Any" is not allowed +async def serve_listeners( # type: ignore[misc] handler: Handler[StreamT], listeners: list[ListenerT], *, diff --git a/src/trio/_highlevel_socket.py b/src/trio/_highlevel_socket.py index 4e1992c0e9..0b099f9049 100644 --- a/src/trio/_highlevel_socket.py +++ b/src/trio/_highlevel_socket.py @@ -68,7 +68,7 @@ class SocketStream(HalfCloseableStream): """ - def __init__(self, socket: SocketType): + def __init__(self, socket: SocketType) -> None: if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketStream requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -363,7 +363,7 @@ class SocketListener(Listener[SocketStream]): """ - def __init__(self, socket: SocketType): + def __init__(self, socket: SocketType) -> None: if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketListener requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: diff --git a/src/trio/_path.py b/src/trio/_path.py index 2c9dfff292..a58136b75b 100644 --- a/src/trio/_path.py +++ b/src/trio/_path.py @@ -30,8 +30,9 @@ T = TypeVar("T") -def _wraps_async( - wrapped: Callable[..., Any], +# Explicit .../"Any" is not allowed +def _wraps_async( # type: ignore[misc] + wrapped: Callable[..., object], ) -> Callable[[Callable[P, T]], Callable[P, Awaitable[T]]]: def decorator(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]: async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: diff --git a/src/trio/_repl.py b/src/trio/_repl.py index 9da7e1ecde..754b0017bb 100644 --- a/src/trio/_repl.py +++ b/src/trio/_repl.py @@ -22,7 +22,7 @@ class TrioInteractiveConsole(InteractiveConsole): # we make the type more specific on our subclass locals: dict[str, object] - def __init__(self, repl_locals: dict[str, object] | None = None): + def __init__(self, repl_locals: dict[str, object] | None = None) -> None: super().__init__(locals=repl_locals) self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT diff --git a/src/trio/_signals.py b/src/trio/_signals.py index 63ee176c1b..729c48ad4e 100644 --- a/src/trio/_signals.py +++ b/src/trio/_signals.py @@ -7,7 +7,7 @@ import trio -from ._util import ConflictDetector, is_main_thread, signal_raise +from ._util import ConflictDetector, is_main_thread if TYPE_CHECKING: from collections.abc import AsyncIterator, Callable, Generator, Iterable @@ -78,7 +78,7 @@ def __init__(self) -> None: def _add(self, signum: int) -> None: if self._closed: - signal_raise(signum) + signal.raise_signal(signum) else: self._pending[signum] = None self._lot.unpark() @@ -95,7 +95,7 @@ def deliver_next() -> None: if self._pending: signum, _ = self._pending.popitem(last=False) try: - signal_raise(signum) + signal.raise_signal(signum) finally: deliver_next() diff --git a/src/trio/_socket.py b/src/trio/_socket.py index f7a5cfbd66..59adf62c77 100644 --- a/src/trio/_socket.py +++ b/src/trio/_socket.py @@ -12,7 +12,6 @@ from typing import ( TYPE_CHECKING, Any, - Literal, SupportsIndex, TypeVar, Union, @@ -50,7 +49,8 @@ # most users, so currently we just specify it as `Any`. Otherwise we would write: # `AddressFormat = TypeVar("AddressFormat")` # but instead we simply do: -AddressFormat: TypeAlias = Any +# Explicit "Any" is not allowed +AddressFormat: TypeAlias = Any # type: ignore[misc] # Usage: @@ -65,7 +65,7 @@ class _try_sync: def __init__( self, blocking_exc_override: Callable[[BaseException], bool] | None = None, - ): + ) -> None: self._blocking_exc_override = blocking_exc_override def _is_blocking_io_error(self, exc: BaseException) -> bool: @@ -332,7 +332,7 @@ def fromshare(info: bytes) -> SocketType: TypeT: TypeAlias = int FamilyDefault = _stdlib_socket.AF_INET else: - FamilyDefault: Literal[None] = None + FamilyDefault: None = None FamilyT: TypeAlias = Union[int, AddressFamily, None] TypeT: TypeAlias = Union[_stdlib_socket.socket, int] @@ -473,7 +473,7 @@ async def _resolve_address_nocp( ipv6_v6only: bool | int, address: AddressFormat, local: bool, -) -> Any: +) -> AddressFormat: # Do some pre-checking (or exit early for non-IP sockets) if family == _stdlib_socket.AF_INET: if not isinstance(address, tuple) or not len(address) == 2: @@ -710,7 +710,7 @@ def recvmsg( __bufsize: int, __ancbufsize: int = 0, __flags: int = 0, - ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, object]]: raise NotImplementedError if sys.platform != "win32" or ( @@ -722,7 +722,7 @@ def recvmsg_into( __buffers: Iterable[Buffer], __ancbufsize: int = 0, __flags: int = 0, - ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, object]]: raise NotImplementedError def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: @@ -743,7 +743,7 @@ async def sendto( __address: tuple[object, ...] | str | Buffer, ) -> int: ... - async def sendto(self, *args: Any) -> int: + async def sendto(self, *args: object) -> int: raise NotImplementedError if sys.platform != "win32" or ( @@ -777,7 +777,7 @@ async def sendmsg( class _SocketType(SocketType): - def __init__(self, sock: _stdlib_socket.socket): + def __init__(self, sock: _stdlib_socket.socket) -> None: if type(sock) is not _stdlib_socket.socket: # For example, ssl.SSLSocket subclasses socket.socket, but we # certainly don't want to blindly wrap one of those. @@ -1191,7 +1191,7 @@ def recvmsg( __bufsize: int, __ancbufsize: int = 0, __flags: int = 0, - ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: ... + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, object]]: ... recvmsg = _make_simple_sock_method_wrapper( _stdlib_socket.socket.recvmsg, @@ -1213,7 +1213,7 @@ def recvmsg_into( __buffers: Iterable[Buffer], __ancbufsize: int = 0, __flags: int = 0, - ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: ... + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, object]]: ... recvmsg_into = _make_simple_sock_method_wrapper( _stdlib_socket.socket.recvmsg_into, @@ -1253,8 +1253,8 @@ async def sendto( __address: tuple[object, ...] | str | Buffer, ) -> int: ... - @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc] - async def sendto(self, *args: Any) -> int: + @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) + async def sendto(self, *args: object) -> int: """Similar to :meth:`socket.socket.sendto`, but async.""" # args is: data[, flags], address # and kwargs are not accepted @@ -1279,10 +1279,10 @@ async def sendto(self, *args: Any) -> int: @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) async def sendmsg( self, - __buffers: Iterable[Buffer], - __ancdata: Iterable[tuple[int, int, Buffer]] = (), - __flags: int = 0, - __address: AddressFormat | None = None, + buffers: Iterable[Buffer], + ancdata: Iterable[tuple[int, int, Buffer]] = (), + flags: int = 0, + address: AddressFormat | None = None, ) -> int: """Similar to :meth:`socket.socket.sendmsg`, but async. @@ -1290,15 +1290,15 @@ async def sendmsg( available. """ - if __address is not None: - __address = await self._resolve_address_nocp(__address, local=False) + if address is not None: + address = await self._resolve_address_nocp(address, local=False) return await self._nonblocking_helper( _core.wait_writable, _stdlib_socket.socket.sendmsg, - __buffers, - __ancdata, - __flags, - __address, + buffers, + ancdata, + flags, + address, ) ################################################################ diff --git a/src/trio/_ssl.py b/src/trio/_ssl.py index 8dd9b2123d..9c259435b4 100644 --- a/src/trio/_ssl.py +++ b/src/trio/_ssl.py @@ -16,6 +16,10 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable + from typing_extensions import TypeVarTuple, Unpack + + Ts = TypeVarTuple("Ts") + # General theory of operation: # # We implement an API that closely mirrors the stdlib ssl module's blocking @@ -219,7 +223,13 @@ class NeedHandshakeError(Exception): class _Once: - def __init__(self, afn: Callable[..., Awaitable[object]], *args: object) -> None: + __slots__ = ("_afn", "_args", "_done", "started") + + def __init__( + self, + afn: Callable[[*Ts], Awaitable[object]], + *args: Unpack[Ts], + ) -> None: self._afn = afn self._args = args self.started = False @@ -413,7 +423,11 @@ def __init__( "version", } - def __getattr__(self, name: str) -> Any: + # Explicit "Any" is not allowed + def __getattr__( # type: ignore[misc] + self, + name: str, + ) -> Any: if name in self._forwarded: if name in self._after_handshake and not self._handshook.done: raise NeedHandshakeError(f"call do_handshake() before calling {name!r}") @@ -445,8 +459,8 @@ def _check_status(self) -> None: # too. async def _retry( self, - fn: Callable[..., T], - *args: object, + fn: Callable[[*Ts], T], + *args: Unpack[Ts], ignore_want_read: bool = False, is_handshake: bool = False, ) -> T | None: diff --git a/src/trio/_subprocess_platform/__init__.py b/src/trio/_subprocess_platform/__init__.py index d74cd462a0..daa28d8cd2 100644 --- a/src/trio/_subprocess_platform/__init__.py +++ b/src/trio/_subprocess_platform/__init__.py @@ -8,7 +8,7 @@ import trio from .. import _core, _subprocess -from .._abc import ReceiveStream, SendStream # noqa: TCH001 +from .._abc import ReceiveStream, SendStream # noqa: TC001 _wait_child_exiting_error: ImportError | None = None _create_child_pipe_error: ImportError | None = None @@ -85,11 +85,11 @@ def create_pipe_from_child_output() -> tuple[ClosableReceiveStream, int]: elif os.name == "posix": - def create_pipe_to_child_stdin(): + def create_pipe_to_child_stdin() -> tuple[trio.lowlevel.FdStream, int]: rfd, wfd = os.pipe() return trio.lowlevel.FdStream(wfd), rfd - def create_pipe_from_child_output(): + def create_pipe_from_child_output() -> tuple[trio.lowlevel.FdStream, int]: rfd, wfd = os.pipe() return trio.lowlevel.FdStream(rfd), wfd @@ -106,12 +106,12 @@ def create_pipe_from_child_output(): from .._windows_pipes import PipeReceiveStream, PipeSendStream - def create_pipe_to_child_stdin(): + def create_pipe_to_child_stdin() -> tuple[PipeSendStream, int]: # for stdin, we want the write end (our end) to use overlapped I/O rh, wh = windows_pipe(overlapped=(False, True)) return PipeSendStream(wh), msvcrt.open_osfhandle(rh, os.O_RDONLY) - def create_pipe_from_child_output(): + def create_pipe_from_child_output() -> tuple[PipeReceiveStream, int]: # for stdout/err, it's the read end that's overlapped rh, wh = windows_pipe(overlapped=(True, False)) return PipeReceiveStream(rh), msvcrt.open_osfhandle(wh, 0) diff --git a/src/trio/_sync.py b/src/trio/_sync.py index b3f9e2b4d3..53808c261d 100644 --- a/src/trio/_sync.py +++ b/src/trio/_sync.py @@ -220,7 +220,7 @@ class CapacityLimiter(AsyncContextManagerMixin): """ # total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing - def __init__(self, total_tokens: int | float): # noqa: PYI041 + def __init__(self, total_tokens: int | float) -> None: # noqa: PYI041 self._lot = ParkingLot() self._borrowers: set[Task | object] = set() # Maps tasks attempting to acquire -> borrower, to handle on-behalf-of @@ -433,7 +433,7 @@ class Semaphore(AsyncContextManagerMixin): """ - def __init__(self, initial_value: int, *, max_value: int | None = None): + def __init__(self, initial_value: int, *, max_value: int | None = None) -> None: if not isinstance(initial_value, int): raise TypeError("initial_value must be an int") if initial_value < 0: @@ -759,7 +759,7 @@ class Condition(AsyncContextManagerMixin): """ - def __init__(self, lock: Lock | None = None): + def __init__(self, lock: Lock | None = None) -> None: if lock is None: lock = Lock() if type(lock) is not Lock: diff --git a/src/trio/_tests/_check_type_completeness.json b/src/trio/_tests/_check_type_completeness.json index badb7cba17..72d981f89c 100644 --- a/src/trio/_tests/_check_type_completeness.json +++ b/src/trio/_tests/_check_type_completeness.json @@ -40,7 +40,6 @@ "No docstring found for class \"trio._core._local.RunVarToken\"", "No docstring found for class \"trio.lowlevel.RunVarToken\"", "No docstring found for class \"trio.lowlevel.Task\"", - "No docstring found for class \"trio._core._ki.KIProtectionSignature\"", "No docstring found for class \"trio.socket.SocketType\"", "No docstring found for class \"trio.socket.gaierror\"", "No docstring found for class \"trio.socket.herror\"", diff --git a/src/trio/_tests/test_deprecate_strict_exception_groups_false.py b/src/trio/_tests/test_deprecate_strict_exception_groups_false.py index 7e575aa92e..1b02c9ee73 100644 --- a/src/trio/_tests/test_deprecate_strict_exception_groups_false.py +++ b/src/trio/_tests/test_deprecate_strict_exception_groups_false.py @@ -32,7 +32,7 @@ async def foo_loose_nursery() -> None: async with trio.open_nursery(strict_exception_groups=False): ... - def helper(fun: Callable[..., Awaitable[None]], num: int) -> None: + def helper(fun: Callable[[], Awaitable[None]], num: int) -> None: with pytest.warns( trio.TrioDeprecationWarning, match="strict_exception_groups=False", diff --git a/src/trio/_tests/test_exports.py b/src/trio/_tests/test_exports.py index d824814022..7d8a7e3c0b 100644 --- a/src/trio/_tests/test_exports.py +++ b/src/trio/_tests/test_exports.py @@ -19,11 +19,10 @@ import trio import trio.testing -from trio._tests.pytest_plugin import skip_if_optional_else_raise +from trio._tests.pytest_plugin import RUN_SLOW, skip_if_optional_else_raise from .. import _core, _util from .._core._tests.tutil import slow -from .pytest_plugin import RUN_SLOW if TYPE_CHECKING: from collections.abc import Iterable, Iterator @@ -390,11 +389,13 @@ def lookup_symbol(symbol: str) -> dict[str, str]: assert "node" in cached_type_info node = cached_type_info["node"] - static_names = no_hidden(k for k in node["names"] if not k.startswith(".")) + static_names = no_hidden( + k for k in node.get("names", ()) if not k.startswith(".") + ) for symbol in node["mro"][1:]: node = lookup_symbol(symbol)["node"] static_names |= no_hidden( - k for k in node["names"] if not k.startswith(".") + k for k in node.get("names", ()) if not k.startswith(".") ) static_names -= ignore_names @@ -572,3 +573,37 @@ def test_classes_are_final() -> None: continue assert class_is_final(class_) + + +# Plugin might not be running, especially if running from an installed version. +@pytest.mark.skipif( + not hasattr(attrs.field, "trio_modded"), + reason="Pytest plugin not installed.", +) +def test_pyright_recognizes_init_attributes() -> None: + """Check whether we provide `alias` for all underscore prefixed attributes. + + Attrs always sets the `alias` attribute on fields, so a pytest plugin is used + to monkeypatch `field()` to record whether an alias was defined in the metadata. + See `_trio_check_attrs_aliases`. + """ + for module in PUBLIC_MODULES: + for class_ in module.__dict__.values(): + if not attrs.has(class_): + continue + if isinstance(class_, _util.NoPublicConstructor): + continue + + attributes = [ + attr + for attr in attrs.fields(class_) + if attr.init + if attr.alias + not in ( + attr.name, + # trio_original_args may not be present in autoattribs + attr.metadata.get("trio_original_args", {}).get("alias"), + ) + ] + + assert attributes == [], class_ diff --git a/src/trio/_tests/test_highlevel_open_tcp_listeners.py b/src/trio/_tests/test_highlevel_open_tcp_listeners.py index 30596aa4fb..e78e4414d2 100644 --- a/src/trio/_tests/test_highlevel_open_tcp_listeners.py +++ b/src/trio/_tests/test_highlevel_open_tcp_listeners.py @@ -4,7 +4,7 @@ import socket as stdlib_socket import sys from socket import AddressFamily, SocketKind -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, cast, overload import attrs import pytest @@ -30,6 +30,8 @@ from typing_extensions import Buffer + from trio._socket import AddressFormat + async def test_open_tcp_listeners_basic() -> None: listeners = await open_tcp_listeners(0) @@ -195,7 +197,7 @@ def setsockopt( ) -> None: pass - async def bind(self, address: Any) -> None: + async def bind(self, address: AddressFormat) -> None: pass def listen(self, /, backlog: int = min(stdlib_socket.SOMAXCONN, 128)) -> None: @@ -310,7 +312,9 @@ async def handler(stream: SendStream) -> None: async with trio.open_nursery() as nursery: # nursery.start is incorrectly typed, awaiting #2773 - listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0) + value = await nursery.start(serve_tcp, handler, 0) + assert isinstance(value, list) + listeners = cast("list[SocketListener]", value) stream = await open_stream_to_socket_listener(listeners[0]) async with stream: assert await stream.receive_some(1) == b"x" diff --git a/src/trio/_tests/test_highlevel_open_tcp_stream.py b/src/trio/_tests/test_highlevel_open_tcp_stream.py index 0032a551dc..98adf7efea 100644 --- a/src/trio/_tests/test_highlevel_open_tcp_stream.py +++ b/src/trio/_tests/test_highlevel_open_tcp_stream.py @@ -3,7 +3,7 @@ import socket import sys from socket import AddressFamily, SocketKind -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import attrs import pytest @@ -360,7 +360,8 @@ async def run_scenario( # If this is True, we require there to be an exception, and return # (exception, scenario object) expect_error: tuple[type[BaseException], ...] | type[BaseException] = (), - **kwargs: Any, + happy_eyeballs_delay: float | None = 0.25, + local_address: str | None = None, ) -> tuple[SocketType, Scenario] | tuple[BaseException, Scenario]: supported_families = set() if ipv4_supported: @@ -372,7 +373,12 @@ async def run_scenario( trio.socket.set_custom_socket_factory(scenario) try: - stream = await open_tcp_stream("test.example.com", port, **kwargs) + stream = await open_tcp_stream( + "test.example.com", + port, + happy_eyeballs_delay=happy_eyeballs_delay, + local_address=local_address, + ) assert expect_error == () scenario.check(stream.socket) return (stream.socket, scenario) diff --git a/src/trio/_tests/test_highlevel_serve_listeners.py b/src/trio/_tests/test_highlevel_serve_listeners.py index 0ce82e7846..9268555b32 100644 --- a/src/trio/_tests/test_highlevel_serve_listeners.py +++ b/src/trio/_tests/test_highlevel_serve_listeners.py @@ -2,7 +2,7 @@ import errno from functools import partial -from typing import TYPE_CHECKING, NoReturn +from typing import TYPE_CHECKING, NoReturn, cast import attrs @@ -96,11 +96,13 @@ async def do_tests(parent_nursery: Nursery) -> None: parent_nursery.cancel_scope.cancel() async with trio.open_nursery() as nursery: - l2: list[MemoryListener] = await nursery.start( + value = await nursery.start( trio.serve_listeners, handler, listeners, ) + assert isinstance(value, list) + l2 = cast("list[MemoryListener]", value) assert l2 == listeners # This is just split into another function because gh-136 isn't # implemented yet @@ -172,7 +174,9 @@ async def connection_watcher( # the exception is wrapped twice because we open two nested nurseries with RaisesGroup(RaisesGroup(Done)): async with trio.open_nursery() as nursery: - handler_nursery: trio.Nursery = await nursery.start(connection_watcher) + value = await nursery.start(connection_watcher) + assert isinstance(value, trio.Nursery) + handler_nursery: trio.Nursery = value await nursery.start( partial( trio.serve_listeners, diff --git a/src/trio/_tests/test_highlevel_ssl_helpers.py b/src/trio/_tests/test_highlevel_ssl_helpers.py index ca23c333c7..e42f311981 100644 --- a/src/trio/_tests/test_highlevel_ssl_helpers.py +++ b/src/trio/_tests/test_highlevel_ssl_helpers.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import partial -from typing import TYPE_CHECKING, Any, NoReturn +from typing import TYPE_CHECKING, NoReturn, cast import attrs import pytest @@ -66,7 +66,11 @@ async def getaddrinfo( ]: return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)] - async def getnameinfo(self, *args: Any) -> NoReturn: # pragma: no cover + async def getnameinfo( + self, + sockaddr: tuple[str, int] | tuple[str, int, int, int], + flags: int, + ) -> NoReturn: # pragma: no cover raise NotImplementedError @@ -79,17 +83,17 @@ async def test_open_ssl_over_tcp_stream_and_everything_else( # TODO: this function wraps an SSLListener around a SocketListener, this is illegal # according to current type hints, and probably for good reason. But there should # maybe be a different wrapper class/function that could be used instead? - res: list[SSLListener[SocketListener]] = ( # type: ignore[type-var] - await nursery.start( - partial( - serve_ssl_over_tcp, - echo_handler, - 0, - SERVER_CTX, - host="127.0.0.1", - ), - ) + value = await nursery.start( + partial( + serve_ssl_over_tcp, + echo_handler, + 0, + SERVER_CTX, + host="127.0.0.1", + ), ) + assert isinstance(value, list) + res = cast("list[SSLListener[SocketListener]]", value) # type: ignore[type-var] (listener,) = res async with listener: # listener.transport_listener is of type Listener[Stream] diff --git a/src/trio/_tests/test_repl.py b/src/trio/_tests/test_repl.py index eba166d7d5..be9338ce4c 100644 --- a/src/trio/_tests/test_repl.py +++ b/src/trio/_tests/test_repl.py @@ -103,12 +103,11 @@ async def test_KI_interrupts( console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals()) raw_input = build_raw_input( [ - "from trio._util import signal_raise", "import signal, trio, trio.lowlevel", "async def f():", " trio.lowlevel.spawn_system_task(" " trio.to_thread.run_sync," - " signal_raise,signal.SIGINT," + " signal.raise_signal, signal.SIGINT," " )", # just awaiting this kills the test runner?! " await trio.sleep_forever()", " print('should not see this')", diff --git a/src/trio/_tests/test_signals.py b/src/trio/_tests/test_signals.py index 453b1b68f2..d149b86575 100644 --- a/src/trio/_tests/test_signals.py +++ b/src/trio/_tests/test_signals.py @@ -10,7 +10,6 @@ from .. import _core from .._signals import _signal_handler, get_pending_signal_count, open_signal_receiver -from .._util import signal_raise if TYPE_CHECKING: from types import FrameType @@ -21,16 +20,16 @@ async def test_open_signal_receiver() -> None: with open_signal_receiver(signal.SIGILL) as receiver: # Raise it a few times, to exercise signal coalescing, both at the # call_soon level and at the SignalQueue level - signal_raise(signal.SIGILL) - signal_raise(signal.SIGILL) + signal.raise_signal(signal.SIGILL) + signal.raise_signal(signal.SIGILL) await _core.wait_all_tasks_blocked() - signal_raise(signal.SIGILL) + signal.raise_signal(signal.SIGILL) await _core.wait_all_tasks_blocked() async for signum in receiver: # pragma: no branch assert signum == signal.SIGILL break assert get_pending_signal_count(receiver) == 0 - signal_raise(signal.SIGILL) + signal.raise_signal(signal.SIGILL) async for signum in receiver: # pragma: no branch assert signum == signal.SIGILL break @@ -101,8 +100,8 @@ async def test_open_signal_receiver_no_starvation() -> None: print(signal.getsignal(signal.SIGILL)) previous = None for _ in range(10): - signal_raise(signal.SIGILL) - signal_raise(signal.SIGFPE) + signal.raise_signal(signal.SIGILL) + signal.raise_signal(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() if previous is None: previous = await receiver.__anext__() @@ -134,8 +133,8 @@ def direct_handler(signo: int, frame: FrameType | None) -> None: # before we exit the with block: with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler): with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: - signal_raise(signal.SIGILL) - signal_raise(signal.SIGFPE) + signal.raise_signal(signal.SIGILL) + signal.raise_signal(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() assert delivered_directly == {signal.SIGILL, signal.SIGFPE} delivered_directly.clear() @@ -145,8 +144,8 @@ def direct_handler(signo: int, frame: FrameType | None) -> None: # we exit the with block: with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler): with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: - signal_raise(signal.SIGILL) - signal_raise(signal.SIGFPE) + signal.raise_signal(signal.SIGILL) + signal.raise_signal(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() assert get_pending_signal_count(receiver) == 2 assert delivered_directly == {signal.SIGILL, signal.SIGFPE} @@ -157,14 +156,14 @@ def direct_handler(signo: int, frame: FrameType | None) -> None: print(3) with _signal_handler({signal.SIGILL}, signal.SIG_IGN): with open_signal_receiver(signal.SIGILL) as receiver: - signal_raise(signal.SIGILL) + signal.raise_signal(signal.SIGILL) await wait_run_sync_soon_idempotent_queue_barrier() # test passes if the process reaches this point without dying print(4) with _signal_handler({signal.SIGILL}, signal.SIG_IGN): with open_signal_receiver(signal.SIGILL) as receiver: - signal_raise(signal.SIGILL) + signal.raise_signal(signal.SIGILL) await wait_run_sync_soon_idempotent_queue_barrier() assert get_pending_signal_count(receiver) == 1 # test passes if the process reaches this point without dying @@ -177,8 +176,8 @@ def raise_handler(signum: int, frame: FrameType | None) -> NoReturn: with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler): with pytest.raises(RuntimeError) as excinfo: with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: - signal_raise(signal.SIGILL) - signal_raise(signal.SIGFPE) + signal.raise_signal(signal.SIGILL) + signal.raise_signal(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() assert get_pending_signal_count(receiver) == 2 exc = excinfo.value diff --git a/src/trio/_tests/test_socket.py b/src/trio/_tests/test_socket.py index 8d70f232a6..9b56f92149 100644 --- a/src/trio/_tests/test_socket.py +++ b/src/trio/_tests/test_socket.py @@ -8,17 +8,19 @@ import tempfile from pathlib import Path from socket import AddressFamily, SocketKind -from typing import TYPE_CHECKING, Any, Callable, Union +from typing import TYPE_CHECKING, Union, cast import attrs import pytest from .. import _core, socket as tsocket -from .._core._tests.tutil import binds_ipv6, creates_ipv6 -from .._socket import _NUMERIC_ONLY, SocketType, _SocketType, _try_sync +from .._core._tests.tutil import binds_ipv6, can_create_ipv6, creates_ipv6 +from .._socket import _NUMERIC_ONLY, AddressFormat, SocketType, _SocketType, _try_sync from ..testing import assert_checkpoints, wait_all_tasks_blocked if TYPE_CHECKING: + from collections.abc import Callable + from typing_extensions import TypeAlias from .._highlevel_socket import SocketStream @@ -31,9 +33,18 @@ Union[tuple[str, int], tuple[str, int, int, int]], ] GetAddrInfoResponse: TypeAlias = list[GaiTuple] + GetAddrInfoArgs: TypeAlias = tuple[ + Union[str, bytes, None], + Union[str, bytes, int, None], + int, + int, + int, + int, + ] else: GaiTuple: object GetAddrInfoResponse = object + GetAddrInfoArgs = object ################################################################ # utils @@ -41,15 +52,34 @@ class MonkeypatchedGAI: - def __init__(self, orig_getaddrinfo: Callable[..., GetAddrInfoResponse]) -> None: + __slots__ = ("_orig_getaddrinfo", "_responses", "record") + + def __init__( + self, + orig_getaddrinfo: Callable[ + [str | bytes | None, str | bytes | int | None, int, int, int, int], + GetAddrInfoResponse, + ], + ) -> None: self._orig_getaddrinfo = orig_getaddrinfo - self._responses: dict[tuple[Any, ...], GetAddrInfoResponse | str] = {} - self.record: list[tuple[Any, ...]] = [] + self._responses: dict[ + GetAddrInfoArgs, + GetAddrInfoResponse | str, + ] = {} + self.record: list[GetAddrInfoArgs] = [] # get a normalized getaddrinfo argument tuple - def _frozenbind(self, *args: Any, **kwargs: Any) -> tuple[Any, ...]: + def _frozenbind( + self, + host: str | bytes | None, + port: str | bytes | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> GetAddrInfoArgs: sig = inspect.signature(self._orig_getaddrinfo) - bound = sig.bind(*args, **kwargs) + bound = sig.bind(host, port, family=family, type=type, proto=proto, flags=flags) bound.apply_defaults() frozenbound = bound.args assert not bound.kwargs @@ -58,18 +88,39 @@ def _frozenbind(self, *args: Any, **kwargs: Any) -> tuple[Any, ...]: def set( self, response: GetAddrInfoResponse | str, - *args: Any, - **kwargs: Any, + host: str | bytes | None, + port: str | bytes | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, ) -> None: - self._responses[self._frozenbind(*args, **kwargs)] = response - - def getaddrinfo(self, *args: Any, **kwargs: Any) -> GetAddrInfoResponse | str: - bound = self._frozenbind(*args, **kwargs) + self._responses[ + self._frozenbind( + host, + port, + family=family, + type=type, + proto=proto, + flags=flags, + ) + ] = response + + def getaddrinfo( + self, + host: str | bytes | None, + port: str | bytes | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> GetAddrInfoResponse | str: + bound = self._frozenbind(host, port, family, type, proto, flags) self.record.append(bound) if bound in self._responses: return self._responses[bound] - if bound[-1] & stdlib_socket.AI_NUMERICHOST: - return self._orig_getaddrinfo(*args, **kwargs) + if flags & stdlib_socket.AI_NUMERICHOST: + return self._orig_getaddrinfo(host, port, family, type, proto, flags) raise RuntimeError(f"gai called with unexpected arguments {bound}") @@ -324,9 +375,12 @@ async def test_sniff_sockopts() -> None: from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM # generate the combinations of families/types we're testing: + families = [AF_INET] + if can_create_ipv6: + families.append(AF_INET6) sockets = [ stdlib_socket.socket(family, type_) - for family in [AF_INET, AF_INET6] + for family in families for type_ in [SOCK_DGRAM, SOCK_STREAM] ] for socket in sockets: @@ -594,11 +648,13 @@ async def res( | tuple[str, str, int] | tuple[str, str, int, int] ), - ) -> Any: - return await sock._resolve_address_nocp( + ) -> tuple[str | int, ...]: + value = await sock._resolve_address_nocp( args, local=local, # noqa: B023 # local is not bound in function definition ) + assert isinstance(value, tuple) + return cast("tuple[Union[str, int], ...]", value) assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80)) if v6: @@ -793,7 +849,10 @@ async def test_SocketType_connect_paths() -> None: # nose -- and then swap it back out again before we hit # wait_socket_writable, which insists on a real socket. class CancelSocket(stdlib_socket.socket): - def connect(self, *args: Any, **kwargs: Any) -> None: + def connect( + self, + address: AddressFormat, + ) -> None: # accessing private method only available in _SocketType assert isinstance(sock, _SocketType) @@ -803,7 +862,7 @@ def connect(self, *args: Any, **kwargs: Any) -> None: self.family, self.type, ) - sock._sock.connect(*args, **kwargs) + sock._sock.connect(address) # If connect *doesn't* raise, then pretend it did raise BlockingIOError # pragma: no cover @@ -850,15 +909,17 @@ async def test_resolve_address_exception_in_connect_closes_socket() -> None: with tsocket.socket() as sock: async def _resolve_address_nocp( - self: Any, - *args: Any, - **kwargs: Any, + address: AddressFormat, + *, + local: bool, ) -> None: + assert address == "" + assert not local cancel_scope.cancel() await _core.checkpoint() assert isinstance(sock, _SocketType) - sock._resolve_address_nocp = _resolve_address_nocp # type: ignore[method-assign, assignment] + sock._resolve_address_nocp = _resolve_address_nocp # type: ignore[method-assign] with assert_checkpoints(): with pytest.raises(_core.Cancelled): await sock.connect("") diff --git a/src/trio/_tests/test_ssl.py b/src/trio/_tests/test_ssl.py index 1d264b5168..9b872a2f49 100644 --- a/src/trio/_tests/test_ssl.py +++ b/src/trio/_tests/test_ssl.py @@ -167,8 +167,8 @@ def ssl_echo_serve_sync( # Fixture that gives a raw socket connected to a trio-test-1 echo server # (running in a thread). Useful for testing making connections with different # SSLContexts. -@asynccontextmanager # type: ignore[misc] # decorated contains Any -async def ssl_echo_server_raw(**kwargs: Any) -> AsyncIterator[SocketStream]: +@asynccontextmanager +async def ssl_echo_server_raw(expect_fail: bool = False) -> AsyncIterator[SocketStream]: a, b = stdlib_socket.socketpair() async with trio.open_nursery() as nursery: # Exiting the 'with a, b' context manager closes the sockets, which @@ -177,7 +177,7 @@ async def ssl_echo_server_raw(**kwargs: Any) -> AsyncIterator[SocketStream]: with a, b: nursery.start_soon( trio.to_thread.run_sync, - partial(ssl_echo_serve_sync, b, **kwargs), + partial(ssl_echo_serve_sync, b, expect_fail=expect_fail), ) yield SocketStream(tsocket.from_stdlib_socket(a)) @@ -185,12 +185,12 @@ async def ssl_echo_server_raw(**kwargs: Any) -> AsyncIterator[SocketStream]: # Fixture that gives a properly set up SSLStream connected to a trio-test-1 # echo server (running in a thread) -@asynccontextmanager # type: ignore[misc] # decorated contains Any +@asynccontextmanager async def ssl_echo_server( client_ctx: SSLContext, - **kwargs: Any, + expect_fail: bool = False, ) -> AsyncIterator[SSLStream[Stream]]: - async with ssl_echo_server_raw(**kwargs) as sock: + async with ssl_echo_server_raw(expect_fail=expect_fail) as sock: yield SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org") @@ -200,7 +200,10 @@ async def ssl_echo_server( # jakkdl: it seems to implement all the abstract methods (now), so I made it inherit # from Stream for the sake of typechecking. class PyOpenSSLEchoStream(Stream): - def __init__(self, sleeper: None = None) -> None: + def __init__( + self, + sleeper: Callable[[str], Awaitable[None]] | None = None, + ) -> None: ctx = SSL.Context(SSL.SSLv23_METHOD) # TLS 1.3 removes renegotiation support. Which is great for them, but # we still have to support versions before that, and that means we @@ -248,6 +251,7 @@ def __init__(self, sleeper: None = None) -> None: "simultaneous calls to PyOpenSSLEchoStream.receive_some", ) + self.sleeper: Callable[[str], Awaitable[None]] if sleeper is None: async def no_op_sleeper(_: object) -> None: @@ -383,16 +387,17 @@ async def do_test( await do_test("receive_some", (1,), "receive_some", (1,)) -@contextmanager # type: ignore[misc] # decorated contains Any +@contextmanager def virtual_ssl_echo_server( client_ctx: SSLContext, - **kwargs: Any, + sleeper: Callable[[str], Awaitable[None]] | None = None, ) -> Iterator[SSLStream[PyOpenSSLEchoStream]]: - fakesock = PyOpenSSLEchoStream(**kwargs) + fakesock = PyOpenSSLEchoStream(sleeper=sleeper) yield SSLStream(fakesock, client_ctx, server_hostname="trio-test-1.example.org") -def ssl_wrap_pair( +# Explicit "Any" is not allowed +def ssl_wrap_pair( # type: ignore[misc] client_ctx: SSLContext, client_transport: T_Stream, server_transport: T_Stream, @@ -422,23 +427,43 @@ def ssl_wrap_pair( MemoryStapledStream: TypeAlias = StapledStream[MemorySendStream, MemoryReceiveStream] -def ssl_memory_stream_pair(client_ctx: SSLContext, **kwargs: Any) -> tuple[ +def ssl_memory_stream_pair( + client_ctx: SSLContext, + client_kwargs: dict[str, str | bytes | bool | None] | None = None, + server_kwargs: dict[str, str | bytes | bool | None] | None = None, +) -> tuple[ SSLStream[MemoryStapledStream], SSLStream[MemoryStapledStream], ]: client_transport, server_transport = memory_stream_pair() - return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs) + return ssl_wrap_pair( + client_ctx, + client_transport, + server_transport, + client_kwargs=client_kwargs, + server_kwargs=server_kwargs, + ) MyStapledStream: TypeAlias = StapledStream[SendStream, ReceiveStream] -def ssl_lockstep_stream_pair(client_ctx: SSLContext, **kwargs: Any) -> tuple[ +def ssl_lockstep_stream_pair( + client_ctx: SSLContext, + client_kwargs: dict[str, str | bytes | bool | None] | None = None, + server_kwargs: dict[str, str | bytes | bool | None] | None = None, +) -> tuple[ SSLStream[MyStapledStream], SSLStream[MyStapledStream], ]: client_transport, server_transport = lockstep_stream_pair() - return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs) + return ssl_wrap_pair( + client_ctx, + client_transport, + server_transport, + client_kwargs=client_kwargs, + server_kwargs=server_kwargs, + ) # Simple smoke test for handshake/send/receive/shutdown talking to a @@ -1318,13 +1343,17 @@ async def test_getpeercert(client_ctx: SSLContext) -> None: async def test_SSLListener(client_ctx: SSLContext) -> None: async def setup( - **kwargs: Any, + https_compatible: bool = False, ) -> tuple[tsocket.SocketType, SSLListener[SocketStream], SSLStream[SocketStream]]: listen_sock = tsocket.socket() await listen_sock.bind(("127.0.0.1", 0)) listen_sock.listen(1) socket_listener = SocketListener(listen_sock) - ssl_listener = SSLListener(socket_listener, SERVER_CTX, **kwargs) + ssl_listener = SSLListener( + socket_listener, + SERVER_CTX, + https_compatible=https_compatible, + ) transport_client = await open_tcp_stream(*listen_sock.getsockname()) ssl_client = SSLStream( diff --git a/src/trio/_tests/test_subprocess.py b/src/trio/_tests/test_subprocess.py index 32c793f589..d7ed48e889 100644 --- a/src/trio/_tests/test_subprocess.py +++ b/src/trio/_tests/test_subprocess.py @@ -87,8 +87,11 @@ def got_signal(proc: Process, sig: SignalType) -> bool: return proc.returncode != 0 -@asynccontextmanager # type: ignore[misc] # Any in decorator -async def open_process_then_kill(*args: Any, **kwargs: Any) -> AsyncIterator[Process]: +@asynccontextmanager # type: ignore[misc] # Any in decorated +async def open_process_then_kill( + *args: Any, + **kwargs: Any, +) -> AsyncIterator[Process]: proc = await open_process(*args, **kwargs) try: yield proc @@ -97,11 +100,16 @@ async def open_process_then_kill(*args: Any, **kwargs: Any) -> AsyncIterator[Pro await proc.wait() -@asynccontextmanager # type: ignore[misc] # Any in decorator -async def run_process_in_nursery(*args: Any, **kwargs: Any) -> AsyncIterator[Process]: +@asynccontextmanager # type: ignore[misc] # Any in decorated +async def run_process_in_nursery( + *args: Any, + **kwargs: Any, +) -> AsyncIterator[Process]: async with _core.open_nursery() as nursery: kwargs.setdefault("check", False) - proc: Process = await nursery.start(partial(run_process, *args, **kwargs)) + value = await nursery.start(partial(run_process, *args, **kwargs)) + assert isinstance(value, Process) + proc: Process = value yield proc nursery.cancel_scope.cancel() @@ -112,7 +120,11 @@ async def run_process_in_nursery(*args: Any, **kwargs: Any) -> AsyncIterator[Pro ids=["open_process", "run_process in nursery"], ) -BackgroundProcessType: TypeAlias = Callable[..., AbstractAsyncContextManager[Process]] +# Explicit .../"Any" is not allowed +BackgroundProcessType: TypeAlias = Callable[ # type: ignore[misc] + ..., + AbstractAsyncContextManager[Process], +] @background_process_param @@ -629,7 +641,9 @@ async def test_warn_on_cancel_SIGKILL_escalation( async def test_run_process_background_fail() -> None: with RaisesGroup(subprocess.CalledProcessError): async with _core.open_nursery() as nursery: - proc: Process = await nursery.start(run_process, EXIT_FALSE) + value = await nursery.start(run_process, EXIT_FALSE) + assert isinstance(value, Process) + proc: Process = value assert proc.returncode == 1 diff --git a/src/trio/_tests/test_testing_raisesgroup.py b/src/trio/_tests/test_testing_raisesgroup.py index 17eb6afcc7..bb86d88646 100644 --- a/src/trio/_tests/test_testing_raisesgroup.py +++ b/src/trio/_tests/test_testing_raisesgroup.py @@ -3,7 +3,6 @@ import re import sys from types import TracebackType -from typing import Any import pytest @@ -25,7 +24,7 @@ def test_raises_group() -> None: f'Invalid argument "{TypeError()!r}" must be exception type, Matcher, or RaisesGroup.', ), ): - RaisesGroup(TypeError()) + RaisesGroup(TypeError()) # type: ignore[call-overload] with RaisesGroup(ValueError): raise ExceptionGroup("foo", (ValueError(),)) @@ -235,7 +234,10 @@ def test_RaisesGroup_matches() -> None: def test_message() -> None: - def check_message(message: str, body: RaisesGroup[Any]) -> None: + def check_message( + message: str, + body: RaisesGroup[BaseException], + ) -> None: with pytest.raises( AssertionError, match=f"^DID NOT RAISE any exception, expected {re.escape(message)}$", @@ -351,9 +353,9 @@ def check_errno_is_5(e: OSError) -> bool: def test_matcher_tostring() -> None: assert str(Matcher(ValueError)) == "Matcher(ValueError)" assert str(Matcher(match="[a-z]")) == "Matcher(match='[a-z]')" - pattern_no_flags = re.compile("noflag", 0) + pattern_no_flags = re.compile(r"noflag", 0) assert str(Matcher(match=pattern_no_flags)) == "Matcher(match='noflag')" - pattern_flags = re.compile("noflag", re.IGNORECASE) + pattern_flags = re.compile(r"noflag", re.IGNORECASE) assert str(Matcher(match=pattern_flags)) == f"Matcher(match={pattern_flags!r})" assert ( str(Matcher(ValueError, match="re", check=bool)) diff --git a/src/trio/_tests/test_threads.py b/src/trio/_tests/test_threads.py index f9d1c5e67d..ac770f451e 100644 --- a/src/trio/_tests/test_threads.py +++ b/src/trio/_tests/test_threads.py @@ -55,7 +55,8 @@ async def test_do_in_trio_thread() -> None: trio_thread = threading.current_thread() - async def check_case( + # Explicit "Any" is not allowed + async def check_case( # type: ignore[misc] do_in_trio_thread: Callable[..., threading.Thread], fn: Callable[..., T | Awaitable[T]], expected: tuple[str, T], @@ -219,7 +220,7 @@ def f(name: str) -> Callable[[None], threading.Thread]: # test that you can set a custom name, and that it's reset afterwards async def test_thread_name(name: str) -> None: thread = await to_thread_run_sync(f(name), thread_name=name) - assert re.match("Trio thread [0-9]*", thread.name) + assert re.match(r"Trio thread [0-9]*", thread.name) await test_thread_name("") await test_thread_name("fobiedoo") @@ -300,7 +301,7 @@ async def test_thread_name(name: str, expected: str | None = None) -> None: os_thread_name = _get_thread_name(thread.ident) assert os_thread_name is not None, "should skip earlier if this is the case" - assert re.match("Trio thread [0-9]*", os_thread_name) + assert re.match(r"Trio thread [0-9]*", os_thread_name) await test_thread_name("") await test_thread_name("fobiedoo") diff --git a/src/trio/_tests/test_unix_pipes.py b/src/trio/_tests/test_unix_pipes.py index 6f8fa6e02e..c850ebefea 100644 --- a/src/trio/_tests/test_unix_pipes.py +++ b/src/trio/_tests/test_unix_pipes.py @@ -12,6 +12,9 @@ from .._core._tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken from ..testing import check_one_way_stream, wait_all_tasks_blocked +if TYPE_CHECKING: + from .._file_io import _HasFileNo + posix = os.name == "posix" pytestmark = pytest.mark.skipif(not posix, reason="posix only") @@ -30,7 +33,7 @@ async def make_pipe() -> tuple[FdStream, FdStream]: return FdStream(w), FdStream(r) -async def make_clogged_pipe(): +async def make_clogged_pipe() -> tuple[FdStream, FdStream]: s, r = await make_pipe() try: while True: @@ -197,8 +200,11 @@ async def expect_closedresourceerror() -> None: orig_wait_readable = _core._run.TheIOManager.wait_readable - async def patched_wait_readable(*args, **kwargs) -> None: - await orig_wait_readable(*args, **kwargs) + async def patched_wait_readable( + self: _core._run.TheIOManager, + fd: int | _HasFileNo, + ) -> None: + await orig_wait_readable(self, fd) await r.aclose() monkeypatch.setattr(_core._run.TheIOManager, "wait_readable", patched_wait_readable) @@ -225,8 +231,11 @@ async def expect_closedresourceerror() -> None: orig_wait_writable = _core._run.TheIOManager.wait_writable - async def patched_wait_writable(*args, **kwargs) -> None: - await orig_wait_writable(*args, **kwargs) + async def patched_wait_writable( + self: _core._run.TheIOManager, + fd: int | _HasFileNo, + ) -> None: + await orig_wait_writable(self, fd) await s.aclose() monkeypatch.setattr(_core._run.TheIOManager, "wait_writable", patched_wait_writable) diff --git a/src/trio/_tests/test_util.py b/src/trio/_tests/test_util.py index 41ce5f27cb..5036d76e52 100644 --- a/src/trio/_tests/test_util.py +++ b/src/trio/_tests/test_util.py @@ -1,8 +1,11 @@ -import signal +from __future__ import annotations + import sys import types -from typing import Any, TypeVar +from typing import TYPE_CHECKING, TypeVar +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Coroutine, Generator import pytest import trio @@ -21,25 +24,13 @@ fixup_module_metadata, generic_function, is_main_thread, - signal_raise, ) from ..testing import wait_all_tasks_blocked -T = TypeVar("T") - - -def test_signal_raise() -> None: - record = [] +if TYPE_CHECKING: + from collections.abc import AsyncGenerator - def handler(signum: int, _: object) -> None: - record.append(signum) - - old = signal.signal(signal.SIGFPE, handler) - try: - signal_raise(signal.SIGFPE) - finally: - signal.signal(signal.SIGFPE, old) - assert record == [signal.SIGFPE] +T = TypeVar("T") async def test_ConflictDetector() -> None: @@ -116,9 +107,11 @@ async def f() -> None: # pragma: no cover import asyncio if sys.version_info < (3, 11): - # not bothering to type this one - @asyncio.coroutine # type: ignore[misc] - def generator_based_coro() -> Any: # pragma: no cover + + @asyncio.coroutine + def generator_based_coro() -> ( + Generator[Coroutine[None, None, None], None, None] + ): # pragma: no cover yield from asyncio.sleep(1) with pytest.raises(TypeError) as excinfo: @@ -147,12 +140,13 @@ def generator_based_coro() -> Any: # pragma: no cover assert "appears to be synchronous" in str(excinfo.value) - async def async_gen(_: object) -> Any: # pragma: no cover + async def async_gen( + _: object, + ) -> AsyncGenerator[None, None]: # pragma: no cover yield - # does not give arg-type typing error with pytest.raises(TypeError) as excinfo: - coroutine_or_error(async_gen, [0]) # type: ignore[unused-coroutine] + coroutine_or_error(async_gen, [0]) # type: ignore[arg-type,unused-coroutine] msg = "expected an async function but got an async generator" assert msg in str(excinfo.value) diff --git a/src/trio/_tests/type_tests/path.py b/src/trio/_tests/type_tests/path.py index 6ea6717a6e..6749d06276 100644 --- a/src/trio/_tests/type_tests/path.py +++ b/src/trio/_tests/type_tests/path.py @@ -48,17 +48,15 @@ def sync_attrs(path: trio.Path) -> None: assert_type(path.as_posix(), str) assert_type(path.as_uri(), str) assert_type(path.is_absolute(), bool) - if sys.version_info > (3, 9): - assert_type(path.is_relative_to(path), bool) + assert_type(path.is_relative_to(path), bool) assert_type(path.is_reserved(), bool) assert_type(path.joinpath(path, "folder"), trio.Path) assert_type(path.match("*.py"), bool) assert_type(path.relative_to("/usr"), trio.Path) - if sys.version_info > (3, 12): - assert_type(path.relative_to("/", walk_up=True), bool) + if sys.version_info >= (3, 12): + assert_type(path.relative_to("/", walk_up=True), trio.Path) assert_type(path.with_name("filename.txt"), trio.Path) - if sys.version_info > (3, 9): - assert_type(path.with_stem("readme"), trio.Path) + assert_type(path.with_stem("readme"), trio.Path) assert_type(path.with_suffix(".log"), trio.Path) @@ -75,7 +73,7 @@ async def async_attrs(path: trio.Path) -> None: assert_type(await path.group(), str) assert_type(await path.is_dir(), bool) assert_type(await path.is_file(), bool) - if sys.version_info > (3, 12): + if sys.version_info >= (3, 12): assert_type(await path.is_junction(), bool) if sys.platform != "win32": assert_type(await path.is_mount(), bool) @@ -95,8 +93,7 @@ async def async_attrs(path: trio.Path) -> None: assert_type(await path.owner(), str) assert_type(await path.read_bytes(), bytes) assert_type(await path.read_text(encoding="utf16", errors="replace"), str) - if sys.version_info > (3, 9): - assert_type(await path.readlink(), trio.Path) + assert_type(await path.readlink(), trio.Path) assert_type(await path.rename("another"), trio.Path) assert_type(await path.replace(path), trio.Path) assert_type(await path.resolve(), trio.Path) @@ -107,7 +104,7 @@ async def async_attrs(path: trio.Path) -> None: assert_type(await path.rmdir(), None) assert_type(await path.samefile("something_else"), bool) assert_type(await path.symlink_to("somewhere"), None) - if sys.version_info > (3, 10): + if sys.version_info >= (3, 10): assert_type(await path.hardlink_to("elsewhere"), None) assert_type(await path.touch(), None) assert_type(await path.unlink(missing_ok=True), None) diff --git a/src/trio/_tests/type_tests/raisesgroup.py b/src/trio/_tests/type_tests/raisesgroup.py index e637ace076..4d5ed4882c 100644 --- a/src/trio/_tests/type_tests/raisesgroup.py +++ b/src/trio/_tests/type_tests/raisesgroup.py @@ -1,21 +1,7 @@ -"""The typing of RaisesGroup involves a lot of deception and lies, since AFAIK what we -actually want to achieve is ~impossible. This is because we specify what we expect with -instances of RaisesGroup and exception classes, but excinfo.value will be instances of -[Base]ExceptionGroup and instances of exceptions. So we need to "translate" from -RaisesGroup to ExceptionGroup. - -The way it currently works is that RaisesGroup[E] corresponds to -ExceptionInfo[BaseExceptionGroup[E]], so the top-level group will be correct. But -RaisesGroup[RaisesGroup[ValueError]] will become -ExceptionInfo[BaseExceptionGroup[RaisesGroup[ValueError]]]. To get around that we specify -RaisesGroup as a subclass of BaseExceptionGroup during type checking - which should mean -that most static type checking for end users should be mostly correct. -""" - from __future__ import annotations import sys -from typing import Union +from typing import Callable, Union from trio.testing import Matcher, RaisesGroup from typing_extensions import assert_type @@ -26,17 +12,6 @@ # split into functions to isolate the different scopes -def check_inheritance_and_assignments() -> None: - # Check inheritance - _: BaseExceptionGroup[ValueError] = RaisesGroup(ValueError) - _ = RaisesGroup(RaisesGroup(ValueError)) # type: ignore - - a: BaseExceptionGroup[BaseExceptionGroup[ValueError]] - a = RaisesGroup(RaisesGroup(ValueError)) - a = BaseExceptionGroup("", (BaseExceptionGroup("", (ValueError(),)),)) - assert a - - def check_matcher_typevar_default(e: Matcher) -> object: assert e.exception_type is not None exc: type[BaseException] = e.exception_type @@ -46,29 +21,32 @@ def check_matcher_typevar_default(e: Matcher) -> object: def check_basic_contextmanager() -> None: - # One level of Group is correctly translated - except it's a BaseExceptionGroup - # instead of an ExceptionGroup. with RaisesGroup(ValueError) as e: raise ExceptionGroup("foo", (ValueError(),)) - assert_type(e.value, BaseExceptionGroup[ValueError]) + assert_type(e.value, ExceptionGroup[ValueError]) def check_basic_matches() -> None: # check that matches gets rid of the naked ValueError in the union exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup("", (ValueError(),)) if RaisesGroup(ValueError).matches(exc): - assert_type(exc, BaseExceptionGroup[ValueError]) + assert_type(exc, ExceptionGroup[ValueError]) + + # also check that BaseExceptionGroup shows up for BaseExceptions + if RaisesGroup(KeyboardInterrupt).matches(exc): + assert_type(exc, BaseExceptionGroup[KeyboardInterrupt]) def check_matches_with_different_exception_type() -> None: - # This should probably raise some type error somewhere, since - # ValueError != KeyboardInterrupt e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup( "", (KeyboardInterrupt(),), ) + + # note: it might be tempting to have this warn. + # however, that isn't possible with current typing if RaisesGroup(ValueError).matches(e): - assert_type(e, BaseExceptionGroup[ValueError]) + assert_type(e, ExceptionGroup[ValueError]) def check_matcher_init() -> None: @@ -134,16 +112,7 @@ def handle_value(e: BaseExceptionGroup[ValueError]) -> bool: def raisesgroup_narrow_baseexceptiongroup() -> None: - """Check type narrowing specifically for the container exceptiongroup. - This is not currently working, and after playing around with it for a bit - I think the only way is to introduce a subclass `NonBaseRaisesGroup`, and overload - `__new__` in Raisesgroup to return the subclass when exceptions are non-base. - (or make current class BaseRaisesGroup and introduce RaisesGroup for non-base) - I encountered problems trying to type this though, see - https://github.com/python/mypy/issues/17251 - That is probably possible to work around by entirely using `__new__` instead of - `__init__`, but........ ugh. - """ + """Check type narrowing specifically for the container exceptiongroup.""" def handle_group(e: ExceptionGroup[Exception]) -> bool: return True @@ -151,42 +120,36 @@ def handle_group(e: ExceptionGroup[Exception]) -> bool: def handle_group_value(e: ExceptionGroup[ValueError]) -> bool: return True - # should work, but BaseExceptionGroup does not get narrowed to ExceptionGroup - RaisesGroup(ValueError, check=handle_group_value) # type: ignore + RaisesGroup(ValueError, check=handle_group_value) - # should work, but BaseExceptionGroup does not get narrowed to ExceptionGroup - RaisesGroup(Exception, check=handle_group) # type: ignore + RaisesGroup(Exception, check=handle_group) def check_matcher_transparent() -> None: with RaisesGroup(Matcher(ValueError)) as e: ... _: BaseExceptionGroup[ValueError] = e.value - assert_type(e.value, BaseExceptionGroup[ValueError]) + assert_type(e.value, ExceptionGroup[ValueError]) def check_nested_raisesgroups_contextmanager() -> None: with RaisesGroup(RaisesGroup(ValueError)) as excinfo: raise ExceptionGroup("foo", (ValueError(),)) - # thanks to inheritance this assignment works _: BaseExceptionGroup[BaseExceptionGroup[ValueError]] = excinfo.value - # and it can mostly be treated like an exceptiongroup - print(excinfo.value.exceptions[0].exceptions[0]) - # but assert_type reveals the lies - print(type(excinfo.value)) # would print "ExceptionGroup" - # typing says it's a BaseExceptionGroup assert_type( excinfo.value, - BaseExceptionGroup[RaisesGroup[ValueError]], + ExceptionGroup[ExceptionGroup[ValueError]], ) - print(type(excinfo.value.exceptions[0])) # would print "ExceptionGroup" - # but type checkers are utterly confused assert_type( excinfo.value.exceptions[0], - Union[RaisesGroup[ValueError], BaseExceptionGroup[RaisesGroup[ValueError]]], + # this union is because of how typeshed defines .exceptions + Union[ + ExceptionGroup[ValueError], + ExceptionGroup[ExceptionGroup[ValueError]], + ], ) @@ -196,9 +159,9 @@ def check_nested_raisesgroups_matches() -> None: "", (ExceptionGroup("", (ValueError(),)),), ) - # has the same problems as check_nested_raisesgroups_contextmanager + if RaisesGroup(RaisesGroup(ValueError)).matches(exc): - assert_type(exc, BaseExceptionGroup[RaisesGroup[ValueError]]) + assert_type(exc, ExceptionGroup[ExceptionGroup[ValueError]]) def check_multiple_exceptions_1() -> None: @@ -206,7 +169,7 @@ def check_multiple_exceptions_1() -> None: b = RaisesGroup(Matcher(ValueError), Matcher(ValueError)) c = RaisesGroup(ValueError, Matcher(ValueError)) - d: BaseExceptionGroup[ValueError] + d: RaisesGroup[ValueError] d = a d = b d = c @@ -219,7 +182,7 @@ def check_multiple_exceptions_2() -> None: b = RaisesGroup(Matcher(ValueError), TypeError) c = RaisesGroup(ValueError, TypeError) - d: BaseExceptionGroup[Exception] + d: RaisesGroup[Exception] d = a d = b d = c @@ -252,3 +215,25 @@ def check_raisesgroup_overloads() -> None: # if they're both false we can of course specify nested raisesgroup RaisesGroup(RaisesGroup(ValueError)) + + +def check_triple_nested_raisesgroup() -> None: + with RaisesGroup(RaisesGroup(RaisesGroup(ValueError))) as e: + assert_type(e.value, ExceptionGroup[ExceptionGroup[ExceptionGroup[ValueError]]]) + + +def check_check_typing() -> None: + # mypy issue is https://github.com/python/mypy/issues/18185 + + # fmt: off + # mypy raises an error on `assert_type` + # pyright raises an error on `RaisesGroup(ValueError).check` + # to satisfy both, need to disable formatting and put it on one line + assert_type(RaisesGroup(ValueError).check, # type: ignore + Union[ + Callable[[BaseExceptionGroup[ValueError]], None], + Callable[[ExceptionGroup[ValueError]], None], + None, + ], + ) + # fmt: on diff --git a/src/trio/_threads.py b/src/trio/_threads.py index dcab70babb..c0e32f2f3b 100644 --- a/src/trio/_threads.py +++ b/src/trio/_threads.py @@ -146,8 +146,9 @@ class ThreadPlaceholder: # Types for the to_thread_run_sync message loop @attrs.frozen(eq=False, slots=False) -class Run(Generic[RetT]): - afn: Callable[..., Awaitable[RetT]] +# Explicit .../"Any" is not allowed +class Run(Generic[RetT]): # type: ignore[misc] + afn: Callable[..., Awaitable[RetT]] # type: ignore[misc] args: tuple[object, ...] context: contextvars.Context = attrs.field( init=False, @@ -205,8 +206,9 @@ def in_trio_thread() -> None: @attrs.frozen(eq=False, slots=False) -class RunSync(Generic[RetT]): - fn: Callable[..., RetT] +# Explicit .../"Any" is not allowed +class RunSync(Generic[RetT]): # type: ignore[misc] + fn: Callable[..., RetT] # type: ignore[misc] args: tuple[object, ...] context: contextvars.Context = attrs.field( init=False, @@ -521,7 +523,8 @@ def _send_message_to_trio( return message_to_trio.queue.get().unwrap() -def from_thread_run( +# Explicit "Any" is not allowed +def from_thread_run( # type: ignore[misc] afn: Callable[..., Awaitable[RetT]], *args: object, trio_token: TrioToken | None = None, @@ -565,7 +568,8 @@ def from_thread_run( return _send_message_to_trio(trio_token, Run(afn, args)) -def from_thread_run_sync( +# Explicit "Any" is not allowed +def from_thread_run_sync( # type: ignore[misc] fn: Callable[..., RetT], *args: object, trio_token: TrioToken | None = None, diff --git a/src/trio/_tools/gen_exports.py b/src/trio/_tools/gen_exports.py index c762ee138a..b4db597b63 100755 --- a/src/trio/_tools/gen_exports.py +++ b/src/trio/_tools/gen_exports.py @@ -34,12 +34,11 @@ import sys -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._ki import enable_ki_protection from ._run import GLOBAL_RUN_CONTEXT """ -TEMPLATE = """sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True -try: +TEMPLATE = """try: return{}GLOBAL_RUN_CONTEXT.{}.{} except AttributeError: raise RuntimeError("must be called from async context") from None @@ -237,7 +236,7 @@ def gen_public_wrappers_source(file: File) -> str: is_cm = False # Remove decorators - method.decorator_list = [] + method.decorator_list = [ast.Name("enable_ki_protection")] # Create pass through arguments new_args = create_passthrough_args(method) diff --git a/src/trio/_util.py b/src/trio/_util.py index 0bb9b89ddc..9b5d1a9436 100644 --- a/src/trio/_util.py +++ b/src/trio/_util.py @@ -3,9 +3,7 @@ import collections.abc import inspect -import os import signal -import threading from abc import ABCMeta from collections.abc import Awaitable, Callable, Sequence from functools import update_wrapper @@ -22,7 +20,8 @@ import trio -CallT = TypeVar("CallT", bound=Callable[..., Any]) +# Explicit "Any" is not allowed +CallT = TypeVar("CallT", bound=Callable[..., Any]) # type: ignore[misc] T = TypeVar("T") RetT = TypeVar("RetT") @@ -35,60 +34,6 @@ PosArgsT = TypeVarTuple("PosArgsT") -if TYPE_CHECKING: - # Don't type check the implementation below, pthread_kill does not exist on Windows. - def signal_raise(signum: int) -> None: ... - - -# Equivalent to the C function raise(), which Python doesn't wrap -elif os.name == "nt": - # On Windows, os.kill exists but is really weird. - # - # If you give it CTRL_C_EVENT or CTRL_BREAK_EVENT, it tries to deliver - # those using GenerateConsoleCtrlEvent. But I found that when I tried - # to run my test normally, it would freeze waiting... unless I added - # print statements, in which case the test suddenly worked. So I guess - # these signals are only delivered if/when you access the console? I - # don't really know what was going on there. From reading the - # GenerateConsoleCtrlEvent docs I don't know how it worked at all. - # - # I later spent a bunch of time trying to make GenerateConsoleCtrlEvent - # work for creating synthetic control-C events, and... failed - # utterly. There are lots of details in the code and comments - # removed/added at this commit: - # https://github.com/python-trio/trio/commit/95843654173e3e826c34d70a90b369ba6edf2c23 - # - # OTOH, if you pass os.kill any *other* signal number... then CPython - # just calls TerminateProcess (wtf). - # - # So, anyway, os.kill is not so useful for testing purposes. Instead, - # we use raise(): - # - # https://msdn.microsoft.com/en-us/library/dwwzkt4c.aspx - # - # Have to import cffi inside the 'if os.name' block because we don't - # depend on cffi on non-Windows platforms. (It would be easy to switch - # this to ctypes though if we ever remove the cffi dependency.) - # - # Some more information: - # https://bugs.python.org/issue26350 - # - # Anyway, we use this for two things: - # - redelivering unhandled signals - # - generating synthetic signals for tests - # and for both of those purposes, 'raise' works fine. - import cffi - - _ffi = cffi.FFI() - _ffi.cdef("int raise(int);") - _lib = _ffi.dlopen("api-ms-win-crt-runtime-l1-1-0.dll") - signal_raise = getattr(_lib, "raise") -else: - - def signal_raise(signum: int) -> None: - signal.pthread_kill(threading.get_ident(), signum) - - # See: #461 as to why this is needed. # The gist is that threading.main_thread() has the capability to lie to us # if somebody else edits the threading ident cache to replace the main @@ -231,14 +176,16 @@ def __exit__( self._held = False -def async_wraps( +# Explicit "Any" is not allowed +def async_wraps( # type: ignore[misc] cls: type[object], wrapped_cls: type[object], attr_name: str, ) -> Callable[[CallT], CallT]: """Similar to wraps, but for async wrappers of non-async functions.""" - def decorator(func: CallT) -> CallT: + # Explicit "Any" is not allowed + def decorator(func: CallT) -> CallT: # type: ignore[misc] func.__name__ = attr_name func.__qualname__ = f"{cls.__qualname__}.{attr_name}" @@ -301,11 +248,15 @@ def open_memory_channel(max_buffer_size: int) -> Tuple[ but at least it becomes possible to write those. """ - def __init__(self, fn: Callable[..., RetT]) -> None: + # Explicit .../"Any" is not allowed + def __init__( # type: ignore[misc] + self, + fn: Callable[..., RetT], + ) -> None: update_wrapper(self, fn) self._fn = fn - def __call__(self, *args: Any, **kwargs: Any) -> RetT: + def __call__(self, *args: object, **kwargs: object) -> RetT: return self._fn(*args, **kwargs) def __getitem__(self, subscript: object) -> Self: @@ -394,9 +345,11 @@ def name_asyncgen(agen: AsyncGeneratorType[object, NoReturn]) -> str: # work around a pyright error if TYPE_CHECKING: - Fn = TypeVar("Fn", bound=Callable[..., object]) + # Explicit .../"Any" is not allowed + Fn = TypeVar("Fn", bound=Callable[..., object]) # type: ignore[misc] - def wraps( + # Explicit .../"Any" is not allowed + def wraps( # type: ignore[misc] wrapped: Callable[..., object], assigned: Sequence[str] = ..., updated: Sequence[str] = ..., diff --git a/src/trio/testing/_check_streams.py b/src/trio/testing/_check_streams.py index 3bf5442814..e58e2ddfed 100644 --- a/src/trio/testing/_check_streams.py +++ b/src/trio/testing/_check_streams.py @@ -311,7 +311,7 @@ async def expect_cancelled( # receive stream causes it to wake up. async with _ForceCloseBoth(await stream_maker()) as (s, r): - async def receive_expecting_closed(): + async def receive_expecting_closed() -> None: with _assert_raises(_core.ClosedResourceError): await r.receive_some(10) diff --git a/src/trio/testing/_fake_net.py b/src/trio/testing/_fake_net.py index 2cf319cb2b..1c3af8c5ef 100644 --- a/src/trio/testing/_fake_net.py +++ b/src/trio/testing/_fake_net.py @@ -36,6 +36,8 @@ from typing_extensions import Buffer, Self, TypeAlias + from trio._socket import AddressFormat + IPAddress: TypeAlias = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] @@ -210,7 +212,7 @@ def __init__( family: AddressFamily, type: SocketKind, proto: int, - ): + ) -> None: self._fake_net = fake_net if not family: # pragma: no cover @@ -313,7 +315,7 @@ async def _sendmsg( buffers: Iterable[Buffer], ancdata: Iterable[tuple[int, int, Buffer]] = (), flags: int = 0, - address: Any | None = None, + address: AddressFormat | None = None, ) -> int: self._check_closed() @@ -357,7 +359,12 @@ async def _recvmsg_into( buffers: Iterable[Buffer], ancbufsize: int = 0, flags: int = 0, - ) -> tuple[int, list[tuple[int, int, bytes]], int, Any]: + ) -> tuple[ + int, + list[tuple[int, int, bytes]], + int, + tuple[str, int] | tuple[str, int, int, int], + ]: if ancbufsize != 0: raise NotImplementedError("FakeNet doesn't support ancillary data") if flags != 0: @@ -496,10 +503,14 @@ async def sendto( self, __data: Buffer, # noqa: PYI063 __flags: int, - __address: tuple[object, ...] | str | None | Buffer, + __address: tuple[object, ...] | str | Buffer | None, ) -> int: ... - async def sendto(self, *args: Any) -> int: + # Explicit "Any" is not allowed + async def sendto( # type: ignore[misc] + self, + *args: Any, + ) -> int: data: Buffer flags: int address: tuple[object, ...] | str | Buffer @@ -520,7 +531,11 @@ async def recv_into(self, buf: Buffer, nbytes: int = 0, flags: int = 0) -> int: got_bytes, _address = await self.recvfrom_into(buf, nbytes, flags) return got_bytes - async def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[bytes, Any]: + async def recvfrom( + self, + bufsize: int, + flags: int = 0, + ) -> tuple[bytes, AddressFormat]: data, _ancdata, _msg_flags, address = await self._recvmsg(bufsize, flags) return data, address @@ -529,7 +544,7 @@ async def recvfrom_into( buf: Buffer, nbytes: int = 0, flags: int = 0, - ) -> tuple[int, Any]: + ) -> tuple[int, AddressFormat]: if nbytes != 0 and nbytes != memoryview(buf).nbytes: raise NotImplementedError("partial recvfrom_into") got_nbytes, _ancdata, _msg_flags, address = await self._recvmsg_into( @@ -544,7 +559,7 @@ async def _recvmsg( bufsize: int, ancbufsize: int = 0, flags: int = 0, - ) -> tuple[bytes, list[tuple[int, int, bytes]], int, Any]: + ) -> tuple[bytes, list[tuple[int, int, bytes]], int, AddressFormat]: buf = bytearray(bufsize) got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into( [buf], diff --git a/src/trio/testing/_memory_streams.py b/src/trio/testing/_memory_streams.py index b5a4c418dd..3564e9699a 100644 --- a/src/trio/testing/_memory_streams.py +++ b/src/trio/testing/_memory_streams.py @@ -113,7 +113,7 @@ def __init__( send_all_hook: AsyncHook | None = None, wait_send_all_might_not_block_hook: AsyncHook | None = None, close_hook: SyncHook | None = None, - ): + ) -> None: self._conflict_detector = _util.ConflictDetector( "another task is using this stream", ) @@ -223,7 +223,7 @@ def __init__( self, receive_some_hook: AsyncHook | None = None, close_hook: SyncHook | None = None, - ): + ) -> None: self._conflict_detector = _util.ConflictDetector( "another task is using this stream", ) @@ -548,7 +548,7 @@ async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: class _LockstepSendStream(SendStream): - def __init__(self, lbq: _LockstepByteQueue): + def __init__(self, lbq: _LockstepByteQueue) -> None: self._lbq = lbq def close(self) -> None: @@ -566,7 +566,7 @@ async def wait_send_all_might_not_block(self) -> None: class _LockstepReceiveStream(ReceiveStream): - def __init__(self, lbq: _LockstepByteQueue): + def __init__(self, lbq: _LockstepByteQueue) -> None: self._lbq = lbq def close(self) -> None: diff --git a/src/trio/testing/_raises_group.py b/src/trio/testing/_raises_group.py index e93e0bf3bd..700c16ca6a 100644 --- a/src/trio/testing/_raises_group.py +++ b/src/trio/testing/_raises_group.py @@ -2,7 +2,6 @@ import re import sys -from contextlib import AbstractContextManager from re import Pattern from typing import ( TYPE_CHECKING, @@ -25,6 +24,7 @@ from _pytest._code.code import ExceptionChainRepr, ReprExceptionInfo, Traceback from typing_extensions import TypeGuard, TypeVar + # this conditional definition is because we want to allow a TypeVar default MatchE = TypeVar( "MatchE", bound=BaseException, @@ -35,12 +35,16 @@ from typing import TypeVar MatchE = TypeVar("MatchE", bound=BaseException, covariant=True) + # RaisesGroup doesn't work with a default. -E = TypeVar("E", bound=BaseException, covariant=True) -# These two typevars are special cased in sphinx config to workaround lookup bugs. +BaseExcT_co = TypeVar("BaseExcT_co", bound=BaseException, covariant=True) +BaseExcT_1 = TypeVar("BaseExcT_1", bound=BaseException) +BaseExcT_2 = TypeVar("BaseExcT_2", bound=BaseException) +ExcT_1 = TypeVar("ExcT_1", bound=Exception) +ExcT_2 = TypeVar("ExcT_2", bound=Exception) if sys.version_info < (3, 11): - from exceptiongroup import BaseExceptionGroup + from exceptiongroup import BaseExceptionGroup, ExceptionGroup @final @@ -52,7 +56,7 @@ class _ExceptionInfo(Generic[MatchE]): def __init__( self, excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None, - ): + ) -> None: self._excinfo = excinfo def fill_unfilled( @@ -147,7 +151,7 @@ def _stringify_exception(exc: BaseException) -> str: # String patterns default to including the unicode flag. -_regex_no_flags = re.compile("").flags +_REGEX_NO_FLAGS = re.compile(r"").flags @final @@ -174,7 +178,7 @@ def __init__( exception_type: type[MatchE], match: str | Pattern[str] = ..., check: Callable[[MatchE], bool] = ..., - ): ... + ) -> None: ... @overload def __init__( @@ -183,10 +187,10 @@ def __init__( match: str | Pattern[str], # If exception_type is not provided, check() must do any typechecks itself. check: Callable[[BaseException], bool] = ..., - ): ... + ) -> None: ... @overload - def __init__(self, *, check: Callable[[BaseException], bool]): ... + def __init__(self, *, check: Callable[[BaseException], bool]) -> None: ... def __init__( self, @@ -238,7 +242,7 @@ def matches(self, exception: BaseException) -> TypeGuard[MatchE]: return False # If exception_type is None check() accepts BaseException. # If non-none, we have done an isinstance check above. - return self.check is None or self.check(cast(MatchE, exception)) + return self.check is None or self.check(cast("MatchE", exception)) def __str__(self) -> str: reqs = [] @@ -247,39 +251,15 @@ def __str__(self) -> str: if (match := self.match) is not None: # If no flags were specified, discard the redundant re.compile() here. reqs.append( - f"match={match.pattern if match.flags == _regex_no_flags else match!r}", + f"match={match.pattern if match.flags == _REGEX_NO_FLAGS else match!r}", ) if self.check is not None: reqs.append(f"check={self.check!r}") return f'Matcher({", ".join(reqs)})' -# typing this has been somewhat of a nightmare, with the primary difficulty making -# the return type of __enter__ correct. Ideally it would function like this -# with RaisesGroup(RaisesGroup(ValueError)) as excinfo: -# ... -# assert_type(excinfo.value, ExceptionGroup[ExceptionGroup[ValueError]]) -# in addition to all the simple cases, but getting all the way to the above seems maybe -# impossible. The type being RaisesGroup[RaisesGroup[ValueError]] is probably also fine, -# as long as I add fake properties corresponding to the properties of exceptiongroup. But -# I had trouble with it handling recursive cases properly. - -# Current solution settles on the above giving BaseExceptionGroup[RaisesGroup[ValueError]], and it not -# being a type error to do `with RaisesGroup(ValueError()): ...` - but that will error on runtime. - -# We lie to type checkers that we inherit, so excinfo.value and sub-exceptiongroups can be treated as ExceptionGroups -if TYPE_CHECKING: - SuperClass = BaseExceptionGroup -else: - # At runtime, use a redundant Generic base class which effectively gets ignored. - SuperClass = Generic - - @final -class RaisesGroup( - AbstractContextManager[ExceptionInfo[BaseExceptionGroup[E]]], - SuperClass[E], -): +class RaisesGroup(Generic[BaseExcT_co]): """Contextmanager for checking for an expected `ExceptionGroup`. This works similar to ``pytest.raises``, and a version of it will hopefully be added upstream, after which this can be deprecated and removed. See https://github.com/pytest-dev/pytest/issues/11538 @@ -332,61 +312,121 @@ class RaisesGroup( even though it generally does not care about the order of the exceptions in the group. To avoid the above you should specify the first ValueError with a Matcher as well. - - It is also not typechecked perfectly, and that's likely not possible with the current approach. Most common usage should work without issue though. """ - # needed for pyright, since BaseExceptionGroup.__new__ takes two arguments - if TYPE_CHECKING: - - def __new__(cls, *args: object, **kwargs: object) -> RaisesGroup[E]: ... - # allow_unwrapped=True requires: singular exception, exception not being # RaisesGroup instance, match is None, check is None @overload def __init__( self, - exception: type[E] | Matcher[E], + exception: type[BaseExcT_co] | Matcher[BaseExcT_co], *, allow_unwrapped: Literal[True], flatten_subgroups: bool = False, - match: None = None, - check: None = None, - ): ... + ) -> None: ... # flatten_subgroups = True also requires no nested RaisesGroup @overload def __init__( self, - exception: type[E] | Matcher[E], - *other_exceptions: type[E] | Matcher[E], - allow_unwrapped: Literal[False] = False, + exception: type[BaseExcT_co] | Matcher[BaseExcT_co], + *other_exceptions: type[BaseExcT_co] | Matcher[BaseExcT_co], flatten_subgroups: Literal[True], match: str | Pattern[str] | None = None, - check: Callable[[BaseExceptionGroup[E]], bool] | None = None, - ): ... + check: Callable[[BaseExceptionGroup[BaseExcT_co]], bool] | None = None, + ) -> None: ... + + # simplify the typevars if possible (the following 3 are equivalent but go simpler->complicated) + # ... the first handles RaisesGroup[ValueError], the second RaisesGroup[ExceptionGroup[ValueError]], + # the third RaisesGroup[ValueError | ExceptionGroup[ValueError]]. + # ... otherwise, we will get results like RaisesGroup[ValueError | ExceptionGroup[Never]] (I think) + # (technically correct but misleading) + @overload + def __init__( + self: RaisesGroup[ExcT_1], + exception: type[ExcT_1] | Matcher[ExcT_1], + *other_exceptions: type[ExcT_1] | Matcher[ExcT_1], + match: str | Pattern[str] | None = None, + check: Callable[[ExceptionGroup[ExcT_1]], bool] | None = None, + ) -> None: ... @overload def __init__( - self, - exception: type[E] | Matcher[E] | E, - *other_exceptions: type[E] | Matcher[E] | E, - allow_unwrapped: Literal[False] = False, - flatten_subgroups: Literal[False] = False, + self: RaisesGroup[ExceptionGroup[ExcT_2]], + exception: RaisesGroup[ExcT_2], + *other_exceptions: RaisesGroup[ExcT_2], match: str | Pattern[str] | None = None, - check: Callable[[BaseExceptionGroup[E]], bool] | None = None, - ): ... + check: Callable[[ExceptionGroup[ExceptionGroup[ExcT_2]]], bool] | None = None, + ) -> None: ... + @overload def __init__( - self, - exception: type[E] | Matcher[E] | E, - *other_exceptions: type[E] | Matcher[E] | E, + self: RaisesGroup[ExcT_1 | ExceptionGroup[ExcT_2]], + exception: type[ExcT_1] | Matcher[ExcT_1] | RaisesGroup[ExcT_2], + *other_exceptions: type[ExcT_1] | Matcher[ExcT_1] | RaisesGroup[ExcT_2], + match: str | Pattern[str] | None = None, + check: ( + Callable[[ExceptionGroup[ExcT_1 | ExceptionGroup[ExcT_2]]], bool] | None + ) = None, + ) -> None: ... + + # same as the above 3 but handling BaseException + @overload + def __init__( + self: RaisesGroup[BaseExcT_1], + exception: type[BaseExcT_1] | Matcher[BaseExcT_1], + *other_exceptions: type[BaseExcT_1] | Matcher[BaseExcT_1], + match: str | Pattern[str] | None = None, + check: Callable[[BaseExceptionGroup[BaseExcT_1]], bool] | None = None, + ) -> None: ... + + @overload + def __init__( + self: RaisesGroup[BaseExceptionGroup[BaseExcT_2]], + exception: RaisesGroup[BaseExcT_2], + *other_exceptions: RaisesGroup[BaseExcT_2], + match: str | Pattern[str] | None = None, + check: ( + Callable[[BaseExceptionGroup[BaseExceptionGroup[BaseExcT_2]]], bool] | None + ) = None, + ) -> None: ... + + @overload + def __init__( + self: RaisesGroup[BaseExcT_1 | BaseExceptionGroup[BaseExcT_2]], + exception: type[BaseExcT_1] | Matcher[BaseExcT_1] | RaisesGroup[BaseExcT_2], + *other_exceptions: type[BaseExcT_1] + | Matcher[BaseExcT_1] + | RaisesGroup[BaseExcT_2], + match: str | Pattern[str] | None = None, + check: ( + Callable[ + [BaseExceptionGroup[BaseExcT_1 | BaseExceptionGroup[BaseExcT_2]]], + bool, + ] + | None + ) = None, + ) -> None: ... + + def __init__( + self: RaisesGroup[ExcT_1 | BaseExcT_1 | BaseExceptionGroup[BaseExcT_2]], + exception: type[BaseExcT_1] | Matcher[BaseExcT_1] | RaisesGroup[BaseExcT_2], + *other_exceptions: type[BaseExcT_1] + | Matcher[BaseExcT_1] + | RaisesGroup[BaseExcT_2], allow_unwrapped: bool = False, flatten_subgroups: bool = False, match: str | Pattern[str] | None = None, - check: Callable[[BaseExceptionGroup[E]], bool] | None = None, + check: ( + Callable[[BaseExceptionGroup[BaseExcT_1]], bool] + | Callable[[ExceptionGroup[ExcT_1]], bool] + | None + ) = None, ): - self.expected_exceptions: tuple[type[E] | Matcher[E] | E, ...] = ( + self.expected_exceptions: tuple[ + type[BaseExcT_co] | Matcher[BaseExcT_co] | RaisesGroup[BaseException], + ..., + ] = ( exception, *other_exceptions, ) @@ -448,8 +488,19 @@ def __init__( " RaisesGroup.", ) - def __enter__(self) -> ExceptionInfo[BaseExceptionGroup[E]]: - self.excinfo: ExceptionInfo[BaseExceptionGroup[E]] = ExceptionInfo.for_later() + @overload + def __enter__( + self: RaisesGroup[ExcT_1], + ) -> ExceptionInfo[ExceptionGroup[ExcT_1]]: ... + @overload + def __enter__( + self: RaisesGroup[BaseExcT_1], + ) -> ExceptionInfo[BaseExceptionGroup[BaseExcT_1]]: ... + + def __enter__(self) -> ExceptionInfo[BaseExceptionGroup[BaseException]]: + self.excinfo: ExceptionInfo[BaseExceptionGroup[BaseExcT_co]] = ( + ExceptionInfo.for_later() + ) return self.excinfo def _unroll_exceptions( @@ -466,10 +517,21 @@ def _unroll_exceptions( res.append(exc) return res + @overload + def matches( + self: RaisesGroup[ExcT_1], + exc_val: BaseException | None, + ) -> TypeGuard[ExceptionGroup[ExcT_1]]: ... + @overload + def matches( + self: RaisesGroup[BaseExcT_1], + exc_val: BaseException | None, + ) -> TypeGuard[BaseExceptionGroup[BaseExcT_1]]: ... + def matches( self, exc_val: BaseException | None, - ) -> TypeGuard[BaseExceptionGroup[E]]: + ) -> TypeGuard[BaseExceptionGroup[BaseExcT_co]]: """Check if an exception matches the requirements of this RaisesGroup. Example:: @@ -502,8 +564,6 @@ def matches( _stringify_exception(exc_val), ): return False - if self.check is not None and not self.check(exc_val): - return False remaining_exceptions = list(self.expected_exceptions) actual_exceptions: Sequence[BaseException] = exc_val.exceptions @@ -514,9 +574,6 @@ def matches( if len(actual_exceptions) != len(self.expected_exceptions): return False - # it should be possible to get RaisesGroup.matches typed so as not to - # need type: ignore, but I'm not sure that's possible while also having it - # transparent for the end user. for e in actual_exceptions: for rem_e in remaining_exceptions: if ( @@ -524,11 +581,14 @@ def matches( or (isinstance(rem_e, RaisesGroup) and rem_e.matches(e)) or (isinstance(rem_e, Matcher) and rem_e.matches(e)) ): - remaining_exceptions.remove(rem_e) # type: ignore[arg-type] + remaining_exceptions.remove(rem_e) break else: return False - return True + + # only run `self.check` once we know `exc_val` is correct. (see the types) + # unfortunately mypy isn't smart enough to recognize the above `for`s as narrowing. + return self.check is None or self.check(exc_val) # type: ignore[arg-type] def __exit__( self, @@ -549,7 +609,7 @@ def __exit__( # Cast to narrow the exception type now that it's verified. exc_info = cast( - "tuple[type[BaseExceptionGroup[E]], BaseExceptionGroup[E], types.TracebackType]", + "tuple[type[BaseExceptionGroup[BaseExcT_co]], BaseExceptionGroup[BaseExcT_co], types.TracebackType]", (exc_type, exc_val, exc_tb), ) self.excinfo.fill_unfilled(exc_info) diff --git a/test-requirements.in b/test-requirements.in index add7798d05..809e171e3b 100644 --- a/test-requirements.in +++ b/test-requirements.in @@ -11,8 +11,9 @@ cryptography>=41.0.0 # cryptography<41 segfaults on pypy3.10 # Tools black; implementation_name == "cpython" -mypy -ruff >= 0.6.6 +mypy # Would use mypy[faster-cache], but orjson has build issues on pypy +orjson; implementation_name == "cpython" +ruff >= 0.8.0 astor # code generation uv >= 0.2.24 codespell diff --git a/test-requirements.txt b/test-requirements.txt index 314eaf4b94..87b3c581eb 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,10 +1,10 @@ # This file was autogenerated by uv via the following command: # uv pip compile --universal --python-version=3.9 test-requirements.in -o test-requirements.txt -alabaster==0.7.13 +alabaster==0.7.16 # via sphinx astor==0.8.1 # via -r test-requirements.in -astroid==3.2.4 +astroid==3.3.5 # via pylint async-generator==1.10 # via -r test-requirements.in @@ -14,7 +14,7 @@ attrs==24.2.0 # outcome babel==2.16.0 # via sphinx -black==24.8.0 ; implementation_name == 'cpython' +black==24.10.0 ; implementation_name == 'cpython' # via -r test-requirements.in certifi==2024.8.30 # via requests @@ -22,7 +22,7 @@ cffi==1.17.1 ; platform_python_implementation != 'PyPy' or os_name == 'nt' # via # -r test-requirements.in # cryptography -charset-normalizer==3.3.2 +charset-normalizer==3.4.0 # via requests click==8.1.7 ; implementation_name == 'cpython' # via black @@ -34,9 +34,9 @@ colorama==0.4.6 ; (implementation_name != 'cpython' and sys_platform == 'win32') # pylint # pytest # sphinx -coverage==7.6.1 +coverage==7.6.8 # via -r test-requirements.in -cryptography==43.0.1 +cryptography==43.0.3 # via # -r test-requirements.in # pyopenssl @@ -44,7 +44,7 @@ cryptography==43.0.1 # types-pyopenssl dill==0.3.9 # via pylint -docutils==0.20.1 +docutils==0.21.2 # via sphinx exceptiongroup==1.2.2 ; python_full_version < '3.11' # via @@ -63,15 +63,15 @@ iniconfig==2.0.0 # via pytest isort==5.13.2 # via pylint -jedi==0.19.1 ; implementation_name == 'cpython' +jedi==0.19.2 ; implementation_name == 'cpython' # via -r test-requirements.in jinja2==3.1.4 # via sphinx -markupsafe==2.1.5 +markupsafe==3.0.2 # via jinja2 mccabe==0.7.0 # via pylint -mypy==1.11.2 +mypy==1.13.0 # via -r test-requirements.in mypy-extensions==1.0.0 # via @@ -80,9 +80,11 @@ mypy-extensions==1.0.0 # mypy nodeenv==1.9.1 # via pyright +orjson==3.10.12 ; implementation_name == 'cpython' + # via -r test-requirements.in outcome==1.3.0.post0 # via -r test-requirements.in -packaging==24.1 +packaging==24.2 # via # black # pytest @@ -101,17 +103,17 @@ pycparser==2.22 ; platform_python_implementation != 'PyPy' or os_name == 'nt' # via cffi pygments==2.18.0 # via sphinx -pylint==3.2.7 +pylint==3.3.1 # via -r test-requirements.in pyopenssl==24.2.1 # via -r test-requirements.in -pyright==1.1.382.post1 +pyright==1.1.389 # via -r test-requirements.in pytest==8.3.3 # via -r test-requirements.in requests==2.32.3 # via sphinx -ruff==0.6.8 +ruff==0.8.2 # via -r test-requirements.in sniffio==1.3.1 # via -r test-requirements.in @@ -119,39 +121,40 @@ snowballstemmer==2.2.0 # via sphinx sortedcontainers==2.4.0 # via -r test-requirements.in -sphinx==7.1.2 +sphinx==7.4.7 # via -r test-requirements.in -sphinxcontrib-applehelp==1.0.4 +sphinxcontrib-applehelp==2.0.0 # via sphinx -sphinxcontrib-devhelp==1.0.2 +sphinxcontrib-devhelp==2.0.0 # via sphinx -sphinxcontrib-htmlhelp==2.0.1 +sphinxcontrib-htmlhelp==2.1.0 # via sphinx sphinxcontrib-jsmath==1.0.1 # via sphinx -sphinxcontrib-qthelp==1.0.3 +sphinxcontrib-qthelp==2.0.0 # via sphinx -sphinxcontrib-serializinghtml==1.1.5 +sphinxcontrib-serializinghtml==2.0.0 # via sphinx -tomli==2.0.1 ; python_full_version < '3.11' +tomli==2.2.1 ; python_full_version < '3.11' # via # black # mypy # pylint # pytest + # sphinx tomlkit==0.13.2 # via pylint -trustme==1.1.0 +trustme==1.2.0 # via -r test-requirements.in types-cffi==1.16.0.20240331 # via # -r test-requirements.in # types-pyopenssl -types-docutils==0.21.0.20240907 +types-docutils==0.21.0.20241128 # via -r test-requirements.in types-pyopenssl==24.1.0.20240722 # via -r test-requirements.in -types-setuptools==75.1.0.20240917 +types-setuptools==75.6.0.20241126 # via types-cffi typing-extensions==4.12.2 # via @@ -163,7 +166,7 @@ typing-extensions==4.12.2 # pyright urllib3==2.2.3 # via requests -uv==0.4.17 +uv==0.5.5 # via -r test-requirements.in -zipp==3.20.2 ; python_full_version < '3.10' +zipp==3.21.0 ; python_full_version < '3.10' # via importlib-metadata diff --git a/tests/_trio_check_attrs_aliases.py b/tests/_trio_check_attrs_aliases.py new file mode 100644 index 0000000000..b4a339dabc --- /dev/null +++ b/tests/_trio_check_attrs_aliases.py @@ -0,0 +1,22 @@ +"""Plugins are executed by Pytest before test modules. + +We use this to monkeypatch attrs.field(), so that we can detect if aliases are used for test_exports. +""" + +from typing import Any + +import attrs + +orig_field = attrs.field + + +def field(**kwargs: Any) -> Any: + original_args = kwargs.copy() + metadata = kwargs.setdefault("metadata", {}) + metadata["trio_original_args"] = original_args + return orig_field(**kwargs) + + +# Mark it as being ours, so the test knows it can actually run. +field.trio_modded = True # type: ignore +attrs.field = field