Skip to content

Commit

Permalink
Add calibration to CQL as in CalQL paper arXiv:2303.05479 (#915)
Browse files Browse the repository at this point in the history
- [X] I have marked all applicable categories:
    + [ ] exception-raising fix
    + [ ] algorithm implementation fix
    + [ ] documentation modification
    + [X] new feature
- [X] I have reformatted the code using `make format` (**required**)
- [X] I have checked the code using `make commit-checks` (**required**)
- [X] If applicable, I have mentioned the relevant/related issue(s)
- [X] If applicable, I have listed every items in this Pull Request
below
  • Loading branch information
BFAnas authored Oct 3, 2023
1 parent 6449a43 commit c30b4ab
Show file tree
Hide file tree
Showing 13 changed files with 296 additions and 54 deletions.
203 changes: 173 additions & 30 deletions examples/offline/d4rl_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,41 +22,182 @@

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="HalfCheetah-v2")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2")
parser.add_argument("--buffer-size", type=int, default=1000000)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
parser.add_argument("--actor-lr", type=float, default=1e-4)
parser.add_argument("--critic-lr", type=float, default=3e-4)
parser.add_argument("--alpha", type=float, default=0.2)
parser.add_argument("--auto-alpha", default=True, action="store_true")
parser.add_argument("--alpha-lr", type=float, default=1e-4)
parser.add_argument("--cql-alpha-lr", type=float, default=3e-4)
parser.add_argument("--start-timesteps", type=int, default=10000)
parser.add_argument("--epoch", type=int, default=200)
parser.add_argument("--step-per-epoch", type=int, default=5000)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--batch-size", type=int, default=256)

parser.add_argument("--tau", type=float, default=0.005)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--cql-weight", type=float, default=1.0)
parser.add_argument("--with-lagrange", type=bool, default=True)
parser.add_argument("--lagrange-threshold", type=float, default=10.0)
parser.add_argument("--gamma", type=float, default=0.99)

parser.add_argument("--eval-freq", type=int, default=1)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=1 / 35)
parser.add_argument(
"--task",
type=str,
default="Hopper-v2",
help="The name of the OpenAI Gym environment to train on.",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="The random seed to use.",
)
parser.add_argument(
"--expert-data-task",
type=str,
default="hopper-expert-v2",
help="The name of the OpenAI Gym environment to use for expert data collection.",
)
parser.add_argument(
"--buffer-size",
type=int,
default=1000000,
help="The size of the replay buffer.",
)
parser.add_argument(
"--hidden-sizes",
type=int,
nargs="*",
default=[256, 256],
help="The list of hidden sizes for the neural networks.",
)
parser.add_argument(
"--actor-lr",
type=float,
default=1e-4,
help="The learning rate for the actor network.",
)
parser.add_argument(
"--critic-lr",
type=float,
default=3e-4,
help="The learning rate for the critic network.",
)
parser.add_argument(
"--alpha",
type=float,
default=0.2,
help="The weight of the entropy term in the loss function.",
)
parser.add_argument(
"--auto-alpha",
default=True,
action="store_true",
help="Whether to use automatic entropy tuning.",
)
parser.add_argument(
"--alpha-lr",
type=float,
default=1e-4,
help="The learning rate for the entropy tuning.",
)
parser.add_argument(
"--cql-alpha-lr",
type=float,
default=3e-4,
help="The learning rate for the CQL entropy tuning.",
)
parser.add_argument(
"--start-timesteps",
type=int,
default=10000,
help="The number of timesteps before starting to train.",
)
parser.add_argument(
"--epoch",
type=int,
default=200,
help="The number of epochs to train for.",
)
parser.add_argument(
"--step-per-epoch",
type=int,
default=5000,
help="The number of steps per epoch.",
)
parser.add_argument(
"--n-step",
type=int,
default=3,
help="The number of steps to use for N-step TD learning.",
)
parser.add_argument(
"--batch-size",
type=int,
default=256,
help="The batch size for training.",
)
parser.add_argument(
"--tau",
type=float,
default=0.005,
help="The soft target update coefficient.",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="The temperature for the Boltzmann policy.",
)
parser.add_argument(
"--cql-weight",
type=float,
default=1.0,
help="The weight of the CQL loss term.",
)
parser.add_argument(
"--with-lagrange",
type=bool,
default=True,
help="Whether to use the Lagrange multiplier for CQL.",
)
parser.add_argument(
"--calibrated",
type=bool,
default=True,
help="Whether to use calibration for CQL.",
)
parser.add_argument(
"--lagrange-threshold",
type=float,
default=10.0,
help="The Lagrange multiplier threshold for CQL.",
)
parser.add_argument("--gamma", type=float, default=0.99, help="The discount factor")
parser.add_argument(
"--eval-freq",
type=int,
default=1,
help="The frequency of evaluation.",
)
parser.add_argument(
"--test-num",
type=int,
default=10,
help="The number of episodes to evaluate for.",
)
parser.add_argument(
"--logdir",
type=str,
default="log",
help="The directory to save logs to.",
)
parser.add_argument(
"--render",
type=float,
default=1 / 35,
help="The frequency of rendering the environment.",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="The device to train on (cpu or cuda).",
)
parser.add_argument(
"--resume-path",
type=str,
default=None,
help="The path to the checkpoint to resume from.",
)
parser.add_argument(
"--resume-id",
type=str,
default=None,
help="The ID of the checkpoint to resume from.",
)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
"--logger",
type=str,
Expand Down Expand Up @@ -145,6 +286,8 @@ def test_cql():
critic1_optim,
critic2,
critic2_optim,
calibrated=args.calibrated,
action_space=env.action_space,
cql_alpha_lr=args.cql_alpha_lr,
cql_weight=args.cql_weight,
tau=args.tau,
Expand Down
44 changes: 34 additions & 10 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ gymnasium = "^0.29.0"
h5py = "^3.9.0"
numba = "^0.57.1"
numpy = "^1"
overrides = "^7.4.0"
packaging = "*"
pettingzoo = "^1.22"
tensorboard = "^2.5.0"
Expand Down
8 changes: 0 additions & 8 deletions test/offline/test_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,6 @@ def save_best_fn(policy):
def stop_fn(mean_rewards):
return mean_rewards >= args.reward_threshold

def watch():
policy.load_state_dict(
torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
)
policy.eval()
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)

# trainer
trainer = OfflineTrainer(
policy=policy,
Expand Down
2 changes: 2 additions & 0 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,8 @@ def __init__(
batch_dict = cast(Sequence[dict | BatchProtocol], batch_dict)
self.stack_(batch_dict)
if len(kwargs) > 0:
# TODO: that's a rather weird pattern, is it really needed?
# Feels like kwargs could be just merged into batch_dict in the beginning
self.__init__(kwargs, copy=copy) # type: ignore

def __setattr__(self, key: str, value: Any) -> None:
Expand Down
4 changes: 3 additions & 1 deletion tianshou/data/buffer/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Self, cast
from typing import Any, Self, TypeVar, cast

import h5py
import numpy as np
Expand All @@ -8,6 +8,8 @@
from tianshou.data.types import RolloutBatchProtocol
from tianshou.data.utils.converter import from_hdf5, to_hdf5

TBuffer = TypeVar("TBuffer", bound="ReplayBuffer")


class ReplayBuffer:
""":class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment.
Expand Down
Loading

0 comments on commit c30b4ab

Please sign in to comment.