mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
support rope_3d in spec mode (#4034)
This commit is contained in:
@@ -46,7 +46,8 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
|
||||
const int gqa_group_size,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
const float rms_norm_eps,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -109,8 +110,9 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
|
||||
@@ -41,7 +41,8 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
const bool use_neox_style,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
const float rms_norm_eps,
|
||||
const bool rope_3d) {
|
||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||
const uint32_t elem_nums =
|
||||
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||
@@ -53,7 +54,6 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
|
||||
if (use_neox_style) {
|
||||
PD_THROW(
|
||||
"append_speculate_cache_rope_qk_norm not support neox rope yet");
|
||||
@@ -82,7 +82,8 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -426,7 +427,6 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||
}
|
||||
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope_qk_norm(
|
||||
@@ -457,11 +457,13 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
use_neox_rotary_style,
|
||||
reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
|
||||
reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
|
||||
rms_norm_eps);
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
||||
}
|
||||
|
||||
} else {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope(
|
||||
|
||||
Reference in New Issue
Block a user