mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
support qk norm (#3145)
@@ -28,6 +28,7 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethod
if TYPE_CHECKING:
    from fastdeploy.model_executor.forward_meta import ForwardMeta

from fastdeploy.model_executor.layers.utils import get_tensor


class Attention(nn.Layer):
@@ -49,6 +50,7 @@ class Attention(nn.Layer):
        linear_smooth: paddle.Tensor = None,
        use_neox_rotary_style: bool = False,
        use_qk_norm: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        """
        Initializes `Attention` with the given parameters.
@@ -63,6 +65,8 @@ class Attention(nn.Layer):
            prefix (str, optional): The name of current layer. Defaults to "".
            linear_shift (Optional[paddle.Tensor], optional): The shift of linear. Defaults to None.
            linear_smooth (Optional[paddle.Tensor], optional): The smooth of linear. Defaults to None.
            use_qk_norm (bool, optional): Whether to apply RMSNorm on Q and K after RoPE. Defaults to False.
            rms_norm_eps (float, optional): The epsilon of RMSNorm. Defaults to 1e-6.

        Raises:
            ValueError: If the `v_head_dim` is less than 0.
@@ -102,6 +106,27 @@ class Attention(nn.Layer):
            logger.info(
                f"Attention is running in cache kv {self.kvcache_quant_method.cache_quant_config.quant_type} mode"
            )
        self.use_qk_norm = use_qk_norm
        self.rms_norm_eps = rms_norm_eps
        if self.use_qk_norm:
            self.q_norm_key = f"{self.prefix}.q_norm"
            self.k_norm_key = f"{self.prefix}.k_norm"
            self.init_weight()

    def init_weight(self):
        self.q_norm_weight = self.create_parameter(
            shape=[self.qk_head_dim],
            dtype=self._dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

        self.k_norm_weight = self.create_parameter(
            shape=[self.qk_head_dim],
            dtype=self._dtype,
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
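
A note on the `Constant(0)` initializer above: the zeros are only placeholders, since `load_state_dict` below overwrites both parameters with checkpoint values. Also, the shape is `[qk_head_dim]`, so a single scale vector is shared across all attention heads. A toy sketch of the same `create_parameter` pattern; the `ToyNorm` name and the 128-wide `head_dim` are made up for illustration:

import paddle
from paddle import nn

class ToyNorm(nn.Layer):
    def __init__(self, head_dim: int = 128):
        super().__init__()
        # Same pattern as init_weight above: a per-channel scale, zero-initialized.
        self.weight = self.create_parameter(
            shape=[head_dim],
            dtype="float32",
            is_bias=False,
            default_initializer=paddle.nn.initializer.Constant(0),
        )

layer = ToyNorm()
layer.weight.set_value(paddle.ones([128]))  # real values arrive later, as in load_state_dict
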
    def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
        """
@@ -109,6 +134,11 @@ class Attention(nn.Layer):
        """
        if self.kvcache_quant_method is not None:
            self.kvcache_quant_method.create_weights(self, state_dict)
        if self.use_qk_norm:
            q_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.q_norm_key + ".weight")))
            k_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.k_norm_key + ".weight")))
            self.q_norm_weight.set_value(q_norm_weight_tensor)
            self.k_norm_weight.set_value(k_norm_weight_tensor)
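
The checkpoint lookup above follows the `q_norm_key` / `k_norm_key` pattern set in `__init__`. A hypothetical example of the entries `load_state_dict` would pop; the prefix string and the 128-wide shape are assumptions:

import numpy as np

prefix = "model.layers.0.self_attn"  # hypothetical layer prefix
state_dict = {
    f"{prefix}.q_norm.weight": np.ones([128], dtype="float32"),
    f"{prefix}.k_norm.weight": np.ones([128], dtype="float32"),
    # ... other attention weights for this layer ...
}
# load_state_dict pops these two keys, converts them via get_tensor / paddle.to_tensor,
# and copies them into q_norm_weight / k_norm_weight with set_value.
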
    def forward(
        self,