support fa3 rope3d (#3622)

Author: xiaoxiaohehe001
Date: 2025-08-27 11:31:29 +08:00
Committed by: GitHub
Parent: 85afa72763
Commit: ad319a87cc
4 changed files with 17 additions and 7 deletions

This change threads a new rope_3d flag through the FA3 GQA rope/write-cache path (CUDA kernel, host launcher, C++ declaration, and the Python wrappers) so that 3-D rotary embedding tables, which carry a per-batch dimension, are indexed correctly.

View File

@@ -37,7 +37,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     const int q_num_head,
     const int kv_num_head,
     const int seq_len,
-    const int last_dim) {
+    const int last_dim,
+    const bool rope_3d) {
   using LoadT = AlignedVector<T, VecSize>;
   constexpr int HalfVecSize = VecSize / 2;
   using LoadEmbT = AlignedVector<float, HalfVecSize>;
@@ -62,6 +63,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
   const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
   const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
+  int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
   const int64_t base_idx =
       token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
       h_bias;
@@ -80,8 +82,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     Load<T, VecSize>(&qkv[base_idx], &src_vec);
     // do rope
     if (hi < q_num_head + kv_num_head) {
-      Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-      Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
 #pragma unroll
       for (int i = 0; i < HalfVecSize; i++) {
         const float input_left = static_cast<float>(src_vec[2 * i]);
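
The hunks above are the core of the change: with rope_3d enabled, every batch element owns its own cos/sin table, so the kernel shifts the embedding index by ori_bi * last_dim * seq_len before loading cos_emb and sin_emb. A minimal Python sketch of that index math and of the per-pair rotation applied inside the loop (function names are illustrative; the rotation is written in the standard interleaved RoPE form, whose exact sign convention is not visible in this hunk):

    def rope_emb_index(ori_seq_id, h_bias, ori_bi, seq_len, last_dim, rope_3d):
        """Mirror of the kernel's emb_idx / new_emb_idx computation."""
        half_lastdim = last_dim // 2
        emb_idx = ori_seq_id * half_lastdim + h_bias // 2
        # rope_3d: jump to this batch element's region of the cos/sin table
        return emb_idx + ori_bi * last_dim * seq_len if rope_3d else emb_idx

    def rotate_pair(x_left, x_right, cos, sin):
        """Interleaved RoPE on one (even, odd) element pair (assumed sign convention)."""
        return x_left * cos - x_right * sin, x_right * cos + x_left * sin

With rope_3d=False the index is identical to the pre-patch emb_idx, so the 2-D rope path is unchanged.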
@@ -118,6 +120,7 @@ void gqa_rotary_qk_split_variable(
     const int seq_len,
     const int input_output_len,
     const int dim_head,
+    const bool rope_3d,
     const cudaStream_t &stream) {
   int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
   constexpr int PackSize = 16 / sizeof(T);
@@ -146,7 +149,8 @@ void gqa_rotary_qk_split_variable(
       num_heads,
       kv_num_heads,
       seq_len,
-      dim_head);
+      dim_head,
+      rope_3d);
 }

 template <typename T,
@@ -890,7 +894,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::optional<paddle::Tensor>& kv_signal_data,
     const int kv_token_num,
     const int max_seq_len,
-    const std::string& cache_quant_type) {
+    const std::string& cache_quant_type,
+    const bool rope_3d) {
   typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
   typedef typename traits_::DataType DataType_;
   typedef typename traits_::data_t data_t;
@@ -953,8 +958,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
         num_heads,
         kv_num_heads,
         max_seq_len,
-        rotary_embs.dims()[2],
+        rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
         head_dim,
+        rope_3d,
         stream);
   if (token_num < kv_token_num) {
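
Note the seq_len argument passed to gqa_rotary_qk_split_variable above: because the rope_3d table carries an extra batch axis, the sequence-length extent moves from rotary_embs.dims()[2] to rotary_embs.dims()[3]. A hedged sketch of that selection (the exact rotary_embs layout is not spelled out in this diff; the axis roles in the comment are an assumption):

    def rotary_seq_len(rotary_embs_shape, rope_3d):
        # Assumed layouts (not confirmed by this diff):
        #   2-D rope: seq_len sits on axis 2 of the cos/sin tensor
        #   3-D rope: a batch axis is inserted, pushing seq_len to axis 3
        return rotary_embs_shape[3] if rope_3d else rotary_embs_shape[2]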

View File

@@ -111,7 +111,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::optional<paddle::Tensor> &cache_v_zp,
     const paddle::optional<paddle::Tensor> &kv_signal_data,
     const int kv_token_num, const int max_seq_len,
-    const std::string &cache_quant_type);
+    const std::string &cache_quant_type,
+    const bool rope_3d);

 std::vector<paddle::Tensor>
 PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,

View File

@@ -311,6 +311,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.kv_token_num_cpu[0].item(),
                 self.max_seq_len,
                 getattr(layer, "cache_quant_type_str", "none"),
+                self.rope_3d,
             )
             res_encoder = self.flash_attn_func(

View File

@@ -49,6 +49,7 @@ def gqa_rope_write_cache(
     kv_token_num: int = 1,
     max_seq_len: int = 0,
     cache_quant_type: str = "none",
+    rope_3d: bool = False,
 ):
     if current_platform.is_cuda():
         from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache
@@ -81,6 +82,7 @@ def gqa_rope_write_cache(
             kv_token_num,
             max_seq_len,
             cache_quant_type,
+            rope_3d,
         )
         return q, k, v, qkv_
     else:
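
On the Python side the new argument is opt-in: gqa_rope_write_cache defaults rope_3d to False, and FlashAttentionBackend forwards its own self.rope_3d, so existing 2-D rope callers keep their behavior. A small self-contained mock of that defaulting (the real op takes many tensor arguments omitted here and receives these trailing parameters positionally):

    def gqa_rope_write_cache_mock(*tensors, kv_token_num=1, max_seq_len=0,
                                  cache_quant_type="none", rope_3d=False):
        # Stand-in capturing only the new flag's default behavior.
        return rope_3d

    assert gqa_rope_write_cache_mock() is False             # legacy callers: 2-D rope path
    assert gqa_rope_write_cache_mock(rope_3d=True) is True  # FA3 rope-3d callers opt in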