diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh index 48d769d81..4fb5c93d0 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh @@ -46,7 +46,8 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( const int gqa_group_size, const float* q_norm_weight, const float* k_norm_weight, - const float rms_norm_eps) { + const float rms_norm_eps, + const bool rope_3d) { using LoadT = AlignedVector; using LoadFloat = AlignedVector; using LoadInT = AlignedVector; @@ -109,8 +110,9 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( if (hi < num_heads + gqa_group_size) { // q k rope const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); } float thread_m2 = 0.0f; float warp_m2 = 0.0f; diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu index 99b9f1030..4fd07ae23 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu @@ -41,7 +41,8 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv, const bool use_neox_style, const float* q_norm_weight, const float* k_norm_weight, - const float rms_norm_eps) { + const float rms_norm_eps, + const bool rope_3d) { int output_inner_dim = num_heads + 2 * kv_num_heads; const uint32_t elem_nums = use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2 @@ -53,7 +54,6 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv, const int blocksize = 128; int grid_size = 1; GetNumBlocks<128>(pack_num, &grid_size); - if (use_neox_style) { PD_THROW( "append_speculate_cache_rope_qk_norm not support neox rope yet"); @@ -82,7 +82,8 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv, kv_num_heads, q_norm_weight, k_norm_weight, - rms_norm_eps); + rms_norm_eps, + rope_3d); } } @@ -426,7 +427,6 @@ void SpeculateWriteCacheWithRoPEKernel( ? rotary_embs.get().data() + max_seq_len * dim_head : rotary_embs.get().data() + max_seq_len * dim_head / 2; } - if (q_norm_weight && k_norm_weight) { if (cache_quant_type_str == "none") { append_speculate_cache_rope_qk_norm( @@ -457,11 +457,13 @@ void SpeculateWriteCacheWithRoPEKernel( use_neox_rotary_style, reinterpret_cast(q_norm_weight.get().data()), reinterpret_cast(k_norm_weight.get().data()), - rms_norm_eps); + rms_norm_eps, + rope_3d); } else { PD_THROW( "append_decode_cache_rope_qk_norm not support cachekv quant yet"); } + } else { if (cache_quant_type_str == "none") { append_speculate_cache_rope(