From d58c1db8a078bde555753997dd2ecded269ef33e Mon Sep 17 00:00:00 2001 From: chen <103103266+ckl117@users.noreply.github.com> Date: Mon, 17 Nov 2025 20:47:33 +0800 Subject: [PATCH] [Feature][OP] Append Attn Support CUDA-PDL (#5072) --- .../append_attn/append_attention_func.cuh | 16 + .../decoder_write_cache_with_rope_impl.cuh | 114 +- .../decoder_write_cache_with_rope_kernel.cu | 919 +++++++------ .../encoder_write_cache_with_rope_impl.cuh | 1176 ++++++++++------- .../append_attn/gqa_rope_write_cache.cu | 1051 ++++++++------- .../multiquery_attention_c16_impl.cuh | 332 +++-- .../multiquery_attention_c4_impl.cuh | 466 ++++--- .../multiquery_attention_c8_impl.cuh | 310 +++-- custom_ops/gpu_ops/helper.cu | 214 +-- custom_ops/gpu_ops/helper.h | 296 +++-- fastdeploy/engine/engine.py | 1 + fastdeploy/envs.py | 1 + 12 files changed, 2828 insertions(+), 2068 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 4c1c51dd4..4936e0172 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -2296,6 +2296,9 @@ __global__ void merge_multi_chunks_decoder_kernel( const int bid = blockIdx.x, hid = blockIdx.y; __shared__ T smem[bdy * HEAD_DIM]; __shared__ float md_smem[bdy * 2]; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif const int start_token_idx = cu_seqlens_q[bid]; const int seq_len_q = seq_lens_q[bid]; if (seq_len_q == 0) return; @@ -2332,6 +2335,10 @@ __global__ void merge_multi_chunks_decoder_kernel( } else if constexpr (std::is_same::value) { m = -3.0e+30f; } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + #pragma unroll 2 for (int i = ty; i < num_chunks_this_seq; i += bdy) { uint32_t offset = (bid * num_chunks + i) * num_heads + hid; @@ -2397,6 +2404,9 @@ __global__ void merge_multi_chunks_decoder_kernel( out_vec, &out[(start_token_idx * num_heads + hid) * head_dim + vid * vec_size]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template = 900)) + cudaGridDependencySynchronize(); +#endif for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) { const uint32_t bid = batch_id_per_token[qid]; if (bid == -1) { @@ -2569,4 +2582,7 @@ __global__ void merge_multi_chunks_v2_kernel( } __syncthreads(); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh index 999399f4c..5c141d7e3 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh @@ -109,6 +109,9 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel( const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size; const int half_head_size = head_size / 2; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) { int64_t linear_index = gloabl_hi * head_size + threadIdx.x * VecSize; @@ -198,6 +201,9 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel( } } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + 
cudaTriggerProgrammaticLaunchCompletion(); +#endif } template @@ -239,6 +245,9 @@ __global__ void append_decode_cache_T_rope_kernel( const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size; // const int64_t offset = 2 * hidden_size; const int half_head_size = head_size / 2; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int32_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -305,10 +314,13 @@ __global__ void append_decode_cache_T_rope_kernel( } } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template -__global__ void append_decode_cache_T_rope_kernel( +__global__ void append_decode_cache_T_quant_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size, @@ -352,6 +364,9 @@ __global__ void append_decode_cache_T_rope_kernel( const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size; // const int64_t offset = 2 * hidden_size; const int half_head_size = head_size / 2; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int32_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -427,6 +442,9 @@ __global__ void append_decode_cache_T_rope_kernel( } } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template @@ -473,7 +491,9 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel( const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size; const int64_t half_hidden_size = hidden_size / 2; // const int64_t offset = 2 * hidden_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int32_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -566,6 +586,9 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel( } } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template @@ -608,7 +631,9 @@ __global__ void append_decode_cache_T_neox_rope_kernel( const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size; const int64_t half_hidden_size = hidden_size / 2; // const int64_t offset = 2 * hidden_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int32_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -680,10 +705,13 @@ __global__ void append_decode_cache_T_neox_rope_kernel( } } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template -__global__ void append_decode_cache_T_neox_rope_kernel( +__global__ void append_decode_cache_T_quant_neox_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size, @@ -726,7 +754,9 @@ __global__ void append_decode_cache_T_neox_rope_kernel( const int half_head_size = head_size / 2; const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size; const int64_t half_hidden_size = hidden_size / 2; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + 
cudaGridDependencySynchronize(); +#endif for (int32_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -814,6 +844,9 @@ __global__ void append_decode_cache_T_neox_rope_kernel( } } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template = 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q using LoadT = AlignedVector; @@ -1118,6 +1153,9 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel( value_cache[tgt_cache_idx4] = cache_vec[3]; } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template = 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q const T* qkv_now = @@ -1356,6 +1396,9 @@ __global__ void append_decode_cache_int8_rope_kernel( value_cache[tgt_cache_idx4] = cache_vec[3]; } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template -__global__ void append_decode_cache_int8_rope_kernel( +__global__ void int_append_decode_cache_int8_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads, @@ -1412,7 +1455,9 @@ __global__ void append_decode_cache_int8_rope_kernel( block_table_now = block_tables + bid * max_blocks_per_seq; const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q using LoadT = AlignedVector; @@ -1674,6 +1719,9 @@ __global__ void append_decode_cache_int8_rope_kernel( value_cache[tgt_cache_idx4] = cache_vec[3]; } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template @@ -1721,7 +1769,9 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( block_table_now = block_tables + bid * max_blocks_per_seq; const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q using LoadT = AlignedVector; @@ -1977,10 +2027,13 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( value_cache[tgt_cache_idx4] = cache_vec[3]; } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template -__global__ void append_decode_cache_int8_neox_rope_kernel( +__global__ void int_append_decode_cache_int8_neox_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads, @@ -2030,7 +2083,9 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( block_table_now = block_tables + bid * max_blocks_per_seq; const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q using LoadT = AlignedVector; @@ -2374,6 +2429,9 @@ __global__ void append_decode_cache_int8_neox_rope_kernel( value_cache[tgt_cache_idx4] = cache_vec[3]; } } +#if (defined(__CUDA_ARCH__) 
&& (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template @@ -2424,7 +2482,9 @@ __global__ void append_decode_cache_int4_rope_kernel( const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q const T* qkv_now = @@ -2648,10 +2708,13 @@ __global__ void append_decode_cache_int4_rope_kernel( (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template -__global__ void append_decode_cache_int4_rope_kernel( +__global__ void int_append_decode_cache_int4_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads, @@ -2703,7 +2766,9 @@ __global__ void append_decode_cache_int4_rope_kernel( const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q using LoadT = AlignedVector; @@ -2981,6 +3046,9 @@ __global__ void append_decode_cache_int4_rope_kernel( (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template @@ -3031,7 +3099,9 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q using LoadT = AlignedVector; @@ -3355,10 +3425,13 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template -__global__ void append_decode_cache_int4_neox_rope_kernel( +__global__ void int_append_decode_cache_int4_neox_rope_kernel( const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads, // head_size] uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads, @@ -3410,7 +3483,9 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]); const int block_offset = write_seq_id % block_size; - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif if (head_idx < num_heads) { // q using LoadT = AlignedVector; @@ -3808,4 +3883,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel( (uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F); } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu index f8114188e..d19c2ef38 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu @@ -52,28 +52,33 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv, int grid_size 
= 1; GetNumBlocks<128>(pack_num, &grid_size); dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1); - append_decode_cache_T_rope_qk_norm_kernel - <<>>(reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - max_seq_len, - max_blocks_per_seq, - num_heads, - dim_head, - block_size, - elem_nums, - kv_num_heads, - rope_3d, - q_norm_weight, - k_norm_weight, - rms_norm_eps); + launchWithPdlWhenEnabled( + append_decode_cache_T_rope_qk_norm_kernel, + grid_size, + block_dim, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + max_seq_len, + max_blocks_per_seq, + num_heads, + dim_head, + block_size, + elem_nums, + kv_num_heads, + rope_3d, + q_norm_weight, + k_norm_weight, + rms_norm_eps); } template @@ -111,118 +116,140 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, GetNumBlocks<128>(pack_num, &grid_size); if (use_neox_style) { if (qkv_out_scales) { - append_decode_cache_T_neox_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - qkv_out_scales, - qkv_biases, - max_seq_len, - max_blocks_per_seq, - num_heads, - dim_head, - block_size, - elem_nums, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled( + append_decode_cache_T_quant_neox_rope_kernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + max_seq_len, + max_blocks_per_seq, + num_heads, + dim_head, + block_size, + elem_nums, + kv_num_heads, + rope_3d); } else { if (rotary_dim < dim_head) { - append_decode_cache_T_neox_partial_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - max_seq_len, - max_blocks_per_seq, - num_heads, - dim_head, - rotary_dim, - block_size, - elem_nums, - kv_num_heads, - rope_3d); + auto* kernelFn = + append_decode_cache_T_neox_partial_rope_kernel; + launchWithPdlWhenEnabled(kernelFn, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + max_seq_len, + max_blocks_per_seq, + num_heads, + dim_head, + rotary_dim, + block_size, + elem_nums, + kv_num_heads, + rope_3d); } else { - append_decode_cache_T_neox_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - max_seq_len, - max_blocks_per_seq, - num_heads, - dim_head, - block_size, - elem_nums, - kv_num_heads, - rope_3d); + auto* kernelFn = append_decode_cache_T_neox_rope_kernel; + launchWithPdlWhenEnabled(kernelFn, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + max_seq_len, + max_blocks_per_seq, + num_heads, + dim_head, + block_size, + elem_nums, + kv_num_heads, + rope_3d); } } } else { if (qkv_out_scales) { - append_decode_cache_T_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - 
seq_lens_encoder, - cos_emb, - sin_emb, - qkv_out_scales, - qkv_biases, - max_seq_len, - max_blocks_per_seq, - num_heads, - dim_head, - block_size, - elem_nums, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled( + append_decode_cache_T_quant_rope_kernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + max_seq_len, + max_blocks_per_seq, + num_heads, + dim_head, + block_size, + elem_nums, + kv_num_heads, + rope_3d); } else { - append_decode_cache_T_rope_kernel - <<>>(reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - max_seq_len, - max_blocks_per_seq, - num_heads, - dim_head, - block_size, - elem_nums, - kv_num_heads, - rope_3d); + auto* kernelFn = append_decode_cache_T_rope_kernel; + launchWithPdlWhenEnabled(kernelFn, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + max_seq_len, + max_blocks_per_seq, + num_heads, + dim_head, + block_size, + elem_nums, + kv_num_heads, + rope_3d); } } } @@ -261,113 +288,128 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv, dim3 grids(bsz, all_warps / num_warps); if (use_neox_style) { if (qkv_out_scales) { - append_decode_cache_int8_neox_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - qkv_out_scales, - qkv_biases, - cache_k_scale, - cache_v_scale, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 127.0f, - -127.0f, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled( + int_append_decode_cache_int8_neox_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + cache_k_scale, + cache_v_scale, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d); } else { - append_decode_cache_int8_neox_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - cache_k_scale, - cache_v_scale, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 127.0f, - -127.0f, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled(append_decode_cache_int8_neox_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d); } } else { if (qkv_out_scales) { - append_decode_cache_int8_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - qkv_out_scales, - qkv_biases, - cache_k_scale, - cache_v_scale, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 127.0f, - -127.0f, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled( + 
int_append_decode_cache_int8_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + cache_k_scale, + cache_v_scale, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d); } else { - append_decode_cache_int8_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - cache_k_scale, - cache_v_scale, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 127.0f, - -127.0f, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled( + append_decode_cache_int8_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d); } } } @@ -405,111 +447,124 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv, dim3 grids(bsz, all_warps / num_warps); if (use_neox_style) { if (qkv_out_scales) { - append_decode_cache_int4_neox_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - qkv_out_scales, - qkv_biases, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 7.0f, - -8.0f, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled( + int_append_decode_cache_int4_neox_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads, + rope_3d); } else { - append_decode_cache_int4_neox_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 7.0f, - -8.0f, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled(append_decode_cache_int4_neox_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads, + rope_3d); } } else { if (qkv_out_scales) { - append_decode_cache_int4_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - qkv_out_scales, - qkv_biases, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 7.0f, - -8.0f, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled(int_append_decode_cache_int4_rope_kernel, + grids, + 
num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + qkv_out_scales, + qkv_biases, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads, + rope_3d); } else { - append_decode_cache_int4_rope_kernel - <<>>( - reinterpret_cast(qkv), - key_cache, - value_cache, - qkv_out, - block_tables, - cu_seqlens_q, - seq_lens, - seq_lens_encoder, - cos_emb, - sin_emb, - cache_k_scale, - cache_v_scale, - cache_k_zp, - cache_v_zp, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 7.0f, - -8.0f, - kv_num_heads, - rope_3d); + launchWithPdlWhenEnabled(append_decode_cache_int4_rope_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv), + key_cache, + value_cache, + qkv_out, + block_tables, + cu_seqlens_q, + seq_lens, + seq_lens_encoder, + cos_emb, + sin_emb, + cache_k_scale, + cache_v_scale, + cache_k_zp, + cache_v_zp, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 7.0f, + -8.0f, + kv_num_heads, + rope_3d); } } } @@ -610,77 +665,85 @@ void DecoderWriteCacheWithRoPEKernel( const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; dim3 grids(bsz, all_warps / num_warps); - append_decode_cache_int8_rope_qk_norm_kernel - <<>>( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - const_cast(reinterpret_cast( - cache_k_scale.get().data())), - const_cast(reinterpret_cast( - (cache_v_scale.get().data()))), - q_norm_weight.get().data(), - k_norm_weight.get().data(), - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 127.0f, - -127.0f, - kv_num_heads, - rope_3d, - rms_norm_eps); + launchWithPdlWhenEnabled( + append_decode_cache_int8_rope_qk_norm_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + const_cast(reinterpret_cast( + cache_k_scale.get().data())), + const_cast(reinterpret_cast( + (cache_v_scale.get().data()))), + q_norm_weight.get().data(), + k_norm_weight.get().data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d, + rms_norm_eps); } else if ((cache_quant_type_str == "cache_fp8")) { constexpr int num_warps = 4; const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; dim3 grids(bsz, all_warps / num_warps); - append_decode_cache_int8_rope_qk_norm_kernel - <<>>( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - const_cast(reinterpret_cast( - cache_k_scale.get().data())), - const_cast(reinterpret_cast( - (cache_v_scale.get().data()))), - q_norm_weight.get().data(), - k_norm_weight.get().data(), - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 127.0f, - -127.0f, - kv_num_heads, - rope_3d, - rms_norm_eps); + launchWithPdlWhenEnabled( + 
append_decode_cache_int8_rope_qk_norm_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + const_cast(reinterpret_cast( + cache_k_scale.get().data())), + const_cast(reinterpret_cast( + (cache_v_scale.get().data()))), + q_norm_weight.get().data(), + k_norm_weight.get().data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d, + rms_norm_eps); } else { PD_THROW( "append_decode_cache_rope_qk_norm just supports cache_quant_type " @@ -822,38 +885,42 @@ void DecoderWriteCacheWithRoPEKernel( const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps; dim3 grids(bsz, all_warps / num_warps); - append_decode_cache_int8_rope_qk_norm_kernel - <<>>( - reinterpret_cast(qkv_ptr), - key_cache_out->data(), - value_cache_out->data(), - reinterpret_cast(qkv_out->data()), - block_tables.data(), - cu_seqlens_q.data(), - seq_lens.data(), - seq_lens_encoder.data(), - cos_emb, - sin_emb, - const_cast(reinterpret_cast( - cache_k_scale.get().data())), - const_cast(reinterpret_cast( - (cache_v_scale.get().data()))), - nullptr, - nullptr, - max_seq_len, - max_blocks_per_seq, - num_heads, - block_size, - 127.0f, - -127.0f, - kv_num_heads, - rope_3d, - rms_norm_eps); + launchWithPdlWhenEnabled( + append_decode_cache_int8_rope_qk_norm_kernel, + grids, + num_warps * 32, + 0, + stream, + reinterpret_cast(qkv_ptr), + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast(qkv_out->data()), + block_tables.data(), + cu_seqlens_q.data(), + seq_lens.data(), + seq_lens_encoder.data(), + cos_emb, + sin_emb, + const_cast(reinterpret_cast( + cache_k_scale.get().data())), + const_cast(reinterpret_cast( + (cache_v_scale.get().data()))), + nullptr, + nullptr, + max_seq_len, + max_blocks_per_seq, + num_heads, + block_size, + 127.0f, + -127.0f, + kv_num_heads, + rope_3d, + rms_norm_eps); } else if (cache_quant_type_str == "cache_int4_zp") { append_decode_cache_int4_rope( reinterpret_cast(qkv_ptr), diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh index 688838cca..c1d1fd27c 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh @@ -19,7 +19,7 @@ #include "utils.cuh" template -__global__ void VariableLengthRotaryKernel( +__global__ void IntVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, @@ -49,6 +49,9 @@ __global__ void VariableLengthRotaryKernel( const int half_lastdim = last_dim / 2; const int hidden_size = num_head * last_dim; const int offset = 3 * hidden_size; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -62,7 +65,8 @@ __global__ void VariableLengthRotaryKernel( const int hi = qkv_bias / last_dim; const int h_bias = qkv_bias % last_dim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * 
half_lastdim + h_bias / 2; int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; @@ -102,6 +106,9 @@ __global__ void VariableLengthRotaryKernel( } Store(bias_vec, &qkv_out[base_idx]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } template @@ -129,6 +136,9 @@ __global__ void VariableLengthRotaryKernel( const int half_lastdim = last_dim / 2; const int hidden_size = num_head * last_dim; const int offset = 2 * hidden_size; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -142,7 +152,8 @@ __global__ void VariableLengthRotaryKernel( const int hi = qkv_bias / last_dim; const int h_bias = qkv_bias % last_dim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2; int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; @@ -164,10 +175,13 @@ __global__ void VariableLengthRotaryKernel( } Store(src_vec, &qkv_out[base_idx]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } template -__global__ void NeoxVariableLengthRotaryKernel( +__global__ void IntNeoxVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, @@ -200,6 +214,9 @@ __global__ void NeoxVariableLengthRotaryKernel( const int hidden_size = num_head * half_lastdim; const int full_hidden_size = num_head * last_dim; const int offset = 3 * hidden_size; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -213,10 +230,12 @@ __global__ void NeoxVariableLengthRotaryKernel( const int hi = qkv_bias / half_lastdim; const int h_bias = qkv_bias % half_lastdim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * last_dim + h_bias; - int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx; + int new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx; const int bias_idx_left = qkv_id * full_hidden_size + hi * last_dim + h_bias; const int bias_idx_right = bias_idx_left + half_lastdim; @@ -260,6 +279,9 @@ __global__ void NeoxVariableLengthRotaryKernel( Store(left_bias_vec, &qkv_out[base_idx_left]); Store(right_bias_vec, &qkv_out[base_idx_right]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } template @@ -288,6 +310,9 @@ __global__ void NeoxVariableLengthRotaryKernel( const int hidden_size = num_head * half_lastdim; const int full_hidden_size = num_head * last_dim; const int offset = 2 * hidden_size; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -301,10 +326,12 @@ __global__ void NeoxVariableLengthRotaryKernel( const int hi = qkv_bias / half_lastdim; const int h_bias = qkv_bias % half_lastdim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * last_dim + h_bias; - int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx; + int new_emb_idx = + rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx; const int base_idx_left = token_idx * 3 * full_hidden_size + qkv_id * full_hidden_size + hi * last_dim + h_bias; @@ -328,10 +355,13 @@ __global__ void NeoxVariableLengthRotaryKernel( Store(left_vec, &qkv_out[base_idx_left]); Store(right_vec, &qkv_out[base_idx_right]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } template -__global__ void GQAVariableLengthRotaryKernel( +__global__ void IntGQAVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, @@ -361,6 +391,9 @@ __global__ void GQAVariableLengthRotaryKernel( int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; const int half_lastdim = last_dim / 2; const int offset = (q_num_head + 2 * kv_num_head) * last_dim; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -372,10 +405,12 @@ __global__ void GQAVariableLengthRotaryKernel( const int hi = bias / last_dim; const int h_bias = bias % last_dim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; - int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; + int64_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * last_dim * seq_len : emb_idx; const int64_t bias_idx = hi * last_dim + h_bias; const int64_t base_idx = token_idx * offset + bias_idx; Load(&qkv[base_idx], &src_vec); @@ -412,9 +447,11 @@ __global__ void GQAVariableLengthRotaryKernel( } Store(bias_vec, &qkv_out[base_idx]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } - template __global__ void GQAVariableLengthRotaryQKNormKernel( const T *qkv, @@ -431,10 +468,9 @@ __global__ void GQAVariableLengthRotaryQKNormKernel( const int seq_len, const int last_dim, const bool rope_3d, - const float* q_norm_weight, - const float* k_norm_weight, - const float rms_norm_eps -) { + const float *q_norm_weight, + const float *k_norm_weight, + const float rms_norm_eps) { using LoadT = AlignedVector; constexpr int HalfVecSize = VecSize / 2; using LoadEmbT = AlignedVector; @@ -449,7 +485,11 @@ __global__ void GQAVariableLengthRotaryQKNormKernel( const int half_lastdim = last_dim / 2; const int offset = (q_num_head + kv_num_head) * last_dim; const int all_head_num = elem_cnt / last_dim; - for (int global_hi = global_warp_idx; global_hi < all_head_num; global_hi += all_warp_num) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + for (int global_hi = global_warp_idx; global_hi < all_head_num; + global_hi += all_warp_num) { int64_t linear_index = global_hi * last_dim + threadIdx.x * VecSize; const int token_idx = linear_index / offset; const int ori_bi = batch_id_per_token[token_idx]; @@ -458,14 +498,16 @@ __global__ void GQAVariableLengthRotaryQKNormKernel( const int hi = bias / last_dim; const int h_bias = bias % last_dim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; const int64_t base_idx = token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim + h_bias; Load(&qkv[base_idx], &src_vec); - int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; + int64_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * last_dim * seq_len : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec); Load(&sin_emb[new_emb_idx], &sin_emb_vec); @@ -485,13 +527,12 @@ __global__ void GQAVariableLengthRotaryQKNormKernel( thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; } WelfordWarpAllReduce(thread_m2, &warp_m2); - float row_variance = - max(warp_m2 / last_dim, 0.0f); + float row_variance = max(warp_m2 / last_dim, 0.0f); float row_inv_var = Rsqrt(row_variance + rms_norm_eps); if (hi < q_num_head) { Load(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec); - #pragma unroll +#pragma unroll for (int i = 0; i < VecSize; i++) { src_vec[i] = static_cast(tmp_vec[i] * row_inv_var * q_norm_vec[i]); } @@ -503,24 +544,26 @@ __global__ void GQAVariableLengthRotaryQKNormKernel( } Store(src_vec, &qkv_out[base_idx]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } template -__global__ void GQAVariableLengthRotaryKernel( - const T *qkv, - const float *cos_emb, - const float *sin_emb, - const int *batch_id_per_token, - const int *cu_seqlens_q, - const int *seq_lens, - const int *seq_lens_decoder, - T *qkv_out, - const int64_t elem_cnt, - const int q_num_head, - const int kv_num_head, - const int seq_len, - const int last_dim, - const bool rope_3d) { +__global__ void GQAVariableLengthRotaryKernel(const T *qkv, + const float *cos_emb, + const float *sin_emb, + const int *batch_id_per_token, + const int *cu_seqlens_q, + const int *seq_lens, + const int *seq_lens_decoder, + T *qkv_out, + const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int seq_len, + const int last_dim, + const bool rope_3d) { using LoadT = AlignedVector; constexpr int HalfVecSize = VecSize / 2; using LoadEmbT = AlignedVector; @@ -530,18 +573,23 @@ __global__ void GQAVariableLengthRotaryKernel( int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; const int half_lastdim = last_dim / 2; const int offset = (q_num_head + kv_num_head) * last_dim; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; linear_index += step) { const int token_idx = linear_index / offset; - const int ori_bi = batch_id_per_token[token_idx];; + const int ori_bi = batch_id_per_token[token_idx]; + ; if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / last_dim; const int h_bias = bias % last_dim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; const int64_t base_idx = @@ -549,7 +597,8 @@ __global__ void GQAVariableLengthRotaryKernel( h_bias; Load(&qkv[base_idx], &src_vec); - int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; + int64_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * last_dim * seq_len : emb_idx; Load(&cos_emb[new_emb_idx], &cos_emb_vec); Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll @@ -565,27 +614,31 @@ __global__ void GQAVariableLengthRotaryKernel( } Store(src_vec, &qkv_out[base_idx]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } template -__global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv, - const float *cos_emb, // [1, 1, seq_len, dim_head / 2] - const float *sin_emb, - const float *qkv_out_scales, - const int *batch_id_per_token, - const int *cu_seqlens_q, - const int *seq_lens, - const int *seq_lens_decoder, - const T *qkv_biases, - const T *cache_k_scales, - const T *cache_v_scales, - T *qkv_out, - const int64_t elem_cnt, - const int q_num_head, - const int kv_num_head, - const int seq_len, - const int last_dim, - const bool rope_3d) { +__global__ void IntGQAVariableLengthRotaryQuantKVKernel( + const int *qkv, + const float *cos_emb, // [1, 1, seq_len, dim_head / 2] + const float *sin_emb, + const float *qkv_out_scales, + const int *batch_id_per_token, + const int *cu_seqlens_q, + const int *seq_lens, + const int *seq_lens_decoder, + const T *qkv_biases, + const T *cache_k_scales, + const T *cache_v_scales, + T *qkv_out, + const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int seq_len, + const int last_dim, + const bool rope_3d) { using LoadIn = AlignedVector; using LoadBiasT = AlignedVector; constexpr int HalfVecSize = VecSize / 2; @@ -600,6 +653,9 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv, const int half_lastdim = last_dim / 2; // const int hidden_size = num_head * last_dim; const int offset = (q_num_head + 2 * kv_num_head) * last_dim; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -611,10 +667,12 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv, const int hi = bias / last_dim; const int h_bias = bias % last_dim; - int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; - int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; + int64_t new_emb_idx = + rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; const int64_t bias_idx = hi * last_dim + h_bias; const int64_t base_idx = token_idx * offset + bias_idx; Load(&qkv[base_idx], &src_vec); @@ -634,47 +692,59 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv, input_right = qkv_biases ? 
input_right * out_scale_vec[2 * i + 1] + static_cast(bias_vec[2 * i + 1]) : input_right * out_scale_vec[2 * i + 1]; - if (hi < q_num_head) { // qk rope + if (hi < q_num_head) { // qk rope const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - bias_vec[2 * i] = static_cast(input_left * cos_tmp - input_right * sin_tmp); - bias_vec[2 * i + 1] = static_cast(input_right * cos_tmp + input_left * sin_tmp); + bias_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + bias_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); } else if (hi < q_num_head + kv_num_head) { int k_hi = hi - q_num_head; const int scale_idx = k_hi * last_dim + h_bias; const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - bias_vec[2 * i] = static_cast((input_left * cos_tmp - input_right * sin_tmp) * float(cache_k_scales[scale_idx + 2 * i])); - bias_vec[2 * i + 1] = static_cast((input_right * cos_tmp + input_left * sin_tmp) * float(cache_k_scales[scale_idx + 2 * i + 1])); + bias_vec[2 * i] = + static_cast((input_left * cos_tmp - input_right * sin_tmp) * + float(cache_k_scales[scale_idx + 2 * i])); + bias_vec[2 * i + 1] = + static_cast((input_right * cos_tmp + input_left * sin_tmp) * + float(cache_k_scales[scale_idx + 2 * i + 1])); } else { int v_hi = hi - q_num_head - kv_num_head; const int scale_idx = v_hi * last_dim + h_bias; - bias_vec[2 * i] = static_cast(input_left * float(cache_v_scales[scale_idx + 2 * i])); - bias_vec[2 * i + 1] = static_cast(input_right * float(cache_v_scales[scale_idx + 2 * i + 1])); + bias_vec[2 * i] = static_cast( + input_left * float(cache_v_scales[scale_idx + 2 * i])); + bias_vec[2 * i + 1] = static_cast( + input_right * float(cache_v_scales[scale_idx + 2 * i + 1])); } } Store(bias_vec, &qkv_out[base_idx]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } template -__global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv, - const float *cos_emb, // [1, 1, seq_len, dim_head / 2] - const float *sin_emb, - const int *batch_id_per_token, - const int *cu_seqlens_q, - const int *seq_lens, - const int *seq_lens_decoder, - const T *qkv_biases, - const T *cache_k_scales, - const T *cache_v_scales, - T *qkv_out, - const int64_t elem_cnt, - const int q_num_head, - const int kv_num_head, - const int seq_len, - const int last_dim, - const bool rope_3d) { +__global__ void GQAVariableLengthRotaryQuantKVKernel( + const T *qkv, + const float *cos_emb, // [1, 1, seq_len, dim_head / 2] + const float *sin_emb, + const int *batch_id_per_token, + const int *cu_seqlens_q, + const int *seq_lens, + const int *seq_lens_decoder, + const T *qkv_biases, + const T *cache_k_scales, + const T *cache_v_scales, + T *qkv_out, + const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int seq_len, + const int last_dim, + const bool rope_3d) { using LoadT = AlignedVector; constexpr int HalfVecSize = VecSize / 2; using LoadEmbT = AlignedVector; @@ -686,6 +756,9 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv, const int half_lastdim = last_dim / 2; // const int hidden_size = num_head * last_dim; const int offset = (q_num_head + 2 * kv_num_head) * last_dim; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -697,10 +770,12 @@ __global__ void 
GQAVariableLengthRotaryQuantKVKernel(const T *qkv, const int hi = bias / last_dim; const int h_bias = bias % last_dim; - int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; - int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; + int64_t new_emb_idx = + rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; const int64_t bias_idx = hi * last_dim + h_bias; const int64_t base_idx = token_idx * offset + bias_idx; Load(&qkv[base_idx], &src_vec); @@ -711,37 +786,54 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv, Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll for (int i = 0; i < HalfVecSize; i++) { - const float input_left = qkv_biases ? static_cast(src_vec[2 * i]+ bias_vec[2 * i]) : static_cast(src_vec[2 * i]); - const float input_right = qkv_biases ? static_cast(src_vec[2 * i + 1] + bias_vec[2 * i + 1]) : static_cast(src_vec[2 * i + 1]); + const float input_left = + qkv_biases ? static_cast(src_vec[2 * i] + bias_vec[2 * i]) + : static_cast(src_vec[2 * i]); + const float input_right = + qkv_biases + ? static_cast(src_vec[2 * i + 1] + bias_vec[2 * i + 1]) + : static_cast(src_vec[2 * i + 1]); // const float cos_tmp = cos_emb_vec[i]; // const float sin_tmp = sin_emb_vec[i]; - // src_vec[2 * i] = static_cast(input_left * cos_tmp - input_right * sin_tmp); - // src_vec[2 * i + 1] = static_cast(input_right * cos_tmp + input_left * sin_tmp); - if (hi < q_num_head) { // qk rope + // src_vec[2 * i] = static_cast(input_left * cos_tmp - input_right * + // sin_tmp); src_vec[2 * i + 1] = static_cast(input_right * cos_tmp + + // input_left * sin_tmp); + if (hi < q_num_head) { // qk rope const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - src_vec[2 * i] = static_cast(input_left * cos_tmp - input_right * sin_tmp); - src_vec[2 * i + 1] = static_cast(input_right * cos_tmp + input_left * sin_tmp); + src_vec[2 * i] = + static_cast(input_left * cos_tmp - input_right * sin_tmp); + src_vec[2 * i + 1] = + static_cast(input_right * cos_tmp + input_left * sin_tmp); } else if (hi < q_num_head + kv_num_head) { int k_hi = hi - q_num_head; const int scale_idx = k_hi * last_dim + h_bias; const float cos_tmp = cos_emb_vec[i]; const float sin_tmp = sin_emb_vec[i]; - src_vec[2 * i] = static_cast((input_left * cos_tmp - input_right * sin_tmp) * float(cache_k_scales[scale_idx + 2 * i])); - src_vec[2 * i + 1] = static_cast((input_right * cos_tmp + input_left * sin_tmp) * float(cache_k_scales[scale_idx + 2 * i + 1])); + src_vec[2 * i] = + static_cast((input_left * cos_tmp - input_right * sin_tmp) * + float(cache_k_scales[scale_idx + 2 * i])); + src_vec[2 * i + 1] = + static_cast((input_right * cos_tmp + input_left * sin_tmp) * + float(cache_k_scales[scale_idx + 2 * i + 1])); } else { int v_hi = hi - q_num_head - kv_num_head; const int scale_idx = v_hi * last_dim + h_bias; - src_vec[2 * i] = static_cast(input_left * float(cache_v_scales[scale_idx + 2 * i])); - src_vec[2 * i + 1] = static_cast(input_right * float(cache_v_scales[scale_idx + 2 * i + 1])); + src_vec[2 * i] = static_cast( + input_left * float(cache_v_scales[scale_idx + 2 * i])); + src_vec[2 * i + 1] = static_cast( + input_right * float(cache_v_scales[scale_idx + 2 * i + 1])); } } Store(src_vec, &qkv_out[base_idx]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); 
+#endif } template -__global__ void GQANeoxVariableLengthRotaryKernel( +__global__ void IntGQANeoxVariableLengthRotaryKernel( const int *qkv, const float *cos_emb, // [1, 1, seq_len, dim_head / 2] const float *sin_emb, @@ -773,6 +865,9 @@ __global__ void GQANeoxVariableLengthRotaryKernel( int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; const int half_lastdim = last_dim / 2; const int offset = (q_num_head + 2 * kv_num_head) * half_lastdim; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -784,10 +879,12 @@ __global__ void GQANeoxVariableLengthRotaryKernel( const int hi = bias / half_lastdim; const int h_bias = bias % half_lastdim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * last_dim + h_bias; - int new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx; + int new_emb_idx = + rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx; const int bias_idx_left = hi * last_dim + h_bias; const int bias_idx_right = bias_idx_left + half_lastdim; const int base_idx_left = @@ -831,26 +928,28 @@ __global__ void GQANeoxVariableLengthRotaryKernel( Store(left_bias_vec, &qkv_out[base_idx_left]); Store(right_bias_vec, &qkv_out[base_idx_right]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } template -__global__ void GQANeoxVariableLengthRotaryKernel( - const T *qkv, - const float *cos_emb, - const float *sin_emb, - const int *batch_id_per_token, - const int *cu_seqlens_q, - const int *seq_lens, - const int *seq_lens_decoder, - const float *qkv_out_scales, - const T *qkv_biases, - T *qkv_out, - const int64_t elem_cnt, - const int q_num_head, - const int kv_num_head, - const int seq_len, - const int last_dim, - const bool rope_3d) { +__global__ void GQANeoxVariableLengthRotaryKernel(const T *qkv, + const float *cos_emb, + const float *sin_emb, + const int *batch_id_per_token, + const int *cu_seqlens_q, + const int *seq_lens, + const int *seq_lens_decoder, + const float *qkv_out_scales, + const T *qkv_biases, + T *qkv_out, + const int64_t elem_cnt, + const int q_num_head, + const int kv_num_head, + const int seq_len, + const int last_dim, + const bool rope_3d) { using LoadT = AlignedVector; using LoadEmbT = AlignedVector; LoadT left_vec; @@ -860,6 +959,9 @@ __global__ void GQANeoxVariableLengthRotaryKernel( int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; const int half_lastdim = last_dim / 2; const int offset = (q_num_head + kv_num_head) * half_lastdim; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -871,10 +973,12 @@ __global__ void GQANeoxVariableLengthRotaryKernel( const int hi = bias / half_lastdim; const int h_bias = bias % half_lastdim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * last_dim + h_bias; - int64_t new_emb_idx = rope_3d ? 
emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx; + int64_t new_emb_idx = + rope_3d ? emb_idx + ori_bi * last_dim * seq_len * 2 : emb_idx; const int base_idx_left = token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim + h_bias; @@ -898,6 +1002,9 @@ __global__ void GQANeoxVariableLengthRotaryKernel( Store(left_vec, &qkv_out[base_idx_left]); Store(right_vec, &qkv_out[base_idx_right]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } template @@ -928,6 +1035,9 @@ __global__ void GQANeoxVariableLengthPartialRotaryKernel( int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; const int rotary_dim_half = rotary_dim / 2; const int offset = (q_num_head + kv_num_head) * rotary_dim_half; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (int64_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -939,10 +1049,12 @@ __global__ void GQANeoxVariableLengthPartialRotaryKernel( const int hi = bias / rotary_dim_half; const int h_bias = bias % rotary_dim_half; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int emb_idx = ori_seq_id * rotary_dim_half + h_bias; - int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * head_dim * seq_len * 2 : emb_idx; + int64_t new_emb_idx = + rope_3d ? emb_idx + ori_bi * head_dim * seq_len * 2 : emb_idx; const int base_idx_left = token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim + h_bias; @@ -966,6 +1078,9 @@ __global__ void GQANeoxVariableLengthPartialRotaryKernel( Store(left_vec, &qkv_out[base_idx_left]); Store(right_vec, &qkv_out[base_idx_right]); } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif } template @@ -976,11 +1091,11 @@ __global__ void cache_kernel( // head_size] T *__restrict__ value_cache, // [num_blocks, kv_num_heads, block_size, // head_size] - const int *__restrict__ block_tables, // [bsz, max_blocks_per_seq] - const int *__restrict__ batch_id_per_token, // [num_tokens] - const int *__restrict__ cu_seqlens_q, // [bsz] - const int *__restrict__ seq_lens, // [bsz] - const int *__restrict__ seq_lens_decoder, // [bsz] + const int *__restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int *__restrict__ batch_id_per_token, // [num_tokens] + const int *__restrict__ cu_seqlens_q, // [bsz] + const int *__restrict__ seq_lens, // [bsz] + const int *__restrict__ seq_lens_decoder, // [bsz] const int max_seq_len, const int max_blocks_per_seq, const int num_heads, @@ -994,6 +1109,9 @@ __global__ void cache_kernel( uint32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; const uint32_t hidden_size = kv_num_heads * head_size; const uint32_t offset = 2 * hidden_size; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif for (uint32_t linear_index = global_thread_idx * VecSize, step = gridDim.x * blockDim.x * VecSize; linear_index < elem_cnt; @@ -1007,7 +1125,8 @@ __global__ void cache_kernel( const int32_t ori_bi = batch_id_per_token[token_idx]; if (ori_bi == -1) continue; // skip batch_id_per_token[token_idx]=-1 if (seq_lens[ori_bi] == 0) continue; - const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const uint32_t ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + 
seq_lens_decoder[ori_bi]; const int32_t *block_table_now = nullptr; @@ -1029,9 +1148,11 @@ __global__ void cache_kernel( Store(src_vec, &value_cache[tgt_idx]); } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } - template + bool IsFP8 = false> __global__ void append_write_cache_kv_c8_qkv( uint8_t *__restrict__ cache_k, uint8_t *__restrict__ cache_v, @@ -1063,6 +1184,9 @@ __global__ void append_write_cache_kv_c8_qkv( const T cache_k_scale = cache_k_scales[kv_head_idx]; const T cache_v_scale = cache_v_scales[kv_head_idx]; const uint32_t tid = threadIdx.x, wid = threadIdx.y; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif const uint32_t batch_id = batch_ids[btid]; const uint32_t tile_id = tile_ids[btid]; const uint32_t seq_len_this_time = seq_lens_this_time[batch_id]; @@ -1095,16 +1219,15 @@ __global__ void append_write_cache_kv_c8_qkv( // int lane_id = wid * 32 + tid; // pad zero for this kv_head_idx for this block LoadPadKVT pad_cache_vec; - *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); // reset k constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE; constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k; uint32_t tgt_idx = (block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM + tid % num_vecs_per_head_k * KV_VEC_SIZE; - for (int block_i = tid / num_vecs_per_head_k; - block_i < BLOCK_SIZE; - block_i += num_token_each_time_k) { + for (int block_i = tid / num_vecs_per_head_k; block_i < BLOCK_SIZE; + block_i += num_token_each_time_k) { Store(pad_cache_vec, &cache_k[tgt_idx + block_i * HEAD_DIM]); } @@ -1112,13 +1235,12 @@ __global__ void append_write_cache_kv_c8_qkv( // reset v const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE; const int num_token_each_time_v = 32 / num_vecs_per_head_v; - tgt_idx = - (block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE + - tid % num_vecs_per_head_v * KV_VEC_SIZE; + tgt_idx = (block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE + + tid % num_vecs_per_head_v * KV_VEC_SIZE; for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM; - block_i += num_token_each_time_v) { - Store( - pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]); + block_i += num_token_each_time_v) { + Store(pad_cache_vec, + &cache_v[tgt_idx + block_i * BLOCK_SIZE]); } } smem_t k_smem(k_smem_ori); @@ -1212,7 +1334,8 @@ __global__ void append_write_cache_kv_c8_qkv( uint8_t uint_quant_value; if (chunk_start_k + (v_id / 4) * 8 >= start_len && chunk_start_k + (v_id / 4) * 8 < end_len) { - uint_quant_value = QuantToC8(cache_k_scale, k_frag_T[v_id], 127.0f, -127.0f); + uint_quant_value = QuantToC8( + cache_k_scale, k_frag_T[v_id], 127.0f, -127.0f); } else { uint_quant_value = 0; } @@ -1268,7 +1391,8 @@ __global__ void append_write_cache_kv_c8_qkv( uint8_t uint_quant_value; if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len && chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) { - uint_quant_value = QuantToC8(cache_v_scale, v_frag_T[v_id], 127.0f, -127.0f); + uint_quant_value = QuantToC8( + cache_v_scale, v_frag_T[v_id], 127.0f, -127.0f); // store now } else { uint_quant_value = 0; } @@ -1299,6 +1423,9 @@ __global__ void append_write_cache_kv_c8_qkv( 16 * num_frags_z_v * num_vecs_per_head; chunk_start_v -= 16 * num_frags_z_v; } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif }
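// The host-side wrappers later in this patch route kernel launches through
// launchWithPdlWhenEnabled(...), whose definition lives in
// custom_ops/gpu_ops/helper.h and is gated by the switch added in
// fastdeploy/envs.py; neither is shown in this excerpt. Below is only a
// minimal sketch of how such a helper can be built on cudaLaunchKernelEx with
// the ProgrammaticStreamSerialization attribute. The helper name, flag name,
// and fallback path here are assumptions and may differ from the real code.
#include <cuda_runtime.h>
#include <cstdlib>

template <typename KernelT, typename... Args>
void launchWithPdlWhenEnabledSketch(KernelT kernel, dim3 grid, dim3 block,
                                    size_t dyn_smem, cudaStream_t stream,
                                    Args... args) {
  // Assumed opt-in environment switch; the actual flag name may differ.
  static const bool use_pdl = [] {
    const char *v = std::getenv("FD_USE_CUDA_PDL");
    return v != nullptr && v[0] == '1';
  }();
  if (!use_pdl) {
    // Plain stream launch, equivalent to the old <<<...>>> call sites.
    kernel<<<grid, block, dyn_smem, stream>>>(args...);
    return;
  }
  // Programmatic Dependent Launch (sm_90+, CUDA 11.8+): the grid may be
  // scheduled while the previous kernel on `stream` is still running. Inside
  // the kernel, cudaGridDependencySynchronize() waits for the predecessor's
  // writes and cudaTriggerProgrammaticLaunchCompletion() signals that the
  // dependent launch may proceed.
  cudaLaunchAttribute attr{};
  attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attr.val.programmaticStreamSerializationAllowed = 1;
  cudaLaunchConfig_t cfg{};
  cfg.gridDim = grid;
  cfg.blockDim = block;
  cfg.dynamicSmemBytes = dyn_smem;
  cfg.stream = stream;
  cfg.attrs = &attr;
  cfg.numAttrs = 1;
  cudaLaunchKernelEx(&cfg, kernel, args...);
}
// Usage mirrors the call sites below, e.g.
//   launchWithPdlWhenEnabledSketch(cache_kernel<T, PackSize>, grid_size,
//                                  blocksize, 0, stream, /* kernel args */...);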
template = 900)) + cudaGridDependencySynchronize(); +#endif const uint32_t batch_id = batch_ids[btid]; const uint32_t tile_id = tile_ids[btid]; const uint32_t seq_len_this_time = seq_lens_this_time[batch_id]; @@ -1364,16 +1494,15 @@ __global__ void append_write_cache_kv_c8_qkv_dynamic( using LoadPadKVT = AlignedVector; // pad zero for this kv_head_idx for this block LoadPadKVT pad_cache_vec; - *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); // reset k constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE; constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k; uint32_t tgt_idx = (block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM + tid % num_vecs_per_head_k * KV_VEC_SIZE; - for (int block_i = tid / num_vecs_per_head_k; - block_i < BLOCK_SIZE; - block_i += num_token_each_time_k) { + for (int block_i = tid / num_vecs_per_head_k; block_i < BLOCK_SIZE; + block_i += num_token_each_time_k) { Store(pad_cache_vec, &cache_k[tgt_idx + block_i * HEAD_DIM]); } @@ -1381,13 +1510,12 @@ __global__ void append_write_cache_kv_c8_qkv_dynamic( // reset v const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE; const int num_token_each_time_v = 32 / num_vecs_per_head_v; - tgt_idx = - (block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE + - tid % num_vecs_per_head_v * KV_VEC_SIZE; + tgt_idx = (block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE + + tid % num_vecs_per_head_v * KV_VEC_SIZE; for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM; - block_i += num_token_each_time_v) { - Store( - pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]); + block_i += num_token_each_time_v) { + Store(pad_cache_vec, + &cache_v[tgt_idx + block_i * BLOCK_SIZE]); } } smem_t k_smem(k_smem_ori); @@ -1456,10 +1584,10 @@ __global__ void append_write_cache_kv_c8_qkv_dynamic( // reduce scale // 16 rows per warp uint32_t kv_reduce_frag[4]; - T *kv_reduce_frag_T = reinterpret_cast(kv_reduce_frag); + T *kv_reduce_frag_T = reinterpret_cast(kv_reduce_frag); - T k_local_max_value[num_frags_z * 2]; - T v_local_max_value[num_frags_z * 2]; + T k_local_max_value[num_frags_z * 2]; + T v_local_max_value[num_frags_z * 2]; #pragma unroll for (int i = 0; i < num_frags_z * 2; i++) { k_local_max_value[i] = -INFINITY; @@ -1469,7 +1597,8 @@ __global__ void append_write_cache_kv_c8_qkv_dynamic( v_local_max_value[i] = -INFINITY; } const int num_kv_heads = gridDim.z; - const int scale_offset = block_id * num_kv_heads * BLOCK_SIZE + kv_head_idx * BLOCK_SIZE; + const int scale_offset = + block_id * num_kv_heads * BLOCK_SIZE + kv_head_idx * BLOCK_SIZE; T *cache_k_scale_now = cache_k_scales + scale_offset; T *cache_v_scale_now = cache_v_scales + scale_offset; // k scale @@ -1481,19 +1610,23 @@ __global__ void append_write_cache_kv_c8_qkv_dynamic( k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag); #pragma unroll for (int i = 0; i < 4; i++) { - k_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), k_local_max_value[fz * 2]); + k_local_max_value[fz * 2] = + __hmax(__habs(kv_reduce_frag_T[i]), k_local_max_value[fz * 2]); } #pragma unroll for (int i = 0; i < 4; i++) { - k_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), k_local_max_value[fz * 2 + 1]); + k_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), + k_local_max_value[fz * 2 + 1]); } k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy); } // reduce per row for (int i = 0; i < 2; i++) { 
T local_max_value = __habs(k_local_max_value[fz * 2 + i]); - local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2)); - local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1)); + local_max_value = __hmax(local_max_value, + __shfl_xor_sync(0xffffffff, local_max_value, 2)); + local_max_value = __hmax(local_max_value, + __shfl_xor_sync(0xffffffff, local_max_value, 1)); // used for quant k_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value); } @@ -1510,17 +1643,18 @@ __global__ void append_write_cache_kv_c8_qkv_dynamic( } if (tile_start + offset_now + 8 >= start_len) { if (tile_start + offset_now + 8 < end_len) { - cache_k_scale_now[offset_now + 8] = __hdiv(1, k_local_max_value[fz * 2 + 1]); + cache_k_scale_now[offset_now + 8] = + __hdiv(1, k_local_max_value[fz * 2 + 1]); } else { cache_k_scale_now[offset_now + 8] = 0; } } } __syncthreads(); - k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1 + k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1 } - // v scale - #pragma unroll +// v scale +#pragma unroll for (uint32_t fz = 0; fz < num_frags_z; ++fz) { #pragma unroll for (uint32_t fy = 0; fy < num_frags_y; ++fy) { @@ -1528,19 +1662,23 @@ __global__ void append_write_cache_kv_c8_qkv_dynamic( v_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag); #pragma unroll for (int i = 0; i < 4; i++) { - v_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), v_local_max_value[fz * 2]); + v_local_max_value[fz * 2] = + __hmax(__habs(kv_reduce_frag_T[i]), v_local_max_value[fz * 2]); } #pragma unroll for (int i = 0; i < 4; i++) { - v_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), v_local_max_value[fz * 2 + 1]); + v_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), + v_local_max_value[fz * 2 + 1]); } k_smem_offset_r = v_smem.advance_offset_by_column<2>(k_smem_offset_r, fy); } // reduce per row for (int i = 0; i < 2; i++) { T local_max_value = __habs(v_local_max_value[fz * 2 + i]); - local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2)); - local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1)); + local_max_value = __hmax(local_max_value, + __shfl_xor_sync(0xffffffff, local_max_value, 2)); + local_max_value = __hmax(local_max_value, + __shfl_xor_sync(0xffffffff, local_max_value, 1)); v_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value); } // store @@ -1558,7 +1696,8 @@ __global__ void append_write_cache_kv_c8_qkv_dynamic( } if (tile_start + offset_now + 8 >= start_len) { if (tile_start + offset_now + 8 < end_len) { - cache_v_scale_now[offset_now + 8] = __hdiv(1, v_local_max_value[fz * 2 + 1]); + cache_v_scale_now[offset_now + 8] = + __hdiv(1, v_local_max_value[fz * 2 + 1]); v_scale_smem[offset_now + 8] = v_local_max_value[fz * 2 + 1]; } else { cache_v_scale_now[offset_now + 8] = 0; @@ -1567,7 +1706,7 @@ __global__ void append_write_cache_kv_c8_qkv_dynamic( } } __syncthreads(); - k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1 + k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1 } __syncthreads(); @@ -1607,7 +1746,11 @@ __global__ void append_write_cache_kv_c8_qkv_dynamic( uint8_t uint_quant_value; if (chunk_start_k + (v_id / 4) * 8 >= start_len && chunk_start_k + (v_id / 4) * 8 < end_len) { - uint_quant_value = QuantToC8(k_local_max_value[fz * 2 + v_id / 4], k_frag_T[v_id], 127.0f, -127.0f); + uint_quant_value = QuantToC8( + k_local_max_value[fz * 2 + v_id / 4], + k_frag_T[v_id], + 
127.0f, + -127.0f); } else { uint_quant_value = 0; } @@ -1673,7 +1816,8 @@ __global__ void append_write_cache_kv_c8_qkv_dynamic( uint8_t uint_quant_value; if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len && chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) { - uint_quant_value = QuantToC8(v_scales[fz * 4 + v_id % 4], v_frag_T[v_id], 127.0f, -127.0f); + uint_quant_value = QuantToC8( + v_scales[fz * 4 + v_id % 4], v_frag_T[v_id], 127.0f, -127.0f); // store now } else { uint_quant_value = 0; } @@ -1704,6 +1848,9 @@ __global__ void append_write_cache_kv_c8_qkv_dynamic( 16 * num_frags_z_v * num_vecs_per_head; chunk_start_v -= 16 * num_frags_z_v; } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } // Write Cache KV in Append @@ -1736,6 +1883,9 @@ __global__ void append_write_cache_kv_c4_qkv( constexpr uint32_t pad_len = BLOCK_SIZE; const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; const uint32_t tid = threadIdx.x, wid = threadIdx.y; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif const uint32_t batch_id = batch_ids[btid]; const uint32_t tile_id = tile_ids[btid]; const uint32_t seq_len_this_time = seq_lens_this_time[batch_id]; @@ -1768,28 +1918,27 @@ __global__ void append_write_cache_kv_c4_qkv( using LoadPadKVT = AlignedVector; // pad zero for this kv_head_idx for this block LoadPadKVT pad_cache_vec; - *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); + *(reinterpret_cast(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0); // reset k - constexpr int num_vecs_per_head_k = HEAD_DIM_HALF / KV_VEC_SIZE; // 4 - constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k; // 8 + constexpr int num_vecs_per_head_k = HEAD_DIM_HALF / KV_VEC_SIZE; // 4 + constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k; // 8 uint32_t tgt_idx = (block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM_HALF + tid % num_vecs_per_head_k * KV_VEC_SIZE; - for (int block_i = tid / num_vecs_per_head_k; - block_i < BLOCK_SIZE; - block_i += num_token_each_time_k) { + for (int block_i = tid / num_vecs_per_head_k; block_i < BLOCK_SIZE; + block_i += num_token_each_time_k) { Store(pad_cache_vec, &cache_k[tgt_idx + block_i * HEAD_DIM_HALF]); } // reset v - const int num_vecs_per_head_v = BLOCK_SIZE_HALF / KV_VEC_SIZE; // 2 - const int num_token_each_time_v = 32 / num_vecs_per_head_v; // 16 + const int num_vecs_per_head_v = BLOCK_SIZE_HALF / KV_VEC_SIZE; // 2 + const int num_token_each_time_v = 32 / num_vecs_per_head_v; // 16 tgt_idx = (block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE_HALF + tid % num_vecs_per_head_v * KV_VEC_SIZE; for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM; - block_i += num_token_each_time_v) { + block_i += num_token_each_time_v) { Store( pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE_HALF]); } @@ -1845,16 +1994,10 @@ __global__ void append_write_cache_kv_c4_qkv( for (uint32_t fy = 0; fy < num_frags_y / 4; ++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b()) if (chunk_start >= start_len && chunk_start < end_len) { - k_smem - .load_128b_async( - kv_smem_offset_w, - qkv_input + k_read_idx, - chunk_start < end_len); - v_smem - .load_128b_async( - kv_smem_offset_w, - qkv_input + v_read_idx, - chunk_start < end_len); + k_smem.load_128b_async( + kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len); + v_smem.load_128b_async( + kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len); }
kv_smem_offset_w = k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy); @@ -2040,6 +2183,9 @@ __global__ void append_write_cache_kv_c4_qkv( 16 * num_frags_z_v * num_vecs_per_head; chunk_start_v -= 16 * num_frags_z_v; } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template @@ -2061,9 +2207,8 @@ void rotary_qk_variable( const cudaStream_t &stream, bool use_neox_style = false, bool rope_3d = false) { - int64_t elem_nums = - qkv_out_scales ? token_num * 3 * head_num * dim_head - : token_num * 2 * head_num * dim_head; + int64_t elem_nums = qkv_out_scales ? token_num * 3 * head_num * dim_head + : token_num * 2 * head_num * dim_head; if (use_neox_style) { elem_nums /= 2; } @@ -2077,77 +2222,89 @@ void rotary_qk_variable( const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; if (qkv_out_scales) { - VariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out_scales, - qkv_bias, - qkv_out, - elem_nums, - head_num, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled(IntVariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head, + rope_3d); } else { - VariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out, - elem_nums, - head_num, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled(VariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head, + rope_3d); } } else { const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head; if (qkv_out_scales) { - NeoxVariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out_scales, - qkv_bias, - qkv_out, - elem_nums, - head_num, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled(IntNeoxVariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head, + rope_3d); } else { - NeoxVariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out, - elem_nums, - head_num, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled(NeoxVariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + head_num, + seq_len, + dim_head, + rope_3d); } } } @@ -2180,7 +2337,9 @@ void gqa_rotary_qk_norm_variable( ?
token_num * (num_heads + 2 * kv_num_heads) * dim_head : token_num * (num_heads + kv_num_heads) * dim_head; // for all q k v if (dim_head != 128) { - PADDLE_THROW("gqa rotary with qk norm only support head_dim=128, but got %d.", dim_head); + PADDLE_THROW( + "gqa rotary with qk norm only support head_dim=128, but got %d.", + dim_head); } constexpr int HEAD_DIM = 128; constexpr int PackSize = HEAD_DIM / kWarpSize; @@ -2188,30 +2347,32 @@ void gqa_rotary_qk_norm_variable( const int blocksize = 128; int grid_size = 1; GetNumBlocks<128>(pack_num, &grid_size); - dim3 Block_Size(kWarpSize, blocksize/kWarpSize, 1); + dim3 Block_Size(kWarpSize, blocksize / kWarpSize, 1); const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; - - GQAVariableLengthRotaryQKNormKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rope_3d, - q_norm_weight, - k_norm_weight, - rms_norm_eps); + launchWithPdlWhenEnabled(GQAVariableLengthRotaryQKNormKernel, + grid_size, + Block_Size, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d, + q_norm_weight, + k_norm_weight, + rms_norm_eps); } template @@ -2253,114 +2414,134 @@ void gqa_rotary_qk_variable( const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; if (qkv_out_scales) { - GQAVariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out_scales, - qkv_bias, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled(IntGQAVariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d); } else { - GQAVariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rope_3d); + auto *kernelFn = GQAVariableLengthRotaryKernel; + launchWithPdlWhenEnabled(kernelFn, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d); } } else { const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head; if (qkv_out_scales) { - GQANeoxVariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out_scales, - qkv_bias, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled( + IntGQANeoxVariableLengthRotaryKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + 
qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d); } else { - if (rotary_dim < dim_head){ + if (rotary_dim < dim_head) { PD_CHECK((rotary_dim / 2) % PackSize == 0); elem_nums = qkv_out_scales ? token_num * (num_heads + 2 * kv_num_heads) * rotary_dim - : token_num * (num_heads + kv_num_heads) * rotary_dim; // for all q k v + : token_num * (num_heads + kv_num_heads) * + rotary_dim; // for all q k v if (use_neox_style) { elem_nums /= 2; } const int pack_num_new = elem_nums / PackSize; GetNumBlocks<128>(pack_num_new, &grid_size); - GQANeoxVariableLengthPartialRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - rotary_emb + input_output_len * rotary_dim / 2, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out_scales, - qkv_bias, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rotary_dim, - rope_3d); - }else{ - GQANeoxVariableLengthRotaryKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_out_scales, - qkv_bias, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rope_3d); + auto *kernelFn = GQANeoxVariableLengthPartialRotaryKernel; + launchWithPdlWhenEnabled(kernelFn, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + rotary_emb + input_output_len * rotary_dim / 2, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rotary_dim, + rope_3d); + } else { + auto *kernelFn = GQANeoxVariableLengthRotaryKernel; + launchWithPdlWhenEnabled(kernelFn, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_out_scales, + qkv_bias, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d); } } } @@ -2399,49 +2580,57 @@ void gqa_rotary_qk_quant_variable( int grid_size = 1; GetNumBlocks<128>(pack_num, &grid_size); const float *cos_emb = rotary_emb; - const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; + const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; if (!use_neox_style) { if (qkv_out_scales) { - GQAVariableLengthRotaryQuantKVKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - qkv_out_scales, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_bias, - cache_k_scales, - cache_v_scales, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled( + IntGQAVariableLengthRotaryQuantKVKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + sin_emb, + qkv_out_scales, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_bias, + cache_k_scales, + cache_v_scales, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d); } else { - GQAVariableLengthRotaryQuantKVKernel - <<>>( - reinterpret_cast(qkv_input), - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens, - seq_lens_decoder, - qkv_bias, - cache_k_scales, - cache_v_scales, - qkv_out, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled( + GQAVariableLengthRotaryQuantKVKernel, + grid_size, + blocksize, + 0, + stream, + reinterpret_cast(qkv_input), + cos_emb, + 
sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens, + seq_lens_decoder, + qkv_bias, + cache_k_scales, + cache_v_scales, + qkv_out, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d); } } else { PADDLE_THROW("Use_neox_style mode isn't implemented yet"); @@ -2470,14 +2659,18 @@ void CascadeAppendWriteCacheKVQKV( auto head_dim = meta_data.head_dims; auto block_size = meta_data.block_size; - const uint32_t elem_nums = - num_tokens * 2 * kv_num_heads * head_dim; + const uint32_t elem_nums = num_tokens * 2 * kv_num_heads * head_dim; constexpr int PackSize = 16 / sizeof(T); const int pack_num = elem_nums / PackSize; const int blocksize = 128; int grid_size = 1; GetNumBlocks<128>(pack_num, &grid_size); - cache_kernel<<>>( + launchWithPdlWhenEnabled( + cache_kernel, + grid_size, + blocksize, + 0, + stream, reinterpret_cast(const_cast(qkv.data())), reinterpret_cast(key_cache_out->data()), reinterpret_cast(value_cache_out->data()), @@ -2515,7 +2708,7 @@ void CascadeAppendWriteCacheKVC8QKV( int num_blocks_x_cpu, int max_seq_len, bool is_scale_channel_wise, - const std::string& cache_quant_type, + const std::string &cache_quant_type, cudaStream_t &stream, paddle::Tensor *cache_k_out, paddle::Tensor *cache_v_out) { @@ -2544,43 +2737,50 @@ void CascadeAppendWriteCacheKVC8QKV( HEAD_DIM, BLOCK_SIZE, num_warps, - true, false>; + true, + false>; if (cache_quant_type == "cache_fp8") { kernel_fn = append_write_cache_kv_c8_qkv; + num_frags_y, + num_frags_z, + HEAD_DIM, + BLOCK_SIZE, + num_warps, + true, + true>; } if (is_scale_channel_wise) { kernel_fn = append_write_cache_kv_c8_qkv; + num_frags_y, + num_frags_z, + HEAD_DIM, + BLOCK_SIZE, + num_warps, + false>; } cudaFuncSetAttribute( kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel_fn<<>>(cache_k_out->data(), - cache_v_out->data(), - qkv.data(), - cache_k_scale.data(), - cache_v_scale.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - seq_lens_this_time.data(), - seq_lens_decoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads); + launchWithPdlWhenEnabled(kernel_fn, + grids, + blocks, + 0, + stream, + cache_k_out->data(), + cache_v_out->data(), + qkv.data(), + cache_k_scale.data(), + cache_v_scale.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + block_table.data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads); } else { auto kernel_fn = append_write_cache_kv_c8_qkv_dynamic; + true, + true>; cudaFuncSetAttribute( kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel_fn<<>>(cache_k_out->data(), - cache_v_out->data(), - reinterpret_cast(qkv.data()), - const_cast(reinterpret_cast(cache_k_scale.data())), - const_cast(reinterpret_cast(cache_v_scale.data())), - batch_ids.data(), - tile_ids_per_batch.data(), - seq_lens_this_time.data(), - seq_lens_decoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads); + launchWithPdlWhenEnabled( + kernel_fn, + grids, + blocks, + 0, + stream, + cache_k_out->data(), + cache_v_out->data(), + reinterpret_cast(qkv.data()), + const_cast( + reinterpret_cast(cache_k_scale.data())), + const_cast( + reinterpret_cast(cache_v_scale.data())), + batch_ids.data(), + tile_ids_per_batch.data(), + seq_lens_this_time.data(), + 
seq_lens_decoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + block_table.data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads); } } @@ -2660,22 +2869,27 @@ void CascadeAppendWriteCacheKVC4QKV( num_warps>; cudaFuncSetAttribute( kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel_fn<<>>(cache_k_out->data(), - cache_v_out->data(), - qkv.data(), - cache_k_scale.data(), - cache_v_scale.data(), - cache_k_zp.data(), - cache_v_zp.data(), - batch_ids.data(), - tile_ids_per_batch.data(), - seq_lens_this_time.data(), - seq_lens_decoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - block_table.data(), - max_seq_len, - max_blocks_per_seq, - num_heads, - kv_num_heads); + launchWithPdlWhenEnabled(kernel_fn, + grids, + blocks, + 0, + stream, + cache_k_out->data(), + cache_v_out->data(), + qkv.data(), + cache_k_scale.data(), + cache_v_scale.data(), + cache_k_zp.data(), + cache_v_zp.data(), + batch_ids.data(), + tile_ids_per_batch.data(), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + block_table.data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + kv_num_heads); } diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu index 0388b9fb6..4f847e8de 100644 --- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu +++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "encoder_write_cache_with_rope_impl.cuh" #include "helper.h" #include "paddle/extension.h" -#include "paddle/phi/core/memory/memcpy.h" -#include "encoder_write_cache_with_rope_impl.cuh" #include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/core/memory/memcpy.h" #include "remote_cache_kv_ipc.h" template @@ -59,25 +59,30 @@ __global__ void GQAVariableLengthRotarySplitKernel( const int hi = bias / last_dim; const int h_bias = bias % last_dim; - const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; + const int ori_seq_id = + (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id; const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; - int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; + int64_t new_emb_idx = + rope_3d ? 
emb_idx + ori_bi * last_dim * seq_len : emb_idx; const int64_t base_idx = token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim + h_bias; int64_t base_split_idx; T *out_p = nullptr; if (hi < q_num_head) { - base_split_idx = token_idx * q_num_head * last_dim + hi * last_dim + h_bias; + base_split_idx = + token_idx * q_num_head * last_dim + hi * last_dim + h_bias; out_p = q; } else if (hi < q_num_head + kv_num_head) { - base_split_idx = kv_write_idx * kv_num_head * last_dim + (hi - q_num_head) * last_dim + h_bias; + base_split_idx = kv_write_idx * kv_num_head * last_dim + + (hi - q_num_head) * last_dim + h_bias; out_p = k; } else { out_p = v; - base_split_idx = kv_write_idx * kv_num_head * last_dim + (hi - q_num_head - kv_num_head) * last_dim + h_bias; + base_split_idx = kv_write_idx * kv_num_head * last_dim + + (hi - q_num_head - kv_num_head) * last_dim + h_bias; } Load(&qkv[base_idx], &src_vec); // do rope @@ -103,7 +108,7 @@ __global__ void GQAVariableLengthRotarySplitKernel( template void gqa_rotary_qk_split_variable( - T *qkv_out, // [token_num, 3, num_head, dim_head] + T *qkv_out, // [token_num, 3, num_head, dim_head] T *q, T *k, T *v, @@ -131,49 +136,52 @@ void gqa_rotary_qk_split_variable( const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; - GQAVariableLengthRotarySplitKernel - <<>>( - qkv_input, - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens_encoder, - seq_lens_decoder, - cu_seqlens_k, - qkv_out, - q, - k, - v, - elem_nums, - num_heads, - kv_num_heads, - seq_len, - dim_head, - rope_3d); + launchWithPdlWhenEnabled(GQAVariableLengthRotarySplitKernel, + grid_size, + blocksize, + 0, + stream, + qkv_input, + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + cu_seqlens_k, + qkv_out, + q, + k, + v, + elem_nums, + num_heads, + kv_num_heads, + seq_len, + dim_head, + rope_3d); } template -__global__ void append_cache_kv_c16( - const T *__restrict__ cache_k, - const T *__restrict__ cache_v, - T *__restrict__ k_out, - T *__restrict__ v_out, - const int *__restrict__ seq_lens_this_time, - const int *__restrict__ seq_lens_decoder, - const int *__restrict__ cu_seqlens_k, - const int *__restrict__ block_tables, - const int *batch_ids, - const int *tile_ids_per_batch, - const int max_blocks_per_seq, - const int kv_num_heads) { + uint32_t NUM_WARPS = 4> +__global__ void append_cache_kv_c16(const T *__restrict__ cache_k, + const T *__restrict__ cache_v, + T *__restrict__ k_out, + T *__restrict__ v_out, + const int *__restrict__ seq_lens_this_time, + const int *__restrict__ seq_lens_decoder, + const int *__restrict__ cu_seqlens_k, + const int *__restrict__ block_tables, + const int *batch_ids, + const int *tile_ids_per_batch, + const int max_blocks_per_seq, + const int kv_num_heads) { // start_kv_idx: start kv_idx current block // batch_id:block's batch_id - // TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT with template(int8/fp8) + // TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT + // with template(int8/fp8) const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z; const uint32_t tid = threadIdx.x, wid = threadIdx.y; @@ -189,13 +197,17 @@ __global__ void append_cache_kv_c16( // cache_kv idx uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; uint32_t block_stride = kv_num_heads * kv_h_stride; - const CacheT *cur_cache_k = cache_k + block_id * block_stride + kv_head_idx * kv_h_stride; - const CacheT 
*cur_cache_v = cache_v + block_id * block_stride + kv_head_idx * kv_h_stride; + const CacheT *cur_cache_k = + cache_k + block_id * block_stride + kv_head_idx * kv_h_stride; + const CacheT *cur_cache_v = + cache_v + block_id * block_stride + kv_head_idx * kv_h_stride; // k_out v_out idx uint32_t kv_t_stride = kv_num_heads * HEAD_DIM; - T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; - T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + T *k_write_ptr = + k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + T *v_write_ptr = + v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; uint32_t kv_frag[4]; T *frag_dq_T = reinterpret_cast(kv_frag); @@ -206,26 +218,31 @@ __global__ void append_cache_kv_c16( extern __shared__ uint8_t smem[]; smem_t k_smem(smem); - uint32_t k_smem_offset_w = smem_t::get_permuted_offset( - wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp - uint32_t k_smem_offset_r = smem_t::get_permuted_offset( - wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + - tid % 8 * num_elems_per_128b(); + uint32_t k_read_idx = + (wid * 4 + tid / 8) * HEAD_DIM + tid % 8 * num_elems_per_128b(); // load k_smem 64 rows 128 cols - for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter - for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter + for (int fz = 0; fz < 4; + fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter + for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter k_smem.load_128b_async( - k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); - k_smem_offset_w = - k_smem.advance_offset_by_column<8, num_vecs_per_head>(k_smem_offset_w, fy); + k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); + k_smem_offset_w = k_smem.advance_offset_by_column<8, num_vecs_per_head>( + k_smem_offset_w, fy); k_read_idx += 8 * num_elems_per_128b(); } k_smem_offset_w = - k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_w) - 16; + k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>( + k_smem_offset_w) - + 16; k_read_idx += 4 * NUM_WARPS * HEAD_DIM - 16 * num_elems_per_128b(); } commit_group(); @@ -233,9 +250,10 @@ __global__ void append_cache_kv_c16( __syncthreads(); // deal k_smem 64 rows 128 cols - for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter + for (int fz = 0; fz < 1; + fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter uint32_t row_idx = wid * 16 + tid / 4; - for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter + for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter uint32_t col_idx = fy * 16 + tid % 4 * 2; k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag); // layout @@ -243,7 +261,8 @@ __global__ void append_cache_kv_c16( r0c0,r0c1, r0c8,r0c9 r8c0,r8c1, r8c8,r8c9 ***/ - T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx; + T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + + kv_head_idx * HEAD_DIM + col_idx; T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride; if (row_idx < end_idx) { @@ -260,33 +279,40 @@ __global__ void append_cache_kv_c16( 
k_tile_ptr1[9] = frag_dq_T[7]; } k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head>( - k_smem_offset_r, fy); + k_smem_offset_r, fy); } k_smem_offset_r = - k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(k_smem_offset_r) - 16; + k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>( + k_smem_offset_r) - + 16; } // ================v================ smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT)); - uint32_t v_smem_offset_w = smem_t::get_permuted_offset( - wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp - uint32_t v_smem_offset_r = smem_t::get_permuted_offset( - wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - uint32_t v_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + - tid % 8 * num_elems_per_128b(); + uint32_t v_read_idx = + (wid * 4 + tid / 8) * HEAD_DIM + tid % 8 * num_elems_per_128b(); // load v_smem 64 rows 128 cols - for (int fz = 0; fz < 4; fz++) { // // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter - for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter + for (int fz = 0; fz < 4; fz++) { // // 4 rows pre warp once, 16 rows all 4 + // warps once, need 4 iter + for (int fy = 0; fy < 2; fy++) { // 8 * 128b = 64 * bf16 once, need 2 iter v_smem.load_128b_async( - v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); - v_smem_offset_w = - v_smem.advance_offset_by_column<8, num_vecs_per_head>(v_smem_offset_w, fy); + v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); + v_smem_offset_w = v_smem.advance_offset_by_column<8, num_vecs_per_head>( + v_smem_offset_w, fy); v_read_idx += 8 * num_elems_per_128b(); } v_smem_offset_w = - v_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_w) - 16; + v_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head>( + v_smem_offset_w) - + 16; v_read_idx += 4 * NUM_WARPS * HEAD_DIM - 16 * num_elems_per_128b(); } commit_group(); @@ -294,9 +320,10 @@ __global__ void append_cache_kv_c16( __syncthreads(); // deal v_smem 64 rows 128 cols - for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter + for (int fz = 0; fz < 1; + fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter uint32_t row_idx = wid * 16 + tid / 4; - for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter + for (int fy = 0; fy < 8; fy++) { // 2 * 128b = 16 * bf16 once, need 8 iter uint32_t col_idx = fy * 16 + tid % 4 * 2; v_smem.ldmatrix_m8n8x4(v_smem_offset_r, kv_frag); // layout @@ -304,7 +331,8 @@ __global__ void append_cache_kv_c16( r0c0,r0c1, r0c8,r0c9 r8c0,r8c1, r8c8,r8c9 ***/ - T *v_tile_ptr0 = v_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx; + T *v_tile_ptr0 = v_write_ptr + row_idx * kv_t_stride + + kv_head_idx * HEAD_DIM + col_idx; T *v_tile_ptr1 = v_tile_ptr0 + 8 * kv_t_stride; if (row_idx < end_idx) { @@ -321,10 +349,12 @@ __global__ void append_cache_kv_c16( v_tile_ptr1[9] = frag_dq_T[7]; } v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_head>( - v_smem_offset_r, fy); + v_smem_offset_r, fy); } v_smem_offset_r = - v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>(v_smem_offset_r) - 16; + v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head>( + v_smem_offset_r) - + 16; } } @@ -332,26 +362,26 @@ 
template -__global__ void append_cache_kv_c8( - const CacheT *__restrict__ cache_k, - const CacheT *__restrict__ cache_v, - T *__restrict__ k_out, - T *__restrict__ v_out, - const T *__restrict__ cache_k_dequant_scales, - const T *__restrict__ cache_v_dequant_scales, - const int *__restrict__ seq_lens_this_time, - const int *__restrict__ seq_lens_decoder, - const int *__restrict__ cu_seqlens_k, - const int *__restrict__ block_tables, - const int *batch_ids, - const int *tile_ids_per_batch, - const int max_blocks_per_seq, - const int kv_num_heads) { + uint32_t NUM_WARPS = 4, + bool IS_FP8 = false> +__global__ void append_cache_kv_c8(const CacheT *__restrict__ cache_k, + const CacheT *__restrict__ cache_v, + T *__restrict__ k_out, + T *__restrict__ v_out, + const T *__restrict__ cache_k_dequant_scales, + const T *__restrict__ cache_v_dequant_scales, + const int *__restrict__ seq_lens_this_time, + const int *__restrict__ seq_lens_decoder, + const int *__restrict__ cu_seqlens_k, + const int *__restrict__ block_tables, + const int *batch_ids, + const int *tile_ids_per_batch, + const int max_blocks_per_seq, + const int kv_num_heads) { // start_kv_idx: start kv_idx current block // batch_id:block's batch_id - // TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT with template(int8/fp8) + // TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT + // with template(int8/fp8) const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z; const uint32_t tid = threadIdx.x, wid = threadIdx.y; @@ -367,13 +397,17 @@ __global__ void append_cache_kv_c8( // cache_kv idx uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; uint32_t block_stride = kv_num_heads * kv_h_stride; - const CacheT *cur_cache_k = cache_k + block_id * block_stride + kv_head_idx * kv_h_stride; - const CacheT *cur_cache_v = cache_v + block_id * block_stride + kv_head_idx * kv_h_stride; + const CacheT *cur_cache_k = + cache_k + block_id * block_stride + kv_head_idx * kv_h_stride; + const CacheT *cur_cache_v = + cache_v + block_id * block_stride + kv_head_idx * kv_h_stride; // k_out v_out idx uint32_t kv_t_stride = kv_num_heads * HEAD_DIM; - T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; - T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + T *k_write_ptr = + k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + T *v_write_ptr = + v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; uint32_t k_frag[4], v_frag[4], frag_dq[4]; T *frag_dq_T = reinterpret_cast(frag_dq); @@ -389,26 +423,32 @@ __global__ void append_cache_kv_c8( extern __shared__ uint8_t smem[]; smem_t k_smem(smem); - uint32_t k_smem_offset_w = smem_t::get_permuted_offset( - wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); // 4 * 4 per warp - uint32_t k_smem_offset_r = smem_t::get_permuted_offset( - wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - uint32_t k_read_idx = (wid * 4 + tid / 8) * HEAD_DIM + - tid % 8 * num_elems_per_128b(); + uint32_t k_read_idx = + (wid * 4 + tid / 8) * HEAD_DIM + tid % 8 * num_elems_per_128b(); // load v_smem 64 rows, 128 cols - for (int fz = 0; fz < 4; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter - for (int fy = 0; fy < 1; fy++) { // 8 * 128b = 128 * uint8 once, need 1 iter + for (int fz = 0; 
fz < 4; + fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter + for (int fy = 0; fy < 1; + fy++) { // 8 * 128b = 128 * uint8 once, need 1 iter k_smem.load_128b_async( - k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); - k_smem_offset_w = - k_smem.advance_offset_by_column<8, num_vecs_per_head_k>(k_smem_offset_w, fy); + k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); + k_smem_offset_w = k_smem.advance_offset_by_column<8, num_vecs_per_head_k>( + k_smem_offset_w, fy); k_read_idx += 8 * num_elems_per_128b(); } k_smem_offset_w = - k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_w) - 8; + k_smem.advance_offset_by_row<4 * NUM_WARPS, num_vecs_per_head_k>( + k_smem_offset_w) - + 8; k_read_idx += 4 * NUM_WARPS * HEAD_DIM - 8 * num_elems_per_128b(); } commit_group(); @@ -416,9 +456,10 @@ __global__ void append_cache_kv_c8( __syncthreads(); // deal k_smem 64 rows, 128 cols - for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter + for (int fz = 0; fz < 1; + fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter uint32_t row_idx = wid * 16 + tid / 4; - for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 once, need 4 iter + for (int fy = 0; fy < 4; fy++) { // 2 * 128b = 32 * uint8 once, need 4 iter uint32_t col_idx = fy * 32 + tid % 4 * 2; k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag); // layout @@ -427,11 +468,13 @@ __global__ void append_cache_kv_c8( r0c16,r0c17,r0c24,r0c25, r8c16,r8c17,r8c24,r8c25 ***/ for (int i = 0; i < 4 / 2; i++) { - T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx; + T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + + kv_head_idx * HEAD_DIM + col_idx; T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride; if (row_idx < end_idx) { - convert_c8(frag_dq_T,k_frag[2 * i]); // 4 * uint8/fp8 -> 4 * T + convert_c8(frag_dq_T, + k_frag[2 * i]); // 4 * uint8/fp8 -> 4 * T k_tile_ptr0[0] = frag_dq_T[0] * cache_k_scale; k_tile_ptr0[1] = frag_dq_T[1] * cache_k_scale; k_tile_ptr0[8] = frag_dq_T[2] * cache_k_scale; @@ -439,7 +482,8 @@ __global__ void append_cache_kv_c8( } if (row_idx + 8 < end_idx) { - convert_c8(frag_dq_T + 4,k_frag[2 * i + 1]); // 4 * uint8/fp8 -> 4 * T + convert_c8(frag_dq_T + 4, + k_frag[2 * i + 1]); // 4 * uint8/fp8 -> 4 * T k_tile_ptr1[0] = frag_dq_T[4] * cache_k_scale; k_tile_ptr1[1] = frag_dq_T[5] * cache_k_scale; k_tile_ptr1[8] = frag_dq_T[6] * cache_k_scale; @@ -448,33 +492,41 @@ __global__ void append_cache_kv_c8( col_idx += 16; } k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head_k>( - k_smem_offset_r, fy); + k_smem_offset_r, fy); } k_smem_offset_r = - k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_r) - 8; + k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head_k>( + k_smem_offset_r) - + 8; } // ================v================ smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT)); - uint32_t v_smem_offset_w = smem_t::get_permuted_offset( - wid * 8 + tid / 4, tid % 4); // 4 * 8 per warp + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); // 4 * 8 per warp - uint32_t v_smem_offset_r = smem_t::get_permuted_offset( - wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); - uint32_t v_read_idx = (wid * 8 + tid / 4) * BLOCK_SIZE + - tid % 4 * num_elems_per_128b(); + uint32_t v_read_idx 
= + (wid * 8 + tid / 4) * BLOCK_SIZE + tid % 4 * num_elems_per_128b(); // load v_smem 128 rows 64 cols - for (int fy = 0; fy < 4; fy++) { // 8 rows pre warp once, 32 rows all 4 warps once, need 4 iter - for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 once, need 1 iter + for (int fy = 0; fy < 4; + fy++) { // 8 rows pre warp once, 32 rows all 4 warps once, need 4 iter + for (int fz = 0; fz < 1; fz++) { // 4 * 128b = 64 * uint8 once, need 1 iter v_smem.load_128b_async( - v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); + v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); v_smem_offset_w = - v_smem.advance_offset_by_column<4, num_vecs_per_blocksize>(v_smem_offset_w, fz); + v_smem.advance_offset_by_column<4, num_vecs_per_blocksize>( + v_smem_offset_w, fz); v_read_idx += 4 * num_elems_per_128b(); } v_smem_offset_w = - v_smem.advance_offset_by_row<8 * NUM_WARPS, num_vecs_per_blocksize>(v_smem_offset_w) - 4; + v_smem.advance_offset_by_row<8 * NUM_WARPS, num_vecs_per_blocksize>( + v_smem_offset_w) - + 4; v_read_idx += 8 * NUM_WARPS * BLOCK_SIZE - 4 * num_elems_per_128b(); } @@ -483,17 +535,21 @@ __global__ void append_cache_kv_c8( __syncthreads(); // deal v_smem 128 rows 64 cols - for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter + for (int fy = 0; fy < 2; + fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4; - for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 once, need 2 iter + for (int fz = 0; fz < 2; fz++) { // 2 * 128b = 32 * uint8 once, need 2 iter uint32_t kv_idx = fz * 32 + tid % 4 * 2; v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag); // layout for (int i = 0; i < 4 / 2; i++) { - T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride + kv_head_idx * HEAD_DIM + dim_idx; + T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride + + kv_head_idx * HEAD_DIM + dim_idx; T *v_tile_ptr1 = v_tile_ptr0 + 8; - convert_c8(frag_dq_T, v_frag[2 * i]); // 4 * uint8/fp8 -> 4 * T - convert_c8(frag_dq_T + 4, v_frag[2 * i + 1]); // 4 * uint8/fp8 -> 4 * T + convert_c8(frag_dq_T, + v_frag[2 * i]); // 4 * uint8/fp8 -> 4 * T + convert_c8(frag_dq_T + 4, + v_frag[2 * i + 1]); // 4 * uint8/fp8 -> 4 * T if (kv_idx < end_idx) { v_tile_ptr0[0] = frag_dq_T[0] * cache_v_scale; v_tile_ptr1[0] = frag_dq_T[4] * cache_v_scale; @@ -512,11 +568,14 @@ __global__ void append_cache_kv_c8( } kv_idx += 16; } - v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>( - v_smem_offset_r, fz); + v_smem_offset_r = + v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>( + v_smem_offset_r, fz); } v_smem_offset_r = - v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>(v_smem_offset_r) - 4; + v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>( + v_smem_offset_r) - + 4; } } @@ -524,27 +583,27 @@ template -__global__ void append_cache_kv_c4( - const CacheT *__restrict__ cache_k, - const CacheT *__restrict__ cache_v, - T *__restrict__ k_out, - T *__restrict__ v_out, - const T *__restrict__ cache_k_dequant_scales, - const T *__restrict__ cache_v_dequant_scales, - const T *__restrict__ cache_k_zero_point, - const T *__restrict__ cache_v_zero_point, - const int *__restrict__ seq_lens_this_time, - const int *__restrict__ seq_lens_decoder, - const int *__restrict__ cu_seqlens_k, - const int *__restrict__ block_tables, - const int *batch_ids, - const int *tile_ids_per_batch, - const int max_blocks_per_seq, - const int kv_num_heads) { + 
uint32_t NUM_WARPS = 4> +__global__ void append_cache_kv_c4(const CacheT *__restrict__ cache_k, + const CacheT *__restrict__ cache_v, + T *__restrict__ k_out, + T *__restrict__ v_out, + const T *__restrict__ cache_k_dequant_scales, + const T *__restrict__ cache_v_dequant_scales, + const T *__restrict__ cache_k_zero_point, + const T *__restrict__ cache_v_zero_point, + const int *__restrict__ seq_lens_this_time, + const int *__restrict__ seq_lens_decoder, + const int *__restrict__ cu_seqlens_k, + const int *__restrict__ block_tables, + const int *batch_ids, + const int *tile_ids_per_batch, + const int max_blocks_per_seq, + const int kv_num_heads) { // start_kv_idx: start kv_idx current block // batch_id:block's batch_id - // TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT with template(int8/fp8) + // TODO: 1.scale preload 2.frag_dq_T reuse 3.pipeline 4.store aligned 5.cacheT + // with template(int8/fp8) const uint32_t tile_idx = blockIdx.x, kv_head_idx = blockIdx.z; const uint32_t tid = threadIdx.x, wid = threadIdx.y; @@ -564,13 +623,17 @@ __global__ void append_cache_kv_c4( // cache_kv idx uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM_HALF; uint32_t block_stride = kv_num_heads * kv_h_stride; - const CacheT *cur_cache_k = cache_k + block_id * block_stride + kv_head_idx * kv_h_stride; - const CacheT *cur_cache_v = cache_v + block_id * block_stride + kv_head_idx * kv_h_stride; + const CacheT *cur_cache_k = + cache_k + block_id * block_stride + kv_head_idx * kv_h_stride; + const CacheT *cur_cache_v = + cache_v + block_id * block_stride + kv_head_idx * kv_h_stride; // k_out v_out idx uint32_t kv_t_stride = kv_num_heads * HEAD_DIM; - T *k_write_ptr = k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; - T *v_write_ptr = v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + T *k_write_ptr = + k_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; + T *v_write_ptr = + v_out + (cu_seqlens_k[batch_id] + start_kv_idx) * kv_t_stride; extern __shared__ uint8_t smem[]; @@ -582,8 +645,8 @@ __global__ void append_cache_kv_c4( const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; const T *cache_v_scale_now = cache_v_dequant_scales + kv_head_idx * HEAD_DIM; const T *cache_v_zp_now = cache_v_zero_point + kv_head_idx * HEAD_DIM; - T *cache_k_scale_smem = reinterpret_cast( - smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT)); + T *cache_k_scale_smem = + reinterpret_cast(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT)); T *cache_k_zero_point_smem = cache_k_scale_smem + HEAD_DIM; T *cache_v_scale_smem = cache_k_zero_point_smem + HEAD_DIM; T *cache_v_zero_point_smem = cache_v_scale_smem + HEAD_DIM; @@ -597,104 +660,150 @@ __global__ void append_cache_kv_c4( smem_t k_smem(smem); constexpr uint32_t num_vecs_per_head_k = - HEAD_DIM_HALF / num_elems_per_128b(); // 2 + HEAD_DIM_HALF / num_elems_per_128b(); // 2 constexpr uint32_t num_vecs_per_blocksize = BLOCK_SIZE_HALF / num_elems_per_128b(); - constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; // 4 + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; // 4 constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; - uint32_t k_smem_offset_w = smem_t::get_permuted_offset( - wid * 8 + tid / 4, tid % 4); // 2(iter) * 4(warp) * 8 row per warp + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); // 2(iter) * 4(warp) * 8 row per warp - uint32_t k_smem_offset_r = smem_t::get_permuted_offset( - wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); // + 
uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); // uint32_t k_read_idx = (wid * 8 + tid / 4) * HEAD_DIM / 2 + - tid % 4 * num_elems_per_128b(); + tid % 4 * num_elems_per_128b(); // load k_smem 64 rows 128 cols - for (int fz = 0; fz < 2; fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter - for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 once, need 1 iter + for (int fz = 0; fz < 2; + fz++) { // 4 rows pre warp once, 16 rows all 4 warps once, need 4 iter + for (int fy = 0; fy < 1; fy++) { // 4 * 128b = 128 * int4 once, need 1 iter k_smem.load_128b_async( - k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); - k_smem_offset_w = - k_smem.advance_offset_by_column<4, num_vecs_per_head_k>(k_smem_offset_w, fy); + k_smem_offset_w, cur_cache_k + k_read_idx, end_idx > 0); + k_smem_offset_w = k_smem.advance_offset_by_column<4, num_vecs_per_head_k>( + k_smem_offset_w, fy); k_read_idx += 4 * num_elems_per_128b(); } k_smem_offset_w = - k_smem.advance_offset_by_row<8 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_w) - 4; - k_read_idx += 8 * NUM_WARPS * HEAD_DIM / 2 - 4 * num_elems_per_128b(); + k_smem.advance_offset_by_row<8 * NUM_WARPS, num_vecs_per_head_k>( + k_smem_offset_w) - + 4; + k_read_idx += + 8 * NUM_WARPS * HEAD_DIM / 2 - 4 * num_elems_per_128b(); } commit_group(); wait_group<0>(); __syncthreads(); // deal k_smem 64 rows 128 cols - for (int fz = 0; fz < 1; fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter + for (int fz = 0; fz < 1; + fz++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 1 iter uint32_t row_idx = wid * 16 + tid / 4; - for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 once, need 2 iter + for (int fy = 0; fy < 2; fy++) { // 2 * 128b = 64 * int4 once, need 2 iter uint32_t col_idx = fy * 64 + tid % 4 * 2; k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag); - for (int i = 0; i < 2; i++) { - T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + kv_head_idx * HEAD_DIM + col_idx; + T *k_tile_ptr0 = k_write_ptr + row_idx * kv_t_stride + + kv_head_idx * HEAD_DIM + col_idx; T *k_tile_ptr1 = k_tile_ptr0 + 8 * kv_t_stride; convert_int4(frag_dq_T, k_frag[2 * i]); convert_int4(frag_dq_T + 8, k_frag[2 * i + 1]); if (row_idx < end_idx) { - k_tile_ptr0[0] = (frag_dq_T[0] - cache_k_zero_point_smem[col_idx]) * cache_k_scale_smem[col_idx]; - k_tile_ptr0[1] = (frag_dq_T[1] - cache_k_zero_point_smem[col_idx + 1]) * cache_k_scale_smem[col_idx + 1]; - k_tile_ptr0[8] = (frag_dq_T[2] - cache_k_zero_point_smem[col_idx + 8]) * cache_k_scale_smem[col_idx + 8]; - k_tile_ptr0[9] = (frag_dq_T[3] - cache_k_zero_point_smem[col_idx + 9]) * cache_k_scale_smem[col_idx + 9]; - k_tile_ptr0[16] = (frag_dq_T[8] - cache_k_zero_point_smem[col_idx + 16]) * cache_k_scale_smem[col_idx + 16]; - k_tile_ptr0[17] = (frag_dq_T[9] - cache_k_zero_point_smem[col_idx + 17]) * cache_k_scale_smem[col_idx + 17]; - k_tile_ptr0[24] = (frag_dq_T[10] - cache_k_zero_point_smem[col_idx + 24]) * cache_k_scale_smem[col_idx + 24]; - k_tile_ptr0[25] = (frag_dq_T[11] - cache_k_zero_point_smem[col_idx + 25]) * cache_k_scale_smem[col_idx + 25]; + k_tile_ptr0[0] = (frag_dq_T[0] - cache_k_zero_point_smem[col_idx]) * + cache_k_scale_smem[col_idx]; + k_tile_ptr0[1] = + (frag_dq_T[1] - cache_k_zero_point_smem[col_idx + 1]) * + cache_k_scale_smem[col_idx + 1]; + k_tile_ptr0[8] = + (frag_dq_T[2] - cache_k_zero_point_smem[col_idx + 8]) * + cache_k_scale_smem[col_idx + 8]; + k_tile_ptr0[9] = + (frag_dq_T[3] - 
cache_k_zero_point_smem[col_idx + 9]) * + cache_k_scale_smem[col_idx + 9]; + k_tile_ptr0[16] = + (frag_dq_T[8] - cache_k_zero_point_smem[col_idx + 16]) * + cache_k_scale_smem[col_idx + 16]; + k_tile_ptr0[17] = + (frag_dq_T[9] - cache_k_zero_point_smem[col_idx + 17]) * + cache_k_scale_smem[col_idx + 17]; + k_tile_ptr0[24] = + (frag_dq_T[10] - cache_k_zero_point_smem[col_idx + 24]) * + cache_k_scale_smem[col_idx + 24]; + k_tile_ptr0[25] = + (frag_dq_T[11] - cache_k_zero_point_smem[col_idx + 25]) * + cache_k_scale_smem[col_idx + 25]; } if (row_idx + 8 < end_idx) { - k_tile_ptr1[0] = (frag_dq_T[4] - cache_k_zero_point_smem[col_idx]) * cache_k_scale_smem[col_idx]; - k_tile_ptr1[1] = (frag_dq_T[5] - cache_k_zero_point_smem[col_idx + 1]) * cache_k_scale_smem[col_idx + 1]; - k_tile_ptr1[8] = (frag_dq_T[6] - cache_k_zero_point_smem[col_idx + 8]) * cache_k_scale_smem[col_idx + 8]; - k_tile_ptr1[9] = (frag_dq_T[7] - cache_k_zero_point_smem[col_idx + 9]) * cache_k_scale_smem[col_idx + 9]; - k_tile_ptr1[16] = (frag_dq_T[12] - cache_k_zero_point_smem[col_idx + 16]) * cache_k_scale_smem[col_idx + 16]; - k_tile_ptr1[17] = (frag_dq_T[13] - cache_k_zero_point_smem[col_idx + 17]) * cache_k_scale_smem[col_idx + 17]; - k_tile_ptr1[24] = (frag_dq_T[14] - cache_k_zero_point_smem[col_idx + 24]) * cache_k_scale_smem[col_idx + 24]; - k_tile_ptr1[25] = (frag_dq_T[15] - cache_k_zero_point_smem[col_idx + 25]) * cache_k_scale_smem[col_idx + 25]; + k_tile_ptr1[0] = (frag_dq_T[4] - cache_k_zero_point_smem[col_idx]) * + cache_k_scale_smem[col_idx]; + k_tile_ptr1[1] = + (frag_dq_T[5] - cache_k_zero_point_smem[col_idx + 1]) * + cache_k_scale_smem[col_idx + 1]; + k_tile_ptr1[8] = + (frag_dq_T[6] - cache_k_zero_point_smem[col_idx + 8]) * + cache_k_scale_smem[col_idx + 8]; + k_tile_ptr1[9] = + (frag_dq_T[7] - cache_k_zero_point_smem[col_idx + 9]) * + cache_k_scale_smem[col_idx + 9]; + k_tile_ptr1[16] = + (frag_dq_T[12] - cache_k_zero_point_smem[col_idx + 16]) * + cache_k_scale_smem[col_idx + 16]; + k_tile_ptr1[17] = + (frag_dq_T[13] - cache_k_zero_point_smem[col_idx + 17]) * + cache_k_scale_smem[col_idx + 17]; + k_tile_ptr1[24] = + (frag_dq_T[14] - cache_k_zero_point_smem[col_idx + 24]) * + cache_k_scale_smem[col_idx + 24]; + k_tile_ptr1[25] = + (frag_dq_T[15] - cache_k_zero_point_smem[col_idx + 25]) * + cache_k_scale_smem[col_idx + 25]; } col_idx += 32; } k_smem_offset_r = k_smem.advance_offset_by_column<2, num_vecs_per_head_k>( - k_smem_offset_r, fy); + k_smem_offset_r, fy); } k_smem_offset_r = - k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head_k>(k_smem_offset_r) - 4; + k_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_head_k>( + k_smem_offset_r) - + 4; } // ================v================ smem_t v_smem(smem + BLOCK_SIZE * HEAD_DIM * sizeof(CacheT) / 2); - uint32_t v_smem_offset_w = smem_t::get_permuted_offset( - wid * 16 + tid / 2, tid % 2); // 4 * 8 per warp + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 16 + tid / 2, tid % 2); // 4 * 8 per warp - uint32_t v_smem_offset_r = smem_t::get_permuted_offset( - wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + wid * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); uint32_t v_read_idx = (wid * 16 + tid / 2) * BLOCK_SIZE_HALF + - tid % 2 * num_elems_per_128b(); + tid % 2 * num_elems_per_128b(); // load v_smem 128 rows 64 rows - for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter - for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 
64 * int4 once, need 1 iter + for (int fy = 0; fy < 2; + fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter + for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter v_smem.load_128b_async( - v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); + v_smem_offset_w, cur_cache_v + v_read_idx, end_idx > 0); v_smem_offset_w = - v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>(v_smem_offset_w, fz); + v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>( + v_smem_offset_w, fz); v_read_idx += 2 * num_elems_per_128b(); } v_smem_offset_w = - v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>(v_smem_offset_w) - 2; - v_read_idx += 16 * NUM_WARPS * BLOCK_SIZE_HALF - 2 * num_elems_per_128b(); + v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>( + v_smem_offset_w) - + 2; + v_read_idx += + 16 * NUM_WARPS * BLOCK_SIZE_HALF - 2 * num_elems_per_128b(); } commit_group(); @@ -702,82 +811,116 @@ __global__ void append_cache_kv_c4( __syncthreads(); // deal v_smem 128 rows 64 cols - for (int fy = 0; fy < 2; fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter + for (int fy = 0; fy < 2; + fy++) { // 16 rows pre warp once, 64 rows all 4 warps once, need 2 iter uint32_t dim_idx = fy * NUM_WARPS * 16 + wid * 16 + tid / 4; - for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter + for (int fz = 0; fz < 1; fz++) { // 2 * 128b = 64 * int4 once, need 1 iter uint32_t kv_idx = fz * 64 + tid % 4 * 2; v_smem.ldmatrix_m8n8x4(v_smem_offset_r, v_frag); // layout for (int i = 0; i < 2; i++) { - T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride + kv_head_idx * HEAD_DIM + dim_idx; + T *v_tile_ptr0 = v_write_ptr + kv_idx * kv_t_stride + + kv_head_idx * HEAD_DIM + dim_idx; T *v_tile_ptr1 = v_tile_ptr0 + 8; convert_int4(frag_dq_T, v_frag[2 * i]); convert_int4(frag_dq_T + 8, v_frag[2 * i + 1]); if (kv_idx < end_idx) { - v_tile_ptr0[0] = (frag_dq_T[0] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; - v_tile_ptr1[0] = (frag_dq_T[4] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; + v_tile_ptr0[0] = (frag_dq_T[0] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[0] = + (frag_dq_T[4] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; } if (kv_idx + 1 < end_idx) { - v_tile_ptr0[kv_t_stride] = (frag_dq_T[1] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; - v_tile_ptr1[kv_t_stride] = (frag_dq_T[5] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; + v_tile_ptr0[kv_t_stride] = + (frag_dq_T[1] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[kv_t_stride] = + (frag_dq_T[5] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; } if (kv_idx + 8 < end_idx) { - v_tile_ptr0[8 * kv_t_stride] = (frag_dq_T[2] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; - v_tile_ptr1[8 * kv_t_stride] = (frag_dq_T[6] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; + v_tile_ptr0[8 * kv_t_stride] = + (frag_dq_T[2] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[8 * kv_t_stride] = + (frag_dq_T[6] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; } if (kv_idx + 9 < end_idx) { - v_tile_ptr0[9 * kv_t_stride] = (frag_dq_T[3] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; - v_tile_ptr1[9 * kv_t_stride] = 
(frag_dq_T[7] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; + v_tile_ptr0[9 * kv_t_stride] = + (frag_dq_T[3] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[9 * kv_t_stride] = + (frag_dq_T[7] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; } if (kv_idx + 16 < end_idx) { - v_tile_ptr0[16 * kv_t_stride] = (frag_dq_T[8] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; - v_tile_ptr1[16 * kv_t_stride] = (frag_dq_T[12] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; + v_tile_ptr0[16 * kv_t_stride] = + (frag_dq_T[8] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[16 * kv_t_stride] = + (frag_dq_T[12] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; } if (kv_idx + 17 < end_idx) { - v_tile_ptr0[17 * kv_t_stride] = (frag_dq_T[9] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; - v_tile_ptr1[17 * kv_t_stride] = (frag_dq_T[13] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; + v_tile_ptr0[17 * kv_t_stride] = + (frag_dq_T[9] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[17 * kv_t_stride] = + (frag_dq_T[13] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; } if (kv_idx + 24 < end_idx) { - v_tile_ptr0[24 * kv_t_stride] = (frag_dq_T[10] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; - v_tile_ptr1[24 * kv_t_stride] = (frag_dq_T[14] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; + v_tile_ptr0[24 * kv_t_stride] = + (frag_dq_T[10] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[24 * kv_t_stride] = + (frag_dq_T[14] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; } if (kv_idx + 25 < end_idx) { - v_tile_ptr0[25 * kv_t_stride] = (frag_dq_T[11] - cache_v_zero_point_smem[dim_idx]) * cache_v_scale_smem[dim_idx]; - v_tile_ptr1[25 * kv_t_stride] = (frag_dq_T[15] - cache_v_zero_point_smem[dim_idx + 8]) * cache_v_scale_smem[dim_idx + 8]; + v_tile_ptr0[25 * kv_t_stride] = + (frag_dq_T[11] - cache_v_zero_point_smem[dim_idx]) * + cache_v_scale_smem[dim_idx]; + v_tile_ptr1[25 * kv_t_stride] = + (frag_dq_T[15] - cache_v_zero_point_smem[dim_idx + 8]) * + cache_v_scale_smem[dim_idx + 8]; } kv_idx += 32; } - v_smem_offset_r = v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>( - v_smem_offset_r, fz); + v_smem_offset_r = + v_smem.advance_offset_by_column<2, num_vecs_per_blocksize>( + v_smem_offset_r, fz); } v_smem_offset_r = - v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>(v_smem_offset_r) - 2; + v_smem.advance_offset_by_row<16 * NUM_WARPS, num_vecs_per_blocksize>( + v_smem_offset_r) - + 2; } } template -void AppendCacheKV( - const paddle::Tensor &cache_k, - const paddle::Tensor &cache_v, - const paddle::Tensor &cache_k_dequant_scales, - const paddle::Tensor &cache_v_dequant_scales, - const paddle::Tensor &cache_k_zp, - const paddle::Tensor &cache_v_zp, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &cu_seqlens_k, - const paddle::Tensor &block_tables, - const paddle::Tensor &cache_batch_ids, - const paddle::Tensor &cache_tile_ids_per_batch, - const paddle::Tensor &cache_num_blocks_x, - const int max_blocks_per_seq, - const int kv_num_heads, - const std::string &cache_quant_type, - paddle::Tensor *k_out, - paddle::Tensor 
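// Illustrative aside (not part of the patch): append_cache_kv_c4 above turns each
// 4-bit cache entry back into T via convert_int4() and then applies
// "(value - zero_point) * scale" with the per-channel scale/zero-point tensors
// staged in shared memory. A scalar sketch of that dequantization step; the
// two-values-per-byte packing and low-nibble-first order assumed here are
// illustrative, since the kernel's real layout is handled by convert_int4 and
// the permuted smem offsets.
#include <cstdint>

inline float dequant_int4_zp(uint8_t packed,
                             bool high_nibble,
                             float scale,
                             float zero_point) {
  // Unpack one unsigned 4-bit value (0..15) from the byte.
  const int q = high_nibble ? (packed >> 4) : (packed & 0x0F);
  // Same affine mapping as the k/v writes above: (q - zp) * scale.
  return (static_cast<float>(q) - zero_point) * scale;
}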
*v_out, - const cudaStream_t& stream -) { +void AppendCacheKV(const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::Tensor &cache_k_dequant_scales, + const paddle::Tensor &cache_v_dequant_scales, + const paddle::Tensor &cache_k_zp, + const paddle::Tensor &cache_v_zp, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &cu_seqlens_k, + const paddle::Tensor &block_tables, + const paddle::Tensor &cache_batch_ids, + const paddle::Tensor &cache_tile_ids_per_batch, + const paddle::Tensor &cache_num_blocks_x, + const int max_blocks_per_seq, + const int kv_num_heads, + const std::string &cache_quant_type, + paddle::Tensor *k_out, + paddle::Tensor *v_out, + const cudaStream_t &stream) { using NV_TYPE = typename cascade_attn_type_traits::type; constexpr int NUM_WARPS = 4; int block_num = cache_num_blocks_x.data()[0]; @@ -785,131 +928,160 @@ void AppendCacheKV( dim3 blocks(32, NUM_WARPS); if (cache_quant_type == "none") { const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(T) * 2; - auto kernel_func = append_cache_kv_c16; + auto kernel_func = + append_cache_kv_c16; if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(kernel_func, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); + cudaFuncSetAttribute( + kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); } - kernel_func<<>>( - reinterpret_cast(const_cast(cache_k.data())), - reinterpret_cast(const_cast(cache_v.data())), - reinterpret_cast(k_out->data()), - reinterpret_cast(v_out->data()), - seq_lens_this_time.data(), - seq_lens_decoder.data(), - cu_seqlens_k.data(), - block_tables.data(), - cache_batch_ids.data(), - cache_tile_ids_per_batch.data(), - max_blocks_per_seq, - kv_num_heads - ); - } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") { + launchWithPdlWhenEnabled( + kernel_func, + grids, + blocks, + smem_size, + stream, + reinterpret_cast(const_cast(cache_k.data())), + reinterpret_cast(const_cast(cache_v.data())), + reinterpret_cast(k_out->data()), + reinterpret_cast(v_out->data()), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + cu_seqlens_k.data(), + block_tables.data(), + cache_batch_ids.data(), + cache_tile_ids_per_batch.data(), + max_blocks_per_seq, + kv_num_heads); + } else if (cache_quant_type == "cache_int8" || + cache_quant_type == "cache_fp8") { const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) * 2; - auto kernel_func = append_cache_kv_c8; + auto kernel_func = append_cache_kv_c8; if (cache_quant_type == "cache_fp8") { - kernel_func = append_cache_kv_c8; + kernel_func = append_cache_kv_c8; } if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(kernel_func, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); + cudaFuncSetAttribute( + kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); } - kernel_func<<>>( - cache_k.data(), - cache_v.data(), - reinterpret_cast(k_out->data()), - reinterpret_cast(v_out->data()), - reinterpret_cast(const_cast(cache_k_dequant_scales.data())), - reinterpret_cast(const_cast(cache_v_dequant_scales.data())), - seq_lens_this_time.data(), - seq_lens_decoder.data(), - cu_seqlens_k.data(), - block_tables.data(), - cache_batch_ids.data(), - cache_tile_ids_per_batch.data(), - max_blocks_per_seq, - kv_num_heads - ); + launchWithPdlWhenEnabled(kernel_func, + grids, + blocks, + smem_size, + stream, + cache_k.data(), + cache_v.data(), + reinterpret_cast(k_out->data()), + reinterpret_cast(v_out->data()), + 
reinterpret_cast(const_cast( + cache_k_dequant_scales.data())), + reinterpret_cast(const_cast( + cache_v_dequant_scales.data())), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + cu_seqlens_k.data(), + block_tables.data(), + cache_batch_ids.data(), + cache_tile_ids_per_batch.data(), + max_blocks_per_seq, + kv_num_heads); } else if (cache_quant_type == "cache_int4_zp") { - const uint32_t smem_size = BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) + 4 * HEAD_DIM * sizeof(T); + const uint32_t smem_size = + BLOCK_SIZE * HEAD_DIM * sizeof(uint8_t) + 4 * HEAD_DIM * sizeof(T); - auto kernel_func = append_cache_kv_c4; + auto kernel_func = + append_cache_kv_c4; if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(kernel_func, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); + cudaFuncSetAttribute( + kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); } - kernel_func<<>>( - cache_k.data(), - cache_v.data(), - reinterpret_cast(k_out->data()), - reinterpret_cast(v_out->data()), - reinterpret_cast(const_cast(cache_k_dequant_scales.data())), - reinterpret_cast(const_cast(cache_v_dequant_scales.data())), - reinterpret_cast(const_cast(cache_k_zp.data())), - reinterpret_cast(const_cast(cache_v_zp.data())), - seq_lens_this_time.data(), - seq_lens_decoder.data(), - cu_seqlens_k.data(), - block_tables.data(), - cache_batch_ids.data(), - cache_tile_ids_per_batch.data(), - max_blocks_per_seq, - kv_num_heads - ); + launchWithPdlWhenEnabled( + kernel_func, + grids, + blocks, + smem_size, + stream, + cache_k.data(), + cache_v.data(), + reinterpret_cast(k_out->data()), + reinterpret_cast(v_out->data()), + reinterpret_cast( + const_cast(cache_k_dequant_scales.data())), + reinterpret_cast( + const_cast(cache_v_dequant_scales.data())), + reinterpret_cast(const_cast(cache_k_zp.data())), + reinterpret_cast(const_cast(cache_v_zp.data())), + seq_lens_this_time.data(), + seq_lens_decoder.data(), + cu_seqlens_k.data(), + block_tables.data(), + cache_batch_ids.data(), + cache_tile_ids_per_batch.data(), + max_blocks_per_seq, + kv_num_heads); } else { PADDLE_THROW("%s mode isn't implemented yet", cache_quant_type.c_str()); } } std::vector GQARopeWriteCacheKernel( - const paddle::Tensor& qkv, - const paddle::Tensor& key_cache, - const paddle::Tensor& value_cache, - const paddle::Tensor& cu_seqlens_q, - const paddle::Tensor& cu_seqlens_k, - const paddle::Tensor& rotary_embs, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& batch_id_per_token, - const paddle::Tensor& block_tables, - const paddle::Tensor& kv_batch_ids, - const paddle::Tensor& kv_tile_ids, - const paddle::Tensor& kv_num_blocks, - const paddle::Tensor& cache_batch_ids, - const paddle::Tensor& cache_tile_ids, - const paddle::Tensor& cache_num_blocks, - const paddle::optional& cache_k_quant_scales, - const paddle::optional& cache_v_quant_scales, - const paddle::optional& cache_k_dequant_scales, - const paddle::optional& cache_v_dequant_scales, - const paddle::optional& cache_k_zp, - const paddle::optional& cache_v_zp, - const paddle::optional& kv_signal_data, + const paddle::Tensor &qkv, + const paddle::Tensor &key_cache, + const paddle::Tensor &value_cache, + const paddle::Tensor &cu_seqlens_q, + const paddle::Tensor &cu_seqlens_k, + const paddle::Tensor &rotary_embs, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor 
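// Illustrative aside (not part of the patch): launchWithPdlWhenEnabled(kernel,
// grid, block, smem, stream, args...) is the helper these hunks switch to in
// place of the triple-chevron launches; its definition lives in
// custom_ops/gpu_ops/helper.h, which is outside this section. The sketch below
// is only an assumption of how such a wrapper can be built on CUDA 12+: when
// the caller opts in, it launches through cudaLaunchKernelEx with the
// programmatic-stream-serialization attribute set, otherwise it falls back to
// an ordinary launch. The name LaunchMaybeWithPdl and the explicit pdl_enabled
// parameter are illustrative, not the actual FastDeploy API.
#include <cuda_runtime.h>

#include <utility>

template <typename... Params, typename... Args>
cudaError_t LaunchMaybeWithPdl(bool pdl_enabled,
                               void (*kernel)(Params...),
                               dim3 grid,
                               dim3 block,
                               size_t smem,
                               cudaStream_t stream,
                               Args &&...args) {
  if (!pdl_enabled) {
    // Plain launch: the next kernel on the stream waits for full completion.
    kernel<<<grid, block, smem, stream>>>(std::forward<Args>(args)...);
    return cudaGetLastError();
  }
  // PDL launch: the next dependent kernel may begin once this one calls
  // cudaTriggerProgrammaticLaunchCompletion() (or finishes).
  cudaLaunchConfig_t config = {};
  config.gridDim = grid;
  config.blockDim = block;
  config.dynamicSmemBytes = smem;
  config.stream = stream;
  cudaLaunchAttribute attr;
  attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attr.val.programmaticStreamSerializationAllowed = 1;
  config.attrs = &attr;
  config.numAttrs = 1;
  return cudaLaunchKernelEx(&config, kernel, std::forward<Args>(args)...);
}
// How the real helper decides whether PDL is enabled (for example via an
// environment switch) is not visible in this section.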
&batch_id_per_token, + const paddle::Tensor &block_tables, + const paddle::Tensor &kv_batch_ids, + const paddle::Tensor &kv_tile_ids, + const paddle::Tensor &kv_num_blocks, + const paddle::Tensor &cache_batch_ids, + const paddle::Tensor &cache_tile_ids, + const paddle::Tensor &cache_num_blocks, + const paddle::optional &cache_k_quant_scales, + const paddle::optional &cache_v_quant_scales, + const paddle::optional &cache_k_dequant_scales, + const paddle::optional &cache_v_dequant_scales, + const paddle::optional &cache_k_zp, + const paddle::optional &cache_v_zp, + const paddle::optional &kv_signal_data, const int kv_token_num, const int max_seq_len, - const std::string& cache_quant_type, + const std::string &cache_quant_type, const bool rope_3d) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; const int kv_num_blocks_data = kv_num_blocks.data()[0]; - const auto& qkv_dims = qkv.dims(); - const auto& key_cache_dims = key_cache.dims(); + const auto &qkv_dims = qkv.dims(); + const auto &key_cache_dims = key_cache.dims(); const int token_num = qkv_dims[0]; const int max_blocks_per_seq = block_tables.dims()[1]; const int block_size = key_cache.dims()[2]; const int batch_size = seq_lens_this_time.dims()[0]; const int kv_num_heads = key_cache_dims[1]; - const int head_dim = cache_quant_type == "cache_int4_zp" ? key_cache_dims[3] * 2 : key_cache_dims[3]; - const int num_heads = qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads; + const int head_dim = cache_quant_type == "cache_int4_zp" + ? key_cache_dims[3] * 2 + : key_cache_dims[3]; + const int num_heads = + qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads; const float softmax_scale = 1.f / sqrt(head_dim); AppendAttnMetaData meta_data; @@ -921,90 +1093,82 @@ std::vector GQARopeWriteCacheKernel( meta_data.block_size = block_size; meta_data.batch_size = seq_lens_this_time.dims()[0]; - phi::GPUContext* dev_ctx = static_cast(phi::DeviceContextPool::Instance().Get(qkv.place())); + phi::GPUContext *dev_ctx = static_cast( + phi::DeviceContextPool::Instance().Get(qkv.place())); auto stream = qkv.stream(); - paddle::Tensor qkv_out = GetEmptyTensor( - qkv.dims(), - qkv.dtype(), - qkv.place()); + paddle::Tensor qkv_out = GetEmptyTensor(qkv.dims(), qkv.dtype(), qkv.place()); paddle::Tensor q = GetEmptyTensor( - {token_num, num_heads, head_dim}, - qkv.dtype(), - qkv.place()); + {token_num, num_heads, head_dim}, qkv.dtype(), qkv.place()); paddle::Tensor k = GetEmptyTensor( - {kv_token_num, kv_num_heads, head_dim}, - qkv.dtype(), - qkv.place()); + {kv_token_num, kv_num_heads, head_dim}, qkv.dtype(), qkv.place()); paddle::Tensor v = GetEmptyTensor( - {kv_token_num, kv_num_heads, head_dim}, - qkv.dtype(), - qkv.place()); + {kv_token_num, kv_num_heads, head_dim}, qkv.dtype(), qkv.place()); // rope gqa_rotary_qk_split_variable( - qkv_out.data(), - q.data(), - k.data(), - v.data(), - qkv.data(), - rotary_embs.data(), - batch_id_per_token.data(), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - token_num, - num_heads, - kv_num_heads, - max_seq_len, - rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2], - head_dim, - rope_3d, - stream); + qkv_out.data(), + q.data(), + k.data(), + v.data(), + qkv.data(), + rotary_embs.data(), + batch_id_per_token.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + token_num, + num_heads, + kv_num_heads, + max_seq_len, + rope_3d ? 
rotary_embs.dims()[3] : rotary_embs.dims()[2], + head_dim, + rope_3d, + stream); if (token_num < kv_token_num) { - AppendCacheKV( - key_cache, - value_cache, - cache_k_dequant_scales.get(), - cache_v_dequant_scales.get(), - cache_k_zp.get(), - cache_v_zp.get(), - seq_lens_this_time, - seq_lens_decoder, - cu_seqlens_k, - block_tables, - cache_batch_ids, - cache_tile_ids, - cache_num_blocks, - max_blocks_per_seq, - kv_num_heads, - cache_quant_type, - &k, - &v, - stream - ); + AppendCacheKV(key_cache, + value_cache, + cache_k_dequant_scales.get(), + cache_v_dequant_scales.get(), + cache_k_zp.get(), + cache_v_zp.get(), + seq_lens_this_time, + seq_lens_decoder, + cu_seqlens_k, + block_tables, + cache_batch_ids, + cache_tile_ids, + cache_num_blocks, + max_blocks_per_seq, + kv_num_heads, + cache_quant_type, + &k, + &v, + stream); } // write cache if (cache_quant_type == "none") { CascadeAppendWriteCacheKVQKV( - meta_data, - qkv_out, - block_tables, - batch_id_per_token, - cu_seqlens_q, - seq_lens_encoder, - seq_lens_decoder, - max_seq_len, - stream, - const_cast(&key_cache), - const_cast(&value_cache)); - } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8") { + meta_data, + qkv_out, + block_tables, + batch_id_per_token, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + max_seq_len, + stream, + const_cast(&key_cache), + const_cast(&value_cache)); + } else if (cache_quant_type == "cache_int8" || + cache_quant_type == "cache_fp8" || + cache_quant_type == "block_wise_fp8") { CascadeAppendWriteCacheKVC8QKV( meta_data, - *const_cast(&key_cache), - *const_cast(&value_cache), + *const_cast(&key_cache), + *const_cast(&value_cache), qkv_out, cache_k_quant_scales.get(), cache_v_quant_scales.get(), @@ -1017,16 +1181,16 @@ std::vector GQARopeWriteCacheKernel( kv_tile_ids, kv_num_blocks_data, max_seq_len, - false, // is_scale_channel_wise + false, // is_scale_channel_wise cache_quant_type, stream, - const_cast(&key_cache), - const_cast(&value_cache)); + const_cast(&key_cache), + const_cast(&value_cache)); } else if (cache_quant_type == "cache_int4_zp") { CascadeAppendWriteCacheKVC4QKV( meta_data, - *const_cast(&key_cache), - *const_cast(&value_cache), + *const_cast(&key_cache), + *const_cast(&value_cache), qkv_out, cache_k_quant_scales.get(), cache_v_quant_scales.get(), @@ -1042,31 +1206,37 @@ std::vector GQARopeWriteCacheKernel( kv_num_blocks_data, max_seq_len, stream, - const_cast(&key_cache), - const_cast(&value_cache)); + const_cast(&key_cache), + const_cast(&value_cache)); } else { PD_THROW( "cache_quant_type_str should be one of [none, cache_int8, cache_fp8, " "cache_int4_zp]"); } - const char* fmt_write_cache_completed_signal_str = std::getenv("FLAGS_fmt_write_cache_completed_signal"); - const char* FLAGS_use_pd_disaggregation_per_chunk = std::getenv("FLAGS_use_pd_disaggregation_per_chunk"); + const char *fmt_write_cache_completed_signal_str = + std::getenv("FLAGS_fmt_write_cache_completed_signal"); + const char *FLAGS_use_pd_disaggregation_per_chunk = + std::getenv("FLAGS_use_pd_disaggregation_per_chunk"); if (fmt_write_cache_completed_signal_str && (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 || std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) { - if (FLAGS_use_pd_disaggregation_per_chunk && - (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 || - std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) { - cudaLaunchHostFunc(qkv.stream(), - 
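// Illustrative aside (not part of the patch): the two FLAGS_* environment
// switches read just above follow the same convention, namely that a flag
// counts as enabled only if it is set and equals "true" or "1". A small helper
// capturing that check; the name env_flag_enabled is illustrative, and the
// patch itself keeps the inline std::getenv / std::strcmp form.
#include <cstdlib>
#include <cstring>

inline bool env_flag_enabled(const char *name) {
  const char *value = std::getenv(name);
  return value != nullptr &&
         (std::strcmp(value, "true") == 0 || std::strcmp(value, "1") == 0);
}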
&(RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_query), - (void*)nullptr); - } else { - if (kv_signal_data) { - cudaLaunchHostFunc(qkv.stream(), - &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise, - (void*)(const_cast(kv_signal_data.get().data()))); - } + if (FLAGS_use_pd_disaggregation_per_chunk && + (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 || + std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) { + cudaLaunchHostFunc( + qkv.stream(), + &(RemoteCacheKvIpc:: + save_cache_kv_complete_signal_layerwise_per_query), + (void *)nullptr); + } else { + if (kv_signal_data) { + cudaLaunchHostFunc( + qkv.stream(), + &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise, + (void *)(const_cast( + kv_signal_data.get().data()))); } + } } return {q, k, v, qkv_out}; } @@ -1096,12 +1266,7 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache) paddle::Optional("cache_k_zp"), paddle::Optional("cache_v_zp"), paddle::Optional("kv_signal_data")}) - .Outputs({"q", - "k", - "v", - "qkv_out", - "key_cache_out", - "value_cache_out"}) + .Outputs({"q", "k", "v", "qkv_out", "key_cache_out", "value_cache_out"}) .SetInplaceMap({{"key_cache", "key_cache_out"}, {"value_cache", "value_cache_out"}}) .Attrs({"kv_token_num: int", diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index 90fd7079c..54e0715d7 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -157,7 +157,9 @@ __global__ void multi_query_append_attention_kernel( const uint32_t q_end = min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE)); - +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif load_q_global_smem( q_base_ptr, &qo_smem, @@ -410,6 +412,9 @@ __global__ void multi_query_append_attention_kernel( } } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template = 900)) + cudaGridDependencySynchronize(); +#endif + load_q_global_smem_multi_warps= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template >>( + launchWithPdlWhenEnabled( + nosplit_kv_kernel, + grids, + blocks, + smem_size, + stream, reinterpret_cast(const_cast(qkv.data())), reinterpret_cast(const_cast(cache_k.data())), reinterpret_cast(const_cast(cache_v.data())), @@ -996,7 +1012,12 @@ void MultiQueryAppendAttention( num_chunks * num_heads)); } - split_kv_kernel<<>>( + launchWithPdlWhenEnabled( + split_kv_kernel, + grids, + blocks, + smem_size, + stream, reinterpret_cast(const_cast(qkv.data())), reinterpret_cast(const_cast(cache_k.data())), reinterpret_cast(const_cast(cache_v.data())), @@ -1037,79 +1058,89 @@ void MultiQueryAppendAttention( constexpr int blocky = (128 + blockx - 1) / blockx; dim3 grids_merge(bsz, num_heads); dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - sinks ? 
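// Illustrative aside (not part of the patch): the hunks in this file wrap the
// new device-side calls in "#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))"
// so the kernels still compile for pre-Hopper targets. A minimal,
// self-contained sketch of that pattern; the kernel below and its arguments
// are purely illustrative. With a PDL-style launch the consumer grid may start
// before the producer grid has drained, so it must call
// cudaGridDependencySynchronize() before reading anything the producer wrote,
// and it signals early completion with cudaTriggerProgrammaticLaunchCompletion()
// once its last global-memory write has been issued.
#include <cuda_runtime.h>

__global__ void pdl_consumer_sketch(const float *__restrict__ producer_out,
                                    float *__restrict__ out,
                                    int n) {
  // Independent prologue work (index math, smem setup) can overlap with the
  // tail of the producer kernel.
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();  // producer's global writes are now visible
#endif
  if (idx < n) {
    out[idx] = 2.0f * producer_out[idx];  // placeholder dependent work
  }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaTriggerProgrammaticLaunchCompletion();  // let the next dependent kernel begin
#endif
}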
reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); + auto *kernelFn = merge_multi_chunks_decoder_kernel; + launchWithPdlWhenEnabled( + kernelFn, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); } else { constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx; dim3 grids_merge(min(sm_count * 4, token_num), num_heads); // 128k is too large dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); + auto *kernelFn = merge_multi_chunks_v2_kernel; + launchWithPdlWhenEnabled( + kernelFn, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? 
reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); } } } else { @@ -1177,8 +1208,12 @@ void MultiQueryAppendAttention( cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); } - - nosplit_kv_kernel<<>>( + launchWithPdlWhenEnabled( + nosplit_kv_kernel, + grids, + blocks, + smem_size, + stream, reinterpret_cast(const_cast(qkv.data())), reinterpret_cast(const_cast(cache_k.data())), reinterpret_cast(const_cast(cache_v.data())), @@ -1254,7 +1289,12 @@ void MultiQueryAppendAttention( num_chunks * num_heads)); } } - split_kv_kernel<<>>( + launchWithPdlWhenEnabled( + split_kv_kernel, + grids, + blocks, + smem_size, + stream, reinterpret_cast(const_cast(qkv.data())), reinterpret_cast(const_cast(cache_k.data())), reinterpret_cast(const_cast(cache_v.data())), @@ -1299,78 +1339,88 @@ void MultiQueryAppendAttention( constexpr int blocky = (128 + blockx - 1) / blockx; dim3 grids_merge(bsz, num_heads); dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); + auto *kernelFn = merge_multi_chunks_decoder_kernel; + launchWithPdlWhenEnabled( + kernelFn, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); } else { constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx; dim3 grids_merge(min(sm_count * 4, token_num), num_heads); dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - sinks ? 
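// Illustrative aside (not part of the patch): the split-KV branches above run
// split_kv_kernel to produce, per KV chunk, a partial output together with the
// running softmax maximum (tmp_m) and denominator (tmp_d);
// merge_multi_chunks_decoder_kernel / merge_multi_chunks_v2_kernel then fold
// the chunks together. A scalar sketch of the usual log-sum-exp merge rule for
// two partials, assuming each stored partial output is already normalized by
// its own denominator; how exactly these kernels store the partials is not
// shown in this section.
#include <algorithm>
#include <cmath>

struct ChunkPartial {
  float o;  // partial (normalized) output for one element
  float m;  // running max of the attention scores in the chunk
  float d;  // softmax denominator accumulated over the chunk
};

inline ChunkPartial merge_two_chunks(ChunkPartial a, ChunkPartial b) {
  const float m = std::max(a.m, b.m);
  const float wa = std::exp(a.m - m) * a.d;  // rescaled denominator of chunk a
  const float wb = std::exp(b.m - m) * b.d;  // rescaled denominator of chunk b
  const float d = wa + wb;
  const float o = (wa * a.o + wb * b.o) / d;
  return {o, m, d};
}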
reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); + auto *kernelFn = merge_multi_chunks_v2_kernel; + launchWithPdlWhenEnabled( + kernelFn, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); } } } diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh index c2fb7299b..dab3dcee4 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh @@ -31,7 +31,7 @@ template __global__ void multi_query_append_attention_c4_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, // head_dim] CacheT *__restrict__ cache_v, @@ -41,7 +41,7 @@ __global__ void multi_query_append_attention_c4_kernel( const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim] const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM] const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] - const T *__restrict__ sinks, // [q_num_heads] + const T *__restrict__ sinks, // [q_num_heads] const int *__restrict__ seq_lens, const int *__restrict__ seq_lens_kv, const int *__restrict__ batch_ids, @@ -87,8 +87,8 @@ __global__ void multi_query_append_attention_c4_kernel( block_table_now = block_table + batch_id * max_block_num_per_seq; - //When cudagraph capture prefill, may launch more gridDim.x - if(btid >= static_cast(num_blocks_x_cpu)){ + // When cudagraph capture prefill, may launch more gridDim.x + if (btid >= static_cast(num_blocks_x_cpu)) { return; } @@ -125,6 +125,9 @@ __global__ void multi_query_append_attention_c4_kernel( float o_frag[num_frags_x][num_frags_y][8]; float m_frag[num_frags_x][2]; float d_frag[num_frags_x][2]; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif const T *cache_k_scale_now = cache_k_scale + kv_head_idx * HEAD_DIM; const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; @@ -180,7 +183,8 @@ __global__ void multi_query_append_attention_c4_kernel( } else { o_base_ptr_int8 = out + o_offset; } - const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const int *mask_offset_this_seq = + mask_offset ? 
mask_offset + q_start_seq_id * 2 : nullptr; smem_t qo_smem(smem); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( @@ -241,7 +245,6 @@ __global__ void multi_query_append_attention_c4_kernel( v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT)); - const uint32_t num_iterations = div_up( CAUSAL ? (min(chunk_len, @@ -252,12 +255,13 @@ __global__ void multi_query_append_attention_c4_kernel( : chunk_len, num_frags_z * 16); const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, + (CAUSAL ? (min(chunk_len, sub_if_greater_or_zero( kv_len - q_len + tile_id * num_rows_per_block / GROUP_SIZE, chunk_start))) - : mask_offset ? 0 : chunk_len) / + : mask_offset ? 0 + : chunk_len) / (num_frags_z * 16); uint32_t k_smem_offset_r = @@ -270,9 +274,7 @@ __global__ void multi_query_append_attention_c4_kernel( uint32_t k_smem_offset_w = smem_t::get_permuted_offset( - wid * 8 + tid / 4, - tid % - 4); + wid * 8 + tid / 4, tid % 4); uint32_t v_smem_offset_w = smem_t::get_permuted_offset( wid * 16 + tid / 2, tid % 2); // 2 * 128 / 8 = 32B, 64 nums @@ -417,15 +419,19 @@ __global__ void multi_query_append_attention_c4_kernel( if constexpr (!partition_kv) { if (sinks) { float current_sinks[num_frags_x][2]; - #pragma unroll +#pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - #pragma unroll +#pragma unroll for (uint32_t j = 0; j < 2; ++j) { - const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE; - current_sinks[fx][j] = static_cast(sinks[q_head_idx + h_offset]); + const uint32_t h_offset = + (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % + GROUP_SIZE; + current_sinks[fx][j] = + static_cast(sinks[q_head_idx + h_offset]); } } - normalize_d(o_frag, d_frag, m_frag, current_sinks); + normalize_d( + o_frag, d_frag, m_frag, current_sinks); } else { normalize_d(o_frag, d_frag); } @@ -497,6 +503,9 @@ __global__ void multi_query_append_attention_c4_kernel( } } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template __global__ void multi_query_append_attention_c4_warp1_4_kernel( - T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] + T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size, // head_dim] CacheT *__restrict__ cache_v, @@ -533,7 +542,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] const int *__restrict__ mask_offset, - const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask + const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask const int max_seq_len, const int max_dec_len, const int max_block_num_per_seq, @@ -573,8 +582,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( const uint32_t num_rows_per_block = num_frags_x * 16; const int *block_table_now = block_table + batch_id * max_block_num_per_seq; - //When cudagraph capture prefill, may launch more gridDim.x - if(btid >= static_cast(num_blocks_x_cpu)){ + // When cudagraph capture prefill, may launch more gridDim.x + if (btid >= static_cast(num_blocks_x_cpu)) { return; } @@ -612,6 +621,9 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( float m_frag[num_frags_x][2]; float d_frag[num_frags_x][2]; init_states(o_frag, m_frag, d_frag); +#if 
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif const T *cache_k_scale_now = cache_k_scale + kv_head_idx * HEAD_DIM; const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM; @@ -664,11 +676,12 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( tid % 8 * num_elems_per_128b(); } } - const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const int *mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; smem_t qo_smem(smem); - uint32_t q_smem_offset_r = smem_t::get_permuted_offset( - tid % 16, tid / 16); + uint32_t q_smem_offset_r = + smem_t::get_permuted_offset(tid % 16, tid / 16); load_q_global_smem_multi_warps( - wid * 8 + tid / 4, - tid % - 4); + wid * 8 + tid / 4, tid % 4); uint32_t v_smem_offset_w = smem_t::get_permuted_offset( wid * 16 + tid / 2, tid % 2); @@ -824,16 +834,18 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( NUM_WARPS, num_frags_x, num_frags_y, - num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr, - q_base_seq_id_this_block, - kv_idx_base + wid * num_frags_z * 16, - q_len, - kv_len, - chunk_end, - attn_mask_len, - s_frag, - mask_offset_this_seq, - sliding_window); + num_frags_z>( + attn_mask ? attn_mask + batch_id * attn_mask_len * attn_mask_len + : nullptr, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + attn_mask_len, + s_frag, + mask_offset_this_seq, + sliding_window); } update_mdo_states( @@ -903,15 +915,19 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( if (num_chunks_this_seq <= 1) { if (sinks) { float current_sinks[num_frags_x][2]; - #pragma unroll +#pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - #pragma unroll +#pragma unroll for (uint32_t j = 0; j < 2; ++j) { - const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE; - current_sinks[fx][j] = static_cast(sinks[q_head_idx + h_offset]); + const uint32_t h_offset = + (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % + GROUP_SIZE; + current_sinks[fx][j] = + static_cast(sinks[q_head_idx + h_offset]); } } - normalize_d(o_frag, d_frag, m_frag, current_sinks); + normalize_d( + o_frag, d_frag, m_frag, current_sinks); } else { normalize_d(o_frag, d_frag); } @@ -987,6 +1003,9 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( } } } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif } template >>( + launchWithPdlWhenEnabled( + nosplit_kv_kernel, + grids, + blocks, + smem_size, + stream, reinterpret_cast(const_cast(qkv.data())), const_cast(cache_k.data()), const_cast(cache_v.data()), @@ -1138,8 +1162,8 @@ void MultiQueryAppendC4Attention( const_cast(smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -1188,7 +1212,12 @@ void MultiQueryAppendC4Attention( static_cast(speculate_max_draft_token_num * bsz * num_chunks * num_heads)); } - split_kv_kernel<<>>( + launchWithPdlWhenEnabled( + split_kv_kernel, + grids, + blocks, + smem_size, + stream, reinterpret_cast(const_cast(qkv.data())), const_cast(cache_k.data()), const_cast(cache_v.data()), @@ -1207,8 +1236,8 @@ void MultiQueryAppendC4Attention( const_cast(smooth_weight.get().data())) : nullptr, sinks ? 
reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -1238,79 +1267,86 @@ void MultiQueryAppendC4Attention( constexpr int blocky = (128 + blockx - 1) / blockx; dim3 grids_merge(bsz, num_heads); dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); + launchWithPdlWhenEnabled( + merge_multi_chunks_decoder_kernel, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); } else { constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); + launchWithPdlWhenEnabled( + merge_multi_chunks_v2_kernel, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? 
reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); } } } else { @@ -1353,7 +1389,6 @@ void MultiQueryAppendC4Attention( const float ratio = static_cast(num_blocks_need) / static_cast(num_blocks_per_wave); - uint32_t chunk_size = static_cast(max_partition_size); if (!is_decoder) { chunk_size = static_cast(encoder_max_partition_size); @@ -1362,9 +1397,9 @@ void MultiQueryAppendC4Attention( const int num_chunks = div_up(max_seq_len, chunk_size); uint32_t attn_mask_len; if (attn_mask) { - attn_mask_len = attn_mask.get().shape()[1]; + attn_mask_len = attn_mask.get().shape()[1]; } else { - attn_mask_len = -1; + attn_mask_len = -1; } dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); @@ -1391,7 +1426,12 @@ void MultiQueryAppendC4Attention( cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); } - nosplit_kv_kernel<<>>( + launchWithPdlWhenEnabled( + nosplit_kv_kernel, + grids, + blocks, + smem_size, + stream, reinterpret_cast(const_cast(qkv.data())), const_cast(cache_k.data()), const_cast(cache_v.data()), @@ -1410,8 +1450,8 @@ void MultiQueryAppendC4Attention( const_cast(smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -1420,7 +1460,7 @@ void MultiQueryAppendC4Attention( block_table.data(), meta_data.mask_offset, attn_mask ? const_cast(attn_mask.get().data()) - : nullptr, + : nullptr, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1476,27 +1516,32 @@ void MultiQueryAppendC4Attention( num_chunks * num_heads)); } } - split_kv_kernel<<>>( + launchWithPdlWhenEnabled( + split_kv_kernel, + grids, + blocks, + smem_size, + stream, reinterpret_cast(const_cast(qkv.data())), const_cast(cache_k.data()), const_cast(cache_v.data()), reinterpret_cast(const_cast(cache_k_scale.data())), cache_k_zp ? reinterpret_cast( - const_cast(cache_k_zp.get().data())) - : nullptr, + const_cast(cache_k_zp.get().data())) + : nullptr, reinterpret_cast(const_cast(cache_v_scale.data())), cache_v_zp ? reinterpret_cast( - const_cast(cache_v_zp.get().data())) - : nullptr, + const_cast(cache_v_zp.get().data())) + : nullptr, shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, + const_cast(shift_bias.get().data())) + : nullptr, smooth_weight ? reinterpret_cast( const_cast(smooth_weight.get().data())) : nullptr, sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, + const_cast(sinks.get().data())) + : nullptr, seq_lens_q.data(), seq_lens_kv.data(), batch_ids.data(), @@ -1505,7 +1550,7 @@ void MultiQueryAppendC4Attention( block_table.data(), meta_data.mask_offset, attn_mask ? const_cast(attn_mask.get().data()) - : nullptr, + : nullptr, max_seq_len, max_dec_len, max_block_num_per_seq, @@ -1529,79 +1574,86 @@ void MultiQueryAppendC4Attention( constexpr int blocky = (128 + blockx - 1) / blockx; dim3 grids_merge(bsz, num_heads); dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? 
reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); + launchWithPdlWhenEnabled( + merge_multi_chunks_decoder_kernel, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); } else { constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx; - dim3 grids_merge(min(sm_count * 4, token_num), - num_heads); + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); + launchWithPdlWhenEnabled( + merge_multi_chunks_v2_kernel, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? 
reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); } } } diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh index 01adb2130..39b371c78 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh @@ -1271,8 +1271,12 @@ void MultiQueryAppendC8Attention( cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); } - - nosplit_kv_kernel<<>>( + launchWithPdlWhenEnabled( + nosplit_kv_kernel, + grids, + blocks, + smem_size, + stream, reinterpret_cast(const_cast(qkv.data())), const_cast(cache_k.data()), const_cast(cache_v.data()), @@ -1335,7 +1339,12 @@ void MultiQueryAppendC8Attention( static_cast(speculate_max_draft_token_num * bsz * num_chunks * num_heads)); } - split_kv_kernel<<>>( + launchWithPdlWhenEnabled( + split_kv_kernel, + grids, + blocks, + smem_size, + stream, reinterpret_cast(const_cast(qkv.data())), const_cast(cache_k.data()), const_cast(cache_v.data()), @@ -1379,78 +1388,86 @@ void MultiQueryAppendC8Attention( constexpr int blocky = (128 + blockx - 1) / blockx; dim3 grids_merge(bsz, num_heads); dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); + launchWithPdlWhenEnabled( + merge_multi_chunks_decoder_kernel, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); } else { constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx; dim3 grids_merge(min(sm_count * 4, token_num), num_heads); dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - sinks ? 
reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); + launchWithPdlWhenEnabled( + merge_multi_chunks_v2_kernel, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); } } } else { @@ -1568,8 +1585,12 @@ void MultiQueryAppendC8Attention( cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); } - - nosplit_kv_kernel<<>>( + launchWithPdlWhenEnabled( + nosplit_kv_kernel, + grids, + blocks, + smem_size, + stream, reinterpret_cast(const_cast(qkv.data())), const_cast(cache_k.data()), const_cast(cache_v.data()), @@ -1648,7 +1669,12 @@ void MultiQueryAppendC8Attention( num_chunks * num_heads)); } } - split_kv_kernel<<>>( + launchWithPdlWhenEnabled( + split_kv_kernel, + grids, + blocks, + smem_size, + stream, reinterpret_cast(const_cast(qkv.data())), const_cast(cache_k.data()), const_cast(cache_v.data()), @@ -1695,73 +1721,87 @@ void MultiQueryAppendC8Attention( constexpr int blocky = (128 + blockx - 1) / blockx; dim3 grids_merge(bsz, num_heads); dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_decoder_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM); + auto *kernelFn = merge_multi_chunks_decoder_kernel; + launchWithPdlWhenEnabled( + kernelFn, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? 
reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM); } else { constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx; dim3 grids_merge(min(sm_count * 4, token_num), num_heads); dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_v2_kernel - <<>>( - reinterpret_cast(tmp_workspace->ptr()), - static_cast(tmp_m->ptr()), - static_cast(tmp_d->ptr()), - seq_lens_q.data(), - seq_lens_kv.data(), - seq_lens_encoder.data(), - batch_id_per_token.data(), - cu_seqlens_q.data(), - shift_bias ? reinterpret_cast( - const_cast(shift_bias.get().data())) - : nullptr, - smooth_weight ? reinterpret_cast(const_cast( - smooth_weight.get().data())) - : nullptr, - sinks ? reinterpret_cast( - const_cast(sinks.get().data())) - : nullptr, - reinterpret_cast(out->data()), - quant_max_bound, - quant_min_bound, - in_scale, - max_seq_len, - num_chunks, - num_heads, - chunk_size, - HEAD_DIM, - token_num, - speculate_max_draft_token_num); + launchWithPdlWhenEnabled( + merge_multi_chunks_v2_kernel, + grids_merge, + blocks_merge, + 0, + stream, + reinterpret_cast(tmp_workspace->ptr()), + static_cast(tmp_m->ptr()), + static_cast(tmp_d->ptr()), + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + shift_bias ? reinterpret_cast( + const_cast(shift_bias.get().data())) + : nullptr, + smooth_weight ? reinterpret_cast( + const_cast(smooth_weight.get().data())) + : nullptr, + sinks ? reinterpret_cast( + const_cast(sinks.get().data())) + : nullptr, + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + in_scale, + max_seq_len, + num_chunks, + num_heads, + chunk_size, + HEAD_DIM, + token_num, + speculate_max_draft_token_num); } } } diff --git a/custom_ops/gpu_ops/helper.cu b/custom_ops/gpu_ops/helper.cu index bb9dd88da..49c78914f 100644 --- a/custom_ops/gpu_ops/helper.cu +++ b/custom_ops/gpu_ops/helper.cu @@ -16,10 +16,10 @@ #include float bfloat16_to_float(__nv_bfloat16 x) { - uint32_t tmp_x = *(reinterpret_cast(&x)); - tmp_x = tmp_x << 16; - float float_x = *(reinterpret_cast(&tmp_x)); - return float_x; + uint32_t tmp_x = *(reinterpret_cast(&x)); + tmp_x = tmp_x << 16; + float float_x = *(reinterpret_cast(&tmp_x)); + return float_x; } template @@ -27,120 +27,132 @@ static void PrintMatrix(const T* mat_d, int num, std::string name, int numOfCols) { - std::vector tmp(num); - cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost); + std::vector tmp(num); + cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost); - std::ofstream outfile; - outfile.open(name + ".dtxt", std::ios::out | std::ios::app); - std::stringstream ss; + std::ofstream outfile; + outfile.open(name + ".dtxt", std::ios::out | std::ios::app); + std::stringstream ss; - for (int i = 0; i < num; ++i) { - if (std::is_same::value || std::is_same::value || - std::is_same::value) { - ss << static_cast(tmp[i]) << " "; - } else { - ss << std::setprecision(8) << static_cast(tmp[i]) << " "; - } - if (i % numOfCols == numOfCols - 1) { - ss << std::endl; - } + for (int i = 0; i < num; ++i) { + if (std::is_same::value || std::is_same::value || + std::is_same::value) { + ss << static_cast(tmp[i]) << " "; + } else { + ss << std::setprecision(8) << static_cast(tmp[i]) << " "; } - outfile << ss.str(); - outfile.close(); + if (i % numOfCols == numOfCols - 1) { + ss << 
std::endl; + } + } + outfile << ss.str(); + outfile.close(); } GPUMemoryChecker::GPUMemoryChecker() { - nvmlReturn_t result = nvmlInit_v2(); - if (NVML_SUCCESS != result) { - throw std::runtime_error("Failed to initialize NVML: " + - std::string(nvmlErrorString(result))); - } + nvmlReturn_t result = nvmlInit_v2(); + if (NVML_SUCCESS != result) { + throw std::runtime_error("Failed to initialize NVML: " + + std::string(nvmlErrorString(result))); + } - result = nvmlDeviceGetCount_v2(&deviceCount_); - if (NVML_SUCCESS != result) { - nvmlShutdown(); - throw std::runtime_error("Failed to get GPU count: " + - std::string(nvmlErrorString(result))); - } - - getCUDAVisibleDevice(); -} - -GPUMemoryChecker::~GPUMemoryChecker() { + result = nvmlDeviceGetCount_v2(&deviceCount_); + if (NVML_SUCCESS != result) { nvmlShutdown(); + throw std::runtime_error("Failed to get GPU count: " + + std::string(nvmlErrorString(result))); + } + + getCUDAVisibleDevice(); } -void GPUMemoryChecker::getCUDAVisibleDevice(){ - std::vector devices; - const char* env_p = std::getenv("CUDA_VISIBLE_DEVICES"); - if(!env_p){ - for(int i = 0; i < deviceCount_; i++){ - visible_device_.push_back(i); - return ; - } - } +GPUMemoryChecker::~GPUMemoryChecker() { nvmlShutdown(); } - std::string env_str(env_p); - std::istringstream stream(env_str); - std::string device_id; - - while(std::getline(stream, device_id, ',')){ - visible_device_.push_back(std::stoi(device_id)); - visible_device_mem_usage_.push_back(-1); +void GPUMemoryChecker::getCUDAVisibleDevice() { + std::vector devices; + const char* env_p = std::getenv("CUDA_VISIBLE_DEVICES"); + if (!env_p) { + for (int i = 0; i < deviceCount_; i++) { + visible_device_.push_back(i); + return; } - std::cout << "\nVisible NVIDIA GPU devices" << env_str << std::endl; - return ; + } + + std::string env_str(env_p); + std::istringstream stream(env_str); + std::string device_id; + + while (std::getline(stream, device_id, ',')) { + visible_device_.push_back(std::stoi(device_id)); + visible_device_mem_usage_.push_back(-1); + } + std::cout << "\nVisible NVIDIA GPU devices" << env_str << std::endl; + return; } void GPUMemoryChecker::addCheckPoint(const char* call_file, int call_line) { - try { + try { + for (int i = 0; i < visible_device_.size(); i++) { + unsigned int device_id = visible_device_.at(i); + nvmlDevice_t device; + nvmlReturn_t result = nvmlDeviceGetHandleByIndex_v2(device_id, &device); + if (NVML_SUCCESS != result) { + std::cerr << "Failed to get handle for GPU " << device_id << ": " + << nvmlErrorString(result) << std::endl; + continue; + } + char name[NVML_DEVICE_NAME_BUFFER_SIZE]; + result = nvmlDeviceGetName(device, name, NVML_DEVICE_NAME_BUFFER_SIZE); + if (NVML_SUCCESS != result) { + std::cerr << "Failed to get name for GPU " << device_id << ": " + << nvmlErrorString(result) << std::endl; + continue; + } - for (int i = 0; i < visible_device_.size(); i++) { - unsigned int device_id = visible_device_.at(i); - nvmlDevice_t device; - nvmlReturn_t result = nvmlDeviceGetHandleByIndex_v2(device_id, &device); - if (NVML_SUCCESS != result) { - std::cerr << "Failed to get handle for GPU " << device_id << ": " - << nvmlErrorString(result) << std::endl; - continue; - } + nvmlMemory_t memoryInfo; + result = nvmlDeviceGetMemoryInfo(device, &memoryInfo); + if (NVML_SUCCESS != result) { + std::cerr << "Failed to get memory info for GPU " << device_id << ": " + << nvmlErrorString(result) << std::endl; + continue; + } - char name[NVML_DEVICE_NAME_BUFFER_SIZE]; - result = nvmlDeviceGetName(device, name, 
NVML_DEVICE_NAME_BUFFER_SIZE); - if (NVML_SUCCESS != result) { - std::cerr << "Failed to get name for GPU " << device_id << ": " - << nvmlErrorString(result) << std::endl; - continue; - } - - nvmlMemory_t memoryInfo; - result = nvmlDeviceGetMemoryInfo(device, &memoryInfo); - if (NVML_SUCCESS != result) { - std::cerr << "Failed to get memory info for GPU " << device_id << ": " - << nvmlErrorString(result) << std::endl; - continue; - } - - // Check GPU memory - const char* env_c = std::getenv("MEMCHECKER_CHECK_MEMORY"); - if (env_c){ - assert(memoryInfo.used <= visible_device_mem_usage_.at(i) && "GPU Memory does not allow growth!"); - } - visible_device_mem_usage_[i] = memoryInfo.used; - } - - // Check GPU memory - const char* env_p = std::getenv("MEMCHECKER_PRINT_MEMORY"); - if (env_p){ - std::cout << "\nCall Line: "<< call_line << "\t"; - for (int i = 0; i < visible_device_.size(); i++) { - unsigned int device_id = visible_device_.at(i); - std::cout << "GPU " << device_id << ": " - << " Used memory: " << visible_device_mem_usage_.at(device_id) / (1024 * 1024) << " MB\t"; - } - } - } catch (const std::exception& e) { - std::cerr << "Error: " << e.what() << std::endl; + // Check GPU memory + const char* env_c = std::getenv("MEMCHECKER_CHECK_MEMORY"); + if (env_c) { + assert(memoryInfo.used <= visible_device_mem_usage_.at(i) && + "GPU Memory does not allow growth!"); + } + visible_device_mem_usage_[i] = memoryInfo.used; } + + // Check GPU memory + const char* env_p = std::getenv("MEMCHECKER_PRINT_MEMORY"); + if (env_p) { + std::cout << "\nCall Line: " << call_line << "\t"; + for (int i = 0; i < visible_device_.size(); i++) { + unsigned int device_id = visible_device_.at(i); + std::cout << "GPU " << device_id << ": " + << " Used memory: " + << visible_device_mem_usage_.at(device_id) / (1024 * 1024) + << " MB\t"; + } + } + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + } +} + +bool getEnvEnablePDL() { + static std::once_flag flag; + static bool enablePDL = false; + + std::call_once(flag, [&]() { + int sm_version = GetSMVersion(); + if (sm_version >= 90) { + enablePDL = getBoolEnv("FD_ENABLE_PDL"); + } + }); + return enablePDL; } diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h index ed740a882..b06ec0211 100644 --- a/custom_ops/gpu_ops/helper.h +++ b/custom_ops/gpu_ops/helper.h @@ -20,6 +20,7 @@ #include "glog/logging.h" #endif #include +#include #include #include #include @@ -27,19 +28,17 @@ #include #include #include -#include #include #include -#include #include #ifdef PADDLE_WITH_HIP #include #include #include -#include #include #include +#include namespace cub = hipcub; #else #include @@ -58,8 +57,8 @@ namespace cub = hipcub; #else #include "paddle/phi/core/cuda_stream.h" #endif -#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/dense_tensor.h" #ifdef PADDLE_WITH_COREX #define WARP_SIZE 64 @@ -74,14 +73,16 @@ namespace cub = hipcub; using json = nlohmann::json; #endif -#define CUDA_CHECK(call) \ - do { \ - const cudaError_t error_code = call; \ - if (error_code != cudaSuccess) { \ - std::printf("at %s:%d - %s.\n", __FILE__, __LINE__, \ - cudaGetErrorString(error_code)); \ - exit(1); \ - } \ +#define CUDA_CHECK(call) \ + do { \ + const cudaError_t error_code = call; \ + if (error_code != cudaSuccess) { \ + std::printf("at %s:%d - %s.\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(error_code)); \ + exit(1); \ + } \ } while (0) #ifdef PADDLE_WITH_HIP @@ 
-110,9 +111,10 @@ inline hipError_t GetNumBlocks(int64_t n, int *num_blocks) { return err; } } - *num_blocks = std::max( - 1, std::min((n + kBlockSize - 1) / kBlockSize, - sm_count * tpm / kBlockSize * kNumWaves)); + *num_blocks = + std::max(1, + std::min((n + kBlockSize - 1) / kBlockSize, + sm_count * tpm / kBlockSize * kNumWaves)); return hipSuccess; } #else @@ -141,9 +143,10 @@ inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { return err; } } - *num_blocks = std::max( - 1, std::min((n + kBlockSize - 1) / kBlockSize, - sm_count * tpm / kBlockSize * kNumWaves)); + *num_blocks = + std::max(1, + std::min((n + kBlockSize - 1) / kBlockSize, + sm_count * tpm / kBlockSize * kNumWaves)); return cudaSuccess; } @@ -163,51 +166,54 @@ inline int GetGPUComputeCapability(int id) { #endif #ifndef DISPATCH_FLOAT_FP6_DTYPE -#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...) \ - switch (pd_dtype) { \ - case phi::DataType::FLOAT32: { \ - using c_type = float; \ - __VA_ARGS__ \ - break; \ - } \ - case phi::DataType::BFLOAT16: { \ - using c_type = phi::dtype::bfloat16; \ - __VA_ARGS__ \ - break; \ - } \ - case phi::DataType::FLOAT16: { \ - using c_type = phi::dtype::float16; \ - __VA_ARGS__ \ - break; \ - } \ - default: { \ - PD_THROW("Only supported attr of input type in [fp32, fp16, bf16]."); \ - } \ - } +#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...) \ + switch (pd_dtype) { \ + case phi::DataType::FLOAT32: { \ + using c_type = float; \ + __VA_ARGS__ \ + break; \ + } \ + case phi::DataType::BFLOAT16: { \ + using c_type = phi::dtype::bfloat16; \ + __VA_ARGS__ \ + break; \ + } \ + case phi::DataType::FLOAT16: { \ + using c_type = phi::dtype::float16; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + PD_THROW("Only supported attr of input type in [fp32, fp16, bf16]."); \ + } \ + } #endif inline constexpr uint32_t next_pow_2(uint32_t const num) { - if (num <= 1) - return num; + if (num <= 1) return num; return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); } -template class PDTraits; +template +class PDTraits; -template <> class PDTraits { -public: +template <> +class PDTraits { + public: typedef float DataType; typedef float data_t; }; -template <> class PDTraits { -public: +template <> +class PDTraits { + public: typedef half DataType; typedef paddle::float16 data_t; }; -template <> class PDTraits { -public: +template <> +class PDTraits { + public: #ifdef PADDLE_WITH_HIP typedef hip_bfloat16 DataType; #else @@ -216,27 +222,31 @@ public: typedef paddle::bfloat16 data_t; }; -template <> class PDTraits { -public: +template <> +class PDTraits { + public: typedef int8_t DataType; typedef int8_t data_t; }; -template <> class PDTraits { -public: +template <> +class PDTraits { + public: typedef uint8_t DataType; typedef uint8_t data_t; }; #ifndef PADDLE_WITH_COREX -template <> class PDTraits { -public: +template <> +class PDTraits { + public: typedef __nv_fp8_e4m3 DataType; typedef paddle::float8_e4m3fn data_t; }; #endif -template struct alignas(sizeof(T) * Size) AlignedVector { +template +struct alignas(sizeof(T) * Size) AlignedVector { T val[Size]; HOSTDEVICE inline const T &operator[](int i) const { return val[i]; } @@ -261,7 +271,7 @@ HOSTDEVICE inline void Store(const AlignedVector &vec, T *addr) { template HOSTDEVICE inline void Store(const AlignedVector &vec, int8_t *addr) { - printf("Error: Store hip_bfloat16 to int8_t is not supported!"); + printf("Error: Store hip_bfloat16 to int8_t is not supported!"); } #else template @@ -279,11 +289,13 @@ HOSTDEVICE inline void 
Store(const AlignedVector &vec, constexpr int VEC_16B = 16; -template __device__ T max_func(const T a, const T b) { +template +__device__ T max_func(const T a, const T b) { return a > b ? a : b; } -template struct MaxOp { +template +struct MaxOp { __device__ __forceinline__ T operator()(const T &a, const T &b) const { return max_func(a, b); } @@ -316,14 +328,14 @@ inline json readJsonFromFile(const std::string &filePath) { } #endif -#define cudaCheckError() \ - { \ - cudaError_t e = cudaGetLastError(); \ - if (e != cudaSuccess) { \ - std::cerr << "CUDA Error " << __FILE__ << ":" << __LINE__ << ": " \ - << cudaGetErrorString(e) << std::endl; \ - exit(EXIT_FAILURE); \ - } \ +#define cudaCheckError() \ + { \ + cudaError_t e = cudaGetLastError(); \ + if (e != cudaSuccess) { \ + std::cerr << "CUDA Error " << __FILE__ << ":" << __LINE__ << ": " \ + << cudaGetErrorString(e) << std::endl; \ + exit(EXIT_FAILURE); \ + } \ } // place must be an existing place object and cannot use paddle::CPUPlace() or @@ -336,8 +348,8 @@ inline paddle::Tensor GetEmptyTensor(const common::DDim &dims, auto *allocator = paddle::GetAllocator(place); phi::DenseTensor dense_tensor; dense_tensor.Resize(dims); - dense_tensor.AllocateFrom(allocator, dtype, - dense_tensor.numel() * phi::SizeOf(dtype)); + dense_tensor.AllocateFrom( + allocator, dtype, dense_tensor.numel() * phi::SizeOf(dtype)); return paddle::Tensor(std::make_shared(dense_tensor)); } @@ -348,39 +360,63 @@ inline paddle::Tensor GetEmptyTensor(const common::DDim &dims, auto *allocator = paddle::GetAllocator(place); phi::DenseTensor dense_tensor; dense_tensor.Resize(dims); - dense_tensor.AllocateFrom(allocator, dtype, - dense_tensor.numel() * phi::SizeOf(dtype)); + dense_tensor.AllocateFrom( + allocator, dtype, dense_tensor.numel() * phi::SizeOf(dtype)); dense_tensor.set_strides(strides); return paddle::Tensor(std::make_shared(dense_tensor)); } #endif -__global__ void free_and_dispatch_block( - bool *stop_flags, int *seq_lens_this_time, int *seq_lens_decoder, - int *block_tables, int *encoder_block_lens, bool *is_block_step, - int *step_block_list, // [bsz] - int *step_len, int *recover_block_list, int *recover_len, - int *need_block_list, int *need_block_len, int *used_list_len, - int *free_list, int *free_list_len, int64_t *first_token_ids, const int bsz, - const int block_size, const int block_num_per_seq, - const int max_decoder_block_num); +__global__ void free_and_dispatch_block(bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num); __global__ void speculate_free_and_dispatch_block( - bool *stop_flags, int *seq_lens_this_time, int *seq_lens_decoder, - int *block_tables, int *encoder_block_lens, bool *is_block_step, - int *step_block_list, // [bsz] - int *step_len, int *recover_block_list, int *recover_len, - int *need_block_list, int *need_block_len, int *used_list_len, - int *free_list, int *free_list_len, int64_t *first_token_ids, - int *accept_num, const int bsz, const int block_size, - const int block_num_per_seq, const int max_decoder_block_num, + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int 
*block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + int64_t *first_token_ids, + int *accept_num, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num, const int max_draft_tokens); __device__ bool speculate_free_and_dispatch_block(const int &qid, int *need_block_list, const int &need_block_len); -static std::string global_base64_chars = // NOLINT +static std::string global_base64_chars = // NOLINT "Tokp9lA/BjimRVKx32edMPFftOzsbNQ8C15Xn+YUEGc4WD0uLIq7hyJ6vZaHSwrg"; // Base64 编码函数 @@ -501,7 +537,8 @@ inline T get_relative_best(nlohmann::json *json_data, } #endif -__device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids, +__device__ inline bool is_in_end(const int64_t id, + const int64_t *end_ids, int length) { bool flag = false; for (int i = 0; i < length; i++) { @@ -512,22 +549,20 @@ __device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids, return flag; } -template inline __device__ __host__ T div_up(T m, T n) { +template +inline __device__ __host__ T div_up(T m, T n) { return (m + n - 1) / n; } template __device__ __inline__ T ClipFunc(const T v, const T min, const T max) { - if (v > max) - return max; - if (v < min) - return min; + if (v > max) return max; + if (v < min) return min; return v; } template static void PrintMatrix3(const T *mat_d, int num, std::string name) { - std::vector tmp(num); #ifdef PADDLE_WITH_HIP hipMemcpy(tmp.data(), mat_d, sizeof(T) * num, hipMemcpyDeviceToHost); @@ -535,7 +570,6 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) { cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost); #endif - std::ofstream outfile; outfile.open(name + ".txt", std::ios::out); std::stringstream ss; @@ -544,7 +578,7 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) { if (std::is_same::value || std::is_same::value) { ss << static_cast(tmp[i]) << std::endl; } else { - ss << std::setprecision(8) << (float)(tmp[i]) << std::endl; // NOLINT + ss << std::setprecision(8) << (float)(tmp[i]) << std::endl; // NOLINT } } outfile << ss.str(); @@ -573,7 +607,8 @@ __forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr, } __forceinline__ __device__ void st_flag_release(uint32_t *flag_addr, - uint32_t flag, int mode = 0) { + uint32_t flag, + int mode = 0) { if (mode == 0) { asm volatile("st.release.sys.global.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); @@ -589,7 +624,8 @@ __forceinline__ __device__ void st_flag_release(uint32_t *flag_addr, inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { int max_shared_mem_per_block_opt_in = 0; cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, - cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + cudaDevAttrMaxSharedMemoryPerBlockOptin, + device); return max_shared_mem_per_block_opt_in; } #endif @@ -627,29 +663,29 @@ inline bool checkAttentionBackend() { #ifndef GPU_MEMORY_CHECKER_H #define GPU_MEMORY_CHECKER_H class GPUMemoryChecker { -public: - static GPUMemoryChecker* getInstance() { - static GPUMemoryChecker instance; - return &instance; - } + public: + static GPUMemoryChecker *getInstance() { + static GPUMemoryChecker instance; + return &instance; + } - void addCheckPoint(const char* call_file, int call_line); - unsigned int getGPUCount() const { return 
deviceCount_; } - void getCUDAVisibleDevice(); + void addCheckPoint(const char *call_file, int call_line); + unsigned int getGPUCount() const { return deviceCount_; } + void getCUDAVisibleDevice(); - GPUMemoryChecker(const GPUMemoryChecker&) = delete; - void operator=(const GPUMemoryChecker&) = delete; + GPUMemoryChecker(const GPUMemoryChecker &) = delete; + void operator=(const GPUMemoryChecker &) = delete; -private: - GPUMemoryChecker(); - ~GPUMemoryChecker(); + private: + GPUMemoryChecker(); + ~GPUMemoryChecker(); - unsigned int deviceCount_; - std::vector visible_device_; - std::vector visible_device_mem_usage_; + unsigned int deviceCount_; + std::vector visible_device_; + std::vector visible_device_mem_usage_; }; -#endif // GPU_MEMORY_CHECKER_H +#endif // GPU_MEMORY_CHECKER_H __device__ __forceinline__ float warpReduceMax(float value) { value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16)); value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8)); @@ -674,3 +710,31 @@ __device__ __forceinline__ float blockReduceMax(float value) { return value; } +inline bool getBoolEnv(char const *name) { + char const *env = std::getenv(name); + return env && env[0] == '1' && env[1] == '\0'; +} + +bool getEnvEnablePDL(); + +template +inline void launchWithPdlWhenEnabled(KernelFn kernelFn, + dim3 grid, + dim3 block, + size_t dynamicShmSize, + cudaStream_t stream, + Args &&...args) { + cudaLaunchConfig_t kernelConfig; + kernelConfig.gridDim = grid; + kernelConfig.blockDim = block; + kernelConfig.dynamicSmemBytes = dynamicShmSize; + kernelConfig.stream = stream; + + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL(); + kernelConfig.attrs = attrs; + kernelConfig.numAttrs = 1; + + cudaLaunchKernelEx(&kernelConfig, kernelFn, std::forward(args)...); +} diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index ee261fa0f..40e02445b 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -448,6 +448,7 @@ class LLMEngine: "NCCL_ALGO": "Ring", "FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)), "OMP_NUM_THREADS": int(os.getenv("OMP_NUM_THREADS", 3)), + "FD_ENABLE_PDL": envs.FD_ENABLE_PDL, } # environment variables needed by Dy2St variables.update( diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index c1d112bc1..9d6b597b1 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -159,6 +159,7 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")), "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")), "FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")), + "FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")), }
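
Note on the launch-site hunks in multiquery_attention_c8_impl.cuh (and the sibling c4/c16 files): every triple-chevron launch of the form kernel<<<grids, blocks, smem_size, stream>>>(args...) is rewritten as launchWithPdlWhenEnabled(kernel, grids, blocks, smem_size, stream, args...), keeping the launch parameters and the kernel argument list in their original order. A minimal sketch of the resulting call shape, assuming a hypothetical scale_kernel and wrapper launch_scale (neither is part of the patch) and that helper.h is on the include path:

#include <cuda_runtime.h>
#include "helper.h"  // assumed include for launchWithPdlWhenEnabled

// Hypothetical kernel used only to illustrate the call shape.
__global__ void scale_kernel(float *data, float alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= alpha;
}

void launch_scale(float *d_data, float alpha, int n, cudaStream_t stream) {
  dim3 blocks(256);
  dim3 grids((n + 255) / 256);
  size_t smem_size = 0;

  // Before the patch the call site would read:
  //   scale_kernel<<<grids, blocks, smem_size, stream>>>(d_data, alpha, n);
  // After the patch the same parameters go through the PDL-aware helper:
  launchWithPdlWhenEnabled(
      scale_kernel, grids, blocks, smem_size, stream, d_data, alpha, n);
}

Where the kernel is a template, one merge_multi_chunks_decoder_kernel call site above first binds a concrete instantiation to a pointer (auto *kernelFn = ...;) and hands that pointer to the helper; the launch itself is unchanged.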
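
The helper in helper.h packs the grid, block, dynamic shared memory size and stream into a cudaLaunchConfig_t, attaches a cudaLaunchAttributeProgrammaticStreamSerialization attribute whose value is taken from getEnvEnablePDL(), and launches through cudaLaunchKernelEx. For the resulting overlap to be correct, the kernels themselves must take part in the PDL handshake. The following is a minimal, self-contained sketch of that contract with two toy kernels; producer_kernel, consumer_kernel, run_pair and the buffers are illustrative only and not code from this patch:

#include <cuda_runtime.h>
#include "helper.h"  // assumed include for launchWithPdlWhenEnabled

// Toy producer/consumer pair illustrating the device-side PDL contract.
__global__ void producer_kernel(float *buf, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] = static_cast<float>(i);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Tell the runtime that grids depending on this one may start now.
  cudaTriggerProgrammaticLaunchCompletion();
#endif
}

__global__ void consumer_kernel(float *buf, float *out, int n) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Under PDL this grid may launch early; wait here before reading
  // anything the producer wrote.
  cudaGridDependencySynchronize();
#endif
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 2.0f * buf[i];
}

void run_pair(float *d_buf, float *d_out, int n, cudaStream_t stream) {
  dim3 block(256);
  dim3 grid((n + 255) / 256);
  launchWithPdlWhenEnabled(producer_kernel, grid, block, 0, stream, d_buf, n);
  launchWithPdlWhenEnabled(consumer_kernel, grid, block, 0, stream, d_buf, d_out, n);
}

When FD_ENABLE_PDL is off or the device is below SM90, getEnvEnablePDL() returns false, programmaticStreamSerializationAllowed is 0, and cudaLaunchKernelEx behaves like an ordinary stream-ordered launch; the device-side calls are compiled only for __CUDA_ARCH__ >= 900 and are harmless in that case, since the dependent grid simply finds its dependency already resolved.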
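
On the Python side, envs.py registers FD_ENABLE_PDL with a default of "1" and engine.py adds it to the environment variables forwarded to worker processes, so the switch reaches the process that actually launches the kernels. getBoolEnv() accepts only the literal string "1" as true, and getEnvEnablePDL() caches its decision once via std::call_once and additionally requires an SM90-or-newer GPU (GetSMVersion() >= 90). The net effect is that PDL is enabled by default on SM90+ devices and can be turned off by exporting FD_ENABLE_PDL=0.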