Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-06 09:07:10 +08:00)
Supports DP+TP+EP hybrid parallel deployment strategy (#3489)
* Support DP+TP+EP hybrid parallel deployment strategy
* Support DP+TP+EP hybrid parallel deployment strategy
* fix conflict
* add moe_tp_ep function split_allgather_out
* del tp_group in moe_cutlass_backend
* for ci
* fix parallel_config for ci
* del log
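The diff below replaces the embedding layer's reliance on fleet's global hybrid communicate group (hcg, mp_rank, ring_id) with values carried on fd_config.parallel_config (tensor_parallel_size, tensor_parallel_rank, tp_group), so that under a DP+TP+EP hybrid layout each data-parallel replica can run tensor parallelism over its own process group. A minimal sketch of how such a per-replica TP group could be built and then stored as tp_group is shown here; the helper name and the contiguous-rank layout are illustrative assumptions, not code from this PR:

# Hypothetical sketch, not code from this PR: one way a per-replica TP group
# could be derived before being stored on parallel_config as tp_group.
# Assumes paddle.distributed.init_parallel_env() has already been called and
# that TP ranks are laid out contiguously inside the global world.
import paddle.distributed as dist

def build_tp_group(tensor_parallel_size: int):
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    tp_group = None
    for start in range(0, world_size, tensor_parallel_size):
        ranks = list(range(start, start + tensor_parallel_size))
        group = dist.new_group(ranks=ranks)  # every rank must join every new_group call
        if rank in ranks:
            tp_group = group
    return tp_group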
@@ -57,43 +57,37 @@ class VocabParallelEmbedding(nn.Layer):
-        hcg = fleet.get_hybrid_communicate_group()
-        self.mp_rank: int = hcg.get_model_parallel_rank()
         self.column_cut = False
-        self.world_size: int = hcg.get_model_parallel_world_size()
-        self.ring_id: int = hcg.get_model_parallel_group().id
         self.use_ep: bool = fd_config.parallel_config.use_ep
+        self.world_size: int = fd_config.parallel_config.tensor_parallel_size
+        self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
+        self.tp_group = fd_config.parallel_config.tp_group
         self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
         self.initializer_range: float = fd_config.model_config.initializer_range
         self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
         self.params_dtype: str = params_dtype

-        if self.use_ep:
-            self.embeddings = nn.Embedding(
+        if not self.column_cut:
+            self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
                 num_embeddings,
                 embedding_dim,
+                mp_group=self.tp_group,
+                weight_attr=paddle.ParamAttr(
+                    initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
+                ),
             )
+            if self.world_size > 1:
+                set_weight_attrs(self.embeddings.weight, {"output_dim": False})
         else:
-            if not self.column_cut:
-                self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
-                    num_embeddings,
-                    embedding_dim,
-                    mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    weight_attr=paddle.ParamAttr(
-                        initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
-                    ),
-                )
-                if self.world_size > 1:
-                    set_weight_attrs(self.embeddings.weight, {"output_dim": False})
-            else:
-                # column cut embedding
-                self.embeddings = nn.Embedding(
-                    num_embeddings,
-                    embedding_dim // self.world_size,
-                )
+            # column cut embedding
+            self.embeddings = nn.Embedding(
+                num_embeddings,
+                embedding_dim // self.world_size,
+            )

-                self.embeddings.weight.is_distributed = True
-                self.embeddings.weight.split_axis = 1
-                if self.world_size > 1:
-                    set_weight_attrs(self.embeddings.weight, {"output_dim": True})
+            self.embeddings.weight.is_distributed = True
+            self.embeddings.weight.split_axis = 1
+            if self.world_size > 1:
+                set_weight_attrs(self.embeddings.weight, {"output_dim": True})

         self.prefix = prefix
         self.dropout = nn.Dropout(self.hidden_dropout_prob)
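In the column-cut branch above, each tensor-parallel rank holds only embedding_dim // world_size columns of the embedding table, and split_axis = 1 marks the weight as sharded along the hidden dimension. A small worked example of the resulting shapes, with sizes chosen purely for illustration and not taken from this PR:

# Illustration only (made-up sizes): per-rank weight shapes in the column-cut branch.
num_embeddings, embedding_dim, world_size = 32000, 4096, 4

full_weight_shape = (num_embeddings, embedding_dim)                 # logical embedding table
shard_weight_shape = (num_embeddings, embedding_dim // world_size)  # nn.Embedding held by each TP rank

print(full_weight_shape, shard_weight_shape)  # (32000, 4096) (32000, 1024)
# split_axis = 1 records that the shard is cut along the hidden dimension,
# which is why the forward pass can all_gather the shards and concat on axis -1.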
@@ -125,20 +119,17 @@ class VocabParallelEmbedding(nn.Layer):
         Returns:
             Tensor: Embedded tensor representation of the input IDs.
         """
-        if self.use_ep:
+        if self.column_cut:
             input_embedings = self.embeddings(ids_remove_padding)
+            inputs_embeds_temp = []
+            paddle.distributed.all_gather(
+                inputs_embeds_temp,
+                input_embedings,
+                group=self.tp_group,
+                sync_op=True,
+            )
+            input_embedings = paddle.concat(inputs_embeds_temp, -1)
         else:
-            if self.column_cut:
-                input_embedings = self.embeddings(ids_remove_padding)
-                inputs_embeds_temp = []
-                paddle.distributed.all_gather(
-                    inputs_embeds_temp,
-                    input_embedings,
-                    group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
-                    sync_op=True,
-                )
-                input_embedings = paddle.concat(inputs_embeds_temp, -1)
-            else:
-                input_embedings = self.embeddings(ids_remove_padding)
+            input_embedings = self.embeddings(ids_remove_padding)

         return input_embedings
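The rewritten forward path embeds tokens locally and, when column_cut is enabled, uses all_gather over tp_group followed by a concat on the last axis to rebuild the full hidden dimension. A single-process sketch of that recombination, using paddle.split to stand in for the per-rank lookup results (shapes and values are illustrative only, not from the PR):

# Hypothetical single-process check: simulate the column-cut shards and verify that
# concatenating them on the last axis rebuilds the full embedding, which is what
# all_gather(...) + paddle.concat(..., -1) achieves across TP ranks.
import paddle

num_tokens, hidden_size, tp_size = 3, 8, 2
full = paddle.arange(num_tokens * hidden_size, dtype="float32").reshape([num_tokens, hidden_size])

shards = paddle.split(full, tp_size, axis=-1)  # stand-ins for each rank's local lookup result
rebuilt = paddle.concat(shards, axis=-1)       # mirrors the concat after all_gather

assert bool(paddle.all(rebuilt == full))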