[MTP] support expert-parallel in mtp (#2835)

This commit is contained in:
freeliuzc
2025-07-14 14:28:50 +08:00
committed by GitHub
parent ece88596ed
commit 7f64d408a9
2 changed files with 5 additions and 4 deletions

View File

@@ -250,7 +250,8 @@ def load_composite_checkpoint(
# 2. Tensor Parallel (TP) # 2. Tensor Parallel (TP)
# 3. Pre-sharded (pre-split) # 3. Pre-sharded (pre-split)
""" """
if fd_config.parallel_config.use_ep: if fd_config.parallel_config.use_ep and \
fd_config.speculative_config.model_type != "mtp":
state_dict = load_ep_checkpoint(model_path, state_dict = load_ep_checkpoint(model_path,
fd_config.model_config, fd_config.model_config,
return_numpy=True) return_numpy=True)

View File

@@ -182,7 +182,7 @@ def post_process_normal(sampler_output: SamplerOutput,
) )
def post_process_specualate(model_output, skip_save_output: bool = False): def post_process_specualate(model_output, save_each_rank: bool = False, skip_save_output: bool = False):
"""""" """"""
speculate_update_v3( speculate_update_v3(
model_output.seq_lens_encoder, model_output.seq_lens_encoder,
@@ -204,7 +204,7 @@ def post_process_specualate(model_output, skip_save_output: bool = False):
model_output.accept_num, model_output.accept_num,
model_output.not_need_stop, model_output.not_need_stop,
model_output.mp_rank, model_output.mp_rank,
False, save_each_rank,
) )
speculate_clear_accept_nums(model_output.accept_num, speculate_clear_accept_nums(model_output.accept_num,
@@ -231,7 +231,7 @@ def post_process(sampler_output: SamplerOutput,
skip_save_output: bool = False) -> None: skip_save_output: bool = False) -> None:
""" Post-processing steps after completing a single token generation. """ """ Post-processing steps after completing a single token generation. """
if speculative_decoding: if speculative_decoding:
post_process_specualate(model_output, skip_save_output) post_process_specualate(model_output, save_each_rank, skip_save_output)
else: else:
post_process_normal(sampler_output, model_output, save_each_rank, post_process_normal(sampler_output, model_output, save_each_rank,
skip_save_output) skip_save_output)