Flash attention for rocm #1

Open: wants to merge 322 commits into base: main
Changes from 80 commits
Commits
f5d8763
speed up
fsx950223 Feb 16, 2023
5c257c9
remove useless changes
fsx950223 Feb 16, 2023
63405db
update ck
fsx950223 Feb 16, 2023
17ea3a7
Merge pull request #9 from fsx950223/optimize
guangzlu Feb 16, 2023
d1bf99a
optimize performance
fsx950223 Feb 17, 2023
84ed6d5
Update license file for CMake files.
groenenboomj Feb 17, 2023
43f28bd
fix a bug
fsx950223 Feb 20, 2023
c730d50
Update Dockerfile for ROCm
groenenboomj Feb 21, 2023
24c81ea
Add forward pass benchmark
groenenboomj Feb 25, 2023
228bb1a
Increase number of run samples for flash attention forward pass
groenenboomj Feb 25, 2023
53dd6cd
changed submodule into attn-bwd-develop
guangzlu Feb 27, 2023
55165c0
added dropout verify into fwd
guangzlu Feb 28, 2023
c7ec4c0
modified fmha_api.cpp
guangzlu Mar 1, 2023
b1473a8
moified some files
guangzlu Mar 1, 2023
f788e7d
added dropout verify
guangzlu Mar 2, 2023
9f6d0ae
batched seqlen can pass
guangzlu Mar 2, 2023
7164b75
fix bugs
fsx950223 Mar 2, 2023
de43726
merge updates
fsx950223 Mar 3, 2023
f51255f
fix multi gpu
fsx950223 Mar 6, 2023
fb1be67
added template parameter for 32 < d <=64
guangzlu Mar 7, 2023
f6b11c7
fix bugs
fsx950223 Mar 7, 2023
9de5c29
fixed initialization of z tensor and added workaround in test file fo…
guangzlu Mar 7, 2023
f1eb89e
update ck
fsx950223 Mar 8, 2023
f26ced0
Merge remote-tracking branch 'origin/dropout-verify' into mlperf_test
fsx950223 Mar 8, 2023
93677be
speed up tests
fsx950223 Mar 8, 2023
9b94f55
fix test cases
fsx950223 Mar 8, 2023
40978cd
remove z tensor
fsx950223 Mar 8, 2023
ccd80bf
fix a bug
fsx950223 Mar 8, 2023
05aec02
modified method to verify dropout
guangzlu Mar 8, 2023
75458cb
merge updates
fsx950223 Mar 9, 2023
83c46c8
Merge remote-tracking branch 'origin/dropout-verify' into mlperf_test2
fsx950223 Mar 9, 2023
27f84e8
enable bf16
fsx950223 Mar 9, 2023
06acbdb
optimize
fsx950223 Mar 10, 2023
065c2f0
optimize api
fsx950223 Mar 11, 2023
324bcbf
optimize api
fsx950223 Mar 11, 2023
d3b9fc6
update ck
fsx950223 Mar 13, 2023
890091e
format code
fsx950223 Mar 13, 2023
cefe848
Merge pull request #14 from fsx950223/mlperf_test2
guangzlu Mar 14, 2023
80b3a49
fixed and optimized dropout verify
guangzlu Mar 14, 2023
d4d0c6f
modifed some annotation and test file
guangzlu Mar 14, 2023
92cedaf
fixed test file
guangzlu Mar 14, 2023
0103fcb
optimized dropout verify
guangzlu Mar 15, 2023
2d64089
modified test file
guangzlu Mar 15, 2023
a3ecabe
Merge pull request #12 from ROCmSoftwarePlatform/dropout-verify
groenenboomj Apr 11, 2023
d0cc349
modified ck backend
guangzlu Apr 13, 2023
79d7ca1
modified api
guangzlu Apr 13, 2023
f4827f8
can run now
guangzlu Apr 13, 2023
325367c
modified output of dq dk dv
guangzlu Apr 13, 2023
7a81af7
fixed fp16 path
guangzlu Apr 14, 2023
a67bc9c
can pass unpadded test now
guangzlu Apr 14, 2023
36de0b6
modified api for deterministic use
guangzlu Apr 19, 2023
e3ff7b1
add deterministic and fp32 tensor result cast in API
guangzlu Apr 21, 2023
f7d1133
fixed test file and updated ck
guangzlu Apr 21, 2023
e84f4a0
optimized fmha_api.cpp
guangzlu Apr 21, 2023
963dfb9
fix patch path
fsx950223 Apr 27, 2023
9ee09b1
udpate dockerfile for ROCm 5.4 and Py3.8; modify patch path
May 22, 2023
3b883df
add switch for RTZ and deterministic
May 30, 2023
58b0844
add switches for RTZ and deterministic
May 30, 2023
44a17a5
modify ignores
May 30, 2023
66cd14d
submodule updates
May 31, 2023
b6b4090
python runtime api for deterministic and performance mode
Jun 1, 2023
618918f
update python api
Jun 1, 2023
c0be910
update python api
Jun 1, 2023
261c92a
update python api
Jun 1, 2023
f4854a2
bug fix
Jun 1, 2023
93844af
bug fix
Jun 1, 2023
f638aa6
bug fixes
Jun 1, 2023
6cb6b26
Update README.md
Jun 2, 2023
d5d80c5
bug fixes
Jun 2, 2023
adcd98f
Update README.md
Jun 2, 2023
0c84715
Merge pull request #15 from ROCmSoftwarePlatform/jhzhan/release_test
sabreshao Jun 2, 2023
7633247
Merge branch 'flash_attention_for_rocm' into jhzhan/release_test
sabreshao Jun 2, 2023
782e7ab
Merge pull request #2 from ROCmSoftwarePlatform/jhzhan/release_test
sabreshao Jun 2, 2023
918cd00
modify readme and minor changes
Jun 2, 2023
cfb7f3f
modify readme
Jun 2, 2023
2205fdc
refine readme
Jun 2, 2023
9273197
Update flash_attn_interface.py
Jun 2, 2023
7e6a96a
Update dockerfile
Jun 5, 2023
9c01c25
update docker and readme to remove private reference
paklui Jun 6, 2023
ceea624
unify data types of input, output, and gemm in either FP16 or BF16 fo…
Jun 6, 2023
d565fad
using BF16 as GEMM type in performance mode
Jun 7, 2023
e488af5
Merge branch 'flash_attention_for_rocm' of https://github.com/ROCmSof…
Jun 7, 2023
ee0665c
change random seeds api in accordance with PyTorch 1.13.1+
Jun 15, 2023
8559ccd
Update fmha_utils.h
Jun 15, 2023
9887a29
fix pt2.0 build
fsx950223 Jun 19, 2023
3e1f9ea
fix setup.py
fsx950223 Jun 19, 2023
8512242
fix bugs
fsx950223 Jun 19, 2023
beab3fb
support torch1.12
fsx950223 Jun 20, 2023
4d05af4
update dockerfile
fsx950223 Jun 20, 2023
9838670
update README
fsx950223 Jun 20, 2023
0317244
rename folder
fsx950223 Jun 20, 2023
662535c
rename files
fsx950223 Jun 20, 2023
6a51836
remove useless code
fsx950223 Jun 20, 2023
ad3259a
optimize performance
fsx950223 Jun 21, 2023
99637e4
Merge remote-tracking branch 'public/flash_attention_for_rocm2' into …
fsx950223 Jun 21, 2023
983d299
remove useless code
fsx950223 Jun 21, 2023
e90010b
update README
fsx950223 Jun 21, 2023
63ce40f
Update instruction in README.md
sabreshao Jun 21, 2023
78aada9
disable triton test cases
fsx950223 Jun 21, 2023
1a11344
Fix misalignment between Dockerfile_1.12.rocm and hipify_patch_1.12.p…
sabreshao Jun 21, 2023
dedea21
changed submodule
guangzlu Jun 26, 2023
424141b
added qloop
guangzlu Jun 27, 2023
994eca4
updated to newest version of ck
guangzlu Jun 27, 2023
f67f948
fixed unittest mode
guangzlu Jun 27, 2023
7e190a4
modified gemm type for perf mode
guangzlu Jun 28, 2023
98258ef
modified backend to use rtz
guangzlu Jun 28, 2023
8bb4d98
updated ck
guangzlu Jun 28, 2023
ab576b9
change backend to attn-train-develop-qloop
guangzlu Jul 4, 2023
6834c97
add rtz
guangzlu Jul 5, 2023
10d7481
Revert "Update fmha_utils.h"
sabreshao Jul 11, 2023
777e166
fix pt2.0 build
fsx950223 Jun 19, 2023
b8f2ee6
fix setup.py
fsx950223 Jun 19, 2023
3e4f367
fix bugs
fsx950223 Jun 19, 2023
bfb1d75
support torch1.12
fsx950223 Jun 20, 2023
b83723c
update dockerfile
fsx950223 Jun 20, 2023
6e2a304
update README
fsx950223 Jun 20, 2023
5ad9386
rename folder
fsx950223 Jun 20, 2023
e29d75f
rename files
fsx950223 Jun 20, 2023
deb2e94
remove useless code
fsx950223 Jun 20, 2023
3f5297b
optimize performance
fsx950223 Jun 21, 2023
22b64b1
remove useless code
fsx950223 Jun 21, 2023
535f1b7
update README
fsx950223 Jun 21, 2023
1ddabb8
disable triton test cases
fsx950223 Jun 21, 2023
db62edc
Fix misalignment between Dockerfile_1.12.rocm and hipify_patch_1.12.p…
sabreshao Jun 21, 2023
cf2ffe1
Update instruction in README.md
sabreshao Jun 21, 2023
6aacb04
added kloop into qloop
guangzlu Jul 12, 2023
67d897b
can compile qloop and kloop together
guangzlu Jul 14, 2023
0cb0cd5
can compile qloop and kloop together modified python file
guangzlu Jul 14, 2023
9551449
Merge pull request #6 from ROCmSoftwarePlatform/attn-qloop-kloop-v2
sabreshao Jul 14, 2023
0ba1882
updated ck
guangzlu Jul 14, 2023
489a673
default using qloop
guangzlu Jul 14, 2023
a988787
modified README.md
guangzlu Jul 14, 2023
34e29f7
Merge branch 'flash_attention_for_rocm' into attn-qloop-kloop-v2
guangzlu Jul 14, 2023
0821eb0
Reduce the compiling time by spliting into several cpp files (#7)
Jul 31, 2023
05d45e4
Remove PyTorch patch by updating PyTorch
groenenboomj Aug 7, 2023
a2e81ca
ck sync up
Aug 11, 2023
0627500
added template for non padding
guangzlu Aug 11, 2023
ed7ccb3
added code to judg whether to use unpadding mode
guangzlu Aug 12, 2023
b6d78bd
fixed bugs
guangzlu Aug 14, 2023
21b45c3
fixed bug for flash_fwd_runner_qloop_hdim64_fp16_noncausal_gfx90a.cpp
guangzlu Aug 16, 2023
52427b5
Merge pull request #8 from ROCmSoftwarePlatform/remove_patch
groenenboomj Aug 16, 2023
6d88e70
added non dropout code path for qloop dim64
guangzlu Aug 17, 2023
eabcebf
fixed bug
guangzlu Aug 17, 2023
fe1cb5a
added unpad for kloop
guangzlu Aug 18, 2023
4619d9c
Merge pull request #10 from ROCmSoftwarePlatform/inference-opt
guangzlu Aug 22, 2023
c902c75
gfx941 support
Aug 28, 2023
ce59e9f
fix RTN logic
Aug 29, 2023
d394549
Optimized API for packed conditions (#12)
guangzlu Aug 31, 2023
efd5e04
compatiable with xformers (#13)
fsx950223 Sep 13, 2023
0b037c2
Merge tag 'v2.0.0' of https://github.com/Dao-AILab/flash-attention in…
Sep 15, 2023
0de9665
added setup.py for ROCm; increase code readability; rename files.
Sep 15, 2023
48f57bf
modified mha_fwd; added mha_varlen_fwd
Sep 15, 2023
cef81d1
enable mha_bwd + mha_varlen_bwd
Sep 18, 2023
16b0c17
updated ck and removed kloop
guangzlu Sep 18, 2023
1e9ddf8
update python interface
Sep 18, 2023
7581be2
removed kloop related files
guangzlu Sep 18, 2023
94273a8
updated test file
guangzlu Sep 18, 2023
b13603f
modified test file
guangzlu Sep 19, 2023
f978f3e
fix get_env_
Sep 19, 2023
c656139
added bwd light version
guangzlu Sep 19, 2023
3f53461
format code
Sep 19, 2023
631f027
optimize code for light
guangzlu Sep 19, 2023
a6900a4
sync to 2.0.4
Sep 19, 2023
dc98ee5
sync to 2.0.4
Sep 19, 2023
37e5961
bug fixes
Sep 19, 2023
609262f
stage process for bwd nonpadding
guangzlu Sep 20, 2023
8216584
modified ratit for bwd
guangzlu Sep 20, 2023
de70d9d
added padding branch
guangzlu Sep 20, 2023
e82b97a
removed kloop stuff
guangzlu Sep 21, 2023
d7be208
added rtn to ck
guangzlu Sep 21, 2023
444e15a
Merge pull request #15 from ROCmSoftwarePlatform/bwd-prof-opt
sabreshao Sep 22, 2023
1d4913f
bug fixes
Sep 28, 2023
39c6578
bug fixes
Sep 28, 2023
0d557df
bug fixes
Sep 28, 2023
623ffbb
Merge pull request 'https://github.com/ROCmSoftwarePlatform/flash-att…
Sep 28, 2023
e61ba7a
Update README.md
Sep 28, 2023
94b9dd5
added batched template
Oct 7, 2023
67162e3
added batched template
Oct 7, 2023
fe31011
bug fixes for batched template
Oct 7, 2023
ae87e65
bug fixes for batched template
Oct 7, 2023
89806a4
added batched
Oct 8, 2023
aa59b0f
params -> BaseParams for static members
Oct 11, 2023
aa96f3e
hpp suffix is prefered in cpp hence changed
Oct 11, 2023
f47a112
removing deprecated files for ifu readiness
Oct 11, 2023
86964ee
improving logic
Oct 11, 2023
185fd79
refine params
Oct 11, 2023
46172fb
cleaned redundencies
Oct 11, 2023
bc37a40
bug fixes
Oct 12, 2023
14df6f1
bug fixes
Oct 12, 2023
149c2b4
bug fixed
Oct 12, 2023
8cfc14c
update CK and RTN logic
Oct 13, 2023
f047ddb
bug fixes
Oct 13, 2023
2338516
bug fixes
Oct 13, 2023
7a382ae
bug fixing
Oct 13, 2023
889d4a8
bug fixing
Oct 13, 2023
182ef77
using ck index
Oct 18, 2023
89e44a0
initilize vectors
Oct 18, 2023
790eca7
fixing mqa/gqa params
Oct 18, 2023
6bcce4f
added mqa/gqa APIs
Oct 18, 2023
393b1ad
remove zombie code
Oct 18, 2023
cbca76f
update interface
Oct 18, 2023
490a01b
bug fixes
Oct 18, 2023
232e5a9
bug fixes
Oct 18, 2023
db9541b
bug fixing
Oct 18, 2023
f046d04
fixing unit test cases
Oct 20, 2023
1fe24cf
bug fixes
Oct 20, 2023
f5783bb
passed all unit test
Oct 24, 2023
cd463f9
sync interface
Oct 24, 2023
9f90750
Merge branch 'junhzhan/ifu-v2.0.0' of https://github.com/ROCmSoftware…
Oct 24, 2023
5d1365a
added optional FP32 dQKV for unit tests
Oct 25, 2023
a807948
pass qkv.contiguous() instead of assigning values
Oct 25, 2023
3a31e7e
simple code
fsx950223 Oct 25, 2023
f4c8dde
add time kernel env
fsx950223 Oct 26, 2023
6bc3374
updated ckbackend
guangzlu Oct 26, 2023
b4d20b2
simple code
fsx950223 Oct 26, 2023
15c19e2
modified api to support mqa gqa
guangzlu Oct 26, 2023
0c5b579
fix dropout z tensors allocation; enable unit test
Oct 26, 2023
d7b631a
Merge branch 'junhzhan/ifu-v2.0.0' of https://github.com/ROCmSoftware…
Oct 26, 2023
b5ba498
added mqa gqa
guangzlu Oct 26, 2023
cc78698
updated ck backend
guangzlu Oct 26, 2023
6daeb0c
enableed mqa gqa for batched conditions
guangzlu Oct 27, 2023
b6a9f6e
fixed params
guangzlu Oct 27, 2023
5e80fc7
passed mqa & gqa for varlen tests
guangzlu Oct 27, 2023
02c234b
Merge pull request #16 from ROCmSoftwarePlatform/ifu-mqa
guangzlu Oct 30, 2023
b27bd1d
Update .gitignore
Oct 30, 2023
9a5273d
better .gitignore
Oct 30, 2023
2d11119
update git ignore
Oct 30, 2023
4d79450
update uint8 dropout in FA
Oct 30, 2023
a197406
update RTN swtich; enable MQA/GQA UT
Oct 30, 2023
8da5b66
Merge branch 'junhzhan/ifu-v2.0.0' of https://github.com/ROCmSoftware…
Oct 30, 2023
5378a20
tidy codes
Oct 31, 2023
1b808f4
add legacy interface support
Oct 31, 2023
23ee8fb
Merge branch 'junhzhan/ifu-v2.0.0' of https://github.com/ROCmSoftware…
Oct 31, 2023
0c92f31
code formatting
Oct 31, 2023
2c057b4
fixed bugs for grouped mha && d%8=0
guangzlu Nov 1, 2023
1cd7f89
Disable MQA UT
Nov 3, 2023
edc7698
Merge pull request #14 from ROCmSoftwarePlatform/junhzhan/ifu-v2.0.0
sabreshao Nov 3, 2023
5f1ae07
Remove Hardcoded Building Options (#19)
Nov 17, 2023
675d324
Add build script
dejay-vu Nov 21, 2023
8a77d72
Use GPU_ARCHS instead of PYTORCH_ROCM_ARCH
dejay-vu Nov 21, 2023
18060ee
Update README.md
dejay-vu Nov 29, 2023
fa589c3
Update README.md
dejay-vu Nov 29, 2023
fa285bf
Update README.md
dejay-vu Nov 29, 2023
3d2b6f5
Update README.md
dejay-vu Nov 29, 2023
3b786a2
Update README.md
Naomiusearch Nov 29, 2023
820b2b1
Update README.md
Naomiusearch Nov 29, 2023
68aac13
Merge pull request #23 from Naomiusearch/flash_attention_for_rocm
dejay-vu Dec 5, 2023
b64f45e
Allow gfx908 to build
luizanao Jan 26, 2024
ae7928c
Merge pull request #38 from luizanao/add-support-gfx908
dejay-vu Feb 4, 2024
2554f49
add benchmark script (#49)
fsx950223 Mar 8, 2024
3 changes: 3 additions & 0 deletions .gitmodules
@@ -1,3 +1,6 @@
[submodule "csrc/flash_attn/cutlass"]
path = csrc/flash_attn/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "csrc/flash_attn_rocm/composable_kernel"]
path = csrc/flash_attn_rocm/composable_kernel
url = https://github.com/ROCmSoftwarePlatform/composable_kernel
18 changes: 18 additions & 0 deletions Dockerfile.rocm
@@ -0,0 +1,18 @@
# BSD 3 Clause
# Copyright 2023 Advanced Micro Devices, Inc.
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

FROM rocm/pytorch:rocm5.3.2_ubuntu20.04_py3.7_pytorch_1.12.1

WORKDIR /workspace
USER root

RUN pip install ninja
RUN git clone --recurse-submodules --branch flash_attention_for_rocm https://github.com/ROCmSoftwarePlatform/flash-attention.git
RUN cd /workspace/flash-attention \
&& patch /opt/conda/lib/python3.7/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
&& python setup.py install
5 changes: 5 additions & 0 deletions MANIFEST.in
@@ -7,3 +7,8 @@ recursive-include flash_attn *.cu
recursive-include flash_attn *.h
recursive-include flash_attn *.cuh
recursive-include flash_attn *.cpp

recursive-include flash_attn_rocm *.cu
recursive-include flash_attn_rocm *.h
recursive-include flash_attn_rocm *.cuh
recursive-include flash_attn_rocm *.cpp
86 changes: 86 additions & 0 deletions benchmarks/benchmark_flash_attention_forward.py
@@ -0,0 +1,86 @@
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func


def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        attn_mask: (batch_size, seqlen)
        dropout_p: float
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
        attention: softmax after dropout
    """
    q, k, v = (qkv.float() if upcast else qkv).unbind(dim=2)
    seqlen = qkv.shape[1]
    d = qkv.shape[-1]
    scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
    scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf'))
    if causal:
        causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
        scores.masked_fill_(causal_mask, float('-inf'))
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    # return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
    return output.to(dtype=qkv.dtype)


torch.manual_seed(0)
repeats = 250
batch_size = [1,32,64,128]
nheads = 16
seqlen = [1024,2048,4096]
n = 1024
d = n // nheads
dropout_p = 0.1
causal = False
dtype = torch.float16
device = 'cuda'

result_summary = []

for bs in batch_size:
    for sq in seqlen:
        if (bs > 32 and sq > 2048) or (bs > 64 and sq > 1024):
            continue
        x = torch.randn(bs, sq, n, device='cuda', dtype=dtype, requires_grad=True)
        Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)

        lengths = torch.randint(sq - 20, sq, (bs, 1), device='cuda')
        attention_mask_bool = repeat(torch.arange(sq, device='cuda'), 's -> b s', b=bs) < lengths
        attention_mask = torch.zeros(bs, sq, device='cuda', dtype=dtype)
        attention_mask[~attention_mask_bool] = -10000.0
        attention_mask = rearrange(attention_mask, 'b s -> b 1 1 s')

        x_unpad, indices, cu_sqs, max_sq_in_batch = unpad_input(x, attention_mask_bool)
        qkv_unpad = rearrange(Wqkv(x_unpad), 'nnz (t h d) -> nnz t h d', t=3,
                              h=nheads).detach().requires_grad_()
        qkv = rearrange(Wqkv(x), 'b s (t h d) -> b s t h d', t=3, h=nheads).detach().requires_grad_()

        print(f'Batch size: {bs}, Sequence Length: {sq}')

        fn = lambda qkv_unpad: flash_attn_unpadded_qkvpacked_func(
            qkv_unpad, cu_sqs, max_sq_in_batch, dropout_p, causal=causal
        )
        fa_time,fa_measurement = benchmark_forward(fn, qkv_unpad, repeats=repeats, desc='FlashAttention')
        fn = lambda qkv: attention_ref(qkv, attention_mask_bool, dropout_p, causal=causal)
        pyt_time,pyt_measurement = benchmark_forward(fn, qkv, repeats=repeats, desc='PyTorch Standard Attention')

        relative_perf = ((pyt_measurement.mean-fa_measurement.mean)/pyt_measurement.mean) * 100

        result_summary.append([bs,sq,pyt_measurement.mean,fa_measurement.mean,relative_perf])

        print(f'Flash Attention Speedup: {relative_perf}\n')

print(f'batch size, sequence length, PyTorch Standard Attention, FlashAttention, speedup relative to PyTorch\n {result_summary}')
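
Note on the metric above: the value the script prints as "speedup" is computed as a percentage reduction in mean runtime relative to the PyTorch baseline, which is not the same as a speedup ratio. The sketch below shows the difference. The helper names and timings are hypothetical and not part of the benchmark.

```python
def percent_time_reduction(baseline_mean: float, candidate_mean: float) -> float:
    """Same formula as relative_perf in the script: time saved, as a % of the baseline."""
    return (baseline_mean - candidate_mean) / baseline_mean * 100


def speedup_ratio(baseline_mean: float, candidate_mean: float) -> float:
    """How many times faster the candidate is than the baseline."""
    return baseline_mean / candidate_mean


# Illustrative mean runtimes in seconds (made up, not measured).
pyt_mean = 0.040
fa_mean = 0.010

reduction = percent_time_reduction(pyt_mean, fa_mean)
ratio = speedup_ratio(pyt_mean, fa_mean)
print(f"time reduction: {reduction:.1f}%, speedup: {ratio:.1f}x")
```

With these example numbers, a 75% reduction in runtime corresponds to a 4x speedup, so the two figures should not be read interchangeably.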
154 changes: 154 additions & 0 deletions csrc/flash_attn_rocm/CMakeLists.txt
@@ -0,0 +1,154 @@
# BSD 3 Clause
# Copyright 2023 Advanced Micro Devices, Inc.
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(fmha_api)

IF(NOT DEFINED ENV{ROCM_PATH})
SET(ROCM_PATH /opt/rocm)
ELSE()
SET(ROCM_PATH $ENV{ROCM_PATH})
ENDIF()
if(NOT DEFINED ENV{ROCM_INCLUDE_DIRS})
set(ROCM_INCLUDE_DIRS ${ROCM_PATH}/include)
else()
set(ROCM_INCLUDE_DIRS $ENV{ROCM_INCLUDE_DIRS})
endif()
# HIP_PATH
IF(NOT DEFINED ENV{HIP_PATH})
SET(HIP_PATH ${ROCM_PATH}/hip)
ELSE()
SET(HIP_PATH $ENV{HIP_PATH})
ENDIF()



IF(NOT EXISTS ${HIP_PATH})
return()
ENDIF()



# HCC_PATH
IF(NOT DEFINED ENV{HCC_PATH})
SET(HCC_PATH ${ROCM_PATH}/hcc)
ELSE()
SET(HCC_PATH $ENV{HCC_PATH})
ENDIF()



# HSA_PATH
IF(NOT DEFINED ENV{HSA_PATH})
SET(HSA_PATH ${ROCM_PATH}/hsa)
ELSE()
SET(HSA_PATH $ENV{HSA_PATH})
ENDIF()



# ROCBLAS_PATH
IF(NOT DEFINED ENV{ROCBLAS_PATH})
SET(ROCBLAS_PATH ${ROCM_PATH}/rocblas)
ELSE()
SET(ROCBLAS_PATH $ENV{ROCBLAS_PATH})
ENDIF()



# ROCSPARSE_PATH
IF(NOT DEFINED ENV{ROCSPARSE_PATH})
SET(ROCSPARSE_PATH ${ROCM_PATH}/rocsparse)
ELSE()
SET(ROCSPARSE_PATH $ENV{ROCSPARSE_PATH})
ENDIF()



# ROCFFT_PATH
IF(NOT DEFINED ENV{ROCFFT_PATH})
SET(ROCFFT_PATH ${ROCM_PATH}/rocfft)
ELSE()
SET(ROCFFT_PATH $ENV{ROCFFT_PATH})
ENDIF()



# HIPSPARSE_PATH
IF(NOT DEFINED ENV{HIPSPARSE_PATH})
SET(HIPSPARSE_PATH ${ROCM_PATH}/hipsparse)
ELSE()
SET(HIPSPARSE_PATH $ENV{HIPSPARSE_PATH})
ENDIF()



# THRUST_PATH
IF(NOT DEFINED ENV{THRUST_PATH})
SET(THRUST_PATH ${ROCM_PATH}/include)
ELSE()
SET(THRUST_PATH $ENV{THRUST_PATH})
ENDIF()



# HIPRAND_PATH
IF(NOT DEFINED ENV{HIPRAND_PATH})
SET(HIPRAND_PATH ${ROCM_PATH}/hiprand)
ELSE()
SET(HIPRAND_PATH $ENV{HIPRAND_PATH})
ENDIF()



# ROCRAND_PATH
IF(NOT DEFINED ENV{ROCRAND_PATH})
SET(ROCRAND_PATH ${ROCM_PATH}/rocrand)
ELSE()
SET(ROCRAND_PATH $ENV{ROCRAND_PATH})
ENDIF()



# MIOPEN_PATH
IF(NOT DEFINED ENV{MIOPEN_PATH})
SET(MIOPEN_PATH ${ROCM_PATH}/miopen)
ELSE()
SET(MIOPEN_PATH $ENV{MIOPEN_PATH})
ENDIF()



# Add HIP to the CMAKE Module Path
set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH})

find_package(HIP)

set(CMAKE_CXX_COMPILER /opt/rocm/hip/bin/hipcc)
set(CMAKE_CXX_STANDARD 20)

list(APPEND CMAKE_PREFIX_PATH "/opt/conda/lib/python3.8/site-packages/torch/share/cmake")
find_package(Torch REQUIRED)

find_package(rocblas)
find_package(hipfft)
find_package(hiprand)
find_package(hipsparse)

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/composable_kernel/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/composable_kernel/library/include)
include_directories(/opt/conda/include/python3.8)

aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/src FLA_SRCS)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/composable_kernel/library/src/utility CK_SRCS)

add_executable(fmha_api fmha_api.cpp ${FLA_SRCS} ${CK_SRCS})
target_link_libraries(fmha_api "${TORCH_LIBRARIES}")

message("${TORCH_LIBRARIES}")
31 changes: 31 additions & 0 deletions csrc/flash_attn_rocm/Dockerfile
@@ -0,0 +1,31 @@
# BSD 3 Clause
# Copyright 2023 Advanced Micro Devices, Inc.
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

FROM rocm/pytorch:rocm5.3.2_ubuntu20.04_py3.7_pytorch_1.12.1
WORKDIR /flash_attn

USER root

ENV DEBIAN_FRONTEND noninteractive
ENV TZ "Asia/Shanghai"

RUN apt-get update \
&& apt install -y git-all \
&& git clone https://<your github username>:<your github token>@github.com/ROCmSoftwarePlatform/flash-attention_private \
&& cd /flash_attn/flash-attention_private \
&& git checkout flash_attention_for_rocm \
&& cd /flash_attn/flash-attention_private/csrc/flash_attn_rocm/composable_kernel \
&& git submodule init \
&& git submodule update \
&& cd /flash_attn/flash-attention_private/csrc/flash_attn_rocm \
&& mkdir build \
&& cd build \
&& cmake .. \
&& cd /flash_attn/flash-attention_private/csrc/flash_attn_rocm/build \
&& make -j64

49 changes: 49 additions & 0 deletions csrc/flash_attn_rocm/README.md
@@ -0,0 +1,49 @@
This folder contains the ROCm APIs; the backend code comes from Composable Kernel (CK).

The files are introduced below.

"src/fmha.h" is the header file for the C++ APIs. It declares the function "run_fmha_fp16_bf16_gfx90a".

"fmha_api.cpp" is the C++ file that defines the API function "mha_fwd", which calls "run_fmha_fp16_bf16_gfx90a". This file also contains a main function for testing the API.

"src/fmha_fprop_fp16_bf16_kernel.gfx90a.cpp" is the interface that links the API in fmha_api.cpp to the CK backend. It defines the function "run_fmha_fp16_bf16_gfx90a", which uses the parameters passed from "mha_fwd" to choose suitable instance parameters for the CK function. The function "run_fmha_fp16_bf16_gfx90a_loop_" uses the parameters from "run_fmha_fp16_bf16_gfx90a" to initialize the CK instance and call the CK function.

"CMakeLists.txt" is the CMake file used to compile the example above.

Usage for "CMakeLists.txt":
```
$mkdir build
$cd build
$cmake ..
$make
```

My Docker image is from https://hub.docker.com/layers/rocm/pytorch/rocm5.3.2_ubuntu20.04_py3.7_pytorch_1.12.1/images/sha256-387b2538d14cfd55a9510b7ea07049f1e71b7e755413080153b997c798fe5099?context=explore

If you use a different Docker image or install PyTorch yourself, change the `CMAKE_PREFIX_PATH` entry in CMakeLists.txt (the `list(APPEND CMAKE_PREFIX_PATH ...)` line) to your own path.

You can use the command
```
python -c 'import torch;print(torch.utils.cmake_prefix_path)'
```
to find your path.
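
As an alternative to editing CMakeLists.txt, the path can probably be passed on the command line through the standard CMake variable `CMAKE_PREFIX_PATH`. The sketch below only assembles such a command. The helper name `cmake_configure_command` is hypothetical, and this approach has not been tested against this build; other hard-coded paths in CMakeLists.txt, such as the Python include directory, would still need to match your environment.

```python
import shlex
from typing import List


def cmake_configure_command(prefix_path: str, source_dir: str = "..") -> List[str]:
    """Assemble a cmake configure command that passes the PyTorch CMake prefix path."""
    return ["cmake", f"-DCMAKE_PREFIX_PATH={prefix_path}", source_dir]


if __name__ == "__main__":
    try:
        import torch  # only needed to look up the real path

        prefix = torch.utils.cmake_prefix_path
    except ImportError:
        # Placeholder used when PyTorch is not installed; replace with your own path.
        prefix = "/path/to/site-packages/torch/share/cmake"
    command = cmake_configure_command(prefix)
    # Print the command so it can be copied and run inside the build directory.
    print(" ".join(shlex.quote(arg) for arg in command))
```

Run the printed command from the `build` directory in place of the plain `cmake ..` step.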

To build with the Dockerfile:

First, replace the GitHub username and token in this line with your own: https://github.com/ROCmSoftwarePlatform/flash-attention_private/blob/41ddb2fb3884085ee5318d30f8e919944ee18745/csrc/flash_attn_rocm/Dockerfile#L11

Then run
```
sudo docker build -t flash_attention:rocm5.3.2 .
```

To test performance, set the parameter "time_kernel" to true. The kernel then runs 10 times and reports the average running time. You can find the parameter in this line: https://github.com/ROCmSoftwarePlatform/flash-attention_private/blob/fb47a607682a873a3e0b17ae220849cc11a34d8b/csrc/flash_attn_rocm/src/fmha_fprop_fp16_bf16_kernel.gfx90a.cpp#L142
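
The idea of running a kernel several times and averaging can be sketched on the host side as follows. This is only an illustration in Python and not how CK measures time internally. The helper name `average_runtime_ms` is hypothetical, and a real GPU measurement would also need device synchronization or device-side timers, because kernel launches are asynchronous.

```python
import time
from typing import Callable


def average_runtime_ms(fn: Callable[[], None], repeats: int = 10) -> float:
    """Call fn `repeats` times and return the mean wall-clock time per call in milliseconds."""
    total_seconds = 0.0
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        total_seconds += time.perf_counter() - start
    return total_seconds / repeats * 1000.0


# Example: time a small CPU-only workload 10 times.
print(f"{average_runtime_ms(lambda: sum(range(1000))):.4f} ms per call")
```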

To verify the results, set the parameter "do_verification" in this line: https://github.com/ROCmSoftwarePlatform/flash-attention_private/blob/fb47a607682a873a3e0b17ae220849cc11a34d8b/csrc/flash_attn_rocm/fmha_api.cpp#L271 . The code then runs the same computation on the CPU, compares it with the results from the device, and reports whether the device results are correct.
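
A comparison of this kind usually checks each element against the reference within a tolerance, because fp16/bf16 device results rarely match a CPU computation bit for bit. The sketch below illustrates that idea only. The function name `results_match` and the tolerance values are assumptions and are not taken from fmha_api.cpp.

```python
import math
from typing import Sequence


def results_match(
    device_out: Sequence[float],
    reference_out: Sequence[float],
    rel_tol: float = 1e-3,  # assumed tolerance, not the value used by the C++ code
    abs_tol: float = 1e-3,  # assumed tolerance, not the value used by the C++ code
) -> bool:
    """Return True if every device value is close to the corresponding reference value."""
    if len(device_out) != len(reference_out):
        return False
    return all(
        math.isclose(d, r, rel_tol=rel_tol, abs_tol=abs_tol)
        for d, r in zip(device_out, reference_out)
    )


# Small differences pass; a large difference fails.
print(results_match([1.0, 2.0], [1.0005, 2.0]))
print(results_match([1.0, 2.0], [1.1, 2.0]))
```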





1 change: 1 addition & 0 deletions csrc/flash_attn_rocm/composable_kernel
Submodule composable_kernel added at 5736b4