mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Loader V1] modify layername for DeepSeekV3 (#3336)
Co-authored-by: Yuanle Liu <yuanlehome@163.com> Co-authored-by: YUNSHEN XIE <1084314248@qq.com>
This commit is contained in:
@@ -539,7 +539,7 @@ class DeepSeekV3Model(nn.Layer):
|
|||||||
prefix="deepseek_v3.embed_tokens",
|
prefix="deepseek_v3.embed_tokens",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.decoder_layers = nn.LayerList(
|
self.layers = nn.LayerList(
|
||||||
[
|
[
|
||||||
DeepSeekV3DecoderLayer(
|
DeepSeekV3DecoderLayer(
|
||||||
fd_config,
|
fd_config,
|
||||||
@@ -564,7 +564,7 @@ class DeepSeekV3Model(nn.Layer):
|
|||||||
self.norm.load_state_dict(state_dict)
|
self.norm.load_state_dict(state_dict)
|
||||||
for i in range(self.num_layers):
|
for i in range(self.num_layers):
|
||||||
logger.info(f"Start load layer {i}")
|
logger.info(f"Start load layer {i}")
|
||||||
self.decoder_layers[i].load_state_dict(state_dict)
|
self.layers[i].load_state_dict(state_dict)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -578,7 +578,7 @@ class DeepSeekV3Model(nn.Layer):
|
|||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
for i in range(self.num_layers):
|
for i in range(self.num_layers):
|
||||||
hidden_states, residual = self.decoder_layers[i](
|
hidden_states, residual = self.layers[i](
|
||||||
forward_meta,
|
forward_meta,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
residual,
|
residual,
|
||||||
@@ -658,12 +658,11 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
|
|||||||
|
|
||||||
for loaded_weight_name, loaded_weight in weights_iterator:
|
for loaded_weight_name, loaded_weight in weights_iterator:
|
||||||
loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model")
|
loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model")
|
||||||
loaded_weight_name = loaded_weight_name.replace("layers", "decoder_layers")
|
|
||||||
|
|
||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in loaded_weight_name:
|
if weight_name not in loaded_weight_name:
|
||||||
continue
|
continue
|
||||||
if "mlp.experts." in loaded_weight_name and loaded_weight_name not in params_dict:
|
if "mlp.experts." in loaded_weight_name:
|
||||||
continue
|
continue
|
||||||
model_param_name = loaded_weight_name.replace(weight_name, param_name)
|
model_param_name = loaded_weight_name.replace(weight_name, param_name)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user