
Commit

use index_put with full block_indices
Signed-off-by: Chendi.Xue <[email protected]>
xuechendi committed Jan 8, 2025
1 parent 612abed commit ce48860
Showing 3 changed files with 35 additions and 78 deletions.
22 changes: 4 additions & 18 deletions vllm/attention/backends/hpu_attn.py
@@ -268,7 +268,6 @@ def forward(
block_offsets = kwargs.get('block_offsets', None)
seq_lens_tensor = kwargs.get('seq_lens_tensor', None)
attn_bias = kwargs.get('attn_bias', None)
- seq_lens_tensor_list = kwargs.get('seq_lens_tensor_list', None)
enable_merged_prefill = attn_metadata.enable_merged_prefill
if block_indices is None:
block_indices = attn_metadata.block_indices
@@ -278,25 +277,12 @@ def forward(
seq_lens_tensor = attn_metadata.seq_lens_tensor
if attn_bias is None: # This is the case for prompt run
attn_bias = attn_metadata.attn_bias
- if enable_merged_prefill and attn_metadata.is_prompt and kv_cache is not None:
- max_len = attn_metadata.slot_mapping.size(1)
- # we need to copy the key and value tensors to the padded tensors
- # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size]
- padded_key_tensor = split_and_pad_to_length(
- key, max_len, seq_lens_tensor_list)
- padded_value_tensor = split_and_pad_to_length(
- value, max_len, seq_lens_tensor_list)
- padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten(
- 0, (block_indices.size(0), -1))
- padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten(
- 0, (block_indices.size(0), -1))

+ if enable_merged_prefill and attn_metadata.is_prompt and kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

- key_cache = self.k_cache(padded_key_tensor, key_cache,
- kv_cache, self.num_kv_heads, self.head_size)
+ key_cache = self.k_cache(key, key_cache,
+ block_indices, block_offsets)
- value_cache = self.v_cache(padded_value_tensor, value_cache,
+ value_cache = self.v_cache(value, value_cache,
block_indices, block_offsets)
else:
if attn_metadata.is_prompt:
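In short, the merged-prefill prompt path no longer pads key/value per sequence and scatters whole blocks; as far as the diff shows, it writes the flat key/value tensors straight into the paged KV cache using per-token block indices and offsets (the "full block_indices" of the commit title). A rough sketch of that write, assuming a cache of shape [num_blocks, block_size, num_kv_heads, head_size]; the helper below is illustrative only, not the actual VLLMKVCache/HPU kernel:

import torch

def write_kv_to_paged_cache(kv, cache, slot_mapping, block_size):
    # kv: [num_tokens, num_kv_heads, head_size]
    # cache: [num_blocks, block_size, num_kv_heads, head_size]
    # slot_mapping: [num_tokens] flat slot ids (no padding entries)
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offsets = torch.fmod(slot_mapping, block_size)
    # index_put_ scatters every token into its (block, offset) slot in one call,
    # so no per-sequence padding of key/value is needed beforehand.
    cache.index_put_((block_indices, block_offsets), kv)
    return cache

# Example: 6 tokens from a merged prompt written into a 4-block cache (block_size=2).
cache = torch.zeros(4, 2, 1, 8)
kv = torch.randn(6, 1, 8)
slots = torch.tensor([0, 1, 2, 3, 4, 5])
write_kv_to_paged_cache(kv, cache, slots, block_size=2)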
15 changes: 2 additions & 13 deletions vllm/model_executor/models/llama.py
@@ -292,7 +292,6 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
- seq_lens_tensor_list: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
if isinstance(hidden_states, torch.Tensor):
skip_split = hidden_states.size()[0] == 1
@@ -314,8 +313,7 @@ def forward(
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
- attn_metadata=attn_metadata,
- seq_lens_tensor_list=seq_lens_tensor_list)
+ attn_metadata=attn_metadata)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
@@ -481,26 +479,17 @@ def forward(
import habana_frameworks.torch as htorch
htorch.core.mark_step()

- if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt:
- seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist()
- else:
- seq_lens_tensor_list = None
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
- attn_metadata, residual, seq_lens_tensor_list)
+ attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})

- # we need to split result before do RMSNorm
- if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt:
- max_len=attn_metadata.slot_mapping.size(1)
- hidden_states = split_and_pad_to_length(hidden_states.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list)
- residual = split_and_pad_to_length(residual.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

76 changes: 29 additions & 47 deletions vllm/worker/hpu_model_runner.py
@@ -214,7 +214,6 @@ def generate_prompt_buckets(self):

prompt_bs_bucket_cfg = self.global_state.prompt_bs_bucket_cfg
prompt_seq_bucket_cfg = self.global_state.prompt_seq_bucket_cfg
print("prompt_seq_bucket_cfg: ", prompt_seq_bucket_cfg)
origin_max_prompt_len = prompt_seq_bucket_cfg[2]
max_prompt_len = prompt_bs_bucket_cfg[2] * prompt_seq_bucket_cfg[2]
max_prompt_len = min(self.max_num_batched_tokens, max_prompt_len)
@@ -226,23 +225,7 @@ def generate_prompt_buckets(self):
prompt_seq_bucket_cfg,
self.max_num_batched_tokens)

print("prompt_buckets: ", prompt_buckets)
# expand
self.global_state.prompt_buckets = []
VLLM_PROMPT_BS_BUCKET_MAX = int(
os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', 16))
for bucket in prompt_buckets:
bs = 1
while bs <= VLLM_PROMPT_BS_BUCKET_MAX:
seq_len = bucket[1] // bs
if seq_len <= 32:
bs = bs * 2
continue
self.global_state.prompt_buckets.append(
(bs * bucket[0], seq_len))
bs = bs * 2

self.global_state.prompt_buckets = list(filter(lambda bucket: bucket[1] <= origin_max_prompt_len, self.global_state.prompt_buckets))
self.global_state.prompt_buckets = list(filter(lambda bucket: bucket[1] <= origin_max_prompt_len and bucket[0] == 1, prompt_buckets))

msg = (f"Generated {len(self.global_state.prompt_buckets)} "
f"prompt buckets [bs, seq]: "
@@ -425,13 +408,12 @@ def _set_block_scales(self, metadata, device):
return metadata

def _set_indices_and_offsets(self, metadata, block_size, is_prompt):
- slot_mapping = metadata.slot_mapping.flatten()
- indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
- if is_prompt:
- indices = indices.unflatten(0, (-1, block_size))[:, 0]
- offsets = None
+ if metadata.enable_merged_prefill and is_prompt:
+ slot_mapping = metadata.slot_mapping
else:
- offsets = torch.fmod(slot_mapping, block_size)
+ slot_mapping = metadata.slot_mapping.flatten()
+ indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+ offsets = torch.fmod(slot_mapping, block_size)
metadata = metadata._replace(block_offsets=offsets,
block_indices=indices)
return metadata
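After this change, merged-prefill prompts keep the 2-D slot_mapping and derive per-token block indices and offsets from it, instead of the old prompt-only shortcut (one index per block, offsets=None). A toy comparison of the two computations, assuming the new path derives offsets the same way the decode path does:

import torch

block_size = 4
slot_mapping = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]])  # one merged prompt row

# Old prompt path: one index per block, whole blocks written at once.
flat = slot_mapping.flatten()
old_indices = torch.div(flat, block_size, rounding_mode="floor")
old_indices = old_indices.unflatten(0, (-1, block_size))[:, 0]  # tensor([0, 1])
old_offsets = None

# New merged-prefill path: full per-token indices plus offsets for index_put.
new_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")  # [[0,0,0,0,1,1,1,1]]
new_offsets = torch.fmod(slot_mapping, block_size)                        # [[0,1,2,3,0,1,2,3]]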
@@ -481,8 +463,7 @@ def forward(self, *args, **kwargs):
self._prepare_cos_sin(kwargs['positions'])
if kwargs['attn_metadata'].is_prompt:
print("Warming up HPU Graph - input_ids: ", input_ids.shape,
"seq_lens_tensor: ", kwargs['attn_metadata'].seq_lens_tensor,
"selected_token_indices: ", selected_token_indices)
"seq_lens_tensor: ", kwargs['attn_metadata'].seq_lens_tensor.shape, 'slot_mapping: ', kwargs['attn_metadata'].slot_mapping.shape, 'selected_token_indices: ', selected_token_indices)
hidden_states = self.model(*args, **kwargs)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = hidden_states.index_select(0, selected_token_indices)
@@ -1244,13 +1225,14 @@ def _prepare_prompt_merged(
#context_lens
#prefix_block_list

+ slot_mapping_merged = list(itertools.chain.from_iterable(slot_mapping))
+ slot_mapping_merged = [i for i in slot_mapping_merged if i != _PAD_SLOT_ID]
+ slot_mapping = [slot_mapping_merged]
input_tokens_merged = list(itertools.chain.from_iterable(input_tokens))
input_tokens_merged = [input_tokens_merged]
input_positions_merged = list(
itertools.chain.from_iterable(input_positions))
input_positions_merged = [input_positions_merged]
- slot_mapping_merged = list(itertools.chain.from_iterable(slot_mapping))
- slot_mapping_merged = [slot_mapping_merged]
context_lens_merged = [sum(context_lens)]
total_seq_lens = [sum(seq_lens)]
total_query_lens = [sum(query_lens)]
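The merged slot_mapping is now built up front by chaining the per-sequence slot lists and dropping padding entries, so the single merged row contains only real slots. A toy illustration (the _PAD_SLOT_ID value and slot numbers are placeholders):

import itertools

_PAD_SLOT_ID = -1
slot_mapping = [[0, 1, 2, _PAD_SLOT_ID], [8, 9, _PAD_SLOT_ID, _PAD_SLOT_ID]]

merged = list(itertools.chain.from_iterable(slot_mapping))
merged = [i for i in merged if i != _PAD_SLOT_ID]
slot_mapping = [merged]  # [[0, 1, 2, 8, 9]] -- one row, no pad slots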
@@ -1284,11 +1266,15 @@ def _prepare_prompt_merged(
device='cpu')

slot_mapping = make_tensor_with_pad(slot_mapping,
- max_len=max_prompt_len,
+ max_len=merged_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.long,
device='cpu')

+ max_prefill_bs = int(os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', '16'))
+ max_prefill_bs = max(max_prefill_bs, len(seq_lens))
+ seq_lens = seq_lens + [0] * (max_prefill_bs - len(seq_lens))
+ context_lens = context_lens + [0] * (max_prefill_bs - len(context_lens))
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.long,
device='cpu')
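seq_lens and context_lens are now zero-padded up to the prompt batch-size bucket before being turned into tensors; presumably this keeps seq_lens_tensor at a fixed shape across merged-prefill batches for HPU graph reuse (that motivation is an assumption). A quick example with invented lengths:

import os
import torch

seq_lens = [37, 12, 90]
context_lens = [0, 0, 0]

max_prefill_bs = int(os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', '16'))
max_prefill_bs = max(max_prefill_bs, len(seq_lens))
seq_lens = seq_lens + [0] * (max_prefill_bs - len(seq_lens))
context_lens = context_lens + [0] * (max_prefill_bs - len(context_lens))

seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long, device='cpu')
# seq_lens_tensor.shape == torch.Size([16]) for any batch of up to 16 prompts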
@@ -1522,7 +1508,6 @@ def prepare_input_tensors(
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[TModelInputForHPU, SamplingMetadata]:
if len(seq_group_metadata_list) == 0:
print("seq_group_metadata_list is empty")
return self._model_input_cls(), None

input_tokens = None
@@ -1561,7 +1546,6 @@ def prepare_input_tensors(
decode_reqs.append(seq_group_meta)

# Prepare input tensors.
#print("prefill_reqs: ", prefill_reqs, "decode_reqs: ", decode_reqs)
prepare_prompt_impl = self._prepare_prompt_merged if self.enable_merged_prefill else self._prepare_prompt
(
input_tokens,
@@ -1615,23 +1599,21 @@ def prepare_input_tensors(

# FIXME: We need to adjust selected_token_indices to accommodate
# for padding
- if self.enable_merged_prefill:
- max_len = slot_mapping.size(1)
- else:
+ if not self.enable_merged_prefill:
max_len = input_tokens.size(1)
- paddings = [max_len - q for q in query_lens]
- paddings = [0] + paddings[:-1]
- paddings = list(itertools.accumulate(paddings))
- paddings_prompt_logprobs = []
- for i, seq_group_metadata in enumerate(seq_group_metadata_list):
- if seq_group_metadata.sampling_params.prompt_logprobs is not None \
- and seq_group_metadata.is_prompt:
- paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i])
- paddings = torch.tensor(
- paddings_prompt_logprobs if paddings_prompt_logprobs else paddings,
- dtype=sampling_metadata.selected_token_indices.dtype,
- device=sampling_metadata.selected_token_indices.device)
- sampling_metadata.selected_token_indices.add_(paddings)
+     paddings = [max_len - q for q in query_lens]
+     paddings = [0] + paddings[:-1]
+     paddings = list(itertools.accumulate(paddings))
+     paddings_prompt_logprobs = []
+     for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+     if seq_group_metadata.sampling_params.prompt_logprobs is not None \
+     and seq_group_metadata.is_prompt:
+     paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i])
+     paddings = torch.tensor(
+     paddings_prompt_logprobs if paddings_prompt_logprobs else paddings,
+     dtype=sampling_metadata.selected_token_indices.dtype,
+     device=sampling_metadata.selected_token_indices.device)
+     sampling_metadata.selected_token_indices.add_(paddings)

if self.lora_config:
lora_mapping = LoRAMapping(
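Finally, the selected_token_indices adjustment now runs only when merged prefill is disabled: on the regular path each sequence is padded to max_len, so the flattened index of a sequence's last token has to be shifted by the accumulated padding of the sequences before it, while the merged path already packs query tokens back to back (that reading is an inference from the diff). A toy walk-through of the adjustment on the regular path, with invented lengths:

import itertools
import torch

max_len = 8                      # padded per-sequence length (input_tokens.size(1))
query_lens = [3, 5, 2]           # real query lengths

paddings = [max_len - q for q in query_lens]     # [5, 3, 6]
paddings = [0] + paddings[:-1]                   # [0, 5, 3]
paddings = list(itertools.accumulate(paddings))  # [0, 5, 8]

# Last-token indices in the unpadded, concatenated layout...
selected_token_indices = torch.tensor([2, 7, 9])
# ...shifted to the matching positions in the padded [batch, max_len] layout.
selected_token_indices += torch.tensor(paddings)  # tensor([ 2, 12, 17])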
