fix_eplb (#3160)
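
This commit threads an is_rearrange flag through the MoE weight-loading path (load_experts_weight and process_prequanted_weights) and adds an update_state_dict chain to the ERNIE 4.5 model classes, so expert weights can be reloaded in place after EPLB (expert-parallel load balancing) rearranges experts across ranks. It also accepts moe_num_experts given as a list in load_ep_checkpoint and fixes the MoE quant-type probe in Ernie4_5_VLMoE.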
@@ -332,7 +332,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
         self.moe_quant_type = "w4a8"
         self.pack_num = 2

-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """
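
The new keyword defaults to False, so existing call sites keep working; only the rearrangement path passes True. The same signature change is applied to the other MoE method classes below.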
@@ -500,7 +500,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
         self.moe_quant_type = self.quant_config.algo
         self.pack_num = 1

-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """
@@ -62,7 +62,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             quanted_weight_scale = quanted_weight_scale.transpose([0, 2, 1]).contiguous()
             create_and_set_parameter(layer, scale_name, quanted_weight_scale)

-    def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
         """
         Paddle cutlass process prequanted weights.
         """
@@ -72,9 +72,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)

         up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
-            state_dict,
-            up_gate_proj_expert_weight_key,
-            down_proj_expert_weight_key,
+            state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
         )
         # self.check(layer, up_gate_proj_weights, down_proj_weights)
         up_gate_proj_weight_scale = []
@@ -217,6 +217,7 @@ class FusedMoE(nn.Layer):
         state_dict: dict,
         up_gate_proj_expert_weight_key: str,
         down_proj_expert_weight_key: str,
+        is_rearrange: bool = False,
     ):
         """
         Load experts weight from state_dict.
@@ -245,7 +246,13 @@ class FusedMoE(nn.Layer):
         ]
         up_gate_proj_weights = []
         down_proj_weights = []
-        is_ffn_merged = up_gate_proj_expert_weight_key.format(self.expert_id_offset) in state_dict
+        if isinstance(state_dict, list):
+            state_dict = dict(state_dict)
+        is_ffn_merged = (
+            up_gate_proj_expert_weight_key.format(logical_expert_ids[0] if is_rearrange else self.expert_id_offset)
+            in state_dict
+        )
         if is_ffn_merged:
             for expert_idx in logical_expert_ids:
                 down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
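
After rearrangement, the experts materialized on a rank are given by logical_expert_ids rather than a contiguous block starting at self.expert_id_offset, so the merged-FFN probe has to use the first logical id (and a list-form state_dict is first coerced to a dict). A minimal sketch of the probe logic, not FastDeploy code; the key template and ids are illustrative:

# Why the merged-FFN probe must use a logical expert id once EPLB has
# rearranged experts. Key template and ids below are made up.
key = "ernie.layers.0.mlp.experts.{}.up_gate_proj.weight"
state_dict = {key.format(i): None for i in (7, 19, 42, 55)}  # experts on this rank

expert_id_offset = 32                  # contiguous-block start without rearrangement
logical_expert_ids = [7, 19, 42, 55]   # assignment after rearrangement

is_rearrange = True
probe_id = logical_expert_ids[0] if is_rearrange else expert_id_offset
print(key.format(probe_id) in state_dict)  # True; probing with 32 would miss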
@@ -381,7 +388,7 @@ class FusedMoE(nn.Layer):

         if self.fd_config.model_config.is_quantized:
             if getattr(self.fd_config.quant_config, "is_permuted", False):
-                self.quant_method.process_prequanted_weights(self, state_dict)
+                self.quant_method.process_prequanted_weights(self, state_dict, is_rearrange)
             else:
                 self.quant_method.create_weights(self, state_dict)
         else:
@@ -112,7 +112,11 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
             num_local_ffn_keys.append(down_proj_in_scale_key)

             # for EP w4a8, we need all expert's activation_scale for up_gate_proj
-            for j in range(fd_config.model_config.moe_num_experts):
+            num_experts = fd_config.model_config.moe_num_experts
+            if isinstance(num_experts, list):
+                num_experts = num_experts[0]
+
+            for j in range(num_experts):
                 up_gate_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
                 num_local_ffn_keys.append(up_gate_proj_in_scale_key)

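For context, moe_num_experts may be a plain int or a list (plausibly one count per modality in the VL variants; that reading is an assumption). Taking the first entry keeps the key enumeration working either way. A tiny self-contained sketch of the normalization:

def normalize_num_experts(moe_num_experts):
    # Mirror the diff: accept an int or a list of ints and use the
    # first entry when a list is given.
    if isinstance(moe_num_experts, list):
        return moe_num_experts[0]
    return moe_num_experts

assert normalize_num_experts(64) == 64
assert normalize_num_experts([64, 8]) == 64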
@@ -95,11 +95,9 @@ class Ernie4_5_MLP(nn.Layer):


 class Ernie4_5_MoE(nn.Layer):
-    def __init__(self,
-                 fd_config: FDConfig,
-                 layer_id: int,
-                 prefix: str,
-                 redundant_table_manger: RedundantExpertManger = None) -> None:
+    def __init__(
+        self, fd_config: FDConfig, layer_id: int, prefix: str, redundant_table_manger: RedundantExpertManger = None
+    ) -> None:
         super().__init__()
         moe_quant_type = ""
         if hasattr(fd_config.quant_config, "moe_quant_type"):
@@ -176,6 +174,9 @@ class Ernie4_5_MoE(nn.Layer):
         if self.num_shared_experts > 0:
             self.shared_experts.load_state_dict(state_dict)

+    def update_state_dict(self, state_dict):
+        self.fused_moe.load_state_dict(state_dict, True)
+
     def forward(self, hidden_states: paddle.Tensor):
         out = self.fused_moe(hidden_states)
         if self.num_shared_experts > 0:
@@ -281,6 +282,9 @@ class Ernie4_5_DecoderLayer(nn.Layer):
         self.input_layernorm.load_state_dict(state_dict)
         self.post_attention_layernorm.load_state_dict(state_dict)

+    def update_state_dict(self, state_dict):
+        self.mlp.update_state_dict(state_dict)
+
     def forward(
         self,
         forward_meta: ForwardMeta,
@@ -321,6 +325,7 @@ class Ernie4_5_Model(nn.Layer):

         self.num_layers = fd_config.model_config.num_hidden_layers
         fd_config.model_config.pretrained_config.prefix_name = "ernie"
+        self.fd_config = fd_config

         self.redundant_table_manger = None
         if fd_config.model_config.enable_redundant_experts is True:
@@ -372,6 +377,22 @@ class Ernie4_5_Model(nn.Layer):
             logger.info(f"Start load layer {i}")
             self.layers[i].load_state_dict(state_dict)

+    def update_state_dict(self, state_dict):
+        """
+        Update model parameters from a given state dictionary.
+
+        Args:
+            state_dict (dict[str, np.ndarray | paddle.Tensor]):
+                A dictionary containing model parameters, where keys are parameter names
+                and values are NumPy arrays or PaddlePaddle tensors.
+        """
+        for i in range(
+            self.fd_config.model_config.moe_layer_start_index,
+            self.fd_config.model_config.num_hidden_layers,
+        ):
+            logger.info(f"Start update layer {i}")
+            self.layers[i].update_state_dict(state_dict)
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
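Taken together, the new methods form a delegation chain (Ernie4_5_Model.update_state_dict -> Ernie4_5_DecoderLayer.update_state_dict -> Ernie4_5_MoE.update_state_dict -> FusedMoE.load_state_dict(state_dict, True)), so a rebalance touches only the MoE layers while exercising the is_rearrange path. A self-contained mock of the chain, for illustration only:

# Mock of the update_state_dict chain added by this commit; the real
# classes live in FastDeploy, these stand-ins only show the call flow.
class MockFusedMoE:
    def load_state_dict(self, state_dict, is_rearrange=False):
        print(f"reload experts, is_rearrange={is_rearrange}")

class MockMoE:
    def __init__(self):
        self.fused_moe = MockFusedMoE()

    def update_state_dict(self, state_dict):
        self.fused_moe.load_state_dict(state_dict, True)

class MockDecoderLayer:
    def __init__(self):
        self.mlp = MockMoE()

    def update_state_dict(self, state_dict):
        self.mlp.update_state_dict(state_dict)

MockDecoderLayer().update_state_dict({})  # prints: reload experts, is_rearrange=True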
@@ -99,8 +99,8 @@ class Ernie4_5_VLMoE(nn.Layer):
         assert text_moe_layer_start_index <= text_moe_layer_end_index

         moe_quant_type = ""
-        if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
-            moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
+        if hasattr(fd_config.quant_config, "moe_quant_type"):
+            moe_quant_type = fd_config.quant_config.moe_quant_type

         if layer_id >= text_moe_layer_start_index and layer_id <= text_moe_layer_end_index:
             if moe_quant_type == "tensor_wise_fp8" or (
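
The old probe called the quant config's name() method, whose return value (the method's registered name) need not match the moe_quant_type string compared against below; the fix reads the same attribute that Ernie4_5_MoE already checks. A minimal sketch of the difference, with a made-up stand-in config class and attribute values:

# Illustrative stand-in for a quant config; values are not from FastDeploy.
class FakeQuantConfig:
    moe_quant_type = "tensor_wise_fp8"

    def name(self):
        return "wfp8afp8"  # registered algorithm name, not the MoE quant type

cfg = FakeQuantConfig()
old_probe = getattr(cfg, "name", lambda: "")()  # "wfp8afp8"
new_probe = cfg.moe_quant_type if hasattr(cfg, "moe_quant_type") else ""  # "tensor_wise_fp8"
print(old_probe == "tensor_wise_fp8", new_probe == "tensor_wise_fp8")  # False True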