diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb index de2e06cfd..41a7143e8 100644 --- a/docs/02_notebooks/L0_overview.ipynb +++ b/docs/02_notebooks/L0_overview.ipynb @@ -141,7 +141,7 @@ "# Let's watch its performance!\n", "policy.eval()\n", "result = test_collector.collect(n_episode=1, render=False)\n", - "print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))" + "print(\"Final reward: {}, length: {}\".format(result.returns.mean(), result.lens.mean()))" ] }, { diff --git a/docs/02_notebooks/L4_Policy.ipynb b/docs/02_notebooks/L4_Policy.ipynb index 4a815d07a..e31214fb4 100644 --- a/docs/02_notebooks/L4_Policy.ipynb +++ b/docs/02_notebooks/L4_Policy.ipynb @@ -54,16 +54,19 @@ }, "outputs": [], "source": [ - "from typing import Dict, List\n", + "from typing import cast\n", "\n", "import numpy as np\n", "import torch\n", "import gymnasium as gym\n", "\n", - "from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as\n", - "from tianshou.policy import BasePolicy\n", + "from dataclasses import dataclass\n", + "\n", + "from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as, SequenceSummaryStats\n", + "from tianshou.policy import BasePolicy, TrainingStats\n", "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor" + "from tianshou.utils.net.discrete import Actor\n", + "from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol" ] }, { @@ -102,12 +105,14 @@ "\n", "\n", "\n", - "1. Since Tianshou is a **Deep** RL libraries, there should be a policy network in our Policy Module, also a Torch optimizer.\n", - "2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to preprocess training data and computes quantities like episodic returns (gradient free), then it will call `Policy.learn()` to perform the back-propagation.\n", + "1. Since Tianshou is a **Deep** RL libraries, there should be a policy network in our Policy Module, \n", + "also a Torch optimizer.\n", + "2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to \n", + "preprocess training data and computes quantities like episodic returns (gradient free), \n", + "then it will call `Policy.learn()` to perform the back-propagation.\n", + "3. Each Policy is accompanied by a dedicated implementation of `TrainingStats` to store details of training.\n", "\n", "Then we get the implementation below.\n", - "\n", - "\n", "\n" ] }, @@ -119,7 +124,14 @@ }, "outputs": [], "source": [ - "class REINFORCEPolicy(BasePolicy):\n", + "@dataclass(kw_only=True)\n", + "class REINFORCETrainingStats(TrainingStats):\n", + " \"\"\"A dedicated class for REINFORCE training statistics.\"\"\"\n", + "\n", + " loss: SequenceSummaryStats\n", + "\n", + "\n", + "class REINFORCEPolicy(BasePolicy[REINFORCETrainingStats]):\n", " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", "\n", " def __init__(\n", @@ -138,7 +150,7 @@ " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", " pass\n", "\n", - " def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", + " def learn(self, batch: Batch, batch_size: int, repeat: int) -> REINFORCETrainingStats:\n", " \"\"\"Perform the back-propagation.\"\"\"\n", " return" ] @@ -220,7 +232,9 @@ }, "source": [ "### Policy.learn()\n", - "Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. 
Final we can construct our loss function and perform the back-propagation." + "Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Finally,\n", + "we can construct our loss function and perform the back-propagation. The method \n", + "should look something like this:" ] }, { @@ -231,22 +245,24 @@ }, "outputs": [], "source": [ - "def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", + "from tianshou.utils.optim import optim_step\n", + "\n", + "\n", + "def learn(self, batch: Batch, batch_size: int, repeat: int):\n", " \"\"\"Perform the back-propagation.\"\"\"\n", - " logging_losses = []\n", + " train_losses = []\n", " for _ in range(repeat):\n", " for minibatch in batch.split(batch_size, merge_last=True):\n", - " self.optim.zero_grad()\n", " result = self(minibatch)\n", " dist = result.dist\n", " act = to_torch_as(minibatch.act, result.act)\n", " ret = to_torch(minibatch.returns, torch.float, result.act.device)\n", " log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n", " loss = -(log_prob * ret).mean()\n", - " loss.backward()\n", - " self.optim.step()\n", - " logging_losses.append(loss.item())\n", - " return {\"loss\": logging_losses}" + " optim_step(loss, self.optim)\n", + " train_losses.append(loss.item())\n", + "\n", + " return REINFORCETrainingStats(loss=SequenceSummaryStats.from_sequence(train_losses))" ] }, { @@ -256,7 +272,12 @@ }, "source": [ "## Implementation\n", - "Finally we can assemble the implemented methods and form a REINFORCE Policy." + "Now we can assemble the methods and form a REINFORCE Policy. The outputs of\n", + "`learn` will be collected to a dedicated dataclass.\n", + "\n", + "We will also use protocols to specify what fields are expected and produced inside a `Batch` in\n", + "each processing step. By using protocols, we can get better type checking and IDE support \n", + "without having to implement a separate class for each combination of fields." 
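The notebook text above leans on `RolloutBatchProtocol` / `BatchWithReturnsProtocol` imported from `tianshou.data.types`. As a rough, non-authoritative sketch of the idea (the class name, field names and helper below are illustrative assumptions, not the actual definitions in `tianshou.data.types`), a batch protocol is simply a structural type listing the attributes a `Batch` is expected to carry:

from typing import Protocol

import numpy as np


class BatchWithReturnsLike(Protocol):
    """Illustrative structural type: any Batch-like object exposing these fields matches it."""

    obs: np.ndarray
    act: np.ndarray
    rew: np.ndarray
    returns: np.ndarray


def total_return(batch: BatchWithReturnsLike) -> float:
    # A type checker accepts any object that structurally provides the annotated
    # fields, so no dedicated subclass is needed per combination of fields.
    return float(np.asarray(batch.returns).sum())

This is also why `process_fn` in the cell below can simply `cast(BatchWithReturnsProtocol, batch)` after attaching `batch.returns`: the cast only informs the type checker and does not change the object at runtime.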
] }, { @@ -290,30 +311,33 @@ " act = dist.sample()\n", " return Batch(act=act, dist=dist)\n", "\n", - " def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n", + " def process_fn(\n", + " self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray\n", + " ) -> BatchWithReturnsProtocol:\n", " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", " returns, _ = self.compute_episodic_return(\n", " batch, buffer, indices, gamma=0.99, gae_lambda=1.0\n", " )\n", " batch.returns = returns\n", - " return batch\n", + " return cast(BatchWithReturnsProtocol, batch)\n", "\n", - " def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", + " def learn(\n", + " self, batch: BatchWithReturnsProtocol, batch_size: int, repeat: int\n", + " ) -> REINFORCETrainingStats:\n", " \"\"\"Perform the back-propagation.\"\"\"\n", - " logging_losses = []\n", + " train_losses = []\n", " for _ in range(repeat):\n", " for minibatch in batch.split(batch_size, merge_last=True):\n", - " self.optim.zero_grad()\n", " result = self(minibatch)\n", " dist = result.dist\n", " act = to_torch_as(minibatch.act, result.act)\n", " ret = to_torch(minibatch.returns, torch.float, result.act.device)\n", " log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n", " loss = -(log_prob * ret).mean()\n", - " loss.backward()\n", - " self.optim.step()\n", - " logging_losses.append(loss.item())\n", - " return {\"loss\": logging_losses}" + " optim_step(loss, self.optim)\n", + " train_losses.append(loss.item())\n", + "\n", + " return REINFORCETrainingStats(loss=SequenceSummaryStats.from_sequence(train_losses))" ] }, { @@ -370,8 +394,8 @@ "source": [ "print(policy)\n", "print(\"========================================\")\n", - "for para in policy.parameters():\n", - " print(para.shape)" + "for param in policy.parameters():\n", + " print(param.shape)" ] }, { @@ -831,6 +855,13 @@ "\n", "" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/docs/02_notebooks/L5_Collector.ipynb b/docs/02_notebooks/L5_Collector.ipynb index 0dbb77de6..fc5588401 100644 --- a/docs/02_notebooks/L5_Collector.ipynb +++ b/docs/02_notebooks/L5_Collector.ipynb @@ -116,9 +116,9 @@ "source": [ "collect_result = test_collector.collect(n_episode=9)\n", "print(collect_result)\n", - "print(\"Rewards of 9 episodes are {}\".format(collect_result[\"rews\"]))\n", - "print(\"Average episode reward is {}.\".format(collect_result[\"rew\"]))\n", - "print(\"Average episode length is {}.\".format(collect_result[\"len\"]))" + "print(f\"Returns of 9 episodes are {collect_result.returns}\")\n", + "print(f\"Average episode return is {collect_result.returns_stat.mean}.\")\n", + "print(f\"Average episode length is {collect_result.lens_stat.mean}.\")" ] }, { @@ -146,9 +146,9 @@ "test_collector.reset()\n", "collect_result = test_collector.collect(n_episode=9, random=True)\n", "print(collect_result)\n", - "print(\"Rewards of 9 episodes are {}\".format(collect_result[\"rews\"]))\n", - "print(\"Average episode reward is {}.\".format(collect_result[\"rew\"]))\n", - "print(\"Average episode length is {}.\".format(collect_result[\"len\"]))" + "print(f\"Returns of 9 episodes are {collect_result.returns}\")\n", + "print(f\"Average episode return is {collect_result.returns_stat.mean}.\")\n", + "print(f\"Average episode length is {collect_result.lens_stat.mean}.\")" ] }, { @@ -157,7 +157,7 @@ "id": "sKQRTiG10ljU" 
}, "source": [ - "Seems that an initialized policy performs even worse than a random policy without any training." + "It seems like an initialized policy performs even worse than a random policy without any training." ] }, { diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb index db6b0fb86..da4010b70 100644 --- a/docs/02_notebooks/L6_Trainer.ipynb +++ b/docs/02_notebooks/L6_Trainer.ipynb @@ -146,7 +146,7 @@ "replaybuffer.reset()\n", "for i in range(10):\n", " evaluation_result = test_collector.collect(n_episode=10)\n", - " print(\"Evaluation reward is {}\".format(evaluation_result[\"rew\"]))\n", + " print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n", " train_collector.collect(n_step=2000)\n", " # 0 means taking all data stored in train_collector.buffer\n", " policy.update(0, train_collector.buffer, batch_size=512, repeat=1)\n", diff --git a/docs/02_notebooks/L7_Experiment.ipynb b/docs/02_notebooks/L7_Experiment.ipynb index 55a3be144..ad450374e 100644 --- a/docs/02_notebooks/L7_Experiment.ipynb +++ b/docs/02_notebooks/L7_Experiment.ipynb @@ -303,7 +303,7 @@ "# Let's watch its performance!\n", "policy.eval()\n", "result = test_collector.collect(n_episode=1, render=False)\n", - "print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))" + "print(f\"Final episode reward: {result.returns.mean()}, length: {result.lens.mean()}\")" ] } ], diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 9066e8694..1849f2efc 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -239,3 +239,10 @@ logp autogenerated subpackage subpackages +recurse +rollout +rollouts +prepend +prepends +dict +dicts diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index c0837d645..21b42aa7b 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -184,7 +184,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 2d94b3356..520761e55 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -225,7 +225,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 671a16d1b..97bd7ded4 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -197,7 +197,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index af7bfe1d9..dc59da469 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -194,7 +194,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = 
result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index c46abf49d..1dd695615 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -252,7 +252,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index c72f80d97..5231c0391 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -180,7 +180,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 5fd3d4380..ded22f5d9 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -223,7 +223,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 30e2a21be..543d4b8fb 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -236,7 +236,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index bfd248fa2..003caa3fd 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -130,7 +130,7 @@ def test_fn(epoch, env_step): logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! @@ -139,8 +139,7 @@ def test_fn(epoch, env_step): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 1d9aabfd5..ff532830a 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -145,7 +145,7 @@ def test_fn(epoch, env_step): logger=logger, ).run() - # assert stop_fn(result["best_reward"]) + # assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
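All of the example-script hunks in this patch apply the same mechanical migration from the old dict-style collect result to the new attribute-style stats object. A condensed sketch of the pattern, using only attributes that appear in this diff and assuming `test_collector` is an already constructed Tianshou `Collector`:

result = test_collector.collect(n_episode=10)

# old dict-style key      -> new attribute access (as used throughout this patch)
# result["rews"].mean()   -> result.returns_stat.mean   (or result.returns.mean())
# result["lens"].mean()   -> result.lens_stat.mean      (or result.lens.mean())
# result["n/ep"]          -> result.n_collected_episodes
# result["n/st"]          -> result.n_collected_steps
print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}")
print(f"Collected {result.n_collected_episodes} episodes / {result.n_collected_steps} steps")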
@@ -154,8 +154,7 @@ def test_fn(epoch, env_step): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index e3703b3c9..2dda6b610 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -187,8 +187,7 @@ def stop_fn(mean_rewards): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 8a3154195..007a310d8 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -126,7 +126,7 @@ def test_fn(epoch, env_step): logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! @@ -135,8 +135,7 @@ def test_fn(epoch, env_step): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 748c9d73c..49a34af14 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -140,7 +140,7 @@ def stop_fn(mean_rewards): logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
@@ -148,8 +148,7 @@ def stop_fn(mean_rewards): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 4e464156c..3dcc92fcd 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -257,7 +257,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 05e31019f..6e37100db 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -225,7 +225,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index fbd2ec989..b58c802ba 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -222,7 +222,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 7b7dba7b6..b3772b4d6 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -21,7 +21,7 @@ def main( experiment_config: ExperimentConfig, - task: str = "Ant-v3", + task: str = "Ant-v4", buffer_size: int = 4096, hidden_sizes: Sequence[int] = (64, 64), lr: float = 7e-4, diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index f20292831..75948679b 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -171,7 +171,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index bad4f0e28..6b1e8c13d 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -219,7 +219,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff 
--git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 9b190c7e4..e7287f583 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -227,7 +227,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 80b07b635..e39c7b092 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -199,7 +199,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 66219bf8c..cfcdb792f 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -18,7 +18,7 @@ def main( experiment_config: ExperimentConfig, - task: str = "Ant-v3", + task: str = "Ant-v4", buffer_size: int = 1000000, hidden_sizes: Sequence[int] = (256, 256), ensemble_size: int = 10, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 50bf312a6..1b394189b 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -199,7 +199,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index c42775379..a53c04c3a 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -193,7 +193,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 7a556e219..6af4c6192 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -17,7 +17,7 @@ def main( experiment_config: ExperimentConfig, - task: str = "Ant-v3", + task: str = "Ant-v4", buffer_size: int = 1000000, hidden_sizes: Sequence[int] = (256, 256), actor_lr: float = 1e-3, diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 667db0afb..8e9476143 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -191,7 +191,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if 
__name__ == "__main__": diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 1f783f106..67b9c1847 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -22,7 +22,7 @@ def main( experiment_config: ExperimentConfig, - task: str = "Ant-v3", + task: str = "Ant-v4", buffer_size: int = 1000000, hidden_sizes: Sequence[int] = (256, 256), actor_lr: float = 3e-4, diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index e37cf9196..97fc0f910 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -224,7 +224,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 63e4256f0..e2015382d 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -188,8 +188,8 @@ def watch(): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) - rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + rew = result.returns_stat.mean + print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") if args.watch: watch() diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 14947b04f..82a0d6edf 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -164,8 +164,8 @@ def watch(): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) - rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + rew = result.returns_stat.mean + print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") if args.watch: watch() diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 7432560d1..db768f2b0 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -186,8 +186,8 @@ def watch(): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) - rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + rew = result.returns_stat.mean + print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") if args.watch: watch() diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index c4576924f..3efb56f91 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -147,8 +147,8 @@ def watch(): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) - rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + rew = result.returns_stat.mean + print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") if args.watch: watch() diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index a710c44ac..5e3746e34 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -229,7 +229,7 @@ def watch(): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, 
render=args.render) - print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 1c29a1152..160137fce 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -366,7 +366,7 @@ def watch(): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index 617091d36..89384127c 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -166,7 +166,7 @@ def watch(): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index fef7d3f81..c4c4e9dd9 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -216,7 +216,7 @@ def watch(): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 03ff92e0d..46a33769f 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -190,10 +190,10 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() - lens = result["lens"].mean() * args.skip_num - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') - print(f'Mean length (over {result["n/ep"]} episodes): {lens}') + rew = result.returns_stat.mean + lens = result.lens_stat.mean * args.skip_num + print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") + print(f"Mean length (over {result.n_collected_episodes} episodes): {lens}") if args.watch: watch() diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index d84a9a34a..0811a24dd 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -255,10 +255,10 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() - lens = result["lens"].mean() * args.skip_num - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') - print(f'Mean length (over {result["n/ep"]} episodes): {lens}') + rew = result.returns_stat.mean + lens = result.lens_stat.mean * args.skip_num + print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") + print(f"Mean length (over {result.n_collected_episodes} episodes): {lens}") if args.watch: watch() diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 1f8d4f556..985e6ef50 100644 
--- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -188,7 +188,7 @@ def test_collector(gym_reset_kwargs): assert np.all(c2obs == obs1) or np.all(c2obs == obs2) c2.reset_env(gym_reset_kwargs=gym_reset_kwargs) c2.reset_buffer() - assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs)["n/ep"] == 8 + assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs).n_collected_episodes == 8 valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57] obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3] assert np.all(c2.buffer.obs[:, 0] == obs) @@ -237,9 +237,9 @@ def test_collector_with_async(gym_reset_kwargs): ptr = [0, 0, 0, 0] for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs) - assert result["n/ep"] >= n_episode + assert result.n_collected_episodes >= n_episode # check buffer data, obs and obs_next, env_id - for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]): + for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): env_len = i + 2 total = env_len * count indices = np.arange(ptr[i], ptr[i] + total) % bufsize @@ -252,7 +252,7 @@ def test_collector_with_async(gym_reset_kwargs): # test async n_step, for now the buffer should be full of data for n_step in tqdm.trange(1, 15, desc="test async n_step"): result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs) - assert result["n/st"] >= n_step + assert result.n_collected_steps >= n_step for i in range(4): env_len = i + 2 seq = np.arange(env_len) @@ -284,12 +284,12 @@ def test_collector_with_dict_state(): ) c1.collect(n_step=12) result = c1.collect(n_episode=8) - assert result["n/ep"] == 8 - lens = np.bincount(result["lens"]) + assert result.n_collected_episodes == 8 + lens = np.bincount(result.lens) assert ( - result["n/st"] == 21 + result.n_collected_steps == 21 and np.all(lens == [0, 0, 2, 2, 2, 2]) - or result["n/st"] == 20 + or result.n_collected_steps == 20 and np.all(lens == [0, 0, 3, 1, 2, 2]) ) batch, _ = c1.buffer.sample(10) @@ -407,9 +407,9 @@ def test_collector_with_ma(): policy = MyPolicy() c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) # n_step=3 will collect a full episode - rew = c0.collect(n_step=3)["rews"] + rew = c0.collect(n_step=3).returns assert len(rew) == 0 - rew = c0.collect(n_episode=2)["rews"] + rew = c0.collect(n_episode=2).returns assert rew.shape == (2, 4) assert np.all(rew == 1) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] @@ -420,9 +420,9 @@ def test_collector_with_ma(): VectorReplayBuffer(total_size=100, buffer_num=4), Logger.single_preprocess_fn, ) - rew = c1.collect(n_step=12)["rews"] + rew = c1.collect(n_step=12).returns assert rew.shape == (2, 4) and np.all(rew == 1), rew - rew = c1.collect(n_episode=8)["rews"] + rew = c1.collect(n_episode=8).returns assert rew.shape == (8, 4) assert np.all(rew == 1) batch, _ = c1.buffer.sample(10) @@ -528,7 +528,7 @@ def test_collector_with_ma(): VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), Logger.single_preprocess_fn, ) - rew = c2.collect(n_episode=10)["rews"] + rew = c2.collect(n_episode=10).returns assert rew.shape == (10, 4) assert np.all(rew == 1) batch, _ = c2.buffer.sample(10) @@ -580,8 +580,8 @@ def test_collector_with_atari_setting(): c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) c3.collect(n_step=12) result = c3.collect(n_episode=9) - assert result["n/ep"] == 9 - assert result["n/st"] == 23 + 
assert result.n_collected_episodes == 9 + assert result.n_collected_steps == 23 assert c3.buffer.obs.shape == (100, 4, 84, 84) obs = np.zeros_like(c3.buffer.obs) obs[np.arange(8)] = reference_obs[[0, 1, 0, 1, 0, 1, 0, 1]] @@ -608,8 +608,8 @@ def test_collector_with_atari_setting(): ) c4.collect(n_step=12) result = c4.collect(n_episode=9) - assert result["n/ep"] == 9 - assert result["n/st"] == 23 + assert result.n_collected_episodes == 9 + assert result.n_collected_steps == 23 assert c4.buffer.obs.shape == (100, 84, 84) obs = np.zeros_like(c4.buffer.obs) slice_obs = reference_obs[:, -1] @@ -676,8 +676,8 @@ def test_collector_with_atari_setting(): assert len(buf) == 5 assert len(c5.buffer) == 12 result = c5.collect(n_episode=9) - assert result["n/ep"] == 9 - assert result["n/st"] == 23 + assert result.n_collected_episodes == 9 + assert result.n_collected_steps == 23 assert len(buf) == 35 assert np.all( buf.obs[: len(buf)] @@ -768,11 +768,11 @@ def test_collector_with_atari_setting(): # test buffer=None c6 = Collector(policy, envs) result1 = c6.collect(n_step=12) - for key in ["n/ep", "n/st", "rews", "lens"]: - assert np.allclose(result1[key], result_[key]) + for key in ["n_collected_episodes", "n_collected_steps", "returns", "lens"]: + assert np.allclose(getattr(result1, key), getattr(result_, key)) result2 = c6.collect(n_episode=9) - for key in ["n/ep", "n/st", "rews", "lens"]: - assert np.allclose(result2[key], result[key]) + for key in ["n_collected_episodes", "n_collected_steps", "returns", "lens"]: + assert np.allclose(getattr(result2, key), getattr(result, key)) @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") diff --git a/test/base/test_logger.py b/test/base/test_logger.py new file mode 100644 index 000000000..1634f4d8f --- /dev/null +++ b/test/base/test_logger.py @@ -0,0 +1,59 @@ +import numpy as np +import pytest + +from tianshou.utils import BaseLogger + + +class TestBaseLogger: + @staticmethod + @pytest.mark.parametrize( + "input_dict, expected_output", + [ + ({"a": 1, "b": {"c": 2, "d": {"e": 3}}}, {"a": 1, "b/c": 2, "b/d/e": 3}), + ({"a": {"b": {"c": 1}}}, {"a/b/c": 1}), + ], + ) + def test_flatten_dict_basic(input_dict, expected_output): + result = BaseLogger.prepare_dict_for_logging(input_dict) + assert result == expected_output + + @staticmethod + @pytest.mark.parametrize( + "input_dict, delimiter, expected_output", + [ + ({"a": {"b": {"c": 1}}}, "|", {"a|b|c": 1}), + ({"a": {"b": {"c": 1}}}, ".", {"a.b.c": 1}), + ], + ) + def test_flatten_dict_custom_delimiter(input_dict, delimiter, expected_output): + result = BaseLogger.prepare_dict_for_logging(input_dict, delimiter=delimiter) + assert result == expected_output + + @staticmethod + @pytest.mark.parametrize( + "input_dict, exclude_arrays, expected_output", + [ + ( + {"a": np.array([1, 2, 3]), "b": {"c": np.array([4, 5, 6])}}, + False, + {"a": np.array([1, 2, 3]), "b/c": np.array([4, 5, 6])}, + ), + ({"a": np.array([1, 2, 3]), "b": {"c": np.array([4, 5, 6])}}, True, {}), + ], + ) + def test_flatten_dict_exclude_arrays(input_dict, exclude_arrays, expected_output): + result = BaseLogger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays) + assert result.keys() == expected_output.keys() + for val1, val2 in zip(result.values(), expected_output.values(), strict=True): + assert np.all(val1 == val2) + + @staticmethod + @pytest.mark.parametrize( + "input_dict, expected_output", + [ + ({"a": (1,), "b": {"c": "2", "d": {"e": 3}}}, {"b/d/e": 3}), + ], + ) + def 
test_flatten_dict_invalid_values_filtered_out(input_dict, expected_output): + result = BaseLogger.prepare_dict_for_logging(input_dict) + assert result == expected_output diff --git a/test/base/test_stats.py b/test/base/test_stats.py new file mode 100644 index 000000000..b9ec67a12 --- /dev/null +++ b/test/base/test_stats.py @@ -0,0 +1,40 @@ +import pytest + +from tianshou.policy.base import TrainingStats, TrainingStatsWrapper + + +class DummyTrainingStatsWrapper(TrainingStatsWrapper): + def __init__(self, wrapped_stats: TrainingStats, *, dummy_field: int): + self.dummy_field = dummy_field + super().__init__(wrapped_stats) + + +class TestStats: + @staticmethod + def test_training_stats_wrapper(): + train_stats = TrainingStats(train_time=1.0) + train_stats.loss_field = 12 + + wrapped_train_stats = DummyTrainingStatsWrapper(train_stats, dummy_field=42) + + # basic readout + assert wrapped_train_stats.train_time == 1.0 + assert wrapped_train_stats.loss_field == 12 + + # mutation of TrainingStats fields + wrapped_train_stats.train_time = 2.0 + wrapped_train_stats.smoothed_loss["foo"] = 50 + assert wrapped_train_stats.train_time == 2.0 + assert wrapped_train_stats.smoothed_loss["foo"] == 50 + + # loss stats dict + assert wrapped_train_stats.get_loss_stats_dict() == {"loss_field": 12, "dummy_field": 42} + + # new fields can't be added + with pytest.raises(AttributeError): + wrapped_train_stats.new_loss_field = 90 + + # existing fields, wrapped and not-wrapped, can be mutated + wrapped_train_stats.loss_field = 13 + wrapped_train_stats.dummy_field = 43 + assert wrapped_train_stats.wrapped_stats.loss_field == wrapped_train_stats.loss_field == 13 diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 360a80393..30e0e33de 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -126,7 +126,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -135,8 +135,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index ab02e4cde..edd51f274 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -144,7 +144,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -153,8 +153,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 9fe6f1895..1327ede6e 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -174,22 +174,21 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): save_checkpoint_fn=save_checkpoint_fn, ) - for epoch, epoch_stat, info in trainer: - 
print(f"Epoch: {epoch}") + for epoch_stat in trainer: + print(f"Epoch: {epoch_stat.epoch}") print(epoch_stat) - print(info) + # print(info) - assert stop_fn(info["best_reward"]) + assert stop_fn(epoch_stat.info_stat.best_reward) if __name__ == "__main__": - pprint.pprint(info) + pprint.pprint(epoch_stat) # Let's watch its performance! env = gym.make(args.task) policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_ppo_resume(args=get_args()): diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 87588e5b4..2258d6c20 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -153,7 +153,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -162,8 +162,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 13b44b5e9..9c18904d6 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -1,12 +1,14 @@ import argparse import os +import gymnasium as gym import numpy as np import pytest import torch from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import ImitationPolicy, SACPolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -57,8 +59,11 @@ def get_args(): @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_sac_with_il(args=get_args()): # if you want to use python vector env, please refer to other test scripts - train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) - test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) + # train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) + # test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) + env = gym.make(args.task) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n args.max_action = env.action_space.high[0] @@ -146,7 +151,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) # here we define an imitation collector with a trivial policy policy.eval() @@ -172,7 +177,8 @@ def stop_fn(mean_rewards): ) il_test_collector = Collector( il_policy, - envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed), + # 
envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed), + gym.make(args.task), ) train_collector.reset() result = OffpolicyTrainer( @@ -188,7 +194,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 76659d0e0..ecad08bb0 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -143,22 +143,21 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ) - for epoch, epoch_stat, info in trainer: - print(f"Epoch: {epoch}") - print(epoch_stat) - print(info) + for epoch_stat in trainer: + print(f"Epoch: {epoch_stat.epoch}") + pprint.pprint(epoch_stat) + # print(info) - assert stop_fn(info["best_reward"]) + assert stop_fn(epoch_stat.info_stat.best_reward) if __name__ == "__main__": - pprint.pprint(info) + pprint.pprint(epoch_stat.info_stat) # Let's watch its performance! env = gym.make(args.task) policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index ee3f2bb5f..300c96db2 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -148,7 +148,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -157,8 +157,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index e833cc36d..97205c5e7 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -135,7 +135,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -144,8 +144,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") policy.eval() # here we define an imitation collector with a trivial policy @@ -173,7 +172,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -182,8 +181,7 @@ def stop_fn(mean_rewards): il_policy.eval() collector = Collector(il_policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: 
{result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 469c04cf4..f63224331 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -136,7 +136,7 @@ def stop_fn(mean_rewards): stop_fn=stop_fn, ).run() - # assert stop_fn(result["best_reward"]) + # assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! @@ -145,8 +145,9 @@ def stop_fn(mean_rewards): test_envs.seed(args.seed) test_collector.reset() collector_result = test_collector.collect(n_episode=args.test_num, render=args.render) - rews, lens = collector_result["rews"], collector_result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print( + f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}", + ) if __name__ == "__main__": diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index edb903240..e4406df18 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -186,7 +186,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -196,8 +196,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_c51_resume(args=get_args()): diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 6f23cfbad..751f849ff 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -139,7 +139,7 @@ def test_fn(epoch, env_step): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -149,8 +149,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_pdqn(args=get_args()): diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 68b3d53d3..5ca79fd30 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -122,7 +122,7 @@ def test_fn(epoch, env_step): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -131,8 +131,7 @@ def test_fn(epoch, env_step): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 55a66e450..84dd207e8 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ 
-156,7 +156,7 @@ def test_fn(epoch, env_step): logger=logger, update_per_step=args.update_per_step, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -166,8 +166,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_pfqf(args=get_args()): diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 504ada3a4..23c35b9a8 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -152,7 +152,7 @@ def test_fn(epoch, env_step): logger=logger, update_per_step=args.update_per_step, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -162,8 +162,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_piqn(args=get_args()): diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 310f0a135..6d9873ab3 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -117,7 +117,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -126,8 +126,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 53d2b29f5..e66dc23a9 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -140,7 +140,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -149,8 +149,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index e68211e96..6d54d59cf 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -143,7 +143,7 @@ def test_fn(epoch, env_step): logger=logger, update_per_step=args.update_per_step, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -153,8 +153,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - 
rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_pqrdqn(args=get_args()): diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index c36d7a136..c7e51a36e 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -205,7 +205,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -215,8 +215,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_rainbow_resume(args=get_args()): diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 49f4b9648..4831f12de 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -129,7 +129,7 @@ def stop_fn(mean_rewards): update_per_step=args.update_per_step, test_in_train=False, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -138,8 +138,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index e8c4a80d2..ecfb4d2b6 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -183,7 +183,7 @@ def test_fn(epoch, env_step): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -193,8 +193,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index e0858b14f..27adc12d8 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -179,7 +179,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -188,8 +188,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/modelbased/test_psrl.py 
b/test/modelbased/test_psrl.py index 45ddf1647..d2fbf7358 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -123,10 +123,9 @@ def stop_fn(mean_rewards): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.rew_mean}, length: {result.len_mean}") elif env.spec.reward_threshold: - assert result["best_reward"] >= env.spec.reward_threshold + assert result.best_reward >= env.spec.reward_threshold if __name__ == "__main__": diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index c42524cb6..cf4a5bd24 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -147,7 +147,7 @@ def test_fn(epoch, env_step): logger=logger, update_per_step=args.update_per_step, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) # save buffer in pickle format, for imitation learning unittest buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) @@ -159,5 +159,5 @@ def test_fn(epoch, env_step): else: with open(args.save_buffer_name, "wb") as f: pickle.dump(buf, f) - print(result["rews"].mean()) + print(result.returns_stat.mean) return buf diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index f08510a54..064369cd2 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -150,8 +150,7 @@ def stop_fn(mean_rewards): ).run() train_collector.reset() result = train_collector.collect(n_step=args.buffer_size) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if args.save_buffer_name.endswith(".hdf5"): buffer.save_hdf5(args.save_buffer_name) else: diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index da6e211ed..c82eadcd7 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -198,7 +198,7 @@ def watch(): logger=logger, show_progress=args.show_progress, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) # Let's watch its performance! if __name__ == "__main__": @@ -207,8 +207,7 @@ def watch(): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 1a7c86ba8..693eecc49 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -188,22 +188,23 @@ def stop_fn(mean_rewards): logger=logger, ) - for epoch, epoch_stat, info in trainer: - print(f"Epoch: {epoch}") - print(epoch_stat) - print(info) + for epoch_stat in trainer: + print(f"Epoch: {epoch_stat.epoch}") + pprint.pprint(epoch_stat) + # print(info) - assert stop_fn(info["best_reward"]) + assert stop_fn(epoch_stat.info_stat.best_reward) # Let's watch its performance! 
if __name__ == "__main__": - pprint.pprint(info) + pprint.pprint(epoch_stat.info_stat) env = gym.make(args.task) policy.eval() collector = Collector(policy, env) collector_result = collector.collect(n_episode=1, render=args.render) - rews, lens = collector_result["rews"], collector_result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print( + f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}", + ) if __name__ == "__main__": diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 93f342ab0..4f4208b5e 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -157,7 +157,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -167,8 +167,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_discrete_bcq_resume(args=get_args()): diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index b795d4eac..b7e4cc567 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -119,7 +119,7 @@ def stop_fn(mean_rewards): logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -129,8 +129,7 @@ def stop_fn(mean_rewards): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 83fd79c69..ea880a530 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -122,7 +122,7 @@ def stop_fn(mean_rewards): logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -131,8 +131,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 4257933af..650aabb8a 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -214,7 +214,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -223,8 +223,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): policy.eval() collector = Collector(policy, env) result = 
collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 9cd1f6ea4..af915844d 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -181,22 +181,23 @@ def stop_fn(mean_rewards): logger=logger, ) - for epoch, epoch_stat, info in trainer: - print(f"Epoch: {epoch}") + for epoch_stat in trainer: + print(f"Epoch: {epoch_stat.epoch}") print(epoch_stat) - print(info) + # print(info) - assert stop_fn(info["best_reward"]) + assert stop_fn(epoch_stat.info_stat.best_reward) # Let's watch its performance! if __name__ == "__main__": - pprint.pprint(info) + pprint.pprint(epoch_stat.info_stat) env = gym.make(args.task) policy.eval() collector = Collector(policy, env) collector_result = collector.collect(n_episode=1, render=args.render) - rews, lens = collector_result["rews"], collector_result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print( + f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}", + ) if __name__ == "__main__": diff --git a/test/pettingzoo/test_pistonball.py b/test/pettingzoo/test_pistonball.py index 4a6c59655..5043a96ca 100644 --- a/test/pettingzoo/test_pistonball.py +++ b/test/pettingzoo/test_pistonball.py @@ -9,7 +9,7 @@ def test_piston_ball(args=get_args()): return result, agent = train_agent(args) - # assert result["best_reward"] >= args.win_rate + # assert result.best_reward >= args.win_rate if __name__ == "__main__": pprint.pprint(result) diff --git a/test/pettingzoo/test_pistonball_continuous.py b/test/pettingzoo/test_pistonball_continuous.py index afb0d5448..f85884d53 100644 --- a/test/pettingzoo/test_pistonball_continuous.py +++ b/test/pettingzoo/test_pistonball_continuous.py @@ -11,7 +11,7 @@ def test_piston_ball_continuous(args=get_args()): return result, agent = train_agent(args) - # assert result["best_reward"] >= 30.0 + # assert result.best_reward >= 30.0 if __name__ == "__main__": pprint.pprint(result) diff --git a/test/pettingzoo/test_tic_tac_toe.py b/test/pettingzoo/test_tic_tac_toe.py index 524cdb92a..f689283f2 100644 --- a/test/pettingzoo/test_tic_tac_toe.py +++ b/test/pettingzoo/test_tic_tac_toe.py @@ -9,7 +9,7 @@ def test_tic_tac_toe(args=get_args()): return result, agent = train_agent(args) - assert result["best_reward"] >= args.win_rate + assert result.best_reward >= args.win_rate if __name__ == "__main__": pprint.pprint(result) diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 7a86ce857..623079890 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -18,7 +18,13 @@ VectorReplayBuffer, ) from tianshou.data.buffer.cached import CachedReplayBuffer -from tianshou.data.collector import Collector, AsyncCollector +from tianshou.data.stats import ( + EpochStats, + InfoStats, + SequenceSummaryStats, + TimingStats, +) +from tianshou.data.collector import Collector, AsyncCollector, CollectStats, CollectStatsBase __all__ = [ "Batch", @@ -37,5 +43,11 @@ "HERVectorReplayBuffer", "CachedReplayBuffer", "Collector", + "CollectStats", + "CollectStatsBase", "AsyncCollector", + "EpochStats", + "InfoStats", + "SequenceSummaryStats", + "TimingStats", ] diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index da0442078..f188f3ca0 100644 --- 
a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1,6 +1,7 @@ import time import warnings from collections.abc import Callable +from dataclasses import dataclass from typing import Any, cast import gymnasium as gym @@ -13,6 +14,7 @@ PrioritizedReplayBuffer, ReplayBuffer, ReplayBufferManager, + SequenceSummaryStats, VectorReplayBuffer, to_numpy, ) @@ -22,6 +24,34 @@ from tianshou.policy import BasePolicy +@dataclass(kw_only=True) +class CollectStatsBase: + """The most basic stats, often used for offline learning.""" + + n_collected_episodes: int = 0 + """The number of collected episodes.""" + n_collected_steps: int = 0 + """The number of collected steps.""" + + +@dataclass(kw_only=True) +class CollectStats(CollectStatsBase): + """A data structure for storing the statistics of rollouts.""" + + collect_time: float = 0.0 + """The time for collecting transitions.""" + collect_speed: float = 0.0 + """The speed of collecting (env_step per second).""" + returns: np.ndarray + """The collected episode returns.""" + returns_stat: SequenceSummaryStats | None # can be None if no episode ends during collect step + """Stats of the collected returns.""" + lens: np.ndarray + """The collected episode lengths.""" + lens_stat: SequenceSummaryStats | None # can be None if no episode ends during collect step + """Stats of the collected episode lengths.""" + + class Collector: """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. @@ -191,7 +221,7 @@ def collect( render: float | None = None, no_grad: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: + ) -> CollectStats: """Collect a specified number of step or episode. To ensure unbiased sampling result with n_episode option, this function will @@ -214,17 +244,7 @@ def collect( One and only one collection number specification is permitted, either ``n_step`` or ``n_episode``. - :return: A dict including the following keys - - * ``n/ep`` collected number of episodes. - * ``n/st`` collected number of steps. - * ``rews`` array of episode reward over collected episodes. - * ``lens`` array of episode length over collected episodes. - * ``idxs`` array of episode start index in buffer over collected episodes. - * ``rew`` mean of episodic rewards. - * ``len`` mean of episodic lengths. - * ``rew_std`` standard error of episodic rewards. - * ``len_std`` standard error of episodic lengths. + :return: A dataclass object """ assert not self.env.is_async, "Please use AsyncCollector if using async venv." if n_step is not None: @@ -253,9 +273,9 @@ def collect( step_count = 0 episode_count = 0 - episode_rews = [] - episode_lens = [] - episode_start_indices = [] + episode_returns: list[float] = [] + episode_lens: list[int] = [] + episode_start_indices: list[int] = [] while True: assert len(self.data) == len(ready_env_ids) @@ -334,9 +354,9 @@ def collect( env_ind_local = np.where(done)[0] env_ind_global = ready_env_ids[env_ind_local] episode_count += len(env_ind_local) - episode_lens.append(ep_len[env_ind_local]) - episode_rews.append(ep_rew[env_ind_local]) - episode_start_indices.append(ep_idx[env_ind_local]) + episode_lens.extend(ep_len[env_ind_local]) + episode_returns.extend(ep_rew[env_ind_local]) + episode_start_indices.extend(ep_idx[env_ind_local]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. 
self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs) @@ -361,7 +381,8 @@ def collect( # generate statistics self.collect_step += step_count self.collect_episode += episode_count - self.collect_time += max(time.time() - start_time, 1e-9) + collect_time = max(time.time() - start_time, 1e-9) + self.collect_time += collect_time if n_episode: data = Batch( @@ -378,27 +399,20 @@ def collect( self.data = cast(RolloutBatchProtocol, data) self.reset_env() - if episode_count > 0: - rews, lens, idxs = list( - map(np.concatenate, [episode_rews, episode_lens, episode_start_indices]), - ) - rew_mean, rew_std = rews.mean(), rews.std() - len_mean, len_std = lens.mean(), lens.std() - else: - rews, lens, idxs = np.array([]), np.array([], int), np.array([], int) - rew_mean = rew_std = len_mean = len_std = 0 - - return { - "n/ep": episode_count, - "n/st": step_count, - "rews": rews, - "lens": lens, - "idxs": idxs, - "rew": rew_mean, - "len": len_mean, - "rew_std": rew_std, - "len_std": len_std, - } + return CollectStats( + n_collected_episodes=episode_count, + n_collected_steps=step_count, + collect_time=collect_time, + collect_speed=step_count / collect_time, + returns=np.array(episode_returns), + returns_stat=SequenceSummaryStats.from_sequence(episode_returns) + if len(episode_returns) > 0 + else None, + lens=np.array(episode_lens, int), + lens_stat=SequenceSummaryStats.from_sequence(episode_lens) + if len(episode_lens) > 0 + else None, + ) class AsyncCollector(Collector): @@ -438,7 +452,7 @@ def collect( render: float | None = None, no_grad: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: + ) -> CollectStats: """Collect a specified number of step or episode with async env setting. This function doesn't collect exactly n_step or n_episode number of @@ -461,17 +475,7 @@ def collect( One and only one collection number specification is permitted, either ``n_step`` or ``n_episode``. - :return: A dict including the following keys - - * ``n/ep`` collected number of episodes. - * ``n/st`` collected number of steps. - * ``rews`` array of episode reward over collected episodes. - * ``lens`` array of episode length over collected episodes. - * ``idxs`` array of episode start index in buffer over collected episodes. - * ``rew`` mean of episodic rewards. - * ``len`` mean of episodic lengths. - * ``rew_std`` standard error of episodic rewards. - * ``len_std`` standard error of episodic lengths. + :return: A dataclass object """ # collect at least n_step or n_episode if n_step is not None: @@ -494,9 +498,9 @@ def collect( step_count = 0 episode_count = 0 - episode_rews = [] - episode_lens = [] - episode_start_indices = [] + episode_returns: list[float] = [] + episode_lens: list[int] = [] + episode_start_indices: list[int] = [] while True: whole_data = self.data @@ -602,9 +606,9 @@ def collect( env_ind_local = np.where(done)[0] env_ind_global = ready_env_ids[env_ind_local] episode_count += len(env_ind_local) - episode_lens.append(ep_len[env_ind_local]) - episode_rews.append(ep_rew[env_ind_local]) - episode_start_indices.append(ep_idx[env_ind_local]) + episode_lens.extend(ep_len[env_ind_local]) + episode_returns.extend(ep_rew[env_ind_local]) + episode_start_indices.extend(ep_idx[env_ind_local]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. 
self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs) @@ -633,26 +637,20 @@ def collect( # generate statistics self.collect_step += step_count self.collect_episode += episode_count - self.collect_time += max(time.time() - start_time, 1e-9) - - if episode_count > 0: - rews, lens, idxs = list( - map(np.concatenate, [episode_rews, episode_lens, episode_start_indices]), - ) - rew_mean, rew_std = rews.mean(), rews.std() - len_mean, len_std = lens.mean(), lens.std() - else: - rews, lens, idxs = np.array([]), np.array([], int), np.array([], int) - rew_mean = rew_std = len_mean = len_std = 0 - - return { - "n/ep": episode_count, - "n/st": step_count, - "rews": rews, - "lens": lens, - "idxs": idxs, - "rew": rew_mean, - "len": len_mean, - "rew_std": rew_std, - "len_std": len_std, - } + collect_time = max(time.time() - start_time, 1e-9) + self.collect_time += collect_time + + return CollectStats( + n_collected_episodes=episode_count, + n_collected_steps=step_count, + collect_time=collect_time, + collect_speed=step_count / collect_time, + returns=np.array(episode_returns), + returns_stat=SequenceSummaryStats.from_sequence(episode_returns) + if len(episode_returns) > 0 + else None, + lens=np.array(episode_lens, int), + lens_stat=SequenceSummaryStats.from_sequence(episode_lens) + if len(episode_lens) > 0 + else None, + ) diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py new file mode 100644 index 000000000..980b3f84d --- /dev/null +++ b/tianshou/data/stats.py @@ -0,0 +1,86 @@ +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import numpy as np + +if TYPE_CHECKING: + from tianshou.data import CollectStats, CollectStatsBase + from tianshou.policy.base import TrainingStats + + +@dataclass(kw_only=True) +class SequenceSummaryStats: + """A data structure for storing the statistics of a sequence.""" + + mean: float + std: float + max: float + min: float + + @classmethod + def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "SequenceSummaryStats": + return cls( + mean=float(np.mean(sequence)), + std=float(np.std(sequence)), + max=float(np.max(sequence)), + min=float(np.min(sequence)), + ) + + +@dataclass(kw_only=True) +class TimingStats: + """A data structure for storing timing statistics.""" + + total_time: float = 0.0 + """The total time elapsed.""" + train_time: float = 0.0 + """The total time elapsed for training (collecting samples plus model update).""" + train_time_collect: float = 0.0 + """The total time elapsed for collecting training transitions.""" + train_time_update: float = 0.0 + """The total time elapsed for updating models.""" + test_time: float = 0.0 + """The total time elapsed for testing models.""" + update_speed: float = 0.0 + """The speed of updating (env_step per second).""" + + +@dataclass(kw_only=True) +class InfoStats: + """A data structure for storing information about the learning process.""" + + gradient_step: int + """The total gradient step.""" + best_reward: float + """The best reward over the test results.""" + best_reward_std: float + """Standard deviation of the best reward over the test results.""" + train_step: int + """The total collected step of training collector.""" + train_episode: int + """The total collected episode of training collector.""" + test_step: int + """The total collected step of test collector.""" + test_episode: int + """The total collected episode of test collector.""" + + timing: TimingStats + """The timing statistics.""" + + 
+@dataclass(kw_only=True) +class EpochStats: + """A data structure for storing epoch statistics.""" + + epoch: int + """The current epoch.""" + + train_collect_stat: "CollectStatsBase" + """The statistics of the last call to the training collector.""" + test_collect_stat: Optional["CollectStats"] + """The statistics of the last call to the test collector.""" + training_stat: "TrainingStats" + """The statistics of the last model update step.""" + info_stat: InfoStats + """The information of the collector.""" diff --git a/tianshou/data/types.py b/tianshou/data/types.py index a63a9d1c0..eb65f75d3 100644 --- a/tianshou/data/types.py +++ b/tianshou/data/types.py @@ -1,3 +1,5 @@ +from typing import Protocol + import numpy as np import torch @@ -5,7 +7,7 @@ from tianshou.data.batch import BatchProtocol, arr_type -class ObsBatchProtocol(BatchProtocol): +class ObsBatchProtocol(BatchProtocol, Protocol): """Observations of an environment that a policy can turn into actions. Typically used inside a policy's forward @@ -15,7 +17,7 @@ class ObsBatchProtocol(BatchProtocol): info: arr_type -class RolloutBatchProtocol(ObsBatchProtocol): +class RolloutBatchProtocol(ObsBatchProtocol, Protocol): """Typically, the outcome of sampling from a replay buffer.""" obs_next: arr_type | BatchProtocol @@ -25,52 +27,52 @@ class RolloutBatchProtocol(ObsBatchProtocol): truncated: arr_type -class BatchWithReturnsProtocol(RolloutBatchProtocol): +class BatchWithReturnsProtocol(RolloutBatchProtocol, Protocol): """With added returns, usually computed with GAE.""" returns: arr_type -class PrioBatchProtocol(RolloutBatchProtocol): +class PrioBatchProtocol(RolloutBatchProtocol, Protocol): """Contains weights that can be used for prioritized replay.""" weight: np.ndarray | torch.Tensor -class RecurrentStateBatch(BatchProtocol): +class RecurrentStateBatch(BatchProtocol, Protocol): """Used by RNNs in policies, contains `hidden` and `cell` fields.""" hidden: torch.Tensor cell: torch.Tensor -class ActBatchProtocol(BatchProtocol): +class ActBatchProtocol(BatchProtocol, Protocol): """Simplest batch, just containing the action. Useful e.g., for random policy.""" act: arr_type -class ActStateBatchProtocol(ActBatchProtocol): +class ActStateBatchProtocol(ActBatchProtocol, Protocol): """Contains action and state (which can be None), useful for policies that can support RNNs.""" state: dict | BatchProtocol | np.ndarray | None -class ModelOutputBatchProtocol(ActStateBatchProtocol): +class ModelOutputBatchProtocol(ActStateBatchProtocol, Protocol): """In addition to state and action, contains model output: (logits).""" logits: torch.Tensor state: dict | BatchProtocol | np.ndarray | None -class FQFBatchProtocol(ModelOutputBatchProtocol): +class FQFBatchProtocol(ModelOutputBatchProtocol, Protocol): """Model outputs, fractions and quantiles_tau - specific to the FQF model.""" fractions: torch.Tensor quantiles_tau: torch.Tensor -class BatchWithAdvantagesProtocol(BatchWithReturnsProtocol): +class BatchWithAdvantagesProtocol(BatchWithReturnsProtocol, Protocol): """Contains estimated advantages and values. Returns are usually computed from GAE of advantages by adding the value. @@ -80,7 +82,7 @@ class BatchWithAdvantagesProtocol(BatchWithReturnsProtocol): v_s: torch.Tensor -class DistBatchProtocol(ModelOutputBatchProtocol): +class DistBatchProtocol(ModelOutputBatchProtocol, Protocol): """Contains dist instances for actions (created by dist_fn). Usually categorical or normal. 
@@ -89,13 +91,13 @@ class DistBatchProtocol(ModelOutputBatchProtocol): dist: torch.distributions.Distribution -class DistLogProbBatchProtocol(DistBatchProtocol): +class DistLogProbBatchProtocol(DistBatchProtocol, Protocol): """Contains dist objects that can be sampled from and log_prob of taken action.""" log_prob: torch.Tensor -class LogpOldProtocol(BatchWithAdvantagesProtocol): +class LogpOldProtocol(BatchWithAdvantagesProtocol, Protocol): """Contains logp_old, often needed for importance weights, in particular in PPO. Builds on batches that contain advantages and values. @@ -104,7 +106,7 @@ class LogpOldProtocol(BatchWithAdvantagesProtocol): logp_old: torch.Tensor -class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol): +class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol, Protocol): """Contains taus for algorithms using quantile regression. See e.g. https://arxiv.org/abs/1806.06923 @@ -113,7 +115,7 @@ class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol): taus: torch.Tensor -class ImitationBatchProtocol(ActBatchProtocol): +class ImitationBatchProtocol(ActBatchProtocol, Protocol): """Similar to other batches, but contains imitation_logits and q_value fields.""" state: dict | Batch | np.ndarray | None diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 361c3a94c..3989d3583 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -4,12 +4,12 @@ from collections.abc import Sequence from dataclasses import dataclass from pprint import pformat -from typing import Any, Self +from typing import Self import numpy as np import torch -from tianshou.data import Collector +from tianshou.data import Collector, InfoStats from tianshou.highlevel.agent import ( A2CAgentFactory, AgentFactory, @@ -121,8 +121,8 @@ class ExperimentResult: world: World """contains all the essential instances of the experiment""" - trainer_result: dict[str, Any] | None - """dictionary of results as returned by the trainer (if any)""" + trainer_result: InfoStats | None + """dataclass of results as returned by the trainer (if any)""" class Experiment(ToStringMixin): @@ -280,7 +280,7 @@ def run( # train policy log.info("Starting training") - trainer_result: dict[str, Any] | None = None + trainer_result: InfoStats | None = None if self.config.train: trainer = self.agent_factory.create_trainer(world, policy_persistence) world.trainer = trainer @@ -309,7 +309,9 @@ def _watch_agent( policy.eval() test_collector.reset() result = test_collector.collect(n_episode=num_episodes, render=render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + assert result.returns_stat is not None # for mypy + assert result.lens_stat is not None # for mypy + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") class ExperimentBuilder: diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index c8fa45e8e..5e6967ad7 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -1,7 +1,7 @@ """Policy package.""" # isort:skip_file -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import BasePolicy, TrainingStats from tianshou.policy.random import RandomPolicy from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.bdq import BranchingDQNPolicy @@ -63,4 +63,5 @@ "PSRLPolicy", "ICMPolicy", "MultiAgentPolicyManager", + "TrainingStats", ] diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 
3de0d0fda..73dd45572 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -1,16 +1,19 @@ import logging +import time from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Any, Literal, TypeAlias, cast +from dataclasses import dataclass, field +from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast import gymnasium as gym import numpy as np import torch from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete from numba import njit +from overrides import override from torch import nn -from tianshou.data import ReplayBuffer, to_numpy, to_torch_as +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch_as from tianshou.data.batch import Batch, BatchProtocol, arr_type from tianshou.data.buffer.base import TBuffer from tianshou.data.types import ( @@ -26,7 +29,106 @@ TLearningRateScheduler: TypeAlias = torch.optim.lr_scheduler.LRScheduler | MultipleLRSchedulers -class BasePolicy(ABC, nn.Module): +@dataclass(kw_only=True) +class TrainingStats: + _non_loss_fields = ("train_time", "smoothed_loss") + + train_time: float = 0.0 + """The time for learning models.""" + + # TODO: modified in the trainer but not used anywhere else. Should be refactored. + smoothed_loss: dict = field(default_factory=dict) + """The smoothed loss statistics of the policy learn step.""" + + # Mainly so that we can override this in the TrainingStatsWrapper + def _get_self_dict(self) -> dict[str, Any]: + return self.__dict__ + + def get_loss_stats_dict(self) -> dict[str, float]: + """Return loss statistics as a dict for logging. + + Returns a dict with all fields except train_time and smoothed_loss. Moreover, fields with value None are excluded, + and instances of SequenceSummaryStats are replaced by their mean. + """ + result = {} + for k, v in self._get_self_dict().items(): + if k.startswith("_"): + logger.debug(f"Skipping {k=} as it starts with an underscore.") + continue + if k in self._non_loss_fields or v is None: + continue + if isinstance(v, SequenceSummaryStats): + result[k] = v.mean + else: + result[k] = v + + return result + + +class TrainingStatsWrapper(TrainingStats): + _setattr_frozen = False + _training_stats_public_fields = TrainingStats.__dataclass_fields__.keys() + + def __init__(self, wrapped_stats: TrainingStats) -> None: + """In this particular case, super().__init__() should be called LAST in the subclass init.""" + self._wrapped_stats = wrapped_stats + + # HACK: special sauce for the existing attributes of the base TrainingStats class + # for some reason, delattr doesn't work here, so we need to delegate their handling + # to the wrapped stats object by always keeping the value there and in self in sync + # see also __setattr__ + for k in self._training_stats_public_fields: + super().__setattr__(k, getattr(self._wrapped_stats, k)) + + self._setattr_frozen = True + + @override + def _get_self_dict(self) -> dict[str, Any]: + return {**self._wrapped_stats._get_self_dict(), **self.__dict__} + + @property + def wrapped_stats(self) -> TrainingStats: + return self._wrapped_stats + + def __getattr__(self, name: str) -> Any: + return getattr(self._wrapped_stats, name) + + def __setattr__(self, name: str, value: Any) -> None: + """Setattr logic for wrapper of a dataclass with default values. + + 1. If name exists directly in self, set it there. + 2. If it exists in self._wrapped_stats, set it there instead. + 3. 
Special case: if name is in the base TrainingStats class, keep it in sync between self and the _wrapped_stats. + 4. If name doesn't exist in either and attribute setting is frozen, raise an AttributeError. + """ + # HACK: special sauce for the existing attributes of the base TrainingStats class, see init + # Need to keep them in sync with the wrapped stats object + if name in self._training_stats_public_fields: + setattr(self._wrapped_stats, name, value) + super().__setattr__(name, value) + return + + if not self._setattr_frozen: + super().__setattr__(name, value) + return + + if not hasattr(self, name): + raise AttributeError( + f"Setting new attributes on StatsWrappers outside of init is not allowed. " + f"Tried to set {name=}, {value=} on {self.__class__.__name__}. \n" + f"NOTE: you may get this error if you call super().__init__() in your subclass init too early! " + f"The call to super().__init__() should be the last call in your subclass init.", + ) + if hasattr(self._wrapped_stats, name): + setattr(self._wrapped_stats, name, value) + else: + super().__setattr__(name, value) + + +TTrainingStats = TypeVar("TTrainingStats", bound=TrainingStats) + + +class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): """The base class for any RL policy. Tianshou aims to modularize RL algorithms. It comes into several classes of @@ -321,10 +423,10 @@ def process_fn( return batch @abstractmethod - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, Any]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTrainingStats: """Update policy with a given batch of data. - :return: A dict, including the data needed to be logged (e.g., loss). + :return: A dataclass object, including the data needed to be logged (e.g., loss). .. note:: @@ -372,13 +474,15 @@ def update( sample_size: int | None, buffer: ReplayBuffer | None, **kwargs: Any, - ) -> dict[str, Any]: + ) -> TTrainingStats: """Update the policy network and replay buffer. It includes 3 function steps: process_fn, learn, and post_process_fn. In addition, this function will change the value of ``self.updating``: it will be False before this function and will be True when executing :meth:`update`. - Please refer to :ref:`policy_state` for more detailed explanation. + Please refer to :ref:`policy_state` for more detailed explanation. The return + value of learn is augmented with the training time within update, while smoothed + loss values are computed in the trainer. :param sample_size: 0 means it will extract all the data from the buffer, otherwise it will sample a batch with given sample_size. None also @@ -386,20 +490,24 @@ def update( first. TODO: remove the option for 0? :param buffer: the corresponding replay buffer. - :return: A dict, including the data needed to be logged (e.g., loss) from + :return: A dataclass object containing the data needed to be logged (e.g., loss) from ``policy.learn()``. """ + # TODO: when does this happen? 
+ # -> this never happens in practice, as update is either called with a collector buffer or guarded by an assert beforehand if buffer is None: - return {} + return TrainingStats() # type: ignore[return-value] + start_time = time.time() batch, indices = buffer.sample(sample_size) self.updating = True batch = self.process_fn(batch, buffer, indices) - result = self.learn(batch, **kwargs) + training_stat = self.learn(batch, **kwargs) self.post_process_fn(batch, buffer, indices) if self.lr_scheduler is not None: self.lr_scheduler.step() self.updating = False - return result + training_stat.train_time = time.time() - start_time + return training_stat @staticmethod def value_mask(buffer: ReplayBuffer, indices: np.ndarray) -> np.ndarray: diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 5acacf1dd..1daa9ae71 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -1,4 +1,5 @@ -from typing import Any, Literal, cast +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -13,10 +14,18 @@ RolloutBatchProtocol, ) from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats -class ImitationPolicy(BasePolicy): +@dataclass(kw_only=True) +class ImitationTrainingStats(TrainingStats): + loss: float = 0.0 + + +TImitationTrainingStats = TypeVar("TImitationTrainingStats", bound=ImitationTrainingStats) + + +class ImitationPolicy(BasePolicy[TImitationTrainingStats], Generic[TImitationTrainingStats]): """Implementation of vanilla imitation learning. :param actor: a model following the rules in @@ -68,7 +77,12 @@ def forward( result = Batch(logits=logits, act=act, state=hidden) return cast(ModelOutputBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *ags: Any, **kwargs: Any) -> dict[str, float]: + def learn( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TImitationTrainingStats: self.optim.zero_grad() if self.action_type == "continuous": # regression act = self(batch).act @@ -80,4 +94,5 @@ def learn(self, batch: RolloutBatchProtocol, *ags: Any, **kwargs: Any) -> dict[s loss = F.nll_loss(act, act_target) loss.backward() self.optim.step() - return {"loss": loss.item()} + + return ImitationTrainingStats(loss=loss.item()) # type: ignore diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index 14f9056d8..dee1a80a3 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -1,5 +1,6 @@ import copy -from typing import Any, Literal, Self, cast +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -10,12 +11,23 @@ from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.utils.net.continuous import VAE from tianshou.utils.optim import clone_optimizer -class BCQPolicy(BasePolicy): +@dataclass(kw_only=True) +class BCQTrainingStats(TrainingStats): + actor_loss: float + critic1_loss: float + critic2_loss: float + vae_loss: float + + +TBCQTrainingStats = TypeVar("TBCQTrainingStats", bound=BCQTrainingStats) + + +class 
BCQPolicy(BasePolicy[TBCQTrainingStats], Generic[TBCQTrainingStats]): """Implementation of BCQ algorithm. arXiv:1812.02900. :param actor_perturbation: the actor perturbation. `(s, a -> perturbed a)` @@ -142,7 +154,7 @@ def sync_weight(self) -> None: self.soft_update(self.critic2_target, self.critic2, self.tau) self.soft_update(self.actor_perturbation_target, self.actor_perturbation, self.tau) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBCQTrainingStats: # batch: obs, act, rew, done, obs_next. (numpy array) # (batch_size, state_dim) batch: Batch = to_torch(batch, dtype=torch.float, device=self.device) @@ -213,9 +225,9 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ # update target network self.sync_weight() - return { - "loss/actor": actor_loss.item(), - "loss/critic1": critic1_loss.item(), - "loss/critic2": critic2_loss.item(), - "loss/vae": vae_loss.item(), - } + return BCQTrainingStats( # type: ignore + actor_loss=actor_loss.item(), + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + vae_loss=vae_loss.item(), + ) diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index c3bbf6b1e..1ce6d83d4 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -1,4 +1,5 @@ -from typing import Any, Literal, Self, cast +from dataclasses import dataclass +from typing import Any, Literal, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -13,10 +14,23 @@ from tianshou.exploration import BaseNoise from tianshou.policy import SACPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.sac import SACTrainingStats +from tianshou.utils.conversion import to_optional_float from tianshou.utils.net.continuous import ActorProb -class CQLPolicy(SACPolicy): +@dataclass(kw_only=True) +class CQLTrainingStats(SACTrainingStats): + """A data structure for storing loss statistics of the CQL learn step.""" + + cql_alpha: float | None = None + cql_alpha_loss: float | None = None + + +TCQLTrainingStats = TypeVar("TCQLTrainingStats", bound=CQLTrainingStats) + + +class CQLPolicy(SACPolicy[TCQLTrainingStats]): """Implementation of CQL algorithm. arXiv:2006.04779. :param actor: the actor network following the rules in @@ -233,7 +247,7 @@ def process_fn( # Should probably be fixed! 
return batch - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TCQLTrainingStats: # type: ignore batch: Batch = to_torch(batch, dtype=torch.float, device=self.device) obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next batch_size = obs.shape[0] @@ -244,6 +258,7 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ actor_loss.backward() self.actor_optim.step() + alpha_loss = None # compute alpha loss if self.is_auto_alpha: log_pi = log_pi + self.target_entropy @@ -341,6 +356,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ ) # shape: (1) + cql_alpha_loss = None + cql_alpha = None if self.with_lagrange: cql_alpha = torch.clamp( self.cql_log_alpha.exp(), @@ -373,16 +390,12 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ self.sync_weight() - result = { - "loss/actor": actor_loss.item(), - "loss/critic1": critic1_loss.item(), - "loss/critic2": critic2_loss.item(), - } - if self.is_auto_alpha: - self.alpha = cast(torch.Tensor, self.alpha) - result["loss/alpha"] = alpha_loss.item() - result["alpha"] = self.alpha.item() - if self.with_lagrange: - result["loss/cql_alpha"] = cql_alpha_loss.item() - result["cql_alpha"] = cql_alpha.item() - return result + return CQLTrainingStats( # type: ignore[return-value] + actor_loss=to_optional_float(actor_loss), + critic1_loss=to_optional_float(critic1_loss), + critic2_loss=to_optional_float(critic2_loss), + alpha=to_optional_float(self.alpha), + alpha_loss=to_optional_float(alpha_loss), + cql_alpha_loss=to_optional_float(cql_alpha_loss), + cql_alpha=to_optional_float(cql_alpha), + ) diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index d32d4ef6a..8412e0a60 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -1,5 +1,6 @@ import math -from typing import Any, Self, cast +from dataclasses import dataclass +from typing import Any, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -14,12 +15,23 @@ ) from tianshou.policy import DQNPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.dqn import DQNTrainingStats float_info = torch.finfo(torch.float32) INF = float_info.max -class DiscreteBCQPolicy(DQNPolicy): +@dataclass(kw_only=True) +class DiscreteBCQTrainingStats(DQNTrainingStats): + q_loss: float + i_loss: float + reg_loss: float + + +TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteBCQTrainingStats) + + +class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]): """Implementation of discrete BCQ algorithm. arXiv:1910.01708. 
:param model: a model following the rules in @@ -136,7 +148,12 @@ def forward( # type: ignore result = Batch(act=act, state=state, q_value=q_value, imitation_logits=imitation_logits) return cast(ImitationBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TDiscreteBCQTrainingStats: if self._iter % self.freq == 0: self.sync_weight() self._iter += 1 @@ -155,9 +172,9 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ loss.backward() self.optim.step() - return { - "loss": loss.item(), - "loss/q": q_loss.item(), - "loss/i": i_loss.item(), - "loss/reg": reg_loss.item(), - } + return DiscreteBCQTrainingStats( # type: ignore[return-value] + loss=loss.item(), + q_loss=q_loss.item(), + i_loss=i_loss.item(), + reg_loss=reg_loss.item(), + ) diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index 924fd9947..dc23cb75a 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -1,4 +1,5 @@ -from typing import Any +from dataclasses import dataclass +from typing import Any, TypeVar import gymnasium as gym import numpy as np @@ -9,9 +10,19 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import QRDQNPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats -class DiscreteCQLPolicy(QRDQNPolicy): +@dataclass(kw_only=True) +class DiscreteCQLTrainingStats(QRDQNTrainingStats): + cql_loss: float + qr_loss: float + + +TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", bound=DiscreteCQLTrainingStats) + + +class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]): """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779. 
:param model: a model following the rules in @@ -72,7 +83,12 @@ def __init__( ) self.min_q_weight = min_q_weight - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TDiscreteCQLTrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() @@ -101,8 +117,9 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ loss.backward() self.optim.step() self._iter += 1 - return { - "loss": loss.item(), - "loss/qr": qr_loss.item(), - "loss/cql": min_q_loss.item(), - } + + return DiscreteCQLTrainingStats( # type: ignore[return-value] + loss=loss.item(), + qr_loss=qr_loss.item(), + cql_loss=min_q_loss.item(), + ) diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index f22ffe101..9a3c2db9f 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -1,5 +1,6 @@ from copy import deepcopy -from typing import Any, Literal +from dataclasses import dataclass +from typing import Any, Literal, TypeVar import gymnasium as gym import torch @@ -9,10 +10,20 @@ from tianshou.data import to_torch, to_torch_as from tianshou.data.types import RolloutBatchProtocol from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.pg import PGPolicy +from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats -class DiscreteCRRPolicy(PGPolicy): +@dataclass +class DiscreteCRRTrainingStats(PGTrainingStats): + actor_loss: float + critic_loss: float + cql_loss: float + + +TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteCRRTrainingStats) + + +class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]): r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134. 
:param actor: the actor network following the rules in @@ -96,7 +107,7 @@ def learn( # type: ignore batch: RolloutBatchProtocol, *args: Any, **kwargs: Any, - ) -> dict[str, float]: + ) -> TDiscreteCRRTrainingStats: if self._target and self._iter % self._freq == 0: self.sync_weight() self.optim.zero_grad() @@ -131,9 +142,10 @@ def learn( # type: ignore loss.backward() self.optim.step() self._iter += 1 - return { - "loss": loss.item(), - "loss/actor": actor_loss.item(), - "loss/critic": critic_loss.item(), - "loss/cql": min_q_loss.item(), - } + + return DiscreteCRRTrainingStats( # type: ignore[return-value] + loss=loss.item(), + actor_loss=actor_loss.item(), + critic_loss=critic_loss.item(), + cql_loss=min_q_loss.item(), + ) diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 5461f3657..c98f7afb8 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -1,18 +1,35 @@ -from typing import Any, Literal +from dataclasses import dataclass +from typing import Any, Literal, TypeVar import gymnasium as gym import numpy as np import torch import torch.nn.functional as F -from tianshou.data import ReplayBuffer, to_numpy, to_torch +from tianshou.data import ( + ReplayBuffer, + SequenceSummaryStats, + to_numpy, + to_torch, +) from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.policy import PPOPolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.ppo import PPOTrainingStats -class GAILPolicy(PPOPolicy): +@dataclass(kw_only=True) +class GailTrainingStats(PPOTrainingStats): + disc_loss: SequenceSummaryStats + acc_pi: SequenceSummaryStats + acc_exp: SequenceSummaryStats + + +TGailTrainingStats = TypeVar("TGailTrainingStats", bound=GailTrainingStats) + + +class GAILPolicy(PPOPolicy[TGailTrainingStats]): r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476. :param actor: the actor network following the rules in BasePolicy. 
(s -> logits) @@ -142,7 +159,7 @@ def learn( # type: ignore batch_size: int | None, repeat: int, **kwargs: Any, - ) -> dict[str, list[float]]: + ) -> TGailTrainingStats: # update discriminator losses = [] acc_pis = [] @@ -162,8 +179,15 @@ def learn( # type: ignore acc_pis.append((logits_pi < 0).float().mean().item()) acc_exps.append((logits_exp > 0).float().mean().item()) # update policy - res = super().learn(batch, batch_size, repeat, **kwargs) - res["loss/disc"] = losses - res["stats/acc_pi"] = acc_pis - res["stats/acc_exp"] = acc_exps - return res + ppo_loss_stat = super().learn(batch, batch_size, repeat, **kwargs) + + disc_losses_summary = SequenceSummaryStats.from_sequence(losses) + acc_pi_summary = SequenceSummaryStats.from_sequence(acc_pis) + acc_exps_summary = SequenceSummaryStats.from_sequence(acc_exps) + + return GailTrainingStats( # type: ignore[return-value] + **ppo_loss_stat.__dict__, + disc_loss=disc_losses_summary, + acc_pi=acc_pi_summary, + acc_exp=acc_exps_summary, + ) diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index 4cb60bf21..7ef700b0c 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -1,4 +1,5 @@ -from typing import Any, Literal +from dataclasses import dataclass +from typing import Any, Literal, TypeVar import gymnasium as gym import torch @@ -9,9 +10,18 @@ from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.policy import TD3Policy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.td3 import TD3TrainingStats -class TD3BCPolicy(TD3Policy): +@dataclass(kw_only=True) +class TD3BCTrainingStats(TD3TrainingStats): + pass + + +TTD3BCTrainingStats = TypeVar("TTD3BCTrainingStats", bound=TD3BCTrainingStats) + + +class TD3BCPolicy(TD3Policy[TTD3BCTrainingStats]): """Implementation of TD3+BC. arXiv:2106.06860. 
:param actor: the actor network following the rules in @@ -94,7 +104,7 @@ def __init__( ) self.alpha = alpha - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3BCTrainingStats: # type: ignore # critic 1&2 td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) @@ -112,8 +122,9 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ self.actor_optim.step() self.sync_weight() self._cnt += 1 - return { - "loss/actor": self._last, - "loss/critic1": critic1_loss.item(), - "loss/critic2": critic2_loss.item(), - } + + return TD3BCTrainingStats( # type: ignore[return-value] + actor_loss=self._last, + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + ) diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 8f05d6f69..54d560b9d 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -9,11 +9,31 @@ from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import ( + TLearningRateScheduler, + TrainingStats, + TrainingStatsWrapper, + TTrainingStats, +) from tianshou.utils.net.discrete import IntrinsicCuriosityModule -class ICMPolicy(BasePolicy): +class ICMTrainingStats(TrainingStatsWrapper): + def __init__( + self, + wrapped_stats: TrainingStats, + *, + icm_loss: float, + icm_forward_loss: float, + icm_inverse_loss: float, + ) -> None: + self.icm_loss = icm_loss + self.icm_forward_loss = icm_forward_loss + self.icm_inverse_loss = icm_inverse_loss + super().__init__(wrapped_stats) + + +class ICMPolicy(BasePolicy[ICMTrainingStats]): """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363. :param policy: a base policy to add ICM to. 
@@ -37,7 +57,7 @@ class ICMPolicy(BasePolicy): def __init__( self, *, - policy: BasePolicy, + policy: BasePolicy[TTrainingStats], model: IntrinsicCuriosityModule, optim: torch.optim.Optimizer, lr_scale: float, @@ -128,8 +148,13 @@ def post_process_fn( self.policy.post_process_fn(batch, buffer, indices) batch.rew = batch.policy.orig_rew # restore original reward - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: - res = self.policy.learn(batch, **kwargs) + def learn( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> ICMTrainingStats: + training_stat = self.policy.learn(batch, **kwargs) self.optim.zero_grad() act_hat = batch.policy.act_hat act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) @@ -140,11 +165,10 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ ) * self.lr_scale loss.backward() self.optim.step() - res.update( - { - "loss/icm": loss.item(), - "loss/icm/forward": forward_loss.item(), - "loss/icm/inverse": inverse_loss.item(), - }, + + return ICMTrainingStats( + training_stat, + icm_loss=loss.item(), + icm_forward_loss=forward_loss.item(), + icm_inverse_loss=inverse_loss.item(), ) - return res diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index 52dadfedb..8c1374709 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -1,4 +1,5 @@ -from typing import Any, cast +from dataclasses import dataclass +from typing import Any, TypeVar, cast import gymnasium as gym import numpy as np @@ -8,7 +9,16 @@ from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats + + +@dataclass(kw_only=True) +class PSRLTrainingStats(TrainingStats): + psrl_rew_mean: float = 0.0 + psrl_rew_std: float = 0.0 + + +TPSRLTrainingStats = TypeVar("TPSRLTrainingStats", bound=PSRLTrainingStats) class PSRLModel: @@ -140,7 +150,7 @@ def __call__( return self.policy[obs] -class PSRLPolicy(BasePolicy): +class PSRLPolicy(BasePolicy[TPSRLTrainingStats]): """Implementation of Posterior Sampling Reinforcement Learning. Reference: Strens M. 
A Bayesian framework for reinforcement learning [C] @@ -217,7 +227,7 @@ def forward( act = self.model(batch.obs, state=state, info=batch.info) return cast(ActBatchProtocol, Batch(act=act)) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TPSRLTrainingStats: n_s, n_a = self.model.n_state, self.model.n_action trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) @@ -236,7 +246,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ trans_count[obs_next, :, obs_next] += 1 rew_count[obs_next, :] += 1 self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count) - return { - "psrl/rew_mean": float(self.model.rew_mean.mean()), - "psrl/rew_std": float(self.model.rew_std.mean()), - } + + return PSRLTrainingStats( # type: ignore[return-value] + psrl_rew_mean=float(self.model.rew_mean.mean()), + psrl_rew_std=float(self.model.rew_std.mean()), + ) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 5d43ec1ee..2aad187dd 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -1,4 +1,5 @@ -from typing import Any, Literal, cast +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -6,15 +7,27 @@ import torch.nn.functional as F from torch import nn -from tianshou.data import ReplayBuffer, to_torch_as +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy import PGPolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.utils.net.common import ActorCritic -class A2CPolicy(PGPolicy): +@dataclass(kw_only=True) +class A2CTrainingStats(TrainingStats): + loss: SequenceSummaryStats + actor_loss: SequenceSummaryStats + vf_loss: SequenceSummaryStats + ent_loss: SequenceSummaryStats + + +TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var] """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783. :param actor: the actor network following the rules in BasePolicy. 
(s -> logits) @@ -146,7 +159,7 @@ def learn( # type: ignore repeat: int, *args: Any, **kwargs: Any, - ) -> dict[str, list[float]]: + ) -> TA2CTrainingStats: losses, actor_losses, vf_losses, ent_losses = [], [], [], [] split_batch_size = batch_size or -1 for _ in range(repeat): @@ -175,9 +188,14 @@ def learn( # type: ignore ent_losses.append(ent_loss.item()) losses.append(loss.item()) - return { - "loss": losses, - "loss/actor": actor_losses, - "loss/vf": vf_losses, - "loss/ent": ent_losses, - } + loss_summary_stat = SequenceSummaryStats.from_sequence(losses) + actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) + vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) + ent_loss_summary_stat = SequenceSummaryStats.from_sequence(ent_losses) + + return A2CTrainingStats( # type: ignore[return-value] + loss=loss_summary_stat, + actor_loss=actor_loss_summary_stat, + vf_loss=vf_loss_summary_stat, + ent_loss=ent_loss_summary_stat, + ) diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py index b78aa15e7..a91ea0093 100644 --- a/tianshou/policy/modelfree/bdq.py +++ b/tianshou/policy/modelfree/bdq.py @@ -1,4 +1,5 @@ -from typing import Any, Literal, cast +from dataclasses import dataclass +from typing import Any, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -14,10 +15,19 @@ ) from tianshou.policy import DQNPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.dqn import DQNTrainingStats from tianshou.utils.net.common import BranchingNet -class BranchingDQNPolicy(DQNPolicy): +@dataclass(kw_only=True) +class BDQNTrainingStats(DQNTrainingStats): + pass + + +TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats) + + +class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]): """Implementation of the Branching dual Q network arXiv:1711.08946. :param model: BranchingNet mapping (obs, state, info) -> logits. 
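# Where learn() used to return a raw list of per-minibatch losses, it now collapses that list
# into a SequenceSummaryStats before handing it to the stats dataclass. A small sketch of the
# helper (numbers are illustrative; only mean/std are shown, matching the fields used elsewhere
# in this patch):
from tianshou.data import SequenceSummaryStats

losses = [0.42, 0.38, 0.35, 0.33]  # e.g. one entry per minibatch in an A2C update
summary = SequenceSummaryStats.from_sequence(losses)
print(summary.mean, summary.std)   # scalar summaries suitable for logging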
@@ -151,7 +161,7 @@ def forward( result = Batch(logits=logits, act=act, state=hidden) return cast(ModelOutputBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() @@ -169,7 +179,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ loss.backward() self.optim.step() self._iter += 1 - return {"loss": loss.item()} + + return BDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] def exploration_noise( self, diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 600694fa2..bd4491469 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -1,4 +1,5 @@ -from typing import Any +from dataclasses import dataclass +from typing import Any, Generic, TypeVar import gymnasium as gym import numpy as np @@ -8,9 +9,18 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import DQNPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.dqn import DQNTrainingStats -class C51Policy(DQNPolicy): +@dataclass(kw_only=True) +class C51TrainingStats(DQNTrainingStats): + pass + + +TC51TrainingStats = TypeVar("TC51TrainingStats", bound=C51TrainingStats) + + +class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]): """Implementation of Categorical Deep Q-Network. arXiv:1707.06887. :param model: a model following the rules in @@ -107,7 +117,7 @@ def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor: ).clamp(0, 1) * next_dist.unsqueeze(1) return target_dist.sum(-1) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() @@ -124,4 +134,5 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ loss.backward() self.optim.step() self._iter += 1 - return {"loss": loss.item()} + + return C51TrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index d2f0987f5..1b371d4b3 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -1,6 +1,7 @@ import warnings from copy import deepcopy -from typing import Any, Literal, Self, cast +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -16,10 +17,19 @@ ) from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats -class DDPGPolicy(BasePolicy): +@dataclass(kw_only=True) +class DDPGTrainingStats(TrainingStats): + actor_loss: float + critic_loss: float + + +TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats) + + +class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971. 
:param actor: The actor network following the rules in @@ -185,7 +195,7 @@ def _mse_optimizer( optimizer.step() return td, critic_loss - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDDPGTrainingStats: # type: ignore # critic td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) batch.weight = td # prio-buffer @@ -195,7 +205,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ actor_loss.backward() self.actor_optim.step() self.sync_weight() - return {"loss/actor": actor_loss.item(), "loss/critic": critic_loss.item()} + + return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) # type: ignore[return-value] def exploration_noise( self, diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 1b4a3427d..8a80f184c 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -1,4 +1,5 @@ -from typing import Any, cast +from dataclasses import dataclass +from typing import Any, TypeVar, cast import gymnasium as gym import numpy as np @@ -11,9 +12,18 @@ from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import SACPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.sac import SACTrainingStats -class DiscreteSACPolicy(SACPolicy): +@dataclass +class DiscreteSACTrainingStats(SACTrainingStats): + pass + + +TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteSACTrainingStats) + + +class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]): """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207. 
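# With learn()/update() now returning dataclasses instead of dicts, callers read named
# attributes rather than string keys. A rough before/after sketch for DDPG (`policy` and
# `buffer` are assumed to be an already configured DDPGPolicy and ReplayBuffer):
stats = policy.update(sample_size=256, buffer=buffer)  # -> DDPGTrainingStats
actor_loss = stats.actor_loss    # previously result["loss/actor"]
critic_loss = stats.critic_loss  # previously result["loss/critic"]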
:param actor: the actor network following the rules in @@ -117,7 +127,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: ) return target_q.sum(dim=-1) + self.alpha * dist.entropy() - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats: # type: ignore weight = batch.pop("weight", 1.0) target_q = batch.returns.flatten() act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long) @@ -163,17 +173,16 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ self.sync_weight() - result = { - "loss/actor": actor_loss.item(), - "loss/critic1": critic1_loss.item(), - "loss/critic2": critic2_loss.item(), - } if self.is_auto_alpha: self.alpha = cast(torch.Tensor, self.alpha) - result["loss/alpha"] = alpha_loss.item() - result["alpha"] = self.alpha.item() - return result + return DiscreteSACTrainingStats( # type: ignore[return-value] + actor_loss=actor_loss.item(), + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + alpha=self.alpha.item() if isinstance(self.alpha, torch.Tensor) else self.alpha, + alpha_loss=None if not self.is_auto_alpha else alpha_loss.item(), + ) def exploration_noise( self, diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index f8c458255..5b90510b4 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -1,5 +1,6 @@ from copy import deepcopy -from typing import Any, Literal, Self, cast +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -14,10 +15,18 @@ RolloutBatchProtocol, ) from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats -class DQNPolicy(BasePolicy): +@dataclass(kw_only=True) +class DQNTrainingStats(TrainingStats): + loss: float + + +TDQNTrainingStats = TypeVar("TDQNTrainingStats", bound=DQNTrainingStats) + + +class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]): """Implementation of Deep Q Network. arXiv:1312.5602. Implementation of Double Q-Learning. arXiv:1509.06461. 
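# The stats classes are ordinary keyword-only dataclasses, so custom policies can extend them
# the same way the built-in subclasses further down do (e.g. QRDQNTrainingStats). The class
# MyDQNTrainingStats and its grad_norm field are purely illustrative:
from dataclasses import dataclass

from tianshou.policy.modelfree.dqn import DQNTrainingStats


@dataclass(kw_only=True)
class MyDQNTrainingStats(DQNTrainingStats):
    grad_norm: float  # extra diagnostic on top of the inherited `loss` field


stats = MyDQNTrainingStats(loss=0.12, grad_norm=1.8)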
@@ -199,7 +208,7 @@ def forward(
         result = Batch(logits=logits, act=act, state=hidden)
         return cast(ModelOutputBatchProtocol, result)
 
-    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
+    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -220,7 +229,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[
         loss.backward()
         self.optim.step()
         self._iter += 1
-        return {"loss": loss.item()}
+
+        return DQNTrainingStats(loss=loss.item())  # type: ignore[return-value]
 
     def exploration_noise(
         self,
diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py
index c68043614..9f1b083ee 100644
--- a/tianshou/policy/modelfree/fqf.py
+++ b/tianshou/policy/modelfree/fqf.py
@@ -1,4 +1,5 @@
-from typing import Any, Literal, cast
+from dataclasses import dataclass
+from typing import Any, Literal, TypeVar, cast
 
 import gymnasium as gym
 import numpy as np
@@ -9,10 +10,21 @@
 from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
 from tianshou.policy import DQNPolicy, QRDQNPolicy
 from tianshou.policy.base import TLearningRateScheduler
+from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats
 from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
 
 
-class FQFPolicy(QRDQNPolicy):
+@dataclass(kw_only=True)
+class FQFTrainingStats(QRDQNTrainingStats):
+    quantile_loss: float
+    fraction_loss: float
+    entropy_loss: float
+
+
+TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats)
+
+
+class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]):
     """Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140.
 
     :param model: a model following the rules in
@@ -141,7 +153,7 @@ def forward(  # type: ignore
         )
         return cast(FQFBatchProtocol, result)
 
-    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
+    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats:
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         weight = batch.pop("weight", 1.0)
@@ -199,9 +211,10 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[
         quantile_loss.backward()
         self.optim.step()
         self._iter += 1
-        return {
-            "loss": quantile_loss.item() + fraction_entropy_loss.item(),
-            "loss/quantile": quantile_loss.item(),
-            "loss/fraction": fraction_loss.item(),
-            "loss/entropy": entropy_loss.item(),
-        }
+
+        return FQFTrainingStats(  # type: ignore[return-value]
+            loss=quantile_loss.item() + fraction_entropy_loss.item(),
+            quantile_loss=quantile_loss.item(),
+            fraction_loss=fraction_loss.item(),
+            entropy_loss=entropy_loss.item(),
+        )
diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py
index c87242fcd..f242c146d 100644
--- a/tianshou/policy/modelfree/iqn.py
+++ b/tianshou/policy/modelfree/iqn.py
@@ -1,4 +1,5 @@
-from typing import Any, Literal, cast
+from dataclasses import dataclass
+from typing import Any, Literal, TypeVar, cast
 
 import gymnasium as gym
 import numpy as np
@@ -14,9 +15,18 @@
 )
 from tianshou.policy import QRDQNPolicy
 from tianshou.policy.base import TLearningRateScheduler
+from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats
 
 
-class IQNPolicy(QRDQNPolicy):
+@dataclass(kw_only=True)
+class IQNTrainingStats(QRDQNTrainingStats):
+    pass
+
+
+TIQNTrainingStats = TypeVar("TIQNTrainingStats", bound=IQNTrainingStats)
+
+
+class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]):
     """Implementation of Implicit Quantile Network. arXiv:1806.06923.
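# FQFTrainingStats illustrates the aggregation pattern: it inherits the overall `loss` field
# from DQNTrainingStats (via QRDQNTrainingStats) and adds the FQF-specific components tracked
# in learn() above. The values here are illustrative only:
from tianshou.policy.modelfree.fqf import FQFTrainingStats

stats = FQFTrainingStats(
    loss=0.30,           # combined quantile + fraction/entropy loss, mirroring learn() above
    quantile_loss=0.25,
    fraction_loss=0.04,
    entropy_loss=0.01,
)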
:param model: a model following the rules in @@ -121,7 +131,7 @@ def forward( result = Batch(logits=logits, act=act, state=hidden, taus=taus) return cast(QuantileRegressionBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TIQNTrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() @@ -148,4 +158,5 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ loss.backward() self.optim.step() self._iter += 1 - return {"loss": loss.item()} + + return IQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index bb1a1cccd..f2939450c 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -1,4 +1,5 @@ -from typing import Any, Literal +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar import gymnasium as gym import numpy as np @@ -7,14 +8,25 @@ from torch import nn from torch.distributions import kl_divergence -from tianshou.data import Batch, ReplayBuffer +from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy import A2CPolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.modelfree.pg import TDistributionFunction -class NPGPolicy(A2CPolicy): +@dataclass(kw_only=True) +class NPGTrainingStats(TrainingStats): + actor_loss: SequenceSummaryStats + vf_loss: SequenceSummaryStats + kl: SequenceSummaryStats + + +TNPGTrainingStats = TypeVar("TNPGTrainingStats", bound=NPGTrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # type: ignore[type-var] """Implementation of Natural Policy Gradient. 
https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf @@ -112,7 +124,7 @@ def learn( # type: ignore batch_size: int | None, repeat: int, **kwargs: Any, - ) -> dict[str, list[float]]: + ) -> TNPGTrainingStats: actor_losses, vf_losses, kls = [], [], [] split_batch_size = batch_size or -1 for _ in range(repeat): @@ -144,7 +156,7 @@ def learn( # type: ignore new_dist = self(minibatch).dist kl = kl_divergence(old_dist, new_dist).mean() - # optimize citirc + # optimize critic for _ in range(self.optim_critic_iters): value = self.critic(minibatch.obs).flatten() vf_loss = F.mse_loss(minibatch.returns, value) @@ -156,7 +168,15 @@ def learn( # type: ignore vf_losses.append(vf_loss.item()) kls.append(kl.item()) - return {"loss/actor": actor_losses, "loss/vf": vf_losses, "kl": kls} + actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) + vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) + kl_summary_stat = SequenceSummaryStats.from_sequence(kls) + + return NPGTrainingStats( # type: ignore[return-value] + actor_loss=actor_loss_summary_stat, + vf_loss=vf_loss_summary_stat, + kl=kl_summary_stat, + ) def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor: """Matrix vector product.""" diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 09be46a8d..7db588be6 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -1,12 +1,19 @@ import warnings from collections.abc import Callable -from typing import Any, Literal, TypeAlias, cast +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast import gymnasium as gym import numpy as np import torch -from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as +from tianshou.data import ( + Batch, + ReplayBuffer, + SequenceSummaryStats, + to_torch, + to_torch_as, +) from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( BatchWithReturnsProtocol, @@ -15,14 +22,22 @@ RolloutBatchProtocol, ) from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.utils import RunningMeanStd # TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution] TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution] -class PGPolicy(BasePolicy): +@dataclass(kw_only=True) +class PGTrainingStats(TrainingStats): + loss: SequenceSummaryStats + + +TPGTrainingStats = TypeVar("TPGTrainingStats", bound=PGTrainingStats) + + +class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]): """Implementation of REINFORCE algorithm. :param actor: mapping (s->model_output), should follow the rules in @@ -192,12 +207,12 @@ def forward( # TODO: why does mypy complain? 
def learn( # type: ignore self, - batch: RolloutBatchProtocol, + batch: BatchWithReturnsProtocol, batch_size: int | None, repeat: int, *args: Any, **kwargs: Any, - ) -> dict[str, list[float]]: + ) -> TPGTrainingStats: losses = [] split_batch_size = batch_size or -1 for _ in range(repeat): @@ -213,4 +228,6 @@ def learn( # type: ignore self.optim.step() losses.append(loss.item()) - return {"loss": losses} + loss_summary_stat = SequenceSummaryStats.from_sequence(losses) + + return PGTrainingStats(loss=loss_summary_stat) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 3f8c47fb9..fde9e7c79 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -1,19 +1,32 @@ -from typing import Any, Literal +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar import gymnasium as gym import numpy as np import torch from torch import nn -from tianshou.data import ReplayBuffer, to_torch_as +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.policy import A2CPolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.utils.net.common import ActorCritic -class PPOPolicy(A2CPolicy): +@dataclass(kw_only=True) +class PPOTrainingStats(TrainingStats): + loss: SequenceSummaryStats + clip_loss: SequenceSummaryStats + vf_loss: SequenceSummaryStats + ent_loss: SequenceSummaryStats + + +TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var] r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347. :param actor: the actor network following the rules in BasePolicy. 
(s -> logits) @@ -132,7 +145,7 @@ def learn( # type: ignore repeat: int, *args: Any, **kwargs: Any, - ) -> dict[str, list[float]]: + ) -> TPPOTrainingStats: losses, clip_losses, vf_losses, ent_losses = [], [], [], [] split_batch_size = batch_size or -1 for step in range(repeat): @@ -182,9 +195,14 @@ def learn( # type: ignore ent_losses.append(ent_loss.item()) losses.append(loss.item()) - return { - "loss": losses, - "loss/clip": clip_losses, - "loss/vf": vf_losses, - "loss/ent": ent_losses, - } + losses_summary = SequenceSummaryStats.from_sequence(losses) + clip_losses_summary = SequenceSummaryStats.from_sequence(clip_losses) + vf_losses_summary = SequenceSummaryStats.from_sequence(vf_losses) + ent_losses_summary = SequenceSummaryStats.from_sequence(ent_losses) + + return PPOTrainingStats( # type: ignore[return-value] + loss=losses_summary, + clip_loss=clip_losses_summary, + vf_loss=vf_losses_summary, + ent_loss=ent_losses_summary, + ) diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index f4b96151c..b2f5d1e8c 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -1,5 +1,6 @@ import warnings -from typing import Any +from dataclasses import dataclass +from typing import Any, Generic, TypeVar import gymnasium as gym import numpy as np @@ -10,9 +11,18 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import DQNPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.dqn import DQNTrainingStats -class QRDQNPolicy(DQNPolicy): +@dataclass(kw_only=True) +class QRDQNTrainingStats(DQNTrainingStats): + pass + + +TQRDQNTrainingStats = TypeVar("TQRDQNTrainingStats", bound=QRDQNTrainingStats) + + +class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]): """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044. 
     :param model: a model following the rules in
@@ -95,7 +105,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
 
     def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor:
         return super().compute_q_value(logits.mean(2), mask)
 
-    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
+    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TQRDQNTrainingStats:
         if self._target and self._iter % self.freq == 0:
             self.sync_weight()
         self.optim.zero_grad()
@@ -118,4 +128,5 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[
         loss.backward()
         self.optim.step()
         self._iter += 1
-        return {"loss": loss.item()}
+
+        return QRDQNTrainingStats(loss=loss.item())  # type: ignore[return-value]
diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py
index 2b476bcc7..fad567cd2 100644
--- a/tianshou/policy/modelfree/rainbow.py
+++ b/tianshou/policy/modelfree/rainbow.py
@@ -1,9 +1,11 @@
-from typing import Any
+from dataclasses import dataclass
+from typing import Any, TypeVar
 
 from torch import nn
 
 from tianshou.data.types import RolloutBatchProtocol
 from tianshou.policy import C51Policy
+from tianshou.policy.modelfree.c51 import C51TrainingStats
 from tianshou.utils.net.discrete import NoisyLinear
 
 
@@ -25,8 +27,16 @@ def _sample_noise(model: nn.Module) -> bool:
     return sampled_any_noise
 
 
+@dataclass(kw_only=True)
+class RainbowTrainingStats(C51TrainingStats):
+    loss: float
+
+
+TRainbowTrainingStats = TypeVar("TRainbowTrainingStats", bound=RainbowTrainingStats)
+
+
 # TODO: is this class worth keeping? It barely does anything
-class RainbowPolicy(C51Policy):
+class RainbowPolicy(C51Policy[TRainbowTrainingStats]):
     """Implementation of Rainbow DQN. arXiv:1710.02298.
 
     Same parameters as :class:`~tianshou.policy.C51Policy`.
@@ -37,7 +47,12 @@ class RainbowPolicy(C51Policy):
     explanation.
     """
 
-    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
+    def learn(
+        self,
+        batch: RolloutBatchProtocol,
+        *args: Any,
+        **kwargs: Any,
+    ) -> TRainbowTrainingStats:
         _sample_noise(self.model)
         if self._target and _sample_noise(self.model_old):
             self.model_old.train()  # so that NoisyLinear takes effect
diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py
index cdc5ad93d..d1b00714d 100644
--- a/tianshou/policy/modelfree/redq.py
+++ b/tianshou/policy/modelfree/redq.py
@@ -1,4 +1,5 @@
-from typing import Any, Literal, cast
+from dataclasses import dataclass
+from typing import Any, Literal, TypeVar, cast
 
 import gymnasium as gym
 import numpy as np
@@ -10,9 +11,21 @@
 from tianshou.exploration import BaseNoise
 from tianshou.policy import DDPGPolicy
 from tianshou.policy.base import TLearningRateScheduler
+from tianshou.policy.modelfree.ddpg import DDPGTrainingStats
 
 
-class REDQPolicy(DDPGPolicy):
+@dataclass
+class REDQTrainingStats(DDPGTrainingStats):
+    """A data structure for storing loss statistics of the REDQ learn step."""
+
+    alpha: float | None = None
+    alpha_loss: float | None = None
+
+
+TREDQTrainingStats = TypeVar("TREDQTrainingStats", bound=REDQTrainingStats)
+
+
+class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
     """Implementation of REDQ. arXiv:2101.05982.
:param actor: The actor network following the rules in @@ -99,6 +112,8 @@ def __init__( self.deterministic_eval = deterministic_eval self.__eps = np.finfo(np.float32).eps.item() + self._last_actor_loss = 0.0 # only for logging purposes + # TODO: reduce duplication with SACPolicy self.alpha: float | torch.Tensor self._is_auto_alpha = not isinstance(alpha, float) @@ -168,7 +183,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: return target_q - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TREDQTrainingStats: # type: ignore # critic ensemble weight = getattr(batch, "weight", 1.0) current_qs = self.critic(batch.obs, batch.act).flatten(1) @@ -181,6 +196,7 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ batch.weight = torch.mean(td, dim=0) # prio-buffer self.critic_gradient_step += 1 + alpha_loss = None # actor if self.critic_gradient_step % self.actor_delay == 0: obs_result = self(batch) @@ -201,12 +217,14 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ self.sync_weight() - result = {"loss/critics": critic_loss.item()} if self.critic_gradient_step % self.actor_delay == 0: - result["loss/actor"] = (actor_loss.item(),) - if self.is_auto_alpha: - self.alpha = cast(torch.Tensor, self.alpha) - result["loss/alpha"] = alpha_loss.item() - result["alpha"] = self.alpha.item() - - return result + self._last_actor_loss = actor_loss.item() + if self.is_auto_alpha: + self.alpha = cast(torch.Tensor, self.alpha) + + return REDQTrainingStats( # type: ignore[return-value] + actor_loss=self._last_actor_loss, + critic_loss=critic_loss.item(), + alpha=self.alpha.item() if isinstance(self.alpha, torch.Tensor) else self.alpha, + alpha_loss=alpha_loss, + ) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 2c494d4c6..f487e33e9 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -1,5 +1,6 @@ from copy import deepcopy -from typing import Any, Literal, Self, cast +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -14,11 +15,25 @@ ) from tianshou.exploration import BaseNoise from tianshou.policy import DDPGPolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.utils.conversion import to_optional_float from tianshou.utils.optim import clone_optimizer -class SACPolicy(DDPGPolicy): +@dataclass(kw_only=True) +class SACTrainingStats(TrainingStats): + actor_loss: float + critic1_loss: float + critic2_loss: float + alpha: float | None = None + alpha_loss: float | None = None + + +TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var] """Implementation of Soft Actor-Critic. arXiv:1812.05905. :param actor: the actor network following the rules in @@ -120,7 +135,10 @@ def __init__( self.target_entropy, self.log_alpha, self.alpha_optim = alpha self.alpha = self.log_alpha.detach().exp() else: - alpha = cast(float, alpha) + alpha = cast( + float, + alpha, + ) # can we convert alpha to a constant tensor here? 
then mypy wouldn't complain self.alpha = alpha # TODO or not TODO: add to BasePolicy? @@ -195,7 +213,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - self.alpha * obs_next_result.log_prob ) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats: # type: ignore # critic 1&2 td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) @@ -212,6 +230,7 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ self.actor_optim.zero_grad() actor_loss.backward() self.actor_optim.step() + alpha_loss = None if self.is_auto_alpha: log_prob = obs_result.log_prob.detach() + self.target_entropy @@ -224,14 +243,10 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ self.sync_weight() - result = { - "loss/actor": actor_loss.item(), - "loss/critic1": critic1_loss.item(), - "loss/critic2": critic2_loss.item(), - } - if self.is_auto_alpha: - self.alpha = cast(torch.Tensor, self.alpha) - result["loss/alpha"] = alpha_loss.item() - result["alpha"] = self.alpha.item() - - return result + return SACTrainingStats( # type: ignore[return-value] + actor_loss=actor_loss.item(), + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + alpha=to_optional_float(self.alpha), + alpha_loss=to_optional_float(alpha_loss), + ) diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index c364fc8e6..dbf7b6589 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -1,5 +1,6 @@ from copy import deepcopy -from typing import Any, Literal, Self +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar import gymnasium as gym import numpy as np @@ -9,11 +10,22 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.exploration import BaseNoise from tianshou.policy import DDPGPolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.utils.optim import clone_optimizer -class TD3Policy(DDPGPolicy): +@dataclass(kw_only=True) +class TD3TrainingStats(TrainingStats): + actor_loss: float + critic1_loss: float + critic2_loss: float + + +TTD3TrainingStats = TypeVar("TTD3TrainingStats", bound=TD3TrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class TD3Policy(DDPGPolicy[TTD3TrainingStats], Generic[TTD3TrainingStats]): # type: ignore[type-var] """Implementation of TD3, arXiv:1802.09477. 
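# SACTrainingStats (and the REDQ/discrete-SAC variants above) make the entropy-temperature
# entries optional: alpha/alpha_loss stay None unless automatic alpha tuning is enabled, so
# logging code can branch on None instead of probing for missing dict keys. Values here are
# illustrative:
from tianshou.policy.modelfree.sac import SACTrainingStats

stats = SACTrainingStats(actor_loss=0.9, critic1_loss=0.5, critic2_loss=0.6)
if stats.alpha_loss is not None:  # only populated when alpha is auto-tuned
    print(stats.alpha, stats.alpha_loss)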
:param actor: the actor network following the rules in @@ -128,7 +140,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: self.critic2_old(obs_next_batch.obs, act_), ) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3TrainingStats: # type: ignore # critic 1&2 td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) @@ -143,8 +155,9 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ self.actor_optim.step() self.sync_weight() self._cnt += 1 - return { - "loss/actor": self._last, - "loss/critic1": critic1_loss.item(), - "loss/critic2": critic2_loss.item(), - } + + return TD3TrainingStats( # type: ignore[return-value] + actor_loss=self._last, + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + ) diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index 7546a25d8..babc23bfa 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -1,18 +1,28 @@ import warnings -from typing import Any, Literal +from dataclasses import dataclass +from typing import Any, Literal, TypeVar import gymnasium as gym import torch import torch.nn.functional as F from torch.distributions import kl_divergence -from tianshou.data import Batch +from tianshou.data import Batch, SequenceSummaryStats from tianshou.policy import NPGPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.npg import NPGTrainingStats from tianshou.policy.modelfree.pg import TDistributionFunction -class TRPOPolicy(NPGPolicy): +@dataclass(kw_only=True) +class TRPOTrainingStats(NPGTrainingStats): + step_size: SequenceSummaryStats + + +TTRPOTrainingStats = TypeVar("TTRPOTrainingStats", bound=TRPOTrainingStats) + + +class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]): """Implementation of Trust Region Policy Optimization. arXiv:1502.05477. :param actor: the actor network following the rules in BasePolicy. 
(s -> logits) @@ -94,7 +104,7 @@ def learn( # type: ignore batch_size: int | None, repeat: int, **kwargs: Any, - ) -> dict[str, list[float]]: + ) -> TTRPOTrainingStats: actor_losses, vf_losses, step_sizes, kls = [], [], [], [] split_batch_size = batch_size or -1 for _ in range(repeat): @@ -171,9 +181,14 @@ def learn( # type: ignore step_sizes.append(step_size.item()) kls.append(kl.item()) - return { - "loss/actor": actor_losses, - "loss/vf": vf_losses, - "step_size": step_sizes, - "kl": kls, - } + actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) + vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) + kl_summary_stat = SequenceSummaryStats.from_sequence(kls) + step_size_stat = SequenceSummaryStats.from_sequence(step_sizes) + + return TRPOTrainingStats( # type: ignore[return-value] + actor_loss=actor_loss_summary_stat, + vf_loss=vf_loss_summary_stat, + kl=kl_summary_stat, + step_size=step_size_stat, + ) diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 4cccac773..e41e069d1 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -1,12 +1,13 @@ -from typing import Any, Literal, Self +from typing import Any, Literal, Protocol, Self, cast, overload import numpy as np +from overrides import override from tianshou.data import Batch, ReplayBuffer -from tianshou.data.batch import BatchProtocol +from tianshou.data.batch import BatchProtocol, IndexType from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats try: from tianshou.env.pettingzoo_env import PettingZooEnv @@ -14,6 +15,54 @@ PettingZooEnv = None # type: ignore +class MapTrainingStats(TrainingStats): + def __init__( + self, + agent_id_to_stats: dict[str | int, TrainingStats], + train_time_aggregator: Literal["min", "max", "mean"] = "max", + ) -> None: + self._agent_id_to_stats = agent_id_to_stats + train_times = [agent_stats.train_time for agent_stats in agent_id_to_stats.values()] + match train_time_aggregator: + case "max": + aggr_function = max + case "min": + aggr_function = min + case "mean": + aggr_function = np.mean # type: ignore + case _: + raise ValueError( + f"Unknown {train_time_aggregator=}", + ) + self.train_time = aggr_function(train_times) + self.smoothed_loss = {} + + @override + def get_loss_stats_dict(self) -> dict[str, float]: + """Collects loss_stats_dicts from all agents, prepends agent_id to all keys, and joins results.""" + result_dict = {} + for agent_id, stats in self._agent_id_to_stats.items(): + agent_loss_stats_dict = stats.get_loss_stats_dict() + for k, v in agent_loss_stats_dict.items(): + result_dict[f"{agent_id}/" + k] = v + return result_dict + + +class MAPRolloutBatchProtocol(RolloutBatchProtocol, Protocol): + # TODO: this might not be entirely correct. + # The whole MAP data processing pipeline needs more documentation and possibly some refactoring + @overload + def __getitem__(self, index: str) -> RolloutBatchProtocol: + ... + + @overload + def __getitem__(self, index: IndexType) -> Self: + ... + + def __getitem__(self, index: str | IndexType) -> Any: + ... + + class MultiAgentPolicyManager(BasePolicy): """Multi-agent policy manager for MARL. 
@@ -58,8 +107,10 @@ def __init__( # (this MultiAgentPolicyManager) policy.set_agent_id(env.agents[i]) - self.policies = dict(zip(env.agents, policies, strict=True)) + self.policies: dict[str | int, BasePolicy] = dict(zip(env.agents, policies, strict=True)) + """Maps agent_id to policy.""" + # TODO: unused - remove it? def replace_policy(self, policy: BasePolicy, agent_id: int) -> None: """Replace the "agent_id"th policy in this manager.""" policy.set_agent_id(agent_id) @@ -68,17 +119,18 @@ def replace_policy(self, policy: BasePolicy, agent_id: int) -> None: # TODO: violates Liskov substitution principle def process_fn( # type: ignore self, - batch: RolloutBatchProtocol, + batch: MAPRolloutBatchProtocol, buffer: ReplayBuffer, indice: np.ndarray, - ) -> BatchProtocol: - """Dispatch batch data from obs.agent_id to every policy's process_fn. + ) -> MAPRolloutBatchProtocol: + """Dispatch batch data from `obs.agent_id` to every policy's process_fn. Save original multi-dimensional rew in "save_rew", set rew to the reward of each agent during their "process_fn", and restore the original reward afterwards. """ - results = {} + # TODO: maybe only str is actually allowed as agent_id? See MAPRolloutBatchProtocol + results: dict[str | int, RolloutBatchProtocol] = {} assert isinstance( batch.obs, BatchProtocol, @@ -92,7 +144,7 @@ def process_fn( # type: ignore for agent, policy in self.policies.items(): agent_index = np.nonzero(batch.obs.agent_id == agent)[0] if len(agent_index) == 0: - results[agent] = Batch() + results[agent] = cast(RolloutBatchProtocol, Batch()) continue tmp_batch, tmp_indice = batch[agent_index], indice[agent_index] if has_rew: @@ -133,10 +185,12 @@ def forward( # type: ignore ) -> Batch: """Dispatch batch data from obs.agent_id to every policy's forward. + :param batch: TODO: document what is expected at input and make a BatchProtocol for it :param state: if None, it means all agents have no state. If not None, it should contain keys of "agent_1", "agent_2", ... :return: a Batch with the following contents: + TODO: establish a BatcProtocol for this :: @@ -202,34 +256,24 @@ def forward( # type: ignore holder["state"] = state_dict return holder - def learn( + # Violates Liskov substitution principle + def learn( # type: ignore self, - batch: RolloutBatchProtocol, + batch: MAPRolloutBatchProtocol, *args: Any, **kwargs: Any, - ) -> dict[str, float | list[float]]: + ) -> MapTrainingStats: """Dispatch the data to all policies for learning. - :return: a dict with the following contents: - - :: - - { - "agent_1/item1": item 1 of agent_1's policy.learn output - "agent_1/item2": item 2 of agent_1's policy.learn output - "agent_2/xxx": xxx - ... - "agent_n/xxx": xxx - } + :param batch: must map agent_ids to rollout batches """ - results = {} + agent_id_to_stats = {} for agent_id, policy in self.policies.items(): data = batch[agent_id] if not data.is_empty(): - out = policy.learn(batch=data, **kwargs) - for k, v in out.items(): - results[agent_id + "/" + k] = v - return results + train_stats = policy.learn(batch=data, **kwargs) + agent_id_to_stats[agent_id] = train_stats + return MapTrainingStats(agent_id_to_stats) # Need a train method that set all sub-policies to train mode. # No need for a similar eval function, as eval internally uses the train function. 
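# MapTrainingStats aggregates per-agent stats and namespaces their losses when flattened for
# the logger. A sketch (assumes the per-agent stats carry a default train_time; the agent ids,
# loss values, and resulting keys are illustrative):
from tianshou.policy.modelfree.dqn import DQNTrainingStats
from tianshou.policy.multiagent.mapolicy import MapTrainingStats

ma_stats = MapTrainingStats(
    {"agent_1": DQNTrainingStats(loss=0.2), "agent_2": DQNTrainingStats(loss=0.3)},
    train_time_aggregator="max",  # the default; "min" and "mean" are also accepted
)
ma_stats.get_loss_stats_dict()  # expected to yield keys like "agent_1/loss", "agent_2/loss"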
diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index 38ff7d064..943ae99f2 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import Any, TypeVar, cast import numpy as np @@ -6,9 +6,17 @@ from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import BasePolicy +from tianshou.policy.base import TrainingStats -class RandomPolicy(BasePolicy): +class RandomTrainingStats(TrainingStats): + pass + + +TRandomTrainingStats = TypeVar("TRandomTrainingStats", bound=RandomTrainingStats) + + +class RandomPolicy(BasePolicy[TRandomTrainingStats]): """A random agent used in multi-agent learning. It randomly chooses an action from the legal action. @@ -41,6 +49,6 @@ def forward( result = Batch(act=logits.argmax(axis=-1)) return cast(ActBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TRandomTrainingStats: # type: ignore """Since a random agent learns nothing, it returns an empty dict.""" - return {} + return RandomTrainingStats() # type: ignore[return-value] diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index fbd3e9197..84faf7999 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -3,13 +3,24 @@ from abc import ABC, abstractmethod from collections import defaultdict, deque from collections.abc import Callable +from dataclasses import asdict from typing import Any import numpy as np import tqdm -from tianshou.data import AsyncCollector, Collector, ReplayBuffer +from tianshou.data import ( + AsyncCollector, + Collector, + CollectStats, + EpochStats, + InfoStats, + ReplayBuffer, + SequenceSummaryStats, +) +from tianshou.data.collector import CollectStatsBase from tianshou.policy import BasePolicy +from tianshou.policy.base import TrainingStats from tianshou.trainer.utils import gather_info, test_episode from tianshou.utils import ( BaseLogger, @@ -19,6 +30,7 @@ deprecation, tqdm_config, ) +from tianshou.utils.logging import set_numerical_fields_to_precision log = logging.getLogger(__name__) @@ -189,8 +201,9 @@ def __init__( self.start_epoch = 0 # This is only used for logging but creeps into the implementations # of the trainers. 
I believe it would be better to remove - self.gradient_step = 0 + self._gradient_step = 0 self.env_step = 0 + self.policy_update_time = 0.0 self.max_epoch = max_epoch self.step_per_epoch = step_per_epoch @@ -218,7 +231,7 @@ def __init__( self.resume_from_log = resume_from_log self.is_run = False - self.last_rew, self.last_len = 0.0, 0 + self.last_rew, self.last_len = 0.0, 0.0 self.epoch = self.start_epoch self.best_epoch = self.start_epoch @@ -233,10 +246,10 @@ def reset(self) -> None: ( self.start_epoch, self.env_step, - self.gradient_step, + self._gradient_step, ) = self.logger.restore_data() - self.last_rew, self.last_len = 0.0, 0 + self.last_rew, self.last_len = 0.0, 0.0 self.start_time = time.time() if self.train_collector is not None: self.train_collector.reset_stat() @@ -258,10 +271,11 @@ def reset(self) -> None: self.env_step, self.reward_metric, ) + assert test_result.returns_stat is not None # for mypy self.best_epoch = self.start_epoch self.best_reward, self.best_reward_std = ( - test_result["rew"], - test_result["rew_std"], + test_result.returns_stat.mean, + test_result.returns_stat.std, ) if self.save_best_fn: self.save_best_fn(self.policy) @@ -274,7 +288,7 @@ def __iter__(self): # type: ignore self.reset() return self - def __next__(self) -> None | tuple[int, dict[str, Any], dict[str, Any]]: + def __next__(self) -> EpochStats: """Perform one epoch (both train and eval).""" self.epoch += 1 self.iter_num += 1 @@ -291,29 +305,32 @@ def __next__(self) -> None | tuple[int, dict[str, Any], dict[str, Any]]: # set policy in train mode self.policy.train() - epoch_stat: dict[str, Any] = {} - progress = tqdm.tqdm if self.show_progress else DummyTqdm # perform n step_per_epoch with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t: while t.n < t.total and not self.stop_fn_flag: - data: dict[str, Any] = {} - result: dict[str, Any] = {} + train_stat: CollectStatsBase if self.train_collector is not None: - data, result, self.stop_fn_flag = self.train_step() - t.update(result["n/st"]) + pbar_data_dict, train_stat, self.stop_fn_flag = self.train_step() + t.update(train_stat.n_collected_steps) if self.stop_fn_flag: - t.set_postfix(**data) + t.set_postfix(**pbar_data_dict) break else: + pbar_data_dict = {} assert self.buffer, "No train_collector or buffer specified" - result["n/ep"] = len(self.buffer) - result["n/st"] = int(self.gradient_step) + train_stat = CollectStatsBase( + n_collected_episodes=len(self.buffer), + n_collected_steps=int(self._gradient_step), + ) t.update() - self.policy_update_fn(data, result) - t.set_postfix(**data) + update_stat = self.policy_update_fn(train_stat) + pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict) + pbar_data_dict["gradient_step"] = self._gradient_step + + t.set_postfix(**pbar_data_dict) if t.n <= t.total and not self.stop_fn_flag: t.update() @@ -322,49 +339,49 @@ def __next__(self) -> None | tuple[int, dict[str, Any], dict[str, Any]]: if self.train_collector is None: assert self.buffer is not None batch_size = self.batch_size or len(self.buffer) - self.env_step = self.gradient_step * batch_size + self.env_step = self._gradient_step * batch_size + test_stat = None if not self.stop_fn_flag: self.logger.save_data( self.epoch, self.env_step, - self.gradient_step, + self._gradient_step, self.save_checkpoint_fn, ) # test if self.test_collector is not None: test_stat, self.stop_fn_flag = self.test_step() - if not self.is_run: - epoch_stat.update(test_stat) - - if not self.is_run: - epoch_stat.update({k: 
v.get() for k, v in self.stat.items()}) - epoch_stat["gradient_step"] = self.gradient_step - epoch_stat.update( - { - "env_step": self.env_step, - "rew": self.last_rew, - "len": int(self.last_len), - "n/ep": int(result["n/ep"]), - "n/st": int(result["n/st"]), - }, - ) - info = gather_info( - self.start_time, - self.train_collector, - self.test_collector, - self.best_reward, - self.best_reward_std, - ) - return self.epoch, epoch_stat, info - return None - def test_step(self) -> tuple[dict[str, Any], bool]: + info_stat = gather_info( + start_time=self.start_time, + policy_update_time=self.policy_update_time, + gradient_step=self._gradient_step, + best_reward=self.best_reward, + best_reward_std=self.best_reward_std, + train_collector=self.train_collector, + test_collector=self.test_collector, + ) + + self.logger.log_info_data(asdict(info_stat), self.epoch) + + # in case trainer is used with run(), epoch_stat will not be returned + epoch_stat: EpochStats = EpochStats( + epoch=self.epoch, + train_collect_stat=train_stat, + test_collect_stat=test_stat, + training_stat=update_stat, + info_stat=info_stat, + ) + + return epoch_stat + + def test_step(self) -> tuple[CollectStats, bool]: """Perform one testing step.""" assert self.episode_per_test is not None assert self.test_collector is not None stop_fn_flag = False - test_result = test_episode( + test_stat = test_episode( self.policy, self.test_collector, self.test_fn, @@ -374,7 +391,8 @@ def test_step(self) -> tuple[dict[str, Any], bool]: self.env_step, self.reward_metric, ) - rew, rew_std = test_result["rew"], test_result["rew_std"] + assert test_stat.returns_stat is not None # for mypy + rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std if self.best_epoch < 0 or self.best_reward < rew: self.best_epoch = self.epoch self.best_reward = float(rew) @@ -389,22 +407,13 @@ def test_step(self) -> tuple[dict[str, Any], bool]: log.info(log_msg) if self.verbose: print(log_msg, flush=True) - if not self.is_run: - test_stat = { - "test_reward": rew, - "test_reward_std": rew_std, - "best_reward": self.best_reward, - "best_reward_std": self.best_reward_std, - "best_epoch": self.best_epoch, - } - else: - test_stat = {} + if self.stop_fn and self.stop_fn(self.best_reward): stop_fn_flag = True return test_stat, stop_fn_flag - def train_step(self) -> tuple[dict[str, Any], dict[str, Any], bool]: + def train_step(self) -> tuple[dict[str, Any], CollectStats, bool]: """Perform one training step.""" assert self.episode_per_test is not None assert self.train_collector is not None @@ -415,25 +424,33 @@ def train_step(self) -> tuple[dict[str, Any], dict[str, Any], bool]: n_step=self.step_per_collect, n_episode=self.episode_per_collect, ) - if result["n/ep"] > 0 and self.reward_metric: - rew = self.reward_metric(result["rews"]) - result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) - self.env_step += int(result["n/st"]) - self.logger.log_train_data(result, self.env_step) - self.last_rew = result["rew"] if result["n/ep"] > 0 else self.last_rew - self.last_len = result["len"] if result["n/ep"] > 0 else self.last_len + + self.env_step += result.n_collected_steps + + if result.n_collected_episodes > 0: + assert result.returns_stat is not None # for mypy + assert result.lens_stat is not None # for mypy + self.last_rew = result.returns_stat.mean + self.last_len = result.lens_stat.mean + if self.reward_metric: # TODO: move inside collector + rew = self.reward_metric(result.returns) + result.returns = rew + result.returns_stat = 
SequenceSummaryStats.from_sequence(rew) + + self.logger.log_train_data(asdict(result), self.env_step) + data = { "env_step": str(self.env_step), "rew": f"{self.last_rew:.2f}", "len": str(int(self.last_len)), - "n/ep": str(int(result["n/ep"])), - "n/st": str(int(result["n/st"])), + "n/ep": str(result.n_collected_episodes), + "n/st": str(result.n_collected_steps), } if ( - result["n/ep"] > 0 + result.n_collected_episodes > 0 and self.test_in_train and self.stop_fn - and self.stop_fn(result["rew"]) + and self.stop_fn(result.returns_stat.mean) # type: ignore ): assert self.test_collector is not None test_result = test_episode( @@ -445,31 +462,54 @@ def train_step(self) -> tuple[dict[str, Any], dict[str, Any], bool]: self.logger, self.env_step, ) - if self.stop_fn(test_result["rew"]): + assert test_result.returns_stat is not None # for mypy + if self.stop_fn(test_result.returns_stat.mean): stop_fn_flag = True - self.best_reward = test_result["rew"] - self.best_reward_std = test_result["rew_std"] + self.best_reward = test_result.returns_stat.mean + self.best_reward_std = test_result.returns_stat.std else: self.policy.train() return data, result, stop_fn_flag - def log_update_data(self, data: dict[str, Any], losses: dict[str, Any]) -> None: - """Log losses to current logger.""" - for k in losses: - self.stat[k].add(losses[k]) - losses[k] = self.stat[k].get() - data[k] = f"{losses[k]:.3f}" - self.logger.log_update_data(losses, self.gradient_step) + # TODO: move moving average computation and logging into its own logger + # TODO: maybe think about a command line logger instead of always printing data dict + def _update_moving_avg_stats_and_log_update_data(self, update_stat: TrainingStats) -> None: + """Log losses, update moving average stats, and also modify the smoothed_loss in update_stat.""" + cur_losses_dict = update_stat.get_loss_stats_dict() + update_stat.smoothed_loss = self._update_moving_avg_stats_and_get_averaged_data( + cur_losses_dict, + ) + self.logger.log_update_data(asdict(update_stat), self._gradient_step) + + # TODO: seems convoluted, there should be a better way of dealing with the moving average stats + def _update_moving_avg_stats_and_get_averaged_data( + self, + data: dict[str, float], + ) -> dict[str, float]: + """Add entries to the moving average object in the trainer and retrieve the averaged results. + + :param data: any entries to be tracked in the moving average object. + :return: A dictionary containing the averaged values of the tracked entries. + + """ + smoothed_data = {} + for key, loss_item in data.items(): + self.stat[key].add(loss_item) + smoothed_data[key] = self.stat[key].get() + return smoothed_data @abstractmethod - def policy_update_fn(self, data: dict[str, Any], result: dict[str, Any]) -> None: + def policy_update_fn( + self, + collect_stats: CollectStatsBase, + ) -> TrainingStats: """Policy update function for different trainer implementation. - :param data: information in progress bar. - :param result: collector's return value. + :param collect_stats: provides info about the most recent collection. In the offline case, this will contain + stats of the whole dataset """ - def run(self) -> dict[str, float | str]: + def run(self) -> InfoStats: """Consume iterator. See itertools - recipes. 
Use functions that consume iterators at C speed @@ -479,25 +519,28 @@ def run(self) -> dict[str, float | str]: self.is_run = True deque(self, maxlen=0) # feed the entire iterator into a zero-length deque info = gather_info( - self.start_time, - self.train_collector, - self.test_collector, - self.best_reward, - self.best_reward_std, + start_time=self.start_time, + policy_update_time=self.policy_update_time, + gradient_step=self._gradient_step, + best_reward=self.best_reward, + best_reward_std=self.best_reward_std, + train_collector=self.train_collector, + test_collector=self.test_collector, ) finally: self.is_run = False return info - def _sample_and_update(self, buffer: ReplayBuffer, data: dict[str, Any]) -> None: - self.gradient_step += 1 + def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: + """Sample a mini-batch, perform one gradient step, and update the _gradient_step counter.""" + self._gradient_step += 1 # Note: since sample_size=batch_size, this will perform # exactly one gradient step. This is why we don't need to calculate the # number of gradient steps, like in the on-policy case. - losses = self.policy.update(sample_size=self.batch_size, buffer=buffer) - data.update({"gradient_step": str(self.gradient_step)}) - self.log_update_data(data, losses) + update_stat = self.policy.update(sample_size=self.batch_size, buffer=buffer) + self._update_moving_avg_stats_and_log_update_data(update_stat) + return update_stat class OfflineTrainer(BaseTrainer): @@ -512,12 +555,14 @@ class OfflineTrainer(BaseTrainer): def policy_update_fn( self, - data: dict[str, Any], - result: dict[str, Any] | None = None, - ) -> None: + collect_stats: CollectStatsBase | None = None, + ) -> TrainingStats: """Perform one off-line policy update.""" assert self.buffer - self._sample_and_update(self.buffer, data) + update_stat = self._sample_and_update(self.buffer) + # logging + self.policy_update_time += update_stat.train_time + return update_stat class OffpolicyTrainer(BaseTrainer): @@ -532,20 +577,31 @@ class OffpolicyTrainer(BaseTrainer): assert isinstance(BaseTrainer.__doc__, str) __doc__ += BaseTrainer.gen_doc("offpolicy") + "\n".join(BaseTrainer.__doc__.split("\n")[1:]) - def policy_update_fn(self, data: dict[str, Any], result: dict[str, Any]) -> None: - """Perform off-policy updates. + def policy_update_fn( + self, + # TODO: this is the only implementation where collect_stats is actually needed. Maybe change interface? + collect_stats: CollectStatsBase, + ) -> TrainingStats: + """Perform `update_per_step * n_collected_steps` gradient steps by sampling mini-batches from the buffer. - :param data: - :param result: must contain `n/st` key, see documentation of - `:meth:~tianshou.data.collector.Collector.collect` for the kind of - data returned there. `n/st` stands for `step_count` + :param collect_stats: the :class:`~TrainingStats` instance returned by the last gradient step. Some values + in it will be replaced by their moving averages. """ assert self.train_collector is not None - n_collected_steps = result["n/st"] - # Same as training intensity, right? 
- num_updates = round(self.update_per_step * n_collected_steps) - for _ in range(num_updates): - self._sample_and_update(self.train_collector.buffer, data) + n_collected_steps = collect_stats.n_collected_steps + n_gradient_steps = round(self.update_per_step * n_collected_steps) + if n_gradient_steps == 0: + raise ValueError( + f"n_gradient_steps is 0, n_collected_steps={n_collected_steps}, " + f"update_per_step={self.update_per_step}", + ) + for _ in range(n_gradient_steps): + update_stat = self._sample_and_update(self.train_collector.buffer) + + # logging + self.policy_update_time += update_stat.train_time + # TODO: only the last update_stat is returned, should be improved + return update_stat class OnpolicyTrainer(BaseTrainer): @@ -561,12 +617,11 @@ class OnpolicyTrainer(BaseTrainer): def policy_update_fn( self, - data: dict[str, Any], - result: dict[str, Any] | None = None, - ) -> None: - """Perform one on-policy update.""" + result: CollectStatsBase | None = None, + ) -> TrainingStats: + """Perform one on-policy update by passing the entire buffer to the policy's update method.""" assert self.train_collector is not None - losses = self.policy.update( + training_stat = self.policy.update( sample_size=0, buffer=self.train_collector.buffer, # Note: sample_size is None, so the whole buffer is used for the update. @@ -578,13 +633,14 @@ def policy_update_fn( ) # just for logging, no functional role + self.policy_update_time += training_stat.train_time # TODO: remove the gradient step counting in trainers? Doesn't seem like # it's important and it adds complexity - self.gradient_step += 1 + self._gradient_step += 1 if self.batch_size is None: - self.gradient_step += 1 + self._gradient_step += 1 elif self.batch_size > 0: - self.gradient_step += int((len(self.train_collector.buffer) - 0.1) // self.batch_size) + self._gradient_step += int((len(self.train_collector.buffer) - 0.1) // self.batch_size) # Note: this is the main difference to the off-policy trainer! 
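In other words, the off-policy trainer now performs `round(update_per_step * n_collected_steps)` gradient steps per collection phase, and raises instead of silently doing nothing when that rounds to zero. A quick worked example with made-up hyperparameters:

```python
# made-up hyperparameters, just to illustrate the schedule implemented above
update_per_step = 0.1
n_collected_steps = 64  # e.g. the step_per_collect of the train collector

assert round(update_per_step * n_collected_steps) == 6  # ~1 update per 10 steps

# a configuration like this used to skip updates silently and now raises a
# ValueError instead:
assert round(0.01 * 16) == 0
```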
# The second difference is that batches of data are sampled without replacement @@ -593,4 +649,6 @@ def policy_update_fn( self.train_collector.reset_buffer(keep_statistics=True) # The step is the number of mini-batches used for the update, so essentially - self.log_update_data(data, losses) + self._update_moving_avg_stats_and_log_update_data(training_stat) + + return training_stat diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index e8193d5a8..300c7c470 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -1,10 +1,16 @@ import time from collections.abc import Callable -from typing import Any +from dataclasses import asdict import numpy as np -from tianshou.data import Collector +from tianshou.data import ( + Collector, + CollectStats, + InfoStats, + SequenceSummaryStats, + TimingStats, +) from tianshou.policy import BasePolicy from tianshou.utils import BaseLogger @@ -18,7 +24,7 @@ def test_episode( logger: BaseLogger | None = None, global_step: int | None = None, reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, -) -> dict[str, Any]: +) -> CollectStats: """A simple wrapper of testing policy in collector.""" collector.reset_env() collector.reset_buffer() @@ -26,72 +32,72 @@ def test_episode( if test_fn: test_fn(epoch, global_step) result = collector.collect(n_episode=n_episode) - if reward_metric: - rew = reward_metric(result["rews"]) - result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) + if reward_metric: # TODO: move into collector + rew = reward_metric(result.returns) + result.returns = rew + result.returns_stat = SequenceSummaryStats.from_sequence(rew) if logger and global_step is not None: - logger.log_test_data(result, global_step) + assert result.n_collected_episodes > 0 + logger.log_test_data(asdict(result), global_step) return result def gather_info( start_time: float, - train_collector: Collector | None, - test_collector: Collector | None, + policy_update_time: float, + gradient_step: int, best_reward: float, best_reward_std: float, -) -> dict[str, float | str]: + train_collector: Collector | None = None, + test_collector: Collector | None = None, +) -> InfoStats: """A simple wrapper of gathering information from collectors. 
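For reference, the `reward_metric` hook that `test_episode` applies to `result.returns` has the signature `Callable[[np.ndarray], np.ndarray]`. A hedged sketch of such a metric (the multi-agent shape is purely hypothetical):

```python
import numpy as np


def reward_metric(returns: np.ndarray) -> np.ndarray:
    """Reduce hypothetical per-agent returns of shape (n_episodes, n_agents)
    to one scalar per episode, as expected by test_episode above."""
    return returns.max(axis=-1)


print(reward_metric(np.array([[1.0, 3.0], [2.0, 0.5]])))  # [3. 2.]
```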
- :return: A dictionary with the following keys: + :return: A dataclass object with the following members (depending on available collectors): + * ``gradient_step`` the total number of gradient steps; + * ``best_reward`` the best reward over the test results; + * ``best_reward_std`` the standard deviation of best reward over the test results; * ``train_step`` the total collected step of training collector; * ``train_episode`` the total collected episode of training collector; - * ``train_time/collector`` the time for collecting transitions in the \ - training collector; - * ``train_time/model`` the time for training models; - * ``train_speed`` the speed of training (env_step per second); * ``test_step`` the total collected step of test collector; * ``test_episode`` the total collected episode of test collector; + * ``timing`` the timing statistics, with the following members: + * ``total_time`` the total time elapsed; + * ``train_time`` the total time elapsed for learning training (collecting samples plus model update); + * ``train_time_collect`` the time for collecting transitions in the \ + training collector; + * ``train_time_update`` the time for training models; * ``test_time`` the time for testing; - * ``test_speed`` the speed of testing (env_step per second); - * ``best_reward`` the best reward over the test results; - * ``duration`` the total elapsed time. + * ``update_speed`` the speed of updating (env_step per second). """ - duration = max(0, time.time() - start_time) - model_time = duration - result: dict[str, float | str] = { - "duration": f"{duration:.2f}s", - "train_time/model": f"{model_time:.2f}s", - } + duration = max(0.0, time.time() - start_time) + test_time = 0.0 + update_speed = 0.0 + train_time_collect = 0.0 if test_collector is not None: - model_time = max(0, duration - test_collector.collect_time) - test_speed = test_collector.collect_step / test_collector.collect_time - result.update( - { - "test_step": test_collector.collect_step, - "test_episode": test_collector.collect_episode, - "test_time": f"{test_collector.collect_time:.2f}s", - "test_speed": f"{test_speed:.2f} step/s", - "best_reward": best_reward, - "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}", - "duration": f"{duration:.2f}s", - "train_time/model": f"{model_time:.2f}s", - }, - ) + test_time = test_collector.collect_time + if train_collector is not None: - model_time = max(0, model_time - train_collector.collect_time) - if test_collector is not None: - train_speed = train_collector.collect_step / (duration - test_collector.collect_time) - else: - train_speed = train_collector.collect_step / duration - result.update( - { - "train_step": train_collector.collect_step, - "train_episode": train_collector.collect_episode, - "train_time/collector": f"{train_collector.collect_time:.2f}s", - "train_time/model": f"{model_time:.2f}s", - "train_speed": f"{train_speed:.2f} step/s", - }, - ) - return result + train_time_collect = train_collector.collect_time + update_speed = train_collector.collect_step / (duration - test_time) + + timing_stat = TimingStats( + total_time=duration, + train_time=duration - test_time, + train_time_collect=train_time_collect, + train_time_update=policy_update_time, + test_time=test_time, + update_speed=update_speed, + ) + + return InfoStats( + gradient_step=gradient_step, + best_reward=best_reward, + best_reward_std=best_reward_std, + train_step=train_collector.collect_step if train_collector is not None else 0, + train_episode=train_collector.collect_episode if 
train_collector is not None else 0, + test_step=test_collector.collect_step if test_collector is not None else 0, + test_episode=test_collector.collect_episode if test_collector is not None else 0, + timing=timing_stat, + ) diff --git a/tianshou/utils/conversion.py b/tianshou/utils/conversion.py new file mode 100644 index 000000000..bae2db331 --- /dev/null +++ b/tianshou/utils/conversion.py @@ -0,0 +1,25 @@ +from typing import overload + +import torch + + +@overload +def to_optional_float(x: torch.Tensor) -> float: + ... + + +@overload +def to_optional_float(x: float) -> float: + ... + + +@overload +def to_optional_float(x: None) -> None: + ... + + +def to_optional_float(x: torch.Tensor | float | None) -> float | None: + """For the common case where one needs to extract a float from a scalar Tensor, which may be None.""" + if isinstance(x, torch.Tensor): + return x.item() + return x diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index 92ba41d35..74e0fa4cc 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -1,10 +1,23 @@ +import typing from abc import ABC, abstractmethod from collections.abc import Callable +from enum import Enum from numbers import Number +from typing import Any import numpy as np -LOG_DATA_TYPE = dict[str, int | Number | np.number | np.ndarray] +VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray +VALID_LOG_VALS = typing.get_args( + VALID_LOG_VALS_TYPE, +) # I know it's stupid, but we can't use Union type in isinstance + + +class DataScope(Enum): + TRAIN = "train" + TEST = "test" + UPDATE = "update" + INFO = "info" class BaseLogger(ABC): @@ -15,6 +28,7 @@ class BaseLogger(ABC): :param train_interval: the log interval in log_train_data(). Default to 1000. :param test_interval: the log interval in log_test_data(). Default to 1. :param update_interval: the log interval in log_update_data(). Default to 1000. + :param info_interval: the log interval in log_info_data(). Default to 1. """ def __init__( @@ -22,17 +36,20 @@ def __init__( train_interval: int = 1000, test_interval: int = 1, update_interval: int = 1000, + info_interval: int = 1, ) -> None: super().__init__() self.train_interval = train_interval self.test_interval = test_interval self.update_interval = update_interval + self.info_interval = info_interval self.last_log_train_step = -1 self.last_log_test_step = -1 self.last_log_update_step = -1 + self.last_log_info_step = -1 @abstractmethod - def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: + def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: """Specify how the writer is used to log data. :param str step_type: namespace which the data dict belongs to. @@ -40,53 +57,96 @@ def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: :param data: the data to write with format ``{key: value}``. """ - def log_train_data(self, collect_result: dict, step: int) -> None: + @staticmethod + def prepare_dict_for_logging( + input_dict: dict[str, Any], + parent_key: str = "", + delimiter: str = "/", + exclude_arrays: bool = True, + ) -> dict[str, VALID_LOG_VALS_TYPE]: + """Flattens and filters a nested dictionary by recursively traversing all levels and compressing the keys. + + Filtering is performed with respect to valid logging data types. + + :param input_dict: The nested dictionary to be flattened and filtered. + :param parent_key: The parent key used as a prefix before the input_dict keys. 
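The new `to_optional_float` helper above is a small convenience for logging code that may receive scalar tensors, plain floats, or `None`. A usage sketch:

```python
import torch

from tianshou.utils.conversion import to_optional_float

assert to_optional_float(torch.tensor(0.5)) == 0.5
assert to_optional_float(1.5) == 1.5
assert to_optional_float(None) is None
```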
+ :param delimiter: The delimiter used to separate the keys. + :param exclude_arrays: Whether to exclude numpy arrays from the output. + :return: A flattened dictionary where the keys are compressed and values are filtered. + """ + result = {} + + def add_to_result( + cur_dict: dict, + prefix: str = "", + ) -> None: + for key, value in cur_dict.items(): + if exclude_arrays and isinstance(value, np.ndarray): + continue + + new_key = prefix + delimiter + key + new_key = new_key.lstrip(delimiter) + + if isinstance(value, dict): + add_to_result( + value, + new_key, + ) + elif isinstance(value, VALID_LOG_VALS): + result[new_key] = value + + add_to_result(input_dict, prefix=parent_key) + return result + + def log_train_data(self, log_data: dict, step: int) -> None: """Use writer to log statistics generated during training. - :param collect_result: a dict containing information of data collected in - training stage, i.e., returns of collector.collect(). - :param step: stands for the timestep the collect_result being logged. + :param log_data: a dict containing the information returned by the collector during the train step. + :param step: stands for the timestep the collector result is logged. """ - if collect_result["n/ep"] > 0 and step - self.last_log_train_step >= self.train_interval: - log_data = { - "train/episode": collect_result["n/ep"], - "train/reward": collect_result["rew"], - "train/length": collect_result["len"], - } + # TODO: move interval check to calling method + if step - self.last_log_train_step >= self.train_interval: + log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.TRAIN.value) self.write("train/env_step", step, log_data) self.last_log_train_step = step - def log_test_data(self, collect_result: dict, step: int) -> None: + def log_test_data(self, log_data: dict, step: int) -> None: """Use writer to log statistics generated during evaluating. - :param collect_result: a dict containing information of data collected in - evaluating stage, i.e., returns of collector.collect(). - :param step: stands for the timestep the collect_result being logged. + :param log_data:a dict containing the information returned by the collector during the evaluation step. + :param step: stands for the timestep the collector result is logged. """ - assert collect_result["n/ep"] > 0 + # TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer) if step - self.last_log_test_step >= self.test_interval: - log_data = { - "test/env_step": step, - "test/reward": collect_result["rew"], - "test/length": collect_result["len"], - "test/reward_std": collect_result["rew_std"], - "test/length_std": collect_result["len_std"], - } - self.write("test/env_step", step, log_data) + log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.TEST.value) + self.write(DataScope.TEST.value + "/env_step", step, log_data) self.last_log_test_step = step - def log_update_data(self, update_result: dict, step: int) -> None: + def log_update_data(self, log_data: dict, step: int) -> None: """Use writer to log statistics generated during updating. - :param update_result: a dict containing information of data collected in - updating stage, i.e., returns of policy.update(). - :param step: stands for the timestep the collect_result being logged. + :param log_data:a dict containing the information returned during the policy update step. + :param step: stands for the timestep the policy training data is logged. 
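To make the flattening concrete, here is what `prepare_dict_for_logging` produces for a made-up nested stats dict (keys and values invented for illustration):

```python
from tianshou.utils import BaseLogger

nested = {
    "returns_stat": {"mean": 10.0, "std": 2.0},
    "lens": [1, 2, 3],        # not a valid scalar log type -> filtered out
    "n_collected_steps": 30,
}
flat = BaseLogger.prepare_dict_for_logging(nested, parent_key="train")
print(flat)
# {'train/returns_stat/mean': 10.0, 'train/returns_stat/std': 2.0,
#  'train/n_collected_steps': 30}
```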
""" + # TODO: move interval check to calling method if step - self.last_log_update_step >= self.update_interval: - log_data = {f"update/{k}": v for k, v in update_result.items()} - self.write("update/gradient_step", step, log_data) + log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.UPDATE.value) + self.write(DataScope.UPDATE.value + "/gradient_step", step, log_data) self.last_log_update_step = step + def log_info_data(self, log_data: dict, step: int) -> None: + """Use writer to log global statistics. + + :param log_data: a dict containing information of data collected at the end of an epoch. + :param step: stands for the timestep the training info is logged. + """ + if ( + step - self.last_log_info_step >= self.info_interval + ): # TODO: move interval check to calling method + log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.INFO.value) + self.write(DataScope.INFO.value + "/epoch", step, log_data) + self.last_log_info_step = step + @abstractmethod def save_data( self, @@ -121,7 +181,7 @@ class LazyLogger(BaseLogger): def __init__(self) -> None: super().__init__() - def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: + def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: """The LazyLogger writes nothing.""" def save_data( diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index e4e6ea9df..2a26963a5 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -4,7 +4,7 @@ from tensorboard.backend.event_processing import event_accumulator from torch.utils.tensorboard import SummaryWriter -from tianshou.utils.logger.base import LOG_DATA_TYPE, BaseLogger +from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger from tianshou.utils.warning import deprecation @@ -15,6 +15,7 @@ class TensorboardLogger(BaseLogger): :param train_interval: the log interval in log_train_data(). Default to 1000. :param test_interval: the log interval in log_test_data(). Default to 1. :param update_interval: the log interval in log_update_data(). Default to 1000. + :param info_interval: the log interval in log_info_data(). Default to 1. :param save_interval: the save interval in save_data(). Default to 1 (save at the end of each epoch). 
:param write_flush: whether to flush tensorboard result after each @@ -27,16 +28,17 @@ def __init__( train_interval: int = 1000, test_interval: int = 1, update_interval: int = 1000, + info_interval: int = 1, save_interval: int = 1, write_flush: bool = True, ) -> None: - super().__init__(train_interval, test_interval, update_interval) + super().__init__(train_interval, test_interval, update_interval, info_interval) self.save_interval = save_interval self.write_flush = write_flush self.last_save_step = -1 self.writer = writer - def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: + def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: for k, v in data.items(): self.writer.add_scalar(k, v, global_step=step) if self.write_flush: # issue 580 diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index e984d464d..53dbf107a 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -6,7 +6,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.utils import BaseLogger, TensorboardLogger -from tianshou.utils.logger.base import LOG_DATA_TYPE +from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE with contextlib.suppress(ImportError): import wandb @@ -30,6 +30,7 @@ class WandbLogger(BaseLogger): :param test_interval: the log interval in log_test_data(). Default to 1. :param update_interval: the log interval in log_update_data(). Default to 1000. + :param info_interval: the log interval in log_info_data(). Default to 1. :param save_interval: the save interval in save_data(). Default to 1 (save at the end of each epoch). :param write_flush: whether to flush tensorboard result after each @@ -46,6 +47,7 @@ def __init__( train_interval: int = 1000, test_interval: int = 1, update_interval: int = 1000, + info_interval: int = 1, save_interval: int = 1000, write_flush: bool = True, project: str | None = None, @@ -55,7 +57,7 @@ def __init__( config: argparse.Namespace | dict | None = None, monitor_gym: bool = True, ) -> None: - super().__init__(train_interval, test_interval, update_interval) + super().__init__(train_interval, test_interval, update_interval, info_interval) self.last_save_step = -1 self.save_interval = save_interval self.write_flush = write_flush @@ -91,7 +93,7 @@ def load(self, writer: SummaryWriter) -> None: self.write_flush, ) - def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: + def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: if self.tensorboard_logger is None: raise Exception( "`logger` needs to load the Tensorboard Writer before " diff --git a/tianshou/utils/logging.py b/tianshou/utils/logging.py index 607f09c21..dcda429e6 100644 --- a/tianshou/utils/logging.py +++ b/tianshou/utils/logging.py @@ -20,6 +20,22 @@ _logFormat = LOG_DEFAULT_FORMAT +def set_numerical_fields_to_precision(data: dict[str, Any], precision: int = 3) -> dict[str, Any]: + """Returns a copy of the given dictionary with all numerical values rounded to the given precision. + + Note: does not recurse into nested dictionaries. 
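Putting the logger changes together, a hedged usage sketch of the new `info_interval` and `log_info_data` path (the log directory and the logged values are made up):

```python
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import TensorboardLogger

writer = SummaryWriter("log/example")  # made-up log dir
logger = TensorboardLogger(writer, update_interval=1000, info_interval=1)

# presumably called at the end of an epoch with the gathered info dict; nested
# keys are flattened by prepare_dict_for_logging before being written:
logger.log_info_data({"best_reward": 195.0, "timing": {"total_time": 12.3}}, step=1)
# -> scalars "info/best_reward" and "info/timing/total_time" at step 1
```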
+ + :param data: a dictionary + :param precision: the precision to be used + """ + result = {} + for k, v in data.items(): + if isinstance(v, float): + v = round(v, precision) + result[k] = v + return result + + def remove_log_handlers() -> None: """Removes all current log handlers.""" logger = getLogger() diff --git a/tianshou/utils/optim.py b/tianshou/utils/optim.py index 0c1093cc9..c69ef71db 100644 --- a/tianshou/utils/optim.py +++ b/tianshou/utils/optim.py @@ -8,19 +8,24 @@ def optim_step( loss: torch.Tensor, optim: torch.optim.Optimizer, - module: nn.Module, + module: nn.Module | None = None, max_grad_norm: float | None = None, ) -> None: - """Perform a single optimization step. + """Perform a single optimization step: zero_grad -> backward (-> clip_grad_norm) -> step. :param loss: :param optim: - :param module: + :param module: the module to optimize, required if max_grad_norm is passed :param max_grad_norm: if passed, will clip gradients using this """ optim.zero_grad() loss.backward() if max_grad_norm: + if not module: + raise ValueError( + "module must be passed if max_grad_norm is passed. " + "Note: often the module will be the policy, i.e.`self`", + ) nn.utils.clip_grad_norm_(module.parameters(), max_norm=max_grad_norm) optim.step() diff --git a/tianshou/utils/statistics.py b/tianshou/utils/statistics.py index 1f336c59c..779b0babe 100644 --- a/tianshou/utils/statistics.py +++ b/tianshou/utils/statistics.py @@ -29,7 +29,10 @@ def __init__(self, size: int = 100) -> None: self.cache: list[np.number] = [] self.banned = [np.inf, np.nan, -np.inf] - def add(self, data_array: Number | np.number | list | np.ndarray | torch.Tensor) -> float: + def add( + self, + data_array: Number | float | np.number | list | np.ndarray | torch.Tensor, + ) -> float: """Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with only one element, a python scalar, or