Commit f53a496: Merged with thuml-master

carlocagnetta committed Feb 2, 2024
2 parents e2a14ed + 4756ee8
Showing 37 changed files with 986 additions and 430 deletions.
242 changes: 182 additions & 60 deletions README.md

Large diffs are not rendered by default.

Binary file added docs/_static/images/discrete_dqn_hl.gif
33 changes: 0 additions & 33 deletions examples/atari/atari_callbacks.py

This file was deleted.

20 changes: 10 additions & 10 deletions examples/atari/atari_dqn_hl.py
@@ -2,15 +2,11 @@
 import os

-from examples.atari.atari_callbacks import (
-    TestEpochCallbackDQNSetEps,
-    TrainEpochCallbackNatureDQNEpsLinearDecay,
-)
 from examples.atari.atari_network import (
     IntermediateModuleFactoryAtariDQN,
     IntermediateModuleFactoryAtariDQNFeatures,
 )
-from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
+from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     DQNExperimentBuilder,
@@ -20,14 +16,18 @@
 from tianshou.highlevel.params.policy_wrapper import (
     PolicyWrapperFactoryIntrinsicCuriosity,
 )
+from tianshou.highlevel.trainer import (
+    EpochTestCallbackDQNSetEps,
+    EpochTrainCallbackDQNEpsLinearDecay,
+)
 from tianshou.utils import logging
 from tianshou.utils.logging import datetime_tag


 def main(
     experiment_config: ExperimentConfig,
     task: str = "PongNoFrameskip-v4",
-    scale_obs: int = 0,
+    scale_obs: bool = False,
     eps_test: float = 0.005,
     eps_train: float = 1.0,
     eps_train_final: float = 0.05,
@@ -79,11 +79,11 @@ def main(
             ),
         )
         .with_model_factory(IntermediateModuleFactoryAtariDQN())
-        .with_trainer_epoch_callback_train(
-            TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
+        .with_epoch_train_callback(
+            EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
         )
-        .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
-        .with_trainer_stop_callback(AtariStopCallback(task))
+        .with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
+        .with_epoch_stop_callback(AtariEpochStopCallback(task))
     )
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
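
The same callback rename recurs in atari_iqn_hl.py below and in the other high-level examples. For downstream code the migration is mechanical; a minimal sketch of the new API follows (the configure_dqn_callbacks helper is hypothetical, written only to isolate the renamed imports and builder methods):

from examples.atari.atari_wrapper import AtariEpochStopCallback
from tianshou.highlevel.trainer import (
    EpochTestCallbackDQNSetEps,
    EpochTrainCallbackDQNEpsLinearDecay,
)


def configure_dqn_callbacks(builder, task, eps_train, eps_train_final, eps_test):
    # The callbacks moved from examples.atari.atari_callbacks into
    # tianshou.highlevel.trainer, and the builder methods dropped the
    # "trainer_" prefix:
    #   with_trainer_epoch_callback_train -> with_epoch_train_callback
    #   with_trainer_epoch_callback_test  -> with_epoch_test_callback
    #   with_trainer_stop_callback        -> with_epoch_stop_callback
    return (
        builder.with_epoch_train_callback(
            EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
        )
        .with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
        .with_epoch_stop_callback(AtariEpochStopCallback(task))
    )
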
20 changes: 10 additions & 10 deletions examples/atari/atari_iqn_hl.py
@@ -3,28 +3,28 @@
 import os
 from collections.abc import Sequence

-from examples.atari.atari_callbacks import (
-    TestEpochCallbackDQNSetEps,
-    TrainEpochCallbackNatureDQNEpsLinearDecay,
-)
 from examples.atari.atari_network import (
     IntermediateModuleFactoryAtariDQN,
 )
-from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
+from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     ExperimentConfig,
     IQNExperimentBuilder,
 )
 from tianshou.highlevel.params.policy_params import IQNParams
+from tianshou.highlevel.trainer import (
+    EpochTestCallbackDQNSetEps,
+    EpochTrainCallbackDQNEpsLinearDecay,
+)
 from tianshou.utils import logging
 from tianshou.utils.logging import datetime_tag


 def main(
     experiment_config: ExperimentConfig,
     task: str = "PongNoFrameskip-v4",
-    scale_obs: int = 0,
+    scale_obs: bool = False,
     eps_test: float = 0.005,
     eps_train: float = 1.0,
     eps_train_final: float = 0.05,
@@ -83,11 +83,11 @@ def main(
             ),
         )
         .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
-        .with_trainer_epoch_callback_train(
-            TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
+        .with_epoch_train_callback(
+            EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
         )
-        .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
-        .with_trainer_stop_callback(AtariStopCallback(task))
+        .with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
+        .with_epoch_stop_callback(AtariEpochStopCallback(task))
         .build()
     )
     experiment.run(log_name)
37 changes: 23 additions & 14 deletions examples/atari/atari_network.py
@@ -23,19 +23,27 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.
     return layer


-def scale_obs(module: type[nn.Module], denom: float = 255.0) -> type[nn.Module]:
-    class scaled_module(module):
-        def forward(
-            self,
-            obs: np.ndarray | torch.Tensor,
-            state: Any | None = None,
-            info: dict[str, Any] | None = None,
-        ) -> tuple[torch.Tensor, Any]:
-            if info is None:
-                info = {}
-            return super().forward(obs / denom, state, info)
-
-    return scaled_module
+class ScaledObsInputModule(torch.nn.Module):
+    def __init__(self, module: torch.nn.Module, denom: float = 255.0):
+        super().__init__()
+        self.module = module
+        self.denom = denom
+        # This is required such that the value can be retrieved by downstream modules (see usages of get_output_dim)
+        self.output_dim = module.output_dim
+
+    def forward(
+        self,
+        obs: np.ndarray | torch.Tensor,
+        state: Any | None = None,
+        info: dict[str, Any] | None = None,
+    ) -> tuple[torch.Tensor, Any]:
+        if info is None:
+            info = {}
+        return self.module.forward(obs / self.denom, state, info)
+
+
+def scale_obs(module: nn.Module, denom: float = 255.0) -> nn.Module:
+    return ScaledObsInputModule(module, denom=denom)


 class DQN(nn.Module):
@@ -238,15 +246,16 @@ def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool, features_o
         self.features_only = features_only

     def create_module(self, envs: Environments, device: TDevice) -> Actor:
-        net_cls = scale_obs(DQN) if self.scale_obs else DQN
-        net = net_cls(
+        net = DQN(
             *envs.get_observation_shape(),
             envs.get_action_shape(),
             device=device,
             features_only=self.features_only,
             output_dim=self.hidden_size,
             layer_init=layer_init,
         )
+        if self.scale_obs:
+            net = scale_obs(net)
         return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)


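The net effect of this refactoring: observation scaling is no longer done by creating a subclass on the fly (scale_obs(DQN) used to return a new class), but by wrapping an existing module instance, with output_dim carried over so downstream consumers of get_output_dim keep working. A minimal usage sketch, assuming the DQN constructor signature visible in atari_ppo.py below; the shapes and sizes are illustrative placeholders, not values from this commit:

import torch

from examples.atari.atari_network import DQN, scale_obs

# 4 stacked 84x84 frames, 6 discrete actions (placeholder values).
net = DQN(4, 84, 84, 6, device="cpu", features_only=True, output_dim=512)
net = scale_obs(net)  # wraps the instance in ScaledObsInputModule (denom=255.0)

obs = torch.rand(1, 4, 84, 84) * 255.0
features, state = net(obs)  # forward divides obs by denom before delegating
assert net.output_dim == 512  # preserved from the wrapped module
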
5 changes: 3 additions & 2 deletions examples/atari/atari_ppo.py
@@ -109,15 +109,16 @@ def test_ppo(args=get_args()):
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     # define model
-    net_cls = scale_obs(DQN) if args.scale_obs else DQN
-    net = net_cls(
+    net = DQN(
         *args.state_shape,
         args.action_shape,
         device=args.device,
         features_only=True,
         output_dim=args.hidden_size,
         layer_init=layer_init,
     )
+    if args.scale_obs:
+        net = scale_obs(net)
     actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
     critic = Critic(net, device=args.device)
     optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5)
4 changes: 2 additions & 2 deletions examples/atari/atari_ppo_hl.py
@@ -7,7 +7,7 @@
     ActorFactoryAtariDQN,
     IntermediateModuleFactoryAtariDQNFeatures,
 )
-from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
+from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     ExperimentConfig,
@@ -95,7 +95,7 @@ def main(
         )
         .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True))
         .with_critic_factory_use_actor()
-        .with_trainer_stop_callback(AtariStopCallback(task))
+        .with_epoch_stop_callback(AtariEpochStopCallback(task))
     )
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
6 changes: 3 additions & 3 deletions examples/atari/atari_sac_hl.py
@@ -6,7 +6,7 @@
     ActorFactoryAtariDQN,
     IntermediateModuleFactoryAtariDQNFeatures,
 )
-from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
+from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.experiment import (
     DiscreteSACExperimentBuilder,
@@ -24,7 +24,7 @@
 def main(
     experiment_config: ExperimentConfig,
     task: str = "PongNoFrameskip-v4",
-    scale_obs: int = 0,
+    scale_obs: bool = False,
     buffer_size: int = 100000,
     actor_lr: float = 1e-5,
     critic_lr: float = 1e-5,
@@ -82,7 +82,7 @@ def main(
         )
         .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True))
         .with_common_critic_factory_use_actor()
-        .with_trainer_stop_callback(AtariStopCallback(task))
+        .with_epoch_stop_callback(AtariEpochStopCallback(task))
     )
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(