mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Supports DP+TP+EP hybrid parallel deployment strategy (#3489)
* Support DP+TP+EP hybrid parallel deployment strategy * Support DP+TP+EP hybrid parallel deployment strategy * fix conflict * add moe_tp_ep function split_allgather_out * del tp_group in moe_cutlass_backend * for ci * fix parallel_config for ci * del log
This commit is contained in:
@@ -57,43 +57,37 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
hcg = fleet.get_hybrid_communicate_group()
|
||||
self.mp_rank: int = hcg.get_model_parallel_rank()
|
||||
self.column_cut = False
|
||||
self.world_size: int = hcg.get_model_parallel_world_size()
|
||||
self.ring_id: int = hcg.get_model_parallel_group().id
|
||||
self.use_ep: bool = fd_config.parallel_config.use_ep
|
||||
self.world_size: int = fd_config.parallel_config.tensor_parallel_size
|
||||
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
|
||||
self.initializer_range: float = fd_config.model_config.initializer_range
|
||||
self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
|
||||
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
|
||||
self.params_dtype: str = params_dtype
|
||||
|
||||
if self.use_ep:
|
||||
self.embeddings = nn.Embedding(
|
||||
if not self.column_cut:
|
||||
self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
mp_group=self.tp_group,
|
||||
weight_attr=paddle.ParamAttr(
|
||||
initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
|
||||
),
|
||||
)
|
||||
if self.world_size > 1:
|
||||
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
|
||||
else:
|
||||
if not self.column_cut:
|
||||
self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
weight_attr=paddle.ParamAttr(
|
||||
initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
|
||||
),
|
||||
)
|
||||
if self.world_size > 1:
|
||||
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
|
||||
else:
|
||||
# column cut embedding
|
||||
self.embeddings = nn.Embedding(
|
||||
num_embeddings,
|
||||
embedding_dim // self.world_size,
|
||||
)
|
||||
# column cut embedding
|
||||
self.embeddings = nn.Embedding(
|
||||
num_embeddings,
|
||||
embedding_dim // self.world_size,
|
||||
)
|
||||
|
||||
self.embeddings.weight.is_distributed = True
|
||||
self.embeddings.weight.split_axis = 1
|
||||
if self.world_size > 1:
|
||||
set_weight_attrs(self.embeddings.weight, {"output_dim": True})
|
||||
self.embeddings.weight.is_distributed = True
|
||||
self.embeddings.weight.split_axis = 1
|
||||
if self.world_size > 1:
|
||||
set_weight_attrs(self.embeddings.weight, {"output_dim": True})
|
||||
|
||||
self.prefix = prefix
|
||||
self.dropout = nn.Dropout(self.hidden_dropout_prob)
|
||||
@@ -125,20 +119,17 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
Returns:
|
||||
Tensor: Embedded tensor representation of the input IDs.
|
||||
"""
|
||||
if self.use_ep:
|
||||
if self.column_cut:
|
||||
input_embedings = self.embeddings(ids_remove_padding)
|
||||
inputs_embeds_temp = []
|
||||
paddle.distributed.all_gather(
|
||||
inputs_embeds_temp,
|
||||
input_embedings,
|
||||
group=self.tp_group,
|
||||
sync_op=True,
|
||||
)
|
||||
input_embedings = paddle.concat(inputs_embeds_temp, -1)
|
||||
else:
|
||||
if self.column_cut:
|
||||
input_embedings = self.embeddings(ids_remove_padding)
|
||||
inputs_embeds_temp = []
|
||||
paddle.distributed.all_gather(
|
||||
inputs_embeds_temp,
|
||||
input_embedings,
|
||||
group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
sync_op=True,
|
||||
)
|
||||
input_embedings = paddle.concat(inputs_embeds_temp, -1)
|
||||
else:
|
||||
input_embedings = self.embeddings(ids_remove_padding)
|
||||
input_embedings = self.embeddings(ids_remove_padding)
|
||||
|
||||
return input_embedings
|
||||
|
@@ -703,6 +703,7 @@ class RowParallelLinear(LinearBase):
|
||||
self.fd_config = fd_config
|
||||
self.skip_quant = False
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_size
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
self.num_heads = fd_config.model_config.num_attention_heads // self.nranks
|
||||
@@ -751,7 +752,7 @@ class RowParallelLinear(LinearBase):
|
||||
out = paddle.matmul(x, self.weight)
|
||||
|
||||
if self.reduce_results and self.nranks > 1:
|
||||
tensor_model_parallel_all_reduce(out)
|
||||
tensor_model_parallel_all_reduce(out, self.tp_group)
|
||||
|
||||
return out
|
||||
|
||||
|
@@ -58,7 +58,7 @@ class ParallelLMHead(nn.Layer):
|
||||
self.bias_key: Optional[str] = prefix + ".bias"
|
||||
else:
|
||||
self.bias_key: Optional[str] = None
|
||||
self.use_ep: bool = fd_config.parallel_config.use_ep
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
self.column_cut = True
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_size
|
||||
self.fd_config = fd_config
|
||||
@@ -68,60 +68,46 @@ class ParallelLMHead(nn.Layer):
|
||||
|
||||
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
|
||||
|
||||
if self.use_ep:
|
||||
self.weight = self.create_parameter(
|
||||
shape=[embedding_dim, num_embeddings],
|
||||
dtype=paddle.get_default_dtype(),
|
||||
is_bias=False,
|
||||
if self.column_cut:
|
||||
need_gather = True
|
||||
self.linear = ColumnParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=self.tp_group,
|
||||
weight_attr=None,
|
||||
has_bias=True if self.bias_key is not None else False,
|
||||
gather_output=need_gather,
|
||||
fuse_matmul_bias=False,
|
||||
)
|
||||
if self.bias_key is not None:
|
||||
self.bias = self.create_parameter(
|
||||
shape=[num_embeddings],
|
||||
dtype=paddle.get_default_dtype(),
|
||||
is_bias=True,
|
||||
)
|
||||
|
||||
set_weight_attrs(
|
||||
self.linear.weight,
|
||||
{
|
||||
"weight_loader": default_weight_loader(self.fd_config),
|
||||
"model_format": self.fd_config.model_config.model_format,
|
||||
},
|
||||
)
|
||||
if self.nranks > 1:
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": True})
|
||||
else:
|
||||
if self.column_cut:
|
||||
need_gather = True
|
||||
self.linear = ColumnParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
weight_attr=None,
|
||||
has_bias=True if self.bias_key is not None else False,
|
||||
gather_output=need_gather,
|
||||
fuse_matmul_bias=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.linear.weight,
|
||||
{
|
||||
"weight_loader": default_weight_loader(self.fd_config),
|
||||
"model_format": self.fd_config.model_config.model_format,
|
||||
},
|
||||
)
|
||||
if self.nranks > 1:
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": True})
|
||||
else:
|
||||
self.linear = RowParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
weight_attr=None,
|
||||
has_bias=True if self.bias_key is not None else False,
|
||||
input_is_parallel=False,
|
||||
fuse_matmul_bias=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.linear.weight,
|
||||
{
|
||||
"weight_loader": default_weight_loader(self.fd_config),
|
||||
"model_format": self.fd_config.model_config.model_format,
|
||||
},
|
||||
)
|
||||
self.linear = RowParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=self.tp_group,
|
||||
weight_attr=None,
|
||||
has_bias=True if self.bias_key is not None else False,
|
||||
input_is_parallel=False,
|
||||
fuse_matmul_bias=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.linear.weight,
|
||||
{
|
||||
"weight_loader": default_weight_loader(self.fd_config),
|
||||
"model_format": self.fd_config.model_config.model_format,
|
||||
},
|
||||
)
|
||||
|
||||
if self.nranks > 1:
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": False})
|
||||
if self.nranks > 1:
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": False})
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
|
||||
"""
|
||||
@@ -131,24 +117,19 @@ class ParallelLMHead(nn.Layer):
|
||||
state_dict (dict): A dictionary containing the checkpoint weights and biases.
|
||||
"""
|
||||
|
||||
if self.use_ep:
|
||||
self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
|
||||
if self.bias_key is not None:
|
||||
self.bias.set_value(get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype()))
|
||||
if self.tie_word_embeddings:
|
||||
self.linear.weight.set_value(
|
||||
get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
|
||||
)
|
||||
else:
|
||||
if self.tie_word_embeddings:
|
||||
self.linear.weight.set_value(
|
||||
get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
|
||||
)
|
||||
else:
|
||||
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
|
||||
if self.linear.weight.shape != weight_tensor.shape:
|
||||
weight_tensor = weight_tensor.transpose([1, 0])
|
||||
self.linear.weight.set_value(weight_tensor)
|
||||
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
|
||||
if self.linear.weight.shape != weight_tensor.shape:
|
||||
weight_tensor = weight_tensor.transpose([1, 0])
|
||||
self.linear.weight.set_value(weight_tensor)
|
||||
|
||||
if self.bias_key is not None:
|
||||
bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
|
||||
self.linear.bias.set_value(bias)
|
||||
if self.bias_key is not None:
|
||||
bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
|
||||
self.linear.bias.set_value(bias)
|
||||
|
||||
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
|
||||
"""
|
||||
@@ -161,11 +142,5 @@ class ParallelLMHead(nn.Layer):
|
||||
Tensor: The output tensor after processing through the layer.
|
||||
"""
|
||||
logits = input
|
||||
if self.use_ep:
|
||||
if self.bias_key is None:
|
||||
logits = paddle.matmul(logits, self.weight)
|
||||
else:
|
||||
logits = paddle.incubate.nn.functional.fused_linear(logits, self.weight, self.bias)
|
||||
else:
|
||||
logits = self.linear(logits)
|
||||
logits = self.linear(logits)
|
||||
return logits
|
||||
|
@@ -466,6 +466,6 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
1.0,
|
||||
)[0]
|
||||
if layer.tp_size > 1:
|
||||
tensor_model_parallel_all_reduce(tmp_ffn_out)
|
||||
tensor_model_parallel_all_reduce(tmp_ffn_out, self.tp_group)
|
||||
|
||||
return tmp_ffn_out
|
||||
|
@@ -98,6 +98,11 @@ class FusedMoE(nn.Layer):
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_size
|
||||
self.ep_size = fd_config.parallel_config.expert_parallel_size
|
||||
self.ep_rank = fd_config.parallel_config.expert_parallel_rank
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
# NOTE(Zhenyu Li): just supports tp_size = 1 when ep_size > 1 in MOE now.
|
||||
if self.ep_size > 1:
|
||||
self.tp_size = 1
|
||||
self.tp_rank = 0
|
||||
|
||||
assert (self.tp_size >= 1 and self.ep_size == 1) or (
|
||||
self.tp_size == 1 and self.ep_size > 1
|
||||
|
Reference in New Issue
Block a user