
Support "triangular-scan" for native lower/upper triangular algorithms #2667

quattro opened this issue Dec 12, 2024 · 3 comments

@quattro

quattro commented Dec 12, 2024

Request description

Hi! First, I want to emphasize how much my lab and I rely on JAX to dramatically improve our scientific software and algorithms; we truly appreciate the incredible work the JAX/XLA teams have done.

I'm curious whether it would be possible to define HLO/XLA primitives that target the nested double loops common to upper/lower triangular matrix algorithms (e.g., backsolve, Levinson-Durbin, etc.). Because shapes must remain static across loop iterations, workarounds for this style of problem usually involve masking. However, provided the original shape is static/known, the shape of each row operation is inferable as well.

For example, given an n x n matrix, the outer loop iterates over i = 1..n, while the inner loop runs over j = 1..i, which is often handled by subsetting/vectorizing the necessary computation.

Is this special case something that could be supported? While it may seem niche, it covers many classical algorithms in linear algebra.
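To make the loop structure concrete, here is a minimal sketch of forward substitution (one of the backsolve-style algorithms mentioned above), written as plain NumPy. The growing inner slice `L[i, :i]` is exactly what clashes with XLA's static-shape requirement; the function name and layout here are illustrative, not from any existing API.

```python
import numpy as np

def forward_solve(L, b):
    # Naive forward substitution for lower-triangular L: the inner
    # dot product's length grows with i, so each row operation has a
    # different (but statically inferable) shape.
    n = L.shape[0]
    x = np.zeros(n)
    for i in range(n):
        x[i] = (b[i] - L[i, :i] @ x[:i]) / L[i, i]
    return x
```

Under JAX today, the ragged slice `L[i, :i]` would typically be replaced by a full-width row plus a mask, doing O(n) work per row regardless of how much of it is live.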

@GleasonK
Member

Hello! I'm a fan of the idea! A few thoughts:

  1. How would this be represented in a framework like JAX? Is the thought that the tensor would be annotated that its data is triangular? Or would this be an operator that only computes on the triangular section of the data, regardless of what's in the other portion?
  2. Such a feature would require compiler support; are you hoping to contribute this? We could likely prototype something quickly using composite ops to unblock exploring an accelerated XLA implementation.
  3. We've been discussing something like a ragged_map op; in theory that could be used to achieve this iteration granularity (depending on the op design, of course).

cc @wsmoses, who I think has talked about something similar to this as well? Let me know if you have any thoughts.

@wsmoses
Contributor

wsmoses commented Dec 13, 2024

So offhand, I was thinking earlier of having a type or operation attribute that specifies whether the data has some known structure. For example, for a matmul, marking the first operand as upper triangular would let one use TRMM (https://netlib.org/lapack/explore-html-3.6.1/d1/d54/group__double__blas__level3_gaf07edfbb2d2077687522652c9e283e1e.html) instead of GEMM. Of course, this is still a spec-level change; it would be up to whatever StableHLO lowers into to leverage the additional information (or not).
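For illustration, the TRMM-vs-GEMM distinction can be demonstrated through SciPy's low-level BLAS bindings (assuming SciPy is available; this is a sketch of the idea, not a proposed StableHLO lowering):

```python
import numpy as np
from scipy.linalg.blas import dtrmm

# If the compiler knew A's first operand was upper triangular, it could
# emit a TRMM call that reads only A's upper triangle, instead of a
# general GEMM over the full (half-zero) matrix.
rng = np.random.default_rng(0)
A = np.triu(rng.standard_normal((4, 4)))  # upper-triangular operand
B = rng.standard_normal((4, 3))

C = dtrmm(1.0, A, B, lower=0)  # side=0 (default): C = alpha * A @ B
```

The result matches `A @ B`, but the triangular routine does roughly half the flops, which is the kind of win a structure attribute could unlock.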

@quattro
Author

quattro commented Dec 16, 2024

Hi all, thanks for the detailed responses!

  1. How would this be represented in a framework like JAX? Is the thought that the tensor would be annotated that its data is triangular? Or would this be an operator that only computes on the triangular section of the data, regardless of what's in the other portion?

My initial thoughts were the latter, resembling something along the lines of tri_scan(x, my_func, init, 'lower') to capture the behavior of:

n = x.shape[0]
carry = init
for i in range(n):
  if mode == 'lower':
    carry = my_func(x[i, 0:i], carry)
  else:
    carry = my_func(x[i, i:n], carry)
return carry

This structure encompasses back/forward solves and other linear-algebraic operations that require inspecting only the upper/lower triangular part of a square matrix.
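For comparison, here is a hedged sketch of how this has to be emulated in JAX today with `jax.lax.scan` and masking: every row keeps its full static width n, and a boolean mask selects the triangular slice. The names `tri_scan_masked` and the three-argument `f(row, mask, carry)` signature are hypothetical, introduced only to mirror the loop above.

```python
import jax
import jax.numpy as jnp

def tri_scan_masked(x, f, init, mode="lower"):
    # Static-shape workaround: the mask reproduces x[i, 0:i] (lower)
    # or x[i, i:n] (upper) from the loop above, but every row still
    # carries its full width through the scan.
    n = x.shape[0]
    idx = jnp.arange(n)

    def step(carry, args):
        i, row = args
        mask = idx < i if mode == "lower" else idx >= i
        return f(row, mask, carry), None

    carry, _ = jax.lax.scan(step, init, (idx, x))
    return carry

# Hypothetical f: accumulate the masked row sum into the carry.
x = jnp.arange(9.0).reshape(3, 3)
total = tri_scan_masked(x, lambda row, m, c: c + jnp.where(m, row, 0.0).sum(), 0.0)
```

A native tri_scan primitive could skip the masked-out O(n^2)/2 work that this version still performs.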

  2. Such a feature would require compiler support; are you hoping to contribute this? We could likely prototype something quickly using composite ops to unblock exploring an accelerated XLA implementation.

While I would love to, I'm afraid I don't have the time to contribute substantially toward this, given my current research/mentoring/administrative duties. I apologize; I realize it's a bit unfair to request a feature while contributing so little toward its realization.

  3. We've been discussing something like a ragged_map op; in theory that could be used to achieve this iteration granularity (depending on the op design, of course).

Yes, absolutely. The case I've outlined above could be cast as a special case of ragged_map.
