Supports DP+TP+EP hybrid parallel deployment strategy (#3489)

* Support DP+TP+EP hybrid parallel deployment strategy

* fix conflict

* add moe_tp_ep function split_allgather_out

* del tp_group in moe_cutlass_backend

* for ci

* fix parallel_config for ci

* del log
Author: lzy
Date: 2025-08-26 15:04:01 +08:00
Committed by: GitHub
Parent: 52eda7fdb3
Commit: d339df2e90

15 changed files with 304 additions and 224 deletions
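Both MoE blocks touched below follow the same pattern: the layer caches its parallel sizes from fd_config.parallel_config and derives use_ep / use_tp flags that gate the new hybrid path. A stand-in sketch of that derivation, with a SimpleNamespace in place of the real FDConfig (which is not part of this diff) and hypothetical degree values:

# Stand-in sketch: SimpleNamespace mimics fd_config.parallel_config;
# the real FDConfig lives elsewhere in FastDeploy.
from types import SimpleNamespace

parallel_config = SimpleNamespace(
    expert_parallel_size=8,   # EP degree (hypothetical value)
    tensor_parallel_size=2,   # TP degree (hypothetical value)
    tensor_parallel_rank=0,
    tp_group=None,            # the real field is a paddle communication group
)

use_ep = parallel_config.expert_parallel_size > 1
use_tp = parallel_config.tensor_parallel_size > 1
# The split + all_gather path added below runs only when both flags are set
# and the batch holds at least tensor_parallel_size tokens.
print(use_ep and use_tp)  # True -> hybrid EP+TP forward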

View File

@@ -103,6 +103,14 @@ class Ernie4_5_MoE(nn.Layer):
         if hasattr(fd_config.quant_config, "moe_quant_type"):
             moe_quant_type = fd_config.quant_config.moe_quant_type
+        self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size
+        self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
+        self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
+        self.tp_group = fd_config.parallel_config.tp_group
+        self.use_ep = self.expert_parallel_size > 1
+        self.use_tp = self.tensor_parallel_size > 1
         if moe_quant_type == "w4a8" or moe_quant_type == "w4afp8":
             weight_key_map = {
                 "gate_weight_key": f"{prefix}.gate.weight",
@@ -201,8 +209,30 @@ class Ernie4_5_MoE(nn.Layer):
         if self.num_shared_experts > 0:
             self.shared_experts.load_state_dict(state_dict)

+    def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int):
+        token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size
+        # AllGather will hang when the data shapes on multi-ranks are different!
+        part_hidden_states = paddle.zeros(
+            shape=[token_num_per_rank, hidden_states.shape[1]], dtype=hidden_states.dtype
+        )
+        start_offset = self.tensor_parallel_rank * token_num_per_rank
+        end_offset = (self.tensor_parallel_rank + 1) * token_num_per_rank
+        if end_offset > token_num:
+            end_offset = token_num
+        part_hidden_states[: (end_offset - start_offset), :] = hidden_states[start_offset:end_offset, :]
+        out = self.experts(part_hidden_states, self.gate)
+        multi_outs = []
+        paddle.distributed.all_gather(multi_outs, out, self.tp_group)
+        out = paddle.concat(multi_outs, axis=0)
+        out = out[:token_num, :]
+        return out
+
     def forward(self, hidden_states: paddle.Tensor):
-        out = self.experts(hidden_states, self.gate)
+        token_num = hidden_states.shape[0]
+        if self.use_ep and self.use_tp and token_num >= self.tensor_parallel_size:
+            out = self.split_allgather_out(hidden_states, token_num)
+        else:
+            out = self.experts(hidden_states, self.gate)
         if self.num_shared_experts > 0:
             s_x = self.shared_experts(hidden_states)
             out = out + s_x
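The helper's arithmetic deserves a note: tokens are ceil-divided across TP ranks so that every rank hands an identically shaped tensor to all_gather (unequal shapes would hang the collective, per the comment in the code), and the padded rows are dropped after concatenation. A single-process sketch of the partitioning, with illustrative numbers and no real collective:

# How split_allgather_out partitions tokens across TP ranks (simulation only).
token_num = 10               # tokens in the current batch (hypothetical)
tensor_parallel_size = 4     # TP degree (hypothetical)

# Ceil-divide so all ranks use the same padded shape.
token_num_per_rank = (token_num + tensor_parallel_size - 1) // tensor_parallel_size  # 3

for rank in range(tensor_parallel_size):
    start = rank * token_num_per_rank
    end = min((rank + 1) * token_num_per_rank, token_num)
    real = max(end - start, 0)
    pad = token_num_per_rank - real
    print(f"rank {rank}: tokens [{start}:{end}] -> {real} real + {pad} padded")

# all_gather then yields tensor_parallel_size * token_num_per_rank = 12 rows;
# the trailing 12 - token_num = 2 padded rows are removed by out[:token_num, :].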

View File

@@ -51,6 +51,15 @@ class Qwen3MoeBlock(nn.Layer):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size
+        self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
+        self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
+        self.tp_group = fd_config.parallel_config.tp_group
+        self.use_ep = self.expert_parallel_size > 1
+        self.use_tp = self.tensor_parallel_size > 1
+
         weight_key_map = {
             "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
             "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
@@ -74,8 +83,30 @@ class Qwen3MoeBlock(nn.Layer):
weight_dtype="float32",
)
def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int):
token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size
# AllGather will hang when the data shapes on multi-ranks are different!
part_hidden_states = paddle.zeros(
shape=[token_num_per_rank, hidden_states.shape[1]], dtype=hidden_states.dtype
)
start_offset = self.tensor_parallel_rank * token_num_per_rank
end_offset = (self.tensor_parallel_rank + 1) * token_num_per_rank
if end_offset > token_num:
end_offset = token_num
part_hidden_states[: (end_offset - start_offset), :] = hidden_states[start_offset:end_offset, :]
out = self.experts(part_hidden_states, self.gate)
multi_outs = []
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = paddle.concat(multi_outs, axis=0)
out = out[:token_num, :]
return out
def forward(self, x):
out = self.experts(x, self.gate)
token_num = x.shape[0]
if self.use_ep and self.use_tp and token_num >= self.tensor_parallel_size:
out = self.split_allgather_out(x, token_num)
else:
out = self.experts(x, self.gate)
return out
def load_state_dict(self, state_dict):
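Qwen3MoeBlock gains the same helper as Ernie4_5_MoE above. As a sanity check that the pad-then-trim round trip is lossless, here is a minimal single-process sketch in which a plain concat stands in for all_gather and the experts call is replaced by the identity (sizes are hypothetical):

import paddle

token_num, hidden, tp = 10, 8, 4   # hypothetical sizes
x = paddle.randn([token_num, hidden])
per_rank = (token_num + tp - 1) // tp

parts = []
for rank in range(tp):
    part = paddle.zeros([per_rank, hidden], dtype=x.dtype)
    start, end = rank * per_rank, min((rank + 1) * per_rank, token_num)
    part[: end - start, :] = x[start:end, :]   # copy this rank's slice, rest stays zero
    parts.append(part)                         # what each rank would contribute

out = paddle.concat(parts, axis=0)[:token_num, :]  # gather, then trim padding
assert bool((out == x).all())                      # order and values preserved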

View File

@@ -72,6 +72,7 @@ class TensorSplitMode(Enum):
"""TensorSplitMode"""
GQA = "is_gqa"
TP_ROW_BIAS = "is_tp_row_bias"
TRANSPOSE = "transpose"
QKV = "is_old_qkv"
PairFused = "is_naive_2fuse"
@@ -212,7 +213,7 @@ def gqa_qkv_split_func(
"""
def fn(x, is_column=True):
"""fucn"""
"""func"""
def get_shape(tensor):
"""get_shape"""
@@ -430,7 +431,15 @@ def split_or_merge_func_v1(
     def fn(x, **kwargs):
         """func"""
         is_gqa = kwargs.pop("is_gqa", False)
-        if is_gqa:
+        is_tp_row_bias = kwargs.pop("is_tp_row_bias", False)
+        if is_tp_row_bias:
+            tensor = x[:, ...]
+            if isinstance(tensor, paddle.Tensor):
+                res = tensor / tensor_parallel_degree
+            else:
+                res = paddle.to_tensor(tensor, paddle.get_default_dtype()) / tensor_parallel_degree
+            return res
+        elif is_gqa:
             func = split_or_merge_qkv_func(
                 is_split=is_split,
                 tensor_parallel_degree=tensor_parallel_degree,
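The new is_tp_row_bias branch scales the bias by 1/tensor_parallel_degree. The rationale (my reading; the diff itself does not spell it out): a row-parallel linear produces partial outputs that are summed across ranks with all_reduce, so a rank that added the full bias would leave tensor_parallel_degree copies of it in the reduced result; storing bias / degree on each rank makes the bias appear in the sum exactly once. A tiny single-process check of that identity:

import paddle

tp_degree = 4                             # hypothetical TP degree
bias = paddle.to_tensor([1.0, 2.0, 3.0])  # hypothetical bias

# Each rank holds bias / tp_degree; summing over ranks (stand-in for
# all_reduce) reproduces the original bias exactly once.
per_rank_bias = bias / tp_degree
reduced = sum(per_rank_bias for _ in range(tp_degree))
assert bool(paddle.allclose(reduced, bias))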