
jax.full allocates memory on the wrong device #25396

Open
Chaosruler972 opened this issue Dec 11, 2024 · 2 comments

Labels: bug (Something isn't working)

@Chaosruler972

Description

When I use jax.numpy.full, it allocates memory on the TPU, even when instructed to allocate it on the CPU.

If the memory requested is smaller than the free memory on the TPU, the buffer is then deallocated and reallocated on the CPU instead.

However, if the memory requested is larger than the free memory on the TPU, the process crashes with an OOM error.

In my snippet this is an extreme edge case, where I am trying to allocate 40 GB, which normally never happens. However, during training and execution, being able to allocate JAX buffers on the CPU may be the difference between a training pipeline fitting within specific memory constraints or not.

snippet:

import jax
import jax.numpy as jnp

cpu = jax.devices('cpu')[0]
# Request a 40 GiB buffer explicitly on the CPU device; the observed
# behaviour is that it is first allocated on the TPU.
a = jnp.full(shape=(40*(2**30)), fill_value=1.0, dtype=jnp.uint8, device=cpu)
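
For reference, a smaller variant of the repro can be used to check where the result actually lands via Array.devices() (a sketch; the 1 MiB size is arbitrary, and this only shows the final placement, not the transient TPU allocation described above):

import jax
import jax.numpy as jnp

cpu = jax.devices('cpu')[0]
# Small enough to fit anywhere; only the final placement matters here.
b = jnp.full(shape=(2**20,), fill_value=1.0, dtype=jnp.uint8, device=cpu)
print(b.devices())  # expected: {CpuDevice(id=0)}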

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.37
jaxlib: 0.4.36
numpy: 1.26.4
python: 3.10.12 (main, Nov 6 2024, 20:22:13) [GCC 11.4.0]
device info: TPU v6 lite-4, 4 local devices
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-7691bba4-w-0', release='6.8.0-1015-gcp', version='#17~22.04.1-Ubuntu SMP Tue Sep 3 16:11:52 UTC 2024', machine='x86_64')

@Chaosruler972 added the bug label on Dec 11, 2024
@yashk2810 (Collaborator) commented Dec 11, 2024

I think the best solution here is to do the following, to make sure all the ops inside jnp.full run on the CPU:

with jax.default_device(jax.devices('cpu')[0]):
  a = jnp.full(shape=(40*(2**30)), fill_value=1.0, dtype=jnp.uint8)
  print(a)
# any op outside the scope will run on the TPU/GPU
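
(Presumably this works because jax.default_device controls the backend on which the fill computation itself is staged; judging by the behaviour reported above, the device= argument only moves the result after it has been materialized on the default backend.)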

@Chaosruler972 (Author)

> I think the best solution here is to do the following, to make sure all the ops inside jnp.full run on the CPU:
>
> with jax.default_device(jax.devices('cpu')[0]):
>   a = jnp.full(shape=(40*(2**30)), fill_value=1.0, dtype=jnp.uint8)
>   print(a)
> # any op outside the scope will run on the TPU/GPU

That's a very good workaround!

However, I think the behaviour of jnp.full and similar APIs should be more faithful to their arguments than to the global environment, especially when the mismatch can lead to OOM crashes.
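
In the meantime, another possible workaround (a sketch, assuming the buffer fits in host RAM) is to build the array on the host with NumPy and pin it to the CPU device with jax.device_put, so no accelerator memory is touched at any point:

import numpy as np
import jax

cpu = jax.devices('cpu')[0]
# Build host-side with NumPy, then hand the buffer to JAX already
# committed to the CPU device.
a = jax.device_put(np.full(40 * (2**30), 1.0, dtype=np.uint8), cpu)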
