[Executor] Fixed the issue of CUDA graph execution failure caused by different branches during decoding (#3223) (#3512)

* 彻底解决解码切块问题 (Completely resolve the decode chunking issue)

* update C8 and C4 kernel

* fix problem

* fix with pre-commit

* retain branch for mtp

Co-authored-by: Jundong Liu <61149469+littledgg@users.noreply.github.com>
This commit is contained in:
RAM
2025-08-21 20:58:47 +08:00
committed by GitHub
parent 1b399b91c0
commit d97aab25bc
3 changed files with 47 additions and 47 deletions

View File

@@ -1061,12 +1061,11 @@ void MultiQueryAppendAttention(
   if (!is_decoder) {
     chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
   }
-  const int num_chunks = div_up(max_dec_len, chunk_size);
+  const int num_chunks = div_up(max_seq_len, chunk_size);
   dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
   dim3 blocks(32, num_warps);
-  if (num_chunks <= 0) {
+  if (num_chunks <= 1) {
     auto nosplit_kv_kernel =
         multi_query_append_attention_warp1_4_kernel<NV_TYPE,
                                                     false,

@@ -1161,8 +1160,8 @@ void MultiQueryAppendAttention(
[formatting-only hunk: re-indentation; duplicated side-by-side columns collapsed]
   reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
   reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
   shift_bias ? reinterpret_cast<NV_TYPE *>(
                    const_cast<T *>(shift_bias.get().data<T>()))
              : nullptr,
   smooth_weight ? reinterpret_cast<NV_TYPE *>(
                       const_cast<T *>(smooth_weight.get().data<T>()))
                 : nullptr,

@@ -1208,8 +1207,8 @@ void MultiQueryAppendAttention(
[formatting-only hunk]
   seq_lens_encoder.data<int>(),
   cu_seqlens_q.data<int>(),
   shift_bias ? reinterpret_cast<NV_TYPE *>(
                    const_cast<T *>(shift_bias.get().data<T>()))
              : nullptr,
   smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                       smooth_weight.get().data<T>()))
                 : nullptr,

@@ -1226,14 +1225,14 @@ void MultiQueryAppendAttention(
[formatting-only hunk]
   constexpr int blockx = HEAD_DIM / vec_size;
   constexpr int blocky = (128 + blockx - 1) / blockx;
   dim3 grids_merge(min(sm_count * 4, token_num),
                    num_heads);
   dim3 blocks_merge(blockx, blocky);
   merge_multi_chunks_v2_kernel<NV_TYPE,
                                vec_size,
                                blocky,
                                HEAD_DIM,
                                OUT_NV_TYPE,
                                ENABLE_PREFILL>
       <<<grids_merge, blocks_merge, 0, stream>>>(
           reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
           static_cast<float *>(tmp_m->ptr()),

@@ -1244,8 +1243,8 @@ void MultiQueryAppendAttention(
[formatting-only hunk]
   batch_id_per_token.data<int>(),
   cu_seqlens_q.data<int>(),
   shift_bias ? reinterpret_cast<NV_TYPE *>(
                    const_cast<T *>(shift_bias.get().data<T>()))
              : nullptr,
   smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                       smooth_weight.get().data<T>()))
                 : nullptr,

View File

@@ -1285,10 +1285,11 @@ void MultiQueryAppendC4Attention(
   if (!is_decoder) {
     chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
   }
-  const int num_chunks = div_up(max_dec_len, chunk_size);
+  const int num_chunks = div_up(max_seq_len, chunk_size);
   dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
   dim3 blocks(32, num_warps);
-  if (num_chunks <= 1) {
+  if (num_chunks <= 0) {
     auto nosplit_kv_kernel =
         multi_query_append_attention_c4_warp1_4_kernel<NV_TYPE,
                                                        uint8_t,

@@ -1392,15 +1393,15 @@ void MultiQueryAppendC4Attention(
[formatting-only hunk: re-indentation; duplicated side-by-side columns collapsed]
   const_cast<uint8_t *>(cache_v.data<uint8_t>()),
   reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
   cache_k_zp ? reinterpret_cast<NV_TYPE *>(
                    const_cast<T *>(cache_k_zp.get().data<T>()))
              : nullptr,
   reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
   cache_v_zp ? reinterpret_cast<NV_TYPE *>(
                    const_cast<T *>(cache_v_zp.get().data<T>()))
              : nullptr,
   shift_bias ? reinterpret_cast<NV_TYPE *>(
                    const_cast<T *>(shift_bias.get().data<T>()))
              : nullptr,
   smooth_weight ? reinterpret_cast<NV_TYPE *>(
                       const_cast<T *>(smooth_weight.get().data<T>()))
                 : nullptr,

@@ -1445,8 +1446,8 @@ void MultiQueryAppendC4Attention(
[formatting-only hunk]
   seq_lens_encoder.data<int>(),
   cu_seqlens_q.data<int>(),
   shift_bias ? reinterpret_cast<NV_TYPE *>(
                    const_cast<T *>(shift_bias.get().data<T>()))
              : nullptr,
   smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                       smooth_weight.get().data<T>()))
                 : nullptr,

@@ -1463,14 +1464,14 @@ void MultiQueryAppendC4Attention(
[formatting-only hunk]
   constexpr int blockx = HEAD_DIM / vec_size;
   constexpr int blocky = (128 + blockx - 1) / blockx;
   dim3 grids_merge(min(sm_count * 4, token_num),
                    num_heads);
   dim3 blocks_merge(blockx, blocky);
   merge_multi_chunks_v2_kernel<NV_TYPE,
                                vec_size,
                                blocky,
                                HEAD_DIM,
                                OUT_NV_TYPE,
                                ENABLE_PREFILL>
       <<<grids_merge, blocks_merge, 0, stream>>>(
           reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
           static_cast<float *>(tmp_m->ptr()),

@@ -1481,8 +1482,8 @@ void MultiQueryAppendC4Attention(
[formatting-only hunk]
   batch_id_per_token.data<int>(),
   cu_seqlens_q.data<int>(),
   shift_bias ? reinterpret_cast<NV_TYPE *>(
                    const_cast<T *>(shift_bias.get().data<T>()))
              : nullptr,
   smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                       smooth_weight.get().data<T>()))
                 : nullptr,

View File

@@ -1254,10 +1254,10 @@ void MultiQueryAppendC8Attention(
     chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
   }
-  const int num_chunks = div_up(max_dec_len, chunk_size);
+  const int num_chunks = div_up(max_seq_len, chunk_size);
   dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
   dim3 blocks(32, num_warps);
-  if (num_chunks <= 1) {
+  if (num_chunks <= 0) {
     auto nosplit_kv_kernel =
         multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
                                                        uint8_t,

@@ -1377,8 +1377,8 @@ void MultiQueryAppendC8Attention(
[formatting-only hunk: re-indentation; duplicated side-by-side columns collapsed]
   reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
   reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
   shift_bias ? reinterpret_cast<NV_TYPE *>(
                    const_cast<T *>(shift_bias.get().data<T>()))
              : nullptr,
   smooth_weight ? reinterpret_cast<NV_TYPE *>(
                       const_cast<T *>(smooth_weight.get().data<T>()))
                 : nullptr,

@@ -1418,8 +1418,8 @@ void MultiQueryAppendC8Attention(
[formatting-only hunk]
   seq_lens_encoder.data<int>(),
   cu_seqlens_q.data<int>(),
   shift_bias ? reinterpret_cast<NV_TYPE *>(
                    const_cast<T *>(shift_bias.get().data<T>()))
              : nullptr,
   smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                       smooth_weight.get().data<T>()))
                 : nullptr,

@@ -1436,14 +1436,14 @@ void MultiQueryAppendC8Attention(
[formatting-only hunk]
   constexpr int blockx = HEAD_DIM / vec_size;
   constexpr int blocky = (128 + blockx - 1) / blockx;
   dim3 grids_merge(min(sm_count * 4, token_num),
                    num_heads);
   dim3 blocks_merge(blockx, blocky);
   merge_multi_chunks_v2_kernel<NV_TYPE,
                                vec_size,
                                blocky,
                                HEAD_DIM,
                                OUT_NV_TYPE,
                                ENABLE_PREFILL>
       <<<grids_merge, blocks_merge, 0, stream>>>(
           reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
           static_cast<float *>(tmp_m->ptr()),

@@ -1454,8 +1454,8 @@ void MultiQueryAppendC8Attention(
[formatting-only hunk]
   batch_id_per_token.data<int>(),
   cu_seqlens_q.data<int>(),
   shift_bias ? reinterpret_cast<NV_TYPE *>(
                    const_cast<T *>(shift_bias.get().data<T>()))
              : nullptr,
   smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
                       smooth_weight.get().data<T>()))
                 : nullptr,