【Inference Optimize】Update MergedReplicatedLinear for DSK qkv_a_proj_with_mqa. (#3673)
* support MergedReplicatedLinear
* update MergedReplicatedLinear to support DSK_wint4 V1_load
* update model name
* update linear class
* fix
* fix v0 moe_bias load

---------

Co-authored-by: bukejiyu <52310069+bukejiyu@users.noreply.github.com>
@@ -298,6 +298,76 @@ class ReplicatedLinear(LinearBase):
 )


+class MergedReplicatedLinear(ReplicatedLinear):
+    """
+    MergedReplicatedLinear linear layer.
+    """
+
+    def __init__(
+        self,
+        fd_config: FDConfig,
+        prefix: str = "",
+        input_size: int = None,
+        output_sizes: list[int] = None,
+        with_bias: bool = False,
+        add_bias: bool = False,
+        skip_quant: bool = False,
+        weight_dtype: str = "",
+        weight_key: str = "",
+    ):
+        """
+        Initializes a merged replicated linear layer.
+
+        Args:
+            fd_config (FDConfig): Inference-related parameters.
+            prefix (str): Unique name of the layer, used to name internal attributes.
+                Can be arbitrarily named.
+            input_size (int): Number of input features. Defaults to None.
+            output_sizes (list[int]): Sizes of the fused output partitions; the total
+                output size is their sum. Defaults to None.
+            with_bias (bool): Whether to include bias or not. Defaults to False.
+            add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to False.
+            skip_quant (bool): Whether to skip quantization. Defaults to False.
+        """
+        super().__init__(
+            fd_config=fd_config,
+            prefix=prefix,
+            input_size=input_size,
+            output_size=sum(output_sizes),
+            with_bias=with_bias,
+            add_bias=add_bias,
+            skip_quant=skip_quant,
+            weight_dtype=weight_dtype,
+            weight_key=weight_key,
+        )
+        self.output_sizes = output_sizes
+
+    def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
+        model_format = getattr(param, "model_format", "")
+        loaded_weight = get_tensor(loaded_weight)
+
+        if model_format == "torch":
+            # torch checkpoints store [out, in]; Paddle expects [in, out]
+            loaded_weight = loaded_weight.transpose([1, 0])
+
+        assert loaded_shard_id in ["q_a", "kv_a"]
+        if not param._is_initialized():
+            param.initialize()
+
+        if loaded_shard_id == "q_a":
+            param_shard_offset = 0
+            param_shard_size = self.output_sizes[0]
+        else:
+            # loaded_shard_id == "kv_a"
+            param_shard_offset = self.output_sizes[0]
+            param_shard_size = self.output_sizes[1]
+
+        if hasattr(param, "tensor_track"):
+            # record which columns of the fused weight have been loaded
+            param.tensor_track.mark(start=param_shard_offset, end=param_shard_offset + param_shard_size)
+        param = slice_fn(param, True, start=param_shard_offset, end=param_shard_offset + param_shard_size)
+        assert param.shape == loaded_weight.shape, (
+            f"Attempted to load weight ({loaded_weight.shape}) into parameter ({param.shape})"
+        )
+        param.copy_(loaded_weight, False)
+
+
 class ColumnParallelLinear(LinearBase):
     """
     ColumnParallelLinear Layer.
@@ -24,6 +24,7 @@ from paddle.nn.quant import weight_only_linear, weight_quantize
 from fastdeploy import envs
 from fastdeploy.model_executor.layers.linear import (
     MergedColumnParallelLinear,
+    MergedReplicatedLinear,
     QKVParallelLinear,
 )
 from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
@@ -203,11 +204,15 @@ class WeightOnlyLinearMethod(QuantMethodBase):
             default_initializer=paddle.nn.initializer.Constant(0),
         )
         quant_attrs = extra_weight_attrs
-        if isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear):
+        if (
+            isinstance(layer, MergedColumnParallelLinear)
+            or isinstance(layer, QKVParallelLinear)
+            or isinstance(layer, MergedReplicatedLinear)
+        ):
             quant_attrs = {
                 **extra_weight_attrs,
                 "tensor_track": TensorTracker(
-                    shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim")
+                    shape=layer.weight_shape, output_dim=extra_weight_attrs.get("output_dim", True)
                 ),
             }
         set_weight_attrs(
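With the default output_dim=True, the tracker watches the fused weight's output dimension so weight-only quantization can be deferred until both shards have been copied in. A hypothetical, simplified tracker showing the idea (names and internals are illustrative, not the real TensorTracker):

import numpy as np

class ShardTracker:
    # toy stand-in for a tensor_track-style marker
    def __init__(self, out_features: int):
        self.loaded = np.zeros(out_features, dtype=bool)

    def mark(self, start: int, end: int):
        # called once per shard that weight_loader copies in
        self.loaded[start:end] = True

    def is_fully_loaded(self) -> bool:
        return bool(self.loaded.all())

tracker = ShardTracker(40)
tracker.mark(0, 16)   # "q_a" shard
tracker.mark(16, 40)  # "kv_a" shard
assert tracker.is_fully_loaded()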
@@ -38,6 +38,7 @@ from fastdeploy.model_executor.layers.linear import (
     ColumnParallelLinear,
     KVBatchLinear,
     MergedColumnParallelLinear,
+    MergedReplicatedLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
@@ -169,6 +170,13 @@ class DeepSeekV3MoE(nn.Layer):
 
     def load_state_dict(self, state_dict):
         """ """
+        if self.experts.gate_correction_bias is not None:
+            gate_correction_bias_tensor = state_dict.pop(self.experts.gate_correction_bias_key)
+            if self.experts.gate_correction_bias.shape != gate_correction_bias_tensor.shape:
+                gate_correction_bias_tensor = gate_correction_bias_tensor.reshape(
+                    self.experts.gate_correction_bias.shape
+                )
+            self.experts.gate_correction_bias.set_value(gate_correction_bias_tensor)
         self.gate.load_state_dict(state_dict)
         self.experts.load_state_dict(state_dict)
         self.shared_experts.load_state_dict(state_dict)
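The shape guard handles checkpoints that store the correction bias with a different rank than the parameter (the "fix v0 moe_bias load" item in the commit message). A sketch with assumed shapes, e.g. a flat [num_experts] checkpoint tensor versus a [1, num_experts] parameter:

import paddle

num_experts = 8
param = paddle.create_parameter(shape=[1, num_experts], dtype="float32")
ckpt = paddle.rand([num_experts])  # tensor popped from state_dict

if param.shape != ckpt.shape:
    ckpt = ckpt.reshape(param.shape)  # align ranks before set_value
param.set_value(ckpt)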
@@ -211,11 +219,11 @@ class DeepseekV3MLAAttention(nn.Layer):
 
         if self.q_lora_rank is not None:
             # NOTE: (changwenbin) qkv_a_proj horizontal fusion
-            self.qkv_a_proj_with_mqa = ReplicatedLinear(
+            self.qkv_a_proj_with_mqa = MergedReplicatedLinear(
                 fd_config=fd_config,
                 prefix=f"{prefix}.qkv_a_proj_with_mqa",
                 input_size=self.hidden_size,
-                output_size=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
+                output_sizes=[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                 with_bias=False,
             )
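Switching from ReplicatedLinear to MergedReplicatedLinear keeps the horizontal fusion (one GEMM produces both the q_a and kv_a projections) while exposing per-shard offsets to the loader. A sketch of the fused compute with DeepSeek-V3-like sizes (all dimensions assumed for illustration):

import paddle

q_lora_rank, kv_lora_rank, qk_rope_head_dim = 1536, 512, 64
hidden_size = 7168
output_sizes = [q_lora_rank, kv_lora_rank + qk_rope_head_dim]

x = paddle.rand([2, 8, hidden_size])               # [batch, seq, hidden]
w = paddle.rand([hidden_size, sum(output_sizes)])  # fused weight
qkv_a = paddle.matmul(x, w)                        # one GEMM instead of two
q_a, kv_a = paddle.split(qkv_a, num_or_sections=output_sizes, axis=-1)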
@@ -636,6 +644,8 @@ class DeepseekV3ForCausalLM(ModelForCasualLM):
             ("embed_tokens.embeddings", "embed_tokens", None),
             ("lm_head.linear", "lm_head", None),
             ("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
+            ("qkv_a_proj_with_mqa", "q_a_proj", "q_a"),
+            ("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", "kv_a"),
         ]
         # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
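The two new entries let the checkpoint loader rename the separate q_a_proj and kv_a_proj_with_mqa keys onto the single fused parameter, tagging each with the shard id that MergedReplicatedLinear.weight_loader expects. A simplified sketch of that renaming (the real loader loop differs):

mappings = [
    ("qkv_a_proj_with_mqa", "q_a_proj", "q_a"),
    ("qkv_a_proj_with_mqa", "kv_a_proj_with_mqa", "kv_a"),
]

def resolve(weight_name: str):
    # map a checkpoint key to (fused parameter key, shard id)
    for param_name, ckpt_name, shard_id in mappings:
        if ckpt_name in weight_name:
            return weight_name.replace(ckpt_name, param_name), shard_id
    return weight_name, None

print(resolve("layers.0.self_attn.q_a_proj.weight"))
# -> ('layers.0.self_attn.qkv_a_proj_with_mqa.weight', 'q_a')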
@@ -58,6 +58,19 @@ model_param_map = {
             {"quant_type": "block_wise_fp8", "backend": "deepgemm", "env": {"DG_NVCC_OVERRIDE_CPP_STANDARD": "17"}},
         ],
     },
+    "DeepSeek-V3-0324": {
+        "tensor_parallel_size": 2,
+        "quantizations": [
+            {
+                "quant_type": "wint4",
+                "env": {
+                    "FD_ATTENTION_BACKEND": "MLA_ATTN",
+                    "FLAGS_mla_use_tensorcore": "1",
+                    "FLAGS_flash_attn_version": "3",
+                },
+            },
+        ],
+    },
 }
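The added test entry runs DeepSeek-V3-0324 under wint4 with the MLA attention flags set. A sketch of how a harness might apply such an "env" block before launching the server (illustrative; the actual CI runner may differ):

import os

case = {
    "quant_type": "wint4",
    "env": {
        "FD_ATTENTION_BACKEND": "MLA_ATTN",
        "FLAGS_mla_use_tensorcore": "1",
        "FLAGS_flash_attn_version": "3",
    },
}
os.environ.update(case["env"])  # flags become visible to the launched process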