Selective merged prefill #643
base: mlperf_features
@@ -10,6 +10,7 @@
 import vllm_hpu_extension.ops as ops
 from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
                                       VLLMKVCache)
+from vllm_hpu_extension.cache_ops import insert_or_update_cache

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)

@@ -18,6 +19,7 @@
                                              HPUPagedAttentionMetadata)
 from vllm.logger import init_logger
 from vllm.utils import is_fake_hpu
+from vllm.model_executor.models.utils import split_and_pad_to_length

 logger = init_logger(__name__)

@@ -29,6 +31,55 @@
     logger.warning("Could not import HPU FusedSDPA kernel. "
                    "vLLM will use native implementation.")


+def prompt_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[torch.Tensor] = None,
+    p: float = 0.0,
+    scale: Optional[float] = None,
+    matmul_qk_op=torch.matmul,
+    softmax_op=torch.softmax,
+    matmul_av_op=torch.matmul,
+    valid_seq_lengths: Optional[torch.Tensor] = None,
+    fsdpa_op=None,
+) -> torch.Tensor:
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+    query_heads = query.size(1)
+    kv_heads = key.size(1)
+    #if attn_bias is not None or fsdpa_op is None:
+    if fsdpa_op is None:
+        if query_heads != kv_heads:
+            query = query.unflatten(1, (kv_heads, -1))
+            key = key.unflatten(1, (kv_heads, 1))
+            value = value.unflatten(1, (kv_heads, 1))
+            if attn_bias is not None:
+                attn_bias = attn_bias.unsqueeze(1)
+        attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
+        if attn_bias is not None:
+            attn_weights.add_(attn_bias)
+        attn_weights = softmax_op(attn_weights, dim=-1)
+        attn_weights = matmul_av_op(attn_weights, value)
+        if query_heads != kv_heads:
+            attn_weights = attn_weights.flatten(1, 2)
+    else:
+        VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get('VLLM_REMOVE_REPEAT_KV_CACHE', '1') == '1'
+        # TODO: remove after fusedsdpa fix for query_heads != kv_heads
+        if query_heads != kv_heads:
+            if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE:
+                key = ops.repeat_kv(key, int(query_heads // kv_heads))
+                value = ops.repeat_kv(value, int(query_heads // kv_heads))
+            if attn_bias is not None:
+                attn_bias = attn_bias.unsqueeze(1)
+        softmax_mode = 'fast'
+        recompute_mode = True
+        attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, False,
+                                scale, softmax_mode, recompute_mode,
+                                None, 'right')
+    attn_weights = attn_weights.transpose(1, 2)
+    return attn_weights


 class HPUAttentionBackend(AttentionBackend):

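For intuition about the new `prompt_attention` helper, here is a minimal sketch of a call through its native (non-FusedSDPA) path with `fsdpa_op=None`. The sizes, the causal-bias construction, and the `[batch, seq_len, num_heads, head_size]` input layout are assumptions for illustration only, not something this PR prescribes:

```python
import torch

# Hypothetical sizes: 2 prompts padded to 128 tokens, 8 query heads
# sharing 2 KV heads (grouped-query attention), head size 64.
batch, seq_len, q_heads, kv_heads, head_size = 2, 128, 8, 2, 64

query = torch.randn(batch, seq_len, q_heads, head_size)
key = torch.randn(batch, seq_len, kv_heads, head_size)
value = torch.randn(batch, seq_len, kv_heads, head_size)

# Simple causal bias: 0 on/below the diagonal, -inf above it.
attn_bias = torch.full((seq_len, seq_len), float("-inf")).triu(1)
attn_bias = attn_bias.expand(batch, seq_len, seq_len)

# fsdpa_op=None exercises the matmul/softmax fallback defined above;
# prompt_attention from this file must be in scope.
out = prompt_attention(query, key, value,
                       attn_bias=attn_bias,
                       scale=head_size ** -0.5)
print(out.shape)  # torch.Size([2, 128, 8, 64])
```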
@@ -83,6 +134,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
     context_lens_tensor: Optional[torch.Tensor]
+    enable_merged_prefill: bool = False
     seq_lens: Optional[List[int]] = None
     encoder_seq_lens: Optional[List[int]] = None
     encoder_seq_lens_tensor: Optional[torch.Tensor] = None

@@ -213,6 +265,7 @@ def forward(
         block_offsets = kwargs.get('block_offsets', None)
         seq_lens_tensor = kwargs.get('seq_lens_tensor', None)
         attn_bias = kwargs.get('attn_bias', None)
+        enable_merged_prefill = attn_metadata.enable_merged_prefill
         if block_indices is None:
             block_indices = attn_metadata.block_indices
         if block_offsets is None:

@@ -221,20 +274,40 @@ def forward(
             seq_lens_tensor = attn_metadata.seq_lens_tensor
         if attn_bias is None: # This is the case for prompt run
             attn_bias = attn_metadata.attn_bias
-        if attn_metadata.is_prompt:
-            key = key.unflatten(0, (block_indices.size(0), -1))
-            value = value.unflatten(0, (block_indices.size(0), -1))
-        if kv_cache is not None:
-            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
-                kv_cache, self.num_kv_heads, self.head_size)
-
-            # Reshape the input keys and values and store them in the cache.
-            # If kv_cache is not provided, the new key and value tensors are
-            # not cached. This happens during the initial memory profiling run.
-            key_cache = self.k_cache(key, key_cache, block_indices,
-                                     block_offsets)
-            value_cache = self.v_cache(value, value_cache, block_indices,
-                                       block_offsets)
+        if enable_merged_prefill:
+            if attn_metadata.is_prompt:
+                max_len = attn_metadata.slot_mapping.size(1)
+                seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist()
+                # we need to copy the key and value tensors to the padded tensors
+                # shape is [batch_size, entire_seq_len, num_kv_heads, head_size]
+                padded_key_tensor = split_and_pad_to_length(key, max_len, seq_lens_tensor_list)
+                padded_value_tensor = split_and_pad_to_length(value, max_len, seq_lens_tensor_list)
+                padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1))
+                padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1))
+
+            if kv_cache is not None:
+                key_cache, value_cache = HPUPagedAttention.split_kv_cache(
+                    kv_cache, self.num_kv_heads, self.head_size)
+
+                key_cache = self.k_cache(padded_key_tensor, key_cache, block_indices,
+                                         block_offsets)
+                value_cache = self.v_cache(padded_value_tensor, value_cache, block_indices,
+                                           block_offsets)
+        else:
+            if attn_metadata.is_prompt:
+                key = key.unflatten(0, (block_indices.size(0), -1))
+                value = value.unflatten(0, (block_indices.size(0), -1))
+            if kv_cache is not None:
+                key_cache, value_cache = HPUPagedAttention.split_kv_cache(
+                    kv_cache, self.num_kv_heads, self.head_size)
+
+                # Reshape the input keys and values and store them in the cache.
+                # If kv_cache is not provided, the new key and value tensors are
+                # not cached. This happens during the initial memory profiling run.
+                key_cache = self.k_cache(key, key_cache, block_indices,
+                                         block_offsets)
+                value_cache = self.v_cache(value, value_cache, block_indices,
+                                           block_offsets)

         if attn_metadata.is_prompt:
             # Prompt run.

Review thread on the `key_cache = self.k_cache(padded_key_tensor, ...)` line:

- It seems that when decoding,
- I think you're right, but it didn't trigger any error. I'll look into it.
- @yangw1234, after checking the code: since `enable_merged_prefill` is only enabled in prefill_fwd, I'll clean up the code to make it more readable.
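The merged-prefill cache write relies on `split_and_pad_to_length` to turn the packed key/value tokens of the merged prompts into a per-prompt, right-padded layout before writing to the paged KV cache. The helper below is a hypothetical re-implementation of the semantics assumed here (the real function lives in `vllm.model_executor.models.utils` and is not shown in this diff); shapes and names are illustrative only:

```python
import torch
from typing import List

def split_and_pad_to_length_sketch(x: torch.Tensor, max_len: int,
                                   seq_lens: List[int]) -> torch.Tensor:
    """Assumed semantics: split a packed [sum(seq_lens), num_kv_heads, head_size]
    tensor into per-sequence chunks and right-pad each with zeros to max_len,
    giving [len(seq_lens), max_len, num_kv_heads, head_size]."""
    chunks = torch.split(x, seq_lens, dim=0)
    out = x.new_zeros((len(seq_lens), max_len) + x.shape[1:])
    for i, chunk in enumerate(chunks):
        out[i, :chunk.size(0)] = chunk
    return out

# Example: two prompts of 3 and 5 tokens merged into one prefill,
# padded to a slot_mapping width of 8.
packed = torch.randn(3 + 5, 4, 64)   # [total_tokens, kv_heads, head_size]
padded = split_and_pad_to_length_sketch(packed, 8, [3, 5])
print(padded.shape)                  # torch.Size([2, 8, 4, 64])
```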
@@ -253,10 +326,16 @@ def forward(
                 attn_bias = attn_bias.tile(
                     (1, self.num_kv_heads, 1, 1))
                 attn_bias.add_(position_bias)
+            elif enable_merged_prefill:
+                pass
             else:
                 attn_bias = None

-            out = ops.prompt_attention(
+            if enable_merged_prefill:
+                prompt_attn_func = prompt_attention
+            else:
+                prompt_attn_func = ops.prompt_attention
+            out = prompt_attn_func(
                 query.view(query_shape),
                 key.view(kv_shape),
                 value.view(kv_shape),

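The `elif enable_merged_prefill: pass` branch keeps the attention bias that was prepared earlier during metadata building instead of discarding it. Presumably that bias must prevent tokens of one merged prompt from attending to another prompt; the snippet below is purely a sketch of what such a block-diagonal causal bias looks like, and the helper name is invented, not part of this PR:

```python
import torch

def merged_prefill_causal_bias(seq_lens, dtype=torch.float32):
    # Hypothetical helper, not part of this PR: each merged prompt gets its
    # own causal block; everything outside a block is fully masked so one
    # prompt can never attend to another.
    total = sum(seq_lens)
    bias = torch.full((total, total), float("-inf"), dtype=dtype)
    start = 0
    for n in seq_lens:
        causal_block = torch.zeros(n, n, dtype=dtype).masked_fill(
            torch.ones(n, n, dtype=torch.bool).triu(1), float("-inf"))
        bias[start:start + n, start:start + n] = causal_block
        start += n
    return bias  # [sum(seq_lens), sum(seq_lens)]

print(merged_prefill_causal_bias([2, 3]))
```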
Review comment:

Is it possible to get rid of these lines if we prepare `block_indices` and `block_offsets` in a way that excludes the padded tokens?
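A rough sketch of the suggestion above, with an invented `valid_token_slots` helper and a simplified block-table layout (both hypothetical, not vLLM's actual data structures): only real tokens get a cache slot, so padded positions would never be written to the KV cache.

```python
import torch

def valid_token_slots(seq_lens, block_tables, block_size):
    # Hypothetical: compute per-token block indices/offsets only for the
    # real tokens of each merged prompt, skipping padded positions.
    block_indices, block_offsets = [], []
    for seq_id, n in enumerate(seq_lens):
        for pos in range(n):
            block_indices.append(block_tables[seq_id][pos // block_size])
            block_offsets.append(pos % block_size)
    return (torch.tensor(block_indices, dtype=torch.long),
            torch.tensor(block_offsets, dtype=torch.long))

idx, off = valid_token_slots([3, 5], [[0, 1], [2, 3]], block_size=4)
print(idx.tolist(), off.tolist())
# [0, 0, 0, 2, 2, 2, 2, 3] [0, 1, 2, 0, 1, 2, 3, 0]
```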