[Executor] Experimental feature: support prefill in CUDA Graph (#3459)

* Support prefill in CUDA Graph

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.1

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.2

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.3

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.4

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.5

* Fix issue with encoder_num_blocks_x_cpu

* Add early-exit mechanism for attention kernel

* Fix test case for append attention

* Update test code; add annotations to related tensors

* Move get_input_length_list

* Fix test code

* Add annotations about early-exit for attention kernel

* Add annotations about early-exit for attention kernel (part 2)

* Address review comments

* Fix MTP support

---------

Co-authored-by: RAM <gstian5555@outlook.com>
Jundong Liu
2025-09-08 13:12:24 +08:00
committed by GitHub
parent 472402bf4e
commit 3d0aaa5923
21 changed files with 528 additions and 260 deletions

View File

@@ -52,6 +52,7 @@ __global__ void multi_query_append_attention_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -74,6 +75,11 @@ __global__ void multi_query_append_attention_kernel(
block_table_now = block_table + batch_id * max_block_num_per_seq;
// When CUDA Graph captures prefill, gridDim.x may exceed the real block count, so surplus blocks exit early
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -422,6 +428,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -445,6 +452,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
// When CUDA Graph captures prefill, gridDim.x may exceed the real block count, so surplus blocks exit early
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -902,6 +914,7 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -960,6 +973,7 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1134,6 +1148,7 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1206,6 +1221,7 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),

View File

@@ -57,6 +57,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -85,6 +86,11 @@ __global__ void multi_query_append_attention_c4_kernel(
block_table_now = block_table + batch_id * max_block_num_per_seq;
// When CUDA Graph captures prefill, gridDim.x may exceed the real block count, so surplus blocks exit early
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -520,6 +526,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -549,6 +556,11 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
// When CUDA Graph captures prefill, gridDim.x may exceed the real block count, so surplus blocks exit early
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -1107,6 +1119,7 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1171,6 +1184,7 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1365,6 +1379,7 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1445,6 +1460,7 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),

View File

@@ -58,6 +58,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -87,6 +88,11 @@ __global__ void multi_query_append_attention_c8_kernel(
block_table_now = block_table + batch_id * max_block_num_per_seq;
// When CUDA Graph captures prefill, gridDim.x may exceed the real block count, so surplus blocks exit early
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -527,6 +533,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
@@ -556,6 +563,11 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
// When CUDA Graph captures prefill, gridDim.x may exceed the real block count, so surplus blocks exit early
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}
const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
@@ -1159,6 +1171,7 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1217,6 +1230,7 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
@@ -1443,6 +1457,7 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
@@ -1517,6 +1532,7 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),

View File

@@ -191,14 +191,21 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
}
}
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &max_len_kv_cpu, // Inplace, CPU
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
@@ -223,13 +230,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
int max_system_len = max_len_cpu_ptr[6];
int max_just_dec_len_without_system = max_len_cpu_ptr[7];
paddle::Tensor encoder_batch_ids;
paddle::Tensor encoder_tile_ids_per_batch;
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor kv_batch_ids;
paddle::Tensor kv_tile_ids_per_batch;
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
paddle::Tensor max_len_kv_cpu; /*cpu*/
auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
@@ -237,17 +238,14 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(), bsz);
max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);
if (max_enc_len_this_time > 0) {
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_dec_len_this_time, block_size);
kv_batch_ids =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
auto kv_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
@@ -258,16 +256,12 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
block_size, block_size);
kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);
const uint32_t encoder_max_tile_size_per_bs_q =
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
encoder_batch_ids =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
// Clear buffer
const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
auto encoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
@@ -275,21 +269,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
encoder_tile_ids_per_batch.data<int>(),
encoder_num_blocks_x.data<int>(), bsz,
encoder_block_shape_q, group_size);
encoder_num_blocks_x_cpu =
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
encoder_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
kv_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
}
if (max_just_dec_len_this_time > 0) {
@@ -314,15 +294,6 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
}
return {
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu, /*cpu*/
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu, /*cpu*/
max_len_kv_cpu, /*cpu*/
};
}
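The function above now fills caller-provided buffers instead of returning freshly allocated tensors, so the device pointers the attention kernels see stay stable across CUDA Graph capture and replay. A hypothetical Python-side sketch of that in-place contract (shapes and the commented call are illustrative, not verbatim FastDeploy code):

# Hypothetical caller-side sketch: every output of get_block_shape_and_split_kv_block
# is a pre-allocated, fixed-shape buffer that the op writes into each step.
import paddle

encode_max_tile_size, kv_max_tile_size = 1024, 512  # worst-case sizes, placeholder values

encoder_batch_ids = paddle.full([encode_max_tile_size], 0, dtype="int32")
encoder_tile_ids_per_batch = paddle.full([encode_max_tile_size], 0, dtype="int32")
encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
kv_batch_ids = paddle.full([kv_max_tile_size], 0, dtype="int32")
kv_tile_ids_per_batch = paddle.full([kv_max_tile_size], 0, dtype="int32")
kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()

# Per step, the op fills these tensors in place instead of returning new ones, e.g.:
# get_block_shape_and_split_kv_block(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time,
#     decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_x_cpu, max_len_tensor_cpu,
#     encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_x_cpu,
#     kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks_x_cpu, max_len_kv_cpu,
#     encoder_block_shape_q, decoder_block_shape_q, group_size, block_size, decoder_step_token_num)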
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
@@ -333,16 +304,17 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks_x_cpu",
"max_len_tensor_cpu"
"max_len_tensor_cpu",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks_x_cpu",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks_x_cpu",
"max_len_kv_cpu"
})
.Outputs({
paddle::Optional("encoder_batch_ids"),
paddle::Optional("encoder_tile_ids_per_batch"),
paddle::Optional("encoder_num_blocks_x_cpu"),
paddle::Optional("kv_batch_ids"),
paddle::Optional("kv_tile_ids_per_batch"),
paddle::Optional("kv_num_blocks_x_cpu"),
"max_len_kv_cpu"
})
.Attrs({
"encoder_block_shape_q: int",

View File

@@ -299,7 +299,7 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id,
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
const int layer_id);
std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
@@ -307,6 +307,13 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,

View File

@@ -580,6 +580,10 @@ class GraphOptimizationConfig:
""" Whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs.
Thus this flag cannot be used together with splitting_ops."""
self.cudagraph_only_prefill: bool = False
"""When cudagraph_only_prefill is False, only capture decode-only.
When cudagraph_only_prefill is True, only capture prefill-only.
Capturing both decode-only and prefill-only batches is not yet supported."""
self.full_cuda_graph: bool = True
self.max_capture_size: int = None
@@ -592,13 +596,13 @@ class GraphOptimizationConfig:
self.check_legality_parameters()
def init_with_cudagrpah_size(self, max_num_seqs: int = 0) -> None:
def init_with_cudagrpah_size(self, max_capture_size: int = 0) -> None:
"""
Initialize cuda graph capture sizes and
pre-compute the mapping from batch size to padded graph size
"""
# Regular capture sizes
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_num_seqs]
self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size]
dedup_sizes = list(set(self.cudagraph_capture_sizes))
if len(dedup_sizes) < len(self.cudagraph_capture_sizes):
logger.info(
@@ -632,7 +636,7 @@ class GraphOptimizationConfig:
# Shape [128, 144, ... 240, 256]
draft_capture_sizes += [16 * i for i in range(9, 17)]
# Shape [256, 288, ... 992, 1024]
draft_capture_sizes += [32 * i for i in range(17, 33)]
draft_capture_sizes += [32 * i for i in range(9, 33)]
draft_capture_sizes.append(max_num_seqs)
self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
@@ -1140,7 +1144,11 @@ class FDConfig:
# Initialize cuda graph capture list
if self.graph_opt_config.cudagraph_capture_sizes is None:
self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs)
if self.graph_opt_config.cudagraph_only_prefill:
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=512)
else:
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=self.parallel_config.max_num_seqs)
# TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
if self.graph_opt_config.graph_opt_level == 2:
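For reference, a minimal Python sketch (assumed behavior, not the actual implementation) of the size filtering that `init_with_cudagrpah_size` performs against the new `max_capture_size` bound, which FDConfig sets to 512 for prefill-only capture and to `max_num_seqs` otherwise:

# Hypothetical sketch: filter and deduplicate CUDA Graph capture sizes against an upper bound.
from typing import List


def filter_capture_sizes(capture_sizes: List[int], max_capture_size: int) -> List[int]:
    # Drop sizes that exceed the bound (max_num_seqs for decode-only capture,
    # a token budget such as 512 for prefill-only capture).
    kept = [size for size in capture_sizes if size <= max_capture_size]
    # Deduplicate and keep a deterministic, descending capture order.
    return sorted(set(kept), reverse=True)


if __name__ == "__main__":
    sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 512]
    print(filter_capture_sizes(sizes, max_capture_size=512))  # prefill-only token bound
    print(filter_capture_sizes(sizes, max_capture_size=128))  # decode-only, max_num_seqs=128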

View File

@@ -81,14 +81,42 @@ class ForwardMeta:
attn_mask: Optional[paddle.Tensor] = None
# Attention mask offset
attn_mask_offsets: Optional[paddle.Tensor] = None
# A common pattern for launching CUDA kernels is to set the kernel's grids.x dimension
# using a `num_blocks` variable, and then map each thread block to a specific batch and
# data tile using `batch_ids` and `tile_ids_per_batch`.
#
# The variable names below follow this pattern, using a common prefix (e.g., `encoder_`, `decoder_`, `kv_`)
# for variables that are logically grouped together. The mapping works as follows:
#
# Usage: `my_kernel<<<grids, ...>>>(..., batch_ids, tile_ids, ...)`
# `grids.x` = `num_blocks_cpu`
# `batch_id` = `batch_ids[blockIdx.x]`
# `tile_id` = `tile_ids[blockIdx.x]`
# Maps the thread block index (blockIdx.x) to the corresponding batch for the decoder stage in multi_query_append_attention_warp1_4_kernel.
# Decoder batch id. Used by attention backend.
decoder_batch_ids: Optional[paddle.Tensor] = None
# Tile ID for each batch of the decoder. Used by attention backend.
# Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the decoder stage in multi_query_append_attention_warp1_4_kernel.
decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
# The number of blocks that attention backend can use in decode stage
# The number of CUDA blocks to launch in the x-dimension for the multi_query_append_attention_warp1_4_kernel, defining its grids.x.
decoder_num_blocks_cpu: Optional[paddle.Tensor] = None
# Recorded multiple lengths related to prefill or decode
# A tensor that holds multiple lengths related to prefill or decode stages.
max_len_tensor_cpu: Optional[paddle.Tensor] = None
# Maps the thread block index (blockIdx.x) to the corresponding batch for the encoder stage in multi_query_append_attention_kernel.
encoder_batch_ids: Optional[paddle.Tensor] = None
# Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the encoder stage in multi_query_append_attention_kernel.
encoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
# The number of CUDA blocks to launch in the x-dimension for the multi_query_append_attention_kernel, defining its grids.x.
encoder_num_blocks_x_cpu: Optional[paddle.Tensor] = None
# Maps the thread block index (blockIdx.x) to the corresponding batch for the append_write_cache_kv kernel.
kv_batch_ids: Optional[paddle.Tensor] = None
# Maps the thread block index (blockIdx.x) to the specific data tile being processed within that batch for the append_write_cache_kv kernel.
kv_tile_ids_per_batch: Optional[paddle.Tensor] = None
# The number of CUDA blocks to launch in the x-dimension for the append_write_cache_kv kernel, defining its grids.x.
kv_num_blocks_x_cpu: Optional[paddle.Tensor] = None
# The maximum sequence length of the KV cache, which may represent the current maximum decoder length.
max_len_kv_cpu: Optional[paddle.Tensor] = None
# Sequence length of the encoder for every batch
seq_lens_encoder: Optional[paddle.Tensor] = None
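To make the mapping described in the comments above concrete, here is a pure-Python stand-in (illustrative only, not the CUDA `split_q_block` kernel) that builds `batch_ids`, `tile_ids`, and the block count that becomes the kernel's `gridDim.x`:

# Pure-Python illustration of the blockIdx.x -> (batch_id, tile_id) mapping.
import math
from typing import List, Tuple


def split_q_block_host(seq_lens: List[int], block_shape_q: int, group_size: int) -> Tuple[List[int], List[int], int]:
    batch_ids, tile_ids = [], []
    for batch_id, seq_len in enumerate(seq_lens):
        # Tiles contributed by this sequence: ceil((len * group_size) / block_shape_q)
        num_tiles = math.ceil((seq_len * group_size) / block_shape_q) if seq_len > 0 else 0
        for tile_id in range(num_tiles):
            batch_ids.append(batch_id)
            tile_ids.append(tile_id)
    return batch_ids, tile_ids, len(batch_ids)


if __name__ == "__main__":
    batch_ids, tile_ids, num_blocks = split_q_block_host([160, 0, 80], block_shape_q=64, group_size=5)
    # num_blocks plays the role of encoder_num_blocks_x_cpu / gridDim.x;
    # batch_ids[blockIdx.x] and tile_ids[blockIdx.x] locate each thread block's work.
    print(num_blocks, batch_ids, tile_ids)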
@@ -133,6 +161,7 @@ class ForwardMeta:
"shape": obj.shape,
"dtype": str(obj.dtype),
"place": str(obj.place),
# "content": obj if obj.numel()<10 else "Too big to show"
}
return tensor_info
elif isinstance(obj, (list, tuple)):

View File

@@ -49,14 +49,6 @@ class AppendAttentionMetadata(AttentionMetadata):
AppendAttentionMetadata
"""
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
max_len_kv: paddle.Tensor = None
_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
@@ -142,15 +134,7 @@ class AppendAttentionBackend(AttentionBackend):
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length
(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
@@ -158,6 +142,13 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.max_len_kv_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
@@ -288,17 +279,17 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.max_len_kv,
forward_meta.max_len_kv_cpu,
res,
metadata.rotary_embs,
metadata.attn_mask,
@@ -344,17 +335,17 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.max_len_kv,
forward_meta.max_len_kv_cpu,
metadata.rotary_embs,
metadata.attn_mask,
layer.qkv_bias,

View File

@@ -65,13 +65,6 @@ class FlashAttentionMetadata(AttentionMetadata):
rotary_embs: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
max_len_kv: paddle.Tensor = None
cu_seqlens_q: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
@@ -198,15 +191,7 @@ class FlashAttentionBackend(AttentionBackend):
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.rotary_embs = forward_meta.rotary_embs
metadata.block_tables = forward_meta.block_tables
(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
@@ -214,6 +199,13 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.max_len_kv_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
@@ -295,9 +287,9 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
metadata.block_tables,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
@@ -336,17 +328,17 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids, # from buffer
forward_meta.decoder_tile_ids_per_batch, # from buffer
forward_meta.decoder_num_blocks_cpu,
metadata.max_len_tensor_cpu_decoder,
metadata.max_len_kv,
forward_meta.max_len_kv_cpu,
metadata.rotary_embs,
forward_meta.attn_mask,
layer.qkv_bias,

View File

@@ -69,14 +69,6 @@ class MLAAttentionMetadata(AttentionMetadata):
MLAAttentionMetadata for Multi-Layer Attention
"""
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
max_len_kv: paddle.Tensor = None
_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
@@ -191,15 +183,7 @@ class MLAAttentionBackend(AttentionBackend):
metadata.attn_mask = forward_meta.attn_mask
metadata.pre_caches_length = forward_meta.pre_caches_length
(
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
@@ -207,6 +191,13 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.max_len_kv_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
@@ -362,19 +353,19 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.cu_seqlens_q,
forward_meta.batch_id_per_token,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_cpu,
metadata.max_enc_len_this_time,
metadata.max_dec_len_this_time,
metadata.max_len_kv,
forward_meta.max_len_kv_cpu,
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales
@@ -483,19 +474,19 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.cu_seqlens_q,
forward_meta.batch_id_per_token,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
metadata.encoder_num_blocks,
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_cpu,
metadata.max_enc_len_this_time,
metadata.max_dec_len_this_time,
metadata.max_len_kv,
forward_meta.max_len_kv_cpu,
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales

View File

@@ -32,6 +32,13 @@ def get_block_shape_and_split_kv_block(
decoder_tile_ids_per_batch: paddle.Tensor,
decoder_num_blocks_x_cpu: paddle.Tensor,
max_len_tensor_cpu: paddle.Tensor,
encoder_batch_ids: paddle.Tensor,
encoder_tile_ids_per_batch: paddle.Tensor,
encoder_num_blocks_x_cpu: paddle.Tensor,
kv_batch_ids: paddle.Tensor,
kv_tile_ids_per_batch: paddle.Tensor,
kv_num_blocks_x_cpu: paddle.Tensor,
max_len_kv_cpu: paddle.Tensor,
encoder_block_shape_q: int,
decoder_block_shape_q: int,
group_size: int,
@@ -42,15 +49,7 @@ def get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block
"""
if current_platform.is_cuda():
(
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
max_len_kv_cpu,
) = get_block_shape_and_split_kv_block_cuda(
get_block_shape_and_split_kv_block_cuda(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
@@ -58,20 +57,19 @@ def get_block_shape_and_split_kv_block(
decoder_tile_ids_per_batch,
decoder_num_blocks_x_cpu,
max_len_tensor_cpu,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu,
max_len_kv_cpu,
encoder_block_shape_q,
decoder_block_shape_q,
group_size,
block_size,
decoder_step_token_num,
)
return (
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
max_len_kv_cpu,
)
else:
raise NotImplementedError

View File

@@ -212,6 +212,22 @@ class MTPProposer(Proposer):
self.target_model_inputs["max_len_tensor_cpu"]
).cpu()
self.model_inputs["encoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["encoder_batch_ids"])
self.model_inputs["encoder_tile_ids_per_batch"] = paddle.zeros_like(
self.target_model_inputs["encoder_tile_ids_per_batch"]
)
self.model_inputs["encoder_num_blocks_x_cpu"] = paddle.zeros_like(
self.target_model_inputs["encoder_num_blocks_x_cpu"]
).cpu()
self.model_inputs["kv_batch_ids"] = paddle.zeros_like(self.target_model_inputs["kv_batch_ids"])
self.model_inputs["kv_tile_ids_per_batch"] = paddle.zeros_like(
self.target_model_inputs["kv_tile_ids_per_batch"]
)
self.model_inputs["kv_num_blocks_x_cpu"] = paddle.zeros_like(
self.target_model_inputs["kv_num_blocks_x_cpu"]
).cpu()
self.model_inputs["max_len_kv_cpu"] = paddle.zeros_like(self.target_model_inputs["max_len_kv_cpu"]).cpu()
# Get the attention backend
attn_cls = get_attention_backend()
attn_backend = attn_cls(
@@ -321,6 +337,13 @@ class MTPProposer(Proposer):
self.model_inputs["decoder_tile_ids_per_batch"] = None
self.model_inputs["decoder_num_blocks_cpu"] = None # Pinned Memory
self.model_inputs["max_len_tensor_cpu"] = None # CPU
self.model_inputs["encoder_batch_ids"] = None
self.model_inputs["encoder_tile_ids_per_batch"] = None
self.model_inputs["encoder_num_blocks_x_cpu"] = None # CPU
self.model_inputs["kv_batch_ids"] = None
self.model_inputs["kv_tile_ids_per_batch"] = None
self.model_inputs["kv_num_blocks_x_cpu"] = None # CPU
self.model_inputs["max_len_kv_cpu"] = None # CPU
# Input tokens
self.model_inputs["draft_tokens"] = paddle.full(
@@ -512,6 +535,13 @@ class MTPProposer(Proposer):
cu_seqlens_k=self.model_inputs["cu_seqlens_k"],
block_tables=self.model_inputs["block_tables"],
caches=self.model_inputs["caches"],
encoder_batch_ids=self.model_inputs["encoder_batch_ids"],
encoder_tile_ids_per_batch=self.model_inputs["encoder_tile_ids_per_batch"],
encoder_num_blocks_x_cpu=self.model_inputs["encoder_num_blocks_x_cpu"],
kv_batch_ids=self.model_inputs["kv_batch_ids"],
kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"],
kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"],
max_len_kv_cpu=self.model_inputs["max_len_kv_cpu"],
)
# Initialize attention metadata

View File

@@ -430,6 +430,13 @@ class GCUModelRunner(ModelRunnerBase):
self.share_inputs["decoder_tile_ids_per_batch"] = None
self.share_inputs["decoder_num_blocks_cpu"] = None # Pinned Memory
self.share_inputs["max_len_tensor_cpu"] = None # CPU
self.share_inputs["encoder_batch_ids"] = None
self.share_inputs["encoder_tile_ids_per_batch"] = None
self.share_inputs["encoder_num_blocks_x_cpu"] = None # CPU
self.share_inputs["kv_batch_ids"] = None
self.share_inputs["kv_tile_ids_per_batch"] = None
self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU
self.share_inputs["max_len_kv_cpu"] = None # CPU
# Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -601,6 +608,13 @@ class GCUModelRunner(ModelRunnerBase):
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"],
encoder_batch_ids=self.share_inputs["encoder_batch_ids"],
encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"],
encoder_num_blocks_x_cpu=self.share_inputs["encoder_num_blocks_x_cpu"],
kv_batch_ids=self.share_inputs["kv_batch_ids"],
kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"],
kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"],
)
# Update Batch type for cuda graph
@@ -673,14 +687,31 @@ class GCUModelRunner(ModelRunnerBase):
encoder_block_shape_q = 64
decoder_block_shape_q = 16
decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
group_size = np.ceil(num_heads / self.model_config.kv_num_heads)
decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
(decoder_step_token_num * group_size) / decoder_block_shape_q
)
encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(self.model_config.max_model_len * group_size) / encoder_block_shape_q
)
kv_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
self.model_config.max_model_len / self.fd_config.cache_config.block_size
)
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
# Get the attention backend
attn_cls = get_attention_backend()
attn_backend = attn_cls(

View File

@@ -142,6 +142,7 @@ class GPUModelRunner(ModelRunnerBase):
self.use_cudagraph = self.graph_opt_config.use_cudagraph
self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill
# Initialize share inputs
self._init_share_inputs(self.parallel_config.max_num_seqs)
@@ -177,10 +178,49 @@ class GPUModelRunner(ModelRunnerBase):
"""
Check whether a prefill stage exists
"""
if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0:
return 1
else:
return 0
return int(paddle.max(self.share_inputs["seq_lens_encoder"])) > 0
def exist_decode(self):
"""
Check whether a decode stage exists
"""
return int(paddle.max(self.share_inputs["seq_lens_decoder"])) > 0
def only_prefill(self):
"""
Check whether the current batch is prefill-only
"""
if_only_prefill = True
decode_exists = None
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
only_prefill_batch_list = []
decode_exists = self.exist_decode()
paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists)
if_only_prefill = all(only_prefill_batch_list)
if_only_prefill = if_only_prefill and not (decode_exists if decode_exists is not None else self.exist_decode())
return if_only_prefill
def only_decode(self):
"""
Check whether the current batch is decode-only
"""
# Update Batch type for cuda graph for if_only_decode
if_only_decode = True
prefill_exists = None
# mix ep in single node
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
only_decode_batch_list = []
prefill_exists = self.exist_prefill()
paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
if_only_decode = all(only_decode_batch_list)
if_only_decode = if_only_decode and not (
prefill_exists if prefill_exists is not None else self.exist_prefill()
)
return if_only_decode
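The `only_prefill`/`only_decode` helpers above reach a cross-rank consensus under mixed expert parallelism: each rank gathers every rank's local flag and CUDA Graph is replayed only if all ranks agree. A small sketch of that decision logic, with the collective injected so it runs without an initialized process group (the `gather_fn` parameter is hypothetical; FastDeploy uses paddle.distributed.all_gather_object):

# Sketch of the cross-rank "decode-only?" consensus used to decide CUDA Graph replay.
from typing import Callable, List


def batch_is_decode_only(local_prefill_exists: bool, use_ep_mixed: bool,
                         gather_fn: Callable[[bool], List[bool]]) -> bool:
    if_only_decode = True
    if use_ep_mixed:
        # Every rank reports "I have no prefill work"; all ranks must agree.
        votes = gather_fn(not local_prefill_exists)
        if_only_decode = all(votes)
    # The local rank must also have no prefill work of its own.
    return if_only_decode and not local_prefill_exists


if __name__ == "__main__":
    fake_gather = lambda vote: [vote, True, True]  # pretend the other ranks agree
    print(batch_is_decode_only(local_prefill_exists=False, use_ep_mixed=True, gather_fn=fake_gather))  # True
    print(batch_is_decode_only(local_prefill_exists=True, use_ep_mixed=True, gather_fn=fake_gather))   # False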
def _init_speculative_proposer(self):
"""
@@ -600,27 +640,81 @@ class GPUModelRunner(ModelRunnerBase):
if self.speculative_method in ["mtp"]:
self.proposer.insert_prefill_inputs(req_dicts, num_running_requests)
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
"""Set dummy prefill inputs to share_inputs"""
def get_input_length_list(
self, num_tokens: int, batch_size: int, expected_decode_len: int, capture_prefill: bool = False
):
"""
Generates the input-length lists for _dummy_prefill_inputs. When capturing pure prefill (or MTP),
the lists must be constructed carefully; a worked example follows below.
This function addresses a specific problem: in the pure prefill stage, variable
input lengths (e.g., `prompt[160, 0]` vs. `prompt[80, 80]`) can lead to different
CUDA grid dimensions for the attention kernels. This prevents CUDA Graph
reuse.
The `split_q_block` kernel calculates the total number of blocks, which directly
determines the `gridDim.x` launch parameter of `multi_query_append_attention_kernel`.
The blocks for a single sequence are determined by the formula:
`num_blocks = ceil((sequence_length * group_size) / block_shape_q)`
Due to the `ceil` (ceiling) function, distributing a total number of tokens across
a batch of shorter sequences will result in a larger total block count. For example,
with a `group_size` of 5 and `block_shape_q` of 64:
- A single sequence of 160 tokens requires `ceil((160 * 5) / 64) = 13` blocks.
- Two sequences of 80 tokens each require `ceil((80 * 5) / 64) * 2 = 7 * 2 = 14` blocks.
To ensure graph replayability, this function creates a "dummy" list of sequence
lengths that's designed to produce the theoretical maximum `encoder_num_blocks_x_cpu`
for the given `num_tokens` and `batch_size`. This strategy ensures the captured
CUDA Graph has the largest possible grid dimensions. At runtime, if the actual number
of blocks is less than or equal to this maximum, the kernel can safely execute by
using an early-exit mechanism.
Args:
num_tokens (int): The total number of tokens across all sequences.
batch_size (int): The number of sequences (requests) in the batch.
Returns:
List[int]: A list of integers representing the sequence length for each request.
This list is crafted to maximize the total number of blocks.
"""
# NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
max_dec_len = expected_decode_len + 1
full_length = min(
num_tokens // batch_size,
input_length = min(
num_tokens // (1 if capture_prefill else batch_size),
self.parallel_config.max_model_len - max_dec_len,
)
# NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer is not large enough, which causes NaNs in the result.
# TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
if self.fd_config.parallel_config.enable_expert_parallel:
full_length = min(full_length, 32)
input_length = min(input_length, 32)
input_length = int(full_length * self.cache_config.kv_cache_ratio)
block_num = (
input_length + self.cache_config.block_size - 1
) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
input_length_list = [input_length] * batch_size
if capture_prefill:
if num_tokens < batch_size:
input_length_list = [1] * num_tokens
else:
input_length_list = [1] * (batch_size - 1)
input_length_list.append(num_tokens - batch_size + 1)
len_of_input_length_list = len(input_length_list)
max_dec_len_list = [max_dec_len] * len_of_input_length_list
return input_length_list, max_dec_len_list, block_num
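A worked example of the ceil argument in the docstring above (plain Python; group_size=5 and block_shape_q=64 are the docstring's illustrative numbers, not FastDeploy defaults):

# Shorter sequences inflate the total block count because of ceil(); the
# [1, 1, ..., rest] dummy split reaches the maximum for this example.
import math


def total_blocks(lengths, group_size=5, block_shape_q=64):
    # Each sequence contributes ceil((len * group_size) / block_shape_q) thread blocks.
    return sum(math.ceil((n * group_size) / block_shape_q) for n in lengths if n > 0)


num_tokens, batch_size = 160, 2
print(total_blocks([160]))         # 13 blocks: one 160-token sequence
print(total_blocks([80, 80]))      # 14 blocks: two 80-token sequences
# Dummy split used when capturing pure prefill: batch_size - 1 sequences of length 1
# plus one sequence holding the remaining tokens.
dummy = [1] * (batch_size - 1) + [num_tokens - batch_size + 1]
print(dummy, total_blocks(dummy))  # [1, 159] -> 1 + 13 = 14 blocks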
def _dummy_prefill_inputs(self, input_length_list: List[int], max_dec_len_list: List[int], block_num: int):
"""Set dummy prefill inputs to share_inputs"""
batch_size = len(input_length_list)
for i in range(batch_size):
idx = i
input_length = input_length_list[i]
max_dec_len = max_dec_len_list[i]
self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["eos_token_id"][:] = np.array(
@@ -745,6 +839,13 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["decoder_tile_ids_per_batch"] = None
self.share_inputs["decoder_num_blocks_cpu"] = None # Pinned Memory
self.share_inputs["max_len_tensor_cpu"] = None # CPU
self.share_inputs["encoder_batch_ids"] = None
self.share_inputs["encoder_tile_ids_per_batch"] = None
self.share_inputs["encoder_num_blocks_x_cpu"] = None # CPU
self.share_inputs["kv_batch_ids"] = None
self.share_inputs["kv_tile_ids_per_batch"] = None
self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU
self.share_inputs["max_len_kv_cpu"] = None # CPU
# Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -977,23 +1078,30 @@ class GPUModelRunner(ModelRunnerBase):
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"],
encoder_batch_ids=self.share_inputs["encoder_batch_ids"],
encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"],
encoder_num_blocks_x_cpu=self.share_inputs["encoder_num_blocks_x_cpu"],
kv_batch_ids=self.share_inputs["kv_batch_ids"],
kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"],
kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"],
)
# Update Batch type for cuda graph
only_decode_batch = True
prefill_exists = None
# mix ep in single node
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
only_decode_batch_list = []
prefill_exists = self.exist_prefill()
paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
only_decode_batch = all(only_decode_batch_list)
self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
# Update Batch type for cuda graph for only_decode_batch
if_only_decode = self.only_decode()
only_decode_use_cudagraph = self.use_cudagraph and if_only_decode
# Update config about moe for better performance
# TODO(wanglongzhi):Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. It will be used in MoEMethodBase.apply()
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
self.fd_config.parallel_config.moe_phase.phase = "decode" if if_only_decode else "prefill"
# Update Batch type for cuda graph for only_prefill_batch
only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill()
# Once capturing both prefill-only and decode-only is supported, this will become [only_prefill_use_cudagraph or only_decode_use_cudagraph]
self.forward_meta.step_use_cudagraph = (
self.use_cudagraph
and only_decode_batch
and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph
)
# Initialize attention metadata
@@ -1085,14 +1193,31 @@ class GPUModelRunner(ModelRunnerBase):
encoder_block_shape_q = 64
decoder_block_shape_q = 16
decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
group_size = np.ceil(num_heads / self.model_config.kv_num_heads)
decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
(decoder_step_token_num * group_size) / decoder_block_shape_q
)
encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(self.model_config.max_model_len * group_size) / encoder_block_shape_q
)
kv_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
self.model_config.max_model_len / self.fd_config.cache_config.block_size
)
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
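The buffers above are sized once at their worst case so that their shapes never change between CUDA Graph capture and replay. A short sketch of the same arithmetic with made-up example values (max_num_seqs, head counts, max_model_len, and block_size below are placeholders, not FastDeploy defaults):

# Sketch of the worst-case tile-count arithmetic used to pre-allocate the fixed-size
# encoder/decoder/kv buffers above. All concrete numbers are illustrative only.
import numpy as np

max_num_seqs = 64
num_heads, kv_num_heads = 32, 4
max_model_len = 8192
block_size = 64
encoder_block_shape_q, decoder_block_shape_q = 64, 16
decoder_step_token_num = 1 + 1  # num_speculative_tokens + 1

group_size = np.ceil(num_heads / kv_num_heads)
decode_max_tile_size = max_num_seqs * np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q)
encode_max_tile_size = max_num_seqs * np.ceil((max_model_len * group_size) / encoder_block_shape_q)
kv_max_tile_size = max_num_seqs * np.ceil(max_model_len / block_size)

# decoder_batch_ids / encoder_batch_ids / kv_batch_ids are allocated with these sizes,
# while the real per-step block counts live in the small *_num_blocks_x_cpu tensors.
print(int(decode_max_tile_size), int(encode_max_tile_size), int(kv_max_tile_size))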
# Get the attention backend
attn_cls = get_attention_backend()
attn_backend = attn_cls(
@@ -1112,6 +1237,7 @@ class GPUModelRunner(ModelRunnerBase):
batch_size: paddle.Tensor,
expected_decode_len: int = 1,
in_capturing: bool = False,
capture_prefill: bool = False,
) -> paddle.Tensor:
"""
Use dummy inputs to run before formal execution.
@@ -1119,11 +1245,19 @@ class GPUModelRunner(ModelRunnerBase):
num_tokens:
expected_decode_len: Expected number of tokens generated
in_capturing: Is cuda graph in capturing state
capture_prefill: Capture pure prefill for cuda graph
"""
self._dummy_prefill_inputs(
input_length_list, max_dec_len_list, block_num = self.get_input_length_list(
num_tokens=num_tokens,
batch_size=batch_size,
expected_decode_len=expected_decode_len,
capture_prefill=capture_prefill,
)
self._dummy_prefill_inputs(
input_length_list=input_length_list,
max_dec_len_list=max_dec_len_list,
block_num=block_num,
)
if self.speculative_method in ["mtp"]:
self.proposer.dummy_prefill_inputs(
@@ -1353,6 +1487,20 @@ class GPUModelRunner(ModelRunnerBase):
time_before_capture = time.perf_counter()
expected_decode_len = 1
capture_sizes = self.cudagraph_capture_sizes.copy()
if self.fd_config.graph_opt_config.cudagraph_only_prefill:
for num_tokens in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=num_tokens,
batch_size=self.parallel_config.max_num_seqs,
in_capturing=True,
expected_decode_len=expected_decode_len,
capture_prefill=True,
)
logger.info(
f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
)
else:
for batch_size in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=self.parallel_config.max_num_batched_tokens,
@@ -1360,7 +1508,9 @@ class GPUModelRunner(ModelRunnerBase):
in_capturing=True,
expected_decode_len=expected_decode_len,
)
logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")
logger.info(
f"Warm up the model with the batch_size:{batch_size}, expected_decode_len:{expected_decode_len}"
)
time_after_capture = time.perf_counter()
logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")

View File

@@ -397,7 +397,6 @@ class PaddleDisWorkerProc:
self.get_profile_block_num_signal.value[0] = num_blocks_local
else:
num_blocks_local = self.fd_config.parallel_config.total_block_num
logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
# wait engine launch cache_manager
if self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":

View File

@@ -157,7 +157,7 @@ class TestCUDAGrpahSubgraph(unittest.TestCase):
cache_config = CacheConfig({})
# Initialize cuda graph capture list
graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs)
fd_config = FDConfig(
graph_opt_config=graph_opt_config,
parallel_config=parallel_config,

View File

@@ -104,7 +104,7 @@ class TestCUDAGrpahSpecDecode(unittest.TestCase):
cache_config = CacheConfig({})
# Initialize cuda graph capture list
graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs)
fd_config = FDConfig(
graph_opt_config=graph_opt_config,
parallel_config=parallel_config,

View File

@@ -90,7 +90,7 @@ class TestStaticGraphCUDAGraphSplit(unittest.TestCase):
graph_opt_config = GraphOptimizationConfig({"use_cudagraph": True, "graph_opt_level": 1})
parallel_config = ParallelConfig({"max_num_seqs": 1})
graph_opt_config._set_cudagraph_sizes(max_num_seqs=parallel_config.max_num_seqs)
graph_opt_config.init_with_cudagrpah_size(max_num_seqs=parallel_config.max_num_seqs)
graph_opt_config.init_with_cudagrpah_size(max_capture_size=parallel_config.max_num_seqs)
cache_config = CacheConfig({})
fd_config = FDConfig(

View File

@@ -386,6 +386,14 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
self.encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()
self.cache_shape = (
self.max_block_num,
self.kv_num_head,
@@ -469,15 +477,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
get_block_shape_and_split_kv_block,
)
(
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
self.seq_lens_encoder,
self.seq_lens_decoder,
self.seq_lens_this_time,
@@ -485,6 +485,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.decoder_tile_ids_per_batch,
self.decoder_num_blocks_cpu,
self.max_len_tensor_cpu,
self.encoder_batch_ids,
self.encoder_tile_ids_per_batch,
self.encoder_num_blocks_x_cpu,
self.kv_batch_ids,
self.kv_tile_ids_per_batch,
self.kv_num_blocks_x_cpu,
self.max_len_kv_cpu,
64,
12,
(self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
@@ -508,17 +515,17 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.padding_offset,
self.cum_offset,
self.block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
self.encoder_batch_ids,
self.encoder_tile_ids_per_batch,
self.encoder_num_blocks_x_cpu,
self.kv_batch_ids,
self.kv_tile_ids_per_batch,
self.kv_num_blocks_x_cpu,
self.decoder_batch_ids,
self.decoder_tile_ids_per_batch,
self.decoder_num_blocks_cpu,
self.max_len_tensor_cpu,
max_len_kv,
self.max_len_kv_cpu,
self.rope_emb, # rope_emb
None, # attn_mask
None, # qkv_bias

View File

@@ -382,6 +382,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.decoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
self.encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()
self.cache_shape = (
self.max_block_num,
@@ -450,15 +457,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
get_block_shape_and_split_kv_block,
)
(
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
self.seq_lens_encoder,
self.seq_lens_decoder,
self.seq_lens_this_time,
@@ -466,6 +465,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.decoder_tile_ids_per_batch,
self.decoder_num_blocks_cpu,
self.max_len_tensor_cpu,
self.encoder_batch_ids,
self.encoder_tile_ids_per_batch,
self.encoder_num_blocks_x_cpu,
self.kv_batch_ids,
self.kv_tile_ids_per_batch,
self.kv_num_blocks_x_cpu,
self.max_len_kv_cpu,
64,
12,
(self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
@@ -491,17 +497,17 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.padding_offset,
self.cum_offset,
self.block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
self.encoder_batch_ids,
self.encoder_tile_ids_per_batch,
self.encoder_num_blocks_x_cpu,
self.kv_batch_ids,
self.kv_tile_ids_per_batch,
self.kv_num_blocks_x_cpu,
self.decoder_batch_ids,
self.decoder_tile_ids_per_batch,
self.decoder_num_blocks_cpu,
self.max_len_tensor_cpu,
max_len_kv,
self.max_len_kv_cpu,
out,
self.rope_emb, # rope_emb
None, # attn_mask

View File

@@ -190,30 +190,32 @@ class TestTreeMask(unittest.TestCase):
encoder_block_shape_q = 64
decoder_block_shape_q = 16
group_size = self.num_q_head // self.num_kv_head
decode_max_tile_size = (
self.bsz
* (decoder_step_token_num * (self.num_q_head // self.num_kv_head) + decoder_block_shape_q - 1)
/ decoder_block_shape_q
self.bsz * (decoder_step_token_num * group_size + decoder_block_shape_q - 1) / decoder_block_shape_q
)
encode_max_tile_size = (
self.bsz * (self.max_seq_len * group_size + encoder_block_shape_q - 1) / encoder_block_shape_q
)
kv_max_tile_size = self.bsz * (self.max_seq_len + self.block_size - 1) / self.block_size
decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
encoder_batch_ids = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
encoder_tile_ids_per_batch = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
kv_batch_ids = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
kv_tile_ids_per_batch = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()
q_norm_weight = np.ones([self.head_dim])
k_norm_weight = np.ones([self.head_dim])
self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
self.k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
paddle.device.synchronize()
(
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
max_len_kv,
) = get_block_shape_and_split_kv_block(
get_block_shape_and_split_kv_block(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
@@ -221,6 +223,13 @@ class TestTreeMask(unittest.TestCase):
decoder_tile_ids_per_batch,
decoder_num_blocks,
max_len_tensor_cpu,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu,
max_len_kv_cpu,
encoder_block_shape_q,
decoder_block_shape_q,
self.num_q_head // self.num_kv_head,
@@ -243,15 +252,15 @@ class TestTreeMask(unittest.TestCase):
self.block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
encoder_num_blocks_x_cpu,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
kv_num_blocks_x_cpu,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
max_len_tensor_cpu,
max_len_kv,
max_len_kv_cpu,
rotary_embs,
attn_mask,
None, # qkv_bias