Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-04 16:22:57 +08:00
[Executor] Experimental feature: support prefill in CUDA Graph (#3459)
* Support prefill in CUDA Graph
* Refactor GetBlockShapeAndSplitKVBlock kernel (V2, V2.1, V2.2, V2.3, V2.4, V2.5)
* Fix issue with encoder_num_blocks_x_cpu
* Add early-exit mechanism for the attention kernel
* Fix test case for append-attention
* Update test code; add annotations to the related tensors
* Move get_input_length_list
* Fix test code
* Add annotations about the early-exit path of the attention kernel (two rounds)
* Address review comments
* Fix MTP path

Co-authored-by: RAM <gstian5555@outlook.com>
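The kernel-side core of this change is the early-exit guard: when a prefill step is replayed through a captured CUDA graph, the graph's fixed gridDim.x can exceed the number of blocks the current batch actually needs, so every block first checks its index against num_blocks_x_cpu and returns if it has no work. The plain-Python sketch below only models that idea on the host side; max_grid_x, work_for_block, and the toy data are illustrative and not part of the patch.

# Minimal sketch (illustrative only): a captured "graph" always runs max_grid_x blocks,
# but only the first num_blocks_x of them correspond to real (batch, tile) work.
def replay_captured_grid(max_grid_x, num_blocks_x, batch_ids, tile_ids, work_for_block):
    results = []
    for btid in range(max_grid_x):      # gridDim.x is fixed at capture time
        if btid >= num_blocks_x:        # early exit: this block has no work this step
            continue
        batch_id = batch_ids[btid]      # blockIdx.x -> batch
        tile_id = tile_ids[btid]        # blockIdx.x -> tile within that batch
        results.append(work_for_block(batch_id, tile_id))
    return results

# Example: the graph was captured with 8 blocks, but this step only needs 3 of them.
print(replay_captured_grid(8, 3, [0, 0, 1], [0, 1, 0], lambda b, t: (b, t)))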
@@ -52,6 +52,7 @@ __global__ void multi_query_append_attention_kernel(
     const float quant_min_bound,
     const float in_scale,
     const uint32_t chunk_size,
+    const int num_blocks_x_cpu,
     T *__restrict__ tmp_workspace,  // split kv [token_num, num_chunks,
                                     // num_heads, head_dim]
     float *__restrict__ tmp_m,  // [token_num, num_chunks, num_heads]
@@ -74,6 +75,11 @@ __global__ void multi_query_append_attention_kernel(

   block_table_now = block_table + batch_id * max_block_num_per_seq;

+  // When CUDA Graph captures prefill, more thread blocks than needed may be launched in gridDim.x.
+  if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
+    return;
+  }
+
   const uint32_t q_len = seq_lens[batch_id];
   if (q_len <= 0) {
     return;
@@ -422,6 +428,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
     const float quant_min_bound,
     const float in_scale,
     const uint32_t chunk_size,
+    const int num_blocks_x_cpu,
     T *__restrict__ tmp_workspace,  // split kv [token_num, num_chunks,
                                     // num_heads, head_dim]
     float *__restrict__ tmp_m,  // [token_num, num_chunks, num_heads]
@@ -445,6 +452,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
   const uint32_t num_rows_per_block = num_frags_x * 16;
   const int *block_table_now = block_table + batch_id * max_block_num_per_seq;

+  // When CUDA Graph captures prefill, more thread blocks than needed may be launched in gridDim.x.
+  if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
+    return;
+  }
+
   const uint32_t q_len = seq_lens[batch_id];
   if (q_len <= 0) {
     return;
@@ -902,6 +914,7 @@ void MultiQueryAppendAttention(
            quant_min_bound,
            in_scale,
            chunk_size,
+           num_blocks_x_cpu,
            nullptr,
            nullptr,
            nullptr,
@@ -960,6 +973,7 @@ void MultiQueryAppendAttention(
            quant_min_bound,
            in_scale,
            chunk_size,
+           num_blocks_x_cpu,
            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
            static_cast<float *>(tmp_m->ptr()),
            static_cast<float *>(tmp_d->ptr()),
@@ -1134,6 +1148,7 @@ void MultiQueryAppendAttention(
            quant_min_bound,
            in_scale,
            chunk_size,
+           num_blocks_x_cpu,
            nullptr,
            nullptr,
            nullptr,
@@ -1206,6 +1221,7 @@ void MultiQueryAppendAttention(
            quant_min_bound,
            in_scale,
            chunk_size,
+           num_blocks_x_cpu,
            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
            static_cast<float *>(tmp_m->ptr()),
            static_cast<float *>(tmp_d->ptr()),

@@ -57,6 +57,7 @@ __global__ void multi_query_append_attention_c4_kernel(
     const float quant_min_bound,
     const float in_scale,
     const uint32_t chunk_size,
+    const int num_blocks_x_cpu,
     T *__restrict__ tmp_workspace,  // split kv [token_num, num_chunks,
                                     // num_heads, head_dim]
     float *__restrict__ tmp_m,  // [token_num, num_chunks, num_heads]
@@ -85,6 +86,11 @@ __global__ void multi_query_append_attention_c4_kernel(

   block_table_now = block_table + batch_id * max_block_num_per_seq;

+  // When CUDA Graph captures prefill, more thread blocks than needed may be launched in gridDim.x.
+  if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
+    return;
+  }
+
   const uint32_t q_len = seq_lens[batch_id];
   if (q_len <= 0) {
     return;
@@ -520,6 +526,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
     const float quant_min_bound,
     const float in_scale,
     const uint32_t chunk_size,
+    const int num_blocks_x_cpu,
     T *__restrict__ tmp_workspace,  // split kv [token_num, num_chunks,
                                     // num_heads, head_dim]
     float *__restrict__ tmp_m,  // [token_num, num_chunks, num_heads]
@@ -549,6 +556,11 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
   const uint32_t num_rows_per_block = num_frags_x * 16;
   const int *block_table_now = block_table + batch_id * max_block_num_per_seq;

+  // When CUDA Graph captures prefill, more thread blocks than needed may be launched in gridDim.x.
+  if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
+    return;
+  }
+
   const uint32_t q_len = seq_lens[batch_id];
   if (q_len <= 0) {
     return;
@@ -1107,6 +1119,7 @@ void MultiQueryAppendC4Attention(
            quant_min_bound,
            in_scale,
            chunk_size,
+           num_blocks_x_cpu,
            nullptr,
            nullptr,
            nullptr,
@@ -1171,6 +1184,7 @@ void MultiQueryAppendC4Attention(
            quant_min_bound,
            in_scale,
            chunk_size,
+           num_blocks_x_cpu,
            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
            static_cast<float *>(tmp_m->ptr()),
            static_cast<float *>(tmp_d->ptr()),
@@ -1365,6 +1379,7 @@ void MultiQueryAppendC4Attention(
            quant_min_bound,
            in_scale,
            chunk_size,
+           num_blocks_x_cpu,
            nullptr,
            nullptr,
            nullptr,
@@ -1445,6 +1460,7 @@ void MultiQueryAppendC4Attention(
            quant_min_bound,
            in_scale,
            chunk_size,
+           num_blocks_x_cpu,
            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
            static_cast<float *>(tmp_m->ptr()),
            static_cast<float *>(tmp_d->ptr()),

@@ -58,6 +58,7 @@ __global__ void multi_query_append_attention_c8_kernel(
     const float quant_min_bound,
     const float in_scale,
     const uint32_t chunk_size,
+    const int num_blocks_x_cpu,
     T *__restrict__ tmp_workspace,  // split kv [token_num, num_chunks,
                                     // num_heads, head_dim]
     float *__restrict__ tmp_m,  // [token_num, num_chunks, num_heads]
@@ -87,6 +88,11 @@ __global__ void multi_query_append_attention_c8_kernel(

   block_table_now = block_table + batch_id * max_block_num_per_seq;

+  // When CUDA Graph captures prefill, more thread blocks than needed may be launched in gridDim.x.
+  if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
+    return;
+  }
+
   const uint32_t q_len = seq_lens[batch_id];
   if (q_len <= 0) {
     return;
@@ -527,6 +533,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
     const float quant_min_bound,
     const float in_scale,
     const uint32_t chunk_size,
+    const int num_blocks_x_cpu,
     T *__restrict__ tmp_workspace,  // split kv [token_num, num_chunks,
                                     // num_heads, head_dim]
     float *__restrict__ tmp_m,  // [token_num, num_chunks, num_heads]
@@ -556,6 +563,11 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
   const uint32_t num_rows_per_block = num_frags_x * 16;
   const int *block_table_now = block_table + batch_id * max_block_num_per_seq;

+  // When CUDA Graph captures prefill, more thread blocks than needed may be launched in gridDim.x.
+  if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
+    return;
+  }
+
   const uint32_t q_len = seq_lens[batch_id];
   if (q_len <= 0) {
     return;
@@ -1159,6 +1171,7 @@ void MultiQueryAppendC8Attention(
            quant_min_bound,
            in_scale,
            chunk_size,
+           num_blocks_x_cpu,
            nullptr,
            nullptr,
            nullptr,
@@ -1217,6 +1230,7 @@ void MultiQueryAppendC8Attention(
            quant_min_bound,
            in_scale,
            chunk_size,
+           num_blocks_x_cpu,
            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
            static_cast<float *>(tmp_m->ptr()),
            static_cast<float *>(tmp_d->ptr()),
@@ -1443,6 +1457,7 @@ void MultiQueryAppendC8Attention(
            quant_min_bound,
            in_scale,
            chunk_size,
+           num_blocks_x_cpu,
            nullptr,
            nullptr,
            nullptr,
@@ -1517,6 +1532,7 @@ void MultiQueryAppendC8Attention(
            quant_min_bound,
            in_scale,
            chunk_size,
+           num_blocks_x_cpu,
            reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
            static_cast<float *>(tmp_m->ptr()),
            static_cast<float *>(tmp_d->ptr()),

@@ -191,14 +191,21 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
   }
 }

-std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
+void GetBlockShapeAndSplitKVBlock(
     const paddle::Tensor &seq_lens_encoder,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_this_time,
     paddle::Tensor &decoder_batch_ids,          // Inplace
     paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
     paddle::Tensor &decoder_num_blocks_x_cpu,   // Inplace, Pinned Memory
-    paddle::Tensor &max_len_tensor_cpu,         // Inplace, Pinned Memory
+    paddle::Tensor &max_len_tensor_cpu,         // Inplace, CPU
+    paddle::Tensor &encoder_batch_ids,          // Inplace
+    paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
+    paddle::Tensor &encoder_num_blocks_x_cpu,   // Inplace, CPU
+    paddle::Tensor &kv_batch_ids,               // Inplace
+    paddle::Tensor &kv_tile_ids_per_batch,      // Inplace
+    paddle::Tensor &kv_num_blocks_x_cpu,        // Inplace, CPU
+    paddle::Tensor &max_len_kv_cpu,             // Inplace, CPU
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,
@@ -223,13 +230,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
   int max_system_len = max_len_cpu_ptr[6];
   int max_just_dec_len_without_system = max_len_cpu_ptr[7];

-  paddle::Tensor encoder_batch_ids;
-  paddle::Tensor encoder_tile_ids_per_batch;
-  paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
-  paddle::Tensor kv_batch_ids;
-  paddle::Tensor kv_tile_ids_per_batch;
-  paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
-  paddle::Tensor max_len_kv_cpu;      /*cpu*/
-
   auto max_len_kv =
       GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
@@ -237,17 +238,14 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
       max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
       seq_lens_decoder.data<int>(), bsz);

-  max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
+  max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);

   if (max_enc_len_this_time > 0) {
-    const uint32_t max_tile_size_per_bs_kv =
-        div_up(max_enc_dec_len_this_time, block_size);
-    kv_batch_ids =
-        GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
-                       seq_lens_encoder.place());
-    kv_tile_ids_per_batch =
-        GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
-                       seq_lens_encoder.place());
+    const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
+    const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
     auto kv_num_blocks_x =
         GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());

@@ -258,16 +256,12 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
         kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
         block_size, block_size);

-    kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
-    const uint32_t encoder_max_tile_size_per_bs_q =
-        div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
-    encoder_batch_ids =
-        GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
-                       paddle::DataType::INT32, seq_lens_encoder.place());
-    encoder_tile_ids_per_batch =
-        GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
-                       paddle::DataType::INT32, seq_lens_encoder.place());
+    kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
+    // Clear buffer
+    const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
+    const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
     auto encoder_num_blocks_x =
         GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
     split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
@@ -275,21 +269,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
                                         encoder_tile_ids_per_batch.data<int>(),
                                         encoder_num_blocks_x.data<int>(), bsz,
                                         encoder_block_shape_q, group_size);
-    encoder_num_blocks_x_cpu =
-        encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
-  } else {
-    encoder_batch_ids =
-        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
-    encoder_tile_ids_per_batch =
-        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
-    encoder_num_blocks_x_cpu =
-        GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
-    kv_batch_ids =
-        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
-    kv_tile_ids_per_batch =
-        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
-    kv_num_blocks_x_cpu =
-        GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
+    encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
   }

   if (max_just_dec_len_this_time > 0) {
@@ -314,15 +294,6 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
     decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
   }

-  return {
-      encoder_batch_ids,
-      encoder_tile_ids_per_batch,
-      encoder_num_blocks_x_cpu, /*cpu*/
-      kv_batch_ids,
-      kv_tile_ids_per_batch,
-      kv_num_blocks_x_cpu, /*cpu*/
-      max_len_kv_cpu, /*cpu*/
-  };
 }

 PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
@@ -333,16 +304,17 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
         "decoder_batch_ids",
         "decoder_tile_ids_per_batch",
         "decoder_num_blocks_x_cpu",
-        "max_len_tensor_cpu"
+        "max_len_tensor_cpu",
+        "encoder_batch_ids",
+        "encoder_tile_ids_per_batch",
+        "encoder_num_blocks_x_cpu",
+        "kv_batch_ids",
+        "kv_tile_ids_per_batch",
+        "kv_num_blocks_x_cpu",
+        "max_len_kv_cpu"
     })
     .Outputs({
-        paddle::Optional("encoder_batch_ids"),
-        paddle::Optional("encoder_tile_ids_per_batch"),
-        paddle::Optional("encoder_num_blocks_x_cpu"),
-        paddle::Optional("kv_batch_ids"),
-        paddle::Optional("kv_tile_ids_per_batch"),
-        paddle::Optional("kv_num_blocks_x_cpu"),
-        "max_len_kv_cpu"
     })
     .Attrs({
         "encoder_block_shape_q: int",

@@ -299,7 +299,7 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id,
 paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
                                        const int layer_id);

-std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
+void GetBlockShapeAndSplitKVBlock(
     const paddle::Tensor &seq_lens_encoder,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_this_time,
@@ -307,6 +307,13 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
     paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
     paddle::Tensor &decoder_num_blocks_x_cpu,   // Inplace, Pinned Memory
     paddle::Tensor &max_len_tensor_cpu,         // Inplace, Pinned Memory
+    paddle::Tensor &encoder_batch_ids,          // Inplace
+    paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
+    paddle::Tensor &encoder_num_blocks_x_cpu,   // Inplace, Pinned Memory
+    paddle::Tensor &kv_batch_ids,               // Inplace
+    paddle::Tensor &kv_tile_ids_per_batch,      // Inplace
+    paddle::Tensor &kv_num_blocks_x_cpu,        // Inplace, Pinned Memory
+    paddle::Tensor &max_len_kv_cpu,             // Inplace, Pinned Memory
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
     const int group_size,

@@ -580,6 +580,10 @@ class GraphOptimizationConfig:
         """ Whether to use a full cuda graph for the entire forward pass rather than
         splitting certain operations such as attention into subgraphs.
         Thus this flag cannot be used together with splitting_ops."""
+        self.cudagraph_only_prefill: bool = False
+        """When cudagraph_only_prefill is False, only decode-only batches are captured.
+        When cudagraph_only_prefill is True, only prefill-only batches are captured.
+        Capturing both decode-only and prefill-only batches is not supported yet."""
         self.full_cuda_graph: bool = True

         self.max_capture_size: int = None
@@ -592,13 +596,13 @@ class GraphOptimizationConfig:

         self.check_legality_parameters()

-    def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
+    def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None:
         """
         Initialize cuda graph capture sizes and
         pre-compute the mapping from batch size to padded graph size
         """
         # Regular capture sizes
-        self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_num_seqs]
+        self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size]
         dedup_sizes = list(set(self.cudagraph_capture_sizes))
         if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
             logger.info(
@@ -632,7 +636,7 @@ class GraphOptimizationConfig:
             # Shape [128, 144, ... 240, 256]
             draft_capture_sizes += [16 * i for i in range(9, 17)]
             # Shape [256, 288, ... 992, 1024]
-            draft_capture_sizes += [32 * i for i in range(17, 33)]
+            draft_capture_sizes += [32 * i for i in range(9, 33)]

         draft_capture_sizes.append(max_num_seqs)
         self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
@@ -1140,7 +1144,11 @@ class FDConfig:
         # Initialize cuda graph capture list
         if self.graph_opt_config.cudagraph_capture_sizes is None:
             self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
-        self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs)
+
+        if self.graph_opt_config.cudagraph_only_prefill:
+            self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=512)
+        else:
+            self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=self.parallel_config.max_num_seqs)

         # TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
         if self.graph_opt_config.graph_opt_level == 2:

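For readers skimming the config change: the new flag only decides which upper bound is used to filter cudagraph_capture_sizes. The rough standalone sketch below illustrates that selection; select_capture_sizes and the candidate sizes are made up for illustration and are not the real GraphOptimizationConfig API.

# Rough sketch of the capture-size selection (illustrative values, not the real config class).
def select_capture_sizes(candidate_sizes, cudagraph_only_prefill, max_num_seqs):
    # Prefill-only capture bounds the sizes by a fixed budget (512 in the patch);
    # decode-only capture bounds them by the maximum number of sequences.
    max_capture_size = 512 if cudagraph_only_prefill else max_num_seqs
    return sorted({s for s in candidate_sizes if s <= max_capture_size})

sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
print(select_capture_sizes(sizes, cudagraph_only_prefill=False, max_num_seqs=64))  # capped at 64
print(select_capture_sizes(sizes, cudagraph_only_prefill=True, max_num_seqs=64))   # capped at 512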
@@ -81,14 +81,42 @@ class ForwardMeta:
     attn_mask: Optional[paddle.Tensor] = None
     # Attention mask offset
     attn_mask_offsets: Optional[paddle.Tensor] = None

+    # A common pattern for launching CUDA kernels is to set the kernel's grids.x dimension
+    # using a `num_blocks` variable, and then map each thread block to a specific batch and
+    # data tile using `batch_ids` and `tile_ids_per_batch`.
+    #
+    # The variable names below follow this pattern, using a common prefix (e.g., `encoder_`, `decoder_`, `kv_`)
+    # for variables that are logically grouped together. The mapping works as follows:
+    #
+    # Usage: `my_kernel<<<grids, ...>>>(..., batch_ids, tile_ids, ...)`
+    #   `grids.x`  = `num_blocks_cpu`
+    #   `batch_id` = `batch_ids[blockIdx.x]`
+    #   `tile_id`  = `tile_ids[blockIdx.x]`
+
+    # Maps the thread block index (blockIdx.x) to the corresponding batch for the decoder stage in multi_query_append_attention_warp1_4_kernel.
     # Decoder batch id. Used by attention backend.
     decoder_batch_ids: Optional[paddle.Tensor] = None
-    # Tile ID for each batch of the decoder. Used by attention backend.
+    # Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the decoder stage in multi_query_append_attention_warp1_4_kernel.
     decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
-    # The number of blocks that attention backend can use in decode stage
+    # The number of CUDA blocks to launch in the x-dimension for the multi_query_append_attention_warp1_4_kernel, defining its grids.x.
     decoder_num_blocks_cpu: Optional[paddle.Tensor] = None
-    # Recorded multiple lengths related to prefill or decode
+    # A tensor that holds multiple lengths related to prefill or decode stages.
     max_len_tensor_cpu: Optional[paddle.Tensor] = None
+    # Maps the thread block index (blockIdx.x) to the corresponding batch for the encoder stage in multi_query_append_attention_kernel.
+    encoder_batch_ids: Optional[paddle.Tensor] = None
+    # Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the encoder stage in multi_query_append_attention_kernel.
+    encoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
+    # The number of CUDA blocks to launch in the x-dimension for the multi_query_append_attention_kernel, defining its grids.x.
+    encoder_num_blocks_x_cpu: Optional[paddle.Tensor] = None
+    # Maps the thread block index (blockIdx.x) to the corresponding batch for the append_write_cache_kv kernel.
+    kv_batch_ids: Optional[paddle.Tensor] = None
+    # Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the append_write_cache_kv kernel.
+    kv_tile_ids_per_batch: Optional[paddle.Tensor] = None
+    # The number of CUDA blocks to launch in the x-dimension for the append_write_cache_kv kernel, defining its grids.x.
+    kv_num_blocks_x_cpu: Optional[paddle.Tensor] = None
+    # The maximum sequence length of the KV cache, which may represent the current maximum decoder length.
+    max_len_kv_cpu: Optional[paddle.Tensor] = None
+
     # Sequence length of encoder for ever batch
     seq_lens_encoder: Optional[paddle.Tensor] = None
@@ -133,6 +161,7 @@ class ForwardMeta:
                     "shape": obj.shape,
                     "dtype": str(obj.dtype),
                     "place": str(obj.place),
+                    # "content": obj if obj.numel()<10 else "Too big to show"
                 }
                 return tensor_info
             elif isinstance(obj, (list, tuple)):

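The new ForwardMeta comments describe a blockIdx.x -> (batch, tile) lookup driven by the *_num_blocks tensors. The pure-Python stand-in below shows one way such batch_ids / tile_ids buffers could be filled from per-batch lengths; build_block_mapping and its inputs are hypothetical, and the real work is done by the split_q_block-style kernels in the diff above.

import math

# Simplified stand-in: each batch contributes ceil(len / block_shape) tiles, laid out contiguously.
def build_block_mapping(seq_lens, block_shape):
    batch_ids, tile_ids = [], []
    for batch_id, seq_len in enumerate(seq_lens):
        num_tiles = math.ceil(seq_len / block_shape) if seq_len > 0 else 0
        for tile_id in range(num_tiles):
            batch_ids.append(batch_id)   # blockIdx.x -> batch
            tile_ids.append(tile_id)     # blockIdx.x -> tile inside that batch
    return batch_ids, tile_ids, len(batch_ids)  # last value plays the role of num_blocks_x

batch_ids, tile_ids, num_blocks_x = build_block_mapping([130, 0, 70], block_shape=64)
print(num_blocks_x)  # 5
print(batch_ids)     # [0, 0, 0, 2, 2]
print(tile_ids)      # [0, 1, 2, 0, 1]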
@@ -49,14 +49,6 @@ class AppendAttentionMetadata(AttentionMetadata):
     AppendAttentionMetadata
     """

-    encoder_batch_ids: paddle.Tensor = None
-    encoder_tile_ids_per_batch: paddle.Tensor = None
-    encoder_num_blocks: paddle.Tensor = None
-    kv_batch_ids: paddle.Tensor = None
-    kv_tile_ids_per_batch: paddle.Tensor = None
-    kv_num_blocks: paddle.Tensor = None
-    max_len_kv: paddle.Tensor = None
-
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
@@ -142,15 +134,7 @@ class AppendAttentionBackend(AttentionBackend):
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.attn_mask = forward_meta.attn_mask
         metadata.pre_caches_length = forward_meta.pre_caches_length
-        (
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
-            metadata.max_len_kv,
-        ) = get_block_shape_and_split_kv_block(
+        get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
@@ -158,6 +142,13 @@ class AppendAttentionBackend(AttentionBackend):
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
+            forward_meta.encoder_batch_ids,
+            forward_meta.encoder_tile_ids_per_batch,
+            forward_meta.encoder_num_blocks_x_cpu,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
+            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,
@@ -288,17 +279,17 @@ class AppendAttentionBackend(AttentionBackend):
                 forward_meta.batch_id_per_token,
                 forward_meta.cu_seqlens_q,
                 metadata.block_tables,
-                metadata.encoder_batch_ids,
-                metadata.encoder_tile_ids_per_batch,
-                metadata.encoder_num_blocks,
-                metadata.kv_batch_ids,
-                metadata.kv_tile_ids_per_batch,
-                metadata.kv_num_blocks,
+                forward_meta.encoder_batch_ids,
+                forward_meta.encoder_tile_ids_per_batch,
+                forward_meta.encoder_num_blocks_x_cpu,
+                forward_meta.kv_batch_ids,
+                forward_meta.kv_tile_ids_per_batch,
+                forward_meta.kv_num_blocks_x_cpu,
                 forward_meta.decoder_batch_ids,
                 forward_meta.decoder_tile_ids_per_batch,
                 forward_meta.decoder_num_blocks_cpu,
                 forward_meta.max_len_tensor_cpu,
-                metadata.max_len_kv,
+                forward_meta.max_len_kv_cpu,
                 res,
                 metadata.rotary_embs,
                 metadata.attn_mask,
@@ -344,17 +335,17 @@ class AppendAttentionBackend(AttentionBackend):
                 forward_meta.batch_id_per_token,
                 forward_meta.cu_seqlens_q,
                 metadata.block_tables,
-                metadata.encoder_batch_ids,
-                metadata.encoder_tile_ids_per_batch,
-                metadata.encoder_num_blocks,
-                metadata.kv_batch_ids,
-                metadata.kv_tile_ids_per_batch,
-                metadata.kv_num_blocks,
+                forward_meta.encoder_batch_ids,
+                forward_meta.encoder_tile_ids_per_batch,
+                forward_meta.encoder_num_blocks_x_cpu,
+                forward_meta.kv_batch_ids,
+                forward_meta.kv_tile_ids_per_batch,
+                forward_meta.kv_num_blocks_x_cpu,
                 forward_meta.decoder_batch_ids,
                 forward_meta.decoder_tile_ids_per_batch,
                 forward_meta.decoder_num_blocks_cpu,
                 forward_meta.max_len_tensor_cpu,
-                metadata.max_len_kv,
+                forward_meta.max_len_kv_cpu,
                 metadata.rotary_embs,
                 metadata.attn_mask,
                 layer.qkv_bias,

@@ -65,13 +65,6 @@ class FlashAttentionMetadata(AttentionMetadata):

     rotary_embs: Optional[paddle.Tensor] = None
     block_tables: Optional[paddle.Tensor] = None
-    encoder_batch_ids: paddle.Tensor = None
-    encoder_tile_ids_per_batch: paddle.Tensor = None
-    encoder_num_blocks: paddle.Tensor = None
-    kv_batch_ids: paddle.Tensor = None
-    kv_tile_ids_per_batch: paddle.Tensor = None
-    kv_num_blocks: paddle.Tensor = None
-    max_len_kv: paddle.Tensor = None

     cu_seqlens_q: paddle.Tensor = None
     cu_seqlens_k: paddle.Tensor = None
@@ -198,15 +191,7 @@ class FlashAttentionBackend(AttentionBackend):
         metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.block_tables = forward_meta.block_tables
-        (
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
-            metadata.max_len_kv,
-        ) = get_block_shape_and_split_kv_block(
+        get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
@@ -214,6 +199,13 @@ class FlashAttentionBackend(AttentionBackend):
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
+            forward_meta.encoder_batch_ids,
+            forward_meta.encoder_tile_ids_per_batch,
+            forward_meta.encoder_num_blocks_x_cpu,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
+            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,
@@ -295,9 +287,9 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_meta.seq_lens_decoder,
                 forward_meta.batch_id_per_token,
                 metadata.block_tables,
-                metadata.kv_batch_ids,
-                metadata.kv_tile_ids_per_batch,
-                metadata.kv_num_blocks,
+                forward_meta.kv_batch_ids,
+                forward_meta.kv_tile_ids_per_batch,
+                forward_meta.kv_num_blocks_x_cpu,
                 metadata.pre_cache_batch_ids,
                 metadata.pre_cache_tile_ids_per_batch,
                 metadata.pre_cache_num_blocks_cpu,
@@ -336,17 +328,17 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_meta.batch_id_per_token,
                 forward_meta.cu_seqlens_q,
                 metadata.block_tables,
-                metadata.encoder_batch_ids,
-                metadata.encoder_tile_ids_per_batch,
-                metadata.encoder_num_blocks,
-                metadata.kv_batch_ids,
-                metadata.kv_tile_ids_per_batch,
-                metadata.kv_num_blocks,
+                forward_meta.encoder_batch_ids,
+                forward_meta.encoder_tile_ids_per_batch,
+                forward_meta.encoder_num_blocks_x_cpu,
+                forward_meta.kv_batch_ids,
+                forward_meta.kv_tile_ids_per_batch,
+                forward_meta.kv_num_blocks_x_cpu,
                 forward_meta.decoder_batch_ids,  # from buffer
                 forward_meta.decoder_tile_ids_per_batch,  # from buffer
                 forward_meta.decoder_num_blocks_cpu,
                 metadata.max_len_tensor_cpu_decoder,
-                metadata.max_len_kv,
+                forward_meta.max_len_kv_cpu,
                 metadata.rotary_embs,
                 forward_meta.attn_mask,
                 layer.qkv_bias,

@@ -69,14 +69,6 @@ class MLAAttentionMetadata(AttentionMetadata):
     MLAAttentionMetadata for Multi-Layer Attention
     """

-    encoder_batch_ids: paddle.Tensor = None
-    encoder_tile_ids_per_batch: paddle.Tensor = None
-    encoder_num_blocks: paddle.Tensor = None
-    kv_batch_ids: paddle.Tensor = None
-    kv_tile_ids_per_batch: paddle.Tensor = None
-    kv_num_blocks: paddle.Tensor = None
-    max_len_kv: paddle.Tensor = None
-
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
@@ -191,15 +183,7 @@ class MLAAttentionBackend(AttentionBackend):
         metadata.attn_mask = forward_meta.attn_mask
         metadata.pre_caches_length = forward_meta.pre_caches_length

-        (
-            metadata.encoder_batch_ids,
-            metadata.encoder_tile_ids_per_batch,
-            metadata.encoder_num_blocks,
-            metadata.kv_batch_ids,
-            metadata.kv_tile_ids_per_batch,
-            metadata.kv_num_blocks,
-            metadata.max_len_kv,
-        ) = get_block_shape_and_split_kv_block(
+        get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
@@ -207,6 +191,13 @@ class MLAAttentionBackend(AttentionBackend):
             forward_meta.decoder_tile_ids_per_batch,
             forward_meta.decoder_num_blocks_cpu,
             forward_meta.max_len_tensor_cpu,
+            forward_meta.encoder_batch_ids,
+            forward_meta.encoder_tile_ids_per_batch,
+            forward_meta.encoder_num_blocks_x_cpu,
+            forward_meta.kv_batch_ids,
+            forward_meta.kv_tile_ids_per_batch,
+            forward_meta.kv_num_blocks_x_cpu,
+            forward_meta.max_len_kv_cpu,
             self.encoder_block_shape_q,
             self.decoder_block_shape_q,
             self.group_size,
@@ -362,19 +353,19 @@ class MLAAttentionBackend(AttentionBackend):
                 forward_meta.cu_seqlens_q,
                 forward_meta.batch_id_per_token,
                 metadata.block_tables,
-                metadata.encoder_batch_ids,
-                metadata.encoder_tile_ids_per_batch,
-                metadata.encoder_num_blocks,
-                metadata.kv_batch_ids,
-                metadata.kv_tile_ids_per_batch,
-                metadata.kv_num_blocks,
+                forward_meta.encoder_batch_ids,
+                forward_meta.encoder_tile_ids_per_batch,
+                forward_meta.encoder_num_blocks_x_cpu,
+                forward_meta.kv_batch_ids,
+                forward_meta.kv_tile_ids_per_batch,
+                forward_meta.kv_num_blocks_x_cpu,
                 forward_meta.decoder_batch_ids,
                 forward_meta.decoder_tile_ids_per_batch,
                 forward_meta.decoder_num_blocks_cpu,
                 forward_meta.decoder_num_blocks_cpu,
                 metadata.max_enc_len_this_time,
                 metadata.max_dec_len_this_time,
-                metadata.max_len_kv,
+                forward_meta.max_len_kv_cpu,
                 None,  # attn_mask
                 None,  # qkv_bias
                 None,  # qkv_out_scales
@@ -483,19 +474,19 @@ class MLAAttentionBackend(AttentionBackend):
                 forward_meta.cu_seqlens_q,
                 forward_meta.batch_id_per_token,
                 metadata.block_tables,
-                metadata.encoder_batch_ids,
-                metadata.encoder_tile_ids_per_batch,
-                metadata.encoder_num_blocks,
-                metadata.kv_batch_ids,
-                metadata.kv_tile_ids_per_batch,
-                metadata.kv_num_blocks,
+                forward_meta.encoder_batch_ids,
+                forward_meta.encoder_tile_ids_per_batch,
+                forward_meta.encoder_num_blocks_x_cpu,
+                forward_meta.kv_batch_ids,
+                forward_meta.kv_tile_ids_per_batch,
+                forward_meta.kv_num_blocks_x_cpu,
                 forward_meta.decoder_batch_ids,
                 forward_meta.decoder_tile_ids_per_batch,
                 forward_meta.decoder_num_blocks_cpu,
                 forward_meta.decoder_num_blocks_cpu,
                 metadata.max_enc_len_this_time,
                 metadata.max_dec_len_this_time,
-                metadata.max_len_kv,
+                forward_meta.max_len_kv_cpu,
                 None,  # attn_mask
                 None,  # qkv_bias
                 None,  # qkv_out_scales

@@ -32,6 +32,13 @@ def get_block_shape_and_split_kv_block(
     decoder_tile_ids_per_batch: paddle.Tensor,
     decoder_num_blocks_x_cpu: paddle.Tensor,
     max_len_tensor_cpu: paddle.Tensor,
+    encoder_batch_ids: paddle.Tensor,
+    encoder_tile_ids_per_batch: paddle.Tensor,
+    encoder_num_blocks_x_cpu: paddle.Tensor,
+    kv_batch_ids: paddle.Tensor,
+    kv_tile_ids_per_batch: paddle.Tensor,
+    kv_num_blocks_x_cpu: paddle.Tensor,
+    max_len_kv_cpu: paddle.Tensor,
     encoder_block_shape_q: int,
     decoder_block_shape_q: int,
     group_size: int,
@@ -42,15 +49,7 @@ def get_block_shape_and_split_kv_block(
     get_block_shape_and_split_kv_block
     """
     if current_platform.is_cuda():
-        (
-            encoder_batch_ids,
-            encoder_tile_ids_per_batch,
-            encoder_num_blocks,
-            kv_batch_ids,
-            kv_tile_ids_per_batch,
-            kv_num_blocks,
-            max_len_kv_cpu,
-        ) = get_block_shape_and_split_kv_block_cuda(
+        get_block_shape_and_split_kv_block_cuda(
             seq_lens_encoder,
             seq_lens_decoder,
             seq_lens_this_time,
@@ -58,20 +57,19 @@ def get_block_shape_and_split_kv_block(
             decoder_tile_ids_per_batch,
             decoder_num_blocks_x_cpu,
             max_len_tensor_cpu,
+            encoder_batch_ids,
+            encoder_tile_ids_per_batch,
+            encoder_num_blocks_x_cpu,
+            kv_batch_ids,
+            kv_tile_ids_per_batch,
+            kv_num_blocks_x_cpu,
+            max_len_kv_cpu,
             encoder_block_shape_q,
             decoder_block_shape_q,
             group_size,
             block_size,
             decoder_step_token_num,
         )
-        return (
-            encoder_batch_ids,
-            encoder_tile_ids_per_batch,
-            encoder_num_blocks,
-            kv_batch_ids,
-            kv_tile_ids_per_batch,
-            kv_num_blocks,
-            max_len_kv_cpu,
-        )
     else:
         raise NotImplementedError

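After this change the Python wrapper returns nothing: callers preallocate every scheduling buffer once, so the addresses stay stable across CUDA Graph replays, and the custom op fills them in place. The sketch below shows that calling pattern with illustrative shapes only; the real sizes are derived from max_num_seqs, max_model_len, the block shapes, and the group size in the model runner, and the op call itself is left as a comment because it requires the compiled extension.

import paddle

# Illustrative shapes only; not the sizes the runner actually computes.
encoder_batch_ids = paddle.zeros([1024], dtype="int32")
encoder_tile_ids_per_batch = paddle.zeros([1024], dtype="int32")
encoder_num_blocks_x_cpu = paddle.zeros([1], dtype="int32").cpu()
kv_batch_ids = paddle.zeros([1024], dtype="int32")
kv_tile_ids_per_batch = paddle.zeros([1024], dtype="int32")
kv_num_blocks_x_cpu = paddle.zeros([1], dtype="int32").cpu()
max_len_kv_cpu = paddle.zeros([1], dtype="int32").cpu()

# The op now writes into the buffers above (and the decoder buffers) instead of
# allocating and returning new tensors on every step:
# get_block_shape_and_split_kv_block(
#     seq_lens_encoder, seq_lens_decoder, seq_lens_this_time,
#     decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_x_cpu, max_len_tensor_cpu,
#     encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_x_cpu,
#     kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks_x_cpu, max_len_kv_cpu,
#     encoder_block_shape_q, decoder_block_shape_q, group_size, block_size, decoder_step_token_num,
# )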
@@ -212,6 +212,22 @@ class MTPProposer(Proposer):
             self.target_model_inputs["max_len_tensor_cpu"]
         ).cpu()

+        self.model_inputs["encoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["encoder_batch_ids"])
+        self.model_inputs["encoder_tile_ids_per_batch"] = paddle.zeros_like(
+            self.target_model_inputs["encoder_tile_ids_per_batch"]
+        )
+        self.model_inputs["encoder_num_blocks_x_cpu"] = paddle.zeros_like(
+            self.target_model_inputs["encoder_num_blocks_x_cpu"]
+        ).cpu()
+        self.model_inputs["kv_batch_ids"] = paddle.zeros_like(self.target_model_inputs["kv_batch_ids"])
+        self.model_inputs["kv_tile_ids_per_batch"] = paddle.zeros_like(
+            self.target_model_inputs["kv_tile_ids_per_batch"]
+        )
+        self.model_inputs["kv_num_blocks_x_cpu"] = paddle.zeros_like(
+            self.target_model_inputs["kv_num_blocks_x_cpu"]
+        ).cpu()
+        self.model_inputs["max_len_kv_cpu"] = paddle.zeros_like(self.target_model_inputs["max_len_kv_cpu"]).cpu()
+
         # Get the attention backend
         attn_cls = get_attention_backend()
         attn_backend = attn_cls(
@@ -321,6 +337,13 @@ class MTPProposer(Proposer):
         self.model_inputs["decoder_tile_ids_per_batch"] = None
         self.model_inputs["decoder_num_blocks_cpu"] = None  # Pinning Memory
         self.model_inputs["max_len_tensor_cpu"] = None  # CPU
+        self.model_inputs["encoder_batch_ids"] = None
+        self.model_inputs["encoder_tile_ids_per_batch"] = None
+        self.model_inputs["encoder_num_blocks_x_cpu"] = None  # CPU
+        self.model_inputs["kv_batch_ids"] = None
+        self.model_inputs["kv_tile_ids_per_batch"] = None
+        self.model_inputs["kv_num_blocks_x_cpu"] = None  # CPU
+        self.model_inputs["max_len_kv_cpu"] = None  # CPU

         # Input tokens
         self.model_inputs["draft_tokens"] = paddle.full(
@@ -512,6 +535,13 @@ class MTPProposer(Proposer):
             cu_seqlens_k=self.model_inputs["cu_seqlens_k"],
             block_tables=self.model_inputs["block_tables"],
             caches=self.model_inputs["caches"],
+            encoder_batch_ids=self.model_inputs["encoder_batch_ids"],
+            encoder_tile_ids_per_batch=self.model_inputs["encoder_tile_ids_per_batch"],
+            encoder_num_blocks_x_cpu=self.model_inputs["encoder_num_blocks_x_cpu"],
+            kv_batch_ids=self.model_inputs["kv_batch_ids"],
+            kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"],
+            kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"],
+            max_len_kv_cpu=self.model_inputs["max_len_kv_cpu"],
         )

         # Initialzie attention meta data

@@ -430,6 +430,13 @@ class GCUModelRunner(ModelRunnerBase):
|
|||||||
self.share_inputs["decoder_tile_ids_per_batch"] = None
|
self.share_inputs["decoder_tile_ids_per_batch"] = None
|
||||||
self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
|
self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
|
||||||
self.share_inputs["max_len_tensor_cpu"] = None # CPU
|
self.share_inputs["max_len_tensor_cpu"] = None # CPU
|
||||||
|
self.share_inputs["encoder_batch_ids"] = None
|
||||||
|
self.share_inputs["encoder_tile_ids_per_batch"] = None
|
||||||
|
self.share_inputs["encoder_num_blocks_x_cpu"] = None # CPU
|
||||||
|
self.share_inputs["kv_batch_ids"] = None
|
||||||
|
self.share_inputs["kv_tile_ids_per_batch"] = None
|
||||||
|
self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU
|
||||||
|
self.share_inputs["max_len_kv_cpu"] = None # CPU
|
||||||
|
|
||||||
# Initialize rotary position embedding
|
# Initialize rotary position embedding
|
||||||
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
|
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
|
||||||
@@ -601,6 +608,13 @@ class GCUModelRunner(ModelRunnerBase):
|
|||||||
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
|
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
|
||||||
block_tables=self.share_inputs["block_tables"],
|
block_tables=self.share_inputs["block_tables"],
|
||||||
caches=self.share_inputs["caches"],
|
caches=self.share_inputs["caches"],
|
||||||
|
encoder_batch_ids=self.share_inputs["encoder_batch_ids"],
|
||||||
|
encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"],
|
||||||
|
encoder_num_blocks_x_cpu=self.share_inputs["encoder_num_blocks_x_cpu"],
|
||||||
|
kv_batch_ids=self.share_inputs["kv_batch_ids"],
|
||||||
|
kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"],
|
||||||
|
kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
|
||||||
|
max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update Batch type for cuda graph
|
# Update Batch type for cuda graph
|
||||||
@@ -673,14 +687,31 @@ class GCUModelRunner(ModelRunnerBase):
|
|||||||
encoder_block_shape_q = 64
|
encoder_block_shape_q = 64
|
||||||
decoder_block_shape_q = 16
|
decoder_block_shape_q = 16
|
||||||
decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
|
decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
|
||||||
|
group_size = np.ceil(num_heads / self.model_config.kv_num_heads)
|
||||||
|
|
||||||
decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
|
decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
|
||||||
(decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
|
(decoder_step_token_num * group_size) / decoder_block_shape_q
|
||||||
|
)
|
||||||
|
encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
|
||||||
|
(self.model_config.max_model_len * group_size) / encoder_block_shape_q
|
||||||
|
)
|
||||||
|
kv_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
|
||||||
|
self.model_config.max_model_len / self.fd_config.cache_config.block_size
|
||||||
)
|
)
|
||||||
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
|
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
|
||||||
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
|
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
|
||||||
self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
|
self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
|
||||||
self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
|
self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
|
||||||
|
|
||||||
|
self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
|
||||||
|
self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
|
||||||
|
self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
|
||||||
|
|
||||||
|
self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
|
||||||
|
self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
|
||||||
|
self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
|
||||||
|
self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
|
||||||
|
|
||||||
# Get the attention backend
|
# Get the attention backend
|
||||||
attn_cls = get_attention_backend()
|
attn_cls = get_attention_backend()
|
||||||
attn_backend = attn_cls(
|
attn_backend = attn_cls(
|
||||||
|
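The three `*_max_tile_size` values introduced above are worst-case tile counts, so the encoder/kv batch-id and tile-id buffers can be allocated once at their maximum size and safely reused by captured graphs. A minimal sketch of that arithmetic follows; the configuration numbers are illustrative assumptions, not values taken from any real FastDeploy config.

    # Sketch of the upper-bound computation above, with made-up configuration values.
    import math

    max_num_seqs = 64                # assumed engine limit on concurrent sequences
    max_model_len = 8192             # assumed maximum sequence length
    block_size = 64                  # assumed KV-cache block size
    num_heads, kv_num_heads = 32, 8  # assumed query / key-value head counts
    encoder_block_shape_q = 64
    decoder_block_shape_q = 16
    decoder_step_token_num = 0 + 1   # num_speculative_tokens + 1; no speculative tokens assumed

    group_size = math.ceil(num_heads / kv_num_heads)

    # Worst-case number of (batch, tile) pairs each stage can ever produce,
    # mirroring decode/encode/kv_max_tile_size in the hunk above.
    decode_max_tile_size = max_num_seqs * math.ceil(decoder_step_token_num * group_size / decoder_block_shape_q)
    encode_max_tile_size = max_num_seqs * math.ceil(max_model_len * group_size / encoder_block_shape_q)
    kv_max_tile_size = max_num_seqs * math.ceil(max_model_len / block_size)

    print(decode_max_tile_size, encode_max_tile_size, kv_max_tile_size)
    # With the assumed numbers: 64 32768 8192 -> the sizes used for the
    # decoder/encoder/kv batch-id and tile-id buffers allocated with paddle.full(...).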
@@ -142,6 +142,7 @@ class GPUModelRunner(ModelRunnerBase):
 self.use_cudagraph = self.graph_opt_config.use_cudagraph
 self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
 self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
+self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill

 # Initialize share inputs
 self._init_share_inputs(self.parallel_config.max_num_seqs)
@@ -177,10 +178,49 @@ class GPUModelRunner(ModelRunnerBase):
 """
 check whether prefill stage exist
 """
-if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0:
-return 1
-else:
-return 0
+return int(paddle.max(self.share_inputs["seq_lens_encoder"])) > 0

+def exist_decode(self):
+"""
+check whether decode stage exist
+"""
+return int(paddle.max(self.share_inputs["seq_lens_decoder"])) > 0

+def only_prefill(self):
+"""
+check whether prefill only
+"""
+if_only_prefill = True
+decode_exists = None
+if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+only_prefill_batch_list = []
+decode_exists = self.exist_decode()
+paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists)
+if_only_prefill = all(only_prefill_batch_list)

+if_only_prefill = if_only_prefill and not (decode_exists if decode_exists is not None else self.exist_decode())

+return if_only_prefill

+def only_decode(self):
+"""
+check whether decode only
+"""
+# Update Batch type for cuda graph for if_only_decode
+if_only_decode = True
+prefill_exists = None
+# mix ep in single node
+if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+only_decode_batch_list = []
+prefill_exists = self.exist_prefill()
+paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
+if_only_decode = all(only_decode_batch_list)

+if_only_decode = if_only_decode and not (
+prefill_exists if prefill_exists is not None else self.exist_prefill()
+)

+return if_only_decode

 def _init_speculative_proposer(self):
 """
@@ -600,27 +640,81 @@ class GPUModelRunner(ModelRunnerBase):
 if self.speculative_method in ["mtp"]:
 self.proposer.insert_prefill_inputs(req_dicts, num_running_requests)

-def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
-"""Set dummy prefill inputs to share_inputs"""
+def get_input_length_list(
+self, num_tokens: int, batch_size: int, expected_decode_len: int, capture_prefill: bool = False
+):
+"""
+Generates some list for _dummy_prefill_inputs, when capture pure prefill or mtp,
+the list should be carefully constructed.
+
+This function addresses a specific problem: in the pure prefill stage, variable
+input lengths (e.g., `prompt[160, 0]` vs. `prompt[80, 80]`) can lead to different
+CUDA Grid dimensions for kernels like `split_q_block`. This prevents CUDA Graph
+reuse.
+
+The `split_q_block` kernel calculates the total number of blocks, which directly
+determines the `griddim.x` launch parameter for the `multi_query_append_attention_kernel`.
+The blocks for a single sequence are determined by the formula:
+`num_blocks = ceil((sequence_length * group_size) / block_shape_q)`
+
+Due to the `ceil` (ceiling) function, distributing a total number of tokens across
+a batch of shorter sequences will result in a larger total block count. For example,
+with a `group_size` of 5 and `block_shape_q` of 64:
+- A single sequence of 160 tokens requires `ceil((160 * 5) / 64) = 13` blocks.
+- Two sequences of 80 tokens each require `ceil((80 * 5) / 64) * 2 = 7 * 2 = 14` blocks.
+
+To ensure graph replayability, this function creates a "dummy" list of sequence
+lengths that's designed to produce the theoretical maximum `encoder_num_blocks_x_cpu`
+for the given `num_tokens` and `batch_size`. This strategy ensures the captured
+CUDA Graph has the largest possible grid dimensions. At runtime, if the actual number
+of blocks is less than or equal to this maximum, the kernel can safely execute by
+using an early-exit mechanism.
+
+Args:
+num_tokens (int): The total number of tokens across all sequences.
+batch_size (int): The number of sequences (requests) in the batch.
+
+Returns:
+List[int]: A list of integers representing the sequence length for each request.
+This list is crafted to maximize the total number of blocks.
+"""
 # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
 max_dec_len = expected_decode_len + 1
-full_length = min(
-num_tokens // batch_size,
+input_length = min(
+num_tokens // (1 if capture_prefill else batch_size),
 self.parallel_config.max_model_len - max_dec_len,
 )

 # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan.
 # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
 if self.fd_config.parallel_config.enable_expert_parallel:
-full_length = min(full_length, 32)
+input_length = min(input_length, 32)

-input_length = int(full_length * self.cache_config.kv_cache_ratio)
 block_num = (
 input_length + self.cache_config.block_size - 1
 ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num

+input_length_list = [input_length] * batch_size

+if capture_prefill:
+if num_tokens < batch_size:
+input_length_list = [1] * num_tokens
+else:
+input_length_list = [1] * (batch_size - 1)
+input_length_list.append(num_tokens - batch_size + 1)

+len_of_input_length_list = len(input_length_list)
+max_dec_len_list = [max_dec_len] * len_of_input_length_list

+return input_length_list, max_dec_len_list, block_num

+def _dummy_prefill_inputs(self, input_length_list: List[int], max_dec_len_list: List[int], block_num: int):
+"""Set dummy prefill inputs to share_inputs"""
+batch_size = len(input_length_list)
 for i in range(batch_size):
 idx = i
+input_length = input_length_list[i]
+max_dec_len = max_dec_len_list[i]
 self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
 self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
 self.share_inputs["eos_token_id"][:] = np.array(
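The docstring above argues that, because of the `ceil`, spreading a fixed token budget over more sequences tends to increase the total block count, and that a `[1, 1, ..., num_tokens - batch_size + 1]` length list therefore produces the largest grid the runtime should ever need. A small self-contained check of that argument for one small instance, using illustrative `group_size` and `block_shape_q` values (a sketch, not FastDeploy code):

    # Check of the ceil-based block-count argument in the docstring above.
    import math
    from itertools import combinations_with_replacement

    def total_blocks(lengths, group_size=5, block_shape_q=64):
        """Total split_q_block count for a batch of sequence lengths."""
        return sum(math.ceil(length * group_size / block_shape_q) for length in lengths)

    # The example from the docstring: one 160-token prompt vs. two 80-token prompts.
    print(total_blocks([160]))      # 13
    print(total_blocks([80, 80]))   # 14

    # The dummy list used for capture: (batch_size - 1) length-1 prompts plus the remainder.
    def dummy_lengths(num_tokens, batch_size):
        if num_tokens < batch_size:
            return [1] * num_tokens
        return [1] * (batch_size - 1) + [num_tokens - batch_size + 1]

    # Brute-force check for a small instance: no split of num_tokens over at most
    # batch_size sequences produces more blocks than the dummy list.
    num_tokens, batch_size = 12, 3
    cap = total_blocks(dummy_lengths(num_tokens, batch_size), group_size=5, block_shape_q=4)
    for k in range(1, batch_size + 1):
        for split in combinations_with_replacement(range(1, num_tokens + 1), k):
            if sum(split) == num_tokens:
                assert total_blocks(split, group_size=5, block_shape_q=4) <= cap
    print("captured grid size is an upper bound for all splits of this instance")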
@@ -745,6 +839,13 @@ class GPUModelRunner(ModelRunnerBase):
 self.share_inputs["decoder_tile_ids_per_batch"] = None
 self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
 self.share_inputs["max_len_tensor_cpu"] = None # CPU
+self.share_inputs["encoder_batch_ids"] = None
+self.share_inputs["encoder_tile_ids_per_batch"] = None
+self.share_inputs["encoder_num_blocks_x_cpu"] = None # CPU
+self.share_inputs["kv_batch_ids"] = None
+self.share_inputs["kv_tile_ids_per_batch"] = None
+self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU
+self.share_inputs["max_len_kv_cpu"] = None # CPU

 # Initialize rotary position embedding
 tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -977,23 +1078,30 @@ class GPUModelRunner(ModelRunnerBase):
 cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
 block_tables=self.share_inputs["block_tables"],
 caches=self.share_inputs["caches"],
+encoder_batch_ids=self.share_inputs["encoder_batch_ids"],
+encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"],
+encoder_num_blocks_x_cpu=self.share_inputs["encoder_num_blocks_x_cpu"],
+kv_batch_ids=self.share_inputs["kv_batch_ids"],
+kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"],
+kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
+max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"],
 )

-# Update Batch type for cuda graph
-only_decode_batch = True
-prefill_exists = None
-# mix ep in single node
-if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
-only_decode_batch_list = []
-prefill_exists = self.exist_prefill()
-paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
-only_decode_batch = all(only_decode_batch_list)
-self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
+# Update Batch type for cuda graph for only_decode_batch
+if_only_decode = self.only_decode()
+only_decode_use_cudagraph = self.use_cudagraph and if_only_decode

+# Update config about moe for better performance
+# TODO(wanglongzhi):Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. It will be used in MoEMethodBase.apply()
+if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+self.fd_config.parallel_config.moe_phase.phase = "decode" if if_only_decode else "prefill"

+# Update Batch type for cuda graph for only_prefill_batch
+only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill()

+# When support capture both prefill-only and decode-only, this will use [only_prefill_use_cudagraph or only_decode_use_cudagraph]
 self.forward_meta.step_use_cudagraph = (
-self.use_cudagraph
-and only_decode_batch
-and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
+only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph
 )

 # Initialzie attention meta data
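The branch above is what gates graph replay per step: with `cudagraph_only_prefill` set, only batches that are pure prefill (on every EP rank, per `only_prefill()`) may replay a captured graph; otherwise only pure-decode batches may. A standalone restatement of that decision, so the truth table is easy to read; the helper name and its arguments are illustrative, not part of the FastDeploy API.

    # Standalone restatement of the step_use_cudagraph decision above.
    def step_use_cudagraph(use_cudagraph: bool,
                           cudagraph_only_prefill: bool,
                           if_only_prefill: bool,
                           if_only_decode: bool) -> bool:
        only_prefill_use_cudagraph = use_cudagraph and cudagraph_only_prefill and if_only_prefill
        only_decode_use_cudagraph = use_cudagraph and if_only_decode
        # Replay is only attempted for a "pure" batch of the configured kind.
        return only_prefill_use_cudagraph if cudagraph_only_prefill else only_decode_use_cudagraph

    # Mixed prefill+decode batches never replay a graph in either mode.
    assert step_use_cudagraph(True, True, if_only_prefill=False, if_only_decode=False) is False
    assert step_use_cudagraph(True, False, if_only_prefill=False, if_only_decode=True) is True
    assert step_use_cudagraph(True, True, if_only_prefill=True, if_only_decode=False) is True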
@@ -1085,14 +1193,31 @@ class GPUModelRunner(ModelRunnerBase):
 encoder_block_shape_q = 64
 decoder_block_shape_q = 16
 decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
+group_size = np.ceil(num_heads / self.model_config.kv_num_heads)

 decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
-(decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
+(decoder_step_token_num * group_size) / decoder_block_shape_q
+)
+encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
+(self.model_config.max_model_len * group_size) / encoder_block_shape_q
+)
+kv_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
+self.model_config.max_model_len / self.fd_config.cache_config.block_size
 )
 self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
 self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
 self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
 self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()

+self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
+self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
+self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()

+self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
+self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
+self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
+self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu()

 # Get the attention backend
 attn_cls = get_attention_backend()
 attn_backend = attn_cls(
@@ -1112,6 +1237,7 @@ class GPUModelRunner(ModelRunnerBase):
 batch_size: paddle.Tensor,
 expected_decode_len: int = 1,
 in_capturing: bool = False,
+capture_prefill: bool = False,
 ) -> paddle.Tensor:
 """
 Use dummy inputs to run before formal execution.
@@ -1119,11 +1245,19 @@ class GPUModelRunner(ModelRunnerBase):
 num_tokens:
 expected_decode_len: Expected number of tokens generated
 in_capturing: Is cuda graph in capturing state
+capture_prefill: Capture pure prefill for cuda graph
 """
-self._dummy_prefill_inputs(
+input_length_list, max_dec_len_list, block_num = self.get_input_length_list(
 num_tokens=num_tokens,
 batch_size=batch_size,
 expected_decode_len=expected_decode_len,
+capture_prefill=capture_prefill,
+)
+self._dummy_prefill_inputs(
+input_length_list=input_length_list,
+max_dec_len_list=max_dec_len_list,
+block_num=block_num,
 )
 if self.speculative_method in ["mtp"]:
 self.proposer.dummy_prefill_inputs(
@@ -1353,6 +1487,20 @@ class GPUModelRunner(ModelRunnerBase):
 time_before_capture = time.perf_counter()
 expected_decode_len = 1
 capture_sizes = self.cudagraph_capture_sizes.copy()

+if self.fd_config.graph_opt_config.cudagraph_only_prefill:
+for num_tokens in sorted(capture_sizes, reverse=True):
+self._dummy_run(
+num_tokens=num_tokens,
+batch_size=self.parallel_config.max_num_seqs,
+in_capturing=True,
+expected_decode_len=expected_decode_len,
+capture_prefill=True,
+)
+logger.info(
+f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
+)
+else:
 for batch_size in sorted(capture_sizes, reverse=True):
 self._dummy_run(
 num_tokens=self.parallel_config.max_num_batched_tokens,
@@ -1360,7 +1508,9 @@ class GPUModelRunner(ModelRunnerBase):
 in_capturing=True,
 expected_decode_len=expected_decode_len,
 )
-logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")
+logger.info(
+f"Warm up the model with the num_tokens:{batch_size}, expected_decode_len:{expected_decode_len}"
+)

 time_after_capture = time.perf_counter()
 logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
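In the capture loops above, the same `cudagraph_capture_sizes` list is reinterpreted per mode: for prefill-only capture each entry is a total prefill token budget run at the maximum batch size, while the default path treats each entry as a decode batch size run with `max_num_batched_tokens`. A schematic sketch of that dispatch with a stubbed dummy run; the names and default values below are illustrative assumptions.

    # Schematic sketch of the capture dispatch above, with a stub in place of _dummy_run.
    def dummy_run(num_tokens, batch_size, capture_prefill):
        print(f"capture num_tokens={num_tokens} batch_size={batch_size} prefill={capture_prefill}")

    def capture(capture_sizes, cudagraph_only_prefill, max_num_seqs=64, max_num_batched_tokens=2048):
        for size in sorted(capture_sizes, reverse=True):
            if cudagraph_only_prefill:
                # Each entry is a total prefill token count; the batch size stays fixed.
                dummy_run(num_tokens=size, batch_size=max_num_seqs, capture_prefill=True)
            else:
                # Each entry is a decode batch size; the token budget stays fixed.
                dummy_run(num_tokens=max_num_batched_tokens, batch_size=size, capture_prefill=False)

    capture([1, 2, 4, 8], cudagraph_only_prefill=True)
    capture([1, 2, 4, 8], cudagraph_only_prefill=False)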
@@ -397,7 +397,6 @@ class PaddleDisWorkerProc:
 self.get_profile_block_num_signal.value[0] = num_blocks_local
 else:
 num_blocks_local = self.fd_config.parallel_config.total_block_num
-
 logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
 # wait engine launch cache_manager
 if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
@@ -157,7 +157,7 @@ class TestCUDAGrpahSubgraph(unittest.TestCase):
 cache_config = CacheConfig({})
 # Initialize cuda graph capture list
 graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
-graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
+graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs)
 fd_config = FDConfig(
 graph_opt_config=graph_opt_config,
 parallel_config=parallel_config,
@@ -104,7 +104,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
 cache_config = CacheConfig({})
 # Initialize cuda graph capture list
 graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
-graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
+graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs)
 fd_config = FDConfig(
 graph_opt_config=graph_opt_config,
 parallel_config=parallel_config,
@@ -90,7 +90,7 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase):
 graph_opt_config = GraphOptimizationConfig({"use_cudagraph": True, "graph_opt_level": 1})
 parallel_config = ParallelConfig({"max_num_seqs": 1})
 graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
-graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
+graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs)
 cache_config = CacheConfig({})

 fd_config = FDConfig(
@@ -386,6 +386,14 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
 self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
 self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()

+self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
+self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
+self.encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
+self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
+self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
+self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
+self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()

 self.cache_shape = (
 self.max_block_num,
 self.kv_num_head,
@@ -469,15 +477,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
 get_block_shape_and_split_kv_block,
 )

-(
-encoder_batch_ids,
-encoder_tile_ids_per_batch,
-encoder_num_blocks,
-kv_batch_ids,
-kv_tile_ids_per_batch,
-kv_num_blocks,
-max_len_kv,
-) = get_block_shape_and_split_kv_block(
+get_block_shape_and_split_kv_block(
 self.seq_lens_encoder,
 self.seq_lens_decoder,
 self.seq_lens_this_time,
@@ -485,6 +485,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
 self.decoder_tile_ids_per_batch,
 self.decoder_num_blocks_cpu,
 self.max_len_tensor_cpu,
+self.encoder_batch_ids,
+self.encoder_tile_ids_per_batch,
+self.encoder_num_blocks_x_cpu,
+self.kv_batch_ids,
+self.kv_tile_ids_per_batch,
+self.kv_num_blocks_x_cpu,
+self.max_len_kv_cpu,
 64,
 12,
 (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
@@ -508,17 +515,17 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
 self.padding_offset,
 self.cum_offset,
 self.block_tables,
-encoder_batch_ids,
-encoder_tile_ids_per_batch,
-encoder_num_blocks,
-kv_batch_ids,
-kv_tile_ids_per_batch,
-kv_num_blocks,
+self.encoder_batch_ids,
+self.encoder_tile_ids_per_batch,
+self.encoder_num_blocks_x_cpu,
+self.kv_batch_ids,
+self.kv_tile_ids_per_batch,
+self.kv_num_blocks_x_cpu,
 self.decoder_batch_ids,
 self.decoder_tile_ids_per_batch,
 self.decoder_num_blocks_cpu,
 self.max_len_tensor_cpu,
-max_len_kv,
+self.max_len_kv_cpu,
 self.rope_emb, # rope_emb
 None, # attn_mask
 None, # qkv_bias
@@ -382,6 +382,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
 self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
 self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
 self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
+self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
+self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
+self.encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
+self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
+self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
+self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
+self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()

 self.cache_shape = (
 self.max_block_num,
@@ -450,15 +457,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
 get_block_shape_and_split_kv_block,
 )

-(
-encoder_batch_ids,
-encoder_tile_ids_per_batch,
-encoder_num_blocks,
-kv_batch_ids,
-kv_tile_ids_per_batch,
-kv_num_blocks,
-max_len_kv,
-) = get_block_shape_and_split_kv_block(
+get_block_shape_and_split_kv_block(
 self.seq_lens_encoder,
 self.seq_lens_decoder,
 self.seq_lens_this_time,
@@ -466,6 +465,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
 self.decoder_tile_ids_per_batch,
 self.decoder_num_blocks_cpu,
 self.max_len_tensor_cpu,
+self.encoder_batch_ids,
+self.encoder_tile_ids_per_batch,
+self.encoder_num_blocks_x_cpu,
+self.kv_batch_ids,
+self.kv_tile_ids_per_batch,
+self.kv_num_blocks_x_cpu,
+self.max_len_kv_cpu,
 64,
 12,
 (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
@@ -491,17 +497,17 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
 self.padding_offset,
 self.cum_offset,
 self.block_tables,
-encoder_batch_ids,
-encoder_tile_ids_per_batch,
-encoder_num_blocks,
-kv_batch_ids,
-kv_tile_ids_per_batch,
-kv_num_blocks,
+self.encoder_batch_ids,
+self.encoder_tile_ids_per_batch,
+self.encoder_num_blocks_x_cpu,
+self.kv_batch_ids,
+self.kv_tile_ids_per_batch,
+self.kv_num_blocks_x_cpu,
 self.decoder_batch_ids,
 self.decoder_tile_ids_per_batch,
 self.decoder_num_blocks_cpu,
 self.max_len_tensor_cpu,
-max_len_kv,
+self.max_len_kv_cpu,
 out,
 self.rope_emb, # rope_emb
 None, # attn_mask
@@ -190,30 +190,32 @@ class TestTreeMask(unittest.TestCase):

 encoder_block_shape_q = 64
 decoder_block_shape_q = 16
+group_size = self.num_q_head // self.num_kv_head
 decode_max_tile_size = (
-self.bsz
-* (decoder_step_token_num * (self.num_q_head // self.num_kv_head) + decoder_block_shape_q - 1)
-/ decoder_block_shape_q
+self.bsz * (decoder_step_token_num * group_size + decoder_block_shape_q - 1) / decoder_block_shape_q
 )
+encode_max_tile_size = (
+self.bsz * (self.max_seq_len * group_size + encoder_block_shape_q - 1) / encoder_block_shape_q
+)
+kv_max_tile_size = self.bsz * (self.max_seq_len + self.block_size - 1) / self.block_size

 decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
 decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
 decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
 max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
+encoder_batch_ids = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
+encoder_tile_ids_per_batch = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
+encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
+kv_batch_ids = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
+kv_tile_ids_per_batch = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
+kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
+max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()
 q_norm_weight = np.ones([self.head_dim])
 k_norm_weight = np.ones([self.head_dim])
 self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
 self.k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
 paddle.device.synchronize()
-(
-encoder_batch_ids,
-encoder_tile_ids_per_batch,
-encoder_num_blocks,
-kv_batch_ids,
-kv_tile_ids_per_batch,
-kv_num_blocks,
-max_len_kv,
-) = get_block_shape_and_split_kv_block(
+get_block_shape_and_split_kv_block(
 seq_lens_encoder,
 seq_lens_decoder,
 seq_lens_this_time,
@@ -221,6 +223,13 @@ class TestTreeMask(unittest.TestCase):
 decoder_tile_ids_per_batch,
 decoder_num_blocks,
 max_len_tensor_cpu,
+encoder_batch_ids,
+encoder_tile_ids_per_batch,
+encoder_num_blocks_x_cpu,
+kv_batch_ids,
+kv_tile_ids_per_batch,
+kv_num_blocks_x_cpu,
+max_len_kv_cpu,
 encoder_block_shape_q,
 decoder_block_shape_q,
 self.num_q_head // self.num_kv_head,
@@ -243,15 +252,15 @@ class TestTreeMask(unittest.TestCase):
 self.block_tables,
 encoder_batch_ids,
 encoder_tile_ids_per_batch,
-encoder_num_blocks,
+encoder_num_blocks_x_cpu,
 kv_batch_ids,
 kv_tile_ids_per_batch,
-kv_num_blocks,
+kv_num_blocks_x_cpu,
 decoder_batch_ids,
 decoder_tile_ids_per_batch,
 decoder_num_blocks,
 max_len_tensor_cpu,
-max_len_kv,
+max_len_kv_cpu,
 rotary_embs,
 attn_mask,
 None, # qkv_bias