Skip to content

Commit

Permalink
Use SpaceInfo to determne types action/obs space
Browse files Browse the repository at this point in the history
  • Loading branch information
dantp-ai committed Mar 24, 2024
1 parent f5084ca commit 2e8f3c3
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions test/discrete/test_bdq.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import pprint
from typing import cast

import gymnasium as gym
import numpy as np
Expand All @@ -10,6 +11,7 @@
from tianshou.policy import BranchingDQNPolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils.net.common import BranchingNet
from tianshou.utils.space_info import SpaceInfo


def get_args() -> argparse.Namespace:
Expand Down Expand Up @@ -51,9 +53,10 @@ def get_args() -> argparse.Namespace:
def test_bdq(args: argparse.Namespace = get_args()) -> None:
env = gym.make(args.task)
env = ContinuousToDiscrete(env, args.action_per_branch)

args.state_shape = env.observation_space.shape or env.observation_space.n
args.num_branches = env.action_space.shape[0]
env.action_space = cast(gym.spaces.Discrete, env.action_space)
space_info = SpaceInfo.from_env(env)
args.state_shape = space_info.observation_info.obs_shape
args.num_branches = space_info.action_info.action_dim

if args.reward_threshold is None:
default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250}
Expand Down

0 comments on commit 2e8f3c3

Please sign in to comment.