From 60ca6ce5bc0df0ecbe30d1e67e088461f24c988d Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Thu, 2 Jan 2025 09:29:22 -1000 Subject: [PATCH] Implements FlashDecoding with Sparsity Support (#899) * Implements FlashDecoding * require kv_seq_len * update --- .../gpu_attention_benchmark.py | 528 +++++++++++------- .../flash_attention/gpu_attention_test.py | 94 +++- .../common/flash_attention/gpu_decoding.py | 357 ++++++++++++ axlearn/common/flash_attention/layer.py | 20 +- axlearn/common/flash_attention/layer_test.py | 47 +- .../common/flash_attention/tpu_attention.py | 2 +- axlearn/common/flash_attention/utils.py | 58 +- 7 files changed, 867 insertions(+), 239 deletions(-) create mode 100644 axlearn/common/flash_attention/gpu_decoding.py diff --git a/axlearn/common/flash_attention/gpu_attention_benchmark.py b/axlearn/common/flash_attention/gpu_attention_benchmark.py index 695f9fa9b..b7c1d79a5 100644 --- a/axlearn/common/flash_attention/gpu_attention_benchmark.py +++ b/axlearn/common/flash_attention/gpu_attention_benchmark.py @@ -1,235 +1,375 @@ # Copyright © 2023 Apple Inc. +# +# Some of the code in this file is adapted from: +# +# jax-ml/jax: +# Copyright 2023 The JAX Authors. +# Licensed under the Apache License, Version 2.0 (the "License"). +# pylint: disable=line-too-long """FlashAttention kernel benchmarks. -Sample outputs on A100: - -bench-head32-seq1024-d64: - batch_size Jax Jax Triton Pallas PyTorch PyTorch Triton -0 2.0 1.294677 0.308249 0.178258 2.774377 0.198660 -1 4.0 2.337943 0.417789 0.268292 5.449353 0.327930 -2 8.0 4.341445 0.625852 0.587789 10.795019 0.585716 -3 16.0 8.209461 1.210101 0.736935 21.485954 1.101084 -batch2-seq2048-d64: - num_heads Jax Jax Triton Pallas PyTorch PyTorch Triton -0 12.0 1.758803 0.377092 0.292216 4.154821 0.282745 -1 16.0 2.274562 0.437537 0.324150 5.495135 0.334550 -2 32.0 4.111703 0.646520 0.599221 10.850373 0.558403 -3 40.0 5.044082 0.765017 0.676835 13.539443 0.675699 -4 56.0 6.987845 1.035317 0.764956 18.896891 0.900977 -5 72.0 8.709219 1.234217 0.836006 24.280788 1.129773 -batch2-head32-d64: - seq_len Jax Jax Triton Pallas PyTorch PyTorch Triton -0 128.0 0.141517 0.127410 0.274202 0.102813 0.022059 -1 256.0 0.224241 0.128345 0.268729 0.212434 0.042063 -2 512.0 0.426494 0.148231 0.268689 0.770047 0.081487 -3 1024.0 1.298753 0.316831 0.156720 2.774356 0.198292 -4 2048.0 4.031973 0.626030 0.589759 10.851770 0.558741 -batch2-head32-seq2048: - per_head_dim Jax Jax Triton Pallas PyTorch PyTorch Triton -0 16.0 3.858791 0.437569 0.213704 10.560411 0.331421 -1 32.0 3.955815 0.514656 0.261292 10.627665 0.437445 -2 64.0 4.121394 0.636113 0.584212 10.851702 0.558101 -3 128.0 4.439079 0.973939 0.371719 11.237614 0.754280 - -With backward pass: - -grad-head32-seq1024-d64: - batch_size Jax Jax Triton Pallas PyTorch PyTorch Triton -0 2.0 2.848025 1.942416 1.694485 6.299916 1.133216 -1 4.0 5.322315 1.991064 2.380001 12.186511 1.898364 -2 8.0 10.041783 2.945663 4.342529 23.935926 2.966447 -3 16.0 19.633947 5.056746 7.604803 47.527328 5.060277 -grad-batch2-seq2048-d64: - num_heads Jax Jax Triton Pallas PyTorch PyTorch Triton -0 12.0 3.776024 2.353134 3.311361 9.244321 2.498945 -1 16.0 5.015553 2.435218 3.501143 12.162596 2.632030 -2 32.0 9.679915 2.942006 4.506633 24.000723 3.347016 -3 40.0 11.839152 3.409772 5.224550 29.938248 3.646133 -4 56.0 16.324726 5.345623 8.512093 41.617474 5.386331 -5 72.0 20.826162 5.967896 9.429584 53.344563 6.108087 -grad-batch2-head32-d64: - seq_len Jax Jax Triton Pallas PyTorch PyTorch Triton -0 128.0 2.084800 1.946231 1.715087 
1.179843 0.623846
-1 256.0 2.205729 2.110186 1.704993 0.648029 0.340229
-2 512.0 2.259852 2.102726 1.554168 1.883081 0.480496
-3 1024.0 2.859227 2.069817 1.645913 6.291944 1.120009
-4 2048.0 9.606998 2.954207 4.540852 23.863083 2.998823
-grad-batch2-head32-seq2048:
- per_head_dim Jax Jax Triton Pallas PyTorch PyTorch Triton
-0 16.0 9.076053 2.093116 2.735943 23.176823 1.625986
-1 32.0 9.133391 2.438204 3.100990 23.311085 2.156206
-2 64.0 9.490566 3.001139 4.537563 23.846561 3.010464
-3 128.0 10.347649 5.208034 6.846362 24.728912 5.447452
-
-In addition to the dependencies in attention.py, also requires:
-torch==2.1.0.dev20230726+cu121
-pytorch-triton==2.1.0+9e3e10c5ed
+To run: python3 gpu_attention_benchmark.py > out.txt
+Requires Jax >= 0.4.36. Sample numbers on H100 SXM5:
+is_decode=True, use_bwd=False, num_heads=8, num_kv_heads=8, per_head_dim=128, sw_sz=-1
+ jax axlearn jax-cudnn
+bs=1,seq_len=1024 0.020608 0.018656 0.023680
+bs=1,seq_len=4096 0.037856 0.022784 0.056704
+bs=1,seq_len=8192 0.033792 0.032768 0.104448
+bs=1,seq_len=131072 0.227808 0.198816 1.486752
+bs=4,seq_len=1024 0.021440 0.022208 0.024032
+bs=4,seq_len=4096 0.069728 0.054624 0.059584
+bs=4,seq_len=8192 0.081952 0.076064 0.105920
+bs=4,seq_len=131072 0.823104 0.705056 1.488832
+bs=8,seq_len=1024 0.032544 0.030688 0.024608
+bs=8,seq_len=4096 0.089728 0.071648 0.063584
+bs=8,seq_len=8192 0.129184 0.114944 0.109856
+bs=8,seq_len=131072 1.616800 1.376288 1.503360
+bs=16,seq_len=1024 0.050976 0.048608 0.037504
+bs=16,seq_len=4096 0.136768 0.117312 0.104224
+bs=16,seq_len=8192 0.234688 0.200128 0.190944
+bs=16,seq_len=131072 3.211200 2.727040 2.779872
+bs=32,seq_len=1024 0.078656 0.072992 0.061440
+bs=32,seq_len=4096 0.236576 0.204512 0.190752
+bs=32,seq_len=8192 0.443488 0.372352 0.361216
+bs=32,seq_len=131072 6.392320 5.453344 5.495488
+is_decode=True, use_bwd=False, num_heads=8, seq_len=32768, per_head_dim=128, sw_sz=-1
+ jax axlearn jax-cudnn
+bs=1,num_kv_heads=1 0.049280 0.059296 0.378304
+bs=1,num_kv_heads=8 0.076352 0.070912 0.377344
+bs=8,num_kv_heads=1 0.111072 0.080480 0.377696
+bs=8,num_kv_heads=8 0.425536 0.368576 0.386880
+is_decode=True, use_bwd=False, num_heads=8, num_kv_heads=8, per_head_dim=128
+ jax axlearn jax-cudnn
+bs=1,seq_len=131072,sw_sz=-1 0.228640 0.199040 1.476928
+bs=1,seq_len=131072,sw_sz=4096 0.232320 0.053824 4.441376
+bs=1,seq_len=131072,sw_sz=16384 0.233696 0.061120 4.420992
+bs=8,seq_len=131072,sw_sz=-1 1.621696 1.374080 1.496224
+bs=8,seq_len=131072,sw_sz=4096 1.626016 0.193792 4.463296
+bs=8,seq_len=131072,sw_sz=16384 1.628704 0.318176 4.451648
+is_decode=False, use_bwd=False, num_heads=32, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1
+ jax axlearn jax-cudnn jax-pallas
+bs=2 3.502944 0.915360 0.467744 0.845792
+bs=4 6.969376 1.753152 0.890496 1.617280
+bs=8 13.962816 3.415232 1.735232 3.150752
+is_decode=False, use_bwd=False, bs=2, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1
+ jax axlearn jax-cudnn jax-pallas
+num_heads=12 1.262560 0.393536 0.205952 0.362304
+num_heads=16 1.786816 0.498304 0.257664 0.459936
+num_heads=32 3.507488 2.591456 0.468672 2.443296
+num_heads=48 5.246336 1.338272 0.675968 1.231328
+num_heads=72 7.866848 1.961152 0.995712 1.805376
+is_decode=False, use_bwd=False, bs=2, num_heads=32, num_kv_heads=None, per_head_dim=128, sw_sz=-1
+ jax axlearn jax-cudnn jax-pallas
+seq_len=128 0.030592 0.011584 0.013024 0.012960
+seq_len=256 0.051520 0.015648 0.016640 0.015744
+seq_len=512 0.118720 0.038976 0.028224 0.037152
+seq_len=1024 0.310880
0.096256 0.054784 0.090368 +seq_len=2048 0.931072 0.277312 0.150784 0.256928 +seq_len=4096 3.516672 2.595872 0.465568 2.448128 +is_decode=False, use_bwd=False, bs=2, num_heads=32, num_kv_heads=None, seq_len=4096, sw_sz=-1 + jax axlearn jax-cudnn jax-pallas +per_head_dim=16 3.220960 0.487808 0.332928 0.478720 +per_head_dim=32 3.277824 0.530240 0.334624 0.515040 +per_head_dim=64 3.345376 0.696480 0.338944 0.631296 +per_head_dim=128 3.515616 2.594208 0.465824 2.442784 +is_decode=False, use_bwd=True, num_heads=32, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1 + jax axlearn jax-cudnn jax-pallas +bs=2 10.780096 4.573344 2.080672 4.487104 +bs=4 21.426336 9.336192 3.988224 9.159904 +bs=8 42.808033 18.926559 7.975296 18.075487 +is_decode=False, use_bwd=True, bs=2, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1 + jax axlearn jax-cudnn jax-pallas +num_heads=12 4.128352 1.738016 0.882976 1.696704 +num_heads=16 5.467808 2.307488 1.120608 2.247904 +num_heads=32 10.782432 4.559456 2.082592 4.488448 +num_heads=48 16.119776 6.958272 3.027808 6.858144 +num_heads=72 24.140833 10.706656 4.560288 10.279136 +is_decode=False, use_bwd=True, bs=2, num_heads=32, num_kv_heads=None, per_head_dim=128, sw_sz=-1 + jax axlearn jax-cudnn jax-pallas +seq_len=128 0.058944 0.037824 0.039040 0.036384 +seq_len=256 0.100384 0.069024 0.052608 0.067872 +seq_len=512 0.317056 0.159904 0.111840 0.158912 +seq_len=1024 0.906400 0.431104 0.244160 0.421792 +seq_len=2048 2.861056 1.319648 0.655840 1.297728 +seq_len=4096 10.762560 4.576864 2.079904 4.489056 +is_decode=False, use_bwd=True, bs=2, num_heads=32, num_kv_heads=None, seq_len=4096, sw_sz=-1 + jax axlearn jax-cudnn jax-pallas +per_head_dim=16 10.084800 1.744640 1.263264 1.711296 +per_head_dim=32 10.204480 2.098816 1.291104 2.041184 +per_head_dim=64 10.374720 2.649888 1.335200 2.510304 +per_head_dim=128 10.779680 4.568096 2.079264 4.489792 """ -# pylint: skip-file +# pylint: enable=line-too-long +import itertools +from functools import partial +from typing import Any, Optional, Protocol, Union import jax import jax.numpy as jnp -import triton # pytype: disable=import-error +from jax.experimental.mosaic.gpu.profiler import _event_elapsed, _event_record, has_registrations from jax.experimental.pallas.ops.gpu.attention import mha as pallas_mha +from axlearn.common.attention_bias import sliding_window_causal_mask from axlearn.common.flash_attention.gpu_attention import ( cudnn_dot_product_attention, flash_attention, ) +from axlearn.common.flash_attention.gpu_decoding import Tensor, flash_decoding from axlearn.common.flash_attention.utils import mha_reference +X = jnp.zeros((8192, 8192)) +Y = jnp.zeros((8192, 8192)) -def _perf_report(prefix: str): - # 128 is the most common value for per_head_dim. - batch_size, num_heads, seq_len, per_head_dim = 2, 32, 2048, 128 - - # Vary batch size for fixed heads and seq length. - batch_size_bench = triton.testing.Benchmark( - x_names=["batch_size"], - x_vals=[2, 4, 8, 16], - line_arg="library", - line_vals=["jax", "jax-triton", "jax-pallas", "jax-cudnn"], - line_names=["Jax", "Jax Triton", "Pallas", "jax-cudnn"], - styles=[("blue", "-"), ("purple", "-"), ("green", "-"), ("red", "-")], - ylabel="ms", - plot_name=f"{prefix}-head{num_heads}-seq1024-d{per_head_dim}", - args={"num_heads": num_heads, "seq_len": 1024, "per_head_dim": per_head_dim}, - ) - # Vary num heads for fixed batch and seq length. 
- num_heads_bench = triton.testing.Benchmark( - x_names=["num_heads"], - x_vals=[12, 16, 32, 48, 72], - line_arg="library", - line_vals=["jax", "jax-triton", "jax-pallas", "jax-cudnn"], - line_names=["Jax", "Jax Triton", "Pallas", "jax-cudnn"], - styles=[("blue", "-"), ("purple", "-"), ("green", "-"), ("red", "-")], - ylabel="ms", - plot_name=f"{prefix}-batch{batch_size}-seq{seq_len}-d{per_head_dim}", - args={"batch_size": batch_size, "seq_len": seq_len, "per_head_dim": per_head_dim}, - ) - # Vary seq length for fixed heads and batch size. - seq_len_bench = triton.testing.Benchmark( - x_names=["seq_len"], - x_vals=[2**i for i in range(7, 12)], # 128 to 2048. - line_arg="library", - line_vals=["jax", "jax-triton", "jax-pallas", "jax-cudnn"], - line_names=["Jax", "Jax Triton", "Pallas", "jax-cudnn"], - styles=[("blue", "-"), ("purple", "-"), ("green", "-"), ("red", "-")], - ylabel="ms", - plot_name=f"{prefix}-batch{batch_size}-head{num_heads}-d{per_head_dim}", - args={"batch_size": batch_size, "num_heads": num_heads, "per_head_dim": per_head_dim}, - ) - # Vary per head dim for fixed batch and seq length. - per_head_dim_bench = triton.testing.Benchmark( - x_names=["per_head_dim"], - x_vals=[16, 32, 64, 128], - line_arg="library", - line_vals=["jax", "jax-triton", "jax-pallas", "jax-cudnn"], - line_names=["Jax", "Jax Triton", "Pallas", "jax-cudnn"], - styles=[("blue", "-"), ("purple", "-"), ("green", "-"), ("red", "-")], - ylabel="ms", - plot_name=f"{prefix}-batch{batch_size}-head{num_heads}-seq{seq_len}", - args={"batch_size": batch_size, "num_heads": num_heads, "seq_len": seq_len}, - ) - return triton.testing.perf_report( - [batch_size_bench, num_heads_bench, seq_len_bench, per_head_dim_bench] - ) +BenchFnResult = Union[tuple[Tensor], Tensor] -@_perf_report("fwd") -def bench_flash_attention( - batch_size: int, num_heads: int, seq_len: int, per_head_dim: int, library: str -): - warmup = 25 - rep = 500 - - if library.startswith("jax"): - q = jax.random.normal( - jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.float16 - ) - k = jax.random.normal( - jax.random.PRNGKey(1), (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.float16 - ) - v = jax.random.normal( - jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.float16 - ) - # Bias is not supported in pallas, so we don't include it here. - bias = None - - if "triton" in library: - fn = lambda: flash_attention(q, k, v, bias, causal=True) - elif "pallas" in library: - fn = lambda: pallas_mha(q, k, v, segment_ids=None, causal=True) - elif "cudnn" in library: - fn = lambda: cudnn_dot_product_attention(q, k, v, bias=bias, causal=True) - else: - fn = lambda: mha_reference(q, k, v, bias, causal=True) - ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) +class BenchFn(Protocol): + def __call__(self, *args: Tensor) -> BenchFnResult: + ... - else: - raise ValueError(f"Unsupported: {library}") - return ms +class SweepFn(Protocol): + def __call__(self, library: str, *args: Any, **kwargs: Any) -> tuple[BenchFnResult, float]: + ... + + +def measure(f: BenchFn, *args: Tensor) -> tuple[Tensor, float]: + """Measures the time it takes to execute the function on the GPU. + + This function is modified from + https://github.com/jax-ml/jax/blob/978d35f69704ce95a9d792f9ca9c7e3ee356417f/jax/experimental/mosaic/gpu/profiler.py#L72 + to support measuring fast kernels more accurately. 
This is done by queueing expensive GEMM + kernels before the benchmarked function is launched to avoid including dispatch and kernel + launch overhead to cuda event. -@_perf_report("grad") -def bench_flash_attention_backward( - batch_size: int, num_heads: int, seq_len: int, per_head_dim: int, library: str + Args: + f: The function to measure. It must accept at least one argument and return + at least one output to be measurable. + *args: The arguments to pass to ``f``. + **kwargs: The keyword arguments to pass to ``f``. + + Returns: + The return value of ``f`` and the elapsed time in milliseconds. + """ + if not has_registrations: + raise RuntimeError("This function requires jaxlib >=0.4.36 with CUDA support.") + + if not args: + # We require at least one argument and at least one output to ensure + # that there is a data dependency between `_event_record` calls in + # the resulting HLO program. + raise ValueError("Can only measure functions with arguments") + + @jax.jit + def run(*args): + start_event, args = _event_record(args, copy_before=True) + end_event, outs = _event_record(f(*args), copy_before=False) + if jax.tree.structure(outs).num_leaves == 0: + raise ValueError("Can only measure functions with at least one output") + return outs, _event_elapsed(start_event, end_event) + + jax.block_until_ready(run(*args)) # Warmup. + # Queue some expensive kernels into the stream to make events more accurate. + for _ in range(2): + _ = X @ Y + outs, elapsed = run(*args) + return outs, float(elapsed) + + +def bench_flash_attention( + library: str, + bs: int, + num_heads: int, + num_kv_heads: Optional[int], + seq_len: int, + per_head_dim: int, + is_decode: bool, + use_bwd: bool, + sw_sz: int = -1, ): - warmup = 25 - rep = 500 - - if library.startswith("jax"): - q = jax.random.normal( - jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.float16 - ) - k = jax.random.normal( - jax.random.PRNGKey(1), (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.float16 - ) - v = jax.random.normal( - jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.float16 - ) - # Bias is not supported in pallas, so we don't include it here. - bias = None - - if "triton" in library: + if use_bwd and is_decode: + raise ValueError("use_bwd and is_decode cannot both be true.") + + if num_kv_heads is None: + num_kv_heads = num_heads + q_seq_len = 1 if is_decode else seq_len + if is_decode: + if "pallas" in library: + q_seq_len = 16 # min supported seq length for triton and pallas + else: + q_seq_len = 1 + else: + q_seq_len = seq_len + q = jax.random.normal( + jax.random.PRNGKey(0), + (bs, q_seq_len, num_heads, per_head_dim), + dtype=jnp.float16, + ) + k = jax.random.normal( + jax.random.PRNGKey(1), (bs, seq_len, num_kv_heads, per_head_dim), dtype=jnp.float16 + ) + v = jax.random.normal( + jax.random.PRNGKey(2), (bs, seq_len, num_kv_heads, per_head_dim), dtype=jnp.float16 + ) + # Bias is not supported in pallas, so we don't include it here. + bias = None + if sw_sz != -1: + mask_fn = sliding_window_causal_mask(sw_sz) + assert bias is None + bias = jnp.zeros((1, 1, 1, seq_len), dtype=jnp.float16) + bias = bias.at[:, :, :, :-sw_sz].set(jnp.finfo(jnp.float16).min) + else: + mask_fn = None + args = (q, k, v, bias) + if "axlearn" in library: + if use_bwd: @jax.jit - def test_fn(q, k, v, bias): - return flash_attention(q, k, v, bias, causal=True).sum() + def triton_fn(q, k, v, bias): + # Use mean rather than sum so that gradients won't overflow. 
+ return flash_attention(q, k, v, bias, causal=True).mean() - test_bwd = jax.grad(test_fn, argnums=(0, 1, 2)) - fn = lambda: test_bwd(q, k, v, bias) - elif "pallas" in library: + fn = jax.grad(triton_fn, argnums=(0, 1, 2)) + else: + if q_seq_len == 1: + fn = partial(flash_decoding, kv_seq_len=None, mask_fn=mask_fn) + args = (q, k, v) + else: + fn = partial(flash_attention, causal=True) + elif "pallas" in library: + k = k.repeat(num_heads // num_kv_heads, axis=2) + v = v.repeat(num_heads // num_kv_heads, axis=2) + args = (q, k, v) + if use_bwd: @jax.jit def pallas_fn(q, k, v): - return pallas_mha(q, k, v, segment_ids=None, causal=True).sum() + return pallas_mha(q, k, v, segment_ids=None, causal=True).mean() - pallas_bwd = jax.grad(pallas_fn, argnums=(0, 1, 2)) - fn = lambda: pallas_bwd(q, k, v) - elif "cudnn" in library: + fn = jax.grad(pallas_fn, argnums=(0, 1, 2)) + else: + fn = partial(pallas_mha, segment_ids=None, causal=not is_decode) + elif "cudnn" in library: + k = k.repeat(num_heads // num_kv_heads, axis=2) + v = v.repeat(num_heads // num_kv_heads, axis=2) + if use_bwd: @jax.jit - def cudnn_fn(q, k, v): - return cudnn_dot_product_attention(q, k, v, bias=bias, causal=True).sum() + def cudnn_fn(q, k, v, bias): + return cudnn_dot_product_attention(q, k, v, bias=bias, causal=True).mean() - cudnn_bwd = jax.grad(cudnn_fn, argnums=(0, 1, 2)) - fn = lambda: cudnn_bwd(q, k, v) + fn = jax.grad(cudnn_fn, argnums=(0, 1, 2)) else: + fn = partial(cudnn_dot_product_attention, causal=not is_decode) + else: + k = k.repeat(num_heads // num_kv_heads, axis=2) + v = v.repeat(num_heads // num_kv_heads, axis=2) + if use_bwd: @jax.jit def ref_fn(q, k, v, bias): - return mha_reference(q, k, v, bias, causal=True).sum() + return mha_reference(q, k, v, bias, causal=True).mean() + + fn = jax.grad(ref_fn, argnums=(0, 1, 2)) + else: + fn = partial(mha_reference, causal=not is_decode) - ref_bwd = jax.grad(ref_fn, argnums=(0, 1, 2)) - fn = lambda: ref_bwd(q, k, v, bias) - ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return measure(fn, *args) - else: - raise ValueError(f"Unsupported: {library}") - return ms + +def _sweep( + fn: SweepFn, libraries: list[str], common_kwargs: dict[str, Any], **_sweep_kwargs: list[Any] +): + """Benchmarks `fn` by sweeping through combinations of parameters. + + Args: + fn: The function to benchmark. + libraries: Libraries to benchmark. + common_kwargs: kwargs (k=v) that stays unchanged in the sweep. + sweep_kwargs: kwargs (k=[...]) that will be pass in as cartesian products to `fn`. + """ + sweep_kwargs: dict[str, list[tuple[str, Any]]] = {} + for k, args in _sweep_kwargs.items(): + common_kwargs.pop(k, None) + sweep_kwargs[k] = [(k, v) for v in args] + + # Simple sanity check for results. + def check_fn(result, ref_result): + if not jax.numpy.allclose(result, ref_result, atol=0.1): + raise ValueError( + f"{library} not equal to jax reference. Args: {common_kwargs} {bench_kwargs}." + f"Diff: {result - ref_result}" + ) + + results = [] + ref_result = None + for comb in itertools.product(*sweep_kwargs.values()): + bench_kwargs = dict(comb) + bench_key = ",".join(f"{k}={v}" for k, v in comb) + lib_results = [bench_key] + for i, library in enumerate(libraries): + result, t = fn(library=library, **bench_kwargs, **common_kwargs) + if i == 0: + ref_result = result + elif i > 0 and library != "jax-pallas": + jax.tree.map(check_fn, result, ref_result) + lib_results.append(t) + results.append(lib_results) + + # Header. 
+ print(", ".join(f"{k}={v}" for k, v in common_kwargs.items())) + print(("{:<40}" + "{:<14}" * len(libraries)).format("", *libraries)) + # Result rows. + format_str = "{:<40}" + "{:<14.6f}" * len(libraries) + for lib_results in results: + print(format_str.format(*lib_results)) + + +def benchmark_sweep(libraries: list[str], common_kwargs: dict[str, Any], **sweep_args: list[Any]): + _sweep(bench_flash_attention, libraries, common_kwargs.copy(), **sweep_args) + + +def benchmark_decode(): + libraries = ["jax", "axlearn", "jax-cudnn"] + common_kwargs = dict( + is_decode=True, + use_bwd=False, + bs=1, + num_heads=8, + num_kv_heads=8, + seq_len=32 * 1024, + per_head_dim=128, + sw_sz=-1, + ) + benchmark_sweep( + libraries, common_kwargs, bs=[1, 4, 8, 16, 32], seq_len=[1024, 4096, 8192, 128 * 1024] + ) + benchmark_sweep(libraries, common_kwargs, bs=[1, 8], num_kv_heads=[1, 8]) + benchmark_sweep( + libraries, common_kwargs, bs=[1, 8], seq_len=[128 * 1024], sw_sz=[-1, 4096, 16 * 1024] + ) + + +def bench_flash_attention_fwd_bwd(use_bwd: bool): + common_kwargs = dict( + is_decode=False, + use_bwd=use_bwd, + bs=2, + num_heads=32, + num_kv_heads=None, + seq_len=4096, + per_head_dim=128, + sw_sz=-1, + ) + libraries = ["jax", "axlearn", "jax-cudnn", "jax-pallas"] + benchmark_sweep(libraries, common_kwargs, bs=[2, 4, 8]) + benchmark_sweep(libraries, common_kwargs, num_heads=[12, 16, 32, 48, 72]) + # 128 to 4096. + benchmark_sweep(libraries, common_kwargs, seq_len=[int(2**i) for i in range(7, 13)]) + benchmark_sweep(libraries, common_kwargs, per_head_dim=[16, 32, 64, 128]) -bench_flash_attention.run(save_path=".", print_data=True) -bench_flash_attention_backward.run(save_path=".", print_data=True) +benchmark_decode() +bench_flash_attention_fwd_bwd(False) +bench_flash_attention_fwd_bwd(True) diff --git a/axlearn/common/flash_attention/gpu_attention_test.py b/axlearn/common/flash_attention/gpu_attention_test.py index 901f9bf5b..9a1e8561a 100644 --- a/axlearn/common/flash_attention/gpu_attention_test.py +++ b/axlearn/common/flash_attention/gpu_attention_test.py @@ -17,12 +17,16 @@ import jax import jax.numpy as jnp import pytest +from absl.testing import parameterized +from axlearn.common.attention_bias import sliding_window_causal_mask from axlearn.common.flash_attention.gpu_attention import ( cudnn_dot_product_attention, flash_attention, ) -from axlearn.common.flash_attention.utils import mha_reference +from axlearn.common.flash_attention.gpu_decoding import NEG_INF, flash_decoding +from axlearn.common.flash_attention.utils import _repeat_kv_heads, mha_reference +from axlearn.common.test_utils import TestCase if jax.default_backend() != "gpu": pytest.skip(reason="Incompatible hardware", allow_module_level=True) @@ -92,9 +96,91 @@ def impl(q, k, v, bias, segment_ids): chex.assert_trees_all_close(o, o_ref, atol=0.07) -# We test the flash_attention against the reference mha_reference. -# The outputs should be close in both fp16 and fp32, with a relaxed bound due -# to the numerical difference during operations. 
+class FlashDecodingTest(TestCase): + """Tests FlashDecoding.""" + + @parameterized.product( + [ + dict(zip(["batch_size", "seq_len", "num_heads", "per_head_dim"], args)) + for args in [ + (1, 1024, 32, 64), + (1, 444, 16, 64), + (8, 1596, 48, 128), + (8, 4044, 64, 128), + ] + ], + softmax_scale=[1.0, 0.83], + attention_bias_type=["2d", "4d", None], + input_dtype=[jnp.float32, jnp.float16], + padding=[0, 111], + kv_head_factor=[1, 4, 8], + window_len=[-1, 16, 127], + ) + def test_decode_against_ref( + self, + batch_size: int, + seq_len: int, + num_heads: int, + per_head_dim: int, + softmax_scale: float, + attention_bias_type: Literal["2d", "4d", None], + input_dtype: jnp.dtype, + padding: int, + kv_head_factor: int, + window_len: int, + ): + self.assertEqual(num_heads % kv_head_factor, 0) + assert num_heads % kv_head_factor == 0 + k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4) + q = jax.random.normal(k1, (batch_size, 1, num_heads, per_head_dim), dtype=input_dtype) + k = jax.random.normal( + k2, + (batch_size, seq_len + padding, num_heads // kv_head_factor, per_head_dim), + dtype=input_dtype, + ) + v = jax.random.normal( + k3, + (batch_size, seq_len + padding, num_heads // kv_head_factor, per_head_dim), + dtype=input_dtype, + ) + + if attention_bias_type == "4d": + bias = jax.random.normal( + k4, (batch_size, num_heads, 1, seq_len + padding), dtype=input_dtype + ) + elif attention_bias_type == "2d": + bias = jax.random.normal(k4, (1, 1, 1, seq_len + padding), dtype=input_dtype) + else: + bias = None + + mask_fn = None + if window_len > 0: + mask_fn = sliding_window_causal_mask(window_len) + o = flash_decoding( + q, k, v, bias=bias, softmax_scale=softmax_scale, kv_seq_len=seq_len, mask_fn=mask_fn + ) + if bias is not None: + bias = bias[:, :, :, :seq_len] + if window_len > 0: + if bias is None: + bias = jnp.zeros((1, 1, 1, seq_len), dtype=input_dtype) + bias = bias.at[:, :, :, : -window_len - 1].set(NEG_INF) + o_ref = mha_reference( + q, + _repeat_kv_heads(num_heads, k[:, :seq_len]), + _repeat_kv_heads(num_heads, v[:, :seq_len]), + bias, + None, + causal=False, + softmax_scale=softmax_scale, + ) + self.assertGreaterEqual(jnp.median(jnp.abs(o_ref)).item(), 0.25) + if input_dtype is jnp.float32: + self.assertNestedAllClose(o, o_ref, rtol=0.01, atol=0.01) + else: + self.assertNestedAllClose(o, o_ref, rtol=0.05, atol=0.05) + + @pytest.mark.parametrize( "batch_size,num_heads,seq_len,per_head_dim", [ diff --git a/axlearn/common/flash_attention/gpu_decoding.py b/axlearn/common/flash_attention/gpu_decoding.py new file mode 100644 index 000000000..90e795b3f --- /dev/null +++ b/axlearn/common/flash_attention/gpu_decoding.py @@ -0,0 +1,357 @@ +# Copyright © 2024 Apple Inc. +# +# Some of the code in this file is adapted from: +# +# jax-ml/jax: +# Copyright 2023 The JAX Authors. +# Licensed under the Apache License, Version 2.0 (the "License"). + +"""Implements FlashDecoding. + +Reference: https://pytorch.org/blog/flash-decoding. +TLDR: FlashDecoding addresses the issue of SM under-utilization during decoding when batch size is +small and kv sequence length is long by parallelizing over the kv sequence length dimension. Each +thread block handles a chunk of the kv sequence, writing softmax residuals to HBM. The outputs +from each block are then combined and rescaled using these residuals to get the final output. 
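+
+As a sketch of the combine step (simplified from `_decode_attn_unbatched` below; here o_j, m_j,
+and l_j denote the unscaled partial output, running max, and running sum of the j-th kv chunk):
+
+    m = max_j(m_j)
+    out = sum_j(o_j * exp(m_j - m)) / sum_j(l_j * exp(m_j - m))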
+
+This file is adapted from
+https://github.com/jax-ml/jax/blob/861115ad4bf0f57e53f61d4d083cd2bda6877ab5/jax/experimental/pallas/ops/gpu/decode_attention.py,
+but is heavily modified to improve performance and add support for bias and MaskFn:
+1. The Jax implementation uses double vmap to parallelize over batch and num_kv_heads. This
+   requires an axis permute for k and v, resulting in a transpose kernel that doubles the
+   execution time of the decoding kernel. To remove this transpose, we only vmap the batch
+   dimension and add an additional dimension to the Pallas BlockSpec to let Pallas handle the
+   strided k and v load.
+2. Added support for attention bias.
+3. Added support for MaskFn. The MaskFn makes it possible to support sparse attention such as
+   sliding window attention or global-local attention without materializing the mask as attention
+   bias. The kernel can take advantage of sparsity by skipping fully-masked blocks, significantly
+   improving performance. Note that we do not materialize a compile-time mask. Instead, we rely
+   on runtime voting in thread blocks. This leads to simpler code and faster compilation at the
+   cost of a small runtime overhead (at the microsecond level), which is acceptable.
+
+Performance note (see gpu_attention_benchmark.py for measured numbers):
+No sparsity:
+1. FlashDecoding is faster than XLA across the board by some margin (5%~20%).
+2. FlashDecoding is significantly faster than cudnn for long context when bs * num_kv_heads is
+   small.
+3. FlashDecoding is slightly slower (~5%) than cudnn when bs * num_kv_heads is large.
+With sparsity such as sliding window attention, FlashDecoding is a few times faster. The
+performance gain is proportional to the total context length divided by the window size.
+"""
+from __future__ import annotations
+
+import functools
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+from jax import lax
+from jax._src.cudnn.fused_attention_stablehlo import check_compute_capability
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu
+
+from axlearn.common.attention import NEG_INF, MaskFn, Tensor
+
+
+# Note: split_k_seq_len must be a multiple of block_k.
+def _attn_forward_kernel(
+    # Inputs:
+    q_ref,  # [block_h, head_dim]
+    k_ref,  # [split_k_seq_len, head_dim]
+    v_ref,  # [split_k_seq_len, head_dim]
+    bias_ref,  # [block_h, split_k_seq_len]
+    kv_seq_len_ref,  # [] (i.e., scalar)
+    # Outputs:
+    o_ref,  # [block_h, head_dim]
+    l_ref,  # [block_h,]
+    m_ref,  # [block_h,]
+    # Non-tensors:
+    mask_fn: Optional[MaskFn],
+    softmax_scale: float,
+    block_k: int,
+    block_h: int,
+    qhead_per_kvhead: int,
+):
+    _, head_dim = q_ref.shape
+    split_k_seq_len, _ = k_ref.shape
+    prog_i, prog_j = pl.program_id(1), pl.program_id(2)
+    q_mask = (block_h * prog_i + jnp.arange(block_h) < qhead_per_kvhead)[:, None]
+
+    def _compute(block_kv_start_idx, block_kv_seqlen, o, m_i, l_i):
+        # Load q: it will stay in L1 throughout. Indices form a matrix because we
+        # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
+        # q tile has shape [block_h, head_dim].
+        q = pl.load(q_ref, (slice(None), slice(None)), mask=q_mask) * softmax_scale
+
+        mask_indices = jnp.arange(block_k)
+
+        # Loop over blocks of kv to process entire kv seq_len.
+        def body(start_k, carry):
+            o_prev, m_prev, l_prev = carry
+
+            indices = block_kv_start_idx + start_k * block_k + mask_indices
+            # This mask guards against out-of-bound values only.
+ mask = indices < block_kv_seqlen + logits_mask = mask if mask_fn is None else mask_fn(kv_seq_len - 1, indices) & mask + + def compute(): + curr_k_slice = pl.ds(start_k * block_k, block_k) + k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=mask[:, None], other=0.0) + qk = pl.dot(q, k.T, allow_tf32=False) # [block_h, block_k] + if bias_ref is not None: + qk += pl.load( + bias_ref, (slice(None), curr_k_slice), mask=mask[None, :], other=0.0 + ) + + qk = jnp.where(logits_mask[None, :], qk, NEG_INF) + + m_curr = qk.max(axis=-1) + m_next = jnp.maximum(m_prev, m_curr) + correction = jnp.exp(m_prev - m_next) + l_prev_corr = correction * l_prev + # Use m_next instead of m_curr to avoid a correction on l_curr. + s_curr = jnp.exp(qk - m_next[:, None]) + l_curr = s_curr.sum(axis=-1) + l_next = l_prev_corr + l_curr + v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=mask[:, None], other=0.0) + o_curr = pl.dot(s_curr.astype(v.dtype), v, allow_tf32=False) + + # Flash2 unscaled_o. + o_next = correction[:, None] * o_prev + o_curr + return o_next, m_next, l_next + + def no_compute(): + return carry + + # Skip this block if this block is fully masked. Note: loading V is skipped. This + # basically means that we assume qk is not fully masked across the entire kv_seq_len. + # Note: cannot use jnp.all as reduce_and is not implemented in pallas/triton. + return lax.cond(jnp.sum(logits_mask) > 0, compute, no_compute) + + max_it = pl.cdiv(block_kv_seqlen - block_kv_start_idx, block_k) + (o, m_i, l_i) = lax.fori_loop(0, max_it, body, (o, m_i, l_i)) + return o, m_i, l_i + + # o is the buffer where we accumulate the output on sram. + # m_i and l_i (see FlashAttention2 paper) are updated during the k,v loop. + m_i = jnp.full(block_h, NEG_INF, dtype=jnp.float32) + l_i = jnp.zeros(block_h, dtype=jnp.float32) + o = jnp.zeros((block_h, head_dim), dtype=jnp.float32) + + block_kv_start_idx = prog_j * split_k_seq_len + kv_seq_len = pl.load(kv_seq_len_ref, ()) + block_kv_seqlen = jnp.minimum((prog_j + 1) * split_k_seq_len, kv_seq_len) + + # Skip padding in seq dim. + o, m_i, l_i = jax.lax.cond( + block_kv_start_idx >= kv_seq_len, + lambda: (o, m_i, l_i), + lambda: _compute(block_kv_start_idx, block_kv_seqlen, o, m_i, l_i), + ) + + # Write output to HBM. + vec_q_mask = q_mask.reshape(-1) + pl.store(l_ref, slice(None), l_i, mask=vec_q_mask) + pl.store(m_ref, slice(None), m_i, mask=vec_q_mask) + pl.store(o_ref, (slice(None), slice(None)), o, mask=q_mask) + + +def _get_sm_count() -> int: + """Returns number of SMs for the current GPU or 0 if unknown.""" + if check_compute_capability("9.0"): # H100 + return 132 + if check_compute_capability("8.9"): # L4, L40 + return 0 + if check_compute_capability("8.6"): # A40, A10 + return 0 + # This assumes we're not using A30. + if check_compute_capability("8.0"): # A100, A30 + return 108 + return 0 + + +def _decode_attn_unbatched( + q, # [kv_heads, qhead_per_kvhead, head_dim] + k, # [k_seq_len, kv_heads, head_dim] + v, # [k_seq_len, kv_heads, head_dim] + bias, # [kv_heads, qhead_per_kvhead, k_seq_len] + kv_seq_len, # [] + softmax_scale: float, + mask_fn: Optional[MaskFn], + block_h: int, + block_k: int, + num_warps: int, + num_stages: int, + interpret: bool, + debug: bool, + batch_size: int, +): + num_kvheads, qhead_per_kvhead, head_dim = q.shape + padded_kv_seq_len = k.shape[0] + head_splits = pl.cdiv(qhead_per_kvhead, block_h) + # Calculate the intiial k_splits. 
Cap the k_splits at 16, but increase it if batch_size * + # qhead_per_kvhead * 16 cannot fully utilize GPU and seqlen is long. Each block has 4 wraps. + # Each SM can hold at least two of these blocks according to the smem usage. + good_k_split_for_sm_util = _get_sm_count() // (batch_size * qhead_per_kvhead) + k_splits = min(max(good_k_split_for_sm_util, 16), pl.cdiv(padded_kv_seq_len, block_k)) + split_k_seq_len = pl.cdiv(padded_kv_seq_len, k_splits) + # Round up to a multiple of block_k. + split_k_seq_len = pl.cdiv(split_k_seq_len, block_k) * block_k + k_splits = pl.cdiv(padded_kv_seq_len, split_k_seq_len) + + grid = (num_kvheads, head_splits, k_splits) + block_k = min(block_k, split_k_seq_len) + kernel = functools.partial( + _attn_forward_kernel, + softmax_scale=softmax_scale, + block_k=block_k, + block_h=block_h, + qhead_per_kvhead=qhead_per_kvhead, + mask_fn=mask_fn, + ) + + o, l, m = pl.pallas_call( + kernel, + grid=grid, + in_specs=[ + # kv_h = axis along num kv heads. + # q_h = axis along q heads per kv head. + # k = axis along kv sequence length. + pl.BlockSpec((None, block_h, head_dim), lambda kv_h, q_h, k: (kv_h, q_h, 0)), + pl.BlockSpec((split_k_seq_len, None, head_dim), lambda kv_h, q_h, k: (k, kv_h, 0)), + pl.BlockSpec((split_k_seq_len, None, head_dim), lambda kv_h, q_h, k: (k, kv_h, 0)), + ] + + [ + ( + None + if bias is None + else pl.BlockSpec( + (None, block_h, split_k_seq_len), lambda kv_h, q_h, k: (kv_h, q_h, k) + ) + ) + ] + + [pl.BlockSpec((), lambda kv_h, q_h, k: ())], + out_specs=[ + pl.BlockSpec( + (None, None, block_h, head_dim), lambda kv_h, q_h, k: (kv_h, k, q_h, 0) + ), # o + pl.BlockSpec((None, None, block_h), lambda kv_h, q_h, k: (kv_h, k, q_h)), # l + pl.BlockSpec((None, None, block_h), lambda kv_h, q_h, k: (kv_h, k, q_h)), # m + ], + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps, num_stages=num_stages), + out_shape=[ + jax.ShapeDtypeStruct( + shape=(num_kvheads, k_splits, *q.shape[1:]), dtype=jnp.float32 + ), # o + jax.ShapeDtypeStruct( + shape=(num_kvheads, k_splits, qhead_per_kvhead), dtype=jnp.float32 + ), # l + jax.ShapeDtypeStruct( + shape=(num_kvheads, k_splits, qhead_per_kvhead), dtype=jnp.float32 + ), # m + ], + debug=debug, + interpret=interpret, + name="flash_decoding_forward", + )(q, k, v, bias, kv_seq_len) + + # Combine the results from blocks into final output. + m_next = m.max(axis=1, keepdims=True) # [num_kv_heads, 1, qhead_per_kvhead] + correction = jnp.exp(m - m_next) # [num_kv_heads, k_splits, qhead_per_kvhead] + o = o * correction[..., None] # [num_kv_heads, k_splits, qhead_per_kvhead, head_dim] + l_next = (l * correction).sum(axis=1) # [num_kv_heads, qhead_per_kvhead] + o = o.sum(axis=1) / (l_next[..., None] + jnp.finfo(l_next.dtype).eps) + return o.astype(q.dtype) + + +@functools.partial( + jax.jit, + static_argnames=[ + "softmax_scale", + "block_h", + "block_k", + "num_warps", + "num_stages", + "interpret", + "debug", + "mask_fn", + ], +) +def flash_decoding( + q: Tensor, + k: Tensor, + v: Tensor, + kv_seq_len: Optional[Tensor], + *, + bias: Optional[Tensor] = None, + softmax_scale: float = 1.0, + mask_fn: Optional[MaskFn] = None, + block_h: int = 16, + block_k: int = 128, + num_warps: int = 4, + num_stages: int = 2, + interpret: bool = False, + debug: bool = False, +): + """Runs flash decoding with GQA support. + + Args: + q: Tensor of shape [batch_size, 1, num_q_heads, head_dim]. + k: Tensor of shape [batch_size, padded_kv_seq_len, num_kv_heads, head_dim]. 
+        v: Tensor of shape [batch_size, padded_kv_seq_len, num_kv_heads, head_dim].
+        kv_seq_len: Tensor that can broadcast to [batch_size], indicating the actual kv sequence
+            length for each sequence in the batch. If None, assumes k and v are not padded in the
+            sequence dimension.
+        bias: Tensor that can broadcast to [batch_size, q_heads, 1, padded_kv_seq_len].
+            Defaults to None.
+        softmax_scale: Softmax scale.
+        mask_fn: Mask function to use. Preferred over bias.
+        block_h: Block dimension for num_q_heads // num_kv_heads. Defaults to 16, which is the
+            minimal size required for pl.dot. Increase the block size for better performance if
+            num_q_heads // num_kv_heads > 16.
+        block_k: Block dimension along the sequence dim. Defaults to 128. Decrease it if SMEM
+            usage exceeds the limit.
+        num_warps: See the compiler options of `pl.pallas_call`. Defaults to 4 warps, which has
+            the best performance in most settings.
+        num_stages: See the compiler options of `pl.pallas_call`. Defaults to 2, which is faster
+            than no software pipelining. Higher values may cause SMEM OOM.
+        interpret: See `pl.pallas_call`.
+        debug: See `pl.pallas_call`.
+
+    Returns:
+        A tensor with the same shape and dtype as q.
+    """
+    q = q.squeeze(1)
+    batch_size, q_heads, head_dim = q.shape
+    padded_kv_seq_len, kv_heads = k.shape[1], k.shape[2]
+    if k.shape != v.shape:
+        raise RuntimeError(f"Expect k and v to have the same shape. Got {k.shape=}, {v.shape=}!")
+    if q_heads % kv_heads != 0:
+        raise RuntimeError(
+            f"Expect number of kv heads divides number of q heads. Got {kv_heads=}, {q_heads=}!"
+        )
+    if kv_seq_len is not None:
+        kv_seq_len = jnp.broadcast_to(jnp.asarray(kv_seq_len), (batch_size,))
+    else:
+        kv_seq_len = jnp.full((batch_size,), padded_kv_seq_len, dtype=jnp.int32)
+    q_heads_per_kv_head = q_heads // kv_heads
+    q = q.reshape(batch_size, kv_heads, q_heads_per_kv_head, head_dim)
+
+    if bias is not None:
+        bias = jnp.broadcast_to(bias, (batch_size, q_heads, 1, padded_kv_seq_len))
+        bias = bias.reshape(batch_size, kv_heads, q_heads_per_kv_head, padded_kv_seq_len)
+
+    inner = functools.partial(
+        _decode_attn_unbatched,
+        softmax_scale=softmax_scale,
+        block_h=block_h,
+        block_k=block_k,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        interpret=interpret,
+        debug=debug,
+        mask_fn=mask_fn,
+        batch_size=batch_size,
+    )
+    o = jax.vmap(inner)(q, k, v, bias, kv_seq_len)
+    return o.reshape(batch_size, 1, q_heads, head_dim)
diff --git a/axlearn/common/flash_attention/layer.py b/axlearn/common/flash_attention/layer.py
index 4317f5148..2cb8da181 100644
--- a/axlearn/common/flash_attention/layer.py
+++ b/axlearn/common/flash_attention/layer.py
@@ -105,17 +105,6 @@ def _logit_biases_spec(self, attention_logit_biases: BaseAttentionBias) -> BaseA
         cfg = self.config
         return attention_logit_biases.partition_spec(cfg.mha_dim_to_partition_spec)
 
-    def _repeat_kv_heads(self, key_or_value: Tensor) -> Tensor:
-        """Repeats key or value heads dim to match the query.
-
-        TODO(dhwang2): optimize computation like GroupedQueryAttention.
-        """
-        num_head_repeats = self.config.num_heads // key_or_value.shape[-2]
-        if num_head_repeats == 1:
-            return key_or_value
-        # Repeat along the num_heads dim: [batch, source_length, num_heads, per_head_dim].
-        return jnp.repeat(key_or_value, num_head_repeats, axis=-2)
-
     def _compute_attention(
         self,
         *,
@@ -127,10 +116,6 @@ def _compute_attention(
         cfg = self.config
         backend = self._backend()
 
-        # Repeats key/value heads dim if necessary.
- k_proj = self._repeat_kv_heads(k_proj) - v_proj = self._repeat_kv_heads(v_proj) - batch, target_len, num_heads, _ = q_proj.shape _, source_len, _, _ = k_proj.shape @@ -162,8 +147,11 @@ def _compute_attention( jit_attn, mesh=thread_resources.env.physical_mesh, in_specs=( - # QKV [batch_size, seq_len, num_heads, per_head_dim]. + # Q [batch_size, seq_len, num_heads, per_head_dim]. cfg.mha_dim_to_partition_spec["btnh"], + # KV [batch_size, seq_len, num_kv_heads, per_head_dim]. + # Note: while num_kv_heads can be different from num_heads, their partition spec + # should be the same. cfg.mha_dim_to_partition_spec["bsnh"], cfg.mha_dim_to_partition_spec["bsnh"], # Bias that can broadcast to [batch_size, num_heads, seq_len, seq_len]. diff --git a/axlearn/common/flash_attention/layer_test.py b/axlearn/common/flash_attention/layer_test.py index e2bf076e0..1d0df4931 100644 --- a/axlearn/common/flash_attention/layer_test.py +++ b/axlearn/common/flash_attention/layer_test.py @@ -24,13 +24,12 @@ from jax.experimental import mesh_utils from jax.sharding import Mesh -from axlearn.common.attention import GroupedQueryAttention, apply_attention_logit_biases +from axlearn.common.attention import GroupedQueryAttention from axlearn.common.attention_bias import ( CompositeAttentionBias, SegmentIdAttentionBias, TensorAttentionBias, bool_to_bias, - make_causal_biases, sliding_window_causal_mask, ) from axlearn.common.base_layer import BaseLayer @@ -446,7 +445,7 @@ def test_forward( ) # TODO(markblee): Test probs. self.assertNestedAllClose(ref_out.data, test_out.data, atol=0.05) - jax.clear_backends() + jax.extend.backend.clear_backends() @parameterized.product( _TEST_CONFIGS, @@ -589,9 +588,11 @@ def loss(params, inputs, layer): self.assertNestedAllClose(ref_value, test_value, atol=atol, rtol=rtol) self.assertNestedAllClose(ref_grads, test_grads, atol=atol, rtol=rtol) - jax.clear_backends() + jax.extend.backend.clear_backends() - @parameterized.product(_TEST_CONFIGS, causal=[True], sliding_window_size=[None, 4]) + @parameterized.product( + _TEST_CONFIGS, causal=[True], sliding_window_size=[None, 4], use_bias=[True, False] + ) def test_extend_step( self, batch, @@ -602,6 +603,7 @@ def test_extend_step( mesh_axis_names, causal, sliding_window_size, + use_bias, ): print( f"batch={batch}, seq_len={seq_len} (ignored->16), num_heads={num_heads}, \n" @@ -646,15 +648,13 @@ def test_extend_step( [batch, seq_len, hidden_dim], dtype=dtype, ) - bias = jax.random.normal( - jax.random.PRNGKey(0), - [batch, num_heads, seq_len, seq_len], - dtype=dtype, - ) - # Note: We need to use causal bias for flash attention input in case of decoding. 
- causal_bias = apply_attention_logit_biases(bias, make_causal_biases(seq_len)).astype( - dtype - ) + causal_bias = None + if use_bias: + causal_bias = jax.random.normal( + jax.random.PRNGKey(0), + [batch, num_heads, seq_len, seq_len], + dtype=dtype, + ) kv_state = None return_aux = {"probs"} @@ -713,7 +713,10 @@ def test_extend_step( attention_logit_biases=None, ) ref_inputs = dict( - cached_states=ref_initial_state, kv_state=kv_state, return_aux=return_aux + cached_states=ref_initial_state, + kv_state=kv_state, + return_aux=return_aux, + attention_logit_biases=None, ) decoder_output = jnp.zeros(shape=[seq_len, batch, hidden_dim]).astype(dtype) @@ -721,12 +724,16 @@ def test_extend_step( for t in range(seq_len): cur_query = jnp.expand_dims(query[:, t, :], axis=1) inputs["query"] = cur_query - inputs["attention_logit_biases"] = jnp.expand_dims(causal_bias[:, :, t, :], axis=2) + if use_bias: + inputs["attention_logit_biases"] = jnp.expand_dims( + causal_bias[:, :, t, :], axis=2 + ) ref_inputs["query"] = cur_query - ref_inputs["attention_logit_biases"] = jnp.expand_dims( - causal_bias[:, :, t, :], axis=2 - ) + if use_bias: + ref_inputs["attention_logit_biases"] = jnp.expand_dims( + causal_bias[:, :, t, :], axis=2 + ) ref_extend_step_outputs, _ = F( ref_layer, @@ -778,4 +785,4 @@ def test_extend_step( test_out.data, atol=2e-2, ) - jax.clear_backends() + jax.extend.backend.clear_backends() diff --git a/axlearn/common/flash_attention/tpu_attention.py b/axlearn/common/flash_attention/tpu_attention.py index 0edc6c0c5..de18a402f 100644 --- a/axlearn/common/flash_attention/tpu_attention.py +++ b/axlearn/common/flash_attention/tpu_attention.py @@ -236,7 +236,7 @@ def _tpu_splash_attention( bias: Optional[Tensor] = None, # [batch_size, num_heads, target_len, source_len] segment_ids: Optional[Tensor] = None, # [batch_size, target_len] *, - mask: Union[MaskFnAttentionBias | ZeroAttentionBias], + mask: Union[MaskFnAttentionBias, ZeroAttentionBias], block_sizes: Optional[splash_attention_kernel.BlockSizes] = None, ) -> Tensor: # [batch_size, num_heads, target_len, head_dim]. """Wraps JAX's sparse TPU flash-attention. diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py index 85af50aeb..938a47080 100644 --- a/axlearn/common/flash_attention/utils.py +++ b/axlearn/common/flash_attention/utils.py @@ -22,6 +22,7 @@ ) from axlearn.common.flash_attention.gpu_attention import cudnn_dot_product_attention from axlearn.common.flash_attention.gpu_attention import flash_attention as gpu_flash_attention +from axlearn.common.flash_attention.gpu_decoding import flash_decoding from axlearn.common.flash_attention.tpu_attention import tpu_flash_attention from axlearn.common.utils import Tensor @@ -80,6 +81,18 @@ def mha_reference( return context +def _repeat_kv_heads(num_q_heads: int, key_or_value: Tensor) -> Tensor: + """Repeats key or value heads dim to match the query. + + TODO(dhwang2): optimize computation like GroupedQueryAttention. + """ + num_head_repeats = num_q_heads // key_or_value.shape[-2] + if num_head_repeats == 1: + return key_or_value + # Repeat along the num_heads dim: [batch, source_length, num_heads, per_head_dim]. + return jnp.repeat(key_or_value, num_head_repeats, axis=-2) + + # Accepts [query, key, value, attention_bias, segment_ids] tensors and returns the context Tensor. 
MultiHeadAttentionImpl = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor]
@@ -118,13 +131,14 @@ def jit_attn(
     ) -> Tensor:
         # Fall back to plain MHA implementation when the seq_len is not be divisible by
         # block size.
-        if query.shape[1] % block_size != 0:
+        is_gpu_decoding = query.shape[1] == 1 and backend == "gpu"
+        if not is_gpu_decoding and query.shape[1] % block_size != 0:
             backend = "xla"
-        # For decoding, fall back to non-flash implementation and merge all biases
+        # For non-GPU decoding, fall back to non-flash implementation and merge all biases
         # into a dense floating point bias tensor since that implementation does not
         # support target_positions.
-        if query.shape[1] == 1:
-            # TODO(senyut): Implement FlashDecoding kernel and support TPU decoding.
+        if not is_gpu_decoding and query.shape[1] == 1:
+            # TODO(senyut): Support TPU decoding.
             backend = "xla"
             bias = TensorAttentionBias(bias.value())
 
@@ -148,6 +162,35 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
             return segment_ids.segment_ids
 
         if backend == "gpu":
+            # TODO(hanzhi-zhou): support small q sequence lengths for future use cases such as
+            # speculative decoding.
+            if query.shape[1] == 1:
+                # Decoding case. We should not repeat kv heads to match q heads for FlashDecoding.
+                # Note: decoding is always causal. Discard the causal mask if present.
+                mask, explicit_bias = split(bias, MaskFnAttentionBias)
+                if mask is None or mask.target_positions is None:
+                    raise RuntimeError("Cannot retrieve MaskFnAttentionBias or target_positions.")
+                mask_fn = mask.mask
+                kv_seq_len = mask.target_positions + 1
+                logging.info("Using mask_fn=%s for FlashDecoding.", mask_fn)
+
+                bias = explicit_bias.value()
+                if bias is not None:
+                    logging.info(
+                        "Using explicit_bias=%s for FlashDecoding. "
+                        "This is not expected unless an explicit Tensor bias is used.",
+                        bias,
+                    )
+                return flash_decoding(
+                    query,
+                    key,
+                    value,
+                    bias=bias,
+                    mask_fn=mask_fn,
+                    kv_seq_len=kv_seq_len,
+                    softmax_scale=softmax_scale,
+                )
+
             if query.shape[1] != key.shape[1]:
                 # TODO(xuan-zou): Generalize GPU Flash Attention for q_len != kv_len.
                 # Remove pytest.skip corresponding to q_len != kv_len in layer_test.py once fixed.
@@ -156,6 +199,9 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
                     f"{key.shape[1]} for correctly supported GPU flash attention usage."
                 )
 
+            key = _repeat_kv_heads(query.shape[2], key)
+            value = _repeat_kv_heads(query.shape[2], value)
+
             # We have two implementations to choose from.
             # Both support `causal`.
             # One supports `segment_ids`.
@@ -195,6 +241,8 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
             )
 
         elif backend == "tpu":
+            key = _repeat_kv_heads(query.shape[2], key)
+            value = _repeat_kv_heads(query.shape[2], value)
             # `mask` is supported.
             # `segment_ids` is supported.
             # Optimized handling for the above two types.
@@ -217,6 +265,8 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
             )
 
         elif backend in ("cpu", "xla"):
+            key = _repeat_kv_heads(query.shape[2], key)
+            value = _repeat_kv_heads(query.shape[2], value)
             if backend == "cpu":
                 logging.warning("Flash attention CPU backend is for testing only.")
             logging.warning("Flash attention falling back using plain MHA implementation")