Skip to content

Commit

Permalink
High-level API: Establish a strong link between the actor and the dis…
Browse files Browse the repository at this point in the history
…tribution function (thu-ml#1195)

Establish a strong link between the actor and the distribution function
(dist_fn) used in policies by creating the distribution function in the
actor factory which knows which function is appropriate.

Consequently, remove the policy parameter 'dist_fn' from the high-level
API because it is determined automatically, eliminating the possibility
of misspecification by the user. [breaking change: code must not specify
the 'dist_fn' parameter, but persisted objects continue to work as
expected]

Implements thu-ml#1194

- [X] I have added the correct label(s) to this Pull Request or linked
the relevant issue(s)
- [X] I have provided a description of the changes in this Pull Request
- [X] I have added documentation for my changes and have listed relevant
changes in CHANGELOG.md
- [X] If applicable, I have added tests to cover my changes.
- [X] I have reformatted the code using `poe format` 
- [X] I have checked style and types with `poe lint` and `poe
type-check`
- [ ] (Optional) I ran tests locally with `poe test` 
(or a subset of them with `poe test-reduced`) ,and they pass
- [ ] (Optional) I have tested that documentation builds correctly with
`poe doc-build`
  • Loading branch information
MischaPanch authored Aug 8, 2024
2 parents 9491c79 + 8114dd1 commit bd74273
Show file tree
Hide file tree
Showing 12 changed files with 95 additions and 60 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Release 1.1.0

### Api Extensions
### Changes/Improvements
- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074 #1141 #1183
- `data`:
- `Batch`:
Expand Down Expand Up @@ -107,6 +107,11 @@ continuous and discrete cases. #1032
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077
- `AtariEnvFactory` constructor (in examples, so not really breaking) now requires explicit train and test seeds. #1074
- `EnvFactoryRegistered` now requires an explicit `test_seed` in the constructor. #1074
- `highlevel`:
- The parameter `dist_fn` has been removed from the parameter objects (`PGParams`, `A2CParams`, `PPOParams`, `NPGParams`, `TRPOParams`).
The correct distribution is now determined automatically based on the actor factory being used, avoiding the possibility of
misspecification. Persisted configurations/policies continue to work as expected, but code must not specify the `dist_fn` parameter.
#1194 #1195


### Tests
Expand Down
16 changes: 15 additions & 1 deletion examples/atari/atari_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
IntermediateModule,
IntermediateModuleFactory,
)
from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.common import NetBase
from tianshou.utils.net.discrete import Actor, NoisyLinear

Expand Down Expand Up @@ -246,6 +248,8 @@ def forward(


class ActorFactoryAtariDQN(ActorFactory):
USE_SOFTMAX_OUTPUT = False

def __init__(
self,
scale_obs: bool = True,
Expand Down Expand Up @@ -274,7 +278,17 @@ def create_module(self, envs: Environments, device: TDevice) -> Actor:
)
if self.scale_obs:
net = scale_obs(net)
return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
return Actor(
net,
envs.get_action_shape(),
device=device,
softmax_output=self.USE_SOFTMAX_OUTPUT,
).to(device)

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return DistributionFunctionFactoryCategorical(
is_probs_input=self.USE_SOFTMAX_OUTPUT,
).create_dist_fn(envs)


class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
Expand Down
4 changes: 0 additions & 4 deletions examples/mujoco/mujoco_npg_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
ExperimentConfig,
NPGExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import NPGParams
from tianshou.utils import logging
Expand Down Expand Up @@ -78,7 +75,6 @@ def main(
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
Expand Down
4 changes: 0 additions & 4 deletions examples/mujoco/mujoco_ppo_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
ExperimentConfig,
PPOExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
Expand Down Expand Up @@ -88,7 +85,6 @@ def main(
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
Expand Down
4 changes: 0 additions & 4 deletions examples/mujoco/mujoco_ppo_hl_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
PPOExperimentBuilder,
)
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
Expand Down Expand Up @@ -115,7 +112,6 @@ def main(
recompute_advantage=True,
lr=3e-4,
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config),
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
Expand Down
4 changes: 0 additions & 4 deletions examples/mujoco/mujoco_trpo_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
ExperimentConfig,
TRPOExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import TRPOParams
from tianshou.utils import logging
Expand Down Expand Up @@ -82,7 +79,6 @@ def main(
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
Expand Down
1 change: 1 addition & 0 deletions test/highlevel/test_experiment_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime
@pytest.mark.parametrize(
"builder_cls",
[
PGExperimentBuilder,
PPOExperimentBuilder,
A2CExperimentBuilder,
DQNExperimentBuilder,
Expand Down
4 changes: 4 additions & 0 deletions tianshou/highlevel/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,14 @@ def _create_policy(self, envs: Environments, device: TDevice) -> PGPolicy:
optim_factory=self.optim_factory,
),
)
dist_fn = self.actor_factory.create_dist_fn(envs)
assert dist_fn is not None
return PGPolicy(
actor=actor.module,
optim=actor.optim,
action_space=envs.get_action_space(),
observation_space=envs.get_observation_space(),
dist_fn=dist_fn,
**kwargs,
)

Expand Down Expand Up @@ -333,6 +336,7 @@ def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
kwargs["critic"] = actor_critic.critic
kwargs["optim"] = actor_critic.optim
kwargs["action_space"] = envs.get_action_space()
kwargs["dist_fn"] = self.actor_factory.create_dist_fn(envs)
return kwargs

def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
Expand Down
42 changes: 38 additions & 4 deletions tianshou/highlevel/module/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
)
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryCategorical,
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import BaseActor, ModuleType, Net
from tianshou.utils.string import ToStringMixin
Expand Down Expand Up @@ -47,6 +52,14 @@ class ActorFactory(ModuleFactory, ToStringMixin, ABC):
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
pass

@abstractmethod
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
"""
:param envs: the environments
:return: the distribution function, which converts the actor's output into a distribution, or None
if the actor does not output distribution parameters
"""

def create_module_opt(
self,
envs: Environments,
Expand All @@ -70,7 +83,7 @@ def create_module_opt(
def _init_linear(actor: torch.nn.Module) -> None:
"""Initializes linear layers of an actor module using default mechanisms.
:param module: the actor module.
:param actor: the actor module.
"""
init_linear_orthogonal(actor)
if hasattr(actor, "mu"):
Expand Down Expand Up @@ -104,7 +117,7 @@ def __init__(
self.hidden_activation = hidden_activation
self.discrete_softmax = discrete_softmax

def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
def _create_factory(self, envs: Environments) -> ActorFactory:
env_type = envs.get_type()
factory: ActorFactoryContinuousDeterministicNet | ActorFactoryContinuousGaussianNet | ActorFactoryDiscreteNet
if env_type == EnvType.CONTINUOUS:
Expand All @@ -125,15 +138,22 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
raise ValueError("Continuous action spaces are not supported by the algorithm")
case _:
raise ValueError(self.continuous_actor_type)
return factory.create_module(envs, device)
elif env_type == EnvType.DISCRETE:
factory = ActorFactoryDiscreteNet(
self.DEFAULT_HIDDEN_SIZES,
softmax_output=self.discrete_softmax,
)
return factory.create_module(envs, device)
else:
raise ValueError(f"{env_type} not supported")
return factory

def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
factory = self._create_factory(envs)
return factory.create_module(envs, device)

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
factory = self._create_factory(envs)
return factory.create_dist_fn(envs)


class ActorFactoryContinuous(ActorFactory, ABC):
Expand All @@ -159,6 +179,9 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
device=device,
).to(device)

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return None


class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
def __init__(
Expand Down Expand Up @@ -202,6 +225,9 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:

return actor

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)


class ActorFactoryDiscreteNet(ActorFactory):
def __init__(
Expand Down Expand Up @@ -229,6 +255,11 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
softmax_output=self.softmax_output,
).to(device)

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return DistributionFunctionFactoryCategorical(
is_probs_input=self.softmax_output,
).create_dist_fn(envs)


class ActorFactoryTransientStorageDecorator(ActorFactory):
"""Wraps an actor factory, storing the most recently created actor instance such that it can be retrieved."""
Expand All @@ -254,6 +285,9 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.M
self._actor_future.actor = module
return module

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return self.actor_factory.create_dist_fn(envs)


class IntermediateModuleFactoryFromActorFactory(IntermediateModuleFactory):
def __init__(self, actor_factory: ActorFactory):
Expand Down
6 changes: 6 additions & 0 deletions tianshou/highlevel/module/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,12 @@ class CriticFactoryReuseActor(CriticFactory):
"""A critic factory which reuses the actor's preprocessing component.
This class is for internal use in experiment builders only.
Reuse of the actor network is supported through the concept of an actor future (:class:`ActorFuture`).
When the user declares that he wants to reuse the actor for the critic, we use this factory to support this,
but the actor does not exist yet. So the factory instead receives the future, which will eventually be filled
when the actor factory is called. When the creation method of this factory is eventually called, it can use the
then-filled actor to create the critic.
"""

def __init__(self, actor_future: ActorFuture):
Expand Down
36 changes: 21 additions & 15 deletions tianshou/highlevel/params/dist_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.env import Environments
from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
from tianshou.utils.string import ToStringMixin

Expand All @@ -20,32 +20,38 @@ def create_dist_fn(


class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
def __init__(self, is_probs_input: bool = True):
"""
:param is_probs_input: If True, the distribution function shall create a categorical distribution from a
tensor containing probabilities; otherwise the tensor is assumed to contain logits.
"""
self.is_probs_input = is_probs_input

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
envs.get_type().assert_discrete(self)
return self._dist_fn
if self.is_probs_input:
return self._dist_fn_probs
else:
return self._dist_fn

# NOTE: Do not move/rename because a reference to the function can appear in persisted policies
@staticmethod
def _dist_fn(logits: torch.Tensor) -> torch.distributions.Categorical:
return torch.distributions.Categorical(logits=logits)

# NOTE: Do not move/rename because a reference to the function can appear in persisted policies
@staticmethod
def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
return torch.distributions.Categorical(logits=p)
def _dist_fn_probs(probs: torch.Tensor) -> torch.distributions.Categorical:
return torch.distributions.Categorical(probs=probs)


class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
envs.get_type().assert_continuous(self)
return self._dist_fn

# NOTE: Do not move/rename because a reference to the function can appear in persisted policies
@staticmethod
def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
loc, scale = loc_scale
return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)


class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
match envs.get_type():
case EnvType.DISCRETE:
return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
case EnvType.CONTINUOUS:
return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)
case _:
raise ValueError(envs.get_type())
Loading

0 comments on commit bd74273

Please sign in to comment.