diff --git a/src/nextline_test_utils/strategies/misc.py b/src/nextline_test_utils/strategies/misc.py index c9684dd..84c6f1d 100644 --- a/src/nextline_test_utils/strategies/misc.py +++ b/src/nextline_test_utils/strategies/misc.py @@ -13,7 +13,7 @@ class StMinMaxValuesFactory(Protocol[T]): # pragma: no cover def __call__( - self, *, min_value: Optional[T] = None, max_value: Optional[T] = None + self, min_value: Optional[T] = None, max_value: Optional[T] = None ) -> st.SearchStrategy[T]: ... @@ -68,7 +68,9 @@ def st_graphql_ints( def st_ranges( - st_: StMinMaxValuesFactory[T], + st_: StMinMaxValuesFactory[T] = st.integers, # type: ignore + /, + *, min_start: Optional[T] = None, max_start: Optional[T] = None, min_end: Optional[T] = None, diff --git a/tests/strategies/test_st_datetimes.py b/tests/strategies/test_st_datetimes.py index 0d8e3de..344f0ee 100644 --- a/tests/strategies/test_st_datetimes.py +++ b/tests/strategies/test_st_datetimes.py @@ -7,7 +7,7 @@ @given(st.data()) def test_st_datetimes(data: st.DataObject) -> None: - min_, max_ = data.draw(st_ranges(st_=st_datetimes)) + min_, max_ = data.draw(st_ranges(st_datetimes)) dt_ = data.draw(st_datetimes(min_value=min_, max_value=max_)) assert dt_.tzinfo is None assert dt_.fold == 0 diff --git a/tests/strategies/test_st_ints.py b/tests/strategies/test_st_ints.py index 14cca35..23d9591 100644 --- a/tests/strategies/test_st_ints.py +++ b/tests/strategies/test_st_ints.py @@ -19,7 +19,7 @@ @given(st.data()) def test_st_sqlite_ints(data: st.DataObject) -> None: min_, max_ = data.draw( - st_ranges(st_=st.integers, max_start=SQLITE_INT_MAX, min_end=SQLITE_INT_MIN) + st_ranges(st.integers, max_start=SQLITE_INT_MAX, min_end=SQLITE_INT_MIN) ) i = data.draw(st_sqlite_ints(min_value=min_, max_value=max_)) @@ -31,7 +31,7 @@ def test_st_sqlite_ints(data: st.DataObject) -> None: @given(st.data()) def test_st_graphql_ints(data: st.DataObject) -> None: min_, max_ = data.draw( - st_ranges(st_=st.integers, max_start=GRAPHQL_MAX_INT, min_end=GRAPHQL_MIN_INT) + st_ranges(st.integers, max_start=GRAPHQL_MAX_INT, min_end=GRAPHQL_MIN_INT) ) i = data.draw(st_graphql_ints(min_value=min_, max_value=max_)) diff --git a/tests/strategies/test_st_ranges.py b/tests/strategies/test_st_ranges.py index 25a4e47..a7fa9d1 100644 --- a/tests/strategies/test_st_ranges.py +++ b/tests/strategies/test_st_ranges.py @@ -1,5 +1,5 @@ import sys -from typing import Optional, TypeVar +from typing import Any, Optional, TypeVar, overload from hypothesis import given, settings from hypothesis import strategies as st @@ -50,7 +50,7 @@ def st_max(min_value: Optional[T]) -> st.SearchStrategy[Optional[T]]: class StRangesKwargs(TypedDict, Generic[T], total=False): - st_: StMinMaxValuesFactory[T] + # st_: StMinMaxValuesFactory[T] min_start: Optional[T] max_start: Optional[T] min_end: Optional[T] @@ -63,9 +63,10 @@ class StRangesKwargs(TypedDict, Generic[T], total=False): @st.composite def st_st_ranges_kwargs( - draw: st.DrawFn, st_: StMinMaxValuesFactory[T] + draw: st.DrawFn, st_: StMinMaxValuesFactory[T] | None = None ) -> StRangesKwargs[T]: - kwargs = StRangesKwargs(st_=st_) + st_ = st_ or st.integers # type: ignore + kwargs = StRangesKwargs[T]() min_start, max_start = draw(st_min_max_start(st_=st_)) # type: ignore if min_start is not None: @@ -93,7 +94,7 @@ def st_st_ranges_kwargs( @given(st.data()) def test_st_st_ranges_kwargs(data: st.DataObject) -> None: - st_ = data.draw(st.sampled_from([st_graphql_ints, st_datetimes])) + st_ = data.draw(st.sampled_from([None, st_graphql_ints, st_datetimes])) kwargs = data.draw(st_st_ranges_kwargs(st_=st_)) # type: ignore min_start = kwargs.get('min_start') @@ -108,10 +109,12 @@ def test_st_st_ranges_kwargs(data: st.DataObject) -> None: @given(st.data()) @settings(max_examples=1000) def test_st_ranges(data: st.DataObject) -> None: - st_ = data.draw(st.sampled_from([st_graphql_ints, st_datetimes])) + st_ = data.draw(st.sampled_from([None, st_graphql_ints, st_datetimes])) kwargs = data.draw(st_st_ranges_kwargs(st_=st_)) # type: ignore + + args = (st_,) if st_ is not None else () - start, end = data.draw(st_ranges(**kwargs)) + start, end = data.draw(st_ranges(*args, **kwargs)) # type: ignore allow_start_none = kwargs.get('allow_start_none', True) if not allow_start_none: diff --git a/tests/test_safe_compare.py b/tests/test_safe_compare.py index c963d5f..c104c57 100644 --- a/tests/test_safe_compare.py +++ b/tests/test_safe_compare.py @@ -14,7 +14,7 @@ def test_repr() -> None: def test_safe_compare(data: st.DataObject) -> None: allow_equal = data.draw(st.booleans()) none_or_small, none_or_large = data.draw( - st_ranges(st_=st.integers, allow_equal=allow_equal) + st_ranges(st.integers, allow_equal=allow_equal) ) if allow_equal: assert safe_compare(none_or_small) <= safe_compare(none_or_large)