Supports DP+TP+EP hybrid parallel deployment strategy (#3489)

* Support DP+TP+EP hybrid parallel deployment strategy

* fix conflict

* add moe_tp_ep function split_allgather_out

* del tp_group in moe_cutlass_backend

* for ci

* fix parallel_config for ci

* del log
Author: lzy
Date: 2025-08-26 15:04:01 +08:00 (committed by GitHub)
Parent: 52eda7fdb3
Commit: d339df2e90
15 changed files with 304 additions and 224 deletions

@@ -57,43 +57,37 @@ class VocabParallelEmbedding(nn.Layer):
hcg = fleet.get_hybrid_communicate_group()
self.mp_rank: int = hcg.get_model_parallel_rank()
self.column_cut = False
self.world_size: int = hcg.get_model_parallel_world_size()
self.ring_id: int = hcg.get_model_parallel_group().id
self.use_ep: bool = fd_config.parallel_config.use_ep
self.world_size: int = fd_config.parallel_config.tensor_parallel_size
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group
self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
self.initializer_range: float = fd_config.model_config.initializer_range
self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
self.params_dtype: str = params_dtype
if self.use_ep:
self.embeddings = nn.Embedding(
if not self.column_cut:
self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
num_embeddings,
embedding_dim,
mp_group=self.tp_group,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
),
)
if self.world_size > 1:
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
else:
if not self.column_cut:
self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
num_embeddings,
embedding_dim,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
),
)
if self.world_size > 1:
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
else:
# column cut embedding
self.embeddings = nn.Embedding(
num_embeddings,
embedding_dim // self.world_size,
)
# column cut embedding
self.embeddings = nn.Embedding(
num_embeddings,
embedding_dim // self.world_size,
)
self.embeddings.weight.is_distributed = True
self.embeddings.weight.split_axis = 1
if self.world_size > 1:
set_weight_attrs(self.embeddings.weight, {"output_dim": True})
self.embeddings.weight.is_distributed = True
self.embeddings.weight.split_axis = 1
if self.world_size > 1:
set_weight_attrs(self.embeddings.weight, {"output_dim": True})
self.prefix = prefix
self.dropout = nn.Dropout(self.hidden_dropout_prob)
@@ -125,20 +119,17 @@ class VocabParallelEmbedding(nn.Layer):
Returns:
Tensor: Embedded tensor representation of the input IDs.
"""
if self.use_ep:
if self.column_cut:
input_embedings = self.embeddings(ids_remove_padding)
inputs_embeds_temp = []
paddle.distributed.all_gather(
inputs_embeds_temp,
input_embedings,
group=self.tp_group,
sync_op=True,
)
input_embedings = paddle.concat(inputs_embeds_temp, -1)
else:
if self.column_cut:
input_embedings = self.embeddings(ids_remove_padding)
inputs_embeds_temp = []
paddle.distributed.all_gather(
inputs_embeds_temp,
input_embedings,
group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
sync_op=True,
)
input_embedings = paddle.concat(inputs_embeds_temp, -1)
else:
input_embedings = self.embeddings(ids_remove_padding)
input_embedings = self.embeddings(ids_remove_padding)
return input_embedings
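
For reference, a minimal standalone sketch of the column-cut lookup performed above, now using the tp_group carried in parallel_config instead of querying the hybrid communicate group at call time. The helper name and signature below are illustrative, not part of the PR:

import paddle
import paddle.distributed as dist

def column_cut_embedding_lookup(embeddings, ids, tp_group, world_size):
    # `embeddings` holds this rank's column shard of the table.
    local = embeddings(ids)                       # [num_tokens, embedding_dim // world_size]
    if world_size <= 1:
        return local
    shards = []
    dist.all_gather(shards, local, group=tp_group, sync_op=True)
    return paddle.concat(shards, axis=-1)         # [num_tokens, embedding_dim]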

@@ -703,6 +703,7 @@ class RowParallelLinear(LinearBase):
self.fd_config = fd_config
self.skip_quant = False
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.tp_group = fd_config.parallel_config.tp_group
self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
self.num_heads = fd_config.model_config.num_attention_heads // self.nranks
@@ -751,7 +752,7 @@ class RowParallelLinear(LinearBase):
out = paddle.matmul(x, self.weight)
if self.reduce_results and self.nranks > 1:
tensor_model_parallel_all_reduce(out)
tensor_model_parallel_all_reduce(out, self.tp_group)
return out
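
The change above threads self.tp_group into the all-reduce instead of relying on a global group lookup. As a rough sketch, assuming tensor_model_parallel_all_reduce wraps a plain all-reduce over the given group (the helper below is illustrative, not the PR's code):

import paddle
import paddle.distributed as dist

def row_parallel_forward(x, weight_shard, tp_group, nranks, reduce_results=True):
    out = paddle.matmul(x, weight_shard)          # each rank holds a row shard -> partial result
    if reduce_results and nranks > 1:
        dist.all_reduce(out, group=tp_group)      # sum the partial results across TP ranks
    return out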

@@ -58,7 +58,7 @@ class ParallelLMHead(nn.Layer):
self.bias_key: Optional[str] = prefix + ".bias"
else:
self.bias_key: Optional[str] = None
self.use_ep: bool = fd_config.parallel_config.use_ep
self.tp_group = fd_config.parallel_config.tp_group
self.column_cut = True
self.nranks = fd_config.parallel_config.tensor_parallel_size
self.fd_config = fd_config
@@ -68,60 +68,46 @@ class ParallelLMHead(nn.Layer):
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
if self.use_ep:
self.weight = self.create_parameter(
shape=[embedding_dim, num_embeddings],
dtype=paddle.get_default_dtype(),
is_bias=False,
if self.column_cut:
need_gather = True
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=self.tp_group,
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False,
)
if self.bias_key is not None:
self.bias = self.create_parameter(
shape=[num_embeddings],
dtype=paddle.get_default_dtype(),
is_bias=True,
)
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
},
)
if self.nranks > 1:
set_weight_attrs(self.linear.weight, {"output_dim": True})
else:
if self.column_cut:
need_gather = True
self.linear = ColumnParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
gather_output=need_gather,
fuse_matmul_bias=False,
)
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
},
)
if self.nranks > 1:
set_weight_attrs(self.linear.weight, {"output_dim": True})
else:
self.linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False,
)
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
},
)
self.linear = RowParallelLinear(
embedding_dim,
num_embeddings,
mp_group=self.tp_group,
weight_attr=None,
has_bias=True if self.bias_key is not None else False,
input_is_parallel=False,
fuse_matmul_bias=False,
)
set_weight_attrs(
self.linear.weight,
{
"weight_loader": default_weight_loader(self.fd_config),
"model_format": self.fd_config.model_config.model_format,
},
)
if self.nranks > 1:
set_weight_attrs(self.linear.weight, {"output_dim": False})
if self.nranks > 1:
set_weight_attrs(self.linear.weight, {"output_dim": False})
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
"""
@@ -131,24 +117,19 @@ class ParallelLMHead(nn.Layer):
state_dict (dict): A dictionary containing the checkpoint weights and biases.
"""
if self.use_ep:
self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
if self.bias_key is not None:
self.bias.set_value(get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype()))
if self.tie_word_embeddings:
self.linear.weight.set_value(
get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
)
else:
if self.tie_word_embeddings:
self.linear.weight.set_value(
get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
)
else:
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.linear.weight.set_value(weight_tensor)
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
if self.linear.weight.shape != weight_tensor.shape:
weight_tensor = weight_tensor.transpose([1, 0])
self.linear.weight.set_value(weight_tensor)
if self.bias_key is not None:
bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
self.linear.bias.set_value(bias)
if self.bias_key is not None:
bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
self.linear.bias.set_value(bias)
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
"""
@@ -161,11 +142,5 @@ class ParallelLMHead(nn.Layer):
Tensor: The output tensor after processing through the layer.
"""
logits = input
if self.use_ep:
if self.bias_key is None:
logits = paddle.matmul(logits, self.weight)
else:
logits = paddle.incubate.nn.functional.fused_linear(logits, self.weight, self.bias)
else:
logits = self.linear(logits)
logits = self.linear(logits)
return logits
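
The load_state_dict hunk above keeps the shape-driven transpose for untied weights: a checkpoint may store the LM head as [num_embeddings, embedding_dim] or its transpose, and the loader only transposes when the shapes disagree with the layer parameter. A condensed sketch (the helper name is made up for illustration):

import paddle

def load_lm_head_weight(param, ckpt_tensor):
    w = ckpt_tensor.astype(paddle.get_default_dtype())
    if list(param.shape) != list(w.shape):        # checkpoint stored transposed
        w = w.transpose([1, 0])
    param.set_value(w)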

@@ -466,6 +466,6 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
1.0,
)[0]
if layer.tp_size > 1:
tensor_model_parallel_all_reduce(tmp_ffn_out)
tensor_model_parallel_all_reduce(tmp_ffn_out, self.tp_group)
return tmp_ffn_out

@@ -98,6 +98,11 @@ class FusedMoE(nn.Layer):
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.ep_size = fd_config.parallel_config.expert_parallel_size
self.ep_rank = fd_config.parallel_config.expert_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group
# NOTE(Zhenyu Li): MoE currently supports only tp_size = 1 when ep_size > 1.
if self.ep_size > 1:
self.tp_size = 1
self.tp_rank = 0
assert (self.tp_size >= 1 and self.ep_size == 1) or (
self.tp_size == 1 and self.ep_size > 1
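
The constraint encoded above (the MoE layer forces tp_size = 1 whenever ep_size > 1, and only TP-only or EP-only layouts pass the assert) can be summarized by a small, purely illustrative helper; the name and message below are assumptions, not code from the PR:

def normalize_moe_parallel(tp_size: int, ep_size: int):
    # With expert parallelism enabled, the MoE layer itself runs without TP.
    if ep_size > 1:
        tp_size = 1
    assert (tp_size >= 1 and ep_size == 1) or (tp_size == 1 and ep_size > 1), (
        f"unsupported MoE parallel layout: tp_size={tp_size}, ep_size={ep_size}"
    )
    return tp_size, ep_size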

@@ -321,33 +321,28 @@ def load_composite_checkpoint(
# 2. Tensor Parallel (TP)
# 3. Pre-sharded (pre-split)
"""
if fd_config.parallel_config.use_ep and fd_config.speculative_config.model_type != "mtp":
state_dict = load_ep_checkpoint(model_path, fd_config, return_numpy=True)
rank_dirs = [
f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
]
if len(rank_dirs) > 1:
if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
state_dict = load_pre_sharded_checkpoint(
model_path,
fd_config.parallel_config.tensor_parallel_rank,
use_fastsafetensor=False,
)
else:
rank_dirs = [
f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
]
if len(rank_dirs) > 1:
if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
state_dict = load_pre_sharded_checkpoint(
model_path,
fd_config.parallel_config.tensor_parallel_rank,
use_fastsafetensor=False,
)
if fd_config.load_config.use_fastsafetensor and (current_platform.available() and current_platform.is_cuda()):
state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True)
deal_state_dict(state_dict)
else:
if fd_config.load_config.use_fastsafetensor and (
current_platform.available() and current_platform.is_cuda()
):
state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True)
deal_state_dict(state_dict)
else:
state_dict = load_tp_checkpoint(
model_path,
cls,
fd_config.model_config.pretrained_config,
return_numpy=return_numpy,
)
state_dict = load_tp_checkpoint(
model_path,
cls,
fd_config.model_config.pretrained_config,
return_numpy=return_numpy,
)
if not state_dict:
raise ValueError("weight not found in state_dict !")
return state_dict
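
A condensed sketch of the selection order after this refactor (the wrapper below is illustrative; load_pre_sharded_checkpoint, load_tp_checkpoint_v1, and load_tp_checkpoint are the loaders referenced in the diff): pre-sharded rank directories take priority, then fastsafetensor on CUDA, then the plain TP loader.

import os

def select_loader(model_path, fd_config, current_platform):
    rank_dirs = [
        d for d in os.listdir(model_path)
        if d.startswith("rank") and os.path.isdir(os.path.join(model_path, d))
    ]
    if len(rank_dirs) > 1:
        if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
            raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
        return "load_pre_sharded_checkpoint"
    if fd_config.load_config.use_fastsafetensor and current_platform.available() and current_platform.is_cuda():
        return "load_tp_checkpoint_v1"
    return "load_tp_checkpoint"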

@@ -103,6 +103,14 @@ class Ernie4_5_MoE(nn.Layer):
if hasattr(fd_config.quant_config, "moe_quant_type"):
moe_quant_type = fd_config.quant_config.moe_quant_type
self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size
self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group
self.use_ep = self.expert_parallel_size > 1
self.use_tp = self.tensor_parallel_size > 1
if moe_quant_type == "w4a8" or moe_quant_type == "w4afp8":
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
@@ -201,8 +209,30 @@ class Ernie4_5_MoE(nn.Layer):
if self.num_shared_experts > 0:
self.shared_experts.load_state_dict(state_dict)
def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int):
token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size
# AllGather will hang if the data shapes differ across ranks!
part_hidden_states = paddle.zeros(
shape=[token_num_per_rank, hidden_states.shape[1]], dtype=hidden_states.dtype
)
start_offset = self.tensor_parallel_rank * token_num_per_rank
end_offset = (self.tensor_parallel_rank + 1) * token_num_per_rank
if end_offset > token_num:
end_offset = token_num
part_hidden_states[: (end_offset - start_offset), :] = hidden_states[start_offset:end_offset, :]
out = self.experts(part_hidden_states, self.gate)
multi_outs = []
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = paddle.concat(multi_outs, axis=0)
out = out[:token_num, :]
return out
def forward(self, hidden_states: paddle.Tensor):
out = self.experts(hidden_states, self.gate)
token_num = hidden_states.shape[0]
if self.use_ep and self.use_tp and token_num >= self.tensor_parallel_size:
out = self.split_allgather_out(hidden_states, token_num)
else:
out = self.experts(hidden_states, self.gate)
if self.num_shared_experts > 0:
s_x = self.shared_experts(hidden_states)
out = out + s_x
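
The new split_allgather_out is the core of the DP+TP+EP hybrid path: when both EP and TP are active, each TP rank runs the experts on an equal-sized, zero-padded slice of the tokens (all_gather hangs if shapes differ across ranks), and the per-rank results are gathered and trimmed back to the real token count. A standalone sketch of the pattern, with illustrative names (compute_fn stands in for experts(part, gate)):

import paddle
import paddle.distributed as dist

def split_compute_allgather(hidden_states, compute_fn, tp_rank, tp_size, tp_group):
    token_num = hidden_states.shape[0]
    per_rank = (token_num + tp_size - 1) // tp_size            # ceil division
    part = paddle.zeros([per_rank, hidden_states.shape[1]], dtype=hidden_states.dtype)
    start = tp_rank * per_rank
    end = min((tp_rank + 1) * per_rank, token_num)
    part[: end - start, :] = hidden_states[start:end, :]       # last rank may be zero-padded
    out = compute_fn(part)                                     # local expert computation
    gathered = []
    dist.all_gather(gathered, out, group=tp_group)
    return paddle.concat(gathered, axis=0)[:token_num, :]      # drop the padding rows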

@@ -51,6 +51,15 @@ class Qwen3MoeBlock(nn.Layer):
prefix: str = "",
) -> None:
super().__init__()
self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size
self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group
self.use_ep = self.expert_parallel_size > 1
self.use_tp = self.tensor_parallel_size > 1
weight_key_map = {
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
@@ -74,8 +83,30 @@ class Qwen3MoeBlock(nn.Layer):
weight_dtype="float32",
)
def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int):
token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size
# AllGather will hang if the data shapes differ across ranks!
part_hidden_states = paddle.zeros(
shape=[token_num_per_rank, hidden_states.shape[1]], dtype=hidden_states.dtype
)
start_offset = self.tensor_parallel_rank * token_num_per_rank
end_offset = (self.tensor_parallel_rank + 1) * token_num_per_rank
if end_offset > token_num:
end_offset = token_num
part_hidden_states[: (end_offset - start_offset), :] = hidden_states[start_offset:end_offset, :]
out = self.experts(part_hidden_states, self.gate)
multi_outs = []
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = paddle.concat(multi_outs, axis=0)
out = out[:token_num, :]
return out
def forward(self, x):
out = self.experts(x, self.gate)
token_num = x.shape[0]
if self.use_ep and self.use_tp and token_num >= self.tensor_parallel_size:
out = self.split_allgather_out(x, token_num)
else:
out = self.experts(x, self.gate)
return out
def load_state_dict(self, state_dict):
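
Same split_allgather_out pattern as in Ernie4_5_MoE above. A tiny worked example of the padding arithmetic (pure Python, no distributed setup):

token_num, tp_size = 10, 4
per_rank = (token_num + tp_size - 1) // tp_size    # ceil -> 3 rows per rank
slices = [(r * per_rank, min((r + 1) * per_rank, token_num)) for r in range(tp_size)]
# slices == [(0, 3), (3, 6), (6, 9), (9, 10)]: rank 3 holds 1 real row plus 2 zero-padded rows
# all_gather returns 4 * 3 = 12 rows; slicing with [:token_num] drops the 2 padded rows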

@@ -72,6 +72,7 @@ class TensorSplitMode(Enum):
"""TensorSplitMode"""
GQA = "is_gqa"
TP_ROW_BIAS = "is_tp_row_bias"
TRANSPOSE = "transpose"
QKV = "is_old_qkv"
PairFused = "is_naive_2fuse"
@@ -212,7 +213,7 @@ def gqa_qkv_split_func(
"""
def fn(x, is_column=True):
"""fucn"""
"""func"""
def get_shape(tensor):
"""get_shape"""
@@ -430,7 +431,15 @@ def split_or_merge_func_v1(
def fn(x, **kwargs):
"""func"""
is_gqa = kwargs.pop("is_gqa", False)
if is_gqa:
is_tp_row_bias = kwargs.pop("is_tp_row_bias", False)
if is_tp_row_bias:
tensor = x[:, ...]
if isinstance(tensor, paddle.Tensor):
res = tensor / tensor_parallel_degree
else:
res = paddle.to_tensor(tensor, paddle.get_default_dtype()) / tensor_parallel_degree
return res
elif is_gqa:
func = split_or_merge_qkv_func(
is_split=is_split,
tensor_parallel_degree=tensor_parallel_degree,