[NewFeatures] support eplb (#3547)

* [NewFeatures] support eplb

* fix eplb
Author: xiaoxiaohehe001
Date: 2025-08-26 16:19:30 +08:00
Committed by: GitHub
parent 56e2d7e668
commit 9afa236e39
17 changed files with 174 additions and 67 deletions


@@ -49,6 +49,7 @@ from fastdeploy.model_executor.models.model_base import ModelForCasualLM
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
from fastdeploy.model_executor.models.utils import WeightMeta
from fastdeploy.worker.experts_manager import RedundantExpertManger
class Ernie4_5_MLP(nn.Layer):
@@ -97,7 +98,9 @@ class Ernie4_5_MLP(nn.Layer):
class Ernie4_5_MoE(nn.Layer):
-def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None:
+def __init__(
+    self, fd_config: FDConfig, layer_id: int, prefix: str, redundant_table_manger: RedundantExpertManger = None
+) -> None:
super().__init__()
moe_quant_type = ""
if hasattr(fd_config.quant_config, "moe_quant_type"):
@@ -175,6 +178,7 @@ class Ernie4_5_MoE(nn.Layer):
top_k=fd_config.model_config.moe_k,
layer_idx=layer_id,
gate_correction_bias=None,
redundant_table_manger=redundant_table_manger,
weight_key_map=weight_key_map,
)
@@ -209,6 +213,9 @@ class Ernie4_5_MoE(nn.Layer):
if self.num_shared_experts > 0:
self.shared_experts.load_state_dict(state_dict)
def update_state_dict(self, state_dict):
self.fused_moe.load_state_dict(state_dict, True)
def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int):
token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size
# AllGather will hang when the data shapes on multi-ranks are different!
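The ceiling division and the AllGather comment above capture the key invariant: every rank must contribute a slice of the same shape, so the trailing rank pads when token_num does not divide evenly. A minimal sketch of that bookkeeping in plain Python; `shard_bounds` and its padding count are illustrative names, not part of this diff:

```python
# Minimal sketch (names assumed) of why the ceiling division keeps AllGather
# shapes identical across ranks: every rank contributes exactly
# token_num_per_rank rows, and a trailing rank pads with zero rows if needed.
def shard_bounds(token_num: int, tensor_parallel_size: int, rank: int):
    token_num_per_rank = (token_num + tensor_parallel_size - 1) // tensor_parallel_size
    start = min(rank * token_num_per_rank, token_num)
    end = min(start + token_num_per_rank, token_num)
    pad_rows = token_num_per_rank - (end - start)  # zero rows this rank must append
    return start, end, pad_rows

# 10 tokens on 4 ranks: every rank sends 3 rows; the last rank holds 1 real row
# and pads 2 zero rows so the gathered shapes match.
for rank in range(4):
    print(rank, shard_bounds(10, 4, rank))
```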
@@ -287,6 +294,7 @@ class Ernie4_5_DecoderLayer(nn.Layer):
def __init__(
self,
fd_config: FDConfig,
redundant_table_manger: RedundantExpertManger = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -305,6 +313,7 @@ class Ernie4_5_DecoderLayer(nn.Layer):
self.mlp = Ernie4_5_MoE(
fd_config=fd_config,
layer_id=layer_id,
redundant_table_manger=redundant_table_manger,
prefix=f"{prefix}.mlp",
)
else:
@@ -334,6 +343,9 @@ class Ernie4_5_DecoderLayer(nn.Layer):
self.input_layernorm.load_state_dict(state_dict)
self.post_attention_layernorm.load_state_dict(state_dict)
def update_state_dict(self, state_dict):
self.mlp.update_state_dict(state_dict)
def forward(
self,
forward_meta: ForwardMeta,
@@ -374,6 +386,15 @@ class Ernie4_5_Model(nn.Layer):
self.num_layers = fd_config.model_config.num_hidden_layers
fd_config.model_config.pretrained_config.prefix_name = "ernie"
self.fd_config = fd_config
self.redundant_table_manger = None
if fd_config.model_config.enable_redundant_experts is True:
self.redundant_table_manger = RedundantExpertManger(
n_routed_experts=fd_config.model_config.moe_num_experts,
num_hidden_layers=fd_config.model_config.num_hidden_layers,
redundant_experts_num=fd_config.model_config.redundant_experts_num,
ep_size=fd_config.parallel_config.expert_parallel_size,
)
self.embed_tokens = VocabParallelEmbedding(
fd_config=fd_config,
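The hunk above only shows how RedundantExpertManger is constructed, not what it stores. A plausible reading of its arguments is a per-layer table mapping physical expert slots (the routed experts plus redundant_experts_num replicas, sharded over ep_size EP ranks) back to logical experts, which a rebalancing step can rewrite toward the hottest experts. The toy class below sketches that idea under those assumptions; it is not the FastDeploy implementation:

```python
import numpy as np

class ToyRedundantExpertTable:
    """Illustrative stand-in for RedundantExpertManger (assumption, not the real class)."""

    def __init__(self, n_routed_experts, num_hidden_layers, redundant_experts_num, ep_size):
        self.n_routed = n_routed_experts
        n_physical = n_routed_experts + redundant_experts_num
        assert n_physical % ep_size == 0, "physical experts must shard evenly over EP ranks"
        # Initially the redundant slots simply replicate experts 0..redundant_experts_num-1.
        row = np.concatenate([np.arange(n_routed_experts), np.arange(redundant_experts_num)])
        self.logical_of_physical = np.tile(row, (num_hidden_layers, 1))

    def rebalance(self, layer_id, expert_load):
        # Point this layer's redundant slots at the currently most-loaded experts.
        n_redundant = self.logical_of_physical.shape[1] - self.n_routed
        hottest = np.argsort(expert_load)[::-1][:n_redundant]
        self.logical_of_physical[layer_id, self.n_routed:] = hottest
```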
@@ -387,6 +408,7 @@ class Ernie4_5_Model(nn.Layer):
[
Ernie4_5_DecoderLayer(
fd_config=fd_config,
redundant_table_manger=self.redundant_table_manger,
prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}",
)
for i in range(self.num_layers)
@@ -415,6 +437,22 @@ class Ernie4_5_Model(nn.Layer):
logger.info(f"Start load layer {i}")
self.layers[i].load_state_dict(state_dict)
def update_state_dict(self, state_dict):
"""
Update model parameters from a given state dictionary.
Args:
state_dict (dict[str, np.ndarray | paddle.Tensor]):
A dictionary containing model parameters, where keys are parameter names
and values are NumPy arrays or PaddlePaddle tensors.
"""
for i in range(
self.fd_config.model_config.moe_layer_start_index,
self.fd_config.model_config.num_hidden_layers,
):
logger.info(f"Start update layer {i}")
self.layers[i].update_state_dict(state_dict)
def forward(
self,
ids_remove_padding: paddle.Tensor,
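With the pieces above in place, the new update_state_dict chain (Ernie4_5_Model → Ernie4_5_DecoderLayer → Ernie4_5_MoE → `fused_moe.load_state_dict(state_dict, True)`) gives a worker a hook for reloading only the MoE layers after the expert table changes. A hedged usage sketch follows; `build_rebalanced_state_dict` is a hypothetical callable, not something this diff provides:

```python
def refresh_moe_experts(model, build_rebalanced_state_dict):
    """Reload only the MoE layers of an Ernie4_5_Model built with
    enable_redundant_experts=True (sketch; the helper name is an assumption).

    build_rebalanced_state_dict: hypothetical callable that consults
    model.redundant_table_manger and returns {param_name: np.ndarray | paddle.Tensor}.
    """
    state_dict = build_rebalanced_state_dict(model.redundant_table_manger)
    # update_state_dict (added above) walks layers in
    # [moe_layer_start_index, num_hidden_layers) and forwards to the fused MoE.
    model.update_state_dict(state_dict)
```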


@@ -86,8 +86,8 @@ class Ernie4_5_VLMoeBlock(nn.Layer):
) -> None:
super().__init__()
moe_quant_type = ""
if hasattr(fd_config, "quant_config") and fd_config.quant_config is not None:
moe_quant_type = getattr(fd_config.quant_config, "name", lambda: "")()
if hasattr(fd_config.quant_config, "moe_quant_type"):
moe_quant_type = fd_config.quant_config.moe_quant_type
if moe_quant_type == "tensor_wise_fp8" or (
moe_quant_type == "block_wise_fp8" and fd_config.model_config.is_quantized