Refactoring/remove is empty batch (#1144)
Closes: #1108

### API Update

- The method `Batch.is_empty` has been dropped. #1144

### Breaking Changes

- The method `Batch.is_empty` has been removed. Check a `Batch` for emptiness with
`len(batch.get_keys()) == 0` (no keys at all) or `len(batch) == 0` (no scalar/tensor
leaf values) instead; see the migration sketch below. #1144
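
A minimal migration sketch (the values shown match the examples in the updated tutorial):

```python
from tianshou.data import Batch

b = Batch(a=Batch(), b=Batch(c=Batch()))

# Before (removed in #1144):
#   b.is_empty()              -> False
#   b.is_empty(recurse=True)  -> True

# After:
print(len(b.get_keys()) == 0)  # False: direct emptiness, i.e. the Batch has no keys at all
print(len(b) == 0)             # True: recursive emptiness, i.e. no scalar/tensor leaf values
```
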
dantp-ai authored Jul 18, 2024
1 parent 7a4e5f1 commit 3dad4af
Showing 8 changed files with 55 additions and 94 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -87,6 +87,7 @@ instead of just `nn.Module`. #1032
Can be considered a bugfix. #1063
- The methods `to_numpy` and `to_torch` are not in-place anymore
(use `to_numpy_` or `to_torch_` instead). #1098, #1117
- The method `Batch.is_empty` has been removed. Instead, check for emptiness with `len(batch.get_keys()) == 0` or `len(batch) == 0`. #1144
- Logging:
- `BaseLogger.prepare_dict_for_logging` is now abstract. #1074
- Removed deprecated and unused `BasicLogger` (only affects users who subclassed it). #1074
18 changes: 9 additions & 9 deletions docs/01_tutorials/03_batch.rst
@@ -324,35 +324,35 @@ Still, we can use a tree (in the right) to show the structure of ``Batch`` objects

Reserved keys mean that values will eventually be attached to them. The values can be scalars, tensors, or even **Batch** objects. Understanding this is critical to understanding the behavior of ``Batch`` when dealing with heterogeneous Batches.

The introduction of reserved keys gives rise to the need to check if a key is reserved. Tianshou provides ``Batch.is_empty`` to achieve this.
The introduction of reserved keys gives rise to the need to check if a key is reserved.

.. raw:: html

<details>
<summary>Examples of Batch.is_empty</summary>
<summary>Examples of checking whether a Batch is empty</summary>

.. code-block:: python
>>> Batch().is_empty()
>>> len(Batch().get_keys()) == 0
True
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
>>> len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) == 0
False
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
>>> len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0
True
>>> Batch(d=1).is_empty()
>>> len(Batch(d=1).get_keys()) == 0
False
>>> Batch(a=np.float64(1.0)).is_empty()
>>> len(Batch(a=np.float64(1.0)).get_keys()) == 0
False
.. raw:: html

</details><br>

The ``Batch.is_empty`` function has an option to decide whether to identify direct emptiness (just a ``Batch()``) or to identify recursive emptiness (a ``Batch`` object without any scalar/tensor leaf nodes).
To check whether a Batch is empty, use ``len(batch.get_keys()) == 0`` to test for direct emptiness (just a ``Batch()``) or ``len(batch) == 0`` to test for recursive emptiness (a ``Batch`` object without any scalar/tensor leaf nodes).
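
A compact contrast of the two checks, mirroring the examples above:

.. code-block:: python

    >>> b = Batch(a=Batch(), b=Batch(c=Batch()))
    >>> len(b.get_keys()) == 0   # direct emptiness: b does have keys
    False
    >>> len(b) == 0              # recursive emptiness: no scalar/tensor leaves
    True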

.. note::

Do not get confused with ``Batch.is_empty`` and ``Batch.empty``. ``Batch.empty`` and its in-place variant ``Batch.empty_`` are used to set some values to zeros or None. Check the API documentation for further details.
Do not confuse emptiness checks with ``Batch.empty``. ``Batch.empty`` and its in-place variant ``Batch.empty_`` are used to set some values to zeros or None. Check the API documentation for further details.
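
A minimal sketch of the distinction, assuming ``empty_`` zeroes numeric leaf values in place as described above (see the API documentation for exact semantics):

.. code-block:: python

    >>> import numpy as np
    >>> b = Batch(a=np.array([1.0, 2.0]))
    >>> _ = b.empty_()           # in-place reset: numeric values become zeros
    >>> b.a
    array([0., 0.])
    >>> len(b.get_keys()) == 0   # the key 'a' is still present, so b is not empty
    False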


Length and Shape
36 changes: 18 additions & 18 deletions test/base/test_batch.py
@@ -15,23 +15,23 @@

def test_batch() -> None:
assert list(Batch()) == []
assert Batch().is_empty()
assert not Batch(b={"c": {}}).is_empty()
assert Batch(b={"c": {}}).is_empty(recurse=True)
assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
assert not Batch(d=1).is_empty()
assert not Batch(a=np.float64(1.0)).is_empty()
assert len(Batch().get_keys()) == 0
assert len(Batch(b={"c": {}}).get_keys()) != 0
assert len(Batch(b={"c": {}})) == 0
assert len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) != 0
assert len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0
assert len(Batch(d=1).get_keys()) != 0
assert len(Batch(a=np.float64(1.0)).get_keys()) != 0
assert len(Batch(a=[1, 2, 3], b={"c": {}})) == 3
assert not Batch(a=[1, 2, 3]).is_empty()
assert len(Batch(a=[1, 2, 3]).get_keys()) != 0
b = Batch({"a": [4, 4], "b": [5, 5]}, c=[None, None])
assert b.c.dtype == object
b = Batch(d=[None], e=[starmap], f=Batch)
assert b.d.dtype == b.e.dtype == object
assert b.f == Batch
b = Batch()
b.update()
assert b.is_empty()
assert len(b.get_keys()) == 0
b.update(c=[3, 5])
assert np.allclose(b.c, [3, 5])
# mimic the behavior of dict.update, where kwargs can overwrite keys
@@ -141,7 +141,7 @@ def test_batch() -> None:
assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
assert batch2_sum.a.d.f.is_empty()
assert len(batch2_sum.a.d.f.get_keys()) == 0
with pytest.raises(TypeError):
batch2 += [1] # type: ignore # error is raised explicitly
batch3 = Batch(a={"c": np.zeros(1), "d": Batch(e=np.array([0.0]), f=np.array([3.0]))})
@@ -255,7 +255,7 @@ def test_batch_cat_and_stack() -> None:
ans = Batch.cat([a, b, a])
assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
assert ans.a.t.is_empty()
assert len(ans.a.t.get_keys()) == 0

b1.stack_([b2])
assert isinstance(b1.a.d.e, np.ndarray)
@@ -296,7 +296,7 @@ def test_batch_cat_and_stack() -> None:
b=torch.cat([torch.zeros(3, 3), b2.b]),
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])),
)
assert ans.a.is_empty()
assert len(ans.a.get_keys()) == 0
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)

@@ -325,7 +325,7 @@ def test_batch_cat_and_stack() -> None:
assert np.allclose(d.d, [0, 6, 9])

# test stack with empty Batch()
assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
assert len(Batch.stack([Batch(), Batch(), Batch()]).get_keys()) == 0
a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
b = Batch(a=4, b=5, d=6, e=Batch())
c = Batch(c=7, b=6, d=9, e=Batch())
@@ -334,12 +334,12 @@ def test_batch_cat_and_stack() -> None:
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
assert d.e.is_empty()
assert len(d.e.get_keys()) == 0
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2], axis=-1)
assert test.a.is_empty()
assert test.b.is_empty()
assert len(test.a.get_keys()) == 0
assert len(test.b.get_keys()) == 0
assert np.allclose(test.common.c, np.stack([b1.common.c, b2.common.c], axis=-1))

b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
@@ -362,9 +362,9 @@ def test_batch_cat_and_stack() -> None:

# exceptions
batch_cat: Batch = Batch.cat([])
assert batch_cat.is_empty()
assert len(batch_cat.get_keys()) == 0
batch_stack: Batch = Batch.stack([])
assert batch_stack.is_empty()
assert len(batch_stack.get_keys()) == 0
b1 = Batch(e=[4, 5], d=6)
b2 = Batch(e=[4, 6])
with pytest.raises(ValueError):
4 changes: 2 additions & 2 deletions test/base/test_buffer.py
@@ -1378,5 +1378,5 @@ def test_custom_key() -> None:
sampled_batch.__dict__[key],
Batch,
):
assert batch.__dict__[key].is_empty()
assert sampled_batch.__dict__[key].is_empty()
assert len(batch.__dict__[key].get_keys()) == 0
assert len(sampled_batch.__dict__[key].get_keys()) == 0
78 changes: 19 additions & 59 deletions tianshou/data/batch.py
@@ -190,7 +190,7 @@ def alloc_by_keys_diff(
if key in meta.get_keys():
if isinstance(meta[key], Batch) and isinstance(batch[key], Batch):
alloc_by_keys_diff(meta[key], batch[key], size, stack)
elif isinstance(meta[key], Batch) and meta[key].is_empty():
elif isinstance(meta[key], Batch) and len(meta[key].get_keys()) == 0:
meta[key] = create_value(batch[key], size, stack)
else:
meta[key] = create_value(batch[key], size, stack)
@@ -393,9 +393,6 @@ def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None:
def __len__(self) -> int:
...

def is_empty(self, recurse: bool = False) -> bool:
...

def split(
self,
size: int,
@@ -514,7 +511,7 @@ def __getitem__(self, index: str | IndexType) -> Any:
if len(batch_items) > 0:
new_batch = Batch()
for batch_key, obj in batch_items:
if isinstance(obj, Batch) and obj.is_empty():
if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
new_batch.__dict__[batch_key] = Batch()
else:
new_batch.__dict__[batch_key] = obj[index]
@@ -574,13 +571,13 @@ def __iadd__(self, other: Self | Number | np.number) -> Self:
other.__dict__.values(),
strict=True,
): # TODO are keys consistent?
if isinstance(obj, Batch) and obj.is_empty():
if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] += value
return self
if _is_number(other):
for batch_key, obj in self.items():
if isinstance(obj, Batch) and obj.is_empty():
if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] += other
return self
@@ -594,7 +591,7 @@ def __imul__(self, value: Number | np.number) -> Self:
"""Algebraic multiplication with a scalar value in-place."""
assert _is_number(value), "Only multiplication by a number is supported."
for batch_key, obj in self.__dict__.items():
if isinstance(obj, Batch) and obj.is_empty():
if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] *= value
return self
@@ -607,7 +604,7 @@ def __itruediv__(self, value: Number | np.number) -> Self:
"""Algebraic division with a scalar value in-place."""
assert _is_number(value), "Only division by a number is supported."
for batch_key, obj in self.__dict__.items():
if isinstance(obj, Batch) and obj.is_empty():
if isinstance(obj, Batch) and len(obj.get_keys()) == 0:
continue
self.__dict__[batch_key] /= value
return self
@@ -722,7 +719,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
{
batch_key
for batch_key, obj in batch.items()
if not (isinstance(obj, Batch) and obj.is_empty())
if not (isinstance(obj, Batch) and len(obj.get_keys()) == 0)
}
for batch in batches
]
@@ -753,7 +750,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
if key not in batch.__dict__:
continue
value = batch.get(key)
if isinstance(value, Batch) and value.is_empty():
if isinstance(value, Batch) and len(value.get_keys()) == 0:
continue
try:
self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value
@@ -771,28 +768,27 @@ def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
if len(batch) > 0:
batch_list.append(Batch(batch))
elif isinstance(batch, Batch):
# x.is_empty() means that x is Batch() and should be ignored
if not batch.is_empty():
if len(batch.get_keys()) != 0:
batch_list.append(batch)
else:
raise ValueError(f"Cannot concatenate {type(batch)} in Batch.cat_")
if len(batch_list) == 0:
return
batches = batch_list
try:
# x.is_empty(recurse=True) here means x is a nested empty batch
# len(batch) == 0 here means batch is a nested empty batch
# like Batch(a=Batch), and we have to treat it as length zero and
# keep it.
lens = [0 if batch.is_empty(recurse=True) else len(batch) for batch in batches]
lens = [0 if len(batch) == 0 else len(batch) for batch in batches]
except TypeError as exception:
raise ValueError(
"Batch.cat_ meets an exception. Maybe because there is any "
f"scalar in {batches} but Batch.cat_ does not support the "
"concatenation of scalar.",
) from exception
if not self.is_empty():
if len(self.get_keys()) != 0:
batches = [self, *list(batches)]
lens = [0 if self.is_empty(recurse=True) else len(self), *lens]
lens = [0 if len(self) == 0 else len(self), *lens]
self.__cat(batches, lens)

@staticmethod
@@ -809,22 +805,21 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
if len(batch) > 0:
batch_list.append(Batch(batch))
elif isinstance(batch, Batch):
# x.is_empty() means that x is Batch() and should be ignored
if not batch.is_empty():
if len(batch.get_keys()) != 0:
batch_list.append(batch)
else:
raise ValueError(f"Cannot concatenate {type(batch)} in Batch.stack_")
if len(batch_list) == 0:
return
batches = batch_list
if not self.is_empty():
if len(self.get_keys()) != 0:
batches = [self, *batches]
# collect non-empty keys
keys_map = [
{
batch_key
for batch_key, obj in batch.items()
if not (isinstance(obj, BatchProtocol) and obj.is_empty())
if not (isinstance(obj, BatchProtocol) and len(obj.get_keys()) == 0)
}
for batch in batches
]
@@ -870,7 +865,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
# TODO: fix code/annotations s.t. the ignores can be removed
if (
isinstance(value, BatchProtocol) # type: ignore
and value.is_empty() # type: ignore
and len(value.get_keys()) == 0 # type: ignore
):
continue # type: ignore
try:
@@ -930,7 +925,7 @@ def __len__(self) -> int:
# TODO: causes inconsistent behavior to batch with empty batches
# and batch with empty sequences of other type. Remove, but only after
# Buffer and Collectors have been improved to no longer rely on this
if isinstance(obj, Batch) and obj.is_empty(recurse=True):
if isinstance(obj, Batch) and len(obj) == 0:
continue
if hasattr(obj, "__len__") and (isinstance(obj, Batch) or obj.ndim > 0):
lens.append(len(obj))
@@ -940,45 +935,10 @@ def __len__(self) -> int:
return 0
return min(lens)

def is_empty(self, recurse: bool = False) -> bool:
"""Test if a Batch is empty.
If ``recurse=True``, it further tests the values of the object; else
it only tests the existence of any key.
``b.is_empty(recurse=True)`` is mainly used to distinguish
``Batch(a=Batch(a=Batch()))`` and ``Batch(a=1)``. They both raise
exceptions when applied to ``len()``, but the former can be used in
``cat``, while the latter is a scalar and cannot be used in ``cat``.
Another usage is in ``__len__``, where we have to skip checking the
length of recursively empty Batch.
::
>>> Batch().is_empty()
True
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty()
False
>>> Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True)
True
>>> Batch(d=1).is_empty()
False
>>> Batch(a=np.float64(1.0)).is_empty()
False
"""
if len(self.__dict__) == 0:
return True
if not recurse:
return False
return all(
False if not isinstance(obj, Batch) else obj.is_empty(recurse=True)
for obj in self.values()
)

@property
def shape(self) -> list[int]:
"""Return self.shape."""
if self.is_empty():
if len(self.get_keys()) == 0:
return []
data_shape = []
for obj in self.__dict__.values():
6 changes: 3 additions & 3 deletions tianshou/data/buffer/base.py
@@ -208,7 +208,7 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray:
self._index = (self._index + 1) % self.maxsize
self._size = min(self._size + 1, self.maxsize)
to_indices = np.array(to_indices)
if self._meta.is_empty():
if len(self._meta.get_keys()) == 0:
self._meta = create_value(buffer._meta, self.maxsize, stack=False) # type: ignore
self._meta[to_indices] = buffer._meta[from_indices]
return to_indices
@@ -284,7 +284,7 @@ def add(
batch.done = batch.done.astype(bool)
batch.terminated = batch.terminated.astype(bool)
batch.truncated = batch.truncated.astype(bool)
if self._meta.is_empty():
if len(self._meta.get_keys()) == 0:
self._meta = create_value(batch, self.maxsize, stack) # type: ignore
else: # dynamic key pops up in batch
alloc_by_keys_diff(self._meta, batch, self.maxsize, stack)
@@ -377,7 +377,7 @@ def get(
return np.stack(stack, axis=indices.ndim)

except IndexError as exception:
if not (isinstance(val, Batch) and val.is_empty()):
if not (isinstance(val, Batch) and len(val.get_keys()) == 0):
raise exception # val != Batch()
return Batch()
