[Feature] Support block scheduler v1 for FD (#2928)

* Support FD block scheduler v1

* Support FD block scheduler v1

* Support FD block scheduler v1

* Fix according to copilot review

* Fix according to review

* Remove is_dummy

* Fix bug when real_bsz=1

* Fix infer first token cost time

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
chenjian
2025-07-23 20:31:31 +08:00
committed by GitHub
parent ca0f71bd39
commit 85a78d695d
16 changed files with 898 additions and 40 deletions

View File

@@ -24,7 +24,7 @@ from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.model_executor.graph_optimization.utils import (
profile_run_guard,
sot_warmup_guard,
@@ -42,6 +42,7 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
from fastdeploy.model_executor.model_loader import get_model_from_loader
from fastdeploy.model_executor.ops.gpu import (
recover_decode_task,
set_value_by_flags_and_idx,
share_external_data,
)
@@ -56,6 +57,7 @@ from fastdeploy.platforms import current_platform
if not current_platform.is_dcu():
from fastdeploy.spec_decode import MTPProposer, NgramProposer
from fastdeploy import envs
from fastdeploy.input.mm_processor import DataProcessor
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
@@ -189,10 +191,97 @@ class GPUModelRunner(ModelRunnerBase):
elif request.structural_tag is not None:
schemata_key = ("structural_tag", request.structural_tag)
return (
self.guided_backend.get_logits_processor(schemata_key=schemata_key),
schemata_key,
)
return self.guided_backend.get_logits_processor(schemata_key=schemata_key), schemata_key
def insert_tasks_v1(self, req_dicts: List[Request]):
"""
Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
"""
# NOTE(luotingdan): Lazy initialize kv cache
if "caches" not in self.share_inputs:
self.initialize_kv_cache()
req_len = len(req_dicts)
has_prefill_task = False
for i in range(req_len):
request = req_dicts[i]
idx = request.idx
if request.task_type.value == RequestType.PREFILL.value: # prefill task
logger.debug(f"Handle prefill request {request} at idx {idx}")
prefill_start_index = request.prefill_start_index
prefill_end_index = request.prefill_end_index
length = prefill_end_index - prefill_start_index
input_ids = request.prompt_token_ids + request.output_token_ids
self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
input_ids[prefill_start_index:prefill_end_index]
)
encoder_block_num = len(request.block_tables)
self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
self.share_inputs["block_tables"][idx : idx + 1, :] = -1
self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
)
self.share_inputs["stop_flags"][idx : idx + 1] = False
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
self.share_inputs["is_block_step"][idx : idx + 1] = False
self.share_inputs["step_idx"][idx : idx + 1] = (
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
)
has_prefill_task = True
elif request.task_type.value == RequestType.DECODE.value: # decode task
logger.debug(f"Handle decode request {request} at idx {idx}")
encoder_block_num = len(request.block_tables)
self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
self.share_inputs["block_tables"][idx : idx + 1, :] = -1
self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
)
continue
else: # preempted task
logger.debug(f"Handle preempted request {request} at idx {idx}")
self.share_inputs["block_tables"][idx : idx + 1, :] = -1
self.share_inputs["stop_flags"][idx : idx + 1] = True
self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 0
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.share_inputs["is_block_step"][idx : idx + 1] = False
continue
if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens:
request.eos_token_ids.append(request.eos_token_ids[0])
self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0)
self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
"max_tokens", self.model_config.max_model_len
)
self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1]
self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = length
if request.get("seed") is not None:
self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")
if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
stop_seqs_num = len(request.get("stop_seqs_len"))
for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
request.stop_seqs_len.append(0)
self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
request.get("stop_token_ids"), dtype="int64"
)
if has_prefill_task:
self.share_inputs["not_need_stop"][0] = True
def insert_prefill_inputs(self, req_dicts: List[Request]):
"""
@@ -591,6 +680,18 @@ class GPUModelRunner(ModelRunnerBase):
def _prepare_inputs(self) -> None:
"""Prepare the model inputs"""
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
recover_decode_task(
self.share_inputs["stop_flags"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_encoder"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["step_seq_lens_decoder"],
self.share_inputs["block_tables"],
self.share_inputs["is_block_step"],
self.parallel_config.block_size,
)
# Remove padding
(
ids_remove_padding,
@@ -901,6 +1002,8 @@ class GPUModelRunner(ModelRunnerBase):
post_process(
sampler_output=sampler_output,
model_output=model_output_data,
share_inputs=self.share_inputs,
block_size=self.parallel_config.block_size,
speculative_decoding=self.speculative_decoding,
skip_save_output=True,
)
@@ -1165,6 +1268,8 @@ class GPUModelRunner(ModelRunnerBase):
post_process(
sampler_output=sampler_output,
model_output=model_output_data,
share_inputs=self.share_inputs,
block_size=self.parallel_config.block_size,
save_each_rank=self.parallel_config.use_ep,
speculative_decoding=self.speculative_decoding,
skip_save_output=skip_save_output,
@@ -1180,16 +1285,17 @@ class GPUModelRunner(ModelRunnerBase):
# 7. Updata 'infer_seed' and step_cuda()
self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
step_cuda(
self.share_inputs,
self.parallel_config.block_size,
self.parallel_config.enc_dec_block_num,
self.speculative_config,
self.parallel_config.enable_prefix_caching,
)
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
step_cuda(
self.share_inputs,
self.parallel_config.block_size,
self.parallel_config.enc_dec_block_num,
self.speculative_config,
self.parallel_config.enable_prefix_caching,
)
self._update_chunked_prefill(model_forward_batch)
self._add_cache(model_forward_batch)
self._update_chunked_prefill(model_forward_batch)
self._add_cache(model_forward_batch)
return None
def _add_cache(self, model_forward_batch) -> None:

View File

@@ -22,6 +22,7 @@ import paddle
import pynvml
from paddle import nn
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.platforms import current_platform
@@ -183,7 +184,10 @@ class GpuWorker(WorkerBase):
TODO(gongshaotian):The scheduler should schedule the handling of prefill,
and workers and modelrunners should not perceive it.
"""
self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.model_runner.insert_tasks_v1(req_dicts=req_dicts)
else:
self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
def graph_optimize_and_warm_up_model(self) -> None:
"""