Commit 9f1882d9a8
Author: xiaoxiaohehe001
Date: 2025-09-21 22:04:59 +08:00
Committed by: GitHub
Parent: 5223065d59
4 changed files with 17 additions and 7 deletions

View File

@@ -37,7 +37,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int q_num_head,
const int kv_num_head,
const int seq_len,
- const int last_dim) {
+ const int last_dim,
+ const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
@@ -62,6 +63,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
+ int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int64_t base_idx =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
h_bias;
@@ -80,8 +82,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
Load<T, VecSize>(&qkv[base_idx], &src_vec);
// do rope
if (hi < q_num_head + kv_num_head) {
- Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
- Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+ Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+ Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
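The functional change in this kernel is confined to the cos/sin lookup: when rope_3d is set, the embedding index gains a per-batch offset of seq_len * last_dim. Below is a minimal Python sketch of that index arithmetic, mirroring the integer math above; the assumed layout of the batched rotary table is inferred from the stride, not stated by the diff.

def rotary_emb_index(ori_seq_id: int, h_bias: int, half_lastdim: int,
                     ori_bi: int, last_dim: int, seq_len: int,
                     rope_3d: bool) -> int:
    # Offset within one (seq_len, last_dim / 2) cos/sin table, as in the kernel.
    emb_idx = ori_seq_id * half_lastdim + h_bias // 2
    # With rope_3d the table is assumed to carry a leading batch axis, so batch
    # element ori_bi starts seq_len * last_dim entries further into the buffer.
    return (emb_idx + ori_bi * last_dim * seq_len) if rope_3d else emb_idx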
@@ -118,6 +120,7 @@ void gqa_rotary_qk_split_variable(
const int seq_len,
const int input_output_len,
const int dim_head,
+ const bool rope_3d,
const cudaStream_t &stream) {
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int PackSize = 16 / sizeof(T);
@@ -146,7 +149,8 @@ void gqa_rotary_qk_split_variable(
num_heads,
kv_num_heads,
seq_len,
- dim_head);
+ dim_head,
+ rope_3d);
}
template <typename T,
@@ -890,7 +894,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::optional<paddle::Tensor>& kv_signal_data,
const int kv_token_num,
const int max_seq_len,
- const std::string& cache_quant_type) {
+ const std::string& cache_quant_type,
+ const bool rope_3d) {
typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -953,8 +958,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
num_heads,
kv_num_heads,
max_seq_len,
- rotary_embs.dims()[2],
+ rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
head_dim,
+ rope_3d,
stream);
if (token_num < kv_token_num) {
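At this call site the argument that used to be rotary_embs.dims()[2] becomes rotary_embs.dims()[3] under rope_3d, i.e. the sequence-length axis of the rotary table shifts one position to the right once a batch axis is present. A hedged one-liner capturing just that selection (the full tensor layout is not shown in the diff):

def rotary_seq_len(rotary_embs_shape, rope_3d: bool) -> int:
    # Mirrors the `rope_3d ? dims()[3] : dims()[2]` switch in GQARopeWriteCacheKernel.
    return rotary_embs_shape[3] if rope_3d else rotary_embs_shape[2]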

View File

@@ -107,7 +107,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &kv_signal_data,
const int kv_token_num, const int max_seq_len,
- const std::string &cache_quant_type);
+ const std::string &cache_quant_type,
+ const bool rope_3d);
std::vector<paddle::Tensor>
PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,

View File

@@ -283,6 +283,7 @@ class FlashAttentionBackend(AttentionBackend):
metadata.kv_token_num_cpu[0].item(),
self.max_seq_len,
getattr(layer, "cache_quant_type_str", "none"),
+ self.rope_3d,
)
res = self.flash_attn_func(

View File

@@ -49,6 +49,7 @@ def gqa_rope_write_cache(
kv_token_num: int = 1,
max_seq_len: int = 0,
cache_quant_type: str = "none",
+ rope_3d: bool = False,
):
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache
@@ -81,6 +82,7 @@ def gqa_rope_write_cache(
kv_token_num,
max_seq_len,
cache_quant_type,
+ rope_3d,
)
return q, k, v, qkv_
else:
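On the Python side, rope_3d is a trailing keyword with a False default, so existing callers of gqa_rope_write_cache are unaffected and only backends with a batched rotary table opt in. The helper below is a hypothetical illustration of the new keyword being threaded through, not FastDeploy code:

from typing import Any, Dict

def build_rope_write_cache_kwargs(kv_token_num: int = 1,
                                  max_seq_len: int = 0,
                                  cache_quant_type: str = "none",
                                  rope_3d: bool = False) -> Dict[str, Any]:
    # Defaults mirror the wrapper signature; rope_3d is the flag added by this commit.
    return {
        "kv_token_num": kv_token_num,
        "max_seq_len": max_seq_len,
        "cache_quant_type": cache_quant_type,
        "rope_3d": rope_3d,
    }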