From 522f7fbf98853b1088fab97ce296106f689634f7 Mon Sep 17 00:00:00 2001
From: maxhuettenrauch
Date: Sat, 30 Dec 2023 11:09:03 +0100
Subject: [PATCH] Feature/dataclasses (#996)

This PR adds strict typing to the output of `update` and `learn` in all
policies. It will likely be the last large refactoring PR before the next
release (0.6.0, not 1.0.0), so it requires some attention.

Several difficulties were encountered on the path to that goal:

1. The policy hierarchy is actually "broken" in the sense that the keys of the
   dicts returned by `learn` did not follow the same extension (inheritance)
   pattern as the policies themselves. This is a real problem and should be
   addressed in the near future. More generally, several aspects of the policy
   design and hierarchy deserve a dedicated discussion.
2. Each policy needs to be generic in its stats return type, because one might
   want to extend the policy at some point and then also extend the stats.
   Even within the Tianshou code base this pattern is necessary in many places.
3. The interaction between `learn` and `update` is a bit quirky; we currently
   handle it by having `update` modify a dedicated field inside
   `TrainingStats`, whereas all other fields are handled by `learn`.
4. The ICM module is a policy wrapper and therefore required a
   `TrainingStatsWrapper`. The latter relies on a bunch of black magic.

They were addressed by:

1. Living with the broken hierarchy, which is now made visible by the bounds
   used in the generics. We use `type: ignore` where appropriate.
2. Making all policies generic, with bounds following the policy inheritance
   hierarchy (which is incorrect, see above). We experimented a bit with
   nested `TrainingStats` classes, but that seemed to add more complexity and
   be harder to understand. Unfortunately, mypy thinks that the code below is
   wrong, which is why we have to add `type: ignore` to the return of each
   `learn`:

   ```python
   T = TypeVar("T", bound=int)


   def f() -> T:
       return 3
   ```

3. See above.
4. Writing representative tests for the `TrainingStatsWrapper`. Still, the
   black magic might cause nasty surprises down the line (I am not proud of
   it)...
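For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the generics setup described in point 2 above. All names (`StatsBase`, `MyAlgoStats`, `PolicyBase`, `MyAlgoPolicy`) are illustrative only and are not the actual Tianshou classes, which live in `tianshou.policy.base`:

```python
from dataclasses import dataclass
from typing import Generic, TypeVar


@dataclass(kw_only=True)
class StatsBase:
    # Filled in by `update`, not by `learn` (cf. point 3 above).
    train_time: float = 0.0


@dataclass(kw_only=True)
class MyAlgoStats(StatsBase):
    # Algorithm-specific field produced by the concrete policy's `learn`.
    loss: float


# Bound TypeVars mirror the policy inheritance hierarchy.
TStats = TypeVar("TStats", bound=StatsBase)
TMyAlgoStats = TypeVar("TMyAlgoStats", bound=MyAlgoStats)


class PolicyBase(Generic[TStats]):
    def learn(self, batch: object) -> TStats:
        raise NotImplementedError


class MyAlgoPolicy(PolicyBase[TMyAlgoStats]):
    def learn(self, batch: object) -> TMyAlgoStats:
        # Same situation as `def f() -> T: return 3` above: mypy rejects
        # returning a concrete instance where the bound TypeVar is declared,
        # hence the `type: ignore[return-value]` on each `learn`.
        return MyAlgoStats(loss=0.0)  # type: ignore[return-value]
```

This keeps the concrete policy itself extensible: a further subclass can declare `class MySubPolicy(MyAlgoPolicy[MySubStats])` and thereby narrow the `learn` return type without changing the method body.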
Closes #933 --------- Co-authored-by: Maximilian Huettenrauch Co-authored-by: Michael Panchenko --- docs/02_notebooks/L0_overview.ipynb | 2 +- docs/02_notebooks/L4_Policy.ipynb | 91 ++++-- docs/02_notebooks/L5_Collector.ipynb | 14 +- docs/02_notebooks/L6_Trainer.ipynb | 2 +- docs/02_notebooks/L7_Experiment.ipynb | 2 +- docs/spelling_wordlist.txt | 7 + examples/atari/atari_c51.py | 2 +- examples/atari/atari_dqn.py | 2 +- examples/atari/atari_fqf.py | 2 +- examples/atari/atari_iqn.py | 2 +- examples/atari/atari_ppo.py | 2 +- examples/atari/atari_qrdqn.py | 2 +- examples/atari/atari_rainbow.py | 2 +- examples/atari/atari_sac.py | 2 +- examples/box2d/acrobot_dualdqn.py | 5 +- examples/box2d/bipedal_bdq.py | 5 +- examples/box2d/bipedal_hardcore_sac.py | 3 +- examples/box2d/lunarlander_dqn.py | 5 +- examples/box2d/mcc_sac.py | 5 +- examples/inverse/irl_gail.py | 2 +- examples/mujoco/fetch_her_ddpg.py | 2 +- examples/mujoco/mujoco_a2c.py | 2 +- examples/mujoco/mujoco_a2c_hl.py | 2 +- examples/mujoco/mujoco_ddpg.py | 2 +- examples/mujoco/mujoco_npg.py | 2 +- examples/mujoco/mujoco_ppo.py | 2 +- examples/mujoco/mujoco_redq.py | 2 +- examples/mujoco/mujoco_redq_hl.py | 2 +- examples/mujoco/mujoco_reinforce.py | 2 +- examples/mujoco/mujoco_sac.py | 2 +- examples/mujoco/mujoco_sac_hl.py | 2 +- examples/mujoco/mujoco_td3.py | 2 +- examples/mujoco/mujoco_td3_hl.py | 2 +- examples/mujoco/mujoco_trpo.py | 2 +- examples/offline/atari_bcq.py | 4 +- examples/offline/atari_cql.py | 4 +- examples/offline/atari_crr.py | 4 +- examples/offline/atari_il.py | 4 +- examples/offline/d4rl_bcq.py | 2 +- examples/offline/d4rl_cql.py | 2 +- examples/offline/d4rl_il.py | 2 +- examples/offline/d4rl_td3_bc.py | 2 +- examples/vizdoom/vizdoom_c51.py | 8 +- examples/vizdoom/vizdoom_ppo.py | 8 +- test/base/test_collector.py | 46 +-- test/base/test_logger.py | 59 ++++ test/base/test_stats.py | 40 +++ test/continuous/test_ddpg.py | 5 +- test/continuous/test_npg.py | 5 +- test/continuous/test_ppo.py | 13 +- test/continuous/test_redq.py | 5 +- test/continuous/test_sac_with_il.py | 16 +- test/continuous/test_td3.py | 15 +- test/continuous/test_trpo.py | 5 +- test/discrete/test_a2c_with_il.py | 10 +- test/discrete/test_bdq.py | 7 +- test/discrete/test_c51.py | 5 +- test/discrete/test_dqn.py | 5 +- test/discrete/test_drqn.py | 5 +- test/discrete/test_fqf.py | 5 +- test/discrete/test_iqn.py | 5 +- test/discrete/test_pg.py | 5 +- test/discrete/test_ppo.py | 5 +- test/discrete/test_qrdqn.py | 5 +- test/discrete/test_rainbow.py | 5 +- test/discrete/test_sac.py | 5 +- test/modelbased/test_dqn_icm.py | 5 +- test/modelbased/test_ppo_icm.py | 5 +- test/modelbased/test_psrl.py | 5 +- test/offline/gather_cartpole_data.py | 4 +- test/offline/gather_pendulum_data.py | 3 +- test/offline/test_bcq.py | 5 +- test/offline/test_cql.py | 17 +- test/offline/test_discrete_bcq.py | 5 +- test/offline/test_discrete_cql.py | 5 +- test/offline/test_discrete_crr.py | 5 +- test/offline/test_gail.py | 5 +- test/offline/test_td3_bc.py | 15 +- test/pettingzoo/test_pistonball.py | 2 +- test/pettingzoo/test_pistonball_continuous.py | 2 +- test/pettingzoo/test_tic_tac_toe.py | 2 +- tianshou/data/__init__.py | 14 +- tianshou/data/collector.py | 160 +++++----- tianshou/data/stats.py | 86 +++++ tianshou/data/types.py | 32 +- tianshou/highlevel/experiment.py | 14 +- tianshou/policy/__init__.py | 3 +- tianshou/policy/base.py | 130 +++++++- tianshou/policy/imitation/base.py | 25 +- tianshou/policy/imitation/bcq.py | 32 +- tianshou/policy/imitation/cql.py | 45 ++- 
tianshou/policy/imitation/discrete_bcq.py | 35 ++- tianshou/policy/imitation/discrete_cql.py | 33 +- tianshou/policy/imitation/discrete_crr.py | 32 +- tianshou/policy/imitation/gail.py | 42 ++- tianshou/policy/imitation/td3_bc.py | 27 +- tianshou/policy/modelbased/icm.py | 48 ++- tianshou/policy/modelbased/psrl.py | 27 +- tianshou/policy/modelfree/a2c.py | 40 ++- tianshou/policy/modelfree/bdq.py | 19 +- tianshou/policy/modelfree/c51.py | 19 +- tianshou/policy/modelfree/ddpg.py | 21 +- tianshou/policy/modelfree/discrete_sac.py | 31 +- tianshou/policy/modelfree/dqn.py | 20 +- tianshou/policy/modelfree/fqf.py | 31 +- tianshou/policy/modelfree/iqn.py | 19 +- tianshou/policy/modelfree/npg.py | 34 +- tianshou/policy/modelfree/pg.py | 31 +- tianshou/policy/modelfree/ppo.py | 40 ++- tianshou/policy/modelfree/qrdqn.py | 19 +- tianshou/policy/modelfree/rainbow.py | 21 +- tianshou/policy/modelfree/redq.py | 40 ++- tianshou/policy/modelfree/sac.py | 47 ++- tianshou/policy/modelfree/td3.py | 31 +- tianshou/policy/modelfree/trpo.py | 35 ++- tianshou/policy/multiagent/mapolicy.py | 100 ++++-- tianshou/policy/random.py | 16 +- tianshou/trainer/base.py | 294 +++++++++++------- tianshou/trainer/utils.py | 112 +++---- tianshou/utils/conversion.py | 25 ++ tianshou/utils/logger/base.py | 124 ++++++-- tianshou/utils/logger/tensorboard.py | 8 +- tianshou/utils/logger/wandb.py | 8 +- tianshou/utils/logging.py | 16 + tianshou/utils/optim.py | 11 +- tianshou/utils/statistics.py | 5 +- 126 files changed, 1767 insertions(+), 823 deletions(-) create mode 100644 test/base/test_logger.py create mode 100644 test/base/test_stats.py create mode 100644 tianshou/data/stats.py create mode 100644 tianshou/utils/conversion.py diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb index de2e06cfd..41a7143e8 100644 --- a/docs/02_notebooks/L0_overview.ipynb +++ b/docs/02_notebooks/L0_overview.ipynb @@ -141,7 +141,7 @@ "# Let's watch its performance!\n", "policy.eval()\n", "result = test_collector.collect(n_episode=1, render=False)\n", - "print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))" + "print(\"Final reward: {}, length: {}\".format(result.returns.mean(), result.lens.mean()))" ] }, { diff --git a/docs/02_notebooks/L4_Policy.ipynb b/docs/02_notebooks/L4_Policy.ipynb index 4a815d07a..e31214fb4 100644 --- a/docs/02_notebooks/L4_Policy.ipynb +++ b/docs/02_notebooks/L4_Policy.ipynb @@ -54,16 +54,19 @@ }, "outputs": [], "source": [ - "from typing import Dict, List\n", + "from typing import cast\n", "\n", "import numpy as np\n", "import torch\n", "import gymnasium as gym\n", "\n", - "from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as\n", - "from tianshou.policy import BasePolicy\n", + "from dataclasses import dataclass\n", + "\n", + "from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as, SequenceSummaryStats\n", + "from tianshou.policy import BasePolicy, TrainingStats\n", "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import Actor" + "from tianshou.utils.net.discrete import Actor\n", + "from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol" ] }, { @@ -102,12 +105,14 @@ "\n", "\n", "\n", - "1. Since Tianshou is a **Deep** RL libraries, there should be a policy network in our Policy Module, also a Torch optimizer.\n", - "2. 
In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to preprocess training data and computes quantities like episodic returns (gradient free), then it will call `Policy.learn()` to perform the back-propagation.\n", + "1. Since Tianshou is a **Deep** RL libraries, there should be a policy network in our Policy Module, \n", + "also a Torch optimizer.\n", + "2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to \n", + "preprocess training data and computes quantities like episodic returns (gradient free), \n", + "then it will call `Policy.learn()` to perform the back-propagation.\n", + "3. Each Policy is accompanied by a dedicated implementation of `TrainingStats` to store details of training.\n", "\n", "Then we get the implementation below.\n", - "\n", - "\n", "\n" ] }, @@ -119,7 +124,14 @@ }, "outputs": [], "source": [ - "class REINFORCEPolicy(BasePolicy):\n", + "@dataclass(kw_only=True)\n", + "class REINFORCETrainingStats(TrainingStats):\n", + " \"\"\"A dedicated class for REINFORCE training statistics.\"\"\"\n", + "\n", + " loss: SequenceSummaryStats\n", + "\n", + "\n", + "class REINFORCEPolicy(BasePolicy[REINFORCETrainingStats]):\n", " \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n", "\n", " def __init__(\n", @@ -138,7 +150,7 @@ " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", " pass\n", "\n", - " def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", + " def learn(self, batch: Batch, batch_size: int, repeat: int) -> REINFORCETrainingStats:\n", " \"\"\"Perform the back-propagation.\"\"\"\n", " return" ] @@ -220,7 +232,9 @@ }, "source": [ "### Policy.learn()\n", - "Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Final we can construct our loss function and perform the back-propagation." + "Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Finally,\n", + "we can construct our loss function and perform the back-propagation. The method \n", + "should look something like this:" ] }, { @@ -231,22 +245,24 @@ }, "outputs": [], "source": [ - "def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", + "from tianshou.utils.optim import optim_step\n", + "\n", + "\n", + "def learn(self, batch: Batch, batch_size: int, repeat: int):\n", " \"\"\"Perform the back-propagation.\"\"\"\n", - " logging_losses = []\n", + " train_losses = []\n", " for _ in range(repeat):\n", " for minibatch in batch.split(batch_size, merge_last=True):\n", - " self.optim.zero_grad()\n", " result = self(minibatch)\n", " dist = result.dist\n", " act = to_torch_as(minibatch.act, result.act)\n", " ret = to_torch(minibatch.returns, torch.float, result.act.device)\n", " log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n", " loss = -(log_prob * ret).mean()\n", - " loss.backward()\n", - " self.optim.step()\n", - " logging_losses.append(loss.item())\n", - " return {\"loss\": logging_losses}" + " optim_step(loss, self.optim)\n", + " train_losses.append(loss.item())\n", + "\n", + " return REINFORCETrainingStats(loss=SequenceSummaryStats.from_sequence(train_losses))" ] }, { @@ -256,7 +272,12 @@ }, "source": [ "## Implementation\n", - "Finally we can assemble the implemented methods and form a REINFORCE Policy." + "Now we can assemble the methods and form a REINFORCE Policy. 
The outputs of\n", + "`learn` will be collected to a dedicated dataclass.\n", + "\n", + "We will also use protocols to specify what fields are expected and produced inside a `Batch` in\n", + "each processing step. By using protocols, we can get better type checking and IDE support \n", + "without having to implement a separate class for each combination of fields." ] }, { @@ -290,30 +311,33 @@ " act = dist.sample()\n", " return Batch(act=act, dist=dist)\n", "\n", - " def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n", + " def process_fn(\n", + " self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray\n", + " ) -> BatchWithReturnsProtocol:\n", " \"\"\"Compute the discounted returns for each transition.\"\"\"\n", " returns, _ = self.compute_episodic_return(\n", " batch, buffer, indices, gamma=0.99, gae_lambda=1.0\n", " )\n", " batch.returns = returns\n", - " return batch\n", + " return cast(BatchWithReturnsProtocol, batch)\n", "\n", - " def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n", + " def learn(\n", + " self, batch: BatchWithReturnsProtocol, batch_size: int, repeat: int\n", + " ) -> REINFORCETrainingStats:\n", " \"\"\"Perform the back-propagation.\"\"\"\n", - " logging_losses = []\n", + " train_losses = []\n", " for _ in range(repeat):\n", " for minibatch in batch.split(batch_size, merge_last=True):\n", - " self.optim.zero_grad()\n", " result = self(minibatch)\n", " dist = result.dist\n", " act = to_torch_as(minibatch.act, result.act)\n", " ret = to_torch(minibatch.returns, torch.float, result.act.device)\n", " log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n", " loss = -(log_prob * ret).mean()\n", - " loss.backward()\n", - " self.optim.step()\n", - " logging_losses.append(loss.item())\n", - " return {\"loss\": logging_losses}" + " optim_step(loss, self.optim)\n", + " train_losses.append(loss.item())\n", + "\n", + " return REINFORCETrainingStats(loss=SequenceSummaryStats.from_sequence(train_losses))" ] }, { @@ -370,8 +394,8 @@ "source": [ "print(policy)\n", "print(\"========================================\")\n", - "for para in policy.parameters():\n", - " print(para.shape)" + "for param in policy.parameters():\n", + " print(param.shape)" ] }, { @@ -831,6 +855,13 @@ "\n", "" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/docs/02_notebooks/L5_Collector.ipynb b/docs/02_notebooks/L5_Collector.ipynb index 0dbb77de6..fc5588401 100644 --- a/docs/02_notebooks/L5_Collector.ipynb +++ b/docs/02_notebooks/L5_Collector.ipynb @@ -116,9 +116,9 @@ "source": [ "collect_result = test_collector.collect(n_episode=9)\n", "print(collect_result)\n", - "print(\"Rewards of 9 episodes are {}\".format(collect_result[\"rews\"]))\n", - "print(\"Average episode reward is {}.\".format(collect_result[\"rew\"]))\n", - "print(\"Average episode length is {}.\".format(collect_result[\"len\"]))" + "print(f\"Returns of 9 episodes are {collect_result.returns}\")\n", + "print(f\"Average episode return is {collect_result.returns_stat.mean}.\")\n", + "print(f\"Average episode length is {collect_result.lens_stat.mean}.\")" ] }, { @@ -146,9 +146,9 @@ "test_collector.reset()\n", "collect_result = test_collector.collect(n_episode=9, random=True)\n", "print(collect_result)\n", - "print(\"Rewards of 9 episodes are {}\".format(collect_result[\"rews\"]))\n", - "print(\"Average episode reward is 
{}.\".format(collect_result[\"rew\"]))\n", - "print(\"Average episode length is {}.\".format(collect_result[\"len\"]))" + "print(f\"Returns of 9 episodes are {collect_result.returns}\")\n", + "print(f\"Average episode return is {collect_result.returns_stat.mean}.\")\n", + "print(f\"Average episode length is {collect_result.lens_stat.mean}.\")" ] }, { @@ -157,7 +157,7 @@ "id": "sKQRTiG10ljU" }, "source": [ - "Seems that an initialized policy performs even worse than a random policy without any training." + "It seems like an initialized policy performs even worse than a random policy without any training." ] }, { diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb index db6b0fb86..da4010b70 100644 --- a/docs/02_notebooks/L6_Trainer.ipynb +++ b/docs/02_notebooks/L6_Trainer.ipynb @@ -146,7 +146,7 @@ "replaybuffer.reset()\n", "for i in range(10):\n", " evaluation_result = test_collector.collect(n_episode=10)\n", - " print(\"Evaluation reward is {}\".format(evaluation_result[\"rew\"]))\n", + " print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n", " train_collector.collect(n_step=2000)\n", " # 0 means taking all data stored in train_collector.buffer\n", " policy.update(0, train_collector.buffer, batch_size=512, repeat=1)\n", diff --git a/docs/02_notebooks/L7_Experiment.ipynb b/docs/02_notebooks/L7_Experiment.ipynb index 55a3be144..ad450374e 100644 --- a/docs/02_notebooks/L7_Experiment.ipynb +++ b/docs/02_notebooks/L7_Experiment.ipynb @@ -303,7 +303,7 @@ "# Let's watch its performance!\n", "policy.eval()\n", "result = test_collector.collect(n_episode=1, render=False)\n", - "print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))" + "print(f\"Final episode reward: {result.returns.mean()}, length: {result.lens.mean()}\")" ] } ], diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 9066e8694..1849f2efc 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -239,3 +239,10 @@ logp autogenerated subpackage subpackages +recurse +rollout +rollouts +prepend +prepends +dict +dicts diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index c0837d645..21b42aa7b 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -184,7 +184,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 2d94b3356..520761e55 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -225,7 +225,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 671a16d1b..97bd7ded4 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -197,7 +197,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if 
args.watch: diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index af7bfe1d9..dc59da469 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -194,7 +194,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index c46abf49d..1dd695615 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -252,7 +252,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index c72f80d97..5231c0391 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -180,7 +180,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 5fd3d4380..ded22f5d9 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -223,7 +223,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 30e2a21be..543d4b8fb 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -236,7 +236,7 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() + rew = result.returns_stat.mean print(f"Mean reward (over {result['n/ep']} episodes): {rew}") if args.watch: diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index bfd248fa2..003caa3fd 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -130,7 +130,7 @@ def test_fn(epoch, env_step): logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
@@ -139,8 +139,7 @@ def test_fn(epoch, env_step): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 1d9aabfd5..ff532830a 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -145,7 +145,7 @@ def test_fn(epoch, env_step): logger=logger, ).run() - # assert stop_fn(result["best_reward"]) + # assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! @@ -154,8 +154,7 @@ def test_fn(epoch, env_step): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index e3703b3c9..2dda6b610 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -187,8 +187,7 @@ def stop_fn(mean_rewards): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 8a3154195..007a310d8 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -126,7 +126,7 @@ def test_fn(epoch, env_step): logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! @@ -135,8 +135,7 @@ def test_fn(epoch, env_step): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 748c9d73c..49a34af14 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -140,7 +140,7 @@ def stop_fn(mean_rewards): logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
@@ -148,8 +148,7 @@ def stop_fn(mean_rewards): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 4e464156c..3dcc92fcd 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -257,7 +257,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 05e31019f..6e37100db 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -225,7 +225,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index fbd2ec989..b58c802ba 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -222,7 +222,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 7b7dba7b6..b3772b4d6 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -21,7 +21,7 @@ def main( experiment_config: ExperimentConfig, - task: str = "Ant-v3", + task: str = "Ant-v4", buffer_size: int = 4096, hidden_sizes: Sequence[int] = (64, 64), lr: float = 7e-4, diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index f20292831..75948679b 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -171,7 +171,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index bad4f0e28..6b1e8c13d 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -219,7 +219,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff 
--git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 9b190c7e4..e7287f583 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -227,7 +227,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 80b07b635..e39c7b092 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -199,7 +199,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 66219bf8c..cfcdb792f 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -18,7 +18,7 @@ def main( experiment_config: ExperimentConfig, - task: str = "Ant-v3", + task: str = "Ant-v4", buffer_size: int = 1000000, hidden_sizes: Sequence[int] = (256, 256), ensemble_size: int = 10, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 50bf312a6..1b394189b 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -199,7 +199,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index c42775379..a53c04c3a 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -193,7 +193,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 7a556e219..6af4c6192 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -17,7 +17,7 @@ def main( experiment_config: ExperimentConfig, - task: str = "Ant-v3", + task: str = "Ant-v4", buffer_size: int = 1000000, hidden_sizes: Sequence[int] = (256, 256), actor_lr: float = 1e-3, diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 667db0afb..8e9476143 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -191,7 +191,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if 
__name__ == "__main__": diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 1f783f106..67b9c1847 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -22,7 +22,7 @@ def main( experiment_config: ExperimentConfig, - task: str = "Ant-v3", + task: str = "Ant-v4", buffer_size: int = 1000000, hidden_sizes: Sequence[int] = (256, 256), actor_lr: float = 3e-4, diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index e37cf9196..97fc0f910 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -224,7 +224,7 @@ def save_best_fn(policy): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 63e4256f0..e2015382d 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -188,8 +188,8 @@ def watch(): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) - rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + rew = result.returns_stat.mean + print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") if args.watch: watch() diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 14947b04f..82a0d6edf 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -164,8 +164,8 @@ def watch(): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) - rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + rew = result.returns_stat.mean + print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") if args.watch: watch() diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 7432560d1..db768f2b0 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -186,8 +186,8 @@ def watch(): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) - rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + rew = result.returns_stat.mean + print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") if args.watch: watch() diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index c4576924f..3efb56f91 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -147,8 +147,8 @@ def watch(): test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) - rew = result["rews"].mean() - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + rew = result.returns_stat.mean + print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") if args.watch: watch() diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index a710c44ac..5e3746e34 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -229,7 +229,7 @@ def watch(): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, 
render=args.render) - print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 1c29a1152..160137fce 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -366,7 +366,7 @@ def watch(): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index 617091d36..89384127c 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -166,7 +166,7 @@ def watch(): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index fef7d3f81..c4c4e9dd9 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -216,7 +216,7 @@ def watch(): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 03ff92e0d..46a33769f 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -190,10 +190,10 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() - lens = result["lens"].mean() * args.skip_num - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') - print(f'Mean length (over {result["n/ep"]} episodes): {lens}') + rew = result.returns_stat.mean + lens = result.lens_stat.mean * args.skip_num + print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") + print(f"Mean length (over {result.n_collected_episodes} episodes): {lens}") if args.watch: watch() diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index d84a9a34a..0811a24dd 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -255,10 +255,10 @@ def watch(): print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result["rews"].mean() - lens = result["lens"].mean() * args.skip_num - print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') - print(f'Mean length (over {result["n/ep"]} episodes): {lens}') + rew = result.returns_stat.mean + lens = result.lens_stat.mean * args.skip_num + print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") + print(f"Mean length (over {result.n_collected_episodes} episodes): {lens}") if args.watch: watch() diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 1f8d4f556..985e6ef50 100644 
--- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -188,7 +188,7 @@ def test_collector(gym_reset_kwargs): assert np.all(c2obs == obs1) or np.all(c2obs == obs2) c2.reset_env(gym_reset_kwargs=gym_reset_kwargs) c2.reset_buffer() - assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs)["n/ep"] == 8 + assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs).n_collected_episodes == 8 valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57] obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3] assert np.all(c2.buffer.obs[:, 0] == obs) @@ -237,9 +237,9 @@ def test_collector_with_async(gym_reset_kwargs): ptr = [0, 0, 0, 0] for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs) - assert result["n/ep"] >= n_episode + assert result.n_collected_episodes >= n_episode # check buffer data, obs and obs_next, env_id - for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]): + for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): env_len = i + 2 total = env_len * count indices = np.arange(ptr[i], ptr[i] + total) % bufsize @@ -252,7 +252,7 @@ def test_collector_with_async(gym_reset_kwargs): # test async n_step, for now the buffer should be full of data for n_step in tqdm.trange(1, 15, desc="test async n_step"): result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs) - assert result["n/st"] >= n_step + assert result.n_collected_steps >= n_step for i in range(4): env_len = i + 2 seq = np.arange(env_len) @@ -284,12 +284,12 @@ def test_collector_with_dict_state(): ) c1.collect(n_step=12) result = c1.collect(n_episode=8) - assert result["n/ep"] == 8 - lens = np.bincount(result["lens"]) + assert result.n_collected_episodes == 8 + lens = np.bincount(result.lens) assert ( - result["n/st"] == 21 + result.n_collected_steps == 21 and np.all(lens == [0, 0, 2, 2, 2, 2]) - or result["n/st"] == 20 + or result.n_collected_steps == 20 and np.all(lens == [0, 0, 3, 1, 2, 2]) ) batch, _ = c1.buffer.sample(10) @@ -407,9 +407,9 @@ def test_collector_with_ma(): policy = MyPolicy() c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) # n_step=3 will collect a full episode - rew = c0.collect(n_step=3)["rews"] + rew = c0.collect(n_step=3).returns assert len(rew) == 0 - rew = c0.collect(n_episode=2)["rews"] + rew = c0.collect(n_episode=2).returns assert rew.shape == (2, 4) assert np.all(rew == 1) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] @@ -420,9 +420,9 @@ def test_collector_with_ma(): VectorReplayBuffer(total_size=100, buffer_num=4), Logger.single_preprocess_fn, ) - rew = c1.collect(n_step=12)["rews"] + rew = c1.collect(n_step=12).returns assert rew.shape == (2, 4) and np.all(rew == 1), rew - rew = c1.collect(n_episode=8)["rews"] + rew = c1.collect(n_episode=8).returns assert rew.shape == (8, 4) assert np.all(rew == 1) batch, _ = c1.buffer.sample(10) @@ -528,7 +528,7 @@ def test_collector_with_ma(): VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), Logger.single_preprocess_fn, ) - rew = c2.collect(n_episode=10)["rews"] + rew = c2.collect(n_episode=10).returns assert rew.shape == (10, 4) assert np.all(rew == 1) batch, _ = c2.buffer.sample(10) @@ -580,8 +580,8 @@ def test_collector_with_atari_setting(): c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) c3.collect(n_step=12) result = c3.collect(n_episode=9) - assert result["n/ep"] == 9 - assert result["n/st"] == 23 + 
assert result.n_collected_episodes == 9 + assert result.n_collected_steps == 23 assert c3.buffer.obs.shape == (100, 4, 84, 84) obs = np.zeros_like(c3.buffer.obs) obs[np.arange(8)] = reference_obs[[0, 1, 0, 1, 0, 1, 0, 1]] @@ -608,8 +608,8 @@ def test_collector_with_atari_setting(): ) c4.collect(n_step=12) result = c4.collect(n_episode=9) - assert result["n/ep"] == 9 - assert result["n/st"] == 23 + assert result.n_collected_episodes == 9 + assert result.n_collected_steps == 23 assert c4.buffer.obs.shape == (100, 84, 84) obs = np.zeros_like(c4.buffer.obs) slice_obs = reference_obs[:, -1] @@ -676,8 +676,8 @@ def test_collector_with_atari_setting(): assert len(buf) == 5 assert len(c5.buffer) == 12 result = c5.collect(n_episode=9) - assert result["n/ep"] == 9 - assert result["n/st"] == 23 + assert result.n_collected_episodes == 9 + assert result.n_collected_steps == 23 assert len(buf) == 35 assert np.all( buf.obs[: len(buf)] @@ -768,11 +768,11 @@ def test_collector_with_atari_setting(): # test buffer=None c6 = Collector(policy, envs) result1 = c6.collect(n_step=12) - for key in ["n/ep", "n/st", "rews", "lens"]: - assert np.allclose(result1[key], result_[key]) + for key in ["n_collected_episodes", "n_collected_steps", "returns", "lens"]: + assert np.allclose(getattr(result1, key), getattr(result_, key)) result2 = c6.collect(n_episode=9) - for key in ["n/ep", "n/st", "rews", "lens"]: - assert np.allclose(result2[key], result[key]) + for key in ["n_collected_episodes", "n_collected_steps", "returns", "lens"]: + assert np.allclose(getattr(result2, key), getattr(result, key)) @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") diff --git a/test/base/test_logger.py b/test/base/test_logger.py new file mode 100644 index 000000000..1634f4d8f --- /dev/null +++ b/test/base/test_logger.py @@ -0,0 +1,59 @@ +import numpy as np +import pytest + +from tianshou.utils import BaseLogger + + +class TestBaseLogger: + @staticmethod + @pytest.mark.parametrize( + "input_dict, expected_output", + [ + ({"a": 1, "b": {"c": 2, "d": {"e": 3}}}, {"a": 1, "b/c": 2, "b/d/e": 3}), + ({"a": {"b": {"c": 1}}}, {"a/b/c": 1}), + ], + ) + def test_flatten_dict_basic(input_dict, expected_output): + result = BaseLogger.prepare_dict_for_logging(input_dict) + assert result == expected_output + + @staticmethod + @pytest.mark.parametrize( + "input_dict, delimiter, expected_output", + [ + ({"a": {"b": {"c": 1}}}, "|", {"a|b|c": 1}), + ({"a": {"b": {"c": 1}}}, ".", {"a.b.c": 1}), + ], + ) + def test_flatten_dict_custom_delimiter(input_dict, delimiter, expected_output): + result = BaseLogger.prepare_dict_for_logging(input_dict, delimiter=delimiter) + assert result == expected_output + + @staticmethod + @pytest.mark.parametrize( + "input_dict, exclude_arrays, expected_output", + [ + ( + {"a": np.array([1, 2, 3]), "b": {"c": np.array([4, 5, 6])}}, + False, + {"a": np.array([1, 2, 3]), "b/c": np.array([4, 5, 6])}, + ), + ({"a": np.array([1, 2, 3]), "b": {"c": np.array([4, 5, 6])}}, True, {}), + ], + ) + def test_flatten_dict_exclude_arrays(input_dict, exclude_arrays, expected_output): + result = BaseLogger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays) + assert result.keys() == expected_output.keys() + for val1, val2 in zip(result.values(), expected_output.values(), strict=True): + assert np.all(val1 == val2) + + @staticmethod + @pytest.mark.parametrize( + "input_dict, expected_output", + [ + ({"a": (1,), "b": {"c": "2", "d": {"e": 3}}}, {"b/d/e": 3}), + ], + ) + def 
test_flatten_dict_invalid_values_filtered_out(input_dict, expected_output): + result = BaseLogger.prepare_dict_for_logging(input_dict) + assert result == expected_output diff --git a/test/base/test_stats.py b/test/base/test_stats.py new file mode 100644 index 000000000..b9ec67a12 --- /dev/null +++ b/test/base/test_stats.py @@ -0,0 +1,40 @@ +import pytest + +from tianshou.policy.base import TrainingStats, TrainingStatsWrapper + + +class DummyTrainingStatsWrapper(TrainingStatsWrapper): + def __init__(self, wrapped_stats: TrainingStats, *, dummy_field: int): + self.dummy_field = dummy_field + super().__init__(wrapped_stats) + + +class TestStats: + @staticmethod + def test_training_stats_wrapper(): + train_stats = TrainingStats(train_time=1.0) + train_stats.loss_field = 12 + + wrapped_train_stats = DummyTrainingStatsWrapper(train_stats, dummy_field=42) + + # basic readout + assert wrapped_train_stats.train_time == 1.0 + assert wrapped_train_stats.loss_field == 12 + + # mutation of TrainingStats fields + wrapped_train_stats.train_time = 2.0 + wrapped_train_stats.smoothed_loss["foo"] = 50 + assert wrapped_train_stats.train_time == 2.0 + assert wrapped_train_stats.smoothed_loss["foo"] == 50 + + # loss stats dict + assert wrapped_train_stats.get_loss_stats_dict() == {"loss_field": 12, "dummy_field": 42} + + # new fields can't be added + with pytest.raises(AttributeError): + wrapped_train_stats.new_loss_field = 90 + + # existing fields, wrapped and not-wrapped, can be mutated + wrapped_train_stats.loss_field = 13 + wrapped_train_stats.dummy_field = 43 + assert wrapped_train_stats.wrapped_stats.loss_field == wrapped_train_stats.loss_field == 13 diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 360a80393..30e0e33de 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -126,7 +126,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -135,8 +135,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index ab02e4cde..edd51f274 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -144,7 +144,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -153,8 +153,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 9fe6f1895..1327ede6e 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -174,22 +174,21 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): save_checkpoint_fn=save_checkpoint_fn, ) - for epoch, epoch_stat, info in trainer: - 
print(f"Epoch: {epoch}") + for epoch_stat in trainer: + print(f"Epoch: {epoch_stat.epoch}") print(epoch_stat) - print(info) + # print(info) - assert stop_fn(info["best_reward"]) + assert stop_fn(epoch_stat.info_stat.best_reward) if __name__ == "__main__": - pprint.pprint(info) + pprint.pprint(epoch_stat) # Let's watch its performance! env = gym.make(args.task) policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_ppo_resume(args=get_args()): diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 87588e5b4..2258d6c20 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -153,7 +153,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -162,8 +162,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 13b44b5e9..9c18904d6 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -1,12 +1,14 @@ import argparse import os +import gymnasium as gym import numpy as np import pytest import torch from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import ImitationPolicy, SACPolicy from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger @@ -57,8 +59,11 @@ def get_args(): @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_sac_with_il(args=get_args()): # if you want to use python vector env, please refer to other test scripts - train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) - test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) + # train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) + # test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) + env = gym.make(args.task) + train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)]) + test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n args.max_action = env.action_space.high[0] @@ -146,7 +151,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) # here we define an imitation collector with a trivial policy policy.eval() @@ -172,7 +177,8 @@ def stop_fn(mean_rewards): ) il_test_collector = Collector( il_policy, - envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed), + # 
envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed), + gym.make(args.task), ) train_collector.reset() result = OffpolicyTrainer( @@ -188,7 +194,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 76659d0e0..ecad08bb0 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -143,22 +143,21 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ) - for epoch, epoch_stat, info in trainer: - print(f"Epoch: {epoch}") - print(epoch_stat) - print(info) + for epoch_stat in trainer: + print(f"Epoch: {epoch_stat.epoch}") + pprint.pprint(epoch_stat) + # print(info) - assert stop_fn(info["best_reward"]) + assert stop_fn(epoch_stat.info_stat.best_reward) if __name__ == "__main__": - pprint.pprint(info) + pprint.pprint(epoch_stat.info_stat) # Let's watch its performance! env = gym.make(args.task) policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index ee3f2bb5f..300c96db2 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -148,7 +148,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -157,8 +157,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index e833cc36d..97205c5e7 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -135,7 +135,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -144,8 +144,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") policy.eval() # here we define an imitation collector with a trivial policy @@ -173,7 +172,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -182,8 +181,7 @@ def stop_fn(mean_rewards): il_policy.eval() collector = Collector(il_policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: 
{result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 469c04cf4..f63224331 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -136,7 +136,7 @@ def stop_fn(mean_rewards): stop_fn=stop_fn, ).run() - # assert stop_fn(result["best_reward"]) + # assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! @@ -145,8 +145,9 @@ def stop_fn(mean_rewards): test_envs.seed(args.seed) test_collector.reset() collector_result = test_collector.collect(n_episode=args.test_num, render=args.render) - rews, lens = collector_result["rews"], collector_result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print( + f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}", + ) if __name__ == "__main__": diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index edb903240..e4406df18 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -186,7 +186,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -196,8 +196,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_c51_resume(args=get_args()): diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 6f23cfbad..751f849ff 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -139,7 +139,7 @@ def test_fn(epoch, env_step): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -149,8 +149,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_pdqn(args=get_args()): diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 68b3d53d3..5ca79fd30 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -122,7 +122,7 @@ def test_fn(epoch, env_step): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -131,8 +131,7 @@ def test_fn(epoch, env_step): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 55a66e450..84dd207e8 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ 
-156,7 +156,7 @@ def test_fn(epoch, env_step): logger=logger, update_per_step=args.update_per_step, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -166,8 +166,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_pfqf(args=get_args()): diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 504ada3a4..23c35b9a8 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -152,7 +152,7 @@ def test_fn(epoch, env_step): logger=logger, update_per_step=args.update_per_step, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -162,8 +162,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_piqn(args=get_args()): diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 310f0a135..6d9873ab3 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -117,7 +117,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -126,8 +126,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 53d2b29f5..e66dc23a9 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -140,7 +140,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -149,8 +149,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index e68211e96..6d54d59cf 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -143,7 +143,7 @@ def test_fn(epoch, env_step): logger=logger, update_per_step=args.update_per_step, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -153,8 +153,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - 
rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_pqrdqn(args=get_args()): diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index c36d7a136..c7e51a36e 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -205,7 +205,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -215,8 +215,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_rainbow_resume(args=get_args()): diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 49f4b9648..4831f12de 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -129,7 +129,7 @@ def stop_fn(mean_rewards): update_per_step=args.update_per_step, test_in_train=False, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -138,8 +138,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index e8c4a80d2..ecfb4d2b6 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -183,7 +183,7 @@ def test_fn(epoch, env_step): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -193,8 +193,7 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index e0858b14f..27adc12d8 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -179,7 +179,7 @@ def stop_fn(mean_rewards): save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -188,8 +188,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/modelbased/test_psrl.py 
b/test/modelbased/test_psrl.py index 45ddf1647..d2fbf7358 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -123,10 +123,9 @@ def stop_fn(mean_rewards): test_envs.seed(args.seed) test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.rew_mean}, length: {result.len_mean}") elif env.spec.reward_threshold: - assert result["best_reward"] >= env.spec.reward_threshold + assert result.best_reward >= env.spec.reward_threshold if __name__ == "__main__": diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index c42524cb6..cf4a5bd24 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -147,7 +147,7 @@ def test_fn(epoch, env_step): logger=logger, update_per_step=args.update_per_step, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) # save buffer in pickle format, for imitation learning unittest buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) @@ -159,5 +159,5 @@ def test_fn(epoch, env_step): else: with open(args.save_buffer_name, "wb") as f: pickle.dump(buf, f) - print(result["rews"].mean()) + print(result.returns_stat.mean) return buf diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index f08510a54..064369cd2 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -150,8 +150,7 @@ def stop_fn(mean_rewards): ).run() train_collector.reset() result = train_collector.collect(n_step=args.buffer_size) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if args.save_buffer_name.endswith(".hdf5"): buffer.save_hdf5(args.save_buffer_name) else: diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index da6e211ed..c82eadcd7 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -198,7 +198,7 @@ def watch(): logger=logger, show_progress=args.show_progress, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) # Let's watch its performance! if __name__ == "__main__": @@ -207,8 +207,7 @@ def watch(): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 1a7c86ba8..693eecc49 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -188,22 +188,23 @@ def stop_fn(mean_rewards): logger=logger, ) - for epoch, epoch_stat, info in trainer: - print(f"Epoch: {epoch}") - print(epoch_stat) - print(info) + for epoch_stat in trainer: + print(f"Epoch: {epoch_stat.epoch}") + pprint.pprint(epoch_stat) + # print(info) - assert stop_fn(info["best_reward"]) + assert stop_fn(epoch_stat.info_stat.best_reward) # Let's watch its performance! 
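
The tests above no longer unpack `(epoch, epoch_stat, info)` from the trainer; iterating the trainer now yields a single per-epoch stats object. A minimal sketch of the consuming side, assuming `trainer` and `stop_fn` are already set up exactly as in these tests:

```python
# Sketch only: `trainer` and `stop_fn` are assumed to exist as in the tests above.
for epoch_stat in trainer:
    print(f"Epoch: {epoch_stat.epoch}")
    # aggregate information (best reward so far, timing, ...) lives on info_stat
    if stop_fn(epoch_stat.info_stat.best_reward):
        break
```
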
if __name__ == "__main__": - pprint.pprint(info) + pprint.pprint(epoch_stat.info_stat) env = gym.make(args.task) policy.eval() collector = Collector(policy, env) collector_result = collector.collect(n_episode=1, render=args.render) - rews, lens = collector_result["rews"], collector_result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print( + f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}", + ) if __name__ == "__main__": diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 93f342ab0..4f4208b5e 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -157,7 +157,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -167,8 +167,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") def test_discrete_bcq_resume(args=get_args()): diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index b795d4eac..b7e4cc567 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -119,7 +119,7 @@ def stop_fn(mean_rewards): logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -129,8 +129,7 @@ def stop_fn(mean_rewards): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 83fd79c69..ea880a530 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -122,7 +122,7 @@ def stop_fn(mean_rewards): logger=logger, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -131,8 +131,7 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 4257933af..650aabb8a 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -214,7 +214,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result["best_reward"]) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) @@ -223,8 +223,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): policy.eval() collector = Collector(policy, env) result = 
collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") if __name__ == "__main__": diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 9cd1f6ea4..af915844d 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -181,22 +181,23 @@ def stop_fn(mean_rewards): logger=logger, ) - for epoch, epoch_stat, info in trainer: - print(f"Epoch: {epoch}") + for epoch_stat in trainer: + print(f"Epoch: {epoch_stat.epoch}") print(epoch_stat) - print(info) + # print(info) - assert stop_fn(info["best_reward"]) + assert stop_fn(epoch_stat.info_stat.best_reward) # Let's watch its performance! if __name__ == "__main__": - pprint.pprint(info) + pprint.pprint(epoch_stat.info_stat) env = gym.make(args.task) policy.eval() collector = Collector(policy, env) collector_result = collector.collect(n_episode=1, render=args.render) - rews, lens = collector_result["rews"], collector_result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print( + f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}", + ) if __name__ == "__main__": diff --git a/test/pettingzoo/test_pistonball.py b/test/pettingzoo/test_pistonball.py index 4a6c59655..5043a96ca 100644 --- a/test/pettingzoo/test_pistonball.py +++ b/test/pettingzoo/test_pistonball.py @@ -9,7 +9,7 @@ def test_piston_ball(args=get_args()): return result, agent = train_agent(args) - # assert result["best_reward"] >= args.win_rate + # assert result.best_reward >= args.win_rate if __name__ == "__main__": pprint.pprint(result) diff --git a/test/pettingzoo/test_pistonball_continuous.py b/test/pettingzoo/test_pistonball_continuous.py index afb0d5448..f85884d53 100644 --- a/test/pettingzoo/test_pistonball_continuous.py +++ b/test/pettingzoo/test_pistonball_continuous.py @@ -11,7 +11,7 @@ def test_piston_ball_continuous(args=get_args()): return result, agent = train_agent(args) - # assert result["best_reward"] >= 30.0 + # assert result.best_reward >= 30.0 if __name__ == "__main__": pprint.pprint(result) diff --git a/test/pettingzoo/test_tic_tac_toe.py b/test/pettingzoo/test_tic_tac_toe.py index 524cdb92a..f689283f2 100644 --- a/test/pettingzoo/test_tic_tac_toe.py +++ b/test/pettingzoo/test_tic_tac_toe.py @@ -9,7 +9,7 @@ def test_tic_tac_toe(args=get_args()): return result, agent = train_agent(args) - assert result["best_reward"] >= args.win_rate + assert result.best_reward >= args.win_rate if __name__ == "__main__": pprint.pprint(result) diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 7a86ce857..623079890 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -18,7 +18,13 @@ VectorReplayBuffer, ) from tianshou.data.buffer.cached import CachedReplayBuffer -from tianshou.data.collector import Collector, AsyncCollector +from tianshou.data.stats import ( + EpochStats, + InfoStats, + SequenceSummaryStats, + TimingStats, +) +from tianshou.data.collector import Collector, AsyncCollector, CollectStats, CollectStatsBase __all__ = [ "Batch", @@ -37,5 +43,11 @@ "HERVectorReplayBuffer", "CachedReplayBuffer", "Collector", + "CollectStats", + "CollectStatsBase", "AsyncCollector", + "EpochStats", + "InfoStats", + "SequenceSummaryStats", + "TimingStats", ] diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index da0442078..f188f3ca0 100644 --- 
a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1,6 +1,7 @@ import time import warnings from collections.abc import Callable +from dataclasses import dataclass from typing import Any, cast import gymnasium as gym @@ -13,6 +14,7 @@ PrioritizedReplayBuffer, ReplayBuffer, ReplayBufferManager, + SequenceSummaryStats, VectorReplayBuffer, to_numpy, ) @@ -22,6 +24,34 @@ from tianshou.policy import BasePolicy +@dataclass(kw_only=True) +class CollectStatsBase: + """The most basic stats, often used for offline learning.""" + + n_collected_episodes: int = 0 + """The number of collected episodes.""" + n_collected_steps: int = 0 + """The number of collected steps.""" + + +@dataclass(kw_only=True) +class CollectStats(CollectStatsBase): + """A data structure for storing the statistics of rollouts.""" + + collect_time: float = 0.0 + """The time for collecting transitions.""" + collect_speed: float = 0.0 + """The speed of collecting (env_step per second).""" + returns: np.ndarray + """The collected episode returns.""" + returns_stat: SequenceSummaryStats | None # can be None if no episode ends during collect step + """Stats of the collected returns.""" + lens: np.ndarray + """The collected episode lengths.""" + lens_stat: SequenceSummaryStats | None # can be None if no episode ends during collect step + """Stats of the collected episode lengths.""" + + class Collector: """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. @@ -191,7 +221,7 @@ def collect( render: float | None = None, no_grad: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: + ) -> CollectStats: """Collect a specified number of step or episode. To ensure unbiased sampling result with n_episode option, this function will @@ -214,17 +244,7 @@ def collect( One and only one collection number specification is permitted, either ``n_step`` or ``n_episode``. - :return: A dict including the following keys - - * ``n/ep`` collected number of episodes. - * ``n/st`` collected number of steps. - * ``rews`` array of episode reward over collected episodes. - * ``lens`` array of episode length over collected episodes. - * ``idxs`` array of episode start index in buffer over collected episodes. - * ``rew`` mean of episodic rewards. - * ``len`` mean of episodic lengths. - * ``rew_std`` standard error of episodic rewards. - * ``len_std`` standard error of episodic lengths. + :return: A dataclass object """ assert not self.env.is_async, "Please use AsyncCollector if using async venv." if n_step is not None: @@ -253,9 +273,9 @@ def collect( step_count = 0 episode_count = 0 - episode_rews = [] - episode_lens = [] - episode_start_indices = [] + episode_returns: list[float] = [] + episode_lens: list[int] = [] + episode_start_indices: list[int] = [] while True: assert len(self.data) == len(ready_env_ids) @@ -334,9 +354,9 @@ def collect( env_ind_local = np.where(done)[0] env_ind_global = ready_env_ids[env_ind_local] episode_count += len(env_ind_local) - episode_lens.append(ep_len[env_ind_local]) - episode_rews.append(ep_rew[env_ind_local]) - episode_start_indices.append(ep_idx[env_ind_local]) + episode_lens.extend(ep_len[env_ind_local]) + episode_returns.extend(ep_rew[env_ind_local]) + episode_start_indices.extend(ep_idx[env_ind_local]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. 
self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs) @@ -361,7 +381,8 @@ def collect( # generate statistics self.collect_step += step_count self.collect_episode += episode_count - self.collect_time += max(time.time() - start_time, 1e-9) + collect_time = max(time.time() - start_time, 1e-9) + self.collect_time += collect_time if n_episode: data = Batch( @@ -378,27 +399,20 @@ def collect( self.data = cast(RolloutBatchProtocol, data) self.reset_env() - if episode_count > 0: - rews, lens, idxs = list( - map(np.concatenate, [episode_rews, episode_lens, episode_start_indices]), - ) - rew_mean, rew_std = rews.mean(), rews.std() - len_mean, len_std = lens.mean(), lens.std() - else: - rews, lens, idxs = np.array([]), np.array([], int), np.array([], int) - rew_mean = rew_std = len_mean = len_std = 0 - - return { - "n/ep": episode_count, - "n/st": step_count, - "rews": rews, - "lens": lens, - "idxs": idxs, - "rew": rew_mean, - "len": len_mean, - "rew_std": rew_std, - "len_std": len_std, - } + return CollectStats( + n_collected_episodes=episode_count, + n_collected_steps=step_count, + collect_time=collect_time, + collect_speed=step_count / collect_time, + returns=np.array(episode_returns), + returns_stat=SequenceSummaryStats.from_sequence(episode_returns) + if len(episode_returns) > 0 + else None, + lens=np.array(episode_lens, int), + lens_stat=SequenceSummaryStats.from_sequence(episode_lens) + if len(episode_lens) > 0 + else None, + ) class AsyncCollector(Collector): @@ -438,7 +452,7 @@ def collect( render: float | None = None, no_grad: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: + ) -> CollectStats: """Collect a specified number of step or episode with async env setting. This function doesn't collect exactly n_step or n_episode number of @@ -461,17 +475,7 @@ def collect( One and only one collection number specification is permitted, either ``n_step`` or ``n_episode``. - :return: A dict including the following keys - - * ``n/ep`` collected number of episodes. - * ``n/st`` collected number of steps. - * ``rews`` array of episode reward over collected episodes. - * ``lens`` array of episode length over collected episodes. - * ``idxs`` array of episode start index in buffer over collected episodes. - * ``rew`` mean of episodic rewards. - * ``len`` mean of episodic lengths. - * ``rew_std`` standard error of episodic rewards. - * ``len_std`` standard error of episodic lengths. + :return: A dataclass object """ # collect at least n_step or n_episode if n_step is not None: @@ -494,9 +498,9 @@ def collect( step_count = 0 episode_count = 0 - episode_rews = [] - episode_lens = [] - episode_start_indices = [] + episode_returns: list[float] = [] + episode_lens: list[int] = [] + episode_start_indices: list[int] = [] while True: whole_data = self.data @@ -602,9 +606,9 @@ def collect( env_ind_local = np.where(done)[0] env_ind_global = ready_env_ids[env_ind_local] episode_count += len(env_ind_local) - episode_lens.append(ep_len[env_ind_local]) - episode_rews.append(ep_rew[env_ind_local]) - episode_start_indices.append(ep_idx[env_ind_local]) + episode_lens.extend(ep_len[env_ind_local]) + episode_returns.extend(ep_rew[env_ind_local]) + episode_start_indices.extend(ep_idx[env_ind_local]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. 
self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs) @@ -633,26 +637,20 @@ def collect( # generate statistics self.collect_step += step_count self.collect_episode += episode_count - self.collect_time += max(time.time() - start_time, 1e-9) - - if episode_count > 0: - rews, lens, idxs = list( - map(np.concatenate, [episode_rews, episode_lens, episode_start_indices]), - ) - rew_mean, rew_std = rews.mean(), rews.std() - len_mean, len_std = lens.mean(), lens.std() - else: - rews, lens, idxs = np.array([]), np.array([], int), np.array([], int) - rew_mean = rew_std = len_mean = len_std = 0 - - return { - "n/ep": episode_count, - "n/st": step_count, - "rews": rews, - "lens": lens, - "idxs": idxs, - "rew": rew_mean, - "len": len_mean, - "rew_std": rew_std, - "len_std": len_std, - } + collect_time = max(time.time() - start_time, 1e-9) + self.collect_time += collect_time + + return CollectStats( + n_collected_episodes=episode_count, + n_collected_steps=step_count, + collect_time=collect_time, + collect_speed=step_count / collect_time, + returns=np.array(episode_returns), + returns_stat=SequenceSummaryStats.from_sequence(episode_returns) + if len(episode_returns) > 0 + else None, + lens=np.array(episode_lens, int), + lens_stat=SequenceSummaryStats.from_sequence(episode_lens) + if len(episode_lens) > 0 + else None, + ) diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py new file mode 100644 index 000000000..980b3f84d --- /dev/null +++ b/tianshou/data/stats.py @@ -0,0 +1,86 @@ +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +import numpy as np + +if TYPE_CHECKING: + from tianshou.data import CollectStats, CollectStatsBase + from tianshou.policy.base import TrainingStats + + +@dataclass(kw_only=True) +class SequenceSummaryStats: + """A data structure for storing the statistics of a sequence.""" + + mean: float + std: float + max: float + min: float + + @classmethod + def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "SequenceSummaryStats": + return cls( + mean=float(np.mean(sequence)), + std=float(np.std(sequence)), + max=float(np.max(sequence)), + min=float(np.min(sequence)), + ) + + +@dataclass(kw_only=True) +class TimingStats: + """A data structure for storing timing statistics.""" + + total_time: float = 0.0 + """The total time elapsed.""" + train_time: float = 0.0 + """The total time elapsed for training (collecting samples plus model update).""" + train_time_collect: float = 0.0 + """The total time elapsed for collecting training transitions.""" + train_time_update: float = 0.0 + """The total time elapsed for updating models.""" + test_time: float = 0.0 + """The total time elapsed for testing models.""" + update_speed: float = 0.0 + """The speed of updating (env_step per second).""" + + +@dataclass(kw_only=True) +class InfoStats: + """A data structure for storing information about the learning process.""" + + gradient_step: int + """The total gradient step.""" + best_reward: float + """The best reward over the test results.""" + best_reward_std: float + """Standard deviation of the best reward over the test results.""" + train_step: int + """The total collected step of training collector.""" + train_episode: int + """The total collected episode of training collector.""" + test_step: int + """The total collected step of test collector.""" + test_episode: int + """The total collected episode of test collector.""" + + timing: TimingStats + """The timing statistics.""" + + 
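
`CollectStats.returns_stat` and `CollectStats.lens_stat` are built via `SequenceSummaryStats.from_sequence` and are `None` when no episode finished during the collect call, so consumers should guard for that case. A minimal sketch (the guard on a hypothetical `result` is shown commented out):

```python
from tianshou.data import SequenceSummaryStats

# Summary values are plain floats computed from the sequence.
stats = SequenceSummaryStats.from_sequence([1.0, 2.0, 3.0])
assert stats.mean == 2.0 and stats.min == 1.0 and stats.max == 3.0

# When consuming CollectStats, guard against the no-episode case,
# e.g. a pure n_step collect during which no episode terminated:
# result = collector.collect(n_step=64)
# if result.returns_stat is not None:
#     print(result.returns_stat.mean)
```
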
+@dataclass(kw_only=True) +class EpochStats: + """A data structure for storing epoch statistics.""" + + epoch: int + """The current epoch.""" + + train_collect_stat: "CollectStatsBase" + """The statistics of the last call to the training collector.""" + test_collect_stat: Optional["CollectStats"] + """The statistics of the last call to the test collector.""" + training_stat: "TrainingStats" + """The statistics of the last model update step.""" + info_stat: InfoStats + """The information of the collector.""" diff --git a/tianshou/data/types.py b/tianshou/data/types.py index a63a9d1c0..eb65f75d3 100644 --- a/tianshou/data/types.py +++ b/tianshou/data/types.py @@ -1,3 +1,5 @@ +from typing import Protocol + import numpy as np import torch @@ -5,7 +7,7 @@ from tianshou.data.batch import BatchProtocol, arr_type -class ObsBatchProtocol(BatchProtocol): +class ObsBatchProtocol(BatchProtocol, Protocol): """Observations of an environment that a policy can turn into actions. Typically used inside a policy's forward @@ -15,7 +17,7 @@ class ObsBatchProtocol(BatchProtocol): info: arr_type -class RolloutBatchProtocol(ObsBatchProtocol): +class RolloutBatchProtocol(ObsBatchProtocol, Protocol): """Typically, the outcome of sampling from a replay buffer.""" obs_next: arr_type | BatchProtocol @@ -25,52 +27,52 @@ class RolloutBatchProtocol(ObsBatchProtocol): truncated: arr_type -class BatchWithReturnsProtocol(RolloutBatchProtocol): +class BatchWithReturnsProtocol(RolloutBatchProtocol, Protocol): """With added returns, usually computed with GAE.""" returns: arr_type -class PrioBatchProtocol(RolloutBatchProtocol): +class PrioBatchProtocol(RolloutBatchProtocol, Protocol): """Contains weights that can be used for prioritized replay.""" weight: np.ndarray | torch.Tensor -class RecurrentStateBatch(BatchProtocol): +class RecurrentStateBatch(BatchProtocol, Protocol): """Used by RNNs in policies, contains `hidden` and `cell` fields.""" hidden: torch.Tensor cell: torch.Tensor -class ActBatchProtocol(BatchProtocol): +class ActBatchProtocol(BatchProtocol, Protocol): """Simplest batch, just containing the action. Useful e.g., for random policy.""" act: arr_type -class ActStateBatchProtocol(ActBatchProtocol): +class ActStateBatchProtocol(ActBatchProtocol, Protocol): """Contains action and state (which can be None), useful for policies that can support RNNs.""" state: dict | BatchProtocol | np.ndarray | None -class ModelOutputBatchProtocol(ActStateBatchProtocol): +class ModelOutputBatchProtocol(ActStateBatchProtocol, Protocol): """In addition to state and action, contains model output: (logits).""" logits: torch.Tensor state: dict | BatchProtocol | np.ndarray | None -class FQFBatchProtocol(ModelOutputBatchProtocol): +class FQFBatchProtocol(ModelOutputBatchProtocol, Protocol): """Model outputs, fractions and quantiles_tau - specific to the FQF model.""" fractions: torch.Tensor quantiles_tau: torch.Tensor -class BatchWithAdvantagesProtocol(BatchWithReturnsProtocol): +class BatchWithAdvantagesProtocol(BatchWithReturnsProtocol, Protocol): """Contains estimated advantages and values. Returns are usually computed from GAE of advantages by adding the value. @@ -80,7 +82,7 @@ class BatchWithAdvantagesProtocol(BatchWithReturnsProtocol): v_s: torch.Tensor -class DistBatchProtocol(ModelOutputBatchProtocol): +class DistBatchProtocol(ModelOutputBatchProtocol, Protocol): """Contains dist instances for actions (created by dist_fn). Usually categorical or normal. 
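
The `Protocol` base is repeated on every batch protocol subclass because, under PEP 544, a class that inherits from a protocol without also listing `typing.Protocol` becomes an ordinary class and no longer acts as a protocol for structural type checking. A small illustration with hypothetical names:

```python
from typing import Protocol


class HasAct(Protocol):
    act: int


class HasActAndState(HasAct, Protocol):
    """Still a protocol: structural checks continue to work for it."""

    state: int | None


class ConcreteActHolder(HasAct):
    """No `Protocol` in the bases, so this is a regular class, not a protocol."""

    act: int = 0
```
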
@@ -89,13 +91,13 @@ class DistBatchProtocol(ModelOutputBatchProtocol): dist: torch.distributions.Distribution -class DistLogProbBatchProtocol(DistBatchProtocol): +class DistLogProbBatchProtocol(DistBatchProtocol, Protocol): """Contains dist objects that can be sampled from and log_prob of taken action.""" log_prob: torch.Tensor -class LogpOldProtocol(BatchWithAdvantagesProtocol): +class LogpOldProtocol(BatchWithAdvantagesProtocol, Protocol): """Contains logp_old, often needed for importance weights, in particular in PPO. Builds on batches that contain advantages and values. @@ -104,7 +106,7 @@ class LogpOldProtocol(BatchWithAdvantagesProtocol): logp_old: torch.Tensor -class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol): +class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol, Protocol): """Contains taus for algorithms using quantile regression. See e.g. https://arxiv.org/abs/1806.06923 @@ -113,7 +115,7 @@ class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol): taus: torch.Tensor -class ImitationBatchProtocol(ActBatchProtocol): +class ImitationBatchProtocol(ActBatchProtocol, Protocol): """Similar to other batches, but contains imitation_logits and q_value fields.""" state: dict | Batch | np.ndarray | None diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 361c3a94c..3989d3583 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -4,12 +4,12 @@ from collections.abc import Sequence from dataclasses import dataclass from pprint import pformat -from typing import Any, Self +from typing import Self import numpy as np import torch -from tianshou.data import Collector +from tianshou.data import Collector, InfoStats from tianshou.highlevel.agent import ( A2CAgentFactory, AgentFactory, @@ -121,8 +121,8 @@ class ExperimentResult: world: World """contains all the essential instances of the experiment""" - trainer_result: dict[str, Any] | None - """dictionary of results as returned by the trainer (if any)""" + trainer_result: InfoStats | None + """dataclass of results as returned by the trainer (if any)""" class Experiment(ToStringMixin): @@ -280,7 +280,7 @@ def run( # train policy log.info("Starting training") - trainer_result: dict[str, Any] | None = None + trainer_result: InfoStats | None = None if self.config.train: trainer = self.agent_factory.create_trainer(world, policy_persistence) world.trainer = trainer @@ -309,7 +309,9 @@ def _watch_agent( policy.eval() test_collector.reset() result = test_collector.collect(n_episode=num_episodes, render=render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + assert result.returns_stat is not None # for mypy + assert result.lens_stat is not None # for mypy + print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}") class ExperimentBuilder: diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index c8fa45e8e..5e6967ad7 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -1,7 +1,7 @@ """Policy package.""" # isort:skip_file -from tianshou.policy.base import BasePolicy +from tianshou.policy.base import BasePolicy, TrainingStats from tianshou.policy.random import RandomPolicy from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.bdq import BranchingDQNPolicy @@ -63,4 +63,5 @@ "PSRLPolicy", "ICMPolicy", "MultiAgentPolicyManager", + "TrainingStats", ] diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 
3de0d0fda..73dd45572 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -1,16 +1,19 @@ import logging +import time from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Any, Literal, TypeAlias, cast +from dataclasses import dataclass, field +from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast import gymnasium as gym import numpy as np import torch from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete from numba import njit +from overrides import override from torch import nn -from tianshou.data import ReplayBuffer, to_numpy, to_torch_as +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch_as from tianshou.data.batch import Batch, BatchProtocol, arr_type from tianshou.data.buffer.base import TBuffer from tianshou.data.types import ( @@ -26,7 +29,106 @@ TLearningRateScheduler: TypeAlias = torch.optim.lr_scheduler.LRScheduler | MultipleLRSchedulers -class BasePolicy(ABC, nn.Module): +@dataclass(kw_only=True) +class TrainingStats: + _non_loss_fields = ("train_time", "smoothed_loss") + + train_time: float = 0.0 + """The time for learning models.""" + + # TODO: modified in the trainer but not used anywhere else. Should be refactored. + smoothed_loss: dict = field(default_factory=dict) + """The smoothed loss statistics of the policy learn step.""" + + # Mainly so that we can override this in the TrainingStatsWrapper + def _get_self_dict(self) -> dict[str, Any]: + return self.__dict__ + + def get_loss_stats_dict(self) -> dict[str, float]: + """Return loss statistics as a dict for logging. + + Returns a dict with all fields except train_time and smoothed_loss. Moreover, fields with value None excluded, + and instances of SequenceSummaryStats are replaced by their mean. + """ + result = {} + for k, v in self._get_self_dict().items(): + if k.startswith("_"): + logger.debug(f"Skipping {k=} as it starts with an underscore.") + continue + if k in self._non_loss_fields or v is None: + continue + if isinstance(v, SequenceSummaryStats): + result[k] = v.mean + else: + result[k] = v + + return result + + +class TrainingStatsWrapper(TrainingStats): + _setattr_frozen = False + _training_stats_public_fields = TrainingStats.__dataclass_fields__.keys() + + def __init__(self, wrapped_stats: TrainingStats) -> None: + """In this particular case, super().__init__() should be called LAST in the subclass init.""" + self._wrapped_stats = wrapped_stats + + # HACK: special sauce for the existing attributes of the base TrainingStats class + # for some reason, delattr doesn't work here, so we need to delegate their handling + # to the wrapped stats object by always keeping the value there and in self in sync + # see also __setattr__ + for k in self._training_stats_public_fields: + super().__setattr__(k, getattr(self._wrapped_stats, k)) + + self._setattr_frozen = True + + @override + def _get_self_dict(self) -> dict[str, Any]: + return {**self._wrapped_stats._get_self_dict(), **self.__dict__} + + @property + def wrapped_stats(self) -> TrainingStats: + return self._wrapped_stats + + def __getattr__(self, name: str) -> Any: + return getattr(self._wrapped_stats, name) + + def __setattr__(self, name: str, value: Any) -> None: + """Setattr logic for wrapper of a dataclass with default values. + + 1. If name exists directly in self, set it there. + 2. If it exists in self._wrapped_stats, set it there instead. + 3. 
Special case: if name is in the base TrainingStats class, keep it in sync between self and the _wrapped_stats. + 4. If name doesn't exist in either and attribute setting is frozen, raise an AttributeError. + """ + # HACK: special sauce for the existing attributes of the base TrainingStats class, see init + # Need to keep them in sync with the wrapped stats object + if name in self._training_stats_public_fields: + setattr(self._wrapped_stats, name, value) + super().__setattr__(name, value) + return + + if not self._setattr_frozen: + super().__setattr__(name, value) + return + + if not hasattr(self, name): + raise AttributeError( + f"Setting new attributes on StatsWrappers outside of init is not allowed. " + f"Tried to set {name=}, {value=} on {self.__class__.__name__}. \n" + f"NOTE: you may get this error if you call super().__init__() in your subclass init too early! " + f"The call to super().__init__() should be the last call in your subclass init.", + ) + if hasattr(self._wrapped_stats, name): + setattr(self._wrapped_stats, name, value) + else: + super().__setattr__(name, value) + + +TTrainingStats = TypeVar("TTrainingStats", bound=TrainingStats) + + +class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): """The base class for any RL policy. Tianshou aims to modularize RL algorithms. It comes into several classes of @@ -321,10 +423,10 @@ def process_fn( return batch @abstractmethod - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, Any]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTrainingStats: """Update policy with a given batch of data. - :return: A dict, including the data needed to be logged (e.g., loss). + :return: A dataclass object, including the data needed to be logged (e.g., loss). .. note:: @@ -372,13 +474,15 @@ def update( sample_size: int | None, buffer: ReplayBuffer | None, **kwargs: Any, - ) -> dict[str, Any]: + ) -> TTrainingStats: """Update the policy network and replay buffer. It includes 3 function steps: process_fn, learn, and post_process_fn. In addition, this function will change the value of ``self.updating``: it will be False before this function and will be True when executing :meth:`update`. - Please refer to :ref:`policy_state` for more detailed explanation. + Please refer to :ref:`policy_state` for more detailed explanation. The return + value of learn is augmented with the training time within update, while smoothed + loss values are computed in the trainer. :param sample_size: 0 means it will extract all the data from the buffer, otherwise it will sample a batch with given sample_size. None also @@ -386,20 +490,24 @@ def update( first. TODO: remove the option for 0? :param buffer: the corresponding replay buffer. - :return: A dict, including the data needed to be logged (e.g., loss) from + :return: A dataclass object containing the data needed to be logged (e.g., loss) from ``policy.learn()``. """ + # TODO: when does this happen? 
+ # -> this happens never in practice as update is either called with a collector buffer or an assert before if buffer is None: - return {} + return TrainingStats() # type: ignore[return-value] + start_time = time.time() batch, indices = buffer.sample(sample_size) self.updating = True batch = self.process_fn(batch, buffer, indices) - result = self.learn(batch, **kwargs) + training_stat = self.learn(batch, **kwargs) self.post_process_fn(batch, buffer, indices) if self.lr_scheduler is not None: self.lr_scheduler.step() self.updating = False - return result + training_stat.train_time = time.time() - start_time + return training_stat @staticmethod def value_mask(buffer: ReplayBuffer, indices: np.ndarray) -> np.ndarray: diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 5acacf1dd..1daa9ae71 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -1,4 +1,5 @@ -from typing import Any, Literal, cast +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -13,10 +14,18 @@ RolloutBatchProtocol, ) from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats -class ImitationPolicy(BasePolicy): +@dataclass(kw_only=True) +class ImitationTrainingStats(TrainingStats): + loss: float = 0.0 + + +TImitationTrainingStats = TypeVar("TImitationTrainingStats", bound=ImitationTrainingStats) + + +class ImitationPolicy(BasePolicy[TImitationTrainingStats], Generic[TImitationTrainingStats]): """Implementation of vanilla imitation learning. :param actor: a model following the rules in @@ -68,7 +77,12 @@ def forward( result = Batch(logits=logits, act=act, state=hidden) return cast(ModelOutputBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *ags: Any, **kwargs: Any) -> dict[str, float]: + def learn( + self, + batch: RolloutBatchProtocol, + *ags: Any, + **kwargs: Any, + ) -> TImitationTrainingStats: self.optim.zero_grad() if self.action_type == "continuous": # regression act = self(batch).act @@ -80,4 +94,5 @@ def learn(self, batch: RolloutBatchProtocol, *ags: Any, **kwargs: Any) -> dict[s loss = F.nll_loss(act, act_target) loss.backward() self.optim.step() - return {"loss": loss.item()} + + return ImitationTrainingStats(loss=loss.item()) # type: ignore diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index 14f9056d8..dee1a80a3 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -1,5 +1,6 @@ import copy -from typing import Any, Literal, Self, cast +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -10,12 +11,23 @@ from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.utils.net.continuous import VAE from tianshou.utils.optim import clone_optimizer -class BCQPolicy(BasePolicy): +@dataclass(kw_only=True) +class BCQTrainingStats(TrainingStats): + actor_loss: float + critic1_loss: float + critic2_loss: float + vae_loss: float + + +TBCQTrainingStats = TypeVar("TBCQTrainingStats", bound=BCQTrainingStats) + + +class 
BCQPolicy(BasePolicy[TBCQTrainingStats], Generic[TBCQTrainingStats]): """Implementation of BCQ algorithm. arXiv:1812.02900. :param actor_perturbation: the actor perturbation. `(s, a -> perturbed a)` @@ -142,7 +154,7 @@ def sync_weight(self) -> None: self.soft_update(self.critic2_target, self.critic2, self.tau) self.soft_update(self.actor_perturbation_target, self.actor_perturbation, self.tau) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBCQTrainingStats: # batch: obs, act, rew, done, obs_next. (numpy array) # (batch_size, state_dim) batch: Batch = to_torch(batch, dtype=torch.float, device=self.device) @@ -213,9 +225,9 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ # update target network self.sync_weight() - return { - "loss/actor": actor_loss.item(), - "loss/critic1": critic1_loss.item(), - "loss/critic2": critic2_loss.item(), - "loss/vae": vae_loss.item(), - } + return BCQTrainingStats( # type: ignore + actor_loss=actor_loss.item(), + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + vae_loss=vae_loss.item(), + ) diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index c3bbf6b1e..1ce6d83d4 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -1,4 +1,5 @@ -from typing import Any, Literal, Self, cast +from dataclasses import dataclass +from typing import Any, Literal, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -13,10 +14,23 @@ from tianshou.exploration import BaseNoise from tianshou.policy import SACPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.sac import SACTrainingStats +from tianshou.utils.conversion import to_optional_float from tianshou.utils.net.continuous import ActorProb -class CQLPolicy(SACPolicy): +@dataclass(kw_only=True) +class CQLTrainingStats(SACTrainingStats): + """A data structure for storing loss statistics of the CQL learn step.""" + + cql_alpha: float | None = None + cql_alpha_loss: float | None = None + + +TCQLTrainingStats = TypeVar("TCQLTrainingStats", bound=CQLTrainingStats) + + +class CQLPolicy(SACPolicy[TCQLTrainingStats]): """Implementation of CQL algorithm. arXiv:2006.04779. :param actor: the actor network following the rules in @@ -233,7 +247,7 @@ def process_fn( # Should probably be fixed! 
return batch - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TCQLTrainingStats: # type: ignore batch: Batch = to_torch(batch, dtype=torch.float, device=self.device) obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next batch_size = obs.shape[0] @@ -244,6 +258,7 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ actor_loss.backward() self.actor_optim.step() + alpha_loss = None # compute alpha loss if self.is_auto_alpha: log_pi = log_pi + self.target_entropy @@ -341,6 +356,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ ) # shape: (1) + cql_alpha_loss = None + cql_alpha = None if self.with_lagrange: cql_alpha = torch.clamp( self.cql_log_alpha.exp(), @@ -373,16 +390,12 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ self.sync_weight() - result = { - "loss/actor": actor_loss.item(), - "loss/critic1": critic1_loss.item(), - "loss/critic2": critic2_loss.item(), - } - if self.is_auto_alpha: - self.alpha = cast(torch.Tensor, self.alpha) - result["loss/alpha"] = alpha_loss.item() - result["alpha"] = self.alpha.item() - if self.with_lagrange: - result["loss/cql_alpha"] = cql_alpha_loss.item() - result["cql_alpha"] = cql_alpha.item() - return result + return CQLTrainingStats( # type: ignore[return-value] + actor_loss=to_optional_float(actor_loss), + critic1_loss=to_optional_float(critic1_loss), + critic2_loss=to_optional_float(critic2_loss), + alpha=to_optional_float(self.alpha), + alpha_loss=to_optional_float(alpha_loss), + cql_alpha_loss=to_optional_float(cql_alpha_loss), + cql_alpha=to_optional_float(cql_alpha), + ) diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index d32d4ef6a..8412e0a60 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -1,5 +1,6 @@ import math -from typing import Any, Self, cast +from dataclasses import dataclass +from typing import Any, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -14,12 +15,23 @@ ) from tianshou.policy import DQNPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.dqn import DQNTrainingStats float_info = torch.finfo(torch.float32) INF = float_info.max -class DiscreteBCQPolicy(DQNPolicy): +@dataclass(kw_only=True) +class DiscreteBCQTrainingStats(DQNTrainingStats): + q_loss: float + i_loss: float + reg_loss: float + + +TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteBCQTrainingStats) + + +class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]): """Implementation of discrete BCQ algorithm. arXiv:1910.01708. 
:param model: a model following the rules in @@ -136,7 +148,12 @@ def forward( # type: ignore result = Batch(act=act, state=state, q_value=q_value, imitation_logits=imitation_logits) return cast(ImitationBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TDiscreteBCQTrainingStats: if self._iter % self.freq == 0: self.sync_weight() self._iter += 1 @@ -155,9 +172,9 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ loss.backward() self.optim.step() - return { - "loss": loss.item(), - "loss/q": q_loss.item(), - "loss/i": i_loss.item(), - "loss/reg": reg_loss.item(), - } + return DiscreteBCQTrainingStats( # type: ignore[return-value] + loss=loss.item(), + q_loss=q_loss.item(), + i_loss=i_loss.item(), + reg_loss=reg_loss.item(), + ) diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index 924fd9947..dc23cb75a 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -1,4 +1,5 @@ -from typing import Any +from dataclasses import dataclass +from typing import Any, TypeVar import gymnasium as gym import numpy as np @@ -9,9 +10,19 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import QRDQNPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats -class DiscreteCQLPolicy(QRDQNPolicy): +@dataclass(kw_only=True) +class DiscreteCQLTrainingStats(QRDQNTrainingStats): + cql_loss: float + qr_loss: float + + +TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", bound=DiscreteCQLTrainingStats) + + +class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]): """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779. 
:param model: a model following the rules in @@ -72,7 +83,12 @@ def __init__( ) self.min_q_weight = min_q_weight - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TDiscreteCQLTrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() @@ -101,8 +117,9 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ loss.backward() self.optim.step() self._iter += 1 - return { - "loss": loss.item(), - "loss/qr": qr_loss.item(), - "loss/cql": min_q_loss.item(), - } + + return DiscreteCQLTrainingStats( # type: ignore[return-value] + loss=loss.item(), + qr_loss=qr_loss.item(), + cql_loss=min_q_loss.item(), + ) diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index f22ffe101..9a3c2db9f 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -1,5 +1,6 @@ from copy import deepcopy -from typing import Any, Literal +from dataclasses import dataclass +from typing import Any, Literal, TypeVar import gymnasium as gym import torch @@ -9,10 +10,20 @@ from tianshou.data import to_torch, to_torch_as from tianshou.data.types import RolloutBatchProtocol from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.pg import PGPolicy +from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats -class DiscreteCRRPolicy(PGPolicy): +@dataclass +class DiscreteCRRTrainingStats(PGTrainingStats): + actor_loss: float + critic_loss: float + cql_loss: float + + +TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteCRRTrainingStats) + + +class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]): r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134. 
:param actor: the actor network following the rules in @@ -96,7 +107,7 @@ def learn( # type: ignore batch: RolloutBatchProtocol, *args: Any, **kwargs: Any, - ) -> dict[str, float]: + ) -> TDiscreteCRRTrainingStats: if self._target and self._iter % self._freq == 0: self.sync_weight() self.optim.zero_grad() @@ -131,9 +142,10 @@ def learn( # type: ignore loss.backward() self.optim.step() self._iter += 1 - return { - "loss": loss.item(), - "loss/actor": actor_loss.item(), - "loss/critic": critic_loss.item(), - "loss/cql": min_q_loss.item(), - } + + return DiscreteCRRTrainingStats( # type: ignore[return-value] + loss=loss.item(), + actor_loss=actor_loss.item(), + critic_loss=critic_loss.item(), + cql_loss=min_q_loss.item(), + ) diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index 5461f3657..c98f7afb8 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -1,18 +1,35 @@ -from typing import Any, Literal +from dataclasses import dataclass +from typing import Any, Literal, TypeVar import gymnasium as gym import numpy as np import torch import torch.nn.functional as F -from tianshou.data import ReplayBuffer, to_numpy, to_torch +from tianshou.data import ( + ReplayBuffer, + SequenceSummaryStats, + to_numpy, + to_torch, +) from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.policy import PPOPolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.ppo import PPOTrainingStats -class GAILPolicy(PPOPolicy): +@dataclass(kw_only=True) +class GailTrainingStats(PPOTrainingStats): + disc_loss: SequenceSummaryStats + acc_pi: SequenceSummaryStats + acc_exp: SequenceSummaryStats + + +TGailTrainingStats = TypeVar("TGailTrainingStats", bound=GailTrainingStats) + + +class GAILPolicy(PPOPolicy[TGailTrainingStats]): r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476. :param actor: the actor network following the rules in BasePolicy. 
(s -> logits) @@ -142,7 +159,7 @@ def learn( # type: ignore batch_size: int | None, repeat: int, **kwargs: Any, - ) -> dict[str, list[float]]: + ) -> TGailTrainingStats: # update discriminator losses = [] acc_pis = [] @@ -162,8 +179,15 @@ def learn( # type: ignore acc_pis.append((logits_pi < 0).float().mean().item()) acc_exps.append((logits_exp > 0).float().mean().item()) # update policy - res = super().learn(batch, batch_size, repeat, **kwargs) - res["loss/disc"] = losses - res["stats/acc_pi"] = acc_pis - res["stats/acc_exp"] = acc_exps - return res + ppo_loss_stat = super().learn(batch, batch_size, repeat, **kwargs) + + disc_losses_summary = SequenceSummaryStats.from_sequence(losses) + acc_pi_summary = SequenceSummaryStats.from_sequence(acc_pis) + acc_exps_summary = SequenceSummaryStats.from_sequence(acc_exps) + + return GailTrainingStats( # type: ignore[return-value] + **ppo_loss_stat.__dict__, + disc_loss=disc_losses_summary, + acc_pi=acc_pi_summary, + acc_exp=acc_exps_summary, + ) diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index 4cb60bf21..7ef700b0c 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -1,4 +1,5 @@ -from typing import Any, Literal +from dataclasses import dataclass +from typing import Any, Literal, TypeVar import gymnasium as gym import torch @@ -9,9 +10,18 @@ from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.policy import TD3Policy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.td3 import TD3TrainingStats -class TD3BCPolicy(TD3Policy): +@dataclass(kw_only=True) +class TD3BCTrainingStats(TD3TrainingStats): + pass + + +TTD3BCTrainingStats = TypeVar("TTD3BCTrainingStats", bound=TD3BCTrainingStats) + + +class TD3BCPolicy(TD3Policy[TTD3BCTrainingStats]): """Implementation of TD3+BC. arXiv:2106.06860. 
:param actor: the actor network following the rules in @@ -94,7 +104,7 @@ def __init__( ) self.alpha = alpha - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3BCTrainingStats: # type: ignore # critic 1&2 td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) @@ -112,8 +122,9 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ self.actor_optim.step() self.sync_weight() self._cnt += 1 - return { - "loss/actor": self._last, - "loss/critic1": critic1_loss.item(), - "loss/critic2": critic2_loss.item(), - } + + return TD3BCTrainingStats( # type: ignore[return-value] + actor_loss=self._last, + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + ) diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 8f05d6f69..54d560b9d 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -9,11 +9,31 @@ from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import ( + TLearningRateScheduler, + TrainingStats, + TrainingStatsWrapper, + TTrainingStats, +) from tianshou.utils.net.discrete import IntrinsicCuriosityModule -class ICMPolicy(BasePolicy): +class ICMTrainingStats(TrainingStatsWrapper): + def __init__( + self, + wrapped_stats: TrainingStats, + *, + icm_loss: float, + icm_forward_loss: float, + icm_inverse_loss: float, + ) -> None: + self.icm_loss = icm_loss + self.icm_forward_loss = icm_forward_loss + self.icm_inverse_loss = icm_inverse_loss + super().__init__(wrapped_stats) + + +class ICMPolicy(BasePolicy[ICMTrainingStats]): """Implementation of Intrinsic Curiosity Module. arXiv:1705.05363. :param policy: a base policy to add ICM to. 
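`ICMTrainingStats` is the one case where the new stats wrap another policy's stats instead of subclassing them, via `TrainingStatsWrapper` from `tianshou/policy/base.py` (the wrapper itself is not part of this hunk). The standalone sketch below only illustrates the delegation idea; the `__getattr__` forwarding is an assumption for illustration, not the actual implementation:

```python
from dataclasses import dataclass
from typing import Any


@dataclass(kw_only=True)
class WrappedStatsSketch:
    loss: float


class StatsWrapperSketch:
    """Illustrative only: adds fields on top of an existing stats object and delegates the rest."""

    def __init__(self, wrapped: Any, *, icm_loss: float) -> None:
        self._wrapped = wrapped
        self.icm_loss = icm_loss

    def __getattr__(self, name: str) -> Any:
        # Called only when normal attribute lookup fails, so the wrapper's own
        # fields take precedence over the wrapped stats' fields.
        return getattr(self._wrapped, name)


inner = WrappedStatsSketch(loss=0.5)
outer = StatsWrapperSketch(inner, icm_loss=0.1)
print(outer.icm_loss)  # field added by the wrapper
print(outer.loss)      # delegated to the wrapped stats
```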
@@ -37,7 +57,7 @@ class ICMPolicy(BasePolicy): def __init__( self, *, - policy: BasePolicy, + policy: BasePolicy[TTrainingStats], model: IntrinsicCuriosityModule, optim: torch.optim.Optimizer, lr_scale: float, @@ -128,8 +148,13 @@ def post_process_fn( self.policy.post_process_fn(batch, buffer, indices) batch.rew = batch.policy.orig_rew # restore original reward - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: - res = self.policy.learn(batch, **kwargs) + def learn( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> ICMTrainingStats: + training_stat = self.policy.learn(batch, **kwargs) self.optim.zero_grad() act_hat = batch.policy.act_hat act = to_torch(batch.act, dtype=torch.long, device=act_hat.device) @@ -140,11 +165,10 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ ) * self.lr_scale loss.backward() self.optim.step() - res.update( - { - "loss/icm": loss.item(), - "loss/icm/forward": forward_loss.item(), - "loss/icm/inverse": inverse_loss.item(), - }, + + return ICMTrainingStats( + training_stat, + icm_loss=loss.item(), + icm_forward_loss=forward_loss.item(), + icm_inverse_loss=inverse_loss.item(), ) - return res diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index 52dadfedb..8c1374709 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -1,4 +1,5 @@ -from typing import Any, cast +from dataclasses import dataclass +from typing import Any, TypeVar, cast import gymnasium as gym import numpy as np @@ -8,7 +9,16 @@ from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats + + +@dataclass(kw_only=True) +class PSRLTrainingStats(TrainingStats): + psrl_rew_mean: float = 0.0 + psrl_rew_std: float = 0.0 + + +TPSRLTrainingStats = TypeVar("TPSRLTrainingStats", bound=PSRLTrainingStats) class PSRLModel: @@ -140,7 +150,7 @@ def __call__( return self.policy[obs] -class PSRLPolicy(BasePolicy): +class PSRLPolicy(BasePolicy[TPSRLTrainingStats]): """Implementation of Posterior Sampling Reinforcement Learning. Reference: Strens M. 
A Bayesian framework for reinforcement learning [C] @@ -217,7 +227,7 @@ def forward( act = self.model(batch.obs, state=state, info=batch.info) return cast(ActBatchProtocol, Batch(act=act)) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TPSRLTrainingStats: n_s, n_a = self.model.n_state, self.model.n_action trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) @@ -236,7 +246,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ trans_count[obs_next, :, obs_next] += 1 rew_count[obs_next, :] += 1 self.model.observe(trans_count, rew_sum, rew_square_sum, rew_count) - return { - "psrl/rew_mean": float(self.model.rew_mean.mean()), - "psrl/rew_std": float(self.model.rew_std.mean()), - } + + return PSRLTrainingStats( # type: ignore[return-value] + psrl_rew_mean=float(self.model.rew_mean.mean()), + psrl_rew_std=float(self.model.rew_std.mean()), + ) diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 5d43ec1ee..2aad187dd 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -1,4 +1,5 @@ -from typing import Any, Literal, cast +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -6,15 +7,27 @@ import torch.nn.functional as F from torch import nn -from tianshou.data import ReplayBuffer, to_torch_as +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy import PGPolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.utils.net.common import ActorCritic -class A2CPolicy(PGPolicy): +@dataclass(kw_only=True) +class A2CTrainingStats(TrainingStats): + loss: SequenceSummaryStats + actor_loss: SequenceSummaryStats + vf_loss: SequenceSummaryStats + ent_loss: SequenceSummaryStats + + +TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var] """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783. :param actor: the actor network following the rules in BasePolicy. 
(s -> logits) @@ -146,7 +159,7 @@ def learn( # type: ignore repeat: int, *args: Any, **kwargs: Any, - ) -> dict[str, list[float]]: + ) -> TA2CTrainingStats: losses, actor_losses, vf_losses, ent_losses = [], [], [], [] split_batch_size = batch_size or -1 for _ in range(repeat): @@ -175,9 +188,14 @@ def learn( # type: ignore ent_losses.append(ent_loss.item()) losses.append(loss.item()) - return { - "loss": losses, - "loss/actor": actor_losses, - "loss/vf": vf_losses, - "loss/ent": ent_losses, - } + loss_summary_stat = SequenceSummaryStats.from_sequence(losses) + actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) + vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) + ent_loss_summary_stat = SequenceSummaryStats.from_sequence(ent_losses) + + return A2CTrainingStats( # type: ignore[return-value] + loss=loss_summary_stat, + actor_loss=actor_loss_summary_stat, + vf_loss=vf_loss_summary_stat, + ent_loss=ent_loss_summary_stat, + ) diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py index b78aa15e7..a91ea0093 100644 --- a/tianshou/policy/modelfree/bdq.py +++ b/tianshou/policy/modelfree/bdq.py @@ -1,4 +1,5 @@ -from typing import Any, Literal, cast +from dataclasses import dataclass +from typing import Any, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -14,10 +15,19 @@ ) from tianshou.policy import DQNPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.dqn import DQNTrainingStats from tianshou.utils.net.common import BranchingNet -class BranchingDQNPolicy(DQNPolicy): +@dataclass(kw_only=True) +class BDQNTrainingStats(DQNTrainingStats): + pass + + +TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats) + + +class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]): """Implementation of the Branching dual Q network arXiv:1711.08946. :param model: BranchingNet mapping (obs, state, info) -> logits. 
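The on-policy hunk above (A2C, and likewise NPG/PPO/TRPO below) still accumulates per-minibatch losses in plain lists, but now compresses each list with `SequenceSummaryStats.from_sequence` instead of returning the raw lists. A self-contained sketch of that summarization step; the real `SequenceSummaryStats` lives in `tianshou.data` and its exact fields are not shown in this diff, so the field names here are assumed:

```python
from collections.abc import Sequence
from dataclasses import dataclass

import numpy as np


@dataclass(kw_only=True)
class SummaryStatsSketch:
    # Stand-in for tianshou.data.SequenceSummaryStats; the real field set may differ.
    mean: float
    std: float
    max: float
    min: float

    @classmethod
    def from_sequence(cls, seq: Sequence[float]) -> "SummaryStatsSketch":
        arr = np.asarray(seq, dtype=float)
        return cls(
            mean=float(arr.mean()),
            std=float(arr.std()),
            max=float(arr.max()),
            min=float(arr.min()),
        )


# Per-minibatch losses accumulated inside the repeat loop of an on-policy learn().
actor_losses = [0.9, 0.7, 0.65]
print(SummaryStatsSketch.from_sequence(actor_losses).mean)
```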
@@ -151,7 +161,7 @@ def forward( result = Batch(logits=logits, act=act, state=hidden) return cast(ModelOutputBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() @@ -169,7 +179,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ loss.backward() self.optim.step() self._iter += 1 - return {"loss": loss.item()} + + return BDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] def exploration_noise( self, diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 600694fa2..bd4491469 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -1,4 +1,5 @@ -from typing import Any +from dataclasses import dataclass +from typing import Any, Generic, TypeVar import gymnasium as gym import numpy as np @@ -8,9 +9,18 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import DQNPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.dqn import DQNTrainingStats -class C51Policy(DQNPolicy): +@dataclass(kw_only=True) +class C51TrainingStats(DQNTrainingStats): + pass + + +TC51TrainingStats = TypeVar("TC51TrainingStats", bound=C51TrainingStats) + + +class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]): """Implementation of Categorical Deep Q-Network. arXiv:1707.06887. :param model: a model following the rules in @@ -107,7 +117,7 @@ def _target_dist(self, batch: RolloutBatchProtocol) -> torch.Tensor: ).clamp(0, 1) * next_dist.unsqueeze(1) return target_dist.sum(-1) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TC51TrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() @@ -124,4 +134,5 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ loss.backward() self.optim.step() self._iter += 1 - return {"loss": loss.item()} + + return C51TrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index d2f0987f5..1b371d4b3 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -1,6 +1,7 @@ import warnings from copy import deepcopy -from typing import Any, Literal, Self, cast +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -16,10 +17,19 @@ ) from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats -class DDPGPolicy(BasePolicy): +@dataclass(kw_only=True) +class DDPGTrainingStats(TrainingStats): + actor_loss: float + critic_loss: float + + +TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats) + + +class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971. 
:param actor: The actor network following the rules in @@ -185,7 +195,7 @@ def _mse_optimizer( optimizer.step() return td, critic_loss - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDDPGTrainingStats: # type: ignore # critic td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) batch.weight = td # prio-buffer @@ -195,7 +205,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ actor_loss.backward() self.actor_optim.step() self.sync_weight() - return {"loss/actor": actor_loss.item(), "loss/critic": critic_loss.item()} + + return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) # type: ignore[return-value] def exploration_noise( self, diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 1b4a3427d..8a80f184c 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -1,4 +1,5 @@ -from typing import Any, cast +from dataclasses import dataclass +from typing import Any, TypeVar, cast import gymnasium as gym import numpy as np @@ -11,9 +12,18 @@ from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import SACPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.sac import SACTrainingStats -class DiscreteSACPolicy(SACPolicy): +@dataclass +class DiscreteSACTrainingStats(SACTrainingStats): + pass + + +TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteSACTrainingStats) + + +class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]): """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207. 
:param actor: the actor network following the rules in @@ -117,7 +127,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: ) return target_q.sum(dim=-1) + self.alpha * dist.entropy() - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDiscreteSACTrainingStats: # type: ignore weight = batch.pop("weight", 1.0) target_q = batch.returns.flatten() act = to_torch(batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long) @@ -163,17 +173,16 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ self.sync_weight() - result = { - "loss/actor": actor_loss.item(), - "loss/critic1": critic1_loss.item(), - "loss/critic2": critic2_loss.item(), - } if self.is_auto_alpha: self.alpha = cast(torch.Tensor, self.alpha) - result["loss/alpha"] = alpha_loss.item() - result["alpha"] = self.alpha.item() - return result + return DiscreteSACTrainingStats( # type: ignore[return-value] + actor_loss=actor_loss.item(), + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + alpha=self.alpha.item() if isinstance(self.alpha, torch.Tensor) else self.alpha, + alpha_loss=None if not self.is_auto_alpha else alpha_loss.item(), + ) def exploration_noise( self, diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index f8c458255..5b90510b4 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -1,5 +1,6 @@ from copy import deepcopy -from typing import Any, Literal, Self, cast +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -14,10 +15,18 @@ RolloutBatchProtocol, ) from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats -class DQNPolicy(BasePolicy): +@dataclass(kw_only=True) +class DQNTrainingStats(TrainingStats): + loss: float + + +TDQNTrainingStats = TypeVar("TDQNTrainingStats", bound=DQNTrainingStats) + + +class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]): """Implementation of Deep Q Network. arXiv:1312.5602. Implementation of Double Q-Learning. arXiv:1509.06461. 
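For SAC-style policies, `alpha` and `alpha_loss` are only meaningful when alpha is auto-tuned, so the hunk above passes `None` in the other case instead of omitting dict keys as before. A standalone sketch of how optional fields with `None` defaults keep such stats constructible and loggable (placeholder class, not the actual `SACTrainingStats`):

```python
from dataclasses import asdict, dataclass


@dataclass(kw_only=True)
class SACLikeStats:
    # Placeholder mirroring the alpha handling in the SAC-family stats classes.
    actor_loss: float
    critic1_loss: float
    critic2_loss: float
    alpha: float | None = None       # only a float when alpha is (auto-)tuned
    alpha_loss: float | None = None  # only computed when is_auto_alpha is True


stats = SACLikeStats(actor_loss=0.3, critic1_loss=0.2, critic2_loss=0.25)

# A logger that cannot handle None values can simply drop them.
loggable = {k: v for k, v in asdict(stats).items() if v is not None}
print(loggable)
```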
@@ -199,7 +208,7 @@ def forward( result = Batch(logits=logits, act=act, state=hidden) return cast(ModelOutputBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() @@ -220,7 +229,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ loss.backward() self.optim.step() self._iter += 1 - return {"loss": loss.item()} + + return DQNTrainingStats(loss=loss.item()) # type: ignore[return-value] def exploration_noise( self, diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index c68043614..9f1b083ee 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -1,4 +1,5 @@ -from typing import Any, Literal, cast +from dataclasses import dataclass +from typing import Any, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -9,10 +10,21 @@ from tianshou.data.types import FQFBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import DQNPolicy, QRDQNPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction -class FQFPolicy(QRDQNPolicy): +@dataclass(kw_only=True) +class FQFTrainingStats(QRDQNTrainingStats): + quantile_loss: float + fraction_loss: float + entropy_loss: float + + +TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats) + + +class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]): """Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140. 
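Most of the new stats classes are declared with `@dataclass(kw_only=True)`, and children such as `FQFTrainingStats` add required fields on top of inherited ones. One reason `kw_only` matters here, shown with placeholder classes: it sidesteps the dataclass rule that a non-default field may not follow an inherited field that has a default.

```python
from dataclasses import dataclass


@dataclass(kw_only=True)
class ParentStats:
    # Placeholder parent, not an actual tianshou class.
    loss: float
    alpha: float | None = None  # optional parent field with a default


@dataclass(kw_only=True)
class ChildStats(ParentStats):
    # Without kw_only=True, this required field after an inherited default
    # would raise "non-default argument follows default argument".
    quantile_loss: float


print(ChildStats(loss=0.8, quantile_loss=0.5))
```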
:param model: a model following the rules in @@ -141,7 +153,7 @@ def forward( # type: ignore ) return cast(FQFBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TFQFTrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() weight = batch.pop("weight", 1.0) @@ -199,9 +211,10 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ quantile_loss.backward() self.optim.step() self._iter += 1 - return { - "loss": quantile_loss.item() + fraction_entropy_loss.item(), - "loss/quantile": quantile_loss.item(), - "loss/fraction": fraction_loss.item(), - "loss/entropy": entropy_loss.item(), - } + + return FQFTrainingStats( # type: ignore[return-value] + loss=quantile_loss.item() + fraction_entropy_loss.item(), + quantile_loss=quantile_loss.item(), + fraction_loss=fraction_loss.item(), + entropy_loss=entropy_loss.item(), + ) diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index c87242fcd..f242c146d 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -1,4 +1,5 @@ -from typing import Any, Literal, cast +from dataclasses import dataclass +from typing import Any, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -14,9 +15,18 @@ ) from tianshou.policy import QRDQNPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.qrdqn import QRDQNTrainingStats -class IQNPolicy(QRDQNPolicy): +@dataclass(kw_only=True) +class IQNTrainingStats(QRDQNTrainingStats): + pass + + +TIQNTrainingStats = TypeVar("TIQNTrainingStats", bound=IQNTrainingStats) + + +class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]): """Implementation of Implicit Quantile Network. arXiv:1806.06923. 
:param model: a model following the rules in @@ -121,7 +131,7 @@ def forward( result = Batch(logits=logits, act=act, state=hidden, taus=taus) return cast(QuantileRegressionBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TIQNTrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() @@ -148,4 +158,5 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ loss.backward() self.optim.step() self._iter += 1 - return {"loss": loss.item()} + + return IQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index bb1a1cccd..f2939450c 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -1,4 +1,5 @@ -from typing import Any, Literal +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar import gymnasium as gym import numpy as np @@ -7,14 +8,25 @@ from torch import nn from torch.distributions import kl_divergence -from tianshou.data import Batch, ReplayBuffer +from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy import A2CPolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.modelfree.pg import TDistributionFunction -class NPGPolicy(A2CPolicy): +@dataclass(kw_only=True) +class NPGTrainingStats(TrainingStats): + actor_loss: SequenceSummaryStats + vf_loss: SequenceSummaryStats + kl: SequenceSummaryStats + + +TNPGTrainingStats = TypeVar("TNPGTrainingStats", bound=NPGTrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # type: ignore[type-var] """Implementation of Natural Policy Gradient. 
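The `# type: ignore[type-var]` on classes like `NPGPolicy` above comes from the stats hierarchy not mirroring the policy hierarchy: `NPGTrainingStats` derives from `TrainingStats` rather than from `A2CTrainingStats`, so the bound of the parent policy's type variable is not satisfied. A standalone sketch of the same situation with placeholder names (it runs fine; only the type checker complains on the marked line):

```python
from dataclasses import dataclass
from typing import Generic, TypeVar


@dataclass(kw_only=True)
class StatsA:  # plays the role of the parent policy's stats class
    loss: float


@dataclass(kw_only=True)
class StatsB:  # plays the role of the child policy's stats: NOT a subclass of StatsA
    kl: float


TStatsA = TypeVar("TStatsA", bound=StatsA)
TStatsB = TypeVar("TStatsB", bound=StatsB)


class PolicyA(Generic[TStatsA]):
    ...


# mypy reports that the type argument of PolicyA must be a subtype of StatsA,
# because StatsB does not inherit from StatsA even though PolicyB inherits from PolicyA.
class PolicyB(PolicyA[TStatsB], Generic[TStatsB]):  # type: ignore[type-var]
    ...


PolicyB[StatsB]()  # instantiation works at runtime; the ignore only silences the checker
```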
https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf @@ -112,7 +124,7 @@ def learn( # type: ignore batch_size: int | None, repeat: int, **kwargs: Any, - ) -> dict[str, list[float]]: + ) -> TNPGTrainingStats: actor_losses, vf_losses, kls = [], [], [] split_batch_size = batch_size or -1 for _ in range(repeat): @@ -144,7 +156,7 @@ def learn( # type: ignore new_dist = self(minibatch).dist kl = kl_divergence(old_dist, new_dist).mean() - # optimize citirc + # optimize critic for _ in range(self.optim_critic_iters): value = self.critic(minibatch.obs).flatten() vf_loss = F.mse_loss(minibatch.returns, value) @@ -156,7 +168,15 @@ def learn( # type: ignore vf_losses.append(vf_loss.item()) kls.append(kl.item()) - return {"loss/actor": actor_losses, "loss/vf": vf_losses, "kl": kls} + actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) + vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) + kl_summary_stat = SequenceSummaryStats.from_sequence(kls) + + return NPGTrainingStats( # type: ignore[return-value] + actor_loss=actor_loss_summary_stat, + vf_loss=vf_loss_summary_stat, + kl=kl_summary_stat, + ) def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor: """Matrix vector product.""" diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 09be46a8d..7db588be6 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -1,12 +1,19 @@ import warnings from collections.abc import Callable -from typing import Any, Literal, TypeAlias, cast +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast import gymnasium as gym import numpy as np import torch -from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as +from tianshou.data import ( + Batch, + ReplayBuffer, + SequenceSummaryStats, + to_torch, + to_torch_as, +) from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( BatchWithReturnsProtocol, @@ -15,14 +22,22 @@ RolloutBatchProtocol, ) from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.utils import RunningMeanStd # TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution] TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution] -class PGPolicy(BasePolicy): +@dataclass(kw_only=True) +class PGTrainingStats(TrainingStats): + loss: SequenceSummaryStats + + +TPGTrainingStats = TypeVar("TPGTrainingStats", bound=PGTrainingStats) + + +class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]): """Implementation of REINFORCE algorithm. :param actor: mapping (s->model_output), should follow the rules in @@ -192,12 +207,12 @@ def forward( # TODO: why does mypy complain? 
def learn( # type: ignore self, - batch: RolloutBatchProtocol, + batch: BatchWithReturnsProtocol, batch_size: int | None, repeat: int, *args: Any, **kwargs: Any, - ) -> dict[str, list[float]]: + ) -> TPGTrainingStats: losses = [] split_batch_size = batch_size or -1 for _ in range(repeat): @@ -213,4 +228,6 @@ def learn( # type: ignore self.optim.step() losses.append(loss.item()) - return {"loss": losses} + loss_summary_stat = SequenceSummaryStats.from_sequence(losses) + + return PGTrainingStats(loss=loss_summary_stat) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 3f8c47fb9..fde9e7c79 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -1,19 +1,32 @@ -from typing import Any, Literal +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeVar import gymnasium as gym import numpy as np import torch from torch import nn -from tianshou.data import ReplayBuffer, to_torch_as +from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.policy import A2CPolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.utils.net.common import ActorCritic -class PPOPolicy(A2CPolicy): +@dataclass(kw_only=True) +class PPOTrainingStats(TrainingStats): + loss: SequenceSummaryStats + clip_loss: SequenceSummaryStats + vf_loss: SequenceSummaryStats + ent_loss: SequenceSummaryStats + + +TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var] r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347. :param actor: the actor network following the rules in BasePolicy. 
(s -> logits) @@ -132,7 +145,7 @@ def learn( # type: ignore repeat: int, *args: Any, **kwargs: Any, - ) -> dict[str, list[float]]: + ) -> TPPOTrainingStats: losses, clip_losses, vf_losses, ent_losses = [], [], [], [] split_batch_size = batch_size or -1 for step in range(repeat): @@ -182,9 +195,14 @@ def learn( # type: ignore ent_losses.append(ent_loss.item()) losses.append(loss.item()) - return { - "loss": losses, - "loss/clip": clip_losses, - "loss/vf": vf_losses, - "loss/ent": ent_losses, - } + losses_summary = SequenceSummaryStats.from_sequence(losses) + clip_losses_summary = SequenceSummaryStats.from_sequence(clip_losses) + vf_losses_summary = SequenceSummaryStats.from_sequence(vf_losses) + ent_losses_summary = SequenceSummaryStats.from_sequence(ent_losses) + + return PPOTrainingStats( # type: ignore[return-value] + loss=losses_summary, + clip_loss=clip_losses_summary, + vf_loss=vf_losses_summary, + ent_loss=ent_losses_summary, + ) diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index f4b96151c..b2f5d1e8c 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -1,5 +1,6 @@ import warnings -from typing import Any +from dataclasses import dataclass +from typing import Any, Generic, TypeVar import gymnasium as gym import numpy as np @@ -10,9 +11,18 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import DQNPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.dqn import DQNTrainingStats -class QRDQNPolicy(DQNPolicy): +@dataclass(kw_only=True) +class QRDQNTrainingStats(DQNTrainingStats): + pass + + +TQRDQNTrainingStats = TypeVar("TQRDQNTrainingStats", bound=QRDQNTrainingStats) + + +class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]): """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044. 
:param model: a model following the rules in @@ -95,7 +105,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def compute_q_value(self, logits: torch.Tensor, mask: np.ndarray | None) -> torch.Tensor: return super().compute_q_value(logits.mean(2), mask) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TQRDQNTrainingStats: if self._target and self._iter % self.freq == 0: self.sync_weight() self.optim.zero_grad() @@ -118,4 +128,5 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ loss.backward() self.optim.step() self._iter += 1 - return {"loss": loss.item()} + + return QRDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index 2b476bcc7..fad567cd2 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -1,9 +1,11 @@ -from typing import Any +from dataclasses import dataclass +from typing import Any, TypeVar from torch import nn from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import C51Policy +from tianshou.policy.modelfree.c51 import C51TrainingStats from tianshou.utils.net.discrete import NoisyLinear @@ -25,8 +27,16 @@ def _sample_noise(model: nn.Module) -> bool: return sampled_any_noise +@dataclass(kw_only=True) +class RainbowTrainingStats(C51TrainingStats): + loss: float + + +TRainbowTrainingStats = TypeVar("TRainbowTrainingStats", bound=RainbowTrainingStats) + + # TODO: is this class worth keeping? It barely does anything -class RainbowPolicy(C51Policy): +class RainbowPolicy(C51Policy[TRainbowTrainingStats]): """Implementation of Rainbow DQN. arXiv:1710.02298. Same parameters as :class:`~tianshou.policy.C51Policy`. @@ -37,7 +47,12 @@ class RainbowPolicy(C51Policy): explanation. """ - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn( + self, + batch: RolloutBatchProtocol, + *args: Any, + **kwargs: Any, + ) -> TRainbowTrainingStats: _sample_noise(self.model) if self._target and _sample_noise(self.model_old): self.model_old.train() # so that NoisyLinear takes effect diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index cdc5ad93d..d1b00714d 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -1,4 +1,5 @@ -from typing import Any, Literal, cast +from dataclasses import dataclass +from typing import Any, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -10,9 +11,21 @@ from tianshou.exploration import BaseNoise from tianshou.policy import DDPGPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.ddpg import DDPGTrainingStats -class REDQPolicy(DDPGPolicy): +@dataclass +class REDQTrainingStats(DDPGTrainingStats): + """A data structure for storing loss statistics of the REDQ learn step.""" + + alpha: float | None = None + alpha_loss: float | None = None + + +TREDQTrainingStats = TypeVar("TREDQTrainingStats", bound=REDQTrainingStats) + + +class REDQPolicy(DDPGPolicy[TREDQTrainingStats]): """Implementation of REDQ. arXiv:2101.05982. 
:param actor: The actor network following the rules in @@ -99,6 +112,8 @@ def __init__( self.deterministic_eval = deterministic_eval self.__eps = np.finfo(np.float32).eps.item() + self._last_actor_loss = 0.0 # only for logging purposes + # TODO: reduce duplication with SACPolicy self.alpha: float | torch.Tensor self._is_auto_alpha = not isinstance(alpha, float) @@ -168,7 +183,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: return target_q - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TREDQTrainingStats: # type: ignore # critic ensemble weight = getattr(batch, "weight", 1.0) current_qs = self.critic(batch.obs, batch.act).flatten(1) @@ -181,6 +196,7 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ batch.weight = torch.mean(td, dim=0) # prio-buffer self.critic_gradient_step += 1 + alpha_loss = None # actor if self.critic_gradient_step % self.actor_delay == 0: obs_result = self(batch) @@ -201,12 +217,14 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ self.sync_weight() - result = {"loss/critics": critic_loss.item()} if self.critic_gradient_step % self.actor_delay == 0: - result["loss/actor"] = (actor_loss.item(),) - if self.is_auto_alpha: - self.alpha = cast(torch.Tensor, self.alpha) - result["loss/alpha"] = alpha_loss.item() - result["alpha"] = self.alpha.item() - - return result + self._last_actor_loss = actor_loss.item() + if self.is_auto_alpha: + self.alpha = cast(torch.Tensor, self.alpha) + + return REDQTrainingStats( # type: ignore[return-value] + actor_loss=self._last_actor_loss, + critic_loss=critic_loss.item(), + alpha=self.alpha.item() if isinstance(self.alpha, torch.Tensor) else self.alpha, + alpha_loss=alpha_loss, + ) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 2c494d4c6..f487e33e9 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -1,5 +1,6 @@ from copy import deepcopy -from typing import Any, Literal, Self, cast +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -14,11 +15,25 @@ ) from tianshou.exploration import BaseNoise from tianshou.policy import DDPGPolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.utils.conversion import to_optional_float from tianshou.utils.optim import clone_optimizer -class SACPolicy(DDPGPolicy): +@dataclass(kw_only=True) +class SACTrainingStats(TrainingStats): + actor_loss: float + critic1_loss: float + critic2_loss: float + alpha: float | None = None + alpha_loss: float | None = None + + +TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var] """Implementation of Soft Actor-Critic. arXiv:1812.05905. :param actor: the actor network following the rules in @@ -120,7 +135,10 @@ def __init__( self.target_entropy, self.log_alpha, self.alpha_optim = alpha self.alpha = self.log_alpha.detach().exp() else: - alpha = cast(float, alpha) + alpha = cast( + float, + alpha, + ) # can we convert alpha to a constant tensor here? 
then mypy wouldn't complain self.alpha = alpha # TODO or not TODO: add to BasePolicy? @@ -195,7 +213,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: - self.alpha * obs_next_result.log_prob ) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TSACTrainingStats: # type: ignore # critic 1&2 td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) @@ -212,6 +230,7 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ self.actor_optim.zero_grad() actor_loss.backward() self.actor_optim.step() + alpha_loss = None if self.is_auto_alpha: log_prob = obs_result.log_prob.detach() + self.target_entropy @@ -224,14 +243,10 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ self.sync_weight() - result = { - "loss/actor": actor_loss.item(), - "loss/critic1": critic1_loss.item(), - "loss/critic2": critic2_loss.item(), - } - if self.is_auto_alpha: - self.alpha = cast(torch.Tensor, self.alpha) - result["loss/alpha"] = alpha_loss.item() - result["alpha"] = self.alpha.item() - - return result + return SACTrainingStats( # type: ignore[return-value] + actor_loss=actor_loss.item(), + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + alpha=to_optional_float(self.alpha), + alpha_loss=to_optional_float(alpha_loss), + ) diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index c364fc8e6..dbf7b6589 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -1,5 +1,6 @@ from copy import deepcopy -from typing import Any, Literal, Self +from dataclasses import dataclass +from typing import Any, Generic, Literal, Self, TypeVar import gymnasium as gym import numpy as np @@ -9,11 +10,22 @@ from tianshou.data.types import RolloutBatchProtocol from tianshou.exploration import BaseNoise from tianshou.policy import DDPGPolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.utils.optim import clone_optimizer -class TD3Policy(DDPGPolicy): +@dataclass(kw_only=True) +class TD3TrainingStats(TrainingStats): + actor_loss: float + critic1_loss: float + critic2_loss: float + + +TTD3TrainingStats = TypeVar("TTD3TrainingStats", bound=TD3TrainingStats) + + +# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure. +class TD3Policy(DDPGPolicy[TTD3TrainingStats], Generic[TTD3TrainingStats]): # type: ignore[type-var] """Implementation of TD3, arXiv:1802.09477. 
:param actor: the actor network following the rules in @@ -128,7 +140,7 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: self.critic2_old(obs_next_batch.obs, act_), ) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTD3TrainingStats: # type: ignore # critic 1&2 td1, critic1_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) td2, critic2_loss = self._mse_optimizer(batch, self.critic2, self.critic2_optim) @@ -143,8 +155,9 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[ self.actor_optim.step() self.sync_weight() self._cnt += 1 - return { - "loss/actor": self._last, - "loss/critic1": critic1_loss.item(), - "loss/critic2": critic2_loss.item(), - } + + return TD3TrainingStats( # type: ignore[return-value] + actor_loss=self._last, + critic1_loss=critic1_loss.item(), + critic2_loss=critic2_loss.item(), + ) diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index 7546a25d8..babc23bfa 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -1,18 +1,28 @@ import warnings -from typing import Any, Literal +from dataclasses import dataclass +from typing import Any, Literal, TypeVar import gymnasium as gym import torch import torch.nn.functional as F from torch.distributions import kl_divergence -from tianshou.data import Batch +from tianshou.data import Batch, SequenceSummaryStats from tianshou.policy import NPGPolicy from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.modelfree.npg import NPGTrainingStats from tianshou.policy.modelfree.pg import TDistributionFunction -class TRPOPolicy(NPGPolicy): +@dataclass(kw_only=True) +class TRPOTrainingStats(NPGTrainingStats): + step_size: SequenceSummaryStats + + +TTRPOTrainingStats = TypeVar("TTRPOTrainingStats", bound=TRPOTrainingStats) + + +class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]): """Implementation of Trust Region Policy Optimization. arXiv:1502.05477. :param actor: the actor network following the rules in BasePolicy. 
(s -> logits) @@ -94,7 +104,7 @@ def learn( # type: ignore batch_size: int | None, repeat: int, **kwargs: Any, - ) -> dict[str, list[float]]: + ) -> TTRPOTrainingStats: actor_losses, vf_losses, step_sizes, kls = [], [], [], [] split_batch_size = batch_size or -1 for _ in range(repeat): @@ -171,9 +181,14 @@ def learn( # type: ignore step_sizes.append(step_size.item()) kls.append(kl.item()) - return { - "loss/actor": actor_losses, - "loss/vf": vf_losses, - "step_size": step_sizes, - "kl": kls, - } + actor_loss_summary_stat = SequenceSummaryStats.from_sequence(actor_losses) + vf_loss_summary_stat = SequenceSummaryStats.from_sequence(vf_losses) + kl_summary_stat = SequenceSummaryStats.from_sequence(kls) + step_size_stat = SequenceSummaryStats.from_sequence(step_sizes) + + return TRPOTrainingStats( # type: ignore[return-value] + actor_loss=actor_loss_summary_stat, + vf_loss=vf_loss_summary_stat, + kl=kl_summary_stat, + step_size=step_size_stat, + ) diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 4cccac773..e41e069d1 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -1,12 +1,13 @@ -from typing import Any, Literal, Self +from typing import Any, Literal, Protocol, Self, cast, overload import numpy as np +from overrides import override from tianshou.data import Batch, ReplayBuffer -from tianshou.data.batch import BatchProtocol +from tianshou.data.batch import BatchProtocol, IndexType from tianshou.data.types import RolloutBatchProtocol from tianshou.policy import BasePolicy -from tianshou.policy.base import TLearningRateScheduler +from tianshou.policy.base import TLearningRateScheduler, TrainingStats try: from tianshou.env.pettingzoo_env import PettingZooEnv @@ -14,6 +15,54 @@ PettingZooEnv = None # type: ignore +class MapTrainingStats(TrainingStats): + def __init__( + self, + agent_id_to_stats: dict[str | int, TrainingStats], + train_time_aggregator: Literal["min", "max", "mean"] = "max", + ) -> None: + self._agent_id_to_stats = agent_id_to_stats + train_times = [agent_stats.train_time for agent_stats in agent_id_to_stats.values()] + match train_time_aggregator: + case "max": + aggr_function = max + case "min": + aggr_function = min + case "mean": + aggr_function = np.mean # type: ignore + case _: + raise ValueError( + f"Unknown {train_time_aggregator=}", + ) + self.train_time = aggr_function(train_times) + self.smoothed_loss = {} + + @override + def get_loss_stats_dict(self) -> dict[str, float]: + """Collects loss_stats_dicts from all agents, prepends agent_id to all keys, and joins results.""" + result_dict = {} + for agent_id, stats in self._agent_id_to_stats.items(): + agent_loss_stats_dict = stats.get_loss_stats_dict() + for k, v in agent_loss_stats_dict.items(): + result_dict[f"{agent_id}/" + k] = v + return result_dict + + +class MAPRolloutBatchProtocol(RolloutBatchProtocol, Protocol): + # TODO: this might not be entirely correct. + # The whole MAP data processing pipeline needs more documentation and possibly some refactoring + @overload + def __getitem__(self, index: str) -> RolloutBatchProtocol: + ... + + @overload + def __getitem__(self, index: IndexType) -> Self: + ... + + def __getitem__(self, index: str | IndexType) -> Any: + ... + + class MultiAgentPolicyManager(BasePolicy): """Multi-agent policy manager for MARL. 
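`MapTrainingStats.get_loss_stats_dict` flattens the per-agent stats by prefixing every loss key with its agent id, which replaces the old manual `agent_id + "/" + k` dict building in the manager's `learn`. A standalone sketch of that flattening step (key names are illustrative):

```python
def merge_agent_loss_dicts(agent_id_to_losses: dict[str, dict[str, float]]) -> dict[str, float]:
    """Illustrative mirror of MapTrainingStats.get_loss_stats_dict: prefix keys with the agent id."""
    result: dict[str, float] = {}
    for agent_id, losses in agent_id_to_losses.items():
        for key, value in losses.items():
            result[f"{agent_id}/{key}"] = value
    return result


print(
    merge_agent_loss_dicts(
        {
            "agent_1": {"loss": 0.4, "loss/actor": 0.3},
            "agent_2": {"loss": 0.7, "loss/actor": 0.5},
        },
    ),
)
```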
@@ -58,8 +107,10 @@ def __init__( # (this MultiAgentPolicyManager) policy.set_agent_id(env.agents[i]) - self.policies = dict(zip(env.agents, policies, strict=True)) + self.policies: dict[str | int, BasePolicy] = dict(zip(env.agents, policies, strict=True)) + """Maps agent_id to policy.""" + # TODO: unused - remove it? def replace_policy(self, policy: BasePolicy, agent_id: int) -> None: """Replace the "agent_id"th policy in this manager.""" policy.set_agent_id(agent_id) @@ -68,17 +119,18 @@ def replace_policy(self, policy: BasePolicy, agent_id: int) -> None: # TODO: violates Liskov substitution principle def process_fn( # type: ignore self, - batch: RolloutBatchProtocol, + batch: MAPRolloutBatchProtocol, buffer: ReplayBuffer, indice: np.ndarray, - ) -> BatchProtocol: - """Dispatch batch data from obs.agent_id to every policy's process_fn. + ) -> MAPRolloutBatchProtocol: + """Dispatch batch data from `obs.agent_id` to every policy's process_fn. Save original multi-dimensional rew in "save_rew", set rew to the reward of each agent during their "process_fn", and restore the original reward afterwards. """ - results = {} + # TODO: maybe only str is actually allowed as agent_id? See MAPRolloutBatchProtocol + results: dict[str | int, RolloutBatchProtocol] = {} assert isinstance( batch.obs, BatchProtocol, @@ -92,7 +144,7 @@ def process_fn( # type: ignore for agent, policy in self.policies.items(): agent_index = np.nonzero(batch.obs.agent_id == agent)[0] if len(agent_index) == 0: - results[agent] = Batch() + results[agent] = cast(RolloutBatchProtocol, Batch()) continue tmp_batch, tmp_indice = batch[agent_index], indice[agent_index] if has_rew: @@ -133,10 +185,12 @@ def forward( # type: ignore ) -> Batch: """Dispatch batch data from obs.agent_id to every policy's forward. + :param batch: TODO: document what is expected at input and make a BatchProtocol for it :param state: if None, it means all agents have no state. If not None, it should contain keys of "agent_1", "agent_2", ... :return: a Batch with the following contents: + TODO: establish a BatcProtocol for this :: @@ -202,34 +256,24 @@ def forward( # type: ignore holder["state"] = state_dict return holder - def learn( + # Violates Liskov substitution principle + def learn( # type: ignore self, - batch: RolloutBatchProtocol, + batch: MAPRolloutBatchProtocol, *args: Any, **kwargs: Any, - ) -> dict[str, float | list[float]]: + ) -> MapTrainingStats: """Dispatch the data to all policies for learning. - :return: a dict with the following contents: - - :: - - { - "agent_1/item1": item 1 of agent_1's policy.learn output - "agent_1/item2": item 2 of agent_1's policy.learn output - "agent_2/xxx": xxx - ... - "agent_n/xxx": xxx - } + :param batch: must map agent_ids to rollout batches """ - results = {} + agent_id_to_stats = {} for agent_id, policy in self.policies.items(): data = batch[agent_id] if not data.is_empty(): - out = policy.learn(batch=data, **kwargs) - for k, v in out.items(): - results[agent_id + "/" + k] = v - return results + train_stats = policy.learn(batch=data, **kwargs) + agent_id_to_stats[agent_id] = train_stats + return MapTrainingStats(agent_id_to_stats) # Need a train method that set all sub-policies to train mode. # No need for a similar eval function, as eval internally uses the train function. 
diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index 38ff7d064..943ae99f2 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import Any, TypeVar, cast import numpy as np @@ -6,9 +6,17 @@ from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import BasePolicy +from tianshou.policy.base import TrainingStats -class RandomPolicy(BasePolicy): +class RandomTrainingStats(TrainingStats): + pass + + +TRandomTrainingStats = TypeVar("TRandomTrainingStats", bound=RandomTrainingStats) + + +class RandomPolicy(BasePolicy[TRandomTrainingStats]): """A random agent used in multi-agent learning. It randomly chooses an action from the legal action. @@ -41,6 +49,6 @@ def forward( result = Batch(act=logits.argmax(axis=-1)) return cast(ActBatchProtocol, result) - def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]: + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TRandomTrainingStats: # type: ignore """Since a random agent learns nothing, it returns an empty dict.""" - return {} + return RandomTrainingStats() # type: ignore[return-value] diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index fbd3e9197..84faf7999 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -3,13 +3,24 @@ from abc import ABC, abstractmethod from collections import defaultdict, deque from collections.abc import Callable +from dataclasses import asdict from typing import Any import numpy as np import tqdm -from tianshou.data import AsyncCollector, Collector, ReplayBuffer +from tianshou.data import ( + AsyncCollector, + Collector, + CollectStats, + EpochStats, + InfoStats, + ReplayBuffer, + SequenceSummaryStats, +) +from tianshou.data.collector import CollectStatsBase from tianshou.policy import BasePolicy +from tianshou.policy.base import TrainingStats from tianshou.trainer.utils import gather_info, test_episode from tianshou.utils import ( BaseLogger, @@ -19,6 +30,7 @@ deprecation, tqdm_config, ) +from tianshou.utils.logging import set_numerical_fields_to_precision log = logging.getLogger(__name__) @@ -189,8 +201,9 @@ def __init__( self.start_epoch = 0 # This is only used for logging but creeps into the implementations # of the trainers. 
I believe it would be better to remove - self.gradient_step = 0 + self._gradient_step = 0 self.env_step = 0 + self.policy_update_time = 0.0 self.max_epoch = max_epoch self.step_per_epoch = step_per_epoch @@ -218,7 +231,7 @@ def __init__( self.resume_from_log = resume_from_log self.is_run = False - self.last_rew, self.last_len = 0.0, 0 + self.last_rew, self.last_len = 0.0, 0.0 self.epoch = self.start_epoch self.best_epoch = self.start_epoch @@ -233,10 +246,10 @@ def reset(self) -> None: ( self.start_epoch, self.env_step, - self.gradient_step, + self._gradient_step, ) = self.logger.restore_data() - self.last_rew, self.last_len = 0.0, 0 + self.last_rew, self.last_len = 0.0, 0.0 self.start_time = time.time() if self.train_collector is not None: self.train_collector.reset_stat() @@ -258,10 +271,11 @@ def reset(self) -> None: self.env_step, self.reward_metric, ) + assert test_result.returns_stat is not None # for mypy self.best_epoch = self.start_epoch self.best_reward, self.best_reward_std = ( - test_result["rew"], - test_result["rew_std"], + test_result.returns_stat.mean, + test_result.returns_stat.std, ) if self.save_best_fn: self.save_best_fn(self.policy) @@ -274,7 +288,7 @@ def __iter__(self): # type: ignore self.reset() return self - def __next__(self) -> None | tuple[int, dict[str, Any], dict[str, Any]]: + def __next__(self) -> EpochStats: """Perform one epoch (both train and eval).""" self.epoch += 1 self.iter_num += 1 @@ -291,29 +305,32 @@ def __next__(self) -> None | tuple[int, dict[str, Any], dict[str, Any]]: # set policy in train mode self.policy.train() - epoch_stat: dict[str, Any] = {} - progress = tqdm.tqdm if self.show_progress else DummyTqdm # perform n step_per_epoch with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t: while t.n < t.total and not self.stop_fn_flag: - data: dict[str, Any] = {} - result: dict[str, Any] = {} + train_stat: CollectStatsBase if self.train_collector is not None: - data, result, self.stop_fn_flag = self.train_step() - t.update(result["n/st"]) + pbar_data_dict, train_stat, self.stop_fn_flag = self.train_step() + t.update(train_stat.n_collected_steps) if self.stop_fn_flag: - t.set_postfix(**data) + t.set_postfix(**pbar_data_dict) break else: + pbar_data_dict = {} assert self.buffer, "No train_collector or buffer specified" - result["n/ep"] = len(self.buffer) - result["n/st"] = int(self.gradient_step) + train_stat = CollectStatsBase( + n_collected_episodes=len(self.buffer), + n_collected_steps=int(self._gradient_step), + ) t.update() - self.policy_update_fn(data, result) - t.set_postfix(**data) + update_stat = self.policy_update_fn(train_stat) + pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict) + pbar_data_dict["gradient_step"] = self._gradient_step + + t.set_postfix(**pbar_data_dict) if t.n <= t.total and not self.stop_fn_flag: t.update() @@ -322,49 +339,49 @@ def __next__(self) -> None | tuple[int, dict[str, Any], dict[str, Any]]: if self.train_collector is None: assert self.buffer is not None batch_size = self.batch_size or len(self.buffer) - self.env_step = self.gradient_step * batch_size + self.env_step = self._gradient_step * batch_size + test_stat = None if not self.stop_fn_flag: self.logger.save_data( self.epoch, self.env_step, - self.gradient_step, + self._gradient_step, self.save_checkpoint_fn, ) # test if self.test_collector is not None: test_stat, self.stop_fn_flag = self.test_step() - if not self.is_run: - epoch_stat.update(test_stat) - - if not self.is_run: - epoch_stat.update({k: 
v.get() for k, v in self.stat.items()}) - epoch_stat["gradient_step"] = self.gradient_step - epoch_stat.update( - { - "env_step": self.env_step, - "rew": self.last_rew, - "len": int(self.last_len), - "n/ep": int(result["n/ep"]), - "n/st": int(result["n/st"]), - }, - ) - info = gather_info( - self.start_time, - self.train_collector, - self.test_collector, - self.best_reward, - self.best_reward_std, - ) - return self.epoch, epoch_stat, info - return None - def test_step(self) -> tuple[dict[str, Any], bool]: + info_stat = gather_info( + start_time=self.start_time, + policy_update_time=self.policy_update_time, + gradient_step=self._gradient_step, + best_reward=self.best_reward, + best_reward_std=self.best_reward_std, + train_collector=self.train_collector, + test_collector=self.test_collector, + ) + + self.logger.log_info_data(asdict(info_stat), self.epoch) + + # in case trainer is used with run(), epoch_stat will not be returned + epoch_stat: EpochStats = EpochStats( + epoch=self.epoch, + train_collect_stat=train_stat, + test_collect_stat=test_stat, + training_stat=update_stat, + info_stat=info_stat, + ) + + return epoch_stat + + def test_step(self) -> tuple[CollectStats, bool]: """Perform one testing step.""" assert self.episode_per_test is not None assert self.test_collector is not None stop_fn_flag = False - test_result = test_episode( + test_stat = test_episode( self.policy, self.test_collector, self.test_fn, @@ -374,7 +391,8 @@ def test_step(self) -> tuple[dict[str, Any], bool]: self.env_step, self.reward_metric, ) - rew, rew_std = test_result["rew"], test_result["rew_std"] + assert test_stat.returns_stat is not None # for mypy + rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std if self.best_epoch < 0 or self.best_reward < rew: self.best_epoch = self.epoch self.best_reward = float(rew) @@ -389,22 +407,13 @@ def test_step(self) -> tuple[dict[str, Any], bool]: log.info(log_msg) if self.verbose: print(log_msg, flush=True) - if not self.is_run: - test_stat = { - "test_reward": rew, - "test_reward_std": rew_std, - "best_reward": self.best_reward, - "best_reward_std": self.best_reward_std, - "best_epoch": self.best_epoch, - } - else: - test_stat = {} + if self.stop_fn and self.stop_fn(self.best_reward): stop_fn_flag = True return test_stat, stop_fn_flag - def train_step(self) -> tuple[dict[str, Any], dict[str, Any], bool]: + def train_step(self) -> tuple[dict[str, Any], CollectStats, bool]: """Perform one training step.""" assert self.episode_per_test is not None assert self.train_collector is not None @@ -415,25 +424,33 @@ def train_step(self) -> tuple[dict[str, Any], dict[str, Any], bool]: n_step=self.step_per_collect, n_episode=self.episode_per_collect, ) - if result["n/ep"] > 0 and self.reward_metric: - rew = self.reward_metric(result["rews"]) - result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) - self.env_step += int(result["n/st"]) - self.logger.log_train_data(result, self.env_step) - self.last_rew = result["rew"] if result["n/ep"] > 0 else self.last_rew - self.last_len = result["len"] if result["n/ep"] > 0 else self.last_len + + self.env_step += result.n_collected_steps + + if result.n_collected_episodes > 0: + assert result.returns_stat is not None # for mypy + assert result.lens_stat is not None # for mypy + self.last_rew = result.returns_stat.mean + self.last_len = result.lens_stat.mean + if self.reward_metric: # TODO: move inside collector + rew = self.reward_metric(result.returns) + result.returns = rew + result.returns_stat = 
SequenceSummaryStats.from_sequence(rew) + + self.logger.log_train_data(asdict(result), self.env_step) + data = { "env_step": str(self.env_step), "rew": f"{self.last_rew:.2f}", "len": str(int(self.last_len)), - "n/ep": str(int(result["n/ep"])), - "n/st": str(int(result["n/st"])), + "n/ep": str(result.n_collected_episodes), + "n/st": str(result.n_collected_steps), } if ( - result["n/ep"] > 0 + result.n_collected_episodes > 0 and self.test_in_train and self.stop_fn - and self.stop_fn(result["rew"]) + and self.stop_fn(result.returns_stat.mean) # type: ignore ): assert self.test_collector is not None test_result = test_episode( @@ -445,31 +462,54 @@ def train_step(self) -> tuple[dict[str, Any], dict[str, Any], bool]: self.logger, self.env_step, ) - if self.stop_fn(test_result["rew"]): + assert test_result.returns_stat is not None # for mypy + if self.stop_fn(test_result.returns_stat.mean): stop_fn_flag = True - self.best_reward = test_result["rew"] - self.best_reward_std = test_result["rew_std"] + self.best_reward = test_result.returns_stat.mean + self.best_reward_std = test_result.returns_stat.std else: self.policy.train() return data, result, stop_fn_flag - def log_update_data(self, data: dict[str, Any], losses: dict[str, Any]) -> None: - """Log losses to current logger.""" - for k in losses: - self.stat[k].add(losses[k]) - losses[k] = self.stat[k].get() - data[k] = f"{losses[k]:.3f}" - self.logger.log_update_data(losses, self.gradient_step) + # TODO: move moving average computation and logging into its own logger + # TODO: maybe think about a command line logger instead of always printing data dict + def _update_moving_avg_stats_and_log_update_data(self, update_stat: TrainingStats) -> None: + """Log losses, update moving average stats, and also modify the smoothed_loss in update_stat.""" + cur_losses_dict = update_stat.get_loss_stats_dict() + update_stat.smoothed_loss = self._update_moving_avg_stats_and_get_averaged_data( + cur_losses_dict, + ) + self.logger.log_update_data(asdict(update_stat), self._gradient_step) + + # TODO: seems convoluted, there should be a better way of dealing with the moving average stats + def _update_moving_avg_stats_and_get_averaged_data( + self, + data: dict[str, float], + ) -> dict[str, float]: + """Add entries to the moving average object in the trainer and retrieve the averaged results. + + :param data: any entries to be tracked in the moving average object. + :return: A dictionary containing the averaged values of the tracked entries. + + """ + smoothed_data = {} + for key, loss_item in data.items(): + self.stat[key].add(loss_item) + smoothed_data[key] = self.stat[key].get() + return smoothed_data @abstractmethod - def policy_update_fn(self, data: dict[str, Any], result: dict[str, Any]) -> None: + def policy_update_fn( + self, + collect_stats: CollectStatsBase, + ) -> TrainingStats: """Policy update function for different trainer implementation. - :param data: information in progress bar. - :param result: collector's return value. + :param collect_stats: provides info about the most recent collection. In the offline case, this will contain + stats of the whole dataset """ - def run(self) -> dict[str, float | str]: + def run(self) -> InfoStats: """Consume iterator. See itertools - recipes. 
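As a reference for the smoothing performed by `_update_moving_avg_stats_and_get_averaged_data` above, here is a self-contained sketch; `MovAvgSketch` is a simplified stand-in for `tianshou.utils.MovAvg` (the real class additionally filters inf/nan values and accepts tensors):

```python
from collections import defaultdict, deque


class MovAvgSketch:
    """Simplified moving average over the most recent `size` values."""

    def __init__(self, size: int = 100) -> None:
        self.values = deque(maxlen=size)

    def add(self, value: float) -> float:
        self.values.append(value)
        return self.get()

    def get(self) -> float:
        return sum(self.values) / len(self.values) if self.values else 0.0


# One moving average per loss key, mirroring the trainer's `self.stat` dict.
stat = defaultdict(MovAvgSketch)


def smooth(losses: dict[str, float]) -> dict[str, float]:
    smoothed = {}
    for key, value in losses.items():
        stat[key].add(value)
        smoothed[key] = stat[key].get()
    return smoothed


print(smooth({"loss/actor": 1.0}))  # {'loss/actor': 1.0}
print(smooth({"loss/actor": 3.0}))  # {'loss/actor': 2.0}
```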
Use functions that consume iterators at C speed @@ -479,25 +519,28 @@ def run(self) -> dict[str, float | str]: self.is_run = True deque(self, maxlen=0) # feed the entire iterator into a zero-length deque info = gather_info( - self.start_time, - self.train_collector, - self.test_collector, - self.best_reward, - self.best_reward_std, + start_time=self.start_time, + policy_update_time=self.policy_update_time, + gradient_step=self._gradient_step, + best_reward=self.best_reward, + best_reward_std=self.best_reward_std, + train_collector=self.train_collector, + test_collector=self.test_collector, ) finally: self.is_run = False return info - def _sample_and_update(self, buffer: ReplayBuffer, data: dict[str, Any]) -> None: - self.gradient_step += 1 + def _sample_and_update(self, buffer: ReplayBuffer) -> TrainingStats: + """Sample a mini-batch, perform one gradient step, and update the _gradient_step counter.""" + self._gradient_step += 1 # Note: since sample_size=batch_size, this will perform # exactly one gradient step. This is why we don't need to calculate the # number of gradient steps, like in the on-policy case. - losses = self.policy.update(sample_size=self.batch_size, buffer=buffer) - data.update({"gradient_step": str(self.gradient_step)}) - self.log_update_data(data, losses) + update_stat = self.policy.update(sample_size=self.batch_size, buffer=buffer) + self._update_moving_avg_stats_and_log_update_data(update_stat) + return update_stat class OfflineTrainer(BaseTrainer): @@ -512,12 +555,14 @@ class OfflineTrainer(BaseTrainer): def policy_update_fn( self, - data: dict[str, Any], - result: dict[str, Any] | None = None, - ) -> None: + collect_stats: CollectStatsBase | None = None, + ) -> TrainingStats: """Perform one off-line policy update.""" assert self.buffer - self._sample_and_update(self.buffer, data) + update_stat = self._sample_and_update(self.buffer) + # logging + self.policy_update_time += update_stat.train_time + return update_stat class OffpolicyTrainer(BaseTrainer): @@ -532,20 +577,31 @@ class OffpolicyTrainer(BaseTrainer): assert isinstance(BaseTrainer.__doc__, str) __doc__ += BaseTrainer.gen_doc("offpolicy") + "\n".join(BaseTrainer.__doc__.split("\n")[1:]) - def policy_update_fn(self, data: dict[str, Any], result: dict[str, Any]) -> None: - """Perform off-policy updates. + def policy_update_fn( + self, + # TODO: this is the only implementation where collect_stats is actually needed. Maybe change interface? + collect_stats: CollectStatsBase, + ) -> TrainingStats: + """Perform `update_per_step * n_collected_steps` gradient steps by sampling mini-batches from the buffer. - :param data: - :param result: must contain `n/st` key, see documentation of - `:meth:~tianshou.data.collector.Collector.collect` for the kind of - data returned there. `n/st` stands for `step_count` + :param collect_stats: the :class:`~TrainingStats` instance returned by the last gradient step. Some values + in it will be replaced by their moving averages. """ assert self.train_collector is not None - n_collected_steps = result["n/st"] - # Same as training intensity, right? 
- num_updates = round(self.update_per_step * n_collected_steps) - for _ in range(num_updates): - self._sample_and_update(self.train_collector.buffer, data) + n_collected_steps = collect_stats.n_collected_steps + n_gradient_steps = round(self.update_per_step * n_collected_steps) + if n_gradient_steps == 0: + raise ValueError( + f"n_gradient_steps is 0, n_collected_steps={n_collected_steps}, " + f"update_per_step={self.update_per_step}", + ) + for _ in range(n_gradient_steps): + update_stat = self._sample_and_update(self.train_collector.buffer) + + # logging + self.policy_update_time += update_stat.train_time + # TODO: only the last update_stat is returned, should be improved + return update_stat class OnpolicyTrainer(BaseTrainer): @@ -561,12 +617,11 @@ class OnpolicyTrainer(BaseTrainer): def policy_update_fn( self, - data: dict[str, Any], - result: dict[str, Any] | None = None, - ) -> None: - """Perform one on-policy update.""" + result: CollectStatsBase | None = None, + ) -> TrainingStats: + """Perform one on-policy update by passing the entire buffer to the policy's update method.""" assert self.train_collector is not None - losses = self.policy.update( + training_stat = self.policy.update( sample_size=0, buffer=self.train_collector.buffer, # Note: sample_size is None, so the whole buffer is used for the update. @@ -578,13 +633,14 @@ def policy_update_fn( ) # just for logging, no functional role + self.policy_update_time += training_stat.train_time # TODO: remove the gradient step counting in trainers? Doesn't seem like # it's important and it adds complexity - self.gradient_step += 1 + self._gradient_step += 1 if self.batch_size is None: - self.gradient_step += 1 + self._gradient_step += 1 elif self.batch_size > 0: - self.gradient_step += int((len(self.train_collector.buffer) - 0.1) // self.batch_size) + self._gradient_step += int((len(self.train_collector.buffer) - 0.1) // self.batch_size) # Note: this is the main difference to the off-policy trainer! 
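The number of gradient steps the off-policy trainer performs after each collection phase follows directly from `update_per_step`; a small worked example with made-up numbers:

```python
# Off-policy trainer: gradient steps per collection phase (illustrative values).
update_per_step = 0.1
n_collected_steps = 64

n_gradient_steps = round(update_per_step * n_collected_steps)
print(n_gradient_steps)  # 6

# Very small collections can round down to zero, which now raises a ValueError
# instead of silently performing no update.
assert round(0.1 * 4) == 0
```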
# The second difference is that batches of data are sampled without replacement @@ -593,4 +649,6 @@ def policy_update_fn( self.train_collector.reset_buffer(keep_statistics=True) # The step is the number of mini-batches used for the update, so essentially - self.log_update_data(data, losses) + self._update_moving_avg_stats_and_log_update_data(training_stat) + + return training_stat diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index e8193d5a8..300c7c470 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -1,10 +1,16 @@ import time from collections.abc import Callable -from typing import Any +from dataclasses import asdict import numpy as np -from tianshou.data import Collector +from tianshou.data import ( + Collector, + CollectStats, + InfoStats, + SequenceSummaryStats, + TimingStats, +) from tianshou.policy import BasePolicy from tianshou.utils import BaseLogger @@ -18,7 +24,7 @@ def test_episode( logger: BaseLogger | None = None, global_step: int | None = None, reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, -) -> dict[str, Any]: +) -> CollectStats: """A simple wrapper of testing policy in collector.""" collector.reset_env() collector.reset_buffer() @@ -26,72 +32,72 @@ def test_episode( if test_fn: test_fn(epoch, global_step) result = collector.collect(n_episode=n_episode) - if reward_metric: - rew = reward_metric(result["rews"]) - result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) + if reward_metric: # TODO: move into collector + rew = reward_metric(result.returns) + result.returns = rew + result.returns_stat = SequenceSummaryStats.from_sequence(rew) if logger and global_step is not None: - logger.log_test_data(result, global_step) + assert result.n_collected_episodes > 0 + logger.log_test_data(asdict(result), global_step) return result def gather_info( start_time: float, - train_collector: Collector | None, - test_collector: Collector | None, + policy_update_time: float, + gradient_step: int, best_reward: float, best_reward_std: float, -) -> dict[str, float | str]: + train_collector: Collector | None = None, + test_collector: Collector | None = None, +) -> InfoStats: """A simple wrapper of gathering information from collectors. 
- :return: A dictionary with the following keys: + :return: A dataclass object with the following members (depending on available collectors): + * ``gradient_step`` the total number of gradient steps; + * ``best_reward`` the best reward over the test results; + * ``best_reward_std`` the standard deviation of best reward over the test results; * ``train_step`` the total collected step of training collector; * ``train_episode`` the total collected episode of training collector; - * ``train_time/collector`` the time for collecting transitions in the \ - training collector; - * ``train_time/model`` the time for training models; - * ``train_speed`` the speed of training (env_step per second); * ``test_step`` the total collected step of test collector; * ``test_episode`` the total collected episode of test collector; + * ``timing`` the timing statistics, with the following members: + * ``total_time`` the total time elapsed; + * ``train_time`` the total time elapsed for learning training (collecting samples plus model update); + * ``train_time_collect`` the time for collecting transitions in the \ + training collector; + * ``train_time_update`` the time for training models; * ``test_time`` the time for testing; - * ``test_speed`` the speed of testing (env_step per second); - * ``best_reward`` the best reward over the test results; - * ``duration`` the total elapsed time. + * ``update_speed`` the speed of updating (env_step per second). """ - duration = max(0, time.time() - start_time) - model_time = duration - result: dict[str, float | str] = { - "duration": f"{duration:.2f}s", - "train_time/model": f"{model_time:.2f}s", - } + duration = max(0.0, time.time() - start_time) + test_time = 0.0 + update_speed = 0.0 + train_time_collect = 0.0 if test_collector is not None: - model_time = max(0, duration - test_collector.collect_time) - test_speed = test_collector.collect_step / test_collector.collect_time - result.update( - { - "test_step": test_collector.collect_step, - "test_episode": test_collector.collect_episode, - "test_time": f"{test_collector.collect_time:.2f}s", - "test_speed": f"{test_speed:.2f} step/s", - "best_reward": best_reward, - "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}", - "duration": f"{duration:.2f}s", - "train_time/model": f"{model_time:.2f}s", - }, - ) + test_time = test_collector.collect_time + if train_collector is not None: - model_time = max(0, model_time - train_collector.collect_time) - if test_collector is not None: - train_speed = train_collector.collect_step / (duration - test_collector.collect_time) - else: - train_speed = train_collector.collect_step / duration - result.update( - { - "train_step": train_collector.collect_step, - "train_episode": train_collector.collect_episode, - "train_time/collector": f"{train_collector.collect_time:.2f}s", - "train_time/model": f"{model_time:.2f}s", - "train_speed": f"{train_speed:.2f} step/s", - }, - ) - return result + train_time_collect = train_collector.collect_time + update_speed = train_collector.collect_step / (duration - test_time) + + timing_stat = TimingStats( + total_time=duration, + train_time=duration - test_time, + train_time_collect=train_time_collect, + train_time_update=policy_update_time, + test_time=test_time, + update_speed=update_speed, + ) + + return InfoStats( + gradient_step=gradient_step, + best_reward=best_reward, + best_reward_std=best_reward_std, + train_step=train_collector.collect_step if train_collector is not None else 0, + train_episode=train_collector.collect_episode if 
train_collector is not None else 0, + test_step=test_collector.collect_step if test_collector is not None else 0, + test_episode=test_collector.collect_episode if test_collector is not None else 0, + timing=timing_stat, + ) diff --git a/tianshou/utils/conversion.py b/tianshou/utils/conversion.py new file mode 100644 index 000000000..bae2db331 --- /dev/null +++ b/tianshou/utils/conversion.py @@ -0,0 +1,25 @@ +from typing import overload + +import torch + + +@overload +def to_optional_float(x: torch.Tensor) -> float: + ... + + +@overload +def to_optional_float(x: float) -> float: + ... + + +@overload +def to_optional_float(x: None) -> None: + ... + + +def to_optional_float(x: torch.Tensor | float | None) -> float | None: + """For the common case where one needs to extract a float from a scalar Tensor, which may be None.""" + if isinstance(x, torch.Tensor): + return x.item() + return x diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index 92ba41d35..74e0fa4cc 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -1,10 +1,23 @@ +import typing from abc import ABC, abstractmethod from collections.abc import Callable +from enum import Enum from numbers import Number +from typing import Any import numpy as np -LOG_DATA_TYPE = dict[str, int | Number | np.number | np.ndarray] +VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray +VALID_LOG_VALS = typing.get_args( + VALID_LOG_VALS_TYPE, +) # I know it's stupid, but we can't use Union type in isinstance + + +class DataScope(Enum): + TRAIN = "train" + TEST = "test" + UPDATE = "update" + INFO = "info" class BaseLogger(ABC): @@ -15,6 +28,7 @@ class BaseLogger(ABC): :param train_interval: the log interval in log_train_data(). Default to 1000. :param test_interval: the log interval in log_test_data(). Default to 1. :param update_interval: the log interval in log_update_data(). Default to 1000. + :param info_interval: the log interval in log_info_data(). Default to 1. """ def __init__( @@ -22,17 +36,20 @@ def __init__( train_interval: int = 1000, test_interval: int = 1, update_interval: int = 1000, + info_interval: int = 1, ) -> None: super().__init__() self.train_interval = train_interval self.test_interval = test_interval self.update_interval = update_interval + self.info_interval = info_interval self.last_log_train_step = -1 self.last_log_test_step = -1 self.last_log_update_step = -1 + self.last_log_info_step = -1 @abstractmethod - def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: + def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: """Specify how the writer is used to log data. :param str step_type: namespace which the data dict belongs to. @@ -40,53 +57,96 @@ def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: :param data: the data to write with format ``{key: value}``. """ - def log_train_data(self, collect_result: dict, step: int) -> None: + @staticmethod + def prepare_dict_for_logging( + input_dict: dict[str, Any], + parent_key: str = "", + delimiter: str = "/", + exclude_arrays: bool = True, + ) -> dict[str, VALID_LOG_VALS_TYPE]: + """Flattens and filters a nested dictionary by recursively traversing all levels and compressing the keys. + + Filtering is performed with respect to valid logging data types. + + :param input_dict: The nested dictionary to be flattened and filtered. + :param parent_key: The parent key used as a prefix before the input_dict keys. 
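The new `to_optional_float` helper added above can be exercised as follows, assuming the post-patch module layout (`tianshou.utils.conversion`):

```python
import torch

from tianshou.utils.conversion import to_optional_float

# Scalar tensors are converted via .item(); plain floats and None pass through unchanged.
assert to_optional_float(torch.tensor(1.5)) == 1.5
assert to_optional_float(2.0) == 2.0
assert to_optional_float(None) is None
```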
+ :param delimiter: The delimiter used to separate the keys. + :param exclude_arrays: Whether to exclude numpy arrays from the output. + :return: A flattened dictionary where the keys are compressed and values are filtered. + """ + result = {} + + def add_to_result( + cur_dict: dict, + prefix: str = "", + ) -> None: + for key, value in cur_dict.items(): + if exclude_arrays and isinstance(value, np.ndarray): + continue + + new_key = prefix + delimiter + key + new_key = new_key.lstrip(delimiter) + + if isinstance(value, dict): + add_to_result( + value, + new_key, + ) + elif isinstance(value, VALID_LOG_VALS): + result[new_key] = value + + add_to_result(input_dict, prefix=parent_key) + return result + + def log_train_data(self, log_data: dict, step: int) -> None: """Use writer to log statistics generated during training. - :param collect_result: a dict containing information of data collected in - training stage, i.e., returns of collector.collect(). - :param step: stands for the timestep the collect_result being logged. + :param log_data: a dict containing the information returned by the collector during the train step. + :param step: stands for the timestep the collector result is logged. """ - if collect_result["n/ep"] > 0 and step - self.last_log_train_step >= self.train_interval: - log_data = { - "train/episode": collect_result["n/ep"], - "train/reward": collect_result["rew"], - "train/length": collect_result["len"], - } + # TODO: move interval check to calling method + if step - self.last_log_train_step >= self.train_interval: + log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.TRAIN.value) self.write("train/env_step", step, log_data) self.last_log_train_step = step - def log_test_data(self, collect_result: dict, step: int) -> None: + def log_test_data(self, log_data: dict, step: int) -> None: """Use writer to log statistics generated during evaluating. - :param collect_result: a dict containing information of data collected in - evaluating stage, i.e., returns of collector.collect(). - :param step: stands for the timestep the collect_result being logged. + :param log_data:a dict containing the information returned by the collector during the evaluation step. + :param step: stands for the timestep the collector result is logged. """ - assert collect_result["n/ep"] > 0 + # TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer) if step - self.last_log_test_step >= self.test_interval: - log_data = { - "test/env_step": step, - "test/reward": collect_result["rew"], - "test/length": collect_result["len"], - "test/reward_std": collect_result["rew_std"], - "test/length_std": collect_result["len_std"], - } - self.write("test/env_step", step, log_data) + log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.TEST.value) + self.write(DataScope.TEST.value + "/env_step", step, log_data) self.last_log_test_step = step - def log_update_data(self, update_result: dict, step: int) -> None: + def log_update_data(self, log_data: dict, step: int) -> None: """Use writer to log statistics generated during updating. - :param update_result: a dict containing information of data collected in - updating stage, i.e., returns of policy.update(). - :param step: stands for the timestep the collect_result being logged. + :param log_data:a dict containing the information returned during the policy update step. + :param step: stands for the timestep the policy training data is logged. 
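For clarity, a small usage example of the `prepare_dict_for_logging` helper defined above; the nested dict is made up but mirrors the `asdict(...)` output the loggers now receive:

```python
import numpy as np

from tianshou.utils.logger.base import BaseLogger

nested = {
    "returns_stat": {"mean": 1.0, "std": 0.5},
    "returns": np.array([1.0, 2.0]),  # arrays are dropped when exclude_arrays=True
    "n_collected_steps": 16,
}
flat = BaseLogger.prepare_dict_for_logging(nested, parent_key="train")
print(flat)
# {'train/returns_stat/mean': 1.0, 'train/returns_stat/std': 0.5, 'train/n_collected_steps': 16}
```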
""" + # TODO: move interval check to calling method if step - self.last_log_update_step >= self.update_interval: - log_data = {f"update/{k}": v for k, v in update_result.items()} - self.write("update/gradient_step", step, log_data) + log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.UPDATE.value) + self.write(DataScope.UPDATE.value + "/gradient_step", step, log_data) self.last_log_update_step = step + def log_info_data(self, log_data: dict, step: int) -> None: + """Use writer to log global statistics. + + :param log_data: a dict containing information of data collected at the end of an epoch. + :param step: stands for the timestep the training info is logged. + """ + if ( + step - self.last_log_info_step >= self.info_interval + ): # TODO: move interval check to calling method + log_data = self.prepare_dict_for_logging(log_data, parent_key=DataScope.INFO.value) + self.write(DataScope.INFO.value + "/epoch", step, log_data) + self.last_log_info_step = step + @abstractmethod def save_data( self, @@ -121,7 +181,7 @@ class LazyLogger(BaseLogger): def __init__(self) -> None: super().__init__() - def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: + def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: """The LazyLogger writes nothing.""" def save_data( diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index e4e6ea9df..2a26963a5 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -4,7 +4,7 @@ from tensorboard.backend.event_processing import event_accumulator from torch.utils.tensorboard import SummaryWriter -from tianshou.utils.logger.base import LOG_DATA_TYPE, BaseLogger +from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE, BaseLogger from tianshou.utils.warning import deprecation @@ -15,6 +15,7 @@ class TensorboardLogger(BaseLogger): :param train_interval: the log interval in log_train_data(). Default to 1000. :param test_interval: the log interval in log_test_data(). Default to 1. :param update_interval: the log interval in log_update_data(). Default to 1000. + :param info_interval: the log interval in log_info_data(). Default to 1. :param save_interval: the save interval in save_data(). Default to 1 (save at the end of each epoch). 
:param write_flush: whether to flush tensorboard result after each @@ -27,16 +28,17 @@ def __init__( train_interval: int = 1000, test_interval: int = 1, update_interval: int = 1000, + info_interval: int = 1, save_interval: int = 1, write_flush: bool = True, ) -> None: - super().__init__(train_interval, test_interval, update_interval) + super().__init__(train_interval, test_interval, update_interval, info_interval) self.save_interval = save_interval self.write_flush = write_flush self.last_save_step = -1 self.writer = writer - def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: + def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: for k, v in data.items(): self.writer.add_scalar(k, v, global_step=step) if self.write_flush: # issue 580 diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index e984d464d..53dbf107a 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -6,7 +6,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.utils import BaseLogger, TensorboardLogger -from tianshou.utils.logger.base import LOG_DATA_TYPE +from tianshou.utils.logger.base import VALID_LOG_VALS_TYPE with contextlib.suppress(ImportError): import wandb @@ -30,6 +30,7 @@ class WandbLogger(BaseLogger): :param test_interval: the log interval in log_test_data(). Default to 1. :param update_interval: the log interval in log_update_data(). Default to 1000. + :param info_interval: the log interval in log_info_data(). Default to 1. :param save_interval: the save interval in save_data(). Default to 1 (save at the end of each epoch). :param write_flush: whether to flush tensorboard result after each @@ -46,6 +47,7 @@ def __init__( train_interval: int = 1000, test_interval: int = 1, update_interval: int = 1000, + info_interval: int = 1, save_interval: int = 1000, write_flush: bool = True, project: str | None = None, @@ -55,7 +57,7 @@ def __init__( config: argparse.Namespace | dict | None = None, monitor_gym: bool = True, ) -> None: - super().__init__(train_interval, test_interval, update_interval) + super().__init__(train_interval, test_interval, update_interval, info_interval) self.last_save_step = -1 self.save_interval = save_interval self.write_flush = write_flush @@ -91,7 +93,7 @@ def load(self, writer: SummaryWriter) -> None: self.write_flush, ) - def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: + def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None: if self.tensorboard_logger is None: raise Exception( "`logger` needs to load the Tensorboard Writer before " diff --git a/tianshou/utils/logging.py b/tianshou/utils/logging.py index 607f09c21..dcda429e6 100644 --- a/tianshou/utils/logging.py +++ b/tianshou/utils/logging.py @@ -20,6 +20,22 @@ _logFormat = LOG_DEFAULT_FORMAT +def set_numerical_fields_to_precision(data: dict[str, Any], precision: int = 3) -> dict[str, Any]: + """Returns a copy of the given dictionary with all numerical values rounded to the given precision. + + Note: does not recurse into nested dictionaries. 
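The new `set_numerical_fields_to_precision` helper (its body continues below) can be used as follows; per the implementation, only `float` values are rounded while other entries pass through unchanged:

```python
from tianshou.utils.logging import set_numerical_fields_to_precision

stats = {"loss": 0.123456, "n/st": 64, "note": "ok"}
print(set_numerical_fields_to_precision(stats))               # {'loss': 0.123, 'n/st': 64, 'note': 'ok'}
print(set_numerical_fields_to_precision(stats, precision=1))  # {'loss': 0.1, 'n/st': 64, 'note': 'ok'}
```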
+ + :param data: a dictionary + :param precision: the precision to be used + """ + result = {} + for k, v in data.items(): + if isinstance(v, float): + v = round(v, precision) + result[k] = v + return result + + def remove_log_handlers() -> None: """Removes all current log handlers.""" logger = getLogger() diff --git a/tianshou/utils/optim.py b/tianshou/utils/optim.py index 0c1093cc9..c69ef71db 100644 --- a/tianshou/utils/optim.py +++ b/tianshou/utils/optim.py @@ -8,19 +8,24 @@ def optim_step( loss: torch.Tensor, optim: torch.optim.Optimizer, - module: nn.Module, + module: nn.Module | None = None, max_grad_norm: float | None = None, ) -> None: - """Perform a single optimization step. + """Perform a single optimization step: zero_grad -> backward (-> clip_grad_norm) -> step. :param loss: :param optim: - :param module: + :param module: the module to optimize, required if max_grad_norm is passed :param max_grad_norm: if passed, will clip gradients using this """ optim.zero_grad() loss.backward() if max_grad_norm: + if not module: + raise ValueError( + "module must be passed if max_grad_norm is passed. " + "Note: often the module will be the policy, i.e.`self`", + ) nn.utils.clip_grad_norm_(module.parameters(), max_norm=max_grad_norm) optim.step() diff --git a/tianshou/utils/statistics.py b/tianshou/utils/statistics.py index 1f336c59c..779b0babe 100644 --- a/tianshou/utils/statistics.py +++ b/tianshou/utils/statistics.py @@ -29,7 +29,10 @@ def __init__(self, size: int = 100) -> None: self.cache: list[np.number] = [] self.banned = [np.inf, np.nan, -np.inf] - def add(self, data_array: Number | np.number | list | np.ndarray | torch.Tensor) -> float: + def add( + self, + data_array: Number | float | np.number | list | np.ndarray | torch.Tensor, + ) -> float: """Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with only one element, a python scalar, or
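Finally, a short usage sketch for the relaxed `optim_step` signature above: the module argument is now optional and only required when gradient clipping is requested. The linear model and optimizer here are arbitrary examples:

```python
import torch
from torch import nn

from tianshou.utils.optim import optim_step

model = nn.Linear(4, 1)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Without gradient clipping, the module can be omitted entirely.
loss = model(torch.randn(8, 4)).pow(2).mean()
optim_step(loss, optim)

# With clipping, the module must be passed, otherwise a ValueError is raised.
loss = model(torch.randn(8, 4)).pow(2).mean()
optim_step(loss, optim, module=model, max_grad_norm=0.5)
```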