diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 5f84dfa..661fbe9 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -6,6 +6,7 @@
 from typing import (
     Any,
     Protocol,
+    Self,
     TypeVar,
     Union,
     cast,
@@ -232,7 +233,7 @@ def __getitem__(self, index: str) -> Any:
         ...

     @overload
-    def __getitem__(self: TBatch, index: IndexType) -> TBatch:
+    def __getitem__(self, index: IndexType) -> Self:
         ...
@@ -241,22 +242,22 @@ def __getitem__(self, index: str | IndexType) -> Any:
         ...

     def __setitem__(self, index: str | IndexType, value: Any) -> None:
         ...

-    def __iadd__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
+    def __iadd__(self, other: Self | Number | np.number) -> Self:
         ...

-    def __add__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
+    def __add__(self, other: Self | Number | np.number) -> Self:
         ...

-    def __imul__(self: TBatch, value: Number | np.number) -> TBatch:
+    def __imul__(self, value: Number | np.number) -> Self:
         ...

-    def __mul__(self: TBatch, value: Number | np.number) -> TBatch:
+    def __mul__(self, value: Number | np.number) -> Self:
         ...

-    def __itruediv__(self: TBatch, value: Number | np.number) -> TBatch:
+    def __itruediv__(self, value: Number | np.number) -> Self:
         ...

-    def __truediv__(self: TBatch, value: Number | np.number) -> TBatch:
+    def __truediv__(self, value: Number | np.number) -> Self:
         ...

     def __repr__(self) -> str:
@@ -274,7 +275,7 @@ def to_torch(
         """Change all numpy.ndarray to torch.Tensor in-place."""
         ...

-    def cat_(self, batches: TBatch | Sequence[dict | TBatch]) -> None:
+    def cat_(self, batches: Self | Sequence[dict | Self]) -> None:
         """Concatenate a list of (or one) Batch objects into current batch."""
         ...
@@ -298,7 +299,7 @@ def cat(batches: Sequence[dict | TBatch]) -> TBatch:
         """
         ...

-    def stack_(self, batches: Sequence[dict | TBatch], axis: int = 0) -> None:
+    def stack_(self, batches: Sequence[dict | Self], axis: int = 0) -> None:
         """Stack a list of Batch object into current batch."""
         ...
@@ -327,7 +328,7 @@ def stack(batches: Sequence[dict | TBatch], axis: int = 0) -> TBatch:
         """
         ...

-    def empty_(self: TBatch, index: slice | IndexType | None = None) -> TBatch:
+    def empty_(self, index: slice | IndexType | None = None) -> Self:
         """Return an empty Batch object with 0 or None filled.

         If "index" is specified, it will only reset the specific indexed-data.
@@ -362,7 +363,7 @@ def empty(batch: TBatch, index: IndexType | None = None) -> TBatch:
         """
         ...

-    def update(self, batch: dict | TBatch | None = None, **kwargs: Any) -> None:
+    def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
         """Update this batch from another dict/Batch."""
         ...
@@ -373,11 +374,11 @@ def is_empty(self, recurse: bool = False) -> bool:
         ...

     def split(
-        self: TBatch,
+        self,
         size: int,
         shuffle: bool = True,
         merge_last: bool = False,
-    ) -> Iterator[TBatch]:
+    ) -> Iterator[Self]:
         """Split whole data into multiple small batches.

         :param int size: divide the data batch with the given size, but one
@@ -457,7 +458,7 @@ def __getitem__(self, index: str) -> Any:
         ...

     @overload
-    def __getitem__(self: TBatch, index: IndexType) -> TBatch:
+    def __getitem__(self, index: IndexType) -> Self:
         ...

     def __getitem__(self, index: str | IndexType) -> Any:
@@ -501,7 +502,7 @@ def __setitem__(self, index: str | IndexType, value: Any) -> None:
                 else:
                     self.__dict__[key][index] = None

-    def __iadd__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
+    def __iadd__(self, other: Self | Number | np.number) -> Self:
         """Algebraic addition with another Batch instance in-place."""
         if isinstance(other, Batch):
             for (batch_key, obj), value in zip(
@@ -521,11 +522,11 @@ def __iadd__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
             return self
         raise TypeError("Only addition of Batch or number is supported.")

-    def __add__(self: TBatch, other: TBatch | Number | np.number) -> TBatch:
+    def __add__(self, other: Self | Number | np.number) -> Self:
         """Algebraic addition with another Batch instance out-of-place."""
         return deepcopy(self).__iadd__(other)

-    def __imul__(self: TBatch, value: Number | np.number) -> TBatch:
+    def __imul__(self, value: Number | np.number) -> Self:
         """Algebraic multiplication with a scalar value in-place."""
         assert _is_number(value), "Only multiplication by a number is supported."
         for batch_key, obj in self.__dict__.items():
@@ -534,11 +535,11 @@ def __imul__(self: TBatch, value: Number | np.number) -> TBatch:
             self.__dict__[batch_key] *= value
         return self

-    def __mul__(self: TBatch, value: Number | np.number) -> TBatch:
+    def __mul__(self, value: Number | np.number) -> Self:
         """Algebraic multiplication with a scalar value out-of-place."""
         return deepcopy(self).__imul__(value)

-    def __itruediv__(self: TBatch, value: Number | np.number) -> TBatch:
+    def __itruediv__(self, value: Number | np.number) -> Self:
         """Algebraic division with a scalar value in-place."""
         assert _is_number(value), "Only division by a number is supported."
         for batch_key, obj in self.__dict__.items():
@@ -547,7 +548,7 @@ def __itruediv__(self: TBatch, value: Number | np.number) -> TBatch:
             self.__dict__[batch_key] /= value
         return self

-    def __truediv__(self: TBatch, value: Number | np.number) -> TBatch:
+    def __truediv__(self, value: Number | np.number) -> Self:
         """Algebraic division with a scalar value out-of-place."""
         return deepcopy(self).__itruediv__(value)
@@ -604,7 +605,7 @@ def to_torch(
                     obj = obj.type(dtype)  # noqa: PLW2901
                 self.__dict__[batch_key] = obj

-    def __cat(self: TBatch, batches: Sequence[dict | TBatch], lens: list[int]) -> None:
+    def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
         """Private method for Batch.cat_.

         ::
@@ -798,7 +799,7 @@ def stack(batches: Sequence[dict | TBatch], axis: int = 0) -> TBatch:
         # can't cast to a generic type, so we have to ignore the type here
         return batch  # type: ignore

-    def empty_(self: TBatch, index: slice | IndexType | None = None) -> TBatch:
+    def empty_(self, index: slice | IndexType | None = None) -> Self:
         for batch_key, obj in self.items():
             if isinstance(obj, torch.Tensor):  # most often case
                 self.__dict__[batch_key][index] = 0
@@ -826,7 +827,7 @@ def empty_(self: TBatch, index: slice | IndexType | None = None) -> TBatch:
     def empty(batch: TBatch, index: IndexType | None = None) -> TBatch:
         return deepcopy(batch).empty_(index)

-    def update(self, batch: dict | TBatch | None = None, **kwargs: Any) -> None:
+    def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
         if batch is None:
             self.update(kwargs)
             return
@@ -902,11 +903,11 @@ def shape(self) -> list[int]:
         )

     def split(
-        self: TBatch,
+        self,
         size: int,
         shuffle: bool = True,
         merge_last: bool = False,
-    ) -> Iterator[TBatch]:
+    ) -> Iterator[Self]:
         length = len(self)
         if size == -1:
             size = length
diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py
index 0a55300..3609b79 100644
--- a/tianshou/data/buffer/base.py
+++ b/tianshou/data/buffer/base.py
@@ -1,4 +1,4 @@
-from typing import Any, cast
+from typing import Any, Self, cast

 import h5py
 import numpy as np
@@ -111,7 +111,7 @@ def save_hdf5(self, path: str, compression: str | None = None) -> None:
             to_hdf5(self.__dict__, f, compression=compression)

     @classmethod
-    def load_hdf5(cls, path: str, device: str | None = None) -> "ReplayBuffer":
+    def load_hdf5(cls, path: str, device: str | None = None) -> Self:
         """Load replay buffer from HDF5 file."""
         with h5py.File(path, "r") as f:
             buf = cls.__new__(cls)
@@ -128,7 +128,7 @@ def from_data(
         truncated: h5py.Dataset,
         done: h5py.Dataset,
         obs_next: h5py.Dataset,
-    ) -> "ReplayBuffer":
+    ) -> Self:
         size = len(obs)
         assert all(
             len(dset) == size for dset in [obs, act, rew, terminated, truncated, done, obs_next]
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index fcf8653..7bca1db 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -62,13 +62,13 @@ def __init__(
         policy: BasePolicy,
         env: gym.Env | BaseVectorEnv,
         buffer: ReplayBuffer | None = None,
-        preprocess_fn: Callable[..., Batch] | None = None,
+        preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
         exploration_noise: bool = False,
     ) -> None:
         super().__init__()
         if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
             warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
-            self.env = DummyVectorEnv([lambda: env])
+            self.env = DummyVectorEnv([lambda: env])  # type: ignore
         else:
             self.env = env  # type: ignore
         self.env_num = len(self.env)
@@ -413,7 +413,7 @@ def __init__(
         policy: BasePolicy,
         env: BaseVectorEnv,
         buffer: ReplayBuffer | None = None,
-        preprocess_fn: Callable[..., Batch] | None = None,
+        preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None,
         exploration_noise: bool = False,
     ) -> None:
         # assert env.is_async
diff --git a/tianshou/env/utils.py b/tianshou/env/utils.py
index bc66321..b5be12f 100644
--- a/tianshou/env/utils.py
+++ b/tianshou/env/utils.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Union
+from typing import Any

 import cloudpickle
 import gymnasium
@@ -6,12 +6,7 @@
 from tianshou.env.pettingzoo_env import PettingZooEnv

-if TYPE_CHECKING:
-    import gym
-
-# TODO: remove gym entirely? Currently mypy complains in several places
-# if gym.Env is removed from the Union
-ENV_TYPE = Union[gymnasium.Env, "gym.Env", PettingZooEnv]
+ENV_TYPE = gymnasium.Env | PettingZooEnv

 gym_new_venv_step_type = tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 44cef97..a779a76 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -11,7 +11,7 @@
 from torch import nn

 from tianshou.data import ReplayBuffer, to_numpy, to_torch_as
-from tianshou.data.batch import BatchProtocol, TBatch
+from tianshou.data.batch import BatchProtocol
 from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol
 from tianshou.utils import MultipleLRSchedulers
@@ -185,7 +185,7 @@ def forward(
         """

     @overload
-    def map_action(self, act: TBatch) -> TBatch:
+    def map_action(self, act: BatchProtocol) -> BatchProtocol:
         ...

     @overload
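
Note (commentary, not part of the patch): every hunk above applies the same pattern, replacing the bound TypeVar `TBatch` with `typing.Self` (PEP 673, Python 3.11+) so that methods which return or accept the current instance keep the subclass's type. A minimal, self-contained sketch of the idea, with hypothetical class names:

from copy import deepcopy
from typing import Self


class Container:
    """Stand-in for a Batch-like class; illustrative only."""

    def __init__(self, **entries: float) -> None:
        self.entries = dict(entries)

    def scale_(self, value: float) -> Self:
        # In-place update that returns the instance, mirroring __imul__ above.
        self.entries = {key: val * value for key, val in self.entries.items()}
        return self

    def scale(self, value: float) -> Self:
        # Out-of-place variant, mirroring __mul__ above.
        return deepcopy(self).scale_(value)


class PrioritizedContainer(Container):
    """A subclass that inherits scale()/scale_() unchanged."""


scaled = PrioritizedContainer(a=1.0, b=2.0).scale(3.0)
# Type checkers infer `scaled` as PrioritizedContainer rather than Container,
# without declaring a TypeVar bound to the base class.

The same reasoning applies to the classmethods in buffer/base.py: annotating `load_hdf5` and `from_data` with `Self` types the result as whatever subclass they are called on, instead of the hard-coded "ReplayBuffer" string.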