Skip to content

Commit

Permalink
Feature/algo eval (thu-ml#1074)
Browse files Browse the repository at this point in the history
# Changes

## Dependencies

- New extra "eval"

## API Extensions
- `Experiment` and `ExperimentConfig` now have a `name`, which can
be overridden when `Experiment.run()` is called
- When building an `Experiment` from an `ExperimentConfig`, the user has
the option to add info about seeds to the name.
- New method in `ExperimentConfig` called
`build_default_seeded_experiments`
- `SamplingConfig` now has an explicit training seed (`train_seed`);
`test_seed` is inferred.
- New `evaluation` package for repeating the same experiment with
multiple seeds and aggregating the results (important extension!).
Currently in alpha state.
- Loggers can now restore the logged data into Python by using the new
`restore_logged_data`

## Breaking Changes
- `AtariEnvFactory` (in examples) now receives explicit train and test
seeds
- `EnvFactoryRegistered` now requires an explicit `test_seed`
- `BaseLogger.prepare_dict_for_logging` is now abstract

---------

Co-authored-by: Maximilian Huettenrauch <[email protected]>
Co-authored-by: Michael Panchenko <[email protected]>
Co-authored-by: Michael Panchenko <[email protected]>
  • Loading branch information
4 people authored Apr 20, 2024
1 parent 9c0b3e7 commit ade85ab
Show file tree
Hide file tree
Showing 38 changed files with 1,655 additions and 153 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint_and_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
key: venv-${{ hashFiles('poetry.lock') }}
- name: Install the project dependencies
run: |
poetry install --with dev
poetry install --with dev --extras "eval"
- name: Lint
run: poetry run poe lint
- name: Types
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
key: venv-${{ hashFiles('poetry.lock') }}
- name: Install the project dependencies
run: |
poetry install --with dev --extras "envpool"
poetry install --with dev --extras "envpool eval"
- name: wandb login
run: |
poetry run wandb login e2366d661b89f2bee877c40bee15502d67b7abef
Expand Down
6 changes: 5 additions & 1 deletion docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -261,4 +261,8 @@ BA
BH
BO
BD

configs
postfix
backend
rliable
hl
10 changes: 8 additions & 2 deletions examples/atari/atari_dqn_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ def main(
replay_buffer_save_only_last_obs=True,
)

env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
env_factory = AtariEnvFactory(
task,
sampling_config.train_seed,
sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)

builder = (
DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand Down Expand Up @@ -98,7 +104,7 @@ def main(
)

experiment = builder.build()
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
10 changes: 8 additions & 2 deletions examples/atari/atari_iqn_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,13 @@ def main(
replay_buffer_save_only_last_obs=True,
)

env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
env_factory = AtariEnvFactory(
task,
sampling_config.train_seed,
sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)

experiment = (
IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand All @@ -90,7 +96,7 @@ def main(
.with_epoch_stop_callback(AtariEpochStopCallback(task))
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
10 changes: 8 additions & 2 deletions examples/atari/atari_ppo_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,13 @@ def main(
replay_buffer_save_only_last_obs=True,
)

env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack)
env_factory = AtariEnvFactory(
task,
sampling_config.train_seed,
sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)

builder = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand Down Expand Up @@ -109,7 +115,7 @@ def main(
),
)
experiment = builder.build()
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
10 changes: 8 additions & 2 deletions examples/atari/atari_sac_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ def main(
replay_buffer_save_only_last_obs=True,
)

env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)
env_factory = AtariEnvFactory(
task,
sampling_config.train_seed,
sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)

builder = (
DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand Down Expand Up @@ -97,7 +103,7 @@ def main(
),
)
experiment = builder.build()
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
8 changes: 5 additions & 3 deletions examples/atari/atari_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def make_atari_env(
:return: a tuple of (single env, training envs, test envs).
"""
env_factory = AtariEnvFactory(task, seed, frame_stack, scale=bool(scale))
env_factory = AtariEnvFactory(task, seed, seed + training_num, frame_stack, scale=bool(scale))
envs = env_factory.create_envs(training_num, test_num)
return envs.env, envs.train_envs, envs.test_envs

Expand All @@ -392,7 +392,8 @@ class AtariEnvFactory(EnvFactoryRegistered):
def __init__(
self,
task: str,
seed: int,
train_seed: int,
test_seed: int,
frame_stack: int,
scale: bool = False,
use_envpool_if_available: bool = True,
Expand All @@ -409,7 +410,8 @@ def __init__(
log.info("Not using envpool, because it is not available")
super().__init__(
task=task,
seed=seed,
train_seed=train_seed,
test_seed=test_seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
envpool_factory=envpool_factory,
)
Expand Down
8 changes: 7 additions & 1 deletion examples/discrete/discrete_dqn_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
def main() -> None:
experiment = (
DQNExperimentBuilder(
EnvFactoryRegistered(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY),
EnvFactoryRegistered(
task="CartPole-v1",
seed=0,
venv_type=VectorEnvType.DUMMY,
train_seed=0,
test_seed=10,
),
ExperimentConfig(
persistence_enabled=False,
watch=True,
Expand Down
9 changes: 7 additions & 2 deletions examples/mujoco/mujoco_a2c_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ def main(
repeat_per_collect=repeat_per_collect,
)

env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
)

experiment = (
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand All @@ -78,7 +83,7 @@ def main(
.with_critic_factory_default(hidden_sizes, nn.Tanh)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
9 changes: 7 additions & 2 deletions examples/mujoco/mujoco_ddpg_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ def main(
start_timesteps_random=True,
)

env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=False)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=False,
)

experiment = (
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand All @@ -69,7 +74,7 @@ def main(
.with_critic_factory_default(hidden_sizes)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
8 changes: 5 additions & 3 deletions examples/mujoco/mujoco_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def make_mujoco_env(
:return: a tuple of (single env, training envs, test envs).
"""
envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs(
envs = MujocoEnvFactory(task, seed, seed + num_train_envs, obs_norm=obs_norm).create_envs(
num_train_envs,
num_test_envs,
)
Expand Down Expand Up @@ -73,13 +73,15 @@ class MujocoEnvFactory(EnvFactoryRegistered):
def __init__(
self,
task: str,
seed: int,
train_seed: int,
test_seed: int,
obs_norm: bool = True,
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM,
) -> None:
super().__init__(
task=task,
seed=seed,
train_seed=train_seed,
test_seed=test_seed,
venv_type=venv_type,
envpool_factory=EnvPoolFactory() if envpool_is_available else None,
)
Expand Down
9 changes: 7 additions & 2 deletions examples/mujoco/mujoco_npg_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ def main(
repeat_per_collect=repeat_per_collect,
)

env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
)

experiment = (
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand All @@ -80,7 +85,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
9 changes: 7 additions & 2 deletions examples/mujoco/mujoco_ppo_hl.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,12 @@ def main(
repeat_per_collect=repeat_per_collect,
)

env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)
env_factory = MujocoEnvFactory(
task,
train_seed=sampling_config.train_seed,
test_seed=sampling_config.test_seed,
obs_norm=True,
)

experiment = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
Expand Down Expand Up @@ -90,7 +95,7 @@ def main(
.with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
.build()
)
experiment.run(log_name)
experiment.run(override_experiment_name=log_name)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit ade85ab

Please sign in to comment.