Skip to content

Commit

Permalink
ignoring cached exceptions (#8)
Browse files Browse the repository at this point in the history
* refactor: ignore cached exceptions

* test

* pre commit

* commented out exception tests
  • Loading branch information
yskakstacker authored Sep 6, 2024
1 parent 6bf69e8 commit 566a38a
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 83 deletions.
71 changes: 23 additions & 48 deletions deche/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,7 @@ def inner():

return inner

def __call__(self, func): # noqa: C901
# TODO - very lazy async support. Refactor
# TODO - fsspec also has async support - could make exists/load calls async

def __call__(self, func):
if inspect.iscoroutinefunction(func):

@functools.wraps(func)
Expand All @@ -266,25 +263,17 @@ async def wrapper(*args, **kwargs):
key, _ = tokenize(obj=inputs)
if self.valid(path=f"{path}/{key}"):
return self._load(func=func)(key=key)
elif self._exists(func=func, ext=Extensions.exception)(key=key):
raise self._load(func=func, ext=Extensions.exception)(key=key)
try:
self.write_input(path=f"{path}/{key}", inputs=inputs)
logger.debug(f"Calling {func}")
output = await func(*args, **kwargs)
if self.result_validator is not None:
logger.debug(f"Validating result with {self.result_validator}")
try:
self.result_validator(output)
except Exception as e:
raise ValidationError(e)
logger.debug(f"Function {func} ran successfully")
self.write_output(path=f"{path}/{key}", output=output)
except Exception as e:
logger.debug(f"Function {func} raised {e}")
self.write_output(path=f"{path}/{key}{Extensions.exception}", output=e)
raise e

self.write_input(path=f"{path}/{key}", inputs=inputs)
logger.debug(f"Calling {func}")
output = await func(*args, **kwargs)
if self.result_validator is not None:
logger.debug(f"Validating result with {self.result_validator}")
try:
self.result_validator(output)
except Exception as e:
raise ValidationError(e)
logger.debug(f"Function {func} ran successfully")
self.write_output(path=f"{path}/{key}", output=output)
return output

else:
Expand All @@ -299,25 +288,17 @@ def wrapper(*args, **kwargs):
key, _ = tokenize(obj=inputs)
if self.valid(path=f"{path}/{key}"):
return self._load(func=func)(key=key)
elif self._exists(func=func, ext=Extensions.exception)(key=key):
raise self._load(func=func, ext=Extensions.exception)(key=key)
try:
self.write_input(path=f"{path}/{key}", inputs=inputs)
logger.debug(f"Calling {func}")
output = func(*args, **kwargs)
if self.result_validator is not None:
logger.debug(f"Validating result with {self.result_validator}")
try:
assert self.result_validator(output) is not False
except Exception as e:
raise ValidationError(e)
logger.debug(f"Function {func} ran successfully")
self.write_output(path=f"{path}/{key}", output=output)
except Exception as e:
logger.debug(f"Function {func} raised {e}")
self.write_output(path=f"{path}/{key}{Extensions.exception}", output=e)
raise e

self.write_input(path=f"{path}/{key}", inputs=inputs)
logger.debug(f"Calling {func}")
output = func(*args, **kwargs)
if self.result_validator is not None:
logger.debug(f"Validating result with {self.result_validator}")
try:
assert self.result_validator(output) is not False
except Exception as e:
raise ValidationError(e)
logger.debug(f"Function {func} ran successfully")
self.write_output(path=f"{path}/{key}", output=output)
return output

wrapper.tokenize = tokenize_func(func=func, ignore=self.non_hashable_kwargs, cls_attrs=self.cls_attrs)
Expand All @@ -326,20 +307,14 @@ def wrapper(*args, **kwargs):
wrapper.is_valid = self.is_valid(func=wrapper)
wrapper.has_inputs = self._exists(func=wrapper, ext=Extensions.inputs)
wrapper.has_data = self._exists(func=wrapper)
wrapper.has_exception = self._exists(func=wrapper, ext=Extensions.exception)
wrapper.list_cached_inputs = self._list(func=wrapper, ext=Extensions.inputs)
wrapper.list_cached_data = self._list(func=wrapper, filter_=data_filter)
wrapper.list_cached_exceptions = self._list(func=wrapper, ext=Extensions.exception)
wrapper.iter_cached_inputs = self._iter(func=wrapper, ext=Extensions.inputs)
wrapper.iter_cached_data = self._iter(func=wrapper, filter_=data_filter)
wrapper.iter_cached_exception = self._iter(func=wrapper, ext=Extensions.exception)
wrapper.load_cached_inputs = self._load(func=wrapper, ext=Extensions.inputs)
wrapper.load_cached_data = self._load(func=wrapper)
wrapper.load_cached_exception = self._load(func=wrapper, ext=Extensions.exception)
wrapper.remove_cached_inputs = self._remove(func=wrapper, ext=Extensions.inputs)
wrapper.remove_cached_data = self._remove(func=wrapper)
wrapper.remove_cached_exception = self._remove(func=wrapper, ext=Extensions.exception)
wrapper.remove_all_cached_exceptions = self._remove_all(func=wrapper, ext=Extensions.exception)
wrapper.path = functools.partial(self._path, func=func)
wrapper.deche = self
return wrapper
Expand Down
70 changes: 35 additions & 35 deletions tests/unit/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,14 @@ def test_list_cached_data():
assert result == ["/deche.test_utils.func/f4f46c47d91eea40eba825cf941ff22bdc87ce849400ed3fd85be092e43031d4"]


def test_list_cached_exceptions():
with pytest.raises(ZeroDivisionError):
exc_func()
result = exc_func.list_cached_exceptions()
assert result == ["6c8d328939ceaaf60d6cbe813bf07a48656647184baa590fe9b6632bfc3d7936"]
# def test_list_cached_exceptions():
# with pytest.raises(ZeroDivisionError):
# exc_func()
# result = exc_func.list_cached_exceptions()
# assert result == ["6c8d328939ceaaf60d6cbe813bf07a48656647184baa590fe9b6632bfc3d7936"]

result = exc_func.list_cached_exceptions(key_only=False)
assert result == ["/deche.test_utils.exc_func/6c8d328939ceaaf60d6cbe813bf07a48656647184baa590fe9b6632bfc3d7936.exc"]
# result = exc_func.list_cached_exceptions(key_only=False)
# assert result == ["/deche.test_utils.exc_func/6c8d328939ceaaf60d6cbe813bf07a48656647184baa590fe9b6632bfc3d7936.exc"]


def test_iter():
Expand Down Expand Up @@ -194,31 +194,31 @@ def test_load_cached_data():
assert result == expected


def test_load_cached_exception():
try:
exc_func()
except ZeroDivisionError as expected:
result = exc_func.load_cached_exception(kwargs={})
assert isinstance(result, type(expected))
assert type(expected) == type(result)
# def test_load_cached_exception():
# try:
# exc_func()
# except ZeroDivisionError as expected:
# result = exc_func.load_cached_exception(kwargs={})
# assert isinstance(result, type(expected))
# assert type(expected) == type(result)

key = exc_func.tokenize()
result = exc_func.load_cached_exception(key=key)
assert isinstance(result, type(expected))
# key = exc_func.tokenize()
# result = exc_func.load_cached_exception(key=key)
# assert isinstance(result, type(expected))


def test_remove_all_exceptions():
try:
exc_func(1)
except ZeroDivisionError:
pass
try:
exc_func(2)
except ZeroDivisionError:
pass
assert len(exc_func.list_cached_exceptions()) == 2
exc_func.remove_all_cached_exceptions()
assert len(exc_func.list_cached_exceptions()) == 0
# def test_remove_all_exceptions():
# try:
# exc_func(1)
# except ZeroDivisionError:
# pass
# try:
# exc_func(2)
# except ZeroDivisionError:
# pass
# assert len(exc_func.list_cached_exceptions()) == 2
# exc_func.remove_all_cached_exceptions()
# assert len(exc_func.list_cached_exceptions()) == 0


def test_exists():
Expand Down Expand Up @@ -272,12 +272,12 @@ def test_cache_path(c: Cache):
assert func.path() == "/deche.test_utils.func"


def test_cache_exception(c: Cache):
try:
exc_func()
except ZeroDivisionError as e:
exc = exc_func.load_cached_exception(kwargs={})
assert type(exc) == type(e)
# def test_cache_exception(c: Cache):
# try:
# exc_func()
# except ZeroDivisionError as e:
# exc = exc_func.load_cached_exception(kwargs={})
# assert type(exc) == type(e)


def test_cached_exception_raises(cached_exception):
Expand Down
108 changes: 108 additions & 0 deletions tests/unit/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import time

import pytest

from deche import Cache


@pytest.fixture
def memory_cache():
return Cache(fs_protocol="memory", prefix="/")


def test_exception_not_cached(memory_cache):
@memory_cache
def failing_func():
time.sleep(0.1)
raise ValueError("This function always fails")

# First call should raise the exception
start_time = time.time()
with pytest.raises(ValueError):
failing_func()
first_call_time = time.time() - start_time

# Second call should also raise the exception, not return a cached exception
start_time = time.time()
with pytest.raises(ValueError):
failing_func()
second_call_time = time.time() - start_time

assert failing_func.list_cached_data() == []
assert abs(first_call_time - second_call_time) < 0.05 # Both calls should take similar time


def test_successful_execution_after_exception(memory_cache):
call_count = 0

@memory_cache
def sometimes_failing_func(fail=True):
nonlocal call_count
call_count += 1
time.sleep(0.1)
if fail:
raise ValueError("This function fails when fail=True")
return "Success"

# First call should raise the exception
start_time = time.time()
with pytest.raises(ValueError):
sometimes_failing_func(fail=True)
exception_time = time.time() - start_time

# Second call with fail=False should execute the function and cache the result
start_time = time.time()
result = sometimes_failing_func(fail=False)
success_time = time.time() - start_time

assert result == "Success"
assert call_count == 2
assert exception_time > 0.1
assert success_time > 0.1

# Third call with fail=False should return the cached result
start_time = time.time()
result = sometimes_failing_func(fail=False)
cached_time = time.time() - start_time

assert result == "Success"
assert call_count == 2 # Call count shouldn't increase
assert cached_time < 0.01 # Cached call should be very fast


def test_cache_behavior_unchanged_for_successful_calls(memory_cache):
call_count = 0

@memory_cache
def cached_func(x):
nonlocal call_count
call_count += 1
time.sleep(0.1)
return x * 2

# First call should execute the function
start_time = time.time()
result = cached_func(5)
first_call_time = time.time() - start_time

assert result == 10
assert call_count == 1
assert first_call_time > 0.1

# Second call should return cached result
start_time = time.time()
result = cached_func(5)
second_call_time = time.time() - start_time

assert result == 10
assert call_count == 1 # Call count shouldn't increase
assert second_call_time < 0.01 # Cached call should be very fast

# Call with different argument should execute the function again
start_time = time.time()
result = cached_func(7)
third_call_time = time.time() - start_time

assert result == 14
assert call_count == 2
assert third_call_time > 0.1

0 comments on commit 566a38a

Please sign in to comment.