diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 13c62b8045785..d3e5c6bdf35f6 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -5,6 +5,7 @@
 import random
 import time
 from typing import List, Optional
+import os
 
 import pandas as pd
 import torch
@@ -71,8 +72,19 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
         raise ValueError("output_len too small")
 
     # Load the dataset.
-    with open(dataset_path) as f:
-        dataset = json.load(f)
+    if os.path.splitext(dataset_path)[1] == ".json":
+        with open(dataset_path) as f:
+            dataset = json.load(f)
+    elif os.path.splitext(dataset_path)[1] == ".pkl":
+        import pandas as pd
+        dataset = pd.read_pickle(dataset_path)
+        dataset = dataset[['input', 'output']].to_dict(orient="records")
+        for data in dataset:
+            data["conversations"] = [
+                {"value": data["input"]},
+                {"value": data["output"]}
+            ]
+
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     # Shuffle the dataset.
@@ -80,8 +92,11 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
 
     # Filter out sequences that are too long or too short
     filtered_dataset: List[SampleRequest] = []
+    prompt_lens = []
     for data in dataset:
         if len(filtered_dataset) == num_requests:
+            if args.sort_by_len:
+                filtered_dataset = sorted(filtered_dataset, key=lambda x: x.prompt_len)
             break
 
         # Only keep the first two turns of each conversation.
@@ -120,7 +135,11 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
                           prompt_len=prompt_len,
                           expected_output_len=output_len,
                           multi_modal_data=multi_modal_data))
+        prompt_lens.append(prompt_len)
 
+    print("Prompt length stats: ", pd.Series(prompt_lens).describe())
+    # for i, data in enumerate(filtered_dataset):
+    #     print(i, data.prompt)
     return filtered_dataset
 
 
@@ -151,9 +170,9 @@ def run_vllm(
     use_beam_search = False
 
     if not use_beam_search:
-        for _ in range(2):
+        for _ in range(1):
             start = time.perf_counter()
-            llm.generate(prompts, sampling_params, use_tqdm=True)
+            llm.generate(prompts, sampling_params, use_tqdm=False)
             end = time.perf_counter()
     else:
         prompts = [request.prompt for request in requests]
@@ -445,6 +464,12 @@ def main(args: argparse.Namespace):
                         action='store_true',
                         default=False,
                         help="Disable decoupled async engine frontend.")
+    parser.add_argument("--sort-by-len",
+                        action='store_true',
+                        default=False)
+    parser.add_argument("--bucket-selective",
+                        action='store_true',
+                        default=False)
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
     if args.tokenizer is None:
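The benchmark changes above add a second dataset format (`.pkl`) and an optional length-based ordering of the sampled requests. The snippet below is a standalone sketch, not part of the patch: `SampleRequest` here is a stand-in dataclass for the benchmark's own type, and the field names simply mirror the diff.

```python
# Standalone sketch of the new dataset handling in benchmark_throughput.py.
# SampleRequest is a stand-in for the benchmark's own dataclass.
import json
import os
from dataclasses import dataclass

import pandas as pd


@dataclass
class SampleRequest:
    prompt: str
    prompt_len: int
    expected_output_len: int


def load_conversations(dataset_path: str):
    """.json keeps the conversation layout as-is; .pkl rows are converted to it."""
    ext = os.path.splitext(dataset_path)[1]
    if ext == ".json":
        with open(dataset_path) as f:
            return json.load(f)
    if ext == ".pkl":
        dataset = pd.read_pickle(dataset_path)[["input", "output"]]
        dataset = dataset.to_dict(orient="records")
        for data in dataset:
            data["conversations"] = [{"value": data["input"]},
                                     {"value": data["output"]}]
        return dataset
    raise ValueError(f"unsupported dataset extension: {ext}")


# --sort-by-len orders the sampled requests by prompt length, so prompts that
# end up merged into one prefill batch have similar sizes and waste less padding.
requests = [SampleRequest("x" * n, n, 128) for n in (300, 20, 700)]
requests = sorted(requests, key=lambda r: r.prompt_len)
print([r.prompt_len for r in requests])  # [20, 300, 700]
```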
diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index c5b57cb1967f0..1d5b83c1e61f2 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -30,6 +30,42 @@
                    "vLLM will use native implementation.")
 
 
+def prompt_fsdpa(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[torch.Tensor] = None,
+    p: float = 0.0,
+    scale: Optional[float] = None,
+    matmul_qk_op=torch.matmul,
+    softmax_op=torch.softmax,
+    matmul_av_op=torch.matmul,
+    valid_seq_lengths: Optional[torch.Tensor] = None,
+    fsdpa_op=None,
+) -> torch.Tensor:
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+    query_heads = query.size(1)
+    kv_heads = key.size(1)
+    VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get(
+        'VLLM_REMOVE_REPEAT_KV_CACHE_MERGED_PREFILL', '1') == '1'
+    # TODO: remove after fusedsdpa fix for query_heads != kv_heads
+    if query_heads != kv_heads:
+        if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE:
+            key = ops.repeat_kv(key, int(query_heads // kv_heads))
+            value = ops.repeat_kv(value, int(query_heads // kv_heads))
+        if attn_bias is not None:
+            attn_bias = attn_bias.unsqueeze(1)
+    softmax_mode = 'fast'
+    recompute_mode = True
+    attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, False,
+                            scale, softmax_mode, recompute_mode, None,
+                            'right')
+    attn_weights = attn_weights.transpose(1, 2)
+    return attn_weights
+
+
 class HPUAttentionBackend(AttentionBackend):
 
     @staticmethod
@@ -83,6 +119,9 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
     context_lens_tensor: Optional[torch.Tensor]
+    enable_merged_prefill: bool = False
+    actual_num_prefills: Optional[torch.Tensor] = None
+    repeated_idx_tensor: Optional[torch.Tensor] = None
     seq_lens: Optional[List[int]] = None
     encoder_seq_lens: Optional[List[int]] = None
     encoder_seq_lens_tensor: Optional[torch.Tensor] = None
@@ -213,6 +252,7 @@ def forward(
         block_offsets = kwargs.get('block_offsets', None)
         seq_lens_tensor = kwargs.get('seq_lens_tensor', None)
         attn_bias = kwargs.get('attn_bias', None)
+        enable_merged_prefill = attn_metadata.enable_merged_prefill
         if block_indices is None:
             block_indices = attn_metadata.block_indices
         if block_offsets is None:
@@ -221,7 +261,7 @@
             seq_lens_tensor = attn_metadata.seq_lens_tensor
         if attn_bias is None:  # This is the case for prompt run
             attn_bias = attn_metadata.attn_bias
-        if attn_metadata.is_prompt:
+        if attn_metadata.is_prompt and not enable_merged_prefill:
             key = key.unflatten(0, (block_indices.size(0), -1))
             value = value.unflatten(0, (block_indices.size(0), -1))
         if kv_cache is not None:
@@ -232,9 +272,9 @@
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
             key_cache = self.k_cache(key, key_cache, block_indices,
-                                     block_offsets)
+                                        block_offsets)
             value_cache = self.v_cache(value, value_cache, block_indices,
-                                       block_offsets)
+                                          block_offsets)
 
         if attn_metadata.is_prompt:
             # Prompt run.
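`prompt_fsdpa` above still repeats the KV heads when `query_heads != kv_heads` (the TODO marks this as a workaround until the fused SDPA handles grouped-query attention directly). Below is an illustration of that shape change only; the patch uses `ops.repeat_kv` from vllm_hpu_extension, and `torch.repeat_interleave` is a stand-in with the same resulting shape.

```python
# Illustration only: what repeating KV heads for grouped-query attention does
# to the tensor shapes. Not the patch's code path.
import torch


def repeat_kv_sketch(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    # kv: [batch, kv_heads, seq_len, head_dim]
    # returns: [batch, kv_heads * n_rep, seq_len, head_dim]
    return torch.repeat_interleave(kv, n_rep, dim=1)


query_heads, kv_heads = 8, 2
key = torch.randn(1, kv_heads, 16, 64)
value = torch.randn(1, kv_heads, 16, 64)

key = repeat_kv_sketch(key, query_heads // kv_heads)
value = repeat_kv_sketch(value, query_heads // kv_heads)
assert key.shape == value.shape == (1, query_heads, 16, 64)
```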
@@ -253,10 +293,16 @@ def forward(
                     attn_bias = attn_bias.tile(
                         (1, self.num_kv_heads, 1, 1))
                     attn_bias.add_(position_bias)
+            elif enable_merged_prefill:
+                pass
             else:
                 attn_bias = None
 
-            out = ops.prompt_attention(
+            if enable_merged_prefill and self.prefill_use_fusedsdpa:
+                prompt_attn_func = prompt_fsdpa
+            else:
+                prompt_attn_func = ops.prompt_attention
+            out = prompt_attn_func(
                 query.view(query_shape),
                 key.view(kv_shape),
                 value.view(kv_shape),
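With merged prefill, several prompts share a single attention call, so correctness rests entirely on the block-causal bias that is handed to `prompt_attn_func`. The sketch below checks that idea with plain PyTorch SDPA standing in for the Habana fused op (an assumption, not the patch's code path): one call over the merged sequence with a block-diagonal causal mask matches running each prompt separately.

```python
# Sanity sketch (PyTorch SDPA as a stand-in for the Habana fused op):
# attention over a merged prompt with a block-causal mask equals per-prompt
# causal attention.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
heads, head_dim = 4, 32
seq_lens = [3, 5]
total = sum(seq_lens)

q = torch.randn(1, heads, total, head_dim)
k = torch.randn(1, heads, total, head_dim)
v = torch.randn(1, heads, total, head_dim)

# Block-diagonal causal mask: token i attends to token j only if j <= i and
# both tokens belong to the same prompt (the same rule _set_merged_attn_bias
# encodes as an additive -10000 bias).
mask = torch.zeros(total, total, dtype=torch.bool)
start = 0
for n in seq_lens:
    mask[start:start + n, start:start + n] = torch.ones(n, n, dtype=torch.bool).tril()
    start += n

merged = F.scaled_dot_product_attention(q, k, v,
                                        attn_mask=mask.view(1, 1, total, total))

pieces, start = [], 0
for n in seq_lens:
    s = slice(start, start + n)
    pieces.append(F.scaled_dot_product_attention(q[:, :, s], k[:, :, s], v[:, :, s],
                                                 is_causal=True))
    start += n

assert torch.allclose(merged, torch.cat(pieces, dim=2), atol=1e-5)
```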
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 7c3679d40546d..b92c49dd7c154 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -11,6 +11,7 @@
 import math
 import os
 import time
+import copy
 from array import array
 from enum import IntEnum
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
@@ -19,7 +20,7 @@
 import habana_frameworks.torch as htorch
 import habana_frameworks.torch.internal.bridge_config as bc
 import torch
-from vllm_hpu_extension.bucketing import HPUBucketingContext
+from vllm_hpu_extension.bucketing import HPUBucketingContext, generate_prompt_buckets
 from vllm_hpu_extension.ops import LoraMask as LoraMask
 from vllm_hpu_extension.ops import batch2block, block2batch
 from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
@@ -205,6 +206,33 @@ def get_child(parent, suffix, is_list=False):
 }
 
 
+class HPUBucketingContextWithMergedPrefill(HPUBucketingContext):
+
+    def generate_prompt_buckets(self):
+        print(
+            "HPUBucketingContextWithMergedPrefill - generate_prompt_buckets is called"
+        )
+
+        prompt_bs_bucket_cfg = self.global_state.prompt_bs_bucket_cfg
+        prompt_seq_bucket_cfg = self.global_state.prompt_seq_bucket_cfg
+        origin_max_prompt_len = prompt_seq_bucket_cfg[2]
+        max_prompt_len = prompt_bs_bucket_cfg[2] * prompt_seq_bucket_cfg[2]
+        max_prompt_len = min(self.max_num_batched_tokens, max_prompt_len)
+        prompt_seq_bucket_cfg[2] = max_prompt_len
+
+        prompt_buckets, prompt_omitted_buckets = \
+            generate_prompt_buckets(
+                prompt_bs_bucket_cfg,
+                prompt_seq_bucket_cfg,
+                self.max_num_batched_tokens)
+
+        self.global_state.prompt_buckets = list(filter(lambda bucket: bucket[1] <= origin_max_prompt_len and bucket[0] == 1, prompt_buckets))
+
+        msg = (f"Generated {len(self.global_state.prompt_buckets)} "
+               f"prompt buckets [bs, seq]: "
+               f"{list(sorted(self.global_state.prompt_buckets))}")
+        print(msg)
+
+
 class HpuModelAdapter:
 
     def __init__(self, model, block_size, dtype, enforce_eager, layer_names):
@@ -287,6 +315,35 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
         attn_metadata = prefill_metadata._replace(attn_bias=attn_bias)
         return attn_metadata
 
+    def _set_merged_attn_bias(
+            self,
+            attn_metadata,
+            max_seq_len,
+            device,
+            dtype
+    ):  # create a 2D causal attn mask so each token can only attend to the past
+        if attn_metadata is None or not attn_metadata.is_prompt:
+            return attn_metadata
+        if attn_metadata.attn_bias is not None:
+            return attn_metadata
+        # TODO: Support batch_size > 1
+        # last-token index of the sequence that owns each position
+        repeated_idx = attn_metadata.repeated_idx_tensor.view(1, -1).expand(max_seq_len, -1)
+        # create tensor with all indices from 0 to T-1 repeated T times along dimension 1
+        mask_indices = torch.arange(max_seq_len, dtype=dtype, device=device).view(-1, 1).expand(-1, max_seq_len)
+        # create causal mask and additionally mask out all tokens from preceding sequences
+        mask = mask_indices.le(repeated_idx)
+        causal_mask = torch.ones(max_seq_len, max_seq_len, dtype=torch.bool, device=device).tril()
+        causal_mask = causal_mask.logical_and(mask)
+        # should be -inf, but -10000 is used for numerical stability
+        causal_attn_mask_tensor = torch.zeros_like(causal_mask, device=device, dtype=dtype).masked_fill_(~causal_mask, -10000)
+        causal_attn_mask_tensor = causal_attn_mask_tensor.view(
+            1, 1, causal_attn_mask_tensor.shape[0], causal_attn_mask_tensor.shape[1])
+
+        attn_metadata = attn_metadata._replace(
+            attn_bias=causal_attn_mask_tensor)
+        return attn_metadata
+
     def _set_block_mapping(self, metadata, batch_size, device, dtype):
         mask = torch.arange(0,
                             self.block_size,
@@ -327,9 +384,12 @@ def _set_block_scales(self, metadata, device):
         return metadata
 
     def _set_indices_and_offsets(self, metadata, block_size, is_prompt):
-        slot_mapping = metadata.slot_mapping.flatten()
+        if metadata.enable_merged_prefill and is_prompt:
+            slot_mapping = metadata.slot_mapping
+        else:
+            slot_mapping = metadata.slot_mapping.flatten()
         indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
-        if is_prompt:
+        if not metadata.enable_merged_prefill and is_prompt:
             indices = indices.unflatten(0, (-1, block_size))[:, 0]
             offsets = None
         else:
@@ -340,7 +400,11 @@ def _set_indices_and_offsets(self, metadata, block_size, is_prompt):
 
     def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
                          dtype):
-        if attn_metadata.is_prompt:
+        if attn_metadata.is_prompt and attn_metadata.enable_merged_prefill:
+            attn_metadata = self._set_merged_attn_bias(attn_metadata,
+                                                       seq_len,
+                                                       device, dtype)
+        elif attn_metadata.is_prompt:
             attn_metadata = self._set_attn_bias(attn_metadata, batch_size,
                                                 seq_len, device, dtype)
         else:
@@ -377,6 +441,15 @@ def forward(self, *args, **kwargs):
             LoraMask.setLoraMask(kwargs.pop('lora_mask'))
         if self.layer_names is not None:
             self._prepare_cos_sin(kwargs['positions'])
+        if kwargs['attn_metadata'].is_prompt:
+            am = kwargs['attn_metadata']
+            print("Warming up HPU Graph - input_ids: ", input_ids.shape,
+                  "seq_lens_tensor: ", am.seq_lens_tensor.shape,
+                  "context_lens_tensor: ", am.context_lens_tensor.shape,
+                  "attn_bias: ", am.attn_bias.shape if am.attn_bias is not None else None,
+                  "enable_merged_prefill:", am.enable_merged_prefill,
+                  "slot_mapping: ", am.slot_mapping.shape,
+                  "selected_token_indices: ", selected_token_indices.shape)
         hidden_states = self.model(*args, **kwargs)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         hidden_states = hidden_states.index_select(0, selected_token_indices)
@@ -622,10 +695,16 @@ def __init__(
         self.profiler_counter_helper = HabanaProfilerCounterHelper()
         self.seen_configs: set = set()
         self._mem_margin: Optional[int] = None
-        self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs,
-                                                 self.max_num_prefill_seqs,
-                                                 self.block_size,
-                                                 self.max_num_batched_tokens)
+        self.enable_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL',
+                                                    'false').lower() == 'true'
+        if self.enable_merged_prefill:
+            self.bucketing_ctx = HPUBucketingContextWithMergedPrefill(
+                self.max_num_seqs, self.max_num_prefill_seqs, self.block_size,
+                self.max_num_batched_tokens)
+        else:
+            self.bucketing_ctx = HPUBucketingContext(
+                self.max_num_seqs, self.max_num_prefill_seqs, self.block_size,
+                self.max_num_batched_tokens)
         self.graphed_buckets: Set[Any] = set()
 
         self._set_gc_threshold()
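A worked example of the mask math in `_set_merged_attn_bias`, with made-up sizes (two prompts of lengths 2 and 3 padded to 8). `repeated_idx` holds, for every position in the merged row, the index of the last token of the prompt that owns it, exactly as `_prepare_prompt_merged` builds it further down.

```python
# Worked example of the merged causal bias for seq_lens = [2, 3], padded to 8.
import itertools

import torch

seq_lens = [2, 3]
max_seq_len = 8

ends = list(itertools.accumulate(seq_lens))                   # [2, 5]
repeated_idx = [[e - 1] * n for e, n in zip(ends, seq_lens)]  # [[1, 1], [4, 4, 4]]
repeated_idx = list(itertools.chain.from_iterable(repeated_idx))
repeated_idx += [0] * (max_seq_len - sum(seq_lens))           # padding slots -> 0
repeated_idx = torch.tensor(repeated_idx)

rows = torch.arange(max_seq_len).view(-1, 1).expand(-1, max_seq_len)
same_prompt = rows.le(repeated_idx.view(1, -1).expand(max_seq_len, -1))
causal = torch.ones(max_seq_len, max_seq_len, dtype=torch.bool).tril()
allowed = causal.logical_and(same_prompt)

bias = torch.zeros(max_seq_len, max_seq_len).masked_fill_(~allowed, -10000.0)
print(allowed.int())
# Query rows 0-1 see a causal block over tokens 0-1, rows 2-4 a causal block
# over tokens 2-4, and the padded rows see nothing (their outputs are dropped).
```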
@@ -1011,6 +1090,232 @@ def _prepare_prompt(
             slot_mapping=slot_mapping,
             lora_ids=lora_ids)
 
+    def _prepare_prompt_merged(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> PreparePromptMetadata:
+        input_tokens: List[List[int]] = []
+        input_positions: List[List[int]] = []
+        slot_mapping: List[List[int]] = []
+        lora_index_mapping: List[List[int]] = []
+        lora_prompt_mapping: List[List[int]] = []
+        lora_requests: Set[LoRARequest] = set()
+
+        seq_lens: List[int] = []
+        context_lens: List[int] = []
+        query_lens: List[int] = []
+        prefix_block_tables: List[List[int]] = []
+        multi_modal_kwargs_list: List[MultiModalKwargs] = []
+
+        if len(seq_group_metadata_list) == 0:
+            return PreparePromptMetadata.empty()
+
+        for seq_group_metadata in seq_group_metadata_list:
+            assert seq_group_metadata.is_prompt
+            seq_ids = list(seq_group_metadata.seq_data.keys())
+            assert len(seq_ids) == 1
+            seq_id = seq_ids[0]
+
+            computed_block_nums = seq_group_metadata.computed_block_nums
+            if (self.scheduler_config is not None
+                    and self.scheduler_config.chunked_prefill_enabled
+                    and not (computed_block_nums is None
+                             or computed_block_nums == [])):
+                raise RuntimeError(
+                    "chunked prefill cannot be used with prefix caching "
+                    "now.")
+
+            token_chunk_size = seq_group_metadata.token_chunk_size
+            seq_data = seq_group_metadata.seq_data[seq_id]
+            context_len = seq_data.get_num_computed_tokens()
+            # We should use get_len here because in case of preemption
+            # it contains output tokens.
+            seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
+            prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
+            seq_lens.append(seq_len)
+
+            # NOTE: This only works for oooooooxxx style attention.
+            if computed_block_nums is not None and len(
+                    computed_block_nums) > 0 and self.sliding_window is None:
+                # Prefix is not supported with sliding_window
+                context_len = len(computed_block_nums) * self.block_size
+                prompt_tokens = prompt_tokens[context_len:]
+                prefix_block_tables.append(computed_block_nums)
+            elif self.scheduler_config.chunked_prefill_enabled:
+                if seq_group_metadata.block_tables is not None:
+                    # Prefill has chunked before.
+                    block_table = seq_group_metadata.block_tables[seq_id]
+                    prefix_block_tables.append(block_table)
+                else:
+                    # The first prefill.
+                    prefix_block_tables.append([])
+            else:
+                prefix_block_tables.append([])
+                # Right now, prefill start is always 0. However, this
+                # assumption can be changed once chunked prefill is introduced.
+                assert context_len == 0
+
+            # actual prompt lens
+            context_lens.append(context_len)
+            query_lens.append(seq_len - context_len)
+            input_tokens.append(prompt_tokens)
+            # NOTE(woosuk): Here we assume that the first token in the prompt
+            # is always the first token in the sequence.
+            input_positions.append(list(range(context_len, seq_len)))
+
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_kwargs_list.append(mm_kwargs)
+
+            if seq_group_metadata.block_tables is None:
+                # During memory profiling, the block tables are not initialized
+                # yet. In this case, we just use a dummy slot mapping.
+                slot_mapping.append([_PAD_SLOT_ID] * seq_len)
+                continue
+
+            # Compute the slot mapping.
+            slot_mapping.append([])
+            block_table = seq_group_metadata.block_tables[seq_id]
+
+            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
+            # where start_idx is max(0, seq_len - sliding_window).
+            # For example, if the prompt len is 10, sliding window is 8, and
+            # block size is 4, the first two tokens are masked and the slot
+            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
+            start_idx = 0
+            if self.sliding_window is not None:
+                assert context_len == 0, (
+                    "Prefix caching is currently not supported with "
+                    "sliding window attention")
+                start_idx = max(0, seq_len - self.sliding_window)
+            for i in range(context_len, seq_len):
+                if i < start_idx:
+                    slot_mapping[-1].append(_PAD_SLOT_ID)
+                    continue
+
+                block_number = block_table[i // self.block_size]
+                block_offset = i % self.block_size
+                slot = block_number * self.block_size + block_offset
+                slot_mapping[-1].append(slot)
+
+        slot_mapping_merged = list(itertools.chain.from_iterable(slot_mapping))
+        slot_mapping_merged = [i for i in slot_mapping_merged if i != _PAD_SLOT_ID]
+        slot_mapping = [slot_mapping_merged]
+        input_tokens_merged = list(itertools.chain.from_iterable(input_tokens))
+        input_tokens_merged = [input_tokens_merged]
+        input_positions_merged = list(
+            itertools.chain.from_iterable(input_positions))
+        input_positions_merged = [input_positions_merged]
+        total_seq_lens = [sum(seq_lens)]
+        total_query_lens = [sum(query_lens)]
+
+        max_query_len = max(total_query_lens)
+        real_num_seqs = len(total_query_lens)
+        assert max_query_len > 0
+
+        merged_prompt_len = max(
+            self.bucketing_ctx.get_padded_prompt_seq_len(max(total_seq_lens)),
+            self.block_size)
+        # get cumsum of seq_lens
+        repeated_idx = list(itertools.accumulate(seq_lens))
+        repeated_idx = [[idx - 1] * seq_len for idx, seq_len in zip(repeated_idx, seq_lens)]
+        repeated_idx = list(itertools.chain.from_iterable(repeated_idx)) + [0] * (merged_prompt_len - sum(seq_lens))
+        prefix_block_list_tensor = None
+
+        repeated_idx_tensor = torch.tensor(repeated_idx, dtype=torch.long, device='cpu')
+        input_tokens_tensor = make_tensor_with_pad(input_tokens_merged,
+                                                   max_len=merged_prompt_len,
+                                                   pad=0,
+                                                   dtype=torch.long,
+                                                   device='cpu')
+
+        input_positions = make_tensor_with_pad(input_positions_merged,
+                                               max_len=merged_prompt_len,
+                                               pad=0,
+                                               dtype=torch.long,
+                                               device='cpu')
+
+        slot_mapping = make_tensor_with_pad(slot_mapping,
+                                            max_len=merged_prompt_len,
+                                            pad=_PAD_SLOT_ID,
+                                            dtype=torch.long,
+                                            device='cpu')
+        actual_num_prefills_tensor = torch.tensor(len(seq_lens),
+                                                  dtype=torch.long,
+                                                  device='cpu')
+
+        max_prefill_bs = int(os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', '8'))
+        max_prefill_bs = max(max_prefill_bs, len(seq_lens))
+        seq_lens = seq_lens + [0] * (max_prefill_bs - len(seq_lens))
+        context_lens = context_lens + [0] * (max_prefill_bs - len(context_lens))
+        seq_lens_tensor = torch.tensor(seq_lens,
+                                       dtype=torch.long,
+                                       device='cpu')
+        context_lens_tensor = torch.tensor(context_lens,
+                                           dtype=torch.long,
+                                           device='cpu')
+        # The merged causal attn_bias is built later on the device in
+        # _set_merged_attn_bias, so nothing is created on the CPU here.
+        causal_attn_mask_tensor = None
+        # Note: num_prefill_tokens is calculated using the length of
+        # input_tokens after padding.
+        num_prefill_tokens = input_tokens_tensor.numel()
+        input_tokens_tensor = input_tokens_tensor.to(  # type: ignore
+            self.device, non_blocking=True)
+        input_positions = input_positions.to(  # type: ignore
+            self.device, non_blocking=True)
+        slot_mapping = slot_mapping.to(  # type: ignore
+            self.device, non_blocking=True)
+        seq_lens_tensor = seq_lens_tensor.to(self.device, non_blocking=True)
+        context_lens_tensor = context_lens_tensor.to(self.device,
+                                                     non_blocking=True)
+        repeated_idx_tensor = repeated_idx_tensor.to(self.device, non_blocking=True)
+        actual_num_prefills_tensor = actual_num_prefills_tensor.to(
+            self.device, non_blocking=True)
+
+        attn_metadata = self.attn_backend.make_metadata(
+            is_prompt=True,
+            enable_merged_prefill=True,
+            actual_num_prefills=actual_num_prefills_tensor,
+            repeated_idx_tensor=repeated_idx_tensor,
+            block_list=prefix_block_list_tensor,
+            block_mapping=None,
+            block_usage=None,
+            block_indices=None,
+            block_offsets=None,
+            block_scales=None,
+            block_groups=None,
+            attn_bias=causal_attn_mask_tensor,
+            seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
+            context_lens_tensor=context_lens_tensor,
+            num_prefills=real_num_seqs,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decode_tokens=0,
+            slot_mapping=slot_mapping,
+            multi_modal_placeholder_index_maps=
+            None  # FIXME(kzawora): multi-modality will not work here
+        )
+        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
+        for t in multi_modal_kwargs:
+            if torch.is_tensor(multi_modal_kwargs[t]):
+                multi_modal_kwargs[t] = multi_modal_kwargs[t].to(
+                    self.device, non_blocking=True)
+
+        return PreparePromptMetadata(input_tokens=input_tokens_tensor,
+                                     input_positions=input_positions,
+                                     attn_metadata=attn_metadata,
+                                     seq_lens=seq_lens,
+                                     query_lens=query_lens,
+                                     lora_index_mapping=lora_index_mapping,
+                                     lora_prompt_mapping=lora_prompt_mapping,
+                                     lora_requests=lora_requests,
+                                     multi_modal_kwargs=multi_modal_kwargs,
+                                     slot_mapping=slot_mapping,
+                                     lora_ids=[])
+
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
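`_prepare_prompt_merged` above flattens every per-sequence list into a single batch-of-one row before the tensors are built. A toy walk-through with hypothetical token ids, slots, and sizes (the real padded length comes from the bucketing context, not the rounding used here):

```python
# Toy walk-through of the merging step; token ids, slots and sizes are made up.
import itertools

block_size = 4
seq_lens = [3, 5]
input_tokens = [[11, 12, 13], [21, 22, 23, 24, 25]]
input_positions = [list(range(n)) for n in seq_lens]   # positions restart per prompt
slot_mapping = [[0, 1, 2], [8, 9, 10, 11, 12]]         # slots from each prompt's blocks

merged_tokens = list(itertools.chain.from_iterable(input_tokens))
merged_positions = list(itertools.chain.from_iterable(input_positions))
merged_slots = list(itertools.chain.from_iterable(slot_mapping))

# Stand-in for bucketing_ctx.get_padded_prompt_seq_len(): round up to block_size.
merged_prompt_len = -(-sum(seq_lens) // block_size) * block_size   # -> 8

pad = merged_prompt_len - sum(seq_lens)
print(merged_tokens + [0] * pad)      # the single row fed to the model
print(merged_positions + [0] * pad)   # per-prompt positions, so RoPE restarts per prompt
print(merged_slots)                   # make_tensor_with_pad pads this with _PAD_SLOT_ID
```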
@@ -1223,6 +1528,7 @@ def prepare_input_tensors(
                 decode_reqs.append(seq_group_meta)
 
         # Prepare input tensors.
+        prepare_prompt_impl = self._prepare_prompt_merged if self.enable_merged_prefill else self._prepare_prompt
         (
             input_tokens,
             input_positions,
@@ -1235,7 +1541,7 @@
             multi_modal_kwargs,
             slot_mapping,
             lora_ids,
-        ) = self._prepare_prompt(prefill_reqs)
+        ) = prepare_prompt_impl(prefill_reqs)
         (
             decode_input_tokens,
             decode_input_positions,
@@ -1275,20 +1581,29 @@
         # FIXME: We need to adjust selected_token_indices to accommodate
         # for padding
-        max_len = input_tokens.size(1)
-        paddings = [max_len - q for q in query_lens]
-        paddings = [0] + paddings[:-1]
-        paddings = list(itertools.accumulate(paddings))
-        paddings_prompt_logprobs = []
-        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
-            if seq_group_metadata.sampling_params.prompt_logprobs is not None \
-                    and seq_group_metadata.is_prompt:
-                paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i])
-        paddings = torch.tensor(
-            paddings_prompt_logprobs if paddings_prompt_logprobs else paddings,
-            dtype=sampling_metadata.selected_token_indices.dtype,
-            device=sampling_metadata.selected_token_indices.device)
-        sampling_metadata.selected_token_indices.add_(paddings)
+        if not self.enable_merged_prefill:
+            max_len = input_tokens.size(1)
+            paddings = [max_len - q for q in query_lens]
+            paddings = [0] + paddings[:-1]
+            paddings = list(itertools.accumulate(paddings))
+            paddings_prompt_logprobs = []
+            for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+                if seq_group_metadata.sampling_params.prompt_logprobs is not None \
+                        and seq_group_metadata.is_prompt:
+                    paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i])
+            paddings = torch.tensor(
+                paddings_prompt_logprobs if paddings_prompt_logprobs else paddings,
+                dtype=sampling_metadata.selected_token_indices.dtype,
+                device=sampling_metadata.selected_token_indices.device)
+            sampling_metadata.selected_token_indices.add_(paddings)
+        else:
+            paddings = [0] * (num_prefills - sampling_metadata.selected_token_indices.size(0))
+            paddings = torch.tensor(
+                paddings,
+                dtype=sampling_metadata.selected_token_indices.dtype,
+                device=sampling_metadata.selected_token_indices.device)
+            sampling_metadata.selected_token_indices = \
+                torch.cat((sampling_metadata.selected_token_indices, paddings), dim=0)
 
         if self.lora_config:
             lora_mapping = LoRAMapping(
@@ -1375,6 +1690,9 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
             'attn_bias',
             'seq_lens_tensor',
             'context_lens_tensor',
+            'enable_merged_prefill',
+            'actual_num_prefills',
+            'repeated_idx_tensor',
             'block_list',
             'block_mapping',
             'block_usage',
@@ -1423,8 +1741,11 @@ def profile_run(self) -> None:
         max_batch_size = min(self.max_num_seqs,
                              self.max_num_batched_tokens // max_seq_len)
 
+        origin_enable_merged_prefill = self.enable_merged_prefill
+        self.enable_merged_prefill = False
         self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
                              False, True)
+        self.enable_merged_prefill = origin_enable_merged_prefill
         return
 
     def warmup_scenario(self,
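The merged-prefill branch in `prepare_input_tensors` above pads `selected_token_indices` with zeros instead of shifting it, so the sampler always receives a tensor of a fixed length; the hunk after this note trims it back to `actual_num_prefills` once the forward pass is done. A rough sketch of that pad/trim pair, with hypothetical sizes:

```python
# Rough sketch of the pad-then-trim handling of selected_token_indices.
import torch

seq_lens = [3, 5]          # two real prompts merged into one row
num_prefills = 4           # fixed number of prefill slots assumed for this sketch

# For merged prefill, the selected tokens are the last token of each prompt
# inside the merged row: cumsum(seq_lens) - 1.
selected = torch.cumsum(torch.tensor(seq_lens), dim=0) - 1     # tensor([2, 7])

# Pad with zeros to the fixed length (what the else-branch above does) ...
pad = torch.zeros(num_prefills - selected.size(0), dtype=selected.dtype)
selected = torch.cat((selected, pad), dim=0)                   # tensor([2, 7, 0, 0])

# ... and trim back to the real prefill count after the forward pass
# (what the hunk below does with attn_metadata.actual_num_prefills).
actual_num_prefills = len(seq_lens)
selected = selected[:actual_num_prefills]                      # tensor([2, 7])
print(selected)
```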
@@ -2147,6 +2468,11 @@ def try_revert_dummy_output_tokens():
                         selected_token_indices=sampling_metadata.
                         selected_token_indices)
+                # Trim selected_token_indices after the forward pass so HPU graph capture keeps seeing exactly the same shape.
+                if execute_model_kwargs['attn_metadata'].actual_num_prefills is not None:
+                    actual_num_prefills = execute_model_kwargs['attn_metadata'].actual_num_prefills
+                    sampling_metadata.selected_token_indices = sampling_metadata.selected_token_indices[:actual_num_prefills]
+                    hidden_states = hidden_states[:actual_num_prefills]
 
                 if self.lora_config:
                     LoraMask.setLoraMask(
                         lora_logits_mask.index_select(
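Finally, a hedged usage sketch. The feature is gated entirely by environment variables read in `hpu_model_runner.py` and `hpu_attn.py`; the model name below is a placeholder, and the values shown are the defaults from the diff except for `VLLM_MERGED_PREFILL` itself.

```python
# Usage sketch (assumption: running on Gaudi with the HPU backend installed).
import os

os.environ["VLLM_MERGED_PREFILL"] = "true"      # switch to _prepare_prompt_merged
os.environ["VLLM_PROMPT_BS_BUCKET_MAX"] = "8"   # prefill slots seq_lens/context_lens are padded to
# '1' (default) keeps the repeat_kv fallback in prompt_fsdpa for GQA models.
os.environ["VLLM_REMOVE_REPEAT_KV_CACHE_MERGED_PREFILL"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")   # placeholder model
outputs = llm.generate(["merged prefill smoke test"] * 4,
                       SamplingParams(max_tokens=8))
print([o.outputs[0].text for o in outputs])
```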