Merge branch 'master' into feature/eq-for-slices-batch
dantp-ai authored Apr 16, 2024
2 parents 164cf84 + 049907d commit 91abd6b
Showing 41 changed files with 89 additions and 55 deletions.
13 changes: 4 additions & 9 deletions README.md
@@ -7,12 +7,6 @@
[![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/) [![Conda](https://img.shields.io/conda/vn/conda-forge/tianshou)](https://github.com/conda-forge/tianshou-feedstock) [![Read the Docs](https://img.shields.io/readthedocs/tianshou)](https://tianshou.readthedocs.io/en/master) [![Read the Docs](https://img.shields.io/readthedocs/tianshou-docs-zh-cn?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://tianshou.readthedocs.io/zh/master/) [![Unittest](https://github.com/thu-ml/tianshou/actions/workflows/pytest.yml/badge.svg)](https://github.com/thu-ml/tianshou/actions) [![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou) [![GitHub issues](https://img.shields.io/github/issues/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/issues) [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) [![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network) [![GitHub license](https://img.shields.io/github/license/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/blob/master/LICENSE)


> ⚠️️ **Current Status**: the Tianshou master branch is currently under heavy development,
> moving towards more features, improved interfaces, more documentation.
> You can view the relevant issues in the corresponding
> [milestone](https://github.com/thu-ml/tianshou/milestone/1)
> Stay tuned! (and expect breaking changes until the next major release)
**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch and [Gymnasium](http://github.com/Farama-Foundation/Gymnasium). Unlike other reinforcement learning libraries, which may have complex codebases,
unfriendly high-level APIs, or may not be optimized for speed, Tianshou provides a high-performance, modularized framework
and user-friendly interfaces for building deep reinforcement learning agents. One more aspect that sets Tianshou apart is its
@@ -41,7 +35,7 @@ Supported algorithms include:
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- [Randomized Ensembled Double Q-Learning (REDQ)](https://arxiv.org/pdf/2101.05982.pdf)
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Vanilla Imitation Learning](https://en.wikipedia.org/wiki/Apprenticeship_learning)
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
- [Conservative Q-Learning (CQL)](https://arxiv.org/pdf/2006.04779.pdf)
- [Twin Delayed DDPG with Behavior Cloning (TD3+BC)](https://arxiv.org/pdf/2106.06860.pdf)
@@ -241,8 +235,9 @@ from tianshou.highlevel.env import (
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
from tianshou.highlevel.params.policy_params import DQNParams
from tianshou.highlevel.trainer import (
TrainerEpochCallbackTestDQNSetEps,
TrainerEpochCallbackTrainDQNSetEps,
EpochTestCallbackDQNSetEps,
EpochTrainCallbackDQNSetEps,
EpochStopCallbackRewardThreshold
)
```
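For context, these renamed callbacks are consumed by the high-level experiment builder imported above. A rough sketch of how they might be attached follows; the environment factory, `ExperimentConfig` arguments, builder method names, and hyperparameters are assumptions for illustration, not code from this diff.

```python
# Hedged sketch: attaching the renamed epoch callbacks to a DQN experiment.
# EnvFactoryRegistered, VectorEnvType, and the with_* method names are assumed here.
experiment = (
    DQNExperimentBuilder(
        EnvFactoryRegistered(task="CartPole-v1", train_seed=0, test_seed=10, venv_type=VectorEnvType.DUMMY),
        ExperimentConfig(),
    )
    .with_dqn_params(DQNParams(lr=1e-3, discount_factor=0.9, target_update_freq=320))
    .with_epoch_train_callback(EpochTrainCallbackDQNSetEps(0.3))  # exploration epsilon while training
    .with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.0))  # greedy evaluation
    .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195))
    .build()
)
experiment.run()
```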

1 change: 1 addition & 0 deletions examples/atari/atari_c51.py
@@ -190,6 +190,7 @@ def watch() -> None:
sys.exit(0)

# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OffpolicyTrainer(
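This `train_collector.reset()` line recurs across nearly every example and test in this commit: the warm-up collection no longer relies on `reset_before_collect=True` and instead resets the collector explicitly first. A minimal sketch of the pattern, with placeholder buffer size and step counts; `policy`, `train_envs`, and `args` come from the surrounding script's setup:

```python
# Hedged sketch of the prefill pattern these examples now use.
from tianshou.data import Collector, VectorReplayBuffer

buf = VectorReplayBuffer(total_size=100_000, buffer_num=len(train_envs))  # placeholder size
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)

train_collector.reset()  # reset envs and internal collector state explicitly
train_collector.collect(n_step=args.batch_size * args.training_num)  # prefill the replay buffer
```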
5 changes: 3 additions & 2 deletions examples/atari/atari_dqn.py
@@ -83,7 +83,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_args()


def test_dqn(args: argparse.Namespace = get_args()) -> None:
def main(args: argparse.Namespace = get_args()) -> None:
env, train_envs, test_envs = make_atari_env(
args.task,
args.seed,
@@ -232,6 +232,7 @@ def watch() -> None:
sys.exit(0)

# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OffpolicyTrainer(
@@ -259,4 +260,4 @@ def watch() -> None:


if __name__ == "__main__":
test_dqn(get_args())
main(get_args())
1 change: 1 addition & 0 deletions examples/atari/atari_fqf.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def watch() -> None:
sys.exit(0)

# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OffpolicyTrainer(
1 change: 1 addition & 0 deletions examples/atari/atari_iqn.py
@@ -200,6 +200,7 @@ def watch() -> None:
sys.exit(0)

# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer

1 change: 1 addition & 0 deletions examples/atari/atari_ppo.py
@@ -256,6 +256,7 @@ def watch() -> None:
sys.exit(0)

# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OnpolicyTrainer(
1 change: 1 addition & 0 deletions examples/atari/atari_qrdqn.py
@@ -194,6 +194,7 @@ def watch() -> None:
sys.exit(0)

# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OffpolicyTrainer(
1 change: 1 addition & 0 deletions examples/atari/atari_rainbow.py
@@ -230,6 +230,7 @@ def watch() -> None:
sys.exit(0)

# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OffpolicyTrainer(
1 change: 1 addition & 0 deletions examples/atari/atari_sac.py
@@ -243,6 +243,7 @@ def watch() -> None:
sys.exit(0)

# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OffpolicyTrainer(
32 changes: 19 additions & 13 deletions examples/atari/atari_wrapper.py
@@ -39,6 +39,20 @@ def _parse_reset_result(reset_result: tuple) -> tuple[tuple, dict, bool]:
return reset_result, {}, contains_info


def get_space_dtype(obs_space: gym.spaces.Box) -> type[np.floating] | type[np.integer]:
obs_space_dtype: type[np.integer] | type[np.floating]
if np.issubdtype(obs_space.dtype, np.integer):
obs_space_dtype = np.integer
elif np.issubdtype(obs_space.dtype, np.floating):
obs_space_dtype = np.floating
else:
raise TypeError(
f"Unsupported observation space dtype: {obs_space.dtype}. "
f"This might be a bug in tianshou or gymnasium, please report it!",
)
return obs_space_dtype


class NoopResetEnv(gym.Wrapper):
"""Sample initial states by taking random number of no-ops on reset.
@@ -199,12 +213,8 @@ def __init__(self, env: gym.Env) -> None:
super().__init__(env)
self.size = 84
obs_space = env.observation_space
obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]]
if np.issubdtype(type(obs_space.dtype), np.integer):
obs_space_dtype = np.integer
elif np.issubdtype(type(obs_space.dtype), np.floating):
obs_space_dtype = np.floating
assert isinstance(obs_space, gym.spaces.Box)
obs_space_dtype = get_space_dtype(obs_space)
self.observation_space = gym.spaces.Box(
low=np.min(obs_space.low),
high=np.max(obs_space.high),
@@ -273,15 +283,11 @@ def __init__(self, env: gym.Env, n_frames: int) -> None:
obs_space_shape = env.observation_space.shape
assert obs_space_shape is not None
shape = (n_frames, *obs_space_shape)
assert isinstance(env.observation_space, gym.spaces.Box)
obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]]
if np.issubdtype(type(obs_space.dtype), np.integer):
obs_space_dtype = np.integer
elif np.issubdtype(type(obs_space.dtype), np.floating):
obs_space_dtype = np.floating
assert isinstance(obs_space, gym.spaces.Box)
obs_space_dtype = get_space_dtype(obs_space)
self.observation_space = gym.spaces.Box(
low=np.min(env.observation_space.low),
high=np.max(env.observation_space.high),
low=np.min(obs_space.low),
high=np.max(obs_space.high),
shape=shape,
dtype=obs_space_dtype,
)
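Both wrappers touched above now delegate dtype detection to the new `get_space_dtype` helper. A sketch of the shared pattern when rebuilding an observation space; the target shape and the wrapped `env` are placeholders:

```python
# Hedged sketch: how the refactored wrappers use get_space_dtype to rebuild a Box space.
import gymnasium as gym
import numpy as np

obs_space = env.observation_space
assert isinstance(obs_space, gym.spaces.Box)
obs_space_dtype = get_space_dtype(obs_space)  # np.integer or np.floating, else TypeError
new_space = gym.spaces.Box(
    low=np.min(obs_space.low),
    high=np.max(obs_space.high),
    shape=(84, 84),  # placeholder target shape
    dtype=obs_space_dtype,
)
```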
1 change: 1 addition & 0 deletions examples/box2d/acrobot_dualdqn.py
@@ -92,6 +92,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "dqn")
1 change: 1 addition & 0 deletions examples/box2d/bipedal_bdq.py
@@ -117,6 +117,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
)
test_collector = Collector(policy, test_envs, exploration_noise=False)
# policy.set_eps(1)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
1 change: 1 addition & 0 deletions examples/box2d/lunarlander_dqn.py
@@ -94,6 +94,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "dqn")
1 change: 1 addition & 0 deletions examples/mujoco/fetch_her_ddpg.py
@@ -213,6 +213,7 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray:
)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.reset()
train_collector.collect(n_step=args.start_timesteps, random=True)

def save_best_fn(policy: BasePolicy) -> None:
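The MuJoCo and HER scripts below use the same explicit reset but prefill with uniformly random actions rather than policy actions. A minimal fragment of that variant; `train_collector` and `args.start_timesteps` come from the scripts themselves:

```python
# Hedged fragment: random-action warm-up used by the continuous-control examples.
train_collector.reset()
train_collector.collect(n_step=args.start_timesteps, random=True)  # uniform random actions
```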
1 change: 1 addition & 0 deletions examples/mujoco/mujoco_ddpg.py
@@ -122,6 +122,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.reset()
train_collector.collect(n_step=args.start_timesteps, random=True)

# log
1 change: 1 addition & 0 deletions examples/mujoco/mujoco_redq.py
@@ -150,6 +150,7 @@ def linear(x: int, y: int) -> EnsembleLinear:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.reset()
train_collector.collect(n_step=args.start_timesteps, random=True)

# log
1 change: 1 addition & 0 deletions examples/mujoco/mujoco_sac.py
@@ -144,6 +144,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.reset()
train_collector.collect(n_step=args.start_timesteps, random=True)

# log
1 change: 1 addition & 0 deletions examples/mujoco/mujoco_td3.py
@@ -142,6 +142,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.reset()
train_collector.collect(n_step=args.start_timesteps, random=True)

# log
1 change: 1 addition & 0 deletions examples/vizdoom/vizdoom_c51.py
@@ -196,6 +196,7 @@ def watch() -> None:
sys.exit(0)

# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OffpolicyTrainer(
1 change: 1 addition & 0 deletions examples/vizdoom/vizdoom_ppo.py
@@ -258,6 +258,7 @@ def watch() -> None:
sys.exit(0)

# test train_collector and start filling replay buffer
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = OnpolicyTrainer(
2 changes: 1 addition & 1 deletion test/base/test_batch.py
@@ -557,7 +557,7 @@ def test_batch_standard_compatibility() -> None:
batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0]))
batch_mean = np.mean(batch)
assert isinstance(batch_mean, Batch) # type: ignore # mypy doesn't know but it works, cf. `batch.rst`
assert sorted(batch_mean.keys()) == ["a", "b", "c"] # type: ignore
assert sorted(batch_mean.get_keys()) == ["a", "b", "c"] # type: ignore
with pytest.raises(TypeError):
len(batch_mean)
assert np.all(batch_mean.a == np.mean(batch.a, axis=0))
11 changes: 7 additions & 4 deletions test/base/test_buffer.py
@@ -1379,11 +1379,14 @@ def test_custom_key() -> None:
buffer.add(batch)
sampled_batch, _ = buffer.sample(1)
# Check if they have the same keys
assert set(batch.keys()) == set(
sampled_batch.keys(),
), "Batches have different keys: {} and {}".format(set(batch.keys()), set(sampled_batch.keys()))
assert set(batch.get_keys()) == set(
sampled_batch.get_keys(),
), "Batches have different keys: {} and {}".format(
set(batch.get_keys()),
set(sampled_batch.get_keys()),
)
# Compare the values for each key
for key in batch.keys():
for key in batch.get_keys():
if isinstance(batch.__dict__[key], np.ndarray) and isinstance(
sampled_batch.__dict__[key],
np.ndarray,
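These tests now compare `Batch.get_keys()` instead of `keys()`. A small sketch of what the assertion exercises; the field names are illustrative:

```python
# Hedged sketch: get_keys() returns the top-level field names of a Batch.
import numpy as np
from tianshou.data import Batch

batch = Batch(obs=np.zeros((2, 3)), act=np.array([0, 1]))
sampled = Batch(obs=np.ones((1, 3)), act=np.array([1]))
assert set(batch.get_keys()) == set(sampled.get_keys()) == {"obs", "act"}
```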
2 changes: 1 addition & 1 deletion test/continuous/test_sac_with_il.py
@@ -36,7 +36,7 @@ def get_args() -> argparse.Namespace:
parser.add_argument("--alpha", type=float, default=0.2)
parser.add_argument("--auto-alpha", type=int, default=1)
parser.add_argument("--alpha-lr", type=float, default=3e-4)
parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--epoch", type=int, default=10)
parser.add_argument("--step-per-epoch", type=int, default=24000)
parser.add_argument("--il-step-per-epoch", type=int, default=500)
parser.add_argument("--step-per-collect", type=int, default=10)
3 changes: 2 additions & 1 deletion test/discrete/test_bdq.py
@@ -115,7 +115,8 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
)
test_collector = Collector(policy, test_envs, exploration_noise=False)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)

def train_fn(epoch: int, env_step: int) -> None: # exp decay
eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test)
3 changes: 2 additions & 1 deletion test/discrete/test_c51.py
@@ -119,7 +119,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "c51")
writer = SummaryWriter(log_path)
3 changes: 2 additions & 1 deletion test/discrete/test_dqn.py
@@ -110,7 +110,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "dqn")
writer = SummaryWriter(log_path)
3 changes: 2 additions & 1 deletion test/discrete/test_drqn.py
@@ -94,7 +94,8 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
# the stack_num is for RNN training: sample framestack obs
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "drqn")
writer = SummaryWriter(log_path)
3 changes: 2 additions & 1 deletion test/discrete/test_fqf.py
@@ -127,7 +127,8 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "fqf")
writer = SummaryWriter(log_path)
3 changes: 2 additions & 1 deletion test/discrete/test_iqn.py
@@ -123,7 +123,8 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "iqn")
writer = SummaryWriter(log_path)
3 changes: 2 additions & 1 deletion test/discrete/test_qrdqn.py
@@ -112,7 +112,8 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "qrdqn")
writer = SummaryWriter(log_path)
3 changes: 2 additions & 1 deletion test/discrete/test_rainbow.py
@@ -127,7 +127,8 @@ def noisy_linear(x: int, y: int) -> NoisyLinear:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "rainbow")
writer = SummaryWriter(log_path)
3 changes: 2 additions & 1 deletion test/modelbased/test_dqn_icm.py
@@ -153,7 +153,8 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "dqn_icm")
writer = SummaryWriter(log_path)
2 changes: 1 addition & 1 deletion test/offline/gather_cartpole_data.py
@@ -118,7 +118,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
test_collector = Collector(policy, test_envs, exploration_noise=True)
test_collector.reset()
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, args.task, "qrdqn")
writer = SummaryWriter(log_path)
3 changes: 2 additions & 1 deletion test/pettingzoo/pistonball.py
@@ -135,7 +135,8 @@ def train_agent(
exploration_noise=True,
)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, "pistonball", "dqn")
writer = SummaryWriter(log_path)
