fix_eplb (#3160)
@@ -332,7 +332,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
         self.moe_quant_type = "w4a8"
         self.pack_num = 2

-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """
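This change threads a new is_rearrange flag (default False) through every quantization backend's process_prequanted_weights, so existing call sites keep working while the EPLB path can request weights in the rearranged expert order. A minimal sketch of the pattern, using a hypothetical DummyMoEMethod rather than the real backend classes:

# Minimal sketch, not the real backend: a defaulted keyword keeps old call
# sites valid while new callers can pass the extra flag.
class DummyMoEMethod:
    def process_prequanted_weights(self, layer, state_dict, is_rearrange: bool = False):
        order = "rearranged" if is_rearrange else "original"
        print(f"loading prequantized expert weights in {order} order")

method = DummyMoEMethod()
method.process_prequanted_weights(None, {})                     # old-style call still works
method.process_prequanted_weights(None, {}, is_rearrange=True)  # new EPLB-aware call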
@@ -500,7 +500,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
         self.moe_quant_type = self.quant_config.algo
         self.pack_num = 1

-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """
@@ -62,7 +62,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous()
             create_and_set_parameter(layer, scale_name, quanted_weight_scale)

-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """
@@ -72,9 +72,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)

         up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
-            state_dict,
-            up_gate_proj_expert_weight_key,
-            down_proj_expert_weight_key,
+            state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
         )
         # self.check(layer, up_gate_proj_weights, down_proj_weights)
         up_gate_proj_weight_scale = []
@@ -217,6 +217,7 @@ class FusedMoE(nn.Layer):
         state_dict: dict,
         up_gate_proj_expert_weight_key: str,
         down_proj_expert_weight_key: str,
+        is_rearrange: bool = False,
     ):
         """
         Load experts weight from state_dict.
@@ -245,7 +246,13 @@ class FusedMoE(nn.Layer):
         ]
         up_gate_proj_weights = []
         down_proj_weights = []
-        is_ffn_merged = up_gate_proj_expert_weight_key.format(self.expert_id_offset) in state_dict
+
+        if isinstance(state_dict, list):
+            state_dict = dict(state_dict)
+        is_ffn_merged = (
+            up_gate_proj_expert_weight_key.format(logical_expert_ids[0] if is_rearrange else self.expert_id_offset)
+            in state_dict
+        )
         if is_ffn_merged:
             for expert_idx in logical_expert_ids:
                 down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
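Two behaviors change in load_experts_weight: state_dict may now arrive as a list of (name, tensor) pairs and is normalized with dict(), and the merged-FFN probe keys off the first logical expert id when is_rearrange is set instead of always using self.expert_id_offset. A self-contained sketch of that probe, with invented keys and expert ids:

# Minimal sketch, with invented keys and expert ids, of the merged-FFN probe.
up_gate_proj_expert_weight_key = "ernie.layers.1.mlp.experts.{}.up_gate_proj.weight"

state_dict = [  # may arrive as a list of (name, tensor) pairs
    ("ernie.layers.1.mlp.experts.7.up_gate_proj.weight", "tensor_7"),
]
if isinstance(state_dict, list):
    state_dict = dict(state_dict)

logical_expert_ids = [7, 3, 5]  # hypothetical EPLB-rearranged expert order
expert_id_offset = 0            # offset used when the table is not rearranged
is_rearrange = True

is_ffn_merged = (
    up_gate_proj_expert_weight_key.format(logical_expert_ids[0] if is_rearrange else expert_id_offset)
    in state_dict
)
print(is_ffn_merged)  # True: the probe checks expert 7, not expert 0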
@@ -381,7 +388,7 @@ class FusedMoE(nn.Layer):

         if self.fd_config.model_config.is_quantized:
             if getattr(self.fd_config.quant_config, "is_permuted", False):
-                self.quant_method.process_prequanted_weights(self, state_dict)
+                self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)
             else:
                 self.quant_method.create_weights(self, state_dict)
         else:
@@ -112,7 +112,11 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
             num_local_ffn_keys.append(down_proj_in_scale_key)

             # for EP w4a8, we need all expert's activation_scale for up_gate_proj
-            for j in range(fd_config.model_config.moe_num_experts):
+            num_experts = fd_config.model_config.moe_num_experts
+            if isinstance(num_experts, list):
+                num_experts = num_experts[0]
+
+            for j in range(num_experts):
                 up_gate_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
                 num_local_ffn_keys.append(up_gate_proj_in_scale_key)

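moe_num_experts can evidently be either an int or a list (for example per-modality counts in the VL models), so the loop bound is normalized before iterating; this patch takes the first list entry. A small sketch of the normalization with invented values:

# Minimal sketch, with invented values, of the int-or-list normalization.
def normalize_num_experts(moe_num_experts):
    if isinstance(moe_num_experts, list):
        return moe_num_experts[0]  # convention in this patch: take the first entry
    return moe_num_experts

print(normalize_num_experts(64))        # 64
print(normalize_num_experts([64, 64]))  # 64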
@@ -95,11 +95,9 @@ class Ernie4_5_MLP(nn.Layer):


 class Ernie4_5_MoE(nn.Layer):
-    def __init__(self,
-                 fd_config: FDConfig,
-                 layer_id: int,
-                 prefix: str,
-                 redundant_table_manger: RedundantExpertManger = None) -> None:
+    def __init__(
+        self, fd_config: FDConfig, layer_id: int, prefix: str, redundant_table_manger: RedundantExpertManger = None
+    ) -> None:
         super().__init__()
         moe_quant_type = ""
         if hasattr(fd_config.quant_config, "moe_quant_type"):
@@ -176,6 +174,9 @@ class Ernie4_5_MoE(nn.Layer):
         if self.num_shared_experts > 0:
             self.shared_experts.load_state_dict(state_dict)

+    def update_state_dict(self, state_dict):
+        self.fused_moe.load_state_dict(state_dict, True)
+
     def forward(self, hidden_states: paddle.Tensor):
         out = self.fused_moe(hidden_states)
         if self.num_shared_experts > 0:
@@ -281,6 +282,9 @@ class Ernie4_5_DecoderLayer(nn.Layer):
         self.input_layernorm.load_state_dict(state_dict)
         self.post_attention_layernorm.load_state_dict(state_dict)

+    def update_state_dict(self, state_dict):
+        self.mlp.update_state_dict(state_dict)
+
     def forward(
         self,
         forward_meta: ForwardMeta,
@@ -321,6 +325,7 @@ class Ernie4_5_Model(nn.Layer):

         self.num_layers = fd_config.model_config.num_hidden_layers
         fd_config.model_config.pretrained_config.prefix_name = "ernie"
+        self.fd_config = fd_config

         self.redundant_table_manger = None
         if fd_config.model_config.enable_redundant_experts is True:
@@ -372,6 +377,22 @@ class Ernie4_5_Model(nn.Layer):
             logger.info(f"Start load layer {i}")
             self.layers[i].load_state_dict(state_dict)

+    def update_state_dict(self, state_dict):
+        """
+        Update model parameters from a given state dictionary.
+
+        Args:
+            state_dict (dict[str, np.ndarray | paddle.Tensor]):
+                A dictionary containing model parameters, where keys are parameter names
+                and values are NumPy arrays or PaddlePaddle tensors.
+        """
+        for i in range(
+            self.fd_config.model_config.moe_layer_start_index,
+            self.fd_config.model_config.num_hidden_layers,
+        ):
+            logger.info(f"Start update layer {i}")
+            self.layers[i].update_state_dict(state_dict)
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
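The new update_state_dict methods form a delegation chain: Ernie4_5_Model.update_state_dict walks only the layers from moe_layer_start_index onward (which is why self.fd_config is now stored on the model), each Ernie4_5_DecoderLayer forwards to its mlp, and Ernie4_5_MoE finally calls fused_moe.load_state_dict(state_dict, True) so that only the expert weights are reloaded. A toy sketch of the chain with simplified stand-in classes:

# Minimal sketch with stand-in classes; the real classes carry much more state.
class ToyFusedMoE:
    def load_state_dict(self, state_dict, is_rearrange=False):
        print(f"reloading expert weights, is_rearrange={is_rearrange}")

class ToyMoEBlock:  # stands in for Ernie4_5_MoE
    def __init__(self):
        self.fused_moe = ToyFusedMoE()

    def update_state_dict(self, state_dict):
        self.fused_moe.load_state_dict(state_dict, True)

class ToyDecoderLayer:  # stands in for Ernie4_5_DecoderLayer
    def __init__(self):
        self.mlp = ToyMoEBlock()

    def update_state_dict(self, state_dict):
        self.mlp.update_state_dict(state_dict)

layers = [ToyDecoderLayer() for _ in range(4)]
moe_layer_start_index = 2  # dense layers before this index are skipped
for i in range(moe_layer_start_index, len(layers)):
    layers[i].update_state_dict({})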
@@ -99,8 +99,8 @@ class Ernie4_5_VLMoE(nn.Layer):
         assert text_moe_layer_start_index <= text_moe_layer_end_index

         moe_quant_type = ""
-        if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
-            moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
+        if hasattr(fd_config.quant_config, "moe_quant_type"):
+            moe_quant_type = fd_config.quant_config.moe_quant_type

         if layer_id >= text_moe_layer_start_index and layer_id <= text_moe_layer_end_index:
             if moe_quant_type == "tensor_wise_fp8" or (
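In Ernie4_5_VLMoE the quantization-type probe now reads quant_config.moe_quant_type directly instead of calling the config's name() accessor, matching the check used in Ernie4_5_MoE above. A minimal before/after sketch on a stand-in config object (the name() return value here is invented):

# Minimal sketch on a stand-in config; attribute values are invented.
class StubQuantConfig:
    moe_quant_type = "tensor_wise_fp8"

    def name(self):
        return "mix_quant"

cfg = StubQuantConfig()

old_probe = getattr(cfg, "name", lambda: "")()                            # whatever name() reports
new_probe = cfg.moe_quant_type if hasattr(cfg, "moe_quant_type") else ""  # the MoE-specific field
print(old_probe, new_probe)  # mix_quant tensor_wise_fp8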