Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
[Feature] Support block scheduler v1 for FD (#2928)
* Support FD block scheduler v1
* Fix according to copilot review
* Fix according to review
* Remove is_dummy
* Fix bug when real_bsz=1
* Fix infer first token cost time

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
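For background, a block scheduler manages the KV cache in fixed-size blocks: each running request owns a row of a block table that maps its logical cache blocks to physical ones, so a sequence of L tokens needs ceil(L / block_size) blocks. A minimal, self-contained sketch of that bookkeeping (names and structure are illustrative, not FastDeploy's actual classes):

# Illustration of the block-table bookkeeping a paged KV-cache scheduler
# performs; names here are ours, not FastDeploy's.
BLOCK_SIZE = 64  # matches the block_size=64 default added in the diff below

def blocks_needed(seq_len: int, block_size: int = BLOCK_SIZE) -> int:
    """A sequence of seq_len tokens occupies ceil(seq_len / block_size) KV blocks."""
    return -(-seq_len // block_size)  # ceil division

free_blocks = list(range(100))  # physical block ids in the cache pool
block_table = []                # logical -> physical mapping for one request

def grow_to(seq_len: int) -> None:
    """Allocate physical blocks so the request can hold seq_len tokens."""
    while len(block_table) < blocks_needed(seq_len):
        block_table.append(free_blocks.pop())

grow_to(70)   # 70-token prompt -> 2 blocks
grow_to(129)  # decoded up to 129 tokens -> 3 blocks
print(block_table)  # -> [99, 98, 97]

This per-request block table is why the diff below threads share_inputs["block_tables"] and a block_size parameter into the post-processing path.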
@@ -61,9 +61,10 @@ else:
         speculate_step_system_cache,
         speculate_update_v3,
         step_paddle,
+        step_reschedule,
         step_system_cache,
         update_inputs,
-        step_reschedule,
+        update_inputs_v1,
     )

 from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
@@ -152,6 +153,8 @@ def pre_process(
 def post_process_normal(
     sampler_output: SamplerOutput,
     model_output: ModelOutputData,
+    share_inputs: Dict[str, paddle.Tensor],
+    block_size: int = 64,
     save_each_rank: bool = False,
     skip_save_output: bool = False,
 ) -> ModelRunnerOutput:
@@ -219,17 +222,35 @@ def post_process_normal(

     # 2. Update the input buffer of the model
     with paddle.framework._no_check_dy2st_diff():
-        update_inputs(
-            model_output.stop_flags,
-            model_output.not_need_stop,
-            model_output.seq_lens_this_time,
-            model_output.seq_lens_encoder,
-            model_output.seq_lens_decoder,
-            model_output.input_ids,
-            model_output.stop_nums,
-            sampler_output.sampled_token_ids,
-            model_output.is_block_step,
-        )
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            update_inputs_v1(
+                model_output.stop_flags,
+                model_output.not_need_stop,
+                model_output.seq_lens_this_time,
+                model_output.seq_lens_encoder,
+                model_output.seq_lens_decoder,
+                share_inputs["step_seq_lens_decoder"],
+                share_inputs["prompt_lens"],
+                sampler_output.sampled_token_ids,
+                model_output.input_ids,
+                share_inputs["block_tables"],
+                model_output.stop_nums,
+                model_output.next_tokens,
+                model_output.is_block_step,
+                block_size,
+            )
+        else:
+            update_inputs(
+                model_output.stop_flags,
+                model_output.not_need_stop,
+                model_output.seq_lens_this_time,
+                model_output.seq_lens_encoder,
+                model_output.seq_lens_decoder,
+                model_output.input_ids,
+                model_output.stop_nums,
+                sampler_output.sampled_token_ids,
+                model_output.is_block_step,
+            )
     # 3. Transmit the model's output and stop generation signal via message queue.
     # In the future, we will abandon this approach.
     if not skip_save_output:
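Compared with the legacy update_inputs, the v1 op additionally receives share_inputs["step_seq_lens_decoder"], share_inputs["prompt_lens"], share_inputs["block_tables"], model_output.next_tokens, and block_size. A conceptual pure-Python sketch of the per-request bookkeeping this enables (the real op is a fused GPU kernel over batched paddle.Tensors; the field names below are illustrative):

# Conceptual, per-request view of the state a v1-style update touches each
# step. This is only meant to show the bookkeeping, not the kernel itself.
def update_one_request_v1(req, sampled_token, block_size=64):
    if req["stop_flag"]:
        req["seq_len_this_time"] = 0
        return
    # Append the newly sampled token and advance the decode length.
    req["input_ids"].append(sampled_token)
    req["seq_len_decoder"] += 1
    req["seq_len_this_time"] = 1  # decode phase: one new token per step
    # Crossing a block boundary means the next step writes into a new KV
    # block, which the scheduler must have mapped in the block table.
    if req["seq_len_decoder"] % block_size == 0:
        req["needs_new_block"] = True

req = {"stop_flag": False, "input_ids": [1, 2, 3], "seq_len_decoder": 63,
       "seq_len_this_time": 1, "needs_new_block": False}
update_one_request_v1(req, sampled_token=42)
print(req["seq_len_decoder"], req["needs_new_block"])  # -> 64 True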
@@ -295,6 +316,8 @@ def post_process_specualate(model_output, save_each_rank: bool = False, skip_sav
 def post_process(
     sampler_output: SamplerOutput,
     model_output: ModelOutputData,
+    share_inputs: Dict[str, paddle.Tensor],
+    block_size: int = 64,
     save_each_rank: bool = False,
     speculative_decoding: bool = False,
     skip_save_output: bool = False,
@@ -303,7 +326,7 @@ def post_process(
     if speculative_decoding:
         post_process_specualate(model_output, save_each_rank, skip_save_output)
     else:
-        post_process_normal(sampler_output, model_output, save_each_rank, skip_save_output)
+        post_process_normal(sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output)


 def step_cuda(
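The dispatch itself stays simple: only the normal path gains the new arguments, while the speculative path keeps its old signature. A runnable stand-in mirroring the control flow of the final hunk (all bodies are stubs, not FastDeploy's implementations):

# Stand-in for the post_process dispatch after this PR: share_inputs and
# block_size are threaded through to the normal path only.
from typing import Any, Dict

def post_process_specualate(model_output, save_each_rank=False, skip_save_output=False):
    print("speculative path (unchanged signature)")

def post_process_normal(sampler_output, model_output, share_inputs,
                        block_size=64, save_each_rank=False, skip_save_output=False):
    print(f"normal path, block_size={block_size}, buffers={sorted(share_inputs)}")

def post_process(sampler_output, model_output, share_inputs: Dict[str, Any],
                 block_size=64, save_each_rank=False,
                 speculative_decoding=False, skip_save_output=False):
    if speculative_decoding:
        post_process_specualate(model_output, save_each_rank, skip_save_output)
    else:
        post_process_normal(sampler_output, model_output, share_inputs,
                            block_size, save_each_rank, skip_save_output)

# Placeholder outputs; a real caller passes SamplerOutput / ModelOutputData.
post_process(None, None, {"block_tables": ..., "prompt_lens": ...}, block_size=64)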