Mirror of https://github.com/PaddlePaddle/FastDeploy.git
polish code with new pre-commit rule (#2923)
@@ -24,8 +24,8 @@ from paddle import nn
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.layers.quantization.quant_base import \
-    QuantMethodBase
+from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 
 if TYPE_CHECKING:
     from fastdeploy.model_executor.forward_meta import ForwardMeta
@@ -67,10 +67,14 @@ class Attention(nn.Layer):
             ValueError: If the `v_head_dim` is less than 0.
         """
         super().__init__()
-        self.num_heads: int = fd_config.model_config.num_attention_heads // fd_config.parallel_config.tensor_parallel_size
+        self.num_heads: int = (
+            fd_config.model_config.num_attention_heads // fd_config.parallel_config.tensor_parallel_size
+        )
         self.head_dim: int = fd_config.model_config.head_dim
-        self.kv_num_heads: int = \
-            max(1, fd_config.model_config.num_key_value_heads // fd_config.parallel_config.tensor_parallel_size)
+        self.kv_num_heads: int = max(
+            1,
+            fd_config.model_config.num_key_value_heads // fd_config.parallel_config.tensor_parallel_size,
+        )
         self.layer_id: int = layer_id
         self.v_head_dim: int = v_head_dim if v_head_dim > 0 else self.head_dim
         self.rope_type: str = rope_type
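For context on the hunk above: the `max(1, ...)` guard in `kv_num_heads` covers grouped-query attention, where the model has fewer KV heads than tensor-parallel ranks; each rank then keeps one (replicated) KV head. A minimal sketch of the partitioning arithmetic, using hypothetical head counts not taken from any FastDeploy config:

# Hypothetical values; only the arithmetic mirrors the reformatted lines above.
num_attention_heads = 32  # total query heads in the model
num_key_value_heads = 4   # total KV heads (GQA)
tensor_parallel_size = 8  # model-parallel ranks

num_heads = num_attention_heads // tensor_parallel_size             # 4 query heads per rank
kv_num_heads = max(1, num_key_value_heads // tensor_parallel_size)  # 4 // 8 == 0, clamped to 1

print(num_heads, kv_num_heads)  # 4 1 -> each rank replicates one KV head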
@@ -86,10 +90,8 @@ class Attention(nn.Layer):
         self.out_scale: float = out_scale
         self.use_neox_rotary_style: bool = use_neox_rotary_style
 
-        if fd_config.quant_config and hasattr(fd_config.quant_config,
-                                              "kv_cache_quant_type"):
-            self.kvcache_quant_method: QuantMethodBase = fd_config.quant_config.get_quant_method(
-                self)
+        if fd_config.quant_config and hasattr(fd_config.quant_config, "kv_cache_quant_type"):
+            self.kvcache_quant_method: QuantMethodBase = fd_config.quant_config.get_quant_method(self)
         else:
             self.kvcache_quant_method = None
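The condition collapsed onto one line here is the KV-cache quantization dispatch: the layer only requests a `QuantMethodBase` when the quant config exposes a `kv_cache_quant_type` attribute, and stores `None` otherwise. A hedged sketch of that pattern with stand-in classes (the names below are hypothetical, not FastDeploy's):

# Hypothetical stand-ins; only the hasattr-based dispatch mirrors the diff.
class KVCacheQuantConfig:
    kv_cache_quant_type = "int8"  # presence of this attribute enables the branch

    def get_quant_method(self, layer):
        return ("kv-cache-quant", layer)

def select_kvcache_quant_method(layer, quant_config):
    if quant_config and hasattr(quant_config, "kv_cache_quant_type"):
        return quant_config.get_quant_method(layer)
    return None

print(select_kvcache_quant_method("attn0", KVCacheQuantConfig()))  # ('kv-cache-quant', 'attn0')
print(select_kvcache_quant_method("attn0", None))                  # None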
@@ -100,11 +102,10 @@ class Attention(nn.Layer):
                 f"Attention is running in cache kv {self.kvcache_quant_method.cache_quant_config.quant_type} mode"
             )
 
-    def load_state_dict(self, state_dict: Dict[str,
-                                                paddle.Tensor | np.ndarray]):
-        '''
+    def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
+        """
         Attention only have quant related scales not other parameters.
-        '''
+        """
         if self.kvcache_quant_method is not None:
             self.kvcache_quant_method.create_weights(self, state_dict)
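The final hunk is behavior-preserving: `load_state_dict` still only creates KV-cache quantization scales (per its docstring, the attention layer owns no other parameters), guarded by the `kvcache_quant_method is not None` check. A sketch of that guard with a hypothetical stub quant method and an illustrative scale key:

# Hypothetical stub; the real create_weights lives on FastDeploy's QuantMethodBase subclasses.
class StubKVCacheQuantMethod:
    def create_weights(self, layer, state_dict):
        # Only quant-related scales are pulled from the checkpoint.
        layer["cache_k_scale"] = state_dict.get("cache_k_scale")  # key name is illustrative

layer = {}
StubKVCacheQuantMethod().create_weights(layer, {"cache_k_scale": 0.02})
print(layer)  # {'cache_k_scale': 0.02}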