[Optimization] 1.fix tp+ep moe_forward; 2.set max_prefill_batch=env.MAX_PREFILL_NUM (#5315)

This commit is contained in:
lzy
2025-12-03 13:33:15 +08:00
committed by GitHub
parent f6544c0b1b
commit 690bcb8e50
2 changed files with 14 additions and 7 deletions

View File

@@ -1586,7 +1586,11 @@ class FDConfig:
self.max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))
if current_platform.is_xpu():
self.max_prefill_batch = 1
if self.model_config is not None and self.model_config.enable_mm:
if (
int(envs.ENABLE_V1_KVCACHE_SCHEDULER) == 0
and self.model_config is not None
and self.model_config.enable_mm
):
self.max_prefill_batch = 1 # TODO:当前多模prefill阶段只支持并行度为1,待优化
else:
self.max_prefill_batch = self.scheduler_config.max_num_seqs

View File

@@ -163,6 +163,9 @@ class FusedMoE(nn.Layer):
self.tp_size = 1
self.tp_rank = 0
self.attn_tp_size = fd_config.parallel_config.tensor_parallel_size
self.attn_tp_rank = fd_config.parallel_config.tensor_parallel_rank
assert (self.tp_size >= 1 and self.ep_size == 1) or (
self.tp_size == 1 and self.ep_size > 1
), "MoE only support parallelism on TP or EP dimension."
@@ -598,18 +601,18 @@ class FusedMoE(nn.Layer):
Forward split allgather function.
"""
token_num = x.shape[0]
token_num_per_rank = (token_num + self.tp_size - 1) // self.tp_size
token_num_per_rank = (token_num + self.attn_tp_size - 1) // self.attn_tp_size
# AllGather will hang when the data shapes on multi-ranks are different!
part_x = paddle.zeros(shape=[token_num_per_rank, x.shape[1]], dtype=x.dtype)
start_offset = self.tp_rank * token_num_per_rank
end_offset = (self.tp_rank + 1) * token_num_per_rank
start_offset = self.attn_tp_rank * token_num_per_rank
end_offset = (self.attn_tp_rank + 1) * token_num_per_rank
if start_offset >= token_num:
start_offset = token_num
if end_offset > token_num:
end_offset = token_num
part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
out = self.quant_method.apply(self, part_x, gate)
multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, x.shape[1]], dtype=x.dtype)
multi_outs = paddle.zeros([token_num_per_rank * self.attn_tp_size, x.shape[1]], dtype=x.dtype)
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = multi_outs[:token_num, :]
@@ -629,9 +632,9 @@ class FusedMoE(nn.Layer):
token_num = x.shape[0]
if (
self.ep_size > 1
and self.tp_size > 1
and self.attn_tp_size > 1
and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
and token_num >= self.tp_size
and token_num >= self.attn_tp_size
):
out = self.forward_split_allgather(x, gate)
elif self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.enable_chunked_moe: