Skip to content

Commit

Permalink
move tolist to llamamodel fwd
Browse files Browse the repository at this point in the history
Signed-off-by: Chendi Xue <[email protected]>
  • Loading branch information
xuechendi committed Jan 7, 2025
1 parent fade386 commit d18df2d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def forward(
block_offsets = kwargs.get('block_offsets', None)
seq_lens_tensor = kwargs.get('seq_lens_tensor', None)
attn_bias = kwargs.get('attn_bias', None)
seq_lens_tensor_list = kwargs.get('seq_lens_tensor_list', None)
enable_merged_prefill = attn_metadata.enable_merged_prefill
if block_indices is None:
block_indices = attn_metadata.block_indices
Expand All @@ -279,7 +280,6 @@ def forward(
attn_bias = attn_metadata.attn_bias
if enable_merged_prefill and attn_metadata.is_prompt and kv_cache is not None:
max_len = attn_metadata.slot_mapping.size(1)
seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist()
# we need to copy the key and value tensors to the padded tensors
# shape is [bacth_size, entire_seq_len, num_kv_heads, head_size]
padded_key_tensor = split_and_pad_to_length(
Expand Down
11 changes: 8 additions & 3 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
seq_lens_tensor_list: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
if isinstance(hidden_states, torch.Tensor):
skip_split = hidden_states.size()[0] == 1
Expand All @@ -313,7 +314,8 @@ def forward(
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
attn_metadata=attn_metadata,
seq_lens_tensor_list=seq_lens_tensor_list)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
Expand Down Expand Up @@ -479,11 +481,15 @@ def forward(
import habana_frameworks.torch as htorch
htorch.core.mark_step()

if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt:
seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist()
else:
seq_lens_tensor_list = None
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
attn_metadata, residual, seq_lens_tensor_list)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
Expand All @@ -493,7 +499,6 @@ def forward(
# we need to split result before do RMSNorm
if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt:
max_len=attn_metadata.slot_mapping.size(1)
seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist()
hidden_states = split_and_pad_to_length(hidden_states.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list)
residual = split_and_pad_to_length(residual.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list)
hidden_states, _ = self.norm(hidden_states, residual)
Expand Down

0 comments on commit d18df2d

Please sign in to comment.