【Inference Optimize】Update MergedReplicatedLinear for DSK qkv_a_proj_with_mqa. (#3673)

* support MergedReplicatedLinear

* update MergedReplicatedLinear to support DSK_wint4 V1_load

* update model name

* update linear class

* fix

* fix v0 moe_bias load

---------

Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com>
AIbin authored on 2025-09-05 12:16:05 +08:00, committed by GitHub
parent b23fc654d9
commit 41aee08982
4 changed files with 102 additions and 4 deletions
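For orientation (not part of the diff): DeepSeek-V3's MLA attention fuses q_a_proj and kv_a_proj_with_mqa into a single qkv_a_proj_with_mqa projection whose output is split back into a q_a part and a kv_a part. Below is a minimal numpy sketch of that split, using the DeepSeek-V3 config sizes (hidden_size=7168, q_lora_rank=1536, kv_lora_rank=512, qk_rope_head_dim=64); the variable names are illustrative only and do not appear in the commit.

import numpy as np

hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 7168, 1536, 512, 64
output_sizes = [q_lora_rank, kv_lora_rank + qk_rope_head_dim]

x = np.random.rand(4, hidden_size).astype("float32")                  # [num_tokens, hidden_size]
w = np.random.rand(hidden_size, sum(output_sizes)).astype("float32")  # fused qkv_a weight

fused_out = x @ w                          # one GEMM instead of two separate projections
q_a = fused_out[:, : output_sizes[0]]      # q_a path (q_a_layernorm / q_b_proj)
kv_a = fused_out[:, output_sizes[0] :]     # compressed KV + rope part (kv_a path)

assert q_a.shape == (4, q_lora_rank)
assert kv_a.shape == (4, kv_lora_rank + qk_rope_head_dim)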

@@ -298,6 +298,76 @@ class ReplicatedLinear(LinearBase):
)
class MergedReplicatedLinear(ReplicatedLinear):
    """
    MergedReplicatedLinear linear layer.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        prefix: str = "",
        input_size: int = None,
        output_sizes: list[int] = None,
        with_bias: bool = False,
        add_bias: bool = False,
        skip_quant: bool = False,
        weight_dtype: str = "",
        weight_key: str = "",
    ):
        """
        Initializes a MergedReplicatedLinear layer.

        Args:
            fd_config (FDConfig): Inference-related parameters.
            prefix (str): Unique name of the layer, used to name internal attributes.
                Can be arbitrarily named.
            input_size (int): Number of input features. Defaults to None.
            output_sizes (list[int]): Number of output features for each merged projection. Defaults to None.
            with_bias (bool): Whether to include bias or not. Defaults to False.
            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
            skip_quant (bool): Whether to skip quantization. Defaults to False.
            weight_dtype (str): Data type of the weight. Defaults to "".
            weight_key (str): Key used to look up the weight in the state dict. Defaults to "".
        """
        super().__init__(
            fd_config=fd_config,
            prefix=prefix,
            input_size=input_size,
            output_size=sum(output_sizes),
            with_bias=with_bias,
            add_bias=add_bias,
            skip_quant=skip_quant,
            weight_dtype=weight_dtype,
            weight_key=weight_key,
        )
        self.output_sizes = output_sizes

    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
        model_format = getattr(param, "model_format", "")
        loaded_weight = get_tensor(loaded_weight)
        if model_format == "torch":
            loaded_weight = loaded_weight.transpose([1, 0])
        assert loaded_shard_id in ["q_a", "kv_a"]
        if not param._is_initialized():
            param.initialize()
        if loaded_shard_id == "q_a":
            param_shard_offset = 0
            param_shard_size = self.output_sizes[0]
        else:
            # loaded_shard_id == "kv_a"
            param_shard_offset = self.output_sizes[0]
            param_shard_size = self.output_sizes[1]
        if hasattr(param, "tensor_track"):
            param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
        param = slice_fn(param, True, start=param_shard_offset, end=param_shard_offset + param_shard_size)
        assert param.shape == loaded_weight.shape, (
            f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
        )
        param.copy_(loaded_weight, False)


class ColumnParallelLinear(LinearBase):
    """
    ColumnParallelLinear Layer.
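The weight_loader above copies each checkpoint shard into a column slice of the fused weight at an offset derived from output_sizes. The following is a hedged numpy stand-in for that placement, not FastDeploy code; in the real layer, slice_fn, tensor_track.mark and copy_ play the roles sketched here, and the sizes shown are the DeepSeek-V3 ones.

import numpy as np

hidden_size = 7168
output_sizes = [1536, 512 + 64]          # [q_a width, kv_a width]
fused_weight = np.zeros((hidden_size, sum(output_sizes)), dtype="float32")

def load_shard(param, loaded_weight, loaded_shard_id):
    # Offsets mirror the q_a / kv_a branches of MergedReplicatedLinear.weight_loader.
    assert loaded_shard_id in ("q_a", "kv_a")
    offset = 0 if loaded_shard_id == "q_a" else output_sizes[0]
    size = output_sizes[0] if loaded_shard_id == "q_a" else output_sizes[1]
    assert loaded_weight.shape == (param.shape[0], size)
    # Stand-in for slice_fn(param, ...) followed by param.copy_(loaded_weight, False).
    param[:, offset : offset + size] = loaded_weight

load_shard(fused_weight, np.ones((hidden_size, 1536), dtype="float32"), "q_a")
load_shard(fused_weight, np.ones((hidden_size, 576), dtype="float32"), "kv_a")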

@@ -24,6 +24,7 @@ from paddle.nn.quant import weight_only_linear, weight_quantize
from fastdeploy import envs
from fastdeploy.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    MergedReplicatedLinear,
    QKVParallelLinear,
)
from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
@@ -203,11 +204,15 @@ class WeightOnlyLinearMethod(QuantMethodBase):
                default_initializer=paddle.nn.initializer.Constant(0),
            )
            quant_attrs = extra_weight_attrs
-            if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
+            if (
+                isinstance(layer, MergedColumnParallelLinear)
+                or isinstance(layer, QKVParallelLinear)
+                or isinstance(layer, MergedReplicatedLinear)
+            ):
                quant_attrs = {
                    **extra_weight_attrs,
                    "tensor_track": TensorTracker(
-                        shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
+                        shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim", True)
                    ),
                }
            set_weight_attrs(
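TensorTracker's exact semantics are internal to FastDeploy; as a rough mental model only (an assumption, not the library's implementation), it can be thought of as recording which output-column ranges of a lazily loaded fused weight have been written by per-shard mark() calls, as in this toy stand-in:

# Toy stand-in (assumption, not FastDeploy's TensorTracker): record which output
# columns of a fused weight have been filled by per-shard weight_loader calls.
class ToyTensorTracker:
    def __init__(self, shape, output_dim=True):
        # output_dim=True means shards are laid out along the last (output) axis.
        self.size = shape[-1] if output_dim else shape[0]
        self.filled = [False] * self.size

    def mark(self, start, end):
        for i in range(start, end):
            self.filled[i] = True

    def is_complete(self):
        return all(self.filled)

tracker = ToyTensorTracker(shape=[7168, 1536 + 576], output_dim=True)
tracker.mark(start=0, end=1536)            # q_a shard loaded
tracker.mark(start=1536, end=1536 + 576)   # kv_a shard loaded
assert tracker.is_complete()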

@@ -38,6 +38,7 @@ from fastdeploy.model_executor.layers.linear import (
    ColumnParallelLinear,
    KVBatchLinear,
    MergedColumnParallelLinear,
    MergedReplicatedLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
@@ -169,6 +170,13 @@ class DeepSeekV3MoE(nn.Layer):
    def load_state_dict(self, state_dict):
        """ """
        if self.experts.gate_correction_bias is not None:
            gate_correction_bias_tensor = state_dict.pop(self.experts.gate_correction_bias_key)
            if self.experts.gate_correction_bias.shape != gate_correction_bias_tensor.shape:
                gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(
                    self.experts.gate_correction_bias.shape
                )
            self.experts.gate_correction_bias.set_value(gate_correction_bias_tensor)
        self.gate.load_state_dict(state_dict)
        self.experts.load_state_dict(state_dict)
        self.shared_experts.load_state_dict(state_dict)
@@ -211,11 +219,11 @@ class DeepseekV3MLAAttention(nn.Layer):
        if self.q_lora_rank is not None:
            # NOTE: (changwenbin) qkv_a_proj horizontal fusion
-            self.qkv_a_proj_with_mqa = ReplicatedLinear(
+            self.qkv_a_proj_with_mqa = MergedReplicatedLinear(
                fd_config=fd_config,
                prefix=f"{prefix}.qkv_a_proj_with_mqa",
                input_size=self.hidden_size,
-                output_size=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
+                output_sizes=[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                with_bias=False,
            )
@@ -636,6 +644,8 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
("qkv_a_proj_with_mqa", "q_a_proj", "q_a"),
("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", "kv_a"),
]
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
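The two new mapping entries tell the checkpoint loader to rename q_a_proj and kv_a_proj_with_mqa tensors onto the fused qkv_a_proj_with_mqa parameter and to pass the matching shard id to its weight_loader. A hypothetical routing helper (not the actual FastDeploy loader) illustrating that lookup:

# Hypothetical helper (not FastDeploy code) showing how the
# (param_name, weight_name, shard_id) tuples route checkpoint keys.
stacked_params_mapping = [
    ("qkv_a_proj_with_mqa", "q_a_proj", "q_a"),
    ("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", "kv_a"),
]

def route_key(checkpoint_key):
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in checkpoint_key:
            return checkpoint_key.replace(weight_name, param_name), shard_id
    return checkpoint_key, None

# e.g. "...self_attn.q_a_proj.weight" -> ("...self_attn.qkv_a_proj_with_mqa.weight", "q_a")
print(route_key("model.layers.0.self_attn.q_a_proj.weight"))
print(route_key("model.layers.0.self_attn.kv_a_proj_with_mqa.weight"))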

@@ -58,6 +58,19 @@ model_param_map = {
{"quant_type": "block_wise_fp8", "backend": "deepgemm", "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"}},
],
},
"DeepSeek-V3-0324": {
"tensor_parallel_size": 2,
"quantizations": [
{
"quant_type": "wint4",
"env": {
"FD_ATTENTION_BACKEND": "MLA_ATTN",
"FLAGS_mla_use_tensorcore": "1",
"FLAGS_flash_attn_version": "3",
},
},
],
},
}