mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
[MTP] support expert-parellel in mtp (#2835)
This commit is contained in:
@@ -250,7 +250,8 @@ def load_composite_checkpoint(
|
||||
# 2. Tensor Parallel (TP)
|
||||
# 3. Pre-sharded (pre-split)
|
||||
"""
|
||||
if fd_config.parallel_config.use_ep:
|
||||
if fd_config.parallel_config.use_ep and \
|
||||
fd_config.speculative_config.model_type != "mtp":
|
||||
state_dict = load_ep_checkpoint(model_path,
|
||||
fd_config.model_config,
|
||||
return_numpy=True)
|
||||
|
@@ -182,7 +182,7 @@ def post_process_normal(sampler_output: SamplerOutput,
|
||||
)
|
||||
|
||||
|
||||
def post_process_specualate(model_output, skip_save_output: bool = False):
|
||||
def post_process_specualate(model_output, save_each_rank: bool = False, skip_save_output: bool = False):
|
||||
""""""
|
||||
speculate_update_v3(
|
||||
model_output.seq_lens_encoder,
|
||||
@@ -204,7 +204,7 @@ def post_process_specualate(model_output, skip_save_output: bool = False):
|
||||
model_output.accept_num,
|
||||
model_output.not_need_stop,
|
||||
model_output.mp_rank,
|
||||
False,
|
||||
save_each_rank,
|
||||
)
|
||||
|
||||
speculate_clear_accept_nums(model_output.accept_num,
|
||||
@@ -231,7 +231,7 @@ def post_process(sampler_output: SamplerOutput,
|
||||
skip_save_output: bool = False) -> None:
|
||||
""" Post-processing steps after completing a single token generation. """
|
||||
if speculative_decoding:
|
||||
post_process_specualate(model_output, skip_save_output)
|
||||
post_process_specualate(model_output, save_each_rank, skip_save_output)
|
||||
else:
|
||||
post_process_normal(sampler_output, model_output, save_each_rank,
|
||||
skip_save_output)
|
||||
|
Reference in New Issue
Block a user