support mtp rope_3d (#3791)

* support mtp rope_3d

* Update speculate_write_cache_with_rope_kernel.cu
xiaoxiaohehe001 committed 2025-09-04 17:18:05 +08:00 (committed by GitHub)
parent f36a388ffe
commit f265a26f8b
4 changed files with 81 additions and 45 deletions
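
In short, the patch threads a new rope_3d flag from AppendAttentionKernel and SpeculateWriteCacheWithRoPEKernel down to every speculative-decoding cache-write kernel. When the flag is set, the cos/sin rotary tables are indexed with an additional per-batch offset instead of being shared across the batch. A minimal sketch of that indexing idea follows; the helper name and the exact table layout are assumptions for illustration and are not part of the patch (each kernel inlines the expression directly).

#include <cstdint>

// Illustrative helper (not in the patch): with rope_3d the rotary table is
// assumed to carry a leading batch dimension, so each batch reads from its own
// block of cos/sin values at a fixed stride; otherwise the offset is unchanged
// and every batch shares one [max_seq_len, ...] table.
inline int64_t rope_emb_index(bool rope_3d,
                              int64_t emb_idx,        // offset within one batch's table
                              int64_t batch_id,
                              int64_t batch_stride) { // e.g. max_seq_len * head_dim
  return rope_3d ? emb_idx + batch_id * batch_stride : emb_idx;
}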


@@ -273,6 +273,7 @@ void AppendAttentionKernel(
       cache_v_zp,
       cache_quant_type_str,
       use_neox_rotary_style,
+      rope_3d,
       max_input_length,
       exec_stream,
       &qkv_out,
@@ -299,6 +300,7 @@ void AppendAttentionKernel(
       cache_v_zp,
       cache_quant_type_str,
       use_neox_rotary_style,
+      rope_3d,
       max_input_length,
       exec_stream,
       &qkv_out,


@@ -353,7 +353,8 @@ __global__ void append_speculate_cache_rope_kernel(
     const int head_size,
     const int block_size,
     const int elem_cnt,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   using LoadT = AlignedVector<T, VecSize>;
   using LoadFloat = AlignedVector<float, VecSize>;
   using LoadInT = AlignedVector<InT, VecSize>;
@@ -413,8 +414,9 @@ __global__ void append_speculate_cache_rope_kernel(
     if (hi < num_heads + gqa_group_size) {
       // q k rope
       const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
-      Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-      Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+      int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
+      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
     }
 #pragma unroll
     for (int i = 0; i < HalfVecSize; i++) {
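
For reference, the index math this hunk adds to the interleaved (non-neox) kernel, pulled out into a stand-alone function. Variable names follow the kernel; the function itself is only an illustration of the expression, not code from the patch.

#include <cstdint>

// Mirrors append_speculate_cache_rope_kernel: emb_idx addresses half_head_size
// cos/sin values per position, and rope_3d shifts it by a per-batch block of
// max_seq_len * head_size elements (ori_bi is the original batch index).
inline int64_t interleaved_rope_index(bool rope_3d, int64_t write_seq_id,
                                      int64_t half_head_size, int64_t h_bias,
                                      int64_t ori_bi, int64_t max_seq_len,
                                      int64_t head_size) {
  const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
  return rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
}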
@@ -486,7 +488,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
     const int head_size,
     const int block_size,
     const int elem_cnt,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   using LoadT = AlignedVector<T, VecSize>;
   using LoadFloat = AlignedVector<float, VecSize>;
   using LoadInT = AlignedVector<InT, VecSize>;
@@ -550,8 +553,9 @@ __global__ void append_speculate_cache_neox_rope_kernel(
     if (hi < num_heads + gqa_group_size) {
       // q k rope
       const int64_t emb_idx = write_seq_id * head_size + h_bias;
-      Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-      Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+      int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
+      Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
     }
 #pragma unroll
     for (int i = 0; i < VecSize; i++) {
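
The neox path reads a full head_size of cos/sin values per position (the interleaved path reads half_head_size), and its per-batch stride carries a matching extra factor of 2: max_seq_len * head_size * 2 versus max_seq_len * head_size. A stand-alone restatement of the expression, again for illustration only:

#include <cstdint>

// Mirrors append_speculate_cache_neox_rope_kernel: full head_size per position,
// per-batch stride of max_seq_len * head_size * 2 when rope_3d is enabled.
inline int64_t neox_rope_index(bool rope_3d, int64_t write_seq_id,
                               int64_t head_size, int64_t h_bias,
                               int64_t ori_bi, int64_t max_seq_len) {
  const int64_t emb_idx = write_seq_id * head_size + h_bias;
  return rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
}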
@@ -636,7 +640,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
     const int block_size,
     const float max_bound,
     const float min_bound,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
   static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
   constexpr int NUM_WARPS = 4;
@@ -682,8 +687,9 @@ __global__ void append_speculate_cache_int8_rope_kernel(
     // q rope
     const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-    Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-    Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+    Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+    Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
     if (qkv_out_scales) {
       Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
     }
@@ -743,10 +749,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
     T scale;
     if (head_idx < num_heads + gqa_group_size) {
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-      Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
-      Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
-      Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
-      Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
+      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
+      Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
+      Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
+      Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
       scale = __ldg(&cache_k_scales[kv_head_idx]);
     } else {
       scale = __ldg(&cache_v_scales[kv_head_idx]);
@@ -868,7 +875,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
     const int block_size,
     const float max_bound,
     const float min_bound,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
   static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
   constexpr int NUM_WARPS = 4;
@@ -917,8 +925,9 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
     // q rope
     const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
-    Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-    Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
+    Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+    Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
     if (qkv_out_scales) {
       Load<float, VecSize>(&qkv_out_scales[bias_idx_left],
                            &left_out_scale_vec);
@@ -1013,10 +1022,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
     T scale;
     const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
-    Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
-    Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
-    Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
-    Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
+    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
+    Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
+    Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
+    Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
+    Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
     scale = __ldg(&cache_k_scales[kv_head_idx]);
 #pragma unroll
     for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
@@ -1248,7 +1258,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
     const int block_size,
     const float max_bound,
     const float min_bound,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
   static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
   constexpr int NUM_WARPS = 4;
@@ -1305,8 +1316,9 @@ __global__ void append_speculate_cache_int4_rope_kernel(
     // Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
     // q rope
     const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-    Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-    Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+    Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+    Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
 #pragma unroll
     for (int i = 0; i < HalfVecSize; i++) {
       // dequant + add_bias + rope
@@ -1395,10 +1407,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
     //                        &out_scale_vec2);
     if (head_idx < num_heads + gqa_group_size) {
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-      Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
-      Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
-      Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
-      Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
+      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
+      Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
+      Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
+      Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
       Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx], &scale_vec1);
       Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx + 8], &scale_vec2);
       Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
@@ -1591,7 +1604,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
     const int block_size,
     const float max_bound,
     const float min_bound,
-    const int gqa_group_size) {
+    const int gqa_group_size,
+    const bool rope_3d) {
   static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
   static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
   constexpr int NUM_WARPS = 4;
@@ -1741,10 +1755,11 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
                          &right_out_scale_vec2);
     const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
-    Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
-    Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
-    Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
-    Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
+    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+    Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
+    Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
+    Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
+    Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
     Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx],
                              &left_scale_vec1);
     Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx + 8],
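
The quantized kernels repeat the same pattern with the block's batch index bid and the compile-time HeadDim: the int8 and int4 non-neox paths and the int4 neox path use a stride of max_seq_len * HeadDim, while the int8 neox path uses max_seq_len * HeadDim * 2. A hypothetical shared helper capturing the common expression is sketched below; the patch inlines it at each call site rather than adding such a helper.

// Hypothetical consolidation of the repeated pattern (not part of the patch):
// every kernel computes new_emb_idx = emb_idx + batch_id * batch_stride when
// rope_3d is set, and leaves emb_idx untouched otherwise.
template <typename IndexT>
inline IndexT rope3d_offset(bool rope_3d, IndexT emb_idx, IndexT batch_id,
                            IndexT batch_stride) {
  return rope_3d ? emb_idx + batch_id * batch_stride : emb_idx;
}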


@@ -110,7 +110,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
                                  const int bsz,
                                  const int token_num,
                                  const cudaStream_t& stream,
-                                 const bool use_neox_style) {
+                                 const bool use_neox_style,
+                                 const bool rope_3d) {
   int output_inner_dim = num_heads + 2 * kv_num_heads;
   const uint32_t elem_nums =
@@ -144,7 +145,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
             dim_head,
             block_size,
             elem_nums,
-            kv_num_heads);
+            kv_num_heads,
+            rope_3d);
   } else {
     append_speculate_cache_rope_kernel<T, PackSize>
         <<<grid_size, threads_per_block, 0, stream>>>(
@@ -167,7 +169,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
             dim_head,
             block_size,
             elem_nums,
-            kv_num_heads);
+            kv_num_heads,
+            rope_3d);
   }
 }
@@ -196,7 +199,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
                                       const int bsz,
                                       const int token_num,
                                       const cudaStream_t& stream,
-                                      const bool use_neox_style) {
+                                      const bool use_neox_style,
+                                      const bool rope_3d) {
   constexpr int num_warps = 4;
   const int all_warps =
       ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -238,7 +242,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
            block_size,
            127.0f,
            -127.0f,
-           kv_num_heads);
+           kv_num_heads,
+           rope_3d);
   } else {
     append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
         <<<grids, num_warps * 32, 0, stream>>>(qkv,
@@ -262,7 +267,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
            block_size,
            127.0f,
            -127.0f,
-           kv_num_heads);
+           kv_num_heads,
+           rope_3d);
   }
 }
@@ -293,7 +299,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
                                       const int bsz,
                                       const int token_num,
                                       const cudaStream_t& stream,
-                                      const bool use_neox_style) {
+                                      const bool use_neox_style,
+                                      const bool rope_3d) {
   constexpr int num_warps = 4;
   const int all_warps =
       ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -337,7 +344,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
            block_size,
            7.0f,
            -8.0f,
-           kv_num_heads);
+           kv_num_heads,
+           rope_3d);
   } else {
     append_speculate_cache_int4_rope_kernel<T, 4>
         <<<grids, num_warps * 32, 0, stream>>>(qkv,
@@ -363,7 +371,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
            block_size,
            7.0f,
            -8.0f,
-           kv_num_heads);
+           kv_num_heads,
+           rope_3d);
   }
 }
 template <typename T, typename QKV_TYPE>
@@ -384,6 +393,7 @@ void SpeculateWriteCacheWithRoPEKernel(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -479,7 +489,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         bsz,
         token_nums,
         stream,
-        use_neox_rotary_style);
+        use_neox_rotary_style,
+        rope_3d);
   } else if (cache_quant_type_str == "cache_int8") {
     append_speculate_cache_int8_rope(
         reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -512,7 +523,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         bsz,
         token_nums,
         stream,
-        use_neox_rotary_style);
+        use_neox_rotary_style,
+        rope_3d);
   } else if (cache_quant_type_str == "cache_fp8") {
     append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
         reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -545,7 +557,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         bsz,
         token_nums,
         stream,
-        use_neox_rotary_style);
+        use_neox_rotary_style,
+        rope_3d);
   } else if (cache_quant_type_str == "cache_int4_zp") {
     append_speculate_cache_int4_rope(
         reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -584,7 +597,8 @@ void SpeculateWriteCacheWithRoPEKernel(
         bsz,
         token_nums,
         stream,
-        use_neox_rotary_style);
+        use_neox_rotary_style,
+        rope_3d);
   } else {
     PD_THROW(
         "cache_quant_type_str should be one of [none, cache_int8, "
@@ -612,6 +626,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -641,6 +656,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -669,6 +685,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
@@ -699,6 +716,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,


@@ -35,6 +35,7 @@ void SpeculateWriteCacheWithRoPEKernel(
     const paddle::optional<paddle::Tensor>& cache_v_zp,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
+    const bool rope_3d,
     const int max_seq_len,
     cudaStream_t& stream,
     paddle::Tensor* qkv_out,
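
A quick host-side sanity check of the indexing convention (illustrative; the helper below restates the kernels' inlined expression rather than calling project code): batch 0 must see the same offsets whether or not rope_3d is enabled, and batch b must be shifted by exactly b whole per-batch blocks.

#include <cassert>
#include <cstdint>

// Restates the kernels' inlined expression for testing purposes only.
static int64_t rope3d_index(bool rope_3d, int64_t emb_idx, int64_t batch_id,
                            int64_t batch_stride) {
  return rope_3d ? emb_idx + batch_id * batch_stride : emb_idx;
}

int main() {
  const int64_t max_seq_len = 8192, head_dim = 128;
  const int64_t stride = max_seq_len * head_dim;
  const int64_t emb_idx = 5 * (head_dim / 2) + 3;  // some in-batch offset
  // Batch 0: enabling rope_3d must not change the offset.
  assert(rope3d_index(true, emb_idx, 0, stride) ==
         rope3d_index(false, emb_idx, 0, stride));
  // Batch b: shifted by exactly b per-batch blocks.
  assert(rope3d_index(true, emb_idx, 3, stride) == emb_idx + 3 * stride);
  return 0;
}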