[Feature] Support block scheduler v1 for FD (#2928)

* Support FD block scheduler v1

* Address Copilot review comments

* Address reviewer comments

* Remove is_dummy

* Fix bug when real_bsz=1

* Fix first-token inference cost time

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Author: chenjian
Date: 2025-07-23 20:31:31 +08:00
Committed by: GitHub
Parent: ca0f71bd39
Commit: 85a78d695d
16 changed files with 898 additions and 40 deletions

@@ -61,9 +61,10 @@ else:
         speculate_step_system_cache,
         speculate_update_v3,
         step_paddle,
+        step_reschedule,
         step_system_cache,
         update_inputs,
-        step_reschedule,
+        update_inputs_v1,
     )
 
 from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
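The newly imported update_inputs_v1 is only exercised when the v1 KV-cache scheduler is switched on by an environment flag. A minimal sketch of how such a gate can be read at import time (the env var name FD_ENABLE_V1_KVCACHE_SCHEDULER and the helper below are illustrative assumptions, not FastDeploy's actual envs implementation):

import os

def _bool_env(name: str, default: str = "0") -> bool:
    # Treat "1"/"true" (any case) as enabled; anything else as disabled.
    return os.getenv(name, default).lower() in ("1", "true")

# Mirrors the envs.ENABLE_V1_KVCACHE_SCHEDULER flag checked in
# post_process_normal below; the env var name is an assumption.
ENABLE_V1_KVCACHE_SCHEDULER = _bool_env("FD_ENABLE_V1_KVCACHE_SCHEDULER")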
@@ -152,6 +153,8 @@ def pre_process(
 def post_process_normal(
     sampler_output: SamplerOutput,
     model_output: ModelOutputData,
+    share_inputs: Dict[str, paddle.Tensor],
+    block_size: int = 64,
     save_each_rank: bool = False,
     skip_save_output: bool = False,
 ) -> ModelRunnerOutput:
@@ -219,17 +222,35 @@ def post_process_normal(
     # 2. Update the input buffer of the model
     with paddle.framework._no_check_dy2st_diff():
-        update_inputs(
-            model_output.stop_flags,
-            model_output.not_need_stop,
-            model_output.seq_lens_this_time,
-            model_output.seq_lens_encoder,
-            model_output.seq_lens_decoder,
-            model_output.input_ids,
-            model_output.stop_nums,
-            sampler_output.sampled_token_ids,
-            model_output.is_block_step,
-        )
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            update_inputs_v1(
+                model_output.stop_flags,
+                model_output.not_need_stop,
+                model_output.seq_lens_this_time,
+                model_output.seq_lens_encoder,
+                model_output.seq_lens_decoder,
+                share_inputs["step_seq_lens_decoder"],
+                share_inputs["prompt_lens"],
+                sampler_output.sampled_token_ids,
+                model_output.input_ids,
+                share_inputs["block_tables"],
+                model_output.stop_nums,
+                model_output.next_tokens,
+                model_output.is_block_step,
+                block_size,
+            )
+        else:
+            update_inputs(
+                model_output.stop_flags,
+                model_output.not_need_stop,
+                model_output.seq_lens_this_time,
+                model_output.seq_lens_encoder,
+                model_output.seq_lens_decoder,
+                model_output.input_ids,
+                model_output.stop_nums,
+                sampler_output.sampled_token_ids,
+                model_output.is_block_step,
+            )
 
     # 3. Transmit the model's output and stop generation signal via message queue.
     # In the future, we will abandon this approach.
     if not skip_save_output:
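In the v1 path, share_inputs["block_tables"] maps each request to fixed-size KV-cache blocks, and block_size (64 by default, per the new signatures) sets the block granularity. A small illustration of the ceiling-division arithmetic a block scheduler relies on (a sketch for intuition, not the actual logic inside update_inputs_v1):

def blocks_needed(seq_len: int, block_size: int = 64) -> int:
    # Ceiling division: a partially filled block still occupies a full block.
    return (seq_len + block_size - 1) // block_size

assert blocks_needed(1) == 1    # a single token still takes one block
assert blocks_needed(64) == 1   # exactly one full block
assert blocks_needed(65) == 2   # one token past the boundary adds a block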
@@ -295,6 +316,8 @@ def post_process_specualate(model_output, save_each_rank: bool = False, skip_sav
 def post_process(
     sampler_output: SamplerOutput,
     model_output: ModelOutputData,
+    share_inputs: Dict[str, paddle.Tensor],
+    block_size: int = 64,
     save_each_rank: bool = False,
     speculative_decoding: bool = False,
     skip_save_output: bool = False,
@@ -303,7 +326,7 @@ def post_process(
     if speculative_decoding:
         post_process_specualate(model_output, save_each_rank, skip_save_output)
     else:
-        post_process_normal(sampler_output, model_output, save_each_rank, skip_save_output)
+        post_process_normal(sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output)
 
 
 def step_cuda(
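Because post_process now forwards share_inputs and block_size to post_process_normal, every caller has to thread the two new arguments through. A self-contained toy of the dispatch wiring shown in the last hunk (stub functions stand in for the real ones; not FastDeploy's actual classes or tensors):

from typing import Dict

def post_process_normal(sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output):
    # Stand-in for the real function; just shows which arguments arrive.
    print(f"normal path, block_size={block_size}, buffers={sorted(share_inputs)}")

def post_process_specualate(model_output, save_each_rank, skip_save_output):
    # Spelling mirrors the source identifier; this branch's signature is unchanged.
    print("speculative path")

def post_process(sampler_output, model_output, share_inputs: Dict, block_size: int = 64,
                 save_each_rank: bool = False, speculative_decoding: bool = False,
                 skip_save_output: bool = False):
    if speculative_decoding:
        post_process_specualate(model_output, save_each_rank, skip_save_output)
    else:
        post_process_normal(sampler_output, model_output, share_inputs, block_size,
                            save_each_rank, skip_save_output)

# Example call with the three buffers the v1 path reads from share_inputs.
post_process(None, None, {"block_tables": ..., "prompt_lens": ..., "step_seq_lens_decoder": ...})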