mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00
[Excutor] Fixed the issue of CUDA graph execution failure caused by different branches during decoding (#3223) (#3512)
* 彻底解决解码切块问题 * update C8 and C4 kernel * fix problem * fix with pre-commit * retain branch for mtp Co-authored-by: Jundong Liu <61149469+littledgg@users.noreply.github.com>
This commit is contained in:
@@ -1061,12 +1061,11 @@ void MultiQueryAppendAttention(
|
|||||||
if (!is_decoder) {
|
if (!is_decoder) {
|
||||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||||
}
|
}
|
||||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
|
||||||
|
|
||||||
|
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||||
dim3 blocks(32, num_warps);
|
dim3 blocks(32, num_warps);
|
||||||
|
if (num_chunks <= 0) {
|
||||||
if (num_chunks <= 1) {
|
|
||||||
auto nosplit_kv_kernel =
|
auto nosplit_kv_kernel =
|
||||||
multi_query_append_attention_warp1_4_kernel<NV_TYPE,
|
multi_query_append_attention_warp1_4_kernel<NV_TYPE,
|
||||||
false,
|
false,
|
||||||
@@ -1161,8 +1160,8 @@ void MultiQueryAppendAttention(
|
|||||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
|
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
|
||||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
|
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
const_cast<T *>(shift_bias.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
@@ -1208,8 +1207,8 @@ void MultiQueryAppendAttention(
|
|||||||
seq_lens_encoder.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
const_cast<T *>(shift_bias.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||||
smooth_weight.get().data<T>()))
|
smooth_weight.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
@@ -1226,14 +1225,14 @@ void MultiQueryAppendAttention(
|
|||||||
constexpr int blockx = HEAD_DIM / vec_size;
|
constexpr int blockx = HEAD_DIM / vec_size;
|
||||||
constexpr int blocky = (128 + blockx - 1) / blockx;
|
constexpr int blocky = (128 + blockx - 1) / blockx;
|
||||||
dim3 grids_merge(min(sm_count * 4, token_num),
|
dim3 grids_merge(min(sm_count * 4, token_num),
|
||||||
num_heads);
|
num_heads);
|
||||||
dim3 blocks_merge(blockx, blocky);
|
dim3 blocks_merge(blockx, blocky);
|
||||||
merge_multi_chunks_v2_kernel<NV_TYPE,
|
merge_multi_chunks_v2_kernel<NV_TYPE,
|
||||||
vec_size,
|
vec_size,
|
||||||
blocky,
|
blocky,
|
||||||
HEAD_DIM,
|
HEAD_DIM,
|
||||||
OUT_NV_TYPE,
|
OUT_NV_TYPE,
|
||||||
ENABLE_PREFILL>
|
ENABLE_PREFILL>
|
||||||
<<<grids_merge, blocks_merge, 0, stream>>>(
|
<<<grids_merge, blocks_merge, 0, stream>>>(
|
||||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||||
static_cast<float *>(tmp_m->ptr()),
|
static_cast<float *>(tmp_m->ptr()),
|
||||||
@@ -1244,8 +1243,8 @@ void MultiQueryAppendAttention(
|
|||||||
batch_id_per_token.data<int>(),
|
batch_id_per_token.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
const_cast<T *>(shift_bias.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||||
smooth_weight.get().data<T>()))
|
smooth_weight.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
|
@@ -1285,10 +1285,11 @@ void MultiQueryAppendC4Attention(
|
|||||||
if (!is_decoder) {
|
if (!is_decoder) {
|
||||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||||
}
|
}
|
||||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
|
||||||
|
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||||
dim3 blocks(32, num_warps);
|
dim3 blocks(32, num_warps);
|
||||||
if (num_chunks <= 1) {
|
if (num_chunks <= 0) {
|
||||||
auto nosplit_kv_kernel =
|
auto nosplit_kv_kernel =
|
||||||
multi_query_append_attention_c4_warp1_4_kernel<NV_TYPE,
|
multi_query_append_attention_c4_warp1_4_kernel<NV_TYPE,
|
||||||
uint8_t,
|
uint8_t,
|
||||||
@@ -1392,15 +1393,15 @@ void MultiQueryAppendC4Attention(
|
|||||||
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
|
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
|
||||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
|
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
|
||||||
cache_k_zp ? reinterpret_cast<NV_TYPE *>(
|
cache_k_zp ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(cache_k_zp.get().data<T>()))
|
const_cast<T *>(cache_k_zp.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
|
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
|
||||||
cache_v_zp ? reinterpret_cast<NV_TYPE *>(
|
cache_v_zp ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(cache_v_zp.get().data<T>()))
|
const_cast<T *>(cache_v_zp.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
const_cast<T *>(shift_bias.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
@@ -1445,8 +1446,8 @@ void MultiQueryAppendC4Attention(
|
|||||||
seq_lens_encoder.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
const_cast<T *>(shift_bias.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||||
smooth_weight.get().data<T>()))
|
smooth_weight.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
@@ -1463,14 +1464,14 @@ void MultiQueryAppendC4Attention(
|
|||||||
constexpr int blockx = HEAD_DIM / vec_size;
|
constexpr int blockx = HEAD_DIM / vec_size;
|
||||||
constexpr int blocky = (128 + blockx - 1) / blockx;
|
constexpr int blocky = (128 + blockx - 1) / blockx;
|
||||||
dim3 grids_merge(min(sm_count * 4, token_num),
|
dim3 grids_merge(min(sm_count * 4, token_num),
|
||||||
num_heads);
|
num_heads);
|
||||||
dim3 blocks_merge(blockx, blocky);
|
dim3 blocks_merge(blockx, blocky);
|
||||||
merge_multi_chunks_v2_kernel<NV_TYPE,
|
merge_multi_chunks_v2_kernel<NV_TYPE,
|
||||||
vec_size,
|
vec_size,
|
||||||
blocky,
|
blocky,
|
||||||
HEAD_DIM,
|
HEAD_DIM,
|
||||||
OUT_NV_TYPE,
|
OUT_NV_TYPE,
|
||||||
ENABLE_PREFILL>
|
ENABLE_PREFILL>
|
||||||
<<<grids_merge, blocks_merge, 0, stream>>>(
|
<<<grids_merge, blocks_merge, 0, stream>>>(
|
||||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||||
static_cast<float *>(tmp_m->ptr()),
|
static_cast<float *>(tmp_m->ptr()),
|
||||||
@@ -1481,8 +1482,8 @@ void MultiQueryAppendC4Attention(
|
|||||||
batch_id_per_token.data<int>(),
|
batch_id_per_token.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
const_cast<T *>(shift_bias.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||||
smooth_weight.get().data<T>()))
|
smooth_weight.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
|
@@ -1254,10 +1254,10 @@ void MultiQueryAppendC8Attention(
|
|||||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||||
dim3 blocks(32, num_warps);
|
dim3 blocks(32, num_warps);
|
||||||
if (num_chunks <= 1) {
|
if (num_chunks <= 0) {
|
||||||
auto nosplit_kv_kernel =
|
auto nosplit_kv_kernel =
|
||||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||||
uint8_t,
|
uint8_t,
|
||||||
@@ -1377,8 +1377,8 @@ void MultiQueryAppendC8Attention(
|
|||||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
|
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
|
||||||
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
|
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
const_cast<T *>(shift_bias.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
@@ -1418,8 +1418,8 @@ void MultiQueryAppendC8Attention(
|
|||||||
seq_lens_encoder.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
const_cast<T *>(shift_bias.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||||
smooth_weight.get().data<T>()))
|
smooth_weight.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
@@ -1436,14 +1436,14 @@ void MultiQueryAppendC8Attention(
|
|||||||
constexpr int blockx = HEAD_DIM / vec_size;
|
constexpr int blockx = HEAD_DIM / vec_size;
|
||||||
constexpr int blocky = (128 + blockx - 1) / blockx;
|
constexpr int blocky = (128 + blockx - 1) / blockx;
|
||||||
dim3 grids_merge(min(sm_count * 4, token_num),
|
dim3 grids_merge(min(sm_count * 4, token_num),
|
||||||
num_heads);
|
num_heads);
|
||||||
dim3 blocks_merge(blockx, blocky);
|
dim3 blocks_merge(blockx, blocky);
|
||||||
merge_multi_chunks_v2_kernel<NV_TYPE,
|
merge_multi_chunks_v2_kernel<NV_TYPE,
|
||||||
vec_size,
|
vec_size,
|
||||||
blocky,
|
blocky,
|
||||||
HEAD_DIM,
|
HEAD_DIM,
|
||||||
OUT_NV_TYPE,
|
OUT_NV_TYPE,
|
||||||
ENABLE_PREFILL>
|
ENABLE_PREFILL>
|
||||||
<<<grids_merge, blocks_merge, 0, stream>>>(
|
<<<grids_merge, blocks_merge, 0, stream>>>(
|
||||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||||
static_cast<float *>(tmp_m->ptr()),
|
static_cast<float *>(tmp_m->ptr()),
|
||||||
@@ -1454,8 +1454,8 @@ void MultiQueryAppendC8Attention(
|
|||||||
batch_id_per_token.data<int>(),
|
batch_id_per_token.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
const_cast<T *>(shift_bias.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
|
||||||
smooth_weight.get().data<T>()))
|
smooth_weight.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
|
Reference in New Issue
Block a user