Improve interface of BasePolicy.compute_action #1169
opcode81 committed Jul 8, 2024
1 parent 9362744 commit c7b83f8
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tianshou/policy/base.py
@@ -10,6 +10,7 @@
 import torch
 from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete
 from numba import njit
+from numpy._typing import ArrayLike
 from overrides import override
 from torch import nn

@@ -289,7 +290,7 @@ def soft_update(self, tgt: nn.Module, src: nn.Module, tau: float) -> None:

     def compute_action(
         self,
-        obs: arr_type,
+        obs: ArrayLike,
         info: dict[str, Any] | None = None,
         state: dict | BatchProtocol | np.ndarray | None = None,
     ) -> np.ndarray | int:
@@ -300,8 +301,8 @@
         :param state: the hidden state of RNN policy, used for recurrent policy.
         :return: action as int (for discrete env's) or array (for continuous ones).
         """
-        # need to add empty batch dimension
-        obs = obs[None, :]
+        obs = np.array(obs)  # convert array-like to array (e.g. LazyFrames)
+        obs = obs[None, :]  # need to add empty batch dimension
         obs_batch = cast(ObsBatchProtocol, Batch(obs=obs, info=info))
         act = self.forward(obs_batch, state=state).act.squeeze()
         if isinstance(act, torch.Tensor):
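In effect, `compute_action` is now annotated to accept any NumPy `ArrayLike` observation and normalizes it with `np.array(obs)` before adding the batch dimension, so inputs such as plain Python lists or `LazyFrames` objects from frame stacking no longer have to be converted by the caller. A minimal sketch of the new preprocessing step, using a hypothetical standalone helper `_prepare_obs` (not part of tianshou) that mirrors the two added lines:

```python
import numpy as np

def _prepare_obs(obs):
    # Mirrors the two lines added to BasePolicy.compute_action in this commit;
    # the helper itself is illustrative and not part of tianshou's API.
    obs = np.array(obs)  # convert array-like to array (e.g. LazyFrames)
    return obs[None, :]  # add empty batch dimension

# An observation supplied as a plain list is now handled without manual conversion:
batched = _prepare_obs([0.1, 0.2, 0.3, 0.4])
print(batched.shape)  # (1, 4)
```

Note that the diff imports `ArrayLike` from the private module `numpy._typing`; the public alias `numpy.typing.ArrayLike` refers to the same type.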
