polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

View File

@@ -24,8 +24,8 @@ from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.quantization.quant_base import \
QuantMethodBase
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta
@@ -67,10 +67,14 @@ class Attention(nn.Layer):
ValueError: If the `v_head_dim` is less than 0.
"""
super().__init__()
self.num_heads: int = fd_config.model_config.num_attention_heads // fd_config.parallel_config.tensor_parallel_size
self.num_heads: int = (
fd_config.model_config.num_attention_heads // fd_config.parallel_config.tensor_parallel_size
)
self.head_dim: int = fd_config.model_config.head_dim
self.kv_num_heads: int = \
max(1, fd_config.model_config.num_key_value_heads // fd_config.parallel_config.tensor_parallel_size)
self.kv_num_heads: int = max(
1,
fd_config.model_config.num_key_value_heads // fd_config.parallel_config.tensor_parallel_size,
)
self.layer_id: int = layer_id
self.v_head_dim: int = v_head_dim if v_head_dim > 0 else self.head_dim
self.rope_type: str = rope_type
@@ -86,10 +90,8 @@ class Attention(nn.Layer):
self.out_scale: float = out_scale
self.use_neox_rotary_style: bool = use_neox_rotary_style
if fd_config.quant_config and hasattr(fd_config.quant_config,
"kv_cache_quant_type"):
self.kvcache_quant_method: QuantMethodBase = fd_config.quant_config.get_quant_method(
self)
if fd_config.quant_config and hasattr(fd_config.quant_config, "kv_cache_quant_type"):
self.kvcache_quant_method: QuantMethodBase = fd_config.quant_config.get_quant_method(self)
else:
self.kvcache_quant_method = None
@@ -100,11 +102,10 @@ class Attention(nn.Layer):
f"Attention is running in cache kv {self.kvcache_quant_method.cache_quant_config.quant_type} mode"
)
def load_state_dict(self, state_dict: Dict[str,
paddle.Tensor | np.ndarray]):
'''
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
"""
Attention only have quant related scales not other parameters.
'''
"""
if self.kvcache_quant_method is not None:
self.kvcache_quant_method.create_weights(self, state_dict)