集中式支持fa3 (#3112)

This commit is contained in:
yangjianfengo1
2025-08-01 18:03:36 +08:00
committed by GitHub
parent bdb83e007d
commit 64d7a3194d
4 changed files with 257 additions and 25 deletions

View File

@@ -761,6 +761,17 @@ void SpeculateStepPaddle(
const int encoder_decoder_block_num,
const int max_draft_tokens);
void MergePrefillDecodeOutput(
const paddle::Tensor &encoder_res,
const paddle::Tensor &decoder_res,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &cu_seq_q,
const int head_num,
const int head_dim,
const int max_token);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -1111,4 +1122,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("mtp_step_paddle",&MTPStepPaddle, "mtp_step_paddle function");
m.def("speculate_step_paddle",&SpeculateStepPaddle, "speculate_step_paddle function");
m.def("merge_prefill_decode_output", &MergePrefillDecodeOutput, "merge_prefill_decode_output function");
}

View File

@@ -0,0 +1,117 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
template <int warps, typename T>
__global__ void FillEncoderDecoderResKernel(
T * encoder_res_data,
T * decoder_res_data,
const int * seq_lens_encoder,
const int * seq_lens_decoder,
const int * seq_lens_this_time,
const int * cu_seq_q,
const int head_num,
const int head_dim) {
const int bidb = blockIdx.x;
const int bidh = blockIdx.y;
const int bidt = blockIdx.z * warps;
const int tid = threadIdx.x;
const int warp_id = tid / 32;
const int land_id = tid % 32;
const int token_id = bidt + warp_id;
const int seq_len_encoder = seq_lens_encoder[bidb];
const int seq_len_decoder = seq_lens_decoder[bidb];
const int seq_len_this_time = seq_lens_this_time[bidb];
if (seq_len_encoder > 0 || seq_len_decoder == 0 || token_id >= seq_len_this_time) {
return;
}
const int load_idx = ((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim + land_id * 4;
*reinterpret_cast<float2*>(encoder_res_data + load_idx) = *reinterpret_cast<float2*>(decoder_res_data + load_idx);
}
void MergePrefillDecodeOutput(
const paddle::Tensor &encoder_res,
const paddle::Tensor &decoder_res,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &cu_seq_q,
const int head_num,
const int head_dim,
const int max_token) {
if (head_dim != 128) {
PD_THROW("Only supported head_dim = 128");
}
const int batch_size = seq_lens_encoder.shape()[0];
constexpr int warps = 4;
const int tokens_block = (max_token + warps - 1) / warps;
dim3 grid_dims;
grid_dims.x = batch_size;
grid_dims.y = head_num;
grid_dims.z = tokens_block;
if (encoder_res.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
FillEncoderDecoderResKernel<warps>
<<<grid_dims, 128, 0, encoder_res.stream()>>>(
const_cast<T*>(encoder_res.data<T>()),
const_cast<T*>(decoder_res.data<T>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
cu_seq_q.data<int>(),
head_num,
head_dim
);
} else if (encoder_res.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
FillEncoderDecoderResKernel<warps>
<<<grid_dims, 128, 0, encoder_res.stream()>>>(
const_cast<T*>(encoder_res.data<T>()),
const_cast<T*>(decoder_res.data<T>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
cu_seq_q.data<int>(),
head_num,
head_dim
);
}
}
PD_BUILD_STATIC_OP(merge_prefill_decode_output)
.Inputs({"encoder_res",
"decoder_res",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"cu_seq_q"})
.Outputs({"res"})
.Attrs({"head_num: int",
"head_dim: int",
"max_token: int"})
.SetInplaceMap({{"encoder_res", "res"}})
.SetKernelFn(PD_KERNEL(MergePrefillDecodeOutput));

View File

@@ -294,6 +294,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu",
"gpu_ops/custom_all_reduce/all_reduce.cu",
"gpu_ops/merge_prefill_decode_output.cu",
]
# pd_disaggregation

View File

@@ -34,6 +34,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.ops import (
append_attention,
get_block_shape_and_split_kv_block,
gqa_rope_write_cache,
init_kv_signal_per_query,
@@ -46,6 +47,15 @@ from fastdeploy.model_executor.layers.attention.utils import init_rank_and_devic
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
else:
merge_prefill_decode_output = None
import os
@dataclass
class FlashAttentionMetadata(AttentionMetadata):
@@ -61,6 +71,7 @@ class FlashAttentionMetadata(AttentionMetadata):
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
max_len_kv: paddle.Tensor = None
cu_seqlens_q: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
@@ -76,6 +87,12 @@ class FlashAttentionMetadata(AttentionMetadata):
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
_fuse_kernel_compute_dtype: str = "bf16"
_dtype: paddle.dtype = paddle.bfloat16
max_len_tensor_cpu: paddle.Tensor = None
max_len_tensor_cpu_decoder: paddle.Tensor = None
class FlashAttentionBackend(AttentionBackend):
"""
@@ -143,6 +160,11 @@ class FlashAttentionBackend(AttentionBackend):
print(
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
self.zero_seq_enc_lens_for_decode = paddle.zeros(
shape=[fd_config.parallel_config.max_num_seqs, 1], dtype=paddle.int32
)
def get_attntion_meta(self):
"""get_attntion_meta"""
@@ -208,7 +230,7 @@ class FlashAttentionBackend(AttentionBackend):
) = pre_cache_len_concat(
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
metadata.set_max_lengths[2],
forward_meta.max_len_tensor_cpu[2],
self.block_size,
)
@@ -227,6 +249,18 @@ class FlashAttentionBackend(AttentionBackend):
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
self.rank, int(self.device_id), self.keep_pd_step_flag
)
if metadata._dtype == "bfloat16":
metadata._fuse_kernel_compute_dtype = "bf16"
elif metadata._dtype == "float16":
metadata._fuse_kernel_compute_dtype = "fp16"
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.max_len_tensor_cpu = forward_meta.max_len_tensor_cpu
metadata.max_len_tensor_cpu_decoder = paddle.clone(metadata.max_len_tensor_cpu)
metadata.max_len_tensor_cpu_decoder[1] = 0
self.attention_metadata = metadata
def forward_mixed(
@@ -248,45 +282,112 @@ class FlashAttentionBackend(AttentionBackend):
layer.layer_id + self.start_layer_index,
)
q, k, v, _ = gqa_rope_write_cache(
if metadata.max_len_tensor_cpu[1] > 0:
q, k, v, _ = gqa_rope_write_cache(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.rotary_embs,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
metadata.block_tables,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
metadata.kv_signal_data_list[layer.layer_id],
metadata.kv_token_num_cpu[0].item(),
self.max_seq_len,
getattr(layer, "cache_quant_type_str", "none"),
)
res_encoder = self.flash_attn_func(
q,
k,
v,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
causal=self.causal,
**self.flash_attn_kwargs,
)[0].reshape([-1, self.attn_outputsize_tp])
res_decoder = append_attention(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.rotary_embs,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_encoder,
self.zero_seq_enc_lens_for_decode,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
forward_meta.decoder_batch_ids, # from buffer
forward_meta.decoder_tile_ids_per_batch, # from buffer
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.max_len_kv,
metadata.rotary_embs,
forward_meta.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
layer.linear_shift,
layer.linear_smooth,
metadata.kv_signal_data_list[layer.layer_id],
metadata.kv_token_num_cpu[0].item(),
self.max_seq_len,
metadata._fuse_kernel_compute_dtype,
getattr(layer, "cache_quant_type_str", "none"),
)
layer.use_neox_rotary_style,
self.rope_3d,
self.max_seq_len,
getattr(layer, "quant_max_bound", 0.0),
getattr(layer, "quant_min_bound", 0.0),
getattr(layer, "out_scale", -1.0),
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.max_partition_size,
self.max_seq_len,
self.speculate_max_draft_token_num + 1,
self.causal,
self.speculative_method is not None,
)[0]
res = self.flash_attn_func(
q,
k,
v,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
causal=self.causal,
**self.flash_attn_kwargs,
)[0].reshape([-1, self.attn_outputsize_tp])
return res
if metadata.max_len_tensor_cpu[1] > 0:
merge_prefill_decode_output(
res_encoder,
res_decoder,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
self.num_heads,
self.head_dim,
self.speculate_max_draft_token_num + 1,
)
return res_encoder
else:
return res_decoder