polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

View File

@@ -25,12 +25,15 @@ from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import \
support_graph_optimization
from fastdeploy.model_executor.graph_optimization.decorator import (
support_graph_optimization,
)
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear, RowParallelLinear)
MergedColumnParallelLinear,
RowParallelLinear,
)
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
@@ -39,8 +42,7 @@ from fastdeploy.model_executor.models.qwen3 import Qwen3Attention
class Qwen3MLP(nn.Layer):
"""
"""
""" """
def __init__(
self,
@@ -74,14 +76,12 @@ class Qwen3MLP(nn.Layer):
)
def load_state_dict(self, state_dict):
"""
"""
""" """
self.up_gate_proj.load_state_dict(state_dict)
self.down_proj.load_state_dict(state_dict)
def forward(self, x):
"""
"""
""" """
gate_up_out = self.up_gate_proj(x)
act_out = self.act_fn(gate_up_out)
down_out = self.down_proj(act_out)
@@ -89,8 +89,7 @@ class Qwen3MLP(nn.Layer):
class Qwen3DecoderLayer(nn.Layer):
"""
"""
""" """
def __init__(
self,
@@ -98,7 +97,7 @@ class Qwen3DecoderLayer(nn.Layer):
prefix: str = "",
) -> None:
super().__init__()
layer_id = int(prefix.split(sep='.')[-1])
layer_id = int(prefix.split(sep=".")[-1])
self.self_attn = Qwen3Attention(
fd_config=fd_config,
@@ -106,24 +105,24 @@ class Qwen3DecoderLayer(nn.Layer):
prefix=f"{prefix}.self_attn",
)
weight_key_map = {
"gate_weight_key":
f"{prefix}.mlp.gate.weight",
"up_gate_proj_expert_weight_key":
f"{prefix}.mlp.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key":
f"{prefix}.mlp.experts.{{}}.down_proj.weight",
"gate_weight_key": f"{prefix}.mlp.gate.weight",
"up_gate_proj_expert_weight_key": f"{prefix}.mlp.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.mlp.experts.{{}}.down_proj.weight",
}
if (fd_config.model_config.moe_num_experts is not None
and layer_id >= fd_config.model_config.moe_layer_start_index):
if (
fd_config.model_config.moe_num_experts is not None
and layer_id >= fd_config.model_config.moe_layer_start_index
):
self.mlp = FusedMoE(fd_config,
moe_intermediate_size=fd_config.model_config.
moe_intermediate_size,
num_experts=fd_config.model_config.moe_num_experts,
top_k=fd_config.model_config.moe_topk,
layer_idx=layer_id,
weight_key_map=weight_key_map)
self.mlp = FusedMoE(
fd_config,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
num_experts=fd_config.model_config.moe_num_experts,
top_k=fd_config.model_config.moe_topk,
layer_idx=layer_id,
weight_key_map=weight_key_map,
)
else:
self.mlp = Qwen3MLP(
fd_config,
@@ -145,8 +144,7 @@ class Qwen3DecoderLayer(nn.Layer):
)
def load_state_dict(self, state_dict):
"""
"""
""" """
self.self_attn.load_state_dict(state_dict)
self.mlp.load_state_dict(state_dict)
self.input_layernorm.load_state_dict(state_dict)
@@ -158,14 +156,12 @@ class Qwen3DecoderLayer(nn.Layer):
hidden_states: paddle.Tensor,
residual: paddle.Tensor = None,
):
"""
"""
""" """
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
hidden_states=hidden_states,
@@ -173,8 +169,7 @@ class Qwen3DecoderLayer(nn.Layer):
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
@@ -183,8 +178,7 @@ class Qwen3DecoderLayer(nn.Layer):
@support_graph_optimization
class Qwen3MoeModel(nn.Layer):
"""
"""
""" """
def __init__(
self,
@@ -209,12 +203,15 @@ class Qwen3MoeModel(nn.Layer):
prefix=(f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens"),
)
self.layers = nn.LayerList([
Qwen3DecoderLayer(
fd_config,
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}")
for i in range(self.num_layers)
])
self.layers = nn.LayerList(
[
Qwen3DecoderLayer(
fd_config,
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}",
)
for i in range(self.num_layers)
]
)
self.norm = RMSNorm(
fd_config,
@@ -243,15 +240,13 @@ class Qwen3MoeModel(nn.Layer):
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
):
"""
"""
""" """
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
residual = None
for i in range(self.num_layers):
hidden_states, residual = self.layers[i](forward_meta,
hidden_states, residual)
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
hidden_states = hidden_states + residual
out = self.norm(hidden_states)
@@ -284,8 +279,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
@classmethod
def name(self):
"""
"""
""" """
return "Qwen3MoeForCausalLM"
@paddle.no_grad()
@@ -302,11 +296,10 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
self.lm_head.load_state_dict(state_dict)
def compute_logits(self, hidden_states: paddle.Tensor):
"""
"""
""" """
logits = self.lm_head(hidden_states)
logits = paddle.cast(logits, paddle.float32)
logits[:, self.ori_vocab_size:] = -float("inf")
logits[:, self.ori_vocab_size :] = -float("inf")
return logits
@@ -315,10 +308,8 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
):
"""
"""
hidden_states = self.model(ids_remove_padding=ids_remove_padding,
forward_meta=forward_meta)
""" """
hidden_states = self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
return hidden_states
@@ -340,8 +331,7 @@ class Qwen3MoePretrainedModel(PretrainedModel):
def _get_tensor_parallel_mappings(cls, config, is_split=True):
# TODO not support TP split now, next PR will support TP.
from paddleformers.transformers.conversion_utils import \
split_or_merge_func
from paddleformers.transformers.conversion_utils import split_or_merge_func
fn = split_or_merge_func(
is_split=is_split,
@@ -357,45 +347,33 @@ class Qwen3MoePretrainedModel(PretrainedModel):
"lm_head.weight": partial(fn, is_column=True),
# Row Linear
"embed_tokens.weight": partial(fn, is_column=False),
"layers.0.self_attn.o_proj.weight": partial(fn,
is_column=False),
"layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
}
# Column Linear
config.fuse_attention_qkv = False
if config.fuse_attention_qkv:
base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(
fn, is_column=True)
base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
else:
base_actions["layers.0.self_attn.q_proj.weight"] = partial(
fn, is_column=True)
base_actions["layers.0.self_attn.q_proj.bias"] = partial(
fn, is_column=True)
base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
# if we have enough num_key_value_heads to split, then split it.
if config.num_key_value_heads % config.tensor_parallel_degree == 0:
base_actions["layers.0.self_attn.k_proj.weight"] = partial(
fn, is_column=True)
base_actions["layers.0.self_attn.v_proj.weight"] = partial(
fn, is_column=True)
base_actions["layers.0.self_attn.k_proj.bias"] = partial(
fn, is_column=True)
base_actions["layers.0.self_attn.v_proj.bias"] = partial(
fn, is_column=True)
base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True)
for key, action in base_actions.items():
if "layers.0." in key:
for i in range(num_layers):
final_actions[key.replace("layers.0.",
f"layers.{i}.")] = action
final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
final_actions[key] = action
base_actions = {
"layers.0.mlp.experts.0.gate_proj.weight":
partial(fn, is_column=True),
"layers.0.mlp.experts.0.down_proj.weight":
partial(fn, is_column=False),
"layers.0.mlp.experts.0.up_proj.weight":
partial(fn, is_column=True),
"layers.0.mlp.experts.0.gate_proj.weight": partial(fn, is_column=True),
"layers.0.mlp.experts.0.down_proj.weight": partial(fn, is_column=False),
"layers.0.mlp.experts.0.up_proj.weight": partial(fn, is_column=True),
}
for key, action in base_actions.items():
@@ -413,11 +391,8 @@ class Qwen3MoePretrainedModel(PretrainedModel):
elif isinstance(config.moe_num_experts, int):
num_experts = config.moe_num_experts
else:
raise ValueError(
f"Not support type of num_experts [{type(config.moe_num_experts)}]"
)
raise ValueError(f"Not support type of num_experts [{type(config.moe_num_experts)}]")
mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers,
num_experts)
mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers, num_experts)
return mappings