support qk norm (#3145)

Author: Yuan Xiaolan
Date: 2025-08-05 16:46:14 +08:00
Committed by: GitHub
Parent: 4a10e29804
Commit: 7ce00e597c
17 changed files with 791 additions and 201 deletions


@@ -262,6 +262,9 @@ class AppendAttentionBackend(AttentionBackend):
layer.linear_shift,
layer.linear_smooth,
metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
getattr(layer, "rms_norm_eps", 1e-6),
metadata._fuse_kernel_compute_dtype,
getattr(layer, "cache_quant_type_str", "none"),
layer.use_neox_rotary_style,
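A note on the hunk above: the backend pulls the new per-layer attributes with getattr fallbacks, so layers built without QK norm keep working — a missing weight becomes None and the epsilon falls back to 1e-6. Below is a minimal sketch of that fallback pattern (not part of the diff; _LayerWithoutQKNorm is an illustrative stand-in, not the real FastDeploy layer):

class _LayerWithoutQKNorm:
    # illustrative stand-in built without use_qk_norm, so the q_norm_weight /
    # k_norm_weight attributes simply do not exist on it
    pass

layer = _LayerWithoutQKNorm()

# with QK norm disabled, the fused kernel receives None weights and the default
# epsilon, leaving the pre-existing attention path unchanged
q_norm_weight = getattr(layer, "q_norm_weight", None)  # -> None
k_norm_weight = getattr(layer, "k_norm_weight", None)  # -> None
rms_norm_eps = getattr(layer, "rms_norm_eps", 1e-6)    # -> 1e-6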


@@ -28,6 +28,7 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethod
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.utils import get_tensor
class Attention(nn.Layer):
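A side note on the imports in this hunk: ForwardMeta stays behind TYPE_CHECKING because it is only needed for annotations, while get_tensor is a regular runtime import since load_state_dict below calls it. A generic sketch of that split (not part of the diff; the module and type names are hypothetical):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # annotation-only import: never executed at runtime, so it adds no import cost
    from heavy_module import HeavyMeta  # hypothetical module/type for illustration

def step(meta: "HeavyMeta") -> None:
    # the string annotation keeps this valid even though HeavyMeta is never
    # imported at runtime
    print(meta)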
@@ -49,6 +50,7 @@
linear_smooth: paddle.Tensor = None,
use_neox_rotary_style: bool = False,
use_qk_norm: bool = False,
rms_norm_eps: float = 1e-6,
) -> None:
"""
Initializes `Attention` with the given parameters.
@@ -63,6 +65,8 @@
prefix (str, optional): The name of current layer. Defaults to "".
linear_shift (Optional[paddle.Tensor], optional): The shift of linear. Defaults to None.
linear_smooth (Optional[paddle.Tensor], optional): The smooth of linear. Defaults to None.
use_qk_norm (bool, optional): Whether to apply RMSNorm to Q and K after RoPE. Defaults to False.
rms_norm_eps (float, optional): The epsilon of RMSNorm. Defaults to 1e-6.
Raises:
ValueError: If the `v_head_dim` is less than 0.
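For reference, what the new use_qk_norm / rms_norm_eps options enable: after rotary embedding, Q and K are RMS-normalized over the head dimension and scaled by the learned q_norm_weight / k_norm_weight vectors. The fused append-attention kernel does this internally; the plain-Paddle sketch below (not part of the diff) is only an approximation for illustration, assuming a [tokens, num_heads, head_dim] layout:

import paddle

def qk_rms_norm(x, weight, eps=1e-6):
    # x: [tokens, num_heads, head_dim]; normalize over the last axis in fp32,
    # then scale by the per-head-dim weight (broadcast over tokens and heads)
    x_fp32 = x.astype("float32")
    variance = x_fp32.pow(2).mean(axis=-1, keepdim=True)
    normed = x_fp32 * paddle.rsqrt(variance + eps)
    return (normed * weight.astype("float32")).astype(x.dtype)

# toy shapes for illustration only
q = paddle.randn([8, 4, 128], dtype="float32")       # rotated queries
q_norm_weight = paddle.ones([128], dtype="float32")  # matches shape [qk_head_dim]
q = qk_rms_norm(q, q_norm_weight, eps=1e-6)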
@@ -102,6 +106,27 @@
logger.info(
f"Attention is running in cache kv {self.kvcache_quant_method.cache_quant_config.quant_type} mode"
)
self.use_qk_norm = use_qk_norm
self.rms_norm_eps = rms_norm_eps
if self.use_qk_norm:
self.q_norm_key = f"{self.prefix}.q_norm"
self.k_norm_key = f"{self.prefix}.k_norm"
self.init_weight()
def init_weight(self):
self.q_norm_weight = self.create_parameter(
shape=[self.qk_head_dim],
dtype=self._dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
self.k_norm_weight = self.create_parameter(
shape=[self.qk_head_dim],
dtype=self._dtype,
is_bias=False,
default_initializer=paddle.nn.initializer.Constant(0),
)
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
"""
@@ -109,6 +134,11 @@
"""
if self.kvcache_quant_method is not None:
self.kvcache_quant_method.create_weights(self, state_dict)
if self.use_qk_norm:
q_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.q_norm_key + ".weight")))
k_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.k_norm_key + ".weight")))
self.q_norm_weight.set_value(q_norm_weight_tensor)
self.k_norm_weight.set_value(k_norm_weight_tensor)
def forward(
self,

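On the weight loading above: init_weight only creates zero-filled placeholders of shape [qk_head_dim], and load_state_dict then pops the checkpoint entries named "{prefix}.q_norm.weight" / "{prefix}.k_norm.weight" and copies them in with set_value. A toy state dict illustrating that key convention (not part of the diff; the prefix and head dimension are assumed example values):

import numpy as np

qk_head_dim = 128                    # assumed example value
prefix = "model.layers.0.self_attn"  # assumed example prefix

state_dict = {
    f"{prefix}.q_norm.weight": np.ones([qk_head_dim], dtype="float32"),
    f"{prefix}.k_norm.weight": np.ones([qk_head_dim], dtype="float32"),
}

# load_state_dict pops exactly these keys and writes them into the
# q_norm_weight / k_norm_weight parameters created by init_weight
q_norm = state_dict.pop(f"{prefix}.q_norm.weight")
k_norm = state_dict.pop(f"{prefix}.k_norm.weight")
print(q_norm.shape, k_norm.shape)    # (128,) (128,)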

@@ -60,6 +60,9 @@ def append_attention(
linear_shift: Optional[paddle.Tensor] = None,
linear_smooth: Optional[paddle.Tensor] = None,
kv_signal_data: Optional[paddle.Tensor] = None,
q_norm_weight: Optional[paddle.Tensor] = None,
k_norm_weight: Optional[paddle.Tensor] = None,
rms_norm_eps: float = 1e-6,
compute_type: str = "bf16",
cache_quant_type: str = "none",
use_neox_rotary_style: bool = False,
@@ -114,6 +117,9 @@
linear_shift,
linear_smooth,
kv_signal_data,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
compute_type,
cache_quant_type,
use_neox_rotary_style,
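The three new parameters are optional keywords with backward-compatible defaults, so existing callers of append_attention are unaffected; leaving them at their defaults corresponds to the no-QK-norm path that layers built without use_qk_norm already take (they forward None / 1e-6, as in the backend hunk above). As a rough numeric sanity check of the rms_norm_eps semantics (illustration only, not the fused kernel): after RMSNorm each head vector has root-mean-square close to 1, with eps merely guarding the division:

import paddle

eps = 1e-6
x = paddle.randn([2, 4, 64], dtype="float32")  # [tokens, heads, head_dim]
normed = x * paddle.rsqrt(x.pow(2).mean(axis=-1, keepdim=True) + eps)
rms = normed.pow(2).mean(axis=-1).sqrt()
print(float(rms.min()), float(rms.max()))      # both approximately 1.0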