Bugfix/batch eq for scalar #1186

Merged 4 commits on Aug 2, 2024
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -62,6 +62,7 @@
- Fix `output_dim` not being set if `features_only`=True and `output_dim_added_layer` is not None #1128
- `PPOPolicy`:
- Fix `max_batchsize` not being used in `logp_old` computation inside `process_fn` #1168
- Fix `Batch.__eq__` to allow comparing Batches with scalar array values #1185

### Internal Improvements
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
@@ -108,6 +109,7 @@ continuous and discrete cases. #1032
- Fixed env seeding in `test_sac_with_il.py` so that the test doesn't fail randomly. #1081
- Improved CI triggers and added telemetry (if requested by user) #1177
- Improved environment used in tests.
- Improved batch equality tests to check comparison with scalar values #1185

### Dependencies
- [DeepDiff](https://github.com/seperman/deepdiff) added to help with diffs of batches in tests. #1098
31 changes: 27 additions & 4 deletions test/base/test_batch.py
@@ -544,6 +544,28 @@ def test_nested_shapes_different() -> None:
batch2 = Batch(a=Batch(a=[1, 4]), b=[4, 5])
assert batch1 != batch2

@staticmethod
def test_array_scalars() -> None:
batch1 = Batch(a={"b": 1})
batch2 = Batch(a={"b": 1})
assert batch1 == batch2

batch3 = Batch(a={"c": 2})
assert batch1 != batch3

batch4 = Batch(b={"b": 1})
assert batch1 != batch4

batch5 = Batch(a={"b": 10})
assert batch1 != batch5

batch6 = Batch(a={"b": [1]})
assert batch1 == batch6

batch7 = Batch(a=1, b=5)
batch8 = Batch(a=1, b=5)
assert batch7 == batch8

@staticmethod
def test_slice_equal() -> None:
batch1 = Batch(a=[1, 2, 3])
@@ -837,10 +859,11 @@ def test_slice_distribution() -> None:
selected_idx = [1, 3]
sliced_batch = batch[selected_idx]
sliced_probs = cat_probs[selected_idx]
assert (sliced_batch.dist.probs == Categorical(probs=sliced_probs).probs).all()
assert (
Categorical(probs=sliced_probs).probs == get_sliced_dist(dist, selected_idx).probs
).all()
assert torch.allclose(sliced_batch.dist.probs, Categorical(probs=sliced_probs).probs)
Collaborator:
Thx! I guess it was failing randomly, right?

Contributor Author (@dantp-ai, Aug 2, 2024):
Yes, it was sporadically returning inequality. I will take a look at `__getitem__` soon, as I have an overdue task there on maintaining slicing consistency. Then I can also check what causes slicing to introduce these small decimal errors.

Collaborator:
Thanks, appreciate it! I don't think the `__getitem__` issue can actually be solved, but implementing `__eq__` already solves all its practical problems. So there is no need to address it further, unless you see some simple way of doing that.

The order of slicing, however, is still a valid issue.

assert torch.allclose(
Categorical(probs=sliced_probs).probs,
get_sliced_dist(dist, selected_idx).probs,
)
# retrieving a single index
assert torch.allclose(batch[0].dist.probs, dist.probs[0])
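The switch from exact `==` comparisons to `torch.allclose` in the hunk above tolerates bitwise floating-point differences that slicing and renormalizing probability tensors can introduce. A minimal standalone sketch of why a tolerance-based check is more robust (shown with NumPy's equivalent `np.allclose` to keep the example dependency-light):

```python
import numpy as np

probs = np.array([0.1, 0.2, 0.7], dtype=np.float32)
# A mathematically neutral round trip can still perturb the last bit
# of a float32 value, so bitwise equality between such tensors is fragile.
roundtrip = (probs / 3.0) * 3.0

# Tolerance-based comparison is robust to this kind of rounding noise,
# unlike exact elementwise equality.
assert np.allclose(probs, roundtrip)
```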

10 changes: 10 additions & 0 deletions tianshou/data/batch.py
@@ -678,6 +678,16 @@ def __eq__(self, other: Any) -> bool:

this_batch_no_torch_tensor: Batch = Batch.to_numpy(self)
other_batch_no_torch_tensor: Batch = Batch.to_numpy(other)
# DeepDiff 7.0.1 cannot compare 0-dimensional arrays
# so, we ensure with this transform that all array values have at least 1 dim
this_batch_no_torch_tensor.apply_values_transform(
values_transform=np.atleast_1d,
inplace=True,
)
other_batch_no_torch_tensor.apply_values_transform(
values_transform=np.atleast_1d,
inplace=True,
)
this_dict = this_batch_no_torch_tensor.to_dict(recursive=True)
other_dict = other_batch_no_torch_tensor.to_dict(recursive=True)
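The comment in the hunk above explains the fix: `np.atleast_1d` promotes 0-dimensional arrays to shape `(1,)` so DeepDiff can compare them, while arrays that already have at least one dimension pass through unchanged. A minimal sketch of that normalization (standalone, not using tianshou's `Batch`):

```python
import numpy as np

scalar_arr = np.asarray(1)    # 0-dimensional array, shape ()
vector_arr = np.asarray([1])  # 1-dimensional array, shape (1,)

# np.atleast_1d promotes 0-d arrays to shape (1,) and leaves
# higher-dimensional arrays untouched.
assert np.atleast_1d(scalar_arr).shape == (1,)
assert np.atleast_1d(vector_arr).shape == (1,)

# After normalization the two values compare equal, which is why
# Batch(a={"b": 1}) == Batch(a={"b": [1]}) holds in the new tests.
assert np.array_equal(np.atleast_1d(scalar_arr), np.atleast_1d(vector_arr))
```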
