qwen3_moe (#3084)
@@ -14,6 +14,8 @@
# limitations under the License.
"""

from typing import Optional

import paddle
from paddle import nn
from paddleformers.utils.log import logger
@@ -77,7 +79,7 @@ class FusedMoE(nn.Layer):
        self.fd_config = fd_config
        self.layer_idx = layer_idx
        self.reduce_results = reduce_results

        self.tp_rank = fd_config.parallel_config.tensor_parallel_rank
        self.tp_size = fd_config.parallel_config.tensor_parallel_size
        self.ep_size = fd_config.parallel_config.expert_parallel_size
        self.ep_rank = fd_config.parallel_config.expert_parallel_rank
@@ -109,14 +111,19 @@ class FusedMoE(nn.Layer):
        self.n_group = n_group
        self.routed_scaling_factor = routed_scaling_factor

        self._dtype = self._helper.get_default_dtype()
        self.weight_dtype = self._dtype

        moe_quant_config = fd_config.quant_config
        self.moe_quant_config = moe_quant_config
        self.moe_quant_type = None
        if moe_quant_config:
            self.quant_method = moe_quant_config.get_quant_method(self)
            self.moe_quant_type = moe_quant_config.name()
        else:
            # For now, the unquantized (w_fp16 a_fp16) method cannot be obtained from
            # quant_config; this will be improved in the future.
            # w_fp16 a_fp16
            self.quant_method = get_moe_method()
        self.quant_method.create_weights(self, weight_loader=self.weight_loader)

        self.redundant_table_manger = None
        if self.ep_size > 1:
@@ -140,21 +147,121 @@ class FusedMoE(nn.Layer):
                tp_size={self.tp_size}."
            )

    def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str] = None):
        from fastdeploy.platforms import current_platform

        if shard_id is None:
            # 1. gate and up are fused in the checkpoint
            return
        # 2. gate and up are stored separately in the checkpoint
        assert shard_id in ["gate", "down", "up"]
        expert_param = param[expert_id]
        if current_platform.is_cuda():
            SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
        else:
            SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
        self._load_expert_weight(
            expert_param=expert_param,
            shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
            loaded_weight=loaded_weight,
            shard_id=shard_id,
        )

    def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
        tensor_size = expert_param.shape[shard_dim] // 2
        if shard_id == "gate":
            expert_param = expert_param[..., :tensor_size] if shard_dim else expert_param[:tensor_size, ...]
        elif shard_id == "up":
            expert_param = expert_param[..., tensor_size:] if shard_dim else expert_param[tensor_size:, ...]

        if self.tp_size > 1:
            size = loaded_weight.get_shape()[-1]
            block_size = size // self.tp_size
            shard_offset = self.tp_rank * block_size
            shard_size = (self.tp_rank + 1) * block_size
            loaded_weight = loaded_weight[..., shard_offset:shard_size]

        loaded_weight = get_tensor(loaded_weight)
        # To ensure compatibility across backends, apply an extra transpose for GCU and XPU
        if expert_param.shape != loaded_weight.shape:
            loaded_weight = loaded_weight.transpose([1, 0])
        assert expert_param.shape == loaded_weight.shape, (
            f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
        )
        expert_param.copy_(loaded_weight, False)

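For reference, a minimal standalone sketch (not part of this commit) of the gate/up slicing convention used above: on CUDA the fused gate_up parameter is sharded along its last axis (shard_dim == 1), so "gate" writes into the first half and "up" into the second half. All shapes and names below are illustrative assumptions.

import paddle

# Assumed per-expert fused gate_up layout on CUDA: [hidden, 2 * intermediate].
hidden, intermediate, num_experts = 8, 4, 2
gate_up = paddle.zeros([num_experts, hidden, 2 * intermediate])
expert_param = gate_up[0]

tensor_size = expert_param.shape[1] // 2
gate_half = expert_param[..., :tensor_size]  # target slice for shard_id == "gate"
up_half = expert_param[..., tensor_size:]    # target slice for shard_id == "up"
print(gate_half.shape, up_half.shape)        # [8, 4] [8, 4]
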
    def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
        if self.tp_size > 1:
            size = loaded_weight.get_shape()[shard_dim]
            block_size = size // self.tp_size
            shard_offset = self.tp_rank * block_size
            shard_size = (self.tp_rank + 1) * block_size
            loaded_weight = loaded_weight[shard_offset:shard_size, ...]
        loaded_weight = get_tensor(loaded_weight)
        # To ensure compatibility across backends, apply an extra transpose for GCU and XPU
        if expert_param.shape != loaded_weight.shape:
            loaded_weight = loaded_weight.transpose([1, 0])
        assert expert_param.shape == loaded_weight.shape, (
            f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
        )
        expert_param.copy_(loaded_weight, False)

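The tensor-parallel slice computed in both loaders is plain block partitioning. A small worked example with assumed values:

# Illustrative numbers only: a down_proj sharded dim of 16 split across tp_size = 4.
size, tp_size, tp_rank = 16, 4, 2
block_size = size // tp_size                 # 4 rows per rank
shard_offset = tp_rank * block_size          # 8
shard_size = (tp_rank + 1) * block_size      # 12 (exclusive end index, despite the name)
# rank 2 therefore loads rows 8:12 of the checkpoint tensor.
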
    def _load_expert_weight(
        self,
        expert_param,
        shard_dim,
        loaded_weight,
        shard_id,
    ):
        if shard_id == "down":
            self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id)
        elif shard_id in ["gate", "up"]:
            self._load_gate_up_weight(expert_param, shard_dim, loaded_weight, shard_id)

    @classmethod
    def make_expert_params_mapping(
        cls,
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        param_gate_up_proj_name: str,
        param_down_proj_name: str,
        num_experts: int,
        ckpt_expert_key_name: str = "experts",
        ckpt_gate_up_proj_name: Optional[str] = None,
    ) -> list[tuple[str, str, int, str]]:
        param_name_maping = [
            ("gate", ckpt_gate_proj_name),
            ("down", ckpt_down_proj_name),
            ("up", ckpt_up_proj_name),
        ]
        if ckpt_gate_up_proj_name:
            param_name_maping.append((None, ckpt_gate_up_proj_name))

        return [
            # (param_name, weight_name, expert_id, shard_id)
            (
                (
                    param_gate_up_proj_name
                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
                    else param_down_proj_name
                ),
                f"{ckpt_expert_key_name}.{expert_id}.{weight_name}.",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id, weight_name in param_name_maping
        ]

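As a usage illustration (not part of this commit), with hypothetical Qwen3-style checkpoint names and two experts, the mapping expands to one (param_name, weight_name, expert_id, shard_id) tuple per projection per expert. The param_* names passed here are assumptions for the example.

mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    param_gate_up_proj_name="experts.gate_up_proj_weight",  # hypothetical
    param_down_proj_name="experts.down_proj_weight",        # hypothetical
    num_experts=2,
)
# mapping[:3] ==
# [("experts.gate_up_proj_weight", "experts.0.gate_proj.", 0, "gate"),
#  ("experts.down_proj_weight",    "experts.0.down_proj.", 0, "down"),
#  ("experts.gate_up_proj_weight", "experts.0.up_proj.",   0, "up")]
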
    def init_moe_weights(self):
        """
        Initialize the weight shapes and parameters for the MoE layer.
        Combines weight shape initialization and parameter creation into a single function.
        """
        # Initialize weight shapes
        self._dtype = self._helper.get_default_dtype()
        self.weight_dtype = self._dtype
        gate_weight_shape = [self.hidden_size, self.num_experts]
        gate_correction_bias_shape = [1, self.num_experts]

        self.gate_weight = self.create_parameter(
            shape=gate_weight_shape,
            dtype="float32",
        )
        if self.fd_config.model_config.moe_use_aux_free:
            self.gate_correction_bias = self.create_parameter(
                shape=gate_correction_bias_shape,
@@ -374,26 +481,19 @@ class FusedMoE(nn.Layer):
            )
            self.gate_correction_bias.set_value(gate_correction_bias_tensor)

        gate_weight_key = self.weight_key_map.get("gate_weight_key", None)
        assert gate_weight_key is not None, "gate_weight_key should not be None, please check model checkpoints"

        gate_weight_tensor = get_tensor(state_dict.pop(gate_weight_key))

        self.gate_weight = self.create_parameter(
            shape=gate_weight_tensor.shape,
            dtype="float32",
        )
        self.gate_weight.set_value(gate_weight_tensor.astype("float32"))

        if self.fd_config.model_config.is_quantized:
            if getattr(self.fd_config.quant_config, "is_permuted", True):
                self.quant_method.process_prequanted_weights(self, state_dict)
            else:
                self.quant_method.create_weights(self, state_dict)
        else:
            self.quant_method.create_weights(self, state_dict)
            if self.moe_quant_config:
                self.quant_method.create_weights(self, state_dict)
            else:
                # w_fp16 a_fp16
                self.quant_method.process_loaded_weights(self, state_dict)

    def forward(self, x: paddle.Tensor):
    def forward(self, x: paddle.Tensor, gate: nn.Layer):
        """
        Defines the forward computation of the moe layer.

@@ -404,6 +504,5 @@ class FusedMoE(nn.Layer):
            Tensor: Output tensor.

        """
        gate_out = paddle.matmul(x.cast("float32"), self.gate_weight)
        out = self.quant_method.apply(self, x, gate_out)
        out = self.quant_method.apply(self, x, gate)
        return out

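The signature change above means the caller now passes its router gate layer to FusedMoE instead of relying on an internal gate_weight matmul. A hypothetical call site (names assumed, not taken from this commit) could look like the following sketch:

import paddle
from paddle import nn

class MoeBlockSketch(nn.Layer):
    """Hypothetical caller illustrating the new FusedMoE.forward(x, gate) contract."""

    def __init__(self, hidden_size: int, num_experts: int, fused_moe: nn.Layer):
        super().__init__()
        # The router gate is owned by the caller and handed to FusedMoE.
        self.gate = nn.Linear(hidden_size, num_experts, bias_attr=False)
        self.fused_moe = fused_moe

    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
        # Previously the layer computed gate_out = matmul(x, gate_weight) internally;
        # now the gate layer itself is forwarded to quant_method.apply.
        return self.fused_moe(hidden_states, self.gate)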