[Intel HPU] enable kv cache scheduler v1 for hpu (#5648)

* [Intel HPU] enable kv cache scheduler v1 for hpu

* fix copilot comments
Author: fmiao2372
Date: 2025-12-19 12:03:39 +08:00
Committed by: GitHub
Parent: fc452c8e29
Commit: a8fce47195
6 changed files with 156 additions and 17 deletions

View File

@@ -13,7 +13,7 @@ export CACHE_QUEUE_PORT=8003
export HABANA_PROFILE=0
export HPU_VISIBLE_DEVICES=0
rm -rf log 2>/dev/null
-FD_ENC_DEC_BLOCK_NUM=8 HPU_PERF_BREAKDOWN_SYNC_MODE=1 HPU_WARMUP_BUCKET=1 HPU_WARMUP_MODEL_LEN=4096 FD_ATTENTION_BACKEND=HPU_ATTN \
+ENABLE_V1_KVCACHE_SCHEDULER=1 FD_ENC_DEC_BLOCK_NUM=8 HPU_PERF_BREAKDOWN_SYNC_MODE=1 HPU_WARMUP_BUCKET=1 HPU_WARMUP_MODEL_LEN=4096 FD_ATTENTION_BACKEND=HPU_ATTN \
python -m fastdeploy.entrypoints.openai.api_server \
--model ERNIE-4.5-21B-A3B-Paddle \
--port ${SERVER_PORT} \
@@ -32,7 +32,7 @@ FD_ENC_DEC_BLOCK_NUM=8 HPU_PERF_BREAKDOWN_SYNC_MODE=1 HPU_WARMUP_BUCKET=1 HPU_WA
# (2k + 1k) / 128(block_size) * 128(batch) = 3072
# export HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# rm -rf log 2>/dev/null
-# FD_ENC_DEC_BLOCK_NUM=8 HPU_PERF_BREAKDOWN_SYNC_MODE=1 HPU_WARMUP_BUCKET=1 HPU_WARMUP_MODEL_LEN=3072 FD_ATTENTION_BACKEND=HPU_ATTN \
+# ENABLE_V1_KVCACHE_SCHEDULER=1 FD_ENC_DEC_BLOCK_NUM=8 HPU_PERF_BREAKDOWN_SYNC_MODE=1 HPU_WARMUP_BUCKET=1 HPU_WARMUP_MODEL_LEN=3072 FD_ATTENTION_BACKEND=HPU_ATTN \
# python -m fastdeploy.entrypoints.openai.api_server \
# --model ERNIE-4.5-300B-A47B-Paddle \
# --port ${SERVER_PORT} \
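
A quick check of the sizing arithmetic in the commented 8-card example above ("(2k + 1k) / 128(block_size) * 128(batch) = 3072"), as a minimal Python sketch; the variable names are illustrative and not taken from the repository:

    # Hypothetical recomputation of the sizing comment above (names are illustrative).
    prompt_tokens = 2 * 1024      # ~2k prompt tokens per request
    output_tokens = 1 * 1024      # ~1k generated tokens per request
    block_size = 128              # KV-cache block size
    batch = 128                   # concurrent sequences

    tokens_per_request = prompt_tokens + output_tokens           # 3072, matching HPU_WARMUP_MODEL_LEN=3072
    blocks_per_request = tokens_per_request // block_size        # 24 KV-cache blocks per request
    total_blocks = blocks_per_request * batch                    # 3072 blocks across the batch
    print(tokens_per_request, blocks_per_request, total_blocks)  # 3072 24 3072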

View File

@@ -27,7 +27,7 @@ else
fi
rm -rf log 2>/dev/null
-FD_ENC_DEC_BLOCK_NUM=8 HPU_PERF_BREAKDOWN_SYNC_MODE=1 HPU_WARMUP_BUCKET=0 FD_ATTENTION_BACKEND=HPU_ATTN ENABLE_V1_KVCACHE_SCHEDULER=0 \
+ENABLE_V1_KVCACHE_SCHEDULER=1 FD_ENC_DEC_BLOCK_NUM=8 HPU_PERF_BREAKDOWN_SYNC_MODE=1 HPU_WARMUP_BUCKET=0 FD_ATTENTION_BACKEND=HPU_ATTN ENABLE_V1_KVCACHE_SCHEDULER=0 \
python -m fastdeploy.entrypoints.openai.api_server --model ${MODEL} --port ${SERVER_PORT} \
--engine-worker-queue-port ${ENGINE_WORKER_QUEUE_PORT} --metrics-port ${METRICS_PORT} \
--cache-queue-port ${CACHE_QUEUE_PORT} --tensor-parallel-size ${CARD_NUM} --max-model-len 16384 \

View File

@@ -540,6 +540,7 @@ class EngineArgs:
or current_platform.is_xpu()
or current_platform.is_maca()
or current_platform.is_iluvatar()
+or current_platform.is_intel_hpu()
):
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
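
Read on its own, the hunk above can look as if Intel HPU is being added to a list that forces the flag off. The `or` chain actually enumerates platforms on which the v1 KV cache scheduler is allowed, and the `envs.ENABLE_V1_KVCACHE_SCHEDULER = 0` fallback fires only when the current platform is not in that list (the matching change in the last file of this diff logs "due to not supported" in the same branch). A minimal sketch of that shape; the negated wrapper and the entries above the hunk are reconstructed assumptions, only the `is_intel_hpu()` line is part of this commit:

    # Assumed surrounding structure; only the is_intel_hpu() entry is new in this commit.
    if not (
        current_platform.is_cuda()          # assumed entries not shown in the hunk
        or current_platform.is_xpu()
        or current_platform.is_maca()
        or current_platform.is_iluvatar()
        or current_platform.is_intel_hpu()  # added here, so HPU keeps the v1 scheduler enabled
    ):
        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0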

View File

@@ -26,7 +26,7 @@ from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce_custom
-from fastdeploy.engine.request import Request
+from fastdeploy.engine.request import Request, RequestType
# from fastdeploy.spec_decode import MTPProposer, NgramProposer
from fastdeploy.model_executor.forward_meta import HPUForwardMeta
@@ -474,6 +474,122 @@ class HPUModelRunner(ModelRunnerBase):
         return self.guided_backend.get_logits_processor(schemata_key=schemata_key), schemata_key

+    def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None):
+        """
+        Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
+        req_dict: A list of Request dict
+        num_running_requests: batch_size
+        """
+        # NOTE(luotingdan): Lazy initialize kv cache
+        if "caches" not in self.share_inputs:
+            self.initialize_kv_cache()
+        req_len = len(req_dicts)
+        has_prefill_task = False
+        has_decode_task = False
+        for i in range(req_len):
+            request = req_dicts[i]
+            idx = request.idx
+            if request.task_type.value == RequestType.PREFILL.value:  # prefill task
+                prefill_start_index = request.prefill_start_index
+                prefill_end_index = request.prefill_end_index
+                length = prefill_end_index - prefill_start_index
+                if isinstance(request.prompt_token_ids, np.ndarray):
+                    prompt_token_ids = request.prompt_token_ids.tolist()
+                else:
+                    prompt_token_ids = request.prompt_token_ids
+                input_ids = prompt_token_ids + request.output_token_ids
+                prompt_len = len(prompt_token_ids)
+                self.share_inputs["prompt_ids"][idx : idx + 1, :prompt_len] = np.array(prompt_token_ids, dtype="int64")
+                logger.debug(
+                    f"Handle prefill request {request} at idx {idx}, "
+                    f"{prefill_start_index=}, {prefill_end_index=}, "
+                    f"need_prefilled_token_num={len(input_ids)},"
+                    f"prompt_len={prompt_len}"
+                )
+                self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
+                    input_ids[prefill_start_index:prefill_end_index]
+                )
+                encoder_block_num = len(request.block_tables)
+                self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
+                self.share_inputs["block_tables"][idx : idx + 1, :] = -1
+                self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
+                    request.block_tables, dtype="int32"
+                )
+                self.share_inputs["stop_flags"][idx : idx + 1] = False
+                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
+                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
+                self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
+                self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
+                self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
+                self.share_inputs["is_block_step"][idx : idx + 1] = False
+                # self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids)
+                self.share_inputs["step_idx"][idx : idx + 1] = (
+                    len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
+                )
+                self.share_inputs["pre_ids"][idx : idx + 1] = -1
+                has_prefill_task = True
+            elif request.task_type.value == RequestType.DECODE.value:  # decode task
+                logger.debug(f"Handle decode request {request} at idx {idx}")
+                encoder_block_num = len(request.block_tables)
+                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 1
+                self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
+                self.share_inputs["block_tables"][idx : idx + 1, :] = -1
+                self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
+                    request.block_tables, dtype="int32"
+                )
+                if self.share_inputs["is_block_step"][idx]:  # has tasks to continue to decode
+                    has_decode_task = True
+                continue
+            else:  # preempted task
+                logger.info(f"Handle preempted request {request} at idx {idx}")
+                self.share_inputs["block_tables"][idx : idx + 1, :] = -1
+                self.share_inputs["stop_flags"][idx : idx + 1] = True
+                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 0
+                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
+                self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
+                self.share_inputs["is_block_step"][idx : idx + 1] = False
+                continue
+            assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
+            self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
+            self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
+            self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
+            self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
+            self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
+            self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0)
+            self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
+            self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
+                "max_tokens", self.model_config.max_model_len
+            )
+            self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1]
+            if request.get("seed") is not None:
+                self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")
+            if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
+                stop_seqs_num = len(request.get("stop_seqs_len"))
+                for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
+                    request.sampling_params.stop_seqs_len.append(0)
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
+                    request.sampling_params.stop_seqs_len, dtype="int32"
+                )
+                self.share_inputs["stop_seqs"][
+                    idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
+                ] = np.array(request.get("stop_token_ids"), dtype="int64")
+            else:
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
+        if has_prefill_task or has_decode_task:
+            self.share_inputs["not_need_stop"][0] = True

     def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None):
         """
         Process inputs for prefill tasks and insert it to share_inputs buffer
@@ -581,11 +697,15 @@ class HPUModelRunner(ModelRunnerBase):
             if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
                 stop_seqs_num = len(request.get("stop_seqs_len"))
                 for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
-                    request.stop_seqs_len.append(0)
-                self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
-                self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
-                    request.get("stop_token_ids"), dtype="int64"
+                    request.sampling_params.stop_seqs_len.append(0)
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array(
+                    request.sampling_params.stop_seqs_len, dtype="int32"
                 )
+                self.share_inputs["stop_seqs"][
+                    idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0])
+                ] = np.array(request.get("stop_token_ids"), dtype="int64")
+            else:
+                self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0
             self.sampler.apply_logits_processor(idx, request.get("logits_processor"), prefill_tokens)
@@ -635,6 +755,9 @@ class HPUModelRunner(ModelRunnerBase):
self.share_inputs["input_ids"] = paddle.full(
[max_num_seqs, self.model_config.max_model_len], self.model_config.pad_token_id, dtype="int64"
)
self.share_inputs["prompt_ids"] = paddle.full(
[max_num_seqs, self.model_config.max_model_len], self.model_config.pad_token_id, dtype="int64"
)
self.share_inputs["eos_token_id"] = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64")
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
self.share_inputs["temperature"] = paddle.full(
@@ -659,6 +782,7 @@ class HPUModelRunner(ModelRunnerBase):
self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["not_need_stop"] = paddle.full(
[1], False, dtype="bool"
@@ -724,9 +848,17 @@ class HPUModelRunner(ModelRunnerBase):
self.share_inputs["free_list_len"] = paddle.full([1], self.free_list_len, dtype="int32").cpu()
# Initialize stop seqs
self.share_inputs["stop_seqs_len"] = paddle.full([self.model_config.max_stop_seqs_num], 0, dtype="int32")
self.share_inputs["stop_seqs_len"] = paddle.full(
[max_num_seqs, self.model_config.max_stop_seqs_num], 0, dtype="int32"
)
self.share_inputs["stop_seqs"] = paddle.full(
[self.model_config.max_stop_seqs_num, self.model_config.stop_seqs_max_len], -1, dtype="int32"
[
max_num_seqs,
self.model_config.max_stop_seqs_num,
self.model_config.stop_seqs_max_len,
],
-1,
dtype="int64",
)
if self.speculative_decoding:
max_draft_token_num = self.speculative_config.num_speculative_tokens
@@ -1401,12 +1533,13 @@ class HPUModelRunner(ModelRunnerBase):
         # 7. Updata 'infer_seed' and step_cuda()
         self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
         self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
-        start_time = time.time()
-        step_intel_hpu(self.share_inputs, self.cache_config.block_size, self.model_config.max_model_len)
-        end_time = time.time()
-        execution_time = (end_time - start_time) * 1000
-        hpu_model_runner_profile_logger.info(f"StepPaddle execution time(ms): {execution_time}, BT={real_bs}")
-        self._update_chunked_prefill(model_forward_batch)
+        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            start_time = time.time()
+            step_intel_hpu(self.share_inputs, self.cache_config.block_size, self.model_config.max_model_len)
+            end_time = time.time()
+            execution_time = (end_time - start_time) * 1000
+            hpu_model_runner_profile_logger.info(f"StepPaddle execution time(ms): {execution_time}, BT={real_bs}")
+            self._update_chunked_prefill(model_forward_batch)

         if int(os.environ.get("HABANA_PROFILE", 0)) == 1:
             self.prof.step()

View File

@@ -23,6 +23,7 @@ import paddle
import paddle.nn as nn
from paddle.base import core
+from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import get_logger, set_random_seed
@@ -182,7 +183,10 @@ class HpuWorker(WorkerBase):
         TODO(gongshaotian):The scheduler should schedule the handling of prefill,
         and workers and modelrunners should not perceive it.
         """
-        self.model_runner.insert_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests)
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            self.model_runner.insert_tasks_v1(req_dicts=req_dicts, num_running_requests=num_running_requests)
+        else:
+            self.model_runner.insert_prefill_inputs(req_dicts=req_dicts, num_running_requests=num_running_requests)

     def graph_optimize_and_warm_up_model(self) -> None:
         """

View File

@@ -979,6 +979,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
or current_platform.is_xpu()
or current_platform.is_maca()
or current_platform.is_iluvatar()
+or current_platform.is_intel_hpu()
):
logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.")
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0