Non-symmetric eigenvalue decomposition: unimplemented error on jax>=0.4.36 #25687

Open
gautierronan opened this issue Dec 26, 2024 · 1 comment
Labels
bug Something isn't working

gautierronan commented Dec 26, 2024

Description

According to the docstring of jax.lax.linalg.eig, the function should work on GPU as of jax>=0.4.36. However, in the following MWE, none of the three calls runs as intended:

import jax
import jax.numpy as jnp

print(jax.devices())
# [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]

x = jnp.ones((3, 3))
print(x.device)
# CudaDevice(id=0)

# each of the following calls raises an UNIMPLEMENTED error
jnp.linalg.eig(x)
jax.lax.linalg.eig(x)
jax.lax.linalg.eig(x, use_magma=False)

Each call fails with:

XlaRuntimeError: UNIMPLEMENTED: No registered implementation for custom call to cuhybrid_eig_real for platform CUDA

I do not intend to use MAGMA (I don't have it installed), but if I understand correctly, the function should then fall back to LAPACK on the host CPU.
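Until the GPU path works, a possible workaround is to run the decomposition on the CPU backend explicitly. This is a minimal sketch, not a fix confirmed in this thread: it assumes the CPU backend is available alongside CUDA (as it is in standard jaxlib installs), and `eig_on_host` is a hypothetical helper name. Committing the input to a CPU device with `jax.device_put` makes the computation follow the data, so `eig` never lowers to the `cuhybrid_eig_real` custom call.

```python
import jax
import jax.numpy as jnp


def eig_on_host(x):
    """Hypothetical helper: run jnp.linalg.eig on the CPU backend."""
    cpu = jax.devices("cpu")[0]
    # Commit the operand to the CPU device; JAX then executes eig there
    # (using its LAPACK-based CPU kernels) instead of on the GPU.
    x_cpu = jax.device_put(x, cpu)
    w, v = jnp.linalg.eig(x_cpu)
    return w, v


x = jnp.ones((3, 3))
w, v = eig_on_host(x)
print(w)
```

The results stay on the CPU device; move them back with `jax.device_put` if later GPU computations need them. This costs a host/device round trip, which is likely negligible for small matrices.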

Full stack trace
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
File Untitled-1:5
      2 import jax.numpy as jnp
      4 x = jnp.ones((3, 3))
----> 5 jnp.linalg.eig(x)
      6 jax.lax.linalg.eig(x)
      7 jax.lax.linalg.eig(x, use_magma=False)

File ~/.venv/lib/python3.11/site-packages/jax/_src/numpy/linalg.py:765, in eig(a)
    763 check_arraylike("jnp.linalg.eig", a)
    764 a, = promote_dtypes_inexact(jnp.asarray(a))
--> 765 w, v = lax_linalg.eig(a, compute_left_eigenvectors=False)
    766 return w, v

File ~/.venv/lib/python3.11/site-packages/jax/_src/lax/linalg.py:175, in eig(x, compute_left_eigenvectors, compute_right_eigenvectors, use_magma)
    124 def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
    125         compute_right_eigenvectors: bool = True,
    126         use_magma: bool | None = None) -> list[Array]:
    127   """Eigendecomposition of a general matrix.
    128 
    129   Nonsymmetric eigendecomposition is only implemented on CPU and GPU. On GPU,
   (...)
    173     for that batch element.
    174   """
--> 175   return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
    176                     compute_right_eigenvectors=compute_right_eigenvectors,
    177                     use_magma=use_magma)

File ~/.venv/lib/python3.11/site-packages/jax/_src/core.py:463, in Primitive.bind(self, *args, **params)
    461 trace_ctx.set_trace(eval_trace)
    462 try:
--> 463   return self.bind_with_trace(prev_trace, args, params)
    464 finally:
    465   trace_ctx.set_trace(prev_trace)

File ~/.venv/lib/python3.11/site-packages/jax/_src/core.py:468, in Primitive.bind_with_trace(self, trace, args, params)
    467 def bind_with_trace(self, trace, args, params):
--> 468   return trace.process_primitive(self, args, params)

File ~/.venv/lib/python3.11/site-packages/jax/_src/core.py:941, in EvalTrace.process_primitive(self, primitive, args, params)
    939       return primitive.bind_with_trace(arg._trace, args, params)
    940 check_eval_args(args)
--> 941 return primitive.impl(*args, **params)

File ~/.venv/lib/python3.11/site-packages/jax/_src/lax/linalg.py:715, in eig_impl(operand, compute_left_eigenvectors, compute_right_eigenvectors, use_magma)
    713 def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors,
    714              use_magma):
--> 715   return dispatch.apply_primitive(
    716       eig_p,
    717       operand,
    718       compute_left_eigenvectors=compute_left_eigenvectors,
    719       compute_right_eigenvectors=compute_right_eigenvectors,
    720       use_magma=use_magma,
    721   )

File ~/.venv/lib/python3.11/site-packages/jax/_src/dispatch.py:90, in apply_primitive(prim, *args, **params)
     88 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     89 try:
---> 90   outs = fun(*args)
     91 finally:
     92   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 11 frame]

File ~/.venv/lib/python3.11/site-packages/jax/_src/compiler.py:303, in backend_compile(backend, module, options, host_callbacks)
    297     return backend.compile(
    298         built_c, compile_options=options, host_callbacks=host_callbacks
    299     )
    300   # Some backends don't have `host_callbacks` option yet
    301   # TODO(sharadmv): remove this fallback when all backends allow `compile`
    302   # to take in `host_callbacks`
--> 303   return backend.compile(built_c, compile_options=options)
    304 except xc.XlaRuntimeError as e:
    305   for error_handler in _XLA_RUNTIME_ERROR_HANDLERS:

XlaRuntimeError: UNIMPLEMENTED: No registered implementation for custom call to cuhybrid_eig_real for platform CUDA

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.2.1
python: 3.11.10 (main, Sep  9 2024, 22:11:19) [Clang 18.1.8 ]
device info: NVIDIA L40S-4, 4 local devices"
process_count: 1
platform: uname_result(system='Linux', node='higgs', release='6.8.0-45-generic', version='#45-Ubuntu SMP PREEMPT_DYNAMIC Fri Aug 30 12:02:04 UTC 2024', machine='x86_64')


$ nvidia-smi
Thu Dec 26 12:05:38 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    Off |   00000000:03:00.0 Off |                    0 |
| N/A   48C    P0            107W /  350W |     443MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA L40S                    Off |   00000000:04:00.0 Off |                    0 |
| N/A   47C    P0            103W /  350W |     435MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA L40S                    Off |   00000000:63:00.0 Off |                    0 |
| N/A   52C    P0            115W /  350W |     435MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA L40S                    Off |   00000000:64:00.0 Off |                    0 |
| N/A   48C    P0            106W /  350W |     435MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    338088      C   /home/rgautier/.venv/bin/python               434MiB |
|    1   N/A  N/A    338088      C   /home/rgautier/.venv/bin/python               426MiB |
|    2   N/A  N/A    338088      C   /home/rgautier/.venv/bin/python               426MiB |
|    3   N/A  N/A    338088      C   /home/rgautier/.venv/bin/python               426MiB |
+-----------------------------------------------------------------------------------------+
gautierronan added the bug (Something isn't working) label on Dec 26, 2024

rajasekharporeddy (Contributor) commented:

Hi @gautierronan,

I could not reproduce this error with JAX 0.4.38, NumPy 2.2.1, and Python 3.10-3.12 on two cloud VMs (one with 4 NVIDIA Tesla T4 GPUs, the other with 2 NVIDIA L4 GPUs). Please see the screenshots below for reference.

With 4 T4 GPUs: [screenshot not reproduced]

With 2 L4 GPUs: [screenshot not reproduced]
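A self-checking variant of the reproduction can make such comparisons easier to report. The sketch below is illustrative, not the script used in the screenshots: `eig_residual` is a hypothetical name, and the matrix size and seed are arbitrary. It reports the default backend and the largest entry of A V - V diag(w) for a random nonsymmetric matrix, so a working setup prints a small residual while an affected one raises the UNIMPLEMENTED error.

```python
import jax
import jax.numpy as jnp


def eig_residual(n=4, seed=0):
    """Hypothetical check: return (default backend, max |A v - v diag(w)|)."""
    # A random Gaussian matrix is almost surely nonsymmetric.
    a = jax.random.normal(jax.random.PRNGKey(seed), (n, n))
    w, v = jnp.linalg.eig(a)
    # Column j of v should satisfy a @ v[:, j] == w[j] * v[:, j];
    # `v * w` scales each column j by w[j] via broadcasting.
    residual = jnp.max(jnp.abs(a @ v - v * w))
    return jax.default_backend(), float(residual)


backend, residual = eig_residual()
print(backend, residual)
```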

Could you please install JAX in a fresh environment and check whether the issue persists?
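When comparing environments, it may help to list the installed versions of the JAX distributions alongside the `System info` block. One unconfirmed hypothesis (an assumption, not a diagnosis from this thread) is that a CUDA plugin older than jaxlib 0.4.36 would not register the `cuhybrid_eig_real` custom call. The distribution names below assume a pip install via `jax[cuda12]`, and `installed_versions` is a hypothetical helper built only on the standard library's `importlib.metadata`.

```python
from importlib import metadata

# Assumed distribution names for a pip-installed CUDA 12 build of JAX 0.4.x.
PACKAGES = ["jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt"]


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name}: {version or 'not installed'}")
```

If the plugin versions differ from the `jax`/`jaxlib` versions, reinstalling all of them together in a fresh virtual environment would be a reasonable first step.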

Thank you.
