Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add optional prefix to redis keys #74

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 42 additions & 18 deletions taskiq_redis/redis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
result_px_time: Optional[int] = None,
max_connection_pool_size: Optional[int] = None,
serializer: Optional[TaskiqSerializer] = None,
prefix_str: Optional[str] = None,
**connection_kwargs: Any,
) -> None:
"""
Expand All @@ -82,6 +83,7 @@ def __init__(
self.keep_results = keep_results
self.result_ex_time = result_ex_time
self.result_px_time = result_px_time
self.prefix_str = prefix_str

unavailable_conditions = any(
(
Expand All @@ -99,6 +101,11 @@ def __init__(
"Choose either result_ex_time or result_px_time.",
)

def _task_name(self, task_id: str) -> str:
if self.prefix_str is None:
return task_id
return f"{self.prefix_str}:{task_id}"

async def shutdown(self) -> None:
"""Closes redis connection."""
await self.redis_pool.disconnect()
Expand All @@ -119,7 +126,7 @@ async def set_result(
:param result: TaskiqResult instance.
"""
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id,
"name": self._task_name(task_id),
"value": self.serializer.dumpb(model_dump(result)),
}
if self.result_ex_time:
Expand All @@ -139,7 +146,7 @@ async def is_result_ready(self, task_id: str) -> bool:
:returns: True if the result is ready else False.
"""
async with Redis(connection_pool=self.redis_pool) as redis:
return bool(await redis.exists(task_id))
return bool(await redis.exists(self._task_name(task_id)))

async def get_result(
self,
Expand All @@ -154,14 +161,15 @@ async def get_result(
:raises ResultIsMissingError: if there is no result when trying to get it.
:return: task's return value.
"""
task_name = self._task_name(task_id)
async with Redis(connection_pool=self.redis_pool) as redis:
if self.keep_results:
result_value = await redis.get(
name=task_id,
name=task_name,
)
else:
result_value = await redis.getdel(
name=task_id,
name=task_name,
)

if result_value is None:
Expand Down Expand Up @@ -192,7 +200,7 @@ async def set_progress(
:param result: task's TaskProgress instance.
"""
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id + PROGRESS_KEY_SUFFIX,
"name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
"value": self.serializer.dumpb(model_dump(progress)),
}
if self.result_ex_time:
Expand All @@ -215,7 +223,7 @@ async def get_progress(
"""
async with Redis(connection_pool=self.redis_pool) as redis:
result_value = await redis.get(
name=task_id + PROGRESS_KEY_SUFFIX,
name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
)

if result_value is None:
Expand All @@ -237,6 +245,7 @@ def __init__(
result_ex_time: Optional[int] = None,
result_px_time: Optional[int] = None,
serializer: Optional[TaskiqSerializer] = None,
prefix_str: Optional[str] = None,
**connection_kwargs: Any,
) -> None:
"""
Expand All @@ -261,6 +270,7 @@ def __init__(
self.keep_results = keep_results
self.result_ex_time = result_ex_time
self.result_px_time = result_px_time
self.prefix_str = prefix_str

unavailable_conditions = any(
(
Expand All @@ -278,6 +288,11 @@ def __init__(
"Choose either result_ex_time or result_px_time.",
)

def _task_name(self, task_id: str) -> str:
if self.prefix_str is None:
return task_id
return f"{self.prefix_str}:{task_id}"

async def shutdown(self) -> None:
"""Closes redis connection."""
await self.redis.aclose() # type: ignore[attr-defined]
Expand All @@ -298,7 +313,7 @@ async def set_result(
:param result: TaskiqResult instance.
"""
redis_set_params: Dict[str, Union[str, bytes, int]] = {
"name": task_id,
"name": self._task_name(task_id),
"value": self.serializer.dumpb(model_dump(result)),
}
if self.result_ex_time:
Expand All @@ -316,7 +331,7 @@ async def is_result_ready(self, task_id: str) -> bool:

:returns: True if the result is ready else False.
"""
return bool(await self.redis.exists(task_id)) # type: ignore[attr-defined]
return bool(await self.redis.exists(self._task_name(task_id))) # type: ignore[attr-defined]

async def get_result(
self,
Expand All @@ -331,13 +346,14 @@ async def get_result(
:raises ResultIsMissingError: if there is no result when trying to get it.
:return: task's return value.
"""
task_name = self._task_name(task_id)
if self.keep_results:
result_value = await self.redis.get( # type: ignore[attr-defined]
name=task_id,
name=task_name,
)
else:
result_value = await self.redis.getdel( # type: ignore[attr-defined]
name=task_id,
name=task_name,
)

if result_value is None:
Expand Down Expand Up @@ -368,7 +384,7 @@ async def set_progress(
:param result: task's TaskProgress instance.
"""
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id + PROGRESS_KEY_SUFFIX,
"name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
"value": self.serializer.dumpb(model_dump(progress)),
}
if self.result_ex_time:
Expand All @@ -389,7 +405,7 @@ async def get_progress(
:return: task's TaskProgress instance.
"""
result_value = await self.redis.get( # type: ignore[attr-defined]
name=task_id + PROGRESS_KEY_SUFFIX,
name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
)

if result_value is None:
Expand All @@ -414,6 +430,7 @@ def __init__(
min_other_sentinels: int = 0,
sentinel_kwargs: Optional[Any] = None,
serializer: Optional[TaskiqSerializer] = None,
prefix_str: Optional[str] = None,
**connection_kwargs: Any,
) -> None:
"""
Expand Down Expand Up @@ -443,6 +460,7 @@ def __init__(
self.keep_results = keep_results
self.result_ex_time = result_ex_time
self.result_px_time = result_px_time
self.prefix_str = prefix_str

unavailable_conditions = any(
(
Expand All @@ -460,6 +478,11 @@ def __init__(
"Choose either result_ex_time or result_px_time.",
)

def _task_name(self, task_id: str) -> str:
if self.prefix_str is None:
return task_id
return f"{self.prefix_str}:{task_id}"

@asynccontextmanager
async def _acquire_master_conn(self) -> AsyncIterator[_Redis]:
async with self.sentinel.master_for(self.master_name) as redis_conn:
Expand All @@ -480,7 +503,7 @@ async def set_result(
:param result: TaskiqResult instance.
"""
redis_set_params: Dict[str, Union[str, bytes, int]] = {
"name": task_id,
"name": self._task_name(task_id),
"value": self.serializer.dumpb(model_dump(result)),
}
if self.result_ex_time:
Expand All @@ -500,7 +523,7 @@ async def is_result_ready(self, task_id: str) -> bool:
:returns: True if the result is ready else False.
"""
async with self._acquire_master_conn() as redis:
return bool(await redis.exists(task_id))
return bool(await redis.exists(self._task_name(task_id)))

async def get_result(
self,
Expand All @@ -515,14 +538,15 @@ async def get_result(
:raises ResultIsMissingError: if there is no result when trying to get it.
:return: task's return value.
"""
task_name = self._task_name(task_id)
async with self._acquire_master_conn() as redis:
if self.keep_results:
result_value = await redis.get(
name=task_id,
name=task_name,
)
else:
result_value = await redis.getdel(
name=task_id,
name=task_name,
)

if result_value is None:
Expand Down Expand Up @@ -553,7 +577,7 @@ async def set_progress(
:param result: task's TaskProgress instance.
"""
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id + PROGRESS_KEY_SUFFIX,
"name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
"value": self.serializer.dumpb(model_dump(progress)),
}
if self.result_ex_time:
Expand All @@ -576,7 +600,7 @@ async def get_progress(
"""
async with self._acquire_master_conn() as redis:
result_value = await redis.get(
name=task_id + PROGRESS_KEY_SUFFIX,
name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
)

if result_value is None:
Expand Down