diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index dbe072250..2e2e8c7ba 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -11,10 +11,11 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - +#include "cute/tensor.hpp" #include "helper.h" #include "paddle/extension.h" #include "paddle/phi/core/memory/memcpy.h" +#include "utils.cuh" template __global__ void @@ -116,6 +117,93 @@ void GetMaxLen(const paddle::Tensor &seq_lens_tensor, max_len_tensor.data(), batch_size); } +template +__global__ void search_chunk_size_for_mla( + const int *__restrict__ seq_lens_q, + const int *__restrict__ seq_lens_encoder, + const int *__restrict__ seq_lens_decoder, + int *__restrict__ num_blocks_x, + int *__restrict__ res_chunk_size, + const int bsz, + const int set_chunk_size, + const int block_size, + const int sm_cout) { + const uint32_t conf_id = threadIdx.x; + int gridx = 0; + if (set_chunk_size > 0 && conf_id == 0) { + for (uint32_t bid = 0; bid < bsz; bid++) { + int seq_len = seq_lens_q[bid]; + int seq_len_encoder = seq_lens_encoder[bid]; + int seq_len_decoder = seq_lens_decoder[bid] + seq_len; + if (seq_len == 0 || seq_len_encoder > 0) continue; + + int loop_times; + loop_times = cute::ceil_div(seq_len_decoder, set_chunk_size); + gridx += loop_times; + } + *num_blocks_x = gridx; + *res_chunk_size = set_chunk_size; + } else if (conf_id < config_size) { + __shared__ int gridx_shared[config_size]; + // chunk_size is a multiple of 64 + const int chunk_size = block_size << conf_id; + for (uint32_t bid = 0; bid < bsz; bid++) { + int seq_len = seq_lens_q[bid]; + int seq_len_encoder = seq_lens_encoder[bid]; + int seq_len_decoder = seq_lens_decoder[bid] + seq_len; + if (seq_len == 0 || seq_len_encoder > 0) continue; + + int loop_times; + loop_times = cute::ceil_div(seq_len_decoder, chunk_size); + gridx += loop_times; + } + gridx_shared[conf_id] = gridx; + __syncthreads(); + if (threadIdx.x == 0) { + uint32_t res_id = 0; + uint32_t max_last_wave_block = 0; + for (uint32_t i = 1; i < config_size; i++) { + uint32_t last_wave_block = gridx_shared[i] % sm_cout; + if (last_wave_block >= max_last_wave_block) { + res_id = i; + max_last_wave_block = last_wave_block; + } + } + *num_blocks_x = gridx_shared[res_id]; + *res_chunk_size = block_size << res_id; + } + } +} + +__global__ void split_block_for_mla(const int *__restrict__ seq_lens_q, + const int *__restrict__ seq_lens_encoder, + const int *__restrict__ seq_lens_decoder, + int *__restrict__ batch_ids, + int *__restrict__ tile_ids_per_batch, + const int bsz, + const int chunk_size) { + if (threadIdx.x == 0) { + int index = 0; + for (uint32_t bid = 0; bid < bsz; bid++) { + int seq_len = seq_lens_q[bid]; + int seq_len_encoder = seq_lens_encoder[bid]; + int seq_len_decoder = seq_lens_decoder[bid] + seq_len; + + if (seq_len == 0) continue; + + int loop_times; + loop_times = cute::ceil_div(seq_len_decoder, chunk_size); + if (seq_len_encoder > 0) { + loop_times = 0; + } + for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) { + batch_ids[index] = bid; + tile_ids_per_batch[index++] = tile_id; + } + } + } +} + __global__ void split_q_block(const int *__restrict__ seq_lens_q, const int *__restrict__ seq_lens_encoder, int *__restrict__ batch_ids, @@ -197,7 +285,9 @@ void GetBlockShapeAndSplitKVBlock( const paddle::Tensor &seq_lens_this_time, paddle::Tensor &decoder_batch_ids, // Inplace paddle::Tensor &decoder_tile_ids_per_batch, // Inplace - paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory + paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory + paddle::Tensor &decoder_num_blocks_device, // Inplace + paddle::Tensor &decoder_chunk_size_device, // Inplace paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU paddle::Tensor &encoder_batch_ids, // Inplace paddle::Tensor &encoder_tile_ids_per_batch, // Inplace @@ -230,8 +320,6 @@ void GetBlockShapeAndSplitKVBlock( int max_system_len = max_len_cpu_ptr[6]; int max_just_dec_len_without_system = max_len_cpu_ptr[7]; - - auto max_len_kv = GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place()); get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>( @@ -241,6 +329,106 @@ void GetBlockShapeAndSplitKVBlock( max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false); + // decoder + if (max_dec_len_this_time > 0) { + const bool mla_use_tensorcore = GetMlaUseTensorcore(); + if (mla_use_tensorcore && group_size <= 64) { + const int set_chunk_size = get_mla_dec_chunk_size(bsz); + + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); + + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream)); + + int device; + cudaGetDevice(&device); + int sm_cout; + cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device); + constexpr int config_size = + 12; // search space for chunk size:[64, 128, 256, ... 131072] + + search_chunk_size_for_mla + <<<1, 32, 0, stream>>>(seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + decoder_num_blocks_device.data(), + decoder_chunk_size_device.data(), + bsz, + set_chunk_size, + block_size, + sm_cout); + + decoder_num_blocks_cpu.copy_( + decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); + auto decoder_chunk_size_cpu = + decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false); + const int chunk_size = decoder_chunk_size_cpu.data()[0]; + + // NOTE: (changwenbin) When using auto_chunk, + // decode_max_tile_size must take into account the maximum case, where * 1024 can cover 128K. + // const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024; + + const uint32_t decoder_max_tile_size_per_bs_q = + div_up((decoder_step_token_num * group_size), decoder_block_shape_q); + const uint32_t decoder_batch_shape = + bsz * 1024 * decoder_max_tile_size_per_bs_q; + + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemsetAsync(decoder_batch_ids.data(), + 0, + decoder_batch_shape * sizeof(int32_t), + stream)); + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMemsetAsync(decoder_tile_ids_per_batch.data(), + 0, + decoder_batch_shape * sizeof(int32_t), + stream)); + + + split_block_for_mla<<<1, 32, 0, stream>>>( + seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + decoder_batch_ids.data(), + decoder_tile_ids_per_batch.data(), + bsz, + chunk_size); + + } else { + // Note:(changwenbin)In order to adapt to cudagraph, the maximum value should be taken here + const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q); + const uint32_t decoder_batch_shape = bsz * 1024 * decoder_max_tile_size_per_bs_q; + + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data(), 0, decoder_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data(), 0, decoder_batch_shape * sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream)); + + split_q_block<<<1, 32, 0, stream>>>( + seq_lens_this_time.data(), + seq_lens_encoder.data(), + decoder_batch_ids.data(), + decoder_tile_ids_per_batch.data(), + decoder_num_blocks_device.data(), + bsz, + decoder_block_shape_q, + group_size); + + decoder_num_blocks_cpu.copy_( + decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); + } + } else { + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + decoder_chunk_size_device.data(), 64, sizeof(int32_t), stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + decoder_num_blocks_device.data(), 0, sizeof(int32_t), stream)); + decoder_num_blocks_cpu.copy_( + decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); + } + + // encoder if (max_enc_len_this_time > 0) { const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size); const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv; @@ -272,28 +460,6 @@ void GetBlockShapeAndSplitKVBlock( encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false); } - if (max_just_dec_len_this_time > 0) { - // Clear buffer - const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q); - const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q; - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data(), 0, decoder_batch_shape * sizeof(int32_t), stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data(), 0, decoder_batch_shape * sizeof(int32_t), stream)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data(), 0, sizeof(int32_t), stream)); - - auto decoder_num_blocks_x = - GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place()); - split_q_block<<<1, 32, 0, stream>>>( - seq_lens_this_time.data(), - seq_lens_encoder.data(), - decoder_batch_ids.data(), - decoder_tile_ids_per_batch.data(), - decoder_num_blocks_x.data(), - bsz, - decoder_block_shape_q, - group_size); - decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false); - } - } PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) @@ -303,7 +469,9 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block) "seq_lens_this_time", "decoder_batch_ids", "decoder_tile_ids_per_batch", - "decoder_num_blocks_x_cpu", + "decoder_num_blocks_cpu", + "decoder_num_blocks_device", + "decoder_chunk_size_device", "max_len_tensor_cpu", "encoder_batch_ids", "encoder_tile_ids_per_batch", diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index a8348fed1..1d977f50a 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -63,7 +63,7 @@ std::vector AppendAttention( const paddle::Tensor &kv_num_blocks, const paddle::Tensor &decoder_batch_ids, const paddle::Tensor &decoder_tile_ids_per_batch, - const paddle::Tensor &decoder_num_blocks, + const paddle::Tensor &decoder_num_blocks_cpu, const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv, const paddle::optional &rotary_embs, const paddle::optional &attn_mask, @@ -105,7 +105,7 @@ void AppendAttentionWithOutput( const paddle::Tensor &kv_num_blocks, const paddle::Tensor &decoder_batch_ids, const paddle::Tensor &decoder_tile_ids_per_batch, - const paddle::Tensor &decoder_num_blocks, + const paddle::Tensor &decoder_num_blocks_cpu, const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv, paddle::Tensor &fmha_out, const paddle::optional &rotary_embs, @@ -305,7 +305,9 @@ void GetBlockShapeAndSplitKVBlock( const paddle::Tensor &seq_lens_this_time, paddle::Tensor &decoder_batch_ids, // Inplace paddle::Tensor &decoder_tile_ids_per_batch, // Inplace - paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory + paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory + paddle::Tensor &decoder_num_blocks_device, // Inplace + paddle::Tensor &decoder_chunk_size_device, // Inplace paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory paddle::Tensor &encoder_batch_ids, // Inplace paddle::Tensor &encoder_tile_ids_per_batch, // Inplace @@ -473,23 +475,18 @@ std::vector MultiHeadLatentAttention( const paddle::Tensor& query, const paddle::Tensor& key_cache, const paddle::Tensor& value_cache, - const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& batch_id_per_token, const paddle::Tensor& block_tables, - const paddle::Tensor& encoder_batch_ids, - const paddle::Tensor& encoder_tile_ids_per_batch, - const paddle::Tensor& encoder_num_blocks, const paddle::Tensor& kv_batch_ids, const paddle::Tensor& kv_tile_ids_per_batch, const paddle::Tensor& kv_num_blocks, const paddle::Tensor& decoder_batch_ids, const paddle::Tensor& decoder_tile_ids_per_batch, - const paddle::Tensor& decoder_num_blocks, - const paddle::Tensor& decoder_num_blocks_cpu, - const paddle::Tensor& max_enc_len_this_time, + const paddle::Tensor& decoder_num_blocks_device, + const paddle::Tensor& decoder_chunk_size_device, const paddle::Tensor& max_dec_len_this_time, const paddle::Tensor& max_len_kv, const paddle::optional& attn_mask, diff --git a/custom_ops/gpu_ops/env.h b/custom_ops/gpu_ops/env.h index c7db21ba8..5e8eee339 100644 --- a/custom_ops/gpu_ops/env.h +++ b/custom_ops/gpu_ops/env.h @@ -59,6 +59,15 @@ inline uint32_t get_cascade_attention_num_threads() { inline bool get_mla_use_tensorcore() { static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore"); static const uint32_t mla_use_tensorcore = - mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env)); + mla_use_tensorcore_env == nullptr ? 0 : std::stoul(std::string(mla_use_tensorcore_env)); return mla_use_tensorcore != 0 ? true : false; } +inline int get_mla_dec_chunk_size(int bsz) { + static const char* mla_dec_chunk_size_env = + std::getenv("FLAGS_mla_dec_chunk_size"); + static const int mla_dec_chunk_size = + mla_dec_chunk_size_env == nullptr + ? -1 + : std::stoi(std::string(mla_dec_chunk_size_env)); + return bsz > 1 ? mla_dec_chunk_size : 64; +} diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h index 8256d43cd..a99dac1e2 100644 --- a/custom_ops/gpu_ops/helper.h +++ b/custom_ops/gpu_ops/helper.h @@ -563,3 +563,11 @@ inline int GetSMVersion() { return sm_version; } + +inline bool GetMlaUseTensorcore() { + static const bool flags_mla_use_tensorcore = get_mla_use_tensorcore(); + static const bool enable_mla_tensorcore = GetSMVersion() >= 90 ? true : false; + const bool mla_use_tensorcore = + flags_mla_use_tensorcore && enable_mla_tensorcore; + return mla_use_tensorcore; +} diff --git a/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu b/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu index f7d4b8ae2..3b4c0b1e1 100644 --- a/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu +++ b/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.cu @@ -70,7 +70,6 @@ void BatchMLAWithPagedKVCacheKernel( const paddle::optional& smooth_weight, // [num_kv_heads, head_dim] const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& batch_id_per_token, const paddle::Tensor& block_tables, @@ -78,9 +77,8 @@ void BatchMLAWithPagedKVCacheKernel( const paddle::Tensor& tile_ids_per_batch, const paddle::Tensor& num_blocks_x_device, const std::string& cache_quant_type_str, - const int num_blocks_x, + const paddle::Tensor& decoder_chunk_size_device, const int max_seq_len, - const int max_dec_len, const float softmax_scale, const float quant_max_bound, const float quant_min_bound, @@ -97,14 +95,12 @@ void BatchMLAWithPagedKVCacheKernel( const auto q_head_num = meta_data.q_num_heads; const auto max_block_num_per_seq = meta_data.max_blocks_per_seq; const auto max_block_num = bsz * max_block_num_per_seq; - const uint32_t chunk_size = get_max_partition_size(bsz); - int q_head_dim = meta_data.head_dims; int k_head_dim = meta_data.head_dims; int v_head_dim = meta_data.head_dims_v; // int num_chunks = max_dec_len / chunk_size; - int num_chunks = div_up(max_dec_len, chunk_size); + int num_chunks = div_up(max_seq_len, 64); auto *allocator = paddle::GetAllocator(q.place()); phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp; @@ -127,14 +123,14 @@ void BatchMLAWithPagedKVCacheKernel( params.d = reinterpret_cast(d_tmp->ptr()); params.block_tables = const_cast(block_tables.data()); params.seq_lens_this_time = const_cast(seq_lens_this_time.data()); - params.seq_lens_encoder = const_cast(seq_lens_encoder.data()); params.seq_lens_decoder = const_cast(seq_lens_decoder.data()); params.cumsum_q_seqlens = const_cast(cu_seqlens_q.data()); params.batch_id_per_token = const_cast(batch_id_per_token.data()); params.batch_ids = const_cast(batch_ids.data()); params.tile_ids_per_batch = const_cast(tile_ids_per_batch.data()); params.num_blocks_x = const_cast(num_blocks_x_device.data()); - params.num_blocks_x_int = num_blocks_x; + params.chunk_size_device = + const_cast(decoder_chunk_size_device.data()); params.q_stride_bsz = q_head_num * q_head_dim; params.q_stride_head_num = q_head_dim; params.kv_stride_block_num = block_size * k_head_dim; @@ -151,7 +147,6 @@ void BatchMLAWithPagedKVCacheKernel( params.block_size = block_size; params.max_draft_token_num = draft_token_num; params.sm_scale = softmax_scale; - params.chunk_size = chunk_size; params.chunk_num = num_chunks; if (q_head_dim == 576) { @@ -176,7 +171,6 @@ template void BatchMLAWithPagedKVCacheKernel( const paddle::optional& smooth_weight, // [num_kv_heads, head_dim] const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& batch_id_per_token, const paddle::Tensor& block_tables, @@ -184,9 +178,8 @@ template void BatchMLAWithPagedKVCacheKernel( const paddle::Tensor& tile_ids_per_batch, const paddle::Tensor& num_blocks_x_device, const std::string& cache_quant_type_str, - const int num_blocks_x, + const paddle::Tensor& decoder_chunk_size_device, const int max_seq_len, - const int max_dec_len, const float softmax_scale, const float quant_max_bound, const float quant_min_bound, @@ -210,7 +203,6 @@ template void BatchMLAWithPagedKVCacheKernel( const paddle::optional& smooth_weight, // [num_kv_heads, head_dim] const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& batch_id_per_token, const paddle::Tensor& block_tables, @@ -218,9 +210,8 @@ template void BatchMLAWithPagedKVCacheKernel( const paddle::Tensor& tile_ids_per_batch, const paddle::Tensor& num_blocks_x_device, const std::string& cache_quant_type_str, - const int num_blocks_x, + const paddle::Tensor& decoder_chunk_size_device, const int max_seq_len, - const int max_dec_len, const float softmax_scale, const float quant_max_bound, const float quant_min_bound, diff --git a/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h b/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h index 97fffe39d..62f5d5e9f 100644 --- a/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h +++ b/custom_ops/gpu_ops/mla_attn/batch_mla_with_paged_kv_cache.h @@ -47,7 +47,6 @@ void BatchMLAWithPagedKVCacheKernel( const paddle::optional& smooth_weight, // [num_kv_heads, head_dim] const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& batch_id_per_token, const paddle::Tensor& block_tables, @@ -55,9 +54,8 @@ void BatchMLAWithPagedKVCacheKernel( const paddle::Tensor& tile_ids_per_batch, const paddle::Tensor& num_blocks_x_device, const std::string& cache_quant_type_str, - const int num_blocks_x, + const paddle::Tensor& decoder_chunk_size_device, const int max_seq_len, - const int max_dec_len, const float softmax_scale, const float quant_max_bound, const float quant_min_bound, diff --git a/custom_ops/gpu_ops/mla_attn/mainloop_load.cuh b/custom_ops/gpu_ops/mla_attn/mainloop_load.cuh index 9c67f601f..4ee350f76 100644 --- a/custom_ops/gpu_ops/mla_attn/mainloop_load.cuh +++ b/custom_ops/gpu_ops/mla_attn/mainloop_load.cuh @@ -128,12 +128,13 @@ struct CollectiveMainloop { DTypeMD const* d_ptr; IdType const* kv_block_tables; IdType const* seq_lens_this_time; - IdType const* seq_lens_encoder; + // IdType const* seq_lens_encoder; IdType const* seq_lens_decoder; IdType const* cumsum_q_seqlens; IdType const* batch_ids; IdType const* tile_ids_per_batch; IdType const* num_blocks_x; + IdType const* chunk_size_device; float sm_scale; int bsz; int max_block_num; @@ -144,7 +145,7 @@ struct CollectiveMainloop { int kv_stride_block_size; int o_stride_bsz; int o_stride_head_num; - int chunk_size; + // int chunk_size; int chunk_num; int max_draft_token_num; }; @@ -160,12 +161,13 @@ struct CollectiveMainloop { DTypeMD* d_ptr; IdType* kv_block_tables; IdType* seq_lens_this_time; - IdType* seq_lens_encoder; + // IdType* seq_lens_encoder; IdType* seq_lens_decoder; IdType* cumsum_q_seqlens; IdType* batch_ids; IdType* tile_ids_per_batch; IdType* num_blocks_x; + IdType* chunk_size_device; float sm_scale; int bsz; int max_block_num; @@ -176,7 +178,7 @@ struct CollectiveMainloop { int kv_stride_block_size; int o_stride_bsz; int o_stride_head_num; - int chunk_size; + // int chunk_size; int chunk_num; int max_draft_token_num; TMA_KV tma_load_KV; @@ -198,12 +200,13 @@ struct CollectiveMainloop { const_cast(args.d_ptr), const_cast(args.kv_block_tables), const_cast(args.seq_lens_this_time), - const_cast(args.seq_lens_encoder), + // const_cast(args.seq_lens_encoder), const_cast(args.seq_lens_decoder), const_cast(args.cumsum_q_seqlens), const_cast(args.batch_ids), const_cast(args.tile_ids_per_batch), const_cast(args.num_blocks_x), + const_cast(args.chunk_size_device), args.sm_scale, args.bsz, args.max_block_num, @@ -214,7 +217,7 @@ struct CollectiveMainloop { args.kv_stride_block_size, args.o_stride_bsz, args.o_stride_head_num, - args.chunk_size, + // args.chunk_size, args.chunk_num, args.max_draft_token_num, tma_load_KV @@ -281,9 +284,9 @@ struct CollectiveMainloop { auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx); static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{}); - const int start_len = tile_idx * mainloop_params.chunk_size; + const int start_len = tile_idx * mainloop_params.chunk_size_device[0]; const int start_tile_idx = start_len / BLOCK_SHAPE_KV; - const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1; + const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1; auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1))); @@ -322,9 +325,9 @@ struct CollectiveMainloop { group_modes<0, 2>(sK), group_modes<0, 2>(gKV)); static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{}); - const int start_len = tile_idx * mainloop_params.chunk_size; + const int start_len = tile_idx * mainloop_params.chunk_size_device[0]; const int start_tile_idx = start_len / BLOCK_SHAPE_KV; - const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1; + const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1; auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1))); diff --git a/custom_ops/gpu_ops/mla_attn/mainloop_mma.cuh b/custom_ops/gpu_ops/mla_attn/mainloop_mma.cuh index 77d059583..b6cdabb8f 100644 --- a/custom_ops/gpu_ops/mla_attn/mainloop_mma.cuh +++ b/custom_ops/gpu_ops/mla_attn/mainloop_mma.cuh @@ -57,7 +57,7 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage; static_assert(is_rmem::value, "O tensor must be rmem resident."); - const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size); + const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size_device[0]); static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{}); static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{}); @@ -84,9 +84,9 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params, Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2); Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS); - const int start_len = tile_idx * mainloop_params.chunk_size; + const int start_len = tile_idx * mainloop_params.chunk_size_device[0]; const int start_tile_idx = start_len / BLOCK_SHAPE_KV; - const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1; + const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1; int kv_tile_idx = end_tile_idx; auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { @@ -263,7 +263,7 @@ CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params, using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage; static_assert(is_rmem::value, "O tensor must be rmem resident."); - const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size); + const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size_device[0]); static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{}); static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{}); @@ -295,9 +295,9 @@ CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params, Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4); Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS); - const int start_len = tile_idx * mainloop_params.chunk_size; + const int start_len = tile_idx * mainloop_params.chunk_size_device[0]; const int start_tile_idx = start_len / BLOCK_SHAPE_KV; - const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1; + const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1; int kv_tile_idx = end_tile_idx; auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { diff --git a/custom_ops/gpu_ops/mla_attn/mla_hopper.cuh b/custom_ops/gpu_ops/mla_attn/mla_hopper.cuh index ba1f4b447..2d55d91e5 100644 --- a/custom_ops/gpu_ops/mla_attn/mla_hopper.cuh +++ b/custom_ops/gpu_ops/mla_attn/mla_hopper.cuh @@ -62,13 +62,12 @@ struct Params { alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head] alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head] alignas(16) DTypeO *O; // [token_num, head_num, dim_head] - alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head] - alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num] - alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num] + alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head] + alignas(16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num] + alignas(16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num] alignas(16) IdType *block_tables; alignas(16) IdType *seq_lens_this_time; - alignas(16) IdType *seq_lens_encoder; alignas(16) IdType *seq_lens_decoder; alignas(16) IdType *cumsum_q_seqlens; alignas(16) IdType *batch_id_per_token; @@ -76,7 +75,7 @@ struct Params { alignas(16) IdType *batch_ids; alignas(16) IdType *tile_ids_per_batch; alignas(16) IdType *num_blocks_x; - + alignas(16) IdType *chunk_size_device; uint32_t q_stride_bsz; uint32_t q_stride_head_num; @@ -96,9 +95,7 @@ struct Params { int vo_head_dim; int block_size; int max_draft_token_num; - int chunk_size; int chunk_num; - int num_blocks_x_int; float sm_scale; }; @@ -118,7 +115,7 @@ struct Params { return cudaErrorNotSupported; \ } -template +template __global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1) MLAWithKVCacheKernel(CUTE_GRID_CONSTANT typename CollectiveMainloop::Params const mainloop_params, @@ -137,6 +134,7 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q; static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV; const int num_blocks_x = mainloop_params.num_blocks_x[0]; + const int chunk_size = mainloop_params.chunk_size_device[0]; static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV; @@ -205,58 +203,10 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state(); PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state(); - if constexpr(USE_FIXED_BLOCK) { - for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) { - const int bid = mainloop_params.batch_ids[i]; - const int tile_id = mainloop_params.tile_ids_per_batch[i]; - const int seq_len_now = mainloop_params.seq_lens_this_time[bid]; - const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid]; - const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now; - const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid]; - cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS, - /*id=*/static_cast(NamedBarriers::kWG0WG1WG2Sync)); - - // load Q - collective_mainloop.load_q( - mainloop_params, - pipeline_q, - smem_pipe_write_q, - shared_storage, - threadIdx.x, - bid); - - if constexpr (!use_tma_load_kv) { - // load kv - collective_mainloop.load_kv( - mainloop_params, - pipeline_kv, - smem_pipe_write_kv, - shared_storage, - bid, - seq_len_decoder_now, - tile_id - ); - } else { - if (warp_idx_in_warpgroup == 0) { - // load kv tma - collective_mainloop.load_kv_tma( - mainloop_params, - pipeline_kv, - smem_pipe_write_kv, - shared_storage, - bid, - seq_len_decoder_now, - tile_id - ); - } - } - } - } else { - const int block_id = blockIdx.x; - const int bid = mainloop_params.batch_ids[block_id]; - const int tile_id = mainloop_params.tile_ids_per_batch[block_id]; + for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) { + const int bid = mainloop_params.batch_ids[i]; + const int tile_id = mainloop_params.tile_ids_per_batch[i]; const int seq_len_now = mainloop_params.seq_lens_this_time[bid]; - const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid]; const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now; const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid]; cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS, @@ -309,76 +259,12 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{})); auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale); - if constexpr(USE_FIXED_BLOCK) { - for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) { - clear(tOrO); - clear(attention_updater.scores_scale); - const int bid = mainloop_params.batch_ids[i]; - const int tile_id = mainloop_params.tile_ids_per_batch[i]; - const int seq_len_now = mainloop_params.seq_lens_this_time[bid]; - const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid]; - const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now; - const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid]; - cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS, - /*id=*/static_cast(NamedBarriers::kWG0WG1WG2Sync)); - - if constexpr (BLOCK_SHAPE_KV == 64) { - mma_f16( - mainloop_params, - pipeline_q, - smem_pipe_read_q, - pipeline_kv, - smem_pipe_read_kv, - tOrO, - attention_updater, - threadIdx.x - NUM_COPY_THREADS, - bid, - seq_len_decoder_now, - seq_len_now, - tile_id, - shared_storage); - } else if (BLOCK_SHAPE_KV == 32) { - mma_f16_two_stages( - mainloop_params, - pipeline_q, - smem_pipe_read_q, - pipeline_kv, - smem_pipe_read_kv, - tOrO, - attention_updater, - threadIdx.x - NUM_COPY_THREADS, - bid, - seq_len_decoder_now, - seq_len_now, - tile_id, - shared_storage); - } - - collective_epilogue.store( - epilogue_params, - tOrO, - attention_updater.get_lse(), - shared_storage, - tiled_mma_pv, - threadIdx.x - NUM_COPY_THREADS, - bid, - mainloop_params.bsz, - seq_len_now, - start_token_idx, - tile_id, - seq_len_decoder_now, - mainloop_params.chunk_size, - mainloop_params.max_draft_token_num, - mainloop_params.o_stride_bsz); - } - } else { - const int block_id = blockIdx.x; + for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) { clear(tOrO); clear(attention_updater.scores_scale); - const int bid = mainloop_params.batch_ids[block_id]; - const int tile_id = mainloop_params.tile_ids_per_batch[block_id]; + const int bid = mainloop_params.batch_ids[i]; + const int tile_id = mainloop_params.tile_ids_per_batch[i]; const int seq_len_now = mainloop_params.seq_lens_this_time[bid]; - const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid]; const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now; const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid]; cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS, @@ -429,15 +315,15 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT start_token_idx, tile_id, seq_len_decoder_now, - mainloop_params.chunk_size, + chunk_size, mainloop_params.max_draft_token_num, mainloop_params.o_stride_bsz); - } + } } } -template +template cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params, cudaStream_t stream) { using DTypeQ = typename KernelTraits::DTypeQ; @@ -460,12 +346,12 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params, params.d, params.block_tables, params.seq_lens_this_time, - params.seq_lens_encoder, params.seq_lens_decoder, params.cumsum_q_seqlens, params.batch_ids, params.tile_ids_per_batch, params.num_blocks_x, + params.chunk_size_device, params.sm_scale, params.bsz, params.max_block_num, @@ -476,7 +362,6 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params, params.kv_stride_block_size, params.o_stride_bsz, params.o_stride_head_num, - params.chunk_size, params.chunk_num, params.max_draft_token_num }); @@ -500,13 +385,9 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params, cudaOccupancyMaxActiveBlocksPerMultiprocessor( &act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size); - int gridx; - if constexpr(USE_FIXED_BLOCK) { - gridx = multiprocessor_count; - } else { - gridx = params.num_blocks_x_int; - } - dim3 grid_dims = {gridx, 1, 1}; + // NOTE: (changwenbin) Here the grid size is fixed so that MLA can be captured + // by the graph. + dim3 grid_dims = {multiprocessor_count, 1, 1}; static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32; dim3 block_dims(ctaSize, 1, 1); kernel<<>>( @@ -517,37 +398,38 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params, constexpr int merge_block_size = 256; constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size; constexpr int blocky = (merge_block_size + blockx - 1) / blockx; - dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // 128k is too large + dim3 grids_merge(multiprocessor_count, params.q_num_head); // 128k is too large dim3 blocks_merge(blockx, blocky); - merge_multi_chunks_kernel<<>>( - reinterpret_cast(params.O_tmp), - params.m, - params.d, - params.seq_lens_this_time, - params.seq_lens_decoder, - params.seq_lens_encoder, - params.cumsum_q_seqlens, - params.batch_id_per_token, - reinterpret_cast(params.O), - params.chunk_num, - params.q_num_head, - params.chunk_size, - params.vo_head_dim, - params.token_num, - params.bsz, - params.max_draft_token_num - ); + merge_multi_chunks_kernel + <<>>( + reinterpret_cast(params.O_tmp), + params.m, + params.d, + params.seq_lens_this_time, + params.seq_lens_decoder, + params.cumsum_q_seqlens, + params.batch_id_per_token, + params.chunk_size_device, + reinterpret_cast(params.O), + params.q_num_head, + params.vo_head_dim, + params.token_num, + params.bsz, + params.max_draft_token_num); } return cudaSuccess; } -template +template cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) { constexpr bool CAUSAL = true; if constexpr (HEAD_DIM_QK == 576) { DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE, BatchMLAWithPagedKVCacheKernelTraitsDispatched< - AttentionKernelTraits -__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim] - const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads] - const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads] +__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [max_num_chunks, bsz, max_draft_token, num_heads, head_dim] + const float * __restrict__ multi_m, // [max_num_chunks, bsz, max_draft_token, num_heads] + const float * __restrict__ multi_d, // [max_num_chunks, bsz, max_draft_token, num_heads] const int * __restrict__ seq_lens_this_time, const int * __restrict__ seq_lens_decoder, - const int * __restrict__ seq_lens_encoder, const int *__restrict__ cu_seqlens_q, const int * __restrict__ batch_id_per_token, + const int * __restrict__ chunk_size_device, T * __restrict__ out, // [token_num, num_heads, head_dim] - const int num_chunks, const int num_heads, - const int chunk_size, const int head_dim, const int token_num, const int bsz, @@ -271,13 +269,15 @@ __global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [ __shared__ float md_smem[bdy * 2]; for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) { const uint32_t bid = batch_id_per_token[qid]; + // NOTE : (changwenbin) Batch_id_per_token is initialized to [:]=-1, Marking meaningless batch IDs. + if (bid == -1) continue; const int seq_len_q = seq_lens_this_time[bid]; if (seq_len_q == 0) continue; const uint32_t local_seq_id = qid - cu_seqlens_q[bid]; int seq_len_kv = seq_lens_decoder[bid]; if (seq_len_kv == 0) continue; seq_len_kv += seq_len_q; - const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size); + const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size_device[0]); if (num_chunks_this_seq <= 1) { // not need merge continue; diff --git a/custom_ops/gpu_ops/multi_head_latent_attention.cu b/custom_ops/gpu_ops/multi_head_latent_attention.cu index 98a61e838..126b014b8 100644 --- a/custom_ops/gpu_ops/multi_head_latent_attention.cu +++ b/custom_ops/gpu_ops/multi_head_latent_attention.cu @@ -22,23 +22,18 @@ std::vector MultiHeadLatentAttentionKernel( const paddle::Tensor& query, const paddle::Tensor& key_cache, const paddle::Tensor& value_cache, - const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& batch_id_per_token, const paddle::Tensor& block_tables, - const paddle::Tensor& encoder_batch_ids, - const paddle::Tensor& encoder_tile_ids_per_batch, - const paddle::Tensor& encoder_num_blocks, const paddle::Tensor& kv_batch_ids, const paddle::Tensor& kv_tile_ids_per_batch, const paddle::Tensor& kv_num_blocks, const paddle::Tensor& decoder_batch_ids, const paddle::Tensor& decoder_tile_ids_per_batch, const paddle::Tensor& decoder_num_blocks, - const paddle::Tensor& decoder_num_blocks_cpu, - const paddle::Tensor& max_enc_len_this_time, + const paddle::Tensor& decoder_chunk_size_device, const paddle::Tensor& max_dec_len_this_time, const paddle::Tensor& max_len_kv, const paddle::optional& attn_mask, @@ -64,9 +59,12 @@ std::vector MultiHeadLatentAttentionKernel( typedef PDTraits traits_; typedef typename traits_::data_t data_t; - int decoder_num_blocks_data = decoder_num_blocks_cpu.data()[0]; + // NOTE: (changwenbin) In cuda graph, it will be fixed in the capture stage + // int decoder_num_blocks_data = decoder_num_blocks_cpu.data()[0]; int max_dec_len_this_time_data = max_dec_len_this_time.data()[0]; int max_len_kv_data = max_len_kv.data()[0]; + // int chunk_size = decoder_chunk_size_cpu.data()[0]; + // const bool mla_use_tensorcore = get_mla_use_tensorcore(); auto sm_version = GetSMVersion(); @@ -96,7 +94,6 @@ std::vector MultiHeadLatentAttentionKernel( out_linear_smooths, seq_lens_this_time, seq_lens_decoder, - seq_lens_encoder, cu_seqlens_q, batch_id_per_token, block_tables, @@ -104,9 +101,8 @@ std::vector MultiHeadLatentAttentionKernel( decoder_tile_ids_per_batch, decoder_num_blocks, cache_quant_type_str, - decoder_num_blocks_data, + decoder_chunk_size_device, max_input_length, - max_len_kv_data, softmax_scale, quant_max_bound, quant_min_bound, @@ -145,23 +141,18 @@ std::vector MultiHeadLatentAttention( const paddle::Tensor& query, const paddle::Tensor& key_cache, const paddle::Tensor& value_cache, - const paddle::Tensor& seq_lens_encoder, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& batch_id_per_token, const paddle::Tensor& block_tables, - const paddle::Tensor& encoder_batch_ids, - const paddle::Tensor& encoder_tile_ids_per_batch, - const paddle::Tensor& encoder_num_blocks, const paddle::Tensor& kv_batch_ids, const paddle::Tensor& kv_tile_ids_per_batch, const paddle::Tensor& kv_num_blocks, const paddle::Tensor& decoder_batch_ids, const paddle::Tensor& decoder_tile_ids_per_batch, const paddle::Tensor& decoder_num_blocks, - const paddle::Tensor& decoder_num_blocks_cpu, - const paddle::Tensor& max_enc_len_this_time, + const paddle::Tensor& decoder_chunk_size_device, const paddle::Tensor& max_dec_len_this_time, const paddle::Tensor& max_len_kv, const paddle::optional& attn_mask, @@ -208,23 +199,18 @@ std::vector MultiHeadLatentAttention( query, key_cache, value_cache, - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, cu_seqlens_q, batch_id_per_token, block_tables, - encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks, kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks, - decoder_num_blocks_cpu, - max_enc_len_this_time, + decoder_chunk_size_device, max_dec_len_this_time, max_len_kv, attn_mask, @@ -254,23 +240,18 @@ std::vector MultiHeadLatentAttention( query, key_cache, value_cache, - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, cu_seqlens_q, batch_id_per_token, block_tables, - encoder_batch_ids, - encoder_tile_ids_per_batch, - encoder_num_blocks, kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks, decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks, - decoder_num_blocks_cpu, - max_enc_len_this_time, + decoder_chunk_size_device, max_dec_len_this_time, max_len_kv, attn_mask, @@ -307,23 +288,18 @@ std::vector> MultiHeadLatentAttentionInferShape( const std::vector& query_shape, const std::vector& key_cache_shape, const std::vector& value_cache_shape, - const std::vector& seq_lens_encoder_shape, const std::vector& seq_lens_decoder_shape, const std::vector& seq_lens_this_time_shape, const std::vector& cu_seqlens_q_shape, const std::vector& batch_id_per_token_shape, const std::vector& block_tables_shape, - const std::vector& encoder_batch_ids_shape, - const std::vector& encoder_tile_ids_per_batch_shape, - const std::vector& encoder_num_blocks_shape, const std::vector& kv_batch_ids_shape, const std::vector& kv_tile_ids_per_batch_shape, const std::vector& kv_num_blocks_shape, const std::vector& decoder_batch_ids_shape, const std::vector& decoder_tile_ids_per_batch_shape, const std::vector& decoder_num_blocks_shape, - const std::vector& decoder_num_blocks_cpu_shape, - const std::vector& max_enc_len_this_time_shape, + const std::vector& decoder_chunk_size_device_shape, const std::vector& max_dec_len_this_time_shape, const std::vector& max_len_kv_shape, const paddle::optional>& attn_mask_shape, @@ -361,23 +337,18 @@ std::vector MultiHeadLatentAttentionInferDtype( const paddle::DataType& query_dtype, const paddle::DataType& key_cache_dtype, const paddle::DataType& value_cache_dtype, - const paddle::DataType& seq_lens_encoder_dtype, const paddle::DataType& seq_lens_decoder_dtype, const paddle::DataType& seq_lens_this_time_dtype, const paddle::DataType& cu_seqlens_q_dtype, const paddle::DataType& batch_id_per_token_dtype, const paddle::DataType& block_tables_dtype, - const paddle::DataType& encoder_batch_ids_dtype, - const paddle::DataType& encoder_tile_ids_per_batch_dtype, - const paddle::DataType& encoder_num_blocks_dtype, const paddle::DataType& kv_batch_ids_dtype, const paddle::DataType& kv_tile_ids_per_batch_dtype, const paddle::DataType& kv_num_blocks_dtype, const paddle::DataType& decoder_batch_ids_dtype, const paddle::DataType& decoder_tile_ids_per_batch_dtype, const paddle::DataType& decoder_num_blocks_dtype, - const paddle::DataType& decoder_num_blocks_cpu_dtype, - const paddle::DataType& max_enc_len_this_time_dtype, + const paddle::DataType& decoder_chunk_size_device_dtype, const paddle::DataType& max_dec_len_this_time_dtype, const paddle::DataType& max_len_kv_dtype, const paddle::optional& attn_mask_dtype, @@ -415,23 +386,18 @@ PD_BUILD_STATIC_OP(multi_head_latent_attention) .Inputs({"query", "key_cache", "value_cache", - "seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time", "cu_seqlens_q", "batch_id_per_token", "block_tables", - "encoder_batch_ids", - "encoder_tile_ids_per_batch", - "encoder_num_blocks", "kv_batch_ids", "kv_tile_ids_per_batch", "kv_num_blocks", "decoder_batch_ids", "decoder_tile_ids_per_batch", "decoder_num_blocks", - "decoder_num_blocks_cpu", - "max_enc_len_this_time", + "decoder_chunk_size_device", "max_dec_len_this_time", "max_len_kv", paddle::Optional("attn_mask"), diff --git a/docs/quantization/wint2.png b/docs/quantization/wint2.png deleted file mode 100644 index a117ea8af..000000000 Binary files a/docs/quantization/wint2.png and /dev/null differ diff --git a/docs/zh/quantization/wint2.png b/docs/zh/quantization/wint2.png deleted file mode 100644 index a117ea8af..000000000 Binary files a/docs/zh/quantization/wint2.png and /dev/null differ diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index c775beaaf..f0888302d 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -99,6 +99,8 @@ class ForwardMeta: decoder_batch_ids: Optional[paddle.Tensor] = None # Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the decoder stage in multi_query_append_attention_warp1_4_kernel. decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None + # The number of blocks that attention backend can use in decode stage + decoder_num_blocks_device: Optional[paddle.Tensor] = None # The number of CUDA blocks to launch in the x-dimension for the multi_query_append_attention_warp1_4_kernel, defining its grids.x. decoder_num_blocks_cpu: Optional[paddle.Tensor] = None # A tensor that holds multiple lengths related to prefill or decode stages. @@ -118,6 +120,8 @@ class ForwardMeta: # The maximum sequence length of the KV cache, which may represent the current maximum decoder length. max_len_kv_cpu: Optional[paddle.Tensor] = None + decoder_chunk_size_device: Optional[paddle.Tensor] = None + # Sequence length of encoder for ever batch seq_lens_encoder: Optional[paddle.Tensor] = None # Sequence length of Encoder for ever batch diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 64023e7e2..d42c4b80c 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -141,6 +141,8 @@ class AppendAttentionBackend(AttentionBackend): forward_meta.decoder_batch_ids, forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, + forward_meta.decoder_num_blocks_device, + forward_meta.decoder_chunk_size_device, forward_meta.max_len_tensor_cpu, forward_meta.encoder_batch_ids, forward_meta.encoder_tile_ids_per_batch, diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index c4c504368..6038fe4ca 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -198,6 +198,8 @@ class FlashAttentionBackend(AttentionBackend): forward_meta.decoder_batch_ids, forward_meta.decoder_tile_ids_per_batch, forward_meta.decoder_num_blocks_cpu, + forward_meta.decoder_num_blocks_device, + forward_meta.decoder_chunk_size_device, forward_meta.max_len_tensor_cpu, forward_meta.encoder_batch_ids, forward_meta.encoder_tile_ids_per_batch, diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 724f6eae5..896742962 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -187,9 +187,11 @@ class MLAAttentionBackend(AttentionBackend): forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, - forward_meta.decoder_batch_ids, - forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_batch_ids, # decoder_batch_ids_per_ctax + forward_meta.decoder_tile_ids_per_batch, # decoder_chunk_ids_per_ctax_each_batch forward_meta.decoder_num_blocks_cpu, + forward_meta.decoder_num_blocks_device, + forward_meta.decoder_chunk_size_device, forward_meta.max_len_tensor_cpu, forward_meta.encoder_batch_ids, forward_meta.encoder_tile_ids_per_batch, @@ -347,23 +349,18 @@ class MLAAttentionBackend(AttentionBackend): q, latent_cache, latent_cache, - forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, forward_meta.cu_seqlens_q, forward_meta.batch_id_per_token, metadata.block_tables, - forward_meta.encoder_batch_ids, - forward_meta.encoder_tile_ids_per_batch, - forward_meta.encoder_num_blocks_x_cpu, forward_meta.kv_batch_ids, forward_meta.kv_tile_ids_per_batch, forward_meta.kv_num_blocks_x_cpu, forward_meta.decoder_batch_ids, forward_meta.decoder_tile_ids_per_batch, - forward_meta.decoder_num_blocks_cpu, - forward_meta.decoder_num_blocks_cpu, - metadata.max_enc_len_this_time, + forward_meta.decoder_num_blocks_device, + forward_meta.decoder_chunk_size_device, metadata.max_dec_len_this_time, forward_meta.max_len_kv_cpu, None, # attn_mask @@ -468,23 +465,18 @@ class MLAAttentionBackend(AttentionBackend): q, latent_cache, latent_cache, - forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, forward_meta.cu_seqlens_q, forward_meta.batch_id_per_token, metadata.block_tables, - forward_meta.encoder_batch_ids, - forward_meta.encoder_tile_ids_per_batch, - forward_meta.encoder_num_blocks_x_cpu, forward_meta.kv_batch_ids, forward_meta.kv_tile_ids_per_batch, forward_meta.kv_num_blocks_x_cpu, forward_meta.decoder_batch_ids, forward_meta.decoder_tile_ids_per_batch, - forward_meta.decoder_num_blocks_cpu, - forward_meta.decoder_num_blocks_cpu, - metadata.max_enc_len_this_time, + forward_meta.decoder_num_blocks_device, + forward_meta.decoder_chunk_size_device, metadata.max_dec_len_this_time, forward_meta.max_len_kv_cpu, None, # attn_mask diff --git a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py index 68a7402b8..edcf8a692 100644 --- a/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py +++ b/fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py @@ -30,7 +30,9 @@ def get_block_shape_and_split_kv_block( seq_lens_this_time: paddle.Tensor, decoder_batch_ids: paddle.Tensor, decoder_tile_ids_per_batch: paddle.Tensor, - decoder_num_blocks_x_cpu: paddle.Tensor, + decoder_num_blocks_cpu: paddle.Tensor, + decoder_num_blocks_device: paddle.Tensor, + decoder_chunk_size_device: paddle.Tensor, max_len_tensor_cpu: paddle.Tensor, encoder_batch_ids: paddle.Tensor, encoder_tile_ids_per_batch: paddle.Tensor, @@ -55,7 +57,9 @@ def get_block_shape_and_split_kv_block( seq_lens_this_time, decoder_batch_ids, decoder_tile_ids_per_batch, - decoder_num_blocks_x_cpu, + decoder_num_blocks_cpu, + decoder_num_blocks_device, + decoder_chunk_size_device, max_len_tensor_cpu, encoder_batch_ids, encoder_tile_ids_per_batch, diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 41c18eb54..05408dfcc 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -210,6 +210,12 @@ class MTPProposer(Proposer): self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like( self.target_model_inputs["decoder_num_blocks_cpu"] ).pin_memory() + self.model_inputs["decoder_num_blocks_device"] = paddle.zeros_like( + self.target_model_inputs["decoder_num_blocks_device"] + ) + self.model_inputs["decoder_chunk_size_device"] = paddle.zeros_like( + self.target_model_inputs["decoder_chunk_size_device"] + ) self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like( self.target_model_inputs["max_len_tensor_cpu"] ).cpu() @@ -338,6 +344,8 @@ class MTPProposer(Proposer): self.model_inputs["decoder_batch_ids"] = None self.model_inputs["decoder_tile_ids_per_batch"] = None self.model_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory + self.model_inputs["decoder_num_blocks_device"] = None + self.model_inputs["decoder_chunk_size_device"] = None self.model_inputs["max_len_tensor_cpu"] = None # CPU self.model_inputs["encoder_batch_ids"] = None self.model_inputs["encoder_tile_ids_per_batch"] = None @@ -528,6 +536,8 @@ class MTPProposer(Proposer): decoder_batch_ids=self.model_inputs["decoder_batch_ids"], decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"], decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"], + decoder_num_blocks_device=self.model_inputs["decoder_num_blocks_device"], + decoder_chunk_size_device=self.model_inputs["decoder_chunk_size_device"], max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"], seq_lens_encoder=self.model_inputs["seq_lens_encoder"], seq_lens_decoder=self.model_inputs["seq_lens_decoder"], diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 8bb0239d5..9a7e2bdf5 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -838,6 +838,8 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["decoder_batch_ids"] = None self.share_inputs["decoder_tile_ids_per_batch"] = None self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory + self.share_inputs["decoder_num_blocks_device"] = None + self.share_inputs["decoder_chunk_size_device"] = None self.share_inputs["max_len_tensor_cpu"] = None # CPU self.share_inputs["encoder_batch_ids"] = None self.share_inputs["encoder_tile_ids_per_batch"] = None @@ -991,6 +993,8 @@ class GPUModelRunner(ModelRunnerBase): ) self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) + # NOTE: (changwenbin) Initialized to max_num_seq '-1' before copying, marking illegal positions + self.share_inputs["batch_id_per_token"][:] = -1 self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False) self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) @@ -1070,6 +1074,10 @@ class GPUModelRunner(ModelRunnerBase): decoder_batch_ids=self.share_inputs["decoder_batch_ids"], decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"], decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"], + # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor, + # adapted to cudagraph. + decoder_num_blocks_device=self.share_inputs["decoder_num_blocks_device"], + decoder_chunk_size_device=self.share_inputs["decoder_chunk_size_device"], max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"], seq_lens_encoder=self.share_inputs["seq_lens_encoder"], seq_lens_decoder=self.share_inputs["seq_lens_decoder"], @@ -1196,8 +1204,12 @@ class GPUModelRunner(ModelRunnerBase): decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1 group_size = np.ceil(num_heads / self.model_config.kv_num_heads) - decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( - (decoder_step_token_num * group_size) / decoder_block_shape_q + # NOTE: (changwenbin) When using auto_chunk, + # decode_max_tile_size must take into account the maximum case, where *1024 can cover 128K. + decode_max_tile_size = ( + 1024 + * self.parallel_config.max_num_seqs + * np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q) ) encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil( (self.model_config.max_model_len * group_size) / encoder_block_shape_q @@ -1208,6 +1220,10 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory() + # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor, + # adapted to cudagraph. + self.share_inputs["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32") + self.share_inputs["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32") self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu() self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") diff --git a/tests/layers/test_append_attention.py b/tests/layers/test_append_attention.py index 5b85404c7..31b12e539 100644 --- a/tests/layers/test_append_attention.py +++ b/tests/layers/test_append_attention.py @@ -380,10 +380,12 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.max_enc_len_this_time = paddle.to_tensor([self.max_enc_len_this_time], "int32", place=paddle.CPUPlace()) self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace()) self.seq_lens_this_time = self.seq_lens_encoder - - self.decoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") - self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32") + decode_max_tile_size = 1024 * self.batch_size * np.ceil((2 * 10) / 12) + self.decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + self.decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory() + self.decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + self.decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu() self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") @@ -484,6 +486,8 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.decoder_batch_ids, self.decoder_tile_ids_per_batch, self.decoder_num_blocks_cpu, + self.decoder_num_blocks_device, + self.decoder_chunk_size_device, self.max_len_tensor_cpu, self.encoder_batch_ids, self.encoder_tile_ids_per_batch, diff --git a/tests/layers/test_append_attention_with_output.py b/tests/layers/test_append_attention_with_output.py index 47cc1f384..c198d1291 100644 --- a/tests/layers/test_append_attention_with_output.py +++ b/tests/layers/test_append_attention_with_output.py @@ -378,9 +378,12 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace()) self.seq_lens_this_time = self.seq_lens_encoder - self.decoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") - self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32") + decode_max_tile_size = 1024 * self.batch_size * np.ceil((2 * 10) / 12) + self.decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + self.decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory() + self.decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + self.decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu() self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32") self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32") @@ -464,6 +467,8 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase): self.decoder_batch_ids, self.decoder_tile_ids_per_batch, self.decoder_num_blocks_cpu, + self.decoder_num_blocks_device, + self.decoder_chunk_size_device, self.max_len_tensor_cpu, self.encoder_batch_ids, self.encoder_tile_ids_per_batch, diff --git a/tests/operators/test_tree_mask.py b/tests/operators/test_tree_mask.py index 1d8c81b12..a6bb8bd46 100644 --- a/tests/operators/test_tree_mask.py +++ b/tests/operators/test_tree_mask.py @@ -192,7 +192,7 @@ class TestTreeMask(unittest.TestCase): decoder_block_shape_q = 16 group_size = self.num_q_head // self.num_kv_head decode_max_tile_size = ( - self.bsz * (decoder_step_token_num * group_size + decoder_block_shape_q - 1) / decoder_block_shape_q + 1024 * self.bsz * (decoder_step_token_num * group_size + decoder_block_shape_q - 1) / decoder_block_shape_q ) encode_max_tile_size = ( self.bsz * (self.max_seq_len * group_size + encoder_block_shape_q - 1) / encoder_block_shape_q @@ -202,6 +202,8 @@ class TestTreeMask(unittest.TestCase): decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory() + decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu() encoder_batch_ids = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") encoder_tile_ids_per_batch = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") @@ -222,6 +224,8 @@ class TestTreeMask(unittest.TestCase): decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks, + decoder_num_blocks_device, + decoder_chunk_size_device, max_len_tensor_cpu, encoder_batch_ids, encoder_tile_ids_per_batch,