diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 2b04bd2c4..248bae29d 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -577,6 +577,15 @@ class ParallelConfig:
         else:
             self.pd_disaggregation_mode = "None"
 
+        # ep+tp strategy: "all_reduce" or "all_to_all"
+        # all_reduce: qkv_linear + attn + out_linear + allreduce
+        # all_to_all: allgather + qkv_linear + attn + all2all + out_linear
+        self.ep_tp_strategy = envs.FD_EP_TP_STRATEGY
+        assert self.ep_tp_strategy in [
+            "all_reduce",
+            "all_to_all",
+        ], f"FD_EP_TP_STRATEGY: '{self.ep_tp_strategy}' is not supported, only supports 'all_reduce' or 'all_to_all'."
+
     def set_communicate_group(self):
         # different tp group id
         # prevent different tp_groups using the same group_id
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 65aa35df1..34013d52e 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -155,6 +155,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "ENCODE_FEATURE_BOS_SK": lambda: os.getenv("ENCODE_FEATURE_BOS_SK"),
     # Enable offline perf test mode for PD disaggregation
     "FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
+    # ep+tp strategy: "all_reduce" or "all_to_all"
+    # all_reduce: qkv_linear + attn + out_linear + allreduce
+    # all_to_all: allgather + qkv_linear + attn + all2all + out_linear
+    "FD_EP_TP_STRATEGY": lambda: os.getenv("FD_EP_TP_STRATEGY", "all_reduce"),
 }
 
 
diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py
index b329844da..01636a178 100644
--- a/fastdeploy/model_executor/layers/linear.py
+++ b/fastdeploy/model_executor/layers/linear.py
@@ -799,6 +799,7 @@ class RowParallelLinear(LinearBase):
         reduce_results: bool = True,
         skip_quant: bool = False,
         weight_dtype="",
+        layer_id: int = -1,
     ):
         """
         Initialize a linear layer with additional parameters for inference and quantization.
@@ -815,14 +816,25 @@
         """
         self.fd_config = fd_config
         self.skip_quant = False
+        self.ep_size = fd_config.parallel_config.expert_parallel_size
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.nranks = fd_config.parallel_config.tensor_parallel_size
         self.tp_group = fd_config.parallel_config.tp_group
         self.hidden_size = fd_config.model_config.hidden_size
         self.head_dim = fd_config.model_config.head_dim
-        self.num_heads = fd_config.model_config.num_attention_heads // self.nranks
+        self.split_token = (
+            self.ep_size > 1
+            and self.tp_size > 1
+            and fd_config.parallel_config.ep_tp_strategy == "all_to_all"
+            and layer_id >= fd_config.model_config.moe_layer_start_index
+            and layer_id < fd_config.model_config.num_hidden_layers
+        )
 
         # Split input_size when using TP inference.
-        self.input_size = divide(input_size, self.nranks)
+        if self.split_token:
+            self.input_size = input_size
+        else:
+            self.input_size = divide(input_size, self.nranks)
         self.output_size = output_size
 
         super().__init__(
@@ -854,13 +866,30 @@
 
         self.reduce_results = reduce_results
 
+    def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor:
+        token_num = x.shape[0]
+        token_num_pad = (token_num + self.tp_size - 1) // self.tp_size * self.tp_size
+        if token_num_pad > token_num:
+            x_new = paddle.zeros([token_num_pad, x.shape[1]], x.dtype)
+            x_new[:token_num, :] = x
+            x = x_new
+        out = paddle.zeros_like(x)
+        paddle.distributed.alltoall(out, x, group=self.tp_group)
+        out.reshape_([self.tp_size, -1, x.shape[1]])
+        out = paddle.transpose(out, [1, 0, 2])
+        out.reshape_([x.shape[0] // self.tp_size, self.hidden_size])
+        return out
+
     def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
+        if self.split_token:
+            x = self.all2all_transpose(x)
+
         if self.fd_config.quant_config:
             out = self.quant_method.apply(self, x)
         else:
             out = paddle.matmul(x, self.weight)
 
-        if self.reduce_results and self.nranks > 1:
+        if self.reduce_results and self.nranks > 1 and not self.split_token:
             out = tensor_model_parallel_all_reduce(out, self.tp_group)
         if not self.fd_config.quant_config and self.add_bias:
             out = paddle.add(out, self.bias)
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index efa6a2abe..8c44fec26 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -137,6 +137,7 @@ class FusedMoE(nn.Layer):
         self.ep_size = fd_config.parallel_config.expert_parallel_size
         self.ep_rank = fd_config.parallel_config.expert_parallel_rank
         self.tp_group = fd_config.parallel_config.tp_group
+        self.ep_tp_strategy = self.fd_config.parallel_config.ep_tp_strategy
         # NOTE(Zhenyu Li): just supports tp_size = 1 when ep_size > 1 in MOE now.
         if self.ep_size > 1:
             self.tp_size = 1
@@ -612,7 +613,7 @@ class FusedMoE(nn.Layer):
         """
         token_num = x.shape[0]
         tp_size = self.fd_config.parallel_config.tensor_parallel_size
-        if self.ep_size > 1 and tp_size > 1 and token_num >= tp_size:
+        if self.ep_size > 1 and tp_size > 1 and self.ep_tp_strategy == "all_reduce" and token_num >= tp_size:
             out = self.forward_split_allgather(x, gate)
         else:
             out = self.quant_method.apply(self, x, gate)
diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py
index 6bcb05ba7..9f6d40d6c 100644
--- a/fastdeploy/model_executor/layers/normalization.py
+++ b/fastdeploy/model_executor/layers/normalization.py
@@ -28,6 +28,7 @@ else:
     from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm
 
 from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.forward_meta import ForwardMeta
 
 from .utils import get_tensor
 
@@ -47,6 +48,7 @@ class RMSNorm(nn.Layer):
         quant_scale: float = None,
         begin_norm_axis: int = 1,
         dtype: str = None,
+        layer_id: int = -1,
     ) -> None:
         """
         Initializes the RMSNormalization layer.
@@ -97,6 +99,30 @@
         self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
         self.begin_norm_axis: int = begin_norm_axis
 
+        self.layer_id = layer_id
+        parallel_config = self.fd_config.parallel_config
+        self.ep_size = parallel_config.expert_parallel_size
+        self.tp_size = parallel_config.tensor_parallel_size
+        self.tp_rank = parallel_config.tensor_parallel_rank
+        self.tp_group = parallel_config.tp_group
+        self.ep_tp_strategy = parallel_config.ep_tp_strategy
+        self.moe_layer_start_index = self.fd_config.model_config.moe_layer_start_index
+        is_input_norm = prefix.endswith(".input_layernorm")
+        is_last_norm = prefix.endswith(".norm")
+        self.split_x = (
+            self.ep_size > 1
+            and self.tp_size > 1
+            and self.ep_tp_strategy == "all_to_all"
+            and self.layer_id == self.moe_layer_start_index
+            and is_input_norm
+        )
+        self.allgather_out = (
+            self.ep_size > 1
+            and self.tp_size > 1
+            and self.ep_tp_strategy == "all_to_all"
+            and ((self.layer_id > self.moe_layer_start_index and is_input_norm) or is_last_norm)
+        )
+
         self.init_weight()
 
     def init_weight(self):
@@ -124,7 +150,50 @@
         weight_tensor = get_tensor(state_dict.pop(self.weight_key))
         self.weight.set_value(weight_tensor.astype(self._norm_weight_dtype))
 
-    def forward(self, x, residual_input: Optional[paddle.Tensor] = None) -> paddle.Tensor:
+    def split(self, x):
+        """
+        Split the input tensor across tensor parallel dimension.
+
+        Args:
+            x (paddle.Tensor): Input tensor to be split.
+
+        Returns:
+            paddle.Tensor: Splitted tensor.
+        """
+        token_num = x.shape[0]
+        token_num_per_rank = (token_num + self.tp_size - 1) // self.tp_size
+        # AllGather will hang when the data shapes on multi-ranks are different!
+        start_offset = self.tp_rank * token_num_per_rank
+        end_offset = (self.tp_rank + 1) * token_num_per_rank
+        if start_offset >= token_num:
+            start_offset = token_num
+        if end_offset > token_num:
+            end_offset = token_num
+        part_x = paddle.zeros(shape=[token_num_per_rank, x.shape[1]], dtype=x.dtype)
+        part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
+        return part_x
+
+    def allgather(self, out, token_num):
+        """
+        Gather the output tensor from each tensor parallel rank.
+
+        Args:
+            out (paddle.Tensor): Output tensor to be gathered.
+
+        Returns:
+            paddle.Tensor: Gathered tensor.
+        """
+        token_num_per_rank = out.shape[0]
+        multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, out.shape[1]], dtype=out.dtype)
+        paddle.distributed.all_gather(multi_outs, out, self.tp_group)
+        return multi_outs[:token_num, :]
+
+    def forward(
+        self,
+        x,
+        residual_input: Optional[paddle.Tensor] = None,
+        forward_meta: Optional[ForwardMeta] = None,
+    ) -> paddle.Tensor:
         """
         Defines the forward computation of the layer.
 
@@ -165,10 +234,18 @@
                 quant_max_bound=self.quant_max_bound,
                 quant_min_bound=self.quant_min_bound,
             )
-        if residual_input is not None:
-            return norm_out[0].astype(x_dtype), norm_out[1].astype(residual_input_dtype)
+        out = norm_out[0].astype(x_dtype)
+        residual_out = norm_out[1].astype(residual_input_dtype) if residual_input is not None else None
+
+        if self.split_x:
+            residual_out = self.split(residual_out)
+        if self.allgather_out:
+            out = self.allgather(out, forward_meta.ids_remove_padding.shape[0])
+
+        if residual_input is None:
+            return out
         else:
-            return norm_out[0].astype(x_dtype)
+            return out, residual_out
 
 
 class LayerNorm(nn.Layer):
diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index 06ac35efd..508fb20bf 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -15,6 +15,7 @@
 """
 
 import contextlib
+import copy
 import hashlib
 import inspect
 import json
@@ -267,8 +268,14 @@ def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfi
             filtered_map[k] = weight_list[k]
 
     if fd_config.parallel_config.tensor_parallel_size > 1:
+        no_tp_action_keys = copy.deepcopy(num_local_ffn_keys)
+        if fd_config.parallel_config.ep_tp_strategy == "all_to_all":
+            for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
+                k = f"ernie.layers.{i}.self_attn.o_proj.weight"
+                if k in weight_list:
+                    no_tp_action_keys.append(k)
         tp_actions = cls._get_tensor_parallel_mappings(fd_config.model_config.pretrained_config)
-        new_actions = {k: v for k, v in tp_actions.items() if k not in num_local_ffn_keys}
+        new_actions = {k: v for k, v in tp_actions.items() if k not in no_tp_action_keys}
 
     state_dict = {}
     # Get all safetensor file paths that need to be opened
diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py
index 55d0b4304..bbadfdbef 100644
--- a/fastdeploy/model_executor/models/ernie4_5_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -235,6 +235,7 @@ class Ernie4_5_Attention(nn.Layer):
             prefix=f"{prefix}.o_proj",
             input_size=fd_config.model_config.head_dim * fd_config.model_config.num_attention_heads,
             output_size=fd_config.model_config.hidden_size,
+            layer_id=layer_id,
         )
         self.attn = Attention(
             fd_config=fd_config,
@@ -303,6 +304,7 @@ class Ernie4_5_DecoderLayer(nn.Layer):
             hidden_size=fd_config.model_config.hidden_size,
             eps=fd_config.model_config.rms_norm_eps,
             prefix=f"{prefix}.input_layernorm",
+            layer_id=layer_id,
         )
 
         self.post_attention_layernorm = RMSNorm(
@@ -329,16 +331,27 @@
     ):
         if residual is None:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
+            hidden_states = self.input_layernorm(
+                hidden_states,
+                forward_meta=forward_meta,
+            )
         else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            hidden_states, residual = self.input_layernorm(
+                hidden_states,
+                residual,
+                forward_meta=forward_meta,
+            )
 
         hidden_states = self.self_attn(
             hidden_states=hidden_states,
             forward_meta=forward_meta,
         )
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states,
+            residual,
+            forward_meta=forward_meta,
+        )
 
         hidden_states = self.mlp(hidden_states)
 
@@ -444,7 +457,7 @@
 
         hidden_states = hidden_states + residual
 
-        out = self.norm(hidden_states)
+        out = self.norm(hidden_states, forward_meta=forward_meta)
 
         if current_platform.is_iluvatar() and forward_meta.attn_backend.mixed:
             out = forward_meta.attn_backend.reverse_transpose(out)
diff --git a/scripts/run_ci_xpu.sh b/scripts/run_ci_xpu.sh
index c407d945a..357870c1f 100644
--- a/scripts/run_ci_xpu.sh
+++ b/scripts/run_ci_xpu.sh
@@ -251,7 +251,7 @@ if [ ${vl_test_exit_code} -ne 0 ]; then
 fi
 
 
-echo "============================开始 EP4TP1 测试!============================"
+echo "============================开始 EP8TP1 测试!============================"
 sleep 5
 rm -rf log/*
 rm -f core*
@@ -290,12 +290,12 @@ stop_processes
 if [ ${ep_exit_code} -ne 0 ]; then
     echo "log/workerlog.0"
     cat log/workerlog.0
-    echo "EP4TP1 相关测试失败,请检查pr代码"
+    echo "EP8TP1 相关测试失败,请检查pr代码"
     exit 1
 fi
 
 
-echo "============================开始 EP4TP4 测试!============================"
+echo "============================开始 EP8TP8 allreduce 测试!============================"
 sleep 5
 rm -rf log/*
 rm -f core*
@@ -323,11 +323,55 @@ unset BKCL_PCIE_RING
 unset XSHMEM_MODE
 unset XSHMEM_QP_NUM_PER_RANK
 unset BKCL_RDMA_VERBS
+unset enable_expert_parallel
+unset enable_tensor_parallel
 stop_processes
 
 if [ ${ep_exit_code} -ne 0 ]; then
     echo "log/workerlog.0"
     cat log/workerlog.0
-    echo "EP4TP4 相关测试失败,请检查pr代码"
+    echo "EP8TP8 allreduce 相关测试失败,请检查pr代码"
+    exit 1
+fi
+
+
+echo "============================开始 EP8TP8 all2all 测试!============================"
+sleep 5
+rm -rf log/*
+rm -f core*
+ipcrm --all=msg
+xpu-smi
+export XPU_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+export BKCL_ENABLE_XDR=1
+export BKCL_RDMA_NICS=xgbe1,xgbe2,xgbe3,xgbe4
+export BKCL_TRACE_TOPO=1
+export BKCL_PCIE_RING=1
+export XSHMEM_MODE=1
+export XSHMEM_QP_NUM_PER_RANK=32
+export BKCL_RDMA_VERBS=1
+
+export enable_expert_parallel=1
+export enable_tensor_parallel=1
+export EP_TP_SPLIT_MODE=1
+
+python -m pytest -s --timeout=600 tests/ci_use/XPU_45T/run_ep.py
+ep_exit_code=$?
+
+unset BKCL_ENABLE_XDR
+unset BKCL_RDMA_NICS
+unset BKCL_TRACE_TOPO
+unset BKCL_PCIE_RING
+unset XSHMEM_MODE
+unset XSHMEM_QP_NUM_PER_RANK
+unset BKCL_RDMA_VERBS
+unset enable_expert_parallel
+unset enable_tensor_parallel
+unset EP_TP_SPLIT_MODE
+stop_processes
+
+if [ ${ep_exit_code} -ne 0 ]; then
+    echo "log/workerlog.0"
+    cat log/workerlog.0
+    echo "EP8TP8 all2all 相关测试失败,请检查pr代码"
     exit 1
 fi