[Feature][MTP] Support MTP for rl-model (#4009)

* qk norm for speculate decode C16

* support mtp in v1_scheduler mode

* support mtp rope_3d

* support mtp features

* add unit tests && remove some logs

---------

Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com>
Co-authored-by: xiaoxiaohehe001 <hiteezsf@163.com>
Author: freeliuzc
Date: 2025-09-10 13:34:37 +08:00 (committed by GitHub)
Parent: cce2410fad
Commit: 2f473ba966
21 changed files with 1465 additions and 531 deletions


@@ -273,11 +273,15 @@ void AppendAttentionKernel(
          cache_v_zp,
          cache_quant_type_str,
          use_neox_rotary_style,
          rope_3d,
          max_input_length,
          exec_stream,
          &qkv_out,
          const_cast<paddle::Tensor*>(&key_cache),
          const_cast<paddle::Tensor*>(&value_cache),
          q_norm_weight,
          k_norm_weight,
          rms_norm_eps);
    } else {
      SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
          meta_data,
@@ -296,11 +300,15 @@ void AppendAttentionKernel(
          cache_v_zp,
          cache_quant_type_str,
          use_neox_rotary_style,
          rope_3d,
          max_input_length,
          exec_stream,
          &qkv_out,
          const_cast<paddle::Tensor*>(&key_cache),
          const_cast<paddle::Tensor*>(&value_cache),
          q_norm_weight,
          k_norm_weight,
          rms_norm_eps);
    }
  } else {
    if (qkv_out_scales) {


@@ -120,7 +120,6 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
      float row_variance =
          max(warp_m2 / head_size, 0.0f);
      float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
      if (hi < num_heads) {  // q
        Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
@@ -129,6 +128,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
        }
      } else {  // k
        Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
#pragma unroll
        for (int i = 0; i < VecSize; i++) {
          out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
        }


@@ -18,6 +18,168 @@
#include "mma_tensor_op.cuh" #include "mma_tensor_op.cuh"
#include "utils.cuh" #include "utils.cuh"
template <typename T, int VecSize = 1, typename InT = T>
__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
// head_size]
T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
// head_size // 2]
T* __restrict__ q_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
const float*
qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size]
const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head]
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int output_inner_dim,
const int head_size,
const int block_size,
const int elem_cnt,
const int gqa_group_size,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
using LoadInT = AlignedVector<InT, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
LoadInT src_vec;
LoadFloat scale_vec;
LoadT bias_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
LoadFloat tmp_vec;
LoadFloat q_norm_vec;
LoadFloat k_norm_vec;
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
int64_t all_warp_num = gridDim.x * blockDim.y;
int64_t all_head_dim = elem_cnt / head_size;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size;
const int half_head_size = head_size / 2;
for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) {
int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
const int token_id = linear_index / hidden_size;
const int ori_bi = batch_id_per_token[token_id];
if (seq_lens_decoder[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
const int write_seq_id =
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
}
const int block_offset = write_seq_id % block_size;
const int write_q_idx =
token_id * output_inner_dim * head_size + hi * head_size + h_bias;
const int bias_idx = hi * head_size + h_bias;
Load<InT, VecSize>(&qkv[linear_index], &src_vec);
if (qkv_biases) {
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
}
if (qkv_out_scales) {
Load<float, VecSize>(&qkv_out_scales[bias_idx], &scale_vec);
}
if (hi < num_heads + gqa_group_size) {
// q k rope
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
if (qkv_out_scales) {
input_left *= scale_vec[2 * i];
input_right *= scale_vec[2 * i + 1];
}
if (qkv_biases) {
input_left = input_left + static_cast<float>(bias_vec[2 * i]);
input_right = input_right + static_cast<float>(bias_vec[2 * i + 1]);
}
if (hi < num_heads + gqa_group_size) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
} else {
bias_vec[2 * i] = static_cast<T>(input_left);
bias_vec[2 * i + 1] = static_cast<T>(input_right);
}
}
if (hi < (num_heads + gqa_group_size)) {
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / head_size, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
if (hi < num_heads) {
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
}
} else {
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(bias_vec, &q_out[write_q_idx]);
} else {
// write k/v
const int kv_head_idx = (hi - num_heads) % gqa_group_size;
const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size +
kv_head_idx * block_size * head_size +
block_offset * head_size + h_bias);
// write
if (hi < num_heads + gqa_group_size) {
Store<T, VecSize>(bias_vec, &key_cache[tgt_idx]);
} else {
Store<T, VecSize>(bias_vec, &value_cache[tgt_idx]);
}
}
}
}
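For readers skimming the diff: the new append_speculate_cache_T_rope_qk_norm_kernel fuses rotary embedding with a per-head RMS norm on the q and k heads (value heads skip both and are written through). A minimal host-side sketch of that math, assuming one head vector and separate cos/sin tables; the helper name rope_then_rms_norm is made up for illustration:

#include <cmath>
#include <vector>

std::vector<float> rope_then_rms_norm(const std::vector<float>& head,     // [head_size]
                                      const std::vector<float>& cos_emb,  // [head_size / 2]
                                      const std::vector<float>& sin_emb,  // [head_size / 2]
                                      const std::vector<float>& norm_w,   // [head_size], q_norm_weight or k_norm_weight
                                      float rms_norm_eps) {
  const size_t head_size = head.size();
  std::vector<float> out(head_size);
  float sum_sq = 0.0f;
  for (size_t i = 0; i < head_size / 2; ++i) {
    const float left = head[2 * i], right = head[2 * i + 1];
    const float c = cos_emb[i], s = sin_emb[i];
    const float r0 = left * c - right * s;  // rotate each (even, odd) pair
    const float r1 = right * c + left * s;
    sum_sq += r0 * r0 + r1 * r1;            // the kernel accumulates this per thread and combines it across the warp
    out[2 * i] = r0;
    out[2 * i + 1] = r1;
  }
  // RMS norm over the rotated head: rescale by rsqrt(mean_square + eps) times the per-element norm weight
  const float inv_rms = 1.0f / std::sqrt(sum_sq / head_size + rms_norm_eps);
  for (size_t i = 0; i < head_size; ++i) out[i] *= inv_rms * norm_w[i];
  return out;
}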
template <int VecSize = 4, int HeadDim = 128> template <int VecSize = 4, int HeadDim = 128>
__global__ void append_clear_cache_int8_block( __global__ void append_clear_cache_int8_block(
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
@@ -193,7 +355,8 @@ __global__ void append_speculate_cache_rope_kernel(
    const int head_size,
    const int block_size,
    const int elem_cnt,
    const int gqa_group_size,
    const bool rope_3d) {
  using LoadT = AlignedVector<T, VecSize>;
  using LoadFloat = AlignedVector<float, VecSize>;
  using LoadInT = AlignedVector<InT, VecSize>;
@@ -253,8 +416,9 @@ __global__ void append_speculate_cache_rope_kernel(
    if (hi < num_heads + gqa_group_size) {
      // q k rope
      const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
      int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
    }
#pragma unroll
    for (int i = 0; i < HalfVecSize; i++) {
@@ -326,7 +490,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
    const int head_size,
    const int block_size,
    const int elem_cnt,
    const int gqa_group_size,
    const bool rope_3d) {
  using LoadT = AlignedVector<T, VecSize>;
  using LoadFloat = AlignedVector<float, VecSize>;
  using LoadInT = AlignedVector<InT, VecSize>;
@@ -390,8 +555,9 @@ __global__ void append_speculate_cache_neox_rope_kernel(
    if (hi < num_heads + gqa_group_size) {
      // q k rope
      const int64_t emb_idx = write_seq_id * head_size + h_bias;
      int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
      Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
      Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
    }
#pragma unroll
    for (int i = 0; i < VecSize; i++) {
@@ -476,7 +642,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
    const int block_size,
    const float max_bound,
    const float min_bound,
    const int gqa_group_size,
    const bool rope_3d) {
  static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
  static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
  constexpr int NUM_WARPS = 4;
@@ -522,8 +689,9 @@ __global__ void append_speculate_cache_int8_rope_kernel(
      // q rope
      const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
      if (qkv_out_scales) {
        Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
      }
@@ -583,10 +751,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
    T scale;
    if (head_idx < num_heads + gqa_group_size) {
      const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
      Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
      Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
      Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
      Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
      scale = __ldg(&cache_k_scales[kv_head_idx]);
    } else {
      scale = __ldg(&cache_v_scales[kv_head_idx]);
@@ -708,7 +877,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
    const int block_size,
    const float max_bound,
    const float min_bound,
    const int gqa_group_size,
    const bool rope_3d) {
  static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
  static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
  constexpr int NUM_WARPS = 4;
@@ -757,8 +927,9 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
      // q rope
      const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
      Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
      Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
      if (qkv_out_scales) {
        Load<float, VecSize>(&qkv_out_scales[bias_idx_left],
                             &left_out_scale_vec);
@@ -853,10 +1024,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
    T scale;
    const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
    Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
    Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
    Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
    Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
    scale = __ldg(&cache_k_scales[kv_head_idx]);
#pragma unroll
    for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
@@ -1088,7 +1260,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
    const int block_size,
    const float max_bound,
    const float min_bound,
    const int gqa_group_size,
    const bool rope_3d) {
  static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
  static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
  constexpr int NUM_WARPS = 4;
@@ -1145,8 +1318,9 @@ __global__ void append_speculate_cache_int4_rope_kernel(
      // Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
      // q rope
      const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
      for (int i = 0; i < HalfVecSize; i++) {
        // dequant + add_bias + rope
@@ -1235,10 +1409,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
      // &out_scale_vec2);
      if (head_idx < num_heads + gqa_group_size) {
        const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
        uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
        Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
        Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
        Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
        Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
        Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx], &scale_vec1);
        Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx + 8], &scale_vec2);
        Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
@@ -1431,7 +1606,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
    const int block_size,
    const float max_bound,
    const float min_bound,
    const int gqa_group_size,
    const bool rope_3d) {
  static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
  static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
  constexpr int NUM_WARPS = 4;
@@ -1581,10 +1757,11 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
                           &right_out_scale_vec2);
      const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
      Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
      Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
      Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
      Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
      Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx],
                               &left_scale_vec1);
      Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx + 8],
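The rope_3d flag threaded through these hunks appears to select a per-request rotary table: on top of the usual position-within-table index, the batch index adds a stride of max_seq_len * head_dim (twice that in the neox layout, whose cos/sin rows are head_dim wide). A small sketch of that offset computation under this layout assumption; the helper name is illustrative only:

// Layout assumption for the non-neox case: rotary_embs holds one cos/sin
// table per request, roughly [bsz, 2, max_seq_len, head_dim / 2], so each
// batch element contributes max_seq_len * head_dim extra floats of stride.
inline int64_t rope_cos_sin_index(bool rope_3d,
                                  int batch_id,
                                  int seq_pos,
                                  int pair_idx,      // h_bias / 2 in the kernels
                                  int max_seq_len,
                                  int head_dim) {
  const int64_t emb_idx = static_cast<int64_t>(seq_pos) * (head_dim / 2) + pair_idx;
  return rope_3d ? emb_idx + static_cast<int64_t>(batch_id) * max_seq_len * head_dim
                 : emb_idx;
}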


@@ -15,6 +15,78 @@
#include "speculate_write_cache_with_rope_kernel.h" #include "speculate_write_cache_with_rope_kernel.h"
#include "utils.cuh" #include "utils.cuh"
template <typename T, typename QKV_TYPE>
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
const float* qkv_out_scales,
const T* qkv_biases,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps,
const bool rope_3d) {
int output_inner_dim = num_heads + 2 * kv_num_heads;
const uint32_t elem_nums =
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
: token_num * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
if (use_neox_style) {
PD_THROW(
"append_speculate_cache_rope_qk_norm not support neox rope yet");
} else {
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
<<<grid_size, block_dim, 0, stream>>>(qkv,
key_cache,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
max_seq_len,
max_blocks_per_seq,
num_heads,
output_inner_dim,
dim_head,
block_size,
elem_nums,
kv_num_heads,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
rope_3d);
}
}
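For reference, the launch configuration of this new qk-norm path works out so that one warp owns exactly one 128-wide head row, which is what lets the kernel compute the RMS statistic with a single warp reduction. A small sketch of the arithmetic, assuming head_dim == 128 and 32-thread warps as the kernel itself requires:

// Sizing sketch for the launcher above (assumes head_dim == 128, warp == 32).
constexpr int kHeadDim = 128;
constexpr int kWarpSizeSketch = 32;
constexpr int kPackSize = kHeadDim / kWarpSizeSketch;  // 4 floats per thread
constexpr int kBlockThreads = 128;                     // dim3(32, 4): 4 head rows per block
// One warp covers one head row: 32 threads * 4 elements == 128 == head_dim.
inline int head_rows(int token_num, int num_heads, int kv_num_heads) {
  // number of rows the kernel's grid-stride loop walks (elem_cnt / head_size)
  return token_num * (num_heads + 2 * kv_num_heads);
}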
// rope + write // rope + write
template <typename T, typename QKV_TYPE> template <typename T, typename QKV_TYPE>
void append_speculate_cache_rope(const QKV_TYPE* qkv, void append_speculate_cache_rope(const QKV_TYPE* qkv,
@@ -39,7 +111,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
const int bsz, const int bsz,
const int token_num, const int token_num,
const cudaStream_t& stream, const cudaStream_t& stream,
const bool use_neox_style) { const bool use_neox_style,
const bool rope_3d) {
int output_inner_dim = num_heads + 2 * kv_num_heads; int output_inner_dim = num_heads + 2 * kv_num_heads;
const uint32_t elem_nums = const uint32_t elem_nums =
@@ -73,7 +146,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
dim_head, dim_head,
block_size, block_size,
elem_nums, elem_nums,
kv_num_heads); kv_num_heads,
rope_3d);
} else { } else {
append_speculate_cache_rope_kernel<T, PackSize> append_speculate_cache_rope_kernel<T, PackSize>
<<<grid_size, threads_per_block, 0, stream>>>( <<<grid_size, threads_per_block, 0, stream>>>(
@@ -96,7 +170,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
dim_head, dim_head,
block_size, block_size,
elem_nums, elem_nums,
kv_num_heads); kv_num_heads,
rope_3d);
} }
} }
@@ -125,7 +200,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
const int bsz, const int bsz,
const int token_num, const int token_num,
const cudaStream_t& stream, const cudaStream_t& stream,
const bool use_neox_style) { const bool use_neox_style,
const bool rope_3d) {
constexpr int num_warps = 4; constexpr int num_warps = 4;
const int all_warps = const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -167,7 +243,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
block_size, block_size,
127.0f, 127.0f,
-127.0f, -127.0f,
kv_num_heads); kv_num_heads,
rope_3d);
} else { } else {
append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8> append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
<<<grids, num_warps * 32, 0, stream>>>(qkv, <<<grids, num_warps * 32, 0, stream>>>(qkv,
@@ -191,7 +268,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
block_size, block_size,
127.0f, 127.0f,
-127.0f, -127.0f,
kv_num_heads); kv_num_heads,
rope_3d);
} }
} }
@@ -222,7 +300,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
const int bsz, const int bsz,
const int token_num, const int token_num,
const cudaStream_t& stream, const cudaStream_t& stream,
const bool use_neox_style) { const bool use_neox_style,
const bool rope_3d) {
constexpr int num_warps = 4; constexpr int num_warps = 4;
const int all_warps = const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -266,7 +345,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
block_size, block_size,
7.0f, 7.0f,
-8.0f, -8.0f,
kv_num_heads); kv_num_heads,
rope_3d);
} else { } else {
append_speculate_cache_int4_rope_kernel<T, 4> append_speculate_cache_int4_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(qkv, <<<grids, num_warps * 32, 0, stream>>>(qkv,
@@ -292,7 +372,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
block_size, block_size,
7.0f, 7.0f,
-8.0f, -8.0f,
kv_num_heads); kv_num_heads,
rope_3d);
} }
} }
template <typename T, typename QKV_TYPE> template <typename T, typename QKV_TYPE>
@@ -313,11 +394,15 @@ void SpeculateWriteCacheWithRoPEKernel(
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps) {
  typedef cascade_attn_type_traits<T> traits_;
  typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
  typedef typename traits_::type DataType_;
@@ -342,142 +427,185 @@ void SpeculateWriteCacheWithRoPEKernel(
? rotary_embs.get().data<float>() + max_seq_len * dim_head ? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2; : rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
} }
if (cache_quant_type_str == "none") { if (q_norm_weight && k_norm_weight) {
append_speculate_cache_rope( if (cache_quant_type_str == "none") {
reinterpret_cast<const QKV_TYPE*>(qkv_ptr), append_speculate_cache_rope_qk_norm(
reinterpret_cast<DataType_*>(key_cache_out->data<T>()), reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()), reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()), reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
block_tables.data<int>(), reinterpret_cast<DataType_*>(qkv_out->data<T>()),
batch_id_per_token.data<int>(), block_tables.data<int>(),
cu_seqlens_q.data<int>(), batch_id_per_token.data<int>(),
seq_lens.data<int>(), cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(), seq_lens.data<int>(),
cos_emb, seq_lens_encoder.data<int>(),
sin_emb, cos_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr, sin_emb,
qkv_biases ? reinterpret_cast<DataType_*>( qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
const_cast<T*>(qkv_biases.get().data<T>())) qkv_biases ? reinterpret_cast<DataType_*>(
: nullptr, const_cast<T*>(qkv_biases.get().data<T>()))
max_seq_len, : nullptr,
max_blocks_per_seq, max_seq_len,
num_heads, max_blocks_per_seq,
kv_num_heads, num_heads,
dim_head, kv_num_heads,
block_size, dim_head,
bsz, block_size,
token_nums, bsz,
stream, token_nums,
use_neox_rotary_style); stream,
} else if (cache_quant_type_str == "cache_int8") { use_neox_rotary_style,
append_speculate_cache_int8_rope( reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
reinterpret_cast<const QKV_TYPE*>(qkv_ptr), reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
key_cache_out->data<uint8_t>(), rms_norm_eps,
value_cache_out->data<uint8_t>(), rope_3d);
reinterpret_cast<DataType_*>(qkv_out->data<T>()), } else {
block_tables.data<int>(), PD_THROW(
batch_id_per_token.data<int>(), "append_decode_cache_rope_qk_norm not support cachekv quant yet");
cu_seqlens_q.data<int>(), }
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_speculate_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style);
} else { } else {
PD_THROW( if (cache_quant_type_str == "none") {
"cache_quant_type_str should be one of [none, cache_int8, " append_speculate_cache_rope(
"cache_int4_zp]"); reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int8") {
append_speculate_cache_int8_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_speculate_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
}
} }
} }
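The restructured dispatch above boils down to: if q_norm_weight and k_norm_weight are both provided, route to the new qk-norm kernel, which is only implemented for the unquantized cache; otherwise keep the existing rope kernels per cache_quant_type_str. A stripped-down sketch of that control flow (the function name is illustrative and the real kernel calls are elided):

#include <string>

void dispatch_speculate_write_cache(bool has_qk_norm_weights,
                                    const std::string& cache_quant_type) {
  if (has_qk_norm_weights) {
    if (cache_quant_type == "none") {
      // append_speculate_cache_rope_qk_norm(...);
    } else {
      // PD_THROW: qk norm with a quantized kv cache is not supported yet
    }
  } else {
    if (cache_quant_type == "none") {
      // append_speculate_cache_rope(...);
    } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
      // append_speculate_cache_int8_rope(...);  // cache_fp8 uses the IsFP8 = true instantiation
    } else if (cache_quant_type == "cache_int4_zp") {
      // append_speculate_cache_int4_rope(...);
    } else {
      // PD_THROW: cache_quant_type_str should be one of [none, cache_int8, cache_int4_zp]
    }
  }
}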
@@ -500,11 +628,15 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
const paddle::optional<paddle::Tensor>& cache_v_zp, const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str, const std::string& cache_quant_type_str,
const bool use_neox_rotary_style, const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len, const int max_seq_len,
cudaStream_t& stream, cudaStream_t& stream,
paddle::Tensor* qkv_out, paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out, paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out); paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
template void template void
SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>( SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
@@ -526,11 +658,15 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::optional<paddle::Tensor>& cache_v_zp, const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str, const std::string& cache_quant_type_str,
const bool use_neox_rotary_style, const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len, const int max_seq_len,
cudaStream_t& stream, cudaStream_t& stream,
paddle::Tensor* qkv_out, paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out, paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out); paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>( template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const AppendAttnMetaData& meta_data, const AppendAttnMetaData& meta_data,
@@ -551,11 +687,15 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const paddle::optional<paddle::Tensor>& cache_v_zp, const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str, const std::string& cache_quant_type_str,
const bool use_neox_rotary_style, const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len, const int max_seq_len,
cudaStream_t& stream, cudaStream_t& stream,
paddle::Tensor* qkv_out, paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out, paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out); paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
template void template void
@@ -578,8 +718,12 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const paddle::optional<paddle::Tensor>& cache_v_zp, const paddle::optional<paddle::Tensor>& cache_v_zp,
const std::string& cache_quant_type_str, const std::string& cache_quant_type_str,
const bool use_neox_rotary_style, const bool use_neox_rotary_style,
const bool rope_3d,
const int max_seq_len, const int max_seq_len,
cudaStream_t& stream, cudaStream_t& stream,
paddle::Tensor* qkv_out, paddle::Tensor* qkv_out,
paddle::Tensor* key_cache_out, paddle::Tensor* key_cache_out,
paddle::Tensor* value_cache_out); paddle::Tensor* value_cache_out,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);


@@ -35,8 +35,12 @@ void SpeculateWriteCacheWithRoPEKernel(
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);


@@ -378,9 +378,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
                       const paddle::Tensor &step_seq_lens_decoder,
                       const paddle::Tensor &block_tables,
                       const paddle::Tensor &is_block_step,
                       const paddle::optional<paddle::Tensor> &draft_tokens,
                       const paddle::optional<paddle::Tensor> &step_draft_tokens,
                       const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
                       const int block_size,
                       const int max_draft_tokens);

paddle::Tensor
GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,

@@ -707,6 +709,22 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
                              const paddle::Tensor& seq_lens_decoder);

void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &step_draft_tokens,
const paddle::Tensor &step_seq_lens_this_time,
const paddle::Tensor &accept_num,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &stop_nums,
const int block_size,
const int max_draft_tokens);
void NgramMatch(const paddle::Tensor &input_ids,
                const paddle::Tensor &input_ids_len,
                const paddle::Tensor &pre_ids,

@@ -750,6 +768,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                          const paddle::Tensor& seq_lens_decoder,
                          const paddle::Tensor& step_idx,
                          const paddle::Tensor& not_need_stop,
                          const paddle::Tensor& is_block_step,
                          const paddle::Tensor& batch_drop,
                          const paddle::Tensor& pre_ids,
                          const paddle::Tensor& accept_tokens,
@@ -763,7 +782,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
                          const paddle::Tensor& base_model_draft_tokens,
                          const int max_draft_token,
                          const bool truncate_first_token,
                          const bool splitwise_prefill,
                          const bool kvcache_scheduler_v1);

void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,

@@ -1228,6 +1248,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
  m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function");

  m.def("speculate_schedule_cache",&SpeculateScheduleCache, "SpeculateScheduleCache function");

  m.def("ngram_match", &NgramMatch, "ngram_match function");

  m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");


@@ -15,31 +15,72 @@
#include "helper.h"

__global__ void recover_decode_task(bool *stop_flags,
                                    int *seq_lens_this_time,
                                    int *seq_lens_encoder,
                                    int *seq_lens_decoder,
                                    int *step_seq_lens_decoder,
                                    int *block_tables,
                                    bool *is_block_step,
                                    const int bsz,
                                    const int block_num_per_seq,
                                    const int block_size) {
  int thread_idx = threadIdx.x;
  if (thread_idx < bsz) {
    if(is_block_step[thread_idx] == true) {
      int *block_table_now = block_tables + thread_idx * block_num_per_seq;
      if (block_table_now[step_seq_lens_decoder[thread_idx] / block_size] != -1) {
        // can be recovered for decoding
        is_block_step[thread_idx] = false;
        seq_lens_this_time[thread_idx]= 1;
        stop_flags[thread_idx] = false;
        seq_lens_encoder[thread_idx] = 0;
        seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
      }
    }
  }
}
__global__ void recover_spec_decode_task(bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
int64_t *draft_tokens,
const int64_t *step_draft_tokens,
const int *step_seq_lens_this_time,
const int bsz,
const int block_num_per_seq,
const int block_size,
const int draft_tokens_len,
const int num_extra_tokens) {
int thread_idx = threadIdx.x;
if (thread_idx < bsz) {
if(is_block_step[thread_idx] == true) {
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
int max_possible_block_idx = (step_seq_lens_decoder[thread_idx] + num_extra_tokens) / block_size;
max_possible_block_idx = min(max_possible_block_idx, block_num_per_seq);
if (block_table_now[max_possible_block_idx] != -1) {
// can be recovered for decoding
int64_t *draft_tokens_now = draft_tokens + thread_idx * draft_tokens_len;
const int64_t *step_draft_tokens_now = step_draft_tokens + thread_idx * draft_tokens_len;
is_block_step[thread_idx] = false;
seq_lens_this_time[thread_idx] = step_seq_lens_this_time[thread_idx];
stop_flags[thread_idx] = false;
seq_lens_encoder[thread_idx] = 0;
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
for (int i = 0; i < seq_lens_this_time[thread_idx]; i++) {
draft_tokens_now[i] = step_draft_tokens_now[i];
}
}
}
}
}
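recover_spec_decode_task relaxes the recovery condition for speculative requests: instead of checking only the block that holds the current decoder position, it checks the block that would be needed after one more draft burst (max_draft_tokens * 2 + 1 extra tokens, matching the launch site below). A host-side sketch of that test under those assumptions; here the index is clamped to the last valid slot:

#include <algorithm>

bool can_recover_spec_decode(const int* block_table_row,  // one row of block_tables, [block_num_per_seq]
                             int step_seq_len_decoder,
                             int max_draft_tokens,
                             int block_size,
                             int block_num_per_seq) {
  const int num_extra_tokens = max_draft_tokens * 2 + 1;
  int max_possible_block_idx =
      (step_seq_len_decoder + num_extra_tokens) / block_size;
  // defensive clamp to the last valid slot in this sketch
  max_possible_block_idx = std::min(max_possible_block_idx, block_num_per_seq - 1);
  return block_table_row[max_possible_block_idx] != -1;  // block already mapped, safe to resume
}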
void RecoverDecodeTask(const paddle::Tensor &stop_flags,
                       const paddle::Tensor &seq_lens_this_time,
                       const paddle::Tensor &seq_lens_encoder,
@@ -47,7 +88,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
                       const paddle::Tensor &step_seq_lens_decoder,
                       const paddle::Tensor &block_tables,
                       const paddle::Tensor &is_block_step,
                       const paddle::optional<paddle::Tensor> &draft_tokens,
                       const paddle::optional<paddle::Tensor> &step_draft_tokens,
                       const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
                       const int block_size,
                       const int max_draft_tokens) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place()));
  auto cu_stream = dev_ctx->stream();
@@ -56,17 +101,38 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
#endif
  const int bsz = seq_lens_this_time.shape()[0];
  const int block_num_per_seq = block_tables.shape()[1];
  if (draft_tokens) {
    const int draft_tokens_len = draft_tokens.get_ptr()->shape()[1];
    recover_spec_decode_task<<<1, 1024, 0, cu_stream>>>(
        const_cast<bool *>(stop_flags.data<bool>()),
        const_cast<int *>(seq_lens_this_time.data<int>()),
        const_cast<int *>(seq_lens_encoder.data<int>()),
        const_cast<int *>(seq_lens_decoder.data<int>()),
        const_cast<int *>(step_seq_lens_decoder.data<int>()),
        const_cast<int *>(block_tables.data<int>()),
        const_cast<bool *>(is_block_step.data<bool>()),
        const_cast<int64_t *>(draft_tokens.get_ptr()->data<int64_t>()),
        step_draft_tokens.get_ptr()->data<int64_t>(),
        step_seq_lens_this_time.get_ptr()->data<int>(),
        bsz,
        block_num_per_seq,
        block_size,
        draft_tokens_len,
        max_draft_tokens * 2 + 1);
  } else {
    recover_decode_task<<<1, 1024, 0, cu_stream>>>(
        const_cast<bool *>(stop_flags.data<bool>()),
        const_cast<int *>(seq_lens_this_time.data<int>()),
        const_cast<int *>(seq_lens_encoder.data<int>()),
        const_cast<int *>(seq_lens_decoder.data<int>()),
        const_cast<int *>(step_seq_lens_decoder.data<int>()),
        const_cast<int *>(block_tables.data<int>()),
        const_cast<bool *>(is_block_step.data<bool>()),
        bsz,
        block_num_per_seq,
        block_size);
  }
}
PD_BUILD_STATIC_OP(recover_decode_task)
@@ -76,8 +142,11 @@ PD_BUILD_STATIC_OP(recover_decode_task)
            "seq_lens_decoder",
            "step_seq_lens_decoder",
            "block_tables",
            "is_block_step",
            paddle::Optional("draft_tokens"),
            paddle::Optional("step_draft_tokens"),
            paddle::Optional("step_seq_lens_this_time")})
    .Attrs({"block_size: int", "max_draft_tokens: int"})
    .Outputs({"seq_lens_this_time_out",
              "seq_lens_encoder_out",
              "seq_lens_decoder_out",


@@ -15,7 +15,48 @@
#include "helper.h" #include "helper.h"
#include "paddle/extension.h" #include "paddle/extension.h"
template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>
#define DISPATCH_BLOCKSIZE(BLOCK_SIZE, ...) \
do { \
constexpr int BlockSize = BLOCK_SIZE; \
__VA_ARGS__; \
} while (0)
#define DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, ...) \
do { \
if (truncate_first_token) { \
constexpr bool TRUNCATE_FIRST_TOKEN = true; \
__VA_ARGS__; \
} else { \
constexpr bool TRUNCATE_FIRST_TOKEN = false; \
__VA_ARGS__; \
} \
} while (0)
#define DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, ...) \
do { \
if (kvcache_scheduler_v1) { \
constexpr bool KVCACHE_SCHEDULER_V1 = true; \
__VA_ARGS__; \
} else { \
constexpr bool KVCACHE_SCHEDULER_V1 = false; \
__VA_ARGS__; \
} \
} while (0)
#define DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, ...) \
do { \
if (splitwise_prefill) { \
constexpr bool SPLITWISE_PREFILL = true; \
__VA_ARGS__; \
} else { \
constexpr bool SPLITWISE_PREFILL = false; \
__VA_ARGS__; \
} \
} while (0)
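These DISPATCH_* macros rebind runtime flags as constexpr values so the kernel templates can branch with if constexpr instead of runtime checks. A minimal standalone illustration of how the nesting expands; fake_kernel_launch is a made-up stand-in for the real kernels:

#include <cstdio>

template <int BlockSize, bool TruncateFirst, bool SchedulerV1>
void fake_kernel_launch() {
  // the chosen instantiation sees all three values at compile time
  if constexpr (SchedulerV1) {
    std::printf("BlockSize=%d truncate=%d, v1 scheduler path\n", BlockSize, (int)TruncateFirst);
  } else {
    std::printf("BlockSize=%d truncate=%d, v0 scheduler path\n", BlockSize, (int)TruncateFirst);
  }
}

void dispatch_example(bool truncate_first_token, bool kvcache_scheduler_v1) {
  DISPATCH_BLOCKSIZE(512, {
    DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, {
      DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, {
        fake_kernel_launch<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>();
      });
    });
  });
}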
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
__global__ void process_splitwise_prefill( __global__ void process_splitwise_prefill(
int64_t* draft_tokens, int64_t* draft_tokens,
int64_t* input_ids, int64_t* input_ids,
@@ -25,6 +66,7 @@ __global__ void process_splitwise_prefill(
int* seq_lens_decoder, int* seq_lens_decoder,
int64_t* step_idx, int64_t* step_idx,
bool* not_need_stop, bool* not_need_stop,
bool* is_block_step,
bool* batch_drop, bool* batch_drop,
int64_t* pre_ids, int64_t* pre_ids,
const int64_t* accept_tokens, const int64_t* accept_tokens,
@@ -58,7 +100,7 @@ __global__ void process_splitwise_prefill(
stop_flags[tid] = false; stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0]; int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder; int position = seq_len_encoder;
if (TRCUNCATE_FIRST_TOKEN) { if (TRUNCATE_FIRST_TOKEN) {
input_ids_now[position - 1] = base_model_first_token; input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder; seq_lens_this_time[tid] = seq_len_encoder;
} else { } else {
@@ -84,7 +126,7 @@ __global__ void process_splitwise_prefill(
template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN> template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
__global__ void draft_model_preprocess_kernel( __global__ void draft_model_preprocess_kernel(
int64_t* draft_tokens, int64_t* draft_tokens,
int64_t* input_ids, int64_t* input_ids,
@@ -94,6 +136,7 @@ __global__ void draft_model_preprocess_kernel(
int* seq_lens_decoder, int* seq_lens_decoder,
int64_t* step_idx, int64_t* step_idx,
bool* not_need_stop, bool* not_need_stop,
bool* is_block_step,
bool* batch_drop, bool* batch_drop,
int64_t* pre_ids, int64_t* pre_ids,
const int64_t* accept_tokens, const int64_t* accept_tokens,
@@ -134,14 +177,26 @@ __global__ void draft_model_preprocess_kernel(
      base_model_draft_tokens_now[i] = -1;
    }

    // 1. process block_step situation
    // -- In v0 mode, block_step will drop mtp query.
    // -- In v1 mode, block_step will continue to infer.
    if constexpr(KVCACHE_SCHEDULER_V1) {
      if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
        stop_flags[tid] = true;
        is_block_step[tid] = true;
        // Need to continue infer
      }
    } else {
      if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
        batch_drop[tid] = true;
        stop_flags[tid] = true;
      }
    }

    // 2. process normal query, not in any special case.
    if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
      not_stop_flag = 1;
      // prefill generation
      if (seq_lens_encoder[tid] > 0) {
        // Can be extended to first few tokens
        int seq_len_encoder = seq_lens_encoder[tid];
@@ -149,14 +204,20 @@ __global__ void draft_model_preprocess_kernel(
        int64_t base_model_first_token = accept_tokens_now[0];
        pre_ids_now[0] = base_model_first_token;
        int position = seq_len_encoder;
        if (TRUNCATE_FIRST_TOKEN) {
          input_ids_now[position - 1] = base_model_first_token;
          seq_lens_this_time[tid] = seq_len_encoder;
        } else {
          input_ids_now[position] = base_model_first_token;
          seq_lens_this_time[tid] = seq_len_encoder + 1;
        }
      } else {  // decode generation
        if constexpr (KVCACHE_SCHEDULER_V1) {
          // 3. try to recover mtp infer in V1 mode
          if (!base_model_is_block_step[tid] && is_block_step[tid]) {
            is_block_step[tid] = false;
          }
        }
        if (stop_flags[tid]) {
          stop_flags[tid] = false;
          // TODO: check
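The hunk above encodes the scheduler-version split for blocked requests: v0 drops the MTP query outright, while v1 parks it behind is_block_step so it can be revived later by recover_decode_task. A condensed sketch of that decision, with flag names mirroring the kernel:

struct MtpSlotFlags {
  bool stop_flag = false;
  bool batch_drop = false;
  bool is_block_step = false;
};

void on_base_model_block_step(MtpSlotFlags& s, bool kvcache_scheduler_v1) {
  if (kvcache_scheduler_v1) {
    s.stop_flag = true;
    s.is_block_step = true;   // keep the slot alive; recovery can resume draft decoding
  } else {
    s.stop_flag = true;
    s.batch_drop = true;      // v0: drop the MTP query for this slot
  }
}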
@@ -189,99 +250,8 @@ __global__ void draft_model_preprocess_kernel(
} }
} }
template <bool TRCUNCATE_FIRST_TOKEN>
void DispatchRunner(
const cudaStream_t& stream,
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
bool* not_need_stop,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool splitwise_prefill) {
constexpr int BlockSize = 512;
if (splitwise_prefill) {
process_splitwise_prefill<BlockSize, TRCUNCATE_FIRST_TOKEN>
<<<1, BlockSize, 0, stream>>>(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len);
} else {
draft_model_preprocess_kernel<BlockSize, TRCUNCATE_FIRST_TOKEN>
<<<1, BlockSize, 0, stream>>>(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len);
}
}
void DispatchTokenMode( void DispatchRunner(
const cudaStream_t &stream, const cudaStream_t &stream,
int64_t* draft_tokens, int64_t* draft_tokens,
int64_t* input_ids, int64_t* input_ids,
@@ -291,6 +261,7 @@ void DispatchTokenMode(
int* seq_lens_decoder, int* seq_lens_decoder,
int64_t* step_idx, int64_t* step_idx,
bool* not_need_stop, bool* not_need_stop,
bool* is_block_step,
bool* batch_drop, bool* batch_drop,
int64_t* pre_ids, int64_t* pre_ids,
const int64_t* accept_tokens, const int64_t* accept_tokens,
@@ -310,75 +281,79 @@ void DispatchTokenMode(
const int base_model_draft_tokens_len, const int base_model_draft_tokens_len,
const int pre_ids_len, const int pre_ids_len,
const bool truncate_first_token, const bool truncate_first_token,
const bool splitwise_prefill) { const bool splitwise_prefill,
if (truncate_first_token) { const bool kvcache_scheduler_v1) {
DispatchRunner<true>( DISPATCH_BLOCKSIZE(512, {
stream, DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, {
draft_tokens, DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, {
input_ids, DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, {
stop_flags, if constexpr (SPLITWISE_PREFILL) {
seq_lens_this_time, process_splitwise_prefill<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
seq_lens_encoder, <<<1, BlockSize, 0, stream>>>(
seq_lens_decoder, draft_tokens,
step_idx, input_ids,
not_need_stop, stop_flags,
batch_drop, seq_lens_this_time,
pre_ids, seq_lens_encoder,
accept_tokens, seq_lens_decoder,
accept_num, step_idx,
base_model_seq_lens_this_time, not_need_stop,
base_model_seq_lens_encoder, is_block_step,
base_model_seq_lens_decoder, batch_drop,
base_model_step_idx, pre_ids,
base_model_stop_flags, accept_tokens,
base_model_is_block_step, accept_num,
base_model_draft_tokens, base_model_seq_lens_this_time,
bsz, base_model_seq_lens_encoder,
num_model_step, base_model_seq_lens_decoder,
accept_tokens_len, base_model_step_idx,
draft_tokens_len, base_model_stop_flags,
input_ids_len, base_model_is_block_step,
base_model_draft_tokens_len, base_model_draft_tokens,
pre_ids_len, bsz,
splitwise_prefill num_model_step,
); accept_tokens_len,
} else { draft_tokens_len,
DispatchRunner<false>( input_ids_len,
stream, base_model_draft_tokens_len,
draft_tokens, pre_ids_len);
input_ids, } else {
stop_flags, draft_model_preprocess_kernel<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
seq_lens_this_time, <<<1, BlockSize, 0, stream>>>(
seq_lens_encoder, draft_tokens,
seq_lens_decoder, input_ids,
step_idx, stop_flags,
not_need_stop, seq_lens_this_time,
batch_drop, seq_lens_encoder,
pre_ids, seq_lens_decoder,
accept_tokens, step_idx,
accept_num, not_need_stop,
base_model_seq_lens_this_time, is_block_step,
base_model_seq_lens_encoder, batch_drop,
base_model_seq_lens_decoder, pre_ids,
base_model_step_idx, accept_tokens,
base_model_stop_flags, accept_num,
base_model_is_block_step, base_model_seq_lens_this_time,
base_model_draft_tokens, base_model_seq_lens_encoder,
bsz, base_model_seq_lens_decoder,
num_model_step, base_model_step_idx,
accept_tokens_len, base_model_stop_flags,
draft_tokens_len, base_model_is_block_step,
input_ids_len, base_model_draft_tokens,
base_model_draft_tokens_len, bsz,
pre_ids_len, num_model_step,
splitwise_prefill accept_tokens_len,
); draft_tokens_len,
} input_ids_len,
base_model_draft_tokens_len,
pre_ids_len);
}
});
});
});
});
} }
void DraftModelPreprocess(const paddle::Tensor& draft_tokens, void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& input_ids, const paddle::Tensor& input_ids,
const paddle::Tensor& stop_flags, const paddle::Tensor& stop_flags,
@@ -387,6 +362,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx, const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop, const paddle::Tensor& not_need_stop,
const paddle::Tensor& is_block_step,
const paddle::Tensor& batch_drop, const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids, const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens, const paddle::Tensor& accept_tokens,
@@ -400,7 +376,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& base_model_draft_tokens, const paddle::Tensor& base_model_draft_tokens,
const int num_model_step, const int num_model_step,
const bool truncate_first_token, const bool truncate_first_token,
const bool splitwise_prefill) { const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
int real_bsz = seq_lens_this_time.shape()[0]; int real_bsz = seq_lens_this_time.shape()[0];
int accept_tokens_len = accept_tokens.shape()[1]; int accept_tokens_len = accept_tokens.shape()[1];
int input_ids_len = input_ids.shape()[1]; int input_ids_len = input_ids.shape()[1];
@@ -412,36 +389,38 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
auto not_need_stop_gpu = auto not_need_stop_gpu =
not_need_stop.copy_to(seq_lens_this_time.place(), false); not_need_stop.copy_to(seq_lens_this_time.place(), false);
DispatchTokenMode( DispatchRunner(
cu_stream, cu_stream,
const_cast<int64_t*>(draft_tokens.data<int64_t>()), const_cast<int64_t*>(draft_tokens.data<int64_t>()),
const_cast<int64_t*>(input_ids.data<int64_t>()), const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()), const_cast<bool*>(stop_flags.data<bool>()),
const_cast<int*>(seq_lens_this_time.data<int>()), const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()), const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()), const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()), const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<bool*>(not_need_stop_gpu.data<bool>()), const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<bool*>(batch_drop.data<bool>()), const_cast<bool*>(is_block_step.data<bool>()),
const_cast<int64_t*>(pre_ids.data<int64_t>()), const_cast<bool*>(batch_drop.data<bool>()),
accept_tokens.data<int64_t>(), const_cast<int64_t*>(pre_ids.data<int64_t>()),
accept_num.data<int>(), accept_tokens.data<int64_t>(),
base_model_seq_lens_this_time.data<int>(), accept_num.data<int>(),
base_model_seq_lens_encoder.data<int>(), base_model_seq_lens_this_time.data<int>(),
base_model_seq_lens_decoder.data<int>(), base_model_seq_lens_encoder.data<int>(),
base_model_step_idx.data<int64_t>(), base_model_seq_lens_decoder.data<int>(),
base_model_stop_flags.data<bool>(), base_model_step_idx.data<int64_t>(),
base_model_is_block_step.data<bool>(), base_model_stop_flags.data<bool>(),
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()), base_model_is_block_step.data<bool>(),
real_bsz, const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
num_model_step, real_bsz,
accept_tokens_len, num_model_step,
draft_tokens_len, accept_tokens_len,
input_ids_len, draft_tokens_len,
base_model_draft_tokens_len, input_ids_len,
pre_ids_len, base_model_draft_tokens_len,
truncate_first_token, pre_ids_len,
splitwise_prefill); truncate_first_token,
splitwise_prefill,
kvcache_scheduler_v1);
auto not_need_stop_cpu = auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false); not_need_stop_gpu.copy_to(not_need_stop.place(), false);
@@ -459,6 +438,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
"seq_lens_decoder", "seq_lens_decoder",
"step_idx", "step_idx",
"not_need_stop", "not_need_stop",
"is_block_step",
"batch_drop", "batch_drop",
"pre_ids", "pre_ids",
"accept_tokens", "accept_tokens",
@@ -480,7 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
"not_need_stop_out", "not_need_stop_out",
"batch_drop_out", "batch_drop_out",
"pre_ids_out"}) "pre_ids_out"})
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"}) .Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool", "kvcache_scheduler_v1: bool"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"input_ids", "input_ids_out"}, {"input_ids", "input_ids_out"},
{"stop_flags", "stop_flags_out"}, {"stop_flags", "stop_flags_out"},

View File

@@ -0,0 +1,176 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
template <int THREADBLOCK_SIZE>
__global__ void speculate_schedula_cache(
const int64_t *draft_tokens,
int *block_tables,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *step_draft_tokens,
int *step_seq_lens_this_time,
int *accept_num,
int64_t *accept_tokens,
bool *is_block_step,
bool *not_need_stop,
const int64_t *stop_nums,
const int real_bsz,
const int max_bsz,
const int max_next_step_tokens,
const int draft_tokens_len,
const int accept_tokens_len,
const int block_size,
const int block_num_per_seq) {
const int bid = threadIdx.x;
int stop_flag_now_int = 0;
if (bid < real_bsz) {
if (!stop_flags[bid]) {
const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
int *block_table_now = block_tables + bid * block_num_per_seq;
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
is_block_step[bid] = true;
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
seq_lens_this_time[bid] = 0;
stop_flags[bid] = true;
stop_flag_now_int = 1;
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
seq_lens_decoder[bid] = 0;
accept_num[bid] = 0;
for (int i = 0; i < accept_tokens_len; i++) {
accept_tokens_now[i] = -1;
}
for (int i = 0; i < draft_tokens_len; i++) {
step_draft_tokens_now[i] = draft_tokens_now[i];
}
}
} else {
stop_flag_now_int = 1;
}
} else if (bid >= real_bsz && bid < max_bsz) {
stop_flag_now_int = 1;
}
__syncthreads();
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
if (threadIdx.x == 0) {
// printf("stop_sum %d \n", stop_sum);
not_need_stop[0] = stop_sum < stop_nums[0];
}
}
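A quick worked example of the parking condition above, with values assumed purely for illustration (written in Python for brevity):

    # Assumed: block_size = 4, max_draft_tokens = 2 -> max_next_step_tokens = 2*2 + 2 = 6
    block_size, max_next_step_tokens = 4, 6
    seq_lens_decoder = 10
    max_possible_block_idx = (seq_lens_decoder + max_next_step_tokens) // block_size  # == 4
    # If index 4 is within block_num_per_seq and block_table[4] == -1, the request is
    # parked: is_block_step is set, its lengths and draft tokens are stashed in the
    # step_* buffers, and it is stopped until new blocks are allocated.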
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &step_draft_tokens,
const paddle::Tensor &step_seq_lens_this_time,
const paddle::Tensor &accept_num,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &stop_nums,
const int block_size,
const int max_draft_tokens) {
const int real_bsz = seq_lens_this_time.shape()[0];
const int max_bsz = stop_flags.shape()[0];
const int accept_tokens_len = accept_tokens.shape()[1];
const int draft_token_len = draft_tokens.shape()[1];
const int block_num_per_seq = block_tables.shape()[1];
constexpr int BlockSize = 512;
const int max_next_step_tokens = 2 * max_draft_tokens + 2;
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
speculate_schedula_cache<BlockSize><<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
draft_tokens.data<int64_t>(),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
const_cast<int *>(step_seq_lens_this_time.data<int>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
stop_nums.data<int64_t>(),
real_bsz,
max_bsz,
max_next_step_tokens,
draft_token_len,
accept_tokens_len,
block_size,
block_num_per_seq
);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(speculate_schedule_cache)
.Inputs({"draft_tokens",
"block_tables",
"stop_flags",
"seq_lens_this_time",
"seq_lens_decoder",
"step_seq_lens_decoder",
"step_draft_tokens",
"step_seq_lens_this_time",
"accept_num",
"accept_tokens",
"is_block_step",
"not_need_stop",
"stop_nums"})
.Attrs({"block_size: int", "max_draft_tokens: int"})
.Outputs({"draft_tokens_out",
"block_tables_out",
"stop_flags_out",
"seq_lens_this_time_out",
"seq_lens_decoder_out",
"step_seq_lens_decoder_out",
"step_draft_tokens_out",
"step_seq_lens_this_time_out",
"accept_num_out",
"accept_tokens_out",
"is_block_step_out",
"not_need_stop_out"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"block_tables", "block_tables_out"},
{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
{"step_draft_tokens", "step_draft_tokens_out"},
{"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
{"accept_num", "accept_num_out"},
{"accept_tokens", "accept_tokens_out"},
{"is_block_step", "is_block_step_out"},
{"not_need_stop", "not_need_stop_out"},})
.SetKernelFn(PD_KERNEL(SpeculateScheduleCache));

View File

@@ -889,7 +889,7 @@ class CacheConfig:
else: else:
self.kv_cache_ratio = 0.75 self.kv_cache_ratio = 0.75
self.enc_dec_block_num = 0 if current_platform.is_iluvatar() else 2 self.enc_dec_block_num = 0 if current_platform.is_iluvatar() else 2
self.prealloc_dec_block_slot_num_threshold = 5 self.prealloc_dec_block_slot_num_threshold = 12
self.cache_dtype = "bfloat16" self.cache_dtype = "bfloat16"
self.model_cfg = None self.model_cfg = None
self.enable_chunked_prefill = False self.enable_chunked_prefill = False

View File

@@ -165,8 +165,7 @@ class EngineArgs:
""" """
Ratio of tokens to process in a block. Ratio of tokens to process in a block.
""" """
prealloc_dec_block_slot_num_threshold: int = 12
prealloc_dec_block_slot_num_threshold: int = 5
""" """
Token slot threshold for preallocating decoder blocks. Token slot threshold for preallocating decoder blocks.
""" """
@@ -405,8 +404,6 @@ class EngineArgs:
raise NotImplementedError("Logprob does not support enable_expert_parallel.") raise NotImplementedError("Logprob does not support enable_expert_parallel.")
if not current_platform.is_cuda(): if not current_platform.is_cuda():
raise NotImplementedError("Only CUDA platform supports logprob.") raise NotImplementedError("Only CUDA platform supports logprob.")
if self.speculative_config is not None:
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if self.splitwise_role != "mixed": if self.splitwise_role != "mixed":
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if not current_platform.is_cuda(): if not current_platform.is_cuda():
@@ -706,7 +703,7 @@ class EngineArgs:
cache_group.add_argument( cache_group.add_argument(
"--prealloc-dec-block-slot-num-threshold", "--prealloc-dec-block-slot-num-threshold",
type=int, type=int,
default=5, default=12,
help="Number of token slot threadshold to allocate next blocks for decoding.", help="Number of token slot threadshold to allocate next blocks for decoding.",
) )

View File

@@ -62,7 +62,6 @@ class EngineSevice:
self.cfg = cfg self.cfg = cfg
self.scheduler = cfg.scheduler_config.scheduler() self.scheduler = cfg.scheduler_config.scheduler()
if envs.ENABLE_V1_KVCACHE_SCHEDULER: if envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.resource_manager = ResourceManagerV1( self.resource_manager = ResourceManagerV1(
cfg.max_num_seqs, cfg.max_num_seqs,

View File

@@ -84,10 +84,14 @@ class ResourceManagerV1(ResourceManager):
return len(request.block_tables) * self.config.cache_config.block_size return len(request.block_tables) * self.config.cache_config.block_size
def get_new_block_nums(self, request: Request, num_new_tokens: int): def get_new_block_nums(self, request: Request, num_new_tokens: int):
return ( block_num = (
request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1 request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size - len(request.block_tables) ) // self.config.cache_config.block_size - len(request.block_tables)
if self.config.speculative_config.method is not None:
block_num = min(block_num + 1, self.config.cache_config.max_block_num_per_seq)
return block_num
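A worked example of the arithmetic above, with values assumed purely for illustration:

    # Assumed values: block_size = 64, the request already holds 2 blocks.
    block_size = 64
    num_computed_tokens, num_new_tokens = 100, 30
    held_blocks = 2

    block_num = (num_computed_tokens + num_new_tokens + block_size - 1) // block_size - held_blocks
    # ceil(130 / 64) = 3 blocks needed, 2 already held -> block_num == 1

    speculative_method = "mtp"   # assumed: a draft model is configured
    max_block_num_per_seq = 8    # assumed cap
    if speculative_method is not None:
        # reserve one extra block so draft tokens cannot outrun the block table
        block_num = min(block_num + 1, max_block_num_per_seq)
    # block_num == 2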
def _prepare_prefill_task(self, request, new_token_num): def _prepare_prefill_task(self, request, new_token_num):
request.prefill_start_index = request.num_computed_tokens request.prefill_start_index = request.num_computed_tokens
request.prefill_end_index = request.num_computed_tokens + new_token_num request.prefill_end_index = request.num_computed_tokens + new_token_num

View File

@@ -100,6 +100,8 @@ class AppendAttentionBackend(AttentionBackend):
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr( self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
fd_config.model_config, "use_3d_rope", False fd_config.model_config, "use_3d_rope", False
) )
if fd_config.speculative_config.model_type != "main":
self.rope_3d = False
self.causal: bool = getattr(fd_config.model_config, "causal", True) self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method: str = fd_config.speculative_config.method self.speculative_method: str = fd_config.speculative_config.method
self.use_speculate: bool = self.speculative_method is not None self.use_speculate: bool = self.speculative_method is not None
@@ -356,7 +358,7 @@ class AppendAttentionBackend(AttentionBackend):
getattr(layer, "cache_v_zp", None), getattr(layer, "cache_v_zp", None),
layer.linear_shift, layer.linear_shift,
layer.linear_smooth, layer.linear_smooth,
forward_meta.attn_mask_offsets, None if self.use_speculate else forward_meta.attn_mask_offsets,
metadata.kv_signal_data_list[layer.layer_id], metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None), getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None), getattr(layer, "k_norm_weight", None),
@@ -374,7 +376,7 @@ class AppendAttentionBackend(AttentionBackend):
metadata.max_partition_size, metadata.max_partition_size,
metadata.encoder_max_partition_size, metadata.encoder_max_partition_size,
self.speculate_max_draft_token_num + 1, self.speculate_max_draft_token_num + 1,
self.causal, self.causal or self.use_speculate,
self.speculative_method is not None, self.speculative_method is not None,
) )
return res return res

View File

@@ -306,7 +306,9 @@ def post_process_normal(
) )
def post_process_specualate(model_output, save_each_rank: bool = False, skip_save_output: bool = False): def post_process_specualate(
model_output: ModelOutputData, save_each_rank: bool = False, skip_save_output: bool = False
):
"""""" """"""
speculate_update( speculate_update(
model_output.seq_lens_encoder, model_output.seq_lens_encoder,

View File

@@ -261,7 +261,7 @@ class TokenProcessor:
def _compute_speculative_status(self): def _compute_speculative_status(self):
# TODO(liuzichang): Supplement more statistics # TODO(liuzichang): Supplement more statistics
interval = 10 interval = 1
if self.speculative_stats_step % interval == 0: if self.speculative_stats_step % interval == 0:
accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens
spec_logger.info( spec_logger.info(
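To make the statistic concrete, a small worked example of the acceptance ratio (numbers assumed):

    # Assumed numbers, for illustration only.
    total_step = 100               # decode steps actually executed
    number_of_output_tokens = 300  # tokens emitted across those steps
    accept_ratio = 1 - total_step * 1.0 / number_of_output_tokens
    # accept_ratio == 0.666...: roughly three tokens land per step,
    # i.e. about two drafted tokens are accepted on average.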

View File

@@ -19,8 +19,10 @@ from typing import List
import numpy as np import numpy as np
import paddle import paddle
from paddleformers.utils.log import logger
from fastdeploy.engine.request import Request from fastdeploy import envs
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.base_attention_backend import ( from fastdeploy.model_executor.layers.attention.base_attention_backend import (
@@ -50,14 +52,14 @@ class MTPProposer(Proposer):
Proposer for Multi-Token Prediction (MTP) Proposer for Multi-Token Prediction (MTP)
""" """
def __init__(self, cfg, main_model, local_rank, device_id, main_model_inputs): def __init__(self, cfg, main_model, local_rank, device_id, target_model_inputs):
super().__init__(cfg) super().__init__(cfg)
self.num_main_model_layers = self.model_config.num_hidden_layers self.num_main_model_layers = self.model_config.num_hidden_layers
self.local_rank = local_rank self.local_rank = local_rank
self.device_id = device_id self.device_id = device_id
self._update_cfg(main_model) self._update_cfg(main_model)
self._load_model() self._load_model()
self.main_model_inputs = main_model_inputs self.target_model_inputs = target_model_inputs
self.mtp_strategy = self.speculative_config.mtp_strategy self.mtp_strategy = self.speculative_config.mtp_strategy
self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
@@ -73,7 +75,7 @@ class MTPProposer(Proposer):
""" """
Update config for MTP from global config Update config for MTP from global config
""" """
self.model_config.architectures[0] = "Ernie4_5_MTPForCausalLM" self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP")
self.speculative_config.sharing_model = main_model self.speculative_config.sharing_model = main_model
self.model_config.num_hidden_layers = 1 self.model_config.num_hidden_layers = 1
self.model_config.model = self.speculative_config.model self.model_config.model = self.speculative_config.model
@@ -199,14 +201,16 @@ class MTPProposer(Proposer):
encoder_block_shape_q = 64 encoder_block_shape_q = 64
decoder_block_shape_q = 16 decoder_block_shape_q = 16
self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.main_model_inputs["decoder_batch_ids"]) self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["decoder_batch_ids"])
self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like( self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
self.main_model_inputs["decoder_tile_ids_per_batch"] self.target_model_inputs["decoder_tile_ids_per_batch"]
) )
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like( self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
self.main_model_inputs["decoder_num_blocks_cpu"] self.target_model_inputs["decoder_num_blocks_cpu"]
).pin_memory() ).pin_memory()
self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(self.main_model_inputs["max_len_tensor_cpu"]).cpu() self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(
self.target_model_inputs["max_len_tensor_cpu"]
).cpu()
# Get the attention backend # Get the attention backend
attn_cls = get_attention_backend() attn_cls = get_attention_backend()
@@ -265,28 +269,29 @@ class MTPProposer(Proposer):
""" """
self.model_inputs = {} self.model_inputs = {}
# Same shape/dtype as the base model # Same shape/dtype as the base model
self.model_inputs["block_tables"] = paddle.clone(self.main_model_inputs["block_tables"]) self.model_inputs["block_tables"] = paddle.clone(self.target_model_inputs["block_tables"])
self.model_inputs["input_ids"] = paddle.clone(self.main_model_inputs["input_ids"]) self.model_inputs["input_ids"] = paddle.clone(self.target_model_inputs["input_ids"])
self.seq_lens_this_time_buffer = paddle.clone(self.main_model_inputs["seq_lens_this_time"])
self.model_inputs["input_ids_cpu"] = paddle.full( self.model_inputs["input_ids_cpu"] = paddle.full(
shape=[self.max_num_seqs, self.parallel_config.max_model_len], shape=[self.max_num_seqs, self.parallel_config.max_model_len],
fill_value=-1, fill_value=-1,
dtype="int64", dtype="int64",
).cpu() ).cpu()
self.model_inputs["seq_lens_encoder"] = paddle.clone(self.main_model_inputs["seq_lens_encoder"]) self.seq_lens_this_time_buffer = paddle.clone(self.target_model_inputs["seq_lens_this_time"])
self.model_inputs["seq_lens_decoder"] = paddle.clone(self.main_model_inputs["seq_lens_decoder"])
self.model_inputs["step_idx"] = paddle.clone(self.main_model_inputs["step_idx"]) self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"])
self.model_inputs["stop_flags"] = paddle.clone(self.main_model_inputs["stop_flags"]) self.model_inputs["seq_lens_decoder"] = paddle.clone(self.target_model_inputs["seq_lens_decoder"])
self.model_inputs["stop_nums"] = paddle.clone(self.main_model_inputs["stop_nums"]) self.model_inputs["step_idx"] = paddle.clone(self.target_model_inputs["step_idx"])
self.model_inputs["stop_flags"] = paddle.clone(self.target_model_inputs["stop_flags"])
self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu") self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu")
self.model_inputs["pre_ids"] = paddle.clone(self.main_model_inputs["pre_ids"]) self.model_inputs["pre_ids"] = paddle.clone(self.target_model_inputs["pre_ids"])
self.model_inputs["ids_remove_padding"] = paddle.clone(self.main_model_inputs["ids_remove_padding"]) self.model_inputs["ids_remove_padding"] = paddle.clone(self.target_model_inputs["ids_remove_padding"])
self.model_inputs["batch_id_per_token"] = paddle.clone(self.main_model_inputs["batch_id_per_token"]) self.model_inputs["batch_id_per_token"] = paddle.clone(self.target_model_inputs["batch_id_per_token"])
self.model_inputs["cu_seqlens_q"] = paddle.clone(self.main_model_inputs["cu_seqlens_q"]) self.model_inputs["cu_seqlens_q"] = paddle.clone(self.target_model_inputs["cu_seqlens_q"])
self.model_inputs["cu_seqlens_k"] = paddle.clone(self.main_model_inputs["cu_seqlens_k"]) self.model_inputs["cu_seqlens_k"] = paddle.clone(self.target_model_inputs["cu_seqlens_k"])
self.model_inputs["decoder_batch_ids"] = paddle.clone(self.main_model_inputs["decoder_batch_ids"]) self.model_inputs["decoder_batch_ids"] = paddle.clone(self.target_model_inputs["decoder_batch_ids"])
self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone( self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone(
self.main_model_inputs["decoder_tile_ids_per_batch"] self.target_model_inputs["decoder_tile_ids_per_batch"]
) )
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -298,22 +303,22 @@ class MTPProposer(Proposer):
) )
# self.model_inputs["caches"] = self.cache_kvs # self.model_inputs["caches"] = self.cache_kvs
# Inherit generation hyperparameters from the main model for consistency # Inherit generation hyperparameters from the main model for consistency
self.model_inputs["top_p"] = self.main_model_inputs["top_p"] self.model_inputs["top_p"] = self.target_model_inputs["top_p"]
self.model_inputs["top_k"] = self.main_model_inputs["top_k"] self.model_inputs["top_k"] = self.target_model_inputs["top_k"]
self.model_inputs["temperature"] = self.main_model_inputs["temperature"] self.model_inputs["temperature"] = self.target_model_inputs["temperature"]
self.model_inputs["eos_token_id"] = self.main_model_inputs["eos_token_id"] self.model_inputs["eos_token_id"] = self.target_model_inputs["eos_token_id"]
self.model_inputs["penalty_score"] = self.main_model_inputs["penalty_score"] self.model_inputs["penalty_score"] = self.target_model_inputs["penalty_score"]
self.model_inputs["frequency_score"] = self.main_model_inputs["frequency_score"] self.model_inputs["frequency_score"] = self.target_model_inputs["frequency_score"]
self.model_inputs["presence_score"] = self.main_model_inputs["presence_score"] self.model_inputs["presence_score"] = self.target_model_inputs["presence_score"]
self.model_inputs["infer_seed"] = self.main_model_inputs["infer_seed"] self.model_inputs["infer_seed"] = self.target_model_inputs["infer_seed"]
self.model_inputs["max_dec_len"] = self.main_model_inputs["max_dec_len"] self.model_inputs["max_dec_len"] = self.target_model_inputs["max_dec_len"]
self.model_inputs["min_dec_len"] = self.main_model_inputs["min_dec_len"] self.model_inputs["min_dec_len"] = self.target_model_inputs["min_dec_len"]
self.model_inputs["bad_tokens"] = self.main_model_inputs["bad_tokens"] self.model_inputs["bad_tokens"] = self.target_model_inputs["bad_tokens"]
# Integrate the updated results in model forward # Integrate the updated results in model forward
self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"] self.model_inputs["base_model_draft_tokens"] = self.target_model_inputs["draft_tokens"]
self.model_inputs["substep"] = 0 self.model_inputs["substep"] = 0
# Declare AttentionBackend buffers # Declare AttentionBackend buffers
@@ -327,7 +332,7 @@ class MTPProposer(Proposer):
shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64" shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64"
) )
self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"]) self.model_inputs["encoder_block_lens"] = paddle.clone(self.target_model_inputs["encoder_block_lens"])
self.free_list = list( self.free_list = list(
range( range(
@@ -341,14 +346,76 @@ class MTPProposer(Proposer):
self.model_inputs["free_list"] = paddle.to_tensor(self.free_list, dtype="int32") self.model_inputs["free_list"] = paddle.to_tensor(self.free_list, dtype="int32")
self.model_inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.free_list_len, dtype="int32") self.model_inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.free_list_len, dtype="int32")
self.model_inputs["is_block_step"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool") self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32") self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
if self.num_model_steps > 1: if self.num_model_steps > 1:
self.last_seq_lens_this_time = paddle.full_like( self.last_seq_lens_this_time = paddle.full_like(
self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32" self.target_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
) )
self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu() self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
if "caches" not in self.model_inputs:
self.initialize_kv_cache()
req_len = len(req_dicts)
for i in range(req_len):
request = req_dicts[i]
logger.debug(f"{i}th request-{request.request_id}: {request}")
idx = request.idx
if request.task_type.value == RequestType.PREFILL.value: # prefill task
prefill_start_index = request.prefill_start_index
prefill_end_index = request.prefill_end_index
length = prefill_end_index - prefill_start_index
input_ids = request.prompt_token_ids + request.output_token_ids
self.input_ids_len[idx] = length
self.model_inputs["pre_ids"][idx : idx + 1] = -1
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
idx : idx + 1, 1:length
]
encoder_block_num = len(request.block_tables)
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
)
self.model_inputs["stop_flags"][idx : idx + 1] = False
self.model_inputs["batch_drop"][idx : idx + 1] = False
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
self.seq_lens_this_time_buffer[idx : idx + 1] = length
self.model_inputs["step_idx"][idx : idx + 1] = (
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
)
# has_prefill_task = True
elif request.task_type.value == RequestType.DECODE.value: # decode task
encoder_block_num = len(request.block_tables)
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
)
# if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode
# has_decode_task = True
# continue
else:
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
self.model_inputs["stop_flags"][idx : idx + 1] = True
self.seq_lens_this_time_buffer[idx : idx + 1] = 0
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.model_inputs["is_block_step"][idx : idx + 1] = False
continue
# if has_prefill_task or has_decode_task:
# self.model_inputs["not_need_stop"][0] = True
self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int): def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
""" """
Process inputs for prefill tasks and insert it to model_inputs buffer Process inputs for prefill tasks and insert it to model_inputs buffer
@@ -408,9 +475,9 @@ class MTPProposer(Proposer):
length = len(request.prompt_token_ids) length = len(request.prompt_token_ids)
if length > 1: if length > 1:
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.main_model_inputs["input_ids"][ self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
idx : idx + 1, 1:length "input_ids"
] ][idx : idx + 1, 1:length]
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array( self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
request.prompt_token_ids request.prompt_token_ids
)[1:] )[1:]
@@ -470,6 +537,7 @@ class MTPProposer(Proposer):
""" """
Prepare MTP inputs Prepare MTP inputs
""" """
use_v1_cache_scheduler = envs.ENABLE_V1_KVCACHE_SCHEDULER
draft_model_preprocess( draft_model_preprocess(
self.model_inputs["draft_tokens"], self.model_inputs["draft_tokens"],
self.model_inputs["input_ids"], self.model_inputs["input_ids"],
@@ -480,19 +548,21 @@ class MTPProposer(Proposer):
self.model_inputs["step_idx"], self.model_inputs["step_idx"],
self.model_inputs["not_need_stop"], self.model_inputs["not_need_stop"],
self.model_inputs["batch_drop"], self.model_inputs["batch_drop"],
self.model_inputs["is_block_step"],
self.model_inputs["pre_ids"], self.model_inputs["pre_ids"],
self.main_model_inputs["accept_tokens"], self.target_model_inputs["accept_tokens"],
self.main_model_inputs["accept_num"], self.target_model_inputs["accept_num"],
self.main_model_inputs["seq_lens_this_time"], self.target_model_inputs["seq_lens_this_time"],
self.main_model_inputs["seq_lens_encoder"], self.target_model_inputs["seq_lens_encoder"],
self.main_model_inputs["seq_lens_decoder"], self.target_model_inputs["seq_lens_decoder"],
self.main_model_inputs["step_idx"], self.target_model_inputs["step_idx"],
self.main_model_inputs["stop_flags"], self.target_model_inputs["stop_flags"],
self.main_model_inputs["is_block_step"], self.target_model_inputs["is_block_step"],
self.main_model_inputs["draft_tokens"], self.target_model_inputs["draft_tokens"],
self.num_model_steps, self.num_model_steps,
self.speculative_method in ["eagle", "mtp"], self.speculative_method in ["eagle", "mtp"],
self.role == "prefill", self.role == "prefill",
use_v1_cache_scheduler,
) )
target_hidden_states = eagle_get_hidden_states( target_hidden_states = eagle_get_hidden_states(
@@ -501,9 +571,9 @@ class MTPProposer(Proposer):
self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_encoder"],
self.model_inputs["seq_lens_decoder"], self.model_inputs["seq_lens_decoder"],
self.model_inputs["stop_flags"], self.model_inputs["stop_flags"],
self.main_model_inputs["accept_num"], self.target_model_inputs["accept_num"],
self.main_model_inputs["seq_lens_this_time"], self.target_model_inputs["seq_lens_this_time"],
self.main_model_inputs["seq_lens_encoder"], self.target_model_inputs["seq_lens_encoder"],
self.num_model_steps, self.num_model_steps,
) )
if isinstance(target_hidden_states, list): if isinstance(target_hidden_states, list):
@@ -673,41 +743,41 @@ class MTPProposer(Proposer):
Allocate/Free block of MPT. Allocate/Free block of MPT.
""" """
draft_model_postprocess( draft_model_postprocess(
self.main_model_inputs["draft_tokens"], self.target_model_inputs["draft_tokens"],
self.main_model_inputs["seq_lens_this_time"], self.target_model_inputs["seq_lens_this_time"],
self.main_model_inputs["seq_lens_encoder"], self.target_model_inputs["seq_lens_encoder"],
self.main_model_inputs["stop_flags"], self.target_model_inputs["stop_flags"],
)
mtp_step_paddle(
self.main_model_inputs["stop_flags"],
self.model_inputs["stop_flags"],
self.model_inputs["batch_drop"],
self.model_inputs["seq_lens_this_time"],
self.model_inputs["seq_lens_encoder"],
self.model_inputs["seq_lens_decoder"],
self.model_inputs["block_tables"],
self.model_inputs["encoder_block_lens"],
self.model_inputs["used_list_len"],
self.model_inputs["free_list"],
self.model_inputs["free_list_len"],
self.cache_config.block_size,
self.max_draft_token_num,
) )
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
mtp_step_paddle(
self.target_model_inputs["stop_flags"],
self.model_inputs["stop_flags"],
self.model_inputs["batch_drop"],
self.model_inputs["seq_lens_this_time"],
self.model_inputs["seq_lens_encoder"],
self.model_inputs["seq_lens_decoder"],
self.model_inputs["block_tables"],
self.model_inputs["encoder_block_lens"],
self.model_inputs["used_list_len"],
self.model_inputs["free_list"],
self.model_inputs["free_list_len"],
self.cache_config.block_size,
self.max_draft_token_num,
)
def _extend_draft_token_with_ngram_match(self): def _extend_draft_token_with_ngram_match(self):
# TODO(liuzichang): Optimize this kernel to a CUDA kernel to reduce latency # TODO(liuzichang): Optimize this kernel to a CUDA kernel to reduce latency
device = paddle.CUDAPinnedPlace() device = paddle.CUDAPinnedPlace()
draft_tokens = self.main_model_inputs["draft_tokens"].cpu() draft_tokens = self.target_model_inputs["draft_tokens"].cpu()
seq_lens_this_time = self.main_model_inputs["seq_lens_this_time"].cpu() seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu()
seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu() seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
hybrid_mtp_ngram( hybrid_mtp_ngram(
self.model_inputs["input_ids_cpu"], self.model_inputs["input_ids_cpu"],
self.input_ids_len, self.input_ids_len,
self.model_inputs["pre_ids"]._copy_to(device, True), self.model_inputs["pre_ids"]._copy_to(device, True),
self.model_inputs["step_idx"].cpu(), self.model_inputs["step_idx"].cpu(),
self.main_model_inputs["actual_draft_token_num"].cpu(), self.target_model_inputs["actual_draft_token_num"].cpu(),
draft_tokens, draft_tokens,
seq_lens_this_time, seq_lens_this_time,
seq_lens_decoder, seq_lens_decoder,
@@ -716,8 +786,8 @@ class MTPProposer(Proposer):
self.min_ngram_size, self.min_ngram_size,
self.max_draft_token_num, self.max_draft_token_num,
) )
self.main_model_inputs["draft_tokens"][:] = draft_tokens.cuda() self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
self.main_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
def _run_impl(self, full_hidden_states): def _run_impl(self, full_hidden_states):
"""""" """"""

View File

@@ -59,6 +59,7 @@ else:
recover_decode_task, recover_decode_task,
set_value_by_flags_and_idx, set_value_by_flags_and_idx,
share_external_data, share_external_data,
speculate_schedule_cache,
) )
from fastdeploy.model_executor.pre_and_post_process import ( from fastdeploy.model_executor.pre_and_post_process import (
@@ -383,6 +384,8 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["stop_flags"].sum() == self.parallel_config.max_num_seqs self.share_inputs["stop_flags"].sum() == self.parallel_config.max_num_seqs
) )
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests] self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
if self.speculative_method in ["mtp"]:
self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None): def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None):
""" """
@@ -803,6 +806,13 @@ class GPUModelRunner(ModelRunnerBase):
fill_value=0, fill_value=0,
dtype="int32", dtype="int32",
) )
# For V1_KVCACHE_SCHEDULER
self.share_inputs["step_draft_tokens"] = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1],
fill_value=0,
dtype="int64",
)
self.share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
if self.enable_mm: if self.enable_mm:
head_dim = self.model_config.head_dim head_dim = self.model_config.head_dim
@@ -841,7 +851,11 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["step_seq_lens_decoder"], self.share_inputs["step_seq_lens_decoder"],
self.share_inputs["block_tables"], self.share_inputs["block_tables"],
self.share_inputs["is_block_step"], self.share_inputs["is_block_step"],
self.share_inputs["draft_tokens"] if self.speculative_decoding else None,
self.share_inputs["step_draft_tokens"] if self.speculative_decoding else None,
self.share_inputs["step_seq_lens_this_time"] if self.speculative_decoding else None,
self.cache_config.block_size, self.cache_config.block_size,
self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0,
) )
# Remove padding # Remove padding
@@ -1540,6 +1554,24 @@ class GPUModelRunner(ModelRunnerBase):
self._update_chunked_prefill(model_forward_batch) self._update_chunked_prefill(model_forward_batch)
self._add_cache(model_forward_batch) self._add_cache(model_forward_batch)
elif self.speculative_decoding:
speculate_schedule_cache(
self.share_inputs["draft_tokens"],
self.share_inputs["block_tables"],
self.share_inputs["stop_flags"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["step_seq_lens_decoder"],
self.share_inputs["step_draft_tokens"],
self.share_inputs["step_seq_lens_this_time"],
self.share_inputs["accept_num"],
self.share_inputs["accept_tokens"],
self.share_inputs["is_block_step"],
self.share_inputs["not_need_stop"],
self.share_inputs["stop_nums"],
self.cache_config.block_size,
self.speculative_config.num_speculative_tokens,
)
self.seq_lens_this_time_buffer[:num_running_requests].copy_( self.seq_lens_this_time_buffer[:num_running_requests].copy_(
self.share_inputs["seq_lens_this_time"][:num_running_requests], False self.share_inputs["seq_lens_this_time"][:num_running_requests], False

View File

@@ -742,13 +742,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}") logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
logger.info(f"- Load strategy: {load_config.load_strategy}") logger.info(f"- Load strategy: {load_config.load_strategy}")
if (
args.speculative_config is not None
and ("method" in args.speculative_config)
and (args.speculative_config["method"] is not None)
):
logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not support speculative decoding now.")
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
if args.splitwise_role != "mixed": if args.splitwise_role != "mixed":
logger.info(f"Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported {args.splitwise_role} now.") logger.info(f"Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported {args.splitwise_role} now.")
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 envs.ENABLE_V1_KVCACHE_SCHEDULER = 0

View File

@@ -0,0 +1,238 @@
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import speculate_schedule_cache
def cpu_reference(
draft_tokens,
block_tables,
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
step_seq_lens_decoder,
step_draft_tokens,
step_seq_lens_this_time,
accept_num,
accept_tokens,
is_block_step,
not_need_stop,
stop_nums,
block_size,
max_draft_tokens,
):
"""Pure-NumPy mirror of the CUDA kernel's logic (single block of 512 threads).
Shapes are the same as inputs to the custom op. This mutates the provided
NumPy arrays in-place, exactly like the kernel does.
"""
real_bsz = seq_lens_this_time.shape[0]
max_bsz = stop_flags.shape[0]
draft_tokens_len = draft_tokens.shape[1]
block_num_per_seq = block_tables.shape[1]
max_next_step_tokens = 2 * max_draft_tokens + 2
# Block-local reduction input per thread (threadIdx.x -> bid)
stop_flag_now_int = np.zeros(512, dtype=np.int64) # THREADBLOCK_SIZE = 512
for bid in range(512):
if bid < real_bsz:
if not stop_flags[bid]:
max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) // block_size
if max_possible_block_idx < block_num_per_seq and block_tables[bid, max_possible_block_idx] == -1:
is_block_step[bid] = True
step_seq_lens_this_time[bid] = seq_lens_this_time[bid]
seq_lens_this_time[bid] = 0
stop_flags[bid] = True
step_seq_lens_decoder[bid] = seq_lens_decoder[bid]
seq_lens_decoder[bid] = 0
accept_num[bid] = 0
accept_tokens[bid, :] = -1
step_draft_tokens[bid, :draft_tokens_len] = draft_tokens[bid, :draft_tokens_len]
stop_flag_now_int[bid] = 1
else:
stop_flag_now_int[bid] = 0
else:
stop_flag_now_int[bid] = 1
elif bid < max_bsz:
# Threads in [real_bsz, max_bsz) contribute 1 to reduction
stop_flag_now_int[bid] = 1
else:
stop_flag_now_int[bid] = 0
stop_sum = int(stop_flag_now_int.sum())
not_need_stop[0] = stop_sum < int(stop_nums[0])
class TestSpeculateScheduleCache(unittest.TestCase):
@classmethod
def setUpClass(cls):
if not paddle.is_compiled_with_cuda():
raise unittest.SkipTest("Paddle is not compiled with CUDA; skipping GPU op test.")
paddle.device.set_device("gpu")
def setUp(self):
# --- Construct a deterministic case that exercises all branches ---
# real_bsz < max_bsz to test the padding logic in the CUB reduction
self.real_bsz = 3
self.max_bsz = 5 # only stop_flags has length max_bsz
self.draft_tokens_len = 6
self.accept_tokens_len = 5
self.block_size = 4
self.block_num_per_seq = 3
self.max_draft_tokens = 2 # -> max_next_step_tokens = 6
# Inputs that will trigger for bid 0, not trigger for bid 2, and bid 1 is already stopped
# (seq_lens_decoder + 6) // 4 -> indices [1, 1, 4]; index 4 >= block_num_per_seq (3), so bid 2 does not trigger
self.draft_tokens = paddle.to_tensor(
np.array(
[
[1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3],
],
dtype=np.int64,
)
)
self.block_tables = paddle.to_tensor(np.full((self.real_bsz, self.block_num_per_seq), -1, dtype=np.int32))
# stop_flags length is max_bsz, others are real_bsz
self.stop_flags = paddle.to_tensor(np.array([False, True, False, False, False], dtype=np.bool_))
self.seq_lens_this_time = paddle.to_tensor(np.array([5, 6, 7], dtype=np.int32))
self.seq_lens_decoder = paddle.to_tensor(np.array([1, 1, 10], dtype=np.int32))
# Will be filled by kernel for the triggering bids only
self.step_seq_lens_decoder = paddle.zeros((self.real_bsz,), dtype="int32")
self.step_draft_tokens = paddle.zeros((self.real_bsz, self.draft_tokens_len), dtype="int64")
self.step_seq_lens_this_time = paddle.zeros((self.real_bsz,), dtype="int32")
# Intentionally non-zero so we can verify in-place zeroing only where triggered
self.accept_num = paddle.to_tensor(np.array([9, 8, 7], dtype=np.int32))
self.accept_tokens = paddle.to_tensor(
np.arange(self.real_bsz * self.accept_tokens_len, dtype=np.int64).reshape(
self.real_bsz, self.accept_tokens_len
)
)
self.is_block_step = paddle.zeros((self.real_bsz,), dtype=paddle.bool)
# not_need_stop lives on CPU in the caller; the kernel copies to device internally
self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool).cpu()
# Choose the threshold so that: bid0 triggers (1), bid1 is already stopped (1), padding (5-3)=2 -> stop_sum = 1+1+2 = 4
# Set stop_nums to 5 so not_need_stop = (4 < 5) = True
self.stop_nums = paddle.to_tensor([5], dtype=paddle.int64)
# Keep NumPy copies for CPU reference
self.np_draft_tokens = self.draft_tokens.numpy().copy()
self.np_block_tables = self.block_tables.numpy().copy()
self.np_stop_flags = self.stop_flags.numpy().copy()
self.np_seq_lens_this_time = self.seq_lens_this_time.numpy().copy()
self.np_seq_lens_decoder = self.seq_lens_decoder.numpy().copy()
self.np_step_seq_lens_decoder = self.step_seq_lens_decoder.numpy().copy()
self.np_step_draft_tokens = self.step_draft_tokens.numpy().copy()
self.np_step_seq_lens_this_time = self.step_seq_lens_this_time.numpy().copy()
self.np_accept_num = self.accept_num.numpy().copy()
self.np_accept_tokens = self.accept_tokens.numpy().copy()
self.np_is_block_step = self.is_block_step.numpy().copy()
self.np_not_need_stop = self.not_need_stop.numpy().copy()
self.np_stop_nums = self.stop_nums.numpy().copy()
def test_correctness_against_cpu_reference(self):
# Run GPU kernel (in-place)
speculate_schedule_cache(
self.draft_tokens,
self.block_tables,
self.stop_flags,
self.seq_lens_this_time,
self.seq_lens_decoder,
self.step_seq_lens_decoder,
self.step_draft_tokens,
self.step_seq_lens_this_time,
self.accept_num,
self.accept_tokens,
self.is_block_step,
self.not_need_stop,
self.stop_nums,
self.block_size,
self.max_draft_tokens,
)
# Compute CPU reference (in-place on NumPy copies)
cpu_reference(
self.np_draft_tokens,
self.np_block_tables,
self.np_stop_flags,
self.np_seq_lens_this_time,
self.np_seq_lens_decoder,
self.np_step_seq_lens_decoder,
self.np_step_draft_tokens,
self.np_step_seq_lens_this_time,
self.np_accept_num,
self.np_accept_tokens,
self.np_is_block_step,
self.np_not_need_stop,
self.np_stop_nums,
self.block_size,
self.max_draft_tokens,
)
# Compare all mutated tensors
np.testing.assert_array_equal(self.step_draft_tokens.numpy(), self.np_step_draft_tokens)
np.testing.assert_array_equal(self.accept_tokens.numpy(), self.np_accept_tokens)
np.testing.assert_array_equal(self.stop_flags.numpy(), self.np_stop_flags)
np.testing.assert_array_equal(self.is_block_step.numpy(), self.np_is_block_step)
np.testing.assert_array_equal(self.seq_lens_this_time.numpy(), self.np_seq_lens_this_time)
np.testing.assert_array_equal(self.seq_lens_decoder.numpy(), self.np_seq_lens_decoder)
np.testing.assert_array_equal(self.step_seq_lens_decoder.numpy(), self.np_step_seq_lens_decoder)
np.testing.assert_array_equal(self.step_seq_lens_this_time.numpy(), self.np_step_seq_lens_this_time)
np.testing.assert_array_equal(self.accept_num.numpy(), self.np_accept_num)
self.assertEqual(bool(self.not_need_stop.numpy()[0]), bool(self.np_not_need_stop[0]))
def test_no_trigger_path(self):
# Make block_tables at candidate index != -1 so nothing triggers
# Candidate index for bid 0/1 is 1, set it to 7
bt = self.block_tables.numpy()
bt[:, 1] = 7
self.block_tables = paddle.to_tensor(bt)
# Reset outputs to distinctive values
self.step_seq_lens_decoder[:] = 0
self.step_draft_tokens[:] = 0
self.step_seq_lens_this_time[:] = 0
self.accept_num[:] = -123
self.accept_tokens[:] = -777
self.is_block_step[:] = False
self.not_need_stop[:] = False
# For not_need_stop: stopped_in_real = (bid1 True) = 1, padding = 2 -> stop_sum=3
# With stop_nums=5 -> True
speculate_schedule_cache(
self.draft_tokens,
self.block_tables,
self.stop_flags,
self.seq_lens_this_time,
self.seq_lens_decoder,
self.step_seq_lens_decoder,
self.step_draft_tokens,
self.step_seq_lens_this_time,
self.accept_num,
self.accept_tokens,
self.is_block_step,
self.not_need_stop,
self.stop_nums,
self.block_size,
self.max_draft_tokens,
)
# Nothing should have changed except not_need_stop
np.testing.assert_array_equal(self.step_draft_tokens.numpy(), np.zeros_like(self.step_draft_tokens.numpy()))
np.testing.assert_array_equal(self.is_block_step.numpy(), np.zeros_like(self.is_block_step.numpy()))
np.testing.assert_array_equal(self.accept_tokens.numpy(), np.full_like(self.accept_tokens.numpy(), -777))
np.testing.assert_array_equal(self.accept_num.numpy(), np.full_like(self.accept_num.numpy(), -123))
self.assertTrue(bool(self.not_need_stop.numpy()[0]))
if __name__ == "__main__":
unittest.main()

View File

@@ -5,12 +5,16 @@ import unittest
import numpy as np import numpy as np
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.incubate.nn.functional import fused_rms_norm
from fastdeploy.model_executor.layers.attention.ops import ( from fastdeploy.model_executor.layers.attention.ops import (
append_attention, append_attention,
get_block_shape_and_split_kv_block, get_block_shape_and_split_kv_block,
) )
np.random.seed(0)
paddle.seed(0)
class TestTreeMask(unittest.TestCase): class TestTreeMask(unittest.TestCase):
def setUp(self): def setUp(self):
@@ -27,6 +31,7 @@ class TestTreeMask(unittest.TestCase):
self.head_dim = 128 self.head_dim = 128
self.num_q_head = 20 self.num_q_head = 20
self.num_kv_head = 4 self.num_kv_head = 4
self.use_qknorm = True
self.dtype = "bfloat16" self.dtype = "bfloat16"
self.rope_3d = False self.rope_3d = False
@@ -91,12 +96,20 @@ class TestTreeMask(unittest.TestCase):
cu_seqlens_k[i + 1] = cum_seq_len_k cu_seqlens_k[i + 1] = cum_seq_len_k
return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k
def ref_attention(self, q, k, v, mask): def ref_attention(self, q, k, v, mask, use_qknorm=False):
if use_qknorm:
q = q.reshape([-1, self.head_dim])
q = fused_rms_norm(q.astype("float32"), self.q_norm_weight_tensor, None, 1e-6)[0].astype(self.dtype)
q = q.reshape([self.bsz, -1, self.num_q_head, self.head_dim])
q = q.transpose([0, 2, 1, 3]) q = q.transpose([0, 2, 1, 3])
if len(k) > 1: if len(k) > 1:
k = paddle.concat(k, axis=1) k = paddle.concat(k, axis=1)
else: else:
k = k[0] k = k[0]
if use_qknorm:
k = k.reshape([-1, self.head_dim])
k = fused_rms_norm(k.astype("float32"), self.k_norm_weight_tensor, None, 1e-6)[0].astype(self.dtype)
k = k.reshape([self.bsz, -1, self.num_kv_head, self.head_dim])
k = k.transpose([0, 2, 1, 3])
if len(v) > 1:
v = paddle.concat(v, axis=1)
@@ -127,7 +140,7 @@ class TestTreeMask(unittest.TestCase):
.reshape([-1, self.num_q_head, self.head_dim])
)
def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None, use_qknorm=False):
if prefill:
seq_lens_enc = [
q_len,
@@ -187,6 +200,10 @@ class TestTreeMask(unittest.TestCase):
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
q_norm_weight = np.random.random([self.head_dim]) / 10
k_norm_weight = np.random.random([self.head_dim]) / 10
self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
self.k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
paddle.device.synchronize()
(
encoder_batch_ids,
@@ -237,20 +254,20 @@ class TestTreeMask(unittest.TestCase):
max_len_kv,
rotary_embs,
attn_mask,
None, # qkv_bias
None, # qkv_out_scales
cache_k_scale,
cache_v_scale,
cache_k_out_scale,
cache_v_out_scale,
None, # cache_k_zp
None, # cache_v_zp
None, # linear_shift
None, # linear_smooth
None, # mask_offset
None, # kv_signal_data
self.q_norm_weight_tensor if use_qknorm else None, # q_norm_weight
self.k_norm_weight_tensor if use_qknorm else None, # k_norm_weight
1e-6,
"bf16",
"none",
@@ -271,7 +288,7 @@ class TestTreeMask(unittest.TestCase):
paddle.device.synchronize()
e_time = time.time()
print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / self.run_time):.2f}")
return out.reshape([token_num, self.num_q_head, self.head_dim])
def test_naive_speculative_decoding(self):
prefill_len = 8192 prefill_len = 8192
@@ -279,10 +296,10 @@ class TestTreeMask(unittest.TestCase):
total_len = prefill_len + dec_len_q
mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
self.run_append_c16_attention(prefill_len, 0, True, use_qknorm=self.use_qknorm)
dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, use_qknorm=self.use_qknorm)
ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask, use_qknorm=self.use_qknorm)
np.testing.assert_allclose(
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
)
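For hand-checking the qk-norm path, the reference code above runs fused_rms_norm over each head vector with eps=1e-6 before attention. A plain-NumPy sketch of the per-head normalization that call is assumed to perform (weight has shape [head_dim]):

import numpy as np

def rms_norm_ref(x, weight, eps=1e-6):
    # x: [..., head_dim] in float32; computes x * rsqrt(mean(x^2) + eps) * weight
    variance = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight

# Example: normalize one query head with the q_norm_weight created in run_append_c16_attention
# q_head_normed = rms_norm_ref(q_head.astype("float32"), q_norm_weight)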