According to the docstring of jax.lax.linalg.eig, the function should work on GPU as of jax>=0.4.36. However, when running the following MWE, none of the three calls works as intended:
import jax
import jax.numpy as jnp

print(jax.devices())
# [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]

x = jnp.ones((3, 3))
print(x.device)
# CudaDevice(id=0)

# attempting to run any of the following lines raises an UNIMPLEMENTED error
jnp.linalg.eig(x)
jax.lax.linalg.eig(x)
jax.lax.linalg.eig(x, use_magma=False)
Each call fails with:
XlaRuntimeError: UNIMPLEMENTED: No registered implementation for custom call to cuhybrid_eig_real for platform CUDA
I do not intend to use MAGMA (I don't have it installed), but if I understand correctly, the function should then fall back to LAPACK on the host CPU.
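Until the GPU path works, one possible workaround is to run the decomposition on the CPU backend explicitly. The sketch below assumes the CPU backend is available alongside CUDA (it normally is, unless JAX_PLATFORMS excludes it). It relies on the fact that, outside jit, JAX runs an operation on the device its committed inputs live on. The helper name eig_on_cpu is hypothetical and not part of JAX.

```python
import jax
import jax.numpy as jnp


def eig_on_cpu(x):
    """Hypothetical helper: compute eig of `x` on the first CPU device.

    The input is copied (and committed) to the CPU device, so
    jnp.linalg.eig dispatches to the CPU (LAPACK-backed) implementation
    no matter which device `x` originally lived on.
    """
    cpu = jax.devices("cpu")[0]
    x_cpu = jax.device_put(x, cpu)
    w, v = jnp.linalg.eig(x_cpu)
    return w, v


x = jnp.ones((3, 3))  # placed on the default device (a GPU in the report above)
w, v = eig_on_cpu(x)
print(w)  # complex eigenvalues, approximately 3, 0, 0 (order not guaranteed)
```

The results stay on the CPU device; use jax.device_put(w, jax.devices()[0]) if they are needed back on the GPU.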
I could not reproduce this error with JAX 0.4.38, NumPy 2.2.1, and Python 3.10-3.12 on Cloud VMs (one with 4 NVIDIA Tesla T4 GPUs and another with 2 NVIDIA L4 GPUs). See the screenshots below for reference.
With 4 T4 GPUs:
With 2 L4 GPUs:
Could you please try installing JAX in a fresh environment and check whether the issue still persists?
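Before and after reinstalling, it can help to record exactly which JAX packages are installed. One plausible but unconfirmed explanation for a missing custom-call target such as cuhybrid_eig_real is a CUDA plugin that does not match the installed jax/jaxlib. The sketch below assumes the CUDA 12 wheel distribution names jax-cuda12-plugin and jax-cuda12-pjrt; other installs may use different names, which simply show up as "not installed".

```python
import importlib.metadata

import jax

# Distribution names to inspect. The CUDA plugin names are an assumption
# matching the CUDA 12 wheels; adjust them for other installations.
PACKAGES = ["jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


versions = installed_versions(PACKAGES)
for name, version in versions.items():
    print(f"{name}: {version or 'not installed'}")

# Which backend JAX picked by default and which devices it sees.
print("default backend:", jax.default_backend())
print("devices:", jax.devices())
```

If the plugin versions differ from jaxlib, reinstalling all of them together in a fresh environment is a reasonable first step. jax.print_environment_info() prints a broader summary that is useful to paste into an issue.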