[Intel HPU] enable level 1 prefix caching and fix some bugs (#4971)

* [Intel HPU] enable prefix caching and dense tp moe ep and fix some bugs

* update code by copilot

* remove dense tp and moe ep code
Authored by fmiao2372, committed by GitHub on 2025-11-14 19:42:50 +08:00
parent 0e819cd596, commit e43a5fc055
4 changed files with 130 additions and 38 deletions

View File

@@ -478,7 +478,7 @@ class EngineArgs:
self.enable_prefix_caching = False
if self.speculative_config is not None:
self.enable_prefix_caching = False
if not current_platform.is_cuda() and not current_platform.is_xpu():
if not current_platform.is_cuda() and not current_platform.is_xpu() and not current_platform.is_intel_hpu():
self.enable_prefix_caching = False
# if self.dynamic_load_weight:
# self.enable_prefix_caching = False
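In effect, this hunk adds Intel HPU to the set of platforms allowed to keep prefix caching enabled. A minimal sketch of the resulting predicate (illustrative only; prefix_caching_allowed is a hypothetical helper, the real check lives inline in EngineArgs):

    def prefix_caching_allowed(platform) -> bool:
        # `platform` is assumed to expose the same predicates as
        # fastdeploy.platforms.current_platform.
        return platform.is_cuda() or platform.is_xpu() or platform.is_intel_hpu()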

View File

@@ -40,6 +40,7 @@ from fastdeploy.engine.expert_service import start_data_parallel_service
from fastdeploy.engine.request import Request
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.platforms import current_platform
from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
@@ -136,8 +137,9 @@ class LLMEngine:
# If block numer is specified and model is deployed in mixed mode, start cache manager first
if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed":
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
if not current_platform.is_intel_hpu():
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
# Start workers
self.worker_proc = self._start_worker_service()
@@ -170,8 +172,9 @@ class LLMEngine:
if self.do_profile:
self._stop_profile()
elif self.cfg.scheduler_config.splitwise_role == "mixed" and self.cfg.cache_config.enable_prefix_caching:
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
if not current_platform.is_intel_hpu():
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
# Launch components: scheduler, cache_manager, expert_service et.al.
if self.cfg.scheduler_config.splitwise_role != "mixed":
@@ -673,8 +676,9 @@ class LLMEngine:
self.cfg.cache_config.reset(num_gpu_blocks)
self.engine.resource_manager.reset_cache_config(self.cfg.cache_config)
if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
if not current_platform.is_intel_hpu():
device_ids = self.cfg.parallel_config.device_ids.split(",")
self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
def check_health(self, time_interval_threashold=30):
"""

View File

@@ -16,6 +16,7 @@
from __future__ import annotations
import math
import os
from abc import abstractmethod
from dataclasses import dataclass, field
@@ -39,6 +40,31 @@ if TYPE_CHECKING:
from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
def get_attention_mask(seq_lens_encoder, seq_lens_decoder, batch_size, query_len):
max_context_len = int(paddle.max(seq_lens_decoder).item())
past_mask = paddle.arange(0, max_context_len, dtype=paddle.int32)
past_mask = paddle.greater_equal(
past_mask.reshape([1, -1]).expand([batch_size, -1]), seq_lens_decoder.reshape([-1, 1]).astype(paddle.int32)
)
past_mask = (
past_mask.reshape([batch_size, 1, -1])
.expand([batch_size, query_len, -1])
.reshape([batch_size, 1, query_len, -1])
)
len_mask = paddle.greater_equal(
paddle.arange(0, query_len, dtype=paddle.int32).reshape([1, query_len]),
seq_lens_encoder.unsqueeze(-1).astype(paddle.int32),
)
len_mask = len_mask.reshape([batch_size, 1, 1, query_len])
attn_mask = paddle.triu(paddle.ones((batch_size, 1, query_len, query_len), dtype=paddle.bool), diagonal=1)
mask = attn_mask.logical_or(len_mask)
mask = paddle.concat((past_mask, mask), axis=-1)
off_value = -math.inf
attn_mask = paddle.zeros_like(mask, dtype=paddle.bfloat16).masked_fill_(mask, off_value)
attn_mask = paddle.unsqueeze(attn_mask, axis=1)
return attn_mask
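A small usage sketch of the helper above, showing the expected mask shape when a batch mixes cached context and a new query chunk (the lengths here are illustrative):

    import paddle

    B, query_len = 2, 8
    seq_lens_encoder = paddle.to_tensor([8, 5], dtype="int32")    # new tokens per request
    seq_lens_decoder = paddle.to_tensor([16, 16], dtype="int32")  # cached-prefix length per request

    mask = get_attention_mask(seq_lens_encoder, seq_lens_decoder, B, query_len)
    # Keys cover [cached context | new chunk]: allowed positions are 0, masked positions are -inf.
    print(mask.shape)  # [2, 1, 1, 8, 24] = [B, 1, 1, query_len, max(seq_lens_decoder) + query_len]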
class AttentionBackend_HPU(AttentionBackend):
"""The base class of attention backends"""
@@ -254,16 +280,40 @@ class HPUAttentionBackend(AttentionBackend_HPU):
index_copy_(k_cache, forward_meta.block_indices, key_states, 0)
index_copy_(v_cache, forward_meta.block_indices, value_states, 0)
out_linear_out = fused_sdpa_proj_t(
query_states,
key_value_states,
forward_meta.attn_mask,
None,
o_proj.weight,
scaling_factor=self.head_dim**-0.5,
causal=True,
softmax_mode=0,
)
if forward_meta.block_list.shape == forward_meta.block_indices.shape:
out_linear_out = fused_sdpa_proj_t(
query_states,
key_value_states,
forward_meta.attn_mask,
None,
o_proj.weight,
scaling_factor=self.head_dim**-0.5,
causal=True,
softmax_mode=0,
)
else:
key_states_with_context = k_cache.index_select(forward_meta.block_list)
val_states_with_context = v_cache.index_select(forward_meta.block_list)
key_value_states_with_context = paddle.stack(
[key_states_with_context, val_states_with_context], axis=0
).reshape([kv, B, -1, M, H])
if forward_meta.attn_mask is None:
forward_meta.attn_mask = get_attention_mask(
forward_meta.seq_lens_encoder[forward_meta.batch_ids],
forward_meta.seq_lens_decoder[forward_meta.batch_ids],
query_states.shape[0],
query_states.shape[1],
)
out_linear_out = fused_sdpa_proj_t(
query_states,
key_value_states_with_context,
forward_meta.attn_mask,
None,
o_proj.weight,
scaling_factor=self.head_dim**-0.5,
causal=False,
softmax_mode=0,
)
if self.nranks > 1:
from fastdeploy.distributed.communication import (
@@ -297,11 +347,14 @@ class HPUAttentionBackend(AttentionBackend_HPU):
qkv_proj.weight,
qkv_proj.bias,
o_proj.weight,
None, # past_key: not used in decode mode
None, # past_value: not used in decode mode
self.head_dim,
self.num_heads,
scaling_factor=self.head_dim**-0.5,
transpose=False,
use_neox_style=layer.use_neox_rotary_style,
epsilon=1e-6,
)
# all_reduce
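The prefill path above now branches on whether the request reuses cached blocks: when block_list matches block_indices there is no cached prefix and the original causal fused_sdpa_proj_t call is kept; otherwise the cached keys/values are gathered from the HPU cache and attention runs with causal=False against the explicit mask built by get_attention_mask. A hedged sketch of just the gather step (gather_context_kv is a hypothetical helper):

    import paddle

    def gather_context_kv(k_cache, v_cache, block_list, block_indices):
        # No prefix hit: the scheduled blocks are exactly the freshly written ones.
        if block_list.shape == block_indices.shape:
            return None, None
        # Prefix hit: select every cached block of the request along dim 0 so the
        # new query chunk can attend to [cached context | new tokens].
        k_ctx = paddle.index_select(k_cache, block_list, axis=0)
        v_ctx = paddle.index_select(v_cache, block_list, axis=0)
        return k_ctx, v_ctx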

View File

@@ -100,7 +100,9 @@ def recover_block_hpu(
stop_flags, # hpu
seq_lens_this_time, # hpu
ori_seq_lens_encoder, # cpu
ori_seq_lens_decoder, # cpu
seq_lens_encoder, # hpu
seq_lens_decoder, # hpu
block_tables, # cpu
free_list, # cpu
free_list_len, # cpu
@@ -116,6 +118,7 @@ def recover_block_hpu(
for bid in range(recover_len.item()):
recover_id = recover_block_list[bid].item()
ori_seq_len_encoder = ori_seq_lens_encoder[recover_id].item()
ori_seq_len_decoder = ori_seq_lens_decoder[recover_id].item()
step_idx_now = step_idx[recover_id].item()
seq_len = ori_seq_len_encoder + step_idx_now
encoder_block_len = encoder_block_lens[recover_id].item()
@@ -123,13 +126,13 @@ def recover_block_hpu(
seq_lens_this_time[recover_id] = seq_len
seq_lens_encoder[recover_id] = seq_len
seq_lens_decoder[recover_id] = ori_seq_len_decoder
stop_flags[recover_id] = False
ori_free_list_len = free_list_len[0]
free_list_len[0] -= decoder_used_len
for i in range(decoder_used_len):
block_tables[recover_id, encoder_block_len + i] = free_list[ori_free_list_len - i - 1]
free_list_len[0] -= decoder_used_len
recover_block(input_ids, first_token_ids, pre_ids, next_tokens, recover_id, ori_seq_len_encoder, step_idx_now)
@@ -160,13 +163,16 @@ def step_intel_hpu(share_inputs: Dict[str, paddle.Tensor], block_size: int, max_
max_model_len,
)
if share_inputs["recover_lens"].item() > 0:
logger.info("recover block hpu happening ...")
recover_block_hpu(
share_inputs["recover_block_list"],
share_inputs["recover_lens"],
share_inputs["stop_flags"],
share_inputs["seq_lens_this_time"],
share_inputs["ori_seq_lens_encoder"],
share_inputs["ori_seq_lens_decoder"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["block_tables"],
share_inputs["free_list"],
share_inputs["free_list_len"],
@@ -179,6 +185,7 @@ def step_intel_hpu(share_inputs: Dict[str, paddle.Tensor], block_size: int, max_
share_inputs["first_token_ids"],
)
share_inputs["recover_lens"] = paddle.full([1], 0, dtype="int32").cpu()
share_inputs["not_need_stop"][0] = True
# TODO: replace rebuild_padding_v3 in CustomDevice if we adopt this version pp optimization
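For reference, the recovery path above now also restores the per-request decoder length from the newly tracked ori_seq_lens_decoder and hands decoder blocks back out of the tail of the free list. A minimal sketch of the block-restore step with plain Python lists (restore_decoder_blocks is a hypothetical helper; it assumes the free-list length ends up reduced by decoder_used_len exactly once per recovered request):

    def restore_decoder_blocks(block_tables, free_list, free_list_len,
                               recover_id, encoder_block_len, decoder_used_len):
        ori_free_list_len = free_list_len[0]
        for i in range(decoder_used_len):
            # Blocks are taken from the tail of the free list.
            block_tables[recover_id][encoder_block_len + i] = free_list[ori_free_list_len - i - 1]
        free_list_len[0] = ori_free_list_len - decoder_used_len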
@@ -477,9 +484,11 @@ class HPUModelRunner(ModelRunnerBase):
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["ori_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
else:
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["ori_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
@@ -615,6 +624,7 @@ class HPUModelRunner(ModelRunnerBase):
self.share_inputs["infer_seed"] = paddle.full([max_num_seqs, 1], 0, dtype="int64").cpu()
self.share_inputs["first_token_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int64")
self.share_inputs["ori_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32").cpu()
self.share_inputs["ori_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32").cpu()
self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int32")
@@ -707,6 +717,7 @@ class HPUModelRunner(ModelRunnerBase):
self.share_inputs["seq_lens_decoder"],
self.cache_config.block_size,
self.model_config.dtype,
self.scheduler_config.max_num_batched_tokens,
)
is_prompt = is_prompt.item() == 1 if is_prompt.item() > 0 else None
if is_prompt is True:
@@ -1023,7 +1034,18 @@ class HPUModelRunner(ModelRunnerBase):
""" """
pass
def update_warmup_inputs(self, requests, is_decode=False):
def update_warmup_inputs(self, requests, is_decode=False, context_len=0) -> None:
"""
Update the shared input tensors for warmup requests.
Args:
requests (list): List of request dicts containing input data.
is_decode (bool, optional): If True, sets up inputs for decode phase. Defaults to False.
context_len (int, optional): Length of the cached prefix to simulate for prefix-caching warmup.
If > 0, it is used as the decoder sequence length during prefill, i.e. the number of input
tokens treated as already-cached context. Defaults to 0 (no prefix caching).
"""
for i in range(len(requests)):
request = requests[i]
idx = request["idx"]
@@ -1038,7 +1060,7 @@ class HPUModelRunner(ModelRunnerBase):
self.share_inputs["step_idx"][idx : idx + 1] = 1
else:
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = context_len
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
@@ -1073,35 +1095,48 @@ class HPUModelRunner(ModelRunnerBase):
self.share_inputs["not_need_stop"][0] = True
def warm_up_bucket(self) -> None:
max_prefill_batch = 3 # Hard-Code in FastDeploy/fastdeploy/engine/config.py
max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))
warmup_max_model_len = min(int(os.environ.get("HPU_WARMUP_MODEL_LEN", 4096)), self.model_config.max_model_len)
prefill_batchs = []
prefill_batch_step = int(os.environ.get("BATCH_STEP_PREFILL", 1))
prefill_seq_step = int(os.environ.get("SEQUENCE_STEP_PREFILL", 128))
current_prefill_batch = prefill_batch_step
while current_prefill_batch <= max_prefill_batch:
prefill_batchs.append(int(current_prefill_batch))
current_prefill_batch += prefill_batch_step
max_prefill_length = self.cache_config.block_size + warmup_max_model_len
prefill_context_block_step = int(os.environ.get("CONTEXT_BLOCK_STEP_PREFILL", 1))
for prefill_batch in prefill_batchs:
for prefill_length in range(
self.cache_config.block_size, max_prefill_length, self.cache_config.block_size
for prefill_length_with_context in range(
self.cache_config.block_size, max_prefill_length, prefill_seq_step
):
if prefill_length * prefill_batch > self.scheduler_config.max_num_batched_tokens:
if prefill_length_with_context * prefill_batch > self.scheduler_config.max_num_batched_tokens:
continue
logger.info(f"Warmup prefill_batch: {prefill_batch}, prefill_length: {prefill_length} start")
requests = [
{
"idx": i,
"input_ids": [5] * (prefill_length - 1),
"block_tables": list(range(prefill_length // self.cache_config.block_size)),
"eos_token_ids": [2],
}
for i in range(prefill_batch)
]
self.update_warmup_inputs(requests, is_decode=False)
self.execute_model()
logger.info(f"warmup prefill_batch: {prefill_batch}, prefill_length: {prefill_length} done")
for context_len in range(
0, prefill_length_with_context, self.cache_config.block_size * prefill_context_block_step
):
prefill_length = prefill_length_with_context - context_len
logger.info(
f"Warmup prefill_batch: {prefill_batch}, prefill_length: {prefill_length}, context_len: {context_len} start"
)
requests = [
{
"idx": i,
"input_ids": [5] * (prefill_length_with_context - context_len - 1),
"block_tables": list(range(prefill_length_with_context // self.cache_config.block_size)),
"eos_token_ids": [2],
}
for i in range(prefill_batch)
]
self.update_warmup_inputs(requests, is_decode=False, context_len=context_len)
self.execute_model()
logger.info(
f"warmup prefill_batch: {prefill_batch}, prefill_length: {prefill_length}, context_len: {context_len} done"
)
# when disable prefix caching, only run context_len = 0 for each prefill_batch
if not self.cache_config.enable_prefix_caching:
break
decode_batchs = []
decode_batch_step = int(os.environ.get("BATCH_STEP_DECODE", 4))
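Taken together, the prefill warmup now sweeps a grid of (batch, new-token length, cached-context length) combinations, skipping any combination that would exceed max_num_batched_tokens and collapsing to context_len == 0 when prefix caching is disabled. A hedged sketch of the grid it enumerates (warmup_grid is a hypothetical stand-in for the loop, with the environment-variable knobs passed as arguments):

    def warmup_grid(max_prefill_batch, warmup_max_model_len, block_size, max_num_batched_tokens,
                    seq_step=128, batch_step=1, context_block_step=1, enable_prefix_caching=True):
        grid = []
        for batch in range(batch_step, max_prefill_batch + 1, batch_step):
            for total_len in range(block_size, block_size + warmup_max_model_len, seq_step):
                if total_len * batch > max_num_batched_tokens:
                    continue
                for context_len in range(0, total_len, block_size * context_block_step):
                    grid.append((batch, total_len - context_len, context_len))
                    if not enable_prefix_caching:
                        break  # only context_len == 0 without prefix caching
        return grid

    # e.g. warmup_grid(3, 512, 128, 4096)[:4]
    # -> [(1, 128, 0), (1, 256, 0), (1, 128, 128), (1, 384, 0)]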