Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ALiBi for the non-flash code path #858

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions flash_attn/modules/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,25 @@ class SelfAttention(nn.Module):
(default: 0.0)
"""

def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
super().__init__()
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
self.register_buffer('alibi_slopes', alibi_slopes, persistent=False)
if alibi_slopes is not None:
self.register_buffer('linear_biases', self._build_linear_biases(16), persistent=False)
else:
self.linear_biases = None

def _build_linear_biases(self, seqlen):
context_position = torch.arange(seqlen, device=self.alibi_slopes.device)[:, None]
memory_position = torch.arange(seqlen, device=self.alibi_slopes.device)[None, :]
# distance tensor is of shape (seqlen, seqlen)
distance = torch.abs(memory_position - context_position)
# alibi tensor is of shape (1, H, seqlen, seqlen)
linear_biases = (distance[None, ...] * self.alibi_slopes[:, None, None])[None, ...]
return linear_biases

def forward(self, qkv, causal=None, key_padding_mask=None):
"""Implements the multihead softmax attention.
Expand All @@ -261,6 +275,11 @@ def forward(self, qkv, causal=None, key_padding_mask=None):
padding_mask.masked_fill_(key_padding_mask, 0.0)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
if self.alibi_slopes is not None:
if seqlen > self.linear_biases.shape[-1]:
self.linear_biases = self._build_linear_biases(seqlen)
cropped_biases = self.linear_biases[..., :seqlen, :seqlen]
scores = scores - cropped_biases
if causal:
# "triu_tril_cuda_template" not implemented for 'BFloat16'
# So we have to construct the mask in float
Expand Down Expand Up @@ -420,7 +439,7 @@ def __init__(
self.return_residual = return_residual
self.checkpointing = checkpointing
if use_alibi:
assert use_flash_attn, "ALiBi code path requires flash_attn"
assert not cross_attn or use_flash_attn, "ALiBi code path requires self-attention or cross-attention with flash_attn"
alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
else:
alibi_slopes = None
Expand Down Expand Up @@ -458,7 +477,7 @@ def __init__(
inner_attn_cls = (
partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
if use_flash_attn
else SelfAttention
else partial(SelfAttention, alibi_slopes=alibi_slopes)
)
inner_cross_attn_cls = (
partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
Expand Down