mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00
Supports DP+TP+EP hybrid parallel deployment strategy (#3489)
* Support DP+TP+EP hybrid parallel deployment strategy * Support DP+TP+EP hybrid parallel deployment strategy * fix conflict * add moe_tp_ep function split_allgather_out * del tp_group in moe_cutlass_backend * for ci * fix parallel_config for ci * del log
This commit is contained in:
@@ -43,11 +43,16 @@
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 48: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 48; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 32: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 32; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 48: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 48; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 64: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 64; \
|
||||
__VA_ARGS__ \
|
||||
|
@@ -105,7 +105,8 @@ void SaveOutMmsg(const paddle::Tensor& x,
|
||||
int64_t rank_id,
|
||||
int msg_queue_id,
|
||||
bool save_each_rank) {
|
||||
if (!save_each_rank && rank_id > 0) {
|
||||
// don't use save_each_rank now!
|
||||
if (rank_id > 0) {
|
||||
return;
|
||||
}
|
||||
if (x.place() == paddle::CPUPlace()) {
|
||||
|
@@ -22,6 +22,7 @@ from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
from paddleformers.transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
import fastdeploy
|
||||
@@ -308,7 +309,10 @@ class ParallelConfig:
|
||||
setattr(self, key, value)
|
||||
|
||||
# currently, the expert parallel size is equal data parallel size
|
||||
self.expert_parallel_size = self.data_parallel_size
|
||||
if self.enable_expert_parallel:
|
||||
self.expert_parallel_size = self.data_parallel_size * self.tensor_parallel_size
|
||||
else:
|
||||
self.expert_parallel_size = 1
|
||||
self.use_ep = self.expert_parallel_size > 1
|
||||
if self.splitwise_role == "mixed":
|
||||
self.moe_phase = MoEPhase(phase="prefill")
|
||||
@@ -329,6 +333,22 @@ class ParallelConfig:
|
||||
else:
|
||||
self.pd_disaggregation_mode = "None"
|
||||
|
||||
def set_tp_group(self):
|
||||
# different tp group id
|
||||
# prevent different tp_groups using the same group_id
|
||||
dist.collective._set_custom_gid(self.data_parallel_rank + 100)
|
||||
self.tp_group = dist.new_group(
|
||||
range(
|
||||
self.data_parallel_rank * self.tensor_parallel_size,
|
||||
(self.data_parallel_rank + 1) * self.tensor_parallel_size,
|
||||
)
|
||||
)
|
||||
# same ep group id
|
||||
dist.collective._set_custom_gid(self.data_parallel_size + 100)
|
||||
logger.info(
|
||||
f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
|
||||
)
|
||||
|
||||
def print(self):
|
||||
"""
|
||||
print all config
|
||||
@@ -1104,7 +1124,7 @@ class FDConfig:
|
||||
if self.model_config is not None and self.model_config.enable_mm:
|
||||
self.max_prefill_batch = 1 # TODO:当前多模prefill阶段只支持并行度为1,待优化
|
||||
|
||||
num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.expert_parallel_size
|
||||
num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size
|
||||
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
if num_ranks > self.max_chips_per_node:
|
||||
self.worker_num_per_node = self.max_chips_per_node
|
||||
|
@@ -47,15 +47,20 @@ try:
|
||||
@paddle.jit.marker.unified
|
||||
def tensor_model_parallel_all_reduce(
|
||||
input_: paddle.Tensor,
|
||||
group_: paddle.distributed.communication.group.Group = None,
|
||||
) -> paddle.Tensor:
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
global _TP_AR
|
||||
if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
|
||||
# TODO: supports different_group custom allreduce
|
||||
_TP_AR.custom_all_reduce(input_)
|
||||
elif paddle.in_dynamic_mode():
|
||||
hcg = fleet.get_hybrid_communicate_group()
|
||||
mp_group = hcg.get_model_parallel_group()
|
||||
dist.all_reduce(input_, group=mp_group)
|
||||
if group_ is not None:
|
||||
dist.all_reduce(input_, group=group_)
|
||||
else:
|
||||
hcg = fleet.get_hybrid_communicate_group()
|
||||
mp_group = hcg.get_model_parallel_group()
|
||||
dist.all_reduce(input_, group=mp_group)
|
||||
else:
|
||||
dist.all_reduce(input_)
|
||||
|
||||
|
@@ -57,43 +57,37 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
hcg = fleet.get_hybrid_communicate_group()
|
||||
self.mp_rank: int = hcg.get_model_parallel_rank()
|
||||
self.column_cut = False
|
||||
self.world_size: int = hcg.get_model_parallel_world_size()
|
||||
self.ring_id: int = hcg.get_model_parallel_group().id
|
||||
self.use_ep: bool = fd_config.parallel_config.use_ep
|
||||
self.world_size: int = fd_config.parallel_config.tensor_parallel_size
|
||||
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
self.hidden_dropout_prob: float = fd_config.model_config.hidden_dropout_prob
|
||||
self.initializer_range: float = fd_config.model_config.initializer_range
|
||||
self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
|
||||
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
|
||||
self.params_dtype: str = params_dtype
|
||||
|
||||
if self.use_ep:
|
||||
self.embeddings = nn.Embedding(
|
||||
if not self.column_cut:
|
||||
self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
mp_group=self.tp_group,
|
||||
weight_attr=paddle.ParamAttr(
|
||||
initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
|
||||
),
|
||||
)
|
||||
if self.world_size > 1:
|
||||
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
|
||||
else:
|
||||
if not self.column_cut:
|
||||
self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
weight_attr=paddle.ParamAttr(
|
||||
initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
|
||||
),
|
||||
)
|
||||
if self.world_size > 1:
|
||||
set_weight_attrs(self.embeddings.weight, {"output_dim": False})
|
||||
else:
|
||||
# column cut embedding
|
||||
self.embeddings = nn.Embedding(
|
||||
num_embeddings,
|
||||
embedding_dim // self.world_size,
|
||||
)
|
||||
# column cut embedding
|
||||
self.embeddings = nn.Embedding(
|
||||
num_embeddings,
|
||||
embedding_dim // self.world_size,
|
||||
)
|
||||
|
||||
self.embeddings.weight.is_distributed = True
|
||||
self.embeddings.weight.split_axis = 1
|
||||
if self.world_size > 1:
|
||||
set_weight_attrs(self.embeddings.weight, {"output_dim": True})
|
||||
self.embeddings.weight.is_distributed = True
|
||||
self.embeddings.weight.split_axis = 1
|
||||
if self.world_size > 1:
|
||||
set_weight_attrs(self.embeddings.weight, {"output_dim": True})
|
||||
|
||||
self.prefix = prefix
|
||||
self.dropout = nn.Dropout(self.hidden_dropout_prob)
|
||||
@@ -125,20 +119,17 @@ class VocabParallelEmbedding(nn.Layer):
|
||||
Returns:
|
||||
Tensor: Embedded tensor representation of the input IDs.
|
||||
"""
|
||||
if self.use_ep:
|
||||
if self.column_cut:
|
||||
input_embedings = self.embeddings(ids_remove_padding)
|
||||
inputs_embeds_temp = []
|
||||
paddle.distributed.all_gather(
|
||||
inputs_embeds_temp,
|
||||
input_embedings,
|
||||
group=self.tp_group,
|
||||
sync_op=True,
|
||||
)
|
||||
input_embedings = paddle.concat(inputs_embeds_temp, -1)
|
||||
else:
|
||||
if self.column_cut:
|
||||
input_embedings = self.embeddings(ids_remove_padding)
|
||||
inputs_embeds_temp = []
|
||||
paddle.distributed.all_gather(
|
||||
inputs_embeds_temp,
|
||||
input_embedings,
|
||||
group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
sync_op=True,
|
||||
)
|
||||
input_embedings = paddle.concat(inputs_embeds_temp, -1)
|
||||
else:
|
||||
input_embedings = self.embeddings(ids_remove_padding)
|
||||
input_embedings = self.embeddings(ids_remove_padding)
|
||||
|
||||
return input_embedings
|
||||
|
@@ -703,6 +703,7 @@ class RowParallelLinear(LinearBase):
|
||||
self.fd_config = fd_config
|
||||
self.skip_quant = False
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_size
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
self.hidden_size = fd_config.model_config.hidden_size
|
||||
self.head_dim = fd_config.model_config.head_dim
|
||||
self.num_heads = fd_config.model_config.num_attention_heads // self.nranks
|
||||
@@ -751,7 +752,7 @@ class RowParallelLinear(LinearBase):
|
||||
out = paddle.matmul(x, self.weight)
|
||||
|
||||
if self.reduce_results and self.nranks > 1:
|
||||
tensor_model_parallel_all_reduce(out)
|
||||
tensor_model_parallel_all_reduce(out, self.tp_group)
|
||||
|
||||
return out
|
||||
|
||||
|
@@ -58,7 +58,7 @@ class ParallelLMHead(nn.Layer):
|
||||
self.bias_key: Optional[str] = prefix + ".bias"
|
||||
else:
|
||||
self.bias_key: Optional[str] = None
|
||||
self.use_ep: bool = fd_config.parallel_config.use_ep
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
self.column_cut = True
|
||||
self.nranks = fd_config.parallel_config.tensor_parallel_size
|
||||
self.fd_config = fd_config
|
||||
@@ -68,60 +68,46 @@ class ParallelLMHead(nn.Layer):
|
||||
|
||||
self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
|
||||
|
||||
if self.use_ep:
|
||||
self.weight = self.create_parameter(
|
||||
shape=[embedding_dim, num_embeddings],
|
||||
dtype=paddle.get_default_dtype(),
|
||||
is_bias=False,
|
||||
if self.column_cut:
|
||||
need_gather = True
|
||||
self.linear = ColumnParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=self.tp_group,
|
||||
weight_attr=None,
|
||||
has_bias=True if self.bias_key is not None else False,
|
||||
gather_output=need_gather,
|
||||
fuse_matmul_bias=False,
|
||||
)
|
||||
if self.bias_key is not None:
|
||||
self.bias = self.create_parameter(
|
||||
shape=[num_embeddings],
|
||||
dtype=paddle.get_default_dtype(),
|
||||
is_bias=True,
|
||||
)
|
||||
|
||||
set_weight_attrs(
|
||||
self.linear.weight,
|
||||
{
|
||||
"weight_loader": default_weight_loader(self.fd_config),
|
||||
"model_format": self.fd_config.model_config.model_format,
|
||||
},
|
||||
)
|
||||
if self.nranks > 1:
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": True})
|
||||
else:
|
||||
if self.column_cut:
|
||||
need_gather = True
|
||||
self.linear = ColumnParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
weight_attr=None,
|
||||
has_bias=True if self.bias_key is not None else False,
|
||||
gather_output=need_gather,
|
||||
fuse_matmul_bias=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.linear.weight,
|
||||
{
|
||||
"weight_loader": default_weight_loader(self.fd_config),
|
||||
"model_format": self.fd_config.model_config.model_format,
|
||||
},
|
||||
)
|
||||
if self.nranks > 1:
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": True})
|
||||
else:
|
||||
self.linear = RowParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=fleet.get_hybrid_communicate_group().get_model_parallel_group(),
|
||||
weight_attr=None,
|
||||
has_bias=True if self.bias_key is not None else False,
|
||||
input_is_parallel=False,
|
||||
fuse_matmul_bias=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.linear.weight,
|
||||
{
|
||||
"weight_loader": default_weight_loader(self.fd_config),
|
||||
"model_format": self.fd_config.model_config.model_format,
|
||||
},
|
||||
)
|
||||
self.linear = RowParallelLinear(
|
||||
embedding_dim,
|
||||
num_embeddings,
|
||||
mp_group=self.tp_group,
|
||||
weight_attr=None,
|
||||
has_bias=True if self.bias_key is not None else False,
|
||||
input_is_parallel=False,
|
||||
fuse_matmul_bias=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.linear.weight,
|
||||
{
|
||||
"weight_loader": default_weight_loader(self.fd_config),
|
||||
"model_format": self.fd_config.model_config.model_format,
|
||||
},
|
||||
)
|
||||
|
||||
if self.nranks > 1:
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": False})
|
||||
if self.nranks > 1:
|
||||
set_weight_attrs(self.linear.weight, {"output_dim": False})
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
|
||||
"""
|
||||
@@ -131,24 +117,19 @@ class ParallelLMHead(nn.Layer):
|
||||
state_dict (dict): A dictionary containing the checkpoint weights and biases.
|
||||
"""
|
||||
|
||||
if self.use_ep:
|
||||
self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
|
||||
if self.bias_key is not None:
|
||||
self.bias.set_value(get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype()))
|
||||
if self.tie_word_embeddings:
|
||||
self.linear.weight.set_value(
|
||||
get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
|
||||
)
|
||||
else:
|
||||
if self.tie_word_embeddings:
|
||||
self.linear.weight.set_value(
|
||||
get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()).transpose([1, 0])
|
||||
)
|
||||
else:
|
||||
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
|
||||
if self.linear.weight.shape != weight_tensor.shape:
|
||||
weight_tensor = weight_tensor.transpose([1, 0])
|
||||
self.linear.weight.set_value(weight_tensor)
|
||||
weight_tensor = get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype())
|
||||
if self.linear.weight.shape != weight_tensor.shape:
|
||||
weight_tensor = weight_tensor.transpose([1, 0])
|
||||
self.linear.weight.set_value(weight_tensor)
|
||||
|
||||
if self.bias_key is not None:
|
||||
bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
|
||||
self.linear.bias.set_value(bias)
|
||||
if self.bias_key is not None:
|
||||
bias = get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype())
|
||||
self.linear.bias.set_value(bias)
|
||||
|
||||
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
|
||||
"""
|
||||
@@ -161,11 +142,5 @@ class ParallelLMHead(nn.Layer):
|
||||
Tensor: The output tensor after processing through the layer.
|
||||
"""
|
||||
logits = input
|
||||
if self.use_ep:
|
||||
if self.bias_key is None:
|
||||
logits = paddle.matmul(logits, self.weight)
|
||||
else:
|
||||
logits = paddle.incubate.nn.functional.fused_linear(logits, self.weight, self.bias)
|
||||
else:
|
||||
logits = self.linear(logits)
|
||||
logits = self.linear(logits)
|
||||
return logits
|
||||
|
@@ -466,6 +466,6 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
1.0,
|
||||
)[0]
|
||||
if layer.tp_size > 1:
|
||||
tensor_model_parallel_all_reduce(tmp_ffn_out)
|
||||
tensor_model_parallel_all_reduce(tmp_ffn_out, self.tp_group)
|
||||
|
||||
return tmp_ffn_out
|
||||
|
@@ -98,6 +98,11 @@ class FusedMoE(nn.Layer):
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_size
|
||||
self.ep_size = fd_config.parallel_config.expert_parallel_size
|
||||
self.ep_rank = fd_config.parallel_config.expert_parallel_rank
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
# NOTE(Zhenyu Li): just supports tp_size = 1 when ep_size > 1 in MOE now.
|
||||
if self.ep_size > 1:
|
||||
self.tp_size = 1
|
||||
self.tp_rank = 0
|
||||
|
||||
assert (self.tp_size >= 1 and self.ep_size == 1) or (
|
||||
self.tp_size == 1 and self.ep_size > 1
|
||||
|
@@ -321,33 +321,28 @@ def load_composite_checkpoint(
|
||||
# 2. Tensor Parallel (TP)
|
||||
# 3. Pre-sharded (pre-split)
|
||||
"""
|
||||
if fd_config.parallel_config.use_ep and fd_config.speculative_config.model_type != "mtp":
|
||||
state_dict = load_ep_checkpoint(model_path, fd_config, return_numpy=True)
|
||||
rank_dirs = [
|
||||
f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
|
||||
]
|
||||
if len(rank_dirs) > 1:
|
||||
if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
|
||||
raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
|
||||
state_dict = load_pre_sharded_checkpoint(
|
||||
model_path,
|
||||
fd_config.parallel_config.tensor_parallel_rank,
|
||||
use_fastsafetensor=False,
|
||||
)
|
||||
else:
|
||||
rank_dirs = [
|
||||
f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
|
||||
]
|
||||
if len(rank_dirs) > 1:
|
||||
if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
|
||||
raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
|
||||
state_dict = load_pre_sharded_checkpoint(
|
||||
model_path,
|
||||
fd_config.parallel_config.tensor_parallel_rank,
|
||||
use_fastsafetensor=False,
|
||||
)
|
||||
if fd_config.load_config.use_fastsafetensor and (current_platform.available() and current_platform.is_cuda()):
|
||||
state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True)
|
||||
deal_state_dict(state_dict)
|
||||
else:
|
||||
if fd_config.load_config.use_fastsafetensor and (
|
||||
current_platform.available() and current_platform.is_cuda()
|
||||
):
|
||||
state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True)
|
||||
deal_state_dict(state_dict)
|
||||
else:
|
||||
state_dict = load_tp_checkpoint(
|
||||
model_path,
|
||||
cls,
|
||||
fd_config.model_config.pretrained_config,
|
||||
return_numpy=return_numpy,
|
||||
)
|
||||
state_dict = load_tp_checkpoint(
|
||||
model_path,
|
||||
cls,
|
||||
fd_config.model_config.pretrained_config,
|
||||
return_numpy=return_numpy,
|
||||
)
|
||||
if not state_dict:
|
||||
raise ValueError("weight not found in state_dict !")
|
||||
return state_dict
|
||||
|
@@ -103,6 +103,14 @@ class Ernie4_5_MoE(nn.Layer):
|
||||
if hasattr(fd_config.quant_config, "moe_quant_type"):
|
||||
moe_quant_type = fd_config.quant_config.moe_quant_type
|
||||
|
||||
self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size
|
||||
self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
|
||||
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
|
||||
self.use_ep = self.expert_parallel_size > 1
|
||||
self.us_tp = self.tensor_parallel_size > 1
|
||||
|
||||
if moe_quant_type == "w4a8" or moe_quant_type == "w4afp8":
|
||||
weight_key_map = {
|
||||
"gate_weight_key": f"{prefix}.gate.weight",
|
||||
@@ -201,8 +209,30 @@ class Ernie4_5_MoE(nn.Layer):
|
||||
if self.num_shared_experts > 0:
|
||||
self.shared_experts.load_state_dict(state_dict)
|
||||
|
||||
def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int):
|
||||
token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size
|
||||
# AllGather will hang when the data shapes on multi-ranks are different!
|
||||
part_hidden_states = paddle.zeros(
|
||||
shape=[token_num_per_rank, hidden_states.shape[1]], dtype=hidden_states.dtype
|
||||
)
|
||||
start_offset = self.tensor_parallel_rank * token_num_per_rank
|
||||
end_offset = (self.tensor_parallel_rank + 1) * token_num_per_rank
|
||||
if end_offset > token_num:
|
||||
end_offset = token_num
|
||||
part_hidden_states[: (end_offset - start_offset), :] = hidden_states[start_offset:end_offset, :]
|
||||
out = self.experts(part_hidden_states, self.gate)
|
||||
multi_outs = []
|
||||
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
|
||||
out = paddle.concat(multi_outs, axis=0)
|
||||
out = out[:token_num, :]
|
||||
return out
|
||||
|
||||
def forward(self, hidden_states: paddle.Tensor):
|
||||
out = self.experts(hidden_states, self.gate)
|
||||
token_num = hidden_states.shape[0]
|
||||
if self.use_ep and self.use_tp and token_num >= self.tensor_parallel_size:
|
||||
out = self.split_allgather_out(hidden_states, token_num)
|
||||
else:
|
||||
out = self.experts(hidden_states, self.gate)
|
||||
if self.num_shared_experts > 0:
|
||||
s_x = self.shared_experts(hidden_states)
|
||||
out = out + s_x
|
||||
|
@@ -51,6 +51,15 @@ class Qwen3MoeBlock(nn.Layer):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size
|
||||
self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
|
||||
self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
|
||||
self.tp_group = fd_config.parallel_config.tp_group
|
||||
|
||||
self.use_ep = self.expert_parallel_size > 1
|
||||
self.us_tp = self.tensor_parallel_size > 1
|
||||
|
||||
weight_key_map = {
|
||||
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
|
||||
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
|
||||
@@ -74,8 +83,30 @@ class Qwen3MoeBlock(nn.Layer):
|
||||
weight_dtype="float32",
|
||||
)
|
||||
|
||||
def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int):
|
||||
token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size
|
||||
# AllGather will hang when the data shapes on multi-ranks are different!
|
||||
part_hidden_states = paddle.zeros(
|
||||
shape=[token_num_per_rank, hidden_states.shape[1]], dtype=hidden_states.dtype
|
||||
)
|
||||
start_offset = self.tensor_parallel_rank * token_num_per_rank
|
||||
end_offset = (self.tensor_parallel_rank + 1) * token_num_per_rank
|
||||
if end_offset > token_num:
|
||||
end_offset = token_num
|
||||
part_hidden_states[: (end_offset - start_offset), :] = hidden_states[start_offset:end_offset, :]
|
||||
out = self.experts(part_hidden_states, self.gate)
|
||||
multi_outs = []
|
||||
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
|
||||
out = paddle.concat(multi_outs, axis=0)
|
||||
out = out[:token_num, :]
|
||||
return out
|
||||
|
||||
def forward(self, x):
|
||||
out = self.experts(x, self.gate)
|
||||
token_num = x.shape[0]
|
||||
if self.use_ep and self.use_tp and token_num >= self.tensor_parallel_size:
|
||||
out = self.split_allgather_out(x, token_num)
|
||||
else:
|
||||
out = self.experts(x, self.gate)
|
||||
return out
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
|
@@ -72,6 +72,7 @@ class TensorSplitMode(Enum):
|
||||
"""TensorSplitMode"""
|
||||
|
||||
GQA = "is_gqa"
|
||||
TP_ROW_BIAS = "is_tp_row_bias"
|
||||
TRANSPOSE = "transpose"
|
||||
QKV = "is_old_qkv"
|
||||
PairFused = "is_naive_2fuse"
|
||||
@@ -212,7 +213,7 @@ def gqa_qkv_split_func(
|
||||
"""
|
||||
|
||||
def fn(x, is_column=True):
|
||||
"""fucn"""
|
||||
"""func"""
|
||||
|
||||
def get_shape(tensor):
|
||||
"""get_shape"""
|
||||
@@ -430,7 +431,15 @@ def split_or_merge_func_v1(
|
||||
def fn(x, **kwargs):
|
||||
"""func"""
|
||||
is_gqa = kwargs.pop("is_gqa", False)
|
||||
if is_gqa:
|
||||
is_tp_row_bias = kwargs.pop("is_tp_row_bias", False)
|
||||
if is_tp_row_bias:
|
||||
tensor = x[:, ...]
|
||||
if isinstance(tensor, paddle.Tensor):
|
||||
res = tensor / tensor_parallel_degree
|
||||
else:
|
||||
res = paddle.to_tensor(tensor, paddle.get_default_dtype()) / tensor_parallel_degree
|
||||
return res
|
||||
elif is_gqa:
|
||||
func = split_or_merge_qkv_func(
|
||||
is_split=is_split,
|
||||
tensor_parallel_degree=tensor_parallel_degree,
|
||||
|
@@ -1117,7 +1117,11 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
)
|
||||
sampler_output = self.sampler(logits, self.sampling_metadata)
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
|
||||
paddle.distributed.broadcast(
|
||||
sampler_output.sampled_token_ids,
|
||||
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
||||
group=self.parallel_config.tp_group,
|
||||
)
|
||||
else:
|
||||
self.sampler(
|
||||
logits,
|
||||
@@ -1127,10 +1131,26 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
)
|
||||
sampler_output = None
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
paddle.distributed.broadcast(self.share_inputs["accept_tokens"], 0)
|
||||
paddle.distributed.broadcast(self.share_inputs["accept_num"], 0)
|
||||
paddle.distributed.broadcast(self.share_inputs["step_idx"], 0)
|
||||
paddle.distributed.broadcast(self.share_inputs["stop_flags"], 0)
|
||||
paddle.distributed.broadcast(
|
||||
self.share_inputs["accept_tokens"],
|
||||
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
||||
group=self.parallel_config.tp_group,
|
||||
)
|
||||
paddle.distributed.broadcast(
|
||||
self.share_inputs["accept_num"],
|
||||
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
||||
group=self.parallel_config.tp_group,
|
||||
)
|
||||
paddle.distributed.broadcast(
|
||||
self.share_inputs["step_idx"],
|
||||
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
||||
group=self.parallel_config.tp_group,
|
||||
)
|
||||
paddle.distributed.broadcast(
|
||||
self.share_inputs["stop_flags"],
|
||||
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
||||
group=self.parallel_config.tp_group,
|
||||
)
|
||||
|
||||
# 5. post process
|
||||
model_output_data = ModelOutputData(
|
||||
@@ -1149,7 +1169,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
is_block_step=self.share_inputs["is_block_step"],
|
||||
full_hidden_states=model_output,
|
||||
msg_queue_id=self.parallel_config.msg_queue_id,
|
||||
mp_rank=self.local_rank,
|
||||
mp_rank=self.parallel_config.tensor_parallel_rank,
|
||||
use_ep=self.parallel_config.use_ep,
|
||||
draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
|
||||
actual_draft_token_num=(
|
||||
@@ -1200,13 +1220,15 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
"""
|
||||
if not self.cache_config.enable_chunked_prefill:
|
||||
return
|
||||
for task in tasks:
|
||||
if task.get("prefill_chunk_info", None) is None:
|
||||
continue
|
||||
|
||||
if task.chunk_idx > len(task.prefill_chunk_info):
|
||||
continue
|
||||
self.restore_chunked_prefill_request[task.request_id] = task
|
||||
if tasks is not None:
|
||||
for task in tasks:
|
||||
if task.get("prefill_chunk_info", None) is None:
|
||||
continue
|
||||
|
||||
if task.chunk_idx > len(task.prefill_chunk_info):
|
||||
continue
|
||||
self.restore_chunked_prefill_request[task.request_id] = task
|
||||
|
||||
for id, task in list(self.restore_chunked_prefill_request.items()):
|
||||
idx = task.idx
|
||||
@@ -1384,7 +1406,11 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
skip_idx_list,
|
||||
)
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0)
|
||||
paddle.distributed.broadcast(
|
||||
sampler_output.sampled_token_ids,
|
||||
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
||||
group=self.parallel_config.tp_group,
|
||||
)
|
||||
|
||||
else:
|
||||
self.sampler(
|
||||
@@ -1395,10 +1421,26 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
)
|
||||
sampler_output = None
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
paddle.distributed.broadcast(self.share_inputs["accept_tokens"], 0)
|
||||
paddle.distributed.broadcast(self.share_inputs["accept_num"], 0)
|
||||
paddle.distributed.broadcast(self.share_inputs["step_idx"], 0)
|
||||
paddle.distributed.broadcast(self.share_inputs["stop_flags"], 0)
|
||||
paddle.distributed.broadcast(
|
||||
self.share_inputs["accept_tokens"],
|
||||
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
||||
group=self.parallel_config.tp_group,
|
||||
)
|
||||
paddle.distributed.broadcast(
|
||||
self.share_inputs["accept_num"],
|
||||
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
||||
group=self.parallel_config.tp_group,
|
||||
)
|
||||
paddle.distributed.broadcast(
|
||||
self.share_inputs["step_idx"],
|
||||
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
||||
group=self.parallel_config.tp_group,
|
||||
)
|
||||
paddle.distributed.broadcast(
|
||||
self.share_inputs["stop_flags"],
|
||||
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
|
||||
group=self.parallel_config.tp_group,
|
||||
)
|
||||
|
||||
# 5. Post Process
|
||||
model_output_data = ModelOutputData(
|
||||
@@ -1417,7 +1459,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
is_block_step=self.share_inputs["is_block_step"],
|
||||
full_hidden_states=model_output,
|
||||
msg_queue_id=self.parallel_config.msg_queue_id,
|
||||
mp_rank=self.local_rank,
|
||||
mp_rank=self.parallel_config.tensor_parallel_rank,
|
||||
use_ep=self.parallel_config.use_ep,
|
||||
draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
|
||||
actual_draft_token_num=(
|
||||
@@ -1454,7 +1496,7 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
else:
|
||||
self.proposer.run(share_inputs=self.share_inputs)
|
||||
|
||||
# 7. Updata 'infer_seed' and step_cuda()
|
||||
# 7. Update 'infer_seed' and step_cuda()
|
||||
self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
|
||||
self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
|
||||
|
||||
|
@@ -163,7 +163,7 @@ class PaddleDisWorkerProc:
|
||||
is_server=False,
|
||||
num_client=self.parallel_config.tensor_parallel_size,
|
||||
client_id=self.parallel_config.tensor_parallel_rank,
|
||||
local_data_parallel_id=self.parallel_config.expert_parallel_rank,
|
||||
local_data_parallel_id=self.parallel_config.data_parallel_rank,
|
||||
)
|
||||
|
||||
def init_health_status(self) -> None:
|
||||
@@ -180,7 +180,7 @@ class PaddleDisWorkerProc:
|
||||
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
|
||||
array_size = min(
|
||||
self.max_chips_per_node,
|
||||
self.parallel_config.tensor_parallel_size * self.parallel_config.expert_parallel_size,
|
||||
self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size,
|
||||
)
|
||||
workers_ready = np.zeros(shape=[array_size], dtype=np.int32)
|
||||
self.worker_ready_signal = IPCSignal(
|
||||
@@ -214,7 +214,7 @@ class PaddleDisWorkerProc:
|
||||
)
|
||||
|
||||
# init exist_task_signal
|
||||
workers_exist_task = np.zeros([self.parallel_config.expert_parallel_size], dtype=np.int32)
|
||||
workers_exist_task = np.zeros([self.parallel_config.data_parallel_size], dtype=np.int32)
|
||||
self.exist_task_signal = IPCSignal(
|
||||
name="exist_task_signal",
|
||||
array=workers_exist_task,
|
||||
@@ -224,7 +224,7 @@ class PaddleDisWorkerProc:
|
||||
)
|
||||
|
||||
# init exist_swapped_task_signal
|
||||
workers_swapped_task = np.zeros(shape=[self.parallel_config.expert_parallel_size], dtype=np.int32)
|
||||
workers_swapped_task = np.zeros(shape=[self.parallel_config.data_parallel_size], dtype=np.int32)
|
||||
self.exist_swapped_task_signal = IPCSignal(
|
||||
name="exist_swapped_task_signal",
|
||||
array=workers_swapped_task,
|
||||
@@ -243,32 +243,6 @@ class PaddleDisWorkerProc:
|
||||
create=False,
|
||||
)
|
||||
|
||||
def event_loop_ep(self) -> None:
|
||||
"""
|
||||
Tmp loop function for ep utill DP is supported
|
||||
"""
|
||||
while True:
|
||||
self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(time.time())
|
||||
|
||||
num_running_requests = 0
|
||||
if self.fd_config.parallel_config.tensor_parallel_rank == 0 and self.task_queue.num_tasks() > 0:
|
||||
tasks, read_finish = self.task_queue.get_tasks()
|
||||
|
||||
req_dicts = []
|
||||
for req_dict, bsz in tasks:
|
||||
num_running_requests = int(bsz)
|
||||
req_dicts.extend(req_dict)
|
||||
logger.info(
|
||||
f"Rank: {self.local_rank}, num_running_requests: {num_running_requests}, "
|
||||
f"num_insert_requests: {len(req_dicts)}"
|
||||
)
|
||||
# Process prefill inputs
|
||||
self.worker.preprocess_new_task(req_dicts, num_running_requests)
|
||||
|
||||
# Execute model to generate token. The generated token will be written to the buffer.
|
||||
# These generated tokens can be obtained through get_output op.
|
||||
self.worker.execute_model(num_running_requests)
|
||||
|
||||
def event_loop_normal(self) -> None:
|
||||
"""Main event loop for Paddle Distrubuted Workers.
|
||||
TODO(gongshaotian): support remote calling of functions that control worker.
|
||||
@@ -287,9 +261,10 @@ class PaddleDisWorkerProc:
|
||||
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
# Synchronize before updating weights
|
||||
paddle.distributed.barrier()
|
||||
paddle.distributed.barrier(self.parallel_config.tp_group)
|
||||
|
||||
self.insert_step = False
|
||||
req_dicts = None
|
||||
self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(time.time())
|
||||
|
||||
# The first worker detects whether there are tasks in the task queue
|
||||
@@ -302,12 +277,11 @@ class PaddleDisWorkerProc:
|
||||
if self.nnode > 1 and self.parallel_config.tensor_parallel_size > self.max_chips_per_node:
|
||||
self.task_queue.read_finish_flag.set(1)
|
||||
else:
|
||||
self.exist_task_signal.value[self.fd_config.parallel_config.expert_parallel_rank] = 1
|
||||
self.exist_task_signal.value[self.fd_config.parallel_config.data_parallel_rank] = 1
|
||||
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
# Synchronize the signal for other workers
|
||||
# TODO(@wufeisheng): Split TP group and EP group
|
||||
paddle.distributed.barrier()
|
||||
paddle.distributed.barrier(self.parallel_config.tp_group)
|
||||
|
||||
if self.fd_config.load_config.dynamic_load_weight:
|
||||
if self.exist_task_signal.value[0] == 2:
|
||||
@@ -322,7 +296,7 @@ class PaddleDisWorkerProc:
|
||||
)
|
||||
|
||||
if (
|
||||
self.exist_task_signal.value[self.fd_config.parallel_config.expert_parallel_rank] == 1
|
||||
self.exist_task_signal.value[self.fd_config.parallel_config.data_parallel_rank] == 1
|
||||
or self.task_queue.read_finish_flag.get() == 1
|
||||
):
|
||||
logger.info(f"Rank: {self.local_rank} Detected new requests.")
|
||||
@@ -331,7 +305,7 @@ class PaddleDisWorkerProc:
|
||||
tasks, read_finish = self.task_queue.get_tasks()
|
||||
if read_finish:
|
||||
# Ensure that every worker get the task
|
||||
self.exist_task_signal.value[self.fd_config.parallel_config.expert_parallel_rank] = 0
|
||||
self.exist_task_signal.value[self.fd_config.parallel_config.data_parallel_rank] = 0
|
||||
self.task_queue.read_finish_flag.set(0)
|
||||
|
||||
req_dicts = []
|
||||
@@ -348,9 +322,9 @@ class PaddleDisWorkerProc:
|
||||
# Process prefill inputs
|
||||
self.worker.preprocess_new_task(req_dicts, num_running_requests)
|
||||
|
||||
if not self.worker.model_runner.not_need_stop():
|
||||
if (not self.parallel_config.use_ep) and (not self.worker.model_runner.not_need_stop()):
|
||||
if self.ranks > 1:
|
||||
paddle.distributed.barrier()
|
||||
paddle.distributed.barrier(self.parallel_config.tp_group)
|
||||
|
||||
time.sleep(0.001)
|
||||
continue
|
||||
@@ -633,23 +607,23 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
speculative_config = SpeculativeConfig(args.speculative_config)
|
||||
parallel_config = ParallelConfig(vars(args))
|
||||
cache_config = CacheConfig(vars(args))
|
||||
parallel_config.tensor_parallel_size = args.tensor_parallel_size
|
||||
parallel_config.tensor_parallel_rank = local_rank % args.tensor_parallel_size
|
||||
parallel_config.expert_parallel_size = args.expert_parallel_size
|
||||
parallel_config.tensor_parallel_rank = local_rank % parallel_config.tensor_parallel_size
|
||||
parallel_config.data_parallel_rank = local_rank // parallel_config.tensor_parallel_size
|
||||
# config for EP
|
||||
if args.expert_parallel_size > 1:
|
||||
expert_parallel_rank = int(local_rank / args.tensor_parallel_size)
|
||||
if parallel_config.expert_parallel_size > 1:
|
||||
expert_parallel_rank = int(local_rank % parallel_config.expert_parallel_size)
|
||||
if isinstance(model_config.moe_num_experts, list):
|
||||
num_experts = model_config.moe_num_experts[0]
|
||||
else:
|
||||
num_experts = model_config.moe_num_experts
|
||||
|
||||
num_experts_per_rank = num_experts // args.expert_parallel_size
|
||||
num_experts_per_rank = num_experts // parallel_config.expert_parallel_size
|
||||
num_experts_start_offset = expert_parallel_rank * num_experts_per_rank
|
||||
|
||||
parallel_config.expert_parallel_rank = expert_parallel_rank
|
||||
parallel_config.num_experts_per_rank = num_experts_per_rank
|
||||
parallel_config.num_experts_start_offset = num_experts_start_offset
|
||||
parallel_config.set_tp_group()
|
||||
|
||||
load_config = LoadConfig(vars(args))
|
||||
|
||||
@@ -770,11 +744,7 @@ def run_worker_proc() -> None:
|
||||
worker_proc.init_health_status()
|
||||
|
||||
# Start event loop
|
||||
if fd_config.parallel_config.use_ep:
|
||||
# TODO(wufeisheng): Delete this branch
|
||||
worker_proc.event_loop_ep()
|
||||
else:
|
||||
worker_proc.event_loop_normal()
|
||||
worker_proc.event_loop_normal()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user