Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-03 15:56:49 +08:00)
[Executor] Experimental feature: support prefill in CUDA Graph (#3459)
* Support prefill in Cudagraph
* Refactor GetBlockShapeAndSplitKVBlock Kernel V2
* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.1
* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.2
* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.3
* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.4
* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.5
* Solve problem about encoder_num_blocks_x_cpu
* Add early-exit mechanism for attention kernel
* fix test case about append-attention
* Update testcode, Add annotations to related tensors
* move get_input_length_list
* solve test_code
* Add annotations about early-exit for attention kernel
* Add annotations about early-exit for attention kernel2
* solve comment
* solve mtp

Co-authored-by: RAM <gstian5555@outlook.com>
@@ -52,6 +52,7 @@ __global__ void multi_query_append_attention_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -74,6 +75,11 @@ __global__ void multi_query_append_attention_kernel(

block_table_now = block_table + batch_id * max_block_num_per_seq;

// When CUDA Graph captures prefill, more thread blocks may be launched along gridDim.x than needed.
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
return;
}

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -422,6 +428,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -445,6 +452,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;

// When CUDA Graph captures prefill, more thread blocks may be launched along gridDim.x than needed.
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
return;
}

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -902,6 +914,7 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -960,6 +973,7 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1134,6 +1148,7 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1206,6 +1221,7 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
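The guard added in the hunks above (`if (btid >= num_blocks_x_cpu) { return; }`) is what makes a prefill graph that was captured with an oversized grid safe to replay. A minimal Python sketch of that replay-safety argument (illustrative only; the real check runs inside the CUDA kernels):

# Illustrative sketch: a graph captured with `captured_grid` thread blocks can be
# replayed for any workload whose real block count is smaller, because blocks with
# blockIdx.x >= num_blocks_x hit the early-exit guard and return immediately.
def active_blocks(captured_grid: int, num_blocks_x: int) -> list:
    return [bx for bx in range(captured_grid) if bx < num_blocks_x]

# 14 blocks were captured, but only 13 are needed for this replay.
assert active_blocks(captured_grid=14, num_blocks_x=13) == list(range(13))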
@@ -57,6 +57,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -85,6 +86,11 @@ __global__ void multi_query_append_attention_c4_kernel(

block_table_now = block_table + batch_id * max_block_num_per_seq;

// When CUDA Graph captures prefill, more thread blocks may be launched along gridDim.x than needed.
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
return;
}

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -520,6 +526,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -549,6 +556,11 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;

// When CUDA Graph captures prefill, more thread blocks may be launched along gridDim.x than needed.
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
return;
}

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -1107,6 +1119,7 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1171,6 +1184,7 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1365,6 +1379,7 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1445,6 +1460,7 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -58,6 +58,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -87,6 +88,11 @@ __global__ void multi_query_append_attention_c8_kernel(

block_table_now = block_table + batch_id * max_block_num_per_seq;

// When CUDA Graph captures prefill, more thread blocks may be launched along gridDim.x than needed.
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
return;
}

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -527,6 +533,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -556,6 +563,11 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;

// When CUDA Graph captures prefill, more thread blocks may be launched along gridDim.x than needed.
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
return;
}

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -1159,6 +1171,7 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1217,6 +1230,7 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1443,6 +1457,7 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1517,6 +1532,7 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -191,14 +191,21 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
}
}

std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &max_len_kv_cpu, // Inplace, CPU
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
@@ -223,13 +230,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
int max_system_len = max_len_cpu_ptr[6];
int max_just_dec_len_without_system = max_len_cpu_ptr[7];

paddle::Tensor encoder_batch_ids;
paddle::Tensor encoder_tile_ids_per_batch;
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor kv_batch_ids;
paddle::Tensor kv_tile_ids_per_batch;
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
paddle::Tensor max_len_kv_cpu; /*cpu*/

auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
@@ -237,17 +238,14 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(), bsz);

max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);

max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);

if (max_enc_len_this_time > 0) {
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_dec_len_this_time, block_size);
kv_batch_ids =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
auto kv_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());

@@ -258,16 +256,12 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
block_size, block_size);

kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);

const uint32_t encoder_max_tile_size_per_bs_q =
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
encoder_batch_ids =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
// Clear buffer
const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
auto encoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
@@ -275,21 +269,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
encoder_tile_ids_per_batch.data<int>(),
encoder_num_blocks_x.data<int>(), bsz,
encoder_block_shape_q, group_size);
encoder_num_blocks_x_cpu =
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
encoder_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
kv_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
}

if (max_just_dec_len_this_time > 0) {
@@ -314,15 +294,6 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
}

return {
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu, /*cpu*/
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu, /*cpu*/
max_len_kv_cpu, /*cpu*/
};
}

PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
@@ -333,16 +304,17 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks_x_cpu",
"max_len_tensor_cpu"
"max_len_tensor_cpu",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks_x_cpu",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks_x_cpu",
"max_len_kv_cpu"
})
.Outputs({
paddle::Optional("encoder_batch_ids"),
paddle::Optional("encoder_tile_ids_per_batch"),
paddle::Optional("encoder_num_blocks_x_cpu"),
paddle::Optional("kv_batch_ids"),
paddle::Optional("kv_tile_ids_per_batch"),
paddle::Optional("kv_num_blocks_x_cpu"),
"max_len_kv_cpu"

})
.Attrs({
"encoder_block_shape_q: int",
@@ -299,7 +299,7 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id,
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
const int layer_id);

std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
@@ -307,6 +307,13 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
@@ -580,6 +580,10 @@ class GraphOptimizationConfig:
""" Whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops."""
self.cudagraph_only_prefill: bool = False
"""When cudagraph_only_prefill is False, only decode-only batches are captured.
When cudagraph_only_prefill is True, only prefill-only batches are captured.
Capturing both decode-only and prefill-only batches is not supported yet."""
self.full_cuda_graph: bool = True

self.max_capture_size: int = None
@@ -592,13 +596,13 @@ class GraphOptimizationConfig:

self.check_legality_parameters()

def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None:
"""
Initialize cuda graph capture sizes and
pre-compute the mapping from batch size to padded graph size
"""
# Regular capture sizes
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_num_seqs]
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size]
dedup_sizes = list(set(self.cudagraph_capture_sizes))
if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
logger.info(
@@ -632,7 +636,7 @@ class GraphOptimizationConfig:
# Shape [128, 144, ... 240, 256]
draft_capture_sizes += [16 * i for i in range(9, 17)]
# Shape [256, 288, ... 992, 1024]
draft_capture_sizes += [32 * i for i in range(17, 33)]
draft_capture_sizes += [32 * i for i in range(9, 33)]

draft_capture_sizes.append(max_num_seqs)
self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
@@ -1140,7 +1144,11 @@ class FDConfig:
# Initialize cuda graph capture list
if self.graph_opt_config.cudagraph_capture_sizes is None:
self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs)

if self.graph_opt_config.cudagraph_only_prefill:
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=512)
else:
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=self.parallel_config.max_num_seqs)

# TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
if self.graph_opt_config.graph_opt_level == 2:
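A hedged sketch of the capture-size selection described by `init_with_cudagrpah_size` and the FDConfig hunk above: candidate sizes above the cap are dropped and de-duplicated, and the cap is 512 for prefill-only capture versus `max_num_seqs` for decode-only capture. The helper name and the candidate list below are illustrative; only the 512 constant and the filtering rule come from the diff.

def select_capture_sizes(candidate_sizes, cudagraph_only_prefill, max_num_seqs):
    # Prefill-only graphs are keyed by token count (capped at 512 above);
    # decode-only graphs are keyed by batch size (capped at max_num_seqs).
    max_capture_size = 512 if cudagraph_only_prefill else max_num_seqs
    return sorted({size for size in candidate_sizes if size <= max_capture_size})

print(select_capture_sizes([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024], True, 64))
# [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]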
@@ -81,14 +81,42 @@ class ForwardMeta:
attn_mask: Optional[paddle.Tensor] = None
# Attention mask offset
attn_mask_offsets: Optional[paddle.Tensor] = None

# A common pattern for launching CUDA kernels is to set the kernel's grids.x dimension
# using a `num_blocks` variable, and then map each thread block to a specific batch and
# data tile using `batch_ids` and `tile_ids_per_batch`.
#
# The variable names below follow this pattern, using a common prefix (e.g., `encoder_`, `decoder_`, `kv_`)
# for variables that are logically grouped together. The mapping works as follows:
#
# Usage: `my_kernel<<<grids, ...>>>(..., batch_ids, tile_ids, ...)`
# `grids.x` = `num_blocks_cpu`
# `batch_id` = `batch_ids[blockIdx.x]`
# `tile_id` = `tile_ids[blockIdx.x]`

# Maps the thread block index (blockIdx.x) to the corresponding batch for the decoder stage in multi_query_append_attention_warp1_4_kernel.
# Decoder batch id. Used by attention backend.
decoder_batch_ids: Optional[paddle.Tensor] = None
# Tile ID for each batch of the decoder. Used by attention backend.
# Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the decoder stage in multi_query_append_attention_warp1_4_kernel.
decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
# The number of blocks that attention backend can use in decode stage
# The number of CUDA blocks to launch in the x-dimension for the multi_query_append_attention_warp1_4_kernel, defining its grids.x.
decoder_num_blocks_cpu: Optional[paddle.Tensor] = None
# Recorded multiple lengths related to prefill or decode
# A tensor that holds multiple lengths related to prefill or decode stages.
max_len_tensor_cpu: Optional[paddle.Tensor] = None
# Maps the thread block index (blockIdx.x) to the corresponding batch for the encoder stage in multi_query_append_attention_kernel.
encoder_batch_ids: Optional[paddle.Tensor] = None
# Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the encoder stage in multi_query_append_attention_kernel.
encoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
# The number of CUDA blocks to launch in the x-dimension for the multi_query_append_attention_kernel, defining its grids.x.
encoder_num_blocks_x_cpu: Optional[paddle.Tensor] = None
# Maps the thread block index (blockIdx.x) to the corresponding batch for the append_write_cache_kv kernel.
kv_batch_ids: Optional[paddle.Tensor] = None
# Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the append_write_cache_kv kernel.
kv_tile_ids_per_batch: Optional[paddle.Tensor] = None
# The number of CUDA blocks to launch in the x-dimension for the append_write_cache_kv kernel, defining its grids.x.
kv_num_blocks_x_cpu: Optional[paddle.Tensor] = None
# The maximum sequence length of the KV cache, which may represent the current maximum decoder length.
max_len_kv_cpu: Optional[paddle.Tensor] = None
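The comment block above pins down the contract between the `*_num_blocks*_cpu`, `*_batch_ids` and `*_tile_ids_per_batch` tensors. The following pure-Python sketch mimics what the CUDA-side `split_q_block` produces for the encoder case; the sequence lengths and tile parameters are illustrative, not values from the diff.

import math

def build_block_mapping(seq_lens, group_size, block_shape_q):
    """Build the blockIdx.x -> (batch, tile) mapping described above."""
    batch_ids, tile_ids = [], []
    for batch_id, seq_len in enumerate(seq_lens):
        for tile_id in range(math.ceil(seq_len * group_size / block_shape_q)):
            batch_ids.append(batch_id)
            tile_ids.append(tile_id)
    # len(batch_ids) plays the role of num_blocks_cpu, i.e. the kernel's grids.x.
    return batch_ids, tile_ids, len(batch_ids)

batch_ids, tile_ids, num_blocks = build_block_mapping([160, 0, 80], group_size=5, block_shape_q=64)
# Thread block bx (0 <= bx < num_blocks) works on batch batch_ids[bx], tile tile_ids[bx];
# a sequence of length 0 contributes no blocks.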

# Sequence length of encoder for every batch
seq_lens_encoder: Optional[paddle.Tensor] = None
@@ -133,6 +161,7 @@ class ForwardMeta:
"shape": obj.shape,
"dtype": str(obj.dtype),
"place": str(obj.place),
# "content": obj if obj.numel()<10 else "Too big to show"
}
return tensor_info
elif isinstance(obj, (list, tuple)):
@@ -49,14 +49,6 @@ class AppendAttentionMetadata(AttentionMetadata):
AppendAttentionMetadata
"""

encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
max_len_kv: paddle.Tensor = None

_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
@@ -142,15 +134,7 @@ class AppendAttentionBackend(AttentionBackend):
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length
(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
@@ -158,6 +142,13 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.max_len_kv_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
@@ -288,17 +279,17 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.max_len_kv,
forward_meta.max_len_kv_cpu,
res,
metadata.rotary_embs,
metadata.attn_mask,
@@ -344,17 +335,17 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.max_len_kv,
forward_meta.max_len_kv_cpu,
metadata.rotary_embs,
metadata.attn_mask,
layer.qkv_bias,
@@ -65,13 +65,6 @@ class FlashAttentionMetadata(AttentionMetadata):

rotary_embs: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
max_len_kv: paddle.Tensor = None

cu_seqlens_q: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
@@ -198,15 +191,7 @@ class FlashAttentionBackend(AttentionBackend):
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.rotary_embs = forward_meta.rotary_embs
metadata.block_tables = forward_meta.block_tables
(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
@@ -214,6 +199,13 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.max_len_kv_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
@@ -295,9 +287,9 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
metadata.block_tables,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
@@ -336,17 +328,17 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids, # from buffer
forward_meta.decoder_tile_ids_per_batch, # from buffer
forward_meta.decoder_num_blocks_cpu,
metadata.max_len_tensor_cpu_decoder,
metadata.max_len_kv,
forward_meta.max_len_kv_cpu,
metadata.rotary_embs,
forward_meta.attn_mask,
layer.qkv_bias,
@@ -69,14 +69,6 @@ class MLAAttentionMetadata(AttentionMetadata):
MLAAttentionMetadata for Multi-Layer Attention
"""

encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
max_len_kv: paddle.Tensor = None

_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
@@ -191,15 +183,7 @@ class MLAAttentionBackend(AttentionBackend):
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length

(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
@@ -207,6 +191,13 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.max_len_kv_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
@@ -362,19 +353,19 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.cu_seqlens_q,
forward_meta.batch_id_per_token,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_cpu,
metadata.max_enc_len_this_time,
metadata.max_dec_len_this_time,
metadata.max_len_kv,
forward_meta.max_len_kv_cpu,
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales
@@ -483,19 +474,19 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.cu_seqlens_q,
forward_meta.batch_id_per_token,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_cpu,
metadata.max_enc_len_this_time,
metadata.max_dec_len_this_time,
metadata.max_len_kv,
forward_meta.max_len_kv_cpu,
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales
@@ -32,6 +32,13 @@ def get_block_shape_and_split_kv_block(
decoder_tile_ids_per_batch: paddle.Tensor,
decoder_num_blocks_x_cpu: paddle.Tensor,
max_len_tensor_cpu: paddle.Tensor,
encoder_batch_ids: paddle.Tensor,
encoder_tile_ids_per_batch: paddle.Tensor,
encoder_num_blocks_x_cpu: paddle.Tensor,
kv_batch_ids: paddle.Tensor,
kv_tile_ids_per_batch: paddle.Tensor,
kv_num_blocks_x_cpu: paddle.Tensor,
max_len_kv_cpu: paddle.Tensor,
encoder_block_shape_q: int,
decoder_block_shape_q: int,
group_size: int,
@@ -42,15 +49,7 @@ def get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block
"""
if current_platform.is_cuda():
(
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
max_len_kv_cpu,
) = get_block_shape_and_split_kv_block_cuda(
get_block_shape_and_split_kv_block_cuda(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
@@ -58,20 +57,19 @@ def get_block_shape_and_split_kv_block(
decoder_tile_ids_per_batch,
decoder_num_blocks_x_cpu,
max_len_tensor_cpu,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu,
max_len_kv_cpu,
encoder_block_shape_q,
decoder_block_shape_q,
group_size,
block_size,
decoder_step_token_num,
)
return (
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
max_len_kv_cpu,
)

else:
raise NotImplementedError
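With the refactor above, the op writes every result into caller-provided buffers, so a call site looks roughly like the sketch below. The argument order follows the wrapper signature above and the CUDA call it forwards to; the helper name `split_kv_blocks_inplace` and the default tile parameters are illustrative only.

def split_kv_blocks_inplace(forward_meta, encoder_block_shape_q=64, decoder_block_shape_q=16,
                            group_size=8, block_size=64, decoder_step_token_num=1):
    """Hedged sketch: fill the pre-allocated index buffers on forward_meta in place."""
    get_block_shape_and_split_kv_block(  # the wrapper defined above
        forward_meta.seq_lens_encoder,
        forward_meta.seq_lens_decoder,
        forward_meta.seq_lens_this_time,
        forward_meta.decoder_batch_ids,
        forward_meta.decoder_tile_ids_per_batch,
        forward_meta.decoder_num_blocks_cpu,
        forward_meta.max_len_tensor_cpu,
        forward_meta.encoder_batch_ids,
        forward_meta.encoder_tile_ids_per_batch,
        forward_meta.encoder_num_blocks_x_cpu,
        forward_meta.kv_batch_ids,
        forward_meta.kv_tile_ids_per_batch,
        forward_meta.kv_num_blocks_x_cpu,
        forward_meta.max_len_kv_cpu,
        encoder_block_shape_q,
        decoder_block_shape_q,
        group_size,
        block_size,
        decoder_step_token_num,
    )
    # No return value: every output lands in the buffers passed above.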
@@ -212,6 +212,22 @@ class MTPProposer(Proposer):
self.target_model_inputs["max_len_tensor_cpu"]
).cpu()

self.model_inputs["encoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["encoder_batch_ids"])
self.model_inputs["encoder_tile_ids_per_batch"] = paddle.zeros_like(
self.target_model_inputs["encoder_tile_ids_per_batch"]
)
self.model_inputs["encoder_num_blocks_x_cpu"] = paddle.zeros_like(
self.target_model_inputs["encoder_num_blocks_x_cpu"]
).cpu()
self.model_inputs["kv_batch_ids"] = paddle.zeros_like(self.target_model_inputs["kv_batch_ids"])
self.model_inputs["kv_tile_ids_per_batch"] = paddle.zeros_like(
self.target_model_inputs["kv_tile_ids_per_batch"]
)
self.model_inputs["kv_num_blocks_x_cpu"] = paddle.zeros_like(
self.target_model_inputs["kv_num_blocks_x_cpu"]
).cpu()
self.model_inputs["max_len_kv_cpu"] = paddle.zeros_like(self.target_model_inputs["max_len_kv_cpu"]).cpu()

# Get the attention backend
attn_cls = get_attention_backend()
attn_backend = attn_cls(
@@ -321,6 +337,13 @@ class MTPProposer(Proposer):
self.model_inputs["decoder_tile_ids_per_batch"] = None
self.model_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
self.model_inputs["max_len_tensor_cpu"] = None # CPU
self.model_inputs["encoder_batch_ids"] = None
self.model_inputs["encoder_tile_ids_per_batch"] = None
self.model_inputs["encoder_num_blocks_x_cpu"] = None # CPU
self.model_inputs["kv_batch_ids"] = None
self.model_inputs["kv_tile_ids_per_batch"] = None
self.model_inputs["kv_num_blocks_x_cpu"] = None # CPU
self.model_inputs["max_len_kv_cpu"] = None # CPU

# Input tokens
self.model_inputs["draft_tokens"] = paddle.full(
@@ -512,6 +535,13 @@ class MTPProposer(Proposer):
cu_seqlens_k=self.model_inputs["cu_seqlens_k"],
block_tables=self.model_inputs["block_tables"],
caches=self.model_inputs["caches"],
encoder_batch_ids=self.model_inputs["encoder_batch_ids"],
encoder_tile_ids_per_batch=self.model_inputs["encoder_tile_ids_per_batch"],
encoder_num_blocks_x_cpu=self.model_inputs["encoder_num_blocks_x_cpu"],
kv_batch_ids=self.model_inputs["kv_batch_ids"],
kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"],
kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"],
max_len_kv_cpu=self.model_inputs["max_len_kv_cpu"],
)

# Initialize attention meta data
@@ -430,6 +430,13 @@ class GCUModelRunner(ModelRunnerBase):
self.share_inputs["decoder_tile_ids_per_batch"] = None
self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
self.share_inputs["max_len_tensor_cpu"] = None # CPU
self.share_inputs["encoder_batch_ids"] = None
self.share_inputs["encoder_tile_ids_per_batch"] = None
self.share_inputs["encoder_num_blocks_x_cpu"] = None # CPU
self.share_inputs["kv_batch_ids"] = None
self.share_inputs["kv_tile_ids_per_batch"] = None
self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU
self.share_inputs["max_len_kv_cpu"] = None # CPU

# Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -601,6 +608,13 @@ class GCUModelRunner(ModelRunnerBase):
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"],
encoder_batch_ids=self.share_inputs["encoder_batch_ids"],
encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"],
encoder_num_blocks_x_cpu=self.share_inputs["encoder_num_blocks_x_cpu"],
kv_batch_ids=self.share_inputs["kv_batch_ids"],
kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"],
kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"],
)

# Update Batch type for cuda graph
@@ -673,14 +687,31 @@ class GCUModelRunner(ModelRunnerBase):
encoder_block_shape_q = 64
decoder_block_shape_q = 16
decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
group_size = np.ceil(num_heads / self.model_config.kv_num_heads)

decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
(decoder_step_token_num * group_size) / decoder_block_shape_q
)
encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(self.model_config.max_model_len * group_size) / encoder_block_shape_q
)
kv_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
self.model_config.max_model_len / self.fd_config.cache_config.block_size
)
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()

self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()

self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu()

# Get the attention backend
attn_cls = get_attention_backend()
attn_backend = attn_cls(
@@ -142,6 +142,7 @@ class GPUModelRunner(ModelRunnerBase):
self.use_cudagraph = self.graph_opt_config.use_cudagraph
self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill

# Initialize share inputs
self._init_share_inputs(self.parallel_config.max_num_seqs)
@@ -177,10 +178,49 @@ class GPUModelRunner(ModelRunnerBase):
"""
check whether the prefill stage exists
"""
if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0:
return 1
else:
return 0
return int(paddle.max(self.share_inputs["seq_lens_encoder"])) > 0

def exist_decode(self):
"""
check whether the decode stage exists
"""
return int(paddle.max(self.share_inputs["seq_lens_decoder"])) > 0

def only_prefill(self):
"""
check whether the batch is prefill-only
"""
if_only_prefill = True
decode_exists = None
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
only_prefill_batch_list = []
decode_exists = self.exist_decode()
paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists)
if_only_prefill = all(only_prefill_batch_list)

if_only_prefill = if_only_prefill and not (decode_exists if decode_exists is not None else self.exist_decode())

return if_only_prefill

def only_decode(self):
"""
check whether the batch is decode-only
"""
# Update Batch type for cuda graph for if_only_decode
if_only_decode = True
prefill_exists = None
# mix ep in single node
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
only_decode_batch_list = []
prefill_exists = self.exist_prefill()
paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
if_only_decode = all(only_decode_batch_list)

if_only_decode = if_only_decode and not (
prefill_exists if prefill_exists is not None else self.exist_prefill()
)

return if_only_decode

def _init_speculative_proposer(self):
"""
@@ -600,27 +640,81 @@ class GPUModelRunner(ModelRunnerBase):
if self.speculative_method in ["mtp"]:
self.proposer.insert_prefill_inputs(req_dicts, num_running_requests)

def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
"""Set dummy prefill inputs to share_inputs"""
def get_input_length_list(
self, num_tokens: int, batch_size: int, expected_decode_len: int, capture_prefill: bool = False
):
"""
Generate the input-length lists used by _dummy_prefill_inputs; when capturing
pure prefill (or MTP), the lists must be constructed carefully.

This function addresses a specific problem: in the pure prefill stage, variable
input lengths (e.g., `prompt[160, 0]` vs. `prompt[80, 80]`) can lead to different
CUDA Grid dimensions for kernels like `split_q_block`. This prevents CUDA Graph
reuse.

The `split_q_block` kernel calculates the total number of blocks, which directly
determines the `griddim.x` launch parameter for the `multi_query_append_attention_kernel`.
The blocks for a single sequence are determined by the formula:
`num_blocks = ceil((sequence_length * group_size) / block_shape_q)`

Due to the `ceil` (ceiling) function, distributing a total number of tokens across
a batch of shorter sequences will result in a larger total block count. For example,
with a `group_size` of 5 and `block_shape_q` of 64:
- A single sequence of 160 tokens requires `ceil((160 * 5) / 64) = 13` blocks.
- Two sequences of 80 tokens each require `ceil((80 * 5) / 64) * 2 = 7 * 2 = 14` blocks.

To ensure graph replayability, this function creates a "dummy" list of sequence
lengths that's designed to produce the theoretical maximum `encoder_num_blocks_x_cpu`
for the given `num_tokens` and `batch_size`. This strategy ensures the captured
CUDA Graph has the largest possible grid dimensions. At runtime, if the actual number
of blocks is less than or equal to this maximum, the kernel can safely execute by
using an early-exit mechanism.

Args:
num_tokens (int): The total number of tokens across all sequences.
batch_size (int): The number of sequences (requests) in the batch.

Returns:
List[int]: A list of integers representing the sequence length for each request.
This list is crafted to maximize the total number of blocks.
"""
# NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
max_dec_len = expected_decode_len + 1
full_length = min(
num_tokens // batch_size,
input_length = min(
num_tokens // (1 if capture_prefill else batch_size),
self.parallel_config.max_model_len - max_dec_len,
)

# NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough, which causes NaNs in the result.
# TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
if self.fd_config.parallel_config.enable_expert_parallel:
full_length = min(full_length, 32)
input_length = min(input_length, 32)

input_length = int(full_length * self.cache_config.kv_cache_ratio)
block_num = (
input_length + self.cache_config.block_size - 1
) // self.cache_config.block_size + self.cache_config.enc_dec_block_num

input_length_list = [input_length] * batch_size

if capture_prefill:
if num_tokens < batch_size:
input_length_list = [1] * num_tokens
else:
input_length_list = [1] * (batch_size - 1)
input_length_list.append(num_tokens - batch_size + 1)

len_of_input_length_list = len(input_length_list)
max_dec_len_list = [max_dec_len] * len_of_input_length_list

return input_length_list, max_dec_len_list, block_num
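A quick numerical check of the worked example in the `get_input_length_list` docstring above, using the same `group_size` of 5 and `block_shape_q` of 64:

import math

group_size, block_shape_q = 5, 64
one_seq_of_160 = math.ceil(160 * group_size / block_shape_q)      # 13 blocks
two_seqs_of_80 = 2 * math.ceil(80 * group_size / block_shape_q)   # 2 * 7 = 14 blocks
assert (one_seq_of_160, two_seqs_of_80) == (13, 14)
# Spreading the same token budget over more sequences can only round up more often,
# which is why the dummy length list is built to maximize the block count before capture.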
def _dummy_prefill_inputs(self, input_length_list: List[int], max_dec_len_list: List[int], block_num: int):
"""Set dummy prefill inputs to share_inputs"""
batch_size = len(input_length_list)
for i in range(batch_size):
idx = i
input_length = input_length_list[i]
max_dec_len = max_dec_len_list[i]
self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["eos_token_id"][:] = np.array(
@@ -745,6 +839,13 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["decoder_tile_ids_per_batch"] = None
self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
self.share_inputs["max_len_tensor_cpu"] = None # CPU
self.share_inputs["encoder_batch_ids"] = None
self.share_inputs["encoder_tile_ids_per_batch"] = None
self.share_inputs["encoder_num_blocks_x_cpu"] = None # CPU
self.share_inputs["kv_batch_ids"] = None
self.share_inputs["kv_tile_ids_per_batch"] = None
self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU
self.share_inputs["max_len_kv_cpu"] = None # CPU

# Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -977,23 +1078,30 @@ class GPUModelRunner(ModelRunnerBase):
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"],
encoder_batch_ids=self.share_inputs["encoder_batch_ids"],
encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"],
encoder_num_blocks_x_cpu=self.share_inputs["encoder_num_blocks_x_cpu"],
kv_batch_ids=self.share_inputs["kv_batch_ids"],
kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"],
kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"],
)

# Update Batch type for cuda graph
only_decode_batch = True
prefill_exists = None
# mix ep in single node
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
only_decode_batch_list = []
prefill_exists = self.exist_prefill()
paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
only_decode_batch = all(only_decode_batch_list)
self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
# Update Batch type for cuda graph for only_decode_batch
if_only_decode = self.only_decode()
only_decode_use_cudagraph = self.use_cudagraph and if_only_decode

# Update config about moe for better performance
# TODO(wanglongzhi): Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. It will be used in MoEMethodBase.apply()
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
self.fd_config.parallel_config.moe_phase.phase = "decode" if if_only_decode else "prefill"

# Update Batch type for cuda graph for only_prefill_batch
only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill()

# When support capture both prefill-only and decode-only, this will use [only_prefill_use_cudagraph or only_decode_use_cudagraph]
self.forward_meta.step_use_cudagraph = (
self.use_cudagraph
and only_decode_batch
and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph
)

# Initialize attention meta data
@@ -1085,14 +1193,31 @@ class GPUModelRunner(ModelRunnerBase):
encoder_block_shape_q = 64
decoder_block_shape_q = 16
decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
group_size = np.ceil(num_heads / self.model_config.kv_num_heads)

decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
(decoder_step_token_num * group_size) / decoder_block_shape_q
)
encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(self.model_config.max_model_len * group_size) / encoder_block_shape_q
)
kv_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
self.model_config.max_model_len / self.fd_config.cache_config.block_size
)
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()

self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()

self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
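The buffers above are sized for the worst case. With illustrative values (the numbers below are assumptions, not taken from the diff), the formulas evaluate as follows:

import math

max_num_seqs, max_model_len = 128, 8192            # assumed
num_heads, kv_num_heads = 64, 8                    # assumed
encoder_block_shape_q, decoder_block_shape_q = 64, 16
block_size, decoder_step_token_num = 64, 1

group_size = math.ceil(num_heads / kv_num_heads)                                                               # 8
decode_max_tile_size = max_num_seqs * math.ceil(decoder_step_token_num * group_size / decoder_block_shape_q)   # 128
encode_max_tile_size = max_num_seqs * math.ceil(max_model_len * group_size / encoder_block_shape_q)            # 131072
kv_max_tile_size = max_num_seqs * math.ceil(max_model_len / block_size)                                        # 16384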
|
||||
|
||||
# Get the attention backend
|
||||
attn_cls = get_attention_backend()
|
||||
attn_backend = attn_cls(
|
||||
@@ -1112,6 +1237,7 @@ class GPUModelRunner(ModelRunnerBase):
batch_size: paddle.Tensor,
expected_decode_len: int = 1,
in_capturing: bool = False,
capture_prefill: bool = False,
) -> paddle.Tensor:
"""
Use dummy inputs to run before formal execution.

@@ -1119,11 +1245,19 @@ class GPUModelRunner(ModelRunnerBase):
num_tokens:
expected_decode_len: Expected number of tokens generated
in_capturing: Is cuda graph in capturing state
capture_prefill: Capture pure prefill for cuda graph
"""
self._dummy_prefill_inputs(

input_length_list, max_dec_len_list, block_num = self.get_input_length_list(
    num_tokens=num_tokens,
    batch_size=batch_size,
    expected_decode_len=expected_decode_len,
    capture_prefill=capture_prefill,
)
self._dummy_prefill_inputs(
    input_length_list=input_length_list,
    max_dec_len_list=max_dec_len_list,
    block_num=block_num,
)
if self.speculative_method in ["mtp"]:
    self.proposer.dummy_prefill_inputs(
@@ -1353,6 +1487,20 @@ class GPUModelRunner(ModelRunnerBase):
time_before_capture = time.perf_counter()
expected_decode_len = 1
capture_sizes = self.cudagraph_capture_sizes.copy()

if self.fd_config.graph_opt_config.cudagraph_only_prefill:
    for num_tokens in sorted(capture_sizes, reverse=True):
        self._dummy_run(
            num_tokens=num_tokens,
            batch_size=self.parallel_config.max_num_seqs,
            in_capturing=True,
            expected_decode_len=expected_decode_len,
            capture_prefill=True,
        )
        logger.info(
            f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
        )
else:
    for batch_size in sorted(capture_sizes, reverse=True):
        self._dummy_run(
            num_tokens=self.parallel_config.max_num_batched_tokens,

@@ -1360,7 +1508,9 @@ class GPUModelRunner(ModelRunnerBase):
in_capturing=True,
expected_decode_len=expected_decode_len,
)
logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")
logger.info(
    f"Warm up the model with the num_tokens:{batch_size}, expected_decode_len:{expected_decode_len}"
)

time_after_capture = time.perf_counter()
logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
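Note on the capture loop above: with `cudagraph_only_prefill` enabled, graphs are captured per `num_tokens` with `capture_prefill=True`, rather than per decode batch size. A hedged configuration sketch follows; whether `GraphOptimizationConfig`'s dict constructor accepts a `cudagraph_only_prefill` key is an assumption inferred from the attribute read in the loop (the tests further down only pass `use_cudagraph` and `graph_opt_level`):

```python
# Assumed sketch: requesting prefill-only CUDA graph capture via the graph optimization config.
from fastdeploy.config import GraphOptimizationConfig

graph_opt_config = GraphOptimizationConfig(
    {
        "use_cudagraph": True,
        "cudagraph_only_prefill": True,  # assumption: capture pure-prefill graphs keyed by num_tokens
    }
)
```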
@@ -397,7 +397,6 @@ class PaddleDisWorkerProc:
self.get_profile_block_num_signal.value[0] = num_blocks_local
else:
    num_blocks_local = self.fd_config.parallel_config.total_block_num

logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
# wait engine launch cache_manager
if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
@@ -157,7 +157,7 @@ class TestCUDAGrpahSubgraph(unittest.TestCase):
cache_config = CacheConfig({})
# Initialize cuda graph capture list
graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs)
fd_config = FDConfig(
    graph_opt_config=graph_opt_config,
    parallel_config=parallel_config,
@@ -104,7 +104,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
cache_config = CacheConfig({})
# Initialize cuda graph capture list
graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs)
fd_config = FDConfig(
    graph_opt_config=graph_opt_config,
    parallel_config=parallel_config,
@@ -90,7 +90,7 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase):
graph_opt_config = GraphOptimizationConfig({"use_cudagraph": True, "graph_opt_level": 1})
parallel_config = ParallelConfig({"max_num_seqs": 1})
graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs)
cache_config = CacheConfig({})

fd_config = FDConfig(
@@ -386,6 +386,14 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()

self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
self.encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()

self.cache_shape = (
    self.max_block_num,
    self.kv_num_head,

@@ -469,15 +477,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
    get_block_shape_and_split_kv_block,
)

(
    encoder_batch_ids,
    encoder_tile_ids_per_batch,
    encoder_num_blocks,
    kv_batch_ids,
    kv_tile_ids_per_batch,
    kv_num_blocks,
    max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
    self.seq_lens_encoder,
    self.seq_lens_decoder,
    self.seq_lens_this_time,

@@ -485,6 +485,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
    self.decoder_tile_ids_per_batch,
    self.decoder_num_blocks_cpu,
    self.max_len_tensor_cpu,
    self.encoder_batch_ids,
    self.encoder_tile_ids_per_batch,
    self.encoder_num_blocks_x_cpu,
    self.kv_batch_ids,
    self.kv_tile_ids_per_batch,
    self.kv_num_blocks_x_cpu,
    self.max_len_kv_cpu,
    64,
    12,
    (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,

@@ -508,17 +515,17 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
    self.padding_offset,
    self.cum_offset,
    self.block_tables,
    encoder_batch_ids,
    encoder_tile_ids_per_batch,
    encoder_num_blocks,
    kv_batch_ids,
    kv_tile_ids_per_batch,
    kv_num_blocks,
    self.encoder_batch_ids,
    self.encoder_tile_ids_per_batch,
    self.encoder_num_blocks_x_cpu,
    self.kv_batch_ids,
    self.kv_tile_ids_per_batch,
    self.kv_num_blocks_x_cpu,
    self.decoder_batch_ids,
    self.decoder_tile_ids_per_batch,
    self.decoder_num_blocks_cpu,
    self.max_len_tensor_cpu,
    max_len_kv,
    self.max_len_kv_cpu,
    self.rope_emb,  # rope_emb
    None,  # attn_mask
    None,  # qkv_bias
@@ -382,6 +382,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
self.encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()

self.cache_shape = (
    self.max_block_num,

@@ -450,15 +457,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
    get_block_shape_and_split_kv_block,
)

(
    encoder_batch_ids,
    encoder_tile_ids_per_batch,
    encoder_num_blocks,
    kv_batch_ids,
    kv_tile_ids_per_batch,
    kv_num_blocks,
    max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
    self.seq_lens_encoder,
    self.seq_lens_decoder,
    self.seq_lens_this_time,

@@ -466,6 +465,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
    self.decoder_tile_ids_per_batch,
    self.decoder_num_blocks_cpu,
    self.max_len_tensor_cpu,
    self.encoder_batch_ids,
    self.encoder_tile_ids_per_batch,
    self.encoder_num_blocks_x_cpu,
    self.kv_batch_ids,
    self.kv_tile_ids_per_batch,
    self.kv_num_blocks_x_cpu,
    self.max_len_kv_cpu,
    64,
    12,
    (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,

@@ -491,17 +497,17 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
    self.padding_offset,
    self.cum_offset,
    self.block_tables,
    encoder_batch_ids,
    encoder_tile_ids_per_batch,
    encoder_num_blocks,
    kv_batch_ids,
    kv_tile_ids_per_batch,
    kv_num_blocks,
    self.encoder_batch_ids,
    self.encoder_tile_ids_per_batch,
    self.encoder_num_blocks_x_cpu,
    self.kv_batch_ids,
    self.kv_tile_ids_per_batch,
    self.kv_num_blocks_x_cpu,
    self.decoder_batch_ids,
    self.decoder_tile_ids_per_batch,
    self.decoder_num_blocks_cpu,
    self.max_len_tensor_cpu,
    max_len_kv,
    self.max_len_kv_cpu,
    out,
    self.rope_emb,  # rope_emb
    None,  # attn_mask
@@ -190,30 +190,32 @@ class TestTreeMask(unittest.TestCase):

encoder_block_shape_q = 64
decoder_block_shape_q = 16

group_size = self.num_q_head // self.num_kv_head
decode_max_tile_size = (
    self.bsz
    * (decoder_step_token_num * (self.num_q_head // self.num_kv_head) + decoder_block_shape_q - 1)
    / decoder_block_shape_q
    self.bsz * (decoder_step_token_num * group_size + decoder_block_shape_q - 1) / decoder_block_shape_q
)
encode_max_tile_size = (
    self.bsz * (self.max_seq_len * group_size + encoder_block_shape_q - 1) / encoder_block_shape_q
)
kv_max_tile_size = self.bsz * (self.max_seq_len + self.block_size - 1) / self.block_size

decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
encoder_batch_ids = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
encoder_tile_ids_per_batch = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
kv_batch_ids = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
kv_tile_ids_per_batch = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()
q_norm_weight = np.ones([self.head_dim])
k_norm_weight = np.ones([self.head_dim])
self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
self.k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
paddle.device.synchronize()
(
    encoder_batch_ids,
    encoder_tile_ids_per_batch,
    encoder_num_blocks,
    kv_batch_ids,
    kv_tile_ids_per_batch,
    kv_num_blocks,
    max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
    seq_lens_encoder,
    seq_lens_decoder,
    seq_lens_this_time,

@@ -221,6 +223,13 @@ class TestTreeMask(unittest.TestCase):
    decoder_tile_ids_per_batch,
    decoder_num_blocks,
    max_len_tensor_cpu,
    encoder_batch_ids,
    encoder_tile_ids_per_batch,
    encoder_num_blocks_x_cpu,
    kv_batch_ids,
    kv_tile_ids_per_batch,
    kv_num_blocks_x_cpu,
    max_len_kv_cpu,
    encoder_block_shape_q,
    decoder_block_shape_q,
    self.num_q_head // self.num_kv_head,

@@ -243,15 +252,15 @@ class TestTreeMask(unittest.TestCase):
    self.block_tables,
    encoder_batch_ids,
    encoder_tile_ids_per_batch,
    encoder_num_blocks,
    encoder_num_blocks_x_cpu,
    kv_batch_ids,
    kv_tile_ids_per_batch,
    kv_num_blocks,
    kv_num_blocks_x_cpu,
    decoder_batch_ids,
    decoder_tile_ids_per_batch,
    decoder_num_blocks,
    max_len_tensor_cpu,
    max_len_kv,
    max_len_kv_cpu,
    rotary_embs,
    attn_mask,
    None,  # qkv_bias