support rope_3d in spec mode (#4034)

This commit is contained in:
freeliuzc
2025-09-10 18:15:05 +08:00
committed by GitHub
parent 684e93269b
commit 7ee100903f
2 changed files with 12 additions and 8 deletions

View File

@@ -46,7 +46,8 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
const int gqa_group_size,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
const float rms_norm_eps,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
@@ -109,8 +110,9 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;

View File

@@ -41,7 +41,8 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
const bool use_neox_style,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps) {
const float rms_norm_eps,
const bool rope_3d) {
int output_inner_dim = num_heads + 2 * kv_num_heads;
const uint32_t elem_nums =
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
@@ -53,7 +54,6 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
if (use_neox_style) {
PD_THROW(
"append_speculate_cache_rope_qk_norm not support neox rope yet");
@@ -82,7 +82,8 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
kv_num_heads,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
rms_norm_eps,
rope_3d);
}
}
@@ -426,7 +427,6 @@ void SpeculateWriteCacheWithRoPEKernel(
? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
}
if (q_norm_weight && k_norm_weight) {
if (cache_quant_type_str == "none") {
append_speculate_cache_rope_qk_norm(
@@ -457,11 +457,13 @@ void SpeculateWriteCacheWithRoPEKernel(
use_neox_rotary_style,
reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
rms_norm_eps);
rms_norm_eps,
rope_3d);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
}
} else {
if (cache_quant_type_str == "none") {
append_speculate_cache_rope(