Skip to content

Commit

Permalink
Improve interface of BasePolicy.compute_action (#1170)
Browse files Browse the repository at this point in the history
  • Loading branch information
opcode81 authored Jul 18, 2024
1 parent 1ce9023 commit 7a4e5f1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- policy:
- introduced attribute `in_training_step` that is controlled by the trainer. #1123
- policy automatically set to `eval` mode when collecting and to `train` mode when updating. #1123
- Extended interface of `compute_action` to also support array-like inputs #1169
- `highlevel`:
- `SamplingConfig`:
- Add support for `batch_size=None`. #1077
Expand Down
7 changes: 4 additions & 3 deletions tianshou/policy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete
from numba import njit
from numpy.typing import ArrayLike
from overrides import override
from torch import nn

Expand Down Expand Up @@ -289,7 +290,7 @@ def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None:

def compute_action(
self,
obs: arr_type,
obs: ArrayLike,
info: dict[str, Any] | None = None,
state: dict | BatchProtocol | np.ndarray | None = None,
) -> np.ndarray | int:
Expand All @@ -300,8 +301,8 @@ def compute_action(
:param state: the hidden state of RNN policy, used for recurrent policy.
:return: action as int (for discrete env's) or array (for continuous ones).
"""
# need to add empty batch dimension
obs = obs[None, :]
obs = np.array(obs) # convert array-like to array (e.g. LazyFrames)
obs = obs[None, :] # add batch dimension
obs_batch = cast(ObsBatchProtocol, Batch(obs=obs, info=info))
act = self.forward(obs_batch, state=state).act.squeeze()
if isinstance(act, torch.Tensor):
Expand Down

0 comments on commit 7a4e5f1

Please sign in to comment.