Skip to content

Commit

Permalink
Fix no-untyped-def by adding annotation to funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
dantp-ai committed Jan 4, 2024
1 parent a6a13ce commit 7c28ff7
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions test/offline/test_bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import BCQPolicy
from tianshou.policy import BasePolicy, BCQPolicy
from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import MLP, Net
Expand All @@ -23,7 +23,7 @@
from test.offline.gather_pendulum_data import expert_file_name, gather_data


def get_args():
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="Pendulum-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
Expand Down Expand Up @@ -64,7 +64,7 @@ def get_args():
return parser.parse_known_args()[0]


def test_bcq(args=get_args()):
def test_bcq(args: argparse.Namespace = get_args()) -> None:
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
if args.load_buffer_name.endswith(".hdf5"):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
Expand Down Expand Up @@ -170,13 +170,13 @@ def test_bcq(args=get_args()):
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)

def save_best_fn(policy):
def save_best_fn(policy: BasePolicy) -> None:
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
def stop_fn(mean_rewards: float) -> bool:
return mean_rewards >= args.reward_threshold

def watch():
def watch() -> None:
policy.load_state_dict(
torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
)
Expand Down

0 comments on commit 7c28ff7

Please sign in to comment.