
Error while running tutorial files #122

Open
Arseni1919 opened this issue Dec 8, 2024 · 6 comments

Comments

@Arseni1919

I tried to run the tutorial files, and here is the error I get:

Traceback (most recent call last):
  File "/Users/perchik/PycharmProjects/Learning_DRL/try_jaxmarl.py", line 6, in <module>
    from jaxmarl import make
  File "/Users/perchik/PycharmProjects/Learning_DRL/.venv/lib/python3.12/site-packages/jaxmarl/__init__.py", line 1, in <module>
    from .registration import make, registered_envs
  File "/Users/perchik/PycharmProjects/Learning_DRL/.venv/lib/python3.12/site-packages/jaxmarl/registration.py", line 1, in <module>
    from .environments import (
  File "/Users/perchik/PycharmProjects/Learning_DRL/.venv/lib/python3.12/site-packages/jaxmarl/environments/__init__.py", line 1, in <module>
    from .multi_agent_env import MultiAgentEnv, State
  File "/Users/perchik/PycharmProjects/Learning_DRL/.venv/lib/python3.12/site-packages/jaxmarl/environments/multi_agent_env.py", line 12, in <module>
    from flax import struct
  File "/Users/perchik/PycharmProjects/Learning_DRL/.venv/lib/python3.12/site-packages/flax/__init__.py", line 19, in <module>
    from .configurations import (
  File "/Users/perchik/PycharmProjects/Learning_DRL/.venv/lib/python3.12/site-packages/flax/configurations.py", line 93, in <module>
    flax_filter_frames = define_bool_state(
                         ^^^^^^^^^^^^^^^^^^
  File "/Users/perchik/PycharmProjects/Learning_DRL/.venv/lib/python3.12/site-packages/flax/configurations.py", line 42, in define_bool_state
    return jax_config.define_bool_state('flax_' + name, default, help)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Config' object has no attribute 'define_bool_state'

What could be causing this?
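Reading the traceback from the bottom: flax/configurations.py calls jax_config.define_bool_state(...) while flax is being imported, and the 'Config' object of the installed JAX has no such attribute, so importing jaxmarl fails before any tutorial code runs. A minimal sketch for checking this without importing flax, assuming JAX 0.4.x (where jax.config is the Config object named in the error); the helper names are hypothetical and not part of JAX, Flax or JaxMARL:

```python
import importlib


def has_define_bool_state(config) -> bool:
    """Return True if a JAX-style config object still exposes define_bool_state.

    The flax release in the traceback calls this method at import time,
    so False here predicts the AttributeError.
    """
    return callable(getattr(config, "define_bool_state", None))


def check_installed_jax() -> str:
    """Report on the installed jax without importing flax (which would crash)."""
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return "jax is not installed"
    # Assumption: jax 0.4.x, where jax.config is the Config instance itself.
    if has_define_bool_state(jax.config):
        return f"jax {jax.__version__}: define_bool_state present"
    return f"jax {jax.__version__}: define_bool_state missing (older flax will fail)"


if __name__ == "__main__":
    print(check_installed_jax())
```

This uses feature detection rather than a version number, so it reports the actual missing attribute instead of guessing from version strings.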

@Arseni1919
Author

Here is the code I am trying to run:

"""
Short introduction to running the Overcooked environment and visualising it using random actions.
"""

import jax
from jaxmarl import make
from jaxmarl.viz.overcooked_visualizer import OvercookedVisualizer
from jaxmarl.environments.overcooked import Overcooked, overcooked_layouts, layout_grid_to_dict
import time

# Parameters + random keys
max_steps = 100
key = jax.random.PRNGKey(0)
key, key_r, key_a = jax.random.split(key, 3)

# Get one of the classic layouts (cramped_room, asymm_advantages, coord_ring, forced_coord, counter_circuit)
layout = overcooked_layouts["cramped_room"]

# Or make your own!
# custom_layout_grid = """
# WWOWW
# WA  W
# B P X
# W  AW
# WWOWW
# """
# layout = layout_grid_to_dict(custom_layout_grid)

# Instantiate environment
env = make('overcooked', layout=layout, max_steps=max_steps)

obs, state = env.reset(key_r)
print('list of agents in environment', env.agents)

# Sample random actions
key_a = jax.random.split(key_a, env.num_agents)
actions = {agent: env.action_space(agent).sample(key_a[i]) for i, agent in enumerate(env.agents)}
print('example action dict', actions)

state_seq = []
for _ in range(max_steps):
    state_seq.append(state)
    # Iterate random keys and sample actions
    key, key_s, key_a = jax.random.split(key, 3)
    key_a = jax.random.split(key_a, env.num_agents)

    actions = {agent: env.action_space(agent).sample(key_a[i]) for i, agent in enumerate(env.agents)}

    # Step environment
    obs, state, rewards, dones, infos = env.step(key_s, state, actions)

viz = OvercookedVisualizer()

# Render to screen
for s in state_seq:
    viz.render(env.agent_view_size, s, highlight=False)
    time.sleep(0.25)

# # Or save an animation
# viz.animate(state_seq, agent_view_size=5, filename='animation.gif')

@Arseni1919
Author

Arseni1919 commented Dec 8, 2024

The same error occurs in your Colab notebook (https://colab.research.google.com/github/FLAIROx/JaxMARL/blob/main/jaxmarl/tutorials/JaxMARL_Walkthrough.ipynb#scrollTo=h3VkMmsdPfc0), so the problem is not specific to my machine.

@anhhuyalex

Yep, I can confirm that I get the same error.
Here are the packages I have installed, with their versions:
absl-py==2.1.0
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
asttokens==3.0.0
blinker==1.9.0
brax==0.10.3
certifi==2024.8.30
charset-normalizer==3.4.0
chex==0.1.84
click==8.1.7
cloudpickle==3.1.0
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.9
decorator==5.1.1
distrax==0.1.5
dm-env==1.6
dm-tree==0.1.8
docker-pycreds==0.4.0
dotmap==1.3.30
etils==1.11.0
evosax==0.1.5
executing==2.1.0
Farama-Notifications==0.0.4
flashbax==0.1.0
Flask==3.1.0
Flask-Cors==5.0.0
flax==0.7.4
fonttools==4.55.2
fsspec==2024.10.0
gast==0.6.0
gitdb==4.0.11
GitPython==3.1.43
glfw==2.8.0
grpcio==1.68.1
gym==0.26.2
gym-notices==0.0.8
gymnasium==1.0.0
gymnax==0.0.6
hydra-core==1.3.2
idna==3.10
importlib_resources==6.4.5
iniconfig==2.0.0
ipykernel==6.29.5
ipython==8.30.0
itsdangerous==2.2.0
jax==0.4.25
jaxlib==0.4.25
jaxmarl==0.0.5
jaxopt==0.8.3
jedi==0.19.2
Jinja2==3.1.4
jupyter_client==8.6.3
jupyter_core==5.7.2
kiwisolver==1.4.7
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.9.3
matplotlib-inline==0.1.7
mdurl==0.1.2
ml_collections==1.0.0
ml_dtypes==0.5.0
msgpack==1.1.0
mujoco==3.1.3
mujoco-mjx==3.1.3
nest-asyncio==1.6.0
numpy==1.26.4
omegaconf==2.3.0
opt_einsum==3.4.0
optax==0.1.7
orbax-checkpoint==0.5.18
packaging==24.2
parso==0.8.4
pettingzoo==1.24.3
pexpect==4.9.0
pillow==11.0.0
platformdirs==4.3.6
pluggy==1.5.0
prompt_toolkit==3.0.48
protobuf==5.29.1
psutil==6.1.0
ptyprocess==0.7.0
pure_eval==0.2.3
pydantic==2.10.3
pydantic_core==2.27.1
pygame==2.6.1
Pygments==2.18.0
PyOpenGL==3.1.7
pyparsing==3.2.0
pytest==8.3.4
python-dateutil==2.9.0.post0
pytinyrenderer==0.0.14
PyYAML==6.0.2
pyzmq==26.2.0
requests==2.32.3
rich==13.9.4
safetensors==0.4.2
scipy==1.12.0
sentry-sdk==2.19.2
setproctitle==1.3.4
six==1.17.0
smmap==5.0.1
stack-data==0.6.3
tensorboardX==2.6.2.2
tensorflow-probability==0.25.0
tensorstore==0.1.69
toolz==1.0.0
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
trimesh==4.5.3
typing_extensions==4.12.2
urllib3==2.2.3
wandb==0.19.0
wcwidth==0.2.13
Werkzeug==3.1.3
zipp==3.21.0
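Of the packages listed, the ones involved in this traceback are jax, jaxlib, flax and jaxmarl (jaxmarl imports flax, which calls into jax.config). A small sketch for pulling just those pins out of a long list, assuming the list is plain pip freeze output of name==version lines; the function name relevant_pins is hypothetical:

```python
def relevant_pins(freeze_text: str, names=("jax", "jaxlib", "flax", "jaxmarl")) -> dict:
    """Extract selected name==version pins from pip-freeze-style text.

    Names are compared case-insensitively; lines without '==' are skipped.
    """
    wanted = {n.lower() for n in names}
    pins = {}
    for line in freeze_text.splitlines():
        name, sep, version = line.strip().partition("==")
        if sep and name.lower() in wanted:
            pins[name.lower()] = version
    return pins


# A few lines copied from the list in this comment.
freeze = """flax==0.7.4
jax==0.4.25
jaxlib==0.4.25
jaxmarl==0.0.5
jaxopt==0.8.3"""
print(relevant_pins(freeze))
# {'flax': '0.7.4', 'jax': '0.4.25', 'jaxlib': '0.4.25', 'jaxmarl': '0.0.5'}
```

Note that jaxopt is ignored because the match is on the exact package name, not a prefix.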

@Arseni1919
Author

OMG, running pip install jax==0.4.23 and pip install jaxlib==0.4.23 did the job, but it is an unintuitive, terrible workaround that took me an hour to find. Please make this package compatible with the newest version of JAX. You are positioning this as the baseline environment library for MARL, and you have already done a great job so far. Please fix this bug!
TL;DR: I am working on a MacBook Air M3, and I had to downgrade JAX to make JaxMARL work.
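Until a fix is available, the workaround in this thread is pinning jax and jaxlib to 0.4.23. A minimal guard sketch that warns before importing jaxmarl when the installed jax is newer than that pin, assuming plain X.Y.Z release strings and treating 0.4.23 only as the version reported to work here (not an official compatibility bound); parse_release, jax_pin_ok and KNOWN_GOOD_JAX are hypothetical names:

```python
def parse_release(version: str) -> tuple:
    """Turn a plain release string like '0.4.25' into (0, 4, 25).

    Assumption: no pre-release/dev suffixes; packaging.version handles those.
    """
    return tuple(int(part) for part in version.split(".")[:3])


# Version reported to work in this thread (not an official bound).
KNOWN_GOOD_JAX = "0.4.23"


def jax_pin_ok(installed: str, known_good: str = KNOWN_GOOD_JAX) -> bool:
    """True if the installed version is not newer than the known-good pin."""
    return parse_release(installed) <= parse_release(known_good)


if __name__ == "__main__":
    from importlib.metadata import PackageNotFoundError, version

    try:
        installed = version("jax")
    except PackageNotFoundError:
        print("jax is not installed")
    else:
        if not jax_pin_ok(installed):
            print(f"jax {installed} is newer than {KNOWN_GOOD_JAX}; flax releases "
                  "that call define_bool_state may fail to import")
```

Comparing integer tuples avoids the string-comparison trap where "0.4.3" would sort after "0.4.25".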

@amacrutherford
Collaborator

Hey! Thanks for posting this, and apologies for our late reply. We'll get on fixing this :)

@amacrutherford
Collaborator

Fixed with #123.
