diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu
index d5ece4f53..f1f5a6177 100644
--- a/custom_ops/gpu_ops/append_attention.cu
+++ b/custom_ops/gpu_ops/append_attention.cu
@@ -273,6 +273,7 @@ void AppendAttentionKernel(
         cache_v_zp,
         cache_quant_type_str,
         use_neox_rotary_style,
+        rope_3d,
         max_input_length,
         exec_stream,
         &qkv_out,
@@ -299,6 +300,7 @@ void AppendAttentionKernel(
         cache_v_zp,
         cache_quant_type_str,
         use_neox_rotary_style,
+        rope_3d,
         max_input_length,
         exec_stream,
         &qkv_out,
diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh
index 9c9816d3b..48d769d81 100644
--- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh
@@ -353,7 +353,8 @@ __global__ void append_speculate_cache_rope_kernel(
     const int head_size,
     const int block_size,
     const int elem_cnt,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   using LoadT = AlignedVector;
   using LoadFloat = AlignedVector;
   using LoadInT = AlignedVector;
@@ -413,8 +414,9 @@ __global__ void append_speculate_cache_rope_kernel(
     if (hi < num_heads + gqa_group_size) {  // q k rope
       const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
-      Load(&cos_emb[emb_idx], &cos_emb_vec);
-      Load(&sin_emb[emb_idx], &sin_emb_vec);
+      int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
+      Load(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load(&sin_emb[new_emb_idx], &sin_emb_vec);
     }
 #pragma unroll
     for (int i = 0; i < HalfVecSize; i++) {
@@ -486,7 +488,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
     const int head_size,
     const int block_size,
     const int elem_cnt,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   using LoadT = AlignedVector;
   using LoadFloat = AlignedVector;
   using LoadInT = AlignedVector;
@@ -550,8 +553,9 @@ __global__ void append_speculate_cache_neox_rope_kernel(
     if (hi < num_heads + gqa_group_size) {  // q k rope
       const int64_t emb_idx = write_seq_id * head_size + h_bias;
-      Load(&cos_emb[emb_idx], &cos_emb_vec);
-      Load(&sin_emb[emb_idx], &sin_emb_vec);
+      int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
+      Load(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load(&sin_emb[new_emb_idx], &sin_emb_vec);
     }
 #pragma unroll
     for (int i = 0; i < VecSize; i++) {
@@ -636,7 +640,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
     const int block_size,
     const float max_bound,
     const float min_bound,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
   static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
   constexpr int NUM_WARPS = 4;
@@ -682,8 +687,9 @@ __global__ void append_speculate_cache_int8_rope_kernel(
       // q rope
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-      Load(&cos_emb[emb_idx], &cos_emb_vec);
-      Load(&sin_emb[emb_idx], &sin_emb_vec);
+      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      Load(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load(&sin_emb[new_emb_idx], &sin_emb_vec);
       if (qkv_out_scales) {
         Load(&qkv_out_scales[bias_idx], &out_scale_vec);
       }
@@ -743,10 +749,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
     T scale;
     if (head_idx < num_heads + gqa_group_size) {
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-      Load(&cos_emb[emb_idx], &cos_emb_vec1);
-      Load(&cos_emb[emb_idx + 4], &cos_emb_vec2);
-      Load(&sin_emb[emb_idx], &sin_emb_vec1);
-      Load(&sin_emb[emb_idx + 4], &sin_emb_vec2);
+      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      Load(&cos_emb[new_emb_idx], &cos_emb_vec1);
+      Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
+      Load(&sin_emb[new_emb_idx], &sin_emb_vec1);
+      Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
       scale = __ldg(&cache_k_scales[kv_head_idx]);
     } else {
       scale = __ldg(&cache_v_scales[kv_head_idx]);
@@ -868,7 +875,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
     const int block_size,
     const float max_bound,
     const float min_bound,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
   static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
   constexpr int NUM_WARPS = 4;
@@ -917,8 +925,9 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
       // q rope
       const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
-      Load(&cos_emb[emb_idx], &cos_emb_vec);
-      Load(&sin_emb[emb_idx], &sin_emb_vec);
+      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
+      Load(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load(&sin_emb[new_emb_idx], &sin_emb_vec);
       if (qkv_out_scales) {
         Load(&qkv_out_scales[bias_idx_left],
              &left_out_scale_vec);
@@ -1013,10 +1022,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
     T scale;
     const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
-    Load(&cos_emb[emb_idx], &cos_emb_vec1);
-    Load(&cos_emb[emb_idx + 8], &cos_emb_vec2);
-    Load(&sin_emb[emb_idx], &sin_emb_vec1);
-    Load(&sin_emb[emb_idx + 8], &sin_emb_vec2);
+    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
+    Load(&cos_emb[new_emb_idx], &cos_emb_vec1);
+    Load(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
+    Load(&sin_emb[new_emb_idx], &sin_emb_vec1);
+    Load(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
     scale = __ldg(&cache_k_scales[kv_head_idx]);
 #pragma unroll
     for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
@@ -1248,7 +1258,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
     const int block_size,
     const float max_bound,
     const float min_bound,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
   static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
   constexpr int NUM_WARPS = 4;
@@ -1305,8 +1316,9 @@ __global__ void append_speculate_cache_int4_rope_kernel(
       // Load(&qkv_out_scales[bias_idx], &out_scale_vec);
       // q rope
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-      Load(&cos_emb[emb_idx], &cos_emb_vec);
-      Load(&sin_emb[emb_idx], &sin_emb_vec);
+      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      Load(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load(&sin_emb[new_emb_idx], &sin_emb_vec);
 #pragma unroll
       for (int i = 0; i < HalfVecSize; i++) {
         // dequant + add_bias + rope
@@ -1395,10 +1407,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
     // &out_scale_vec2);
     if (head_idx < num_heads + gqa_group_size) {
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-      Load(&cos_emb[emb_idx], &cos_emb_vec1);
-      Load(&cos_emb[emb_idx + 4], &cos_emb_vec2);
-      Load(&sin_emb[emb_idx], &sin_emb_vec1);
-      Load(&sin_emb[emb_idx + 4], &sin_emb_vec2);
+      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      Load(&cos_emb[new_emb_idx], &cos_emb_vec1);
+      Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
+      Load(&sin_emb[new_emb_idx], &sin_emb_vec1);
+      Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
       Load(&cache_k_scales[cache_idx], &scale_vec1);
       Load(&cache_k_scales[cache_idx + 8], &scale_vec2);
       Load(&cache_k_zero_points[cache_idx], &zp_vec1);
@@ -1591,7 +1604,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
     const int block_size,
     const float max_bound,
     const float min_bound,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
   static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
   constexpr int NUM_WARPS = 4;
@@ -1741,10 +1755,11 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
          &right_out_scale_vec2);
     const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
-    Load(&cos_emb[emb_idx], &cos_emb_vec1);
-    Load(&cos_emb[emb_idx + 8], &cos_emb_vec2);
-    Load(&sin_emb[emb_idx], &sin_emb_vec1);
-    Load(&sin_emb[emb_idx + 8], &sin_emb_vec2);
+    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+    Load(&cos_emb[new_emb_idx], &cos_emb_vec1);
+    Load(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
+    Load(&sin_emb[new_emb_idx], &sin_emb_vec1);
+    Load(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
     Load(&cache_k_scales[left_cache_idx], &left_scale_vec1);
     Load(&cache_k_scales[left_cache_idx + 8],
diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu
index 8e8195c30..99b9f1030 100644
--- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu
+++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu
@@ -110,7 +110,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
                                  const int bsz,
                                  const int token_num,
                                  const cudaStream_t& stream,
-                                 const bool use_neox_style) {
+                                 const bool use_neox_style,
+                                 const bool rope_3d) {
   int output_inner_dim = num_heads + 2 * kv_num_heads;
   const uint32_t elem_nums =
@@ -144,7 +145,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
             dim_head,
             block_size,
             elem_nums,
-            kv_num_heads);
+            kv_num_heads,
+            rope_3d);
   } else {
     append_speculate_cache_rope_kernel
         <<>>(
@@ -167,7 +169,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
             dim_head,
             block_size,
             elem_nums,
-            kv_num_heads);
+            kv_num_heads,
+            rope_3d);
   }
 }
@@ -196,7 +199,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
                                       const int bsz,
                                       const int token_num,
                                       const cudaStream_t& stream,
-                                      const bool use_neox_style) {
+                                      const bool use_neox_style,
+                                      const bool rope_3d) {
   constexpr int num_warps = 4;
   const int all_warps =
       ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -238,7 +242,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
             block_size,
             127.0f,
             -127.0f,
-            kv_num_heads);
+            kv_num_heads,
+            rope_3d);
   } else {
     append_speculate_cache_int8_rope_kernel
         <<>>(qkv,
@@ -262,7 +267,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
             block_size,
             127.0f,
             -127.0f,
-            kv_num_heads);
+            kv_num_heads,
+            rope_3d);
   }
 }
@@ -293,7 +299,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
                                       const int bsz,
                                       const int token_num,
                                       const cudaStream_t& stream,
-                                      const bool use_neox_style) {
+                                      const bool use_neox_style,
+                                      const bool rope_3d) {
   constexpr int num_warps = 4;
   const int all_warps =
       ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -337,7 +344,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
             block_size,
             7.0f,
             -8.0f,
-            kv_num_heads);
+            kv_num_heads,
+            rope_3d);
   } else {
     append_speculate_cache_int4_rope_kernel
         <<>>(qkv,
@@ -363,7 +371,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
             block_size,
             7.0f,
             -8.0f,
-            kv_num_heads);
+            kv_num_heads,
+            rope_3d);
   }
 }
 template
@@ -384,6 +393,7 @@ void SpeculateWriteCacheWithRoPEKernel(
     const paddle::optional& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -479,7 +489,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         bsz,
         token_nums,
         stream,
-        use_neox_rotary_style);
+        use_neox_rotary_style,
+        rope_3d);
   } else if (cache_quant_type_str == "cache_int8") {
     append_speculate_cache_int8_rope(
         reinterpret_cast(qkv_ptr),
@@ -512,7 +523,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         bsz,
         token_nums,
         stream,
-        use_neox_rotary_style);
+        use_neox_rotary_style,
+        rope_3d);
   } else if (cache_quant_type_str == "cache_fp8") {
     append_speculate_cache_int8_rope(
         reinterpret_cast(qkv_ptr),
@@ -545,7 +557,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         bsz,
         token_nums,
         stream,
-        use_neox_rotary_style);
+        use_neox_rotary_style,
+        rope_3d);
   } else if (cache_quant_type_str == "cache_int4_zp") {
     append_speculate_cache_int4_rope(
         reinterpret_cast(qkv_ptr),
@@ -584,7 +597,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         bsz,
         token_nums,
         stream,
-        use_neox_rotary_style);
+        use_neox_rotary_style,
+        rope_3d);
   } else {
     PD_THROW(
         "cache_quant_type_str should be one of [none, cache_int8, "
@@ -612,6 +626,7 @@ template void SpeculateWriteCacheWithRoPEKernel(
     const paddle::optional& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -641,6 +656,7 @@ SpeculateWriteCacheWithRoPEKernel(
     const paddle::optional& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
    const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -669,6 +685,7 @@ template void SpeculateWriteCacheWithRoPEKernel(
     const paddle::optional& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -699,6 +716,7 @@ SpeculateWriteCacheWithRoPEKernel(
     const paddle::optional& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h
index a44a9db15..2db42bc26 100644
--- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h
+++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h
@@ -35,6 +35,7 @@ void SpeculateWriteCacheWithRoPEKernel(
     const paddle::optional& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
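
The change repeated across every kernel in this diff is a single indexing decision: with rope_3d false the cos/sin tables are shared by all sequences in the batch, and with rope_3d true each sequence reads its own slice, offset by the batch id times a max_seq_len * head_dim (or doubled, for the neox variants) stride. The standalone sketch below illustrates only that selection; the helper name rope_emb_index, the batch_stride parameter, and the toy values are assumptions for illustration, not code from this patch. It compiles as a .cu file with nvcc.

// Illustrative sketch only: mirrors the pattern
//   new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
// added in front of each cos/sin load above. Names and values are hypothetical.
#include <cstdint>
#include <cstdio>

__host__ __device__ inline int64_t rope_emb_index(int64_t emb_idx,
                                                  int64_t batch_id,
                                                  int64_t batch_stride,
                                                  bool rope_3d) {
  // 2D rope: one table shared by every sequence in the batch.
  // 3D rope: one table per sequence, so shift by the per-batch stride.
  return rope_3d ? emb_idx + batch_id * batch_stride : emb_idx;
}

int main() {
  const int64_t max_seq_len = 8, head_dim = 128;
  const int64_t batch_stride = max_seq_len * head_dim;  // assumed layout
  // Element 5 of token 3, using the half-head indexing of the non-neox kernels.
  const int64_t emb_idx = 3 * (head_dim / 2) + 5;
  std::printf("2D index: %lld\n",
              (long long)rope_emb_index(emb_idx, /*batch_id=*/2, batch_stride, false));
  std::printf("3D index: %lld\n",
              (long long)rope_emb_index(emb_idx, /*batch_id=*/2, batch_stride, true));
  return 0;
}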