mirror of https://github.com/PaddlePaddle/FastDeploy.git
polish code with new pre-commit rule (#2923)
@@ -23,9 +23,12 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 import paddle
 
 from fastdeploy.model_executor.layers.attention.ops import (
-    append_attention, get_block_shape_and_split_kv_block,
-    init_signal_layerwise, open_shm_and_get_meta_signal,
-    init_kv_signal_per_query)
+    append_attention,
+    get_block_shape_and_split_kv_block,
+    init_kv_signal_per_query,
+    init_signal_layerwise,
+    open_shm_and_get_meta_signal,
+)
 
 if TYPE_CHECKING:
     from fastdeploy.model_executor.forward_meta import ForwardMeta
@@ -33,9 +36,10 @@ if TYPE_CHECKING:
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
-    AttentionBackend, AttentionMetadata)
-from fastdeploy.model_executor.layers.attention.utils import \
-    init_rank_and_device_id
+    AttentionBackend,
+    AttentionMetadata,
+)
+from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
 
 
 @dataclass
@@ -43,6 +47,7 @@ class AppendAttentionMetadata(AttentionMetadata):
     """
    AppendAttentionMetadata
    """
+
     max_len_kv: paddle.Tensor = None
     set_max_lengths: int = -1
     encoder_batch_ids: paddle.Tensor = None
@@ -75,8 +80,13 @@ class AppendAttentionBackend(AttentionBackend):
     AppendAttentionBackend backend implementation.
     """
 
-    def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
-                 head_dim: int) -> None:
+    def __init__(
+        self,
+        fd_config: FDConfig,
+        kv_num_heads: int,
+        num_heads: int,
+        head_dim: int,
+    ) -> None:
         """
         AppendAttentionBackend __init__
         """
@@ -84,9 +94,9 @@ class AppendAttentionBackend(AttentionBackend):
         self.attention_metadata: AppendAttentionMetadata = None
         self.block_size: int = fd_config.parallel_config.block_size
         self.max_seq_len: int = fd_config.parallel_config.max_model_len
-        self.rope_theta: float = (10000.0
-                                  if fd_config.model_config.rope_theta is None
-                                  else fd_config.model_config.rope_theta)
+        self.rope_theta: float = (
+            10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta
+        )
         self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
         self.speculative_method: str = fd_config.speculative_config.method
@@ -99,11 +109,10 @@ class AppendAttentionBackend(AttentionBackend):
         self.num_heads: int = num_heads
         self.head_dim: int = fd_config.model_config.head_dim
         self.num_layers: int = fd_config.model_config.num_hidden_layers
-        self.max_partition_size: int = int(
-            os.getenv("FLAGS_max_partition_size", 32768))
+        self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768))
 
         self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
-
+
         self.start_layer_index: int = fd_config.model_config.start_layer_index
 
         if fd_config.parallel_config.expert_parallel_rank is None:
@@ -137,7 +146,7 @@ class AppendAttentionBackend(AttentionBackend):
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
             metadata.decoder_batch_ids,  # will copy to buffer
-            metadata.decoder_tile_ids_per_batch, # will copy to buffer
+            metadata.decoder_tile_ids_per_batch,  # will copy to buffer
             metadata.decoder_num_blocks,
             metadata.max_len_kv,
             metadata.set_max_lengths,
@@ -165,12 +174,12 @@ class AppendAttentionBackend(AttentionBackend):
             )
         elif self.pd_disaggregation_mode == "per_query":
             metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
-                self.rank, int(self.device_id), self.keep_pd_step_flag)
+                self.rank, int(self.device_id), self.keep_pd_step_flag
+            )
 
         self.attention_metadata: AttentionMetadata = metadata
         forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
-        forward_meta.decoder_tile_ids_per_batch.copy_(
-            metadata.decoder_tile_ids_per_batch, False)
+        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)
 
     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
@@ -183,8 +192,12 @@ class AppendAttentionBackend(AttentionBackend):
         """
         Caculate kv cache shape
         """
-        return (max_num_blocks, self.kv_num_heads, self.block_size,
-                self.head_dim)
+        return (
+            max_num_blocks,
+            self.kv_num_heads,
+            self.block_size,
+            self.head_dim,
+        )
 
     def forward_mixed(
         self,
@@ -203,10 +216,10 @@ class AppendAttentionBackend(AttentionBackend):
         metadata = self.attention_metadata
 
         if self.pd_disaggregation_mode == "per_query":
-            metadata.kv_signal_data_list[
-                layer.layer_id] = init_signal_layerwise(
-                    metadata.kv_signal_metadata,
-                    layer.layer_id + self.start_layer_index)
+            metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
+                metadata.kv_signal_metadata,
+                layer.layer_id + self.start_layer_index,
+            )
 
         res = append_attention(
             qkv,
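
Note: the first hunk is typical of the whole change: the names inside the parenthesized import are sorted alphabetically and wrapped one per line with a trailing comma. As a hedged illustration only, the sketch below reproduces that reflow with isort's black-compatible profile; the actual hook list adopted in #2923 is not visible in this excerpt, so isort here is a stand-in for whatever import sorter the pre-commit config really uses.

# Minimal sketch, assuming the `isort` package is installed.
import isort

before = (
    "from fastdeploy.model_executor.layers.attention.ops import (\n"
    "    append_attention, get_block_shape_and_split_kv_block,\n"
    "    init_signal_layerwise, open_shm_and_get_meta_signal,\n"
    "    init_kv_signal_per_query)\n"
)

# profile="black" sorts the names alphabetically, wraps one per line, and
# appends the trailing comma -- the exact shape of the "+" side above.
print(isort.code(before, profile="black"))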
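
The exploded __init__ signature and return tuple follow the same mechanism: once a trailing comma is present, black-style formatters keep one element per line regardless of line length (the "magic trailing comma"). A small sketch of that behavior, assuming the `black` package; the 120-character line limit is a guess inferred from the long single-line "+" results in the hunks above.

# Hedged sketch of the magic trailing comma; not the verified #2923 config.
import black

mode = black.Mode(line_length=120)  # hypothetical line length

flat = "def f(fd_config, kv_num_heads: int, num_heads: int, head_dim: int) -> None:\n    pass\n"
exploded = "def f(fd_config, kv_num_heads: int, num_heads: int, head_dim: int,) -> None:\n    pass\n"

# No trailing comma and the signature fits -> black keeps it on one line.
print(black.format_str(flat, mode=mode))
# A pre-existing trailing comma -> black puts one parameter per line,
# matching the reformatted signatures in this commit.
print(black.format_str(exploded, mode=mode))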