From 164cf84122913044aa4f90d1ec97e5f1e29f2db4 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Tue, 16 Apr 2024 15:55:16 +0200 Subject: [PATCH] Implement non-inplace to_numpy for Batch * Breaking change: Previous in-place `Batch.to_numpy` is now `Batch.to_numpy_` (following naming convention of other in-place methods). * Update places where in-place was expected * Add tests for both to_numpy/to_numpy_ --- docs/01_tutorials/03_batch.rst | 4 ++-- docs/02_notebooks/L1_Batch.ipynb | 2 +- test/base/test_batch.py | 26 +++++++++++++++++++++++++- tianshou/data/batch.py | 31 ++++++++++++++++++++++++------- tianshou/data/utils/converter.py | 2 +- 5 files changed, 53 insertions(+), 12 deletions(-) diff --git a/docs/01_tutorials/03_batch.rst b/docs/01_tutorials/03_batch.rst index 71f82f84e..46fa86b3d 100644 --- a/docs/01_tutorials/03_batch.rst +++ b/docs/01_tutorials/03_batch.rst @@ -485,8 +485,8 @@ Miscellaneous Notes tensor([[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]]) - >>> # data.to_numpy is also available - >>> data.to_numpy() + >>> # data.to_numpy_ is also available + >>> data.to_numpy_() .. raw:: html diff --git a/docs/02_notebooks/L1_Batch.ipynb b/docs/02_notebooks/L1_Batch.ipynb index 9c80349da..54008ee64 100644 --- a/docs/02_notebooks/L1_Batch.ipynb +++ b/docs/02_notebooks/L1_Batch.ipynb @@ -331,7 +331,7 @@ }, "outputs": [], "source": [ - "batch_cat.to_numpy()\n", + "batch_cat.to_numpy_()\n", "print(batch_cat)\n", "batch_cat.to_torch()\n", "print(batch_cat)" diff --git a/test/base/test_batch.py b/test/base/test_batch.py index b694abcf5..0ce219e75 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -478,7 +478,7 @@ def test_batch_from_to_numpy_without_copy() -> None: a_mem_addr_orig = batch.a.__array_interface__["data"][0] c_mem_addr_orig = batch.b.c.__array_interface__["data"][0] batch.to_torch() - batch.to_numpy() + batch.to_numpy_() a_mem_addr_new = batch.a.__array_interface__["data"][0] c_mem_addr_new = batch.b.c.__array_interface__["data"][0] assert a_mem_addr_new == a_mem_addr_orig @@ -703,6 +703,30 @@ def test_to_dict_nested_batch_with_torch_tensor() -> None: assert not DeepDiff(batch.to_dict(recurse=True), expected) +class TestToNumpy: + """Tests for `Batch.to_numpy()` and its in-place counterpart `Batch.to_numpy_()` .""" + + @staticmethod + def test_to_numpy() -> None: + batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])}) + new_batch: Batch = Batch.to_numpy(batch) + assert id(batch) != id(new_batch) + assert isinstance(batch.b, torch.Tensor) + assert isinstance(batch.c.d, torch.Tensor) + + assert isinstance(new_batch.b, np.ndarray) + assert isinstance(new_batch.c.d, np.ndarray) + + @staticmethod + def test_to_numpy_() -> None: + batch = Batch(a=1, b=torch.arange(5), c={"d": torch.tensor([1, 2, 3])}) + id_batch = id(batch) + batch.to_numpy_() + assert id_batch == id(batch) + assert isinstance(batch.b, np.ndarray) + assert isinstance(batch.c.d, np.ndarray) + + if __name__ == "__main__": test_batch() test_batch_over_batch() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 8e14dafd2..4eee0cd81 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -272,7 +272,12 @@ def __iter__(self) -> Iterator[Self]: def __eq__(self, other: Any) -> bool: ... - def to_numpy(self) -> None: + @staticmethod + def to_numpy(batch: TBatch) -> TBatch: + """Change all torch.Tensor to numpy.ndarray and return a new Batch.""" + ... + + def to_numpy_(self) -> None: """Change all torch.Tensor to numpy.ndarray in-place.""" ... @@ -508,10 +513,10 @@ def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return False - self.to_numpy() - other.to_numpy() - this_dict = self.to_dict(recurse=True) - other_dict = other.to_dict(recurse=True) + this_batch_no_torch_tensor: Batch = Batch.to_numpy(self) + other_batch_no_torch_tensor: Batch = Batch.to_numpy(other) + this_dict = this_batch_no_torch_tensor.to_dict(recurse=True) + other_dict = other_batch_no_torch_tensor.to_dict(recurse=True) return not DeepDiff(this_dict, other_dict) @@ -614,12 +619,24 @@ def __repr__(self) -> str: self_str = self.__class__.__name__ + "()" return self_str - def to_numpy(self) -> None: + @staticmethod + def to_numpy(batch: TBatch) -> TBatch: + batch_dict = deepcopy(batch) + for batch_key, obj in batch_dict.items(): + if isinstance(obj, torch.Tensor): + batch_dict.__dict__[batch_key] = obj.detach().cpu().numpy() + elif isinstance(obj, Batch): + obj = Batch.to_numpy(obj) + batch_dict.__dict__[batch_key] = obj + + return batch_dict + + def to_numpy_(self) -> None: for batch_key, obj in self.items(): if isinstance(obj, torch.Tensor): self.__dict__[batch_key] = obj.detach().cpu().numpy() elif isinstance(obj, Batch): - obj.to_numpy() + obj.to_numpy_() def to_torch( self, diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 2df462da5..7edf3ff45 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -26,7 +26,7 @@ def to_numpy(x: Any) -> Batch | np.ndarray: return np.array(None, dtype=object) if isinstance(x, dict | Batch): x = Batch(x) if isinstance(x, dict) else deepcopy(x) - x.to_numpy() + x.to_numpy_() return x if isinstance(x, list | tuple): return to_numpy(_parse_value(x))