Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-11-03 02:53:26 +08:00
[Sync] Update to latest code (#2679)
* [Sync] Update to latest code
* Add new code files
* Add new code files
* update code
* Try to fix build.sh
* Try to fix build.sh
* Update code
* Update requirements.txt
* Update code

---------

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
@@ -30,10 +30,15 @@ class FusedMoE(nn.Layer):
    def __init__(
        self,
        fd_config,
        reduce_results: bool = True,
        moe_intermediate_size: int = -1,
        num_experts: int = -1,
        expert_id_offset: int = 0,
        top_k: int = -1,
        topk_method: str = "",
        topk_group: int = -1,
        n_group: int = -1,
        routed_scaling_factor: float = 1.0,
        layer_idx: int = -1,
        moe_tag: str = "",
        weight_key_map: dict = {},
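One detail worth flagging in the new signature: weight_key_map: dict = {} uses a mutable default argument, which Python evaluates once at definition time and then shares across every call that omits the argument. Whether that matters here depends on whether the dict is ever mutated in place; the snippet below is only a standalone illustration of the general Python behaviour, not of FastDeploy code.

# --- illustrative sketch (not part of the commit diff) ---
# Python evaluates a mutable default argument once, so all callers that rely
# on the default share the same dict object.
def add_key(key, mapping: dict = {}):
    mapping[key] = True
    return mapping

a = add_key("ffn1")
b = add_key("ffn2")
print(a is b)  # True: both calls returned the same shared dict
print(b)       # {'ffn1': True, 'ffn2': True}
# --- end sketch ---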
@@ -49,6 +54,7 @@ class FusedMoE(nn.Layer):

        self.fd_config = fd_config
        self.layer_idx = layer_idx
        self.reduce_results = reduce_results

        self.tp_size = fd_config.parallel_config.tensor_parallel_degree
        self.ep_size = fd_config.parallel_config.expert_parallel_degree
@@ -60,28 +66,32 @@ class FusedMoE(nn.Layer):

        self.hidden_size = fd_config.model_config.hidden_size
        self.moe_config = fd_config.moe_config

        self.num_experts = num_experts
        self.num_local_experts = self.num_experts // self.ep_size

        self.moe_intermediate_size = moe_intermediate_size // self.tp_size

        self.top_k = top_k
        self.hidden_size = self.hidden_size
        self.moe_intermediate_size = moe_intermediate_size // self.tp_size
        self.weight_key_map = weight_key_map

        self.use_method = envs.FD_MOE_BACKEND.lower()
        self.gate_correction_bias = None
        self.moe_tag = moe_tag

        if self.ep_size > 1:
            expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts

        self.expert_id_offset = expert_id_offset

        if fd_config.quant_config:
            self.quant_method = fd_config.quant_config.get_quant_method(self)
        # used for deepseek_v3
        self.topk_method = topk_method
        self.topk_group = topk_group
        self.n_group = n_group
        self.routed_scaling_factor = routed_scaling_factor

        moe_quant_config = fd_config.quant_config
        if moe_quant_config:
            self.quant_method = moe_quant_config.get_quant_method(self)
            self.moe_quant_type = moe_quant_config.name()
        else:
            # now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
            from .fused_moe_cutlass_backend import CutlassMoEMethod
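The expert-parallel bookkeeping above is plain integer arithmetic: each EP rank owns num_experts // ep_size consecutive experts starting at ep_rank * num_local_experts, and tensor parallelism splits each expert's FFN width by tp_size. A standalone sketch of that arithmetic with made-up sizes (the numbers are examples, not values from this commit):

# --- illustrative sketch (not part of the commit diff) ---
# Mirrors the partitioning arithmetic in __init__ above, with example sizes.
def local_expert_range(num_experts, ep_size, ep_rank, base_offset=0):
    num_local_experts = num_experts // ep_size
    expert_id_offset = base_offset + ep_rank * num_local_experts
    return expert_id_offset, expert_id_offset + num_local_experts

num_experts, ep_size, tp_size = 64, 4, 2
moe_intermediate_size = 1024

for ep_rank in range(ep_size):
    lo, hi = local_expert_range(num_experts, ep_size, ep_rank)
    print(f"ep_rank={ep_rank} owns experts [{lo}, {hi})")
# ep_rank=0 owns experts [0, 16)
# ep_rank=1 owns experts [16, 32)
# ep_rank=2 owns experts [32, 48)
# ep_rank=3 owns experts [48, 64)

# Each TP rank holds a 1/tp_size slice of every local expert's FFN width:
print(moe_intermediate_size // tp_size)  # 512
# --- end sketch ---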
@@ -90,12 +100,78 @@ class FusedMoE(nn.Layer):
        if self.ep_size > 1:
            self.quant_method.init_ep(self)

        if fd_config.load_config.dynamic_load_weight:
            # It's for RL to build model
            self.init_moe_weights()

        logger.info(
            f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset+self.num_local_experts}), \
            {top_k=}, hidden_size={self.hidden_size}, {moe_intermediate_size=}, \
            , ep_size={self.ep_size}, \
            tp_size={self.tp_size}.")

    def init_moe_weights(self):
        """
        Initialize the weight shapes and parameters for the MoE layer.
        Combines weight shape initialization and parameter creation into a single function.
        """
        # Initialize weight shapes
        self._dtype = self._helper.get_default_dtype()
        self.weight_dtype = self._dtype
        gate_weight_shape = [self.hidden_size, self.num_experts]
        gate_correction_bias_shape = [1, self.num_experts]

        self.gate_weight = self.create_parameter(
            shape=gate_weight_shape,
            dtype="float32",
        )
        if self.moe_config.moe_use_aux_free:
            self.gate_correction_bias = self.create_parameter(
                shape=gate_correction_bias_shape,
                dtype="float32",
            )
        ffn1_output_dim = self.moe_intermediate_size * 2
        if self.moe_quant_type in ["fp8", "wint8"]:
            ffn1_weight_shape = [self.num_local_experts, ffn1_output_dim, self.hidden_size]
            ffn2_weight_shape = [self.num_local_experts, self.hidden_size, self.moe_intermediate_size]
        else:
            ffn1_weight_shape = [self.num_local_experts, self.hidden_size, ffn1_output_dim]
            ffn2_weight_shape = [self.num_local_experts, self.moe_intermediate_size, self.hidden_size]

        # Create parameters
        if self.moe_quant_type == "fp8":
            # (TODO: gaoziyuan)
            pass
        else:
            self.weight_dtype = "int8"
            self.init_weight_only_scale()

        # FFN1 parameters
        self.moe_ffn1_weight = self.create_parameter(
            shape=ffn1_weight_shape,
            dtype=self.weight_dtype,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        # FFN2 parameters
        self.moe_ffn2_weight = self.create_parameter(
            shape=ffn2_weight_shape,
            dtype=self.weight_dtype,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

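# --- illustrative sketch (not part of the commit diff) ---
# For the weight-only paths (wint8 today, fp8 once the TODO above is filled in),
# init_moe_weights stores the expert weights transposed relative to the
# unquantized layout: [num_local_experts, out_dim, in_dim] instead of
# [num_local_experts, in_dim, out_dim]. This snippet only replays that shape
# logic with example sizes; the numbers, and the reading of the 2x factor as
# the fused gate/up projection, are assumptions for illustration.
num_local_experts = 16            # e.g. num_experts // ep_size
hidden_size = 2048
moe_intermediate_size = 512       # already divided by tp_size
ffn1_output_dim = moe_intermediate_size * 2   # assumed: fused gate + up projection

for moe_quant_type in ("wint8", "unquantized"):
    if moe_quant_type in ("fp8", "wint8"):
        ffn1_shape = [num_local_experts, ffn1_output_dim, hidden_size]
        ffn2_shape = [num_local_experts, hidden_size, moe_intermediate_size]
    else:
        ffn1_shape = [num_local_experts, hidden_size, ffn1_output_dim]
        ffn2_shape = [num_local_experts, moe_intermediate_size, hidden_size]
    print(moe_quant_type, "ffn1:", ffn1_shape, "ffn2:", ffn2_shape)
# wint8 ffn1: [16, 1024, 2048] ffn2: [16, 2048, 512]
# unquantized ffn1: [16, 2048, 1024] ffn2: [16, 512, 2048]
# --- end sketch ---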
    def init_weight_only_scale(self):
        """
        Initialize the weight scale.
        """
        self.moe_ffn1_weight_scale = self.create_parameter(
            shape=[self.num_local_experts, self.moe_intermediate_size * 2],
            dtype=self._dtype,
        )
        self.moe_ffn2_weight_scale = self.create_parameter(
            shape=[self.num_local_experts, self.hidden_size],
            dtype=self._dtype,
        )

    def load_experts_weight(self, state_dict: dict,
                            ffn1_expert_weight_key: str,
                            ffn2_expert_weight_key: str):
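init_weight_only_scale above allocates one scale per local expert per output channel: [num_local_experts, moe_intermediate_size * 2] for FFN1 and [num_local_experts, hidden_size] for FFN2. Assuming a conventional weight-only int8 scheme, in which each output channel of the int8 weight is multiplied by its scale at dequantization time (an assumption about intent, not a description of FastDeploy's actual kernels), the bookkeeping looks roughly like this:

# --- illustrative sketch (not part of the commit diff) ---
# Generic weight-only int8 dequantization with the same scale shapes as above.
# NOT the kernel FastDeploy uses; numpy stands in for the real implementation.
import numpy as np

num_local_experts, hidden_size, moe_intermediate_size = 2, 8, 4
ffn1_output_dim = moe_intermediate_size * 2

# wint8 layout from init_moe_weights: [experts, out_dim, in_dim]
w_int8 = np.random.randint(-127, 128,
                           size=(num_local_experts, ffn1_output_dim, hidden_size),
                           dtype=np.int8)
# Scale layout from init_weight_only_scale: [experts, out_dim]
scale = np.random.rand(num_local_experts, ffn1_output_dim).astype(np.float16)

# Per-output-channel dequantization: broadcast the scale across the input dim.
w_fp16 = w_int8.astype(np.float16) * scale[:, :, None]
print(w_fp16.shape)  # (2, 8, 8)
# --- end sketch ---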