diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 82d6e5e3f3225..b29d95c756dbe 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -268,7 +268,6 @@ def forward(
         block_offsets = kwargs.get('block_offsets', None)
         seq_lens_tensor = kwargs.get('seq_lens_tensor', None)
         attn_bias = kwargs.get('attn_bias', None)
-        seq_lens_tensor_list = kwargs.get('seq_lens_tensor_list', None)
         enable_merged_prefill = attn_metadata.enable_merged_prefill
         if block_indices is None:
             block_indices = attn_metadata.block_indices
@@ -278,25 +277,12 @@ def forward(
             seq_lens_tensor = attn_metadata.seq_lens_tensor
         if attn_bias is None:  # This is the case for prompt run
             attn_bias = attn_metadata.attn_bias
-        if enable_merged_prefill and attn_metadata.is_prompt and kv_cache is not None:
-            max_len = attn_metadata.slot_mapping.size(1)
-            # we need to copy the key and value tensors to the padded tensors
-            # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size]
-            padded_key_tensor = split_and_pad_to_length(
-                key, max_len, seq_lens_tensor_list)
-            padded_value_tensor = split_and_pad_to_length(
-                value, max_len, seq_lens_tensor_list)
-            padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten(
-                0, (block_indices.size(0), -1))
-            padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten(
-                0, (block_indices.size(0), -1))
-
+        if enable_merged_prefill and attn_metadata.is_prompt and kv_cache is not None:
             key_cache, value_cache = HPUPagedAttention.split_kv_cache(
-                kv_cache, self.num_kv_heads, self.head_size)
-
-            key_cache = self.k_cache(padded_key_tensor, key_cache,
+                kv_cache, self.num_kv_heads, self.head_size)
+            key_cache = self.k_cache(key, key_cache,
                                      block_indices, block_offsets)
-            value_cache = self.v_cache(padded_value_tensor, value_cache,
+            value_cache = self.v_cache(value, value_cache,
                                        block_indices, block_offsets)
         else:
             if attn_metadata.is_prompt:
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index e0800010db7ee..b44b7417b03ac 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -292,7 +292,6 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
-        seq_lens_tensor_list: List[int],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if isinstance(hidden_states, torch.Tensor):
             skip_split = hidden_states.size()[0] == 1
@@ -314,8 +313,7 @@ def forward(
             hidden_states = self.self_attn(positions=positions,
                                            hidden_states=hidden_states,
                                            kv_cache=kv_cache,
-                                           attn_metadata=attn_metadata,
-                                           seq_lens_tensor_list=seq_lens_tensor_list)
+                                           attn_metadata=attn_metadata)
 
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
@@ -481,26 +479,17 @@ def forward(
             import habana_frameworks.torch as htorch
             htorch.core.mark_step()
-        if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt:
-            seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist()
-        else:
-            seq_lens_tensor_list = None
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
                                             kv_caches[i - self.start_layer],
-                                            attn_metadata, residual, seq_lens_tensor_list)
+                                            attn_metadata, residual)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
-        # we need to split result before do RMSNorm
-        if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt:
-            max_len=attn_metadata.slot_mapping.size(1)
-            hidden_states = split_and_pad_to_length(hidden_states.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list)
-            residual = split_and_pad_to_length(residual.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list)
 
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index b0dac110b8ac3..b9466c4f56a31 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -214,7 +214,6 @@ def generate_prompt_buckets(self):
         prompt_bs_bucket_cfg = self.global_state.prompt_bs_bucket_cfg
         prompt_seq_bucket_cfg = self.global_state.prompt_seq_bucket_cfg
-        print("prompt_seq_bucket_cfg: ", prompt_seq_bucket_cfg)
         origin_max_prompt_len = prompt_seq_bucket_cfg[2]
         max_prompt_len = prompt_bs_bucket_cfg[2] * prompt_seq_bucket_cfg[2]
         max_prompt_len = min(self.max_num_batched_tokens, max_prompt_len)
@@ -226,23 +225,7 @@ def generate_prompt_buckets(self):
             prompt_seq_bucket_cfg,
             self.max_num_batched_tokens)
-        print("prompt_buckets: ", prompt_buckets)
-        # expand
-        self.global_state.prompt_buckets = []
-        VLLM_PROMPT_BS_BUCKET_MAX = int(
-            os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', 16))
-        for bucket in prompt_buckets:
-            bs = 1
-            while bs <= VLLM_PROMPT_BS_BUCKET_MAX:
-                seq_len = bucket[1] // bs
-                if seq_len <= 32:
-                    bs = bs * 2
-                    continue
-                self.global_state.prompt_buckets.append(
-                    (bs * bucket[0], seq_len))
-                bs = bs * 2
-
-        self.global_state.prompt_buckets = list(filter(lambda bucket: bucket[1] <= origin_max_prompt_len, self.global_state.prompt_buckets))
+        self.global_state.prompt_buckets = list(filter(lambda bucket: bucket[1] <= origin_max_prompt_len and bucket[0] == 1, prompt_buckets))
 
         msg = (f"Generated {len(self.global_state.prompt_buckets)} "
                f"prompt buckets [bs, seq]: "
@@ -425,13 +408,12 @@ def _set_block_scales(self, metadata, device):
         return metadata
 
     def _set_indices_and_offsets(self, metadata, block_size, is_prompt):
-        slot_mapping = metadata.slot_mapping.flatten()
-        indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
-        if is_prompt:
-            indices = indices.unflatten(0, (-1, block_size))[:, 0]
-            offsets = None
+        if metadata.enable_merged_prefill and is_prompt:
+            slot_mapping = metadata.slot_mapping
         else:
-            offsets = torch.fmod(slot_mapping, block_size)
+            slot_mapping = metadata.slot_mapping.flatten()
+        indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+        offsets = torch.fmod(slot_mapping, block_size)
         metadata = metadata._replace(block_offsets=offsets,
                                      block_indices=indices)
         return metadata
@@ -481,8 +463,7 @@ def forward(self, *args, **kwargs):
             self._prepare_cos_sin(kwargs['positions'])
         if kwargs['attn_metadata'].is_prompt:
             print("Warming up HPU Graph - input_ids: ", input_ids.shape,
-                  "seq_lens_tensor: ", kwargs['attn_metadata'].seq_lens_tensor,
-                  "selected_token_indices: ", selected_token_indices)
+                  "seq_lens_tensor: ", kwargs['attn_metadata'].seq_lens_tensor.shape, 'slot_mapping: ', kwargs['attn_metadata'].slot_mapping.shape, 'selected_token_indices: ', selected_token_indices)
         hidden_states = self.model(*args, **kwargs)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         hidden_states = hidden_states.index_select(0, selected_token_indices)
@@ -1244,13 +1225,14 @@ def _prepare_prompt_merged(
 
         #context_lens
         #prefix_block_list
+        slot_mapping_merged = list(itertools.chain.from_iterable(slot_mapping))
+        slot_mapping_merged = [i for i in slot_mapping_merged if i != _PAD_SLOT_ID]
+        slot_mapping = [slot_mapping_merged]
         input_tokens_merged = list(itertools.chain.from_iterable(input_tokens))
         input_tokens_merged = [input_tokens_merged]
         input_positions_merged = list(
             itertools.chain.from_iterable(input_positions))
         input_positions_merged = [input_positions_merged]
-        slot_mapping_merged = list(itertools.chain.from_iterable(slot_mapping))
-        slot_mapping_merged = [slot_mapping_merged]
         context_lens_merged = [sum(context_lens)]
         total_seq_lens = [sum(seq_lens)]
         total_query_lens = [sum(query_lens)]
@@ -1284,11 +1266,15 @@ def _prepare_prompt_merged(
                                                device='cpu')
 
         slot_mapping = make_tensor_with_pad(slot_mapping,
-                                            max_len=max_prompt_len,
+                                            max_len=merged_prompt_len,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long,
                                             device='cpu')
 
+        max_prefill_bs = int(os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', '16'))
+        max_prefill_bs = max(max_prefill_bs, len(seq_lens))
+        seq_lens = seq_lens + [0] * (max_prefill_bs - len(seq_lens))
+        context_lens = context_lens + [0] * (max_prefill_bs - len(context_lens))
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.long,
                                        device='cpu')
@@ -1522,7 +1508,6 @@ def prepare_input_tensors(
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[TModelInputForHPU, SamplingMetadata]:
         if len(seq_group_metadata_list) == 0:
-            print("seq_group_metadata_list is empty")
             return self._model_input_cls(), None
 
         input_tokens = None
@@ -1561,7 +1546,6 @@ def prepare_input_tensors(
                 decode_reqs.append(seq_group_meta)
 
         # Prepare input tensors.
-        #print("prefill_reqs: ", prefill_reqs, "decode_reqs: ", decode_reqs)
         prepare_prompt_impl = self._prepare_prompt_merged if self.enable_merged_prefill else self._prepare_prompt
         (
             input_tokens,
@@ -1615,23 +1599,21 @@ def prepare_input_tensors(
 
         # FIXME: We need to adjust selected_token_indices to accommodate
         # for padding
-        if self.enable_merged_prefill:
-            max_len = slot_mapping.size(1)
-        else:
+        if not self.enable_merged_prefill:
             max_len = input_tokens.size(1)
-        paddings = [max_len - q for q in query_lens]
-        paddings = [0] + paddings[:-1]
-        paddings = list(itertools.accumulate(paddings))
-        paddings_prompt_logprobs = []
-        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
-            if seq_group_metadata.sampling_params.prompt_logprobs is not None \
-                    and seq_group_metadata.is_prompt:
-                paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i])
-        paddings = torch.tensor(
-            paddings_prompt_logprobs if paddings_prompt_logprobs else paddings,
-            dtype=sampling_metadata.selected_token_indices.dtype,
-            device=sampling_metadata.selected_token_indices.device)
-        sampling_metadata.selected_token_indices.add_(paddings)
+            paddings = [max_len - q for q in query_lens]
+            paddings = [0] + paddings[:-1]
+            paddings = list(itertools.accumulate(paddings))
+            paddings_prompt_logprobs = []
+            for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+                if seq_group_metadata.sampling_params.prompt_logprobs is not None \
+                        and seq_group_metadata.is_prompt:
+                    paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i])
+            paddings = torch.tensor(
+                paddings_prompt_logprobs if paddings_prompt_logprobs else paddings,
+                dtype=sampling_metadata.selected_token_indices.dtype,
+                device=sampling_metadata.selected_token_indices.device)
+            sampling_metadata.selected_token_indices.add_(paddings)
 
         if self.lora_config:
             lora_mapping = LoRAMapping(