diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 21b42aa7b..1946cc790 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -8,15 +8,15 @@ import torch from atari_network import C51 from atari_wrapper import make_atari_env -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import C51Policy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=0) @@ -66,7 +66,7 @@ def get_args(): return parser.parse_args() -def test_c51(args=get_args()): +def test_c51(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, @@ -87,7 +87,7 @@ def test_c51(args=get_args()): net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy = C51Policy( + policy: C51Policy = C51Policy( model=net, optim=optim, discount_factor=args.gamma, @@ -123,21 +123,19 @@ def test_c51(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -147,7 +145,7 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= 20 return False - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # nature DQN setting, linear decay in the first 1M steps if env_step <= 1e6: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) @@ -157,11 +155,11 @@ def train_fn(epoch, env_step): if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # watch agent's performance - def watch(): + def watch() -> None: print("Setup test envs ...") policy.eval() policy.set_eps(args.eps_test) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 520761e55..c669fa714 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -8,17 +8,17 @@ import torch from atari_network import DQN from atari_wrapper import make_atari_env -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import DQNPolicy +from tianshou.policy.base import BasePolicy from tianshou.policy.modelbased.icm 
import ICMPolicy from tianshou.trainer import OffpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.discrete import IntrinsicCuriosityModule -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=0) @@ -83,7 +83,7 @@ def get_args(): return parser.parse_args() -def test_dqn(args=get_args()): +def test_dqn(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, @@ -104,7 +104,7 @@ def test_dqn(args=get_args()): net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy = DQNPolicy( + policy: DQNPolicy = DQNPolicy( model=net, optim=optim, action_space=env.action_space, @@ -158,21 +158,19 @@ def test_dqn(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -182,7 +180,7 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= 20 return False - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # nature DQN setting, linear decay in the first 1M steps if env_step <= 1e6: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) @@ -192,17 +190,17 @@ def train_fn(epoch, env_step): if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) - def save_checkpoint_fn(epoch, env_step, gradient_step): + def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save({"model": policy.state_dict()}, ckpt_path) return ckpt_path # watch agent's performance - def watch(): + def watch() -> None: print("Setup test envs ...") policy.eval() policy.set_eps(args.eps_test) diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 830ade184..e7a100368 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import functools import os from examples.atari.atari_network import ( @@ -48,7 +49,7 @@ def main( icm_lr_scale: float = 0.0, icm_reward_scale: float = 0.01, icm_forward_loss_weight: float = 0.2, -): +) -> None: log_name = os.path.join(task, "dqn", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( @@ -102,4 +103,5 @@ def main( if __name__ == "__main__": - 
logging.run_cli(main) + run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) + logging.run_cli(run_with_default_config) diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 97bd7ded4..31adf9efd 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -8,16 +8,16 @@ import torch from atari_network import DQN from atari_wrapper import make_atari_env -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import FQFPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=3128) @@ -69,7 +69,7 @@ def get_args(): return parser.parse_args() -def test_fqf(args=get_args()): +def test_fqf(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, @@ -99,7 +99,7 @@ def test_fqf(args=get_args()): fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr) # define policy - policy = FQFPolicy( + policy: FQFPolicy = FQFPolicy( model=net, optim=optim, fraction_model=fraction_net, @@ -136,21 +136,19 @@ def test_fqf(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -160,7 +158,7 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= 20 return False - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # nature DQN setting, linear decay in the first 1M steps if env_step <= 1e6: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) @@ -170,11 +168,11 @@ def train_fn(epoch, env_step): if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # watch agent's performance - def watch(): + def watch() -> None: print("Setup test envs ...") policy.eval() policy.set_eps(args.eps_test) diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index dc59da469..d9832b9ea 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -8,16 +8,16 @@ import torch from atari_network import DQN from atari_wrapper import make_atari_env 
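A minimal sketch of the logger wiring that atari_c51, atari_dqn, and atari_fqf now share, using only the LoggerFactoryDefault attributes exercised in this diff (logger_type, wandb_project, create_logger); the make_logger helper name is hypothetical:

import argparse

from examples.common import logger_factory


def make_logger(args: argparse.Namespace, log_path: str, log_name: str):
    # Select the backend exactly as the scripts above do, then build the logger
    # through the shared factory instead of instantiating SummaryWriter/WandbLogger
    # by hand in every example.
    if args.logger == "wandb":
        logger_factory.logger_type = "wandb"
        logger_factory.wandb_project = args.wandb_project
    else:
        logger_factory.logger_type = "tensorboard"
    return logger_factory.create_logger(
        log_dir=log_path,
        experiment_name=log_name,
        run_id=args.resume_id,
        config_dict=vars(args),
    )
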
-from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import IQNPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.discrete import ImplicitQuantileNetwork -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1234) @@ -69,7 +69,7 @@ def get_args(): return parser.parse_args() -def test_iqn(args=get_args()): +def test_iqn(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, @@ -97,7 +97,7 @@ def test_iqn(args=get_args()): ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy = IQNPolicy( + policy: IQNPolicy = IQNPolicy( model=net, optim=optim, action_space=env.action_space, @@ -133,21 +133,19 @@ def test_iqn(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -157,7 +155,7 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= 20 return False - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # nature DQN setting, linear decay in the first 1M steps if env_step <= 1e6: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) @@ -167,11 +165,11 @@ def train_fn(epoch, env_step): if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # watch agent's performance - def watch(): + def watch() -> None: print("Setup test envs ...") policy.eval() policy.set_eps(args.eps_test) diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index 412ef1db3..32492a30b 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import functools import os from collections.abc import Sequence @@ -47,7 +48,7 @@ def main( test_num: int = 10, frames_stack: int = 4, save_buffer_name: str | None = None, # TODO support? 
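The __main__ blocks of the high-level scripts (atari_dqn_hl above, atari_iqn_hl below) now hand logging.run_cli a partial application so experiment_config has a concrete default; a minimal sketch of that entry-point pattern, assuming the ExperimentConfig and logging imports those scripts already use and with a stub standing in for the real main:

import functools

from tianshou.highlevel.experiment import ExperimentConfig
from tianshou.utils import logging


def main(experiment_config: ExperimentConfig, task: str = "PongNoFrameskip-v4") -> None:
    ...  # stub standing in for the real experiment setup


if __name__ == "__main__":
    # Supply a default ExperimentConfig via functools.partial, mirroring the
    # __main__ blocks changed in this diff.
    run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig())
    logging.run_cli(run_with_default_config)
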
-): +) -> None: log_name = os.path.join(task, "iqn", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( @@ -94,4 +95,5 @@ def main( if __name__ == "__main__": - logging.run_cli(main) + run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) + logging.run_cli(run_with_default_config) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index 7266eadae..2b8288a7c 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -24,7 +24,7 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0. class ScaledObsInputModule(torch.nn.Module): - def __init__(self, module: torch.nn.Module, denom: float = 255.0): + def __init__(self, module: torch.nn.Module, denom: float = 255.0) -> None: super().__init__() self.module = module self.denom = denom @@ -240,7 +240,12 @@ def forward( class ActorFactoryAtariDQN(ActorFactory): - def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool, features_only: bool): + def __init__( + self, + hidden_size: int | Sequence[int], + scale_obs: bool, + features_only: bool, + ) -> None: self.hidden_size = hidden_size self.scale_obs = scale_obs self.features_only = features_only @@ -260,7 +265,7 @@ def create_module(self, envs: Environments, device: TDevice) -> Actor: class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory): - def __init__(self, features_only: bool = False, net_only: bool = False): + def __init__(self, features_only: bool = False, net_only: bool = False) -> None: self.features_only = features_only self.net_only = net_only @@ -276,5 +281,5 @@ def create_intermediate_module(self, envs: Environments, device: TDevice) -> Int class IntermediateModuleFactoryAtariDQNFeatures(IntermediateModuleFactoryAtariDQN): - def __init__(self): + def __init__(self) -> None: super().__init__(features_only=True, net_only=True) diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 63da7a26d..86f54d4d7 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -8,18 +8,19 @@ import torch from atari_network import DQN, layer_init, scale_obs from atari_wrapper import make_atari_env +from torch.distributions import Categorical, Distribution from torch.optim.lr_scheduler import LambdaLR -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import ICMPolicy, PPOPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=4213) @@ -91,7 +92,7 @@ def get_args(): return parser.parse_args() -def test_ppo(args=get_args()): +def test_ppo(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, @@ -131,10 +132,10 @@ def test_ppo(args=get_args()): lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) # define policy - def dist(p): - return torch.distributions.Categorical(logits=p) + def dist(logits: torch.Tensor) -> Distribution: + return Categorical(logits=logits) - 
policy = PPOPolicy( + policy: PPOPolicy = PPOPolicy( actor=actor, critic=critic, optim=optim, @@ -200,21 +201,19 @@ def dist(p): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -224,14 +223,14 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= 20 return False - def save_checkpoint_fn(epoch, env_step, gradient_step): + def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save({"model": policy.state_dict()}, ckpt_path) return ckpt_path # watch agent's performance - def watch(): + def watch() -> None: print("Setup test envs ...") policy.eval() test_envs.seed(args.seed) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 2dafb599d..736fb1dd1 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import functools import os from collections.abc import Sequence @@ -53,7 +54,7 @@ def main( icm_lr_scale: float = 0.0, icm_reward_scale: float = 0.01, icm_forward_loss_weight: float = 0.2, -): +) -> None: log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( @@ -113,4 +114,5 @@ def main( if __name__ == "__main__": - logging.run_cli(main) + run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) + logging.run_cli(run_with_default_config) diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 5231c0391..cef1c4247 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -8,15 +8,15 @@ import torch from atari_network import QRDQN from atari_wrapper import make_atari_env -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import QRDQNPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=0) @@ -64,7 +64,7 @@ def get_args(): return parser.parse_args() -def test_qrdqn(args=get_args()): +def test_qrdqn(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, @@ -85,7 +85,7 @@ def test_qrdqn(args=get_args()): net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device) optim = 
torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy = QRDQNPolicy( + policy: QRDQNPolicy = QRDQNPolicy( model=net, optim=optim, action_space=env.action_space, @@ -119,21 +119,19 @@ def test_qrdqn(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -143,7 +141,7 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= 20 return False - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # nature DQN setting, linear decay in the first 1M steps if env_step <= 1e6: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) @@ -153,11 +151,11 @@ def train_fn(epoch, env_step): if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # watch agent's performance - def watch(): + def watch() -> None: print("Setup test envs ...") policy.eval() policy.set_eps(args.eps_test) diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index ded22f5d9..dbea00688 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -8,15 +8,15 @@ import torch from atari_network import Rainbow from atari_wrapper import make_atari_env -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer -from tianshou.policy import RainbowPolicy +from tianshou.policy import C51Policy, RainbowPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=0) @@ -75,7 +75,7 @@ def get_args(): return parser.parse_args() -def test_rainbow(args=get_args()): +def test_rainbow(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, @@ -104,7 +104,7 @@ def test_rainbow(args=get_args()): ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy = RainbowPolicy( + policy: C51Policy = RainbowPolicy( model=net, optim=optim, discount_factor=args.gamma, @@ -121,6 +121,7 @@ def test_rainbow(args=get_args()): print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM + buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer if args.no_priority: buffer = 
VectorReplayBuffer( args.buffer_size, @@ -152,21 +153,19 @@ def test_rainbow(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -176,7 +175,7 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= 20 return False - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # nature DQN setting, linear decay in the first 1M steps if env_step <= 1e6: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) @@ -194,11 +193,11 @@ def train_fn(epoch, env_step): if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/beta": beta}) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # watch agent's performance - def watch(): + def watch() -> None: print("Setup test envs ...") policy.eval() policy.set_eps(args.eps_test) diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 543d4b8fb..7dc60c0e8 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -8,16 +8,16 @@ import torch from atari_network import DQN from atari_wrapper import make_atari_env -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import DiscreteSACPolicy, ICMPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=4213) @@ -85,7 +85,7 @@ def get_args(): return parser.parse_args() -def test_discrete_sac(args=get_args()): +def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_atari_env( args.task, args.seed, @@ -124,7 +124,7 @@ def test_discrete_sac(args=get_args()): alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy = DiscreteSACPolicy( + policy: DiscreteSACPolicy = DiscreteSACPolicy( actor=actor, actor_optim=actor_optim, critic=critic1, @@ -183,21 +183,19 @@ def test_discrete_sac(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = 
TensorboardLogger(writer) - else: # wandb - logger.load(writer) + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -207,14 +205,14 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= 20 return False - def save_checkpoint_fn(epoch, env_step, gradient_step): + def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, "checkpoint.pth") torch.save({"model": policy.state_dict()}, ckpt_path) return ckpt_path # watch agent's performance - def watch(): + def watch() -> None: print("Setup test envs ...") policy.eval() test_envs.seed(args.seed) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index a8a5bd4ed..f2e02773d 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import functools import os from examples.atari.atari_network import ( @@ -47,7 +48,7 @@ def main( icm_lr_scale: float = 0.0, icm_reward_scale: float = 0.01, icm_forward_loss_weight: float = 0.2, -): +) -> None: log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( @@ -100,4 +101,5 @@ def main( if __name__ == "__main__": - logging.run_cli(main) + run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) + logging.run_cli(run_with_default_config) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 2f098b380..a2fdcca1f 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -46,7 +46,7 @@ class NoopResetEnv(gym.Wrapper): :param int noop_max: the maximum value of no-ops to run. """ - def __init__(self, env, noop_max=30): + def __init__(self, env, noop_max=30) -> None: super().__init__(env) self.noop_max = noop_max self.noop_action = 0 @@ -79,7 +79,7 @@ class MaxAndSkipEnv(gym.Wrapper): :param int skip: number of `skip`-th frame. """ - def __init__(self, env, skip=4): + def __init__(self, env, skip=4) -> None: super().__init__(env) self._skip = skip @@ -117,7 +117,7 @@ class EpisodicLifeEnv(gym.Wrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env): + def __init__(self, env) -> None: super().__init__(env) self.lives = 0 self.was_real_done = True @@ -174,7 +174,7 @@ class FireResetEnv(gym.Wrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env): + def __init__(self, env) -> None: super().__init__(env) assert env.unwrapped.get_action_meanings()[1] == "FIRE" assert len(env.unwrapped.get_action_meanings()) >= 3 @@ -191,7 +191,7 @@ class WarpFrame(gym.ObservationWrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env): + def __init__(self, env) -> None: super().__init__(env) self.size = 84 self.observation_space = gym.spaces.Box( @@ -213,7 +213,7 @@ class ScaledFloatFrame(gym.ObservationWrapper): :param gym.Env env: the environment to wrap. 
""" - def __init__(self, env): + def __init__(self, env) -> None: super().__init__(env) low = np.min(env.observation_space.low) high = np.max(env.observation_space.high) @@ -236,7 +236,7 @@ class ClipRewardEnv(gym.RewardWrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env): + def __init__(self, env) -> None: super().__init__(env) self.reward_range = (-1, 1) @@ -252,7 +252,7 @@ class FrameStack(gym.Wrapper): :param int n_frames: the number of frames to stack. """ - def __init__(self, env, n_frames): + def __init__(self, env, n_frames) -> None: super().__init__(env) self.n_frames = n_frames self.frames = deque([], maxlen=n_frames) @@ -353,7 +353,7 @@ def __init__( frame_stack: int, scale: bool = False, use_envpool_if_available: bool = True, - ): + ) -> None: assert "NoFrameskip" in task self.frame_stack = frame_stack self.scale = scale @@ -388,7 +388,7 @@ class EnvPoolFactory(EnvPoolFactory): it sets the creation keyword arguments accordingly. """ - def __init__(self, parent: "AtariEnvFactory"): + def __init__(self, parent: "AtariEnvFactory") -> None: self.parent = parent if self.parent.scale: warnings.warn( @@ -411,7 +411,7 @@ def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict: class AtariEpochStopCallback(EpochStopCallback): - def __init__(self, task: str): + def __init__(self, task: str) -> None: self.task = task def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool: diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 003caa3fd..2f6114dc5 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -1,6 +1,7 @@ import argparse import os import pprint +from typing import cast import gymnasium as gym import numpy as np @@ -10,12 +11,14 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Acrobot-v1") parser.add_argument("--seed", type=int, default=0) @@ -46,10 +49,12 @@ def get_args(): return parser.parse_args() -def test_dqn(args=get_args()): +def test_dqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + env.action_space = cast(gym.spaces.Discrete, env.action_space) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -71,7 +76,7 @@ def test_dqn(args=get_args()): dueling_param=(Q_param, V_param), ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = DQNPolicy( + policy: DQNPolicy = DQNPolicy( model=net, optim=optim, action_space=env.action_space, @@ -94,13 +99,18 @@ def test_dqn(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: 
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): - return mean_rewards >= env.spec.reward_threshold + def stop_fn(mean_rewards: float) -> bool: + if env.spec: + if not env.spec.reward_threshold: + return False + else: + return mean_rewards >= env.spec.reward_threshold + return False - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: if env_step <= 100000: policy.set_eps(args.eps_train) elif env_step <= 500000: @@ -109,7 +119,7 @@ def train_fn(epoch, env_step): else: policy.set_eps(0.5 * args.eps_train) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer @@ -138,8 +148,8 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index ff532830a..628116150 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -11,12 +11,13 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv from tianshou.policy import BranchingDQNPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import BranchingNet -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() # task parser.add_argument("--task", type=str, default="BipedalWalker-v3") @@ -52,7 +53,7 @@ def get_args(): return parser.parse_args() -def test_bdq(args=get_args()): +def test_bdq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) env = ContinuousToDiscrete(env, args.action_per_branch) @@ -97,7 +98,12 @@ def test_bdq(args=get_args()): device=args.device, ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = BranchingDQNPolicy(net, optim, args.gamma, target_update_freq=args.target_update_freq) + policy: BranchingDQNPolicy = BranchingDQNPolicy( + net, + optim, + args.gamma, + target_update_freq=args.target_update_freq, + ) # collector train_collector = Collector( policy, @@ -114,17 +120,17 @@ def test_bdq(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= getattr(env.spec.reward_threshold) - def train_fn(epoch, env_step): # exp decay + def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) policy.set_eps(eps) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer @@ -153,8 +159,8 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: 
{result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 2dda6b610..0d999679a 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -1,22 +1,26 @@ import argparse import os import pprint +from typing import Any import gymnasium as gym import numpy as np import torch +from gymnasium.core import WrapperActType, WrapperObsType from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv from tianshou.policy import SACPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="BipedalWalkerHardcore-v3") parser.add_argument("--seed", type=int, default=0) @@ -51,32 +55,41 @@ def get_args(): class Wrapper(gym.Wrapper): """Env wrapper for reward scale, action repeat and removing done penalty.""" - def __init__(self, env, action_repeat=3, reward_scale=5, rm_done=True): + def __init__( + self, + env: gym.Env, + action_repeat: int = 3, + reward_scale: int = 5, + rm_done: bool = True, + ) -> None: super().__init__(env) self.action_repeat = action_repeat self.reward_scale = reward_scale self.rm_done = rm_done - def step(self, action): + def step( + self, + action: WrapperActType, + ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]: rew_sum = 0.0 for _ in range(self.action_repeat): - obs, rew, done, info = self.env.step(action) + obs, rew, terminated, truncated, info = self.env.step(action) + done = terminated | truncated # remove done reward penalty if not done or not self.rm_done: - rew_sum = rew_sum + rew + rew_sum = rew_sum + float(rew) if done: break # scale reward - return obs, self.reward_scale * rew_sum, done, info + return obs, self.reward_scale * rew_sum, terminated, truncated, info -def test_sac_bipedal(args=get_args()): +def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: env = Wrapper(gym.make(args.task)) - - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] - + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action train_envs = SubprocVectorEnv( [lambda: Wrapper(gym.make(args.task)) for _ in range(args.training_num)], ) @@ -119,13 +132,14 @@ def test_sac_bipedal(args=get_args()): critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + action_dim = space_info.action_info.action_dim if args.auto_alpha: - target_entropy = -np.prod(env.action_space.shape) + target_entropy = -action_dim log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy = SACPolicy( + policy: 
SACPolicy = SACPolicy( actor=actor, actor_optim=actor_optim, critic=critic1, @@ -157,11 +171,16 @@ def test_sac_bipedal(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): - return mean_rewards >= env.spec.reward_threshold + def stop_fn(mean_rewards: float) -> bool: + if env.spec: + if not env.spec.reward_threshold: + return False + else: + return mean_rewards >= env.spec.reward_threshold + return False # trainer result = OffpolicyTrainer( @@ -186,8 +205,8 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 007a310d8..22f192b6d 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -10,12 +10,13 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import DQNPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() # the parameters are found by Optuna parser.add_argument("--task", type=str, default="LunarLander-v2") @@ -47,10 +48,11 @@ def get_args(): return parser.parse_args() -def test_dqn(args=get_args()): +def test_dqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -72,7 +74,7 @@ def test_dqn(args=get_args()): dueling_param=(Q_param, V_param), ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = DQNPolicy( + policy: DQNPolicy = DQNPolicy( model=net, optim=optim, action_space=env.action_space, @@ -95,17 +97,22 @@ def test_dqn(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): - return mean_rewards >= env.spec.reward_threshold + def stop_fn(mean_rewards: float) -> bool: + if env.spec: + if not env.spec.reward_threshold: + return False + else: + return mean_rewards >= env.spec.reward_threshold + return False - def train_fn(epoch, env_step): # exp decay + def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test) policy.set_eps(eps) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer @@ -134,8 +141,8 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) 
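The same reward-threshold guard now appears in acrobot_dualdqn, bipedal_hardcore_sac, lunarlander_dqn, mcc_sac, and discrete_dqn; a factored-out sketch of it follows (the make_stop_fn wrapper is hypothetical, the guard itself matches those scripts):

from collections.abc import Callable

import gymnasium as gym


def make_stop_fn(env: gym.Env) -> Callable[[float], bool]:
    # Only stop early when the registered spec actually defines a reward
    # threshold; otherwise keep training, as in the examples above.
    def stop_fn(mean_rewards: float) -> bool:
        if env.spec and env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        return False

    return stop_fn
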
test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 49a34af14..8a19d1806 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -11,13 +11,15 @@ from tianshou.env import DummyVectorEnv from tianshou.exploration import OUNoise from tianshou.policy import SACPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="MountainCarContinuous-v0") parser.add_argument("--seed", type=int, default=1626) @@ -48,11 +50,12 @@ def get_args(): return parser.parse_args() -def test_sac(args=get_args()): +def test_sac(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) # test_envs = gym.make(args.task) @@ -85,13 +88,14 @@ def test_sac(args=get_args()): critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + action_dim = space_info.action_info.action_dim if args.auto_alpha: - target_entropy = -np.prod(env.action_space.shape) + target_entropy = -action_dim log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy = SACPolicy( + policy: SACPolicy = SACPolicy( actor=actor, actor_optim=actor_optim, critic=critic1, @@ -118,11 +122,16 @@ def test_sac(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): - return mean_rewards >= env.spec.reward_threshold + def stop_fn(mean_rewards: float) -> bool: + if env.spec: + if not env.spec.reward_threshold: + return False + else: + return mean_rewards >= env.spec.reward_threshold + return False # trainer result = OffpolicyTrainer( @@ -147,8 +156,8 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/common.py 
b/examples/common.py new file mode 100644 index 000000000..a86115c2a --- /dev/null +++ b/examples/common.py @@ -0,0 +1,3 @@ +from tianshou.highlevel.logger import LoggerFactoryDefault + +logger_factory = LoggerFactoryDefault() diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 55ca5efa7..bf553252c 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -1,11 +1,14 @@ +from typing import cast + import gymnasium as gym import torch from torch.utils.tensorboard import SummaryWriter import tianshou as ts +from tianshou.utils.space_info import SpaceInfo -def main(): +def main() -> None: task = "CartPole-v1" lr, epoch, batch_size = 1e-3, 10, 64 train_num, test_num = 10, 100 @@ -25,12 +28,14 @@ def main(): # you can define other net by following the API: # https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network env = gym.make(task, render_mode="human") - state_shape = env.observation_space.shape or env.observation_space.n - action_shape = env.action_space.shape or env.action_space.n + env.action_space = cast(gym.spaces.Discrete, env.action_space) + space_info = SpaceInfo.from_env(env) + state_shape = space_info.observation_info.obs_shape + action_shape = space_info.action_info.action_shape net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) optim = torch.optim.Adam(net.parameters(), lr=lr) - policy = ts.policy.DQNPolicy( + policy: ts.policy.DQNPolicy = ts.policy.DQNPolicy( model=net, optim=optim, discount_factor=gamma, @@ -50,6 +55,14 @@ def main(): exploration_noise=True, ) # because DQN uses epsilon-greedy method + def stop_fn(mean_rewards: float) -> bool: + if env.spec: + if not env.spec.reward_threshold: + return False + else: + return mean_rewards >= env.spec.reward_threshold + return False + result = ts.trainer.OffpolicyTrainer( policy=policy, train_collector=train_collector, @@ -62,7 +75,7 @@ def main(): update_per_step=1 / step_per_collect, train_fn=lambda epoch, env_step: policy.set_eps(eps_train), test_fn=lambda epoch, env_step: policy.set_eps(eps_test), - stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, + stop_fn=stop_fn, logger=logger, ).run() print(f"Finished training in {result.timing.total_time} seconds") diff --git a/examples/discrete/discrete_dqn_hl.py b/examples/discrete/discrete_dqn_hl.py index e0f4ca58b..097692c2b 100644 --- a/examples/discrete/discrete_dqn_hl.py +++ b/examples/discrete/discrete_dqn_hl.py @@ -13,7 +13,7 @@ from tianshou.utils.logging import run_main -def main(): +def main() -> None: experiment = ( DQNExperimentBuilder( EnvFactoryRegistered(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY), diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 3dcc92fcd..705acaa00 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -4,23 +4,26 @@ import datetime import os import pprint +from typing import SupportsFloat import d4rl import gymnasium as gym import numpy as np import torch from torch import nn -from torch.distributions import Independent, Normal +from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer from tianshou.env import SubprocVectorEnv, VectorEnvNormObs from tianshou.policy import GAILPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer 
import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.space_info import SpaceInfo class NoRewardEnv(gym.RewardWrapper): @@ -29,15 +32,15 @@ class NoRewardEnv(gym.RewardWrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env): + def __init__(self, env: gym.Env) -> None: super().__init__(env) - def reward(self, reward): + def reward(self, reward: SupportsFloat) -> np.ndarray: """Set reward to 0.""" return np.zeros_like(reward) -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) @@ -86,14 +89,15 @@ def get_args(): return parser.parse_args() -def test_gail(args=get_args()): +def test_gail(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + print("Action range:", args.min_action, args.max_action) # train_envs = gym.make(args.task) train_envs = SubprocVectorEnv( [lambda: NoRewardEnv(gym.make(args.task)) for _ in range(args.training_num)], @@ -163,7 +167,7 @@ def test_gail(args=get_args()): lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits): + def dist(*logits: torch.Tensor) -> Distribution: return Independent(Normal(*logits), 1) # expert replay buffer @@ -185,7 +189,7 @@ def dist(*logits): ) print("dataset loaded") - policy = GAILPolicy( + policy: GAILPolicy = GAILPolicy( actor=actor, critic=critic, optim=optim, @@ -217,6 +221,7 @@ def dist(*logits): print("Loaded agent from: ", args.resume_path) # collector + buffer: ReplayBuffer if args.training_num > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: @@ -231,7 +236,7 @@ def dist(*logits): writer.add_text("args", str(args)) logger = TensorboardLogger(writer, update_interval=100, train_interval=100) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: @@ -256,8 +261,8 @@ def save_best_fn(policy): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 6e37100db..dc471c200 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -12,6 +12,7 @@ import wandb from torch.utils.tensorboard import SummaryWriter + from tianshou.data import ( Collector, HERReplayBuffer, @@ -22,13 +23,14 @@ from tianshou.env import ShmemVectorEnv, 
TruncatedAsTerminated from tianshou.exploration import GaussianNoise from tianshou.policy import DDPGPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net, get_dict_state_decorator from tianshou.utils.net.continuous import Actor, Critic -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="FetchReach-v3") parser.add_argument("--seed", type=int, default=0) @@ -87,7 +89,7 @@ def make_fetch_env(task, training_num, test_num): return env, train_envs, test_envs -def test_ddpg(args=get_args()): +def test_ddpg(args: argparse.Namespace = get_args()) -> None: # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") args.algo_name = "ddpg" @@ -153,7 +155,7 @@ def test_ddpg(args=get_args()): ) critic = dict_state_dec(Critic)(net_c, device=args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) - policy = DDPGPolicy( + policy: DDPGPolicy = DDPGPolicy( actor=actor, actor_optim=actor_optim, critic=critic, @@ -174,6 +176,7 @@ def test_ddpg(args=get_args()): def compute_reward_fn(ag: np.ndarray, g: np.ndarray): return env.compute_reward(ag, g, {}) + buffer: VectorReplayBuffer | ReplayBuffer | HERReplayBuffer | HERVectorReplayBuffer if args.replay_buffer == "normal": if args.training_num > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) @@ -199,7 +202,7 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray): test_collector = Collector(policy, test_envs) train_collector.collect(n_step=args.start_timesteps, random=True) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: @@ -224,8 +227,8 @@ def save_best_fn(policy): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 7abd6201f..95b645dc3 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -9,19 +9,19 @@ import torch from mujoco_env import make_mujoco_env from torch import nn -from torch.distributions import Independent, Normal +from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.policy import A2CPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) @@ -70,7 +70,7 @@ def get_args(): return parser.parse_args() -def test_a2c(args=get_args()): +def test_a2c(args: argparse.Namespace = get_args()) 
-> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -137,10 +137,10 @@ def test_a2c(args=get_args()): lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits): + def dist(*logits: torch.Tensor) -> Distribution: return Independent(Normal(*logits), 1) - policy = A2CPolicy( + policy: A2CPolicy = A2CPolicy( actor=actor, critic=critic, optim=optim, @@ -166,6 +166,7 @@ def dist(*logits): print("Loaded agent from: ", args.resume_path) # collector + buffer: VectorReplayBuffer | ReplayBuffer if args.training_num > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: @@ -181,21 +182,19 @@ def dist(*logits): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) @@ -221,8 +220,8 @@ def save_best_fn(policy): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index b3772b4d6..5b29a7565 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import functools import os from collections.abc import Sequence from typing import Literal @@ -40,7 +41,7 @@ def main( bound_action_method: Literal["clip", "tanh"] = "clip", lr_decay: bool = True, max_grad_norm: float = 0.5, -): +) -> None: log_name = os.path.join(task, "a2c", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( @@ -82,4 +83,5 @@ def main( if __name__ == "__main__": - logging.run_cli(main) + run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) + logging.run_cli(run_with_default_config) diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 7cb91322a..04fc0109b 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -8,18 +8,18 @@ import numpy as np import torch from mujoco_env import make_mujoco_env -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise from tianshou.policy import DDPGPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import 
Actor, Critic -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) @@ -64,7 +64,7 @@ def get_args(): return parser.parse_args() -def test_ddpg(args=get_args()): +def test_ddpg(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -97,7 +97,7 @@ def test_ddpg(args=get_args()): ) critic = Critic(net_c, device=args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) - policy = DDPGPolicy( + policy: DDPGPolicy = DDPGPolicy( actor=actor, actor_optim=actor_optim, critic=critic, @@ -131,21 +131,19 @@ def test_ddpg(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: @@ -170,8 +168,8 @@ def save_best_fn(policy): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 40a31366b..0026acfa4 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import functools import os from collections.abc import Sequence @@ -34,7 +35,7 @@ def main( batch_size: int = 256, training_num: int = 1, test_num: int = 10, -): +) -> None: log_name = os.path.join(task, "ddpg", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( @@ -73,4 +74,5 @@ def main( if __name__ == "__main__": - logging.run_cli(main) + run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) + logging.run_cli(run_with_default_config) diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index 081e41e9e..210d46558 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -59,7 +59,7 @@ def restore(self, event: RestoreEvent, world: World): class MujocoEnvFactory(EnvFactoryRegistered): - def __init__(self, task: str, seed: int, obs_norm=True): + def __init__(self, task: str, seed: int, obs_norm=True) -> None: super().__init__( task=task, seed=seed, diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index f6b951bef..454565a46 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -9,19 +9,19 @@ import torch from mujoco_env import make_mujoco_env from torch import nn -from torch.distributions import 
Independent, Normal +from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.policy import NPGPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) @@ -75,7 +75,7 @@ def get_args(): return parser.parse_args() -def test_npg(args=get_args()): +def test_npg(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -134,10 +134,10 @@ def test_npg(args=get_args()): lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits): + def dist(*logits: torch.Tensor) -> Distribution: return Independent(Normal(*logits), 1) - policy = NPGPolicy( + policy: NPGPolicy = NPGPolicy( actor=actor, critic=critic, optim=optim, @@ -163,6 +163,7 @@ def dist(*logits): print("Loaded agent from: ", args.resume_path) # collector + buffer: VectorReplayBuffer | ReplayBuffer if args.training_num > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: @@ -178,21 +179,19 @@ def dist(*logits): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) @@ -218,8 +217,8 @@ def save_best_fn(policy): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 6784108d9..85451be50 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import functools import os from collections.abc import Sequence from typing import Literal @@ -42,7 +43,7 @@ def main( norm_adv: bool = True, optim_critic_iters: int = 20, actor_step_size: float = 0.1, -): +) -> None: log_name = os.path.join(task, "npg", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( @@ -84,4 +85,5 @@ def main( if __name__ == "__main__": - 
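Note on the logger change repeated in each script above: the per-script WandbLogger / TensorboardLogger wiring is now delegated to the shared logger_factory imported from examples.common. That module is not part of this diff, so its internals are an assumption here; judging only from the attributes (logger_type, wandb_project) and the create_logger() call used above, a minimal factory could look like the sketch below. Names and defaults beyond those already visible in the diff are hypothetical.

# Hypothetical sketch of examples.common.logger_factory, inferred from its
# call sites in this diff; the real implementation may differ.
import os
from dataclasses import dataclass

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import TensorboardLogger, WandbLogger


@dataclass
class LoggerFactory:
    logger_type: str = "tensorboard"
    wandb_project: str | None = None

    def create_logger(
        self,
        log_dir: str,
        experiment_name: str,
        run_id: str | None,
        config_dict: dict,
    ) -> TensorboardLogger | WandbLogger:
        # Mirrors the boilerplate removed from each script: one SummaryWriter,
        # the args dumped as text, and either a tensorboard or a wandb logger.
        writer = SummaryWriter(log_dir)
        writer.add_text("args", str(config_dict))
        if self.logger_type == "wandb":
            logger = WandbLogger(
                save_interval=1,
                name=experiment_name.replace(os.path.sep, "__"),
                run_id=run_id,
                config=config_dict,
                project=self.wandb_project,
            )
            logger.load(writer)
            return logger
        return TensorboardLogger(writer)


# Module-level instance, matching `from examples.common import logger_factory`.
logger_factory = LoggerFactory()

Exposing a mutable module-level instance is what lets each script flip logger_type and wandb_project before calling create_logger, exactly as in the blocks above.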
logging.run_cli(main) + run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) + logging.run_cli(run_with_default_config) diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 40b2f3946..61fb8cf84 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -9,19 +9,19 @@ import torch from mujoco_env import make_mujoco_env from torch import nn -from torch.distributions import Independent, Normal +from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.policy import PPOPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) @@ -75,7 +75,7 @@ def get_args(): return parser.parse_args() -def test_ppo(args=get_args()): +def test_ppo(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -137,10 +137,10 @@ def test_ppo(args=get_args()): lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits): + def dist(*logits: torch.Tensor) -> Distribution: return Independent(Normal(*logits), 1) - policy = PPOPolicy( + policy: PPOPolicy = PPOPolicy( actor=actor, critic=critic, optim=optim, @@ -171,6 +171,7 @@ def dist(*logits): print("Loaded agent from: ", args.resume_path) # collector + buffer: VectorReplayBuffer | ReplayBuffer if args.training_num > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: @@ -186,21 +187,19 @@ def dist(*logits): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) @@ -226,8 +225,8 @@ def save_best_fn(policy): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index f4110ae1b..955d77ef2 100644 --- a/examples/mujoco/mujoco_ppo_hl.py 
+++ b/examples/mujoco/mujoco_ppo_hl.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import functools import os from collections.abc import Sequence from typing import Literal @@ -47,7 +48,7 @@ def main( value_clip: bool = False, norm_adv: bool = False, recompute_adv: bool = True, -): +) -> None: log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( @@ -94,4 +95,5 @@ def main( if __name__ == "__main__": - logging.run_cli(main) + run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) + logging.run_cli(run_with_default_config) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index b50304fbb..66c9f7db6 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -8,17 +8,17 @@ import numpy as np import torch from mujoco_env import make_mujoco_env -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.policy import REDQPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.continuous import ActorProb, Critic -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) @@ -68,7 +68,7 @@ def get_args(): return parser.parse_args() -def test_redq(args=get_args()): +def test_redq(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -121,7 +121,7 @@ def linear(x, y): alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy = REDQPolicy( + policy: REDQPolicy = REDQPolicy( actor=actor, actor_optim=actor_optim, critic=critics, @@ -143,6 +143,7 @@ def linear(x, y): print("Loaded agent from: ", args.resume_path) # collector + buffer: VectorReplayBuffer | ReplayBuffer if args.training_num > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: @@ -159,21 +160,19 @@ def linear(x, y): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: @@ -198,8 +197,8 @@ def save_best_fn(policy): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + 
print(collector_stats) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index cfcdb792f..8b8234b12 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import functools import os from collections.abc import Sequence from typing import Literal @@ -40,7 +41,7 @@ def main( target_mode: Literal["mean", "min"] = "min", training_num: int = 1, test_num: int = 10, -): +) -> None: log_name = os.path.join(task, "redq", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( @@ -82,4 +83,5 @@ def main( if __name__ == "__main__": - logging.run_cli(main) + run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) + logging.run_cli(run_with_default_config) diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index be0eb048a..6175c37cd 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -9,19 +9,19 @@ import torch from mujoco_env import make_mujoco_env from torch import nn -from torch.distributions import Independent, Normal +from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.policy import PGPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) @@ -67,7 +67,7 @@ def get_args(): return parser.parse_args() -def test_reinforce(args=get_args()): +def test_reinforce(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -119,13 +119,13 @@ def test_reinforce(args=get_args()): lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist_fn(*logits): + def dist(*logits: torch.Tensor) -> Distribution: return Independent(Normal(*logits), 1) - policy = PGPolicy( + policy: PGPolicy = PGPolicy( actor=actor, optim=optim, - dist_fn=dist_fn, + dist_fn=dist, action_space=env.action_space, discount_factor=args.gamma, reward_normalization=args.rew_norm, @@ -143,6 +143,7 @@ def dist_fn(*logits): print("Loaded agent from: ", args.resume_path) # collector + buffer: VectorReplayBuffer | ReplayBuffer if args.training_num > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: @@ -158,21 +159,19 @@ def dist_fn(*logits): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + 
experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) @@ -198,8 +197,8 @@ def save_best_fn(policy): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index a524b0e82..f793ab82b 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import functools import os from collections.abc import Sequence from typing import Literal @@ -35,7 +36,7 @@ def main( rew_norm: bool = True, action_bound_method: Literal["clip", "tanh"] = "tanh", lr_decay: bool = True, -): +) -> None: log_name = os.path.join(task, "reinforce", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( @@ -71,4 +72,5 @@ def main( if __name__ == "__main__": - logging.run_cli(main) + run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) + logging.run_cli(run_with_default_config) diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 286f2cc93..a09118979 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -8,17 +8,17 @@ import numpy as np import torch from mujoco_env import make_mujoco_env -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.policy import SACPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) @@ -65,7 +65,7 @@ def get_args(): return parser.parse_args() -def test_sac(args=get_args()): +def test_sac(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -117,7 +117,7 @@ def test_sac(args=get_args()): alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy = SACPolicy( + policy: SACPolicy = SACPolicy( actor=actor, actor_optim=actor_optim, critic=critic1, @@ -137,6 +137,7 @@ def test_sac(args=get_args()): print("Loaded agent from: ", args.resume_path) # collector + buffer: VectorReplayBuffer | ReplayBuffer if args.training_num > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: @@ -153,21 +154,19 @@ def test_sac(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = 
TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: @@ -192,8 +191,8 @@ def save_best_fn(policy): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 6af4c6192..cc22a306e 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import functools import os from collections.abc import Sequence @@ -36,7 +37,7 @@ def main( batch_size: int = 256, training_num: int = 1, test_num: int = 10, -): +) -> None: log_name = os.path.join(task, "sac", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( @@ -79,4 +80,5 @@ def main( if __name__ == "__main__": - logging.run_cli(main) + run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) + logging.run_cli(run_with_default_config) diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index ef986163e..057905c64 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -8,18 +8,18 @@ import numpy as np import torch from mujoco_env import make_mujoco_env -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise from tianshou.policy import TD3Policy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) @@ -67,7 +67,7 @@ def get_args(): return parser.parse_args() -def test_td3(args=get_args()): +def test_td3(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -112,7 +112,7 @@ def test_td3(args=get_args()): critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - policy = TD3Policy( + policy: TD3Policy = TD3Policy( actor=actor, actor_optim=actor_optim, critic=critic1, @@ -135,6 +135,7 @@ def test_td3(args=get_args()): print("Loaded agent from: ", args.resume_path) # collector + buffer: VectorReplayBuffer | ReplayBuffer if args.training_num > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: @@ -151,21 +152,19 @@ def test_td3(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - 
name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) if not args.watch: @@ -190,8 +189,8 @@ def save_best_fn(policy): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 67b9c1847..e6ab40d47 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import functools import os from collections.abc import Sequence @@ -42,7 +43,7 @@ def main( batch_size: int = 256, training_num: int = 1, test_num: int = 10, -): +) -> None: log_name = os.path.join(task, "td3", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( @@ -84,4 +85,5 @@ def main( if __name__ == "__main__": - logging.run_cli(main) + run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) + logging.run_cli(run_with_default_config) diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index d38afdbe5..b001fd04c 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -9,19 +9,19 @@ import torch from mujoco_env import make_mujoco_env from torch import nn -from torch.distributions import Independent, Normal +from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.policy import TRPOPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Ant-v4") parser.add_argument("--seed", type=int, default=0) @@ -78,7 +78,7 @@ def get_args(): return parser.parse_args() -def test_trpo(args=get_args()): +def test_trpo(args: argparse.Namespace = get_args()) -> None: env, train_envs, test_envs = make_mujoco_env( args.task, args.seed, @@ -137,10 +137,10 @@ def test_trpo(args=get_args()): lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits): + def dist(*logits: torch.Tensor) -> Distribution: return Independent(Normal(*logits), 1) - policy = TRPOPolicy( + policy: TRPOPolicy = TRPOPolicy( actor=actor, 
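The dist helper defined just above (and in the other Gaussian-policy examples in this diff) wraps the actor's (mu, sigma) output in Independent(Normal(mu, sigma), 1), which reinterprets the last axis as event dimensions so that log_prob sums over action dimensions instead of returning one value per dimension. A small standalone illustration of that behaviour:

import torch
from torch.distributions import Independent, Normal

mu = torch.zeros(4, 3)      # batch of 4 states, 3-dimensional actions
sigma = torch.ones(4, 3)

per_dim = Normal(mu, sigma)       # batch_shape=(4, 3), event_shape=()
joint = Independent(per_dim, 1)   # batch_shape=(4,),   event_shape=(3,)

a = joint.sample()                       # shape (4, 3)
print(per_dim.log_prob(a).shape)         # torch.Size([4, 3])
print(joint.log_prob(a).shape)           # torch.Size([4]), summed over action dims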
critic=critic, optim=optim, @@ -168,6 +168,7 @@ def dist(*logits): print("Loaded agent from: ", args.resume_path) # collector + buffer: VectorReplayBuffer | ReplayBuffer if args.training_num > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: @@ -183,21 +184,19 @@ def dist(*logits): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()} torch.save(state, os.path.join(log_path, "policy.pth")) @@ -223,8 +222,8 @@ def save_best_fn(policy): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 855caecb6..528e974af 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import functools import os from collections.abc import Sequence from typing import Literal @@ -44,7 +45,7 @@ def main( max_kl: float = 0.01, backtrack_coeff: float = 0.8, max_backtracks: int = 10, -): +) -> None: log_name = os.path.join(task, "trpo", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( @@ -88,4 +89,5 @@ def main( if __name__ == "__main__": - logging.run_cli(main) + run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig()) + logging.run_cli(run_with_default_config) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index e2015382d..766662a07 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -9,20 +9,20 @@ import numpy as np import torch -from torch.utils.tensorboard import SummaryWriter from examples.atari.atari_network import DQN from examples.atari.atari_wrapper import make_atari_env +from examples.common import logger_factory from examples.offline.utils import load_buffer from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import DiscreteBCQPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1626) @@ -72,7 +72,7 @@ def get_args(): return parser.parse_known_args()[0] -def test_discrete_bcq(args=get_args()): +def test_discrete_bcq(args: 
argparse.Namespace = get_args()) -> None: # envs env, _, test_envs = make_atari_env( args.task, @@ -114,7 +114,7 @@ def test_discrete_bcq(args=get_args()): actor_critic = ActorCritic(policy_net, imitation_net) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) # define policy - policy = DiscreteBCQPolicy( + policy: DiscreteBCQPolicy = DiscreteBCQPolicy( model=policy_net, imitator=imitation_net, optim=optim, @@ -158,28 +158,26 @@ def test_discrete_bcq(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return False # watch agent's performance - def watch(): + def watch() -> None: print("Setup test envs ...") policy.eval() policy.set_eps(args.eps_test) diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 82a0d6edf..5f1afcdcd 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -9,18 +9,18 @@ import numpy as np import torch -from torch.utils.tensorboard import SummaryWriter from examples.atari.atari_network import QRDQN from examples.atari.atari_wrapper import make_atari_env +from examples.common import logger_factory from examples.offline.utils import load_buffer from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import DiscreteCQLPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer -from tianshou.utils import TensorboardLogger, WandbLogger -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1626) @@ -70,7 +70,7 @@ def get_args(): return parser.parse_known_args()[0] -def test_discrete_cql(args=get_args()): +def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: # envs env, _, test_envs = make_atari_env( args.task, @@ -92,7 +92,7 @@ def test_discrete_cql(args=get_args()): net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy = DiscreteCQLPolicy( + policy: DiscreteCQLPolicy = DiscreteCQLPolicy( model=net, optim=optim, action_space=env.action_space, @@ -134,28 +134,26 @@ def test_discrete_cql(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = 
"wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return False # watch agent's performance - def watch(): + def watch() -> None: print("Setup test envs ...") policy.eval() policy.set_eps(args.eps_test) diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index db768f2b0..97622b6d5 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -9,20 +9,20 @@ import numpy as np import torch -from torch.utils.tensorboard import SummaryWriter from examples.atari.atari_network import DQN from examples.atari.atari_wrapper import make_atari_env +from examples.common import logger_factory from examples.offline.utils import load_buffer from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import DiscreteCRRPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1626) @@ -72,7 +72,7 @@ def get_args(): return parser.parse_known_args()[0] -def test_discrete_crr(args=get_args()): +def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: # envs env, _, test_envs = make_atari_env( args.task, @@ -113,7 +113,7 @@ def test_discrete_crr(args=get_args()): actor_critic = ActorCritic(actor, critic) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) # define policy - policy = DiscreteCRRPolicy( + policy: DiscreteCRRPolicy = DiscreteCRRPolicy( actor=actor, critic=critic, optim=optim, @@ -157,28 +157,26 @@ def test_discrete_crr(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return False # watch agent's performance - def watch(): + def watch() -> None: print("Setup test envs ...") policy.eval() test_envs.seed(args.seed) diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 3efb56f91..615d38ec0 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -9,18 +9,18 @@ import numpy as np import torch -from torch.utils.tensorboard import 
SummaryWriter from examples.atari.atari_network import DQN from examples.atari.atari_wrapper import make_atari_env +from examples.common import logger_factory from examples.offline.utils import load_buffer from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import ImitationPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer -from tianshou.utils import TensorboardLogger, WandbLogger -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1626) @@ -63,7 +63,7 @@ def get_args(): return parser.parse_known_args()[0] -def test_il(args=get_args()): +def test_il(args: argparse.Namespace = get_args()) -> None: # envs env, _, test_envs = make_atari_env( args.task, @@ -85,7 +85,7 @@ def test_il(args=get_args()): net = DQN(*args.state_shape, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy = ImitationPolicy(actor=net, optim=optim, action_space=env.action_space) + policy: ImitationPolicy = ImitationPolicy(actor=net, optim=optim, action_space=env.action_space) # load a previous policy if args.resume_path: policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) @@ -118,28 +118,26 @@ def test_il(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return False # watch agent's performance - def watch(): + def watch() -> None: print("Setup test envs ...") policy.eval() test_envs.seed(args.seed) diff --git a/examples/offline/convert_rl_unplugged_atari.py b/examples/offline/convert_rl_unplugged_atari.py index 02202dc75..a28a35e5f 100755 --- a/examples/offline/convert_rl_unplugged_atari.py +++ b/examples/offline/convert_rl_unplugged_atari.py @@ -238,7 +238,7 @@ def process_dataset( process_shard(url, filepath, ofname) -def main(args): +def main(args) -> None: if args.task not in ALL_GAMES: raise KeyError(f"`{args.task}` is not in the list of games.") fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards) diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 5e3746e34..7c275b555 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -14,13 +14,15 @@ from tianshou.data import Collector from tianshou.env import SubprocVectorEnv from tianshou.policy import BCQPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import MLP, Net from 
tianshou.utils.net.continuous import VAE, Critic, Perturbation +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) @@ -70,19 +72,20 @@ def get_args(): return parser.parse_args() -def test_bcq(): +def test_bcq() -> None: args = get_args() env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] # float + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action print("device:", args.device) print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + print("Action range:", args.min_action, args.max_action) - args.state_dim = args.state_shape[0] - args.action_dim = args.action_shape[0] + args.state_dim = space_info.observation_info.obs_dim + args.action_dim = space_info.action_info.action_dim print("Max_action", args.max_action) # test_envs = gym.make(args.task) @@ -149,11 +152,12 @@ def test_bcq(): ).to(args.device) vae_optim = torch.optim.Adam(vae.parameters()) - policy = BCQPolicy( + policy: BCQPolicy = BCQPolicy( actor_perturbation=actor, actor_perturbation_optim=actor_optim, critic=critic1, critic_optim=critic1_optim, + action_space=env.action_space, critic2=critic2, critic2_optim=critic2_optim, vae=vae, @@ -179,7 +183,12 @@ def test_bcq(): log_path = os.path.join(args.logdir, log_name) # logger - if args.logger == "wandb": + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger: WandbLogger | TensorboardLogger + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: logger = WandbLogger( save_interval=1, name=log_name.replace(os.path.sep, "__"), @@ -187,17 +196,12 @@ def test_bcq(): config=args, project=args.wandb_project, ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb logger.load(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def watch(): + def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") @@ -228,8 +232,8 @@ def watch(): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 160137fce..0e9fe62cd 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -4,6 +4,7 @@ import datetime import os import pprint +from typing import cast import gymnasium as gym import numpy as np @@ -14,13 +15,15 @@ from tianshou.data import Collector from tianshou.env import SubprocVectorEnv from tianshou.policy import CQLPolicy +from tianshou.policy.base import 
BasePolicy from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument( "--task", @@ -214,19 +217,22 @@ def get_args(): return parser.parse_args() -def test_cql(): +def test_cql() -> None: args = get_args() env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] # float + env.action_space = cast(gym.spaces.Box, env.action_space) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action + args.min_action = space_info.action_info.min_action print("device:", args.device) print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + print("Action range:", args.min_action, args.max_action) - args.state_dim = args.state_shape[0] - args.action_dim = args.action_shape[0] + args.state_dim = space_info.observation_info.obs_dim + args.action_dim = space_info.action_info.action_dim print("Max_action", args.max_action) # test_envs = gym.make(args.task) @@ -274,12 +280,12 @@ def test_cql(): critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) if args.auto_alpha: - target_entropy = -np.prod(env.action_space.shape) + target_entropy = -args.action_dim log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy = CQLPolicy( + policy: CQLPolicy = CQLPolicy( actor=actor, actor_optim=actor_optim, critic=critic, @@ -296,8 +302,8 @@ def test_cql(): temperature=args.temperature, with_lagrange=args.with_lagrange, lagrange_threshold=args.lagrange_threshold, - min_action=np.min(env.action_space.low), - max_action=np.max(env.action_space.high), + min_action=args.min_action, + max_action=args.max_action, device=args.device, ) @@ -316,7 +322,12 @@ def test_cql(): log_path = os.path.join(args.logdir, log_name) # logger - if args.logger == "wandb": + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger: WandbLogger | TensorboardLogger + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: logger = WandbLogger( save_interval=1, name=log_name.replace(os.path.sep, "__"), @@ -324,17 +335,12 @@ def test_cql(): config=args, project=args.wandb_project, ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb logger.load(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def watch(): + def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") @@ -365,8 +371,8 @@ def watch(): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: 
{result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index 89384127c..b7153ed11 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -14,13 +14,15 @@ from tianshou.data import Collector from tianshou.env import SubprocVectorEnv from tianshou.policy import ImitationPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) @@ -57,19 +59,20 @@ def get_args(): return parser.parse_args() -def test_il(): +def test_il() -> None: args = get_args() env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] # float + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action print("device:", args.device) print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + print("Action range:", args.min_action, args.max_action) - args.state_dim = args.state_shape[0] - args.action_dim = args.action_shape[0] + args.state_dim = space_info.observation_info.obs_dim + args.action_dim = space_info.action_info.action_dim print("Max_action", args.max_action) test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) @@ -93,9 +96,9 @@ def test_il(): ).to(args.device) optim = torch.optim.Adam(actor.parameters(), lr=args.lr) - policy = ImitationPolicy( - actor, - optim, + policy: ImitationPolicy = ImitationPolicy( + actor=actor, + optim=optim, action_space=env.action_space, action_scaling=True, action_bound_method="clip", @@ -116,7 +119,12 @@ def test_il(): log_path = os.path.join(args.logdir, log_name) # logger - if args.logger == "wandb": + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger: WandbLogger | TensorboardLogger + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: logger = WandbLogger( save_interval=1, name=log_name.replace(os.path.sep, "__"), @@ -124,17 +132,12 @@ def test_il(): config=args, project=args.wandb_project, ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb logger.load(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def watch(): + def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") @@ -165,8 +168,8 @@ def watch(): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) 
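The d4rl scripts above switch from reading env.observation_space and env.action_space by hand to the SpaceInfo helper from tianshou.utils.space_info. Only its call sites appear in this diff, so the following is a rough sketch of the interface those call sites rely on, restricted to Box spaces as used in these examples; class and field layout other than the attributes accessed above are guesses, and the actual implementation may differ.

# Rough sketch inferred from space_info.observation_info.* and
# space_info.action_info.* usage in this diff; not the real module.
from dataclasses import dataclass

import gymnasium as gym
import numpy as np


@dataclass
class ObservationInfo:
    obs_shape: tuple[int, ...]
    obs_dim: int


@dataclass
class ActionInfo:
    action_shape: tuple[int, ...]
    action_dim: int
    min_action: float
    max_action: float


@dataclass
class SpaceInfo:
    observation_info: ObservationInfo
    action_info: ActionInfo

    @classmethod
    def from_env(cls, env: gym.Env) -> "SpaceInfo":
        # Box-only sketch: the continuous-control examples above all use Box
        # spaces; discrete spaces would need an extra branch.
        obs_space = env.observation_space
        act_space = env.action_space
        assert isinstance(obs_space, gym.spaces.Box)
        assert isinstance(act_space, gym.spaces.Box)
        return cls(
            observation_info=ObservationInfo(
                obs_shape=obs_space.shape,
                obs_dim=int(np.prod(obs_space.shape)),
            ),
            action_info=ActionInfo(
                action_shape=act_space.shape,
                action_dim=int(np.prod(act_space.shape)),
                min_action=float(np.min(act_space.low)),
                max_action=float(np.max(act_space.high)),
            ),
        )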
- print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index c4c4e9dd9..b2a0b24c8 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -12,16 +12,18 @@ from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer from tianshou.data import Collector -from tianshou.env import SubprocVectorEnv, VectorEnvNormObs +from tianshou.env import BaseVectorEnv, SubprocVectorEnv, VectorEnvNormObs from tianshou.exploration import GaussianNoise from tianshou.policy import TD3BCPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="HalfCheetah-v2") parser.add_argument("--seed", type=int, default=0) @@ -71,21 +73,24 @@ def get_args(): return parser.parse_args() -def test_td3_bc(): +def test_td3_bc() -> None: args = get_args() env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] # float + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action + args.min_action = space_info.action_info.min_action print("device:", args.device) print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + print("Action range:", args.min_action, args.max_action) - args.state_dim = args.state_shape[0] - args.action_dim = args.action_shape[0] + args.state_dim = space_info.observation_info.obs_dim + args.action_dim = space_info.action_info.action_dim print("Max_action", args.max_action) + test_envs: BaseVectorEnv test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) if args.norm_obs: test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False) @@ -130,7 +135,7 @@ def test_td3_bc(): critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - policy = TD3BCPolicy( + policy: TD3BCPolicy = TD3BCPolicy( actor=actor, actor_optim=actor_optim, critic=critic1, @@ -163,7 +168,12 @@ def test_td3_bc(): log_path = os.path.join(args.logdir, log_name) # logger - if args.logger == "wandb": + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger: WandbLogger | TensorboardLogger + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: logger = WandbLogger( save_interval=1, name=log_name.replace(os.path.sep, "__"), @@ -171,17 +181,12 @@ def test_td3_bc(): config=args, project=args.wandb_project, ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb logger.load(writer) - def 
save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def watch(): + def watch() -> None: if args.resume_path is None: args.resume_path = os.path.join(log_path, "policy.pth") @@ -215,8 +220,8 @@ def watch(): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 9128eaa92..f5c974fa0 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -41,7 +41,7 @@ def battle_button_comb(): class Env(gym.Env): - def __init__(self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False): + def __init__(self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False) -> None: super().__init__() self.save_lmp = save_lmp self.health_setting = "battle" in cfg_path diff --git a/examples/vizdoom/replay.py b/examples/vizdoom/replay.py index bd894b1c1..4437a08ba 100755 --- a/examples/vizdoom/replay.py +++ b/examples/vizdoom/replay.py @@ -6,7 +6,7 @@ import vizdoom as vzd -def main(cfg_path="maps/D3_battle.cfg", lmp_path="test.lmp"): +def main(cfg_path="maps/D3_battle.cfg", lmp_path="test.lmp") -> None: game = vzd.DoomGame() game.load_config(cfg_path) game.set_screen_format(vzd.ScreenFormat.CRCGCB) diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 46a33769f..0c6b0a23a 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -8,15 +8,15 @@ import torch from env import make_vizdoom_env from network import C51 -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import C51Policy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="D1_basic") parser.add_argument("--seed", type=int, default=0) @@ -72,7 +72,7 @@ def get_args(): return parser.parse_args() -def test_c51(args=get_args()): +def test_c51(args: argparse.Namespace = get_args()) -> None: # make environments env, train_envs, test_envs = make_vizdoom_env( args.task, @@ -95,7 +95,7 @@ def test_c51(args=get_args()): net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy = C51Policy( + policy: C51Policy = C51Policy( model=net, optim=optim, discount_factor=args.gamma, @@ -131,21 +131,19 @@ def test_c51(args=get_args()): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - - def save_best_fn(policy): + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = 
logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) + + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -153,7 +151,7 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= env.spec.reward_threshold return False - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # nature DQN setting, linear decay in the first 1M steps if env_step <= 1e6: eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final) @@ -163,11 +161,11 @@ def train_fn(epoch, env_step): if env_step % 1000 == 0: logger.write("train/env_step", env_step, {"train/eps": eps}) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # watch agent's performance - def watch(): + def watch() -> None: print("Setup test envs ...") policy.eval() policy.set_eps(args.eps_test) diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 0811a24dd..b474b7b83 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -8,18 +8,19 @@ import torch from env import make_vizdoom_env from network import DQN +from torch.distributions import Categorical, Distribution from torch.optim.lr_scheduler import LambdaLR -from torch.utils.tensorboard import SummaryWriter +from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import ICMPolicy, PPOPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="D1_basic") parser.add_argument("--seed", type=int, default=0) @@ -97,7 +98,7 @@ def get_args(): return parser.parse_args() -def test_ppo(args=get_args()): +def test_ppo(args: argparse.Namespace = get_args()) -> None: # make environments env, train_envs, test_envs = make_vizdoom_env( args.task, @@ -136,10 +137,10 @@ def test_ppo(args=get_args()): lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) # define policy - def dist(p): - return torch.distributions.Categorical(logits=p) + def dist(logits: torch.Tensor) -> Distribution: + return Categorical(logits=logits) - policy = PPOPolicy( + policy: PPOPolicy = PPOPolicy( actor=actor, critic=critic, optim=optim, @@ -210,21 +211,19 @@ def dist(p): # logger if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: 
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: @@ -233,7 +232,7 @@ def stop_fn(mean_rewards: float) -> bool: return False # watch agent's performance - def watch(): + def watch() -> None: print("Setup test envs ...") policy.eval() test_envs.seed(args.seed) diff --git a/test/base/env.py b/test/base/env.py index b2faf7a5d..8a2de26cc 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -1,11 +1,12 @@ import random import time from copy import deepcopy +from typing import Any, Literal import gymnasium as gym import networkx as nx import numpy as np -from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple +from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Space, Tuple class MyTestEnv(gym.Env): @@ -13,15 +14,15 @@ class MyTestEnv(gym.Env): def __init__( self, - size, - sleep=0, - dict_state=False, - recurse_state=False, - ma_rew=0, - multidiscrete_action=False, - random_sleep=False, - array_state=False, - ): + size: int, + sleep: int = 0, + dict_state: bool = False, + recurse_state: bool = False, + ma_rew: int = 0, + multidiscrete_action: bool = False, + random_sleep: bool = False, + array_state: bool = False, + ) -> None: assert ( dict_state + recurse_state + array_state <= 1 ), "dict_state / recurse_state / array_state can be only one true" @@ -70,7 +71,11 @@ def __init__( self.terminated = False self.index = 0 - def reset(self, seed=None, options=None): + def reset( + self, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[dict[str, Any] | np.ndarray, dict]: if options is None: options = {"state": 0} super().reset(seed=seed) @@ -79,14 +84,14 @@ def reset(self, seed=None, options=None): self.index = options["state"] return self._get_state(), {"key": 1, "env": self} - def _get_reward(self): + def _get_reward(self) -> list[int] | int: """Generate a non-scalar reward if ma_rew is True.""" end_flag = int(self.terminated) if self.ma_rew > 0: return [end_flag] * self.ma_rew return end_flag - def _get_state(self): + def _get_state(self) -> dict[str, Any] | np.ndarray: """Generate state(observation) of MyTestEnv.""" if self.dict_state: return { @@ -110,15 +115,15 @@ def _get_state(self): return img return np.array([self.index], dtype=np.float32) - def do_sleep(self): + def do_sleep(self) -> None: if self.sleep > 0: sleep_time = random.random() if self.random_sleep else 1 sleep_time *= self.sleep time.sleep(sleep_time) - def step(self, action): + def step(self, action: np.ndarray | int): self.steps += 1 - if self._md_action: + if self._md_action and isinstance(action, np.ndarray): action = action[0] if self.terminated: raise ValueError("step after done !!!") @@ -149,7 +154,7 @@ def step(self, action): class NXEnv(gym.Env): - def __init__(self, size, obs_type, feat_dim=32): + def __init__(self, size: int, obs_type: str, feat_dim: int = 32) -> None: self.size = size self.feat_dim = feat_dim self.graph = nx.Graph() @@ -157,26 +162,34 @@ def __init__(self, size, obs_type, feat_dim=32): assert obs_type in ["array", "object"] self.obs_type = obs_type - def _encode_obs(self): + def _encode_obs(self) -> np.ndarray | nx.Graph: if self.obs_type == "array": return np.stack([v["data"] for v in self.graph._node.values()]) return deepcopy(self.graph) - def reset(self): + def reset( + self, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[np.ndarray | nx.Graph, dict]: + super().reset(seed=seed) graph_state = np.random.rand(self.size, self.feat_dim) for i in 
range(self.size): self.graph.nodes[i]["data"] = graph_state[i] return self._encode_obs(), {} - def step(self, action): + def step( + self, + action: Space, + ) -> tuple[np.ndarray | nx.Graph, float, Literal[False], Literal[False], dict]: next_graph_state = np.random.rand(self.size, self.feat_dim) for i in range(self.size): self.graph.nodes[i]["data"] = next_graph_state[i] - return self._encode_obs(), 1.0, 0, 0, {} + return self._encode_obs(), 1.0, False, False, {} class MyGoalEnv(MyTestEnv): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: assert ( kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0 ), "dict_state / recurse_state not supported" @@ -193,12 +206,12 @@ def __init__(self, *args, **kwargs): }, ) - def reset(self, *args, **kwargs): + def reset(self, *args: Any, **kwargs: Any) -> tuple[dict[str, Any], dict]: obs, info = super().reset(*args, **kwargs) new_obs = {"observation": obs, "achieved_goal": obs, "desired_goal": self._goal} return new_obs, info - def step(self, *args, **kwargs): + def step(self, *args: Any, **kwargs: Any) -> tuple[dict[str, Any], float, bool, bool, dict]: obs_next, rew, terminated, truncated, info = super().step(*args, **kwargs) new_obs_next = { "observation": obs_next, @@ -213,7 +226,5 @@ def compute_reward_fn( desired_goal: np.ndarray, info: dict, ) -> np.ndarray: - axis = -1 - if self.array_state: - axis = (-3, -2, -1) + axis: tuple[int, ...] = (-3, -2, -1) if self.array_state else (-1,) return (achieved_goal == desired_goal).all(axis=axis) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index ba15282fe..aaacffdd4 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -11,7 +11,7 @@ from tianshou.data import Batch, to_numpy, to_torch -def test_batch(): +def test_batch() -> None: assert list(Batch()) == [] assert Batch().is_empty() assert not Batch(b={"c": {}}).is_empty() @@ -180,7 +180,7 @@ def test_batch(): assert Batch(a=np.array([g1, g2], dtype=object)).a.dtype == object -def test_batch_over_batch(): +def test_batch_over_batch() -> None: batch = Batch(a=[3, 4, 5], b=[4, 5, 6]) batch2 = Batch({"c": [6, 7, 8], "b": batch}) batch2.b.b[-1] = 0 @@ -225,7 +225,7 @@ def test_batch_over_batch(): assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3) -def test_batch_cat_and_stack(): +def test_batch_cat_and_stack() -> None: # test cat with compatible keys b1 = Batch(a=[{"b": np.float64(1.0), "d": Batch(e=np.array(3.0))}]) b2 = Batch(a=[{"b": np.float64(4.0), "d": {"e": np.array(6.0)}}]) @@ -365,7 +365,7 @@ def test_batch_cat_and_stack(): Batch.stack([b1, b2], axis=1) -def test_batch_over_batch_to_torch(): +def test_batch_over_batch_to_torch() -> None: batch = Batch( a=np.float64(1.0), b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)), @@ -390,7 +390,7 @@ def test_batch_over_batch_to_torch(): assert batch.b.e.dtype == torch.float32 -def test_utils_to_torch_numpy(): +def test_utils_to_torch_numpy() -> None: batch = Batch( a=np.float64(1.0), b=Batch(c=np.ones((1,), dtype=np.float32), d=torch.ones((1,), dtype=torch.float64)), @@ -457,7 +457,7 @@ def test_utils_to_torch_numpy(): to_torch(np.array([{}, "2"])) -def test_batch_pickle(): +def test_batch_pickle() -> None: batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4])) batch_pk = pickle.loads(pickle.dumps(batch)) assert batch.obs.a == batch_pk.obs.a @@ -465,7 +465,7 @@ def test_batch_pickle(): assert np.all(batch.np == batch_pk.np) -def 
test_batch_from_to_numpy_without_copy(): +def test_batch_from_to_numpy_without_copy() -> None: batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,)))) a_mem_addr_orig = batch.a.__array_interface__["data"][0] c_mem_addr_orig = batch.b.c.__array_interface__["data"][0] @@ -477,7 +477,7 @@ def test_batch_from_to_numpy_without_copy(): assert c_mem_addr_new == c_mem_addr_orig -def test_batch_copy(): +def test_batch_copy() -> None: batch = Batch(a=np.array([3, 4, 5]), b=np.array([4, 5, 6])) batch2 = Batch({"c": np.array([6, 7, 8]), "b": batch}) orig_c_addr = batch2.c.__array_interface__["data"][0] @@ -509,7 +509,7 @@ def test_batch_copy(): assert orig_b_b_addr != curr_b_b_addr -def test_batch_empty(): +def test_batch_empty() -> None: b5_dict = np.array([{"a": False, "b": {"c": 2.0, "d": 1.0}}, {"a": True, "b": {"c": 3.0}}]) b5 = Batch(b5_dict) b5[1] = Batch.empty(b5[0]) @@ -545,7 +545,7 @@ def test_batch_empty(): assert b0.shape == [] -def test_batch_standard_compatibility(): +def test_batch_standard_compatibility() -> None: batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0])) batch_mean = np.mean(batch) assert isinstance(batch_mean, Batch) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index c4dcd6405..99154bbdf 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -27,7 +27,7 @@ from test.base.env import MyGoalEnv, MyTestEnv -def test_replaybuffer(size=10, bufsize=20): +def test_replaybuffer(size=10, bufsize=20) -> None: env = MyTestEnv(size) buf = ReplayBuffer(bufsize) buf.update(buf) @@ -139,7 +139,7 @@ def test_replaybuffer(size=10, bufsize=20): assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3]) -def test_ignore_obs_next(size=10): +def test_ignore_obs_next(size=10) -> None: # Issue 82 buf = ReplayBuffer(size, ignore_obs_next=True) for i in range(size): @@ -208,7 +208,7 @@ def test_ignore_obs_next(size=10): assert data.obs_next -def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): +def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None: env = MyTestEnv(size) buf = ReplayBuffer(bufsize, stack_num=stack_num) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) @@ -279,7 +279,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): buf[bufsize * 2] -def test_priortized_replaybuffer(size=32, bufsize=15): +def test_priortized_replaybuffer(size=32, bufsize=15) -> None: env = MyTestEnv(size) buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5) @@ -329,7 +329,7 @@ def test_priortized_replaybuffer(size=32, bufsize=15): assert weight[mask][0] <= 1 -def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4): +def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4) -> None: env_size = size env = MyGoalEnv(env_size, array_state=True) @@ -491,7 +491,7 @@ def compute_reward_fn(ag, g): assert int(buf.obs.desired_goal[10][0]) in [11, 12, 13, 14, 15, 16, 17, 18, 19, 20] -def test_update(): +def test_update() -> None: buf1 = ReplayBuffer(4, stack_num=2) buf2 = ReplayBuffer(4, stack_num=2) for i in range(5): @@ -515,7 +515,7 @@ def test_update(): b.update(b) -def test_segtree(): +def test_segtree() -> None: realop = np.sum # small test actual_len = 8 @@ -616,7 +616,7 @@ def sample_tree(): print("tree", timeit(sample_tree, setup=sample_tree, number=1000)) -def test_pickle(): +def test_pickle() -> None: size = 100 vbuf = ReplayBuffer(size, stack_num=2) pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4) @@ 
-654,7 +654,7 @@ def test_pickle(): assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))], pbuf.weight[np.arange(len(pbuf))]) -def test_hdf5(): +def test_hdf5() -> None: size = 100 buffers = { "array": ReplayBuffer(size, stack_num=2), @@ -714,7 +714,7 @@ def test_hdf5(): to_hdf5(data, grp) -def test_replaybuffermanager(): +def test_replaybuffermanager() -> None: buf = VectorReplayBuffer(20, 4) batch = Batch( obs=[1, 2, 3], @@ -923,7 +923,7 @@ def test_replaybuffermanager(): assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == object -def test_cachedbuffer(): +def test_cachedbuffer() -> None: buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5) assert buf.sample_indices(0).tolist() == [] # check the normal function/usage/storage in CachedReplayBuffer @@ -1023,7 +1023,7 @@ def test_cachedbuffer(): assert np.allclose(buf.next(indices), [1, 1, 11, 11]) -def test_multibuf_stack(): +def test_multibuf_stack() -> None: size = 5 bufsize = 9 stack_num = 4 @@ -1208,7 +1208,7 @@ def test_multibuf_stack(): assert buf6[0].obs.shape == (4, 84, 84) -def test_multibuf_hdf5(): +def test_multibuf_hdf5() -> None: size = 100 buffers = { "vector": VectorReplayBuffer(size * 4, 4), @@ -1284,7 +1284,7 @@ def test_multibuf_hdf5(): os.remove(path) -def test_from_data(): +def test_from_data() -> None: obs_data = np.ndarray((10, 3, 3), dtype="uint8") for i in range(10): obs_data[i] = i * np.ones((3, 3), dtype="uint8") @@ -1311,7 +1311,7 @@ def test_from_data(): os.remove(path) -def test_custom_key(): +def test_custom_key() -> None: batch = Batch( obs_next=np.array( [ diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 985e6ef50..f7a24a86e 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -34,7 +34,7 @@ def __init__( dict_state=False, need_state=True, action_shape=None, - ): + ) -> None: """Mock policy for testing. :param action_space: the action space of the environment. If None, a dummy Box space will be used. 
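The docstring above notes that MyPolicy falls back to a dummy Box space when no action space is given. A minimal sketch of that fallback, assuming the same Box(-1, 1, (1,)) bounds that AnyPolicy uses later in test_env_finite.py; the helper name is illustrative and not part of the test suite:

import gymnasium as gym
import numpy as np


def resolve_action_space(action_space: gym.Space | None) -> gym.Space:
    # Prefer the space supplied by the test; otherwise build a dummy Box space.
    if action_space is not None:
        return action_space
    return gym.spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)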
@@ -64,7 +64,7 @@ def learn(self): class Logger: - def __init__(self, writer): + def __init__(self, writer) -> None: self.cnt = 0 self.writer = writer @@ -92,7 +92,7 @@ def single_preprocess_fn(**kwargs): @pytest.mark.parametrize("gym_reset_kwargs", [None, {}]) -def test_collector(gym_reset_kwargs): +def test_collector(gym_reset_kwargs) -> None: writer = SummaryWriter("log/collector") logger = Logger(writer) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]] @@ -219,7 +219,7 @@ def test_collector(gym_reset_kwargs): @pytest.mark.parametrize("gym_reset_kwargs", [None, {}]) -def test_collector_with_async(gym_reset_kwargs): +def test_collector_with_async(gym_reset_kwargs) -> None: env_lens = [2, 3, 4, 5] writer = SummaryWriter("log/async_collector") logger = Logger(writer) @@ -264,7 +264,7 @@ def test_collector_with_async(gym_reset_kwargs): c1.collect() -def test_collector_with_dict_state(): +def test_collector_with_dict_state() -> None: env = MyTestEnv(size=5, sleep=0, dict_state=True) policy = MyPolicy(dict_state=True) c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) @@ -402,7 +402,7 @@ def test_collector_with_dict_state(): batch, _ = c2.buffer.sample(10) -def test_collector_with_ma(): +def test_collector_with_ma() -> None: env = MyTestEnv(size=5, sleep=0, ma_rew=4) policy = MyPolicy() c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) @@ -534,7 +534,7 @@ def test_collector_with_ma(): batch, _ = c2.buffer.sample(10) -def test_collector_with_atari_setting(): +def test_collector_with_atari_setting() -> None: reference_obs = np.zeros([6, 4, 84, 84]) for i in range(6): reference_obs[i, 3, np.arange(84), np.arange(84)] = i @@ -776,7 +776,7 @@ def test_collector_with_atari_setting(): @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") -def test_collector_envpool_gym_reset_return_info(): +def test_collector_envpool_gym_reset_return_info() -> None: envs = envpool.make_gymnasium("Pendulum-v1", num_envs=4, gym_reset_return_info=True) policy = MyPolicy(action_shape=(len(envs), 1)) diff --git a/test/base/test_env.py b/test/base/test_env.py index 05a12474b..edeb3f361 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -53,7 +53,7 @@ def recurse_comp(a, b): return False -def test_async_env(size=10000, num=8, sleep=0.1): +def test_async_env(size=10000, num=8, sleep=0.1) -> None: # simplify the test case, just keep stepping env_fns = [ lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True) @@ -106,7 +106,7 @@ def test_async_env(size=10000, num=8, sleep=0.1): assert spent_time < 6.0 * sleep * num / (num + 1) -def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7): +def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None: env_fns = [ lambda: MyTestEnv(size=size, sleep=sleep * 2), lambda: MyTestEnv(size=size, sleep=sleep * 3), @@ -154,7 +154,7 @@ def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7): assert total_pass >= 2 -def test_vecenv(size=10, num=8, sleep=0.001): +def test_vecenv(size=10, num=8, sleep=0.001) -> None: env_fns = [ lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True) for i in range(size, size + num) @@ -219,7 +219,7 @@ def assert_get(v, expected): v.close() -def test_attr_unwrapped(): +def test_attr_unwrapped() -> None: train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1")]) train_envs.set_env_attr("test_attribute", 1337) assert train_envs.get_env_attr("test_attribute") == [1337] @@ -227,7 +227,7 @@ 
def test_attr_unwrapped(): assert hasattr(train_envs.workers[0].env.unwrapped, "test_attribute") -def test_env_obs_dtype(): +def test_env_obs_dtype() -> None: for obs_type in ["array", "object"]: envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]]) obs, info = envs.reset() @@ -236,7 +236,7 @@ def test_env_obs_dtype(): assert obs.dtype == object -def test_env_reset_optional_kwargs(size=10000, num=8): +def test_env_reset_optional_kwargs(size=10000, num=8) -> None: env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)] test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv] if has_ray(): @@ -248,7 +248,7 @@ def test_env_reset_optional_kwargs(size=10000, num=8): assert isinstance(info[0], dict) -def test_venv_wrapper_gym(num_envs: int = 4): +def test_venv_wrapper_gym(num_envs: int = 4) -> None: # Issue 697 envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(num_envs)]) envs = VectorEnvNormObs(envs) @@ -329,7 +329,7 @@ def reset_result_to_obs(reset_result): assert np.allclose(no, to) -def test_venv_norm_obs(): +def test_venv_norm_obs() -> None: sizes = np.array([5, 10, 15, 20]) action = np.array([1, 1, 1, 1]) total_step = 30 @@ -343,9 +343,9 @@ def test_venv_norm_obs(): run_align_norm_obs(raw, train_env, test_env, action_list) -def test_gym_wrappers(): +def test_gym_wrappers() -> None: class DummyEnv(gym.Env): - def __init__(self): + def __init__(self) -> None: self.action_space = gym.spaces.Box(low=-1.0, high=2.0, shape=(4,), dtype=np.float32) self.observation_space = gym.spaces.Discrete(2) @@ -387,7 +387,7 @@ def step(self, act): @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") -def test_venv_wrapper_envpool(): +def test_venv_wrapper_envpool() -> None: raw = envpool.make_gymnasium("Ant-v3", num_envs=4) train = VectorEnvNormObs(envpool.make_gymnasium("Ant-v3", num_envs=4)) test = VectorEnvNormObs(envpool.make_gymnasium("Ant-v3", num_envs=4), update_obs_rms=False) @@ -397,7 +397,7 @@ def test_venv_wrapper_envpool(): @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") -def test_venv_wrapper_envpool_gym_reset_return_info(): +def test_venv_wrapper_envpool_gym_reset_return_info() -> None: num_envs = 4 env = VectorEnvNormObs( envpool.make_gymnasium("Ant-v3", num_envs=num_envs, gym_reset_return_info=True), diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 2617dc798..d1e780251 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -14,7 +14,7 @@ class DummyDataset(Dataset): - def __init__(self, length): + def __init__(self, length) -> None: self.length = length self.episodes = [3 * i % 5 + 1 for i in range(self.length)] @@ -27,7 +27,7 @@ def __len__(self): class FiniteEnv(gym.Env): - def __init__(self, dataset, num_replicas, rank): + def __init__(self, dataset, num_replicas, rank) -> None: self.dataset = dataset self.num_replicas = num_replicas self.rank = rank @@ -62,7 +62,7 @@ def step(self, action): class FiniteVectorEnv(BaseVectorEnv): - def __init__(self, env_fns, **kwargs): + def __init__(self, env_fns, **kwargs) -> None: super().__init__(env_fns, **kwargs) self._alive_env_ids = set() self._reset_alive_envs() @@ -165,7 +165,7 @@ class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv): class AnyPolicy(BasePolicy): - def __init__(self): + def __init__(self) -> None: super().__init__(action_space=Box(-1, 1, (1,))) def forward(self, batch, state=None): @@ -180,7 +180,7 @@ def 
_finite_env_factory(dataset, num_replicas, rank): class MetricTracker: - def __init__(self): + def __init__(self) -> None: self.counter = Counter() self.finished = set() @@ -199,7 +199,7 @@ def validate(self): assert v == k * 3 % 5 + 1 -def test_finite_dummy_vector_env(): +def test_finite_dummy_vector_env() -> None: dataset = DummyDataset(100) envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) policy = AnyPolicy() @@ -213,7 +213,7 @@ def test_finite_dummy_vector_env(): envs.tracker.validate() -def test_finite_subproc_vector_env(): +def test_finite_subproc_vector_env() -> None: dataset = DummyDataset(100) envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) policy = AnyPolicy() diff --git a/test/base/test_logger.py b/test/base/test_logger.py index 1634f4d8f..69c02a285 100644 --- a/test/base/test_logger.py +++ b/test/base/test_logger.py @@ -1,3 +1,5 @@ +from typing import Literal + import numpy as np import pytest @@ -13,7 +15,11 @@ class TestBaseLogger: ({"a": {"b": {"c": 1}}}, {"a/b/c": 1}), ], ) - def test_flatten_dict_basic(input_dict, expected_output): + def test_flatten_dict_basic( + input_dict: dict[str, int | dict[str, int | dict[str, int]]] + | dict[str, dict[str, dict[str, int]]], + expected_output: dict[str, int], + ) -> None: result = BaseLogger.prepare_dict_for_logging(input_dict) assert result == expected_output @@ -25,7 +31,11 @@ def test_flatten_dict_basic(input_dict, expected_output): ({"a": {"b": {"c": 1}}}, ".", {"a.b.c": 1}), ], ) - def test_flatten_dict_custom_delimiter(input_dict, delimiter, expected_output): + def test_flatten_dict_custom_delimiter( + input_dict: dict[str, dict[str, dict[str, int]]], + delimiter: Literal["|", "."], + expected_output: dict[str, int], + ) -> None: result = BaseLogger.prepare_dict_for_logging(input_dict, delimiter=delimiter) assert result == expected_output @@ -41,7 +51,11 @@ def test_flatten_dict_custom_delimiter(input_dict, delimiter, expected_output): ({"a": np.array([1, 2, 3]), "b": {"c": np.array([4, 5, 6])}}, True, {}), ], ) - def test_flatten_dict_exclude_arrays(input_dict, exclude_arrays, expected_output): + def test_flatten_dict_exclude_arrays( + input_dict: dict[str, np.ndarray | dict[str, np.ndarray]], + exclude_arrays: bool, + expected_output: dict[str, np.ndarray], + ) -> None: result = BaseLogger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays) assert result.keys() == expected_output.keys() for val1, val2 in zip(result.values(), expected_output.values(), strict=True): @@ -54,6 +68,9 @@ def test_flatten_dict_exclude_arrays(input_dict, exclude_arrays, expected_output ({"a": (1,), "b": {"c": "2", "d": {"e": 3}}}, {"b/d/e": 3}), ], ) - def test_flatten_dict_invalid_values_filtered_out(input_dict, expected_output): + def test_flatten_dict_invalid_values_filtered_out( + input_dict: dict[str, tuple[Literal[1]] | dict[str, str | dict[str, int]]], + expected_output: dict[str, int], + ) -> None: result = BaseLogger.prepare_dict_for_logging(input_dict) assert result == expected_output diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 672f194bf..9fe6f8c3a 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -43,7 +43,7 @@ def policy(request): actor_critic = ActorCritic(actor, critic) optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3) - policy = PPOPolicy( + policy: PPOPolicy = PPOPolicy( actor=actor, critic=critic, dist_fn=dist_fn, @@ -56,7 +56,7 @@ def policy(request): class TestPolicyBasics: - def 
test_get_action(self, policy): + def test_get_action(self, policy) -> None: sample_obs = torch.randn(obs_shape) policy.deterministic_eval = False actions = [policy.compute_action(sample_obs) for _ in range(10)] diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 48ec89f8a..2dbf47c29 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -19,7 +19,7 @@ def compute_episodic_return_base(batch, gamma): return batch -def test_episodic_returns(size=2560): +def test_episodic_returns(size=2560) -> None: fn = BasePolicy.compute_episodic_return buf = ReplayBuffer(20) batch = Batch( @@ -195,7 +195,7 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices): return returns -def test_nstep_returns(size=10000): +def test_nstep_returns(size=10000) -> None: buf = ReplayBuffer(10) for i in range(12): buf.add( @@ -273,7 +273,7 @@ def test_nstep_returns(size=10000): assert np.allclose(returns_multidim, returns[:, np.newaxis]) -def test_nstep_returns_with_timelimit(size=10000): +def test_nstep_returns_with_timelimit(size=10000) -> None: buf = ReplayBuffer(10) for i in range(12): buf.add( diff --git a/test/base/test_stats.py b/test/base/test_stats.py index b9ec67a12..537519287 100644 --- a/test/base/test_stats.py +++ b/test/base/test_stats.py @@ -4,14 +4,14 @@ class DummyTrainingStatsWrapper(TrainingStatsWrapper): - def __init__(self, wrapped_stats: TrainingStats, *, dummy_field: int): + def __init__(self, wrapped_stats: TrainingStats, *, dummy_field: int) -> None: self.dummy_field = dummy_field super().__init__(wrapped_stats) class TestStats: @staticmethod - def test_training_stats_wrapper(): + def test_training_stats_wrapper() -> None: train_stats = TrainingStats(train_time=1.0) train_stats.loss_field = 12 diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 98e5a2027..bd14ffe2a 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -7,7 +7,7 @@ from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic -def test_noise(): +def test_noise() -> None: noise = GaussianNoise() size = (3, 4, 5) assert np.allclose(noise(size).shape, size) @@ -16,7 +16,7 @@ def test_noise(): assert np.allclose(noise(size).shape, size) -def test_moving_average(): +def test_moving_average() -> None: stat = MovAvg(10) assert np.allclose(stat.get(), 0) assert np.allclose(stat.mean(), 0) @@ -30,7 +30,7 @@ def test_moving_average(): assert np.allclose(stat.std() ** 2, 2) -def test_rms(): +def test_rms() -> None: rms = RunningMeanStd() assert np.allclose(rms.mean, 0) assert np.allclose(rms.var, 1) @@ -40,7 +40,7 @@ def test_rms(): assert np.allclose(rms.var, np.array([[0, 0], [2, 14 / 3.0]]), atol=1e-3) -def test_net(): +def test_net() -> None: # here test the networks that does not appear in the other script bsz = 64 # MLP @@ -75,7 +75,7 @@ def test_net(): assert list(net(data)[0].shape) == expect_output_shape # concat net = Net(state_shape, action_shape, hidden_sizes=[128], concat=True) - data = torch.rand([bsz, np.prod(state_shape) + np.prod(action_shape)]) + data = torch.rand([bsz, int(np.prod(state_shape)) + int(np.prod(action_shape))]) expect_output_shape = [bsz, 128] assert list(net(data)[0].shape) == expect_output_shape net = Net( @@ -94,12 +94,12 @@ def test_net(): assert mu.shape == sigma.shape assert list(mu.shape) == [bsz, 5] net = RecurrentCritic(3, state_shape, action_shape) - data = torch.rand([bsz, 8, np.prod(state_shape)]) + data = torch.rand([bsz, 8, int(np.prod(state_shape))]) act = torch.rand(expect_output_shape) assert 
list(net(data, act).shape) == [bsz, 1] -def test_lr_schedulers(): +def test_lr_schedulers() -> None: initial_lr_1 = 10.0 step_size_1 = 1 gamma_1 = 0.5 diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 30e0e33de..e2de17e85 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -11,13 +11,15 @@ from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.policy import DDPGPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) @@ -48,14 +50,18 @@ def get_args(): return parser.parse_known_args()[0] -def test_ddpg(args=get_args()): +def test_ddpg(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -81,7 +87,7 @@ def test_ddpg(args=get_args()): ) critic = Critic(net, device=args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) - policy = DDPGPolicy( + policy: DDPGPolicy = DDPGPolicy( actor=actor, actor_optim=actor_optim, critic=critic, @@ -105,10 +111,10 @@ def test_ddpg(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer @@ -134,8 +140,8 @@ def stop_fn(mean_rewards): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index edd51f274..bcfe6b07b 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -6,19 +6,22 @@ import numpy as np import torch from torch import nn -from torch.distributions import Independent, Normal +from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import 
SummaryWriter from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import NPGPolicy +from tianshou.policy.base import BasePolicy +from tianshou.policy.modelfree.npg import NPGTrainingStats from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) @@ -50,14 +53,20 @@ def get_args(): return parser.parse_known_args()[0] -def test_npg(args=get_args()): +def test_npg(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action + if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -94,10 +103,10 @@ def test_npg(args=get_args()): # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward - def dist(*logits): + def dist(*logits: torch.Tensor) -> Distribution: return Independent(Normal(*logits), 1) - policy = NPGPolicy( + policy: NPGPolicy[NPGTrainingStats] = NPGPolicy( actor=actor, critic=critic, optim=optim, @@ -123,10 +132,10 @@ def dist(*logits): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer @@ -152,8 +161,8 @@ def stop_fn(mean_rewards): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 1327ede6e..d092bc67c 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -5,19 +5,22 @@ import gymnasium as gym import numpy as np import torch -from torch.distributions import Independent, Normal +from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import PPOPolicy +from 
tianshou.policy.base import BasePolicy +from tianshou.policy.modelfree.ppo import PPOTrainingStats from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) @@ -56,14 +59,20 @@ def get_args(): return parser.parse_known_args()[0] -def test_ppo(args=get_args()): +def test_ppo(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action + if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -91,10 +100,10 @@ def test_ppo(args=get_args()): # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward - def dist(*logits): + def dist(*logits: torch.Tensor) -> Distribution: return Independent(Normal(*logits), 1) - policy = PPOPolicy( + policy: PPOPolicy[PPOTrainingStats] = PPOPolicy( actor=actor, critic=critic, optim=optim, @@ -124,13 +133,13 @@ def dist(*logits): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - def save_checkpoint_fn(epoch, env_step, gradient_step): + def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, "checkpoint.pth") # Example: saving by epoch num @@ -187,11 +196,11 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) -def test_ppo_resume(args=get_args()): +def test_ppo_resume(args: argparse.Namespace = get_args()) -> None: args.resume = True test_ppo(args) diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 2258d6c20..20177f429 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -1,22 +1,26 @@ import argparse import os import pprint +from 
typing import cast import gymnasium as gym import numpy as np import torch +import torch.nn as nn from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import REDQPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import EnsembleLinear, Net from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) @@ -52,14 +56,19 @@ def get_args(): return parser.parse_known_args()[0] -def test_redq(args=get_args()): +def test_redq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + env.action_space = cast(gym.spaces.Box, env.action_space) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -81,7 +90,7 @@ def test_redq(args=get_args()): ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - def linear(x, y): + def linear(x: int, y: int) -> nn.Module: return EnsembleLinear(args.ensemble_size, x, y) net_c = Net( @@ -97,13 +106,14 @@ def linear(x, y): ) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) + action_dim = space_info.action_info.action_dim if args.auto_alpha: - target_entropy = -np.prod(env.action_space.shape) + target_entropy = -action_dim log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy = REDQPolicy( + policy: REDQPolicy = REDQPolicy( actor=actor, actor_optim=actor_optim, critic=critic, @@ -132,10 +142,10 @@ def linear(x, y): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer @@ -161,8 +171,8 @@ def stop_fn(mean_rewards): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": 
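The setup refactor just applied to test_redq.py recurs across the remaining test scripts in this patch: observation/action shapes and the action bound come from SpaceInfo.from_env rather than raw space attributes, and the reward-threshold lookup guards against env.spec being None. A condensed sketch of that shared pattern, using only the attributes shown in the diff; the helper name and the default-threshold dict are illustrative:

import argparse

import gymnasium as gym

from tianshou.utils.space_info import SpaceInfo


def apply_env_defaults(args: argparse.Namespace, env: gym.Env) -> None:
    # Read shapes and bounds from the env's spaces in one place.
    space_info = SpaceInfo.from_env(env)
    args.state_shape = space_info.observation_info.obs_shape
    args.action_shape = space_info.action_info.action_shape
    args.max_action = space_info.action_info.max_action
    if args.reward_threshold is None:
        default_reward_threshold = {"Pendulum-v1": -250}
        # env.spec can be None, so only dereference reward_threshold when a spec exists.
        args.reward_threshold = default_reward_threshold.get(
            args.task,
            env.spec.reward_threshold if env.spec else None,
        )

Each script then builds its networks from args.state_shape and args.action_shape exactly as before.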
diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 9c18904d6..c0e17c62c 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -10,10 +10,12 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import ImitationPolicy, SACPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, ActorProb, Critic +from tianshou.utils.space_info import SpaceInfo try: import envpool @@ -21,7 +23,7 @@ envpool = None -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) @@ -57,19 +59,23 @@ def get_args(): @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") -def test_sac_with_il(args=get_args()): +def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: # if you want to use python vector env, please refer to other test scripts # train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) # test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) env = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # you can also use tianshou.env.SubprocVectorEnv # seed np.random.seed(args.seed) @@ -97,13 +103,14 @@ def test_sac_with_il(args=get_args()): critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + action_dim = space_info.action_info.action_dim if args.auto_alpha: - target_entropy = -np.prod(env.action_space.shape) + target_entropy = -action_dim log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy = SACPolicy( + policy: BasePolicy = SACPolicy( actor=actor, actor_optim=actor_optim, critic=critic1, @@ -130,10 +137,10 @@ def test_sac_with_il(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer @@ -157,19 +164,20 @@ def stop_fn(mean_rewards): policy.eval() if 
args.task.startswith("Pendulum"): args.reward_threshold -= 50 # lower the goal - net = Actor( - Net( - args.state_shape, - hidden_sizes=args.imitation_hidden_sizes, - device=args.device, - ), + il_net = Net( + args.state_shape, + hidden_sizes=args.imitation_hidden_sizes, + device=args.device, + ) + il_actor = Actor( + il_net, args.action_shape, max_action=args.max_action, device=args.device, ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) - il_policy = ImitationPolicy( - actor=net, + optim = torch.optim.Adam(il_actor.parameters(), lr=args.il_lr) + il_policy: BasePolicy = ImitationPolicy( + actor=il_actor, optim=optim, action_space=env.action_space, action_scaling=True, diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index ecad08bb0..7b8690f94 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -11,13 +11,15 @@ from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.policy import TD3Policy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) @@ -50,14 +52,18 @@ def get_args(): return parser.parse_known_args()[0] -def test_td3(args=get_args()): +def test_td3(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -92,7 +98,7 @@ def test_td3(args=get_args()): ) critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - policy = TD3Policy( + policy: TD3Policy = TD3Policy( actor=actor, actor_optim=actor_optim, critic=critic1, @@ -122,10 +128,10 @@ def test_td3(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # Iterator trainer @@ -156,8 +162,8 @@ def stop_fn(mean_rewards): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, 
length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 300c96db2..807061231 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -6,19 +6,21 @@ import numpy as np import torch from torch import nn -from torch.distributions import Independent, Normal +from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import TRPOPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) @@ -53,14 +55,17 @@ def get_args(): return parser.parse_known_args()[0] -def test_trpo(args=get_args()): +def test_trpo(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -97,10 +102,10 @@ def test_trpo(args=get_args()): # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward - def dist(*logits): + def dist(*logits: torch.Tensor) -> Distribution: return Independent(Normal(*logits), 1) - policy = TRPOPolicy( + policy: BasePolicy = TRPOPolicy( actor=actor, critic=critic, optim=optim, @@ -127,10 +132,10 @@ def dist(*logits): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer @@ -156,8 +161,8 @@ def stop_fn(mean_rewards): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 97205c5e7..3ca7ce6cf 100644 --- 
a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -11,6 +11,7 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import A2CPolicy, ImitationPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net @@ -22,7 +23,7 @@ envpool = None -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -60,7 +61,7 @@ def get_args(): @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") -def test_a2c_with_il(args=get_args()): +def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: # if you want to use python vector env, please refer to other test scripts train_envs = env = envpool.make( args.task, @@ -88,7 +89,7 @@ def test_a2c_with_il(args=get_args()): critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) dist = torch.distributions.Categorical - policy = A2CPolicy( + policy: A2CPolicy = A2CPolicy( actor=actor, critic=critic, optim=optim, @@ -114,10 +115,10 @@ def test_a2c_with_il(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer @@ -143,8 +144,8 @@ def stop_fn(mean_rewards): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) policy.eval() # here we define an imitation collector with a trivial policy @@ -153,7 +154,11 @@ def stop_fn(mean_rewards): net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) net = Actor(net, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) - il_policy = ImitationPolicy(actor=net, optim=optim, action_space=env.action_space) + il_policy: ImitationPolicy = ImitationPolicy( + actor=net, + optim=optim, + action_space=env.action_space, + ) il_test_collector = Collector( il_policy, envpool.make(args.task, env_type="gymnasium", num_envs=args.test_num, seed=args.seed), @@ -180,8 +185,8 @@ def stop_fn(mean_rewards): env = gym.make(args.task) il_policy.eval() collector = Collector(il_policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index f63224331..a45b02114 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -12,7 +12,7 @@ from tianshou.utils.net.common import BranchingNet -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() # task parser.add_argument("--task", type=str, 
default="Pendulum-v1") @@ -48,7 +48,7 @@ def get_args(): return parser.parse_known_args()[0] -def test_bdq(args=get_args()): +def test_bdq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) env = ContinuousToDiscrete(env, args.action_per_branch) @@ -92,7 +92,7 @@ def test_bdq(args=get_args()): device=args.device, ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = BranchingDQNPolicy( + policy: BranchingDQNPolicy = BranchingDQNPolicy( model=net, optim=optim, discount_factor=args.gamma, @@ -110,14 +110,14 @@ def test_bdq(args=get_args()): # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) - def train_fn(epoch, env_step): # exp decay + def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) policy.set_eps(eps) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer @@ -144,10 +144,8 @@ def stop_fn(mean_rewards): policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_result = test_collector.collect(n_episode=args.test_num, render=args.render) - print( - f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}", - ) + collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index e4406df18..013e2c414 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -2,21 +2,29 @@ import os import pickle import pprint +from typing import cast import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.data import ( + Collector, + PrioritizedVectorReplayBuffer, + ReplayBuffer, + VectorReplayBuffer, +) from tianshou.env import DummyVectorEnv from tianshou.policy import C51Policy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -54,13 +62,18 @@ def get_args(): return parser.parse_known_args()[0] -def test_c51(args=get_args()): +def test_c51(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + env.action_space = cast(gym.spaces.Discrete, env.action_space) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 195} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # 
train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -81,7 +94,7 @@ def test_c51(args=get_args()): num_atoms=args.num_atoms, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = C51Policy( + policy: C51Policy = C51Policy( model=net, optim=optim, action_space=env.action_space, @@ -93,6 +106,7 @@ def test_c51(args=get_args()): target_update_freq=args.target_update_freq, ).to(args.device) # buffer + buf: ReplayBuffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( args.buffer_size, @@ -112,13 +126,13 @@ def test_c51(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: policy.set_eps(args.eps_train) @@ -128,10 +142,10 @@ def train_fn(epoch, env_step): else: policy.set_eps(0.1 * args.eps_train) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) - def save_checkpoint_fn(epoch, env_step, gradient_step): + def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, "checkpoint.pth") # Example: saving by epoch num @@ -195,16 +209,16 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) -def test_c51_resume(args=get_args()): +def test_c51_resume(args: argparse.Namespace = get_args()) -> None: args.resume = True test_c51(args) -def test_pc51(args=get_args()): +def test_pc51(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 args.seed = 1 diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 751f849ff..d0aba10c4 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -1,21 +1,29 @@ import argparse import os import pprint +from typing import cast import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.data import ( + Collector, + PrioritizedVectorReplayBuffer, + ReplayBuffer, + VectorReplayBuffer, +) from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -48,13 +56,18 @@ def 
get_args(): return parser.parse_known_args()[0] -def test_dqn(args=get_args()): +def test_dqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + env.action_space = cast(gym.spaces.Discrete, env.action_space) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 195} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -75,7 +88,7 @@ def test_dqn(args=get_args()): # dueling=(Q_param, V_param), ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = DQNPolicy( + policy: BasePolicy = DQNPolicy( model=net, optim=optim, discount_factor=args.gamma, @@ -84,6 +97,7 @@ def test_dqn(args=get_args()): action_space=env.action_space, ) # buffer + buf: ReplayBuffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( args.buffer_size, @@ -103,13 +117,13 @@ def test_dqn(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: policy.set_eps(args.eps_train) @@ -119,7 +133,7 @@ def train_fn(epoch, env_step): else: policy.set_eps(0.1 * args.eps_train) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer @@ -148,11 +162,11 @@ def test_fn(epoch, env_step): policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) -def test_pdqn(args=get_args()): +def test_pdqn(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 args.seed = 1 diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 5ca79fd30..3ed1f4fbe 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -1,6 +1,7 @@ import argparse import os import pprint +from typing import cast import gymnasium as gym import numpy as np @@ -10,12 +11,14 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Recurrent +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() 
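For the discrete CartPole tests the same refactor also casts env.action_space to gym.spaces.Discrete before building SpaceInfo and gives the policy variable an explicit annotation. A sketch of how those pieces combine, assuming the DQNPolicy keyword arguments shown across the hunks above; the hidden sizes, learning rate and target_update_freq are illustrative stand-ins for the argparse defaults.

from typing import cast

import gymnasium as gym
import torch

from tianshou.policy import DQNPolicy
from tianshou.policy.base import BasePolicy
from tianshou.utils.net.common import Net
from tianshou.utils.space_info import SpaceInfo

env = gym.make("CartPole-v0")
# Narrow the action-space type for the type checker before deriving shapes from it.
env.action_space = cast(gym.spaces.Discrete, env.action_space)
space_info = SpaceInfo.from_env(env)

net = Net(
    space_info.observation_info.obs_shape,
    space_info.action_info.action_shape,
    hidden_sizes=[64, 64],
    device="cpu",
)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
# The explicit annotation (BasePolicy or the concrete policy class) is the new part.
policy: BasePolicy = DQNPolicy(
    model=net,
    optim=optim,
    action_space=env.action_space,
    discount_factor=0.99,
    target_update_freq=320,
)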
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -46,13 +49,18 @@ def get_args(): return parser.parse_known_args()[0] -def test_drqn(args=get_args()): +def test_drqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + env.action_space = cast(gym.spaces.Discrete, env.action_space) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 195} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -68,7 +76,7 @@ def test_drqn(args=get_args()): args.device, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = DQNPolicy( + policy: DQNPolicy = DQNPolicy( model=net, optim=optim, discount_factor=args.gamma, @@ -93,16 +101,16 @@ def test_drqn(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: policy.set_eps(args.eps_train) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer @@ -130,8 +138,8 @@ def test_fn(epoch, env_step): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 84dd207e8..7de090119 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -1,22 +1,30 @@ import argparse import os import pprint +from typing import cast import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.data import ( + Collector, + PrioritizedVectorReplayBuffer, + ReplayBuffer, + VectorReplayBuffer, +) from tianshou.env import DummyVectorEnv from tianshou.policy import FQFPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") 
parser.add_argument("--reward-threshold", type=float, default=None) @@ -53,13 +61,18 @@ def get_args(): return parser.parse_known_args()[0] -def test_fqf(args=get_args()): +def test_fqf(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + space_info = SpaceInfo.from_env(env) + env.action_space = cast(gym.spaces.Discrete, env.action_space) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 195} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -88,7 +101,7 @@ def test_fqf(args=get_args()): optim = torch.optim.Adam(net.parameters(), lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr) - policy = FQFPolicy( + policy: BasePolicy = FQFPolicy( model=net, optim=optim, fraction_model=fraction_net, @@ -101,6 +114,7 @@ def test_fqf(args=get_args()): target_update_freq=args.target_update_freq, ).to(args.device) # buffer + buf: ReplayBuffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( args.buffer_size, @@ -120,13 +134,13 @@ def test_fqf(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: policy.set_eps(args.eps_train) @@ -136,7 +150,7 @@ def train_fn(epoch, env_step): else: policy.set_eps(0.1 * args.eps_train) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer @@ -165,11 +179,11 @@ def test_fn(epoch, env_step): policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) -def test_pfqf(args=get_args()): +def test_pfqf(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_fqf(args) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 23c35b9a8..5ff71f515 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -1,22 +1,30 @@ import argparse import os import pprint +from typing import cast import gymnasium as gym import numpy as np import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.data import ( + Collector, + PrioritizedVectorReplayBuffer, + ReplayBuffer, + 
VectorReplayBuffer, +) from tianshou.env import DummyVectorEnv from tianshou.policy import IQNPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import ImplicitQuantileNetwork +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -53,13 +61,18 @@ def get_args(): return parser.parse_known_args()[0] -def test_iqn(args=get_args()): +def test_iqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + space_info = SpaceInfo.from_env(env) + env.action_space = cast(gym.spaces.Discrete, env.action_space) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 195} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -85,7 +98,7 @@ def test_iqn(args=get_args()): device=args.device, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = IQNPolicy( + policy: IQNPolicy = IQNPolicy( model=net, optim=optim, action_space=env.action_space, @@ -97,6 +110,7 @@ def test_iqn(args=get_args()): target_update_freq=args.target_update_freq, ).to(args.device) # buffer + buf: ReplayBuffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( args.buffer_size, @@ -116,13 +130,13 @@ def test_iqn(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: policy.set_eps(args.eps_train) @@ -132,7 +146,7 @@ def train_fn(epoch, env_step): else: policy.set_eps(0.1 * args.eps_train) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer @@ -161,11 +175,11 @@ def test_fn(epoch, env_step): policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) -def test_piqn(args=get_args()): +def test_piqn(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_iqn(args) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 6d9873ab3..795086d1b 100644 --- a/test/discrete/test_pg.py +++ 
b/test/discrete/test_pg.py @@ -11,12 +11,14 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import PGPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -43,13 +45,17 @@ def get_args(): return parser.parse_known_args()[0] -def test_pg(args=get_args()): +def test_pg(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 195} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -70,7 +76,7 @@ def test_pg(args=get_args()): ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) dist_fn = torch.distributions.Categorical - policy = PGPolicy( + policy: BasePolicy = PGPolicy( actor=net, optim=optim, dist_fn=dist_fn, @@ -96,10 +102,10 @@ def test_pg(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer @@ -125,8 +131,8 @@ def stop_fn(mean_rewards): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index e66dc23a9..63ef55122 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -5,19 +5,23 @@ import gymnasium as gym import numpy as np import torch +import torch.nn as nn from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import PPOPolicy +from tianshou.policy.base import BasePolicy +from tianshou.policy.modelfree.ppo import PPOTrainingStats from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = 
argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -54,13 +58,17 @@ def get_args(): return parser.parse_known_args()[0] -def test_ppo(args=get_args()): +def test_ppo(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 195} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -73,9 +81,11 @@ def test_ppo(args=get_args()): test_envs.seed(args.seed) # model net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor: nn.Module + critic: nn.Module if torch.cuda.is_available(): - actor = DataParallelNet(Actor(net, args.action_shape, device=None).to(args.device)) - critic = DataParallelNet(Critic(net, device=None).to(args.device)) + actor = DataParallelNet(Actor(net, args.action_shape, device=args.device).to(args.device)) + critic = DataParallelNet(Critic(net, device=args.device).to(args.device)) else: actor = Actor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) @@ -87,7 +97,7 @@ def test_ppo(args=get_args()): torch.nn.init.zeros_(m.bias) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) dist = torch.distributions.Categorical - policy = PPOPolicy( + policy: PPOPolicy[PPOTrainingStats] = PPOPolicy( actor=actor, critic=critic, optim=optim, @@ -119,10 +129,10 @@ def test_ppo(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer @@ -148,8 +158,8 @@ def stop_fn(mean_rewards): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 6d54d59cf..c1bbcc3fa 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -1,6 +1,7 @@ import argparse import os import pprint +from typing import cast import gymnasium as gym import numpy as np @@ -10,12 +11,15 @@ from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import QRDQNPolicy +from tianshou.policy.base import BasePolicy +from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats from tianshou.trainer import OffpolicyTrainer from 
tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -49,15 +53,22 @@ def get_args(): return parser.parse_known_args()[0] -def test_qrdqn(args=get_args()): +def test_qrdqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - if args.task == "CartPole-v0": + env.action_space = cast(gym.spaces.Discrete, env.action_space) + + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + + if args.task == "CartPole-v0" and env.spec: env.spec.reward_threshold = 190 # lower the goal - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 195} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -78,7 +89,7 @@ def test_qrdqn(args=get_args()): num_atoms=args.num_quantiles, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = QRDQNPolicy( + policy: QRDQNPolicy[QRDQNTrainingStats] = QRDQNPolicy( model=net, optim=optim, action_space=env.action_space, @@ -88,6 +99,7 @@ def test_qrdqn(args=get_args()): target_update_freq=args.target_update_freq, ).to(args.device) # buffer + buf: PrioritizedVectorReplayBuffer | VectorReplayBuffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( args.buffer_size, @@ -107,13 +119,13 @@ def test_qrdqn(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: policy.set_eps(args.eps_train) @@ -123,7 +135,7 @@ def train_fn(epoch, env_step): else: policy.set_eps(0.1 * args.eps_train) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer @@ -152,11 +164,11 @@ def test_fn(epoch, env_step): policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) -def test_pqrdqn(args=get_args()): +def test_pqrdqn(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_qrdqn(args) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index c7e51a36e..fafa1e03b 100644 --- a/test/discrete/test_rainbow.py +++ 
b/test/discrete/test_rainbow.py @@ -2,6 +2,7 @@ import os import pickle import pprint +from typing import cast import gymnasium as gym import numpy as np @@ -11,13 +12,16 @@ from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import RainbowPolicy +from tianshou.policy.base import BasePolicy +from tianshou.policy.modelfree.rainbow import RainbowTrainingStats from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import NoisyLinear +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -57,13 +61,20 @@ def get_args(): return parser.parse_known_args()[0] -def test_rainbow(args=get_args()): +def test_rainbow(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + env.action_space = cast(gym.spaces.Discrete, env.action_space) + + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 195} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -77,7 +88,7 @@ def test_rainbow(args=get_args()): # model - def noisy_linear(x, y): + def noisy_linear(x: int, y: int) -> NoisyLinear: return NoisyLinear(x, y, args.noisy_std) net = Net( @@ -90,7 +101,7 @@ def noisy_linear(x, y): dueling_param=({"linear_layer": noisy_linear}, {"linear_layer": noisy_linear}), ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = RainbowPolicy( + policy: RainbowPolicy[RainbowTrainingStats] = RainbowPolicy( model=net, optim=optim, discount_factor=args.gamma, @@ -102,6 +113,7 @@ def noisy_linear(x, y): target_update_freq=args.target_update_freq, ).to(args.device) # buffer + buf: PrioritizedVectorReplayBuffer | VectorReplayBuffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( args.buffer_size, @@ -122,13 +134,13 @@ def noisy_linear(x, y): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # eps annealing, just a demo if env_step <= 10000: policy.set_eps(args.eps_train) @@ -147,10 +159,10 @@ def train_fn(epoch, env_step): beta = args.beta_final buf.set_beta(beta) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) - def save_checkpoint_fn(epoch, 
env_step, gradient_step): + def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, "checkpoint.pth") # Example: saving by epoch num @@ -214,16 +226,16 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) -def test_rainbow_resume(args=get_args()): +def test_rainbow_resume(args: argparse.Namespace = get_args()) -> None: args.resume = True test_rainbow(args) -def test_prainbow(args=get_args()): +def test_prainbow(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 args.seed = 1 diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 4831f12de..9d5e27be6 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -1,6 +1,7 @@ import argparse import os import pprint +from typing import cast import gymnasium as gym import numpy as np @@ -10,13 +11,16 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import DiscreteSACPolicy +from tianshou.policy.base import BasePolicy +from tianshou.policy.modelfree.discrete_sac import DiscreteSACTrainingStats from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -48,13 +52,20 @@ def get_args(): return parser.parse_known_args()[0] -def test_discrete_sac(args=get_args()): +def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + env.action_space = cast(gym.spaces.Discrete, env.action_space) + + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 170} # lower the goal - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) @@ -64,24 +75,26 @@ def test_discrete_sac(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model + obs_dim = space_info.observation_info.obs_dim + action_dim = space_info.action_info.action_dim net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, softmax_output=False, device=args.device).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), 
lr=args.actor_lr) net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - critic1 = Critic(net_c1, last_size=args.action_shape, device=args.device).to(args.device) + critic1 = Critic(net_c1, last_size=action_dim, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - critic2 = Critic(net_c2, last_size=args.action_shape, device=args.device).to(args.device) + net_c2 = Net(obs_dim, hidden_sizes=args.hidden_sizes, device=args.device) + critic2 = Critic(net_c2, last_size=action_dim, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) # better not to use auto alpha in CartPole if args.auto_alpha: - target_entropy = 0.98 * np.log(np.prod(args.action_shape)) + target_entropy = 0.98 * np.log(action_dim) log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy = DiscreteSACPolicy( + policy: DiscreteSACPolicy[DiscreteSACTrainingStats] = DiscreteSACPolicy( actor=actor, actor_optim=actor_optim, critic=critic1, @@ -107,10 +120,10 @@ def test_discrete_sac(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer @@ -137,8 +150,8 @@ def stop_fn(mean_rewards): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/highlevel/__init__.py b/test/highlevel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py index 4d09ce86d..1dd1273d4 100644 --- a/test/highlevel/env_factory.py +++ b/test/highlevel/env_factory.py @@ -5,10 +5,10 @@ class DiscreteTestEnvFactory(EnvFactoryRegistered): - def __init__(self): + def __init__(self) -> None: super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY) class ContinuousTestEnvFactory(EnvFactoryRegistered): - def __init__(self): + def __init__(self) -> None: super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY) diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py index e53c0f744..028c46cf4 100644 --- a/test/highlevel/test_experiment_builder.py +++ b/test/highlevel/test_experiment_builder.py @@ -8,6 +8,7 @@ DDPGExperimentBuilder, DiscreteSACExperimentBuilder, DQNExperimentBuilder, + ExperimentBuilder, ExperimentConfig, IQNExperimentBuilder, PGExperimentBuilder, @@ -33,7 +34,7 @@ PGExperimentBuilder, ], ) -def test_experiment_builder_continuous_default_params(builder_cls): +def test_experiment_builder_continuous_default_params(builder_cls: type[ExperimentBuilder]) -> None: env_factory = ContinuousTestEnvFactory() sampling_config = SamplingConfig( num_epochs=1, @@ -62,7 +63,7 @@ def test_experiment_builder_continuous_default_params(builder_cls): 
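The discrete SAC hunk above swaps args.action_shape for the scalar action_dim exposed by SpaceInfo, both for the critics' output size and for the entropy target. A standalone sketch of that setup; the hidden sizes and the 3e-4 learning rate are illustrative, while the 0.98 * log(action_dim) target and the alpha triple follow the hunk.

import gymnasium as gym
import numpy as np
import torch

from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Critic
from tianshou.utils.space_info import SpaceInfo

env = gym.make("CartPole-v0")
space_info = SpaceInfo.from_env(env)
action_dim = space_info.action_info.action_dim

# Critics now size their output with the scalar action_dim instead of args.action_shape.
net_c = Net(space_info.observation_info.obs_shape, hidden_sizes=[64, 64], device="cpu")
critic = Critic(net_c, last_size=action_dim, device="cpu")

# Auto-alpha setup in the new form: the entropy target uses log(action_dim) directly.
target_entropy = 0.98 * np.log(action_dim)
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
alpha = (target_entropy, log_alpha, alpha_optim)  # stored in args.alpha by the test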
IQNExperimentBuilder, ], ) -def test_experiment_builder_discrete_default_params(builder_cls): +def test_experiment_builder_discrete_default_params(builder_cls: type[ExperimentBuilder]) -> None: env_factory = DiscreteTestEnvFactory() sampling_config = SamplingConfig( num_epochs=1, diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index ecfb4d2b6..0f957c75d 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -1,6 +1,7 @@ import argparse import os import pprint +from typing import cast import gymnasium as gym import numpy as np @@ -10,13 +11,16 @@ from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy, ICMPolicy +from tianshou.policy.base import BasePolicy +from tianshou.policy.modelfree.dqn import DQNTrainingStats from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.discrete import IntrinsicCuriosityModule +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -67,13 +71,20 @@ def get_args(): return parser.parse_known_args()[0] -def test_dqn_icm(args=get_args()): +def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + env.action_space = cast(gym.spaces.Discrete, env.action_space) + + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 195} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -94,7 +105,7 @@ def test_dqn_icm(args=get_args()): # dueling=(Q_param, V_param), ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = DQNPolicy( + policy: DQNPolicy[DQNTrainingStats] = DQNPolicy( model=net, optim=optim, action_space=env.action_space, @@ -103,13 +114,14 @@ def test_dqn_icm(args=get_args()): target_update_freq=args.target_update_freq, ) feature_dim = args.hidden_sizes[-1] + obs_dim = space_info.observation_info.obs_dim feature_net = MLP( - np.prod(args.state_shape), + obs_dim, output_dim=feature_dim, hidden_sizes=args.hidden_sizes[:-1], device=args.device, ) - action_dim = np.prod(args.action_shape) + action_dim = space_info.action_info.action_dim icm_net = IntrinsicCuriosityModule( feature_net, feature_dim, @@ -118,7 +130,7 @@ def test_dqn_icm(args=get_args()): device=args.device, ).to(args.device) icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy = ICMPolicy( + policy: ICMPolicy = ICMPolicy( policy=policy, model=icm_net, optim=icm_optim, @@ -128,6 +140,7 @@ def test_dqn_icm(args=get_args()): forward_loss_weight=args.forward_loss_weight, ) # buffer + 
buf: PrioritizedVectorReplayBuffer | VectorReplayBuffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( args.buffer_size, @@ -147,13 +160,13 @@ def test_dqn_icm(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: policy.set_eps(args.eps_train) @@ -163,7 +176,7 @@ def train_fn(epoch, env_step): else: policy.set_eps(0.1 * args.eps_train) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer @@ -192,8 +205,8 @@ def test_fn(epoch, env_step): policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 27adc12d8..3095f7cc9 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -11,13 +11,16 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import ICMPolicy, PPOPolicy +from tianshou.policy.base import BasePolicy +from tianshou.policy.modelfree.ppo import PPOTrainingStats from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, ActorCritic, Net from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule +from tianshou.utils.space_info import SpaceInfo -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -72,13 +75,19 @@ def get_args(): return parser.parse_known_args()[0] -def test_ppo(args=get_args()): +def test_ppo(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 195} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -101,7 +110,7 @@ def test_ppo(args=get_args()): torch.nn.init.zeros_(m.bias) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) dist = torch.distributions.Categorical - policy = PPOPolicy( + policy: PPOPolicy[PPOTrainingStats] = PPOPolicy( actor=actor, critic=critic, optim=optim, @@ 
-123,12 +132,12 @@ def test_ppo(args=get_args()): ) feature_dim = args.hidden_sizes[-1] feature_net = MLP( - np.prod(args.state_shape), + space_info.observation_info.obs_dim, output_dim=feature_dim, hidden_sizes=args.hidden_sizes[:-1], device=args.device, ) - action_dim = np.prod(args.action_shape) + action_dim = space_info.action_info.action_dim icm_net = IntrinsicCuriosityModule( feature_net, feature_dim, @@ -158,10 +167,10 @@ def test_ppo(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer @@ -187,8 +196,8 @@ def stop_fn(mean_rewards): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index d2fbf7358..55719f47e 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -18,7 +18,7 @@ envpool = None -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="NChain-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -46,7 +46,7 @@ def get_args(): @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") -def test_psrl(args=get_args()): +def test_psrl(args: argparse.Namespace = get_args()) -> None: # if you want to use python vector env, please refer to other test scripts train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) @@ -65,7 +65,7 @@ def test_psrl(args=get_args()): trans_count_prior = np.ones((n_state, n_action, n_state)) rew_mean_prior = np.full((n_state, n_action), args.rew_mean_prior) rew_std_prior = np.full((n_state, n_action), args.rew_std_prior) - policy = PSRLPolicy( + policy: PSRLPolicy = PSRLPolicy( trans_count_prior=trans_count_prior, rew_mean_prior=rew_mean_prior, rew_std_prior=rew_std_prior, @@ -83,20 +83,19 @@ def test_psrl(args=get_args()): ) test_collector = Collector(policy, test_envs) # Logger + log_path = os.path.join(args.logdir, args.task, "psrl") + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger: WandbLogger | TensorboardLogger | LazyLogger if args.logger == "wandb": logger = WandbLogger(save_interval=1, project="psrl", name="wandb_test", config=args) - if args.logger != "none": - log_path = os.path.join(args.logdir, args.task, "psrl") - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: - logger.load(writer) + logger.load(writer) + elif args.logger == "tensorboard": + logger = TensorboardLogger(writer) else: logger = LazyLogger() - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold train_collector.collect(n_step=args.buffer_size, random=True) diff --git a/test/offline/gather_cartpole_data.py 
b/test/offline/gather_cartpole_data.py index cf4a5bd24..7f1a3128b 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -1,6 +1,7 @@ import argparse import os import pickle +from typing import cast import gymnasium as gym import numpy as np @@ -10,16 +11,19 @@ from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import QRDQNPolicy +from tianshou.policy.base import BasePolicy +from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net +from tianshou.utils.space_info import SpaceInfo -def expert_file_name(): +def expert_file_name() -> str: return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v0.pkl") -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -54,14 +58,21 @@ def get_args(): return parser.parse_known_args()[0] -def gather_data(): +def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: args = get_args() env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + env.action_space = cast(gym.spaces.Discrete, env.action_space) + + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 190} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -82,7 +93,7 @@ def gather_data(): num_atoms=args.num_quantiles, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = QRDQNPolicy( + policy: QRDQNPolicy[QRDQNTrainingStats] = QRDQNPolicy( model=net, optim=optim, action_space=env.action_space, @@ -92,6 +103,7 @@ def gather_data(): target_update_freq=args.target_update_freq, ).to(args.device) # buffer + buf: VectorReplayBuffer | PrioritizedVectorReplayBuffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( args.buffer_size, @@ -111,13 +123,13 @@ def gather_data(): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: # eps annnealing, just a demo if env_step <= 10000: policy.set_eps(args.eps_train) @@ -127,7 +139,7 @@ def train_fn(epoch, env_step): else: policy.set_eps(0.1 * args.eps_train) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.set_eps(args.eps_test) # trainer @@ -153,11 +165,11 @@ def test_fn(epoch, env_step): buf = VectorReplayBuffer(args.buffer_size, 
buffer_num=len(test_envs)) policy.set_eps(0.2) collector = Collector(policy, test_envs, buf, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + collector_stats = collector.collect(n_step=args.buffer_size) if args.save_buffer_name.endswith(".hdf5"): buf.save_hdf5(args.save_buffer_name) else: with open(args.save_buffer_name, "wb") as f: pickle.dump(buf, f) - print(result.returns_stat.mean) + print(collector_stats) return buf diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index 064369cd2..e7a76221a 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -10,17 +10,20 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import SACPolicy +from tianshou.policy.base import BasePolicy +from tianshou.policy.modelfree.sac import SACTrainingStats from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.space_info import SpaceInfo -def expert_file_name(): +def expert_file_name() -> str: return os.path.join(os.path.dirname(__file__), "expert_SAC_Pendulum-v1.pkl") -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) @@ -61,16 +64,22 @@ def get_args(): return parser.parse_known_args()[0] -def gather_data(): +def gather_data() -> VectorReplayBuffer: """Return expert buffer data.""" args = get_args() env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action + if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -100,13 +109,14 @@ def gather_data(): critic = Critic(net_c, device=args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) + action_dim = space_info.action_info.action_dim if args.auto_alpha: - target_entropy = -np.prod(env.action_space.shape) + target_entropy = -action_dim log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy = SACPolicy( + policy: SACPolicy[SACTrainingStats] = SACPolicy( actor=actor, actor_optim=actor_optim, critic=critic, @@ -127,10 +137,10 @@ def gather_data(): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + 
def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer @@ -149,8 +159,8 @@ def stop_fn(mean_rewards): logger=logger, ).run() train_collector.reset() - result = train_collector.collect(n_step=args.buffer_size) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = train_collector.collect(n_step=args.buffer_size) + print(collector_stats) if args.save_buffer_name.endswith(".hdf5"): buffer.save_hdf5(args.save_buffer_name) else: diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index c82eadcd7..425f70b25 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -11,11 +11,13 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import BCQPolicy +from tianshou.policy import BasePolicy, BCQPolicy +from tianshou.policy.imitation.bcq import BCQTrainingStats from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import VAE, Critic, Perturbation +from tianshou.utils.space_info import SpaceInfo if __name__ == "__main__": from gather_pendulum_data import expert_file_name, gather_data @@ -23,7 +25,7 @@ from test.offline.gather_pendulum_data import expert_file_name, gather_data -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) @@ -64,7 +66,7 @@ def get_args(): return parser.parse_known_args()[0] -def test_bcq(args=get_args()): +def test_bcq(args: argparse.Namespace = get_args()) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -74,16 +76,23 @@ def test_bcq(args=get_args()): else: buffer = gather_data() env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] # float + + space_info = SpaceInfo.from_env(env) + + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action + args.state_dim = space_info.observation_info.obs_dim + args.action_dim = space_info.action_info.action_dim + if args.reward_threshold is None: # too low? 
default_reward_threshold = {"Pendulum-v0": -1100, "Pendulum-v1": -1100} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) - args.state_dim = args.state_shape[0] - args.action_dim = args.action_shape[0] # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -139,7 +148,7 @@ def test_bcq(args=get_args()): ).to(args.device) vae_optim = torch.optim.Adam(vae.parameters()) - policy = BCQPolicy( + policy: BCQPolicy[BCQTrainingStats] = BCQPolicy( actor_perturbation=actor, actor_perturbation_optim=actor_optim, critic=critic, @@ -170,13 +179,13 @@ def test_bcq(args=get_args()): writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - def watch(): + def watch() -> None: policy.load_state_dict( torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")), ) @@ -206,8 +215,8 @@ def watch(): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 693eecc49..5ce5b406d 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -3,6 +3,7 @@ import os import pickle import pprint +from typing import cast import gymnasium as gym import numpy as np @@ -11,11 +12,13 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import CQLPolicy +from tianshou.policy import BasePolicy, CQLPolicy +from tianshou.policy.imitation.cql import CQLTrainingStats from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.space_info import SpaceInfo if __name__ == "__main__": from gather_pendulum_data import expert_file_name, gather_data @@ -23,7 +26,7 @@ from test.offline.gather_pendulum_data import expert_file_name, gather_data -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) @@ -68,7 +71,7 @@ def get_args(): return parser.parse_known_args()[0] -def test_cql(args=get_args()): +def test_cql(args: argparse.Namespace = get_args()) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -78,16 +81,25 @@ def test_cql(args=get_args()): else: buffer = gather_data() env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] # float + env.action_space = 
cast(gym.spaces.Box, env.action_space) + + space_info = SpaceInfo.from_env(env) + + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.min_action = space_info.action_info.min_action + args.max_action = space_info.action_info.max_action + args.state_dim = space_info.observation_info.obs_dim + args.action_dim = space_info.action_info.action_dim + if args.reward_threshold is None: # too low? default_reward_threshold = {"Pendulum-v0": -1200, "Pendulum-v1": -1200} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) - args.state_dim = args.state_shape[0] - args.action_dim = args.action_shape[0] # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -124,12 +136,12 @@ def test_cql(args=get_args()): critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) if args.auto_alpha: - target_entropy = -np.prod(env.action_space.shape) + target_entropy = -np.prod(args.action_shape) log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy = CQLPolicy( + policy: CQLPolicy[CQLTrainingStats] = CQLPolicy( actor=actor, actor_optim=actor_optim, critic=critic, @@ -146,8 +158,8 @@ def test_cql(args=get_args()): temperature=args.temperature, with_lagrange=args.with_lagrange, lagrange_threshold=args.lagrange_threshold, - min_action=np.min(env.action_space.low), - max_action=np.max(env.action_space.high), + min_action=args.min_action, + max_action=args.max_action, device=args.device, ) @@ -168,10 +180,10 @@ def test_cql(args=get_args()): writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer @@ -202,9 +214,10 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) collector_result = collector.collect(n_episode=1, render=args.render) - print( - f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}", - ) + if collector_result.returns_stat and collector_result.lens_stat: + print( + f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}", + ) if __name__ == "__main__": diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 4f4208b5e..32e6d5696 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -2,6 +2,7 @@ import os import pickle import pprint +from typing import cast import gymnasium as gym import numpy as np @@ -10,11 +11,12 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import DiscreteBCQPolicy +from tianshou.policy import BasePolicy, DiscreteBCQPolicy from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.discrete import Actor +from tianshou.utils.space_info import SpaceInfo if __name__ == "__main__": from 
gather_cartpole_data import expert_file_name, gather_data @@ -22,7 +24,7 @@ from test.offline.gather_cartpole_data import expert_file_name, gather_data -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -52,14 +54,19 @@ def get_args(): return parser.parse_known_args()[0] -def test_discrete_bcq(args=get_args()): +def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: # envs env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + env.action_space = cast(gym.spaces.Discrete, env.action_space) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 185} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed np.random.seed(args.seed) @@ -82,7 +89,7 @@ def test_discrete_bcq(args=get_args()): actor_critic = ActorCritic(policy_net, imitation_net) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) - policy = DiscreteBCQPolicy( + policy: DiscreteBCQPolicy = DiscreteBCQPolicy( model=policy_net, imitator=imitation_net, optim=optim, @@ -111,13 +118,13 @@ def test_discrete_bcq(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - def save_checkpoint_fn(epoch, env_step, gradient_step): + def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, "checkpoint.pth") # Example: saving by epoch num @@ -166,11 +173,11 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) -def test_discrete_bcq_resume(args=get_args()): +def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None: args.resume = True test_discrete_bcq(args) diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index b7e4cc567..3d8bb4c39 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -2,6 +2,7 @@ import os import pickle import pprint +from typing import cast import gymnasium as gym import numpy as np @@ -10,10 +11,11 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import DiscreteCQLPolicy +from tianshou.policy import BasePolicy, DiscreteCQLPolicy from tianshou.trainer import 
OfflineTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net +from tianshou.utils.space_info import SpaceInfo if __name__ == "__main__": from gather_cartpole_data import expert_file_name, gather_data @@ -21,7 +23,7 @@ from test.offline.gather_cartpole_data import expert_file_name, gather_data -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -49,14 +51,19 @@ def get_args(): return parser.parse_known_args()[0] -def test_discrete_cql(args=get_args()): +def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: # envs env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + env.action_space = cast(gym.spaces.Discrete, env.action_space) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 170} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed np.random.seed(args.seed) @@ -73,7 +80,7 @@ def test_discrete_cql(args=get_args()): ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = DiscreteCQLPolicy( + policy: DiscreteCQLPolicy = DiscreteCQLPolicy( model=net, optim=optim, action_space=env.action_space, @@ -100,10 +107,10 @@ def test_discrete_cql(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold result = OfflineTrainer( @@ -128,8 +135,8 @@ def stop_fn(mean_rewards): policy.eval() policy.set_eps(args.eps_test) collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index ea880a530..beca5467f 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -2,6 +2,7 @@ import os import pickle import pprint +from typing import cast import gymnasium as gym import numpy as np @@ -10,11 +11,12 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import DiscreteCRRPolicy +from tianshou.policy import BasePolicy, DiscreteCRRPolicy from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.space_info import SpaceInfo if __name__ == "__main__": from gather_cartpole_data import expert_file_name, gather_data @@ -22,7 +24,7 @@ from test.offline.gather_cartpole_data
import expert_file_name, gather_data -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--reward-threshold", type=float, default=None) @@ -47,14 +49,19 @@ def get_args(): return parser.parse_known_args()[0] -def test_discrete_crr(args=get_args()): +def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: # envs env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + env.action_space = cast(gym.spaces.Discrete, env.action_space) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: default_reward_threshold = {"CartPole-v0": 180} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed np.random.seed(args.seed) @@ -69,16 +76,17 @@ def test_discrete_crr(args=get_args()): device=args.device, softmax_output=False, ) + action_dim = space_info.action_info.action_dim critic = Critic( net, hidden_sizes=args.hidden_sizes, - last_size=np.prod(args.action_shape), + last_size=action_dim, device=args.device, ) actor_critic = ActorCritic(actor, critic) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) - policy = DiscreteCRRPolicy( + policy: DiscreteCRRPolicy = DiscreteCRRPolicy( actor=actor, critic=critic, optim=optim, @@ -103,10 +111,10 @@ def test_discrete_crr(args=get_args()): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold result = OfflineTrainer( @@ -130,8 +138,8 @@ def stop_fn(mean_rewards): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 650aabb8a..37eb3352f 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -6,16 +6,17 @@ import gymnasium as gym import numpy as np import torch -from torch.distributions import Independent, Normal +from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.policy import GAILPolicy +from tianshou.policy import BasePolicy, GAILPolicy from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.space_info import SpaceInfo if __name__ == "__main__": from gather_pendulum_data import expert_file_name, gather_data @@ -23,7 +24,7 @@ from 
test.offline.gather_pendulum_data import expert_file_name, gather_data -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) @@ -65,7 +66,7 @@ def get_args(): return parser.parse_known_args()[0] -def test_gail(args=get_args()): +def test_gail(args: argparse.Namespace = get_args()) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -77,10 +78,14 @@ def test_gail(args=get_args()): env = gym.make(args.task) if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -1100, "Pendulum-v1": -1100} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) @@ -128,10 +133,10 @@ def test_gail(args=get_args()): # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward - def dist(*logits): + def dist(*logits: torch.Tensor) -> Distribution: return Independent(Normal(*logits), 1) - policy = GAILPolicy( + policy: BasePolicy = GAILPolicy( actor=actor, critic=critic, optim=optim, @@ -165,13 +170,13 @@ def dist(*logits): writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold - def save_checkpoint_fn(epoch, env_step, gradient_step): + def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html ckpt_path = os.path.join(log_path, "checkpoint.pth") # Example: saving by epoch num @@ -222,8 +227,8 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index af915844d..961af2ab3 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -13,10 +13,12 @@ from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.policy import TD3BCPolicy +from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer from 
tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic +from tianshou.utils.space_info import SpaceInfo if __name__ == "__main__": from gather_pendulum_data import expert_file_name, gather_data @@ -24,7 +26,7 @@ from test.offline.gather_pendulum_data import expert_file_name, gather_data -def get_args(): +def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="Pendulum-v1") parser.add_argument("--reward-threshold", type=float, default=None) @@ -64,7 +66,7 @@ def get_args(): return parser.parse_known_args()[0] -def test_td3_bc(args=get_args()): +def test_td3_bc(args: argparse.Namespace = get_args()) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -74,16 +76,20 @@ def test_td3_bc(args=get_args()): else: buffer = gather_data() env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] # float + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action if args.reward_threshold is None: # too low? default_reward_threshold = {"Pendulum-v0": -1200, "Pendulum-v1": -1200} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) - args.state_dim = args.state_shape[0] - args.action_dim = args.action_shape[0] + args.state_dim = space_info.observation_info.obs_dim + args.action_dim = space_info.action_info.action_dim # test_envs = gym.make(args.task) test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) # seed @@ -126,7 +132,7 @@ def test_td3_bc(args=get_args()): critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - policy = TD3BCPolicy( + policy: TD3BCPolicy = TD3BCPolicy( actor=actor, actor_optim=actor_optim, critic=critic1, @@ -161,10 +167,10 @@ def test_td3_bc(args=get_args()): writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # trainer @@ -194,10 +200,8 @@ def stop_fn(mean_rewards): env = gym.make(args.task) policy.eval() collector = Collector(policy, env) - collector_result = collector.collect(n_episode=1, render=args.render) - print( - f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}", - ) + collector_stats = collector.collect(n_episode=1, render=args.render) + print(collector_stats) if __name__ == "__main__": diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 722902049..0dd750b4c 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -97,7 +97,7 @@ def get_agents( device=args.device, ).to(args.device) optim = torch.optim.Adam(net.parameters(),
lr=args.lr) - agent = DQNPolicy( + agent: DQNPolicy = DQNPolicy( model=net, optim=optim, action_space=env.action_space, @@ -142,16 +142,16 @@ def train_agent( writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: pass - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return False - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: [agent.set_eps(args.eps_train) for agent in policy.policies.values()] - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: [agent.set_eps(args.eps_test) for agent in policy.policies.values()] def reward_metric(rews): diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 0cdf54fec..8bbb20cfd 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -8,7 +8,7 @@ import torch from pettingzoo.butterfly import pistonball_v6 from torch import nn -from torch.distributions import Independent, Normal +from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer @@ -181,10 +181,10 @@ def get_agents( torch.nn.init.zeros_(m.bias) optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=args.lr) - def dist(*logits): + def dist(*logits: torch.Tensor) -> Distribution: return Independent(Normal(*logits), 1) - agent = PPOPolicy( + agent: PPOPolicy = PPOPolicy( actor, critic, optim, @@ -241,10 +241,10 @@ def train_agent( writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: pass - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return False def reward_metric(rews): diff --git a/test/pettingzoo/test_pistonball.py b/test/pettingzoo/test_pistonball.py index 5043a96ca..2432ca531 100644 --- a/test/pettingzoo/test_pistonball.py +++ b/test/pettingzoo/test_pistonball.py @@ -1,9 +1,10 @@ +import argparse import pprint from pistonball import get_args, train_agent, watch -def test_piston_ball(args=get_args()): +def test_piston_ball(args: argparse.Namespace = get_args()) -> None: if args.watch: watch(args) return diff --git a/test/pettingzoo/test_pistonball_continuous.py b/test/pettingzoo/test_pistonball_continuous.py index f85884d53..bb2979bc6 100644 --- a/test/pettingzoo/test_pistonball_continuous.py +++ b/test/pettingzoo/test_pistonball_continuous.py @@ -1,3 +1,4 @@ +import argparse import pprint import pytest @@ -5,7 +6,7 @@ @pytest.mark.skip(reason="runtime too long and unstable result") -def test_piston_ball_continuous(args=get_args()): +def test_piston_ball_continuous(args: argparse.Namespace = get_args()) -> None: if args.watch: watch(args) return diff --git a/test/pettingzoo/test_tic_tac_toe.py b/test/pettingzoo/test_tic_tac_toe.py index f689283f2..44aa86b9f 100644 --- a/test/pettingzoo/test_tic_tac_toe.py +++ b/test/pettingzoo/test_tic_tac_toe.py @@ -1,9 +1,10 @@ +import argparse import pprint from tic_tac_toe import get_args, train_agent, watch -def test_tic_tac_toe(args=get_args()): +def test_tic_tac_toe(args: argparse.Namespace = get_args()) -> None: if args.watch: watch(args) return diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 7835c5d30..62b66dfa4 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ 
b/test/pettingzoo/tic_tac_toe.py @@ -177,20 +177,20 @@ def train_agent( writer.add_text("args", str(args)) logger = TensorboardLogger(writer) - def save_best_fn(policy): + def save_best_fn(policy: BasePolicy) -> None: if hasattr(args, "model_save_path"): model_save_path = args.model_save_path else: model_save_path = os.path.join(args.logdir, "tic_tac_toe", "dqn", "policy.pth") torch.save(policy.policies[agents[args.agent_id - 1]].state_dict(), model_save_path) - def stop_fn(mean_rewards): + def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.win_rate - def train_fn(epoch, env_step): + def train_fn(epoch: int, env_step: int) -> None: policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_train) - def test_fn(epoch, env_step): + def test_fn(epoch: int, env_step: int | None) -> None: policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) def reward_metric(rews): diff --git a/tianshou/utils/space_info.py b/tianshou/utils/space_info.py new file mode 100644 index 000000000..f8b99053f --- /dev/null +++ b/tianshou/utils/space_info.py @@ -0,0 +1,113 @@ +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Self + +import gymnasium as gym +import numpy as np +from gymnasium import spaces + +from tianshou.utils.string import ToStringMixin + + +@dataclass(kw_only=True) +class ActionSpaceInfo(ToStringMixin): + """A data structure for storing the different attributes of the action space.""" + + action_shape: int | Sequence[int] + """The shape of the action space.""" + min_action: float + """The smallest allowable action or in the continuous case the lower bound for allowable action value.""" + max_action: float + """The largest allowable action or in the continuous case the upper bound for allowable action value.""" + + @property + def action_dim(self) -> int: + """Return the number of distinct actions (must be greater than zero) an agent can take in its action space.""" + if isinstance(self.action_shape, int): + return self.action_shape + else: + return int(np.prod(self.action_shape)) + + @classmethod + def from_space(cls, space: spaces.Space) -> Self: + """Instantiate the `ActionSpaceInfo` object from a `Space`; supported spaces are Box and Discrete.""" + if isinstance(space, spaces.Box): + return cls( + action_shape=space.shape, + min_action=float(np.min(space.low)), + max_action=float(np.max(space.high)), + ) + elif isinstance(space, spaces.Discrete): + return cls( + action_shape=int(space.n), + min_action=float(space.start), + max_action=float(space.start + space.n - 1), + ) + else: + raise ValueError( + f"Unsupported space type: {space.__class__}. 
Currently supported types are Discrete and Box.", + ) + + def _tostring_additional_entries(self) -> dict[str, Any]: + return {"action_dim": self.action_dim} + + +@dataclass(kw_only=True) +class ObservationSpaceInfo(ToStringMixin): + """A data structure for storing the different attributes of the observation space.""" + + obs_shape: int | Sequence[int] + """The shape of the observation space.""" + + @property + def obs_dim(self) -> int: + """Return the number of distinct features (must be greater than zero) or dimensions in the observation space.""" + if isinstance(self.obs_shape, int): + return self.obs_shape + else: + return int(np.prod(self.obs_shape)) + + @classmethod + def from_space(cls, space: spaces.Space) -> Self: + """Instantiate the `ObservationSpaceInfo` object from a `Space`; supported spaces are Box and Discrete.""" + if isinstance(space, spaces.Box): + return cls( + obs_shape=space.shape, + ) + elif isinstance(space, spaces.Discrete): + return cls( + obs_shape=int(space.n), + ) + else: + raise ValueError( + f"Unsupported space type: {space.__class__}. Currently supported types are Discrete and Box.", + ) + + def _tostring_additional_entries(self) -> dict[str, Any]: + return {"obs_dim": self.obs_dim} + + +@dataclass(kw_only=True) +class SpaceInfo(ToStringMixin): + """A data structure for storing the attributes of both the action and observation space.""" + + action_info: ActionSpaceInfo + """Stores the attributes of the action space.""" + observation_info: ObservationSpaceInfo + """Stores the attributes of the observation space.""" + + @classmethod + def from_env(cls, env: gym.Env) -> Self: + """Instantiate the `SpaceInfo` object from `gym.Env.action_space` and `gym.Env.observation_space`.""" + return cls.from_spaces(env.action_space, env.observation_space) + + @classmethod + def from_spaces(cls, action_space: spaces.Space, observation_space: spaces.Space) -> Self: + """Instantiate the `SpaceInfo` object from an action space and an observation space.""" + action_info = ActionSpaceInfo.from_space(action_space) + observation_info = ObservationSpaceInfo.from_space(observation_space) + + return cls( + action_info=action_info, + observation_info=observation_info, + )
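
For reference, a minimal usage sketch of the new `SpaceInfo` helper (not part of the diff itself), showing how the updated test scripts obtain shapes that were previously derived via `env.observation_space.shape or env.observation_space.n`; it assumes the CartPole-v0 task used by the offline tests above:

import gymnasium as gym

from tianshou.utils.space_info import SpaceInfo

env = gym.make("CartPole-v0")
space_info = SpaceInfo.from_env(env)

# Shapes in the form consumed by the test scripts above.
state_shape = space_info.observation_info.obs_shape  # (4,) for CartPole's Box observation space
action_shape = space_info.action_info.action_shape  # 2 for the Discrete(2) action space
action_dim = space_info.action_info.action_dim  # flattened action dimensionality, here also 2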