mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00
[Feature][MTP]Support MTP for rl-model (#4009)
* qk norm for speculate decode C16 * support mtp in v1_scheduler mode * support mtp rope_3d * support mtp features * add unit test && del some log --------- Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com> Co-authored-by: xiaoxiaohehe001 <hiteezsf@163.com>
This commit is contained in:
@@ -273,11 +273,15 @@ void AppendAttentionKernel(
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
SpeculateWriteCacheWithRoPEKernel<data_t, data_t>(
|
||||
meta_data,
|
||||
@@ -296,11 +300,15 @@ void AppendAttentionKernel(
|
||||
cache_v_zp,
|
||||
cache_quant_type_str,
|
||||
use_neox_rotary_style,
|
||||
rope_3d,
|
||||
max_input_length,
|
||||
exec_stream,
|
||||
&qkv_out,
|
||||
const_cast<paddle::Tensor*>(&key_cache),
|
||||
const_cast<paddle::Tensor*>(&value_cache));
|
||||
const_cast<paddle::Tensor*>(&value_cache),
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps);
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
|
@@ -120,7 +120,6 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
float row_variance =
|
||||
max(warp_m2 / head_size, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
|
||||
if (hi < num_heads) { // q
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
@@ -129,6 +128,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
}
|
||||
} else { // k
|
||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
|
||||
}
|
||||
|
@@ -18,6 +18,168 @@
|
||||
#include "mma_tensor_op.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, int VecSize = 1, typename InT = T>
|
||||
__global__ void append_speculate_cache_T_rope_qk_norm_kernel(
|
||||
const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size,
|
||||
// head_size]
|
||||
T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ q_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||
const float* __restrict__ cos_emb,
|
||||
const float* __restrict__ sin_emb,
|
||||
const float*
|
||||
qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size]
|
||||
const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int output_inner_dim,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
constexpr int HalfVecSize = VecSize / 2;
|
||||
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||
LoadInT src_vec;
|
||||
LoadFloat scale_vec;
|
||||
LoadT bias_vec;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
LoadFloat tmp_vec;
|
||||
LoadFloat q_norm_vec;
|
||||
LoadFloat k_norm_vec;
|
||||
|
||||
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
||||
int64_t all_warp_num = gridDim.x * blockDim.y;
|
||||
int64_t all_head_dim = elem_cnt / head_size;
|
||||
|
||||
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size;
|
||||
const int half_head_size = head_size / 2;
|
||||
for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) {
|
||||
int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
|
||||
const int token_id = linear_index / hidden_size;
|
||||
const int ori_bi = batch_id_per_token[token_id];
|
||||
if (seq_lens_decoder[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int hi = bias / head_size; // q + k + v
|
||||
const int h_bias = bias % head_size;
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int write_seq_id =
|
||||
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
|
||||
if (write_seq_id == 0) continue;
|
||||
|
||||
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||
if (block_idx < 0) {
|
||||
printf(
|
||||
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
|
||||
"%d %d %d %d\n",
|
||||
block_idx,
|
||||
write_seq_id,
|
||||
ori_bi,
|
||||
seq_lens_decoder[ori_bi],
|
||||
token_id,
|
||||
cu_seqlens_q[ori_bi]);
|
||||
}
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
const int write_q_idx =
|
||||
token_id * output_inner_dim * head_size + hi * head_size + h_bias;
|
||||
|
||||
const int bias_idx = hi * head_size + h_bias;
|
||||
Load<InT, VecSize>(&qkv[linear_index], &src_vec);
|
||||
if (qkv_biases) {
|
||||
Load<T, VecSize>(&qkv_biases[bias_idx], &bias_vec);
|
||||
}
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &scale_vec);
|
||||
}
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// add_bias + rope
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
if (qkv_out_scales) {
|
||||
input_left *= scale_vec[2 * i];
|
||||
input_right *= scale_vec[2 * i + 1];
|
||||
}
|
||||
if (qkv_biases) {
|
||||
input_left = input_left + static_cast<float>(bias_vec[2 * i]);
|
||||
input_right = input_right + static_cast<float>(bias_vec[2 * i + 1]);
|
||||
}
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
tmp_vec[2 * i] = tmp1;
|
||||
tmp_vec[2 * i + 1] = tmp2;
|
||||
} else {
|
||||
bias_vec[2 * i] = static_cast<T>(input_left);
|
||||
bias_vec[2 * i + 1] = static_cast<T>(input_right);
|
||||
}
|
||||
}
|
||||
if (hi < (num_heads + gqa_group_size)) {
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / head_size, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
if (hi < num_heads) {
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
|
||||
}
|
||||
} else {
|
||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (hi < num_heads) {
|
||||
// write q
|
||||
Store<T, VecSize>(bias_vec, &q_out[write_q_idx]);
|
||||
} else {
|
||||
// write k/v
|
||||
const int kv_head_idx = (hi - num_heads) % gqa_group_size;
|
||||
const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size +
|
||||
kv_head_idx * block_size * head_size +
|
||||
block_offset * head_size + h_bias);
|
||||
// write
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
Store<T, VecSize>(bias_vec, &key_cache[tgt_idx]);
|
||||
} else {
|
||||
Store<T, VecSize>(bias_vec, &value_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int VecSize = 4, int HeadDim = 128>
|
||||
__global__ void append_clear_cache_int8_block(
|
||||
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
|
||||
@@ -193,7 +355,8 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -253,8 +416,9 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
@@ -326,7 +490,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int elem_cnt,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadFloat = AlignedVector<float, VecSize>;
|
||||
using LoadInT = AlignedVector<InT, VecSize>;
|
||||
@@ -390,8 +555,9 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
if (hi < num_heads + gqa_group_size) {
|
||||
// q k rope
|
||||
const int64_t emb_idx = write_seq_id * head_size + h_bias;
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2: emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
@@ -476,7 +642,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -522,8 +689,9 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
}
|
||||
@@ -583,10 +751,11 @@ __global__ void append_speculate_cache_int8_rope_kernel(
|
||||
T scale;
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
} else {
|
||||
scale = __ldg(&cache_v_scales[kv_head_idx]);
|
||||
@@ -708,7 +877,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -757,8 +927,9 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
if (qkv_out_scales) {
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx_left],
|
||||
&left_out_scale_vec);
|
||||
@@ -853,10 +1024,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
|
||||
|
||||
T scale;
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
scale = __ldg(&cache_k_scales[kv_head_idx]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
@@ -1088,7 +1260,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1145,8 +1318,9 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
// Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
@@ -1235,10 +1409,11 @@ __global__ void append_speculate_cache_int4_rope_kernel(
|
||||
// &out_scale_vec2);
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
Load<float, 1>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[emb_idx + 4], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx], &scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[cache_idx + 8], &scale_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_zero_points[cache_idx], &zp_vec1);
|
||||
@@ -1431,7 +1606,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
const int block_size,
|
||||
const float max_bound,
|
||||
const float min_bound,
|
||||
const int gqa_group_size) {
|
||||
const int gqa_group_size,
|
||||
const bool rope_3d) {
|
||||
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
|
||||
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
|
||||
constexpr int NUM_WARPS = 4;
|
||||
@@ -1581,10 +1757,11 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
|
||||
&right_out_scale_vec2);
|
||||
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[emb_idx + 8], &sin_emb_vec2);
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx],
|
||||
&left_scale_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&cache_k_scales[left_cache_idx + 8],
|
||||
|
@@ -15,6 +15,78 @@
|
||||
#include "speculate_write_cache_with_rope_kernel.h"
|
||||
#include "utils.cuh"
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
const float* qkv_out_scales,
|
||||
const T* qkv_biases,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const bool rope_3d) {
|
||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||
const uint32_t elem_nums =
|
||||
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||
: token_num * (num_heads + 2 * kv_num_heads) * dim_head;
|
||||
constexpr int HEAD_DIM = 128;
|
||||
|
||||
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
if (use_neox_style) {
|
||||
PD_THROW(
|
||||
"append_speculate_cache_rope_qk_norm not support neox rope yet");
|
||||
} else {
|
||||
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
|
||||
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
|
||||
<<<grid_size, block_dim, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales,
|
||||
qkv_biases,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
output_inner_dim,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
|
||||
// rope + write
|
||||
template <typename T, typename QKV_TYPE>
|
||||
void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
@@ -39,7 +111,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style) {
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||
|
||||
const uint32_t elem_nums =
|
||||
@@ -73,7 +146,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_speculate_cache_rope_kernel<T, PackSize>
|
||||
<<<grid_size, threads_per_block, 0, stream>>>(
|
||||
@@ -96,7 +170,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,7 +200,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style) {
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
@@ -167,7 +243,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_speculate_cache_int8_rope_kernel<T, 4, 0, 128, QKV_TYPE, IsFP8>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
@@ -191,7 +268,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,7 +300,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style) {
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d) {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
@@ -266,7 +345,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_speculate_cache_int4_rope_kernel<T, 4>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
@@ -292,7 +372,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
|
||||
block_size,
|
||||
7.0f,
|
||||
-8.0f,
|
||||
kv_num_heads);
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
template <typename T, typename QKV_TYPE>
|
||||
@@ -313,11 +394,15 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out) {
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
typedef cascade_attn_type_traits<T> traits_;
|
||||
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
|
||||
typedef typename traits_::type DataType_;
|
||||
@@ -342,142 +427,185 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||
}
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
append_speculate_cache_int8_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_speculate_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style);
|
||||
if (q_norm_weight && k_norm_weight) {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope_qk_norm(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
reinterpret_cast<const float*>(q_norm_weight.get().data<float>()),
|
||||
reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
||||
}
|
||||
|
||||
} else {
|
||||
PD_THROW(
|
||||
"cache_quant_type_str should be one of [none, cache_int8, "
|
||||
"cache_int4_zp]");
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
append_speculate_cache_int8_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_speculate_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"cache_quant_type_str should be one of [none, cache_int8, "
|
||||
"cache_int4_zp]");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -500,11 +628,15 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
||||
template void
|
||||
SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
@@ -526,11 +658,15 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
||||
template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
@@ -551,11 +687,15 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
||||
|
||||
template void
|
||||
@@ -578,8 +718,12 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
@@ -35,8 +35,12 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
const paddle::optional<paddle::Tensor>& cache_v_zp,
|
||||
const std::string& cache_quant_type_str,
|
||||
const bool use_neox_rotary_style,
|
||||
const bool rope_3d,
|
||||
const int max_seq_len,
|
||||
cudaStream_t& stream,
|
||||
paddle::Tensor* qkv_out,
|
||||
paddle::Tensor* key_cache_out,
|
||||
paddle::Tensor* value_cache_out);
|
||||
paddle::Tensor* value_cache_out,
|
||||
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
@@ -378,9 +378,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size);
|
||||
|
||||
|
||||
const paddle::optional<paddle::Tensor> &draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
|
||||
const int block_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
paddle::Tensor
|
||||
GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,
|
||||
@@ -707,6 +709,22 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
|
||||
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
|
||||
const paddle::Tensor& seq_lens_decoder);
|
||||
|
||||
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &step_draft_tokens,
|
||||
const paddle::Tensor &step_seq_lens_this_time,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const int block_size,
|
||||
const int max_draft_tokens);
|
||||
|
||||
void NgramMatch(const paddle::Tensor &input_ids,
|
||||
const paddle::Tensor &input_ids_len,
|
||||
const paddle::Tensor &pre_ids,
|
||||
@@ -750,6 +768,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
@@ -763,7 +782,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const int max_draft_token,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill);
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1);
|
||||
|
||||
|
||||
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
|
||||
@@ -1228,6 +1248,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function");
|
||||
|
||||
m.def("speculate_schedule_cache",&SpeculateScheduleCache, "SpeculateScheduleCache function");
|
||||
|
||||
m.def("ngram_match", &NgramMatch, "ngram_match function");
|
||||
|
||||
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");
|
||||
|
@@ -15,31 +15,72 @@
|
||||
#include "helper.h"
|
||||
|
||||
__global__ void recover_decode_task(bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
int thread_idx = threadIdx.x;
|
||||
if (thread_idx < bsz) {
|
||||
if(is_block_step[thread_idx] == true) {
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
if (block_table_now[step_seq_lens_decoder[thread_idx] / block_size] != -1) {
|
||||
// can be recovered for decoding
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx]= 1;
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
}
|
||||
// can be recovered for decoding
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx]= 1;
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void recover_spec_decode_task(bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int *block_tables,
|
||||
bool *is_block_step,
|
||||
int64_t *draft_tokens,
|
||||
const int64_t *step_draft_tokens,
|
||||
const int *step_seq_lens_this_time,
|
||||
const int bsz,
|
||||
const int block_num_per_seq,
|
||||
const int block_size,
|
||||
const int draft_tokens_len,
|
||||
const int num_extra_tokens) {
|
||||
int thread_idx = threadIdx.x;
|
||||
if (thread_idx < bsz) {
|
||||
if(is_block_step[thread_idx] == true) {
|
||||
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
|
||||
int max_possible_block_idx = (step_seq_lens_decoder[thread_idx] + num_extra_tokens) / block_size;
|
||||
max_possible_block_idx = min(max_possible_block_idx, block_num_per_seq);
|
||||
if (block_table_now[max_possible_block_idx] != -1) {
|
||||
// can be recovered for decoding
|
||||
int64_t *draft_tokens_now = draft_tokens + thread_idx * draft_tokens_len;
|
||||
const int64_t *step_draft_tokens_now = step_draft_tokens + thread_idx * draft_tokens_len;
|
||||
is_block_step[thread_idx] = false;
|
||||
seq_lens_this_time[thread_idx] = step_seq_lens_this_time[thread_idx];
|
||||
stop_flags[thread_idx] = false;
|
||||
seq_lens_encoder[thread_idx] = 0;
|
||||
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
|
||||
for (int i = 0; i < seq_lens_this_time[thread_idx]; i++) {
|
||||
draft_tokens_now[i] = step_draft_tokens_now[i];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_encoder,
|
||||
@@ -47,7 +88,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const int block_size) {
|
||||
const paddle::optional<paddle::Tensor> &draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_draft_tokens,
|
||||
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
|
||||
const int block_size,
|
||||
const int max_draft_tokens) {
|
||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE
|
||||
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place()));
|
||||
auto cu_stream = dev_ctx->stream();
|
||||
@@ -56,17 +101,38 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
|
||||
#endif
|
||||
const int bsz = seq_lens_this_time.shape()[0];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
recover_decode_task<<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
if (draft_tokens) {
|
||||
const int draft_tokens_len = draft_tokens.get_ptr()->shape()[1];
|
||||
recover_spec_decode_task<<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<int64_t *>(draft_tokens.get_ptr()->data<int64_t>()),
|
||||
step_draft_tokens.get_ptr()->data<int64_t>(),
|
||||
step_seq_lens_this_time.get_ptr()->data<int>(),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size,
|
||||
draft_tokens_len,
|
||||
max_draft_tokens * 2 + 1);
|
||||
|
||||
} else {
|
||||
recover_decode_task<<<1, 1024, 0, cu_stream>>>(
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
bsz,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
}
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(recover_decode_task)
|
||||
@@ -76,8 +142,11 @@ PD_BUILD_STATIC_OP(recover_decode_task)
|
||||
"seq_lens_decoder",
|
||||
"step_seq_lens_decoder",
|
||||
"block_tables",
|
||||
"is_block_step"})
|
||||
.Attrs({"block_size: int"})
|
||||
"is_block_step",
|
||||
paddle::Optional("draft_tokens"),
|
||||
paddle::Optional("step_draft_tokens"),
|
||||
paddle::Optional("step_seq_lens_this_time")})
|
||||
.Attrs({"block_size: int", "max_draft_tokens: int"})
|
||||
.Outputs({"seq_lens_this_time_out",
|
||||
"seq_lens_encoder_out",
|
||||
"seq_lens_decoder_out",
|
||||
|
@@ -15,7 +15,48 @@
|
||||
#include "helper.h"
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>
|
||||
|
||||
#define DISPATCH_BLOCKSIZE(BLOCK_SIZE, ...) \
|
||||
do { \
|
||||
constexpr int BlockSize = BLOCK_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, ...) \
|
||||
do { \
|
||||
if (truncate_first_token) { \
|
||||
constexpr bool TRUNCATE_FIRST_TOKEN = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool TRUNCATE_FIRST_TOKEN = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, ...) \
|
||||
do { \
|
||||
if (kvcache_scheduler_v1) { \
|
||||
constexpr bool KVCACHE_SCHEDULER_V1 = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool KVCACHE_SCHEDULER_V1 = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, ...) \
|
||||
do { \
|
||||
if (splitwise_prefill) { \
|
||||
constexpr bool SPLITWISE_PREFILL = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool SPLITWISE_PREFILL = false; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
|
||||
__global__ void process_splitwise_prefill(
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
@@ -25,6 +66,7 @@ __global__ void process_splitwise_prefill(
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
@@ -58,7 +100,7 @@ __global__ void process_splitwise_prefill(
|
||||
stop_flags[tid] = false;
|
||||
int64_t base_model_first_token = accept_tokens_now[0];
|
||||
int position = seq_len_encoder;
|
||||
if (TRCUNCATE_FIRST_TOKEN) {
|
||||
if (TRUNCATE_FIRST_TOKEN) {
|
||||
input_ids_now[position - 1] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder;
|
||||
} else {
|
||||
@@ -84,7 +126,7 @@ __global__ void process_splitwise_prefill(
|
||||
|
||||
|
||||
|
||||
template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>
|
||||
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
|
||||
__global__ void draft_model_preprocess_kernel(
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
@@ -94,6 +136,7 @@ __global__ void draft_model_preprocess_kernel(
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
@@ -134,14 +177,26 @@ __global__ void draft_model_preprocess_kernel(
|
||||
base_model_draft_tokens_now[i] = -1;
|
||||
}
|
||||
|
||||
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
|
||||
batch_drop[tid] = true;
|
||||
stop_flags[tid] = true;
|
||||
// 1. process block_step situation
|
||||
// -- In v0 mode, block_step will drop mtp query.
|
||||
// -- In v1 mode, block_step will continue to infer.
|
||||
if constexpr(KVCACHE_SCHEDULER_V1) {
|
||||
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
|
||||
stop_flags[tid] = true;
|
||||
is_block_step[tid] = true;
|
||||
// Need to continue infer
|
||||
}
|
||||
} else {
|
||||
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
|
||||
batch_drop[tid] = true;
|
||||
stop_flags[tid] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// 2. process normal query, not in any special case.
|
||||
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
|
||||
not_stop_flag = 1;
|
||||
// 1. first token
|
||||
// prefill generation
|
||||
if (seq_lens_encoder[tid] > 0) {
|
||||
// Can be extended to first few tokens
|
||||
int seq_len_encoder = seq_lens_encoder[tid];
|
||||
@@ -149,14 +204,20 @@ __global__ void draft_model_preprocess_kernel(
|
||||
int64_t base_model_first_token = accept_tokens_now[0];
|
||||
pre_ids_now[0] = base_model_first_token;
|
||||
int position = seq_len_encoder;
|
||||
if (TRCUNCATE_FIRST_TOKEN) {
|
||||
if (TRUNCATE_FIRST_TOKEN) {
|
||||
input_ids_now[position - 1] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder;
|
||||
} else {
|
||||
input_ids_now[position] = base_model_first_token;
|
||||
seq_lens_this_time[tid] = seq_len_encoder + 1;
|
||||
}
|
||||
} else {
|
||||
} else { // decode generation
|
||||
if constexpr (KVCACHE_SCHEDULER_V1) {
|
||||
// 3. try to recover mtp infer in V1 mode
|
||||
if (!base_model_is_block_step[tid] && is_block_step[tid]) {
|
||||
is_block_step[tid] = false;
|
||||
}
|
||||
}
|
||||
if (stop_flags[tid]) {
|
||||
stop_flags[tid] = false;
|
||||
// TODO: check
|
||||
@@ -189,99 +250,8 @@ __global__ void draft_model_preprocess_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <bool TRCUNCATE_FIRST_TOKEN>
|
||||
void DispatchRunner(
|
||||
const cudaStream_t& stream,
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
bool* stop_flags,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
const int* accept_num,
|
||||
const int* base_model_seq_lens_this_time,
|
||||
const int* base_model_seq_lens_encoder,
|
||||
const int* base_model_seq_lens_decoder,
|
||||
const int64_t* base_model_step_idx,
|
||||
const bool* base_model_stop_flags,
|
||||
const bool* base_model_is_block_step,
|
||||
int64_t* base_model_draft_tokens,
|
||||
const int bsz,
|
||||
const int num_model_step,
|
||||
const int accept_tokens_len,
|
||||
const int draft_tokens_len,
|
||||
const int input_ids_len,
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len,
|
||||
const bool splitwise_prefill) {
|
||||
constexpr int BlockSize = 512;
|
||||
if (splitwise_prefill) {
|
||||
process_splitwise_prefill<BlockSize, TRCUNCATE_FIRST_TOKEN>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
} else {
|
||||
draft_model_preprocess_kernel<BlockSize, TRCUNCATE_FIRST_TOKEN>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
}
|
||||
}
|
||||
|
||||
void DispatchTokenMode(
|
||||
void DispatchRunner(
|
||||
const cudaStream_t &stream,
|
||||
int64_t* draft_tokens,
|
||||
int64_t* input_ids,
|
||||
@@ -291,6 +261,7 @@ void DispatchTokenMode(
|
||||
int* seq_lens_decoder,
|
||||
int64_t* step_idx,
|
||||
bool* not_need_stop,
|
||||
bool* is_block_step,
|
||||
bool* batch_drop,
|
||||
int64_t* pre_ids,
|
||||
const int64_t* accept_tokens,
|
||||
@@ -310,75 +281,79 @@ void DispatchTokenMode(
|
||||
const int base_model_draft_tokens_len,
|
||||
const int pre_ids_len,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill) {
|
||||
if (truncate_first_token) {
|
||||
DispatchRunner<true>(
|
||||
stream,
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
splitwise_prefill
|
||||
);
|
||||
} else {
|
||||
DispatchRunner<false>(
|
||||
stream,
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
splitwise_prefill
|
||||
);
|
||||
}
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1) {
|
||||
DISPATCH_BLOCKSIZE(512, {
|
||||
DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, {
|
||||
DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, {
|
||||
DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, {
|
||||
if constexpr (SPLITWISE_PREFILL) {
|
||||
process_splitwise_prefill<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
} else {
|
||||
draft_model_preprocess_kernel<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
|
||||
<<<1, BlockSize, 0, stream>>>(
|
||||
draft_tokens,
|
||||
input_ids,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_idx,
|
||||
not_need_stop,
|
||||
is_block_step,
|
||||
batch_drop,
|
||||
pre_ids,
|
||||
accept_tokens,
|
||||
accept_num,
|
||||
base_model_seq_lens_this_time,
|
||||
base_model_seq_lens_encoder,
|
||||
base_model_seq_lens_decoder,
|
||||
base_model_step_idx,
|
||||
base_model_stop_flags,
|
||||
base_model_is_block_step,
|
||||
base_model_draft_tokens,
|
||||
bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& input_ids,
|
||||
const paddle::Tensor& stop_flags,
|
||||
@@ -387,6 +362,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& seq_lens_decoder,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& not_need_stop,
|
||||
const paddle::Tensor& is_block_step,
|
||||
const paddle::Tensor& batch_drop,
|
||||
const paddle::Tensor& pre_ids,
|
||||
const paddle::Tensor& accept_tokens,
|
||||
@@ -400,7 +376,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
const paddle::Tensor& base_model_draft_tokens,
|
||||
const int num_model_step,
|
||||
const bool truncate_first_token,
|
||||
const bool splitwise_prefill) {
|
||||
const bool splitwise_prefill,
|
||||
const bool kvcache_scheduler_v1) {
|
||||
int real_bsz = seq_lens_this_time.shape()[0];
|
||||
int accept_tokens_len = accept_tokens.shape()[1];
|
||||
int input_ids_len = input_ids.shape()[1];
|
||||
@@ -412,36 +389,38 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
|
||||
auto not_need_stop_gpu =
|
||||
not_need_stop.copy_to(seq_lens_this_time.place(), false);
|
||||
|
||||
DispatchTokenMode(
|
||||
cu_stream,
|
||||
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<bool*>(batch_drop.data<bool>()),
|
||||
const_cast<int64_t*>(pre_ids.data<int64_t>()),
|
||||
accept_tokens.data<int64_t>(),
|
||||
accept_num.data<int>(),
|
||||
base_model_seq_lens_this_time.data<int>(),
|
||||
base_model_seq_lens_encoder.data<int>(),
|
||||
base_model_seq_lens_decoder.data<int>(),
|
||||
base_model_step_idx.data<int64_t>(),
|
||||
base_model_stop_flags.data<bool>(),
|
||||
base_model_is_block_step.data<bool>(),
|
||||
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
|
||||
real_bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
truncate_first_token,
|
||||
splitwise_prefill);
|
||||
DispatchRunner(
|
||||
cu_stream,
|
||||
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
|
||||
const_cast<int64_t*>(input_ids.data<int64_t>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
const_cast<int*>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int*>(seq_lens_encoder.data<int>()),
|
||||
const_cast<int*>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t*>(step_idx.data<int64_t>()),
|
||||
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
|
||||
const_cast<bool*>(is_block_step.data<bool>()),
|
||||
const_cast<bool*>(batch_drop.data<bool>()),
|
||||
const_cast<int64_t*>(pre_ids.data<int64_t>()),
|
||||
accept_tokens.data<int64_t>(),
|
||||
accept_num.data<int>(),
|
||||
base_model_seq_lens_this_time.data<int>(),
|
||||
base_model_seq_lens_encoder.data<int>(),
|
||||
base_model_seq_lens_decoder.data<int>(),
|
||||
base_model_step_idx.data<int64_t>(),
|
||||
base_model_stop_flags.data<bool>(),
|
||||
base_model_is_block_step.data<bool>(),
|
||||
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
|
||||
real_bsz,
|
||||
num_model_step,
|
||||
accept_tokens_len,
|
||||
draft_tokens_len,
|
||||
input_ids_len,
|
||||
base_model_draft_tokens_len,
|
||||
pre_ids_len,
|
||||
truncate_first_token,
|
||||
splitwise_prefill,
|
||||
kvcache_scheduler_v1);
|
||||
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
|
||||
@@ -459,6 +438,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
"seq_lens_decoder",
|
||||
"step_idx",
|
||||
"not_need_stop",
|
||||
"is_block_step",
|
||||
"batch_drop",
|
||||
"pre_ids",
|
||||
"accept_tokens",
|
||||
@@ -480,7 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
|
||||
"not_need_stop_out",
|
||||
"batch_drop_out",
|
||||
"pre_ids_out"})
|
||||
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
|
||||
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool", "kvcache_scheduler_v1: bool"})
|
||||
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
|
||||
{"input_ids", "input_ids_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
|
@@ -0,0 +1,176 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__global__ void speculate_schedula_cache(
|
||||
const int64_t *draft_tokens,
|
||||
int *block_tables,
|
||||
bool *stop_flags,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *step_draft_tokens,
|
||||
int *step_seq_lens_this_time,
|
||||
int *accept_num,
|
||||
int64_t *accept_tokens,
|
||||
bool *is_block_step,
|
||||
bool *not_need_stop,
|
||||
const int64_t *stop_nums,
|
||||
const int real_bsz,
|
||||
const int max_bsz,
|
||||
const int max_next_step_tokens,
|
||||
const int draft_tokens_len,
|
||||
const int accept_tokens_len,
|
||||
const int block_size,
|
||||
const int block_num_per_seq) {
|
||||
const int bid = threadIdx.x;
|
||||
int stop_flag_now_int = 0;
|
||||
if (bid < real_bsz) {
|
||||
if (!stop_flags[bid]) {
|
||||
const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
|
||||
int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
|
||||
int *block_table_now = block_tables + bid * block_num_per_seq;
|
||||
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
|
||||
const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
|
||||
if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
|
||||
is_block_step[bid] = true;
|
||||
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
|
||||
seq_lens_this_time[bid] = 0;
|
||||
stop_flags[bid] = true;
|
||||
stop_flag_now_int = 1;
|
||||
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
|
||||
seq_lens_decoder[bid] = 0;
|
||||
accept_num[bid] = 0;
|
||||
for (int i = 0; i < accept_tokens_len; i++) {
|
||||
accept_tokens_now[i] = -1;
|
||||
}
|
||||
for (int i = 0; i < draft_tokens_len; i++) {
|
||||
step_draft_tokens_now[i] = draft_tokens_now[i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
} else if (bid >= real_bsz && bid < max_bsz) {
|
||||
stop_flag_now_int = 1;
|
||||
}
|
||||
__syncthreads();
|
||||
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
|
||||
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// printf("stop_sum %d \n", stop_sum);
|
||||
not_need_stop[0] = stop_sum < stop_nums[0];
|
||||
}
|
||||
}
|
||||
|
||||
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
|
||||
const paddle::Tensor &block_tables,
|
||||
const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens_this_time,
|
||||
const paddle::Tensor &seq_lens_decoder,
|
||||
const paddle::Tensor &step_seq_lens_decoder,
|
||||
const paddle::Tensor &step_draft_tokens,
|
||||
const paddle::Tensor &step_seq_lens_this_time,
|
||||
const paddle::Tensor &accept_num,
|
||||
const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &is_block_step,
|
||||
const paddle::Tensor ¬_need_stop,
|
||||
const paddle::Tensor &stop_nums,
|
||||
const int block_size,
|
||||
const int max_draft_tokens) {
|
||||
const int real_bsz = seq_lens_this_time.shape()[0];
|
||||
const int max_bsz = stop_flags.shape()[0];
|
||||
const int accept_tokens_len = accept_tokens.shape()[1];
|
||||
const int draft_token_len = draft_tokens.shape()[1];
|
||||
const int block_num_per_seq = block_tables.shape()[1];
|
||||
|
||||
constexpr int BlockSize = 512;
|
||||
const int max_next_step_tokens = 2 * max_draft_tokens + 2;
|
||||
|
||||
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
|
||||
speculate_schedula_cache<BlockSize><<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
|
||||
draft_tokens.data<int64_t>(),
|
||||
const_cast<int *>(block_tables.data<int>()),
|
||||
const_cast<bool *>(stop_flags.data<bool>()),
|
||||
const_cast<int *>(seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(seq_lens_decoder.data<int>()),
|
||||
const_cast<int *>(step_seq_lens_decoder.data<int>()),
|
||||
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
|
||||
const_cast<int *>(step_seq_lens_this_time.data<int>()),
|
||||
const_cast<int *>(accept_num.data<int>()),
|
||||
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
|
||||
const_cast<bool *>(is_block_step.data<bool>()),
|
||||
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
|
||||
stop_nums.data<int64_t>(),
|
||||
real_bsz,
|
||||
max_bsz,
|
||||
max_next_step_tokens,
|
||||
draft_token_len,
|
||||
accept_tokens_len,
|
||||
block_size,
|
||||
block_num_per_seq
|
||||
);
|
||||
|
||||
auto not_need_stop_cpu =
|
||||
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
|
||||
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
|
||||
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(speculate_schedule_cache)
|
||||
.Inputs({"draft_tokens",
|
||||
"block_tables",
|
||||
"stop_flags",
|
||||
"seq_lens_this_time",
|
||||
"seq_lens_decoder",
|
||||
"step_seq_lens_decoder",
|
||||
"step_draft_tokens",
|
||||
"step_seq_lens_this_time",
|
||||
"accept_num",
|
||||
"accept_tokens",
|
||||
"is_block_step",
|
||||
"not_need_stop",
|
||||
"stop_nums"})
|
||||
.Attrs({"block_size: int", "max_draft_tokens: int"})
|
||||
.Outputs({"draft_tokens_out",
|
||||
"block_tables_out",
|
||||
"stop_flags_out",
|
||||
"seq_lens_this_time_out",
|
||||
"seq_lens_decoder_out",
|
||||
"step_seq_lens_decoder_out",
|
||||
"step_draft_tokens_out",
|
||||
"step_seq_lens_this_time_out",
|
||||
"accept_num_out",
|
||||
"accept_tokens_out",
|
||||
"is_block_step_out",
|
||||
"not_need_stop_out"})
|
||||
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
|
||||
{"block_tables", "block_tables_out"},
|
||||
{"stop_flags", "stop_flags_out"},
|
||||
{"seq_lens_this_time", "seq_lens_this_time_out"},
|
||||
{"seq_lens_decoder", "seq_lens_decoder_out"},
|
||||
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
|
||||
{"step_draft_tokens", "step_draft_tokens_out"},
|
||||
{"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
|
||||
{"accept_num", "accept_num_out"},
|
||||
{"accept_tokens", "accept_tokens_out"},
|
||||
{"is_block_step", "is_block_step_out"},
|
||||
{"not_need_stop", "not_need_stop_out"},})
|
||||
.SetKernelFn(PD_KERNEL(SpeculateScheduleCache));
|
@@ -889,7 +889,7 @@ class CacheConfig:
|
||||
else:
|
||||
self.kv_cache_ratio = 0.75
|
||||
self.enc_dec_block_num = 0 if current_platform.is_iluvatar() else 2
|
||||
self.prealloc_dec_block_slot_num_threshold = 5
|
||||
self.prealloc_dec_block_slot_num_threshold = 12
|
||||
self.cache_dtype = "bfloat16"
|
||||
self.model_cfg = None
|
||||
self.enable_chunked_prefill = False
|
||||
|
@@ -165,8 +165,7 @@ class EngineArgs:
|
||||
"""
|
||||
Ratio of tokens to process in a block.
|
||||
"""
|
||||
|
||||
prealloc_dec_block_slot_num_threshold: int = 5
|
||||
prealloc_dec_block_slot_num_threshold: int = 12
|
||||
"""
|
||||
Token slot threshold for preallocating decoder blocks.
|
||||
"""
|
||||
@@ -405,8 +404,6 @@ class EngineArgs:
|
||||
raise NotImplementedError("Logprob does not support enable_expert_parallel.")
|
||||
if not current_platform.is_cuda():
|
||||
raise NotImplementedError("Only CUDA platform supports logprob.")
|
||||
if self.speculative_config is not None:
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
|
||||
if self.splitwise_role != "mixed":
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
|
||||
if not current_platform.is_cuda():
|
||||
@@ -706,7 +703,7 @@ class EngineArgs:
|
||||
cache_group.add_argument(
|
||||
"--prealloc-dec-block-slot-num-threshold",
|
||||
type=int,
|
||||
default=5,
|
||||
default=12,
|
||||
help="Number of token slot threadshold to allocate next blocks for decoding.",
|
||||
)
|
||||
|
||||
|
@@ -62,7 +62,6 @@ class EngineSevice:
|
||||
self.cfg = cfg
|
||||
|
||||
self.scheduler = cfg.scheduler_config.scheduler()
|
||||
|
||||
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
self.resource_manager = ResourceManagerV1(
|
||||
cfg.max_num_seqs,
|
||||
|
@@ -84,10 +84,14 @@ class ResourceManagerV1(ResourceManager):
|
||||
return len(request.block_tables) * self.config.cache_config.block_size
|
||||
|
||||
def get_new_block_nums(self, request: Request, num_new_tokens: int):
|
||||
return (
|
||||
block_num = (
|
||||
request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
|
||||
) // self.config.cache_config.block_size - len(request.block_tables)
|
||||
|
||||
if self.config.speculative_config.method is not None:
|
||||
block_num = min(block_num + 1, self.config.cache_config.max_block_num_per_seq)
|
||||
return block_num
|
||||
|
||||
def _prepare_prefill_task(self, request, new_token_num):
|
||||
request.prefill_start_index = request.num_computed_tokens
|
||||
request.prefill_end_index = request.num_computed_tokens + new_token_num
|
||||
|
@@ -100,6 +100,8 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
|
||||
fd_config.model_config, "use_3d_rope", False
|
||||
)
|
||||
if fd_config.speculative_config.model_type != "main":
|
||||
self.rope_3d = False
|
||||
self.causal: bool = getattr(fd_config.model_config, "causal", True)
|
||||
self.speculative_method: str = fd_config.speculative_config.method
|
||||
self.use_speculate: bool = self.speculative_method is not None
|
||||
@@ -356,7 +358,7 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
getattr(layer, "cache_v_zp", None),
|
||||
layer.linear_shift,
|
||||
layer.linear_smooth,
|
||||
forward_meta.attn_mask_offsets,
|
||||
None if self.use_speculate else forward_meta.attn_mask_offsets,
|
||||
metadata.kv_signal_data_list[layer.layer_id],
|
||||
getattr(layer, "q_norm_weight", None),
|
||||
getattr(layer, "k_norm_weight", None),
|
||||
@@ -374,7 +376,7 @@ class AppendAttentionBackend(AttentionBackend):
|
||||
metadata.max_partition_size,
|
||||
metadata.encoder_max_partition_size,
|
||||
self.speculate_max_draft_token_num + 1,
|
||||
self.causal,
|
||||
self.causal or self.use_speculate,
|
||||
self.speculative_method is not None,
|
||||
)
|
||||
return res
|
||||
|
@@ -306,7 +306,9 @@ def post_process_normal(
|
||||
)
|
||||
|
||||
|
||||
def post_process_specualate(model_output, save_each_rank: bool = False, skip_save_output: bool = False):
|
||||
def post_process_specualate(
|
||||
model_output: ModelOutputData, save_each_rank: bool = False, skip_save_output: bool = False
|
||||
):
|
||||
""""""
|
||||
speculate_update(
|
||||
model_output.seq_lens_encoder,
|
||||
|
@@ -261,7 +261,7 @@ class TokenProcessor:
|
||||
|
||||
def _compute_speculative_status(self):
|
||||
# TODO(liuzichang): Supplement more statistics
|
||||
interval = 10
|
||||
interval = 1
|
||||
if self.speculative_stats_step % interval == 0:
|
||||
accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens
|
||||
spec_logger.info(
|
||||
|
@@ -19,8 +19,10 @@ from typing import List
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddleformers.utils.log import logger
|
||||
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.engine.request import Request, RequestType
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
from fastdeploy.model_executor.layers.attention import get_attention_backend
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
@@ -50,14 +52,14 @@ class MTPProposer(Proposer):
|
||||
Proposer for Multi-Token-Prediction(MTP)
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, main_model, local_rank, device_id, main_model_inputs):
|
||||
def __init__(self, cfg, main_model, local_rank, device_id, target_model_inputs):
|
||||
super().__init__(cfg)
|
||||
self.num_main_model_layers = self.model_config.num_hidden_layers
|
||||
self.local_rank = local_rank
|
||||
self.device_id = device_id
|
||||
self._update_cfg(main_model)
|
||||
self._load_model()
|
||||
self.main_model_inputs = main_model_inputs
|
||||
self.target_model_inputs = target_model_inputs
|
||||
self.mtp_strategy = self.speculative_config.mtp_strategy
|
||||
self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
|
||||
|
||||
@@ -73,7 +75,7 @@ class MTPProposer(Proposer):
|
||||
"""
|
||||
Update config for MTP from global config
|
||||
"""
|
||||
self.model_config.architectures[0] = "Ernie4_5_MTPForCausalLM"
|
||||
self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP")
|
||||
self.speculative_config.sharing_model = main_model
|
||||
self.model_config.num_hidden_layers = 1
|
||||
self.model_config.model = self.speculative_config.model
|
||||
@@ -199,14 +201,16 @@ class MTPProposer(Proposer):
|
||||
encoder_block_shape_q = 64
|
||||
decoder_block_shape_q = 16
|
||||
|
||||
self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.main_model_inputs["decoder_batch_ids"])
|
||||
self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["decoder_batch_ids"])
|
||||
self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
|
||||
self.main_model_inputs["decoder_tile_ids_per_batch"]
|
||||
self.target_model_inputs["decoder_tile_ids_per_batch"]
|
||||
)
|
||||
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
|
||||
self.main_model_inputs["decoder_num_blocks_cpu"]
|
||||
self.target_model_inputs["decoder_num_blocks_cpu"]
|
||||
).pin_memory()
|
||||
self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(self.main_model_inputs["max_len_tensor_cpu"]).cpu()
|
||||
self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(
|
||||
self.target_model_inputs["max_len_tensor_cpu"]
|
||||
).cpu()
|
||||
|
||||
# Get the attention backend
|
||||
attn_cls = get_attention_backend()
|
||||
@@ -265,28 +269,29 @@ class MTPProposer(Proposer):
|
||||
"""
|
||||
self.model_inputs = {}
|
||||
# Same shape/dytpe with base model
|
||||
self.model_inputs["block_tables"] = paddle.clone(self.main_model_inputs["block_tables"])
|
||||
self.model_inputs["input_ids"] = paddle.clone(self.main_model_inputs["input_ids"])
|
||||
self.seq_lens_this_time_buffer = paddle.clone(self.main_model_inputs["seq_lens_this_time"])
|
||||
self.model_inputs["block_tables"] = paddle.clone(self.target_model_inputs["block_tables"])
|
||||
self.model_inputs["input_ids"] = paddle.clone(self.target_model_inputs["input_ids"])
|
||||
self.model_inputs["input_ids_cpu"] = paddle.full(
|
||||
shape=[self.max_num_seqs, self.parallel_config.max_model_len],
|
||||
fill_value=-1,
|
||||
dtype="int64",
|
||||
).cpu()
|
||||
self.model_inputs["seq_lens_encoder"] = paddle.clone(self.main_model_inputs["seq_lens_encoder"])
|
||||
self.model_inputs["seq_lens_decoder"] = paddle.clone(self.main_model_inputs["seq_lens_decoder"])
|
||||
self.model_inputs["step_idx"] = paddle.clone(self.main_model_inputs["step_idx"])
|
||||
self.model_inputs["stop_flags"] = paddle.clone(self.main_model_inputs["stop_flags"])
|
||||
self.model_inputs["stop_nums"] = paddle.clone(self.main_model_inputs["stop_nums"])
|
||||
self.seq_lens_this_time_buffer = paddle.clone(self.target_model_inputs["seq_lens_this_time"])
|
||||
|
||||
self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"])
|
||||
self.model_inputs["seq_lens_decoder"] = paddle.clone(self.target_model_inputs["seq_lens_decoder"])
|
||||
self.model_inputs["step_idx"] = paddle.clone(self.target_model_inputs["step_idx"])
|
||||
self.model_inputs["stop_flags"] = paddle.clone(self.target_model_inputs["stop_flags"])
|
||||
self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
|
||||
self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu")
|
||||
self.model_inputs["pre_ids"] = paddle.clone(self.main_model_inputs["pre_ids"])
|
||||
self.model_inputs["ids_remove_padding"] = paddle.clone(self.main_model_inputs["ids_remove_padding"])
|
||||
self.model_inputs["batch_id_per_token"] = paddle.clone(self.main_model_inputs["batch_id_per_token"])
|
||||
self.model_inputs["cu_seqlens_q"] = paddle.clone(self.main_model_inputs["cu_seqlens_q"])
|
||||
self.model_inputs["cu_seqlens_k"] = paddle.clone(self.main_model_inputs["cu_seqlens_k"])
|
||||
self.model_inputs["decoder_batch_ids"] = paddle.clone(self.main_model_inputs["decoder_batch_ids"])
|
||||
self.model_inputs["pre_ids"] = paddle.clone(self.target_model_inputs["pre_ids"])
|
||||
self.model_inputs["ids_remove_padding"] = paddle.clone(self.target_model_inputs["ids_remove_padding"])
|
||||
self.model_inputs["batch_id_per_token"] = paddle.clone(self.target_model_inputs["batch_id_per_token"])
|
||||
self.model_inputs["cu_seqlens_q"] = paddle.clone(self.target_model_inputs["cu_seqlens_q"])
|
||||
self.model_inputs["cu_seqlens_k"] = paddle.clone(self.target_model_inputs["cu_seqlens_k"])
|
||||
self.model_inputs["decoder_batch_ids"] = paddle.clone(self.target_model_inputs["decoder_batch_ids"])
|
||||
self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone(
|
||||
self.main_model_inputs["decoder_tile_ids_per_batch"]
|
||||
self.target_model_inputs["decoder_tile_ids_per_batch"]
|
||||
)
|
||||
|
||||
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
|
||||
@@ -298,22 +303,22 @@ class MTPProposer(Proposer):
|
||||
)
|
||||
# self.model_inputs["caches"] = self.cache_kvs
|
||||
# Inherit generation hyperparameters from the main model for consistency
|
||||
self.model_inputs["top_p"] = self.main_model_inputs["top_p"]
|
||||
self.model_inputs["top_k"] = self.main_model_inputs["top_k"]
|
||||
self.model_inputs["temperature"] = self.main_model_inputs["temperature"]
|
||||
self.model_inputs["eos_token_id"] = self.main_model_inputs["eos_token_id"]
|
||||
self.model_inputs["penalty_score"] = self.main_model_inputs["penalty_score"]
|
||||
self.model_inputs["frequency_score"] = self.main_model_inputs["frequency_score"]
|
||||
self.model_inputs["presence_score"] = self.main_model_inputs["presence_score"]
|
||||
self.model_inputs["infer_seed"] = self.main_model_inputs["infer_seed"]
|
||||
self.model_inputs["top_p"] = self.target_model_inputs["top_p"]
|
||||
self.model_inputs["top_k"] = self.target_model_inputs["top_k"]
|
||||
self.model_inputs["temperature"] = self.target_model_inputs["temperature"]
|
||||
self.model_inputs["eos_token_id"] = self.target_model_inputs["eos_token_id"]
|
||||
self.model_inputs["penalty_score"] = self.target_model_inputs["penalty_score"]
|
||||
self.model_inputs["frequency_score"] = self.target_model_inputs["frequency_score"]
|
||||
self.model_inputs["presence_score"] = self.target_model_inputs["presence_score"]
|
||||
self.model_inputs["infer_seed"] = self.target_model_inputs["infer_seed"]
|
||||
|
||||
self.model_inputs["max_dec_len"] = self.main_model_inputs["max_dec_len"]
|
||||
self.model_inputs["min_dec_len"] = self.main_model_inputs["min_dec_len"]
|
||||
self.model_inputs["max_dec_len"] = self.target_model_inputs["max_dec_len"]
|
||||
self.model_inputs["min_dec_len"] = self.target_model_inputs["min_dec_len"]
|
||||
|
||||
self.model_inputs["bad_tokens"] = self.main_model_inputs["bad_tokens"]
|
||||
self.model_inputs["bad_tokens"] = self.target_model_inputs["bad_tokens"]
|
||||
|
||||
# Integrate the updated results in model forward
|
||||
self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"]
|
||||
self.model_inputs["base_model_draft_tokens"] = self.target_model_inputs["draft_tokens"]
|
||||
self.model_inputs["substep"] = 0
|
||||
|
||||
# Declare AttentionBackend buffers
|
||||
@@ -327,7 +332,7 @@ class MTPProposer(Proposer):
|
||||
shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64"
|
||||
)
|
||||
|
||||
self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"])
|
||||
self.model_inputs["encoder_block_lens"] = paddle.clone(self.target_model_inputs["encoder_block_lens"])
|
||||
|
||||
self.free_list = list(
|
||||
range(
|
||||
@@ -341,14 +346,76 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["free_list"] = paddle.to_tensor(self.free_list, dtype="int32")
|
||||
self.model_inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.free_list_len, dtype="int32")
|
||||
|
||||
self.model_inputs["is_block_step"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
|
||||
self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
|
||||
self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
|
||||
if self.num_model_steps > 1:
|
||||
self.last_seq_lens_this_time = paddle.full_like(
|
||||
self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
|
||||
self.target_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
|
||||
)
|
||||
self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
|
||||
|
||||
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
|
||||
|
||||
if "caches" not in self.model_inputs:
|
||||
self.initialize_kv_cache()
|
||||
req_len = len(req_dicts)
|
||||
|
||||
for i in range(req_len):
|
||||
request = req_dicts[i]
|
||||
logger.debug(f"{i}th request-{request.request_id}: {request}")
|
||||
idx = request.idx
|
||||
if request.task_type.value == RequestType.PREFILL.value: # prefill task
|
||||
prefill_start_index = request.prefill_start_index
|
||||
prefill_end_index = request.prefill_end_index
|
||||
length = prefill_end_index - prefill_start_index
|
||||
|
||||
input_ids = request.prompt_token_ids + request.output_token_ids
|
||||
|
||||
self.input_ids_len[idx] = length
|
||||
self.model_inputs["pre_ids"][idx : idx + 1] = -1
|
||||
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
|
||||
idx : idx + 1, 1:length
|
||||
]
|
||||
encoder_block_num = len(request.block_tables)
|
||||
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
|
||||
request.block_tables, dtype="int32"
|
||||
)
|
||||
self.model_inputs["stop_flags"][idx : idx + 1] = False
|
||||
self.model_inputs["batch_drop"][idx : idx + 1] = False
|
||||
|
||||
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
|
||||
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
|
||||
self.seq_lens_this_time_buffer[idx : idx + 1] = length
|
||||
self.model_inputs["step_idx"][idx : idx + 1] = (
|
||||
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
|
||||
)
|
||||
|
||||
# has_prefill_task = True
|
||||
elif request.task_type.value == RequestType.DECODE.value: # decode task
|
||||
encoder_block_num = len(request.block_tables)
|
||||
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
|
||||
request.block_tables, dtype="int32"
|
||||
)
|
||||
# if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode
|
||||
# has_decode_task = True
|
||||
# continue
|
||||
else:
|
||||
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
|
||||
self.model_inputs["stop_flags"][idx : idx + 1] = True
|
||||
self.seq_lens_this_time_buffer[idx : idx + 1] = 0
|
||||
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
|
||||
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
|
||||
self.model_inputs["is_block_step"][idx : idx + 1] = False
|
||||
continue
|
||||
# if has_prefill_task or has_decode_task:
|
||||
# self.model_inputs["not_need_stop"][0] = True
|
||||
self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
|
||||
|
||||
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
|
||||
"""
|
||||
Process inputs for prefill tasks and insert it to model_inputs buffer
|
||||
@@ -408,9 +475,9 @@ class MTPProposer(Proposer):
|
||||
length = len(request.prompt_token_ids)
|
||||
|
||||
if length > 1:
|
||||
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.main_model_inputs["input_ids"][
|
||||
idx : idx + 1, 1:length
|
||||
]
|
||||
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
|
||||
"input_ids"
|
||||
][idx : idx + 1, 1:length]
|
||||
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
|
||||
request.prompt_token_ids
|
||||
)[1:]
|
||||
@@ -470,6 +537,7 @@ class MTPProposer(Proposer):
|
||||
"""
|
||||
Prepare MTP inputs
|
||||
"""
|
||||
use_v1_cache_scheduler = envs.ENABLE_V1_KVCACHE_SCHEDULER
|
||||
draft_model_preprocess(
|
||||
self.model_inputs["draft_tokens"],
|
||||
self.model_inputs["input_ids"],
|
||||
@@ -480,19 +548,21 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["step_idx"],
|
||||
self.model_inputs["not_need_stop"],
|
||||
self.model_inputs["batch_drop"],
|
||||
self.model_inputs["is_block_step"],
|
||||
self.model_inputs["pre_ids"],
|
||||
self.main_model_inputs["accept_tokens"],
|
||||
self.main_model_inputs["accept_num"],
|
||||
self.main_model_inputs["seq_lens_this_time"],
|
||||
self.main_model_inputs["seq_lens_encoder"],
|
||||
self.main_model_inputs["seq_lens_decoder"],
|
||||
self.main_model_inputs["step_idx"],
|
||||
self.main_model_inputs["stop_flags"],
|
||||
self.main_model_inputs["is_block_step"],
|
||||
self.main_model_inputs["draft_tokens"],
|
||||
self.target_model_inputs["accept_tokens"],
|
||||
self.target_model_inputs["accept_num"],
|
||||
self.target_model_inputs["seq_lens_this_time"],
|
||||
self.target_model_inputs["seq_lens_encoder"],
|
||||
self.target_model_inputs["seq_lens_decoder"],
|
||||
self.target_model_inputs["step_idx"],
|
||||
self.target_model_inputs["stop_flags"],
|
||||
self.target_model_inputs["is_block_step"],
|
||||
self.target_model_inputs["draft_tokens"],
|
||||
self.num_model_steps,
|
||||
self.speculative_method in ["eagle", "mtp"],
|
||||
self.role == "prefill",
|
||||
use_v1_cache_scheduler,
|
||||
)
|
||||
|
||||
target_hidden_states = eagle_get_hidden_states(
|
||||
@@ -501,9 +571,9 @@ class MTPProposer(Proposer):
|
||||
self.model_inputs["seq_lens_encoder"],
|
||||
self.model_inputs["seq_lens_decoder"],
|
||||
self.model_inputs["stop_flags"],
|
||||
self.main_model_inputs["accept_num"],
|
||||
self.main_model_inputs["seq_lens_this_time"],
|
||||
self.main_model_inputs["seq_lens_encoder"],
|
||||
self.target_model_inputs["accept_num"],
|
||||
self.target_model_inputs["seq_lens_this_time"],
|
||||
self.target_model_inputs["seq_lens_encoder"],
|
||||
self.num_model_steps,
|
||||
)
|
||||
if isinstance(target_hidden_states, list):
|
||||
@@ -673,41 +743,41 @@ class MTPProposer(Proposer):
|
||||
Allocate/Free block of MPT.
|
||||
"""
|
||||
draft_model_postprocess(
|
||||
self.main_model_inputs["draft_tokens"],
|
||||
self.main_model_inputs["seq_lens_this_time"],
|
||||
self.main_model_inputs["seq_lens_encoder"],
|
||||
self.main_model_inputs["stop_flags"],
|
||||
)
|
||||
|
||||
mtp_step_paddle(
|
||||
self.main_model_inputs["stop_flags"],
|
||||
self.model_inputs["stop_flags"],
|
||||
self.model_inputs["batch_drop"],
|
||||
self.model_inputs["seq_lens_this_time"],
|
||||
self.model_inputs["seq_lens_encoder"],
|
||||
self.model_inputs["seq_lens_decoder"],
|
||||
self.model_inputs["block_tables"],
|
||||
self.model_inputs["encoder_block_lens"],
|
||||
self.model_inputs["used_list_len"],
|
||||
self.model_inputs["free_list"],
|
||||
self.model_inputs["free_list_len"],
|
||||
self.cache_config.block_size,
|
||||
self.max_draft_token_num,
|
||||
self.target_model_inputs["draft_tokens"],
|
||||
self.target_model_inputs["seq_lens_this_time"],
|
||||
self.target_model_inputs["seq_lens_encoder"],
|
||||
self.target_model_inputs["stop_flags"],
|
||||
)
|
||||
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
|
||||
mtp_step_paddle(
|
||||
self.target_model_inputs["stop_flags"],
|
||||
self.model_inputs["stop_flags"],
|
||||
self.model_inputs["batch_drop"],
|
||||
self.model_inputs["seq_lens_this_time"],
|
||||
self.model_inputs["seq_lens_encoder"],
|
||||
self.model_inputs["seq_lens_decoder"],
|
||||
self.model_inputs["block_tables"],
|
||||
self.model_inputs["encoder_block_lens"],
|
||||
self.model_inputs["used_list_len"],
|
||||
self.model_inputs["free_list"],
|
||||
self.model_inputs["free_list_len"],
|
||||
self.cache_config.block_size,
|
||||
self.max_draft_token_num,
|
||||
)
|
||||
|
||||
def _extend_draft_token_with_ngram_match(self):
|
||||
# TODO(liuzichang): Optimize this Kernel to CUDA Kernel to reduce lantency
|
||||
device = paddle.CUDAPinnedPlace()
|
||||
|
||||
draft_tokens = self.main_model_inputs["draft_tokens"].cpu()
|
||||
seq_lens_this_time = self.main_model_inputs["seq_lens_this_time"].cpu()
|
||||
draft_tokens = self.target_model_inputs["draft_tokens"].cpu()
|
||||
seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu()
|
||||
seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
|
||||
hybrid_mtp_ngram(
|
||||
self.model_inputs["input_ids_cpu"],
|
||||
self.input_ids_len,
|
||||
self.model_inputs["pre_ids"]._copy_to(device, True),
|
||||
self.model_inputs["step_idx"].cpu(),
|
||||
self.main_model_inputs["actual_draft_token_num"].cpu(),
|
||||
self.target_model_inputs["actual_draft_token_num"].cpu(),
|
||||
draft_tokens,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
@@ -716,8 +786,8 @@ class MTPProposer(Proposer):
|
||||
self.min_ngram_size,
|
||||
self.max_draft_token_num,
|
||||
)
|
||||
self.main_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
|
||||
self.main_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
|
||||
self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
|
||||
self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
|
||||
|
||||
def _run_impl(self, full_hidden_states):
|
||||
""""""
|
||||
|
@@ -59,6 +59,7 @@ else:
|
||||
recover_decode_task,
|
||||
set_value_by_flags_and_idx,
|
||||
share_external_data,
|
||||
speculate_schedule_cache,
|
||||
)
|
||||
|
||||
from fastdeploy.model_executor.pre_and_post_process import (
|
||||
@@ -383,6 +384,8 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["stop_flags"].sum() == self.parallel_config.max_num_seqs
|
||||
)
|
||||
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
|
||||
if self.speculative_method in ["mtp"]:
|
||||
self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
|
||||
|
||||
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None):
|
||||
"""
|
||||
@@ -803,6 +806,13 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
fill_value=0,
|
||||
dtype="int32",
|
||||
)
|
||||
# For V1_KVCACHE_SCHEDULER
|
||||
self.share_inputs["step_draft_tokens"] = paddle.full(
|
||||
shape=[max_num_seqs, max_draft_token_num + 1],
|
||||
fill_value=0,
|
||||
dtype="int64",
|
||||
)
|
||||
self.share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
|
||||
|
||||
if self.enable_mm:
|
||||
head_dim = self.model_config.head_dim
|
||||
@@ -841,7 +851,11 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
self.share_inputs["step_seq_lens_decoder"],
|
||||
self.share_inputs["block_tables"],
|
||||
self.share_inputs["is_block_step"],
|
||||
self.share_inputs["draft_tokens"] if self.speculative_decoding else None,
|
||||
self.share_inputs["step_draft_tokens"] if self.speculative_decoding else None,
|
||||
self.share_inputs["step_seq_lens_this_time"] if self.speculative_decoding else None,
|
||||
self.cache_config.block_size,
|
||||
self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0,
|
||||
)
|
||||
|
||||
# Remove padding
|
||||
@@ -1540,6 +1554,24 @@ class GPUModelRunner(ModelRunnerBase):
|
||||
|
||||
self._update_chunked_prefill(model_forward_batch)
|
||||
self._add_cache(model_forward_batch)
|
||||
elif self.speculative_decoding:
|
||||
speculate_schedule_cache(
|
||||
self.share_inputs["draft_tokens"],
|
||||
self.share_inputs["block_tables"],
|
||||
self.share_inputs["stop_flags"],
|
||||
self.share_inputs["seq_lens_this_time"],
|
||||
self.share_inputs["seq_lens_decoder"],
|
||||
self.share_inputs["step_seq_lens_decoder"],
|
||||
self.share_inputs["step_draft_tokens"],
|
||||
self.share_inputs["step_seq_lens_this_time"],
|
||||
self.share_inputs["accept_num"],
|
||||
self.share_inputs["accept_tokens"],
|
||||
self.share_inputs["is_block_step"],
|
||||
self.share_inputs["not_need_stop"],
|
||||
self.share_inputs["stop_nums"],
|
||||
self.cache_config.block_size,
|
||||
self.speculative_config.num_speculative_tokens,
|
||||
)
|
||||
|
||||
self.seq_lens_this_time_buffer[:num_running_requests].copy_(
|
||||
self.share_inputs["seq_lens_this_time"][:num_running_requests], False
|
||||
|
@@ -742,13 +742,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
|
||||
logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
|
||||
logger.info(f"- Load strategy: {load_config.load_strategy}")
|
||||
|
||||
if (
|
||||
args.speculative_config is not None
|
||||
and ("method" in args.speculative_config)
|
||||
and (args.speculative_config["method"] is not None)
|
||||
):
|
||||
logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not support speculative decoding now.")
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
|
||||
if args.splitwise_role != "mixed":
|
||||
logger.info(f"Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported {args.splitwise_role} now.")
|
||||
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
|
||||
|
238
tests/operators/test_speculative_schedule_cache.py
Normal file
238
tests/operators/test_speculative_schedule_cache.py
Normal file
@@ -0,0 +1,238 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import speculate_schedule_cache
|
||||
|
||||
|
||||
def cpu_reference(
|
||||
draft_tokens,
|
||||
block_tables,
|
||||
stop_flags,
|
||||
seq_lens_this_time,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
step_draft_tokens,
|
||||
step_seq_lens_this_time,
|
||||
accept_num,
|
||||
accept_tokens,
|
||||
is_block_step,
|
||||
not_need_stop,
|
||||
stop_nums,
|
||||
block_size,
|
||||
max_draft_tokens,
|
||||
):
|
||||
"""Pure-NumPy mirror of the CUDA kernel's logic (single block of 512 threads).
|
||||
Shapes are the same as inputs to the custom op. This mutates the provided
|
||||
NumPy arrays in-place, exactly like the kernel does.
|
||||
"""
|
||||
real_bsz = seq_lens_this_time.shape[0]
|
||||
max_bsz = stop_flags.shape[0]
|
||||
draft_tokens_len = draft_tokens.shape[1]
|
||||
block_num_per_seq = block_tables.shape[1]
|
||||
|
||||
max_next_step_tokens = 2 * max_draft_tokens + 2
|
||||
|
||||
# Block-local reduction input per thread (threadIdx.x -> bid)
|
||||
stop_flag_now_int = np.zeros(512, dtype=np.int64) # THREADBLOCK_SIZE = 512
|
||||
|
||||
for bid in range(512):
|
||||
if bid < real_bsz:
|
||||
if not stop_flags[bid]:
|
||||
max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) // block_size
|
||||
if max_possible_block_idx < block_num_per_seq and block_tables[bid, max_possible_block_idx] == -1:
|
||||
is_block_step[bid] = True
|
||||
step_seq_lens_this_time[bid] = seq_lens_this_time[bid]
|
||||
seq_lens_this_time[bid] = 0
|
||||
stop_flags[bid] = True
|
||||
step_seq_lens_decoder[bid] = seq_lens_decoder[bid]
|
||||
seq_lens_decoder[bid] = 0
|
||||
accept_num[bid] = 0
|
||||
accept_tokens[bid, :] = -1
|
||||
step_draft_tokens[bid, :draft_tokens_len] = draft_tokens[bid, :draft_tokens_len]
|
||||
stop_flag_now_int[bid] = 1
|
||||
else:
|
||||
stop_flag_now_int[bid] = 0
|
||||
else:
|
||||
stop_flag_now_int[bid] = 1
|
||||
elif bid < max_bsz:
|
||||
# Threads in [real_bsz, max_bsz) contribute 1 to reduction
|
||||
stop_flag_now_int[bid] = 1
|
||||
else:
|
||||
stop_flag_now_int[bid] = 0
|
||||
|
||||
stop_sum = int(stop_flag_now_int.sum())
|
||||
not_need_stop[0] = stop_sum < int(stop_nums[0])
|
||||
|
||||
|
||||
class TestSpeculateScheduleCache(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if not paddle.is_compiled_with_cuda():
|
||||
raise unittest.SkipTest("Paddle is not compiled with CUDA; skipping GPU op test.")
|
||||
paddle.device.set_device("gpu")
|
||||
|
||||
def setUp(self):
|
||||
# --- Construct a deterministic case that exercises all branches ---
|
||||
# real_bsz < max_bsz to test the padding logic in the CUB reduction
|
||||
self.real_bsz = 3
|
||||
self.max_bsz = 5 # only stop_flags has length max_bsz
|
||||
|
||||
self.draft_tokens_len = 6
|
||||
self.accept_tokens_len = 5
|
||||
self.block_size = 4
|
||||
self.block_num_per_seq = 3
|
||||
self.max_draft_tokens = 2 # -> max_next_step_tokens = 6
|
||||
|
||||
# Inputs that will trigger for bid 0, not trigger for bid 2, and bid 1 is already stopped
|
||||
# seq_lens_decoder + 6 // 4 -> indices: [1, 1, 4]. Index 4 is out of range -> no trigger on bid 2
|
||||
self.draft_tokens = paddle.to_tensor(
|
||||
np.array(
|
||||
[
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[2, 2, 2, 2, 2, 2],
|
||||
[3, 3, 3, 3, 3, 3],
|
||||
],
|
||||
dtype=np.int64,
|
||||
)
|
||||
)
|
||||
self.block_tables = paddle.to_tensor(np.full((self.real_bsz, self.block_num_per_seq), -1, dtype=np.int32))
|
||||
# stop_flags length is max_bsz, others are real_bsz
|
||||
self.stop_flags = paddle.to_tensor(np.array([False, True, False, False, False], dtype=np.bool_))
|
||||
self.seq_lens_this_time = paddle.to_tensor(np.array([5, 6, 7], dtype=np.int32))
|
||||
self.seq_lens_decoder = paddle.to_tensor(np.array([1, 1, 10], dtype=np.int32))
|
||||
|
||||
# Will be filled by kernel for the triggering bids only
|
||||
self.step_seq_lens_decoder = paddle.zeros((self.real_bsz,), dtype="int32")
|
||||
self.step_draft_tokens = paddle.zeros((self.real_bsz, self.draft_tokens_len), dtype="int64")
|
||||
self.step_seq_lens_this_time = paddle.zeros((self.real_bsz,), dtype="int32")
|
||||
|
||||
# Intentionally non-zero so we can verify in-place zeroing only where triggered
|
||||
self.accept_num = paddle.to_tensor(np.array([9, 8, 7], dtype=np.int32))
|
||||
self.accept_tokens = paddle.to_tensor(
|
||||
np.arange(self.real_bsz * self.accept_tokens_len, dtype=np.int64).reshape(
|
||||
self.real_bsz, self.accept_tokens_len
|
||||
)
|
||||
)
|
||||
self.is_block_step = paddle.zeros((self.real_bsz,), dtype=paddle.bool)
|
||||
|
||||
# not_need_stop lives on CPU in the caller; the kernel copies to device internally
|
||||
self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool).cpu()
|
||||
|
||||
# Choose threshold so with: bid0 triggers, bid1 already stopped, padding (5-3)=2 -> stop_sum = 1+1+2 = 4
|
||||
# Set stop_nums to 5 so not_need_stop = (4 < 5) = True
|
||||
self.stop_nums = paddle.to_tensor([5], dtype=paddle.int64)
|
||||
|
||||
# Keep NumPy copies for CPU reference
|
||||
self.np_draft_tokens = self.draft_tokens.numpy().copy()
|
||||
self.np_block_tables = self.block_tables.numpy().copy()
|
||||
self.np_stop_flags = self.stop_flags.numpy().copy()
|
||||
self.np_seq_lens_this_time = self.seq_lens_this_time.numpy().copy()
|
||||
self.np_seq_lens_decoder = self.seq_lens_decoder.numpy().copy()
|
||||
self.np_step_seq_lens_decoder = self.step_seq_lens_decoder.numpy().copy()
|
||||
self.np_step_draft_tokens = self.step_draft_tokens.numpy().copy()
|
||||
self.np_step_seq_lens_this_time = self.step_seq_lens_this_time.numpy().copy()
|
||||
self.np_accept_num = self.accept_num.numpy().copy()
|
||||
self.np_accept_tokens = self.accept_tokens.numpy().copy()
|
||||
self.np_is_block_step = self.is_block_step.numpy().copy()
|
||||
self.np_not_need_stop = self.not_need_stop.numpy().copy()
|
||||
self.np_stop_nums = self.stop_nums.numpy().copy()
|
||||
|
||||
def test_correctness_against_cpu_reference(self):
|
||||
# Run GPU kernel (in-place)
|
||||
speculate_schedule_cache(
|
||||
self.draft_tokens,
|
||||
self.block_tables,
|
||||
self.stop_flags,
|
||||
self.seq_lens_this_time,
|
||||
self.seq_lens_decoder,
|
||||
self.step_seq_lens_decoder,
|
||||
self.step_draft_tokens,
|
||||
self.step_seq_lens_this_time,
|
||||
self.accept_num,
|
||||
self.accept_tokens,
|
||||
self.is_block_step,
|
||||
self.not_need_stop,
|
||||
self.stop_nums,
|
||||
self.block_size,
|
||||
self.max_draft_tokens,
|
||||
)
|
||||
|
||||
# Compute CPU reference (in-place on NumPy copies)
|
||||
cpu_reference(
|
||||
self.np_draft_tokens,
|
||||
self.np_block_tables,
|
||||
self.np_stop_flags,
|
||||
self.np_seq_lens_this_time,
|
||||
self.np_seq_lens_decoder,
|
||||
self.np_step_seq_lens_decoder,
|
||||
self.np_step_draft_tokens,
|
||||
self.np_step_seq_lens_this_time,
|
||||
self.np_accept_num,
|
||||
self.np_accept_tokens,
|
||||
self.np_is_block_step,
|
||||
self.np_not_need_stop,
|
||||
self.np_stop_nums,
|
||||
self.block_size,
|
||||
self.max_draft_tokens,
|
||||
)
|
||||
|
||||
# Compare all mutated tensors
|
||||
np.testing.assert_array_equal(self.step_draft_tokens.numpy(), self.np_step_draft_tokens)
|
||||
np.testing.assert_array_equal(self.accept_tokens.numpy(), self.np_accept_tokens)
|
||||
np.testing.assert_array_equal(self.stop_flags.numpy(), self.np_stop_flags)
|
||||
np.testing.assert_array_equal(self.is_block_step.numpy(), self.np_is_block_step)
|
||||
np.testing.assert_array_equal(self.seq_lens_this_time.numpy(), self.np_seq_lens_this_time)
|
||||
np.testing.assert_array_equal(self.seq_lens_decoder.numpy(), self.np_seq_lens_decoder)
|
||||
np.testing.assert_array_equal(self.step_seq_lens_decoder.numpy(), self.np_step_seq_lens_decoder)
|
||||
np.testing.assert_array_equal(self.step_seq_lens_this_time.numpy(), self.np_step_seq_lens_this_time)
|
||||
np.testing.assert_array_equal(self.accept_num.numpy(), self.np_accept_num)
|
||||
self.assertEqual(bool(self.not_need_stop.numpy()[0]), bool(self.np_not_need_stop[0]))
|
||||
|
||||
def test_no_trigger_path(self):
|
||||
# Make block_tables at candidate index != -1 so nothing triggers
|
||||
# Candidate index for bid 0/1 is 1, set it to 7
|
||||
bt = self.block_tables.numpy()
|
||||
bt[:, 1] = 7
|
||||
self.block_tables = paddle.to_tensor(bt)
|
||||
|
||||
# Reset outputs to distinctive values
|
||||
self.step_seq_lens_decoder[:] = 0
|
||||
self.step_draft_tokens[:] = 0
|
||||
self.step_seq_lens_this_time[:] = 0
|
||||
self.accept_num[:] = -123
|
||||
self.accept_tokens[:] = -777
|
||||
self.is_block_step[:] = False
|
||||
self.not_need_stop[:] = False
|
||||
|
||||
# For not_need_stop: stopped_in_real = (bid1 True) = 1, padding = 2 -> stop_sum=3
|
||||
# With stop_nums=5 -> True
|
||||
speculate_schedule_cache(
|
||||
self.draft_tokens,
|
||||
self.block_tables,
|
||||
self.stop_flags,
|
||||
self.seq_lens_this_time,
|
||||
self.seq_lens_decoder,
|
||||
self.step_seq_lens_decoder,
|
||||
self.step_draft_tokens,
|
||||
self.step_seq_lens_this_time,
|
||||
self.accept_num,
|
||||
self.accept_tokens,
|
||||
self.is_block_step,
|
||||
self.not_need_stop,
|
||||
self.stop_nums,
|
||||
self.block_size,
|
||||
self.max_draft_tokens,
|
||||
)
|
||||
|
||||
# Nothing should have changed except not_need_stop
|
||||
np.testing.assert_array_equal(self.step_draft_tokens.numpy(), np.zeros_like(self.step_draft_tokens.numpy()))
|
||||
np.testing.assert_array_equal(self.is_block_step.numpy(), np.zeros_like(self.is_block_step.numpy()))
|
||||
np.testing.assert_array_equal(self.accept_tokens.numpy(), np.full_like(self.accept_tokens.numpy(), -777))
|
||||
np.testing.assert_array_equal(self.accept_num.numpy(), np.full_like(self.accept_num.numpy(), -123))
|
||||
self.assertTrue(bool(self.not_need_stop.numpy()[0]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@@ -5,12 +5,16 @@ import unittest
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from paddle.incubate.nn.functional import fused_rms_norm
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
append_attention,
|
||||
get_block_shape_and_split_kv_block,
|
||||
)
|
||||
|
||||
np.random.seed(0)
|
||||
paddle.seed(0)
|
||||
|
||||
|
||||
class TestTreeMask(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -27,6 +31,7 @@ class TestTreeMask(unittest.TestCase):
|
||||
self.head_dim = 128
|
||||
self.num_q_head = 20
|
||||
self.num_kv_head = 4
|
||||
self.use_qknorm = True
|
||||
self.dtype = "bfloat16"
|
||||
|
||||
self.rope_3d = False
|
||||
@@ -91,12 +96,20 @@ class TestTreeMask(unittest.TestCase):
|
||||
cu_seqlens_k[i + 1] = cum_seq_len_k
|
||||
return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k
|
||||
|
||||
def ref_attention(self, q, k, v, mask):
|
||||
def ref_attention(self, q, k, v, mask, use_qknorm=False):
|
||||
if use_qknorm:
|
||||
q = q.reshape([-1, self.head_dim])
|
||||
q = fused_rms_norm(q.astype("float32"), self.q_norm_weight_tensor, None, 1e-6)[0].astype(self.dtype)
|
||||
q = q.reshape([self.bsz, -1, self.num_q_head, self.head_dim])
|
||||
q = q.transpose([0, 2, 1, 3])
|
||||
if len(k) > 1:
|
||||
k = paddle.concat(k, axis=1)
|
||||
else:
|
||||
k = k[0]
|
||||
if use_qknorm:
|
||||
k = k.reshape([-1, self.head_dim])
|
||||
k = fused_rms_norm(k.astype("float32"), self.k_norm_weight_tensor, None, 1e-6)[0].astype(self.dtype)
|
||||
k = k.reshape([self.bsz, -1, self.num_kv_head, self.head_dim])
|
||||
k = k.transpose([0, 2, 1, 3])
|
||||
if len(v) > 1:
|
||||
v = paddle.concat(v, axis=1)
|
||||
@@ -127,7 +140,7 @@ class TestTreeMask(unittest.TestCase):
|
||||
.reshape([-1, self.num_q_head, self.head_dim])
|
||||
)
|
||||
|
||||
def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None):
|
||||
def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None, use_qknorm=False):
|
||||
if prefill:
|
||||
seq_lens_enc = [
|
||||
q_len,
|
||||
@@ -187,6 +200,10 @@ class TestTreeMask(unittest.TestCase):
|
||||
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
|
||||
decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
|
||||
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
|
||||
q_norm_weight = np.random.random([self.head_dim]) / 10
|
||||
k_norm_weight = np.random.random([self.head_dim]) / 10
|
||||
self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
|
||||
self.k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
|
||||
paddle.device.synchronize()
|
||||
(
|
||||
encoder_batch_ids,
|
||||
@@ -237,20 +254,20 @@ class TestTreeMask(unittest.TestCase):
|
||||
max_len_kv,
|
||||
rotary_embs,
|
||||
attn_mask,
|
||||
None,
|
||||
None,
|
||||
None, # qkv_bias
|
||||
None, # qkv_out_scales
|
||||
cache_k_scale,
|
||||
cache_v_scale,
|
||||
cache_k_out_scale,
|
||||
cache_v_out_scale,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None, # cache_k_zp
|
||||
None, # cache_v_zp
|
||||
None, # linear_shift
|
||||
None, # linear_smooth
|
||||
None, # mask_offset
|
||||
None, # kv_signal_data
|
||||
self.q_norm_weight_tensor if use_qknorm else None, # q_norm_weight
|
||||
self.k_norm_weight_tensor if use_qknorm else None, # k_norm_weight
|
||||
1e-6,
|
||||
"bf16",
|
||||
"none",
|
||||
@@ -271,7 +288,7 @@ class TestTreeMask(unittest.TestCase):
|
||||
paddle.device.synchronize()
|
||||
e_time = time.time()
|
||||
print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / self.run_time):.2f}")
|
||||
return out[0].reshape([token_num, self.num_q_head, self.head_dim])
|
||||
return out.reshape([token_num, self.num_q_head, self.head_dim])
|
||||
|
||||
def test_naive_speculative_decoding(self):
|
||||
prefill_len = 8192
|
||||
@@ -279,10 +296,10 @@ class TestTreeMask(unittest.TestCase):
|
||||
total_len = prefill_len + dec_len_q
|
||||
mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
|
||||
mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
|
||||
self.run_append_c16_attention(prefill_len, 0, True)
|
||||
dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False)
|
||||
self.run_append_c16_attention(prefill_len, 0, True, use_qknorm=self.use_qknorm)
|
||||
dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, use_qknorm=self.use_qknorm)
|
||||
|
||||
ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask)
|
||||
ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask, use_qknorm=self.use_qknorm)
|
||||
np.testing.assert_allclose(
|
||||
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
|
||||
)
|
||||
|
Reference in New Issue
Block a user