Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convert _default_params to dataclass #237

Merged
merged 12 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ repos:
args: ["--number"]

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
rev: v3.1.0
hooks:
- id: prettier
files: \.(json|yml|yaml|toml)
Expand Down
8 changes: 4 additions & 4 deletions src/cachier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
from .config import (
disable_caching,
enable_caching,
get_default_params,
set_default_params,
get_global_params,
set_global_params,
)
from .core import cachier

__all__ = [
"cachier",
"set_default_params",
"get_default_params",
"set_global_params",
"get_global_params",
"enable_caching",
"disable_caching",
]
101 changes: 62 additions & 39 deletions src/cachier/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import hashlib
import os
import pickle
from typing import Optional, TypedDict, Union
from collections.abc import Mapping
from dataclasses import dataclass, replace
from typing import Optional, Union

from ._types import Backend, HashFunc, Mongetter

Expand All @@ -16,35 +18,24 @@ def _default_hash_func(args, kwds):
return hashlib.sha256(serialized).hexdigest()


class Params(TypedDict):
"""Type definition for cachier parameters."""

caching_enabled: bool
hash_func: HashFunc
backend: Backend
mongetter: Optional[Mongetter]
stale_after: datetime.timedelta
next_time: bool
cache_dir: Union[str, os.PathLike]
pickle_reload: bool
separate_files: bool
wait_for_calc_timeout: int
allow_none: bool


_default_params: Params = {
"caching_enabled": True,
"hash_func": _default_hash_func,
"backend": "pickle",
"mongetter": None,
"stale_after": datetime.timedelta.max,
"next_time": False,
"cache_dir": "~/.cachier/",
"pickle_reload": True,
"separate_files": False,
"wait_for_calc_timeout": 0,
"allow_none": False,
}
@dataclass
class Params:
"""Default definition for cachier parameters."""

caching_enabled: bool = True
hash_func: HashFunc = _default_hash_func
backend: Backend = "pickle"
mongetter: Optional[Mongetter] = None
stale_after: datetime.timedelta = datetime.timedelta.max
next_time: bool = False
cache_dir: Union[str, os.PathLike] = "~/.cachier/"
pickle_reload: bool = True
separate_files: bool = False
wait_for_calc_timeout: int = 0
allow_none: bool = False


_global_params = Params()


def _update_with_defaults(
Expand All @@ -57,11 +48,25 @@ def _update_with_defaults(
if kw_name in func_kwargs:
return func_kwargs.pop(kw_name)
if param is None:
return cachier.config._default_params[name]
return getattr(cachier.config._global_params, name)
return param


def set_default_params(**params):
def set_default_params(**params: Mapping) -> None:
"""Configure default parameters applicable to all memoized functions."""
# It is kept for backwards compatibility with desperation warning
import warnings

warnings.warn(
"Called `set_default_params` is deprecated and will be removed."
" Please use `set_global_params` instead.",
DeprecationWarning,
stacklevel=2,
)
set_global_params(**params)


def set_global_params(**params: Mapping) -> None:
"""Configure global parameters applicable to all memoized functions.

This function takes the same keyword parameters as the ones defined in the
Expand All @@ -76,28 +81,46 @@ def set_default_params(**params):
"""
import cachier

valid_params = (
p for p in params.items() if p[0] in cachier.config._default_params
valid_params = {
k: v
for k, v in params.items()
if hasattr(cachier.config._global_params, k)
}
cachier.config._global_params = replace(
cachier.config._global_params, **valid_params
)


def get_default_params() -> Params:
"""Get current set of default parameters."""
# It is kept for backwards compatibility with desperation warning
import warnings

warnings.warn(
"Called `get_default_params` is deprecated and will be removed."
" Please use `get_global_params` instead.",
DeprecationWarning,
stacklevel=2,
)
_default_params.update(valid_params)
return get_global_params()


def get_default_params():
def get_global_params() -> Params:
"""Get current set of default parameters."""
import cachier

return cachier.config._default_params
return cachier.config._global_params


def enable_caching():
"""Enable caching globally."""
import cachier

cachier.config._default_params["caching_enabled"] = True
cachier.config._global_params.caching_enabled = True


def disable_caching():
"""Disable caching globally."""
import cachier

cachier.config._default_params["caching_enabled"] = False
cachier.config._global_params.caching_enabled = False
5 changes: 3 additions & 2 deletions src/cachier/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Backend,
HashFunc,
Mongetter,
_default_params,
_update_with_defaults,
)
from .cores.base import RecalculationNeeded, _BaseCore
Expand Down Expand Up @@ -176,6 +175,8 @@ def cachier(
None will not be cached and are recalculated every call.

"""
from .config import _global_params

# Check for deprecated parameters
if hash_params is not None:
message = (
Expand Down Expand Up @@ -244,7 +245,7 @@ def func_wrapper(*args, **kwds):
_print = lambda x: None # noqa: E731
if verbose:
_print = print
if ignore_cache or not _default_params["caching_enabled"]:
if ignore_cache or not _global_params.caching_enabled:
return (
func(args[0], **kwargs)
if core.func_is_method
Expand Down
8 changes: 4 additions & 4 deletions tests/test_core_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import pytest

from cachier import cachier, get_default_params
from cachier import cachier, get_global_params
from cachier.cores.mongo import MissingMongetter


def test_get_default_params():
params = get_default_params()
assert tuple(sorted(params)) == (
params = get_global_params()
assert sorted(vars(params).keys()) == [
"allow_none",
"backend",
"cache_dir",
Expand All @@ -20,7 +20,7 @@ def test_get_default_params():
"separate_files",
"stale_after",
"wait_for_calc_timeout",
)
]


def test_bad_name(name="nope"):
Expand Down
25 changes: 13 additions & 12 deletions tests/test_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,23 @@
import random
import threading
import time
from dataclasses import replace

import pytest

import cachier
from tests.test_mongo_core import _test_mongetter

MONGO_DELTA = datetime.timedelta(seconds=3)
_default_params = cachier.get_default_params().copy()
_copied_defaults = replace(cachier.get_global_params())


def setup_function():
cachier.set_default_params(**_default_params)
cachier.set_global_params(**vars(_copied_defaults))


def teardown_function():
cachier.set_default_params(**_default_params)
cachier.set_global_params(**vars(_copied_defaults))


def test_hash_func_default_param():
Expand All @@ -30,7 +31,7 @@ def slow_hash_func(args, kwds):
def fast_hash_func(args, kwds):
return "hash"

cachier.set_default_params(hash_func=slow_hash_func)
cachier.set_global_params(hash_func=slow_hash_func)

@cachier.cachier()
def global_test_1():
Expand All @@ -51,7 +52,7 @@ def global_test_2():


def test_backend_default_param():
cachier.set_default_params(backend="memory")
cachier.set_global_params(backend="memory")

@cachier.cachier()
def global_test_1():
Expand All @@ -67,7 +68,7 @@ def global_test_2():

@pytest.mark.mongo
def test_mongetter_default_param():
cachier.set_default_params(mongetter=_test_mongetter)
cachier.set_global_params(mongetter=_test_mongetter)

@cachier.cachier()
def global_test_1():
Expand All @@ -82,7 +83,7 @@ def global_test_2():


def test_cache_dir_default_param(tmpdir):
cachier.set_default_params(cache_dir=tmpdir / "1")
cachier.set_global_params(cache_dir=tmpdir / "1")

@cachier.cachier()
def global_test_1():
Expand All @@ -97,7 +98,7 @@ def global_test_2():


def test_separate_files_default_param(tmpdir):
cachier.set_default_params(separate_files=True)
cachier.set_global_params(separate_files=True)

@cachier.cachier(cache_dir=tmpdir / "1")
def global_test_1(arg_1, arg_2):
Expand All @@ -117,7 +118,7 @@ def global_test_2(arg_1, arg_2):


def test_allow_none_default_param(tmpdir):
cachier.set_default_params(
cachier.set_global_params(
allow_none=True,
separate_files=True,
verbose_cache=True,
Expand Down Expand Up @@ -167,7 +168,7 @@ def _stale_after_test(arg_1, arg_2):
"""Some function."""
return random.random() + arg_1 + arg_2

cachier.set_default_params(stale_after=MONGO_DELTA)
cachier.set_global_params(stale_after=MONGO_DELTA)

_stale_after_test.clear_cache()
val1 = _stale_after_test(1, 2)
Expand All @@ -187,7 +188,7 @@ def _stale_after_next_time(arg_1, arg_2):
"""Some function."""
return random.random()

cachier.set_default_params(stale_after=NEXT_AFTER_DELTA, next_time=True)
cachier.set_global_params(stale_after=NEXT_AFTER_DELTA, next_time=True)

_stale_after_next_time.clear_cache()
val1 = _stale_after_next_time(1, 2)
Expand Down Expand Up @@ -217,7 +218,7 @@ def _calls_wait_for_calc_timeout_slow(res_queue):
res = _wait_for_calc_timeout_slow(1, 2)
res_queue.put(res)

cachier.set_default_params(wait_for_calc_timeout=2)
cachier.set_global_params(wait_for_calc_timeout=2)
_wait_for_calc_timeout_slow.clear_cache()
res_queue = queue.Queue()
thread1 = threading.Thread(
Expand Down
4 changes: 3 additions & 1 deletion tests/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,17 @@ def test_separate_processes():

def test_global_disable():
@cachier.cachier()
def get_random():
def get_random() -> float:
return random()

get_random.clear_cache()
result_1 = get_random()
result_2 = get_random()
cachier.disable_caching()
assert cachier.config._global_params.caching_enabled is False
result_3 = get_random()
cachier.enable_caching()
assert cachier.config._global_params.caching_enabled is True
result_4 = get_random()
assert result_1 == result_2 == result_4
assert result_1 != result_3
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pickle_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import pandas as pd

from cachier import cachier
from cachier.core import _default_params
from cachier.config import _global_params


def _get_decorated_func(func, **kwargs):
Expand Down Expand Up @@ -329,7 +329,7 @@ def _bad_cache(arg_1, arg_2):
".tests.test_pickle_core._bad_cache_"
f"{hashlib.sha256(pickle.dumps((0.13, 0.02))).hexdigest()}"
)
EXPANDED_CACHIER_DIR = os.path.expanduser(_default_params["cache_dir"])
EXPANDED_CACHIER_DIR = os.path.expanduser(_global_params.cache_dir)
_BAD_CACHE_FPATH = os.path.join(EXPANDED_CACHIER_DIR, _BAD_CACHE_FNAME)
_BAD_CACHE_FPATH_SEPARATE_FILES = os.path.join(
EXPANDED_CACHIER_DIR, _BAD_CACHE_FNAME_SEPARATE_FILES
Expand Down
Loading