diff --git a/fastdeploy/model_executor/models/ernie4_5_mtp.py b/fastdeploy/model_executor/models/ernie4_5_mtp.py
index 19123678a..a722b2e56 100644
--- a/fastdeploy/model_executor/models/ernie4_5_mtp.py
+++ b/fastdeploy/model_executor/models/ernie4_5_mtp.py
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import re
 from functools import partial
 from typing import Dict, Union
 
@@ -250,7 +251,7 @@ class Ernie4_5_MTPModel(nn.Layer):
         self.embed_tokens = fd_config.speculative_config.sharing_model.ernie.embed_tokens
         self.norm = fd_config.speculative_config.sharing_model.ernie.norm
 
-        self.layers = nn.LayerList(
+        self.mtp_block = nn.LayerList(
             [
                 Ernie4_5_DecoderLayer(
                     fd_config=fd_config,
@@ -296,7 +297,7 @@ class Ernie4_5_MTPModel(nn.Layer):
         self.eh_proj.load_state_dict(state_dict)
         for i in range(self.num_layers):
             logger.info(f"Start load layer {i}")
-            self.layers[i].load_state_dict(state_dict)
+            self.mtp_block[i].load_state_dict(state_dict)
 
     def forward(
         self,
@@ -315,7 +316,7 @@ class Ernie4_5_MTPModel(nn.Layer):
         hidden_states = self.eh_proj(inputs_embedding)
         residual = None
         for i in range(self.num_layers):
-            hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
+            hidden_states, residual = self.mtp_block[i](forward_meta, hidden_states, residual)
 
         hidden_states = hidden_states + residual
 
@@ -374,17 +375,23 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
             weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
         """
 
-        from fastdeploy.model_executor.utils import default_weight_loader
+        from fastdeploy.model_executor.utils import (
+            default_weight_loader,
+            process_weights_after_loading,
+        )
 
         all_param_mapping = [
             # (param_name, weight_name, expert_id, shard_id)
             ("embed_tokens.embeddings", "embed_tokens", None, None),
             ("lm_head.linear", "lm_head", None, None),
+            ("enorm", "mtp_emb_norm.0", None, None),
+            ("hnorm", "mtp_hidden_norm.0", None, None),
+            ("eh_proj.linear", "mtp_linear_proj.0", None, None),
         ]
 
         params_dict = dict(self.named_parameters())
         shard_id = None
-
+        process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
         for loaded_weight_name, loaded_weight in weights_iterator:
             for param_name, weight_name, exp_id, shard_id in all_param_mapping:
                 if weight_name not in loaded_weight_name:
@@ -396,11 +403,16 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
             else:
                 if loaded_weight_name not in params_dict.keys():
                     continue
+                model_param_name = loaded_weight_name
                 param = params_dict[loaded_weight_name]
 
             # Get weight loader from parameter and set weight
             weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
             weight_loader(param, loaded_weight)
+            model_sublayer_name = re.sub(
+                r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name
+            )
+            process_weights_after_loading_fn(model_sublayer_name, param)
 
     def compute_logits(self, hidden_states: paddle.Tensor):
         """
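
Note on the new suffix-stripping step: the re.sub call in load_weights derives the owning sublayer's name from a parameter name, so that process_weights_after_loading_fn can be invoked once per sublayer after its weights are set. A minimal, self-contained sketch of that mapping is below; the parameter names are made up for illustration, and only the regex itself is taken from the patch.

# Illustration only -- not part of the patch. Shows how the suffix-stripping
# regex turns a parameter name into its owning sublayer name.
import re

suffix = r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$"

# Hypothetical parameter names; only the regex comes from the patch.
for param_name in (
    "eh_proj.linear.weight",
    "mtp_block.0.mlp.up_gate_proj_weight",
    "mtp_block.0.self_attn.cache_k_scale",
):
    print(re.sub(suffix, "", param_name))
# -> eh_proj.linear
# -> mtp_block.0.mlp
# -> mtp_block.0.self_attn

Because the pattern is anchored with $ and each alternative must be preceded by a literal dot, only the final name component is stripped, leaving the dotted sublayer path intact.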