diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
index 20e8b147e..3b33c750a 100644
--- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
+++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
@@ -37,7 +37,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     const int q_num_head,
     const int kv_num_head,
     const int seq_len,
-    const int last_dim) {
+    const int last_dim,
+    const bool rope_3d) {
   using LoadT = AlignedVector<T, VecSize>;
   constexpr int HalfVecSize = VecSize / 2;
   using LoadEmbT = AlignedVector<float, HalfVecSize>;
@@ -62,6 +63,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
     const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
+    int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
     const int64_t base_idx = token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim + h_bias;
@@ -80,8 +82,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     Load<T, VecSize>(&qkv[base_idx], &src_vec);
     // do rope
     if (hi < q_num_head + kv_num_head) {
-      Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-      Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
 #pragma unroll
       for (int i = 0; i < HalfVecSize; i++) {
         const float input_left = static_cast<float>(src_vec[2 * i]);
@@ -118,6 +120,7 @@ void gqa_rotary_qk_split_variable(
     const int seq_len,
     const int input_output_len,
     const int dim_head,
+    const bool rope_3d,
     const cudaStream_t &stream) {
   int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
   constexpr int PackSize = 16 / sizeof(T);
@@ -146,7 +149,8 @@ void gqa_rotary_qk_split_variable(
       num_heads,
       kv_num_heads,
       seq_len,
-      dim_head);
+      dim_head,
+      rope_3d);
 }
 
 template <paddle::DataType D>
 std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::optional<paddle::Tensor>& kv_signal_data,
     const int kv_token_num,
     const int max_seq_len,
-    const std::string& cache_quant_type) {
+    const std::string& cache_quant_type,
+    const bool rope_3d) {
   typedef PDTraits<D> traits_;
   typedef typename traits_::DataType DataType_;
   typedef typename traits_::data_t data_t;
@@ -953,8 +958,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
       num_heads,
       kv_num_heads,
       max_seq_len,
-      rotary_embs.dims()[2],
+      rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
       head_dim,
+      rope_3d,
       stream);
 
   if (token_num < kv_token_num) {
diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 503e67a50..f8172394b 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -107,7 +107,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::optional<paddle::Tensor> &cache_v_zp,
     const paddle::optional<paddle::Tensor> &kv_signal_data,
     const int kv_token_num, const int max_seq_len,
-    const std::string &cache_quant_type);
+    const std::string &cache_quant_type,
+    const bool rope_3d);
 
 std::vector<paddle::Tensor> PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
index fcbf6fa64..df5e5db9d 100644
--- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
@@ -283,6 +283,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.kv_token_num_cpu[0].item(),
             self.max_seq_len,
             getattr(layer, "cache_quant_type_str", "none"),
+            self.rope_3d,
         )
 
         res = self.flash_attn_func(
diff --git a/fastdeploy/model_executor/layers/attention/ops/gqa_rope_write_cache.py b/fastdeploy/model_executor/layers/attention/ops/gqa_rope_write_cache.py
index ed0b8f239..9aac80df3 100644
--- a/fastdeploy/model_executor/layers/attention/ops/gqa_rope_write_cache.py
+++ b/fastdeploy/model_executor/layers/attention/ops/gqa_rope_write_cache.py
@@ -49,6 +49,7 @@ def gqa_rope_write_cache(
     kv_token_num: int = 1,
     max_seq_len: int = 0,
     cache_quant_type: str = "none",
+    rope_3d: bool = False,
 ):
     if current_platform.is_cuda():
         from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache
@@ -81,6 +82,7 @@ def gqa_rope_write_cache(
             kv_token_num,
             max_seq_len,
             cache_quant_type,
+            rope_3d,
         )
         return q, k, v, qkv_
     else:
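The heart of the change is the per-batch offset applied to the rotary-embedding index when rope_3d is set. As an illustration only (not code from this patch), the sketch below re-derives that index arithmetic in Python; the names ori_bi (batch index), ori_seq_id (token position), h_bias (offset within the head dimension), last_dim (head dim) and seq_len mirror the CUDA variables, and the assumed table layouts are inferred from the indexing rather than stated in the diff.

def rotary_emb_index(ori_bi: int, ori_seq_id: int, h_bias: int,
                     last_dim: int, seq_len: int, rope_3d: bool) -> int:
    """Flat offset into cos_emb / sin_emb for one vectorized load (illustrative sketch).

    Without rope_3d, emb_idx addresses a cos/sin table shared across the batch,
    indexed by position and half the head dimension. With rope_3d, each batch
    entry presumably has its own slice, so the kernel shifts the index by
    ori_bi * last_dim * seq_len, exactly as new_emb_idx does in the CUDA code.
    """
    half_lastdim = last_dim // 2
    emb_idx = ori_seq_id * half_lastdim + h_bias // 2
    return emb_idx + ori_bi * last_dim * seq_len if rope_3d else emb_idx

The host side mirrors this: when rope_3d is set, the sequence-length argument passed to gqa_rotary_qk_split_variable comes from rotary_embs.dims()[3] instead of dims()[2], consistent with the rotary table gaining a leading per-batch dimension, and the flag is threaded through GQARopeWriteCacheKernel, the cpp_extensions.cc declaration, the Python op wrapper (default rope_3d=False), and FlashAttentionBackend, which forwards self.rope_3d.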