Selective merged prefill #643

Open
wants to merge 23 commits into base: mlperf_features
Changes from 10 commits
54 changes: 50 additions & 4 deletions benchmarks/benchmark_throughput.py
@@ -5,6 +5,7 @@
import random
import time
from typing import List, Optional
import os

import pandas as pd
import torch
@@ -71,17 +72,54 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
raise ValueError("output_len too small")

# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
if os.path.splitext(dataset_path)[1] == ".json":
with open(dataset_path) as f:
dataset = json.load(f)
elif os.path.splitext(dataset_path)[1] == ".pkl":
import pandas as pd
dataset = pd.read_pickle(dataset_path)
dataset = dataset[['input', 'output']].to_dict(orient="records")
for data in dataset:
data["conversations"] = [
{"value": data["input"]},
{"value": data["output"]}
]

# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Shuffle the dataset.
random.shuffle(dataset)
#random.shuffle(dataset)

# Filter out sequences that are too long or too short
filtered_dataset: List[SampleRequest] = []
for data in dataset:
if len(filtered_dataset) == num_requests:
if args.sort_by_len:
filtered_dataset = sorted(filtered_dataset, key=lambda x: x.prompt_len)
if args.bucket_selective:
length_map = {}
for i, request in enumerate(filtered_dataset):
length_map.setdefault(request.prompt_len, []).append(i)
ret = {}
for length, indices in length_map.items():
bucket_size = (int(length / 128) + 1) * 128
while len(indices) > 0:
i = indices.pop(0)
if ret.get(bucket_size, None) is None:
ret[bucket_size] = []
ret[bucket_size].append(filtered_dataset[i])
remain_len = bucket_size - length
while remain_len > 0:
if length_map.get(remain_len, None) is not None and len(length_map[remain_len]) > 0:
j = length_map[remain_len].pop(0)
ret[bucket_size].append(filtered_dataset[j])
break
else:
remain_len -= 1
# sort ret by key
ret = dict(sorted(ret.items(), key=lambda x: x[0]))
print("!!!!!!!!!!!!!!!sorted requests:", [(bucket_size, [i.prompt_len for i in req_list]) for bucket_size, req_list in ret.items()])
filtered_dataset = [req for data in ret.items() for req in data[1]]
break

# Only keep the first two turns of each conversation.
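For readers skimming the diff: the new --bucket-selective path above sorts requests by prompt length and then tries to pack each prompt together with at most one complementary-length prompt, so that the pair fills a bucket rounded up to a multiple of 128 tokens. A minimal standalone sketch of that packing idea, using plain integers instead of SampleRequest objects (the helper name and sizes below are illustrative only, not part of the PR):

```python
from collections import defaultdict

def bucket_selective(prompt_lens, bucket_step=128):
    """Pack each prompt with at most one complementary-length prompt so that
    every group fills a bucket rounded up to a multiple of bucket_step."""
    length_map = defaultdict(list)
    for i, plen in enumerate(prompt_lens):
        length_map[plen].append(i)

    buckets = defaultdict(list)
    for length, indices in list(length_map.items()):
        bucket_size = (length // bucket_step + 1) * bucket_step
        while indices:
            i = indices.pop(0)
            buckets[bucket_size].append(prompt_lens[i])
            # Greedily look for one prompt whose length fills the leftover space.
            remain = bucket_size - length
            while remain > 0:
                if length_map.get(remain):
                    j = length_map[remain].pop(0)
                    buckets[bucket_size].append(prompt_lens[j])
                    break
                remain -= 1
    return dict(sorted(buckets.items()))

print(bucket_selective([100, 28, 200, 56, 300]))
# {128: [100, 28], 256: [200, 56], 384: [300]}
```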
@@ -121,6 +159,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
expected_output_len=output_len,
multi_modal_data=multi_modal_data))

# for i, data in enumerate(filtered_dataset):
# print(i, data.prompt)
return filtered_dataset


@@ -151,7 +191,7 @@ def run_vllm(
use_beam_search = False

if not use_beam_search:
for _ in range(2):
for _ in range(3):
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
@@ -445,6 +485,12 @@ def main(args: argparse.Namespace):
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
parser.add_argument("--sort-by-len",
action='store_true',
default=False)
parser.add_argument("--bucket-selective",
action='store_true',
default=False)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
109 changes: 94 additions & 15 deletions vllm/attention/backends/hpu_attn.py
@@ -10,6 +10,7 @@
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
VLLMKVCache)
from vllm_hpu_extension.cache_ops import insert_or_update_cache

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
@@ -18,6 +19,7 @@
HPUPagedAttentionMetadata)
from vllm.logger import init_logger
from vllm.utils import is_fake_hpu
from vllm.model_executor.models.utils import split_and_pad_to_length

logger = init_logger(__name__)

@@ -29,6 +31,55 @@
logger.warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")

def prompt_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor] = None,
p: float = 0.0,
scale: Optional[float] = None,
matmul_qk_op=torch.matmul,
softmax_op=torch.softmax,
matmul_av_op=torch.matmul,
valid_seq_lengths: Optional[torch.Tensor] = None,
fsdpa_op = None,
) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
query_heads = query.size(1)
kv_heads = key.size(1)
#if attn_bias is not None or fsdpa_op is None:
if fsdpa_op is None:
if query_heads != kv_heads:
query = query.unflatten(1, (kv_heads, -1))
key = key.unflatten(1, (kv_heads, 1))
value = value.unflatten(1, (kv_heads, 1))
if attn_bias is not None:
attn_bias = attn_bias.unsqueeze(1)
attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
if attn_bias is not None:
attn_weights.add_(attn_bias)
attn_weights = softmax_op(attn_weights, dim=-1)
attn_weights = matmul_av_op(attn_weights, value)
if query_heads != kv_heads:
attn_weights = attn_weights.flatten(1, 2)
else:
VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get('VLLM_REMOVE_REPEAT_KV_CACHE', '1') == '1'
# TODO: remove after fusedsdpa fix for query_heads != kv_heads
if query_heads != kv_heads:
if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE:
key = ops.repeat_kv(key, int(query_heads // kv_heads))
value = ops.repeat_kv(value, int(query_heads // kv_heads))
if attn_bias is not None:
attn_bias = attn_bias.unsqueeze(1)
softmax_mode = 'fast'
recompute_mode = True
attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, False,
scale, softmax_mode, recompute_mode,
None, 'right')
attn_weights = attn_weights.transpose(1, 2)
return attn_weights
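The prompt_attention above appears to adapt the extension's ops.prompt_attention so that FusedSDPA is used even when an attention bias is present (note the commented-out condition). In the non-FusedSDPA fallback, grouped-query attention is handled by unflattening the query heads under their shared KV head, so the matmuls broadcast instead of materializing repeated K/V tensors. A self-contained sketch of that broadcasting trick, with made-up shapes and no attention mask:

```python
import torch

# Illustrative sizes only: 8 query heads sharing 2 KV heads.
batch, seq, q_heads, kv_heads, head_dim = 2, 16, 8, 2, 64
q = torch.randn(batch, q_heads, seq, head_dim)   # layout after transpose(1, 2)
k = torch.randn(batch, kv_heads, seq, head_dim)
v = torch.randn(batch, kv_heads, seq, head_dim)

# Group query heads under their KV head: [B, kvH, G, S, D] vs [B, kvH, 1, S, D],
# so the batched matmul broadcasts over the group dimension G.
q_grouped = q.unflatten(1, (kv_heads, q_heads // kv_heads))
k_grouped = k.unflatten(1, (kv_heads, 1))
v_grouped = v.unflatten(1, (kv_heads, 1))

scale = head_dim ** -0.5
scores = torch.matmul(q_grouped * scale, k_grouped.transpose(-1, -2))
probs = torch.softmax(scores, dim=-1)
out = torch.matmul(probs, v_grouped).flatten(1, 2)  # back to [B, qH, S, D]
print(out.shape)  # torch.Size([2, 8, 16, 64])
```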

class HPUAttentionBackend(AttentionBackend):

@@ -83,6 +134,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
attn_bias: Optional[torch.Tensor]
seq_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor]
enable_merged_prefill: bool = False
seq_lens: Optional[List[int]] = None
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
@@ -213,6 +265,7 @@ def forward(
block_offsets = kwargs.get('block_offsets', None)
seq_lens_tensor = kwargs.get('seq_lens_tensor', None)
attn_bias = kwargs.get('attn_bias', None)
enable_merged_prefill = attn_metadata.enable_merged_prefill
if block_indices is None:
block_indices = attn_metadata.block_indices
if block_offsets is None:
Expand All @@ -221,20 +274,40 @@ def forward(
seq_lens_tensor = attn_metadata.seq_lens_tensor
if attn_bias is None: # This is the case for prompt run
attn_bias = attn_metadata.attn_bias
if attn_metadata.is_prompt:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key, key_cache, block_indices,
block_offsets)
value_cache = self.v_cache(value, value_cache, block_indices,
block_offsets)
if enable_merged_prefill:
if attn_metadata.is_prompt:
max_len=attn_metadata.slot_mapping.size(1)
seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist()
# Copy the key and value tensors into padded tensors of shape
# [batch_size, padded_seq_len, num_kv_heads, head_size].
padded_key_tensor = split_and_pad_to_length(key, max_len, seq_lens_tensor_list)
padded_value_tensor = split_and_pad_to_length(value, max_len, seq_lens_tensor_list)
padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1))
padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

key_cache = self.k_cache(padded_key_tensor, key_cache, block_indices,
block_offsets)
value_cache = self.v_cache(padded_value_tensor, value_cache, block_indices,
block_offsets)

Review comments on this hunk:

yangw1234 (Jan 6, 2025): Is it possible to get rid of these lines if we prepare block_indices and block_offsets in a way that excludes the padded tokens?

yangw1234: It seems that when decoding, padded_key_tensor is not defined. Would this be a problem?

Author: I think you're right, but it didn't trigger any error. I'll look into it.

Author: @yangw1234, after checking the code, since enable_merged_prefill is only enabled in prefill_fwd, I'll clean up the code to make it more readable.
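To make the first review suggestion concrete: instead of padding key/value and writing every slot, the cache write could be driven by block_indices/block_offsets computed only over the real tokens of each merged sequence. A rough, hypothetical sketch of that index construction (not part of this PR; a real implementation would derive the flat slots from the block tables rather than from batch position):

```python
import torch

def valid_token_slots(seq_lens, max_len, block_size=128):
    # Flat slot index for every real (unpadded) token when sequences are
    # laid out as [batch, max_len]; padded positions are simply skipped.
    slots = [torch.arange(i * max_len, i * max_len + n) for i, n in enumerate(seq_lens)]
    flat = torch.cat(slots)
    return flat // block_size, flat % block_size  # block_indices, block_offsets

block_indices, block_offsets = valid_token_slots([5, 3], max_len=8, block_size=4)
print(block_indices.tolist())  # [0, 0, 0, 0, 1, 2, 2, 2]
print(block_offsets.tolist())  # [0, 1, 2, 3, 0, 0, 1, 2]
```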
else:
if attn_metadata.is_prompt:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key, key_cache, block_indices,
block_offsets)
value_cache = self.v_cache(value, value_cache, block_indices,
block_offsets)

if attn_metadata.is_prompt:
# Prompt run.
@@ -253,10 +326,16 @@ def forward(
attn_bias = attn_bias.tile(
(1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
elif enable_merged_prefill:
pass
else:
attn_bias = None

out = ops.prompt_attention(
if enable_merged_prefill:
prompt_attn_func = prompt_attention
else:
prompt_attn_func = ops.prompt_attention
out = prompt_attn_func(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
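Note that in the merged-prefill branch the attention bias is taken from attn_metadata.attn_bias unchanged (the "elif enable_merged_prefill: pass" above), so the mask is presumably prepared by the model runner: several prompts are packed into one sequence, and each prompt may only attend causally to its own tokens. A minimal sketch of such a block-diagonal causal bias (illustrative sizes, hypothetical helper, head and batch dimensions omitted):

```python
import torch

def merged_prefill_bias(seq_lens, max_len, dtype=torch.float32):
    # -inf everywhere, then open a causal (lower-triangular) window per prompt.
    bias = torch.full((max_len, max_len), float("-inf"), dtype=dtype)
    start = 0
    for n in seq_lens:
        block = torch.triu(torch.full((n, n), float("-inf"), dtype=dtype), diagonal=1)
        bias[start:start + n, start:start + n] = block
        start += n
    return bias

print(merged_prefill_bias([2, 3], max_len=6))
# Rows 0-1 attend only within prompt 0, rows 2-4 only within prompt 1,
# and row 5 is padding (fully masked).
```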
8 changes: 7 additions & 1 deletion vllm/model_executor/models/llama.py
@@ -56,7 +56,7 @@
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
maybe_prefix, split_and_pad_to_length)

is_hpu = current_platform.is_hpu()

@@ -490,6 +490,12 @@ def forward(
"residual": residual
})

# For merged prefill, split the packed result back into per-sequence rows before applying the final RMSNorm.
if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt:
max_len=attn_metadata.slot_mapping.size(1)
seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist()
hidden_states = split_and_pad_to_length(hidden_states.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list)
residual = split_and_pad_to_length(residual.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

23 changes: 23 additions & 0 deletions vllm/model_executor/models/utils.py
@@ -666,3 +666,26 @@ def extract_layer_index(layer_name: str) -> int:
assert len(int_vals) == 1, (f"layer name {layer_name} should"
" only contain one integer")
return int_vals[0]

def split_and_pad_to_length(input, target_length, seq_lens_tensor_list):
# Split a packed [total_tokens, ...] tensor into per-sequence chunks of the given
# lengths and pad each chunk to target_length, returning [batch_size, target_length, ...].
padded_list = torch.split_with_sizes(input[:sum(seq_lens_tensor_list)], seq_lens_tensor_list, dim=0)

padded_tensor = torch.nn.utils.rnn.pad_sequence(padded_list, batch_first=True)
pad_shape = [0] * (input.dim() - 1) * 2
pad_shape += [0, target_length - padded_tensor.size(1)]
padded_tensor = torch.nn.functional.pad(padded_tensor, pad_shape, value=0)
return padded_tensor

def split_and_pad_to_length_2(input, target_length, seq_lens_tensor_list):
# Loop-based reference variant of split_and_pad_to_length: copies each sequence
# into a pre-allocated [batch_size, target_length, num_kv_heads, head_size] tensor.
padded_tensor = torch.zeros((len(seq_lens_tensor_list), target_length, input.size(1), input.size(2)), device=input.device, dtype=input.dtype)

start = 0
for i in range(len(seq_lens_tensor_list)):
padded_tensor[i, :seq_lens_tensor_list[i], :, :] = input[start: start + seq_lens_tensor_list[i], :, :]
start = start + seq_lens_tensor_list[i]

return padded_tensor
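A quick sanity check of the two helpers above (shapes chosen arbitrarily for illustration); both variants should produce the same [batch, target_length, heads, head_size] tensor with zeros in the padded positions:

```python
import torch

# Packed tokens from three sequences of lengths 3, 1 and 2 (plus 2 trailing
# padding rows), each token being a [num_kv_heads=2, head_size=4] slice.
seq_lens = [3, 1, 2]
packed = torch.arange(8 * 2 * 4, dtype=torch.float32).reshape(8, 2, 4)

out1 = split_and_pad_to_length(packed, 4, seq_lens)
out2 = split_and_pad_to_length_2(packed, 4, seq_lens)

print(out1.shape)                      # torch.Size([3, 4, 2, 4])
print(torch.equal(out1, out2))         # True
print(out1[1, 1:].abs().sum().item())  # 0.0 -> positions past seq_len are zero
```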