Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
Supports DP+TP+EP hybrid parallel deployment strategy (#3489)
* Support DP+TP+EP hybrid parallel deployment strategy
* fix conflict
* add moe_tp_ep function split_allgather_out
* del tp_group in moe_cutlass_backend
* for ci
* fix parallel_config for ci
* del log
@@ -51,6 +51,15 @@ class Qwen3MoeBlock(nn.Layer):
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size
+        self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
+        self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
+        self.tp_group = fd_config.parallel_config.tp_group
+
+        self.use_ep = self.expert_parallel_size > 1
+        self.use_tp = self.tensor_parallel_size > 1
+
         weight_key_map = {
             "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
             "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
@@ -74,8 +83,30 @@ class Qwen3MoeBlock(nn.Layer):
             weight_dtype="float32",
         )
 
+    def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int):
+        token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size
+        # AllGather will hang when the data shapes on multi-ranks are different!
+        part_hidden_states = paddle.zeros(
+            shape=[token_num_per_rank, hidden_states.shape[1]], dtype=hidden_states.dtype
+        )
+        start_offset = self.tensor_parallel_rank * token_num_per_rank
+        end_offset = (self.tensor_parallel_rank + 1) * token_num_per_rank
+        if end_offset > token_num:
+            end_offset = token_num
+        part_hidden_states[: (end_offset - start_offset), :] = hidden_states[start_offset:end_offset, :]
+        out = self.experts(part_hidden_states, self.gate)
+        multi_outs = []
+        paddle.distributed.all_gather(multi_outs, out, self.tp_group)
+        out = paddle.concat(multi_outs, axis=0)
+        out = out[:token_num, :]
+        return out
+
     def forward(self, x):
-        out = self.experts(x, self.gate)
+        token_num = x.shape[0]
+        if self.use_ep and self.use_tp and token_num >= self.tensor_parallel_size:
+            out = self.split_allgather_out(x, token_num)
+        else:
+            out = self.experts(x, self.gate)
         return out
 
     def load_state_dict(self, state_dict):
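The new split_allgather_out path shards the MoE forward over the tensor-parallel group: each rank runs the experts on a zero-padded, equally sized slice of the tokens, and an all_gather followed by concat-and-trim rebuilds the full output. Below is a minimal single-process sketch of that partition arithmetic, using NumPy in place of paddle.distributed; the rank loop stands in for the TP group, the identity assignment stands in for self.experts(part_hidden_states, self.gate), both offsets are clamped for the demo, and the name simulate_split_allgather is hypothetical, not part of FastDeploy.

import numpy as np


def simulate_split_allgather(hidden_states: np.ndarray, tensor_parallel_size: int) -> np.ndarray:
    """Single-process illustration of the token split + all_gather round trip."""
    token_num, hidden_dim = hidden_states.shape
    # Ceil-division so every rank owns a slice of identical shape; equal shapes are
    # what keeps a real all_gather from hanging (see the comment in the diff above).
    token_num_per_rank = (token_num + tensor_parallel_size - 1) // tensor_parallel_size

    gathered = []
    for rank in range(tensor_parallel_size):
        part = np.zeros((token_num_per_rank, hidden_dim), dtype=hidden_states.dtype)
        start = min(rank * token_num_per_rank, token_num)    # clamped for this demo
        end = min(start + token_num_per_rank, token_num)
        part[: end - start, :] = hidden_states[start:end, :]  # real tokens; the rest stays zero-padded
        out = part  # stand-in for the per-rank expert call
        gathered.append(out)

    # concat + trim mirrors paddle.concat(multi_outs, axis=0)[:token_num, :]
    return np.concatenate(gathered, axis=0)[:token_num, :]


if __name__ == "__main__":
    x = np.random.rand(10, 4).astype("float32")  # 10 tokens, 4 ranks -> 3 tokens per rank, last rank padded
    y = simulate_split_allgather(x, tensor_parallel_size=4)
    assert y.shape == x.shape and np.allclose(y, x)

With a real expert call in place of the identity stand-in, each TP rank only evaluates the experts on roughly token_num / tensor_parallel_size tokens; the zero padding trades a little wasted compute on the last rank for identical tensor shapes across the group, which is what the AllGather requires.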