Skip to content

Commit

Permalink
feat: update to SheepRLv0.4.8
Browse files (browse the repository at this point in the history)
  • Loading branch information
michele-milesi authored and alexpalms committed Nov 29, 2023
1 parent 57dbc63 commit 4c4d36a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 38 deletions.
72 changes: 35 additions & 37 deletions diambra/arena/sheeprl/make_sheeprl_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,36 @@ def thunk() -> gym.Env:
instantiate_kwargs["rank"] = rank + vector_env_idx
env = hydra.utils.instantiate(cfg.env.wrapper, **instantiate_kwargs)

if not (
isinstance(cfg.algo.mlp_keys.encoder, list)
and isinstance(cfg.algo.cnn_keys.encoder, list)
and len(cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder) > 0
):
raise ValueError(
"`algo.cnn_keys.encoder` and `algo.mlp_keys.encoder` must be lists of strings, got: "
f"cnn encoder keys `{cfg.algo.cnn_keys.encoder}` of type `{type(cfg.algo.cnn_keys.encoder)}` "
f"and mlp encoder keys `{cfg.algo.mlp_keys.encoder}` of type `{type(cfg.algo.mlp_keys.encoder)}`. "
"Both must be non-empty lists."
)

if (
len(
set(k for k in env.observation_space.keys()).intersection(
set(cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder)
)
)
== 0
):
raise ValueError(
f"The user specified keys `{cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder}` "
"are not a subset of the "
f"environment `{env.observation_space.keys()}` observation keys. Please check your config file."
)

env_cnn_keys = set(
[
k
for k in env.observation_space.spaces.keys()
if len(env.observation_space[k].shape) in {2, 3}
]
[k for k in env.observation_space.spaces.keys() if len(env.observation_space[k].shape) in {2, 3}]
)
if cfg.cnn_keys.encoder is None:
user_cnn_keys = set()
else:
user_cnn_keys = set(cfg.cnn_keys.encoder)
cnn_keys = env_cnn_keys.intersection(user_cnn_keys)
cnn_keys = env_cnn_keys.intersection(set(cfg.algo.cnn_keys.encoder))

def transform_obs(obs: Dict[str, Any]):
for k in cnn_keys:
Expand All @@ -93,9 +111,7 @@ def transform_obs(obs: Dict[str, Any]):
# resize
if current_obs.shape[:-1] != (cfg.env.screen_size, cfg.env.screen_size):
current_obs = cv2.resize(
current_obs,
(cfg.env.screen_size, cfg.env.screen_size),
interpolation=cv2.INTER_AREA,
current_obs, (cfg.env.screen_size, cfg.env.screen_size), interpolation=cv2.INTER_AREA
)

# to grayscale
Expand All @@ -116,49 +132,31 @@ def transform_obs(obs: Dict[str, Any]):
env = gym.wrappers.TransformObservation(env, transform_obs)
for k in cnn_keys:
env.observation_space[k] = gym.spaces.Box(
0,
255,
(
1 if cfg.env.grayscale else 3,
cfg.env.screen_size,
cfg.env.screen_size,
),
np.uint8,
0, 255, (1 if cfg.env.grayscale else 3, cfg.env.screen_size, cfg.env.screen_size), np.uint8
)

if cnn_keys is not None and len(cnn_keys) > 0 and cfg.env.frame_stack > 1:
if cfg.env.frame_stack_dilation <= 0:
raise ValueError(
f"The frame stack dilation argument must be greater than zero, got: {cfg.env.frame_stack_dilation}"
)
env = FrameStack(
env, cfg.env.frame_stack, cnn_keys, cfg.env.frame_stack_dilation
)
env = FrameStack(env, cfg.env.frame_stack, cnn_keys, cfg.env.frame_stack_dilation)

if cfg.env.reward_as_observation:
env = RewardAsObservationWrapper(env)

env.action_space.seed(seed)
env.observation_space.seed(seed)
if cfg.env.max_episode_steps and cfg.env.max_episode_steps > 0:
env = gym.wrappers.TimeLimit(
env, max_episode_steps=cfg.env.max_episode_steps
)
env = gym.wrappers.TimeLimit(env, max_episode_steps=cfg.env.max_episode_steps)
env = gym.wrappers.RecordEpisodeStatistics(env)
if (
cfg.env.capture_video
and rank == 0
and vector_env_idx == 0
and run_name is not None
):
if cfg.env.capture_video and rank == 0 and vector_env_idx == 0 and run_name is not None:
if cfg.env.grayscale:
env = GrayscaleRenderWrapper(env)
env = gym.experimental.wrappers.RecordVideoV0(
env,
os.path.join(run_name, prefix + "_videos" if prefix else "videos"),
disable_logger=True,
env, os.path.join(run_name, prefix + "_videos" if prefix else "videos"), disable_logger=True
)
env.metadata["render_fps"] = env.frames_per_sec
return env

return thunk
return thunk
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"stable-baselines3": ["stable-baselines3[extra]~=2.1.0", "pyyaml"],
"ray-rllib": ["ray[rllib]~=2.7.0", "tensorflow", "torch", "pyyaml"],
"sheeprl": [
"sheeprl==0.4.7",
"sheeprl==0.4.8",
"importlib-resources==6.1.0",
],
}
Expand Down

0 comments on commit 4c4d36a

Please sign in to comment.