Skip to content

Commit

Permalink
yet more formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
izhigal committed Jan 26, 2024
1 parent f4b7ec1 commit 5d674c9
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 98 deletions.
3 changes: 1 addition & 2 deletions examples/human/leduc_holdem_human.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""A toy example of playing against pretrained AI on Leduc Hold'em
"""
"""A toy example of playing against pretrained AI on Leduc Hold'em"""

import rlcard
from rlcard import models
Expand Down
3 changes: 1 addition & 2 deletions examples/human/limit_holdem_human.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""A toy example of playing against a random agent on Limit Hold'em
"""
"""A toy example of playing against a random agent on Limit Hold'em"""

import rlcard
from rlcard.agents import LimitholdemHumanAgent as HumanAgent
Expand Down
4 changes: 2 additions & 2 deletions examples/run_dmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import rlcard
from rlcard.agents.dmc_agent import DMCTrainer

def train(args):

def train(args):
# Make the environment
env = rlcard.make(args.env)

Expand All @@ -28,6 +28,7 @@ def train(args):
# Train DMC Agents
trainer.start()


if __name__ == '__main__':
parser = argparse.ArgumentParser("DMC example in RLCard")
parser.add_argument(
Expand Down Expand Up @@ -94,4 +95,3 @@ def train(args):

os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
train(args)

16 changes: 8 additions & 8 deletions examples/run_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
plot_curve,
)

def train(args):

def train(args):
# Check whether gpu is available
device = get_device()

# Seed numpy, torch, random
set_seed(args.seed)

Expand All @@ -40,7 +40,7 @@ def train(args):
agent = DQNAgent(
num_actions=env.num_actions,
state_shape=env.state_shape[0],
mlp_layers=[64,64],
mlp_layers=[64, 64],
device=device,
save_path=args.log_dir,
save_every=args.save_every
Expand All @@ -54,8 +54,8 @@ def train(args):
agent = NFSPAgent(
num_actions=env.num_actions,
state_shape=env.state_shape[0],
hidden_layers_sizes=[64,64],
q_mlp_layers=[64,64],
hidden_layers_sizes=[64, 64],
q_mlp_layers=[64, 64],
device=device,
save_path=args.log_dir,
save_every=args.save_every
Expand Down Expand Up @@ -105,6 +105,7 @@ def train(args):
torch.save(agent, save_path)
print('Model saved in', save_path)


if __name__ == '__main__':
parser = argparse.ArgumentParser("DQN/NFSP example in RLCard")
parser.add_argument(
Expand Down Expand Up @@ -162,13 +163,13 @@ def train(args):
type=str,
default='experiments/leduc_holdem_dqn_result/',
)

parser.add_argument(
"--load_checkpoint_path",
type=str,
default="",
)

parser.add_argument(
"--save_every",
type=int,
Expand All @@ -178,4 +179,3 @@ def train(args):

os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
train(args)

37 changes: 10 additions & 27 deletions rlcard/agents/dmc_agent/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,17 @@
import torch
from torch import nn


class DMCNet(nn.Module):
def __init__(
self,
state_shape,
action_shape,
mlp_layers=None
):
def __init__(self, state_shape, action_shape, mlp_layers=None):
super().__init__()
if mlp_layers is None:
mlp_layers = [512, 512, 512, 512, 512]
input_dim = np.prod(state_shape) + np.prod(action_shape)
layer_dims = [input_dim] + mlp_layers
fc = []
for i in range(len(layer_dims)-1):
fc.append(nn.Linear(layer_dims[i], layer_dims[i+1]))
for i in range(len(layer_dims) - 1):
fc.append(nn.Linear(layer_dims[i], layer_dims[i + 1]))
fc.append(nn.ReLU())
fc.append(nn.Linear(layer_dims[-1], 1))
self.fc_layers = nn.Sequential(*fc)
Expand All @@ -44,19 +40,13 @@ def forward(self, obs, actions):
values = self.fc_layers(x).flatten()
return values


class DMCAgent:
def __init__(
self,
state_shape,
action_shape,
mlp_layers=None,
exp_epsilon=0.01,
device="0",
):
def __init__(self, state_shape, action_shape, mlp_layers=None, exp_epsilon=0.01, device="0"):
if mlp_layers is None:
mlp_layers = [512, 512, 512, 512, 512]
self.use_raw = False
self.device = 'cuda:'+device if device != "cpu" else "cpu"
self.device = 'cuda:' + device if device != "cpu" else "cpu"
self.net = DMCNet(state_shape, action_shape, mlp_layers).to(self.device)
self.exp_epsilon = exp_epsilon
self.action_shape = action_shape
Expand All @@ -78,8 +68,7 @@ def eval_step(self, state):
action_idx = np.argmax(values)
action = action_keys[action_idx]

info = {}
info['values'] = {state['raw_legal_actions'][i]: float(values[i]) for i in range(len(action_keys))}
info = {'values': {state['raw_legal_actions'][i]: float(values[i]) for i in range(len(action_keys))}}

return action, info

Expand Down Expand Up @@ -125,15 +114,9 @@ def state_dict(self):
def set_device(self, device):
self.device = device


class DMCModel:
def __init__(
self,
state_shape,
action_shape,
mlp_layers=None,
exp_epsilon=0.01,
device=0
):
def __init__(self, state_shape, action_shape, mlp_layers=None, exp_epsilon=0.01, device=0):
if mlp_layers is None:
mlp_layers = [512, 512, 512, 512, 512]
self.agents = []
Expand Down
8 changes: 1 addition & 7 deletions rlcard/agents/dmc_agent/pettingzoo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,7 @@ def feed(self, ts):


class DMCModelPettingZoo:
def __init__(
self,
env,
mlp_layers=None,
exp_epsilon=0.01,
device="0"
):
def __init__(self, env, mlp_layers=None, exp_epsilon=0.01, device="0"):
if mlp_layers is None:
mlp_layers = [512, 512, 512, 512, 512]

Expand Down
Loading

0 comments on commit 5d674c9

Please sign in to comment.