Feature/dataclasses (#996)
This PR adds strict typing to the output of `update` and `learn` in all
policies. This will likely be the last large refactoring PR before the
next release (0.6.0, not 1.0.0), so it requires some attention. Several
difficulties were encountered on the path to that goal:

1. The policy hierarchy is actually "broken" in the sense that the keys of the
dicts that were output by `learn` did not follow the same extension (inheritance)
pattern as the policies themselves. This is a real problem and should be addressed
in the near future. More generally, several aspects of the policy design and
hierarchy deserve a dedicated discussion.
2. Each policy needs to be generic in its stats return type, because a policy may
be extended at some point, in which case its stats will typically be extended as
well. This pattern is needed in many places even within Tianshou's own code base.
3. The interaction between `learn` and `update` is a bit quirky. We currently
handle it by having `update` modify special fields inside `TrainingStats`, whereas
all other fields are set by `learn` (see the sketch after this list).
4. The ICM module is a policy wrapper and required a `TrainingStatsWrapper`. The
latter relies on a bunch of black magic.
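
To illustrate point 3, the division of labour between `update` and `learn` looks
roughly like the sketch below. This is a simplified illustration with made-up names,
not the actual Tianshou signatures.

```python
import time
from dataclasses import dataclass


@dataclass(kw_only=True)
class TrainingStats:
    # field owned by learn()
    loss: float
    # "special" field that update() fills in after learn() returns
    train_time: float = 0.0


class SomePolicy:
    def learn(self, batch) -> TrainingStats:
        # gradient steps happen here; learn() knows nothing about timing
        return TrainingStats(loss=0.123)

    def update(self, batch) -> TrainingStats:
        start = time.time()
        # in reality, update() would first sample from a buffer and call process_fn()
        stats = self.learn(batch)
        stats.train_time = time.time() - start  # set by update, not by learn
        return stats
```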

They were addressed by:
1. Live with the broken hierarchy, which is now made visible by bounds in the
generics. We use `type: ignore` where appropriate.
2. Make all policies generic, with bounds following the policy inheritance
hierarchy (which is incorrect, see above). We experimented a bit with nested
TrainingStats classes, but that seemed to add more complexity and be harder to
understand. Unfortunately, mypy considers the code below to be wrong, which is why
we have to add `type: ignore` to the return statement of each `learn`:

```python
from typing import TypeVar

T = TypeVar("T", bound=int)


def f() -> T:
    # mypy flags this return, even though T is bound by int
    return 3
```
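
For concreteness, the resulting generics pattern looks roughly like the sketch
below. The class names and signatures are simplified illustrations, not the exact
Tianshou definitions.

```python
from dataclasses import dataclass
from typing import Generic, TypeVar


@dataclass(kw_only=True)
class TrainingStats:
    train_time: float = 0.0


@dataclass(kw_only=True)
class DQNTrainingStats(TrainingStats):
    loss: float = 0.0


TStats = TypeVar("TStats", bound=TrainingStats)
TDQNStats = TypeVar("TDQNStats", bound=DQNTrainingStats)


class BasePolicy(Generic[TStats]):
    def learn(self, batch) -> TStats:
        raise NotImplementedError


class DQNPolicy(BasePolicy[TDQNStats], Generic[TDQNStats]):
    """Still generic, so that subclasses can narrow the stats type further."""

    def learn(self, batch) -> TDQNStats:
        # analogous to f() above: mypy cannot prove that TDQNStats is exactly
        # DQNTrainingStats here, hence the ignore on the return
        return DQNTrainingStats(loss=0.5)  # type: ignore[return-value]
```

A leaf class whose `learn` is annotated with the concrete stats class would not
need the ignore; the problem only arises where the declared return type is still a
TypeVar.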

3. See above
4. Write representative tests for the `TrainingStatsWrapper`. Still, the
black magic might cause nasty surprises down the line (I am not proud of
it)...
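
For the curious: the "black magic" in such a wrapper essentially amounts to
attribute delegation. A minimal, hypothetical sketch of the idea (not the actual
implementation, which also has to deal with setting attributes and with dataclass
machinery) could look like this:

```python
from dataclasses import dataclass


@dataclass(kw_only=True)
class TrainingStats:
    train_time: float = 0.0
    loss: float = 0.0


class TrainingStatsWrapper:
    """Hypothetical sketch: add wrapper-specific fields, delegate everything else."""

    def __init__(self, wrapped_stats: TrainingStats, *, wrapper_loss: float) -> None:
        self._wrapped = wrapped_stats
        self.wrapper_loss = wrapper_loss

    def __getattr__(self, name: str):
        # __getattr__ is only called when normal attribute lookup fails, so the
        # wrapper's own attributes take precedence over those of the wrapped stats
        return getattr(self._wrapped, name)


stats = TrainingStatsWrapper(TrainingStats(loss=0.1), wrapper_loss=0.2)
print(stats.loss, stats.wrapper_loss)  # -> 0.1 0.2
```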

Closes #933

---------

Co-authored-by: Maximilian Huettenrauch <[email protected]>
Co-authored-by: Michael Panchenko <[email protected]>
3 people authored Dec 30, 2023
1 parent 5d09645 commit 522f7fb
Showing 126 changed files with 1,767 additions and 823 deletions.
2 changes: 1 addition & 1 deletion docs/02_notebooks/L0_overview.ipynb
@@ -141,7 +141,7 @@
"# Let's watch its performance!\n",
"policy.eval()\n",
"result = test_collector.collect(n_episode=1, render=False)\n",
"print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))"
"print(\"Final reward: {}, length: {}\".format(result.returns.mean(), result.lens.mean()))"
]
},
{
91 changes: 61 additions & 30 deletions docs/02_notebooks/L4_Policy.ipynb
@@ -54,16 +54,19 @@
},
"outputs": [],
"source": [
"from typing import Dict, List\n",
"from typing import cast\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import gymnasium as gym\n",
"\n",
"from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as\n",
"from tianshou.policy import BasePolicy\n",
"from dataclasses import dataclass\n",
"\n",
"from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as, SequenceSummaryStats\n",
"from tianshou.policy import BasePolicy, TrainingStats\n",
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor"
"from tianshou.utils.net.discrete import Actor\n",
"from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol"
]
},
{
@@ -102,12 +105,14 @@
"\n",
"\n",
"\n",
"1. Since Tianshou is a **Deep** RL libraries, there should be a policy network in our Policy Module, also a Torch optimizer.\n",
"2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to preprocess training data and computes quantities like episodic returns (gradient free), then it will call `Policy.learn()` to perform the back-propagation.\n",
"1. Since Tianshou is a **Deep** RL libraries, there should be a policy network in our Policy Module, \n",
"also a Torch optimizer.\n",
"2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to \n",
"preprocess training data and computes quantities like episodic returns (gradient free), \n",
"then it will call `Policy.learn()` to perform the back-propagation.\n",
"3. Each Policy is accompanied by a dedicated implementation of `TrainingStats` to store details of training.\n",
"\n",
"Then we get the implementation below.\n",
"\n",
"\n",
"\n"
]
},
@@ -119,7 +124,14 @@
},
"outputs": [],
"source": [
"class REINFORCEPolicy(BasePolicy):\n",
"@dataclass(kw_only=True)\n",
"class REINFORCETrainingStats(TrainingStats):\n",
" \"\"\"A dedicated class for REINFORCE training statistics.\"\"\"\n",
"\n",
" loss: SequenceSummaryStats\n",
"\n",
"\n",
"class REINFORCEPolicy(BasePolicy[REINFORCETrainingStats]):\n",
" \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n",
"\n",
" def __init__(\n",
@@ -138,7 +150,7 @@
" \"\"\"Compute the discounted returns for each transition.\"\"\"\n",
" pass\n",
"\n",
" def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n",
" def learn(self, batch: Batch, batch_size: int, repeat: int) -> REINFORCETrainingStats:\n",
" \"\"\"Perform the back-propagation.\"\"\"\n",
" return"
]
@@ -220,7 +232,9 @@
},
"source": [
"### Policy.learn()\n",
"Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Final we can construct our loss function and perform the back-propagation."
"Data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Finally,\n",
"we can construct our loss function and perform the back-propagation. The method \n",
"should look something like this:"
]
},
{
@@ -231,22 +245,24 @@
},
"outputs": [],
"source": [
"def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n",
"from tianshou.utils.optim import optim_step\n",
"\n",
"\n",
"def learn(self, batch: Batch, batch_size: int, repeat: int):\n",
" \"\"\"Perform the back-propagation.\"\"\"\n",
" logging_losses = []\n",
" train_losses = []\n",
" for _ in range(repeat):\n",
" for minibatch in batch.split(batch_size, merge_last=True):\n",
" self.optim.zero_grad()\n",
" result = self(minibatch)\n",
" dist = result.dist\n",
" act = to_torch_as(minibatch.act, result.act)\n",
" ret = to_torch(minibatch.returns, torch.float, result.act.device)\n",
" log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n",
" loss = -(log_prob * ret).mean()\n",
" loss.backward()\n",
" self.optim.step()\n",
" logging_losses.append(loss.item())\n",
" return {\"loss\": logging_losses}"
" optim_step(loss, self.optim)\n",
" train_losses.append(loss.item())\n",
"\n",
" return REINFORCETrainingStats(loss=SequenceSummaryStats.from_sequence(train_losses))"
]
},
{
@@ -256,7 +272,12 @@
},
"source": [
"## Implementation\n",
"Finally we can assemble the implemented methods and form a REINFORCE Policy."
"Now we can assemble the methods and form a REINFORCE Policy. The outputs of\n",
"`learn` will be collected to a dedicated dataclass.\n",
"\n",
"We will also use protocols to specify what fields are expected and produced inside a `Batch` in\n",
"each processing step. By using protocols, we can get better type checking and IDE support \n",
"without having to implement a separate class for each combination of fields."
]
},
{
@@ -290,30 +311,33 @@
" act = dist.sample()\n",
" return Batch(act=act, dist=dist)\n",
"\n",
" def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n",
" def process_fn(\n",
" self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray\n",
" ) -> BatchWithReturnsProtocol:\n",
" \"\"\"Compute the discounted returns for each transition.\"\"\"\n",
" returns, _ = self.compute_episodic_return(\n",
" batch, buffer, indices, gamma=0.99, gae_lambda=1.0\n",
" )\n",
" batch.returns = returns\n",
" return batch\n",
" return cast(BatchWithReturnsProtocol, batch)\n",
"\n",
" def learn(self, batch: Batch, batch_size: int, repeat: int) -> Dict[str, List[float]]:\n",
" def learn(\n",
" self, batch: BatchWithReturnsProtocol, batch_size: int, repeat: int\n",
" ) -> REINFORCETrainingStats:\n",
" \"\"\"Perform the back-propagation.\"\"\"\n",
" logging_losses = []\n",
" train_losses = []\n",
" for _ in range(repeat):\n",
" for minibatch in batch.split(batch_size, merge_last=True):\n",
" self.optim.zero_grad()\n",
" result = self(minibatch)\n",
" dist = result.dist\n",
" act = to_torch_as(minibatch.act, result.act)\n",
" ret = to_torch(minibatch.returns, torch.float, result.act.device)\n",
" log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n",
" loss = -(log_prob * ret).mean()\n",
" loss.backward()\n",
" self.optim.step()\n",
" logging_losses.append(loss.item())\n",
" return {\"loss\": logging_losses}"
" optim_step(loss, self.optim)\n",
" train_losses.append(loss.item())\n",
"\n",
" return REINFORCETrainingStats(loss=SequenceSummaryStats.from_sequence(train_losses))"
]
},
{
@@ -370,8 +394,8 @@
"source": [
"print(policy)\n",
"print(\"========================================\")\n",
"for para in policy.parameters():\n",
" print(para.shape)"
"for param in policy.parameters():\n",
" print(param.shape)"
]
},
{
@@ -831,6 +855,13 @@
"<img src=../_static/images/policy_table.svg></img>\n",
"</center>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
14 changes: 7 additions & 7 deletions docs/02_notebooks/L5_Collector.ipynb
@@ -116,9 +116,9 @@
"source": [
"collect_result = test_collector.collect(n_episode=9)\n",
"print(collect_result)\n",
"print(\"Rewards of 9 episodes are {}\".format(collect_result[\"rews\"]))\n",
"print(\"Average episode reward is {}.\".format(collect_result[\"rew\"]))\n",
"print(\"Average episode length is {}.\".format(collect_result[\"len\"]))"
"print(f\"Returns of 9 episodes are {collect_result.returns}\")\n",
"print(f\"Average episode return is {collect_result.returns_stat.mean}.\")\n",
"print(f\"Average episode length is {collect_result.lens_stat.mean}.\")"
]
},
{
@@ -146,9 +146,9 @@
"test_collector.reset()\n",
"collect_result = test_collector.collect(n_episode=9, random=True)\n",
"print(collect_result)\n",
"print(\"Rewards of 9 episodes are {}\".format(collect_result[\"rews\"]))\n",
"print(\"Average episode reward is {}.\".format(collect_result[\"rew\"]))\n",
"print(\"Average episode length is {}.\".format(collect_result[\"len\"]))"
"print(f\"Returns of 9 episodes are {collect_result.returns}\")\n",
"print(f\"Average episode return is {collect_result.returns_stat.mean}.\")\n",
"print(f\"Average episode length is {collect_result.lens_stat.mean}.\")"
]
},
{
Expand All @@ -157,7 +157,7 @@
"id": "sKQRTiG10ljU"
},
"source": [
"Seems that an initialized policy performs even worse than a random policy without any training."
"It seems like an initialized policy performs even worse than a random policy without any training."
]
},
{
2 changes: 1 addition & 1 deletion docs/02_notebooks/L6_Trainer.ipynb
@@ -146,7 +146,7 @@
"replaybuffer.reset()\n",
"for i in range(10):\n",
" evaluation_result = test_collector.collect(n_episode=10)\n",
" print(\"Evaluation reward is {}\".format(evaluation_result[\"rew\"]))\n",
" print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n",
" train_collector.collect(n_step=2000)\n",
" # 0 means taking all data stored in train_collector.buffer\n",
" policy.update(0, train_collector.buffer, batch_size=512, repeat=1)\n",
2 changes: 1 addition & 1 deletion docs/02_notebooks/L7_Experiment.ipynb
@@ -303,7 +303,7 @@
"# Let's watch its performance!\n",
"policy.eval()\n",
"result = test_collector.collect(n_episode=1, render=False)\n",
"print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))"
"print(f\"Final episode reward: {result.returns.mean()}, length: {result.lens.mean()}\")"
]
}
],
7 changes: 7 additions & 0 deletions docs/spelling_wordlist.txt
@@ -239,3 +239,10 @@ logp
autogenerated
subpackage
subpackages
recurse
rollout
rollouts
prepend
prepends
dict
dicts
2 changes: 1 addition & 1 deletion examples/atari/atari_c51.py
@@ -184,7 +184,7 @@ def watch():
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result["rews"].mean()
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

if args.watch:
2 changes: 1 addition & 1 deletion examples/atari/atari_dqn.py
@@ -225,7 +225,7 @@ def watch():
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result["rews"].mean()
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

if args.watch:
2 changes: 1 addition & 1 deletion examples/atari/atari_fqf.py
@@ -197,7 +197,7 @@ def watch():
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result["rews"].mean()
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

if args.watch:
2 changes: 1 addition & 1 deletion examples/atari/atari_iqn.py
@@ -194,7 +194,7 @@ def watch():
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result["rews"].mean()
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

if args.watch:
2 changes: 1 addition & 1 deletion examples/atari/atari_ppo.py
@@ -252,7 +252,7 @@ def watch():
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result["rews"].mean()
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

if args.watch:
2 changes: 1 addition & 1 deletion examples/atari/atari_qrdqn.py
@@ -180,7 +180,7 @@ def watch():
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result["rews"].mean()
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

if args.watch:
2 changes: 1 addition & 1 deletion examples/atari/atari_rainbow.py
@@ -223,7 +223,7 @@ def watch():
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result["rews"].mean()
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

if args.watch:
2 changes: 1 addition & 1 deletion examples/atari/atari_sac.py
@@ -236,7 +236,7 @@ def watch():
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result["rews"].mean()
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

if args.watch:
5 changes: 2 additions & 3 deletions examples/box2d/acrobot_dualdqn.py
@@ -130,7 +130,7 @@ def test_fn(epoch, env_step):
logger=logger,
).run()

assert stop_fn(result["best_reward"])
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
@@ -139,8 +139,7 @@ def test_fn(epoch, env_step):
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}")


if __name__ == "__main__":
5 changes: 2 additions & 3 deletions examples/box2d/bipedal_bdq.py
@@ -145,7 +145,7 @@ def test_fn(epoch, env_step):
logger=logger,
).run()

# assert stop_fn(result["best_reward"])
# assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
@@ -154,8 +154,7 @@ def test_fn(epoch, env_step):
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}")


if __name__ == "__main__":
3 changes: 1 addition & 2 deletions examples/box2d/bipedal_hardcore_sac.py
@@ -187,8 +187,7 @@ def stop_fn(mean_rewards):
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}")


if __name__ == "__main__":