Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 16:48:03 +08:00
Simplify the Config code (#2770)
* simplify the code
* fix vl
* delete config
* fix
* perfect code
* fix ci
* fix xpu
* fix xpu
* fix server
* resolve conflict
* fix mtp
* resolve conflict
* fix xpu
* fix xpu
* fix vl
* fix log
* fix qwen moe
* fix qwen moe
* fix qwen moe
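In short, the DeepSeek-V3 hyperparameters that were nested under fd_config.model_config.deepseekv3 are now read directly from fd_config.model_config; tensor_parallel_degree becomes tensor_parallel_size, num_layers becomes num_hidden_layers, and prefix_name moves onto model_config.pretrained_config. A minimal sketch of the access-pattern change (illustrative only, not part of the patch):

    # Before this commit
    moe_size = fd_config.model_config.deepseekv3.moe_intermediate_size
    tp_size = fd_config.parallel_config.tensor_parallel_degree
    n_layers = fd_config.model_config.num_layers

    # After this commit
    moe_size = fd_config.model_config.moe_intermediate_size
    tp_size = fd_config.parallel_config.tensor_parallel_size
    n_layers = fd_config.model_config.num_hidden_layers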
@@ -27,6 +27,7 @@ from paddleformers.utils.log import logger
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication_op import \
     tensor_model_parallel_all_reduce
+from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.activation import SiluAndMul
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
@@ -40,7 +41,6 @@ from fastdeploy.model_executor.layers.rotary_embedding import \
     DeepseekScalingRotaryEmbedding
 from fastdeploy.model_executor.models.model_base import ModelForCasualLM
 from fastdeploy.platforms import current_platform
-from fastdeploy.model_executor.forward_meta import ForwardMeta

 if current_platform.is_cuda():
     from fastdeploy.model_executor.ops.gpu import \
@@ -109,7 +109,7 @@ class DeepSeekV3MoE(nn.Layer):
                  prefix: str) -> None:
         super().__init__()

-        self.tp_size = fd_config.parallel_config.tensor_parallel_degree
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size

         weight_key_map = {
             "gate_weight_key": f"{prefix}.gate.weight",
@@ -124,23 +124,23 @@ class DeepSeekV3MoE(nn.Layer):
         self.fused_moe = FusedMoE(
             fd_config=fd_config,
             reduce_results=False,
-            moe_intermediate_size=fd_config.model_config.deepseekv3.
+            moe_intermediate_size=fd_config.model_config.
             moe_intermediate_size,
-            num_experts=fd_config.model_config.deepseekv3.n_routed_experts,
-            top_k=fd_config.model_config.deepseekv3.num_experts_per_tok,
-            topk_method=fd_config.model_config.deepseekv3.topk_method,
-            topk_group=fd_config.model_config.deepseekv3.topk_group,
-            n_group=fd_config.model_config.deepseekv3.n_group,
-            routed_scaling_factor=fd_config.model_config.deepseekv3.
+            num_experts=fd_config.model_config.n_routed_experts,
+            top_k=fd_config.model_config.num_experts_per_tok,
+            topk_method=fd_config.model_config.topk_method,
+            topk_group=fd_config.model_config.topk_group,
+            n_group=fd_config.model_config.n_group,
+            routed_scaling_factor=fd_config.model_config.
             routed_scaling_factor,
             layer_idx=layer_id,
             weight_key_map=weight_key_map,
         )

-        self.num_shared_experts = fd_config.model_config.deepseekv3.n_shared_experts
+        self.num_shared_experts = fd_config.model_config.n_shared_experts
         shared_experts_intermediate_size = (
             self.num_shared_experts *
-            fd_config.model_config.deepseekv3.moe_intermediate_size)
+            fd_config.model_config.moe_intermediate_size)

         self.shared_experts = DeepSeekV3MLP(
             fd_config=fd_config,
@@ -178,18 +178,18 @@ class DeepseekV3MLAAttention(nn.Layer):
                  prefix: str = "") -> None:
         super().__init__()

-        self.tp_size = fd_config.parallel_config.tensor_parallel_degree
+        self.tp_size = fd_config.parallel_config.tensor_parallel_size
         self.hidden_size = fd_config.model_config.hidden_size
         self.num_attention_heads = fd_config.model_config.num_attention_heads
         self.num_attention_heads_tp = self.num_attention_heads // self.tp_size

         # MLA
-        self.qk_nope_head_dim = fd_config.model_config.deepseekv3.qk_nope_head_dim
-        self.qk_rope_head_dim = fd_config.model_config.deepseekv3.qk_rope_head_dim
+        self.qk_nope_head_dim = fd_config.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = fd_config.model_config.qk_rope_head_dim
         self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
-        self.v_head_dim = fd_config.model_config.deepseekv3.v_head_dim
-        self.q_lora_rank = fd_config.model_config.deepseekv3.q_lora_rank
-        self.kv_lora_rank = fd_config.model_config.deepseekv3.kv_lora_rank
+        self.v_head_dim = fd_config.model_config.v_head_dim
+        self.q_lora_rank = fd_config.model_config.q_lora_rank
+        self.kv_lora_rank = fd_config.model_config.kv_lora_rank

         self.attn_softmax_scale = self.qk_head_dim**-0.5
         self.rope_theta = fd_config.model_config.rope_theta
@@ -255,7 +255,7 @@ class DeepseekV3MLAAttention(nn.Layer):
             qk_nope_head_dim=self.qk_nope_head_dim,
             v_head_dim=self.v_head_dim)

-        self.rope_scaling = fd_config.model_config.deepseekv3.rope_scaling
+        self.rope_scaling = fd_config.model_config.rope_scaling
         if self.rope_scaling:
             mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False)
             scaling_factor = self.rope_scaling["factor"]
@@ -449,9 +449,9 @@ class DeepSeekV3DecoderLayer(nn.Layer):
             prefix=f"{prefix}.self_attn",
         )

-        if (fd_config.model_config.deepseekv3.n_routed_experts is not None
+        if (fd_config.model_config.n_routed_experts is not None
                 and layer_id
-                >= fd_config.model_config.deepseekv3.first_k_dense_replace):
+                >= fd_config.model_config.first_k_dense_replace):
             self.mlp = DeepSeekV3MoE(
                 fd_config=fd_config,
                 layer_id=layer_id,
@@ -525,8 +525,8 @@ class DeepSeekV3Model(nn.Layer):
         Initializer for the DeepSeekV3Model class.
         """
         super().__init__()
-        self.num_layers = fd_config.model_config.num_layers
-        fd_config.model_config.prefix_name = "deepseek_v3"
+        self.num_layers = fd_config.model_config.num_hidden_layers
+        fd_config.model_config.pretrained_config.prefix_name = "deepseek_v3"

         self.embeddings = VocabParallelEmbedding(
             fd_config,
@@ -539,7 +539,7 @@ class DeepSeekV3Model(nn.Layer):
         self.decoder_layers = nn.LayerList([
             DeepSeekV3DecoderLayer(
                 fd_config,
-                prefix=f"{fd_config.model_config.prefix_name}.layers.{i}")
+                prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}")
             for i in range(self.num_layers)
         ])

@@ -755,5 +755,5 @@ class DeepSeekV3PretrainedModel(PretrainedModel):

             return final_actions

-        mappings = get_tensor_parallel_split_mappings(config.num_layers)
+        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
         return mappings
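Downstream code that still reads the old attribute names will break against this change. A possible compatibility shim, sketched here as an assumption and not part of the commit, prefers the new flattened field and falls back to the pre-#2770 nested layout:

    # Hypothetical helper, not from the FastDeploy codebase.
    def get_model_field(model_config, name):
        # New layout: DeepSeek-V3 fields live directly on model_config.
        if hasattr(model_config, name):
            return getattr(model_config, name)
        # Old layout: fields were nested under the deepseekv3 sub-config.
        return getattr(model_config.deepseekv3, name)

    # Example usage (assumes an fd_config: FDConfig is in scope):
    # n_experts = get_model_field(fd_config.model_config, "n_routed_experts")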