Refactoring/mypy issues test #1017

Merged · 74 commits · Feb 6, 2024

Commits (74)
799e755
check beforehand if iterable None
dantp-ai Jan 1, 2024
2d3c637
Revert "check beforehand if iterable None"
dantp-ai Jan 1, 2024
0805062
Add type annotation to func args and return type
dantp-ai Jan 2, 2024
7404a76
Make test/highlevel a package by adding __init__
dantp-ai Jan 3, 2024
1468cdf
Fix type mismatch due to same variable naming
dantp-ai Jan 3, 2024
a6a13ce
Print episode stats iff ep ended during collect
dantp-ai Jan 3, 2024
7c28ff7
Fix no-untyped-def by adding annotation to funcs
dantp-ai Jan 4, 2024
27fbd16
Fix AttributeError in case spec is None
dantp-ai Jan 4, 2024
f5aaea8
Fix: var-annotated. Add type to policy.
dantp-ai Jan 8, 2024
fe3a0e0
Fix no-untyped-def by adding annotation to funcs
dantp-ai Jan 11, 2024
fce2482
Fix var-annotated by providing type for policy var
dantp-ai Jan 11, 2024
0b018a1
Merge branch 'thu-ml:master' into refactoring/mypy-issues-test
dantp-ai Jan 11, 2024
3f2994b
Fix var-annotated by providing type for policy var
dantp-ai Jan 11, 2024
5dcfe0a
Minor simplification in train_step (#1019)
MischaPanch Jan 9, 2024
2e370bb
Bump gitpython from 3.1.40 to 3.1.41 (#1020)
dependabot[bot] Jan 10, 2024
bed23bd
Resolved merge conflict in test/offline/test_cql.py
dantp-ai Jan 12, 2024
af63276
Add type annotations to funcs
dantp-ai Jan 14, 2024
5cdf549
Merge branch 'master' into refactoring/mypy-issues-test
dantp-ai Jan 17, 2024
f5ff4f2
Check if env.spec is not None before accessing attrs
dantp-ai Jan 18, 2024
cadb813
Add utils for typing space infos from gym
dantp-ai Jan 19, 2024
9a5220e
Fix typing action/obs space gym
dantp-ai Jan 19, 2024
05a4de4
add -> None to __init__
dantp-ai Jan 19, 2024
a1d6fa7
Refactor get_spaces_info()
dantp-ai Jan 19, 2024
2a03e56
Add type hints to env space attrs
dantp-ai Jan 20, 2024
8cd8b5b
Print rollout stats only if available. Abstract print in util function.
dantp-ai Jan 20, 2024
47d86af
Print collector stats only if available.
dantp-ai Jan 20, 2024
dea8960
Fix return type hint for __init__()
dantp-ai Jan 20, 2024
d4be1b9
Add return type hint to get_args()
dantp-ai Jan 20, 2024
c40f565
Cast env.action_space to Box
dantp-ai Jan 21, 2024
fb06c85
Add type hint to def test_*(args=get_args())
dantp-ai Jan 21, 2024
fde35cc
Add type hint to def stop_fn(mean_rewards)
dantp-ai Jan 21, 2024
53c4104
Add type hint to def save_best_fn(policy)
dantp-ai Jan 21, 2024
b501eff
Add type hint to func save_checkpoint_fn(epoch, env_step, gradient_step)
dantp-ai Jan 21, 2024
e5b3bbd
Add type hint to func test_fn(epoch, env_step)
dantp-ai Jan 21, 2024
0904458
Add type hint to func train_fn(epoch, env_step)
dantp-ai Jan 21, 2024
5f5710f
Refactor get_spaces_info() to use dataclass
dantp-ai Jan 21, 2024
f9134d8
Minor change: fix type hint
dantp-ai Jan 21, 2024
74fe676
Add return type hint to funcs
dantp-ai Jan 22, 2024
6357e1f
Add type hints
dantp-ai Jan 23, 2024
c9d7174
Add type hints
dantp-ai Jan 23, 2024
7828303
Fix issue incompatible list item
dantp-ai Jan 23, 2024
7f81ac7
Refactor get_spaces_info
dantp-ai Jan 26, 2024
f5b18a5
use Self for type hint
dantp-ai Jan 26, 2024
e1f53dc
Make action/obs_dim properties
dantp-ai Jan 26, 2024
07eee1d
Add docstrings
dantp-ai Jan 26, 2024
5df51ec
Update docstrings
dantp-ai Jan 29, 2024
b1ef1bb
Assume action/obs_space is non-empty
dantp-ai Jan 29, 2024
283f345
Merge branch 'master' into refactoring/mypy-issues-test
dantp-ai Jan 29, 2024
f274310
Fix mypy
dantp-ai Jan 30, 2024
9bcb080
Fix mypy issues
dantp-ai Jan 30, 2024
e6fd697
Fix mypy issues
dantp-ai Jan 30, 2024
f90754d
Use action_dim from SpaceInfo to compute target entropy
dantp-ai Jan 30, 2024
7741728
Fix mypy issues
dantp-ai Jan 31, 2024
73b4f77
Fix mypy issues
dantp-ai Jan 31, 2024
0148f01
Refactor SpaceInfo constructor method from_env()
dantp-ai Jan 31, 2024
a6135c6
Fix mypy issues
dantp-ai Jan 31, 2024
89bb655
Fix mypy issues
dantp-ai Jan 31, 2024
0d185e6
Fix mypy issues
dantp-ai Feb 1, 2024
7a9ae7f
Fix mypy issues
dantp-ai Feb 1, 2024
372ee9d
Fix mypy issues: continuous/test_npg
dantp-ai Feb 1, 2024
63e542d
Fix mypy issues: continuous/test_ddpg
dantp-ai Feb 1, 2024
8816771
Fix similar mypy issues in a bunch of test files
dantp-ai Feb 1, 2024
6c41eb7
Fix similar mypy issues in a bunch of test files
dantp-ai Feb 2, 2024
79ee83f
Fix similar mypy issues in a bunch of test files
dantp-ai Feb 3, 2024
0cdf4c5
Use more specific type annotations for policy
dantp-ai Feb 3, 2024
5589540
Remove print_final_stats and use plain print
dantp-ai Feb 3, 2024
686f0c9
Fix mypy issue for logger in if-else
dantp-ai Feb 3, 2024
5282501
Set return type of dist to torch.distributions.Distribution
dantp-ai Feb 4, 2024
bd490c6
Fix mypy issues: no type annotation for policy, agent
dantp-ai Feb 4, 2024
f2ab587
Fix type annotation for `dist` func
dantp-ai Feb 4, 2024
db4e6a9
Fix logging type annotations in examples
dantp-ai Feb 4, 2024
c4c8c7d
Fix mypy issue incompatible types in assignment buffer
dantp-ai Feb 4, 2024
2d91e7b
Add type annotation to main() and watch()
dantp-ai Feb 5, 2024
4e72567
Fix incompatible type to `run_cli()`
dantp-ai Feb 5, 2024
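
Several of the commits above ("Add utils for typing space infos from gym", "Refactor get_spaces_info() to use dataclass", "Use action_dim from SpaceInfo to compute target entropy") introduce a typed helper around Gymnasium spaces. The merged implementation is not shown on this page; the following is only a minimal sketch of what such a dataclass could look like, with the fields and the from_env/action_dim details treated as assumptions based on the commit messages.

from dataclasses import dataclass

import gymnasium as gym
import numpy as np


@dataclass
class SpaceInfo:
    """Hypothetical typed view of an env's action/observation spaces."""

    action_shape: int | tuple[int, ...]
    obs_shape: int | tuple[int, ...]

    @classmethod
    def from_env(cls, env: gym.Env) -> "SpaceInfo":
        # Only Box and Discrete are handled here; the real helper presumably covers more cases.
        def shape_of(space: gym.Space) -> int | tuple[int, ...]:
            if isinstance(space, gym.spaces.Discrete):
                return int(space.n)
            if isinstance(space, gym.spaces.Box):
                return space.shape
            raise NotImplementedError(f"unsupported space type: {type(space)}")

        return cls(
            action_shape=shape_of(env.action_space),
            obs_shape=shape_of(env.observation_space),
        )

    @property
    def action_dim(self) -> int:
        # Flattened action dimensionality, e.g. what a target-entropy heuristic needs.
        return int(np.prod(self.action_shape))
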
44 changes: 21 additions & 23 deletions examples/atari/atari_c51.py
@@ -8,15 +8,15 @@
import torch
from atari_network import C51
from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter

from examples.common import logger_factory
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import C51Policy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger, WandbLogger


def get_args():
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=0)
@@ -66,7 +66,7 @@ def get_args():
return parser.parse_args()


def test_c51(args=get_args()):
def test_c51(args: argparse.Namespace = get_args()) -> None:
env, train_envs, test_envs = make_atari_env(
args.task,
args.seed,
@@ -87,7 +87,7 @@ def test_c51(args=get_args()):
net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = C51Policy(
policy: C51Policy = C51Policy(
model=net,
optim=optim,
discount_factor=args.gamma,
@@ -123,21 +123,19 @@ def test_c51(args=get_args()):

# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)

def save_best_fn(policy):
logger_factory.logger_type = "wandb"
logger_factory.wandb_project = args.wandb_project
else:
logger_factory.logger_type = "tensorboard"

logger = logger_factory.create_logger(
log_dir=log_path,
experiment_name=log_name,
run_id=args.resume_id,
config_dict=vars(args),
)

def save_best_fn(policy: BasePolicy) -> None:
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards: float) -> bool:
@@ -147,7 +145,7 @@ def stop_fn(mean_rewards: float) -> bool:
return mean_rewards >= 20
return False

def train_fn(epoch, env_step):
def train_fn(epoch: int, env_step: int) -> None:
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final)
@@ -157,11 +155,11 @@ def train_fn(epoch, env_step):
if env_step % 1000 == 0:
logger.write("train/env_step", env_step, {"train/eps": eps})

def test_fn(epoch, env_step):
def test_fn(epoch: int, env_step: int | None) -> None:
policy.set_eps(args.eps_test)

# watch agent's performance
def watch():
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
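
The wandb/tensorboard branching deleted from each example script is replaced by a shared logger_factory imported from examples.common. Its actual implementation is not part of this diff; the sketch below only mirrors the create_logger() call sites and reuses the logger construction visible in the deleted lines, so treat it as an illustration rather than the real module.

import os
from dataclasses import dataclass

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import TensorboardLogger, WandbLogger


@dataclass
class LoggerFactory:
    """Hypothetical factory matching the attributes and call used in these examples."""

    logger_type: str = "tensorboard"
    wandb_project: str | None = None

    def create_logger(
        self,
        log_dir: str,
        experiment_name: str,
        run_id: str | None = None,
        config_dict: dict | None = None,
    ) -> TensorboardLogger | WandbLogger:
        writer = SummaryWriter(log_dir)
        writer.add_text("args", str(config_dict))
        if self.logger_type == "wandb":
            # Same WandbLogger arguments as the code this factory replaces.
            logger = WandbLogger(
                save_interval=1,
                name=experiment_name.replace(os.path.sep, "__"),
                run_id=run_id,
                config=config_dict,
                project=self.wandb_project,
            )
            logger.load(writer)
            return logger
        return TensorboardLogger(writer)


logger_factory = LoggerFactory()
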
44 changes: 21 additions & 23 deletions examples/atari/atari_dqn.py
@@ -8,17 +8,17 @@
import torch
from atari_network import DQN
from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter

from examples.common import logger_factory
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import DQNPolicy
from tianshou.policy.base import BasePolicy
from tianshou.policy.modelbased.icm import ICMPolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.discrete import IntrinsicCuriosityModule


def get_args():
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=0)
@@ -83,7 +83,7 @@ def get_args():
return parser.parse_args()


def test_dqn(args=get_args()):
def test_dqn(args: argparse.Namespace = get_args()) -> None:
env, train_envs, test_envs = make_atari_env(
args.task,
args.seed,
@@ -104,7 +104,7 @@ def test_dqn(args=get_args()):
net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = DQNPolicy(
policy: DQNPolicy = DQNPolicy(
model=net,
optim=optim,
action_space=env.action_space,
@@ -158,21 +158,19 @@ def test_dqn(args=get_args()):

# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)
logger_factory.logger_type = "wandb"
logger_factory.wandb_project = args.wandb_project
else:
logger_factory.logger_type = "tensorboard"

logger = logger_factory.create_logger(
log_dir=log_path,
experiment_name=log_name,
run_id=args.resume_id,
config_dict=vars(args),
)

def save_best_fn(policy):
def save_best_fn(policy: BasePolicy) -> None:
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards: float) -> bool:
@@ -182,7 +180,7 @@ def stop_fn(mean_rewards: float) -> bool:
return mean_rewards >= 20
return False

def train_fn(epoch, env_step):
def train_fn(epoch: int, env_step: int) -> None:
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final)
@@ -192,17 +190,17 @@ def train_fn(epoch, env_step):
if env_step % 1000 == 0:
logger.write("train/env_step", env_step, {"train/eps": eps})

def test_fn(epoch, env_step):
def test_fn(epoch: int, env_step: int | None) -> None:
policy.set_eps(args.eps_test)

def save_checkpoint_fn(epoch, env_step, gradient_step):
def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
torch.save({"model": policy.state_dict()}, ckpt_path)
return ckpt_path

# watch agent's performance
def watch():
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
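
All of the callback annotations added in these scripts follow a handful of shapes: train_fn(epoch: int, env_step: int), test_fn(epoch: int, env_step: int | None), stop_fn(mean_rewards: float) -> bool, save_best_fn(policy: BasePolicy), and save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str. If one wanted to express those shapes once and let mypy check each script against them, Callable aliases like the ones below would do; the alias names are illustrative and not part of Tianshou's API.

from collections.abc import Callable

from tianshou.policy.base import BasePolicy

TrainFn = Callable[[int, int], None]                # (epoch, env_step)
TestFn = Callable[[int, int | None], None]          # (epoch, env_step)
StopFn = Callable[[float], bool]                    # (mean_rewards) -> early-stop flag
SaveBestFn = Callable[[BasePolicy], None]           # (policy)
SaveCheckpointFn = Callable[[int, int, int], str]   # (epoch, env_step, gradient_step) -> ckpt path


def stop_fn(mean_rewards: float) -> bool:
    return mean_rewards >= 20


checked_stop_fn: StopFn = stop_fn  # mypy verifies the signature against the alias
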
6 changes: 4 additions & 2 deletions examples/atari/atari_dqn_hl.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

import functools
import os

from examples.atari.atari_network import (
@@ -48,7 +49,7 @@ def main(
icm_lr_scale: float = 0.0,
icm_reward_scale: float = 0.01,
icm_forward_loss_weight: float = 0.2,
):
) -> None:
log_name = os.path.join(task, "dqn", str(experiment_config.seed), datetime_tag())

sampling_config = SamplingConfig(
@@ -102,4 +103,5 @@ def main(


if __name__ == "__main__":
logging.run_cli(main)
run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig())
logging.run_cli(run_with_default_config)
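
The final hunk wraps main in functools.partial so that run_cli receives a callable whose experiment_config argument is already bound to a default ExperimentConfig(). A standalone illustration of the same technique, using a stand-in function and config class rather than the real ones:

import functools
from dataclasses import dataclass


@dataclass
class DemoConfig:  # stand-in for the real ExperimentConfig
    seed: int = 0


def demo_main(experiment_config: DemoConfig, task: str = "PongNoFrameskip-v4") -> None:
    print(experiment_config, task)


# Binding experiment_config leaves a callable that only exposes the remaining
# parameters, which a CLI wrapper such as run_cli can then parse and override.
run_with_default_config = functools.partial(demo_main, experiment_config=DemoConfig())
run_with_default_config(task="BreakoutNoFrameskip-v4")
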
44 changes: 21 additions & 23 deletions examples/atari/atari_fqf.py
@@ -8,16 +8,16 @@
import torch
from atari_network import DQN
from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter

from examples.common import logger_factory
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import FQFPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction


def get_args():
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=3128)
@@ -69,7 +69,7 @@ def get_args():
return parser.parse_args()


def test_fqf(args=get_args()):
def test_fqf(args: argparse.Namespace = get_args()) -> None:
env, train_envs, test_envs = make_atari_env(
args.task,
args.seed,
@@ -99,7 +99,7 @@ def test_fqf(args=get_args()):
fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr)
# define policy
policy = FQFPolicy(
policy: FQFPolicy = FQFPolicy(
model=net,
optim=optim,
fraction_model=fraction_net,
@@ -136,21 +136,19 @@ def test_fqf(args=get_args()):

# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)

def save_best_fn(policy):
logger_factory.logger_type = "wandb"
logger_factory.wandb_project = args.wandb_project
else:
logger_factory.logger_type = "tensorboard"

logger = logger_factory.create_logger(
log_dir=log_path,
experiment_name=log_name,
run_id=args.resume_id,
config_dict=vars(args),
)

def save_best_fn(policy: BasePolicy) -> None:
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards: float) -> bool:
@@ -160,7 +158,7 @@ def stop_fn(mean_rewards: float) -> bool:
return mean_rewards >= 20
return False

def train_fn(epoch, env_step):
def train_fn(epoch: int, env_step: int) -> None:
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final)
@@ -170,11 +168,11 @@ def train_fn(epoch, env_step):
if env_step % 1000 == 0:
logger.write("train/env_step", env_step, {"train/eps": eps})

def test_fn(epoch, env_step):
def test_fn(epoch: int, env_step: int | None) -> None:
policy.set_eps(args.eps_test)

# watch agent's performance
def watch():
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)