commit 20839abccf (parent 91dc87f1c5)
Author: bukejiyu
Date: 2025-08-06 14:45:27 +08:00
Committed via GitHub
30 changed files with 1361 additions and 1087 deletions

View File

@@ -117,13 +117,12 @@ class DeepSeekV3MoE(nn.Layer):
self.tp_size = fd_config.parallel_config.tensor_parallel_size
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
"gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
}
self.fused_moe = FusedMoE(
self.experts = FusedMoE(
fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
@@ -137,6 +136,16 @@ class DeepSeekV3MoE(nn.Layer):
weight_key_map=weight_key_map,
)
self.gate = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.gate",
input_size=fd_config.model_config.hidden_size,
output_size=fd_config.model_config.n_routed_experts,
with_bias=False,
skip_quant=True,
weight_dtype="float32",
)
self.num_shared_experts = fd_config.model_config.n_shared_experts
shared_experts_intermediate_size = self.num_shared_experts * fd_config.model_config.moe_intermediate_size
@@ -149,13 +158,14 @@ class DeepSeekV3MoE(nn.Layer):
def load_state_dict(self, state_dict):
""" """
self.fused_moe.load_state_dict(state_dict)
self.gate.load_state_dict(state_dict)
self.experts.load_state_dict(state_dict)
self.shared_experts.load_state_dict(state_dict)
def forward(self, hidden_states: paddle.Tensor):
""" """
shared_experts_out = self.shared_experts(hidden_states)
moe_out = self.fused_moe(hidden_states)
moe_out = self.experts(hidden_states, self.gate)
moe_out = moe_out + shared_experts_out
# We do the TP all-reduce after summing the routed-expert and shared-expert outputs.
if self.tp_size > 1:
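For context, the routing step the new external gate provides can be sketched standalone (plain paddle, synthetic shapes; this only shows where the gate output enters the pipeline, not the actual FusedMoE internals, which also apply the e_score_correction_bias from the weight map above):

import paddle
import paddle.nn.functional as F

hidden_size, n_routed_experts, top_k = 16, 8, 2
hidden_states = paddle.randn([4, hidden_size])               # token activations
gate_weight = paddle.randn([hidden_size, n_routed_experts])  # router weight, kept in float32 (skip_quant=True)

router_logits = paddle.matmul(hidden_states, gate_weight)    # what self.gate(hidden_states) would produce
routing_probs = F.softmax(router_logits, axis=-1)
topk_scores, topk_ids = paddle.topk(routing_probs, k=top_k, axis=-1)
# topk_ids selects the experts each token is dispatched to; FusedMoE fuses this
# dispatch with the expert GEMMs and the final weighted combine.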

View File

@@ -37,6 +37,7 @@ from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
@@ -147,7 +148,7 @@ class Ernie4_5_MoE(nn.Layer):
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
}
self.fused_moe = FusedMoE(
self.experts = FusedMoE(
fd_config=fd_config,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
num_experts=fd_config.model_config.moe_num_experts,
@@ -156,6 +157,16 @@ class Ernie4_5_MoE(nn.Layer):
weight_key_map=weight_key_map,
)
self.gate = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.gate",
input_size=fd_config.model_config.hidden_size,
output_size=fd_config.model_config.moe_num_experts,
with_bias=False,
skip_quant=True,
weight_dtype="float32",
)
self.num_shared_experts = fd_config.model_config.moe_num_shared_experts
if self.num_shared_experts > 0:
shared_experts_hidden_dim = self.num_shared_experts * fd_config.model_config.moe_intermediate_size
@@ -166,12 +177,13 @@ class Ernie4_5_MoE(nn.Layer):
)
def load_state_dict(self, state_dict):
self.fused_moe.load_state_dict(state_dict)
self.gate.load_state_dict(state_dict)
self.experts.load_state_dict(state_dict)
if self.num_shared_experts > 0:
self.shared_experts.load_state_dict(state_dict)
def forward(self, hidden_states: paddle.Tensor):
out = self.fused_moe(hidden_states)
out = self.experts(hidden_states, self.gate)
if self.num_shared_experts > 0:
s_x = self.shared_experts(hidden_states)
out = out + s_x
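A note on the expert weight-key templates used above: the doubled braces survive f-string formatting as a literal {}, leaving a per-expert placeholder that FusedMoE can fill with each expert index when loading, for example:

prefix = "ernie.layers.4.mlp"  # hypothetical prefix, for illustration only
key_template = f"{prefix}.experts.{{}}.up_gate_proj.weight"
print(key_template)            # ernie.layers.4.mlp.experts.{}.up_gate_proj.weight
print(key_template.format(7))  # ernie.layers.4.mlp.experts.7.up_gate_proj.weight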
@@ -435,7 +447,7 @@ class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
self.fd_config.model_config.moe_layer_start_index,
self.fd_config.model_config.num_hidden_layers,
):
self.ernie.layers[i].mlp.fused_moe(fake_hidden_states)
self.ernie.layers[i].mlp.experts(fake_hidden_states, self.ernie.layers[i].mlp.gate)
def forward(
self,

View File

@@ -33,6 +33,7 @@ from fastdeploy.model_executor.graph_optimization.decorator import (
support_graph_optimization,
)
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import ReplicatedLinear
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
@@ -73,6 +74,93 @@ class VLMoEMeta:
fake_hidden_states: Optional[paddle.Tensor] = None
class Ernie4_5_VLMoeBlock(nn.Layer):
def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str, moe_tag: str, expert_id_offset: int) -> None:
super().__init__()
moe_quant_type = ""
if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
if moe_quant_type == "tensor_wise_fp8" or (
moe_quant_type == "block_wise_fp8" and fd_config.model_config.is_quantized
):
weight_key_map = {
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight",
"up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
"down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale",
"up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
"down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale",
}
else:
# wint4/wint8/bfloat16
weight_key_map = {
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
}
moe_intermediate_size = (
fd_config.model_config.moe_intermediate_size[0]
if moe_tag == "Text"
else fd_config.model_config.moe_intermediate_size[1]
)
num_experts = (
fd_config.model_config.moe_num_experts[0]
if moe_tag == "Text"
else fd_config.model_config.moe_num_experts[1]
)
self.experts = FusedMoE(
fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=moe_intermediate_size,
num_experts=num_experts,
expert_id_offset=expert_id_offset,
top_k=fd_config.model_config.moe_k,
layer_idx=layer_id,
moe_tag=moe_tag,
weight_key_map=weight_key_map,
)
self.gate = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.gate",
input_size=fd_config.model_config.hidden_size,
output_size=num_experts,
with_bias=False,
skip_quant=True,
weight_dtype="float32",
weight_key="weight" if moe_tag == "Text" else "weight_1",
)
if moe_tag == "Text":
self.experts.extract_gate_correction_bias = self.extract_gate_correction_bias_text
elif moe_tag == "Image":
self.experts.extract_gate_correction_bias = self.extract_gate_correction_bias_image
def forward(self, hidden_states: paddle.Tensor):
out = self.experts(hidden_states, self.gate)
return out
def extract_gate_correction_bias_text(self, gate_correction_bias_key, state_dict):
"""
extract_gate_correction_bias function.
"""
gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32")
return gate_correction_bias_tensor[0].unsqueeze(0)
def extract_gate_correction_bias_image(self, gate_correction_bias_key, state_dict):
"""
extract_gate_correction_bias function.
"""
gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32")
return gate_correction_bias_tensor[1].unsqueeze(0)
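Both modalities read the same moe_statics entry and slice it by row; a tiny illustration with synthetic shapes:

import paddle

# hypothetical [2, num_experts] tensor stored under "...moe_statics.e_score_correction_bias"
e_score_correction_bias = paddle.arange(2 * 64, dtype="float32").reshape([2, 64])
text_bias = e_score_correction_bias[0].unsqueeze(0)   # shape [1, 64] -> Text experts
image_bias = e_score_correction_bias[1].unsqueeze(0)  # shape [1, 64] -> Image experts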
def load_state_dict(self, state_dict):
self.experts.load_state_dict(state_dict)
self.gate.load_state_dict(state_dict)
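The weight_key override on the gate follows the same sharing scheme: the checkpoint stores both routers under a single .gate prefix, e.g. (layer prefix is hypothetical):

checkpoint_keys = {
    "ernie.layers.28.mlp.gate.weight": "Text router weight",     # used when weight_key="weight"
    "ernie.layers.28.mlp.gate.weight_1": "Image router weight",  # used when weight_key="weight_1"
}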
class Ernie4_5_VLMoE(nn.Layer):
def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None:
super().__init__()
@@ -99,43 +187,10 @@ class Ernie4_5_VLMoE(nn.Layer):
assert text_moe_layer_start_index <= text_moe_layer_end_index
moe_quant_type = ""
if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
if layer_id >= text_moe_layer_start_index and layer_id <= text_moe_layer_end_index:
if moe_quant_type == "tensor_wise_fp8" or (
moe_quant_type == "block_wise_fp8" and fd_config.model_config.is_quantized
):
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight",
"up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
"down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale",
"up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
"down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale",
}
else:
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight",
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
}
self.text_fused_moe = FusedMoE(
fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size[0],
num_experts=fd_config.model_config.moe_num_experts[0],
expert_id_offset=0,
top_k=fd_config.model_config.moe_k,
layer_idx=layer_id,
moe_tag="Text",
weight_key_map=weight_key_map,
self.text_fused_moe = Ernie4_5_VLMoeBlock(
fd_config=fd_config, layer_id=layer_id, prefix=f"{prefix}", moe_tag="Text", expert_id_offset=0
)
self.text_fused_moe.extract_gate_correction_bias = self.extract_gate_correction_bias_text
else:
self.text_fused_moe = Ernie4_5_VLMLP(
fd_config=fd_config,
@@ -146,38 +201,13 @@ class Ernie4_5_VLMoE(nn.Layer):
assert image_moe_layer_start_index <= image_moe_layer_end_index
if layer_id >= image_moe_layer_start_index and layer_id <= image_moe_layer_end_index:
if moe_quant_type == "tensor_wise_fp8" or (
moe_quant_type == "block_wise_fp8" and fd_config.model_config.is_quantized
):
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight_1",
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.quant_weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.quant_weight",
"up_gate_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.weight_scale",
"down_proj_expert_weight_scale_key": f"{prefix}.experts.{{}}.down_proj.weight_scale",
"up_gate_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.up_gate_proj.activation_scale",
"down_proj_expert_in_scale_key": f"{prefix}.experts.{{}}.down_proj.activation_scale",
}
else:
weight_key_map = {
"gate_weight_key": f"{prefix}.gate.weight_1",
"gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
}
self.image_fused_moe = FusedMoE(
self.image_fused_moe = Ernie4_5_VLMoeBlock(
fd_config=fd_config,
reduce_results=False,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size[1],
num_experts=fd_config.model_config.moe_num_experts[1],
expert_id_offset=fd_config.model_config.moe_num_experts[0],
top_k=fd_config.model_config.moe_k,
layer_idx=layer_id,
layer_id=layer_id,
prefix=f"{prefix}",
moe_tag="Image",
weight_key_map=weight_key_map,
expert_id_offset=fd_config.model_config.moe_num_experts[0],
)
self.image_fused_moe.extract_gate_correction_bias = self.extract_gate_correction_bias_image
else:
self.image_fused_moe = Ernie4_5_VLMLP(
fd_config=fd_config,
@@ -195,25 +225,11 @@ class Ernie4_5_VLMoE(nn.Layer):
reduce_results=False,
)
def extract_gate_correction_bias_text(self, gate_correction_bias_key, state_dict):
"""
extract_gate_correction_bias function.
"""
gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32")
return gate_correction_bias_tensor[0].unsqueeze(0)
def extract_gate_correction_bias_image(self, gate_correction_bias_key, state_dict):
"""
extract_gate_correction_bias function.
"""
gate_correction_bias_tensor = get_tensor(state_dict[gate_correction_bias_key]).astype("float32")
return gate_correction_bias_tensor[1].unsqueeze(0)
def load_state_dict(self, state_dict):
self.text_fused_moe.load_state_dict(state_dict)
self.image_fused_moe.load_state_dict(state_dict)
if self.text_fused_moe.moe_use_gate_correction_bias:
state_dict.pop(self.text_fused_moe.gate_correction_bias_key)
if self.text_fused_moe.experts.moe_use_gate_correction_bias:
state_dict.pop(self.text_fused_moe.experts.gate_correction_bias_key)
if self.num_shared_experts > 0:
self.shared_experts.load_state_dict(state_dict)
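Because that correction-bias entry is shared by the Text and Image blocks, it can only be popped after both have loaded; roughly (key name shortened, illustrative only):

state_dict = {"mlp.moe_statics.e_score_correction_bias": "bias tensor"}
_ = state_dict["mlp.moe_statics.e_score_correction_bias"]  # read by the Text block
_ = state_dict["mlp.moe_statics.e_score_correction_bias"]  # read again by the Image block
state_dict.pop("mlp.moe_statics.e_score_correction_bias")  # removed once, afterwards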

View File

@@ -32,6 +32,7 @@ from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
@@ -41,6 +42,47 @@ from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.qwen3 import Qwen3Attention
class Qwen3MoeBlock(nn.Layer):
def __init__(
self,
fd_config: FDConfig,
layer_id: int,
prefix: str = "",
) -> None:
super().__init__()
weight_key_map = {
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
}
self.experts = FusedMoE(
fd_config,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
num_experts=fd_config.model_config.num_experts,
top_k=fd_config.model_config.num_experts_per_tok,
layer_idx=layer_id,
weight_key_map=weight_key_map,
)
self.gate = ReplicatedLinear(
fd_config=fd_config,
prefix=f"{prefix}.gate",
input_size=fd_config.model_config.hidden_size,
output_size=fd_config.model_config.num_experts,
with_bias=False,
skip_quant=True,
weight_dtype="float32",
)
def forward(self, x):
out = self.experts(x, self.gate)
return out
def load_state_dict(self, state_dict):
""" """
self.gate.load_state_dict(state_dict)
self.experts.load_state_dict(state_dict)
class Qwen3MLP(nn.Layer):
""" """
@@ -104,22 +146,13 @@ class Qwen3DecoderLayer(nn.Layer):
layer_id=layer_id,
prefix=f"{prefix}.self_attn",
)
weight_key_map = {
"gate_weight_key": f"{prefix}.mlp.gate.weight",
"up_gate_proj_expert_weight_key": f"{prefix}.mlp.experts.{{}}.up_gate_proj.weight",
"down_proj_expert_weight_key": f"{prefix}.mlp.experts.{{}}.down_proj.weight",
}
if fd_config.model_config.num_experts is not None and layer_id >= fd_config.model_config.moe_layer_start_index:
self.mlp = FusedMoE(
fd_config,
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
num_experts=fd_config.model_config.num_experts,
top_k=fd_config.model_config.num_experts_per_tok,
layer_idx=layer_id,
weight_key_map=weight_key_map,
)
mlp_only_layers = (
[] if not hasattr(fd_config.model_config, "mlp_only_layers") else fd_config.model_config.mlp_only_layers
)
if (layer_id not in mlp_only_layers) and (
fd_config.model_config.num_experts > 0 and (layer_id + 1) % fd_config.model_config.decoder_sparse_step == 0
):
self.mlp = Qwen3MoeBlock(fd_config, layer_id, prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MLP(
fd_config,
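The sparse-layer condition introduced above is easy to check in isolation; a short sketch with made-up config values (decoder_sparse_step and mlp_only_layers mirror the fields used in the condition):

num_experts = 128
decoder_sparse_step = 2    # every 2nd layer becomes sparse in this example
mlp_only_layers = [0]      # layers forced to keep a dense MLP
num_hidden_layers = 8      # hypothetical

moe_layer_ids = [
    layer_id
    for layer_id in range(num_hidden_layers)
    if layer_id not in mlp_only_layers
    and num_experts > 0
    and (layer_id + 1) % decoder_sparse_step == 0
]
print(moe_layer_ids)  # [1, 3, 5, 7]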
@@ -279,6 +312,74 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
""" """
return "Qwen3MoeForCausalLM"
def get_expert_mapping(
self,
) -> list[tuple[str, str, int, str]]:
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
param_gate_up_proj_name="experts.up_gate_proj_",
param_down_proj_name="experts.down_proj_",
num_experts=self.fd_config.model_config.num_experts,
)
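The tuples returned here are consumed by the expert branch of load_weights below. Their exact contents depend on FusedMoE.make_expert_params_mapping; the layout sketched next is only a guess inferred from the (param_name, weight_name, expert_id, shard_id) comment and the replace/loop that uses them, not a verified output:

example_expert_mapping = [
    ("experts.up_gate_proj_", "experts.0.gate_proj.", 0, "gate"),
    ("experts.up_gate_proj_", "experts.0.up_proj.", 0, "up"),
    ("experts.down_proj_", "experts.0.down_proj.", 0, "down"),
    # ...and so on for every expert id up to num_experts - 1
]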
@paddle.no_grad()
def load_weights(self, weights_iterator) -> None:
"""
Load model parameters from a given weights_iterator object.
Args:
weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
"""
from fastdeploy.model_executor.models.utils import default_weight_loader
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("up_gate_proj", "gate_proj", "gate"),
("up_gate_proj", "up_proj", "up"),
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
]
expert_params_mapping = self.get_expert_mapping()
params_dict = dict(self.named_parameters())
for loaded_weight_name, loaded_weight in weights_iterator:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in loaded_weight_name:
continue
if "mlp.experts" in loaded_weight_name:
continue
model_param_name = loaded_weight_name.replace(weight_name, param_name)
if model_param_name not in params_dict:
continue
param = params_dict[model_param_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in loaded_weight_name:
continue
model_param_name = loaded_weight_name.replace(weight_name, param_name)
if model_param_name not in params_dict:
continue
param = params_dict[model_param_name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id)
break
else:
if loaded_weight_name not in params_dict:
continue
param = params_dict[loaded_weight_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
weight_loader(param, loaded_weight)
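The stacked-parameter branch boils down to a substring substitution plus a shard hint handed to the fused parameter's weight_loader; a self-contained illustration (the checkpoint name is hypothetical):

loaded_weight_name = "model.layers.0.self_attn.q_proj.weight"
param_name, weight_name, shard_id = ("qkv_proj", "q_proj", "q")

model_param_name = loaded_weight_name.replace(weight_name, param_name)
print(model_param_name, shard_id)  # model.layers.0.self_attn.qkv_proj.weight q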
@paddle.no_grad()
def set_state_dict(self, state_dict):
"""

View File

@@ -72,7 +72,11 @@ def default_weight_loader(fd_config: FDConfig) -> None:
loaded_weight = loaded_weight[..., shard_offset:shard_size]
else:
loaded_weight = loaded_weight[shard_offset:shard_size, ...]
loaded_weight = get_tensor(loaded_weight)
# Some weights (e.g. mlp.gate.weight) are precision-sensitive; their parameters are
# created in float32, so align the loaded tensor with the parameter's dtype before copying.
if param.dtype != loaded_weight.dtype:
loaded_weight = loaded_weight.cast(param.dtype)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"