From d97aab25bc9ea13ace4c142ba44eecda98d03d3a Mon Sep 17 00:00:00 2001
From: RAM
Date: Thu, 21 Aug 2025 20:58:47 +0800
Subject: [PATCH] [Executor] Fixed the issue of CUDA graph execution failure
 caused by different branches during decoding (#3223) (#3512)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Thoroughly fix the decode chunking issue
* update C8 and C4 kernel
* fix problem
* fix with pre-commit
* retain branch for mtp

Co-authored-by: Jundong Liu <61149469+littledgg@users.noreply.github.com>
---
 .../append_attn/append_attention_c16_impl.cuh | 29 +++++++--------
 .../append_attn/append_attention_c4_impl.cuh  | 37 ++++++++++---------
 .../append_attn/append_attention_c8_impl.cuh  | 28 +++++++-------
 3 files changed, 47 insertions(+), 47 deletions(-)

diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
index b7d8441c6..823eeea3b 100644
--- a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
@@ -1061,12 +1061,11 @@ void MultiQueryAppendAttention(
     if (!is_decoder) {
       chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
     }
-    const int num_chunks = div_up(max_dec_len, chunk_size);
+    const int num_chunks = div_up(max_seq_len, chunk_size);
     dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
     dim3 blocks(32, num_warps);
-
-    if (num_chunks <= 1) {
+    if (num_chunks <= 0) {
       auto nosplit_kv_kernel =
           multi_query_append_attention_warp1_4_kernel<NV_TYPE,
                                                       GROUP_SIZE,
                                                       HEAD_DIM,
@@ ... @@ void MultiQueryAppendAttention(
          reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k.data<T>())),
          reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v.data<T>())),
          shift_bias ? reinterpret_cast<NV_TYPE*>(
-                          const_cast<T*>(shift_bias.get().data<T>()))
-                    : nullptr,
+                         const_cast<T*>(shift_bias.get().data<T>()))
+                   : nullptr,
          smooth_weight ? reinterpret_cast<NV_TYPE*>(
                              const_cast<T*>(smooth_weight.get().data<T>()))
                        : nullptr,
@@ -1208,8 +1207,8 @@ void MultiQueryAppendAttention(
          seq_lens_encoder.data<int>(),
          cu_seqlens_q.data<int>(),
          shift_bias ? reinterpret_cast<NV_TYPE*>(
-                          const_cast<T*>(shift_bias.get().data<T>()))
-                    : nullptr,
+                         const_cast<T*>(shift_bias.get().data<T>()))
+                   : nullptr,
          smooth_weight ? reinterpret_cast<NV_TYPE*>(const_cast<T*>(
                              smooth_weight.get().data<T>()))
                        : nullptr,
@@ -1226,14 +1225,14 @@ void MultiQueryAppendAttention(
      constexpr int blockx = HEAD_DIM / vec_size;
      constexpr int blocky = (128 + blockx - 1) / blockx;
      dim3 grids_merge(min(sm_count * 4, token_num),
-                      num_heads);
+                       num_heads);
      dim3 blocks_merge(blockx, blocky);
      merge_multi_chunks_v2_kernel<NV_TYPE,
-                                  vec_size,
-                                  blocky,
-                                  HEAD_DIM,
-                                  OUT_NV_TYPE,
-                                  ENABLE_PREFILL>
+                                   vec_size,
+                                   blocky,
+                                   HEAD_DIM,
+                                   OUT_NV_TYPE,
+                                   ENABLE_PREFILL>
          <<<grids_merge, blocks_merge, 0, stream>>>(
              reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
              static_cast<float*>(tmp_m->ptr()),
@@ -1244,8 +1243,8 @@ void MultiQueryAppendAttention(
              batch_id_per_token.data<int>(),
              cu_seqlens_q.data<int>(),
              shift_bias ? reinterpret_cast<NV_TYPE*>(
-                              const_cast<T*>(shift_bias.get().data<T>()))
-                        : nullptr,
+                               const_cast<T*>(shift_bias.get().data<T>()))
+                         : nullptr,
              smooth_weight ? reinterpret_cast<NV_TYPE*>(const_cast<T*>(
                                  smooth_weight.get().data<T>()))
                            : nullptr,
diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh
index 9f003af88..b7cf89682 100644
--- a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh
@@ -1285,10 +1285,11 @@ void MultiQueryAppendC4Attention(
     if (!is_decoder) {
       chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
     }
-    const int num_chunks = div_up(max_dec_len, chunk_size);
+
+    const int num_chunks = div_up(max_seq_len, chunk_size);
     dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
     dim3 blocks(32, num_warps);
-    if (num_chunks <= 1) {
+    if (num_chunks <= 0) {
       auto nosplit_kv_kernel =
           multi_query_append_attention_c4_warp1_4_kernel<NV_TYPE,
                                                          GROUP_SIZE,
@@ ... @@ void MultiQueryAppendC4Attention(
          const_cast<uint8_t*>(cache_v.data<uint8_t>()),
          reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k_scale.data<T>())),
          cache_k_zp ? reinterpret_cast<NV_TYPE*>(
-                          const_cast<T*>(cache_k_zp.get().data<T>()))
-                    : nullptr,
+                         const_cast<T*>(cache_k_zp.get().data<T>()))
+                   : nullptr,
          reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v_scale.data<T>())),
          cache_v_zp ? reinterpret_cast<NV_TYPE*>(
-                          const_cast<T*>(cache_v_zp.get().data<T>()))
-                    : nullptr,
+                         const_cast<T*>(cache_v_zp.get().data<T>()))
+                   : nullptr,
          shift_bias ? reinterpret_cast<NV_TYPE*>(
-                          const_cast<T*>(shift_bias.get().data<T>()))
-                    : nullptr,
+                         const_cast<T*>(shift_bias.get().data<T>()))
+                   : nullptr,
          smooth_weight ? reinterpret_cast<NV_TYPE*>(
                              const_cast<T*>(smooth_weight.get().data<T>()))
                        : nullptr,
@@ -1445,8 +1446,8 @@ void MultiQueryAppendC4Attention(
          seq_lens_encoder.data<int>(),
          cu_seqlens_q.data<int>(),
          shift_bias ? reinterpret_cast<NV_TYPE*>(
-                          const_cast<T*>(shift_bias.get().data<T>()))
-                    : nullptr,
+                         const_cast<T*>(shift_bias.get().data<T>()))
+                   : nullptr,
          smooth_weight ? reinterpret_cast<NV_TYPE*>(const_cast<T*>(
                              smooth_weight.get().data<T>()))
                        : nullptr,
@@ -1463,14 +1464,14 @@ void MultiQueryAppendC4Attention(
      constexpr int blockx = HEAD_DIM / vec_size;
      constexpr int blocky = (128 + blockx - 1) / blockx;
      dim3 grids_merge(min(sm_count * 4, token_num),
-                      num_heads);
+                       num_heads);
      dim3 blocks_merge(blockx, blocky);
      merge_multi_chunks_v2_kernel<NV_TYPE,
-                                  vec_size,
-                                  blocky,
-                                  HEAD_DIM,
-                                  OUT_NV_TYPE,
-                                  ENABLE_PREFILL>
+                                   vec_size,
+                                   blocky,
+                                   HEAD_DIM,
+                                   OUT_NV_TYPE,
+                                   ENABLE_PREFILL>
          <<<grids_merge, blocks_merge, 0, stream>>>(
              reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
              static_cast<float*>(tmp_m->ptr()),
@@ -1481,8 +1482,8 @@ void MultiQueryAppendC4Attention(
              batch_id_per_token.data<int>(),
              cu_seqlens_q.data<int>(),
              shift_bias ? reinterpret_cast<NV_TYPE*>(
-                              const_cast<T*>(shift_bias.get().data<T>()))
-                        : nullptr,
+                               const_cast<T*>(shift_bias.get().data<T>()))
+                         : nullptr,
              smooth_weight ? reinterpret_cast<NV_TYPE*>(const_cast<T*>(
                                  smooth_weight.get().data<T>()))
                            : nullptr,
diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh
index 3b72597e0..9078c70af 100644
--- a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh
@@ -1254,10 +1254,10 @@ void MultiQueryAppendC8Attention(
       chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
     }
 
-    const int num_chunks = div_up(max_dec_len, chunk_size);
+    const int num_chunks = div_up(max_seq_len, chunk_size);
     dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
     dim3 blocks(32, num_warps);
-    if (num_chunks <= 1) {
+    if (num_chunks <= 0) {
       auto nosplit_kv_kernel =
           multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
                                                          GROUP_SIZE,
@@ ... @@ void MultiQueryAppendC8Attention(
          reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_k_scale.data<T>())),
          reinterpret_cast<NV_TYPE*>(const_cast<T*>(cache_v_scale.data<T>())),
          shift_bias ? reinterpret_cast<NV_TYPE*>(
-                          const_cast<T*>(shift_bias.get().data<T>()))
-                    : nullptr,
+                         const_cast<T*>(shift_bias.get().data<T>()))
+                   : nullptr,
          smooth_weight ? reinterpret_cast<NV_TYPE*>(
                              const_cast<T*>(smooth_weight.get().data<T>()))
                        : nullptr,
@@ -1418,8 +1418,8 @@ void MultiQueryAppendC8Attention(
          seq_lens_encoder.data<int>(),
          cu_seqlens_q.data<int>(),
          shift_bias ? reinterpret_cast<NV_TYPE*>(
-                          const_cast<T*>(shift_bias.get().data<T>()))
-                    : nullptr,
+                         const_cast<T*>(shift_bias.get().data<T>()))
+                   : nullptr,
          smooth_weight ? reinterpret_cast<NV_TYPE*>(const_cast<T*>(
                              smooth_weight.get().data<T>()))
                        : nullptr,
@@ -1436,14 +1436,14 @@ void MultiQueryAppendC8Attention(
      constexpr int blockx = HEAD_DIM / vec_size;
      constexpr int blocky = (128 + blockx - 1) / blockx;
      dim3 grids_merge(min(sm_count * 4, token_num),
-                      num_heads);
+                       num_heads);
      dim3 blocks_merge(blockx, blocky);
      merge_multi_chunks_v2_kernel<NV_TYPE,
-                                  vec_size,
-                                  blocky,
-                                  HEAD_DIM,
-                                  OUT_NV_TYPE,
-                                  ENABLE_PREFILL>
+                                   vec_size,
+                                   blocky,
+                                   HEAD_DIM,
+                                   OUT_NV_TYPE,
+                                   ENABLE_PREFILL>
          <<<grids_merge, blocks_merge, 0, stream>>>(
              reinterpret_cast<NV_TYPE*>(tmp_workspace->ptr()),
              static_cast<float*>(tmp_m->ptr()),
@@ -1454,8 +1454,8 @@ void MultiQueryAppendC8Attention(
              batch_id_per_token.data<int>(),
              cu_seqlens_q.data<int>(),
              shift_bias ? reinterpret_cast<NV_TYPE*>(
-                              const_cast<T*>(shift_bias.get().data<T>()))
-                        : nullptr,
+                               const_cast<T*>(shift_bias.get().data<T>()))
+                         : nullptr,
              smooth_weight ? reinterpret_cast<NV_TYPE*>(const_cast<T*>(
                                  smooth_weight.get().data<T>()))
                            : nullptr,
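
Editorial note, not part of the applied diff: the standalone sketch below
restates why the change works, as the subject line and the hunks above
suggest. The names chunks_before/chunks_after and the constants are
illustrative, not FastDeploy identifiers. A CUDA graph replays exactly the
kernel-launch sequence recorded at capture, so every host-side branch must
resolve the same way at capture time and at every replay. max_dec_len grows
while a batch keeps decoding, so div_up(max_dec_len, chunk_size) can equal 1
when the graph is captured, baking the no-split kernel and a one-chunk grid
into the graph, and exceed 1 on a later step that the frozen graph cannot
follow. Deriving num_chunks from max_seq_len, which is fixed for the life of
the graph, keeps both the branch and dim3 grids(num_blocks_x_cpu, num_chunks,
kv_num_heads) stable across steps; and since num_chunks >= 1 for any positive
max_seq_len, the num_chunks <= 0 arm is effectively dead, retained only for
the MTP path per the commit bullets.

    // Standalone C++ sketch of the capture/replay divergence
    // (hypothetical helper names; builds with any C++11 compiler).
    #include <cstdio>

    static inline int div_up(int a, int b) { return (a + b - 1) / b; }

    // Before the fix: the chunk count tracks the decoded length,
    // which changes from one decode step to the next.
    int chunks_before(int max_dec_len, int chunk_size) {
      return div_up(max_dec_len, chunk_size);
    }

    // After the fix: the chunk count tracks the fixed upper bound,
    // so capture and replay always agree on the branch and the grid.
    int chunks_after(int max_seq_len, int chunk_size) {
      return div_up(max_seq_len, chunk_size);
    }

    int main() {
      const int chunk_size = 64, max_seq_len = 8192;
      // Graph captured at decode step 1, replayed at step 100:
      printf("before: capture=%d replay=%d\n",
             chunks_before(1, chunk_size),     // 1 -> no-split branch captured
             chunks_before(100, chunk_size));  // 2 -> split branch needed
      printf("after:  capture=%d replay=%d\n",
             chunks_after(max_seq_len, chunk_size),   // 128
             chunks_after(max_seq_len, chunk_size));  // 128, branch stable
      return 0;
    }

Under this scheme the split-KV kernel plus merge_multi_chunks_v2_kernel run
on every decode step, trading a little overhead at short sequence lengths
for a launch sequence that stays valid under CUDA graph replay.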