Skip to content

Commit

Permalink
Merge pull request #4 from simonsobs/dev
Browse files Browse the repository at this point in the history
Update the function signature of `st_ranges()`
  • Loading branch information
TaiSakuma authored Oct 1, 2024
2 parents a8743fd + e8e2be7 commit 5211d3e
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 13 deletions.
6 changes: 4 additions & 2 deletions src/nextline_test_utils/strategies/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class StMinMaxValuesFactory(Protocol[T]): # pragma: no cover
def __call__(
self, *, min_value: Optional[T] = None, max_value: Optional[T] = None
self, min_value: Optional[T] = None, max_value: Optional[T] = None
) -> st.SearchStrategy[T]: ...


Expand Down Expand Up @@ -68,7 +68,9 @@ def st_graphql_ints(


def st_ranges(
st_: StMinMaxValuesFactory[T],
st_: StMinMaxValuesFactory[T] = st.integers, # type: ignore
/,
*,
min_start: Optional[T] = None,
max_start: Optional[T] = None,
min_end: Optional[T] = None,
Expand Down
2 changes: 1 addition & 1 deletion tests/strategies/test_st_datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@given(st.data())
def test_st_datetimes(data: st.DataObject) -> None:
min_, max_ = data.draw(st_ranges(st_=st_datetimes))
min_, max_ = data.draw(st_ranges(st_datetimes))
dt_ = data.draw(st_datetimes(min_value=min_, max_value=max_))
assert dt_.tzinfo is None
assert dt_.fold == 0
Expand Down
4 changes: 2 additions & 2 deletions tests/strategies/test_st_ints.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
@given(st.data())
def test_st_sqlite_ints(data: st.DataObject) -> None:
min_, max_ = data.draw(
st_ranges(st_=st.integers, max_start=SQLITE_INT_MAX, min_end=SQLITE_INT_MIN)
st_ranges(st.integers, max_start=SQLITE_INT_MAX, min_end=SQLITE_INT_MIN)
)

i = data.draw(st_sqlite_ints(min_value=min_, max_value=max_))
Expand All @@ -31,7 +31,7 @@ def test_st_sqlite_ints(data: st.DataObject) -> None:
@given(st.data())
def test_st_graphql_ints(data: st.DataObject) -> None:
min_, max_ = data.draw(
st_ranges(st_=st.integers, max_start=GRAPHQL_MAX_INT, min_end=GRAPHQL_MIN_INT)
st_ranges(st.integers, max_start=GRAPHQL_MAX_INT, min_end=GRAPHQL_MIN_INT)
)

i = data.draw(st_graphql_ints(min_value=min_, max_value=max_))
Expand Down
17 changes: 10 additions & 7 deletions tests/strategies/test_st_ranges.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import Optional, TypeVar
from typing import Any, Optional, TypeVar, overload

from hypothesis import given, settings
from hypothesis import strategies as st
Expand Down Expand Up @@ -50,7 +50,7 @@ def st_max(min_value: Optional[T]) -> st.SearchStrategy[Optional[T]]:


class StRangesKwargs(TypedDict, Generic[T], total=False):
st_: StMinMaxValuesFactory[T]
# st_: StMinMaxValuesFactory[T]
min_start: Optional[T]
max_start: Optional[T]
min_end: Optional[T]
Expand All @@ -63,9 +63,10 @@ class StRangesKwargs(TypedDict, Generic[T], total=False):

@st.composite
def st_st_ranges_kwargs(
draw: st.DrawFn, st_: StMinMaxValuesFactory[T]
draw: st.DrawFn, st_: StMinMaxValuesFactory[T] | None = None
) -> StRangesKwargs[T]:
kwargs = StRangesKwargs(st_=st_)
st_ = st_ or st.integers # type: ignore
kwargs = StRangesKwargs[T]()

min_start, max_start = draw(st_min_max_start(st_=st_)) # type: ignore
if min_start is not None:
Expand Down Expand Up @@ -93,7 +94,7 @@ def st_st_ranges_kwargs(

@given(st.data())
def test_st_st_ranges_kwargs(data: st.DataObject) -> None:
st_ = data.draw(st.sampled_from([st_graphql_ints, st_datetimes]))
st_ = data.draw(st.sampled_from([None, st_graphql_ints, st_datetimes]))
kwargs = data.draw(st_st_ranges_kwargs(st_=st_)) # type: ignore

min_start = kwargs.get('min_start')
Expand All @@ -108,10 +109,12 @@ def test_st_st_ranges_kwargs(data: st.DataObject) -> None:
@given(st.data())
@settings(max_examples=1000)
def test_st_ranges(data: st.DataObject) -> None:
st_ = data.draw(st.sampled_from([st_graphql_ints, st_datetimes]))
st_ = data.draw(st.sampled_from([None, st_graphql_ints, st_datetimes]))
kwargs = data.draw(st_st_ranges_kwargs(st_=st_)) # type: ignore

args = (st_,) if st_ is not None else ()

start, end = data.draw(st_ranges(**kwargs))
start, end = data.draw(st_ranges(*args, **kwargs)) # type: ignore

allow_start_none = kwargs.get('allow_start_none', True)
if not allow_start_none:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_safe_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_repr() -> None:
def test_safe_compare(data: st.DataObject) -> None:
allow_equal = data.draw(st.booleans())
none_or_small, none_or_large = data.draw(
st_ranges(st_=st.integers, allow_equal=allow_equal)
st_ranges(st.integers, allow_equal=allow_equal)
)
if allow_equal:
assert safe_compare(none_or_small) <= safe_compare(none_or_large)
Expand Down

0 comments on commit 5211d3e

Please sign in to comment.