support qk norm (#3145)

Authored by Yuan Xiaolan on 2025-08-05 16:46:14 +08:00, committed by GitHub
parent 4a10e29804 · commit 7ce00e597c
17 changed files with 791 additions and 201 deletions


@@ -28,6 +28,7 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethod
if TYPE_CHECKING:
    from fastdeploy.model_executor.forward_meta import ForwardMeta

from fastdeploy.model_executor.layers.utils import get_tensor


class Attention(nn.Layer):
@@ -49,6 +50,7 @@ class Attention(nn.Layer):
        linear_smooth: paddle.Tensor = None,
        use_neox_rotary_style: bool = False,
        use_qk_norm: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        """
        Initializes `LMLayer` with the given parameters.
@@ -63,6 +65,8 @@ class Attention(nn.Layer):
            prefix (str, optional): The name of current layer. Defaults to "".
            linear_shift (Optional[paddle.Tensor], optional): The shift of linear. Defaults to None.
            linear_smooth (Optional[paddle.Tensor], optional): The smooth of linear. Defaults to None.
            use_qk_norm (bool, optional): Whether to apply RMSNorm on Q and K after RoPE. Defaults to False.
            rms_norm_eps (float, optional): The epsilon of RMSNorm. Defaults to 1e-6.

        Raises:
            ValueError: If the `v_head_dim` is less than 0.
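
For orientation, here is a hypothetical caller-side sketch of opting into the new behavior; apart from `use_qk_norm` and `rms_norm_eps` and the arguments visible in the hunk above, the names are assumptions and not part of this diff:

    attn = Attention(
        fd_config=fd_config,                 # assumed config object, not shown in this hunk
        prefix="model.layers.0.self_attn",   # hypothetical layer prefix
        use_neox_rotary_style=False,
        use_qk_norm=True,                    # apply RMSNorm to Q and K after RoPE
        rms_norm_eps=1e-6,                   # epsilon used by that RMSNorm
    )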
@@ -102,6 +106,27 @@ class Attention(nn.Layer):
            logger.info(
                f"Attention is running in cache kv {self.kvcache_quant_method.cache_quant_config.quant_type} mode"
            )
        self.use_qk_norm = use_qk_norm
        self.rms_norm_eps = rms_norm_eps
        if self.use_qk_norm:
            self.q_norm_key = f"{self.prefix}.q_norm"
            self.k_norm_key = f"{self.prefix}.k_norm"
            self.init_weight()

    def init_weight(self):
        self.q_norm_weight = self.create_parameter(
            shape=[self.qk_head_dim],
            dtype=self._dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        self.k_norm_weight = self.create_parameter(
            shape=[self.qk_head_dim],
            dtype=self._dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

    def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
        """
@@ -109,6 +134,11 @@ class Attention(nn.Layer):
        """
        if self.kvcache_quant_method is not None:
            self.kvcache_quant_method.create_weights(self, state_dict)
        if self.use_qk_norm:
            q_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.q_norm_key + ".weight")))
            k_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.k_norm_key + ".weight")))
            self.q_norm_weight.set_value(q_norm_weight_tensor)
            self.k_norm_weight.set_value(k_norm_weight_tensor)

    def forward(
        self,
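
Putting the pieces together, the loader above pops weights keyed by `{prefix}.q_norm.weight` and `{prefix}.k_norm.weight`. Continuing the hypothetical `attn` from the earlier sketch, a made-up example of the expected checkpoint layout (the prefix and head dim are illustrative values, not taken from this diff):

    head_dim = 128                            # hypothetical qk_head_dim
    state_dict = {
        "model.layers.0.self_attn.q_norm.weight": np.ones([head_dim], dtype="float32"),
        "model.layers.0.self_attn.k_norm.weight": np.ones([head_dim], dtype="float32"),
        # ... remaining attention / cache-quant weights ...
    }
    attn.load_state_dict(state_dict)          # copies the two scales into q_norm_weight / k_norm_weight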