[Feat] support mixed ep (#2969)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

* Support mixed ep

* fix comment

* fix comment

* update mixep

* fix conflict

* fix typo

* update

* fix typo

* fix code style

* fix conflict
This commit is contained in:
Longzhi Wang
2025-07-25 15:29:30 +08:00
committed by GitHub
parent 332154f504
commit 0700c90caa
4 changed files with 140 additions and 51 deletions

View File

@@ -794,6 +794,14 @@ class GPUModelRunner(ModelRunnerBase):
# Update Batch type for cuda graph
# TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
is_decode_batch = not ((self.share_inputs["seq_lens_this_time"] > 1).sum() > 0)
# mix ep in single node
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
is_decode_batch_list = []
paddle.distributed.all_gather_object(is_decode_batch_list, is_decode_batch)
is_decode_batch = all(is_decode_batch_list)
self.fd_config.parallel_config.moe_phase.phase = "decode" if is_decode_batch else "prefill"
self.forward_meta.step_use_cudagraph = self.use_cudagraph and is_decode_batch
        # Initialize attention meta data
@@ -1163,16 +1171,18 @@ class GPUModelRunner(ModelRunnerBase):
We plan to replace it with 'ModelForwardBatch'.
intermediate_tensors:
"""
# NOTE(wufeisheng): For Expert Parallelism
if not self.not_need_stop():
self._execute_empty_input()
return None
# 1. Prepare inputs of model and sampler.
skip_idx_list = self._get_skip_idx(model_forward_batch)
self._prepare_inputs()
self.sampler.pre_process(skip_idx_list)
        # NOTE(wufeisheng): If `not_need_stop` is False, it means the current worker is in an idle state.
# This logic is not used in TP (Tensor Parallelism) mode. However, in EP (Expert Parallelism) mode,
# when there is data on other runner, the current runner is required to execute part of the model.
if not self.not_need_stop():
self._execute_empty_input()
return None
# 2. Padding inputs for cuda graph
self.padding_cudagraph_inputs()