Support FA3 in centralized deployment (#3112)

yangjianfengo1
2025-08-01 18:03:36 +08:00
committed by GitHub
parent bdb83e007d
commit 64d7a3194d
4 changed files with 257 additions and 25 deletions


@@ -34,6 +34,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.ops import (
append_attention,
get_block_shape_and_split_kv_block,
gqa_rope_write_cache,
init_kv_signal_per_query,
@@ -46,6 +47,15 @@ from fastdeploy.model_executor.layers.attention.utils import init_rank_and_devic
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
else:
merge_prefill_decode_output = None
import os
@dataclass
class FlashAttentionMetadata(AttentionMetadata):
@@ -61,6 +71,7 @@ class FlashAttentionMetadata(AttentionMetadata):
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
max_len_kv: paddle.Tensor = None
cu_seqlens_q: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
@@ -76,6 +87,12 @@ class FlashAttentionMetadata(AttentionMetadata):
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
_fuse_kernel_compute_dtype: str = "bf16"
_dtype: paddle.dtype = paddle.bfloat16
max_len_tensor_cpu: paddle.Tensor = None
max_len_tensor_cpu_decoder: paddle.Tensor = None
class FlashAttentionBackend(AttentionBackend):
"""
@@ -143,6 +160,11 @@ class FlashAttentionBackend(AttentionBackend):
print(
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
self.zero_seq_enc_lens_for_decode = paddle.zeros(
shape=[fd_config.parallel_config.max_num_seqs, 1], dtype=paddle.int32
)
def get_attntion_meta(self):
"""get_attntion_meta"""
@@ -208,7 +230,7 @@ class FlashAttentionBackend(AttentionBackend):
) = pre_cache_len_concat(
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
metadata.set_max_lengths[2],
forward_meta.max_len_tensor_cpu[2],
self.block_size,
)
@@ -227,6 +249,18 @@ class FlashAttentionBackend(AttentionBackend):
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
self.rank, int(self.device_id), self.keep_pd_step_flag
)
if metadata._dtype == "bfloat16":
metadata._fuse_kernel_compute_dtype = "bf16"
elif metadata._dtype == "float16":
metadata._fuse_kernel_compute_dtype = "fp16"
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.max_len_tensor_cpu = forward_meta.max_len_tensor_cpu
metadata.max_len_tensor_cpu_decoder = paddle.clone(metadata.max_len_tensor_cpu)
metadata.max_len_tensor_cpu_decoder[1] = 0
self.attention_metadata = metadata
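
The hunk above does two pieces of bookkeeping: it maps the metadata dtype to the short string tag the fused kernel expects, and it clones the shared CPU max-length tensor and zeroes slot 1 so the decode path sees a batch with no prefill work. A minimal sketch of the same idea; the dict and helper below are illustrative, not part of the diff, and slot 1 is assumed to hold the max prefill length for the current step:

```python
import paddle

# Same mapping the hunk writes out as an if/elif chain.
_DTYPE_TO_TAG = {"bfloat16": "bf16", "float16": "fp16", "float32": "fp32"}


def build_decoder_max_lens(max_len_tensor_cpu: paddle.Tensor) -> paddle.Tensor:
    """Clone the shared max-length tensor and zero slot 1 (assumed here to be
    the max prefill length) so decode-only kernels see no prefill work."""
    decoder_view = paddle.clone(max_len_tensor_cpu)
    decoder_view[1] = 0
    return decoder_view
```
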
def forward_mixed(
@@ -248,45 +282,112 @@ class FlashAttentionBackend(AttentionBackend):
layer.layer_id + self.start_layer_index,
)
q, k, v, _ = gqa_rope_write_cache(
if metadata.max_len_tensor_cpu[1] > 0:
q, k, v, _ = gqa_rope_write_cache(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.rotary_embs,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
metadata.block_tables,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
metadata.kv_signal_data_list[layer.layer_id],
metadata.kv_token_num_cpu[0].item(),
self.max_seq_len,
getattr(layer, "cache_quant_type_str", "none"),
)
res_encoder = self.flash_attn_func(
q,
k,
v,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
causal=self.causal,
**self.flash_attn_kwargs,
)[0].reshape([-1, self.attn_outputsize_tp])
res_decoder = append_attention(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.rotary_embs,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_encoder,
self.zero_seq_enc_lens_for_decode,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
forward_meta.decoder_batch_ids, # from buffer
forward_meta.decoder_tile_ids_per_batch, # from buffer
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.max_len_kv,
metadata.rotary_embs,
forward_meta.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
layer.linear_shift,
layer.linear_smooth,
metadata.kv_signal_data_list[layer.layer_id],
metadata.kv_token_num_cpu[0].item(),
self.max_seq_len,
metadata._fuse_kernel_compute_dtype,
getattr(layer, "cache_quant_type_str", "none"),
)
layer.use_neox_rotary_style,
self.rope_3d,
self.max_seq_len,
getattr(layer, "quant_max_bound", 0.0),
getattr(layer, "quant_min_bound", 0.0),
getattr(layer, "out_scale", -1.0),
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.max_partition_size,
self.max_seq_len,
self.speculate_max_draft_token_num + 1,
self.causal,
self.speculative_method is not None,
)[0]
res = self.flash_attn_func(
q,
k,
v,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
causal=self.causal,
**self.flash_attn_kwargs,
)[0].reshape([-1, self.attn_outputsize_tp])
return res
if metadata.max_len_tensor_cpu[1] > 0:
merge_prefill_decode_output(
res_encoder,
res_decoder,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
self.num_heads,
self.head_dim,
self.speculate_max_draft_token_num + 1,
)
return res_encoder
else:
return res_decoder
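
Taken together, the rewritten `forward_mixed` splits a mixed batch by token type: prefill tokens go through `gqa_rope_write_cache` plus the FA3 varlen kernel, decode tokens always go through the fused `append_attention` kernel, and `merge_prefill_decode_output` combines the two results when both paths ran. A condensed sketch of that dispatch, with every argument list elided (`...`) and only names that appear in the diff:

```python
def forward_mixed_sketch(self, qkv, layer, forward_meta, metadata):
    # Slot 1 of the CPU max-length tensor is the max prefill length this step.
    has_prefill = metadata.max_len_tensor_cpu[1] > 0

    if has_prefill:
        # Prefill tokens: apply rotary embedding, write the KV cache, then run
        # FA3 variable-length attention over the prefill segments.
        q, k, v, _ = gqa_rope_write_cache(qkv, ...)
        res_encoder = self.flash_attn_func(q, k, v, ...)[0].reshape(
            [-1, self.attn_outputsize_tp]
        )

    # Decode tokens always go through the fused kernel; the zeroed encoder
    # lengths keep it on the decode-only path.
    res_decoder = append_attention(qkv, ...)[0]

    if has_prefill:
        # Merge the decode rows into the prefill output; the diff returns
        # res_encoder afterwards, so the merge updates it in place.
        merge_prefill_decode_output(res_encoder, res_decoder, ...)
        return res_encoder
    return res_decoder
```
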