Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-25 01:20:43 +08:00)
			
		
		
		
	fix mtp (#4105)
This commit is contained in:
		| @@ -16,6 +16,7 @@ | |||||||
|  |  | ||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
|  |  | ||||||
|  | import re | ||||||
| from functools import partial | from functools import partial | ||||||
| from typing import Dict, Union | from typing import Dict, Union | ||||||
|  |  | ||||||
| @@ -250,7 +251,7 @@ class Ernie4_5_MTPModel(nn.Layer): | |||||||
|         self.embed_tokens = fd_config.speculative_config.sharing_model.ernie.embed_tokens |         self.embed_tokens = fd_config.speculative_config.sharing_model.ernie.embed_tokens | ||||||
|         self.norm = fd_config.speculative_config.sharing_model.ernie.norm |         self.norm = fd_config.speculative_config.sharing_model.ernie.norm | ||||||
|  |  | ||||||
|         self.layers = nn.LayerList( |         self.mtp_block = nn.LayerList( | ||||||
|             [ |             [ | ||||||
|                 Ernie4_5_DecoderLayer( |                 Ernie4_5_DecoderLayer( | ||||||
|                     fd_config=fd_config, |                     fd_config=fd_config, | ||||||
| @@ -296,7 +297,7 @@ class Ernie4_5_MTPModel(nn.Layer): | |||||||
|         self.eh_proj.load_state_dict(state_dict) |         self.eh_proj.load_state_dict(state_dict) | ||||||
|         for i in range(self.num_layers): |         for i in range(self.num_layers): | ||||||
|             logger.info(f"Start load layer {i}") |             logger.info(f"Start load layer {i}") | ||||||
|             self.layers[i].load_state_dict(state_dict) |             self.mtp_block[i].load_state_dict(state_dict) | ||||||
|  |  | ||||||
|     def forward( |     def forward( | ||||||
|         self, |         self, | ||||||
| @@ -315,7 +316,7 @@ class Ernie4_5_MTPModel(nn.Layer): | |||||||
|         hidden_states = self.eh_proj(inputs_embedding) |         hidden_states = self.eh_proj(inputs_embedding) | ||||||
|         residual = None |         residual = None | ||||||
|         for i in range(self.num_layers): |         for i in range(self.num_layers): | ||||||
|             hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual) |             hidden_states, residual = self.mtp_block[i](forward_meta, hidden_states, residual) | ||||||
|  |  | ||||||
|         hidden_states = hidden_states + residual |         hidden_states = hidden_states + residual | ||||||
|  |  | ||||||
| @@ -374,17 +375,23 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM): | |||||||
|             weights_iterator (Iterator): An iterator yielding (name, weight) pairs. |             weights_iterator (Iterator): An iterator yielding (name, weight) pairs. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         from fastdeploy.model_executor.utils import default_weight_loader |         from fastdeploy.model_executor.utils import ( | ||||||
|  |             default_weight_loader, | ||||||
|  |             process_weights_after_loading, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         all_param_mapping = [ |         all_param_mapping = [ | ||||||
|             # (param_name, weight_name, expert_id, shard_id) |             # (param_name, weight_name, expert_id, shard_id) | ||||||
|             ("embed_tokens.embeddings", "embed_tokens", None, None), |             ("embed_tokens.embeddings", "embed_tokens", None, None), | ||||||
|             ("lm_head.linear", "lm_head", None, None), |             ("lm_head.linear", "lm_head", None, None), | ||||||
|  |             ("enorm", "mtp_emb_norm.0", None, None), | ||||||
|  |             ("hnorm", "mtp_hidden_norm.0", None, None), | ||||||
|  |             ("eh_proj.linear", "mtp_linear_proj.0", None, None), | ||||||
|         ] |         ] | ||||||
|  |  | ||||||
|         params_dict = dict(self.named_parameters()) |         params_dict = dict(self.named_parameters()) | ||||||
|         shard_id = None |         shard_id = None | ||||||
|  |         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers())) | ||||||
|         for loaded_weight_name, loaded_weight in weights_iterator: |         for loaded_weight_name, loaded_weight in weights_iterator: | ||||||
|             for param_name, weight_name, exp_id, shard_id in all_param_mapping: |             for param_name, weight_name, exp_id, shard_id in all_param_mapping: | ||||||
|                 if weight_name not in loaded_weight_name: |                 if weight_name not in loaded_weight_name: | ||||||
| @@ -396,11 +403,16 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM): | |||||||
|             else: |             else: | ||||||
|                 if loaded_weight_name not in params_dict.keys(): |                 if loaded_weight_name not in params_dict.keys(): | ||||||
|                     continue |                     continue | ||||||
|  |                 model_param_name = loaded_weight_name | ||||||
|                 param = params_dict[loaded_weight_name] |                 param = params_dict[loaded_weight_name] | ||||||
|  |  | ||||||
|             # Get weight loader from parameter and set weight |             # Get weight loader from parameter and set weight | ||||||
|             weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) |             weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) | ||||||
|             weight_loader(param, loaded_weight) |             weight_loader(param, loaded_weight) | ||||||
|  |             model_sublayer_name = re.sub( | ||||||
|  |                 r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name | ||||||
|  |             ) | ||||||
|  |             process_weights_after_loading_fn(model_sublayer_name, param) | ||||||
|  |  | ||||||
|     def compute_logits(self, hidden_states: paddle.Tensor): |     def compute_logits(self, hidden_states: paddle.Tensor): | ||||||
|         """ |         """ | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 YuanRisheng
					YuanRisheng