Files
FastDeploy/fastdeploy/model_executor/models/ernie4_5_moe.py
lizexu123 c86945ef49 [Feature] support pool (#3827)
* support pool

* update pooling

* add pooler_config and check

* update

* support AutoWeightsLoader load weight

* fix

* update

* delete print

* update pre-commit

* fix

* fix xpu

* fix ModelRegistry->model_registry

* fix Copilot review

* fix pooler.py

* delete StepPooler

* fix abstract

* fix default_loader_v1

* fix Pre Commit

* support torch qwen3 dense

* add test and fix torch-qwen

* fix

* fix

* adapter ci:

* fix review

* fix pooling_params.py

* fix

* fix tasks.py 2025

* fix print and logger

* Modify ModelRegistry and delete AutoWeightsLoader

* fix logger

* fix test_embedding

* fix ci bug

* ernie4_5 model_registry

* fix test

* support Qwen3-Embedding-0.6B tp=1 load

* fix extra code

* fix

* delete fix vocab_size

* delete prepare_params_dict

* fix:
2025-09-22 14:09:09 +08:00

861 lines
32 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
import inspect
import re
from functools import partial
from typing import Dict, Union
import numpy as np
import paddle
from paddle import nn
from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.configuration_utils import PretrainedConfig
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
support_graph_optimization,
)
from fastdeploy.model_executor.layers.activation import SiluAndMul
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
ModelRegistry,
)
from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm
from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid
from fastdeploy.model_executor.models.utils import WeightMeta
from fastdeploy.worker.experts_manager import RedundantExpertManger
class Ernie4_5_MLP(nn.Layer):
    """
    Dense gated feed-forward block (used for non-MoE layers and for shared experts).

    Pipeline: fused up/gate column-parallel projection -> SiluAndMul activation
    -> row-parallel down projection.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        intermediate_size: int,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        model_cfg = fd_config.model_config
        self.nranks = fd_config.parallel_config.tensor_parallel_size

        # Up and gate projections share one matmul, so the output is twice the intermediate width.
        self.up_gate_proj = MergedColumnParallelLinear(
            fd_config=fd_config,
            prefix=f"{prefix}.up_gate_proj",
            input_size=model_cfg.hidden_size,
            output_size=intermediate_size * 2,
            with_bias=False,
            activation=model_cfg.hidden_act,
        )
        # reduce_results is forwarded unchanged to the row-parallel down projection.
        self.down_proj = RowParallelLinear(
            fd_config=fd_config,
            prefix=f"{prefix}.down_proj",
            input_size=intermediate_size,
            output_size=model_cfg.hidden_size,
            with_bias=False,
            reduce_results=reduce_results,
        )
        self.act_fn = SiluAndMul(
            fd_config=fd_config,
            bias=None,
            act_method=model_cfg.hidden_act,
        )

    def load_state_dict(self, state_dict):
        """Load the up/gate projection and then the down projection from ``state_dict``."""
        for proj in (self.up_gate_proj, self.down_proj):
            proj.load_state_dict(state_dict)

    def forward(self, hidden_states: paddle.Tensor):
        """Apply the up/gate projection, the activation, then the down projection."""
        return self.down_proj(self.act_fn(self.up_gate_proj(hidden_states)))
class Ernie4_5_MoE(nn.Layer):
    """
    Mixture-of-experts feed-forward block for one ERNIE 4.5 decoder layer.

    Combines:
    - a float32 router (``self.gate``);
    - routed experts (``self.experts``, a ``FusedMoE``);
    - when ``moe_num_shared_experts > 0``, a dense ``shared_experts`` MLP whose
      output is added to the routed output.
    """

    def __init__(
        self, fd_config: FDConfig, layer_id: int, prefix: str, redundant_table_manger: RedundantExpertManger = None
    ) -> None:
        super().__init__()
        # quant_config may be None or may lack moe_quant_type; both mean "no MoE quantization".
        moe_quant_type = getattr(fd_config.quant_config, "moe_quant_type", "")

        self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size
        self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size
        self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank
        self.tp_group = fd_config.parallel_config.tp_group

        self.use_ep = self.expert_parallel_size > 1
        self.use_tp = self.tensor_parallel_size > 1

        weight_key_map = self._build_weight_key_map(fd_config, prefix, moe_quant_type)

        self.gate = ReplicatedLinear(
            fd_config=fd_config,
            prefix=f"{prefix}.gate",
            input_size=fd_config.model_config.hidden_size,
            output_size=fd_config.model_config.moe_num_experts,
            with_bias=False,
            skip_quant=True,
            weight_dtype="float32",
        )

        self.experts = FusedMoE(
            fd_config=fd_config,
            moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
            num_experts=fd_config.model_config.moe_num_experts,
            top_k=fd_config.model_config.moe_k,
            layer_idx=layer_id,
            gate_correction_bias=None,
            redundant_table_manger=redundant_table_manger,
            weight_key_map=weight_key_map,
        )

        # The aux-free routing bias is a zero-initialized [1, num_experts] float32 parameter
        # that load_state_dict later fills from the checkpoint.
        if fd_config.model_config.moe_use_aux_free:
            self.experts.gate_correction_bias = self.create_parameter(
                shape=[1, fd_config.model_config.moe_num_experts],
                dtype="float32",
                default_initializer=paddle.nn.initializer.Constant(0),
            )
        else:
            self.experts.gate_correction_bias = None

        self.num_shared_experts = fd_config.model_config.moe_num_shared_experts
        if self.num_shared_experts > 0:
            shared_experts_hidden_dim = self.num_shared_experts * fd_config.model_config.moe_intermediate_size
            self.shared_experts = Ernie4_5_MLP(
                fd_config=fd_config,
                intermediate_size=shared_experts_hidden_dim,
                prefix=f"{prefix}.shared_experts",
            )

    @staticmethod
    def _build_weight_key_map(fd_config: FDConfig, prefix: str, moe_quant_type: str) -> Dict[str, str]:
        """
        Build the checkpoint key map handed to ``FusedMoE``.

        Expert keys keep a literal ``{}`` placeholder for the expert index. The set of
        keys depends on ``moe_quant_type``:

        - w4a8 / w4afp8 / tensor_wise_fp8 / pre-quantized block_wise_fp8:
          quant_weight, weight_scale and activation_scale.
        - w4w2: quant_weight, weight_scale, super_scales, code_scale and code_zp.
        - anything else: plain ``weight``.
        """
        weight_key_map = {
            "gate_weight_key": f"{prefix}.gate.weight",
            "gate_correction_bias_key": f"{prefix}.moe_statics.e_score_correction_bias",
        }

        def add_expert_keys(kind: str, ckpt_suffix: str) -> None:
            # Adds the up_gate_proj entry, then the down_proj entry, for one tensor kind.
            for proj in ("up_gate_proj", "down_proj"):
                weight_key_map[f"{proj}_expert_{kind}_key"] = f"{prefix}.experts.{{}}.{proj}.{ckpt_suffix}"

        # is_quantized is only read for block_wise_fp8, matching the original short-circuit.
        if moe_quant_type in ("w4a8", "w4afp8", "tensor_wise_fp8") or (
            moe_quant_type == "block_wise_fp8" and fd_config.model_config.is_quantized
        ):
            add_expert_keys("weight", "quant_weight")
            add_expert_keys("weight_scale", "weight_scale")
            add_expert_keys("in_scale", "activation_scale")
        elif moe_quant_type == "w4w2":
            add_expert_keys("weight", "quant_weight")
            add_expert_keys("weight_scale", "weight_scale")
            add_expert_keys("super_scales", "super_scales")
            add_expert_keys("code_scale", "code_scale")
            add_expert_keys("code_zp", "code_zp")
        else:
            add_expert_keys("weight", "weight")
        return weight_key_map

    def load_state_dict(self, state_dict):
        """
        Load router, expert, correction-bias and shared-expert weights.

        The correction bias is popped from ``state_dict`` and reshaped to the
        parameter's shape when the checkpoint stores it differently.
        """
        self.gate.load_state_dict(state_dict)
        self.experts.load_state_dict(state_dict)
        if self.experts.gate_correction_bias is not None:
            gate_correction_bias_tensor = state_dict.pop(self.experts.gate_correction_bias_key)
            if self.experts.gate_correction_bias.shape != gate_correction_bias_tensor.shape:
                gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(
                    self.experts.gate_correction_bias.shape
                )
            self.experts.gate_correction_bias.set_value(gate_correction_bias_tensor)
        if self.num_shared_experts > 0:
            self.shared_experts.load_state_dict(state_dict)

    def update_state_dict(self, state_dict):
        """
        Reload expert weights with the rearrange flag set.

        The original code referenced ``self.fused_moe``, which is never defined on
        this layer and so always raised AttributeError. The experts live in
        ``self.experts``.
        """
        self.experts.load_state_dict(state_dict, True)

    def split_allgather_out(self, hidden_states: paddle.Tensor, token_num: int):
        """
        Split tokens across TP ranks, run the experts locally, then all-gather.

        Each rank processes a zero-padded chunk of ``ceil(token_num / tp)`` rows, so
        every rank passes an identically shaped tensor to all_gather. The gathered
        result is truncated back to ``token_num`` rows.
        """
        token_num_per_rank = (token_num + self.tensor_parallel_size - 1) // self.tensor_parallel_size
        # AllGather will hang when the data shapes on multi-ranks are different!
        part_hidden_states = paddle.zeros(
            shape=[token_num_per_rank, hidden_states.shape[1]], dtype=hidden_states.dtype
        )
        # With ceil-division, trailing ranks may own fewer rows than token_num_per_rank,
        # or none at all (e.g. token_num=5, tp=4: rank 3 would start at row 6). Clamp
        # both offsets so the local slice length is never negative.
        start_offset = min(self.tensor_parallel_rank * token_num_per_rank, token_num)
        end_offset = min(start_offset + token_num_per_rank, token_num)
        local_rows = end_offset - start_offset
        if local_rows > 0:
            part_hidden_states[:local_rows, :] = hidden_states[start_offset:end_offset, :]
        # Ranks with no real rows still run on the zero padding to keep the collective in sync.
        out = self.experts(part_hidden_states, self.gate)
        multi_outs = []
        paddle.distributed.all_gather(multi_outs, out, self.tp_group)
        out = paddle.concat(multi_outs, axis=0)
        # Real rows form a contiguous prefix of the gathered output; drop the tail padding.
        out = out[:token_num, :]
        return out

    def forward(self, hidden_states: paddle.Tensor):
        """
        Route tokens through the experts and add the shared-expert output if present.

        The split/all-gather path is used only when EP and TP are both enabled and
        there are at least as many tokens as TP ranks.
        """
        token_num = hidden_states.shape[0]
        if self.use_ep and self.use_tp and token_num >= self.tensor_parallel_size:
            out = self.split_allgather_out(hidden_states, token_num)
        else:
            out = self.experts(hidden_states, self.gate)

        if self.num_shared_experts > 0:
            s_x = self.shared_experts(hidden_states)
            out = out + s_x
        return out
class Ernie4_5_Attention(nn.Layer):
    """
    Self-attention block: fused QKV projection, attention backend, output projection.
    """

    def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str) -> None:
        super().__init__()
        model_cfg = fd_config.model_config

        self.qkv_proj = QKVParallelLinear(fd_config=fd_config, prefix=f"{prefix}.qkv_proj")

        # The output projection consumes the concatenated per-head attention outputs.
        attn_width = model_cfg.head_dim * model_cfg.num_attention_heads
        self.o_proj = RowParallelLinear(
            fd_config=fd_config,
            prefix=f"{prefix}.o_proj",
            input_size=attn_width,
            output_size=model_cfg.hidden_size,
        )

        self.attn = Attention(
            fd_config=fd_config,
            layer_id=layer_id,
            prefix=prefix,
            use_neox_rotary_style=False,
        )

    def load_state_dict(self, state_dict):
        """Load QKV, output-projection and attention sublayers, in that order."""
        for sublayer in (self.qkv_proj, self.o_proj, self.attn):
            sublayer.load_state_dict(state_dict)

    def forward(
        self,
        forward_meta: ForwardMeta,
        hidden_states: paddle.Tensor,
    ):
        """Project to QKV, run attention with ``forward_meta``, then project back."""
        fused_qkv = self.qkv_proj(hidden_states)
        context = self.attn(qkv=fused_qkv, forward_meta=forward_meta)
        return self.o_proj(context)
class Ernie4_5_DecoderLayer(nn.Layer):
    """
    One ERNIE 4.5 transformer block.

    Structure: pre-norm self-attention followed by an FFN. The FFN is MoE when the
    config defines ``moe_num_experts`` and this layer's index is at or beyond
    ``moe_layer_start_index``; otherwise it is dense.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        redundant_table_manger: RedundantExpertManger = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        model_cfg = fd_config.model_config
        # The layer index is the last dot-separated component of the prefix (e.g. "ernie.layers.3").
        layer_id = int(prefix.rsplit(".", 1)[-1])

        self.self_attn = Ernie4_5_Attention(
            fd_config=fd_config,
            layer_id=layer_id,
            prefix=f"{prefix}.self_attn",
        )

        has_experts = getattr(model_cfg, "moe_num_experts", None) is not None
        if has_experts and layer_id >= model_cfg.moe_layer_start_index:
            self.mlp = Ernie4_5_MoE(
                fd_config=fd_config,
                layer_id=layer_id,
                redundant_table_manger=redundant_table_manger,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = Ernie4_5_MLP(
                fd_config=fd_config,
                intermediate_size=model_cfg.intermediate_size,
                prefix=f"{prefix}.mlp",
            )

        def _make_norm(norm_name):
            # Both norms share hidden size and epsilon; only the weight prefix differs.
            return RMSNorm(
                fd_config,
                hidden_size=model_cfg.hidden_size,
                eps=model_cfg.rms_norm_eps,
                prefix=f"{prefix}.{norm_name}",
            )

        self.input_layernorm = _make_norm("input_layernorm")
        self.post_attention_layernorm = _make_norm("post_attention_layernorm")

    def load_state_dict(self, state_dict):
        """Load attention, FFN and both norms, in that order."""
        for sublayer in (self.self_attn, self.mlp, self.input_layernorm, self.post_attention_layernorm):
            sublayer.load_state_dict(state_dict)

    def update_state_dict(self, state_dict):
        """Delegate a weight update to the FFN sublayer."""
        self.mlp.update_state_dict(state_dict)

    def forward(
        self,
        forward_meta: ForwardMeta,
        hidden_states: paddle.Tensor,
        residual: paddle.Tensor = None,
    ):
        """
        Run the block and return ``(mlp_output, residual)``.

        On the first layer ``residual`` is None: the raw input becomes the residual.
        Later layers pass the residual into the norm, which returns an updated pair.
        """
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            normed, residual = self.input_layernorm(hidden_states), hidden_states

        attn_out = self.self_attn(hidden_states=normed, forward_meta=forward_meta)
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
@support_graph_optimization
class Ernie4_5_Model(nn.Layer):
    """
    ERNIE 4.5 backbone: vocab-parallel embedding, a stack of decoder layers, and a final RMSNorm.
    """

    def __init__(
        self,
        fd_config: FDConfig = None,
    ):
        """
        Initializer for the Ernie4_5_Model class.

        Args:
            fd_config (FDConfig): Global FastDeploy configuration. As a side effect,
                ``model_config.pretrained_config.prefix_name`` is overwritten with
                "ernie". That value prefixes every sublayer's weight name.
        """
        super().__init__()
        model_cfg = fd_config.model_config
        self.num_layers = model_cfg.num_hidden_layers
        model_cfg.pretrained_config.prefix_name = "ernie"
        root = model_cfg.pretrained_config.prefix_name
        self.fd_config = fd_config

        # Built only when enable_redundant_experts is exactly True (identity check).
        self.redundant_table_manger = (
            RedundantExpertManger(
                n_routed_experts=model_cfg.moe_num_experts,
                num_hidden_layers=model_cfg.num_hidden_layers,
                redundant_experts_num=model_cfg.redundant_experts_num,
                ep_size=fd_config.parallel_config.expert_parallel_size,
            )
            if model_cfg.enable_redundant_experts is True
            else None
        )

        self.embed_tokens = VocabParallelEmbedding(
            fd_config=fd_config,
            num_embeddings=model_cfg.vocab_size,
            embedding_dim=model_cfg.hidden_size,
            params_dtype=paddle.get_default_dtype(),
            prefix=f"{root}.embed_tokens",
        )

        self.layers = nn.LayerList(
            [
                Ernie4_5_DecoderLayer(
                    fd_config=fd_config,
                    redundant_table_manger=self.redundant_table_manger,
                    prefix=f"{root}.layers.{layer_idx}",
                )
                for layer_idx in range(self.num_layers)
            ]
        )

        self.norm = RMSNorm(
            fd_config,
            hidden_size=model_cfg.hidden_size,
            eps=model_cfg.rms_norm_eps,
            prefix=f"{root}.norm",
        )

    def load_state_dict(self, state_dict):
        """
        Load model parameters from a given state dictionary.

        Loading order: embedding, final norm, then each decoder layer in index order.

        Args:
            state_dict (dict[str, np.ndarray | paddle.Tensor]):
                Mapping from parameter names to NumPy arrays or Paddle tensors.
        """
        self.embed_tokens.load_state_dict(state_dict)
        self.norm.load_state_dict(state_dict)
        for layer_idx in range(self.num_layers):
            logger.info(f"Start load layer {layer_idx}")
            self.layers[layer_idx].load_state_dict(state_dict)

    def update_state_dict(self, state_dict):
        """
        Update parameters of the layers from ``moe_layer_start_index`` onward.

        Args:
            state_dict (dict[str, np.ndarray | paddle.Tensor]):
                Mapping from parameter names to NumPy arrays or Paddle tensors.
        """
        first_moe = self.fd_config.model_config.moe_layer_start_index
        last_exclusive = self.fd_config.model_config.num_hidden_layers
        for layer_idx in range(first_moe, last_exclusive):
            logger.info(f"Start update layer {layer_idx}")
            self.layers[layer_idx].update_state_dict(state_dict)

    def forward(
        self,
        ids_remove_padding: paddle.Tensor,
        forward_meta: ForwardMeta,
    ):
        """Embed token ids, run all decoder layers, add the residual and apply the final norm."""
        hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)
        residual = None
        for layer_idx in range(self.num_layers):
            hidden_states, residual = self.layers[layer_idx](forward_meta, hidden_states, residual)
        return self.norm(hidden_states + residual)
@ModelRegistry.register_model_class(
    architecture="Ernie4_5_MoeForCausalLM",
    module_path="ernie4_5_moe",
    category=ModelCategory.TEXT_GENERATION,
    primary_use=ModelCategory.TEXT_GENERATION,
)
class Ernie4_5_MoeForCausalLM(ModelForCasualLM):
    """
    Ernie4_5_MoeForCausalLM

    Causal-LM wrapper combining ``Ernie4_5_Model`` (``self.ernie``) with a
    ``ParallelLMHead``. Weights are loaded either from a dict
    (``set_state_dict``) or from a (name, tensor) iterator (``load_weights``).
    """

    def __init__(self, fd_config: FDConfig):
        """
        Args:
            fd_config (FDConfig): Configurations for the LLM model.
        """
        super(Ernie4_5_MoeForCausalLM, self).__init__(fd_config)
        self.fd_config = fd_config
        self.ernie = Ernie4_5_Model(fd_config=fd_config)

        # compute_logits sets every logits column at or beyond this index to -inf.
        self.ori_vocab_size = fd_config.model_config.ori_vocab_size

        self.lm_head = ParallelLMHead(
            fd_config=fd_config,
            embedding_dim=fd_config.model_config.hidden_size,
            num_embeddings=fd_config.model_config.vocab_size,
            prefix="lm_head",
        )
        # When True, lm_head is loaded from the input embedding weights after the model loads.
        self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings

    @classmethod
    def name(self):
        """Return the architecture name string for this class."""
        return "Ernie4_5_MoeForCausalLM"

    @paddle.no_grad()
    def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]):
        """
        Load model parameters from a given state dictionary.

        If ``tie_word_embeddings`` is set, lm_head is loaded from the embedding
        table instead of from ``state_dict``.

        Args:
            state_dict (dict[str, np.ndarray | paddle.Tensor]):
                A dictionary containing model parameters, where keys are parameter names
                and values are NumPy arrays or PaddlePaddle tensors.
        """
        self.ernie.load_state_dict(state_dict)
        if self.tie_word_embeddings:
            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})
        else:
            self.lm_head.load_state_dict(state_dict)

    @paddle.no_grad()
    def load_weights(self, weights_iterator) -> None:
        """
        Load model parameters from a given weights_iterator object.

        For each checkpoint (name, tensor) pair:
        1. The checkpoint name is rewritten to a FastDeploy parameter name, using
           the general and expert substring mappings below.
        2. The parameter's ``weight_loader`` (or the default loader) copies the
           tensor in.
        3. Post-load processing runs for the owning sublayer.

        Checkpoint entries that match no parameter are skipped silently.

        Args:
            weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
        """
        from fastdeploy.model_executor.utils import (
            default_weight_loader,
            process_weights_after_loading,
            rename_offline_ckpt_suffix_to_fd_suffix,
        )

        # Substring rewrites from checkpoint names to FD parameter names.
        general_params_mapping = [
            # (param_name, weight_name, expert_id, shard_id)
            ("embed_tokens.embeddings", "embed_tokens", None, None),
            ("lm_head.linear", "lm_head", None, None),
            ("experts.gate_correction_bias", "moe_statics.e_score_correction_bias", None, None),
            ("qkv_proj", "q_proj", None, "q"),
            ("qkv_proj", "k_proj", None, "k"),
            ("qkv_proj", "v_proj", None, "v"),
            ("up_gate_proj", "gate_proj", None, "gate"),
            ("up_gate_proj", "up_proj", None, "up"),
            ("attn.cache_k_scale", "cachek_matmul.activation_scale", None, None),
            ("attn.cache_v_scale", "cachev_matmul.activation_scale", None, None),
            ("attn.cache_k_zp", "cachek_matmul.activation_zero_point", None, None),
            ("attn.cache_v_zp", "cachev_matmul.activation_zero_point", None, None),
        ]

        expert_params_mapping = []
        if getattr(self.fd_config.model_config, "moe_num_experts", None) is not None:
            # Under expert parallelism the mapping is presumably limited to this
            # rank's local experts (count + start offset from parallel_config).
            if self.fd_config.parallel_config.expert_parallel_size > 1:
                num_experts = self.fd_config.parallel_config.num_experts_per_rank
                num_experts_start_offset = self.fd_config.parallel_config.num_experts_start_offset
            else:
                num_experts = self.fd_config.model_config.moe_num_experts
                num_experts_start_offset = 0

            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                num_experts=num_experts,
                ckpt_down_proj_name="down_proj",
                ckpt_gate_up_proj_name="up_gate_proj",
                ckpt_gate_proj_name="gate_proj",
                ckpt_up_proj_name="up_proj",
                param_gate_up_proj_name="experts.up_gate_proj_",
                param_down_proj_name="experts.down_proj_",
                num_experts_start_offset=num_experts_start_offset,
            )

        # Tag every mapping with an is_moe flag (False for general, True for expert);
        # the flag is passed to checkpoint_to_fd_key_fn below.
        all_param_mapping = [
            (param, weight, exp, shard, False) for param, weight, exp, shard in general_params_mapping
        ] + [(param, weight, exp, shard, True) for param, weight, exp, shard in expert_params_mapping]

        checkpoint_to_fd_key_fn = rename_offline_ckpt_suffix_to_fd_suffix(
            fd_config=self.fd_config, ckpt_weight_suffix="quant_weight", ckpt_scale_suffix="weight_scale"
        )
        params_dict = dict(self.named_parameters())

        process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
        for loaded_weight_name, loaded_weight in weights_iterator:
            # NOTE: str.replace rewrites every occurrence of "model", not only a leading prefix.
            loaded_weight_name = loaded_weight_name.replace("model", "ernie")
            for param_name, weight_name, exp_id, shard_id, is_moe in all_param_mapping:
                # NOTE(review): loaded_weight_name is reassigned on every iteration, so later
                # mappings (and the fallback below) see an already-renamed name. Confirm that
                # applying checkpoint_to_fd_key_fn cumulatively is intended before restructuring.
                loaded_weight_name = checkpoint_to_fd_key_fn(loaded_weight_name, is_moe)
                model_param_name = loaded_weight_name.replace(weight_name, param_name)
                if model_param_name not in params_dict:
                    continue
                # First mapping whose rewritten name is a real parameter wins.
                param = params_dict[model_param_name]
                expert_id = exp_id
                shard_id = shard_id
                break
            else:
                # No mapping matched: try the (suffix-renamed) checkpoint name as-is.
                expert_id = None
                shard_id = None
                loaded_weight_name = checkpoint_to_fd_key_fn(loaded_weight_name, is_moe=False)
                model_param_name = loaded_weight_name
                if model_param_name not in params_dict.keys():
                    # Unknown checkpoint entry: skip it.
                    continue
                param = params_dict[model_param_name]

            # Get weight loader from parameter and set weight
            weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
            sig = inspect.signature(weight_loader)
            # Loaders that accept expert_id get expert_id/shard_id as keywords;
            # the rest get shard_id positionally.
            if "expert_id" in sig.parameters:
                weight_loader(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
            else:
                weight_loader(param, loaded_weight, shard_id)

            # Strip the trailing parameter suffix to obtain the owning sublayer's name.
            model_sublayer_name = re.sub(
                r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name
            )
            process_weights_after_loading_fn(model_sublayer_name, param)

        if self.tie_word_embeddings:
            self.lm_head.load_state_dict({self.lm_head.weight_key: self.ernie.embed_tokens.embeddings.weight})

    def compute_logits(self, hidden_states: paddle.Tensor):
        """
        Project hidden states through lm_head and return float32 logits.

        Columns from ``self.ori_vocab_size`` onward are set to -inf. This is
        presumably to mask padding between ``vocab_size`` and ``ori_vocab_size``
        (confirm).
        """
        logits = self.lm_head(hidden_states)
        logits = logits.astype(paddle.float32)
        logits[:, self.ori_vocab_size :] = -float("inf")

        return logits

    def empty_input_forward(self):
        """
        empty_input_forward

        Runs the experts of every layer from ``moe_layer_start_index`` onward on
        a 0-row hidden-state tensor.

        NOTE(review): this presumably lets a rank with no tokens still take part
        in expert-parallel communication; confirm with the callers.
        """
        fake_hidden_states = paddle.empty(
            shape=[0, self.fd_config.model_config.hidden_size],
            dtype=paddle.get_default_dtype(),
        )
        for i in range(
            self.fd_config.model_config.moe_layer_start_index,
            self.fd_config.model_config.num_hidden_layers,
        ):
            self.ernie.layers[i].mlp.experts(fake_hidden_states, self.ernie.layers[i].mlp.gate)

    def forward(
        self,
        ids_remove_padding: paddle.Tensor,
        forward_meta: ForwardMeta,
    ):
        """Return the backbone's output hidden states; logits come from compute_logits."""
        hidden_states = self.ernie(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)

        return hidden_states

    def clear_grpah_opt_backend(self):
        """Clear graph optimization backend, the captured cuda graph will be cleaned"""
        self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config)
@ModelRegistry.register_model_class(
    architecture="Ernie4_5_ForCausalLM",
    module_path="ernie4_5_moe",
    category=ModelCategory.TEXT_GENERATION,
    primary_use=ModelCategory.TEXT_GENERATION,
)
class Ernie4_5_ForCausalLM(Ernie4_5_MoeForCausalLM):
    """
    Ernie4_5_ForCausalLM

    Reuses the Ernie4_5_MoeForCausalLM implementation unchanged, registered
    under the "Ernie4_5_ForCausalLM" architecture name.
    """

    @classmethod
    def name(cls):
        """Return the architecture name this class is registered under."""
        return "Ernie4_5_ForCausalLM"
@ModelRegistry.register_model_class(
    architecture="Ernie4_5ForCausalLM",
    module_path="ernie4_5_moe",
    category=ModelCategory.TEXT_GENERATION,
    primary_use=ModelCategory.TEXT_GENERATION,
)
class Ernie4_5ForCausalLM(Ernie4_5_ForCausalLM):
    """
    Ernie4_5ForCausalLM 0.3B-PT

    Same implementation as Ernie4_5_ForCausalLM, registered under the
    underscore-free "Ernie4_5ForCausalLM" architecture name.
    """

    @classmethod
    def name(cls):
        """Return the architecture name this class is registered under."""
        return "Ernie4_5ForCausalLM"
class Ernie4_5_MoePretrainedModel(PretrainedModel):
    """
    Ernie4_5_MoePretrainedModel

    Holds the weight metadata (``weight_infos``) and builds the tensor-parallel
    split/merge actions for ERNIE 4.5 checkpoints.
    """

    config_class = FDConfig

    def _init_weight(self, layer):
        """
        _init_weight

        Intentionally a no-op: returns None without touching ``layer``.
        """
        return None

    @classmethod
    def arch_name(cls):
        """Return the architecture name this metadata class is associated with."""
        return "Ernie4_5_MoeForCausalLM"

    # Each entry is a WeightMeta of (weight-name template, is_column[, split mode]).
    # Templates use LayerIdPlaceholder fields that build_expanded_keys fills in later.
    weight_infos = [
        WeightMeta(
            f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.weight",
            True,
            tsm.GQA,
        ),
        WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.weight", False),
        WeightMeta(
            f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.weight",
            True,
            tsm.PairFused,
        ),
        WeightMeta(f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.weight", False),
        WeightMeta(
            f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.up_gate_proj.weight",
            True,
            tsm.PairFused,
        ),
        WeightMeta(
            f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.down_proj.weight",
            False,
        ),
        WeightMeta(
            f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.weight",
            True,
            tsm.PairFused,
        ),
        WeightMeta(
            f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.weight",
            False,
        ),
        WeightMeta(".embed_tokens.weight", False),
        WeightMeta("lm_head.weight", True),
        # quant tensorwise
        WeightMeta(
            f".layers.{{{layerid.LAYER_ID}}}.self_attn.qkv_proj.quant_weight",
            True,
            tsm.GQA,
        ),
        WeightMeta(
            f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.quant_weight",
            False,
        ),
        WeightMeta(
            f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.up_gate_proj.quant_weight",
            True,
            tsm.PairFused,
        ),
        WeightMeta(
            f".layers.{{{layerid.FFN_LAYER_ID}}}.mlp.down_proj.quant_weight",
            False,
        ),
        WeightMeta(
            f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.up_gate_proj.quant_weight",
            True,
            tsm.PairFused,
        ),
        WeightMeta(
            f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.experts.{{{layerid.EXPERT_ID}}}.down_proj.quant_weight",
            False,
        ),
        WeightMeta(
            f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.up_gate_proj.quant_weight",
            True,
            tsm.PairFused,
        ),
        WeightMeta(
            f".layers.{{{layerid.MOE_LAYER_ID}}}.mlp.shared_experts.down_proj.quant_weight",
            False,
        ),
    ]

    @classmethod
    def _get_tensor_parallel_mappings(cls, config: PretrainedConfig, is_split=True):
        """
        get_tensor_parallel_mappings

        Build a mapping from concrete weight names to ``partial`` split/merge
        callables, based on ``cls.weight_infos`` and the TP settings in ``config``.

        Args:
            config (PretrainedConfig): Supplies TP degree/rank, head counts,
                layer count, MoE settings and ``prefix_name``.
            is_split (bool): Forwarded to ``split_or_merge_func_v1``.
        """
        logger.info("erine inference model _get_tensor_parallel_mappings")

        from fastdeploy.model_executor.models.tp_utils import (
            build_expanded_keys,
            has_prefix,
            split_or_merge_func_v1,
        )

        fn = split_or_merge_func_v1(
            is_split=is_split,
            tensor_parallel_degree=config.tensor_parallel_degree,
            tensor_parallel_rank=config.tensor_parallel_rank,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
        )

        def get_tensor_parallel_split_mappings(num_layers, moe_num_experts, moe_layer_start_index, prefix_name):
            # Build one action per weight-name template, then expand the layer/expert placeholders.
            base_actions = {}
            for weight_name, is_column, extra in cls.weight_infos:
                params = {
                    "is_column": is_column,
                    **({extra.value: True} if extra else {}),
                }
                # lm_head and already-prefixed names are used verbatim; everything else
                # gets the model prefix (has_prefix is not consulted for lm_head).
                if "lm_head.weight" in weight_name or has_prefix(prefix_name, weight_name):
                    key = weight_name
                else:
                    key = f"{prefix_name}{weight_name}"
                base_actions[key] = partial(fn, **params)

            # NOTE(review): any value <= 0 (including the -1 default) falls back to num_layers.
            # Ernie4_5_DecoderLayer, however, treats layer_id >= moe_layer_start_index as MoE,
            # so a start index of exactly 0 is handled differently there. Confirm 0 never occurs.
            start_layer = moe_layer_start_index if moe_layer_start_index > 0 else num_layers
            return build_expanded_keys(base_actions, num_layers, start_layer, moe_num_experts)

        mappings = get_tensor_parallel_split_mappings(
            config.num_hidden_layers,
            getattr(config, "moe_num_experts", 0),
            getattr(config, "moe_layer_start_index", -1),
            config.prefix_name,
        )
        return mappings
class Ernie4_5_PretrainedModel(Ernie4_5_MoePretrainedModel):
    """
    Ernie4_5_PretrainedModel

    Inherits all weight metadata and tensor-parallel rules from
    Ernie4_5_MoePretrainedModel; only the architecture name differs.
    """

    @classmethod
    def arch_name(cls):
        """Return the architecture name this metadata class is associated with."""
        return "Ernie4_5_ForCausalLM"
class Ernie4_5PretrainedModel(Ernie4_5_PretrainedModel):
    """
    Ernie4_5PretrainedModel 0.3B-PT

    Same metadata as Ernie4_5_PretrainedModel, associated with the
    underscore-free "Ernie4_5ForCausalLM" architecture name.
    """

    @classmethod
    def arch_name(cls):
        """Return the architecture name this metadata class is associated with."""
        return "Ernie4_5ForCausalLM"