Bugfix/batch eq for scalar #1186
* Batch slicing sometimes yields slightly different floats (see thu-ml#1181)
…-ai/tianshou into bugfix/batch_eq_for_scalar
-assert (
-    Categorical(probs=sliced_probs).probs == get_sliced_dist(dist, selected_idx).probs
-).all()
+assert torch.allclose(sliced_batch.dist.probs, Categorical(probs=sliced_probs).probs)
Thx! I guess it was failing randomly, right?
Yes, the comparison was sporadically returning inequality. I will take a look at getitem soon, as I have an overdue task there on maintaining slicing consistency. Then I can also check what causes slicing to introduce these small floating-point errors.
Thanks, appreciate it! I don't think the getitem issue can actually be solved, but implementing eq already solves all the practical problems with it. So no need to address it further, unless you see some simple way of doing that.
The order of slicing, however, is still a valid issue.
Thx for the PR! All looks good, merging now
Fixes: #1182
Note: Updated `test_batch.test_slice_distribution()` to use allclose (see #1181).
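For readers wondering why the test switched from `==` to a tolerance check, here is a minimal stdlib-only sketch. It is not Tianshou or PyTorch code: `all_close` is a hypothetical helper that follows the documented `torch.allclose` rule (`|a - b| <= atol + rtol * |b|`, with the same default tolerances), and the numbers are made up for illustration.

```python
def all_close(a, b, rtol=1e-5, atol=1e-8):
    """Hypothetical helper mirroring the torch.allclose tolerance rule on plain lists."""
    return len(a) == len(b) and all(
        abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b)
    )


# The same value computed along two different paths can differ in the last bits,
# which is the kind of discrepancy that makes an exact `==` assertion flaky.
computed = [0.1 + 0.2]
expected = [0.3]

exactly_equal = computed == expected           # bitwise comparison of floats
close_enough = all_close(computed, expected)   # comparison within tolerance

print(exactly_equal, close_enough)
```

The tolerance check still rejects values that are genuinely different (for example `all_close([0.3], [0.31])` is false), so it only absorbs rounding noise rather than hiding real mismatches.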