Sync v2.0 version of code to github repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -1,5 +1,5 @@
"""
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,34 +14,13 @@
# limitations under the License.
"""
from dataclasses import dataclass
import paddle
from paddle import nn
from paddlenlp.utils.log import logger
from paddleformers.utils.log import logger
from fastdeploy import envs
from fastdeploy.model_executor.layers.utils import get_tensor
from .cutlass_fused_moe import CutlassFusedMoeMethod
@dataclass
class MoEComputeParams:
"""
some params for computing MoE.
it is given to different compute methods.
"""
global_num_experts: int = -1
top_k: int = -1
hidden_size: int = -1
num_local_experts: int = -1
moe_intermediate_size: int = -1
tp_size: int = -1
ep_size: int = -1
dp_size: int = -1
moe_quant_type: str = ""
class FusedMoE(nn.Layer):
"""
@@ -50,174 +29,195 @@ class FusedMoE(nn.Layer):
def __init__(
self,
llm_config,
fd_config,
moe_intermediate_size: int = -1,
num_experts: int = -1,
expert_id_offset: int = 0,
top_k: int = -1,
moe_use_gate_correction_bias: bool = False,
moe_quant_type: str = "weight_only_int4",
layer_idx: int = -1,
gate_weight_key=None,
gate_correction_bias_key=None,
ffn1_expert_weight_key=None,
ffn2_expert_weight_key=None,
moe_ffn1_bias_keys=None,
moe_ffn2_bias_keys=None,
moe_ffn1_weight_scale_keys=None,
moe_ffn2_weight_scale_keys=None,
moe_ffn1_in_scale_keys=None,
moe_ffn2_in_scale_keys=None,
moe_tag: str = "",
weight_key_map: dict = {},
):
"""
Initialize the Moe layer with given parameters.
Args:
llm_config (LLMConfig): Arguments related to inference, containing
fd_config (FDConfig): Arguments related to inference, containing
attributes such as weight_dtype, act_dtype, mp_size, hidden_size, head_dim,
num_attention_heads, and ffn_hidden_size.
"""
super().__init__()
self.llm_config = llm_config
self.fd_config = fd_config
self.layer_idx = layer_idx
self.tp_size = llm_config.parallel_config.mp_size
self.ep_size = llm_config.parallel_config.ep_size
self.moe_use_gate_correction_bias = moe_use_gate_correction_bias
self.tp_size = fd_config.parallel_config.tensor_parallel_degree
self.ep_size = fd_config.parallel_config.expert_parallel_degree
self.ep_rank = fd_config.parallel_config.expert_parallel_rank
assert (self.tp_size >= 1 and self.ep_size == 1) or \
(self.tp_size == 1 and self.ep_size > 1), \
'MoE only support parallelism on TP or EP dimension.'
self.hidden_size = fd_config.model_config.hidden_size
self.moe_config = fd_config.moe_config
self.hidden_size = llm_config.model_config.hidden_size
self.moe_config = llm_config.moe_config
self.use_offline_quant = llm_config.tmp_config.use_offline_quant
moe_tag = self.llm_config.moe_config.moe_tag
logger.info(f"{moe_tag}MoE is running in {moe_quant_type} mode")
self.moe_quant_type = moe_quant_type
self.num_experts = num_experts
self.num_local_experts = self.num_experts // self.ep_size
logger.info(f'''MoE config is num_experts:{num_experts},
top_k:{top_k},
hidden_size:{self.hidden_size},
moe_intermediate_size:{moe_intermediate_size}''')
logger.info(
f"MoE is running on moe_quant_type: {self.moe_quant_type}, ep:{self.ep_size}, tp:{self.tp_size} mode"
)
self.moe_intermediate_size = moe_intermediate_size // self.tp_size
self.gate_weight_key = gate_weight_key
self.gate_correction_bias_key = gate_correction_bias_key
self.top_k = top_k
self.hidden_size = self.hidden_size
self.moe_intermediate_size = moe_intermediate_size // self.tp_size
self.weight_key_map = weight_key_map
self.ffn1_expert_weight_key = ffn1_expert_weight_key
self.ffn2_expert_weight_key = ffn2_expert_weight_key
self.ffn1_bias_key = moe_ffn1_bias_keys
self.ffn2_bias_key = moe_ffn2_bias_keys
self.use_method = envs.FD_MOE_BACKEND.lower()
self.gate_correction_bias = None
self.moe_tag = moe_tag
if self.moe_quant_type == "w4a8":
# below keys are only used in MoE W4A8!
self.ffn1_expert_weight_scale_key = moe_ffn1_weight_scale_keys
self.ffn2_expert_weight_scale_key = moe_ffn2_weight_scale_keys
self.ffn1_expert_in_scale_key = moe_ffn1_in_scale_keys
self.ffn2_expert_in_scale_key = moe_ffn2_in_scale_keys
if self.ep_size > 1:
expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts
self.compute_method = CutlassFusedMoeMethod()
self.expert_id_offset = expert_id_offset
self.moe_compute_params = MoEComputeParams()
self.moe_compute_params.global_num_experts = self.num_experts
self.moe_compute_params.top_k = top_k
self.moe_compute_params.hidden_size = self.hidden_size
self.moe_compute_params.num_local_experts = self.num_local_experts
self.moe_compute_params.moe_quant_type = self.moe_quant_type
self.moe_compute_params.moe_intermediate_size = self.moe_intermediate_size
self.moe_compute_params.ep_size = self.ep_size
self.moe_compute_params.tp_size = self.tp_size
if fd_config.quant_config:
self.quant_method = fd_config.quant_config.get_quant_method(self)
else:
# now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
from .fused_moe_cutlass_backend import CutlassMoEMethod
self.quant_method = CutlassMoEMethod(None)
def load_gate_state_dict(self, state_dict):
if self.ep_size > 1:
self.quant_method.init_ep(self)
logger.info(
f"{moe_tag}MoE config is {num_experts=}[{expert_id_offset}, {expert_id_offset+self.num_local_experts}), \
{top_k=}, hidden_size={self.hidden_size}, {moe_intermediate_size=}, \
, ep_size={self.ep_size}, \
tp_size={self.tp_size}.")
def load_experts_weight(self, state_dict: dict,
ffn1_expert_weight_key: str,
ffn2_expert_weight_key: str):
"""
load_gate_state_dict function.
Load experts weight from state_dict.
Args:
state_dict (dict): The state_dict of model.
ffn1_expert_weight_key (str): The key of ffn1 expert weight.
ffn2_expert_weight_key (str): The key of ffn2 expert weight.
"""
up_gate_proj_weight = []
up_gate_proj_weight_scale = []
down_proj_weight = []
down_proj_weight_scale = []
for j in range(self.num_experts):
up_gate_proj_weight.append(
get_tensor(
state_dict.pop(self.ffn1_expert_weight_key.format(j))))
down_proj_weight.append(
get_tensor(
state_dict.pop(self.ffn2_expert_weight_key.format(j))))
return up_gate_proj_weight, down_proj_weight
ffn1_weights = []
ffn2_weights = []
is_ffn_merged = ffn1_expert_weight_key.format(
self.expert_id_offset) in state_dict
if is_ffn_merged:
for i in range(self.num_local_experts):
expert_idx = self.expert_id_offset + i
ffn1_weights.append(
get_tensor(
state_dict.pop(
ffn1_expert_weight_key.format(expert_idx))))
ffn2_weights.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_key.format(expert_idx))))
else:
gate_expert_weight_key = ffn1_expert_weight_key.replace(
"up_gate_proj", "gate_proj")
up_expert_weight_key = ffn1_expert_weight_key.replace(
"up_gate_proj", "up_proj")
for j in range(self.num_local_experts):
expert_idx = self.expert_id_offset + j
gate = get_tensor(
state_dict.pop(gate_expert_weight_key.format(expert_idx)))
up = get_tensor(
state_dict.pop(up_expert_weight_key.format(expert_idx)))
ffn1_weights.append(paddle.concat([gate, up], axis=-1))
ffn2_weights.append(
get_tensor(
state_dict.pop(
ffn2_expert_weight_key.format(expert_idx))))
return ffn1_weights, ffn2_weights
def load_state_dict(self, state_dict, is_update: bool = False):
def extract_moe_ffn_weights(self, state_dict: dict):
"""
Extract MoE FFN weights from state dict based on weight key mapping.
Args:
state_dict (dict): Model state dictionary containing the weights.
Returns:
tuple: A tuple containing two lists:
- ffn1_weights: List of tensors for first FFN layer weights
- ffn2_weights: List of tensors for second FFN layer weights
Raises:
AssertionError: If required weight keys are missing or number of weights
doesn't match number of local experts.
"""
ffn1_expert_weight_key = self.weight_key_map.get(
"ffn1_expert_weight_key", None)
ffn2_expert_weight_key = self.weight_key_map.get(
"ffn2_expert_weight_key", None)
assert ffn1_expert_weight_key is not None, "ffn1_expert_weight_key should not be none."
assert ffn2_expert_weight_key is not None, "ffn2_expert_weight_key should not be none."
ffn1_weights, ffn2_weights = self.load_experts_weight(
state_dict, ffn1_expert_weight_key, ffn2_expert_weight_key)
assert len(
ffn1_weights
) == self.num_local_experts, "ffn1_weights length should be equal to num_local_experts."
assert len(
ffn2_weights
) == self.num_local_experts, "ffn2_weights length should be equal to num_local_experts."
return ffn1_weights, ffn2_weights
def extract_gate_correction_bias(self, gate_correction_bias_key,
state_dict):
"""
extract_gate_correction_bias function.
"""
gate_correction_bias_tensor = get_tensor(
state_dict.pop(gate_correction_bias_key)).astype("float32")
return gate_correction_bias_tensor
def load_state_dict(self, state_dict):
"""
load_state_dict function.
"""
# gate
if not is_update:
gate_weight_tensor = get_tensor(state_dict.pop(self.gate_weight_key))
self.gate_weight = self.create_parameter(
shape=gate_weight_tensor.shape,
dtype="float32",
)
self.gate_weight.set_value(gate_weight_tensor)
# gate_correction_bias
self.gate_correction_bias_key = self.weight_key_map.get(
"gate_correction_bias_key", None)
if self.gate_correction_bias_key is not None and self.gate_correction_bias_key in state_dict:
self.moe_use_gate_correction_bias = True
else:
self.moe_use_gate_correction_bias = False
if self.moe_use_gate_correction_bias:
gate_correction_bias_tensor = get_tensor(
state_dict.pop(self.gate_correction_bias_key))
gate_correction_bias_tensor = self.extract_gate_correction_bias(
self.gate_correction_bias_key, state_dict)
self.gate_correction_bias = self.create_parameter(
shape=gate_correction_bias_tensor.shape,
dtype="float32",
)
self.gate_correction_bias.set_value(gate_correction_bias_tensor)
gate_weight_key = self.weight_key_map.get("gate_weight_key", None)
assert gate_weight_key is not None, "gate_weight_key should not be None, please check model checkpoints"
gate_weight_tensor = get_tensor(state_dict.pop(gate_weight_key))
self.gate_weight = self.create_parameter(
shape=gate_weight_tensor.shape,
dtype="float32",
)
self.gate_weight.set_value(gate_weight_tensor.astype("float32"))
if self.fd_config.model_config.is_quantized:
self.quant_method.process_prequanted_weights(self, state_dict)
else:
self.gate_correction_bias = None
self.quant_method.create_weights(self, state_dict)
up_gate_proj_weight, down_proj_weight = self.load_gate_state_dict(
state_dict)
weight1_scale = None
weight2_scale = None
ffn1_in_scale = None
ffn2_in_scale = None
if self.moe_quant_type == "w4a8":
weight1_scale = []
weight2_scale = []
ffn1_in_scale = []
ffn2_in_scale = []
for j in range(self.num_experts):
weight1_scale.append(
get_tensor(
state_dict.pop(
self.ffn1_expert_weight_scale_key.format(
self.layer_idx, j))))
weight2_scale.append(
get_tensor(
state_dict.pop(
self.ffn2_expert_weight_scale_key.format(
self.layer_idx, j))))
ffn1_in_scale.append(
get_tensor(
state_dict.pop(
self.ffn1_expert_in_scale_key.format(
self.layer_idx, j))))
ffn2_in_scale.append(
get_tensor(
state_dict.pop(
self.ffn2_expert_in_scale_key.format(
self.layer_idx, j))))
# other weight is with compute_method
# different method may have different way to create weights
self.compute_method.create_weights(self, self.moe_compute_params,
up_gate_proj_weight,
down_proj_weight, None, None,
weight1_scale, weight2_scale,
ffn1_in_scale, ffn2_in_scale)
def forward(self, x, **kwargs):
def forward(self, x: paddle.Tensor):
"""
Defines the forward computation of the moe layer.
@@ -225,13 +225,9 @@ class FusedMoE(nn.Layer):
x (Tensor): Input tensor to the moe layer.
Returns:
Tensor: Output tensor.
Tensor: Output tensor.s
"""
out = self.compute_method.apply(self, self.moe_compute_params, x)
if self.tp_size > 1:
from fastdeploy.distributed.communication_op import \
tensor_model_parallel_all_reduce
tensor_model_parallel_all_reduce(out)
gate_out = paddle.matmul(x.cast("float32"), self.gate_weight)
out = self.quant_method.apply(self, x, gate_out)
return out