mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-03 15:56:49 +08:00
[MTP] support expert-parellel in mtp (#2835)
This commit is contained in:
@@ -250,7 +250,8 @@ def load_composite_checkpoint(
|
|||||||
# 2. Tensor Parallel (TP)
|
# 2. Tensor Parallel (TP)
|
||||||
# 3. Pre-sharded (pre-split)
|
# 3. Pre-sharded (pre-split)
|
||||||
"""
|
"""
|
||||||
if fd_config.parallel_config.use_ep:
|
if fd_config.parallel_config.use_ep and \
|
||||||
|
fd_config.speculative_config.model_type != "mtp":
|
||||||
state_dict = load_ep_checkpoint(model_path,
|
state_dict = load_ep_checkpoint(model_path,
|
||||||
fd_config.model_config,
|
fd_config.model_config,
|
||||||
return_numpy=True)
|
return_numpy=True)
|
||||||
|
@@ -182,7 +182,7 @@ def post_process_normal(sampler_output: SamplerOutput,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def post_process_specualate(model_output, skip_save_output: bool = False):
|
def post_process_specualate(model_output, save_each_rank: bool = False, skip_save_output: bool = False):
|
||||||
""""""
|
""""""
|
||||||
speculate_update_v3(
|
speculate_update_v3(
|
||||||
model_output.seq_lens_encoder,
|
model_output.seq_lens_encoder,
|
||||||
@@ -204,7 +204,7 @@ def post_process_specualate(model_output, skip_save_output: bool = False):
|
|||||||
model_output.accept_num,
|
model_output.accept_num,
|
||||||
model_output.not_need_stop,
|
model_output.not_need_stop,
|
||||||
model_output.mp_rank,
|
model_output.mp_rank,
|
||||||
False,
|
save_each_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
speculate_clear_accept_nums(model_output.accept_num,
|
speculate_clear_accept_nums(model_output.accept_num,
|
||||||
@@ -231,7 +231,7 @@ def post_process(sampler_output: SamplerOutput,
|
|||||||
skip_save_output: bool = False) -> None:
|
skip_save_output: bool = False) -> None:
|
||||||
""" Post-processing steps after completing a single token generation. """
|
""" Post-processing steps after completing a single token generation. """
|
||||||
if speculative_decoding:
|
if speculative_decoding:
|
||||||
post_process_specualate(model_output, skip_save_output)
|
post_process_specualate(model_output, save_each_rank, skip_save_output)
|
||||||
else:
|
else:
|
||||||
post_process_normal(sampler_output, model_output, save_each_rank,
|
post_process_normal(sampler_output, model_output, save_each_rank,
|
||||||
skip_save_output)
|
skip_save_output)
|
||||||
|
Reference in New Issue
Block a user