【Inference Optimize】DeepSeek-V3-model MLA Optimize (#3886)

* support MLA chunk_size auto search & cuda_graph
This commit is contained in:
AIbin
2025-09-11 10:46:09 +08:00
committed by GitHub
parent 637d96c6ae
commit a7392a0ff9
23 changed files with 375 additions and 310 deletions

View File

@@ -11,10 +11,11 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cute/tensor.hpp"
#include "helper.h"
#include "paddle/extension.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "utils.cuh"
template <int THREADBLOCK_SIZE>
__global__ void
@@ -116,6 +117,93 @@ void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
max_len_tensor.data<int>(), batch_size);
}
template <uint32_t config_size>
__global__ void search_chunk_size_for_mla(
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ seq_lens_decoder,
int *__restrict__ num_blocks_x,
int *__restrict__ res_chunk_size,
const int bsz,
const int set_chunk_size,
const int block_size,
const int sm_cout) {
const uint32_t conf_id = threadIdx.x;
int gridx = 0;
if (set_chunk_size > 0 && conf_id == 0) {
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
if (seq_len == 0 || seq_len_encoder > 0) continue;
int loop_times;
loop_times = cute::ceil_div(seq_len_decoder, set_chunk_size);
gridx += loop_times;
}
*num_blocks_x = gridx;
*res_chunk_size = set_chunk_size;
} else if (conf_id < config_size) {
__shared__ int gridx_shared[config_size];
// chunk_size is a multiple of 64
const int chunk_size = block_size << conf_id;
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
if (seq_len == 0 || seq_len_encoder > 0) continue;
int loop_times;
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
gridx += loop_times;
}
gridx_shared[conf_id] = gridx;
__syncthreads();
if (threadIdx.x == 0) {
uint32_t res_id = 0;
uint32_t max_last_wave_block = 0;
for (uint32_t i = 1; i < config_size; i++) {
uint32_t last_wave_block = gridx_shared[i] % sm_cout;
if (last_wave_block >= max_last_wave_block) {
res_id = i;
max_last_wave_block = last_wave_block;
}
}
*num_blocks_x = gridx_shared[res_id];
*res_chunk_size = block_size << res_id;
}
}
}
__global__ void split_block_for_mla(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ seq_lens_decoder,
int *__restrict__ batch_ids,
int *__restrict__ tile_ids_per_batch,
const int bsz,
const int chunk_size) {
if (threadIdx.x == 0) {
int index = 0;
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;
if (seq_len == 0) continue;
int loop_times;
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
if (seq_len_encoder > 0) {
loop_times = 0;
}
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
}
}
}
}
__global__ void split_q_block(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
int *__restrict__ batch_ids,
@@ -197,7 +285,9 @@ void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_device, // Inplace
paddle::Tensor &decoder_chunk_size_device, // Inplace
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
@@ -230,8 +320,6 @@ void GetBlockShapeAndSplitKVBlock(
int max_system_len = max_len_cpu_ptr[6];
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
@@ -241,6 +329,106 @@ void GetBlockShapeAndSplitKVBlock(
max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);
// decoder
if (max_dec_len_this_time > 0) {
const bool mla_use_tensorcore = GetMlaUseTensorcore();
if (mla_use_tensorcore && group_size <= 64) {
const int set_chunk_size = get_mla_dec_chunk_size(bsz);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
int device;
cudaGetDevice(&device);
int sm_cout;
cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device);
constexpr int config_size =
12; // search space for chunk size:[64, 128, 256, ... 131072]
search_chunk_size_for_mla<config_size>
<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
decoder_num_blocks_device.data<int>(),
decoder_chunk_size_device.data<int>(),
bsz,
set_chunk_size,
block_size,
sm_cout);
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
auto decoder_chunk_size_cpu =
decoder_chunk_size_device.copy_to(paddle::CPUPlace(), false);
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
// NOTE: (changwenbin) When using auto_chunk,
// decode_max_tile_size must take into account the maximum case, where * 1024 can cover 128K.
// const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024;
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape =
bsz * 1024 * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_batch_ids.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(),
0,
decoder_batch_shape * sizeof(int32_t),
stream));
split_block_for_mla<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
bsz,
chunk_size);
} else {
// Note:(changwenbin)In order to adapt to cudagraph, the maximum value should be taken here
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape = bsz * 1024 * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
split_q_block<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_device.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
}
// encoder
if (max_enc_len_this_time > 0) {
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
@@ -272,28 +460,6 @@ void GetBlockShapeAndSplitKVBlock(
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
}
if (max_just_dec_len_this_time > 0) {
// Clear buffer
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
auto decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
}
}
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
@@ -303,7 +469,9 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
"seq_lens_this_time",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks_x_cpu",
"decoder_num_blocks_cpu",
"decoder_num_blocks_device",
"decoder_chunk_size_device",
"max_len_tensor_cpu",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",

View File

@@ -63,7 +63,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &decoder_batch_ids,
const paddle::Tensor &decoder_tile_ids_per_batch,
const paddle::Tensor &decoder_num_blocks,
const paddle::Tensor &decoder_num_blocks_cpu,
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
const paddle::optional<paddle::Tensor> &rotary_embs,
const paddle::optional<paddle::Tensor> &attn_mask,
@@ -105,7 +105,7 @@ void AppendAttentionWithOutput(
const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &decoder_batch_ids,
const paddle::Tensor &decoder_tile_ids_per_batch,
const paddle::Tensor &decoder_num_blocks,
const paddle::Tensor &decoder_num_blocks_cpu,
const paddle::Tensor &set_max_lengths, const paddle::Tensor &max_len_kv,
paddle::Tensor &fmha_out,
const paddle::optional<paddle::Tensor> &rotary_embs,
@@ -305,7 +305,9 @@ void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_device, // Inplace
paddle::Tensor &decoder_chunk_size_device, // Inplace
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
@@ -473,23 +475,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& decoder_num_blocks_device,
const paddle::Tensor& decoder_chunk_size_device,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,

View File

@@ -59,6 +59,15 @@ inline uint32_t get_cascade_attention_num_threads() {
inline bool get_mla_use_tensorcore() {
static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
static const uint32_t mla_use_tensorcore =
mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
mla_use_tensorcore_env == nullptr ? 0 : std::stoul(std::string(mla_use_tensorcore_env));
return mla_use_tensorcore != 0 ? true : false;
}
inline int get_mla_dec_chunk_size(int bsz) {
static const char* mla_dec_chunk_size_env =
std::getenv("FLAGS_mla_dec_chunk_size");
static const int mla_dec_chunk_size =
mla_dec_chunk_size_env == nullptr
? -1
: std::stoi(std::string(mla_dec_chunk_size_env));
return bsz > 1 ? mla_dec_chunk_size : 64;
}

View File

@@ -563,3 +563,11 @@ inline int GetSMVersion() {
return sm_version;
}
inline bool GetMlaUseTensorcore() {
static const bool flags_mla_use_tensorcore = get_mla_use_tensorcore();
static const bool enable_mla_tensorcore = GetSMVersion() >= 90 ? true : false;
const bool mla_use_tensorcore =
flags_mla_use_tensorcore && enable_mla_tensorcore;
return mla_use_tensorcore;
}

View File

@@ -70,7 +70,6 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
@@ -78,9 +77,8 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const paddle::Tensor& decoder_chunk_size_device,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
@@ -97,14 +95,12 @@ void BatchMLAWithPagedKVCacheKernel(
const auto q_head_num = meta_data.q_num_heads;
const auto max_block_num_per_seq = meta_data.max_blocks_per_seq;
const auto max_block_num = bsz * max_block_num_per_seq;
const uint32_t chunk_size = get_max_partition_size(bsz);
int q_head_dim = meta_data.head_dims;
int k_head_dim = meta_data.head_dims;
int v_head_dim = meta_data.head_dims_v;
// int num_chunks = max_dec_len / chunk_size;
int num_chunks = div_up(max_dec_len, chunk_size);
int num_chunks = div_up(max_seq_len, 64);
auto *allocator = paddle::GetAllocator(q.place());
phi::Allocator::AllocationPtr O_tmp, m_tmp, d_tmp;
@@ -127,14 +123,14 @@ void BatchMLAWithPagedKVCacheKernel(
params.d = reinterpret_cast<float*>(d_tmp->ptr());
params.block_tables = const_cast<int*>(block_tables.data<int>());
params.seq_lens_this_time = const_cast<int*>(seq_lens_this_time.data<int>());
params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
params.batch_id_per_token = const_cast<int*>(batch_id_per_token.data<int>());
params.batch_ids = const_cast<int*>(batch_ids.data<int>());
params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
params.num_blocks_x_int = num_blocks_x;
params.chunk_size_device =
const_cast<int*>(decoder_chunk_size_device.data<int>());
params.q_stride_bsz = q_head_num * q_head_dim;
params.q_stride_head_num = q_head_dim;
params.kv_stride_block_num = block_size * k_head_dim;
@@ -151,7 +147,6 @@ void BatchMLAWithPagedKVCacheKernel(
params.block_size = block_size;
params.max_draft_token_num = draft_token_num;
params.sm_scale = softmax_scale;
params.chunk_size = chunk_size;
params.chunk_num = num_chunks;
if (q_head_dim == 576) {
@@ -176,7 +171,6 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
@@ -184,9 +178,8 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const paddle::Tensor& decoder_chunk_size_device,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,
@@ -210,7 +203,6 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
@@ -218,9 +210,8 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const paddle::Tensor& decoder_chunk_size_device,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,

View File

@@ -47,7 +47,6 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::optional<paddle::Tensor>& smooth_weight, // [num_kv_heads, head_dim]
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
@@ -55,9 +54,8 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::Tensor& tile_ids_per_batch,
const paddle::Tensor& num_blocks_x_device,
const std::string& cache_quant_type_str,
const int num_blocks_x,
const paddle::Tensor& decoder_chunk_size_device,
const int max_seq_len,
const int max_dec_len,
const float softmax_scale,
const float quant_max_bound,
const float quant_min_bound,

View File

@@ -128,12 +128,13 @@ struct CollectiveMainloop {
DTypeMD const* d_ptr;
IdType const* kv_block_tables;
IdType const* seq_lens_this_time;
IdType const* seq_lens_encoder;
// IdType const* seq_lens_encoder;
IdType const* seq_lens_decoder;
IdType const* cumsum_q_seqlens;
IdType const* batch_ids;
IdType const* tile_ids_per_batch;
IdType const* num_blocks_x;
IdType const* chunk_size_device;
float sm_scale;
int bsz;
int max_block_num;
@@ -144,7 +145,7 @@ struct CollectiveMainloop {
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
int chunk_size;
// int chunk_size;
int chunk_num;
int max_draft_token_num;
};
@@ -160,12 +161,13 @@ struct CollectiveMainloop {
DTypeMD* d_ptr;
IdType* kv_block_tables;
IdType* seq_lens_this_time;
IdType* seq_lens_encoder;
// IdType* seq_lens_encoder;
IdType* seq_lens_decoder;
IdType* cumsum_q_seqlens;
IdType* batch_ids;
IdType* tile_ids_per_batch;
IdType* num_blocks_x;
IdType* chunk_size_device;
float sm_scale;
int bsz;
int max_block_num;
@@ -176,7 +178,7 @@ struct CollectiveMainloop {
int kv_stride_block_size;
int o_stride_bsz;
int o_stride_head_num;
int chunk_size;
// int chunk_size;
int chunk_num;
int max_draft_token_num;
TMA_KV tma_load_KV;
@@ -198,12 +200,13 @@ struct CollectiveMainloop {
const_cast<DTypeMD*>(args.d_ptr),
const_cast<IdType*>(args.kv_block_tables),
const_cast<IdType*>(args.seq_lens_this_time),
const_cast<IdType*>(args.seq_lens_encoder),
// const_cast<IdType*>(args.seq_lens_encoder),
const_cast<IdType*>(args.seq_lens_decoder),
const_cast<IdType*>(args.cumsum_q_seqlens),
const_cast<IdType*>(args.batch_ids),
const_cast<IdType*>(args.tile_ids_per_batch),
const_cast<IdType*>(args.num_blocks_x),
const_cast<IdType*>(args.chunk_size_device),
args.sm_scale,
args.bsz,
args.max_block_num,
@@ -214,7 +217,7 @@ struct CollectiveMainloop {
args.kv_stride_block_size,
args.o_stride_bsz,
args.o_stride_head_num,
args.chunk_size,
// args.chunk_size,
args.chunk_num,
args.max_draft_token_num,
tma_load_KV
@@ -281,9 +284,9 @@ struct CollectiveMainloop {
auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_slice(thread_idx);
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));
@@ -322,9 +325,9 @@ struct CollectiveMainloop {
group_modes<0, 2>(sK), group_modes<0, 2>(gKV));
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
auto kv_block_tables = make_tensor(make_gmem_ptr(mainloop_params.kv_block_tables), make_layout(make_shape(mainloop_params.bsz, mainloop_params.max_block_num_per_seq), make_stride(mainloop_params.max_block_num_per_seq, 1)));

View File

@@ -57,7 +57,7 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size_device[0]);
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
@@ -84,9 +84,9 @@ CUTLASS_DEVICE void mma_f16(const Params& mainloop_params,
Tensor tOrV2 = threadMmaPVSS.partition_fragment_B(sVt_s2);
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx =cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
int kv_tile_idx = end_tile_idx;
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
@@ -263,7 +263,7 @@ CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
using SmemLayoutVtOneStage = typename Ktraits::SmemLayoutVtOneStage;
static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size);
const int chunk_num_this_seq = cute::ceil_div(kv_len, mainloop_params.chunk_size_device[0]);
static constexpr int BLOCK_SHAPE_Q = get<0>(TileShape_QKD{});
static constexpr int BLOCK_SHAPE_KV = get<1>(TileShape_QKD{});
@@ -295,9 +295,9 @@ CUTLASS_DEVICE void mma_f16_two_stages(const Params& mainloop_params,
Tensor tOrV4 = threadMmaPVSS.partition_fragment_B(sVt_s4);
Tensor tOrP_CS2 = threadMmaPVSS.partition_fragment_A(sPSS);
const int start_len = tile_idx * mainloop_params.chunk_size;
const int start_len = tile_idx * mainloop_params.chunk_size_device[0];
const int start_tile_idx = start_len / BLOCK_SHAPE_KV;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size, kv_len), BLOCK_SHAPE_KV) - 1;
const int end_tile_idx = cute::ceil_div(min(start_len + mainloop_params.chunk_size_device[0], kv_len), BLOCK_SHAPE_KV) - 1;
int kv_tile_idx = end_tile_idx;
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {

View File

@@ -62,13 +62,12 @@ struct Params {
alignas(16) DTypeQ *Q; // [token_num, head_num, dim_head]
alignas(16) DTypeKV *KV; // [max_block_num, block_size, dim_head]
alignas(16) DTypeO *O; // [token_num, head_num, dim_head]
alignas(16) DTypeO *O_tmp; // [num_chunks, bsz, head_num, dim_head]
alignas(16) float *m; // [num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) float *d; // [num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) DTypeO *O_tmp; // [max_num_chunks, bsz, head_num, dim_head]
alignas(16) float *m; // [max_num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) float *d; // [max_num_chunks, bsz * max_draft_token_num * head_num]
alignas(16) IdType *block_tables;
alignas(16) IdType *seq_lens_this_time;
alignas(16) IdType *seq_lens_encoder;
alignas(16) IdType *seq_lens_decoder;
alignas(16) IdType *cumsum_q_seqlens;
alignas(16) IdType *batch_id_per_token;
@@ -76,7 +75,7 @@ struct Params {
alignas(16) IdType *batch_ids;
alignas(16) IdType *tile_ids_per_batch;
alignas(16) IdType *num_blocks_x;
alignas(16) IdType *chunk_size_device;
uint32_t q_stride_bsz;
uint32_t q_stride_head_num;
@@ -96,9 +95,7 @@ struct Params {
int vo_head_dim;
int block_size;
int max_draft_token_num;
int chunk_size;
int chunk_num;
int num_blocks_x_int;
float sm_scale;
};
@@ -118,7 +115,7 @@ struct Params {
return cudaErrorNotSupported; \
}
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL, int SM_COUNT = 132, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
typename CollectiveMainloop::Params const mainloop_params,
@@ -137,6 +134,7 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
const int num_blocks_x = mainloop_params.num_blocks_x[0];
const int chunk_size = mainloop_params.chunk_size_device[0];
static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;
@@ -205,58 +203,10 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();
if constexpr(USE_FIXED_BLOCK) {
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
// load Q
collective_mainloop.load_q(
mainloop_params,
pipeline_q,
smem_pipe_write_q,
shared_storage,
threadIdx.x,
bid);
if constexpr (!use_tma_load_kv) {
// load kv
collective_mainloop.load_kv(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
} else {
if (warp_idx_in_warpgroup == 0) {
// load kv tma
collective_mainloop.load_kv_tma(
mainloop_params,
pipeline_kv,
smem_pipe_write_kv,
shared_storage,
bid,
seq_len_decoder_now,
tile_id
);
}
}
}
} else {
const int block_id = blockIdx.x;
const int bid = mainloop_params.batch_ids[block_id];
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
@@ -309,76 +259,12 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
auto attention_updater = OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);
if constexpr(USE_FIXED_BLOCK) {
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
/*id=*/static_cast<int>(NamedBarriers::kWG0WG1WG2Sync));
if constexpr (BLOCK_SHAPE_KV == 64) {
mma_f16<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
} else if (BLOCK_SHAPE_KV == 32) {
mma_f16_two_stages<Ktraits, CAUSAL>(
mainloop_params,
pipeline_q,
smem_pipe_read_q,
pipeline_kv,
smem_pipe_read_kv,
tOrO,
attention_updater,
threadIdx.x - NUM_COPY_THREADS,
bid,
seq_len_decoder_now,
seq_len_now,
tile_id,
shared_storage);
}
collective_epilogue.store(
epilogue_params,
tOrO,
attention_updater.get_lse(),
shared_storage,
tiled_mma_pv,
threadIdx.x - NUM_COPY_THREADS,
bid,
mainloop_params.bsz,
seq_len_now,
start_token_idx,
tile_id,
seq_len_decoder_now,
mainloop_params.chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
} else {
const int block_id = blockIdx.x;
for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
clear(tOrO);
clear(attention_updater.scores_scale);
const int bid = mainloop_params.batch_ids[block_id];
const int tile_id = mainloop_params.tile_ids_per_batch[block_id];
const int bid = mainloop_params.batch_ids[i];
const int tile_id = mainloop_params.tile_ids_per_batch[i];
const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
const int seq_len_encoder_now = mainloop_params.seq_lens_encoder[bid];
const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
cutlass::arch::NamedBarrier::sync(Ktraits::NUM_THREADS,
@@ -429,15 +315,15 @@ MLAWithKVCacheKernel(CUTE_GRID_CONSTANT
start_token_idx,
tile_id,
seq_len_decoder_now,
mainloop_params.chunk_size,
chunk_size,
mainloop_params.max_draft_token_num,
mainloop_params.o_stride_bsz);
}
}
}
}
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
cudaStream_t stream) {
using DTypeQ = typename KernelTraits::DTypeQ;
@@ -460,12 +346,12 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
params.d,
params.block_tables,
params.seq_lens_this_time,
params.seq_lens_encoder,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_ids,
params.tile_ids_per_batch,
params.num_blocks_x,
params.chunk_size_device,
params.sm_scale,
params.bsz,
params.max_block_num,
@@ -476,7 +362,6 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
params.kv_stride_block_size,
params.o_stride_bsz,
params.o_stride_head_num,
params.chunk_size,
params.chunk_num,
params.max_draft_token_num
});
@@ -500,13 +385,9 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, KernelTraits::NUM_WARPS * 32, smem_size);
int gridx;
if constexpr(USE_FIXED_BLOCK) {
gridx = multiprocessor_count;
} else {
gridx = params.num_blocks_x_int;
}
dim3 grid_dims = {gridx, 1, 1};
// NOTE: (changwenbin) Here the grid size is fixed so that MLA can be captured
// by the graph.
dim3 grid_dims = {multiprocessor_count, 1, 1};
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
dim3 block_dims(ctaSize, 1, 1);
kernel<<<grid_dims, block_dims, smem_size, stream>>>(
@@ -517,37 +398,38 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
constexpr int merge_block_size = 256;
constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
dim3 grids_merge(min(multiprocessor_count, params.token_num), params.q_num_head); // 128k is too large
dim3 grids_merge(multiprocessor_count, params.q_num_head); // 128k is too large
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO><<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE*>(params.O_tmp),
params.m,
params.d,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.seq_lens_encoder,
params.cumsum_q_seqlens,
params.batch_id_per_token,
reinterpret_cast<NV_TYPE*>(params.O),
params.chunk_num,
params.q_num_head,
params.chunk_size,
params.vo_head_dim,
params.token_num,
params.bsz,
params.max_draft_token_num
);
merge_multi_chunks_kernel<NV_TYPE,
vec_size,
blocky,
KernelTraits::HEAD_DIM_VO>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(params.O_tmp),
params.m,
params.d,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.cumsum_q_seqlens,
params.batch_id_per_token,
params.chunk_size_device,
reinterpret_cast<NV_TYPE *>(params.O),
params.q_num_head,
params.vo_head_dim,
params.token_num,
params.bsz,
params.max_draft_token_num);
}
return cudaSuccess;
}
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=false>
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename NV_TYPE, typename Params, bool USE_REG_EALLOC=false, bool USE_FIXED_BLOCK=true>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
constexpr bool CAUSAL = true;
if constexpr (HEAD_DIM_QK == 576) {
DISPATCH_GROUP_SIZE(params.q_num_head, GROUP_SIZE,
BatchMLAWithPagedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false,
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true,
HEAD_DIM_QK,
HEAD_DIM_VO,
GROUP_SIZE,

View File

@@ -249,18 +249,16 @@ struct prefill_softmax_state_t {
};
template <typename T, int vec_size, uint32_t bdy, uint32_t HEAD_DIM>
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [num_chunks, bsz, max_draft_token, num_heads, head_dim]
const float * __restrict__ multi_m, // [num_chunks, bsz, max_draft_token, num_heads]
const float * __restrict__ multi_d, // [num_chunks, bsz, max_draft_token, num_heads]
__global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [max_num_chunks, bsz, max_draft_token, num_heads, head_dim]
const float * __restrict__ multi_m, // [max_num_chunks, bsz, max_draft_token, num_heads]
const float * __restrict__ multi_d, // [max_num_chunks, bsz, max_draft_token, num_heads]
const int * __restrict__ seq_lens_this_time,
const int * __restrict__ seq_lens_decoder,
const int * __restrict__ seq_lens_encoder,
const int *__restrict__ cu_seqlens_q,
const int * __restrict__ batch_id_per_token,
const int * __restrict__ chunk_size_device,
T * __restrict__ out, // [token_num, num_heads, head_dim]
const int num_chunks,
const int num_heads,
const int chunk_size,
const int head_dim,
const int token_num,
const int bsz,
@@ -271,13 +269,15 @@ __global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t bid = batch_id_per_token[qid];
// NOTE : (changwenbin) Batch_id_per_token is initialized to [:]=-1, Marking meaningless batch IDs.
if (bid == -1) continue;
const int seq_len_q = seq_lens_this_time[bid];
if (seq_len_q == 0) continue;
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
int seq_len_kv = seq_lens_decoder[bid];
if (seq_len_kv == 0) continue;
seq_len_kv += seq_len_q;
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size);
const int num_chunks_this_seq = cute::ceil_div(seq_len_kv, chunk_size_device[0]);
if (num_chunks_this_seq <= 1) {
// not need merge
continue;

View File

@@ -22,23 +22,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& decoder_chunk_size_device,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
@@ -64,9 +59,12 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
typedef PDTraits<D> traits_;
typedef typename traits_::data_t data_t;
int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
// NOTE: (changwenbin) In cuda graph, it will be fixed in the capture stage
// int decoder_num_blocks_data = decoder_num_blocks_cpu.data<int>()[0];
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
int max_len_kv_data = max_len_kv.data<int>()[0];
// int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
//
const bool mla_use_tensorcore = get_mla_use_tensorcore();
auto sm_version = GetSMVersion();
@@ -96,7 +94,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
out_linear_smooths,
seq_lens_this_time,
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
batch_id_per_token,
block_tables,
@@ -104,9 +101,8 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
decoder_tile_ids_per_batch,
decoder_num_blocks,
cache_quant_type_str,
decoder_num_blocks_data,
decoder_chunk_size_device,
max_input_length,
max_len_kv_data,
softmax_scale,
quant_max_bound,
quant_min_bound,
@@ -145,23 +141,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& query,
const paddle::Tensor& key_cache,
const paddle::Tensor& value_cache,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
const paddle::Tensor& encoder_num_blocks,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids_per_batch,
const paddle::Tensor& kv_num_blocks,
const paddle::Tensor& decoder_batch_ids,
const paddle::Tensor& decoder_tile_ids_per_batch,
const paddle::Tensor& decoder_num_blocks,
const paddle::Tensor& decoder_num_blocks_cpu,
const paddle::Tensor& max_enc_len_this_time,
const paddle::Tensor& decoder_chunk_size_device,
const paddle::Tensor& max_dec_len_this_time,
const paddle::Tensor& max_len_kv,
const paddle::optional<paddle::Tensor>& attn_mask,
@@ -208,23 +199,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
query,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
cu_seqlens_q,
batch_id_per_token,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
decoder_num_blocks_cpu,
max_enc_len_this_time,
decoder_chunk_size_device,
max_dec_len_this_time,
max_len_kv,
attn_mask,
@@ -254,23 +240,18 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
query,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
cu_seqlens_q,
batch_id_per_token,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
decoder_num_blocks_cpu,
max_enc_len_this_time,
decoder_chunk_size_device,
max_dec_len_this_time,
max_len_kv,
attn_mask,
@@ -307,23 +288,18 @@ std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
const std::vector<int64_t>& query_shape,
const std::vector<int64_t>& key_cache_shape,
const std::vector<int64_t>& value_cache_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& batch_id_per_token_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& encoder_num_blocks_shape,
const std::vector<int64_t>& kv_batch_ids_shape,
const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
const std::vector<int64_t>& kv_num_blocks_shape,
const std::vector<int64_t>& decoder_batch_ids_shape,
const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
const std::vector<int64_t>& decoder_num_blocks_shape,
const std::vector<int64_t>& decoder_num_blocks_cpu_shape,
const std::vector<int64_t>& max_enc_len_this_time_shape,
const std::vector<int64_t>& decoder_chunk_size_device_shape,
const std::vector<int64_t>& max_dec_len_this_time_shape,
const std::vector<int64_t>& max_len_kv_shape,
const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
@@ -361,23 +337,18 @@ std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
const paddle::DataType& query_dtype,
const paddle::DataType& key_cache_dtype,
const paddle::DataType& value_cache_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& batch_id_per_token_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
const paddle::DataType& encoder_num_blocks_dtype,
const paddle::DataType& kv_batch_ids_dtype,
const paddle::DataType& kv_tile_ids_per_batch_dtype,
const paddle::DataType& kv_num_blocks_dtype,
const paddle::DataType& decoder_batch_ids_dtype,
const paddle::DataType& decoder_tile_ids_per_batch_dtype,
const paddle::DataType& decoder_num_blocks_dtype,
const paddle::DataType& decoder_num_blocks_cpu_dtype,
const paddle::DataType& max_enc_len_this_time_dtype,
const paddle::DataType& decoder_chunk_size_device_dtype,
const paddle::DataType& max_dec_len_this_time_dtype,
const paddle::DataType& max_len_kv_dtype,
const paddle::optional<paddle::DataType>& attn_mask_dtype,
@@ -415,23 +386,18 @@ PD_BUILD_STATIC_OP(multi_head_latent_attention)
.Inputs({"query",
"key_cache",
"value_cache",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"cu_seqlens_q",
"batch_id_per_token",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"decoder_num_blocks_cpu",
"max_enc_len_this_time",
"decoder_chunk_size_device",
"max_dec_len_this_time",
"max_len_kv",
paddle::Optional("attn_mask"),

Binary file not shown.

Before

Width:  |  Height:  |  Size: 81 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 81 KiB

View File

@@ -99,6 +99,8 @@ class ForwardMeta:
decoder_batch_ids: Optional[paddle.Tensor] = None
# Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the decoder stage in multi_query_append_attention_warp1_4_kernel.
decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
# The number of blocks that attention backend can use in decode stage
decoder_num_blocks_device: Optional[paddle.Tensor] = None
# The number of CUDA blocks to launch in the x-dimension for the multi_query_append_attention_warp1_4_kernel, defining its grids.x.
decoder_num_blocks_cpu: Optional[paddle.Tensor] = None
# A tensor that holds multiple lengths related to prefill or decode stages.
@@ -118,6 +120,8 @@ class ForwardMeta:
# The maximum sequence length of the KV cache, which may represent the current maximum decoder length.
max_len_kv_cpu: Optional[paddle.Tensor] = None
decoder_chunk_size_device: Optional[paddle.Tensor] = None
# Sequence length of encoder for ever batch
seq_lens_encoder: Optional[paddle.Tensor] = None
# Sequence length of Encoder for ever batch

View File

@@ -141,6 +141,8 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,

View File

@@ -198,6 +198,8 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,

View File

@@ -187,9 +187,11 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_batch_ids, # decoder_batch_ids_per_ctax
forward_meta.decoder_tile_ids_per_batch, # decoder_chunk_ids_per_ctax_each_batch
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
@@ -347,23 +349,18 @@ class MLAAttentionBackend(AttentionBackend):
q,
latent_cache,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
forward_meta.batch_id_per_token,
metadata.block_tables,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_cpu,
metadata.max_enc_len_this_time,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
metadata.max_dec_len_this_time,
forward_meta.max_len_kv_cpu,
None, # attn_mask
@@ -468,23 +465,18 @@ class MLAAttentionBackend(AttentionBackend):
q,
latent_cache,
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
forward_meta.batch_id_per_token,
metadata.block_tables,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_cpu,
metadata.max_enc_len_this_time,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
metadata.max_dec_len_this_time,
forward_meta.max_len_kv_cpu,
None, # attn_mask

View File

@@ -30,7 +30,9 @@ def get_block_shape_and_split_kv_block(
seq_lens_this_time: paddle.Tensor,
decoder_batch_ids: paddle.Tensor,
decoder_tile_ids_per_batch: paddle.Tensor,
decoder_num_blocks_x_cpu: paddle.Tensor,
decoder_num_blocks_cpu: paddle.Tensor,
decoder_num_blocks_device: paddle.Tensor,
decoder_chunk_size_device: paddle.Tensor,
max_len_tensor_cpu: paddle.Tensor,
encoder_batch_ids: paddle.Tensor,
encoder_tile_ids_per_batch: paddle.Tensor,
@@ -55,7 +57,9 @@ def get_block_shape_and_split_kv_block(
seq_lens_this_time,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks_x_cpu,
decoder_num_blocks_cpu,
decoder_num_blocks_device,
decoder_chunk_size_device,
max_len_tensor_cpu,
encoder_batch_ids,
encoder_tile_ids_per_batch,

View File

@@ -210,6 +210,12 @@ class MTPProposer(Proposer):
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
self.target_model_inputs["decoder_num_blocks_cpu"]
).pin_memory()
self.model_inputs["decoder_num_blocks_device"] = paddle.zeros_like(
self.target_model_inputs["decoder_num_blocks_device"]
)
self.model_inputs["decoder_chunk_size_device"] = paddle.zeros_like(
self.target_model_inputs["decoder_chunk_size_device"]
)
self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(
self.target_model_inputs["max_len_tensor_cpu"]
).cpu()
@@ -338,6 +344,8 @@ class MTPProposer(Proposer):
self.model_inputs["decoder_batch_ids"] = None
self.model_inputs["decoder_tile_ids_per_batch"] = None
self.model_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
self.model_inputs["decoder_num_blocks_device"] = None
self.model_inputs["decoder_chunk_size_device"] = None
self.model_inputs["max_len_tensor_cpu"] = None # CPU
self.model_inputs["encoder_batch_ids"] = None
self.model_inputs["encoder_tile_ids_per_batch"] = None
@@ -528,6 +536,8 @@ class MTPProposer(Proposer):
decoder_batch_ids=self.model_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"],
decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"],
decoder_num_blocks_device=self.model_inputs["decoder_num_blocks_device"],
decoder_chunk_size_device=self.model_inputs["decoder_chunk_size_device"],
max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"],
seq_lens_encoder=self.model_inputs["seq_lens_encoder"],
seq_lens_decoder=self.model_inputs["seq_lens_decoder"],

View File

@@ -838,6 +838,8 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["decoder_batch_ids"] = None
self.share_inputs["decoder_tile_ids_per_batch"] = None
self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
self.share_inputs["decoder_num_blocks_device"] = None
self.share_inputs["decoder_chunk_size_device"] = None
self.share_inputs["max_len_tensor_cpu"] = None # CPU
self.share_inputs["encoder_batch_ids"] = None
self.share_inputs["encoder_tile_ids_per_batch"] = None
@@ -991,6 +993,8 @@ class GPUModelRunner(ModelRunnerBase):
)
self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
# NOTE: (changwenbin) Initialized to max_num_seq '-1' before copying, marking illegal positions
self.share_inputs["batch_id_per_token"][:] = -1
self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
@@ -1070,6 +1074,10 @@ class GPUModelRunner(ModelRunnerBase):
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
# NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
# adapted to cudagraph.
decoder_num_blocks_device=self.share_inputs["decoder_num_blocks_device"],
decoder_chunk_size_device=self.share_inputs["decoder_chunk_size_device"],
max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
@@ -1196,8 +1204,12 @@ class GPUModelRunner(ModelRunnerBase):
decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
group_size = np.ceil(num_heads / self.model_config.kv_num_heads)
decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(decoder_step_token_num * group_size) / decoder_block_shape_q
# NOTE: (changwenbin) When using auto_chunk,
# decode_max_tile_size must take into account the maximum case, where *1024 can cover 128K.
decode_max_tile_size = (
1024
* self.parallel_config.max_num_seqs
* np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q)
)
encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(self.model_config.max_model_len * group_size) / encoder_block_shape_q
@@ -1208,6 +1220,10 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
# NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
# adapted to cudagraph.
self.share_inputs["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32")
self.share_inputs["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32")
self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")

View File

@@ -380,10 +380,12 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.max_enc_len_this_time = paddle.to_tensor([self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
self.seq_lens_this_time = self.seq_lens_encoder
self.decoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
decode_max_tile_size = 1024 * self.batch_size * np.ceil((2 * 10) / 12)
self.decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
self.decoder_num_blocks_device = paddle.full([1], 0, dtype="int32")
self.decoder_chunk_size_device = paddle.full([1], 64, dtype="int32")
self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
@@ -484,6 +486,8 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.decoder_batch_ids,
self.decoder_tile_ids_per_batch,
self.decoder_num_blocks_cpu,
self.decoder_num_blocks_device,
self.decoder_chunk_size_device,
self.max_len_tensor_cpu,
self.encoder_batch_ids,
self.encoder_tile_ids_per_batch,

View File

@@ -378,9 +378,12 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
self.seq_lens_this_time = self.seq_lens_encoder
self.decoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
decode_max_tile_size = 1024 * self.batch_size * np.ceil((2 * 10) / 12)
self.decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
self.decoder_num_blocks_device = paddle.full([1], 0, dtype="int32")
self.decoder_chunk_size_device = paddle.full([1], 64, dtype="int32")
self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
@@ -464,6 +467,8 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.decoder_batch_ids,
self.decoder_tile_ids_per_batch,
self.decoder_num_blocks_cpu,
self.decoder_num_blocks_device,
self.decoder_chunk_size_device,
self.max_len_tensor_cpu,
self.encoder_batch_ids,
self.encoder_tile_ids_per_batch,

View File

@@ -192,7 +192,7 @@ class TestTreeMask(unittest.TestCase):
decoder_block_shape_q = 16
group_size = self.num_q_head // self.num_kv_head
decode_max_tile_size = (
self.bsz * (decoder_step_token_num * group_size + decoder_block_shape_q - 1) / decoder_block_shape_q
1024 * self.bsz * (decoder_step_token_num * group_size + decoder_block_shape_q - 1) / decoder_block_shape_q
)
encode_max_tile_size = (
self.bsz * (self.max_seq_len * group_size + encoder_block_shape_q - 1) / encoder_block_shape_q
@@ -202,6 +202,8 @@ class TestTreeMask(unittest.TestCase):
decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
decoder_num_blocks_device = paddle.full([1], 0, dtype="int32")
decoder_chunk_size_device = paddle.full([1], 64, dtype="int32")
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
encoder_batch_ids = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
encoder_tile_ids_per_batch = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
@@ -222,6 +224,8 @@ class TestTreeMask(unittest.TestCase):
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
decoder_num_blocks_device,
decoder_chunk_size_device,
max_len_tensor_cpu,
encoder_batch_ids,
encoder_tile_ids_per_batch,