From 7f64d408a9655cd41aed2ae5f29617251d302eb1 Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Mon, 14 Jul 2025 14:28:50 +0800 Subject: [PATCH] [MTP] support expert-parellel in mtp (#2835) --- fastdeploy/model_executor/load_weight_utils.py | 3 ++- fastdeploy/model_executor/pre_and_post_process.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py index c8ba1f673..012905249 100644 --- a/fastdeploy/model_executor/load_weight_utils.py +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -250,7 +250,8 @@ def load_composite_checkpoint( # 2. Tensor Parallel (TP) # 3. Pre-sharded (pre-split) """ - if fd_config.parallel_config.use_ep: + if fd_config.parallel_config.use_ep and \ + fd_config.speculative_config.model_type != "mtp": state_dict = load_ep_checkpoint(model_path, fd_config.model_config, return_numpy=True) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 5ba348574..0ddb8f6f0 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -182,7 +182,7 @@ def post_process_normal(sampler_output: SamplerOutput, ) -def post_process_specualate(model_output, skip_save_output: bool = False): +def post_process_specualate(model_output, save_each_rank: bool = False, skip_save_output: bool = False): """""" speculate_update_v3( model_output.seq_lens_encoder, @@ -204,7 +204,7 @@ def post_process_specualate(model_output, skip_save_output: bool = False): model_output.accept_num, model_output.not_need_stop, model_output.mp_rank, - False, + save_each_rank, ) speculate_clear_accept_nums(model_output.accept_num, @@ -231,7 +231,7 @@ def post_process(sampler_output: SamplerOutput, skip_save_output: bool = False) -> None: """ Post-processing steps after completing a single token generation. """ if speculative_decoding: - post_process_specualate(model_output, skip_save_output) + post_process_specualate(model_output, save_each_rank, skip_save_output) else: post_process_normal(sampler_output, model_output, save_each_rank, skip_save_output)