Implements FlashDecoding with Sparsity Support (#899)
* Implements FlashDecoding

* require kv_seq_len

* update
hanzhi713 authored Jan 2, 2025
1 parent e4ff72c commit 60ca6ce
Showing 7 changed files with 867 additions and 239 deletions.
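The commit adds a flash_decoding entry point for single-step decoding: it takes a one-position query plus a (possibly padded) KV cache, requires kv_seq_len to mark how much of the cache is valid, and accepts an optional mask_fn for sparsity such as a sliding-window causal mask. Below is a minimal usage sketch, assuming a GPU backend and mirroring the shapes and keyword arguments exercised in gpu_attention_test.py further down; the concrete sizes, the 1/sqrt(head_dim) scale, and the window of 256 are illustrative choices, not values taken from the commit.

import jax
import jax.numpy as jnp

from axlearn.common.attention_bias import sliding_window_causal_mask
from axlearn.common.flash_attention.gpu_decoding import flash_decoding

batch, num_heads, kv_heads, head_dim = 2, 8, 2, 64
kv_seq_len, padding = 1024, 128  # valid cache length vs. extra padded positions
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

# One decoding step: a single query position attends over the padded KV cache.
q = jax.random.normal(k1, (batch, 1, num_heads, head_dim), dtype=jnp.float16)
k = jax.random.normal(k2, (batch, kv_seq_len + padding, kv_heads, head_dim), dtype=jnp.float16)
v = jax.random.normal(k3, (batch, kv_seq_len + padding, kv_heads, head_dim), dtype=jnp.float16)

o = flash_decoding(
    q,
    k,
    v,
    bias=None,
    softmax_scale=head_dim**-0.5,  # illustrative; the test sweeps 1.0 and 0.83
    kv_seq_len=kv_seq_len,  # required: only the first kv_seq_len cache entries are attended to
    mask_fn=sliding_window_causal_mask(256),  # optional sparsity, as in the test
)
# The output keeps the query layout, matching the mha_reference comparison in the test.
assert o.shape == (batch, 1, num_heads, head_dim)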
528 changes: 334 additions & 194 deletions axlearn/common/flash_attention/gpu_attention_benchmark.py

Large diffs are not rendered by default.

94 changes: 90 additions & 4 deletions axlearn/common/flash_attention/gpu_attention_test.py
@@ -17,12 +17,16 @@
import jax
import jax.numpy as jnp
import pytest
from absl.testing import parameterized

from axlearn.common.attention_bias import sliding_window_causal_mask
from axlearn.common.flash_attention.gpu_attention import (
cudnn_dot_product_attention,
flash_attention,
)
from axlearn.common.flash_attention.gpu_decoding import NEG_INF, flash_decoding
from axlearn.common.flash_attention.utils import _repeat_kv_heads, mha_reference
from axlearn.common.test_utils import TestCase

if jax.default_backend() != "gpu":
pytest.skip(reason="Incompatible hardware", allow_module_level=True)
@@ -92,9 +96,91 @@ def impl(q, k, v, bias, segment_ids):
chex.assert_trees_all_close(o, o_ref, atol=0.07)


# We test flash_attention against the reference mha_reference.
# The outputs should be close in both fp16 and fp32, with relaxed bounds due to
# numerical differences in the operations.
class FlashDecodingTest(TestCase):
"""Tests FlashDecoding."""

@parameterized.product(
[
dict(zip(["batch_size", "seq_len", "num_heads", "per_head_dim"], args))
for args in [
(1, 1024, 32, 64),
(1, 444, 16, 64),
(8, 1596, 48, 128),
(8, 4044, 64, 128),
]
],
softmax_scale=[1.0, 0.83],
attention_bias_type=["2d", "4d", None],
input_dtype=[jnp.float32, jnp.float16],
padding=[0, 111],
kv_head_factor=[1, 4, 8],
window_len=[-1, 16, 127],
)
def test_decode_against_ref(
self,
batch_size: int,
seq_len: int,
num_heads: int,
per_head_dim: int,
softmax_scale: float,
attention_bias_type: Literal["2d", "4d", None],
input_dtype: jnp.dtype,
padding: int,
kv_head_factor: int,
window_len: int,
):
self.assertEqual(num_heads % kv_head_factor, 0)
assert num_heads % kv_head_factor == 0
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4)
q = jax.random.normal(k1, (batch_size, 1, num_heads, per_head_dim), dtype=input_dtype)
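# Grouped-query attention: K/V carry num_heads // kv_head_factor heads, and the
# cache is padded beyond the valid kv_seq_len.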
k = jax.random.normal(
k2,
(batch_size, seq_len + padding, num_heads // kv_head_factor, per_head_dim),
dtype=input_dtype,
)
v = jax.random.normal(
k3,
(batch_size, seq_len + padding, num_heads // kv_head_factor, per_head_dim),
dtype=input_dtype,
)

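# Optional attention bias over the padded KV length: "2d" broadcasts across batch
# and heads, "4d" is per batch and head.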
if attention_bias_type == "4d":
bias = jax.random.normal(
k4, (batch_size, num_heads, 1, seq_len + padding), dtype=input_dtype
)
elif attention_bias_type == "2d":
bias = jax.random.normal(k4, (1, 1, 1, seq_len + padding), dtype=input_dtype)
else:
bias = None

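# Sparsity support: mask_fn restricts which cache positions the decoding query may
# attend to, here a sliding causal window.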
mask_fn = None
if window_len > 0:
mask_fn = sliding_window_causal_mask(window_len)
o = flash_decoding(
q, k, v, bias=bias, softmax_scale=softmax_scale, kv_seq_len=seq_len, mask_fn=mask_fn
)
if bias is not None:
bias = bias[:, :, :, :seq_len]
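# Reproduce the sliding-window mask in the reference by setting keys outside the
# window of the single decoding query to NEG_INF.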
if window_len > 0:
if bias is None:
bias = jnp.zeros((1, 1, 1, seq_len), dtype=input_dtype)
bias = bias.at[:, :, :, : -window_len - 1].set(NEG_INF)
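# The reference MHA expects matching Q/KV head counts, so repeat the shared KV heads
# and drop the padded cache positions before comparing.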
o_ref = mha_reference(
q,
_repeat_kv_heads(num_heads, k[:, :seq_len]),
_repeat_kv_heads(num_heads, v[:, :seq_len]),
bias,
None,
causal=False,
softmax_scale=softmax_scale,
)
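# Sanity-check that the reference output has non-trivial magnitude so the absolute
# tolerances below are meaningful.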
self.assertGreaterEqual(jnp.median(jnp.abs(o_ref)).item(), 0.25)
if input_dtype is jnp.float32:
self.assertNestedAllClose(o, o_ref, rtol=0.01, atol=0.01)
else:
self.assertNestedAllClose(o, o_ref, rtol=0.05, atol=0.05)


@pytest.mark.parametrize(
"batch_size,num_heads,seq_len,per_head_dim",
[
