mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00
[Excutor] Fixed the issue of CUDA graph execution failure caused by different branches during decoding (#3223) (#3512)
* 彻底解决解码切块问题 * update C8 and C4 kernel * fix problem * fix with pre-commit * retain branch for mtp Co-authored-by: Jundong Liu <61149469+littledgg@users.noreply.github.com>
This commit is contained in:
@@ -1061,12 +1061,11 @@ void MultiQueryAppendAttention(
|
||||
if (!is_decoder) {
|
||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||
}
|
||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
|
||||
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||
dim3 blocks(32, num_warps);
|
||||
|
||||
if (num_chunks <= 1) {
|
||||
if (num_chunks <= 0) {
|
||||
auto nosplit_kv_kernel =
|
||||
multi_query_append_attention_warp1_4_kernel<NV_TYPE,
|
||||
false,
|
||||
|
@@ -1285,10 +1285,11 @@ void MultiQueryAppendC4Attention(
|
||||
if (!is_decoder) {
|
||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||
}
|
||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
|
||||
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||
dim3 blocks(32, num_warps);
|
||||
if (num_chunks <= 1) {
|
||||
if (num_chunks <= 0) {
|
||||
auto nosplit_kv_kernel =
|
||||
multi_query_append_attention_c4_warp1_4_kernel<NV_TYPE,
|
||||
uint8_t,
|
||||
|
@@ -1254,10 +1254,10 @@ void MultiQueryAppendC8Attention(
|
||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||
}
|
||||
|
||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||
dim3 blocks(32, num_warps);
|
||||
if (num_chunks <= 1) {
|
||||
if (num_chunks <= 0) {
|
||||
auto nosplit_kv_kernel =
|
||||
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
|
||||
uint8_t,
|
||||
|
Reference in New Issue
Block a user