fix mtp (#4153)

Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>

Author: JYChen
Committed: 2025-09-18 10:53:07 +08:00 (via GitHub)
Commit: 74d7b9151d (parent: 0fa28b1068)


@@ -16,6 +16,7 @@
 from __future__ import annotations

+import re
 from functools import partial
 from typing import Dict, Union
@@ -250,7 +251,7 @@ class Ernie4_5_MTPModel(nn.Layer):
         self.embed_tokens = fd_config.speculative_config.sharing_model.ernie.embed_tokens
         self.norm = fd_config.speculative_config.sharing_model.ernie.norm
-        self.layers = nn.LayerList(
+        self.mtp_block = nn.LayerList(
             [
                 Ernie4_5_DecoderLayer(
                     fd_config=fd_config,
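Context for the rename (not part of the diff itself): in Paddle, the attribute name under which a sublayer is registered becomes the prefix of its keys in named_parameters() and state_dict(), so registering the decoder stack as self.mtp_block makes its weights appear under mtp_block.* instead of layers.*, matching the mtp_block[i] accesses updated below. A minimal, illustrative sketch:

import paddle.nn as nn

class Tiny(nn.Layer):
    def __init__(self):
        super().__init__()
        # The attribute name ("mtp_block") becomes the parameter-name prefix.
        self.mtp_block = nn.LayerList([nn.Linear(4, 4)])

print([name for name, _ in Tiny().named_parameters()])
# ['mtp_block.0.weight', 'mtp_block.0.bias']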
@@ -296,7 +297,7 @@ class Ernie4_5_MTPModel(nn.Layer):
         self.eh_proj.load_state_dict(state_dict)
         for i in range(self.num_layers):
             logger.info(f"Start load layer {i}")
-            self.layers[i].load_state_dict(state_dict)
+            self.mtp_block[i].load_state_dict(state_dict)

     def forward(
         self,
@@ -315,7 +316,7 @@ class Ernie4_5_MTPModel(nn.Layer):
         hidden_states = self.eh_proj(inputs_embedding)
         residual = None
         for i in range(self.num_layers):
-            hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
+            hidden_states, residual = self.mtp_block[i](forward_meta, hidden_states, residual)

         hidden_states = hidden_states + residual
@@ -374,17 +375,23 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
             weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
         """
-        from fastdeploy.model_executor.utils import default_weight_loader
+        from fastdeploy.model_executor.utils import (
+            default_weight_loader,
+            process_weights_after_loading,
+        )

         all_param_mapping = [
             # (param_name, weight_name, expert_id, shard_id)
             ("embed_tokens.embeddings", "embed_tokens", None, None),
             ("lm_head.linear", "lm_head", None, None),
+            ("enorm", "mtp_emb_norm.0", None, None),
+            ("hnorm", "mtp_hidden_norm.0", None, None),
+            ("eh_proj.linear", "mtp_linear_proj.0", None, None),
         ]

         params_dict = dict(self.named_parameters())
         shard_id = None
+        process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
         for loaded_weight_name, loaded_weight in weights_iterator:
             for param_name, weight_name, exp_id, shard_id in all_param_mapping:
                 if weight_name not in loaded_weight_name:
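For reference, the three new rows route the MTP-specific checkpoint keys (mtp_emb_norm.0, mtp_hidden_norm.0, mtp_linear_proj.0) onto the model's enorm, hnorm and eh_proj.linear parameters. Below is a rough sketch of how such a (param_name, weight_name, ...) table is typically applied; the exact loop body lives outside this hunk, so the replace-based rewrite is an assumption:

all_param_mapping = [
    ("embed_tokens.embeddings", "embed_tokens", None, None),
    ("lm_head.linear", "lm_head", None, None),
    ("enorm", "mtp_emb_norm.0", None, None),
    ("hnorm", "mtp_hidden_norm.0", None, None),
    ("eh_proj.linear", "mtp_linear_proj.0", None, None),
]

def map_checkpoint_name(loaded_weight_name: str) -> str:
    # Rewrite a checkpoint key to the model parameter name when it matches a
    # mapped weight name; otherwise keep the key unchanged (assumed behaviour).
    for param_name, weight_name, _expert_id, _shard_id in all_param_mapping:
        if weight_name in loaded_weight_name:
            return loaded_weight_name.replace(weight_name, param_name)
    return loaded_weight_name

print(map_checkpoint_name("mtp_hidden_norm.0.weight"))  # -> "hnorm.weight"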
@@ -396,11 +403,16 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
             else:
                 if loaded_weight_name not in params_dict.keys():
                     continue
+                model_param_name = loaded_weight_name
                 param = params_dict[loaded_weight_name]

             # Get weight loader from parameter and set weight
             weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
             weight_loader(param, loaded_weight)
+            model_sublayer_name = re.sub(
+                r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name
+            )
+            process_weights_after_loading_fn(model_sublayer_name, param)

     def compute_logits(self, hidden_states: paddle.Tensor):
         """