diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
index b7d8441c6..823eeea3b 100644
--- a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
@@ -1061,12 +1061,11 @@ void MultiQueryAppendAttention(
     if (!is_decoder) {
       chunk_size = static_cast(encoder_max_partition_size);
     }
-    const int num_chunks = div_up(max_dec_len, chunk_size);
+    const int num_chunks = div_up(max_seq_len, chunk_size);
     dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
     dim3 blocks(32, num_warps);
-
-    if (num_chunks <= 1) {
+    if (num_chunks <= 0) {
       auto nosplit_kv_kernel =
          multi_query_append_attention_warp1_4_kernel(const_cast(cache_k.data())),
          reinterpret_cast(const_cast(cache_v.data())),
          shift_bias ? reinterpret_cast(
-                         const_cast(shift_bias.get().data()))
-                    : nullptr,
+                          const_cast(shift_bias.get().data()))
+                      : nullptr,
          smooth_weight ? reinterpret_cast(
                              const_cast(smooth_weight.get().data()))
                        : nullptr,
@@ -1208,8 +1207,8 @@ void MultiQueryAppendAttention(
          seq_lens_encoder.data(),
          cu_seqlens_q.data(),
          shift_bias ? reinterpret_cast(
-                         const_cast(shift_bias.get().data()))
-                    : nullptr,
+                          const_cast(shift_bias.get().data()))
+                      : nullptr,
          smooth_weight ? reinterpret_cast(const_cast(
                              smooth_weight.get().data()))
                        : nullptr,
@@ -1226,14 +1225,14 @@ void MultiQueryAppendAttention(
     constexpr int blockx = HEAD_DIM / vec_size;
     constexpr int blocky = (128 + blockx - 1) / blockx;
     dim3 grids_merge(min(sm_count * 4, token_num),
-                      num_heads);
+                     num_heads);
     dim3 blocks_merge(blockx, blocky);
     merge_multi_chunks_v2_kernel
+                                 vec_size,
+                                 blocky,
+                                 HEAD_DIM,
+                                 OUT_NV_TYPE,
+                                 ENABLE_PREFILL>
         <<>>(
             reinterpret_cast(tmp_workspace->ptr()),
             static_cast(tmp_m->ptr()),
@@ -1244,8 +1243,8 @@ void MultiQueryAppendAttention(
             batch_id_per_token.data(),
             cu_seqlens_q.data(),
             shift_bias ? reinterpret_cast(
-                            const_cast(shift_bias.get().data()))
-                       : nullptr,
+                             const_cast(shift_bias.get().data()))
+                         : nullptr,
             smooth_weight ? reinterpret_cast(const_cast(
                                 smooth_weight.get().data()))
                           : nullptr,
diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh
index 9f003af88..b7cf89682 100644
--- a/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh
@@ -1285,10 +1285,11 @@ void MultiQueryAppendC4Attention(
     if (!is_decoder) {
       chunk_size = static_cast(encoder_max_partition_size);
     }
-    const int num_chunks = div_up(max_dec_len, chunk_size);
+
+    const int num_chunks = div_up(max_seq_len, chunk_size);
     dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
     dim3 blocks(32, num_warps);
-    if (num_chunks <= 1) {
+    if (num_chunks <= 0) {
       auto nosplit_kv_kernel =
          multi_query_append_attention_c4_warp1_4_kernel(cache_v.data()),
          reinterpret_cast(const_cast(cache_k_scale.data())),
          cache_k_zp ? reinterpret_cast(
-                         const_cast(cache_k_zp.get().data()))
-                    : nullptr,
+                          const_cast(cache_k_zp.get().data()))
+                      : nullptr,
          reinterpret_cast(const_cast(cache_v_scale.data())),
          cache_v_zp ? reinterpret_cast(
-                         const_cast(cache_v_zp.get().data()))
-                    : nullptr,
+                          const_cast(cache_v_zp.get().data()))
+                      : nullptr,
          shift_bias ? reinterpret_cast(
-                         const_cast(shift_bias.get().data()))
-                    : nullptr,
+                          const_cast(shift_bias.get().data()))
+                      : nullptr,
          smooth_weight ? reinterpret_cast(
                              const_cast(smooth_weight.get().data()))
                        : nullptr,
@@ -1445,8 +1446,8 @@ void MultiQueryAppendC4Attention(
          seq_lens_encoder.data(),
          cu_seqlens_q.data(),
          shift_bias ? reinterpret_cast(
-                         const_cast(shift_bias.get().data()))
-                    : nullptr,
+                          const_cast(shift_bias.get().data()))
+                      : nullptr,
          smooth_weight ? reinterpret_cast(const_cast(
                              smooth_weight.get().data()))
                        : nullptr,
@@ -1463,14 +1464,14 @@ void MultiQueryAppendC4Attention(
     constexpr int blockx = HEAD_DIM / vec_size;
     constexpr int blocky = (128 + blockx - 1) / blockx;
     dim3 grids_merge(min(sm_count * 4, token_num),
-                      num_heads);
+                     num_heads);
     dim3 blocks_merge(blockx, blocky);
     merge_multi_chunks_v2_kernel
+                                 vec_size,
+                                 blocky,
+                                 HEAD_DIM,
+                                 OUT_NV_TYPE,
+                                 ENABLE_PREFILL>
         <<>>(
            reinterpret_cast(tmp_workspace->ptr()),
            static_cast(tmp_m->ptr()),
@@ -1481,8 +1482,8 @@ void MultiQueryAppendC4Attention(
            batch_id_per_token.data(),
            cu_seqlens_q.data(),
            shift_bias ? reinterpret_cast(
-                           const_cast(shift_bias.get().data()))
-                      : nullptr,
+                            const_cast(shift_bias.get().data()))
+                        : nullptr,
            smooth_weight ? reinterpret_cast(const_cast(
                                smooth_weight.get().data()))
                          : nullptr,
diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh
index 3b72597e0..9078c70af 100644
--- a/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh
+++ b/custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh
@@ -1254,10 +1254,10 @@ void MultiQueryAppendC8Attention(
       chunk_size = static_cast(encoder_max_partition_size);
     }
-    const int num_chunks = div_up(max_dec_len, chunk_size);
+    const int num_chunks = div_up(max_seq_len, chunk_size);
     dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
     dim3 blocks(32, num_warps);
-    if (num_chunks <= 1) {
+    if (num_chunks <= 0) {
      auto nosplit_kv_kernel =
          multi_query_append_attention_c8_warp1_4_kernel(const_cast(cache_k_scale.data())),
          reinterpret_cast(const_cast(cache_v_scale.data())),
          shift_bias ? reinterpret_cast(
-                         const_cast(shift_bias.get().data()))
-                    : nullptr,
+                          const_cast(shift_bias.get().data()))
+                      : nullptr,
          smooth_weight ? reinterpret_cast(
                              const_cast(smooth_weight.get().data()))
                        : nullptr,
@@ -1418,8 +1418,8 @@ void MultiQueryAppendC8Attention(
          seq_lens_encoder.data(),
          cu_seqlens_q.data(),
          shift_bias ? reinterpret_cast(
-                         const_cast(shift_bias.get().data()))
-                    : nullptr,
+                          const_cast(shift_bias.get().data()))
+                      : nullptr,
          smooth_weight ? reinterpret_cast(const_cast(
                              smooth_weight.get().data()))
                        : nullptr,
@@ -1436,14 +1436,14 @@ void MultiQueryAppendC8Attention(
     constexpr int blockx = HEAD_DIM / vec_size;
     constexpr int blocky = (128 + blockx - 1) / blockx;
     dim3 grids_merge(min(sm_count * 4, token_num),
-                      num_heads);
+                     num_heads);
     dim3 blocks_merge(blockx, blocky);
     merge_multi_chunks_v2_kernel
+                                 vec_size,
+                                 blocky,
+                                 HEAD_DIM,
+                                 OUT_NV_TYPE,
+                                 ENABLE_PREFILL>
         <<>>(
            reinterpret_cast(tmp_workspace->ptr()),
            static_cast(tmp_m->ptr()),
@@ -1454,8 +1454,8 @@ void MultiQueryAppendC8Attention(
            batch_id_per_token.data(),
            cu_seqlens_q.data(),
            shift_bias ? reinterpret_cast(
-                           const_cast(shift_bias.get().data()))
-                      : nullptr,
+                            const_cast(shift_bias.get().data()))
+                        : nullptr,
            smooth_weight ? reinterpret_cast(const_cast(
                                smooth_weight.get().data()))
                          : nullptr,
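
The functional change is the same in all three files: the chunk count is now derived from max_seq_len instead of max_dec_len, and the single-pass ("no split KV") kernel is only selected when num_chunks <= 0, so even a single chunk now goes through the chunked kernels followed by merge_multi_chunks_v2_kernel. A minimal host-side sketch of that dispatch decision, assuming only the div_up call and the thresholds visible in the hunks above (use_split_kv and the sample sizes are illustrative, not part of the patch):

    #include <cstdio>

    // Ceiling division, matching the div_up(max_seq_len, chunk_size) call above.
    static inline int div_up(int a, int b) { return (a + b - 1) / b; }

    // Illustrative helper (not in the patch): returns true when the chunked
    // split-KV path plus the merge kernel would be taken.
    bool use_split_kv(int max_seq_len, int chunk_size) {
      const int num_chunks = div_up(max_seq_len, chunk_size);
      // Patched condition: only num_chunks <= 0 falls back to the no-split kernel.
      // Before the patch, num_chunks <= 1 also took the no-split kernel.
      return num_chunks > 0;
    }

    int main() {
      std::printf("%d\n", use_split_kv(8192, 2048));  // 4 chunks -> 1 (split path)
      std::printf("%d\n", use_split_kv(1024, 2048));  // 1 chunk  -> 1 (split path; no-split before the patch)
      std::printf("%d\n", use_split_kv(0, 2048));     // 0 chunks -> 0 (no-split path)
      return 0;
    }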