diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 75e336b76..f1f5a6177 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -273,11 +273,15 @@ void AppendAttentionKernel( cache_v_zp, cache_quant_type_str, use_neox_rotary_style, + rope_3d, max_input_length, exec_stream, &qkv_out, const_cast(&key_cache), - const_cast(&value_cache)); + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); } else { SpeculateWriteCacheWithRoPEKernel( meta_data, @@ -296,11 +300,15 @@ void AppendAttentionKernel( cache_v_zp, cache_quant_type_str, use_neox_rotary_style, + rope_3d, max_input_length, exec_stream, &qkv_out, const_cast(&key_cache), - const_cast(&value_cache)); + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); } } else { if (qkv_out_scales) { diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh index c8273cd3c..75f9ebd8d 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh @@ -120,7 +120,6 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel( float row_variance = max(warp_m2 / head_size, 0.0f); float row_inv_var = Rsqrt(row_variance + rms_norm_eps); - if (hi < num_heads) { // q Load(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec); #pragma unroll @@ -129,6 +128,7 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel( } } else { // k Load(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec); + #pragma unroll for (int i = 0; i < VecSize; i++) { out_vec[i] = static_cast(tmp_vec[i] * row_inv_var * k_norm_vec[i]); } diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh index 57612c458..4fb5c93d0 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh @@ -18,6 +18,168 @@ #include "mma_tensor_op.cuh" #include "utils.cuh" +template +__global__ void append_speculate_cache_T_rope_qk_norm_kernel( + const InT* __restrict__ qkv, // [token_num, num_heads + 2 * gqa_group_size, + // head_size] + T* __restrict__ key_cache, // [num_blocks, gqa_group_size, block_size, + // head_size // 2] + T* __restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, + // head_size // 2] + T* __restrict__ q_out, + const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int* __restrict__ batch_id_per_token, // [num_tokens] + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seq_lens_decoder, // [bsz] + const float* __restrict__ cos_emb, + const float* __restrict__ sin_emb, + const float* + qkv_out_scales, // [(num_heads + 2 * gqa_group_size) * head_size] + const T* qkv_biases, // [num_head + 2 * gqa_group_size, dim_head] + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int output_inner_dim, + const int head_size, + const int block_size, + const int elem_cnt, + const int gqa_group_size, + const float* q_norm_weight, + const float* k_norm_weight, + const float rms_norm_eps, + const bool rope_3d) { + using LoadT = AlignedVector; + using LoadFloat = AlignedVector; + using LoadInT = AlignedVector; + constexpr int HalfVecSize = VecSize / 2; + using LoadEmbT = AlignedVector; + LoadInT 
src_vec; + LoadFloat scale_vec; + LoadT bias_vec; + LoadEmbT cos_emb_vec; + LoadEmbT sin_emb_vec; + LoadFloat tmp_vec; + LoadFloat q_norm_vec; + LoadFloat k_norm_vec; + + int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y; + int64_t all_warp_num = gridDim.x * blockDim.y; + int64_t all_head_dim = elem_cnt / head_size; + + const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size; + const int half_head_size = head_size / 2; + for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) { + int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize; + const int token_id = linear_index / hidden_size; + const int ori_bi = batch_id_per_token[token_id]; + if (seq_lens_decoder[ori_bi] == 0) continue; + const int bias = linear_index % hidden_size; + const int hi = bias / head_size; // q + k + v + const int h_bias = bias % head_size; + const int start_token_idx = cu_seqlens_q[ori_bi]; + const int write_seq_id = + seq_lens_decoder[ori_bi] + token_id - start_token_idx; + if (write_seq_id == 0) continue; + + const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[write_seq_id / block_size]; + if (block_idx < 0) { + printf( + "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var " + "%d %d %d %d\n", + block_idx, + write_seq_id, + ori_bi, + seq_lens_decoder[ori_bi], + token_id, + cu_seqlens_q[ori_bi]); + } + const int block_offset = write_seq_id % block_size; + + const int write_q_idx = + token_id * output_inner_dim * head_size + hi * head_size + h_bias; + + const int bias_idx = hi * head_size + h_bias; + Load(&qkv[linear_index], &src_vec); + if (qkv_biases) { + Load(&qkv_biases[bias_idx], &bias_vec); + } + if (qkv_out_scales) { + Load(&qkv_out_scales[bias_idx], &scale_vec); + } + if (hi < num_heads + gqa_group_size) { + // q k rope + const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2; + uint32_t new_emb_idx = rope_3d ? 
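+      // rope_3d: each request carries its own cos/sin table of max_seq_len * head_size
+      // entries, so the position offset is shifted into batch ori_bi's slice; plain
+      // RoPE shares a single table indexed by position only.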
emb_idx + ori_bi * max_seq_len * head_size : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); + } + float thread_m2 = 0.0f; + float warp_m2 = 0.0f; +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + // add_bias + rope + float input_left = static_cast(src_vec[2 * i]); + float input_right = static_cast(src_vec[2 * i + 1]); + if (qkv_out_scales) { + input_left *= scale_vec[2 * i]; + input_right *= scale_vec[2 * i + 1]; + } + if (qkv_biases) { + input_left = input_left + static_cast(bias_vec[2 * i]); + input_right = input_right + static_cast(bias_vec[2 * i + 1]); + } + if (hi < num_heads + gqa_group_size) { + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + float tmp1 = input_left * cos_tmp - input_right * sin_tmp; + float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + tmp_vec[2 * i] = tmp1; + tmp_vec[2 * i + 1] = tmp2; + } else { + bias_vec[2 * i] = static_cast(input_left); + bias_vec[2 * i + 1] = static_cast(input_right); + } + } + if (hi < (num_heads + gqa_group_size)) { + WelfordWarpAllReduce(thread_m2, &warp_m2); + float row_variance = + max(warp_m2 / head_size, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + if (hi < num_heads) { + Load(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec); + #pragma unroll + for (int i = 0; i < VecSize; i++) { + bias_vec[i] = static_cast(tmp_vec[i] * row_inv_var * q_norm_vec[i]); + } + } else { + Load(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec); + #pragma unroll + for (int i = 0; i < VecSize; i++) { + bias_vec[i] = static_cast(tmp_vec[i] * row_inv_var * k_norm_vec[i]); + } + } + } + if (hi < num_heads) { + // write q + Store(bias_vec, &q_out[write_q_idx]); + } else { + // write k/v + const int kv_head_idx = (hi - num_heads) % gqa_group_size; + const int tgt_idx = (block_idx * gqa_group_size * block_size * head_size + + kv_head_idx * block_size * head_size + + block_offset * head_size + h_bias); + // write + if (hi < num_heads + gqa_group_size) { + Store(bias_vec, &key_cache[tgt_idx]); + } else { + Store(bias_vec, &value_cache[tgt_idx]); + } + } + } +} + template __global__ void append_clear_cache_int8_block( uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size, @@ -193,7 +355,8 @@ __global__ void append_speculate_cache_rope_kernel( const int head_size, const int block_size, const int elem_cnt, - const int gqa_group_size) { + const int gqa_group_size, + const bool rope_3d) { using LoadT = AlignedVector; using LoadFloat = AlignedVector; using LoadInT = AlignedVector; @@ -253,8 +416,9 @@ __global__ void append_speculate_cache_rope_kernel( if (hi < num_heads + gqa_group_size) { // q k rope const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + int64_t new_emb_idx = rope_3d ? 
emb_idx + ori_bi * max_seq_len * head_size : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); } #pragma unroll for (int i = 0; i < HalfVecSize; i++) { @@ -326,7 +490,8 @@ __global__ void append_speculate_cache_neox_rope_kernel( const int head_size, const int block_size, const int elem_cnt, - const int gqa_group_size) { + const int gqa_group_size, + const bool rope_3d) { using LoadT = AlignedVector; using LoadFloat = AlignedVector; using LoadInT = AlignedVector; @@ -390,8 +555,9 @@ __global__ void append_speculate_cache_neox_rope_kernel( if (hi < num_heads + gqa_group_size) { // q k rope const int64_t emb_idx = write_seq_id * head_size + h_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2: emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); } #pragma unroll for (int i = 0; i < VecSize; i++) { @@ -476,7 +642,8 @@ __global__ void append_speculate_cache_int8_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int gqa_group_size) { + const int gqa_group_size, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -522,8 +689,9 @@ __global__ void append_speculate_cache_int8_rope_kernel( // q rope const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); if (qkv_out_scales) { Load(&qkv_out_scales[bias_idx], &out_scale_vec); } @@ -583,10 +751,11 @@ __global__ void append_speculate_cache_int8_rope_kernel( T scale; if (head_idx < num_heads + gqa_group_size) { const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 4], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 4], &sin_emb_vec2); + uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2); scale = __ldg(&cache_k_scales[kv_head_idx]); } else { scale = __ldg(&cache_v_scales[kv_head_idx]); @@ -708,7 +877,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int gqa_group_size) { + const int gqa_group_size, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -757,8 +927,9 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( // q rope const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + uint32_t new_emb_idx = rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); if (qkv_out_scales) { Load(&qkv_out_scales[bias_idx_left], &left_out_scale_vec); @@ -853,10 +1024,11 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel( T scale; const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 8], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 8], &sin_emb_vec2); + uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 8], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 8], &sin_emb_vec2); scale = __ldg(&cache_k_scales[kv_head_idx]); #pragma unroll for (int i = 0; i < HALF_K_VEC_SIZE; i++) { @@ -1088,7 +1260,8 @@ __global__ void append_speculate_cache_int4_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int gqa_group_size) { + const int gqa_group_size, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -1145,8 +1318,9 @@ __global__ void append_speculate_cache_int4_rope_kernel( // Load(&qkv_out_scales[bias_idx], &out_scale_vec); // q rope const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec); - Load(&sin_emb[emb_idx], &sin_emb_vec); + uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll for (int i = 0; i < HalfVecSize; i++) { // dequant + add_bias + rope @@ -1235,10 +1409,11 @@ __global__ void append_speculate_cache_int4_rope_kernel( // &out_scale_vec2); if (head_idx < num_heads + gqa_group_size) { const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 4], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 4], &sin_emb_vec2); + uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 4], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 4], &sin_emb_vec2); Load(&cache_k_scales[cache_idx], &scale_vec1); Load(&cache_k_scales[cache_idx + 8], &scale_vec2); Load(&cache_k_zero_points[cache_idx], &zp_vec1); @@ -1431,7 +1606,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( const int block_size, const float max_bound, const float min_bound, - const int gqa_group_size) { + const int gqa_group_size, + const bool rope_3d) { static_assert(HeadDim == 128, "just support HeadDim be 128 now!"); static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!"); constexpr int NUM_WARPS = 4; @@ -1581,10 +1757,11 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel( &right_out_scale_vec2); const uint32_t emb_idx = write_seq_id * HeadDim + head_bias; - Load(&cos_emb[emb_idx], &cos_emb_vec1); - Load(&cos_emb[emb_idx + 8], &cos_emb_vec2); - Load(&sin_emb[emb_idx], &sin_emb_vec1); - Load(&sin_emb[emb_idx + 8], &sin_emb_vec2); + uint32_t new_emb_idx = rope_3d ? 
emb_idx + bid * max_seq_len * HeadDim : emb_idx; + Load(&cos_emb[new_emb_idx], &cos_emb_vec1); + Load(&cos_emb[new_emb_idx + 8], &cos_emb_vec2); + Load(&sin_emb[new_emb_idx], &sin_emb_vec1); + Load(&sin_emb[new_emb_idx + 8], &sin_emb_vec2); Load(&cache_k_scales[left_cache_idx], &left_scale_vec1); Load(&cache_k_scales[left_cache_idx + 8], diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu index fb6a24fef..4fd07ae23 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu @@ -15,6 +15,78 @@ #include "speculate_write_cache_with_rope_kernel.h" #include "utils.cuh" +template +void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv, + T* key_cache, + T* value_cache, + T* qkv_out, + const int* block_tables, + const int* batch_id_per_token, + const int* cu_seqlens_q, + const int* seq_lens, + const int* seq_lens_encoder, + const float* cos_emb, + const float* sin_emb, + const float* qkv_out_scales, + const T* qkv_biases, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int kv_num_heads, + const int dim_head, + const int block_size, + const int bsz, + const int token_num, + const cudaStream_t& stream, + const bool use_neox_style, + const float* q_norm_weight, + const float* k_norm_weight, + const float rms_norm_eps, + const bool rope_3d) { + int output_inner_dim = num_heads + 2 * kv_num_heads; + const uint32_t elem_nums = + use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2 + : token_num * (num_heads + 2 * kv_num_heads) * dim_head; + constexpr int HEAD_DIM = 128; + + constexpr int PackSize = HEAD_DIM / kWarpSize; + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks<128>(pack_num, &grid_size); + if (use_neox_style) { + PD_THROW( + "append_speculate_cache_rope_qk_norm not support neox rope yet"); + } else { + dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1); + append_speculate_cache_T_rope_qk_norm_kernel + <<>>(qkv, + key_cache, + value_cache, + qkv_out, + block_tables, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + max_seq_len, + max_blocks_per_seq, + num_heads, + output_inner_dim, + dim_head, + block_size, + elem_nums, + kv_num_heads, + q_norm_weight, + k_norm_weight, + rms_norm_eps, + rope_3d); + } +} + // rope + write template void append_speculate_cache_rope(const QKV_TYPE* qkv, @@ -39,7 +111,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, const int bsz, const int token_num, const cudaStream_t& stream, - const bool use_neox_style) { + const bool use_neox_style, + const bool rope_3d) { int output_inner_dim = num_heads + 2 * kv_num_heads; const uint32_t elem_nums = @@ -73,7 +146,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, dim_head, block_size, elem_nums, - kv_num_heads); + kv_num_heads, + rope_3d); } else { append_speculate_cache_rope_kernel <<>>( @@ -96,7 +170,8 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, dim_head, block_size, elem_nums, - kv_num_heads); + kv_num_heads, + rope_3d); } } @@ -125,7 +200,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, const int bsz, const int token_num, const cudaStream_t& stream, - const bool use_neox_style) { + const bool use_neox_style, + const bool rope_3d) { constexpr int num_warps = 4; const int all_warps = 
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; @@ -167,7 +243,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, block_size, 127.0f, -127.0f, - kv_num_heads); + kv_num_heads, + rope_3d); } else { append_speculate_cache_int8_rope_kernel <<>>(qkv, @@ -191,7 +268,8 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv, block_size, 127.0f, -127.0f, - kv_num_heads); + kv_num_heads, + rope_3d); } } @@ -222,7 +300,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, const int bsz, const int token_num, const cudaStream_t& stream, - const bool use_neox_style) { + const bool use_neox_style, + const bool rope_3d) { constexpr int num_warps = 4; const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; @@ -266,7 +345,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, block_size, 7.0f, -8.0f, - kv_num_heads); + kv_num_heads, + rope_3d); } else { append_speculate_cache_int4_rope_kernel <<>>(qkv, @@ -292,7 +372,8 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv, block_size, 7.0f, -8.0f, - kv_num_heads); + kv_num_heads, + rope_3d); } } template @@ -313,11 +394,15 @@ void SpeculateWriteCacheWithRoPEKernel( const paddle::optional& cache_v_zp, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, + const bool rope_3d, const int max_seq_len, cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out) { + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps) { typedef cascade_attn_type_traits traits_; typedef cascade_attn_type_traits qkt_nv_type_; typedef typename traits_::type DataType_; @@ -342,142 +427,185 @@ void SpeculateWriteCacheWithRoPEKernel( ? rotary_embs.get().data() + max_seq_len * dim_head : rotary_embs.get().data() + max_seq_len * dim_head / 2; } - if (cache_quant_type_str == "none") { - append_speculate_cache_rope( - reinterpret_cast(qkv_ptr), - reinterpret_cast(key_cache_out->data()), - reinterpret_cast(value_cache_out->data()), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - token_nums, - stream, - use_neox_rotary_style); - } else if (cache_quant_type_str == "cache_int8") { - append_speculate_cache_int8_rope( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - cache_k_scale ? reinterpret_cast( - const_cast(cache_k_scale.get().data())) - : nullptr, - cache_v_scale ? 
reinterpret_cast( - const_cast(cache_v_scale.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - token_nums, - stream, - use_neox_rotary_style); - } else if (cache_quant_type_str == "cache_fp8") { - append_speculate_cache_int8_rope( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - cache_k_scale ? reinterpret_cast( - const_cast(cache_k_scale.get().data())) - : nullptr, - cache_v_scale ? reinterpret_cast( - const_cast(cache_v_scale.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - token_nums, - stream, - use_neox_rotary_style); - } else if (cache_quant_type_str == "cache_int4_zp") { - append_speculate_cache_int4_rope( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(const_cast(qkv_out->data())), - block_tables.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - qkv_out_scales ? qkv_out_scales.get().data() : nullptr, - qkv_biases ? reinterpret_cast( - const_cast(qkv_biases.get().data())) - : nullptr, - cache_k_scale ? reinterpret_cast( - const_cast(cache_k_scale.get().data())) - : nullptr, - cache_v_scale ? reinterpret_cast( - const_cast(cache_v_scale.get().data())) - : nullptr, - cache_k_zp ? reinterpret_cast( - const_cast(cache_k_zp.get().data())) - : nullptr, - cache_v_zp ? reinterpret_cast( - const_cast(cache_v_zp.get().data())) - : nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads, - dim_head, - block_size, - bsz, - token_nums, - stream, - use_neox_rotary_style); + if (q_norm_weight && k_norm_weight) { + if (cache_quant_type_str == "none") { + append_speculate_cache_rope_qk_norm( + reinterpret_cast(qkv_ptr), + reinterpret_cast(key_cache_out->data()), + reinterpret_cast(value_cache_out->data()), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style, + reinterpret_cast(q_norm_weight.get().data()), + reinterpret_cast(k_norm_weight.get().data()), + rms_norm_eps, + rope_3d); + } else { + PD_THROW( + "append_decode_cache_rope_qk_norm not support cachekv quant yet"); + } + } else { - PD_THROW( - "cache_quant_type_str should be one of [none, cache_int8, " - "cache_int4_zp]"); + if (cache_quant_type_str == "none") { + append_speculate_cache_rope( + reinterpret_cast(qkv_ptr), + reinterpret_cast(key_cache_out->data()), + reinterpret_cast(value_cache_out->data()), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? 
reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style, + rope_3d); + } else if (cache_quant_type_str == "cache_int8") { + append_speculate_cache_int8_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) + : nullptr, + cache_v_scale ? reinterpret_cast( + const_cast(cache_v_scale.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style, + rope_3d); + } else if (cache_quant_type_str == "cache_fp8") { + append_speculate_cache_int8_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) + : nullptr, + cache_v_scale ? reinterpret_cast( + const_cast(cache_v_scale.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style, + rope_3d); + } else if (cache_quant_type_str == "cache_int4_zp") { + append_speculate_cache_int4_rope( + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(const_cast(qkv_out->data())), + block_tables.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + qkv_out_scales ? qkv_out_scales.get().data() : nullptr, + qkv_biases ? reinterpret_cast( + const_cast(qkv_biases.get().data())) + : nullptr, + cache_k_scale ? reinterpret_cast( + const_cast(cache_k_scale.get().data())) + : nullptr, + cache_v_scale ? reinterpret_cast( + const_cast(cache_v_scale.get().data())) + : nullptr, + cache_k_zp ? reinterpret_cast( + const_cast(cache_k_zp.get().data())) + : nullptr, + cache_v_zp ? 
reinterpret_cast( + const_cast(cache_v_zp.get().data())) + : nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads, + dim_head, + block_size, + bsz, + token_nums, + stream, + use_neox_rotary_style, + rope_3d); + } else { + PD_THROW( + "cache_quant_type_str should be one of [none, cache_int8, " + "cache_int4_zp]"); + } } } @@ -500,11 +628,15 @@ template void SpeculateWriteCacheWithRoPEKernel( const paddle::optional& cache_v_zp, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, + const bool rope_3d, const int max_seq_len, cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void SpeculateWriteCacheWithRoPEKernel( @@ -526,11 +658,15 @@ SpeculateWriteCacheWithRoPEKernel( const paddle::optional& cache_v_zp, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, + const bool rope_3d, const int max_seq_len, cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void SpeculateWriteCacheWithRoPEKernel( const AppendAttnMetaData& meta_data, @@ -551,11 +687,15 @@ template void SpeculateWriteCacheWithRoPEKernel( const paddle::optional& cache_v_zp, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, + const bool rope_3d, const int max_seq_len, cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); template void @@ -578,8 +718,12 @@ SpeculateWriteCacheWithRoPEKernel( const paddle::optional& cache_v_zp, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, + const bool rope_3d, const int max_seq_len, cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h index 40ab34e05..2db42bc26 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.h @@ -35,8 +35,12 @@ void SpeculateWriteCacheWithRoPEKernel( const paddle::optional& cache_v_zp, const std::string& cache_quant_type_str, const bool use_neox_rotary_style, + const bool rope_3d, const int max_seq_len, cudaStream_t& stream, paddle::Tensor* qkv_out, paddle::Tensor* key_cache_out, - paddle::Tensor* value_cache_out); + paddle::Tensor* value_cache_out, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps); diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 027a33dc0..1ced2ce6f 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -378,9 +378,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags, const paddle::Tensor 
&step_seq_lens_decoder, const paddle::Tensor &block_tables, const paddle::Tensor &is_block_step, - const int block_size); - - + const paddle::optional &draft_tokens, + const paddle::optional &step_draft_tokens, + const paddle::optional &step_seq_lens_this_time, + const int block_size, + const int max_draft_tokens); paddle::Tensor GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor, @@ -707,6 +709,22 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens, void SpeculateClearAcceptNums(const paddle::Tensor& accept_num, const paddle::Tensor& seq_lens_decoder); +void SpeculateScheduleCache(const paddle::Tensor &draft_tokens, + const paddle::Tensor &block_tables, + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &step_seq_lens_decoder, + const paddle::Tensor &step_draft_tokens, + const paddle::Tensor &step_seq_lens_this_time, + const paddle::Tensor &accept_num, + const paddle::Tensor &accept_tokens, + const paddle::Tensor &is_block_step, + const paddle::Tensor ¬_need_stop, + const paddle::Tensor &stop_nums, + const int block_size, + const int max_draft_tokens); + void NgramMatch(const paddle::Tensor &input_ids, const paddle::Tensor &input_ids_len, const paddle::Tensor &pre_ids, @@ -750,6 +768,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& step_idx, const paddle::Tensor& not_need_stop, + const paddle::Tensor& is_block_step, const paddle::Tensor& batch_drop, const paddle::Tensor& pre_ids, const paddle::Tensor& accept_tokens, @@ -763,7 +782,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& base_model_draft_tokens, const int max_draft_token, const bool truncate_first_token, - const bool splitwise_prefill); + const bool splitwise_prefill, + const bool kvcache_scheduler_v1); void DraftModelUpdate(const paddle::Tensor& inter_next_tokens, @@ -1228,6 +1248,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function"); + m.def("speculate_schedule_cache",&SpeculateScheduleCache, "SpeculateScheduleCache function"); + m.def("ngram_match", &NgramMatch, "ngram_match function"); m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function"); diff --git a/custom_ops/gpu_ops/recover_decode_task.cu b/custom_ops/gpu_ops/recover_decode_task.cu index 88c7dd51c..ae4e77ad6 100644 --- a/custom_ops/gpu_ops/recover_decode_task.cu +++ b/custom_ops/gpu_ops/recover_decode_task.cu @@ -15,31 +15,72 @@ #include "helper.h" __global__ void recover_decode_task(bool *stop_flags, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int *block_tables, - bool *is_block_step, - const int bsz, - const int block_num_per_seq, - const int block_size) { + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int *block_tables, + bool *is_block_step, + const int bsz, + const int block_num_per_seq, + const int block_size) { int thread_idx = threadIdx.x; if (thread_idx < bsz) { if(is_block_step[thread_idx] == true) { int *block_table_now = block_tables + thread_idx * block_num_per_seq; if (block_table_now[step_seq_lens_decoder[thread_idx] / block_size] != -1) { - // can be recovered for decoding - is_block_step[thread_idx] = false; - seq_lens_this_time[thread_idx]= 1; - stop_flags[thread_idx] = false; - 
seq_lens_encoder[thread_idx] = 0; - seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx]; - } + // can be recovered for decoding + is_block_step[thread_idx] = false; + seq_lens_this_time[thread_idx]= 1; + stop_flags[thread_idx] = false; + seq_lens_encoder[thread_idx] = 0; + seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx]; + + } } } } +__global__ void recover_spec_decode_task(bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int *block_tables, + bool *is_block_step, + int64_t *draft_tokens, + const int64_t *step_draft_tokens, + const int *step_seq_lens_this_time, + const int bsz, + const int block_num_per_seq, + const int block_size, + const int draft_tokens_len, + const int num_extra_tokens) { + int thread_idx = threadIdx.x; + if (thread_idx < bsz) { + if(is_block_step[thread_idx] == true) { + int *block_table_now = block_tables + thread_idx * block_num_per_seq; + int max_possible_block_idx = (step_seq_lens_decoder[thread_idx] + num_extra_tokens) / block_size; + max_possible_block_idx = min(max_possible_block_idx, block_num_per_seq); + if (block_table_now[max_possible_block_idx] != -1) { + // can be recovered for decoding + int64_t *draft_tokens_now = draft_tokens + thread_idx * draft_tokens_len; + const int64_t *step_draft_tokens_now = step_draft_tokens + thread_idx * draft_tokens_len; + is_block_step[thread_idx] = false; + seq_lens_this_time[thread_idx] = step_seq_lens_this_time[thread_idx]; + stop_flags[thread_idx] = false; + seq_lens_encoder[thread_idx] = 0; + seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx]; + for (int i = 0; i < seq_lens_this_time[thread_idx]; i++) { + draft_tokens_now[i] = step_draft_tokens_now[i]; + } + + } + } + } +} + + void RecoverDecodeTask(const paddle::Tensor &stop_flags, const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &seq_lens_encoder, @@ -47,7 +88,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags, const paddle::Tensor &step_seq_lens_decoder, const paddle::Tensor &block_tables, const paddle::Tensor &is_block_step, - const int block_size) { + const paddle::optional &draft_tokens, + const paddle::optional &step_draft_tokens, + const paddle::optional &step_seq_lens_this_time, + const int block_size, + const int max_draft_tokens) { #ifdef PADDLE_WITH_CUSTOM_DEVICE auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place())); auto cu_stream = dev_ctx->stream(); @@ -56,17 +101,38 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags, #endif const int bsz = seq_lens_this_time.shape()[0]; const int block_num_per_seq = block_tables.shape()[1]; - recover_decode_task<<<1, 1024, 0, cu_stream>>>( - const_cast(stop_flags.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(step_seq_lens_decoder.data()), - const_cast(block_tables.data()), - const_cast(is_block_step.data()), - bsz, - block_num_per_seq, - block_size); + if (draft_tokens) { + const int draft_tokens_len = draft_tokens.get_ptr()->shape()[1]; + recover_spec_decode_task<<<1, 1024, 0, cu_stream>>>( + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(step_seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(is_block_step.data()), + const_cast(draft_tokens.get_ptr()->data()), + 
step_draft_tokens.get_ptr()->data(), + step_seq_lens_this_time.get_ptr()->data(), + bsz, + block_num_per_seq, + block_size, + draft_tokens_len, + max_draft_tokens * 2 + 1); + + } else { + recover_decode_task<<<1, 1024, 0, cu_stream>>>( + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(step_seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(is_block_step.data()), + bsz, + block_num_per_seq, + block_size); + } } PD_BUILD_STATIC_OP(recover_decode_task) @@ -76,8 +142,11 @@ PD_BUILD_STATIC_OP(recover_decode_task) "seq_lens_decoder", "step_seq_lens_decoder", "block_tables", - "is_block_step"}) - .Attrs({"block_size: int"}) + "is_block_step", + paddle::Optional("draft_tokens"), + paddle::Optional("step_draft_tokens"), + paddle::Optional("step_seq_lens_this_time")}) + .Attrs({"block_size: int", "max_draft_tokens: int"}) .Outputs({"seq_lens_this_time_out", "seq_lens_encoder_out", "seq_lens_decoder_out", diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu b/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu index 573f6fb68..051d20a03 100644 --- a/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu +++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu @@ -15,7 +15,48 @@ #include "helper.h" #include "paddle/extension.h" -template + +#define DISPATCH_BLOCKSIZE(BLOCK_SIZE, ...) \ + do { \ + constexpr int BlockSize = BLOCK_SIZE; \ + __VA_ARGS__; \ + } while (0) + +#define DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, ...) \ + do { \ + if (truncate_first_token) { \ + constexpr bool TRUNCATE_FIRST_TOKEN = true; \ + __VA_ARGS__; \ + } else { \ + constexpr bool TRUNCATE_FIRST_TOKEN = false; \ + __VA_ARGS__; \ + } \ + } while (0) + +#define DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, ...) \ + do { \ + if (kvcache_scheduler_v1) { \ + constexpr bool KVCACHE_SCHEDULER_V1 = true; \ + __VA_ARGS__; \ + } else { \ + constexpr bool KVCACHE_SCHEDULER_V1 = false; \ + __VA_ARGS__; \ + } \ + } while (0) + +#define DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, ...) 
\ + do { \ + if (splitwise_prefill) { \ + constexpr bool SPLITWISE_PREFILL = true; \ + __VA_ARGS__; \ + } else { \ + constexpr bool SPLITWISE_PREFILL = false; \ + __VA_ARGS__; \ + } \ + } while (0) + + +template __global__ void process_splitwise_prefill( int64_t* draft_tokens, int64_t* input_ids, @@ -25,6 +66,7 @@ __global__ void process_splitwise_prefill( int* seq_lens_decoder, int64_t* step_idx, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, int64_t* pre_ids, const int64_t* accept_tokens, @@ -58,7 +100,7 @@ __global__ void process_splitwise_prefill( stop_flags[tid] = false; int64_t base_model_first_token = accept_tokens_now[0]; int position = seq_len_encoder; - if (TRCUNCATE_FIRST_TOKEN) { + if (TRUNCATE_FIRST_TOKEN) { input_ids_now[position - 1] = base_model_first_token; seq_lens_this_time[tid] = seq_len_encoder; } else { @@ -84,7 +126,7 @@ __global__ void process_splitwise_prefill( -template +template __global__ void draft_model_preprocess_kernel( int64_t* draft_tokens, int64_t* input_ids, @@ -94,6 +136,7 @@ __global__ void draft_model_preprocess_kernel( int* seq_lens_decoder, int64_t* step_idx, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, int64_t* pre_ids, const int64_t* accept_tokens, @@ -134,14 +177,26 @@ __global__ void draft_model_preprocess_kernel( base_model_draft_tokens_now[i] = -1; } - if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { - batch_drop[tid] = true; - stop_flags[tid] = true; + // 1. process block_step situation + // -- In v0 mode, block_step will drop mtp query. + // -- In v1 mode, block_step will continue to infer. + if constexpr(KVCACHE_SCHEDULER_V1) { + if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { + stop_flags[tid] = true; + is_block_step[tid] = true; + // Need to continue infer + } + } else { + if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { + batch_drop[tid] = true; + stop_flags[tid] = true; + } } + // 2. process normal query, not in any special case. if (!(base_model_stop_flags[tid] || batch_drop[tid])) { not_stop_flag = 1; - // 1. first token + // prefill generation if (seq_lens_encoder[tid] > 0) { // Can be extended to first few tokens int seq_len_encoder = seq_lens_encoder[tid]; @@ -149,14 +204,20 @@ __global__ void draft_model_preprocess_kernel( int64_t base_model_first_token = accept_tokens_now[0]; pre_ids_now[0] = base_model_first_token; int position = seq_len_encoder; - if (TRCUNCATE_FIRST_TOKEN) { + if (TRUNCATE_FIRST_TOKEN) { input_ids_now[position - 1] = base_model_first_token; seq_lens_this_time[tid] = seq_len_encoder; } else { input_ids_now[position] = base_model_first_token; seq_lens_this_time[tid] = seq_len_encoder + 1; } - } else { + } else { // decode generation + if constexpr (KVCACHE_SCHEDULER_V1) { + // 3. 
try to recover mtp infer in V1 mode + if (!base_model_is_block_step[tid] && is_block_step[tid]) { + is_block_step[tid] = false; + } + } if (stop_flags[tid]) { stop_flags[tid] = false; // TODO: check @@ -189,99 +250,8 @@ __global__ void draft_model_preprocess_kernel( } } -template -void DispatchRunner( - const cudaStream_t& stream, - int64_t* draft_tokens, - int64_t* input_ids, - bool* stop_flags, - int* seq_lens_this_time, - int* seq_lens_encoder, - int* seq_lens_decoder, - int64_t* step_idx, - bool* not_need_stop, - bool* batch_drop, - int64_t* pre_ids, - const int64_t* accept_tokens, - const int* accept_num, - const int* base_model_seq_lens_this_time, - const int* base_model_seq_lens_encoder, - const int* base_model_seq_lens_decoder, - const int64_t* base_model_step_idx, - const bool* base_model_stop_flags, - const bool* base_model_is_block_step, - int64_t* base_model_draft_tokens, - const int bsz, - const int num_model_step, - const int accept_tokens_len, - const int draft_tokens_len, - const int input_ids_len, - const int base_model_draft_tokens_len, - const int pre_ids_len, - const bool splitwise_prefill) { - constexpr int BlockSize = 512; - if (splitwise_prefill) { - process_splitwise_prefill - <<<1, BlockSize, 0, stream>>>( - draft_tokens, - input_ids, - stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_idx, - not_need_stop, - batch_drop, - pre_ids, - accept_tokens, - accept_num, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - base_model_step_idx, - base_model_stop_flags, - base_model_is_block_step, - base_model_draft_tokens, - bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len); - } else { - draft_model_preprocess_kernel - <<<1, BlockSize, 0, stream>>>( - draft_tokens, - input_ids, - stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_idx, - not_need_stop, - batch_drop, - pre_ids, - accept_tokens, - accept_num, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - base_model_step_idx, - base_model_stop_flags, - base_model_is_block_step, - base_model_draft_tokens, - bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len); - } -} -void DispatchTokenMode( +void DispatchRunner( const cudaStream_t &stream, int64_t* draft_tokens, int64_t* input_ids, @@ -291,6 +261,7 @@ void DispatchTokenMode( int* seq_lens_decoder, int64_t* step_idx, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, int64_t* pre_ids, const int64_t* accept_tokens, @@ -310,75 +281,79 @@ void DispatchTokenMode( const int base_model_draft_tokens_len, const int pre_ids_len, const bool truncate_first_token, - const bool splitwise_prefill) { - if (truncate_first_token) { - DispatchRunner( - stream, - draft_tokens, - input_ids, - stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_idx, - not_need_stop, - batch_drop, - pre_ids, - accept_tokens, - accept_num, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - base_model_step_idx, - base_model_stop_flags, - base_model_is_block_step, - base_model_draft_tokens, - bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len, - splitwise_prefill - ); - } else { - DispatchRunner( - stream, - draft_tokens, - input_ids, - stop_flags, - 
seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_idx, - not_need_stop, - batch_drop, - pre_ids, - accept_tokens, - accept_num, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - base_model_step_idx, - base_model_stop_flags, - base_model_is_block_step, - base_model_draft_tokens, - bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len, - splitwise_prefill - ); - } + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { + DISPATCH_BLOCKSIZE(512, { + DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, { + DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, { + DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, { + if constexpr (SPLITWISE_PREFILL) { + process_splitwise_prefill + <<<1, BlockSize, 0, stream>>>( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + accept_tokens_len, + draft_tokens_len, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len); + } else { + draft_model_preprocess_kernel + <<<1, BlockSize, 0, stream>>>( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + accept_tokens_len, + draft_tokens_len, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len); + } + }); + }); + }); + }); } - - - void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& input_ids, const paddle::Tensor& stop_flags, @@ -387,6 +362,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& step_idx, const paddle::Tensor& not_need_stop, + const paddle::Tensor& is_block_step, const paddle::Tensor& batch_drop, const paddle::Tensor& pre_ids, const paddle::Tensor& accept_tokens, @@ -400,7 +376,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& base_model_draft_tokens, const int num_model_step, const bool truncate_first_token, - const bool splitwise_prefill) { + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { int real_bsz = seq_lens_this_time.shape()[0]; int accept_tokens_len = accept_tokens.shape()[1]; int input_ids_len = input_ids.shape()[1]; @@ -412,36 +389,38 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, auto not_need_stop_gpu = not_need_stop.copy_to(seq_lens_this_time.place(), false); - DispatchTokenMode( - cu_stream, - const_cast(draft_tokens.data()), - const_cast(input_ids.data()), - const_cast(stop_flags.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(step_idx.data()), - const_cast(not_need_stop_gpu.data()), - const_cast(batch_drop.data()), - const_cast(pre_ids.data()), - 
accept_tokens.data(), - accept_num.data(), - base_model_seq_lens_this_time.data(), - base_model_seq_lens_encoder.data(), - base_model_seq_lens_decoder.data(), - base_model_step_idx.data(), - base_model_stop_flags.data(), - base_model_is_block_step.data(), - const_cast(base_model_draft_tokens.data()), - real_bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len, - truncate_first_token, - splitwise_prefill); + DispatchRunner( + cu_stream, + const_cast(draft_tokens.data()), + const_cast(input_ids.data()), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(step_idx.data()), + const_cast(not_need_stop_gpu.data()), + const_cast(is_block_step.data()), + const_cast(batch_drop.data()), + const_cast(pre_ids.data()), + accept_tokens.data(), + accept_num.data(), + base_model_seq_lens_this_time.data(), + base_model_seq_lens_encoder.data(), + base_model_seq_lens_decoder.data(), + base_model_step_idx.data(), + base_model_stop_flags.data(), + base_model_is_block_step.data(), + const_cast(base_model_draft_tokens.data()), + real_bsz, + num_model_step, + accept_tokens_len, + draft_tokens_len, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1); auto not_need_stop_cpu = not_need_stop_gpu.copy_to(not_need_stop.place(), false); @@ -459,6 +438,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess) "seq_lens_decoder", "step_idx", "not_need_stop", + "is_block_step", "batch_drop", "pre_ids", "accept_tokens", @@ -480,7 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess) "not_need_stop_out", "batch_drop_out", "pre_ids_out"}) - .Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"}) + .Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool", "kvcache_scheduler_v1: bool"}) .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, {"input_ids", "input_ids_out"}, {"stop_flags", "stop_flags_out"}, diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu new file mode 100644 index 000000000..633c5bb4d --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu @@ -0,0 +1,176 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
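+//
+// speculate_schedule_cache: with the v1 KV-cache scheduler, a speculative-decoding
+// query whose block table cannot hold the next step's draft tokens is parked rather
+// than dropped. Its current seq lens and draft tokens are snapshotted into the
+// step_* buffers, is_block_step is set and the query is stopped for this step;
+// recover_decode_task later restores the snapshot once enough blocks are free.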
+
+#include "helper.h"
+
+template <int THREADBLOCK_SIZE>
+__global__ void speculate_schedula_cache(
+    const int64_t *draft_tokens,
+    int *block_tables,
+    bool *stop_flags,
+    int *seq_lens_this_time,
+    int *seq_lens_decoder,
+    int *step_seq_lens_decoder,
+    int64_t *step_draft_tokens,
+    int *step_seq_lens_this_time,
+    int *accept_num,
+    int64_t *accept_tokens,
+    bool *is_block_step,
+    bool *not_need_stop,
+    const int64_t *stop_nums,
+    const int real_bsz,
+    const int max_bsz,
+    const int max_next_step_tokens,
+    const int draft_tokens_len,
+    const int accept_tokens_len,
+    const int block_size,
+    const int block_num_per_seq) {
+  const int bid = threadIdx.x;
+  int stop_flag_now_int = 0;
+  if (bid < real_bsz) {
+    if (!stop_flags[bid]) {
+      const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
+      int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
+      int *block_table_now = block_tables + bid * block_num_per_seq;
+      int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
+      const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
+      if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
+        is_block_step[bid] = true;
+        step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
+        seq_lens_this_time[bid] = 0;
+        stop_flags[bid] = true;
+        stop_flag_now_int = 1;
+        step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
+        seq_lens_decoder[bid] = 0;
+        accept_num[bid] = 0;
+        for (int i = 0; i < accept_tokens_len; i++) {
+          accept_tokens_now[i] = -1;
+        }
+        for (int i = 0; i < draft_tokens_len; i++) {
+          step_draft_tokens_now[i] = draft_tokens_now[i];
+        }
+      }
+    } else {
+      stop_flag_now_int = 1;
+    }
+  } else if (bid >= real_bsz && bid < max_bsz) {
+    stop_flag_now_int = 1;
+  }
+  __syncthreads();
+  typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  // printf("stop_flag_now_int %d \n", stop_flag_now_int);
+  int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
+
+  if (threadIdx.x == 0) {
+    // printf("stop_sum %d \n", stop_sum);
+    not_need_stop[0] = stop_sum < stop_nums[0];
+  }
+}
+
+void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
+                            const paddle::Tensor &block_tables,
+                            const paddle::Tensor &stop_flags,
+                            const paddle::Tensor &seq_lens_this_time,
+                            const paddle::Tensor &seq_lens_decoder,
+                            const paddle::Tensor &step_seq_lens_decoder,
+                            const paddle::Tensor &step_draft_tokens,
+                            const paddle::Tensor &step_seq_lens_this_time,
+                            const paddle::Tensor &accept_num,
+                            const paddle::Tensor &accept_tokens,
+                            const paddle::Tensor &is_block_step,
+                            const paddle::Tensor &not_need_stop,
+                            const paddle::Tensor &stop_nums,
+                            const int block_size,
+                            const int max_draft_tokens) {
+  const int real_bsz = seq_lens_this_time.shape()[0];
+  const int max_bsz = stop_flags.shape()[0];
+  const int accept_tokens_len = accept_tokens.shape()[1];
+  const int draft_token_len = draft_tokens.shape()[1];
+  const int block_num_per_seq = block_tables.shape()[1];
+
+  constexpr int BlockSize = 512;
+  const int max_next_step_tokens = 2 * max_draft_tokens + 2;
+
+  auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
+  speculate_schedula_cache<BlockSize><<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
+      draft_tokens.data<int64_t>(),
+      const_cast<int *>(block_tables.data<int>()),
+      const_cast<bool *>(stop_flags.data<bool>()),
+      const_cast<int *>(seq_lens_this_time.data<int>()),
+      const_cast<int *>(seq_lens_decoder.data<int>()),
+      const_cast<int *>(step_seq_lens_decoder.data<int>()),
+      const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
+      const_cast<int *>(step_seq_lens_this_time.data<int>()),
+      const_cast<int *>(accept_num.data<int>()),
+      const_cast<int64_t *>(accept_tokens.data<int64_t>()),
+      const_cast<bool *>(is_block_step.data<bool>()),
+      const_cast<bool *>(not_need_stop_gpu.data<bool>()),
+      stop_nums.data<int64_t>(),
+      real_bsz,
+      max_bsz,
+      max_next_step_tokens,
+      draft_token_len,
+      accept_tokens_len,
+      block_size,
+      block_num_per_seq
+  );
+
+  auto not_need_stop_cpu =
+      not_need_stop_gpu.copy_to(not_need_stop.place(), true);
+  bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
+  not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
+}
+
+PD_BUILD_STATIC_OP(speculate_schedule_cache)
+    .Inputs({"draft_tokens",
+             "block_tables",
+             "stop_flags",
+             "seq_lens_this_time",
+             "seq_lens_decoder",
+             "step_seq_lens_decoder",
+             "step_draft_tokens",
+             "step_seq_lens_this_time",
+             "accept_num",
+             "accept_tokens",
+             "is_block_step",
+             "not_need_stop",
+             "stop_nums"})
+    .Attrs({"block_size: int", "max_draft_tokens: int"})
+    .Outputs({"draft_tokens_out",
+              "block_tables_out",
+              "stop_flags_out",
+              "seq_lens_this_time_out",
+              "seq_lens_decoder_out",
+              "step_seq_lens_decoder_out",
+              "step_draft_tokens_out",
+              "step_seq_lens_this_time_out",
+              "accept_num_out",
+              "accept_tokens_out",
+              "is_block_step_out",
+              "not_need_stop_out"})
+    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
+                    {"block_tables", "block_tables_out"},
+                    {"stop_flags", "stop_flags_out"},
+                    {"seq_lens_this_time", "seq_lens_this_time_out"},
+                    {"seq_lens_decoder", "seq_lens_decoder_out"},
+                    {"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
+                    {"step_draft_tokens", "step_draft_tokens_out"},
+                    {"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
+                    {"accept_num", "accept_num_out"},
+                    {"accept_tokens", "accept_tokens_out"},
+                    {"is_block_step", "is_block_step_out"},
+                    {"not_need_stop", "not_need_stop_out"},})
+    .SetKernelFn(PD_KERNEL(SpeculateScheduleCache));
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index f1ee1d457..9b2eba9c6 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -889,7 +889,7 @@ class CacheConfig:
         else:
             self.kv_cache_ratio = 0.75
         self.enc_dec_block_num = 0 if current_platform.is_iluvatar() else 2
-        self.prealloc_dec_block_slot_num_threshold = 5
+        self.prealloc_dec_block_slot_num_threshold = 12
         self.cache_dtype = "bfloat16"
         self.model_cfg = None
         self.enable_chunked_prefill = False
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 8463a7e39..6777b356b 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -165,8 +165,7 @@ class EngineArgs:
     """
     Ratio of tokens to process in a block.
     """
-
-    prealloc_dec_block_slot_num_threshold: int = 5
+    prealloc_dec_block_slot_num_threshold: int = 12
     """
     Token slot threshold for preallocating decoder blocks.
""" @@ -405,8 +404,6 @@ class EngineArgs: raise NotImplementedError("Logprob does not support enable_expert_parallel.") if not current_platform.is_cuda(): raise NotImplementedError("Only CUDA platform supports logprob.") - if self.speculative_config is not None: - envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 if self.splitwise_role != "mixed": envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 if not current_platform.is_cuda(): @@ -706,7 +703,7 @@ class EngineArgs: cache_group.add_argument( "--prealloc-dec-block-slot-num-threshold", type=int, - default=5, + default=12, help="Number of token slot threadshold to allocate next blocks for decoding.", ) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index dc13c74f0..bce3a4d48 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -62,7 +62,6 @@ class EngineSevice: self.cfg = cfg self.scheduler = cfg.scheduler_config.scheduler() - if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.resource_manager = ResourceManagerV1( cfg.max_num_seqs, diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index bb9ae0bfb..ed6e5fed1 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -84,10 +84,14 @@ class ResourceManagerV1(ResourceManager): return len(request.block_tables) * self.config.cache_config.block_size def get_new_block_nums(self, request: Request, num_new_tokens: int): - return ( + block_num = ( request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1 ) // self.config.cache_config.block_size - len(request.block_tables) + if self.config.speculative_config.method is not None: + block_num = min(block_num + 1, self.config.cache_config.max_block_num_per_seq) + return block_num + def _prepare_prefill_task(self, request, new_token_num): request.prefill_start_index = request.num_computed_tokens request.prefill_end_index = request.num_computed_tokens + new_token_num diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 29d570e23..fb4718fce 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -100,6 +100,8 @@ class AppendAttentionBackend(AttentionBackend): self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr( fd_config.model_config, "use_3d_rope", False ) + if fd_config.speculative_config.model_type != "main": + self.rope_3d = False self.causal: bool = getattr(fd_config.model_config, "causal", True) self.speculative_method: str = fd_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None @@ -356,7 +358,7 @@ class AppendAttentionBackend(AttentionBackend): getattr(layer, "cache_v_zp", None), layer.linear_shift, layer.linear_smooth, - forward_meta.attn_mask_offsets, + None if self.use_speculate else forward_meta.attn_mask_offsets, metadata.kv_signal_data_list[layer.layer_id], getattr(layer, "q_norm_weight", None), getattr(layer, "k_norm_weight", None), @@ -374,7 +376,7 @@ class AppendAttentionBackend(AttentionBackend): metadata.max_partition_size, metadata.encoder_max_partition_size, self.speculate_max_draft_token_num + 1, - self.causal, + self.causal or self.use_speculate, self.speculative_method is not None, ) return res diff --git a/fastdeploy/model_executor/pre_and_post_process.py 
b/fastdeploy/model_executor/pre_and_post_process.py index 975174737..796476de7 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -306,7 +306,9 @@ def post_process_normal( ) -def post_process_specualate(model_output, save_each_rank: bool = False, skip_save_output: bool = False): +def post_process_specualate( + model_output: ModelOutputData, save_each_rank: bool = False, skip_save_output: bool = False +): """""" speculate_update( model_output.seq_lens_encoder, diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 1f9ba002c..e3ed56c5e 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -261,7 +261,7 @@ class TokenProcessor: def _compute_speculative_status(self): # TODO(liuzichang): Supplement more statistics - interval = 10 + interval = 1 if self.speculative_stats_step % interval == 0: accept_ratio = 1 - self.total_step * 1.0 / self.number_of_output_tokens spec_logger.info( diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index ab908a584..6ec6ee190 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -19,8 +19,10 @@ from typing import List import numpy as np import paddle +from paddleformers.utils.log import logger -from fastdeploy.engine.request import Request +from fastdeploy import envs +from fastdeploy.engine.request import Request, RequestType from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention.base_attention_backend import ( @@ -50,14 +52,14 @@ class MTPProposer(Proposer): Proposer for Multi-Token-Prediction(MTP) """ - def __init__(self, cfg, main_model, local_rank, device_id, main_model_inputs): + def __init__(self, cfg, main_model, local_rank, device_id, target_model_inputs): super().__init__(cfg) self.num_main_model_layers = self.model_config.num_hidden_layers self.local_rank = local_rank self.device_id = device_id self._update_cfg(main_model) self._load_model() - self.main_model_inputs = main_model_inputs + self.target_model_inputs = target_model_inputs self.mtp_strategy = self.speculative_config.mtp_strategy self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps @@ -73,7 +75,7 @@ class MTPProposer(Proposer): """ Update config for MTP from global config """ - self.model_config.architectures[0] = "Ernie4_5_MTPForCausalLM" + self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP") self.speculative_config.sharing_model = main_model self.model_config.num_hidden_layers = 1 self.model_config.model = self.speculative_config.model @@ -199,14 +201,16 @@ class MTPProposer(Proposer): encoder_block_shape_q = 64 decoder_block_shape_q = 16 - self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.main_model_inputs["decoder_batch_ids"]) + self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["decoder_batch_ids"]) self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like( - self.main_model_inputs["decoder_tile_ids_per_batch"] + self.target_model_inputs["decoder_tile_ids_per_batch"] ) self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like( - self.main_model_inputs["decoder_num_blocks_cpu"] + self.target_model_inputs["decoder_num_blocks_cpu"] ).pin_memory() - self.model_inputs["max_len_tensor_cpu"] = 
paddle.zeros_like(self.main_model_inputs["max_len_tensor_cpu"]).cpu() + self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like( + self.target_model_inputs["max_len_tensor_cpu"] + ).cpu() # Get the attention backend attn_cls = get_attention_backend() @@ -265,28 +269,29 @@ class MTPProposer(Proposer): """ self.model_inputs = {} # Same shape/dytpe with base model - self.model_inputs["block_tables"] = paddle.clone(self.main_model_inputs["block_tables"]) - self.model_inputs["input_ids"] = paddle.clone(self.main_model_inputs["input_ids"]) - self.seq_lens_this_time_buffer = paddle.clone(self.main_model_inputs["seq_lens_this_time"]) + self.model_inputs["block_tables"] = paddle.clone(self.target_model_inputs["block_tables"]) + self.model_inputs["input_ids"] = paddle.clone(self.target_model_inputs["input_ids"]) self.model_inputs["input_ids_cpu"] = paddle.full( shape=[self.max_num_seqs, self.parallel_config.max_model_len], fill_value=-1, dtype="int64", ).cpu() - self.model_inputs["seq_lens_encoder"] = paddle.clone(self.main_model_inputs["seq_lens_encoder"]) - self.model_inputs["seq_lens_decoder"] = paddle.clone(self.main_model_inputs["seq_lens_decoder"]) - self.model_inputs["step_idx"] = paddle.clone(self.main_model_inputs["step_idx"]) - self.model_inputs["stop_flags"] = paddle.clone(self.main_model_inputs["stop_flags"]) - self.model_inputs["stop_nums"] = paddle.clone(self.main_model_inputs["stop_nums"]) + self.seq_lens_this_time_buffer = paddle.clone(self.target_model_inputs["seq_lens_this_time"]) + + self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"]) + self.model_inputs["seq_lens_decoder"] = paddle.clone(self.target_model_inputs["seq_lens_decoder"]) + self.model_inputs["step_idx"] = paddle.clone(self.target_model_inputs["step_idx"]) + self.model_inputs["stop_flags"] = paddle.clone(self.target_model_inputs["stop_flags"]) + self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"]) self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu") - self.model_inputs["pre_ids"] = paddle.clone(self.main_model_inputs["pre_ids"]) - self.model_inputs["ids_remove_padding"] = paddle.clone(self.main_model_inputs["ids_remove_padding"]) - self.model_inputs["batch_id_per_token"] = paddle.clone(self.main_model_inputs["batch_id_per_token"]) - self.model_inputs["cu_seqlens_q"] = paddle.clone(self.main_model_inputs["cu_seqlens_q"]) - self.model_inputs["cu_seqlens_k"] = paddle.clone(self.main_model_inputs["cu_seqlens_k"]) - self.model_inputs["decoder_batch_ids"] = paddle.clone(self.main_model_inputs["decoder_batch_ids"]) + self.model_inputs["pre_ids"] = paddle.clone(self.target_model_inputs["pre_ids"]) + self.model_inputs["ids_remove_padding"] = paddle.clone(self.target_model_inputs["ids_remove_padding"]) + self.model_inputs["batch_id_per_token"] = paddle.clone(self.target_model_inputs["batch_id_per_token"]) + self.model_inputs["cu_seqlens_q"] = paddle.clone(self.target_model_inputs["cu_seqlens_q"]) + self.model_inputs["cu_seqlens_k"] = paddle.clone(self.target_model_inputs["cu_seqlens_k"]) + self.model_inputs["decoder_batch_ids"] = paddle.clone(self.target_model_inputs["decoder_batch_ids"]) self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone( - self.main_model_inputs["decoder_tile_ids_per_batch"] + self.target_model_inputs["decoder_tile_ids_per_batch"] ) tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) @@ -298,22 +303,22 @@ class MTPProposer(Proposer): ) # 
self.model_inputs["caches"] = self.cache_kvs # Inherit generation hyperparameters from the main model for consistency - self.model_inputs["top_p"] = self.main_model_inputs["top_p"] - self.model_inputs["top_k"] = self.main_model_inputs["top_k"] - self.model_inputs["temperature"] = self.main_model_inputs["temperature"] - self.model_inputs["eos_token_id"] = self.main_model_inputs["eos_token_id"] - self.model_inputs["penalty_score"] = self.main_model_inputs["penalty_score"] - self.model_inputs["frequency_score"] = self.main_model_inputs["frequency_score"] - self.model_inputs["presence_score"] = self.main_model_inputs["presence_score"] - self.model_inputs["infer_seed"] = self.main_model_inputs["infer_seed"] + self.model_inputs["top_p"] = self.target_model_inputs["top_p"] + self.model_inputs["top_k"] = self.target_model_inputs["top_k"] + self.model_inputs["temperature"] = self.target_model_inputs["temperature"] + self.model_inputs["eos_token_id"] = self.target_model_inputs["eos_token_id"] + self.model_inputs["penalty_score"] = self.target_model_inputs["penalty_score"] + self.model_inputs["frequency_score"] = self.target_model_inputs["frequency_score"] + self.model_inputs["presence_score"] = self.target_model_inputs["presence_score"] + self.model_inputs["infer_seed"] = self.target_model_inputs["infer_seed"] - self.model_inputs["max_dec_len"] = self.main_model_inputs["max_dec_len"] - self.model_inputs["min_dec_len"] = self.main_model_inputs["min_dec_len"] + self.model_inputs["max_dec_len"] = self.target_model_inputs["max_dec_len"] + self.model_inputs["min_dec_len"] = self.target_model_inputs["min_dec_len"] - self.model_inputs["bad_tokens"] = self.main_model_inputs["bad_tokens"] + self.model_inputs["bad_tokens"] = self.target_model_inputs["bad_tokens"] # Integrate the updated results in model forward - self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"] + self.model_inputs["base_model_draft_tokens"] = self.target_model_inputs["draft_tokens"] self.model_inputs["substep"] = 0 # Declare AttentionBackend buffers @@ -327,7 +332,7 @@ class MTPProposer(Proposer): shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64" ) - self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"]) + self.model_inputs["encoder_block_lens"] = paddle.clone(self.target_model_inputs["encoder_block_lens"]) self.free_list = list( range( @@ -341,14 +346,76 @@ class MTPProposer(Proposer): self.model_inputs["free_list"] = paddle.to_tensor(self.free_list, dtype="int32") self.model_inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.free_list_len, dtype="int32") + self.model_inputs["is_block_step"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool") self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool") self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32") if self.num_model_steps > 1: self.last_seq_lens_this_time = paddle.full_like( - self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32" + self.target_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32" ) self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu() + def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int): + + if "caches" not in self.model_inputs: + self.initialize_kv_cache() + req_len = len(req_dicts) + + for i in range(req_len): + request = 
req_dicts[i] + logger.debug(f"{i}th request-{request.request_id}: {request}") + idx = request.idx + if request.task_type.value == RequestType.PREFILL.value: # prefill task + prefill_start_index = request.prefill_start_index + prefill_end_index = request.prefill_end_index + length = prefill_end_index - prefill_start_index + + input_ids = request.prompt_token_ids + request.output_token_ids + + self.input_ids_len[idx] = length + self.model_inputs["pre_ids"][idx : idx + 1] = -1 + self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][ + idx : idx + 1, 1:length + ] + encoder_block_num = len(request.block_tables) + self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num + self.model_inputs["block_tables"][idx : idx + 1, :] = -1 + self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( + request.block_tables, dtype="int32" + ) + self.model_inputs["stop_flags"][idx : idx + 1] = False + self.model_inputs["batch_drop"][idx : idx + 1] = False + + self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length + self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index + self.seq_lens_this_time_buffer[idx : idx + 1] = length + self.model_inputs["step_idx"][idx : idx + 1] = ( + len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 + ) + + # has_prefill_task = True + elif request.task_type.value == RequestType.DECODE.value: # decode task + encoder_block_num = len(request.block_tables) + self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num + self.model_inputs["block_tables"][idx : idx + 1, :] = -1 + self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( + request.block_tables, dtype="int32" + ) + # if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode + # has_decode_task = True + # continue + else: + self.model_inputs["block_tables"][idx : idx + 1, :] = -1 + self.model_inputs["stop_flags"][idx : idx + 1] = True + self.seq_lens_this_time_buffer[idx : idx + 1] = 0 + self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0 + self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0 + self.model_inputs["is_block_step"][idx : idx + 1] = False + continue + # if has_prefill_task or has_decode_task: + # self.model_inputs["not_need_stop"][0] = True + self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests] + def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int): """ Process inputs for prefill tasks and insert it to model_inputs buffer @@ -408,9 +475,9 @@ class MTPProposer(Proposer): length = len(request.prompt_token_ids) if length > 1: - self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.main_model_inputs["input_ids"][ - idx : idx + 1, 1:length - ] + self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[ + "input_ids" + ][idx : idx + 1, 1:length] self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array( request.prompt_token_ids )[1:] @@ -470,6 +537,7 @@ class MTPProposer(Proposer): """ Prepare MTP inputs """ + use_v1_cache_scheduler = envs.ENABLE_V1_KVCACHE_SCHEDULER draft_model_preprocess( self.model_inputs["draft_tokens"], self.model_inputs["input_ids"], @@ -480,19 +548,21 @@ class MTPProposer(Proposer): self.model_inputs["step_idx"], self.model_inputs["not_need_stop"], self.model_inputs["batch_drop"], + self.model_inputs["is_block_step"], self.model_inputs["pre_ids"], - 
self.main_model_inputs["accept_tokens"], - self.main_model_inputs["accept_num"], - self.main_model_inputs["seq_lens_this_time"], - self.main_model_inputs["seq_lens_encoder"], - self.main_model_inputs["seq_lens_decoder"], - self.main_model_inputs["step_idx"], - self.main_model_inputs["stop_flags"], - self.main_model_inputs["is_block_step"], - self.main_model_inputs["draft_tokens"], + self.target_model_inputs["accept_tokens"], + self.target_model_inputs["accept_num"], + self.target_model_inputs["seq_lens_this_time"], + self.target_model_inputs["seq_lens_encoder"], + self.target_model_inputs["seq_lens_decoder"], + self.target_model_inputs["step_idx"], + self.target_model_inputs["stop_flags"], + self.target_model_inputs["is_block_step"], + self.target_model_inputs["draft_tokens"], self.num_model_steps, self.speculative_method in ["eagle", "mtp"], self.role == "prefill", + use_v1_cache_scheduler, ) target_hidden_states = eagle_get_hidden_states( @@ -501,9 +571,9 @@ class MTPProposer(Proposer): self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_decoder"], self.model_inputs["stop_flags"], - self.main_model_inputs["accept_num"], - self.main_model_inputs["seq_lens_this_time"], - self.main_model_inputs["seq_lens_encoder"], + self.target_model_inputs["accept_num"], + self.target_model_inputs["seq_lens_this_time"], + self.target_model_inputs["seq_lens_encoder"], self.num_model_steps, ) if isinstance(target_hidden_states, list): @@ -673,41 +743,41 @@ class MTPProposer(Proposer): Allocate/Free block of MPT. """ draft_model_postprocess( - self.main_model_inputs["draft_tokens"], - self.main_model_inputs["seq_lens_this_time"], - self.main_model_inputs["seq_lens_encoder"], - self.main_model_inputs["stop_flags"], - ) - - mtp_step_paddle( - self.main_model_inputs["stop_flags"], - self.model_inputs["stop_flags"], - self.model_inputs["batch_drop"], - self.model_inputs["seq_lens_this_time"], - self.model_inputs["seq_lens_encoder"], - self.model_inputs["seq_lens_decoder"], - self.model_inputs["block_tables"], - self.model_inputs["encoder_block_lens"], - self.model_inputs["used_list_len"], - self.model_inputs["free_list"], - self.model_inputs["free_list_len"], - self.cache_config.block_size, - self.max_draft_token_num, + self.target_model_inputs["draft_tokens"], + self.target_model_inputs["seq_lens_this_time"], + self.target_model_inputs["seq_lens_encoder"], + self.target_model_inputs["stop_flags"], ) + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: + mtp_step_paddle( + self.target_model_inputs["stop_flags"], + self.model_inputs["stop_flags"], + self.model_inputs["batch_drop"], + self.model_inputs["seq_lens_this_time"], + self.model_inputs["seq_lens_encoder"], + self.model_inputs["seq_lens_decoder"], + self.model_inputs["block_tables"], + self.model_inputs["encoder_block_lens"], + self.model_inputs["used_list_len"], + self.model_inputs["free_list"], + self.model_inputs["free_list_len"], + self.cache_config.block_size, + self.max_draft_token_num, + ) def _extend_draft_token_with_ngram_match(self): # TODO(liuzichang): Optimize this Kernel to CUDA Kernel to reduce lantency device = paddle.CUDAPinnedPlace() - draft_tokens = self.main_model_inputs["draft_tokens"].cpu() - seq_lens_this_time = self.main_model_inputs["seq_lens_this_time"].cpu() + draft_tokens = self.target_model_inputs["draft_tokens"].cpu() + seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu() seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu() hybrid_mtp_ngram( self.model_inputs["input_ids_cpu"], 
self.input_ids_len, self.model_inputs["pre_ids"]._copy_to(device, True), self.model_inputs["step_idx"].cpu(), - self.main_model_inputs["actual_draft_token_num"].cpu(), + self.target_model_inputs["actual_draft_token_num"].cpu(), draft_tokens, seq_lens_this_time, seq_lens_decoder, @@ -716,8 +786,8 @@ class MTPProposer(Proposer): self.min_ngram_size, self.max_draft_token_num, ) - self.main_model_inputs["draft_tokens"][:] = draft_tokens.cuda() - self.main_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() + self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda() + self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() def _run_impl(self, full_hidden_states): """""" diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 89834acd6..385d0c5ab 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -59,6 +59,7 @@ else: recover_decode_task, set_value_by_flags_and_idx, share_external_data, + speculate_schedule_cache, ) from fastdeploy.model_executor.pre_and_post_process import ( @@ -383,6 +384,8 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["stop_flags"].sum() == self.parallel_config.max_num_seqs ) self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests] + if self.speculative_method in ["mtp"]: + self.proposer.insert_tasks_v1(req_dicts, num_running_requests) def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None): """ @@ -803,6 +806,13 @@ class GPUModelRunner(ModelRunnerBase): fill_value=0, dtype="int32", ) + # For V1_KVCACHE_SCHEDULER + self.share_inputs["step_draft_tokens"] = paddle.full( + shape=[max_num_seqs, max_draft_token_num + 1], + fill_value=0, + dtype="int64", + ) + self.share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") if self.enable_mm: head_dim = self.model_config.head_dim @@ -841,7 +851,11 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["step_seq_lens_decoder"], self.share_inputs["block_tables"], self.share_inputs["is_block_step"], + self.share_inputs["draft_tokens"] if self.speculative_decoding else None, + self.share_inputs["step_draft_tokens"] if self.speculative_decoding else None, + self.share_inputs["step_seq_lens_this_time"] if self.speculative_decoding else None, self.cache_config.block_size, + self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0, ) # Remove padding @@ -1540,6 +1554,24 @@ class GPUModelRunner(ModelRunnerBase): self._update_chunked_prefill(model_forward_batch) self._add_cache(model_forward_batch) + elif self.speculative_decoding: + speculate_schedule_cache( + self.share_inputs["draft_tokens"], + self.share_inputs["block_tables"], + self.share_inputs["stop_flags"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["step_seq_lens_decoder"], + self.share_inputs["step_draft_tokens"], + self.share_inputs["step_seq_lens_this_time"], + self.share_inputs["accept_num"], + self.share_inputs["accept_tokens"], + self.share_inputs["is_block_step"], + self.share_inputs["not_need_stop"], + self.share_inputs["stop_nums"], + self.cache_config.block_size, + self.speculative_config.num_speculative_tokens, + ) self.seq_lens_this_time_buffer[:num_running_requests].copy_( self.share_inputs["seq_lens_this_time"][:num_running_requests], False diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 
8a0ff6f09..ecd0dbf04 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -742,13 +742,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}") logger.info(f"- Load strategy: {load_config.load_strategy}") - if ( - args.speculative_config is not None - and ("method" in args.speculative_config) - and (args.speculative_config["method"] is not None) - ): - logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not support speculative decoding now.") - envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 if args.splitwise_role != "mixed": logger.info(f"Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported {args.splitwise_role} now.") envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 diff --git a/tests/operators/test_speculative_schedule_cache.py b/tests/operators/test_speculative_schedule_cache.py new file mode 100644 index 000000000..d9b9057ac --- /dev/null +++ b/tests/operators/test_speculative_schedule_cache.py @@ -0,0 +1,238 @@ +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.gpu import speculate_schedule_cache + + +def cpu_reference( + draft_tokens, + block_tables, + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + step_seq_lens_decoder, + step_draft_tokens, + step_seq_lens_this_time, + accept_num, + accept_tokens, + is_block_step, + not_need_stop, + stop_nums, + block_size, + max_draft_tokens, +): + """Pure-NumPy mirror of the CUDA kernel's logic (single block of 512 threads). + Shapes are the same as inputs to the custom op. This mutates the provided + NumPy arrays in-place, exactly like the kernel does. + """ + real_bsz = seq_lens_this_time.shape[0] + max_bsz = stop_flags.shape[0] + draft_tokens_len = draft_tokens.shape[1] + block_num_per_seq = block_tables.shape[1] + + max_next_step_tokens = 2 * max_draft_tokens + 2 + + # Block-local reduction input per thread (threadIdx.x -> bid) + stop_flag_now_int = np.zeros(512, dtype=np.int64) # THREADBLOCK_SIZE = 512 + + for bid in range(512): + if bid < real_bsz: + if not stop_flags[bid]: + max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) // block_size + if max_possible_block_idx < block_num_per_seq and block_tables[bid, max_possible_block_idx] == -1: + is_block_step[bid] = True + step_seq_lens_this_time[bid] = seq_lens_this_time[bid] + seq_lens_this_time[bid] = 0 + stop_flags[bid] = True + step_seq_lens_decoder[bid] = seq_lens_decoder[bid] + seq_lens_decoder[bid] = 0 + accept_num[bid] = 0 + accept_tokens[bid, :] = -1 + step_draft_tokens[bid, :draft_tokens_len] = draft_tokens[bid, :draft_tokens_len] + stop_flag_now_int[bid] = 1 + else: + stop_flag_now_int[bid] = 0 + else: + stop_flag_now_int[bid] = 1 + elif bid < max_bsz: + # Threads in [real_bsz, max_bsz) contribute 1 to reduction + stop_flag_now_int[bid] = 1 + else: + stop_flag_now_int[bid] = 0 + + stop_sum = int(stop_flag_now_int.sum()) + not_need_stop[0] = stop_sum < int(stop_nums[0]) + + +class TestSpeculateScheduleCache(unittest.TestCase): + @classmethod + def setUpClass(cls): + if not paddle.is_compiled_with_cuda(): + raise unittest.SkipTest("Paddle is not compiled with CUDA; skipping GPU op test.") + paddle.device.set_device("gpu") + + def setUp(self): + # --- Construct a deterministic case that exercises all branches --- + # real_bsz < max_bsz to test the padding logic in the CUB reduction + self.real_bsz = 3 + self.max_bsz = 5 # only stop_flags has length max_bsz + + self.draft_tokens_len = 6 + 
self.accept_tokens_len = 5 + self.block_size = 4 + self.block_num_per_seq = 3 + self.max_draft_tokens = 2 # -> max_next_step_tokens = 6 + + # Inputs that will trigger for bid 0, not trigger for bid 2, and bid 1 is already stopped + # seq_lens_decoder + 6 // 4 -> indices: [1, 1, 4]. Index 4 is out of range -> no trigger on bid 2 + self.draft_tokens = paddle.to_tensor( + np.array( + [ + [1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2], + [3, 3, 3, 3, 3, 3], + ], + dtype=np.int64, + ) + ) + self.block_tables = paddle.to_tensor(np.full((self.real_bsz, self.block_num_per_seq), -1, dtype=np.int32)) + # stop_flags length is max_bsz, others are real_bsz + self.stop_flags = paddle.to_tensor(np.array([False, True, False, False, False], dtype=np.bool_)) + self.seq_lens_this_time = paddle.to_tensor(np.array([5, 6, 7], dtype=np.int32)) + self.seq_lens_decoder = paddle.to_tensor(np.array([1, 1, 10], dtype=np.int32)) + + # Will be filled by kernel for the triggering bids only + self.step_seq_lens_decoder = paddle.zeros((self.real_bsz,), dtype="int32") + self.step_draft_tokens = paddle.zeros((self.real_bsz, self.draft_tokens_len), dtype="int64") + self.step_seq_lens_this_time = paddle.zeros((self.real_bsz,), dtype="int32") + + # Intentionally non-zero so we can verify in-place zeroing only where triggered + self.accept_num = paddle.to_tensor(np.array([9, 8, 7], dtype=np.int32)) + self.accept_tokens = paddle.to_tensor( + np.arange(self.real_bsz * self.accept_tokens_len, dtype=np.int64).reshape( + self.real_bsz, self.accept_tokens_len + ) + ) + self.is_block_step = paddle.zeros((self.real_bsz,), dtype=paddle.bool) + + # not_need_stop lives on CPU in the caller; the kernel copies to device internally + self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool).cpu() + + # Choose threshold so with: bid0 triggers, bid1 already stopped, padding (5-3)=2 -> stop_sum = 1+1+2 = 4 + # Set stop_nums to 5 so not_need_stop = (4 < 5) = True + self.stop_nums = paddle.to_tensor([5], dtype=paddle.int64) + + # Keep NumPy copies for CPU reference + self.np_draft_tokens = self.draft_tokens.numpy().copy() + self.np_block_tables = self.block_tables.numpy().copy() + self.np_stop_flags = self.stop_flags.numpy().copy() + self.np_seq_lens_this_time = self.seq_lens_this_time.numpy().copy() + self.np_seq_lens_decoder = self.seq_lens_decoder.numpy().copy() + self.np_step_seq_lens_decoder = self.step_seq_lens_decoder.numpy().copy() + self.np_step_draft_tokens = self.step_draft_tokens.numpy().copy() + self.np_step_seq_lens_this_time = self.step_seq_lens_this_time.numpy().copy() + self.np_accept_num = self.accept_num.numpy().copy() + self.np_accept_tokens = self.accept_tokens.numpy().copy() + self.np_is_block_step = self.is_block_step.numpy().copy() + self.np_not_need_stop = self.not_need_stop.numpy().copy() + self.np_stop_nums = self.stop_nums.numpy().copy() + + def test_correctness_against_cpu_reference(self): + # Run GPU kernel (in-place) + speculate_schedule_cache( + self.draft_tokens, + self.block_tables, + self.stop_flags, + self.seq_lens_this_time, + self.seq_lens_decoder, + self.step_seq_lens_decoder, + self.step_draft_tokens, + self.step_seq_lens_this_time, + self.accept_num, + self.accept_tokens, + self.is_block_step, + self.not_need_stop, + self.stop_nums, + self.block_size, + self.max_draft_tokens, + ) + + # Compute CPU reference (in-place on NumPy copies) + cpu_reference( + self.np_draft_tokens, + self.np_block_tables, + self.np_stop_flags, + self.np_seq_lens_this_time, + self.np_seq_lens_decoder, + self.np_step_seq_lens_decoder, + 
self.np_step_draft_tokens, + self.np_step_seq_lens_this_time, + self.np_accept_num, + self.np_accept_tokens, + self.np_is_block_step, + self.np_not_need_stop, + self.np_stop_nums, + self.block_size, + self.max_draft_tokens, + ) + + # Compare all mutated tensors + np.testing.assert_array_equal(self.step_draft_tokens.numpy(), self.np_step_draft_tokens) + np.testing.assert_array_equal(self.accept_tokens.numpy(), self.np_accept_tokens) + np.testing.assert_array_equal(self.stop_flags.numpy(), self.np_stop_flags) + np.testing.assert_array_equal(self.is_block_step.numpy(), self.np_is_block_step) + np.testing.assert_array_equal(self.seq_lens_this_time.numpy(), self.np_seq_lens_this_time) + np.testing.assert_array_equal(self.seq_lens_decoder.numpy(), self.np_seq_lens_decoder) + np.testing.assert_array_equal(self.step_seq_lens_decoder.numpy(), self.np_step_seq_lens_decoder) + np.testing.assert_array_equal(self.step_seq_lens_this_time.numpy(), self.np_step_seq_lens_this_time) + np.testing.assert_array_equal(self.accept_num.numpy(), self.np_accept_num) + self.assertEqual(bool(self.not_need_stop.numpy()[0]), bool(self.np_not_need_stop[0])) + + def test_no_trigger_path(self): + # Make block_tables at candidate index != -1 so nothing triggers + # Candidate index for bid 0/1 is 1, set it to 7 + bt = self.block_tables.numpy() + bt[:, 1] = 7 + self.block_tables = paddle.to_tensor(bt) + + # Reset outputs to distinctive values + self.step_seq_lens_decoder[:] = 0 + self.step_draft_tokens[:] = 0 + self.step_seq_lens_this_time[:] = 0 + self.accept_num[:] = -123 + self.accept_tokens[:] = -777 + self.is_block_step[:] = False + self.not_need_stop[:] = False + + # For not_need_stop: stopped_in_real = (bid1 True) = 1, padding = 2 -> stop_sum=3 + # With stop_nums=5 -> True + speculate_schedule_cache( + self.draft_tokens, + self.block_tables, + self.stop_flags, + self.seq_lens_this_time, + self.seq_lens_decoder, + self.step_seq_lens_decoder, + self.step_draft_tokens, + self.step_seq_lens_this_time, + self.accept_num, + self.accept_tokens, + self.is_block_step, + self.not_need_stop, + self.stop_nums, + self.block_size, + self.max_draft_tokens, + ) + + # Nothing should have changed except not_need_stop + np.testing.assert_array_equal(self.step_draft_tokens.numpy(), np.zeros_like(self.step_draft_tokens.numpy())) + np.testing.assert_array_equal(self.is_block_step.numpy(), np.zeros_like(self.is_block_step.numpy())) + np.testing.assert_array_equal(self.accept_tokens.numpy(), np.full_like(self.accept_tokens.numpy(), -777)) + np.testing.assert_array_equal(self.accept_num.numpy(), np.full_like(self.accept_num.numpy(), -123)) + self.assertTrue(bool(self.not_need_stop.numpy()[0])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/operators/test_tree_mask.py b/tests/operators/test_tree_mask.py index 10e55a4b1..408162809 100644 --- a/tests/operators/test_tree_mask.py +++ b/tests/operators/test_tree_mask.py @@ -5,12 +5,16 @@ import unittest import numpy as np import paddle import paddle.nn.functional as F +from paddle.incubate.nn.functional import fused_rms_norm from fastdeploy.model_executor.layers.attention.ops import ( append_attention, get_block_shape_and_split_kv_block, ) +np.random.seed(0) +paddle.seed(0) + class TestTreeMask(unittest.TestCase): def setUp(self): @@ -27,6 +31,7 @@ class TestTreeMask(unittest.TestCase): self.head_dim = 128 self.num_q_head = 20 self.num_kv_head = 4 + self.use_qknorm = True self.dtype = "bfloat16" self.rope_3d = False @@ -91,12 +96,20 @@ class TestTreeMask(unittest.TestCase): 
cu_seqlens_k[i + 1] = cum_seq_len_k return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k - def ref_attention(self, q, k, v, mask): + def ref_attention(self, q, k, v, mask, use_qknorm=False): + if use_qknorm: + q = q.reshape([-1, self.head_dim]) + q = fused_rms_norm(q.astype("float32"), self.q_norm_weight_tensor, None, 1e-6)[0].astype(self.dtype) + q = q.reshape([self.bsz, -1, self.num_q_head, self.head_dim]) q = q.transpose([0, 2, 1, 3]) if len(k) > 1: k = paddle.concat(k, axis=1) else: k = k[0] + if use_qknorm: + k = k.reshape([-1, self.head_dim]) + k = fused_rms_norm(k.astype("float32"), self.k_norm_weight_tensor, None, 1e-6)[0].astype(self.dtype) + k = k.reshape([self.bsz, -1, self.num_kv_head, self.head_dim]) k = k.transpose([0, 2, 1, 3]) if len(v) > 1: v = paddle.concat(v, axis=1) @@ -127,7 +140,7 @@ class TestTreeMask(unittest.TestCase): .reshape([-1, self.num_q_head, self.head_dim]) ) - def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None): + def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None, use_qknorm=False): if prefill: seq_lens_enc = [ q_len, @@ -187,6 +200,10 @@ class TestTreeMask(unittest.TestCase): decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory() max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu() + q_norm_weight = np.random.random([self.head_dim]) / 10 + k_norm_weight = np.random.random([self.head_dim]) / 10 + self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32") + self.k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32") paddle.device.synchronize() ( encoder_batch_ids, @@ -237,20 +254,20 @@ class TestTreeMask(unittest.TestCase): max_len_kv, rotary_embs, attn_mask, - None, - None, + None, # qkv_bias + None, # qkv_out_scales cache_k_scale, cache_v_scale, cache_k_out_scale, cache_v_out_scale, - None, - None, - None, - None, - None, - None, - None, - None, + None, # cache_k_zp + None, # cache_v_zp + None, # linear_shift + None, # linear_smooth + None, # mask_offset + None, # kv_signal_data + self.q_norm_weight_tensor if use_qknorm else None, # q_norm_weight + self.k_norm_weight_tensor if use_qknorm else None, # k_norm_weight 1e-6, "bf16", "none", @@ -271,7 +288,7 @@ class TestTreeMask(unittest.TestCase): paddle.device.synchronize() e_time = time.time() print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / self.run_time):.2f}") - return out[0].reshape([token_num, self.num_q_head, self.head_dim]) + return out.reshape([token_num, self.num_q_head, self.head_dim]) def test_naive_speculative_decoding(self): prefill_len = 8192 @@ -279,10 +296,10 @@ class TestTreeMask(unittest.TestCase): total_len = prefill_len + dec_len_q mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len) mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf"))) - self.run_append_c16_attention(prefill_len, 0, True) - dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False) + self.run_append_c16_attention(prefill_len, 0, True, use_qknorm=self.use_qknorm) + dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, use_qknorm=self.use_qknorm) - ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask) + ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask, use_qknorm=self.use_qknorm) 
np.testing.assert_allclose( ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03 )
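
For reference, the per-sequence condition that speculate_schedula_cache evaluates (and that cpu_reference above mirrors) can be restated in a few lines of Python. This is an illustrative sketch only; the helper name is made up and is not part of the patch:

def needs_block_step(bid, block_tables, seq_lens_decoder, block_size, max_draft_tokens):
    # Worst case, the next speculative step appends 2 * max_draft_tokens + 2 token slots.
    max_next_step_tokens = 2 * max_draft_tokens + 2
    candidate = (seq_lens_decoder[bid] + max_next_step_tokens) // block_size
    # A block-table entry of -1 means that slot has no physical KV block allocated yet.
    return candidate < len(block_tables[bid]) and block_tables[bid][candidate] == -1

When this condition holds, the kernel parks the request: it copies seq_lens_this_time, seq_lens_decoder and draft_tokens into the step_* buffers, zeroes seq_lens_this_time, seq_lens_decoder and accept_num, fills accept_tokens with -1, and sets stop_flags and is_block_step so the V1 scheduler can recover the request once blocks become available.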
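
The get_new_block_nums change in resource_manager_v1.py reserves one extra block whenever a speculative method is configured, capped by max_block_num_per_seq. A small worked example of the arithmetic (the concrete numbers are illustrative, not taken from the patch):

def new_block_nums(num_computed_tokens, num_new_tokens, block_size,
                   allocated_blocks, speculative, max_block_num_per_seq):
    block_num = (num_computed_tokens + num_new_tokens + block_size - 1) // block_size - allocated_blocks
    if speculative:
        # one spare block so draft/verify steps cannot outrun the allocation
        block_num = min(block_num + 1, max_block_num_per_seq)
    return block_num

# 100 computed tokens + 1 new token, block_size 64, 2 blocks already held:
# ceil(101 / 64) - 2 = 0 without speculation, 1 with the speculative spare block.

Together with the prealloc_dec_block_slot_num_threshold bump from 5 to 12, this makes decode-time block allocation more eager, consistent with removing the ENABLE_V1_KVCACHE_SCHEDULER = 0 fallback for speculative decoding elsewhere in this patch.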
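
One detail of the new insert_tasks_v1 path: the MTP draft model's prompt is the target model's prompt with the first token dropped, so the draft buffer holds length - 1 tokens per prefill request. A minimal illustration with plain Python lists (token values are made up):

# target_model_inputs["input_ids"][idx, :length]          -> [t0, t1, t2, t3, t4]
# model_inputs["input_ids"][idx, :length - 1] is set from -> [t1, t2, t3, t4]
target_ids = [101, 7, 8, 9, 2]
draft_ids = target_ids[1:]
assert len(draft_ids) == len(target_ids) - 1

step_idx is likewise initialised to len(request.output_token_ids) only when the prefill chunk covers all currently known tokens (prompt plus already generated output); otherwise it stays 0 until the remaining chunks arrive.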