diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 4add2e78e..5af1ddfd2 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -478,7 +478,7 @@ class EngineArgs:
             self.enable_prefix_caching = False
         if self.speculative_config is not None:
             self.enable_prefix_caching = False
-        if not current_platform.is_cuda() and not current_platform.is_xpu():
+        if not current_platform.is_cuda() and not current_platform.is_xpu() and not current_platform.is_intel_hpu():
             self.enable_prefix_caching = False
         # if self.dynamic_load_weight:
         #     self.enable_prefix_caching = False
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index dcf3f8a59..ee261fa0f 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -40,6 +40,7 @@ from fastdeploy.engine.expert_service import start_data_parallel_service
 from fastdeploy.engine.request import Request
 from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
 from fastdeploy.metrics.metrics import main_process_metrics
+from fastdeploy.platforms import current_platform
 from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
 
 
@@ -136,8 +137,9 @@ class LLMEngine:
 
         # If block numer is specified and model is deployed in mixed mode, start cache manager first
         if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed":
-            device_ids = self.cfg.parallel_config.device_ids.split(",")
-            self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
+            if not current_platform.is_intel_hpu():
+                device_ids = self.cfg.parallel_config.device_ids.split(",")
+                self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
 
         # Start workers
         self.worker_proc = self._start_worker_service()
@@ -170,8 +172,9 @@ class LLMEngine:
         if self.do_profile:
             self._stop_profile()
         elif self.cfg.scheduler_config.splitwise_role == "mixed" and self.cfg.cache_config.enable_prefix_caching:
-            device_ids = self.cfg.parallel_config.device_ids.split(",")
-            self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
+            if not current_platform.is_intel_hpu():
+                device_ids = self.cfg.parallel_config.device_ids.split(",")
+                self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
 
         # Launch components: scheduler, cache_manager, expert_service et.al.
         if self.cfg.scheduler_config.splitwise_role != "mixed":
@@ -673,8 +676,9 @@ class LLMEngine:
         self.cfg.cache_config.reset(num_gpu_blocks)
         self.engine.resource_manager.reset_cache_config(self.cfg.cache_config)
         if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
-            device_ids = self.cfg.parallel_config.device_ids.split(",")
-            self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
+            if not current_platform.is_intel_hpu():
+                device_ids = self.cfg.parallel_config.device_ids.split(",")
+                self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
 
     def check_health(self, time_interval_threashold=30):
         """
diff --git a/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py b/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py
index 1b8a6f261..f5dcc284d 100644
--- a/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py
+++ b/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import math
 import os
 from abc import abstractmethod
 from dataclasses import dataclass, field
@@ -39,6 +40,31 @@ if TYPE_CHECKING:
     from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 
 
+def get_attention_mask(seq_lens_encoder, seq_lens_decoder, batch_size, query_len):
+    max_context_len = int(paddle.max(seq_lens_decoder).item())
+    past_mask = paddle.arange(0, max_context_len, dtype=paddle.int32)
+    past_mask = paddle.greater_equal(
+        past_mask.reshape([1, -1]).expand([batch_size, -1]), seq_lens_decoder.reshape([-1, 1]).astype(paddle.int32)
+    )
+    past_mask = (
+        past_mask.reshape([batch_size, 1, -1])
+        .expand([batch_size, query_len, -1])
+        .reshape([batch_size, 1, query_len, -1])
+    )
+    len_mask = paddle.greater_equal(
+        paddle.arange(0, query_len, dtype=paddle.int32).reshape([1, query_len]),
+        seq_lens_encoder.unsqueeze(-1).astype(paddle.int32),
+    )
+    len_mask = len_mask.reshape([batch_size, 1, 1, query_len])
+    attn_mask = paddle.triu(paddle.ones((batch_size, 1, query_len, query_len), dtype=paddle.bool), diagonal=1)
+    mask = attn_mask.logical_or(len_mask)
+    mask = paddle.concat((past_mask, mask), axis=-1)
+    off_value = -math.inf
+    attn_mask = paddle.zeros_like(mask, dtype=paddle.bfloat16).masked_fill_(mask, off_value)
+    attn_mask = paddle.unsqueeze(attn_mask, axis=1)
+    return attn_mask
+
+
 class AttentionBackend_HPU(AttentionBackend):
     """The base class of attention backends"""
 
@@ -254,16 +280,40 @@ class HPUAttentionBackend(AttentionBackend_HPU):
         index_copy_(k_cache, forward_meta.block_indices, key_states, 0)
         index_copy_(v_cache, forward_meta.block_indices, value_states, 0)
 
-        out_linear_out = fused_sdpa_proj_t(
-            query_states,
-            key_value_states,
-            forward_meta.attn_mask,
-            None,
-            o_proj.weight,
-            scaling_factor=self.head_dim**-0.5,
-            causal=True,
-            softmax_mode=0,
-        )
+        if forward_meta.block_list.shape == forward_meta.block_indices.shape:
+            out_linear_out = fused_sdpa_proj_t(
+                query_states,
+                key_value_states,
+                forward_meta.attn_mask,
+                None,
+                o_proj.weight,
+                scaling_factor=self.head_dim**-0.5,
+                causal=True,
+                softmax_mode=0,
+            )
+        else:
+            key_states_with_context = k_cache.index_select(forward_meta.block_list)
+            val_states_with_context = v_cache.index_select(forward_meta.block_list)
+            key_value_states_with_context = paddle.stack(
+                [key_states_with_context, val_states_with_context], axis=0
+            ).reshape([kv, B, -1, M, H])
+            if forward_meta.attn_mask is None:
+                forward_meta.attn_mask = get_attention_mask(
+                    forward_meta.seq_lens_encoder[forward_meta.batch_ids],
+                    forward_meta.seq_lens_decoder[forward_meta.batch_ids],
+                    query_states.shape[0],
+                    query_states.shape[1],
+                )
+            out_linear_out = fused_sdpa_proj_t(
+                query_states,
+                key_value_states_with_context,
+                forward_meta.attn_mask,
+                None,
+                o_proj.weight,
+                scaling_factor=self.head_dim**-0.5,
+                causal=False,
+                softmax_mode=0,
+            )
 
         if self.nranks > 1:
             from fastdeploy.distributed.communication import (
@@ -297,11 +347,14 @@ class HPUAttentionBackend(AttentionBackend_HPU):
             qkv_proj.weight,
             qkv_proj.bias,
             o_proj.weight,
+            None,  # past_key: not used in decode mode
+            None,  # past_value: not used in decode mode
             self.head_dim,
             self.num_heads,
             scaling_factor=self.head_dim**-0.5,
             transpose=False,
             use_neox_style=layer.use_neox_rotary_style,
+            epsilon=1e-6,
         )
 
         # all_reduce
diff --git a/fastdeploy/worker/hpu_model_runner.py b/fastdeploy/worker/hpu_model_runner.py
index c890e4f2c..1656c9c79 100644
--- a/fastdeploy/worker/hpu_model_runner.py
+++ b/fastdeploy/worker/hpu_model_runner.py
@@ -100,7 +100,9 @@ def recover_block_hpu(
     stop_flags,  # hpu
    seq_lens_this_time,  # hpu
     ori_seq_lens_encoder,  # cpu
+    ori_seq_lens_decoder,  # cpu
     seq_lens_encoder,  # hpu
+    seq_lens_decoder,  # hpu
     block_tables,  # cpu
     free_list,  # cpu
     free_list_len,  # cpu
@@ -116,6 +118,7 @@
     for bid in range(recover_len.item()):
         recover_id = recover_block_list[bid].item()
         ori_seq_len_encoder = ori_seq_lens_encoder[recover_id].item()
+        ori_seq_len_decoder = ori_seq_lens_decoder[recover_id].item()
         step_idx_now = step_idx[recover_id].item()
         seq_len = ori_seq_len_encoder + step_idx_now
         encoder_block_len = encoder_block_lens[recover_id].item()
@@ -123,13 +126,13 @@
 
         seq_lens_this_time[recover_id] = seq_len
         seq_lens_encoder[recover_id] = seq_len
+        seq_lens_decoder[recover_id] = ori_seq_len_decoder
         stop_flags[recover_id] = False
 
         ori_free_list_len = free_list_len[0]
-        free_list_len[0] -= decoder_used_len
-
         for i in range(decoder_used_len):
             block_tables[recover_id, encoder_block_len + i] = free_list[ori_free_list_len - i - 1]
+        free_list_len[0] -= decoder_used_len
 
         recover_block(input_ids, first_token_ids, pre_ids, next_tokens, recover_id, ori_seq_len_encoder, step_idx_now)
 
@@ -160,13 +163,16 @@ def step_intel_hpu(share_inputs: Dict[str, paddle.Tensor], block_size: int, max_
         max_model_len,
     )
     if share_inputs["recover_lens"].item() > 0:
+        logger.info("recover block hpu happening ...")
         recover_block_hpu(
             share_inputs["recover_block_list"],
             share_inputs["recover_lens"],
             share_inputs["stop_flags"],
             share_inputs["seq_lens_this_time"],
             share_inputs["ori_seq_lens_encoder"],
+            share_inputs["ori_seq_lens_decoder"],
             share_inputs["seq_lens_encoder"],
+            share_inputs["seq_lens_decoder"],
             share_inputs["block_tables"],
             share_inputs["free_list"],
             share_inputs["free_list_len"],
@@ -179,6 +185,7 @@
             share_inputs["first_token_ids"],
         )
         share_inputs["recover_lens"] = paddle.full([1], 0, dtype="int32").cpu()
+        share_inputs["not_need_stop"][0] = True
 
 
 # TODO: replace rebuild_padding_v3 in CustomDevice if we adopt this version pp optimization
@@ -477,9 +484,11 @@ class HPUModelRunner(ModelRunnerBase):
                 self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
+                self.share_inputs["ori_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                 self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
             else:
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
+                self.share_inputs["ori_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                 self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                 self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
                 self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
@@ -615,6 +624,7 @@ class HPUModelRunner(ModelRunnerBase):
         self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64").cpu()
         self.share_inputs["first_token_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
         self.share_inputs["ori_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32").cpu()
+        self.share_inputs["ori_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32").cpu()
 
         self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int32")
@@ -707,6 +717,7 @@ class HPUModelRunner(ModelRunnerBase):
             self.share_inputs["seq_lens_decoder"],
             self.cache_config.block_size,
             self.model_config.dtype,
+            self.scheduler_config.max_num_batched_tokens,
         )
         is_prompt = is_prompt.item() == 1 if is_prompt.item() > 0 else None
         if is_prompt is True:
@@ -1023,7 +1034,18 @@ class HPUModelRunner(ModelRunnerBase):
         """ """
         pass
 
-    def update_warmup_inputs(self, requests, is_decode=False):
+    def update_warmup_inputs(self, requests, is_decode=False, context_len=0) -> None:
+        """
+        Update the shared input tensors for warmup requests.
+        Args:
+            requests (list): List of request dicts containing input data.
+            is_decode (bool, optional): If True, sets up inputs for the decode phase. Defaults to False.
+            context_len (int, optional): The length of the context (prefix) to use for prefix caching during warmup.
+                If >0, this value is used to set the decoder sequence length for prefill (prefix caching).
+                Typically set to the number of tokens in the prefix to be cached. Defaults to 0 (no prefix caching).
+                This parameter affects the warmup behavior for prefix caching by controlling how much of the input
+                is considered as context for the decoder during the prefill phase.
+ """ for i in range(len(requests)): request = requests[i] idx = request["idx"] @@ -1038,7 +1060,7 @@ class HPUModelRunner(ModelRunnerBase): self.share_inputs["step_idx"][idx : idx + 1] = 1 else: self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length - self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = context_len self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0 @@ -1073,35 +1095,48 @@ class HPUModelRunner(ModelRunnerBase): self.share_inputs["not_need_stop"][0] = True def warm_up_bucket(self) -> None: - max_prefill_batch = 3 # Hard-Code in FastDeploy/fastdeploy/engine/config.py + max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3")) warmup_max_model_len = min(int(os.environ.get("HPU_WARMUP_MODEL_LEN", 4096)), self.model_config.max_model_len) prefill_batchs = [] prefill_batch_step = int(os.environ.get("BATCH_STEP_PREFILL", 1)) + prefill_seq_step = int(os.environ.get("SEQUENCE_STEP_PREFILL", 128)) current_prefill_batch = prefill_batch_step while current_prefill_batch <= max_prefill_batch: prefill_batchs.append(int(current_prefill_batch)) current_prefill_batch += prefill_batch_step max_prefill_length = self.cache_config.block_size + warmup_max_model_len + prefill_context_block_step = int(os.environ.get("CONTEXT_BLOCK_STEP_PREFILL", 1)) for prefill_batch in prefill_batchs: - for prefill_length in range( - self.cache_config.block_size, max_prefill_length, self.cache_config.block_size + for prefill_length_with_context in range( + self.cache_config.block_size, max_prefill_length, prefill_seq_step ): - if prefill_length * prefill_batch > self.scheduler_config.max_num_batched_tokens: + if prefill_length_with_context * prefill_batch > self.scheduler_config.max_num_batched_tokens: continue - logger.info(f"Warmup prefill_batch: {prefill_batch}, prefill_length: {prefill_length} start") - requests = [ - { - "idx": i, - "input_ids": [5] * (prefill_length - 1), - "block_tables": list(range(prefill_length // self.cache_config.block_size)), - "eos_token_ids": [2], - } - for i in range(prefill_batch) - ] - self.update_warmup_inputs(requests, is_decode=False) - self.execute_model() - logger.info(f"warmup prefill_batch: {prefill_batch}, prefill_length: {prefill_length} done") + for context_len in range( + 0, prefill_length_with_context, self.cache_config.block_size * prefill_context_block_step + ): + prefill_length = prefill_length_with_context - context_len + logger.info( + f"Warmup prefill_batch: {prefill_batch}, prefill_length: {prefill_length}, context_len: {context_len} start" + ) + requests = [ + { + "idx": i, + "input_ids": [5] * (prefill_length_with_context - context_len - 1), + "block_tables": list(range(prefill_length_with_context // self.cache_config.block_size)), + "eos_token_ids": [2], + } + for i in range(prefill_batch) + ] + self.update_warmup_inputs(requests, is_decode=False, context_len=context_len) + self.execute_model() + logger.info( + f"warmup prefill_batch: {prefill_batch}, prefill_length: {prefill_length}, context_len: {context_len} done" + ) + # when disable prefix caching, only run context_len = 0 for each prefill_batch + if not self.cache_config.enable_prefix_caching: + break decode_batchs = [] decode_batch_step = int(os.environ.get("BATCH_STEP_DECODE", 4))