[TSP] Support qwen3 moe tsp + cudagraph (#4871)

* support qwen3_moe tsp mode

* fix

* fix

* update

* update

* update

* fix

* support external_rmsnorm

* update

* fix
Yuanle Liu
2025-11-10 23:37:51 +08:00
committed by GitHub
parent fb2eb403ab
commit 3dc0ffa46d
28 changed files with 173 additions and 273 deletions

View File

@@ -40,6 +40,8 @@ When using FastDeploy to deploy models (including offline inference and service
| ```use_cudagraph``` | `bool` | __[DEPRECATED]__ CUDAGraph is enabled by default since version 2.3. It is recommended to read [graph_optimization.md](./features/graph_optimization.md) carefully before enabling it. |
| ```graph_optimization_config``` | `dict[str]` | Parameters related to computation graph optimization; default: `'{"use_cudagraph":true, "graph_opt_level":0}'`. See [graph_optimization.md](./features/graph_optimization.md) for details |
| ```disable_custom_all_reduce``` | `bool` | Disable Custom all-reduce, default: False |
| ```use_internode_ll_two_stage``` | `bool` | Use two-stage communication in DeepEP MoE, default: False |
| ```disable_sequence_parallel_moe``` | `bool` | Disable the sequence-parallel MoE optimization, default: False |
| ```splitwise_role``` | `str` | Whether to enable splitwise inference, default value: mixed, supported parameters: ["mixed", "decode", "prefill"] |
| ```innode_prefill_ports``` | `str` | Internal engine startup ports for prefill instances (only required for single-machine PD separation), default: None |
| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `off`, default: `off` |
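
To illustrate how these parameters combine, here is a minimal offline-inference sketch; the keyword arguments mirror those exercised by tests/ci_use/XPU_45T/run_ep.py in this PR, while the import path, the model name, and the way `graph_optimization_config` is passed are assumptions for illustration only.

```python
# Hedged sketch, not an official recipe: the model id and some kwargs are placeholders.
from fastdeploy import LLM, SamplingParams  # assumed public import path

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",                # placeholder MoE model id
    tensor_parallel_size=4,
    enable_expert_parallel=True,               # EP + TP together activates sequence-parallel MoE (TSP)
    data_parallel_size=1,
    disable_sequence_parallel_moe=False,       # keep the TSP optimization on (the default)
    graph_optimization_config={"use_cudagraph": True, "graph_opt_level": 0},  # assumed dict form
    max_model_len=8192,
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
```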

View File

@@ -38,6 +38,8 @@
| ```use_cudagraph``` | `bool` | __[DEPRECATED]__ CUDAGraph is enabled by default since version 2.3; see [graph_optimization.md](./features/graph_optimization.md) for details |
| ```graph_optimization_config``` | `dict[str]` | Parameters related to computation graph optimization; default: `'{"use_cudagraph":true, "graph_opt_level":0}'`. See [graph_optimization.md](./features/graph_optimization.md) for details |
| ```disable_custom_all_reduce``` | `bool` | Disable custom all-reduce, default: False |
| ```use_internode_ll_two_stage``` | `bool` | Whether to use two-stage communication in DeepEP MoE, default: False |
| ```disable_sequence_parallel_moe``` | `bool` | Disable the sequence-parallel optimization under TP+EP, default: False |
| ```splitwise_role``` | `str` | Whether to enable splitwise inference, default: mixed, supported values: ["mixed", "decode", "prefill"] |
| ```innode_prefill_ports``` | `str` | Internal engine startup ports for prefill instances (only required for single-machine PD separation), default: None |
| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `off`, default: `off` |

View File

@@ -307,8 +307,8 @@ class ModelConfig:
Read configuration information from environment variables and update the object's attributes.
If an attribute is not present or is an empty string in the environment variables, use the default value.
"""
self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
self.stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
self.max_stop_seqs_num = envs.FD_MAX_STOP_SEQS_NUM
self.stop_seqs_max_len = envs.FD_STOP_SEQS_MAX_LEN
def reset_config_value(key, value):
if not hasattr(self, key.lower()):
@@ -548,6 +548,8 @@ class ParallelConfig:
self.do_profile: bool = False
# Use internode_ll_two_stage or not
self.use_internode_ll_two_stage: bool = False
# disable sequence parallel moe
self.disable_sequence_parallel_moe: bool = False
self.pod_ip: str = None
# enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
@@ -577,14 +579,14 @@ class ParallelConfig:
else:
self.pd_disaggregation_mode = "None"
# ep+tp strategy: "all_reduce" or "all_to_all"
# all_reduce: qkv_linear + attn + out_linear + allreduce
# all_to_all: allgather + qkv_linear + attn + all2all + out_linear
self.ep_tp_strategy = envs.FD_EP_TP_STRATEGY
assert self.ep_tp_strategy in [
"all_reduce",
"all_to_all",
], f"FD_EP_TP_STRATEGY: '{self.ep_tp_strategy}' is not supported, only supports 'all_reduce' or 'all_to_all'."
# disable_sequence_parallel_moe: qkv_linear + attn + out_linear + allreduce
# use_sequence_parallel_moe: allgather + qkv_linear + attn + all2all + out_linear
self.use_sequence_parallel_moe = (
(not self.disable_sequence_parallel_moe)
and self.expert_parallel_size > 1
and self.tensor_parallel_size > 1
)
logger.info(f"use_sequence_parallel_moe: {self.use_sequence_parallel_moe}")
def set_communicate_group(self):
# different tp group id
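
A minimal sketch of the decision encoded above: sequence-parallel MoE (TSP) is active only when expert parallelism and tensor parallelism are both enabled and the user has not opted out.

```python
def use_sequence_parallel_moe(expert_parallel_size: int,
                              tensor_parallel_size: int,
                              disable_sequence_parallel_moe: bool) -> bool:
    # Mirrors the ParallelConfig logic above.
    return (not disable_sequence_parallel_moe
            and expert_parallel_size > 1
            and tensor_parallel_size > 1)

assert use_sequence_parallel_moe(8, 4, False) is True   # EP8 x TP4 -> TSP on
assert use_sequence_parallel_moe(8, 1, False) is False  # no TP, nothing to split
assert use_sequence_parallel_moe(8, 4, True) is False   # explicitly disabled via the flag
```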

View File

@@ -240,7 +240,7 @@ class EngineArgs:
disable_custom_all_reduce: bool = False
"""
Flag to enable the custom all-reduce kernel.
Flag to disable the custom all-reduce kernel.
"""
use_internode_ll_two_stage: bool = False
@@ -248,6 +248,19 @@ class EngineArgs:
Flag to use the internode_ll_two_stage kernel.
"""
disable_sequence_parallel_moe: bool = False
"""
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens result in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# This optimization is enabled by default, and can be disabled by using this flag.
"""
engine_worker_queue_port: str = "0"
"""
Port for worker queue communication.
@@ -766,6 +779,12 @@ class EngineArgs:
default=EngineArgs.use_internode_ll_two_stage,
help="Flag to use the internode_ll_two_stage kernel.",
)
parallel_group.add_argument(
"--disable-sequence-parallel-moe",
action="store_true",
default=EngineArgs.disable_sequence_parallel_moe,
help="Flag to disable disable the sequence parallel moe.",
)
parallel_group.add_argument(
"--max-num-seqs",
type=int,

View File

@@ -842,6 +842,8 @@ class AsyncLLMEngine:
"dynamic_load_weight": self.cfg.load_config.dynamic_load_weight,
"disable_any_whitespace": self.cfg.structured_outputs_config.disable_any_whitespace,
"disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce,
"use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage,
"disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
"enable_logprob": self.cfg.model_config.enable_logprob,
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
}

View File

@@ -286,7 +286,7 @@ class LLMEngine:
if request.get("stop_seqs_len") is not None:
stop_seqs_len = request.get("stop_seqs_len")
max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
max_stop_seqs_num = envs.FD_MAX_STOP_SEQS_NUM
if len(stop_seqs_len) > max_stop_seqs_num:
error_msg = (
f"Length of stop ({stop_seqs_len}) exceeds the limit max_stop_seqs_num({max_stop_seqs_num})."
@@ -294,7 +294,7 @@ class LLMEngine:
)
llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=400)
stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
stop_seqs_max_len = envs.FD_STOP_SEQS_MAX_LEN
for single_stop_seq_len in stop_seqs_len:
if single_stop_seq_len > stop_seqs_max_len:
error_msg = (
@@ -568,6 +568,7 @@ class LLMEngine:
"disable_any_whitespace": self.cfg.structured_outputs_config.disable_any_whitespace,
"disable_custom_all_reduce": self.cfg.parallel_config.disable_custom_all_reduce,
"use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage,
"disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe,
"enable_logprob": self.cfg.model_config.enable_logprob,
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
}

View File

@@ -235,7 +235,7 @@ class EngineClient:
if "stop_seqs_len" in task:
stop_seqs_len = task["stop_seqs_len"]
max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
max_stop_seqs_num = envs.FD_MAX_STOP_SEQS_NUM
if len(stop_seqs_len) > max_stop_seqs_num:
error_msg = (
f"Length of stop ({stop_seqs_len}) exceeds the limit max_stop_seqs_num({max_stop_seqs_num})."
@@ -243,7 +243,7 @@ class EngineClient:
)
api_server_logger.error(error_msg)
raise EngineError(error_msg, error_code=400)
stop_seqs_max_len = int(envs.FD_STOP_SEQS_MAX_LEN)
stop_seqs_max_len = envs.FD_STOP_SEQS_MAX_LEN
for single_stop_seq_len in stop_seqs_len:
if single_stop_seq_len > stop_seqs_max_len:
error_msg = (

View File

@@ -35,9 +35,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Model download cache directory.
"FD_MODEL_CACHE": lambda: os.getenv("FD_MODEL_CACHE", None),
# Maximum number of stop sequences.
"FD_MAX_STOP_SEQS_NUM": lambda: os.getenv("FD_MAX_STOP_SEQS_NUM", "5"),
"FD_MAX_STOP_SEQS_NUM": lambda: int(os.getenv("FD_MAX_STOP_SEQS_NUM", "5")),
# Maximum length of stop sequences.
"FD_STOP_SEQS_MAX_LEN": lambda: os.getenv("FD_STOP_SEQS_MAX_LEN", "8"),
"FD_STOP_SEQS_MAX_LEN": lambda: int(os.getenv("FD_STOP_SEQS_MAX_LEN", "8")),
# GPU devices that will be used. This is a string
# split by commas, such as 0,1,2.
"CUDA_VISIBLE_DEVICES": lambda: os.getenv("CUDA_VISIBLE_DEVICES", None),
@@ -159,10 +159,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
"FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")),
"FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")),
# ep+tp strategy: "all_reduce" or "all_to_all"
# all_reduce: qkv_linear + attn + out_linear + allreduce
# all_to_all: allgather + qkv_linear + attn + all2all + out_linear
"FD_EP_TP_STRATEGY": lambda: os.getenv("FD_EP_TP_STRATEGY", "all_reduce"),
}
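
The two entries above now perform the `int(...)` cast inside the lambda, so callers such as ModelConfig, LLMEngine, and EngineClient can drop their own casts. Below is a self-contained sketch of this lazy env-var pattern; the module-level `__getattr__` hook is an assumption about how the real envs module resolves attributes.

```python
import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    # Each value is a zero-argument callable, evaluated on access.
    "FD_MAX_STOP_SEQS_NUM": lambda: int(os.getenv("FD_MAX_STOP_SEQS_NUM", "5")),
    "FD_STOP_SEQS_MAX_LEN": lambda: int(os.getenv("FD_STOP_SEQS_MAX_LEN", "8")),
}

def __getattr__(name: str) -> Any:
    # Assumed attribute hook: envs.FD_MAX_STOP_SEQS_NUM evaluates the lambda,
    # so the environment string is converted to int in exactly one place.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(name)
```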

View File

@@ -550,7 +550,6 @@ class DataProcessor(BaseDataProcessor):
tokenize=False,
split_special_tokens=False,
add_special_tokens=False,
return_tensors="pd",
**kwargs,
)
request["prompt_tokens"] = spliced_message

View File

@@ -823,9 +823,7 @@ class RowParallelLinear(LinearBase):
self.hidden_size = fd_config.model_config.hidden_size
self.head_dim = fd_config.model_config.head_dim
self.split_token = (
self.ep_size > 1
and self.tp_size > 1
and fd_config.parallel_config.ep_tp_strategy == "all_to_all"
fd_config.parallel_config.use_sequence_parallel_moe
and layer_id >= fd_config.model_config.moe_layer_start_index
and layer_id < fd_config.model_config.num_hidden_layers
)
@@ -853,7 +851,7 @@ class RowParallelLinear(LinearBase):
self.quant_method.create_weights(
self,
split_axis=0,
output_dim=False,
output_dim=None if self.split_token else False,
weight_loader=(
self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
),
@@ -877,7 +875,7 @@ class RowParallelLinear(LinearBase):
paddle.distributed.alltoall(out, x, group=self.tp_group)
out.reshape_([self.tp_size, -1, x.shape[1]])
out = paddle.transpose(out, [1, 0, 2])
out.reshape_([x.shape[0] // self.tp_size, self.hidden_size])
out.reshape_([x.shape[0] // self.tp_size, self.input_size])
return out
def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
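
A NumPy sketch of the reshape/transpose bookkeeping that follows the alltoall above; shapes are toy values and the collective itself is simulated locally, so this only illustrates how one rank reassembles its token slice.

```python
import numpy as np

tp_size, tokens_per_rank, partial_hidden = 4, 2, 3
# recv[r] stands for peer rank r's partial hidden slice for *this* rank's token chunk.
recv = [np.full((tokens_per_rank, partial_hidden), r, dtype=np.float32) for r in range(tp_size)]

out = np.concatenate(recv, axis=0)                        # what alltoall leaves in `out`
out = out.reshape(tp_size, tokens_per_rank, partial_hidden)
out = out.transpose(1, 0, 2)                              # tokens first, then source rank
out = out.reshape(tokens_per_rank, tp_size * partial_hidden)
# Each local token row now concatenates the slices from all ranks, i.e. the final shape
# is (token_num // tp_size, input_size), matching the reshape_ call above.
print(out.shape)  # (2, 12)
```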

View File

@@ -137,7 +137,6 @@ class FusedMoE(nn.Layer):
self.ep_size = fd_config.parallel_config.expert_parallel_size
self.ep_rank = fd_config.parallel_config.expert_parallel_rank
self.tp_group = fd_config.parallel_config.tp_group
self.ep_tp_strategy = self.fd_config.parallel_config.ep_tp_strategy
# NOTE(Zhenyu Li): just supports tp_size = 1 when ep_size > 1 in MOE now.
if self.ep_size > 1:
self.tp_size = 1
@@ -582,20 +581,18 @@ class FusedMoE(nn.Layer):
Forward split allgather function.
"""
token_num = x.shape[0]
tp_size = self.fd_config.parallel_config.tensor_parallel_size
tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
token_num_per_rank = (token_num + tp_size - 1) // tp_size
token_num_per_rank = (token_num + self.tp_size - 1) // self.tp_size
# AllGather will hang when the data shapes on multi-ranks are different!
part_x = paddle.zeros(shape=[token_num_per_rank, x.shape[1]], dtype=x.dtype)
start_offset = tp_rank * token_num_per_rank
end_offset = (tp_rank + 1) * token_num_per_rank
start_offset = self.tp_rank * token_num_per_rank
end_offset = (self.tp_rank + 1) * token_num_per_rank
if start_offset >= token_num:
start_offset = token_num
if end_offset > token_num:
end_offset = token_num
part_x[: (end_offset - start_offset), :] = x[start_offset:end_offset, :]
out = self.quant_method.apply(self, part_x, gate)
multi_outs = paddle.zeros([token_num_per_rank * tp_size, x.shape[1]], dtype=x.dtype)
multi_outs = paddle.zeros([token_num_per_rank * self.tp_size, x.shape[1]], dtype=x.dtype)
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
out = multi_outs[:token_num, :]
return out
@@ -612,8 +609,12 @@ class FusedMoE(nn.Layer):
"""
token_num = x.shape[0]
tp_size = self.fd_config.parallel_config.tensor_parallel_size
if self.ep_size > 1 and tp_size > 1 and self.ep_tp_strategy == "all_reduce" and token_num >= tp_size:
if (
self.ep_size > 1
and self.tp_size > 1
and (not self.fd_config.parallel_config.use_sequence_parallel_moe)
and token_num >= self.tp_size
):
out = self.forward_split_allgather(x, gate)
else:
out = self.quant_method.apply(self, x, gate)
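
A NumPy sketch of the padding rule in forward_split_allgather above: every rank must contribute a tensor of identical shape or the all_gather would hang, so ragged tails are zero-padded and the gathered result is truncated back to token_num (the expert computation is replaced here by an identity stand-in).

```python
import numpy as np

token_num, tp_size, hidden = 7, 4, 2
x = np.arange(token_num * hidden, dtype=np.float32).reshape(token_num, hidden)
per_rank = (token_num + tp_size - 1) // tp_size          # ceil division -> 2

gathered = []
for rank in range(tp_size):
    start = min(rank * per_rank, token_num)
    end = min((rank + 1) * per_rank, token_num)
    part = np.zeros((per_rank, hidden), dtype=x.dtype)   # fixed shape on every rank
    part[: end - start] = x[start:end]                   # rank-local slice (may be empty)
    gathered.append(part)                                # stand-in for expert MLP + all_gather

out = np.concatenate(gathered, axis=0)[:token_num]       # drop the padding rows
assert np.array_equal(out, x)
```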

View File

@@ -100,27 +100,19 @@ class RMSNorm(nn.Layer):
self.begin_norm_axis: int = begin_norm_axis
self.layer_id = layer_id
parallel_config = self.fd_config.parallel_config
self.ep_size = parallel_config.expert_parallel_size
self.tp_size = parallel_config.tensor_parallel_size
self.tp_rank = parallel_config.tensor_parallel_rank
self.tp_group = parallel_config.tp_group
self.ep_tp_strategy = parallel_config.ep_tp_strategy
self.moe_layer_start_index = self.fd_config.model_config.moe_layer_start_index
self.ep_size = self.fd_config.parallel_config.expert_parallel_size
self.tp_size = self.fd_config.parallel_config.tensor_parallel_size
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
self.tp_group = self.fd_config.parallel_config.tp_group
is_input_norm = prefix.endswith(".input_layernorm")
is_last_norm = prefix.endswith(".norm")
self.split_x = (
self.ep_size > 1
and self.tp_size > 1
and self.ep_tp_strategy == "all_to_all"
and self.layer_id == self.moe_layer_start_index
self.fd_config.parallel_config.use_sequence_parallel_moe
and self.layer_id == self.fd_config.model_config.moe_layer_start_index
and is_input_norm
)
self.allgather_out = (
self.ep_size > 1
and self.tp_size > 1
and self.ep_tp_strategy == "all_to_all"
and ((self.layer_id > self.moe_layer_start_index and is_input_norm) or is_last_norm)
self.allgather_out = self.fd_config.parallel_config.use_sequence_parallel_moe and (
(self.layer_id > self.fd_config.model_config.moe_layer_start_index and is_input_norm) or is_last_norm
)
self.init_weight()
@@ -193,6 +185,7 @@ class RMSNorm(nn.Layer):
x,
residual_input: Optional[paddle.Tensor] = None,
forward_meta: Optional[ForwardMeta] = None,
external_rmsnorm: Optional[Callable] = None,
) -> paddle.Tensor:
"""
Defines the forward computation of the layer.
@@ -215,37 +208,46 @@ class RMSNorm(nn.Layer):
if residual_input is not None:
residual_input_dtype = residual_input.dtype
residual_input = residual_input.astype(self.weight.dtype)
if current_platform.is_gcu():
if residual_input is None:
norm_out = rms_norm(x, self.weight, self.eps)
return norm_out.astype(x_dtype)
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
else:
norm_out = self.norm_func(
x,
norm_weight=self.weight,
norm_bias=None,
epsilon=self.eps,
begin_norm_axis=self.begin_norm_axis,
bias=self.bias,
residual=residual_input,
quant_scale=(-1 if self.quant_scale is None else self.quant_scale),
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
out = norm_out[0].astype(x_dtype)
residual_out = norm_out[1].astype(residual_input_dtype) if residual_input is not None else None
if self.split_x:
residual_out = self.split(residual_out)
if self.allgather_out:
out = self.allgather(out, forward_meta.ids_remove_padding.shape[0])
if residual_input is None:
return out
residual_out = x
if external_rmsnorm is None:
if current_platform.is_gcu():
if residual_input is None:
norm_out = rms_norm(x, self.weight, self.eps)
return norm_out.astype(x_dtype), residual_out
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
else:
norm_out = self.norm_func(
x,
norm_weight=self.weight,
norm_bias=None,
epsilon=self.eps,
begin_norm_axis=self.begin_norm_axis,
bias=self.bias,
residual=residual_input,
quant_scale=(-1 if self.quant_scale is None else self.quant_scale),
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
else:
return out, residual_out
if residual_input is not None:
x = x + residual_input
norm_out = external_rmsnorm(x, self.weight, self.eps), x
out = norm_out[0].astype(x_dtype)
if residual_input is not None:
residual_out = norm_out[1].astype(residual_input_dtype)
if self.split_x:
assert residual_out is not None
residual_out = self.split(residual_out)
if self.allgather_out:
assert forward_meta is not None
out = self.allgather(out, forward_meta.ids_remove_padding.shape[0])
return out, residual_out
class LayerNorm(nn.Layer):
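
A minimal sketch of a callable that could be passed as `external_rmsnorm`: per the branch above it receives `(x, weight, eps)` and returns only the normalized tensor, because the residual add is performed by the caller before invoking it. The implementation below is a naive reference, not the kernel FastDeploy actually plugs in.

```python
import paddle

def naive_rmsnorm(x: paddle.Tensor, weight: paddle.Tensor, eps: float) -> paddle.Tensor:
    # Plain RMSNorm computed in fp32 for stability, then cast back to the input dtype.
    x_fp32 = x.astype("float32")
    variance = x_fp32.pow(2).mean(axis=-1, keepdim=True)
    return (x_fp32 * paddle.rsqrt(variance + eps)).astype(x.dtype) * weight
```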

View File

@@ -130,7 +130,9 @@ class BlockWiseFP8LinearMethod(QuantMethodBase):
dtype="float32",
is_bias=False,
)
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
extra_weight_attrs["output_dim"] = (
not extra_weight_attrs["output_dim"] if extra_weight_attrs["output_dim"] is not None else None
)
extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
set_weight_attrs(

View File

@@ -269,7 +269,7 @@ def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfi
if fd_config.parallel_config.tensor_parallel_size > 1:
no_tp_action_keys = copy.deepcopy(num_local_ffn_keys)
if fd_config.parallel_config.ep_tp_strategy == "all_to_all":
if fd_config.parallel_config.use_sequence_parallel_moe:
for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
k = f"ernie.layers.{i}.self_attn.o_proj.weight"
if k in weight_list:

View File

@@ -271,6 +271,7 @@ class DeepseekV3MLAAttention(nn.Layer):
input_size=self.num_attention_heads * self.v_head_dim,
output_size=self.hidden_size,
with_bias=False,
layer_id=layer_id,
)
self.kv_b_proj_bmm = KVBatchLinear(
@@ -344,13 +345,13 @@ class DeepseekV3MLAAttention(nn.Layer):
[self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
)
query = self.q_a_layernorm(query)
query = self.q_a_layernorm(query)[0]
query = self.q_b_proj(query)
query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
compressed_kv = self.kv_a_layernorm(compressed_kv)
compressed_kv = self.kv_a_layernorm(compressed_kv)[0]
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
@@ -479,6 +480,7 @@ class DeepSeekV3DecoderLayer(nn.Layer):
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.input_layernorm",
layer_id=layer_id,
)
self.post_attention_layernorm = RMSNorm(
@@ -486,6 +488,7 @@ class DeepSeekV3DecoderLayer(nn.Layer):
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.post_attention_layernorm",
layer_id=layer_id,
)
def load_state_dict(self, state_dict):
@@ -504,11 +507,9 @@ class DeepSeekV3DecoderLayer(nn.Layer):
mask_encoder_batch: paddle.Tensor,
):
""" """
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states, residual = self.input_layernorm(
hidden_states, residual_input=residual, forward_meta=forward_meta
)
hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)
@@ -588,8 +589,7 @@ class DeepSeekV3Model(nn.Layer):
position_ids,
mask_encoder_batch,
)
hidden_states = hidden_states + residual
out = self.norm(hidden_states)
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
return out

View File

@@ -312,6 +312,7 @@ class Ernie4_5_DecoderLayer(nn.Layer):
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.post_attention_layernorm",
layer_id=layer_id,
)
def load_state_dict(self, state_dict):
@@ -329,18 +330,9 @@ class Ernie4_5_DecoderLayer(nn.Layer):
hidden_states: paddle.Tensor,
residual: paddle.Tensor = None,
):
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(
hidden_states,
forward_meta=forward_meta,
)
else:
hidden_states, residual = self.input_layernorm(
hidden_states,
residual,
forward_meta=forward_meta,
)
hidden_states, residual = self.input_layernorm(
hidden_states, residual_input=residual, forward_meta=forward_meta
)
hidden_states = self.self_attn(
hidden_states=hidden_states,
@@ -350,7 +342,6 @@ class Ernie4_5_DecoderLayer(nn.Layer):
hidden_states, residual = self.post_attention_layernorm(
hidden_states,
residual,
forward_meta=forward_meta,
)
hidden_states = self.mlp(hidden_states)
@@ -455,9 +446,7 @@ class Ernie4_5_Model(nn.Layer):
for i in range(self.num_layers):
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
hidden_states = hidden_states + residual
out = self.norm(hidden_states, forward_meta=forward_meta)
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
if current_platform.is_iluvatar() and forward_meta.attn_backend.mixed:
out = forward_meta.attn_backend.reverse_transpose(out)

View File

@@ -318,7 +318,7 @@ class Ernie4_5_MTPModel(nn.Layer):
"""
inputs_embedding = self.embed_tokens(ids_remove_padding=ids_remove_padding)
inputs_embedding = paddle.concat(
[self.enorm(inputs_embedding), self.hnorm(previous_hidden_states)],
[self.enorm(inputs_embedding)[0], self.hnorm(previous_hidden_states)[0]],
axis=-1,
)
hidden_states = self.eh_proj(inputs_embedding)
@@ -326,9 +326,7 @@ class Ernie4_5_MTPModel(nn.Layer):
for i in range(self.num_layers):
hidden_states, residual = self.mtp_block[i](forward_meta, hidden_states, residual)
hidden_states = hidden_states + residual
hidden_states = self.norm(hidden_states)
hidden_states = self.norm(hidden_states, residual)[0]
return hidden_states

View File

@@ -358,6 +358,7 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.input_layernorm",
layer_id=layer_id,
)
self.post_attention_layernorm = RMSNorm(
@@ -365,6 +366,7 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.post_attention_layernorm",
layer_id=layer_id,
)
def load_state_dict(self, state_dict):
@@ -380,11 +382,9 @@ class Ernie4_5_VLDecoderLayer(nn.Layer):
residual: paddle.Tensor = None,
vl_moe_meta: VLMoEMeta = None,
):
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states, residual = self.input_layernorm(
hidden_states, residual_input=residual, forward_meta=forward_meta
)
hidden_states = self.self_attn(
hidden_states=hidden_states,
@@ -546,8 +546,7 @@ class Ernie4_5_VLModel(nn.Layer):
vl_moe_meta,
)
hidden_states = hidden_states + residual
out = self.norm(hidden_states)
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
return out

View File

@@ -194,6 +194,7 @@ class Glm4MoeAttention(nn.Layer):
prefix=f"{prefix}.o_proj",
input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim,
output_size=fd_config.model_config.hidden_size,
layer_id=layer_id,
)
self.attn = Attention(
@@ -229,8 +230,8 @@ class Glm4MoeAttention(nn.Layer):
if self.use_qk_norm:
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
q = self.q_norm(q.reshape([-1, self.num_heads, self.head_dim])).reshape(q.shape)
k = self.k_norm(k.reshape([-1, self.num_kv_heads, self.head_dim])).reshape(k.shape)
q = self.q_norm(q.reshape([-1, self.num_heads, self.head_dim]))[0].reshape(q.shape)
k = self.k_norm(k.reshape([-1, self.num_kv_heads, self.head_dim]))[0].reshape(k.shape)
qkv_out = paddle.concat([q, k, v], axis=-1)
atten_out = self.attn(
@@ -275,6 +276,7 @@ class Glm4MoeDecoderLayer(nn.Layer):
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.input_layernorm",
layer_id=layer_id,
)
self.post_attention_layernorm = RMSNorm(
@@ -282,6 +284,7 @@ class Glm4MoeDecoderLayer(nn.Layer):
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.post_attention_layernorm",
layer_id=layer_id,
)
def forward(
@@ -291,11 +294,9 @@ class Glm4MoeDecoderLayer(nn.Layer):
residual: paddle.Tensor = None,
):
""" """
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states, residual = self.input_layernorm(
hidden_states, residual_input=residual, forward_meta=forward_meta
)
hidden_states = self.self_attn(
hidden_states=hidden_states,
@@ -366,9 +367,8 @@ class Glm4MoeModel(nn.Layer):
for i in range(self.num_layers):
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
hidden_states = hidden_states + residual
out = self.norm(hidden_states)
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
return out

View File

@@ -150,6 +150,7 @@ class GptOssDecoderLayer(nn.Layer):
hidden_size=hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.post_attention_layernorm",
layer_id=layer_id,
)
self.mlp = GptOssMoe(fd_config, layer_id, prefix=f"{prefix}.mlp")
@@ -159,11 +160,9 @@ class GptOssDecoderLayer(nn.Layer):
hidden_states: paddle.Tensor,
residual: paddle.Tensor = None,
):
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states, residual = self.input_layernorm(
hidden_states, residual_input=residual, forward_meta=forward_meta
)
hidden_states = self.self_attn(
hidden_states=hidden_states,
@@ -213,9 +212,8 @@ class GptOssModel(nn.Layer):
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
hidden_states = hidden_states + residual
hidden_states = self.norm(hidden_states)
hidden_states = self.norm(hidden_states, residual)[0]
return hidden_states

View File

@@ -20,7 +20,6 @@ from typing import Dict, Optional, Union
import numpy as np
import paddle
import paddle.nn as nn
from paddleformers.transformers import PretrainedModel
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta
@@ -104,9 +103,7 @@ class PaddleOCRVLModel(nn.Layer):
for i in range(self.num_layers):
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
hidden_states = hidden_states + residual
out = self.norm(hidden_states)
out = self.norm(hidden_states, residual)[0]
return out
@@ -257,94 +254,3 @@ class PaddleOCRVLForConditionalGeneration(ModelForCasualLM):
)
return hidden_states
class PaddleOCRVLPretrainedModel(PretrainedModel):
config_class = FDConfig
def _init_weight(self, layer):
"""
_init_weight
"""
return None
@classmethod
def arch_name(self):
return "PaddleOCRVLForConditionalGeneration"
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
from fastdeploy.model_executor.models.utils import WeightMeta
weight_infos = [
WeightMeta(
f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.weight",
True,
tsm.GQA,
),
WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.weight", False),
WeightMeta(
f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.weight",
True,
tsm.PairFused,
),
WeightMeta(f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.weight", False),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.TEXT_EXPERT_ID}}}.up_gate_proj.weight",
True,
tsm.PairFused,
),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.TEXT_EXPERT_ID}}}.down_proj.weight",
False,
),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.IMG_EXPERT_ID}}}.up_gate_proj.weight",
True,
tsm.PairFused,
),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.IMG_EXPERT_ID}}}.down_proj.weight",
False,
),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.weight",
True,
tsm.PairFused,
),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.weight",
False,
),
WeightMeta(
f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.weight",
False,
),
WeightMeta(".embed_tokens.weight", False),
WeightMeta("lm_head.weight", True),
]
weight_vison = [
# resampler_model
WeightMeta("ernie.resampler_model.spatial_linear.0.weight", False),
WeightMeta("resampler_model.spatial_linear.0.weight", False),
# vision
WeightMeta(
f"vision_model.blocks.{{{layerid.LAYER_ID}}}.attn.proj.weight",
False,
),
WeightMeta(f"vision_model.blocks.{{{layerid.LAYER_ID}}}.mlp.fc2.weight", False),
WeightMeta(f"vision_model.blocks.{{{layerid.LAYER_ID}}}.mlp.fc1.weight", True),
WeightMeta(f"vision_model.blocks.{{{layerid.LAYER_ID}}}.mlp.fc1.bias", True),
WeightMeta(
f"vision_model.blocks.{{{layerid.LAYER_ID}}}.attn.qkv.weight",
True,
tsm.GQA,
),
WeightMeta(
f"vision_model.blocks.{{{layerid.LAYER_ID}}}.attn.qkv.bias",
True,
tsm.GQA,
),
]

View File

@@ -176,6 +176,7 @@ class Qwen2DecoderLayer(nn.Layer):
hidden_size=fd_config.model_config.hidden_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.post_attention_layernorm",
layer_id=layer_id,
)
def load_state_dict(self, state_dict):
@@ -193,11 +194,9 @@ class Qwen2DecoderLayer(nn.Layer):
):
""" """
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states, residual = self.input_layernorm(
hidden_states, residual_input=residual, forward_meta=forward_meta
)
hidden_states = self.self_attn(
hidden_states=hidden_states,
@@ -285,9 +284,7 @@ class Qwen2Model(nn.Layer):
for i in range(self.num_layers):
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
hidden_states = hidden_states + residual
out = self.norm(hidden_states)
out = self.norm(hidden_states, residual)[0]
return out

View File

@@ -124,9 +124,7 @@ class Qwen2_5_VLModel(nn.Layer):
residual,
)
hidden_states = hidden_states + residual
out = self.norm(hidden_states)
out = self.norm(hidden_states, residual)[0]
return out
@@ -262,21 +260,6 @@ class Qwen2_5_VLForConditionalGeneration(ModelForCasualLM):
return logits
def empty_input_forward(self):
"""
empty_input_forward
"""
fake_hidden_states = paddle.empty(
shape=[0, self.fd_config.model_config.hidden_size],
dtype=paddle.get_default_dtype(),
)
for i in range(
self.fd_config.model_config.moe_layer_start_index,
self.fd_config.model_config.num_hidden_layers,
):
self.ernie.layers[i].mlp.text_fused_moe(fake_hidden_states)
self.ernie.layers[i].mlp.image_fused_moe(fake_hidden_states)
def get_input_embeddings(
self,
ids_remove_padding: paddle.Tensor,

View File

@@ -66,6 +66,7 @@ class Qwen3Attention(nn.Layer):
prefix=f"{prefix}.o_proj",
input_size=fd_config.model_config.head_dim * fd_config.model_config.num_attention_heads,
output_size=fd_config.model_config.hidden_size,
layer_id=layer_id,
)
self.attn = Attention(
@@ -114,11 +115,11 @@ class Qwen3Attention(nn.Layer):
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
q_by_head = q.reshape([*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim])
q_by_head = self.q_norm(q_by_head)
q_by_head = self.q_norm(q_by_head)[0]
q = q_by_head.reshape(q.shape)
k_by_head = k.reshape([*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim])
k_by_head = self.k_norm(k_by_head)
k_by_head = self.k_norm(k_by_head)[0]
k = k_by_head.reshape(k.shape)
qkv_out = paddle.concat([q, k, v], axis=-1)
@@ -216,9 +217,7 @@ class Qwen3Model(nn.Layer):
for i in range(self.num_layers):
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
hidden_states = hidden_states + residual
out = self.norm(hidden_states)
out = self.norm(hidden_states, residual)[0]
return out

View File

@@ -167,15 +167,17 @@ class Qwen3DecoderLayer(nn.Layer):
self.input_layernorm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=1e-6,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.input_layernorm",
layer_id=layer_id,
)
self.post_attention_layernorm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=1e-6,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.post_attention_layernorm",
layer_id=layer_id,
)
def load_state_dict(self, state_dict):
@@ -192,11 +194,9 @@ class Qwen3DecoderLayer(nn.Layer):
residual: paddle.Tensor = None,
):
""" """
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states, residual = self.input_layernorm(
hidden_states, residual_input=residual, forward_meta=forward_meta
)
hidden_states = self.self_attn(
hidden_states=hidden_states,
@@ -251,7 +251,7 @@ class Qwen3MoeModel(nn.Layer):
self.norm = RMSNorm(
fd_config,
hidden_size=fd_config.model_config.hidden_size,
eps=1e-6,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm",
)
@@ -275,16 +275,14 @@ class Qwen3MoeModel(nn.Layer):
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
):
""" """
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
hidden_states = hidden_states + residual
out = self.norm(hidden_states)
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
return out
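
The same change recurs in every model file in this PR: the trailing `hidden_states + residual` is folded into the final norm call, which now returns `(out, residual_out)` and the model keeps element `[0]`; moving the add inside RMSNorm also lets the layer run its allgather_out step on the final output when sequence-parallel MoE is active. A NumPy sketch of the equivalence:

```python
import numpy as np

def rmsnorm(x, weight, eps=1e-6):
    return x / np.sqrt((x * x).mean(-1, keepdims=True) + eps) * weight

def fused_add_rmsnorm(x, residual, weight, eps=1e-6):
    s = x + residual                       # residual add now happens inside the layer
    return rmsnorm(s, weight, eps), s      # (out, residual_out), mirroring RMSNorm.forward

h, res, w = np.random.rand(4, 8), np.random.rand(4, 8), np.ones(8)
np.testing.assert_allclose(fused_add_rmsnorm(h, res, w)[0], rmsnorm(h + res, w))
```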

View File

@@ -660,6 +660,11 @@ def parse_args():
action="store_true",
help="enable custom all-reduce",
)
parser.add_argument(
"--disable_sequence_parallel_moe",
action="store_true",
help="disable sequence parallel moe",
)
parser.add_argument("--splitwise_role", type=str, default="mixed", help="splitwise role")
parser.add_argument(
"--tensor_parallel_size",

View File

@@ -360,6 +360,7 @@ export BKCL_RDMA_VERBS=1
export enable_expert_parallel=1
export enable_tensor_parallel=1
export disable_sequence_parallel_moe=1
python -m pytest -s --timeout=600 tests/ci_use/XPU_45T/run_ep.py
ep_exit_code=$?
@@ -373,6 +374,7 @@ unset XSHMEM_QP_NUM_PER_RANK
unset BKCL_RDMA_VERBS
unset enable_expert_parallel
unset enable_tensor_parallel
unset disable_sequence_parallel_moe
stop_processes
if [ ${ep_exit_code} -ne 0 ]; then
@@ -400,7 +402,6 @@ export BKCL_RDMA_VERBS=1
export enable_expert_parallel=1
export enable_tensor_parallel=1
export FD_EP_TP_STRATEGY=all_to_all
python -m pytest -s --timeout=600 tests/ci_use/XPU_45T/run_ep.py
ep_exit_code=$?
@@ -414,7 +415,6 @@ unset XSHMEM_QP_NUM_PER_RANK
unset BKCL_RDMA_VERBS
unset enable_expert_parallel
unset enable_tensor_parallel
unset FD_EP_TP_STRATEGY
stop_processes
if [ ${ep_exit_code} -ne 0 ]; then

View File

@@ -26,6 +26,7 @@ def test_fd_ep():
enable_expert_parallel = strtobool(os.getenv("enable_expert_parallel", "1"))
enable_tensor_parallel = strtobool(os.getenv("enable_tensor_parallel", "0"))
disable_sequence_parallel_moe = strtobool(os.getenv("disable_sequence_parallel_moe", "0"))
print(f"enable_expert_parallel: {enable_expert_parallel}, enable_tensor_parallel: {enable_tensor_parallel}")
if enable_expert_parallel:
if enable_tensor_parallel:
@@ -47,6 +48,7 @@ def test_fd_ep():
enable_expert_parallel=enable_expert_parallel,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=data_parallel_size,
disable_sequence_parallel_moe=disable_sequence_parallel_moe,
max_model_len=8192,
quantization="wint4",
engine_worker_queue_port=engine_worker_queue_port,