From fdcb696617312ce1b5475da4e7714a6f866084ea Mon Sep 17 00:00:00 2001 From: Oliver Berger Date: Thu, 5 Dec 2024 13:30:53 +0100 Subject: [PATCH] fix: use context in factory --- devenv.nix | 2 +- src/buvar/context.py | 44 +++++++++++++++++++++++++++++++------------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/devenv.nix b/devenv.nix index 4276510..26e11fd 100644 --- a/devenv.nix +++ b/devenv.nix @@ -39,7 +39,7 @@ in # https://devenv.sh/languages/ languages.python = { enable = true; - version = "3.12.3"; + version = "3.10"; uv.enable = true; venv = { enable = true; diff --git a/src/buvar/context.py b/src/buvar/context.py index 2cdfd39..d67553d 100644 --- a/src/buvar/context.py +++ b/src/buvar/context.py @@ -4,6 +4,7 @@ import contextlib import contextvars import functools +import sys from . import components @@ -19,19 +20,36 @@ class StackingTaskFactory: def __init__(self, *, parent_factory=None): self.parent_factory = parent_factory - def __call__(self, loop, coro, context=None): - context = current_context().push() - token = buvar_context.set(context) - # with child(): - task = ( - self.parent_factory - if self.parent_factory is not None - else asyncio.tasks.Task - )(loop=loop, coro=coro) - try: - return task - finally: - buvar_context.reset(token) + if sys.version_info < (3, 11): + + def __call__(self, loop, coro, context=None): + component_context = current_context().push() + token = buvar_context.set(component_context) + # with child(): + task = ( + self.parent_factory + if self.parent_factory is not None + else asyncio.tasks.Task + )(loop=loop, coro=coro) + try: + return task + finally: + buvar_context.reset(token) + else: + # INFO: Task() accepts context + def __call__(self, loop, coro, context=None): + component_context = current_context().push() + token = buvar_context.set(component_context) + # with child(): + task = ( + self.parent_factory + if self.parent_factory is not None + else asyncio.tasks.Task + )(loop=loop, coro=coro, context=context) + try: + return task + finally: + buvar_context.reset(token) @classmethod def set(cls, *, loop=None):