From 01fd8c979264ccf3b96c6532d4e7a71e6dfe82e8 Mon Sep 17 00:00:00 2001 From: MylesBartlett Date: Thu, 1 Aug 2024 13:02:56 +0100 Subject: [PATCH] Fix variance issue with Par generic. --- pyrightconfig.json | 4 +- requirements-dev.lock | 11 ++--- requirements.lock | 1 + serox/collections/hash_map.py | 7 +-- serox/collections/hash_set.py | 29 +++++++----- serox/common.py | 10 +++++ serox/iter.py | 84 +++++++++++++++++++++++++---------- serox/option.py | 3 +- serox/range.py | 27 ++++++----- serox/vec.py | 58 +++++++++++++----------- 10 files changed, 151 insertions(+), 83 deletions(-) create mode 100644 serox/common.py diff --git a/pyrightconfig.json b/pyrightconfig.json index c6d3db1..a1fb535 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -4,9 +4,9 @@ "serox/**" ], "exclude": [ - "**/.", "**/__pycache__", "**/node_modules", + "**/.undodir", ".venv", "**/.cache" ], @@ -18,7 +18,7 @@ "reportUnusedCallResult": "error", "reportUnnecessaryTypeIgnoreComment": "warning", "reportMissingSuperCall": "warning", - "reportImportCycles": "error", + "reportImportCycles": "none", "reportShadowedImports": "warning", "reportUninitializedInstanceVariable": "error", "reportPropertyTypeMismatch": "error", diff --git a/requirements-dev.lock b/requirements-dev.lock index b58a043..0911c25 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -7,6 +7,7 @@ # all-features: false # with-sources: false # generate-hashes: false +# universal: false -e file:. black==24.4.2 @@ -44,21 +45,21 @@ packaging==24.1 # via pytest pathspec==0.12.1 # via black -pip==24.1.2 +pip==24.2 platformdirs==4.2.2 # via black # via virtualenv pluggy==1.5.0 # via pytest -pre-commit==3.7.1 +pre-commit==3.8.0 pydoclint==0.5.6 -pyright==1.1.372 -pytest==8.3.1 +pyright==1.1.374 +pytest==8.3.2 # via pytest-cov pytest-cov==5.0.0 pyyaml==6.0.1 # via pre-commit -ruff==0.5.4 +ruff==0.5.5 typing-extensions==4.12.2 # via serox virtualenv==20.26.3 diff --git a/requirements.lock b/requirements.lock index 07482e6..bd086ff 100644 --- a/requirements.lock +++ b/requirements.lock @@ -7,6 +7,7 @@ # all-features: false # with-sources: false # generate-hashes: false +# universal: false -e file:. joblib==1.4.2 diff --git a/serox/collections/hash_map.py b/serox/collections/hash_map.py index 305f368..42b2744 100644 --- a/serox/collections/hash_map.py +++ b/serox/collections/hash_map.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import Any, Generator, Hashable, Iterable, Literal, Sized, override +from serox.common import False_, True_ from serox.conftest import TESTING from serox.convert import Into from serox.default import Default @@ -123,7 +124,7 @@ def clone(self) -> HashMap[K, V]: @dataclass(repr=True, init=False) -class Keys[K, P: bool](Iterator[K, P]): +class Keys[K, P: (True_, False_)](Iterator[K, P]): def __init__(self, inner: HashMap[K, Any], /, par: P) -> None: super().__init__() self.iter = iter(inner.inner.keys()) @@ -138,7 +139,7 @@ def next(self) -> Option[K]: @dataclass(repr=True, init=False) -class Values[V, P: bool](Iterator[V, P]): +class Values[V, P: (True_, False_)](Iterator[V, P]): def __init__(self, inner: HashMap[Any, V], /, par: P) -> None: super().__init__() self.iter = iter(inner.inner.values()) @@ -153,7 +154,7 @@ def next(self) -> Option[V]: @dataclass(repr=True, init=False) -class Entries[K, V, P: bool](Iterator[Entry[K, V], P]): +class Entries[K, V, P: (True_, False_)](Iterator[Entry[K, V], P]): def __init__(self, inner: HashMap[K, V], /, par: P) -> None: super().__init__() self.iter = iter(inner.inner.items()) diff --git a/serox/collections/hash_set.py b/serox/collections/hash_set.py index 154f4d9..4db1492 100644 --- a/serox/collections/hash_set.py +++ b/serox/collections/hash_set.py @@ -1,7 +1,9 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Generator, Hashable, Iterable, Literal, Self, Sized, override +from typing import Any, Generator, Hashable, Iterable, Self, Sized, override +from typing import Iterator as NativeIterator +from serox.common import False_, True_ from serox.convert import Into from serox.default import Default from serox.iter import Extend, FromIterator, IntoIterator, IntoParIterator, Iterator @@ -125,27 +127,32 @@ def from_iter(cls, iter: Iterable[T], /) -> HashSet[T]: return HashSet(*iter) @override - def iter(self) -> Iter[T, Literal[False]]: - return Iter(self, par=False) + def iter(self) -> Iter[T, False_]: + return Iter.new(self, par=False) def __iter__(self) -> Generator[T, None, None]: yield from self.iter() @override - def par_iter(self) -> Iter[T, Literal[True]]: - return Iter(self, par=True) + def par_iter(self) -> Iter[T, True_]: + return Iter.new(self, par=True) @override def clone(self) -> HashSet[T]: return HashSet(*self.inner.copy()) -@dataclass(repr=True, init=False) -class Iter[Item, P: bool](Iterator[Item, P]): - def __init__(self, inner: HashSet[Item], /, par: P) -> None: - super().__init__() - self.iter = iter(inner.inner) - self.par = par +@dataclass(repr=True, frozen=True, kw_only=True) +class Iter[Item, Par: (True_, False_)](Iterator[Item, Par]): + iter: NativeIterator[Item] + par: Par + + @classmethod + def new[Item2, Par2: (True_, False_)]( + cls, data: HashSet[Item2], par: Par2 = True + ) -> Iter[Item2, Par2]: + iter_ = iter(data.inner) + return Iter(iter=iter_, par=par) @override def next(self) -> Option[Item]: diff --git a/serox/common.py b/serox/common.py new file mode 100644 index 0000000..bba8a94 --- /dev/null +++ b/serox/common.py @@ -0,0 +1,10 @@ +from __future__ import annotations +from typing import Literal + +__all__ = [ + "False_", + "True_", +] + +type True_ = Literal[True] +type False_ = Literal[False] diff --git a/serox/iter.py b/serox/iter.py index 11ea1ad..57657c4 100644 --- a/serox/iter.py +++ b/serox/iter.py @@ -9,7 +9,6 @@ Callable, Generator, Iterable, - Literal, Protocol, Self, cast, @@ -25,6 +24,8 @@ ) from serox.cmp import Ord +from serox.common import False_, True_ +from serox.conftest import TESTING from serox.misc import SelfAddable, SelfMultiplicable if TYPE_CHECKING: @@ -50,12 +51,13 @@ "Zip", ] + type Fn1[T, U] = Callable[[T], U] class FromIterator[A](Protocol): @classmethod - def from_iter[P: bool](cls, iter: Iterator[A, P], /) -> Self: ... + def from_iter[P: (True_, False_)](cls, iter: Iterator[A, P], /) -> Self: ... def _identity[T](x: T) -> T: @@ -63,7 +65,7 @@ def _identity[T](x: T) -> T: @runtime_checkable -class Iterator[Item, Par: bool](Protocol): +class Iterator[Item, Par: (True_, False_)](Protocol): par: Par def next(self) -> Option[Item]: ... @@ -195,9 +197,9 @@ def max[U: Ord](self: Iterator[U, Par]) -> U: def min[U: Ord](self: Iterator[U, Par]) -> U: return min(self) - def par_bridge(self) -> Iterator[Item, Literal[True]]: + def par_bridge(self) -> Iterator[Item, True_]: object.__setattr__(self, "par", True) - return cast(Iterator[Item, Literal[True]], self) + return cast(Iterator[Item, True_], self) class Chunk[Item](list[Item], FromIterator[Item]): @@ -214,7 +216,7 @@ def is_empty(self) -> bool: @dataclass -class ArrayChunk[Item, P: bool](Iterator[Chunk[Item], P]): +class ArrayChunk[Item, P: (True_, False_)](Iterator[Chunk[Item], P]): iter: Iterator[Item, P] n: int par: P @@ -231,14 +233,14 @@ def next(self) -> Option[Chunk[Item]]: class IntoIterator[T](Protocol): - def iter(self) -> Iterator[T, Literal[False]]: ... + def iter(self) -> Iterator[T, False_]: ... class IntoParIterator[T](Protocol): - def par_iter(self) -> Iterator[T, Literal[True]]: ... + def par_iter(self) -> Iterator[T, True_]: ... -class DoubleEndedIterator[Item, P: bool](Iterator[Item, P], Protocol): +class DoubleEndedIterator[Item, P: (True_, False_)](Iterator[Item, P], Protocol): def next_back(self) -> Option[Item]: ... def rev(self) -> Rev[Item, P]: @@ -246,7 +248,7 @@ def rev(self) -> Rev[Item, P]: @dataclass(repr=True) -class Filter[Item, P: bool](Iterator[Item, P]): +class Filter[Item, P: (True_, False_)](Iterator[Item, P]): iter: Iterator[Item, P] f: Fn1[Item, bool] par: P @@ -265,7 +267,7 @@ def next(self) -> Option[Item]: @dataclass(repr=True) -class FilterMap[Item, B, P: bool](Iterator[B, P]): +class FilterMap[Item, B, P: (True_, False_)](Iterator[B, P]): iter: Iterator[Item, P] f: Fn1[Item, Option[B]] par: P @@ -286,7 +288,7 @@ def next(self) -> Option[B]: @dataclass(repr=True) -class Map[Item, B, P: bool](Iterator[B, P]): +class Map[Item, B, P: (True_, False_)](Iterator[B, P]): iter: Iterator[Item, P] f: Fn1[Item, B] par: P @@ -297,7 +299,7 @@ def next(self) -> Option[B]: @dataclass(repr=True) -class Take[Item, P: bool](Iterator[Item, P]): +class Take[Item, P: (True_, False_)](Iterator[Item, P]): iter: Iterator[Item, P] _n: int par: P @@ -325,7 +327,7 @@ def nth(self, n: int) -> Option[Item]: @dataclass(repr=True) -class TakeWhile[Item, P: bool](Iterator[Item, P]): +class TakeWhile[Item, P: (True_, False_)](Iterator[Item, P]): iter: Iterator[Item, P] predicate: Fn1[Item, bool] par: P @@ -349,7 +351,7 @@ def next(self) -> Option[Item]: @dataclass(repr=True) -class Zip[A, B, P: bool](Iterator[tuple[A, B], P]): +class Zip[A, B, P: (True_, False_)](Iterator[tuple[A, B], P]): a: Iterator[A, P] b: Iterator[B, P] par: P @@ -371,7 +373,7 @@ def next(self) -> Option[tuple[A, B]]: # Parametrising the first generic of `Iterator` as `Any` to avoid a circular import. @dataclass(repr=True) -class ZipLongest[A, B, P: bool](Iterator[Any, P]): +class ZipLongest[A, B, P: (True_, False_)](Iterator[Any, P]): a: Iterator[A, P] b: Iterator[B, P] par: P @@ -393,7 +395,7 @@ def next(self) -> Option[tuple[A, B] | tuple[Null[A], B] | tuple[A, Null[B]]]: @dataclass(repr=True) -class Chain[A, P: bool](Iterator[A, P]): +class Chain[A, P: (True_, False_)](Iterator[A, P]): a: Iterator[A, P] b: Iterator[A, P] par: P @@ -410,7 +412,7 @@ def next(self) -> Option[A]: @dataclass(repr=True) -class Rev[Item, P: bool](Iterator[Item, P]): +class Rev[Item, P: (True_, False_)](Iterator[Item, P]): iter: DoubleEndedIterator[Item, P] par: P @@ -426,12 +428,27 @@ def extend_one(self, item: Item) -> None: self.extend(Some(item)) -@dataclass(repr=True, init=False) -class Bridge[Item, P: bool](Iterator[Item, P]): - def __init__(self, iter: NativeIterator[Item], par: P = False) -> None: - super().__init__() - self.iter = iter - self.par = par +@dataclass(repr=True, frozen=True, kw_only=True) +class Bridge[Item, Par: (True_, False_)](Iterator[Item, Par]): + """ + A bridge between native Python iterators and `serox` ones. + Can be parallel (`par = True`) or non-parallel (`par = False`). + """ + + iter: NativeIterator[Item] + """The native Python iterator being bridged.""" + par: Par + """Whether to parallelise the iterator.""" + + @classmethod + def new[Item2, Par2: (True_, False_)]( + cls, iter: NativeIterator[Item2], /, par: Par2 = True + ) -> Bridge[Item2, Par2]: + return Bridge(iter=iter, par=par) + + @classmethod + def par_new[Item2](cls, iter: NativeIterator[Item2], /) -> Bridge[Item2, True_]: + return Bridge(iter=iter, par=True) @override def next(self) -> Option[Item]: @@ -441,3 +458,22 @@ def next(self) -> Option[Item]: return Some(self.iter.__next__()) except StopIteration: return Null() + + +if TESTING: + + def test_par_invariance(): + from .collections import HashMap + from .vec import Vec + + values = Vec(*range(4)) + keys = ["foo", "bar", "baz"] + bridge = Bridge.new(iter(keys), par=False) + bridge = Bridge(iter=iter(keys), par=False) + mapped = values.iter().map(lambda x: x**2) + _ = bridge.zip(mapped).collect(HashMap[str, int]) + + bridge = Bridge.new(iter(keys), par=True) + # shouldn't be able to combine parallel iterators with non-parallel ones + # for consistent typing + _ = bridge.zip(values.iter()) # pyright: ignore[reportArgumentType] diff --git a/serox/option.py b/serox/option.py index 43b9a0d..a48b2d7 100644 --- a/serox/option.py +++ b/serox/option.py @@ -15,6 +15,7 @@ override, ) +from serox.common import False_, True_ from serox.convert import From, Into from serox.default import Default from serox.iter import DoubleEndedIterator, IntoIterator @@ -301,7 +302,7 @@ def is_null[T](x: Option[T], /) -> TypeGuard[Null[T]]: repr=True, slots=True, ) -class Iter[Item, P: bool](DoubleEndedIterator[Item, P]): +class Iter[Item, P: (True_, False_)](DoubleEndedIterator[Item, P]): item: Option[Item] par: P diff --git a/serox/range.py b/serox/range.py index 17f0e9d..bc41920 100644 --- a/serox/range.py +++ b/serox/range.py @@ -2,7 +2,7 @@ import copy from dataclasses import dataclass, field from types import EllipsisType -from typing import Literal, Sized, TypeVar, override +from typing import Sized, override from serox.conftest import TESTING from serox.fmt import Debug @@ -10,13 +10,12 @@ from serox.misc import Clone from serox.option import Null, Option, Some +from .common import False_, True_ + __all__ = ["Range"] type Idx = int -P = TypeVar("P", Literal[True], Literal[False]) -P2 = TypeVar("P2", Literal[True], Literal[False]) - @dataclass( eq=True, @@ -24,8 +23,8 @@ init=False, repr=False, ) -class Range( - DoubleEndedIterator[Idx, P], +class Range[Par: (True_, False_)]( + DoubleEndedIterator[Idx, Par], Clone, Debug, Sized, @@ -34,10 +33,10 @@ class Range( start: Idx = field(init=False) end: Idx = field(init=False) - par: P = field(init=False) + par: Par = field(init=False) _ptr: Idx = field(init=False) - def __init__(self, start: Idx | EllipsisType, end: Idx, *, par: P = False) -> None: + def __init__(self, start: Idx | EllipsisType, end: Idx, *, par: Par = False) -> None: start = 0 if start is ... else start object.__setattr__(self, "start", start) object.__setattr__(self, "end", end) @@ -46,7 +45,9 @@ def __init__(self, start: Idx | EllipsisType, end: Idx, *, par: P = False) -> No super().__init__() @classmethod - def new(cls, start: Idx | EllipsisType, end: Idx, *, par: P2 = False) -> Range[P2]: + def new[Par2: (True_, False_)]( + cls, start: Idx | EllipsisType, end: Idx, *, par: Par2 = False + ) -> Range[Par2]: return Range(start, end, par=par) @override @@ -66,8 +67,10 @@ def next_back(self) -> Option[Idx]: return Null() @classmethod - def from_sized(cls, sized: Sized, /, start: int = 0, *, par: P2 = False) -> Range[P2]: - return Range[P2](start, len(sized), par=par) + def from_sized[Par2: (True_, False_)]( + cls, sized: Sized, /, start: int = 0, *, par: Par2 = False + ) -> Range[Par2]: + return Range[Par2](start, len(sized), par=par) def contains(self, item: Idx, /) -> bool: return self.start <= item < self.end @@ -76,7 +79,7 @@ def is_empty(self) -> bool: return not (self.start < self.end) @override - def clone(self) -> Range[P]: + def clone(self) -> Range[Par]: return Range(start=self.start, end=self.end, par=self.par) def len(self) -> int: diff --git a/serox/vec.py b/serox/vec.py index a9f4344..011c7c8 100644 --- a/serox/vec.py +++ b/serox/vec.py @@ -1,12 +1,12 @@ from __future__ import annotations # noqa: I001 -from dataclasses import dataclass, field +from dataclasses import dataclass from random import Random as Rng +from .common import True_, False_ from typing import ( Any, Callable, Generator, Iterable, - Literal, Self, override, ) @@ -40,21 +40,29 @@ type Fn1[T, U] = Callable[[T], U] -@dataclass(repr=True) -class Iter[Item, P: bool](DoubleEndedIterator[Item, P]): +@dataclass(repr=True, frozen=True, kw_only=True) +class Iter[Item, P: (True_, False_)](DoubleEndedIterator[Item, P]): data: SizedIndexable[Item] par: P - _ptr: int = 0 - _end_or_len: int = field(init=False) + end_or_len: int + ptr: int - def __post_init__(self) -> None: - self._end_or_len = len(self.data) + @classmethod + def new[Item2, Par2: (True_, False_)]( + cls, data: SizedIndexable[Item2], /, par: Par2 = False + ) -> Iter[Item2, Par2]: + return Iter( + data=data, + par=par, + ptr=0, + end_or_len=len(data), + ) def _next(self, back: bool) -> Option[Item]: - if self._ptr < self._end_or_len: - ptr = -self._ptr if back else self._ptr + if self.ptr < self.end_or_len: + ptr = -self.ptr if back else self.ptr item = Some(self.data[ptr]) - self._ptr += 1 + object.__setattr__(self, "ptr", self.ptr + 1) return item return Null() @@ -100,7 +108,7 @@ def into(self) -> list[T]: return self.inner @override - def iter(self) -> DoubleEndedIterator[T, Literal[False]]: + def iter(self) -> DoubleEndedIterator[T, False_]: """ Returns an iterator over the underlying list. @@ -108,10 +116,10 @@ def iter(self) -> DoubleEndedIterator[T, Literal[False]]: :returns: A double-ended iterator over the underlying list. """ - return Iter(self, par=False) + return Iter.new(self, par=False) @override - def par_iter(self) -> DoubleEndedIterator[T, Literal[True]]: + def par_iter(self) -> DoubleEndedIterator[T, True_]: """ Returns a parallel iterator over the underlying list. @@ -119,7 +127,7 @@ def par_iter(self) -> DoubleEndedIterator[T, Literal[True]]: :returns: A parallel double-ended iterator over the underlying list. """ - return Iter(self, par=True) + return Iter.new(self, par=True) def __iter__(self) -> Generator[T, None, None]: yield from self.iter() @@ -236,7 +244,7 @@ def insert(self, index: int, element: T) -> None: @override @classmethod - def from_iter[P: bool](cls, s: Iterator[T, P], /) -> Vec[T]: + def from_iter[P: (True_, False_)](cls, s: Iterator[T, P], /) -> Vec[T]: return Vec(*s) @override @@ -261,7 +269,7 @@ def choose(self, rng: Rng) -> Option[T]: return Null() return Some(rng.choice(self.inner)) - def choose_multiple(self, rng: Rng, amount: int) -> Iter[T, Literal[False]]: + def choose_multiple(self, rng: Rng, amount: int) -> Iter[T, False_]: """ Emulates `SliceRandom::choose_multiple` from the `rand` crate. @@ -277,7 +285,7 @@ def choose_multiple(self, rng: Rng, amount: int) -> Iter[T, Literal[False]]: amount = min(amount, self.len()) # TODO: Sample indices instead, passing them to a dedicated iterator to lazily sample # elements from `inner`. - return Iter(rng.sample(self.inner, k=amount), par=False) + return Iter.new(rng.sample(self.inner, k=amount), par=False) @qmark def choose_multiple_weighted( @@ -285,7 +293,7 @@ def choose_multiple_weighted( rng: Rng, amount: int, weight: Fn1[T, float], - ) -> Result[Iter[T, Literal[False]], ValueError]: + ) -> Result[Iter[T, False_], ValueError]: """ Similar to :meth:`~Vec.choose_multiple`, but where the likelihood of each element’s inclusion in the output may be specified. The elements are returned in an arbitrary, @@ -313,7 +321,7 @@ def call(x: T) -> float: weights = self.iter().map(call).collect(Vec[float]).into() # cap the sample size at the population size amount = min(amount, self.len()) - return Ok(Iter(rng.choices(self.inner, weights=weights, k=amount), par=False)) + return Ok(Iter.new(rng.choices(self.inner, weights=weights, k=amount), par=False)) def shuffle(self, rng: Rng) -> None: rng.shuffle(self.inner) @@ -380,7 +388,7 @@ def test_vec(): del item def test_take(): - vec = Vec[float].full(3.14, 3) + vec = Vec[float].full(math.pi, 3) taken = vec.iter().take(2).collect(Vec[float]) assert len(taken) == 2 taken = vec.iter().take(5).collect(Vec[float]) @@ -388,10 +396,10 @@ def test_take(): def test_take_while(): vec = Vec[float](*range(5)) - taken = vec.iter().take_while(lambda x: x < 5).collect(Vec[float]) - assert len(taken) == vec.len() - taken = vec.iter().take_while(lambda x: x < 3).collect(Vec[float]) - assert len(taken) == 3 + lt5 = vec.iter().take_while(lambda x: x < 5).collect(Vec[float]) + assert len(lt5) == vec.len() + lt3 = vec.iter().take_while(lambda x: x < 3).collect(Vec[float]) + assert len(lt3) == 3 vec2 = Vec[Vec[int]](Vec(1, 2, 3)) ls = [2, 3, 4] vec.extend(ls)