[V1 Loader] Support DeepSeekV3(bf16) (#3294)

* Support new loader for DeepSeekV3(bf16)

* Update Paddle version

* Remove unused attribute
Author: Zero Rains
Date: 2025-08-11 13:39:28 +08:00
Committed by: GitHub
Parent: e0aeac58e1
Commit: 42af0b4b64
5 changed files with 141 additions and 5 deletions


@@ -628,6 +628,79 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
         self.model.load_state_dict(state_dict)
         self.lm_head.load_state_dict(state_dict)
 
+    @paddle.no_grad()
+    def load_weights(self, weights_iterator) -> None:
+        """
+        Load model parameters from a given weights_iterator object.
+
+        Args:
+            weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
+        """
+        from fastdeploy.model_executor.models.utils import default_weight_loader
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("up_gate_proj", "gate_proj", "gate"),
+            ("up_gate_proj", "up_proj", "up"),
+            ("embed_tokens.embeddings", "embed_tokens", None),
+            ("lm_head.linear", "lm_head", None),
+            ("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
+        ]
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            param_gate_up_proj_name="experts.up_gate_proj_",
+            param_down_proj_name="experts.down_proj_",
+            num_experts=self.fd_config.model_config.n_routed_experts,
+        )
+        params_dict = dict(self.named_parameters())
+
+        for loaded_weight_name, loaded_weight in weights_iterator:
+            loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model")
+            loaded_weight_name = loaded_weight_name.replace("layers", "decoder_layers")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in loaded_weight_name:
+                    continue
+                if "mlp.experts." in loaded_weight_name and loaded_weight_name not in params_dict:
+                    continue
+                model_param_name = loaded_weight_name.replace(weight_name, param_name)
+                if model_param_name not in params_dict:
+                    continue
+                param = params_dict[model_param_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in loaded_weight_name:
+                        continue
+                    model_param_name = loaded_weight_name.replace(weight_name, param_name)
+                    if model_param_name not in params_dict:
+                        continue
+                    param = params_dict[model_param_name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id)
+                    break
+                else:
+                    if loaded_weight_name not in params_dict:
+                        continue
+                    param = params_dict[loaded_weight_name]
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+                    weight_loader(param, loaded_weight)
+                    if "kv_b_proj.weight" in loaded_weight_name:
+                        # handle kv_b_proj_bmm
+                        model_param_name = loaded_weight_name.replace(
+                            "kv_b_proj.weight", "kv_b_proj_bmm.k_b_proj_weight"
+                        )
+                        param = params_dict[model_param_name]
+                        weight_loader = getattr(param, "weight_loader", None)
+                        weight_loader(param, loaded_weight, shard_id)
+
     def compute_logits(self, hidden_states: paddle.Tensor):
         """ """
         logits = self.lm_head(hidden_states)
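The new load_weights entry point only assumes an iterator of (name, weight) pairs. Below is a minimal sketch of such a driver, assuming the checkpoint is stored as safetensors shards; iterate_safetensors and the path are hypothetical and only illustrate the expected input, not FastDeploy's actual loader plumbing.

import os

from safetensors import safe_open


def iterate_safetensors(checkpoint_dir):
    """Yield (name, weight) pairs from every *.safetensors shard in a directory (illustrative helper)."""
    for fname in sorted(os.listdir(checkpoint_dir)):
        if not fname.endswith(".safetensors"):
            continue
        with safe_open(os.path.join(checkpoint_dir, fname), framework="paddle") as f:
            for name in f.keys():
                # Each per-parameter weight_loader takes care of dtype casts,
                # fused/stacked layouts and expert sharding.
                yield name, f.get_tensor(name)


# model.load_weights(iterate_safetensors("/path/to/DeepSeek-V3-bf16"))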


@@ -78,9 +78,13 @@ def default_weight_loader(fd_config: FDConfig) -> None:
             if param.dtype != loaded_weight.dtype:
                 loaded_weight = loaded_weight.cast(param.dtype)
-            assert param.shape == loaded_weight.shape, (
-                f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
-            )
+            if param.shape != loaded_weight.shape:
+                try:
+                    param = param.reshape(loaded_weight.shape)
+                except ValueError as e:
+                    raise ValueError(
+                        f" Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape}). {e}"
+                    )
             param.copy_(loaded_weight, False)
         except Exception:
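The default_weight_loader change above replaces the hard shape assert: a checkpoint tensor whose shape differs from the parameter is now reshaped rather than rejected, and only a genuine mismatch still raises ValueError. A standalone sketch of the same idea (copy_with_reshape is illustrative, not the library's API; this sketch reshapes the checkpoint tensor to the parameter's layout):

import math

import paddle


def copy_with_reshape(param, loaded_weight):
    """Copy a checkpoint tensor into a parameter, tolerating layout-only shape differences (illustrative)."""
    if param.dtype != loaded_weight.dtype:
        loaded_weight = loaded_weight.cast(param.dtype)
    if param.shape != loaded_weight.shape:
        if math.prod(param.shape) != math.prod(loaded_weight.shape):
            raise ValueError(f"Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape})")
        # Same element count, different layout (e.g. a flattened or fused
        # checkpoint tensor): reshape instead of failing outright.
        loaded_weight = loaded_weight.reshape(param.shape)
    param.copy_(loaded_weight, False)
    return param


# A 6-element checkpoint tensor fills a [2, 3] bf16 parameter in place.
p = paddle.zeros([2, 3], dtype="bfloat16")
w = paddle.arange(6, dtype="float32")
copy_with_reshape(p, w)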