Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-04 16:22:57 +08:00
[V1 Loader] Support DeepSeekV3(bf16) (#3294)
* Support new loader for DeepSeekV3(bf16)
* Update paddle version
* Remove useless attr
@@ -720,6 +720,7 @@ class KVBatchLinear(LinearBase):
        self.v_head_dim = v_head_dim
        # Split num_attention_heads when using TP inference.
        self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
        self.local_rank = fd_config.parallel_config.tensor_parallel_rank

        # Initialize parent with combined dimensions
        super().__init__(
@@ -738,6 +739,63 @@ class KVBatchLinear(LinearBase):
        self.k_weight_key = f"{prefix.replace('kv_b_proj', 'k_b_proj')}.weight"
        self.v_weight_key = f"{prefix.replace('kv_b_proj', 'v_b_proj')}.weight"

        self.k_b_proj_weight = self.create_parameter(
            shape=[self.num_heads_per_partition, self.qk_nope_head_dim, self.kv_lora_rank],
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        self.v_b_proj_weight = self.create_parameter(
            shape=[self.num_heads_per_partition, self.kv_lora_rank, self.v_head_dim],
            dtype=self.weight_dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        set_weight_attrs(
            self.k_b_proj_weight,
            {"weight_loader": self.weight_loader},
        )

        if self.nranks > 0:
            _set_var_distributed(self.k_b_proj_weight, split_axis=1)
            set_weight_attrs(self.k_b_proj_weight, {"output_dim": True})

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        output_dim = getattr(param, "output_dim", None)
        # Tensor parallelism splits the weight along the output_dim
        if output_dim is not None:
            dim = -1
            size = loaded_weight.get_shape()[dim]
            block_size = size // self.nranks
            shard_offset = self.local_rank * block_size
            shard_size = (self.local_rank + 1) * block_size
            loaded_weight = loaded_weight[..., shard_offset:shard_size]
        w = (
            get_tensor(loaded_weight)
            .reshape(
                [
                    self.kv_lora_rank,
                    self.num_heads_per_partition,
                    -1,
                ]
            )
            .transpose(perm=[1, 2, 0])
        )
        if param.dtype != w.dtype:
            w = w.cast(param.dtype)
        # Split into K and V weights
        # wk_b: [num_heads, qk_nope_head_dim, kv_lora_rank]
        wk_b = w[:, : self.qk_nope_head_dim, :]
        if self.v_head_dim is None:
            raise ValueError("self.v_head_dim should not be None")
        # wv_b: [num_heads, kv_lora_rank, v_head_dim]
        wv_b = w[:, -self.v_head_dim :, :].transpose(perm=[0, 2, 1])

        self.k_b_proj_weight.set_value(wk_b)
        self.v_b_proj_weight.set_value(wv_b)

    def load_state_dict(self, state_dict: dict):
        """
        Load the combined KV weight and split it into K and V projections
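For readers unfamiliar with the MLA weight layout, the sketch below walks through the same reshape/transpose/slice that `weight_loader` performs on a combined `kv_b_proj` checkpoint tensor. It is a standalone illustration only: the dimensions are made up (far smaller than DeepSeek-V3's real `kv_lora_rank` and head sizes), and it assumes the checkpoint stores `kv_b_proj` as `[kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)]`, which is what the reshape above implies.

```python
import paddle

# Illustrative (not real DeepSeek-V3) dimensions.
kv_lora_rank = 8
num_heads_per_partition = 4
qk_nope_head_dim = 6
v_head_dim = 5

# Assumed checkpoint layout: [kv_lora_rank, num_heads * (qk_nope + v)].
w_ckpt = paddle.randn([kv_lora_rank, num_heads_per_partition * (qk_nope_head_dim + v_head_dim)])

# Same steps as weight_loader: -> [num_heads, qk_nope + v, kv_lora_rank].
w = w_ckpt.reshape([kv_lora_rank, num_heads_per_partition, -1]).transpose(perm=[1, 2, 0])

wk_b = w[:, :qk_nope_head_dim, :]                       # [num_heads, qk_nope_head_dim, kv_lora_rank]
wv_b = w[:, -v_head_dim:, :].transpose(perm=[0, 2, 1])  # [num_heads, kv_lora_rank, v_head_dim]

assert wk_b.shape == [num_heads_per_partition, qk_nope_head_dim, kv_lora_rank]
assert wv_b.shape == [num_heads_per_partition, kv_lora_rank, v_head_dim]
```

The resulting shapes match the `k_b_proj_weight` and `v_b_proj_weight` parameters created in `__init__`, which is why `set_value` can be used directly at the end of `weight_loader`.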
@@ -52,7 +52,7 @@ def get_moe_scores(
    compute moe scores using e_score_correction_bias.
    """
    scores = paddle.nn.functional.sigmoid(gating_output)
-   scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
+   scores_with_bias = scores + e_score_correction_bias
    scores, topk_values, topk_idx = noaux_tc(
        scores,
        scores_with_bias,
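The only change in this hunk is dropping the explicit `unsqueeze(0)`. A quick sketch of why that is safe under the assumption (made here for illustration, not stated in the diff) that the correction bias is a 1-D tensor of shape `[num_experts]`:

```python
import paddle

num_tokens, num_experts = 3, 8

gating_output = paddle.randn([num_tokens, num_experts])
e_score_correction_bias = paddle.randn([num_experts])  # assumed 1-D for this sketch

scores = paddle.nn.functional.sigmoid(gating_output)

# Broadcasting aligns [num_experts] with [num_tokens, num_experts],
# so the sum is identical with or without the explicit unsqueeze(0).
a = scores + e_score_correction_bias.unsqueeze(0)
b = scores + e_score_correction_bias
assert bool(paddle.allclose(a, b))
```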
@@ -508,10 +508,11 @@ class FusedMoE(nn.Layer):
                gate_correction_bias_tensor = self.extract_gate_correction_bias(
                    self.gate_correction_bias_key, state_dict
                )
                if self.gate_correction_bias.shape != gate_correction_bias_tensor.shape:
                    gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(self.gate_correction_bias.shape)
                self.gate_correction_bias.set_value(gate_correction_bias_tensor)
            else:
                self.gate_correction_bias = None

        else:
            self.gate_correction_bias = None

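A minimal sketch of the shape guard added here: the checkpoint tensor is reshaped to the parameter's layout before `set_value`, which only works when the element counts match. The `[1, num_experts]` vs `[num_experts]` shapes below are illustrative assumptions, not taken from the diff.

```python
import paddle

num_experts = 8

# Hypothetical parameter layout vs. checkpoint layout.
gate_correction_bias = paddle.create_parameter(shape=[1, num_experts], dtype="float32")
loaded = paddle.randn([num_experts])

if gate_correction_bias.shape != loaded.shape:
    # Same element count, different layout -> a reshape is enough.
    loaded = loaded.reshape(gate_correction_bias.shape)

gate_correction_bias.set_value(loaded)
```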
@@ -628,6 +628,79 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
        self.model.load_state_dict(state_dict)
        self.lm_head.load_state_dict(state_dict)

    @paddle.no_grad()
    def load_weights(self, weights_iterator) -> None:
        """
        Load model parameters from a given weights_iterator object.
        Args:
            weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
        """
        from fastdeploy.model_executor.models.utils import default_weight_loader

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("up_gate_proj", "gate_proj", "gate"),
            ("up_gate_proj", "up_proj", "up"),
            ("embed_tokens.embeddings", "embed_tokens", None),
            ("lm_head.linear", "lm_head", None),
            ("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
        ]
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            param_gate_up_proj_name="experts.up_gate_proj_",
            param_down_proj_name="experts.down_proj_",
            num_experts=self.fd_config.model_config.n_routed_experts,
        )
        params_dict = dict(self.named_parameters())

        for loaded_weight_name, loaded_weight in weights_iterator:
            loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model")
            loaded_weight_name = loaded_weight_name.replace("layers", "decoder_layers")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in loaded_weight_name:
                    continue
                if "mlp.experts." in loaded_weight_name and loaded_weight_name not in params_dict:
                    continue
                model_param_name = loaded_weight_name.replace(weight_name, param_name)

                if model_param_name not in params_dict:
                    continue

                param = params_dict[model_param_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in loaded_weight_name:
                        continue
                    model_param_name = loaded_weight_name.replace(weight_name, param_name)
                    if model_param_name not in params_dict:
                        continue
                    param = params_dict[model_param_name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id)
                    break
                else:
                    if loaded_weight_name not in params_dict:
                        continue
                    param = params_dict[loaded_weight_name]
                    weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
                    weight_loader(param, loaded_weight)
                    if "kv_b_proj.weight" in loaded_weight_name:
                        # handle kv_b_proj_bmm
                        model_param_name = loaded_weight_name.replace(
                            "kv_b_proj.weight", "kv_b_proj_bmm.k_b_proj_weight"
                        )
                        param = params_dict[model_param_name]
                        weight_loader = getattr(param, "weight_loader", None)
                        weight_loader(param, loaded_weight, shard_id)

    def compute_logits(self, hidden_states: paddle.Tensor):
        """ """
        logits = self.lm_head(hidden_states)
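The name handling in `load_weights` is easiest to see on a concrete example. The helper below is a hypothetical, stripped-down version of the stacked-parameter branch only (the mapping entries are the ones listed above, the `resolve_stacked` function itself is invented for illustration); it shows how a checkpoint key resolves to a fused model parameter plus a shard id, and is not a replacement for the real loader.

```python
from typing import Optional, Tuple

STACKED_PARAMS_MAPPING = [
    # (param_name, shard_name, shard_id) -- copied from load_weights above
    ("up_gate_proj", "gate_proj", "gate"),
    ("up_gate_proj", "up_proj", "up"),
    ("embed_tokens.embeddings", "embed_tokens", None),
    ("lm_head.linear", "lm_head", None),
]


def resolve_stacked(ckpt_name: str) -> Tuple[str, Optional[str]]:
    """Hypothetical helper: map a checkpoint key to (model param name, shard_id)."""
    name = ckpt_name.replace("deepseek_v3", "model").replace("layers", "decoder_layers")
    for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None


# gate_proj and up_proj both land in the fused up_gate_proj parameter; the shard
# id tells its weight_loader which half of the fused tensor to fill.
print(resolve_stacked("deepseek_v3.layers.0.mlp.gate_proj.weight"))
# ('model.decoder_layers.0.mlp.up_gate_proj.weight', 'gate')
print(resolve_stacked("deepseek_v3.layers.0.mlp.up_proj.weight"))
# ('model.decoder_layers.0.mlp.up_gate_proj.weight', 'up')
```

Expert weights and plain (non-stacked) weights follow the same pattern via `expert_params_mapping` and the final `else` branch, with `kv_b_proj.weight` additionally routed into `kv_b_proj_bmm.k_b_proj_weight` so the split K/V projections get populated.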
@@ -78,9 +78,13 @@ def default_weight_loader(fd_config: FDConfig) -> None:
            if param.dtype != loaded_weight.dtype:
                loaded_weight = loaded_weight.cast(param.dtype)

-           assert param.shape == loaded_weight.shape, (
-               f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
-           )
+           if param.shape != loaded_weight.shape:
+               try:
+                   param = param.reshape(loaded_weight.shape)
+               except ValueError as e:
+                   raise ValueError(
+                       f" Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape}). {e}"
+                   )

            param.copy_(loaded_weight, False)
        except Exception:
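This change to `default_weight_loader` replaces a hard assert with a best-effort reshape. Below is a standalone sketch of the new behaviour under an assumed shape mismatch (a `[1, n]` parameter receiving an `[n]` checkpoint tensor); only the control flow mirrors the code above, the shapes are illustrative.

```python
import paddle

# Hypothetical shapes: a [1, n] parameter receiving an [n] checkpoint tensor.
param = paddle.create_parameter(shape=[1, 6], dtype="float32")
loaded_weight = paddle.randn([6])

if param.dtype != loaded_weight.dtype:
    loaded_weight = loaded_weight.cast(param.dtype)

if param.shape != loaded_weight.shape:
    try:
        # Only succeeds when the element counts match, as in the [1, 6] vs [6] case.
        param = param.reshape(loaded_weight.shape)
    except ValueError as e:
        raise ValueError(
            f"Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape}). {e}"
        )

param.copy_(loaded_weight, False)
```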