mirror of https://github.com/PaddlePaddle/FastDeploy.git
fix mtp (#4153)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import re
 from functools import partial
 from typing import Dict, Union
 
@@ -250,7 +251,7 @@ class Ernie4_5_MTPModel(nn.Layer):
         self.embed_tokens = fd_config.speculative_config.sharing_model.ernie.embed_tokens
         self.norm = fd_config.speculative_config.sharing_model.ernie.norm
 
-        self.layers = nn.LayerList(
+        self.mtp_block = nn.LayerList(
             [
                 Ernie4_5_DecoderLayer(
                     fd_config=fd_config,
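
Renaming the attribute from self.layers to self.mtp_block also changes the prefix under which the decoder stack's parameters are registered, which is what the name-based weight handling further down keys on. A minimal sketch of that effect, assuming Paddle's usual dotted naming for nn.LayerList sublayers; TinyMTP and the Linear sizes are made up for illustration:

from paddle import nn

class TinyMTP(nn.Layer):
    def __init__(self):
        super().__init__()
        # the decoder stack is now registered under "mtp_block" instead of "layers"
        self.mtp_block = nn.LayerList([nn.Linear(8, 8) for _ in range(2)])

model = TinyMTP()
# parameter names pick up the attribute as their prefix,
# e.g. "mtp_block.0.weight", "mtp_block.0.bias", "mtp_block.1.weight", ...
print([name for name, _ in model.named_parameters()])
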
@@ -296,7 +297,7 @@ class Ernie4_5_MTPModel(nn.Layer):
         self.eh_proj.load_state_dict(state_dict)
         for i in range(self.num_layers):
             logger.info(f"Start load layer {i}")
-            self.layers[i].load_state_dict(state_dict)
+            self.mtp_block[i].load_state_dict(state_dict)
 
     def forward(
         self,
@@ -315,7 +316,7 @@ class Ernie4_5_MTPModel(nn.Layer):
         hidden_states = self.eh_proj(inputs_embedding)
         residual = None
         for i in range(self.num_layers):
-            hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
+            hidden_states, residual = self.mtp_block[i](forward_meta, hidden_states, residual)
 
         hidden_states = hidden_states + residual
 
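
For context, each decoder layer returns both the updated hidden states and a running residual, and the residual is folded back in only once, after the last layer. A rough sketch of that control flow with stand-in callables rather than the real Ernie4_5_DecoderLayer (which also takes forward_meta):

# Illustrative control flow only; "layer" stands in for Ernie4_5_DecoderLayer.
def run_mtp_block(mtp_block, hidden_states):
    residual = None
    for layer in mtp_block:
        # each layer returns (hidden_states, residual) so the residual add can be deferred
        hidden_states, residual = layer(hidden_states, residual)
    # fold the final residual back in once, after the last decoder layer
    return hidden_states + residual
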
@@ -374,17 +375,23 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
             weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
         """
 
-        from fastdeploy.model_executor.utils import default_weight_loader
+        from fastdeploy.model_executor.utils import (
+            default_weight_loader,
+            process_weights_after_loading,
+        )
 
         all_param_mapping = [
             # (param_name, weight_name, expert_id, shard_id)
             ("embed_tokens.embeddings", "embed_tokens", None, None),
             ("lm_head.linear", "lm_head", None, None),
+            ("enorm", "mtp_emb_norm.0", None, None),
+            ("hnorm", "mtp_hidden_norm.0", None, None),
+            ("eh_proj.linear", "mtp_linear_proj.0", None, None),
         ]
 
         params_dict = dict(self.named_parameters())
         shard_id = None
+        process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
         for loaded_weight_name, loaded_weight in weights_iterator:
             for param_name, weight_name, exp_id, shard_id in all_param_mapping:
                 if weight_name not in loaded_weight_name:
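
The three new mapping entries tie the draft model's enorm, hnorm, and eh_proj.linear parameters to the checkpoint keys they ship under (mtp_emb_norm.0, mtp_hidden_norm.0, mtp_linear_proj.0). A small sketch of how such a (param_name, weight_name) pair can rewrite a checkpoint key into a model parameter name; the replace-based rewrite is an assumption for illustration, not necessarily the exact logic inside load_weights:

# Hypothetical illustration of mapping checkpoint keys onto model parameter names.
all_param_mapping = [
    ("embed_tokens.embeddings", "embed_tokens", None, None),
    ("lm_head.linear", "lm_head", None, None),
    ("enorm", "mtp_emb_norm.0", None, None),
    ("hnorm", "mtp_hidden_norm.0", None, None),
    ("eh_proj.linear", "mtp_linear_proj.0", None, None),
]

loaded_weight_name = "mtp_linear_proj.0.weight"
for param_name, weight_name, _, _ in all_param_mapping:
    if weight_name in loaded_weight_name:
        # assumed rewrite: substitute the checkpoint prefix with the model's attribute path
        print(loaded_weight_name.replace(weight_name, param_name))  # eh_proj.linear.weight
        break
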
@@ -396,11 +403,16 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
             else:
                 if loaded_weight_name not in params_dict.keys():
                     continue
+                model_param_name = loaded_weight_name
                 param = params_dict[loaded_weight_name]
 
             # Get weight loader from parameter and set weight
             weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
             weight_loader(param, loaded_weight)
+            model_sublayer_name = re.sub(
+                r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name
+            )
+            process_weights_after_loading_fn(model_sublayer_name, param)
 
     def compute_logits(self, hidden_states: paddle.Tensor):
         """
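
The added re.sub strips a known trailing weight or scale suffix from the matched parameter name so that process_weights_after_loading_fn receives the owning sublayer's name. A quick check of what that pattern does on a few plausible names (the example names are made up; only the regex comes from the diff):

import re

# pattern copied from the diff: drop a known trailing weight/scale suffix
pattern = r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$"

for name in (
    "eh_proj.linear.weight",
    "mtp_block.0.mlp.up_gate_proj_weight",
    "mtp_block.0.self_attn.cache_k_scale",
):
    print(re.sub(pattern, "", name))
# eh_proj.linear
# mtp_block.0.mlp
# mtp_block.0.self_attn
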