bukejiyu
2025-08-06 14:45:27 +08:00
committed by GitHub
parent 91dc87f1c5
commit 20839abccf
30 changed files with 1361 additions and 1087 deletions

@@ -14,6 +14,8 @@
# limitations under the License.
"""
from typing import Optional
import paddle
from paddle import nn
from paddleformers.utils.log import logger
@@ -77,7 +79,7 @@ class FusedMoE(nn.Layer):
self.fd_config = fd_config
self.layer_idx = layer_idx
self.reduce_results = reduce_results
self.tp_rank = fd_config.parallel_config.tensor_parallel_rank
self.tp_size = fd_config.parallel_config.tensor_parallel_size
self.ep_size = fd_config.parallel_config.expert_parallel_size
self.ep_rank = fd_config.parallel_config.expert_parallel_rank
@@ -109,14 +111,19 @@ class FusedMoE(nn.Layer):
self.n_group = n_group
self.routed_scaling_factor = routed_scaling_factor
self._dtype = self._helper.get_default_dtype()
self.weight_dtype = self._dtype
moe_quant_config = fd_config.quant_config
self.moe_quant_config = moe_quant_config
self.moe_quant_type = None
if moe_quant_config:
self.quant_method = moe_quant_config.get_quant_method(self)
self.moe_quant_type = moe_quant_config.name()
else:
# For now, the unquantized (w_fp16 a_fp16) method cannot be obtained from quant_config; this will be optimized in the future.
# w_fp16 a_fp16
self.quant_method = get_moe_method()
self.quant_method.create_weights(self, weight_loader=self.weight_loader)
self.redundant_table_manger = None
if self.ep_size > 1:
@@ -140,21 +147,121 @@ class FusedMoE(nn.Layer):
tp_size={self.tp_size}."
)
def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str] = None):
from fastdeploy.platforms import current_platform
if shard_id is None:
# 1. gate and up are fused in the checkpoint on disk
return
# 2. gate and up are stored separately in the checkpoint on disk
assert shard_id in ["gate", "down", "up"]
expert_param = param[expert_id]
if current_platform.is_cuda():
SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
else:
SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
self._load_expert_weight(
expert_param=expert_param,
shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
loaded_weight=loaded_weight,
shard_id=shard_id,
)
def _load_gate_up_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
tensor_size = expert_param.shape[shard_dim] // 2
if shard_id == "gate":
expert_param = expert_param[..., :tensor_size] if shard_dim else expert_param[:tensor_size, ...]
elif shard_id == "up":
expert_param = expert_param[..., tensor_size:] if shard_dim else expert_param[tensor_size:, ...]
if self.tp_size > 1:
size = loaded_weight.get_shape()[-1]
block_size = size // self.tp_size
shard_offset = self.tp_rank * block_size
shard_size = (self.tp_rank + 1) * block_size
loaded_weight = loaded_weight[..., shard_offset:shard_size]
loaded_weight = get_tensor(loaded_weight)
# To ensure compatibility across backends, apply an extra transpose for GCU and XPU
if expert_param.shape != loaded_weight.shape:
loaded_weight = loaded_weight.transpose([1, 0])
assert expert_param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
)
expert_param.copy_(loaded_weight, False)
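For reference, a minimal standalone sketch (hypothetical shapes, not part of this diff) of how the fused gate_up parameter is addressed above: the gate half occupies the first half of the sharded dimension and the up half occupies the second half.
import numpy as np
# Hypothetical per-expert fused gate_up parameter: hidden_size=512, moe_intermediate_size=256,
# sharded along the last axis (shard_dim=1 on CUDA).
expert_param = np.zeros((512, 2 * 256), dtype="float32")
tensor_size = expert_param.shape[1] // 2
gate_view = expert_param[..., :tensor_size]  # written when shard_id == "gate"
up_view = expert_param[..., tensor_size:]    # written when shard_id == "up"
assert gate_view.shape == up_view.shape == (512, 256)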
def _load_down_weight(self, expert_param, shard_dim, loaded_weight, shard_id):
if self.tp_size > 1:
size = loaded_weight.get_shape()[shard_dim]
block_size = size // self.tp_size
shard_offset = self.tp_rank * block_size
shard_size = (self.tp_rank + 1) * block_size
loaded_weight = loaded_weight[shard_offset:shard_size, ...]
loaded_weight = get_tensor(loaded_weight)
# To ensure compatibility across backends, apply an extra transpose for GCU and XPU
if expert_param.shape != loaded_weight.shape:
loaded_weight = loaded_weight.transpose([1, 0])
assert expert_param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
)
expert_param.copy_(loaded_weight, False)
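For reference, a minimal sketch (hypothetical shapes, not part of this diff) of the tensor-parallel slice computed above for the down projection: each rank copies a contiguous block of size // tp_size rows from the checkpoint tensor.
import numpy as np
# Hypothetical checkpoint down_proj weight: moe_intermediate_size=1024, hidden_size=512, tp_size=4.
loaded_weight = np.random.rand(1024, 512).astype("float32")
tp_size, tp_rank = 4, 1
block_size = loaded_weight.shape[0] // tp_size  # 256 rows per rank
shard_offset = tp_rank * block_size             # rank 1 starts at row 256
shard_end = (tp_rank + 1) * block_size          # and ends at row 512
local_shard = loaded_weight[shard_offset:shard_end, ...]
assert local_shard.shape == (256, 512)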
def _load_expert_weight(
self,
expert_param,
shard_dim,
loaded_weight,
shard_id,
):
if shard_id == "down":
self._load_down_weight(expert_param, shard_dim, loaded_weight, shard_id)
elif shard_id in ["gate", "up"]:
self._load_gate_up_weight(expert_param, shard_dim, loaded_weight, shard_id)
@classmethod
def make_expert_params_mapping(
cls,
ckpt_gate_proj_name: str,
ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
param_gate_up_proj_name: str,
param_down_proj_name: str,
num_experts: int,
ckpt_expert_key_name: str = "experts",
ckpt_gate_up_proj_name: Optional[str] = None,
) -> list[tuple[str, str, int, str]]:
param_name_mapping = [
("gate", ckpt_gate_proj_name),
("down", ckpt_down_proj_name),
("up", ckpt_up_proj_name),
]
if ckpt_gate_up_proj_name:
param_name_mapping.append((None, ckpt_gate_up_proj_name))
return [
# (param_name, weight_name, expert_id, shard_id)
(
(
param_gate_up_proj_name
if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
else param_down_proj_name
),
f"{ckpt_expert_key_name}.{expert_id}.{weight_name}.",
expert_id,
shard_id,
)
for expert_id in range(num_experts)
for shard_id, weight_name in param_name_mapping
]
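For context, a hedged sketch (not part of this diff; the import path and the checkpoint/parameter names are assumptions) showing the structure of the mapping this classmethod returns:
from fastdeploy.model_executor.layers.moe.moe import FusedMoE  # assumed import path
# Hypothetical checkpoint and parameter names; only the shape of the output matters here.
mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    param_gate_up_proj_name="up_gate_proj_weight",
    param_down_proj_name="down_proj_weight",
    num_experts=2,
)
for entry in mapping:
    print(entry)
# For expert 0 this yields, in order:
#   ("up_gate_proj_weight", "experts.0.gate_proj.", 0, "gate")
#   ("down_proj_weight", "experts.0.down_proj.", 0, "down")
#   ("up_gate_proj_weight", "experts.0.up_proj.", 0, "up")
# and the same three entries again for expert 1.
A loading loop can then match each checkpoint tensor name against the weight_name column and hand the tensor to weight_loader with the recorded expert_id and shard_id.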
def init_moe_weights(self):
"""
Initialize the weight shapes and parameters for the MoE layer.
Combines weight shape initialization and parameter creation into a single function.
"""
# Initialize weight shapes
self._dtype = self._helper.get_default_dtype()
self.weight_dtype = self._dtype
gate_weight_shape = [self.hidden_size, self.num_experts]
gate_correction_bias_shape = [1, self.num_experts]
self.gate_weight = self.create_parameter(
shape=gate_weight_shape,
dtype="float32",
)
if self.fd_config.model_config.moe_use_aux_free:
self.gate_correction_bias = self.create_parameter(
shape=gate_correction_bias_shape,
@@ -374,26 +481,19 @@ class FusedMoE(nn.Layer):
)
self.gate_correction_bias.set_value(gate_correction_bias_tensor)
gate_weight_key = self.weight_key_map.get("gate_weight_key", None)
assert gate_weight_key is not None, "gate_weight_key should not be None, please check model checkpoints"
gate_weight_tensor = get_tensor(state_dict.pop(gate_weight_key))
self.gate_weight = self.create_parameter(
shape=gate_weight_tensor.shape,
dtype="float32",
)
self.gate_weight.set_value(gate_weight_tensor.astype("float32"))
if self.fd_config.model_config.is_quantized:
if getattr(self.fd_config.quant_config, "is_permuted", True):
self.quant_method.process_prequanted_weights(self, state_dict)
else:
self.quant_method.create_weights(self, state_dict)
else:
self.quant_method.create_weights(self, state_dict)
if self.moe_quant_config:
self.quant_method.create_weights(self, state_dict)
else:
# w_fp16 a_fp16
self.quant_method.process_loaded_weights(self, state_dict)
def forward(self, x: paddle.Tensor):
def forward(self, x: paddle.Tensor, gate: nn.Layer):
"""
Defines the forward computation of the moe layer.
@@ -404,6 +504,5 @@ class FusedMoE(nn.Layer):
Tensor: Output tensor.
"""
gate_out = paddle.matmul(x.cast("float32"), self.gate_weight)
out = self.quant_method.apply(self, x, gate_out)
out = self.quant_method.apply(self, x, gate)
return out
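Since forward now receives the gate as an nn.Layer and no longer computes gate_out itself (the routing matmul presumably moves into quant_method.apply), here is a self-contained sketch, with hypothetical sizes, of the routing logits such a gate produces and how top-k experts would be chosen from them:
import paddle
from paddle import nn
# Hypothetical sizes: 8 tokens, hidden_size=64, 16 experts, top_k=2.
x = paddle.randn([8, 64], dtype="float32")
gate = nn.Linear(64, 16, bias_attr=False)  # stands in for the gate layer the model passes in
gate_out = gate(x.cast("float32"))         # routing logits, shape [8, 16]
scores = paddle.nn.functional.softmax(gate_out, axis=-1)
topk_scores, topk_ids = paddle.topk(scores, k=2, axis=-1)  # per-token expert selection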