When I use jax.numpy.full, it allocates memory on the TPU even when instructed to allocate it on the CPU.
If the memory requested is smaller than the free memory on the TPU, the buffer is then deallocated and reallocated on the CPU instead.
However, if the memory requested is larger than the free TPU memory, the device crashes with an OOM error.
In my snippet this is an extreme edge case (I am trying to allocate 40 GB, which normally never happens), but during training and execution, being able to allocate JAX buffers directly on the CPU can be the difference between a training pipeline fitting within specific memory constraints or not.
snippet:
import jax
import jax.numpy as jnp
cpu = jax.devices('cpu')[0]
a = jnp.full(shape=(40*(2**30)), fill_value=1.0, dtype=jnp.uint8, device=cpu)
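Continuing from the snippet above, a much smaller request (a hypothetical check, not part of the original report) can be used to inspect where the result actually ends up without risking an OOM:

b = jnp.full(shape=(1024,), fill_value=1.0, dtype=jnp.uint8, device=cpu)  # small enough not to OOM
print(b.devices())  # expected to report only the CPU device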
I think the best solution here is to do the following to make sure all ops inside jnp.full run on the CPU:
with jax.default_device(jax.devices('cpu')[0]):
    a = jnp.full(shape=(40*(2**30)), fill_value=1.0, dtype=jnp.uint8)
    print(a)
# any op outside the scope will run on the TPU/GPU
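One way to confirm that the fill really stays off the accelerator is to compare the TPU's memory statistics before and after; this is a rough sketch, and Device.memory_stats() may return None or expose different keys on some backends:

import jax
import jax.numpy as jnp

tpu = jax.devices()[0]
before = tpu.memory_stats()['bytes_in_use']  # where supported by the backend
with jax.default_device(jax.devices('cpu')[0]):
    a = jnp.full(shape=(256 * 2**20,), fill_value=1.0, dtype=jnp.uint8)  # 256 MiB
after = tpu.memory_stats()['bytes_in_use']
print(after - before)  # should stay close to zero if nothing was staged on the TPU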
That's a very good workaround!
However, I think the behaviour of jnp.full and similar APIs needs to be more faithful to their arguments than to their global environment, especially if it can lead to OOM crashes.
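Until the device argument is honoured end-to-end, the context-manager workaround can be wrapped in a small helper so call sites stay tidy; full_on_device below is a hypothetical helper for illustration, not a JAX API:

import jax
import jax.numpy as jnp

def full_on_device(shape, fill_value, dtype, device):
    # Run the fill itself under the requested default device so no
    # intermediate TPU/GPU buffer should be created.
    with jax.default_device(device):
        return jnp.full(shape=shape, fill_value=fill_value, dtype=dtype)

a = full_on_device((1024,), 1.0, jnp.uint8, jax.devices('cpu')[0])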
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.37
jaxlib: 0.4.36
numpy: 1.26.4
python: 3.10.12 (main, Nov 6 2024, 20:22:13) [GCC 11.4.0]
device info: TPU v6 lite-4, 4 local devices
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-7691bba4-w-0', release='6.8.0-1015-gcp', version='#17~22.04.1-Ubuntu SMP Tue Sep 3 16:11:52 UTC 2024', machine='x86_64')