xiaoxiaohehe001 authored 2025-08-03 01:50:07 +08:00, committed by GitHub
parent 9307f2619b
commit 869626b0f4
6 changed files with 46 additions and 16 deletions

View File

@@ -332,7 +332,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
         self.moe_quant_type = "w4a8"
         self.pack_num = 2
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """
@@ -500,7 +500,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
         self.moe_quant_type = self.quant_config.algo
         self.pack_num = 1
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """

View File

@@ -62,7 +62,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous()
             create_and_set_parameter(layer, scale_name, quanted_weight_scale)
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """
@@ -72,9 +72,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
 
         up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
-            state_dict,
-            up_gate_proj_expert_weight_key,
-            down_proj_expert_weight_key,
+            state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
         )
         # self.check(layer, up_gate_proj_weights, down_proj_weights)
         up_gate_proj_weight_scale = []
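For context, layer.load_experts_weight(...) is consumed here as returning per-expert weight lists plus the logical expert ids that were actually loaded (the fourth value is discarded). A rough, hypothetical sketch of that assumed return contract with plain Python lists instead of the real tensor types:

# Hypothetical stand-in mirroring how the return value is unpacked above:
# (up_gate_proj_weights, down_proj_weights, logical_expert_ids, _).
def fake_load_experts_weight(state_dict, up_key, down_key, is_rearrange=False):
    expert_ids = sorted({int(k.split(".")[1]) for k in state_dict})  # toy keys: "experts.<id>.<proj>"
    ups = [state_dict[up_key.format(i)] for i in expert_ids]
    downs = [state_dict[down_key.format(i)] for i in expert_ids]
    return ups, downs, expert_ids, None

toy = {f"experts.{i}.{p}": f"tensor_{p}_{i}" for i in (0, 1) for p in ("up_gate_proj", "down_proj")}
ups, downs, ids, _ = fake_load_experts_weight(toy, "experts.{}.up_gate_proj", "experts.{}.down_proj")
print(ids, ups)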

View File

@@ -217,6 +217,7 @@ class FusedMoE(nn.Layer):
         state_dict: dict,
         up_gate_proj_expert_weight_key: str,
         down_proj_expert_weight_key: str,
+        is_rearrange: bool = False,
     ):
         """
         Load experts weight from state_dict.
@@ -245,7 +246,13 @@ class FusedMoE(nn.Layer):
         ]
         up_gate_proj_weights = []
         down_proj_weights = []
-        is_ffn_merged = up_gate_proj_expert_weight_key.format(self.expert_id_offset) in state_dict
+        if isinstance(state_dict, list):
+            state_dict = dict(state_dict)
+
+        is_ffn_merged = (
+            up_gate_proj_expert_weight_key.format(logical_expert_ids[0] if is_rearrange else self.expert_id_offset)
+            in state_dict
+        )
         if is_ffn_merged:
             for expert_idx in logical_expert_ids:
                 down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
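The probe key for the merged-FFN layout now depends on the load mode: on a rearranged reload the first logical expert id is used instead of this rank's static expert_id_offset, and a list-form state_dict is normalized to a dict first. A small self-contained sketch of that key-selection logic (function and argument names are stand-ins, not the real FusedMoE attributes):

# Stand-in reproduction of the probe above; expert_id_offset and the key template
# are illustrative values, not the real layer configuration.
def probe_merged_ffn(state_dict, key_template, logical_expert_ids, expert_id_offset, is_rearrange):
    if isinstance(state_dict, list):  # some loaders hand over (key, tensor) pairs
        state_dict = dict(state_dict)
    probe_id = logical_expert_ids[0] if is_rearrange else expert_id_offset
    return key_template.format(probe_id) in state_dict

keys = [("experts.7.up_gate_proj.weight", "t7"), ("experts.8.up_gate_proj.weight", "t8")]
# Initial load on a rank whose experts start at offset 7:
print(probe_merged_ffn(keys, "experts.{}.up_gate_proj.weight", [7, 8], 7, is_rearrange=False))  # True
# Rearranged reload where the logical ids no longer match the static offset:
print(probe_merged_ffn(keys, "experts.{}.up_gate_proj.weight", [8, 7], 0, is_rearrange=True))   # True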
@@ -381,7 +388,7 @@ class FusedMoE(nn.Layer):
         if self.fd_config.model_config.is_quantized:
             if getattr(self.fd_config.quant_config, "is_permuted", False):
-                self.quant_method.process_prequanted_weights(self, state_dict)
+                self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)
             else:
                 self.quant_method.create_weights(self, state_dict)
         else:

View File

@@ -112,7 +112,11 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
             num_local_ffn_keys.append(down_proj_in_scale_key)
 
             # for EP w4a8, we need all expert's activation_scale for up_gate_proj
-            for j in range(fd_config.model_config.moe_num_experts):
+            num_experts = fd_config.model_config.moe_num_experts
+            if isinstance(num_experts, list):
+                num_experts = num_experts[0]
+            for j in range(num_experts):
                 up_gate_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
                 num_local_ffn_keys.append(up_gate_proj_in_scale_key)
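Here moe_num_experts is treated as either a plain int or a list (presumably one count per modality in the multimodal configs), and only the first entry is used when expanding the per-expert activation-scale keys. A hedged sketch of that normalization with a made-up config object:

# Illustrative only: FakeModelConfig mimics the two shapes moe_num_experts can take.
class FakeModelConfig:
    def __init__(self, moe_num_experts):
        self.moe_num_experts = moe_num_experts

def expert_scale_keys(layer_id, model_config):
    num_experts = model_config.moe_num_experts
    if isinstance(num_experts, list):  # e.g. per-modality counts; use the first one
        num_experts = num_experts[0]
    return [
        f"ernie.layers.{layer_id}.mlp.experts.{j}.up_gate_proj.activation_scale"
        for j in range(num_experts)
    ]

print(len(expert_scale_keys(3, FakeModelConfig(64))))        # 64
print(len(expert_scale_keys(3, FakeModelConfig([64, 64]))))  # 64, list form handled the same way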

View File

@@ -95,11 +95,9 @@ class Ernie4_5_MLP(nn.Layer):
 
 
 class Ernie4_5_MoE(nn.Layer):
-    def __init__(self,
-                 fd_config: FDConfig,
-                 layer_id: int,
-                 prefix: str,
-                 redundant_table_manger: RedundantExpertManger = None) -> None:
+    def __init__(
+        self, fd_config: FDConfig, layer_id: int, prefix: str, redundant_table_manger: RedundantExpertManger = None
+    ) -> None:
         super().__init__()
         moe_quant_type = ""
         if hasattr(fd_config.quant_config, "moe_quant_type"):
@@ -176,6 +174,9 @@ class Ernie4_5_MoE(nn.Layer):
         if self.num_shared_experts > 0:
             self.shared_experts.load_state_dict(state_dict)
 
+    def update_state_dict(self, state_dict):
+        self.fused_moe.load_state_dict(state_dict, True)
+
     def forward(self, hidden_states: paddle.Tensor):
         out = self.fused_moe(hidden_states)
         if self.num_shared_experts > 0:
@@ -281,6 +282,9 @@ class Ernie4_5_DecoderLayer(nn.Layer):
         self.input_layernorm.load_state_dict(state_dict)
         self.post_attention_layernorm.load_state_dict(state_dict)
 
+    def update_state_dict(self, state_dict):
+        self.mlp.update_state_dict(state_dict)
+
     def forward(
         self,
         forward_meta: ForwardMeta,
@@ -321,6 +325,7 @@ class Ernie4_5_Model(nn.Layer):
         self.num_layers = fd_config.model_config.num_hidden_layers
         fd_config.model_config.pretrained_config.prefix_name = "ernie"
+        self.fd_config = fd_config
 
         self.redundant_table_manger = None
         if fd_config.model_config.enable_redundant_experts is True:
@@ -372,6 +377,22 @@ class Ernie4_5_Model(nn.Layer):
             logger.info(f"Start load layer {i}")
             self.layers[i].load_state_dict(state_dict)
 
+    def update_state_dict(self, state_dict):
+        """
+        Update model parameters from a given state dictionary.
+
+        Args:
+            state_dict (dict[str, np.ndarray | paddle.Tensor]):
+                A dictionary containing model parameters, where keys are parameter names
+                and values are NumPy arrays or PaddlePaddle tensors.
+        """
+        for i in range(
+            self.fd_config.model_config.moe_layer_start_index,
+            self.fd_config.model_config.num_hidden_layers,
+        ):
+            logger.info(f"Start update layer {i}")
+            self.layers[i].update_state_dict(state_dict)
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
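Taken together, the new update_state_dict methods form a delegation chain: Ernie4_5_Model iterates its MoE layers, each decoder layer forwards to its MLP block, and Ernie4_5_MoE finally calls FusedMoE.load_state_dict(state_dict, True) so the experts are reloaded in rearrangement mode. A compressed sketch of that chain with stub classes (the real code goes through FusedMoE and the quant methods shown above, not these stand-ins):

# Stub classes that mirror the delegation added in this commit; no Paddle required.
class FakeFusedMoE:
    def load_state_dict(self, state_dict, is_rearrange=False):
        print(f"reloading {len(state_dict)} expert tensors, is_rearrange={is_rearrange}")

class FakeMoEBlock:  # plays the role of Ernie4_5_MoE
    def __init__(self):
        self.fused_moe = FakeFusedMoE()
    def update_state_dict(self, state_dict):
        self.fused_moe.load_state_dict(state_dict, True)

class FakeDecoderLayer:  # plays the role of Ernie4_5_DecoderLayer
    def __init__(self):
        self.mlp = FakeMoEBlock()
    def update_state_dict(self, state_dict):
        self.mlp.update_state_dict(state_dict)

class FakeModel:  # plays the role of Ernie4_5_Model
    def __init__(self, moe_layer_start_index, num_hidden_layers):
        self.layers = [FakeDecoderLayer() for _ in range(num_hidden_layers)]
        self.moe_layer_start_index = moe_layer_start_index
    def update_state_dict(self, state_dict):
        for i in range(self.moe_layer_start_index, len(self.layers)):
            self.layers[i].update_state_dict(state_dict)

FakeModel(moe_layer_start_index=1, num_hidden_layers=3).update_state_dict({"w0": 0, "w1": 1})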

View File

@@ -99,8 +99,8 @@ class Ernie4_5_VLMoE(nn.Layer):
         assert text_moe_layer_start_index <= text_moe_layer_end_index
 
         moe_quant_type = ""
         if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
-            moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
+            if hasattr(fd_config.quant_config, "moe_quant_type"):
+                moe_quant_type = fd_config.quant_config.moe_quant_type
         if layer_id >= text_moe_layer_start_index and layer_id <= text_moe_layer_end_index:
             if moe_quant_type == "tensor_wise_fp8" or (