Cleaned up handling of output_dim retrieval, adding exceptions for erroneous cases
opcode81 committed Jan 16, 2024
1 parent 6437609 commit 226f8e4
Showing 4 changed files with 55 additions and 13 deletions.
1 change: 1 addition & 0 deletions examples/atari/atari_network.py
@@ -28,6 +28,7 @@ def __init__(self, module: torch.nn.Module, denom: float = 255.0):
        super().__init__()
        self.module = module
        self.denom = denom
        # This is required such that the value can be retrieved by downstream modules (see usages of get_output_dim)
        self.output_dim = module.output_dim

    def forward(
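For context on why the attribute is mirrored here: downstream heads look up the feature dimension of their preprocessing module via get_output_dim (added to tianshou/utils/net/common.py in this commit), so a wrapper that hides an inner network must re-expose the attribute itself. A minimal sketch under that assumption; the wrapper class and the Linear backbone below are hypothetical stand-ins used only for illustration:

import torch
from torch import nn

from tianshou.utils.net.common import get_output_dim


class ScaleWrapper(nn.Module):
    """Hypothetical stand-in for the scaling wrapper shown above."""

    def __init__(self, module: nn.Module, denom: float = 255.0) -> None:
        super().__init__()
        self.module = module
        self.denom = denom
        # Without this line, get_output_dim(wrapper, None) would fail because the
        # wrapper itself carries no output_dim attribute.
        self.output_dim = module.output_dim

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.module(obs / self.denom)


backbone = nn.Linear(4, 16)
backbone.output_dim = 16  # set manually here; tianshou networks define this themselves
assert get_output_dim(ScaleWrapper(backbone), None) == 16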
39 changes: 38 additions & 1 deletion tianshou/utils/net/common.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from typing import Any, TypeAlias, no_type_check
from typing import Any, TypeAlias, TypeVar, no_type_check

import numpy as np
import torch
@@ -13,6 +13,7 @@
ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
TActionShape: TypeAlias = Sequence[int] | int
TLinearLayer: TypeAlias = Callable[[int, int], nn.Module]
T = TypeVar("T")


def miniblock(
@@ -608,3 +609,39 @@ def get_preprocess_net(self) -> nn.Module:
    @abstractmethod
    def get_output_dim(self) -> int:
        pass


def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
    """Gets the given attribute from the given object or takes the alternative value if it is not present.

    If both are present, they are required to match.

    :param obj: the object from which to obtain the attribute value
    :param attr_name: the attribute name
    :param alt_value: the alternative value to use if the attribute is not present or is None;
        it must not be None in that case
    :return: the value
    """
    v = getattr(obj, attr_name, None)
    if v is not None:
        if alt_value is not None and v != alt_value:
            raise ValueError(
                f"Attribute '{attr_name}' of {obj} is defined ({v}) but does not match alt. value ({alt_value})",
            )
        return v
    else:
        if alt_value is None:
            raise ValueError(
                f"Attribute '{attr_name}' of {obj} is not defined and no fallback given",
            )
        return alt_value
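A quick sketch of the intended semantics; DummyNet is a stand-in, and the assertions assume the None-default getattr and the ValueError raises as written above:

from tianshou.utils.net.common import getattr_with_matching_alt_value


class DummyNet:
    output_dim = 64


net = DummyNet()
assert getattr_with_matching_alt_value(net, "output_dim", None) == 64  # attribute only
assert getattr_with_matching_alt_value(net, "output_dim", 64) == 64  # both present and equal
assert getattr_with_matching_alt_value(object(), "output_dim", 64) == 64  # fallback only

try:
    getattr_with_matching_alt_value(net, "output_dim", 128)  # conflicting values
except ValueError as e:
    print(e)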


def get_output_dim(module: nn.Module, alt_value: int | None) -> int:
    """Retrieves the value of the `output_dim` attribute of the given module or uses the given alternative value if the attribute is not present.

    If both are present, they must match.

    :param module: the module
    :param alt_value: the alternative value
    :return: the value
    """
    return getattr_with_matching_alt_value(module, "output_dim", alt_value)
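With this helper in place, the constructors changed below (continuous.py and discrete.py) reduce their input-dimension lookup to a single call. A hedged sketch of such a consumer; LinearHead is hypothetical and not part of this commit:

from collections.abc import Sequence

from torch import nn

from tianshou.utils.net.common import MLP, get_output_dim


class LinearHead(nn.Module):
    """Hypothetical head module mirroring the constructor pattern in the files below."""

    def __init__(
        self,
        preprocess_net: nn.Module,
        output_dim: int,
        hidden_sizes: Sequence[int] = (),
        preprocess_net_output_dim: int | None = None,
    ) -> None:
        super().__init__()
        self.preprocess = preprocess_net
        # Either the preprocessing net exposes output_dim, or the caller must pass
        # preprocess_net_output_dim; a contradiction between the two now raises.
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.last = MLP(input_dim, output_dim, hidden_sizes)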
17 changes: 11 additions & 6 deletions tianshou/utils/net/continuous.py
@@ -1,12 +1,18 @@
import warnings
from collections.abc import Sequence
from typing import Any, cast
from typing import Any

import numpy as np
import torch
from torch import nn

from tianshou.utils.net.common import MLP, BaseActor, TActionShape, TLinearLayer
from tianshou.utils.net.common import (
    MLP,
    BaseActor,
    TActionShape,
    TLinearLayer,
    get_output_dim,
)

SIGMA_MIN = -20
SIGMA_MAX = 2
@@ -50,8 +56,7 @@ def __init__(
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = int(np.prod(action_shape))
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        input_dim = cast(int, input_dim)
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.last = MLP(
            input_dim,
            self.output_dim,
@@ -118,7 +123,7 @@ def __init__(
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = 1
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.last = MLP(
            input_dim,  # type: ignore
            1,
@@ -199,7 +204,7 @@ def __init__(
        self.preprocess = preprocess_net
        self.device = device
        self.output_dim = int(np.prod(action_shape))
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)  # type: ignore
        self._c_sigma = conditioned_sigma
        if conditioned_sigma:
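The behavioral difference from the previous getattr-based lookup, sketched below with a stand-in backbone whose output_dim is set manually: the old expression silently preferred the module attribute even when the caller passed a conflicting preprocess_net_output_dim, whereas get_output_dim surfaces the contradiction.

from torch import nn

from tianshou.utils.net.common import get_output_dim

backbone = nn.Linear(8, 32)
backbone.output_dim = 32  # set manually for illustration

# Old behavior: a wrong caller-supplied value of 64 was silently dropped.
assert getattr(backbone, "output_dim", 64) == 32

# New behavior: the mismatch raises instead of being ignored.
try:
    get_output_dim(backbone, 64)
except ValueError as e:
    print(e)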
11 changes: 5 additions & 6 deletions tianshou/utils/net/discrete.py
@@ -1,13 +1,13 @@
from collections.abc import Sequence
from typing import Any, cast
from typing import Any

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from tianshou.data import Batch, to_torch
from tianshou.utils.net.common import MLP, BaseActor, TActionShape
from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim


class Actor(BaseActor):
@@ -51,8 +51,7 @@ def __init__(
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = int(np.prod(action_shape))
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        input_dim = cast(int, input_dim)
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.last = MLP(
            input_dim,
            self.output_dim,
@@ -118,7 +117,7 @@ def __init__(
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = last_size
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)  # type: ignore

    def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
@@ -197,7 +196,7 @@ def __init__(
    ) -> None:
        last_size = int(np.prod(action_shape))
        super().__init__(preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device)
        self.input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        self.input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(  # type: ignore
            device,
        )
