Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
Centralized support for fa3 (#3112)
@@ -34,6 +34,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.ops import (
    append_attention,
    get_block_shape_and_split_kv_block,
    gqa_rope_write_cache,
    init_kv_signal_per_query,
@@ -46,6 +47,15 @@ from fastdeploy.model_executor.layers.attention.utils import init_rank_and_devic
if TYPE_CHECKING:
    from fastdeploy.model_executor.forward_meta import ForwardMeta

from fastdeploy.platforms import current_platform

if current_platform.is_cuda():
    from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
else:
    merge_prefill_decode_output = None

import os


@dataclass
class FlashAttentionMetadata(AttentionMetadata):
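The guarded import above keeps this module importable on platforms without the CUDA custom ops; merge_prefill_decode_output simply stays None there. A minimal standalone sketch of the same pattern, assuming paddle is installed (the helper name is hypothetical, not a FastDeploy API):

def _cuda_ops_available() -> bool:
    # Hypothetical stand-in for current_platform.is_cuda().
    try:
        import paddle
        return paddle.device.is_compiled_with_cuda()
    except ImportError:
        return False

if _cuda_ops_available():
    from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
else:
    merge_prefill_decode_output = None  # non-CUDA builds fall back to None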
@@ -61,6 +71,7 @@ class FlashAttentionMetadata(AttentionMetadata):
    kv_batch_ids: paddle.Tensor = None
    kv_tile_ids_per_batch: paddle.Tensor = None
    kv_num_blocks: paddle.Tensor = None
    max_len_kv: paddle.Tensor = None

    cu_seqlens_q: paddle.Tensor = None
    cu_seqlens_k: paddle.Tensor = None
@@ -76,6 +87,12 @@ class FlashAttentionMetadata(AttentionMetadata):
    kv_signal_metadata: Optional[paddle.Tensor] = None
    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)

    _fuse_kernel_compute_dtype: str = "bf16"
    _dtype: paddle.dtype = paddle.bfloat16

    max_len_tensor_cpu: paddle.Tensor = None
    max_len_tensor_cpu_decoder: paddle.Tensor = None


class FlashAttentionBackend(AttentionBackend):
    """
@@ -143,6 +160,11 @@ class FlashAttentionBackend(AttentionBackend):
            print(
                "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
            )
        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
        self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
        self.zero_seq_enc_lens_for_decode = paddle.zeros(
            shape=[fd_config.parallel_config.max_num_seqs, 1], dtype=paddle.int32
        )

    def get_attntion_meta(self):
        """get_attntion_meta"""
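The constructor additions above read the partition size from the FLAGS_max_partition_size environment variable and pre-allocate an all-zero encoder-length buffer that is later passed to append_attention on the decode path. A minimal sketch of both pieces, assuming paddle is installed; max_num_seqs is a hypothetical placeholder for fd_config.parallel_config.max_num_seqs:

import os
import paddle

max_num_seqs = 8  # hypothetical placeholder for fd_config.parallel_config.max_num_seqs
# Defaults to 32768 when the flag is unset, matching the diff above.
max_partition_size = int(os.getenv("FLAGS_max_partition_size", "32768"))
# All-zero encoder lengths, shape [max_num_seqs, 1], int32.
zero_seq_enc_lens_for_decode = paddle.zeros(shape=[max_num_seqs, 1], dtype=paddle.int32)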
@@ -208,7 +230,7 @@ class FlashAttentionBackend(AttentionBackend):
        ) = pre_cache_len_concat(
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_this_time,
            metadata.set_max_lengths[2],
            forward_meta.max_len_tensor_cpu[2],
            self.block_size,
        )

@@ -227,6 +249,18 @@ class FlashAttentionBackend(AttentionBackend):
            metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
                self.rank, int(self.device_id), self.keep_pd_step_flag
            )

        if metadata._dtype == "bfloat16":
            metadata._fuse_kernel_compute_dtype = "bf16"
        elif metadata._dtype == "float16":
            metadata._fuse_kernel_compute_dtype = "fp16"
        elif metadata._dtype == "float32":
            metadata._fuse_kernel_compute_dtype = "fp32"

        metadata.max_len_tensor_cpu = forward_meta.max_len_tensor_cpu
        metadata.max_len_tensor_cpu_decoder = paddle.clone(metadata.max_len_tensor_cpu)
        metadata.max_len_tensor_cpu_decoder[1] = 0

        self.attention_metadata = metadata

    def forward_mixed(
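The branch above maps the metadata dtype to the short compute-dtype tag used by the fused kernels, and clones max_len_tensor_cpu into a decoder-only copy with entry [1] zeroed. A hedged sketch of the dtype mapping as a standalone helper (the function name is hypothetical; the commit performs the mapping inline):

def _to_fuse_kernel_compute_dtype(dtype: str) -> str:
    # "bf16" mirrors the dataclass default used when no branch matches.
    mapping = {"bfloat16": "bf16", "float16": "fp16", "float32": "fp32"}
    return mapping.get(dtype, "bf16")

assert _to_fuse_kernel_compute_dtype("float16") == "fp16"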
@@ -248,45 +282,112 @@ class FlashAttentionBackend(AttentionBackend):
            layer.layer_id + self.start_layer_index,
        )

        q, k, v, _ = gqa_rope_write_cache(
        if metadata.max_len_tensor_cpu[1] > 0:
            q, k, v, _ = gqa_rope_write_cache(
                qkv,
                forward_meta.caches[2 * layer.layer_id],
                forward_meta.caches[2 * layer.layer_id + 1],
                metadata.cu_seqlens_q,
                metadata.cu_seqlens_k,
                metadata.rotary_embs,
                forward_meta.seq_lens_this_time,
                forward_meta.seq_lens_encoder,
                forward_meta.seq_lens_decoder,
                forward_meta.batch_id_per_token,
                metadata.block_tables,
                metadata.kv_batch_ids,
                metadata.kv_tile_ids_per_batch,
                metadata.kv_num_blocks,
                metadata.pre_cache_batch_ids,
                metadata.pre_cache_tile_ids_per_batch,
                metadata.pre_cache_num_blocks_cpu,
                getattr(layer, "cache_k_scale", None),
                getattr(layer, "cache_v_scale", None),
                getattr(layer, "cache_k_out_scale", None),
                getattr(layer, "cache_v_out_scale", None),
                getattr(layer, "cache_k_zp", None),
                getattr(layer, "cache_v_zp", None),
                metadata.kv_signal_data_list[layer.layer_id],
                metadata.kv_token_num_cpu[0].item(),
                self.max_seq_len,
                getattr(layer, "cache_quant_type_str", "none"),
            )

            res_encoder = self.flash_attn_func(
                q,
                k,
                v,
                metadata.cu_seqlens_q,
                metadata.cu_seqlens_k,
                max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
                max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
                causal=self.causal,
                **self.flash_attn_kwargs,
            )[0].reshape([-1, self.attn_outputsize_tp])

        res_decoder = append_attention(
            qkv,
            forward_meta.caches[2 * layer.layer_id],
            forward_meta.caches[2 * layer.layer_id + 1],
            metadata.cu_seqlens_q,
            metadata.cu_seqlens_k,
            metadata.rotary_embs,
            forward_meta.seq_lens_this_time,
            forward_meta.seq_lens_encoder,
            self.zero_seq_enc_lens_for_decode,
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_this_time,
            forward_meta.batch_id_per_token,
            forward_meta.cu_seqlens_q,
            metadata.block_tables,
            metadata.encoder_batch_ids,
            metadata.encoder_tile_ids_per_batch,
            metadata.encoder_num_blocks,
            metadata.kv_batch_ids,
            metadata.kv_tile_ids_per_batch,
            metadata.kv_num_blocks,
            metadata.pre_cache_batch_ids,
            metadata.pre_cache_tile_ids_per_batch,
            metadata.pre_cache_num_blocks_cpu,
            forward_meta.decoder_batch_ids,  # from buffer
            forward_meta.decoder_tile_ids_per_batch,  # from buffer
            forward_meta.decoder_num_blocks_cpu,
            forward_meta.max_len_tensor_cpu,
            metadata.max_len_kv,
            metadata.rotary_embs,
            forward_meta.attn_mask,
            layer.qkv_bias,
            layer.qkv_scale,
            getattr(layer, "cache_k_scale", None),
            getattr(layer, "cache_v_scale", None),
            getattr(layer, "cache_k_out_scale", None),
            getattr(layer, "cache_v_out_scale", None),
            getattr(layer, "cache_k_zp", None),
            getattr(layer, "cache_v_zp", None),
            layer.linear_shift,
            layer.linear_smooth,
            metadata.kv_signal_data_list[layer.layer_id],
            metadata.kv_token_num_cpu[0].item(),
            self.max_seq_len,
            metadata._fuse_kernel_compute_dtype,
            getattr(layer, "cache_quant_type_str", "none"),
        )
            layer.use_neox_rotary_style,
            self.rope_3d,
            self.max_seq_len,
            getattr(layer, "quant_max_bound", 0.0),
            getattr(layer, "quant_min_bound", 0.0),
            getattr(layer, "out_scale", -1.0),
            self.encoder_block_shape_q,
            self.decoder_block_shape_q,
            self.max_partition_size,
            self.max_seq_len,
            self.speculate_max_draft_token_num + 1,
            self.causal,
            self.speculative_method is not None,
        )[0]

        res = self.flash_attn_func(
            q,
            k,
            v,
            metadata.cu_seqlens_q,
            metadata.cu_seqlens_k,
            max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
            max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
            causal=self.causal,
            **self.flash_attn_kwargs,
        )[0].reshape([-1, self.attn_outputsize_tp])
        return res
        if metadata.max_len_tensor_cpu[1] > 0:
            merge_prefill_decode_output(
                res_encoder,
                res_decoder,
                forward_meta.seq_lens_encoder,
                forward_meta.seq_lens_decoder,
                forward_meta.seq_lens_this_time,
                forward_meta.cu_seqlens_q,
                self.num_heads,
                self.head_dim,
                self.speculate_max_draft_token_num + 1,
            )
            return res_encoder
        else:
            return res_decoder
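Taken together, the reworked forward_mixed runs the flash_attn_func path only when the batch contains prefill (encoder) tokens, routes decode tokens through append_attention, and merges the two outputs when both exist. An illustrative control-flow sketch, not FastDeploy code; the three callables are hypothetical stand-ins:

def forward_mixed_sketch(has_prefill_tokens, run_prefill_flash_attn, run_decode_append_attention, merge_into):
    res_encoder = None
    if has_prefill_tokens:  # metadata.max_len_tensor_cpu[1] > 0 in the diff
        res_encoder = run_prefill_flash_attn()  # flash_attn_func over the prefill tokens
    res_decoder = run_decode_append_attention()  # decode path via append_attention
    if has_prefill_tokens:
        # merge_prefill_decode_output combines the two; the diff returns res_encoder afterwards.
        merge_into(res_encoder, res_decoder)
        return res_encoder
    return res_decoder

# Example: a decode-only step returns the append_attention result unchanged.
assert forward_mixed_sketch(False, lambda: "prefill", lambda: "decode", lambda a, b: None) == "decode"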