From 869626b0f4ec44f88e208e25f670af962e3cb0c7 Mon Sep 17 00:00:00 2001
From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com>
Date: Sun, 3 Aug 2025 01:50:07 +0800
Subject: [PATCH] fix_eplb (#3160)

---
 .../layers/moe/fused_moe_cutlass_backend.py   |  4 +--
 .../layers/moe/fused_moe_deepgemm_backend.py  |  6 ++--
 fastdeploy/model_executor/layers/moe/moe.py   | 11 +++++--
 .../model_executor/load_weight_utils.py       |  6 +++-
 .../model_executor/models/ernie4_5_moe.py     | 31 ++++++++++++++++---
 .../models/ernie4_5_vl/ernie4_5_vl_moe.py     |  4 +--
 6 files changed, 46 insertions(+), 16 deletions(-)

diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
index 87ac8fbcb..2fe8d1c57 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
@@ -332,7 +332,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
         self.moe_quant_type = "w4a8"
         self.pack_num = 2
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """
@@ -500,7 +500,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
         self.moe_quant_type = self.quant_config.algo
         self.pack_num = 1
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
index 4abee5c94..bf39adffd 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
@@ -62,7 +62,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous()
             create_and_set_parameter(layer, scale_name, quanted_weight_scale)
 
-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
        """
        Paddle cutlass process prequanted weights.
        """
@@ -72,9 +72,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
        down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)

        up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
-            state_dict,
-            up_gate_proj_expert_weight_key,
-            down_proj_expert_weight_key,
+            state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
        )
        # self.check(layer, up_gate_proj_weights, down_proj_weights)
        up_gate_proj_weight_scale = []
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index 9b172247d..627041b17 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -217,6 +217,7 @@ class FusedMoE(nn.Layer):
         state_dict: dict,
         up_gate_proj_expert_weight_key: str,
         down_proj_expert_weight_key: str,
+        is_rearrange: bool = False,
     ):
         """
         Load experts weight from state_dict.
@@ -245,7 +246,13 @@ class FusedMoE(nn.Layer):
         ]
         up_gate_proj_weights = []
         down_proj_weights = []
-        is_ffn_merged = up_gate_proj_expert_weight_key.format(self.expert_id_offset) in state_dict
+
+        if isinstance(state_dict, list):
+            state_dict = dict(state_dict)
+        is_ffn_merged = (
+            up_gate_proj_expert_weight_key.format(logical_expert_ids[0] if is_rearrange else self.expert_id_offset)
+            in state_dict
+        )
         if is_ffn_merged:
             for expert_idx in logical_expert_ids:
                 down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
@@ -381,7 +388,7 @@ class FusedMoE(nn.Layer):
 
         if self.fd_config.model_config.is_quantized:
             if getattr(self.fd_config.quant_config, "is_permuted", False):
-                self.quant_method.process_prequanted_weights(self, state_dict)
+                self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)
             else:
                 self.quant_method.create_weights(self, state_dict)
         else:
diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index 2172d0f82..c894e9e23 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -112,7 +112,11 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
             num_local_ffn_keys.append(down_proj_in_scale_key)
 
         # for EP w4a8, we need all expert's activation_scale for up_gate_proj
-        for j in range(fd_config.model_config.moe_num_experts):
+        num_experts = fd_config.model_config.moe_num_experts
+        if isinstance(num_experts, list):
+            num_experts = num_experts[0]
+
+        for j in range(num_experts):
             up_gate_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
             num_local_ffn_keys.append(up_gate_proj_in_scale_key)
 
diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py
index 1c568382e..fa12a099b 100644
--- a/fastdeploy/model_executor/models/ernie4_5_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_moe.py
@@ -95,11 +95,9 @@ class Ernie4_5_MLP(nn.Layer):
 
 
 class Ernie4_5_MoE(nn.Layer):
-    def __init__(self,
-                 fd_config: FDConfig,
-                 layer_id: int,
-                 prefix: str,
-                 redundant_table_manger: RedundantExpertManger = None) -> None:
+    def __init__(
+        self, fd_config: FDConfig, layer_id: int, prefix: str, redundant_table_manger: RedundantExpertManger = None
+    ) -> None:
         super().__init__()
         moe_quant_type = ""
         if hasattr(fd_config.quant_config, "moe_quant_type"):
@@ -176,6 +174,9 @@ class Ernie4_5_MoE(nn.Layer):
         if self.num_shared_experts > 0:
             self.shared_experts.load_state_dict(state_dict)
 
+    def update_state_dict(self, state_dict):
+        self.fused_moe.load_state_dict(state_dict, True)
+
     def forward(self, hidden_states: paddle.Tensor):
         out = self.fused_moe(hidden_states)
         if self.num_shared_experts > 0:
@@ -281,6 +282,9 @@ class Ernie4_5_DecoderLayer(nn.Layer):
         self.input_layernorm.load_state_dict(state_dict)
         self.post_attention_layernorm.load_state_dict(state_dict)
 
+    def update_state_dict(self, state_dict):
+        self.mlp.update_state_dict(state_dict)
+
     def forward(
         self,
         forward_meta: ForwardMeta,
@@ -321,6 +325,7 @@ class Ernie4_5_Model(nn.Layer):
 
         self.num_layers = fd_config.model_config.num_hidden_layers
         fd_config.model_config.pretrained_config.prefix_name = "ernie"
+        self.fd_config = fd_config
 
         self.redundant_table_manger = None
         if fd_config.model_config.enable_redundant_experts is True:
@@ -372,6 +377,22 @@ class Ernie4_5_Model(nn.Layer):
             logger.info(f"Start load layer {i}")
             self.layers[i].load_state_dict(state_dict)
 
+    def update_state_dict(self, state_dict):
+        """
+        Update model parameters from a given state dictionary.
+
+        Args:
+            state_dict (dict[str, np.ndarray | paddle.Tensor]):
+                A dictionary containing model parameters, where keys are parameter names
+                and values are NumPy arrays or PaddlePaddle tensors.
+        """
+        for i in range(
+            self.fd_config.model_config.moe_layer_start_index,
+            self.fd_config.model_config.num_hidden_layers,
+        ):
+            logger.info(f"Start update layer {i}")
+            self.layers[i].update_state_dict(state_dict)
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
index 2dd562135..428a6ecd3 100644
--- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
+++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py
@@ -99,8 +99,8 @@ class Ernie4_5_VLMoE(nn.Layer):
         assert text_moe_layer_start_index <= text_moe_layer_end_index
 
         moe_quant_type = ""
-        if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
-            moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
+        if hasattr(fd_config.quant_config, "moe_quant_type"):
+            moe_quant_type = fd_config.quant_config.moe_quant_type
 
         if layer_id >= text_moe_layer_start_index and layer_id <= text_moe_layer_end_index:
             if moe_quant_type == "tensor_wise_fp8" or (
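
Usage sketch (illustrative only): a minimal example of how the new update_state_dict() entry point added above might be driven after an EPLB expert rearrangement. Only model.update_state_dict() and the acceptance of a list of (name, tensor) pairs come from this patch; the reload_rearranged_experts helper and the rearranged_weights mapping are assumptions made up for the example.

def reload_rearranged_experts(model, rearranged_weights):
    """Push rearranged expert weights back into the MoE layers.

    `model` is assumed to be an already-constructed Ernie4_5_Model and
    `rearranged_weights` a hypothetical {parameter name: tensor} mapping
    produced by an external EPLB rearrangement step.
    """
    # load_experts_weight() now normalizes a list of (name, tensor) pairs
    # with dict(), so either a dict or a list of pairs can be passed through.
    state_dict = list(rearranged_weights.items())

    # Only the MoE layers (moe_layer_start_index .. num_hidden_layers - 1) are
    # refreshed; update_state_dict() calls load_state_dict(state_dict, True),
    # i.e. is_rearrange=True, so load_experts_weight() keys off
    # logical_expert_ids[0] instead of expert_id_offset.
    model.update_state_dict(state_dict)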
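
Similarly, a small standalone sketch of the moe_num_experts handling introduced in load_ep_checkpoint(): multimodal ERNIE configs may store the expert count as a list of per-modality values (e.g. text and image experts) while text-only configs store a single int; the example values below are made up.

def normalize_num_experts(moe_num_experts):
    # Mirrors the patch: when the config carries a list, the w4a8
    # activation-scale keys are enumerated with the first entry.
    if isinstance(moe_num_experts, list):
        return moe_num_experts[0]
    return moe_num_experts

assert normalize_num_experts(64) == 64
assert normalize_num_experts([64, 64]) == 64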