Error in saving policies #4

Open
cmakelabs opened this issue Oct 31, 2024 · 1 comment

Comments

@cmakelabs

Hello,
I have successfully trained the policy. However, when I tried saving and loading the trained parameters, I encountered the following error in the notebook save_load_policy.ipynb:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[4], line 11
      8 path = FLIGHTNING_PATH + "/../examples/saved_params"
     10 ckptr = PyTreeCheckpointer()
---> 11 ckptr.save(path, params)

File ~/miniconda3/envs/flightning/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py:204, in Checkpointer.save(self, directory, force, *args, **kwargs)
    202 ckpt_args = construct_checkpoint_args(self._handler, True, *args, **kwargs)
    203 tmpdir = asyncio.run(self.create_temporary_path(directory))
--> 204 self._handler.save(tmpdir.get(), args=ckpt_args)
    205 multihost.sync_global_processes(
    206     multihost.unique_barrier_key(
    207         'Checkpointer:save',
   (...)
    211     processes=self._active_processes,
    212 )
    214 # Ensure save operation atomicity and record time saved by checkpoint.

File ~/miniconda3/envs/flightning/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py:566, in PyTreeCheckpointHandler.save(self, directory, item, save_args, args)
    564 """Saves the provided item. See async_save."""
    565 args = _get_impl_save_args(item, save_args, args)
--> 566 self._handler_impl.save(directory, args=args)

File ~/miniconda3/envs/flightning/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py:501, in BasePyTreeCheckpointHandler.save(self, directory, *args, **kwargs)
    498     for f in commit_futures:
    499       f.result()  # Block on result.
--> 501 asyncio.run(async_save(directory, *args, **kwargs))

File ~/miniconda3/envs/flightning/lib/python3.10/site-packages/nest_asyncio.py:30, in _patch_asyncio.<locals>.run(main, debug)
     28 task = asyncio.ensure_future(main)
     29 try:
---> 30     return loop.run_until_complete(task)
     31 finally:
     32     if not task.done():

File ~/miniconda3/envs/flightning/lib/python3.10/site-packages/nest_asyncio.py:98, in _patch_loop.<locals>.run_until_complete(self, future)
     95 if not f.done():
     96     raise RuntimeError(
     97         'Event loop stopped before Future completed.')
---> 98 return f.result()

File ~/miniconda3/envs/flightning/lib/python3.10/asyncio/futures.py:201, in Future.result(self)
    199 self.__log_traceback = False
    200 if self._exception is not None:
--> 201     raise self._exception
    202 return self._result

File ~/miniconda3/envs/flightning/lib/python3.10/asyncio/tasks.py:232, in Task.__step(***failed resolving arguments***)
    228 try:
    229     if exc is None:
    230         # We use the `send` method directly, because coroutines
    231         # don't have `__iter__` and `__next__` methods.
--> 232         result = coro.send(None)
    233     else:
    234         result = coro.throw(exc)

File ~/miniconda3/envs/flightning/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py:494, in BasePyTreeCheckpointHandler.save.<locals>.async_save(*args, **kwargs)
    493 async def async_save(*args, **kwargs):
--> 494   commit_futures = await self.async_save(*args, **kwargs)  # pytype: disable=bad-return-type
    495   # Futures are already running, so sequential waiting is equivalent to
    496   # concurrent waiting.
    497   if commit_futures:  # May be None.

File ~/miniconda3/envs/flightning/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py:421, in BasePyTreeCheckpointHandler.async_save(self, directory, args)
    418 ocdbt_target_data_file_size = args.ocdbt_target_data_file_size
    420 save_args = _fill_missing_save_or_restore_args(item, save_args, mode='save')
--> 421 byte_limiter = serialization.get_byte_limiter(self._save_concurrent_bytes)
    422 param_infos = self._get_param_infos(
    423     item,
    424     directory,
   (...)
    427     byte_limiter=byte_limiter,
    428 )
    429 assert all(
    430     leaf.parent_dir == directory for leaf in jax.tree.leaves(param_infos)
    431 )

File ~/miniconda3/envs/flightning/lib/python3.10/site-packages/orbax/checkpoint/serialization.py:230, in get_byte_limiter(concurrent_bytes)
    226 if concurrent_bytes <= 0:
    227   raise ValueError(
    228       f'Must provide positive `concurrent_bytes`. Found: {concurrent_bytes}'
    229   )
--> 230 return LimitInFlightBytes(concurrent_bytes)

File ~/miniconda3/envs/flightning/lib/python3.10/site-packages/orbax/checkpoint/serialization.py:181, in LimitInFlightBytes.__init__(self, num_bytes)
    179 self._max_bytes = num_bytes
    180 self._available_bytes = num_bytes
--> 181 self._cv = asyncio.Condition(lock=asyncio.Lock())

File ~/miniconda3/envs/flightning/lib/python3.10/asyncio/locks.py:234, in Condition.__init__(self, lock, loop)
    232     lock = Lock()
    233 elif lock._loop is not self._get_loop():
--> 234     raise ValueError("loop argument must agree with lock")
    236 self._lock = lock
    237 # Export the lock's locked(), acquire() and release() methods.

ValueError: loop argument must agree with lock

I assume it is related to async operations within orbax itself, but I am not sure how to resolve it. I would be really grateful for your guidance.

@joheeg
Contributor

joheeg commented Nov 17, 2024

I updated the installation instructions and tested the examples again. I presume your error was caused by a newer version of orbax. If you want to stick with your setup, I suggest coming up with a different saving routine. Otherwise, installing the code with Python 3.9, as now indicated in the README, should also work.
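
For anyone who wants to keep their current environment, one possible shape for a "different saving routine" is sketched below. It is only an illustration, not the routine used in this repository: it bypasses orbax (and therefore its asyncio code path) by writing the parameters to a single `.npz` file with NumPy. The helper names `flatten_params`, `save_params`, and `load_params` are hypothetical, and the sketch assumes `params` is a nested dict with string keys that contain no `/`, whose leaves are array-like values with standard dtypes such as float32.

```python
import numpy as np


def flatten_params(tree, prefix=""):
    """Flatten a nested dict into a flat dict keyed by 'a/b/c' paths (hypothetical helper)."""
    flat = {}
    for key, value in tree.items():
        name = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_params(value, name))
        else:
            # np.asarray also accepts JAX arrays and returns a host-side NumPy array.
            flat[name] = np.asarray(value)
    return flat


def save_params(path, params):
    """Write every leaf of the nested dict into one .npz archive."""
    # Assumption: no flattened key equals "file", which would clash with
    # the first positional parameter of np.savez.
    np.savez(path, **flatten_params(params))


def load_params(path):
    """Rebuild the nested dict from an archive written by save_params."""
    tree = {}
    with np.load(path) as data:
        for name in data.files:
            *parents, leaf = name.split("/")
            node = tree
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = data[name]
    return tree
```

In the notebook this would be called as `save_params(path + ".npz", params)` followed by `load_params(path + ".npz")`. The restored leaves are NumPy arrays, and the result is a plain nested dict, so any custom container types in the original parameters are not preserved.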
