mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-03 07:46:50 +08:00
[MTP] support expert-parellel in mtp (#2835)
This commit is contained in:
@@ -182,7 +182,7 @@ def post_process_normal(sampler_output: SamplerOutput,
|
||||
)
|
||||
|
||||
|
||||
def post_process_specualate(model_output, skip_save_output: bool = False):
|
||||
def post_process_specualate(model_output, save_each_rank: bool = False, skip_save_output: bool = False):
|
||||
""""""
|
||||
speculate_update_v3(
|
||||
model_output.seq_lens_encoder,
|
||||
@@ -204,7 +204,7 @@ def post_process_specualate(model_output, skip_save_output: bool = False):
|
||||
model_output.accept_num,
|
||||
model_output.not_need_stop,
|
||||
model_output.mp_rank,
|
||||
False,
|
||||
save_each_rank,
|
||||
)
|
||||
|
||||
speculate_clear_accept_nums(model_output.accept_num,
|
||||
@@ -231,7 +231,7 @@ def post_process(sampler_output: SamplerOutput,
|
||||
skip_save_output: bool = False) -> None:
|
||||
""" Post-processing steps after completing a single token generation. """
|
||||
if speculative_decoding:
|
||||
post_process_specualate(model_output, skip_save_output)
|
||||
post_process_specualate(model_output, save_each_rank, skip_save_output)
|
||||
else:
|
||||
post_process_normal(sampler_output, model_output, save_each_rank,
|
||||
skip_save_output)
|
||||
|
Reference in New Issue
Block a user