Segmentation fault in jaxlib 0.4.37 and Debian 11 ARM64 #25436

Open
gongomgra opened this issue Dec 12, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@gongomgra

Description

Running a sample test script on a Debian 11 ARM64 machine with the latest jaxlib version (0.4.37) ends in a segmentation fault. Using jax.random.key() instead of jax.random.PRNGKey(), as mentioned in the docs, makes no difference.

root@ec47560c0686:/app/goss# cat /etc/os-release | head -n 4
PRETTY_NAME="Debian GNU/Linux 11 (bullseye)"
NAME="Debian GNU/Linux"
VERSION_ID="11"
VERSION="11 (bullseye)"

root@ec47560c0686:/app/goss# python -c 'import jax.numpy as jnp; from jax import grad, jit, vmap; from jax import random; key = random.PRNGKey(0); x = random.normal(key, (10,)); print(x)'
Segmentation fault

I have verified that I can generate pseudo-random numbers from /dev/urandom without issues, and that the command above works on Debian 11 AMD64:

root@ec47560c0686:/app# od -An  -N8 -d /dev/urandom | sed -e 's| ||g' -e 's|\(.\{11\}\).*|\1|'
40475556011
root@ec47560c0686:/app# od -An  -N8 -d /dev/urandom | sed -e 's| ||g' -e 's|\(.\{11\}\).*|\1|'
64618156373

The jaxlib wheel was built from source with the command given in the docs:

python build/build.py build --wheels=jaxlib

System info (python version, jaxlib version, accelerator, etc.)

root@ec47560c0686:/app# python -c 'import jax; jax.print_environment_info()'
jax:    0.4.37.dev20241209+ffb07cd
jaxlib: 0.4.37.dev20241211
numpy:  2.2.0
python: 3.12.8 (main, Dec  4 2024, 00:26:16) [GCC 10.2.1 20210110]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='123c48685638', release='6.10.0-linuxkit', version='#1 SMP Wed Jul 17 10:51:09 UTC 2024', machine='aarch64')
@gongomgra gongomgra added the bug Something isn't working label Dec 12, 2024
@justinjfu
Collaborator

Can you perform any JAX op (such as x = jnp.ones((10,)); print(x+1)), or is the crash specific to jax.random?

JAX doesn't use the system's RNG; it has its own software implementation. So this is more likely an issue with the JAX installation itself than with the random module in particular.
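One way to narrow this down is to run each candidate operation in its own interpreter, so a crash in one does not hide the others. The following is a minimal sketch, not an official JAX diagnostic: the `PROBES` table and the helper names `run_probe` and `describe` are hypothetical, and it assumes a POSIX system, where a child killed by a signal reports a negative return code. The `-X faulthandler` option is standard CPython and makes the child print a Python traceback to stderr if it receives SIGSEGV.

```python
import signal
import subprocess
import sys

# Hypothetical probe snippets, ordered from "import only" to the reported call.
PROBES = {
    "import jax": "import jax",
    "jnp.ones": "import jax.numpy as jnp; x = jnp.ones((10,)); print(x + 1)",
    "jax.random": (
        "from jax import random; key = random.PRNGKey(0); "
        "print(random.normal(key, (10,)))"
    ),
}


def run_probe(code, timeout=300):
    """Run `code` in a fresh interpreter with faulthandler enabled."""
    return subprocess.run(
        [sys.executable, "-X", "faulthandler", "-c", code],
        capture_output=True,
        text=True,
        timeout=timeout,
    )


def describe(result):
    """Summarize how the child process exited."""
    rc = result.returncode
    if rc == 0:
        return "ok"
    if rc < 0:
        # On POSIX, a negative return code means "killed by signal -rc".
        try:
            name = signal.Signals(-rc).name
        except ValueError:
            name = f"signal {-rc}"
        return f"killed by {name}"
    return f"exited with status {rc}"


if __name__ == "__main__":
    for label, code in PROBES.items():
        result = run_probe(code)
        print(f"{label}: {describe(result)}")
        if result.returncode != 0:
            # With faulthandler on, stderr should include the Python-level traceback.
            print(result.stderr.strip())
```

If "import jax" succeeds but every op that touches the backend is killed by SIGSEGV, that would point at the compiled jaxlib extension rather than at any single JAX module.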

@gongomgra
Author

You are right: the other operation you proposed also produces a segmentation fault.

root@63ae726a28cd:/app# python -c 'import jax.numpy as jnp; x = jnp.ones((10,)); print(x+1)'
Segmentation fault

@hawkinsp
Collaborator

I can't reproduce this on Debian 12 on a Google Cloud C4A (ARM Axion) VM.

In a fresh Debian 12 VM, I did this:

$ sudo apt install python3.11-venv
$ python3.11 -m venv myenv
$ source myenv/bin/activate
$ pip install jax ipython

and I had no trouble running your example.

So one of the following is true:

  • Your ARM machine is different in some important way. Can you share the output of lscpu?
  • Debian 11 vs. Debian 12 makes a difference. GCP doesn't have Debian 11 ARM images, so it's not very easy for me to try this.
  • Something is broken in your installation, for example your Python installation.

Can you try in a fresh Docker container or VM?
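To compare two machines along the lines of the hypotheses above, it may help to collect the same facts on each and diff them. The sketch below is only a rough stand-in for lscpu, assuming a Linux system that exposes /proc/cpuinfo; the helper names `parse_cpu_features` and `machine_summary` are hypothetical. It relies on the Linux convention of listing CPU capabilities under a "Features" field on ARM64 and a "flags" field on x86.

```python
import platform
from pathlib import Path


def parse_cpu_features(cpuinfo_text):
    """Return the set of CPU feature flags found in /proc/cpuinfo-style text."""
    features = set()
    for line in cpuinfo_text.splitlines():
        key, sep, value = line.partition(":")
        # Linux labels this field "Features" on ARM64 and "flags" on x86.
        if sep and key.strip().lower() in ("features", "flags"):
            features.update(value.split())
    return features


def machine_summary(cpuinfo_path="/proc/cpuinfo"):
    """Collect architecture, Python, libc and CPU feature information."""
    path = Path(cpuinfo_path)
    # Fall back to an empty feature list if the file is not available.
    text = path.read_text() if path.exists() else ""
    return {
        "machine": platform.machine(),
        "python": platform.python_version(),
        "libc": " ".join(platform.libc_ver()),
        "cpu_features": sorted(parse_cpu_features(text)),
    }


if __name__ == "__main__":
    for key, value in machine_summary().items():
        print(f"{key}: {value}")
```

Running this on both the failing Debian 11 container and a working machine gives two short listings to compare: a feature present on one host but missing on the other would support the first hypothesis, while identical features with a different libc would point more toward the distribution.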


3 participants