[Executor] Experimental Feature: Support Prefill in CUDA Graph (#3459)

* Support prefill in Cudagraph

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.1

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.2

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.3

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.4

* Refactor GetBlockShapeAndSplitKVBlock Kernel V2.5

* Fix issue with encoder_num_blocks_x_cpu

* Add early-exit mechanism for attention kernel

* Fix test case for append-attention

* Update test code and add annotations to related tensors

* move get_input_length_list

* Fix test code

* Add annotations about early-exit for attention kernel

* Add annotations about early-exit for attention kernel2

* Address review comments

* Fix MTP handling

---------

Co-authored-by: RAM <gstian5555@outlook.com>
Author: Jundong Liu
Date: 2025-09-08 13:12:24 +08:00
Committed by: GitHub
Parent: 472402bf4e
Commit: 3d0aaa5923
21 changed files with 528 additions and 260 deletions


@@ -142,6 +142,7 @@ class GPUModelRunner(ModelRunnerBase):
self.use_cudagraph = self.graph_opt_config.use_cudagraph
self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill
# Initialize share inputs
self._init_share_inputs(self.parallel_config.max_num_seqs)
@@ -177,10 +178,49 @@ class GPUModelRunner(ModelRunnerBase):
"""
Check whether a prefill stage exists.
"""
if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0:
return 1
else:
return 0
return int(paddle.max(self.share_inputs["seq_lens_encoder"])) > 0
def exist_decode(self):
"""
Check whether a decode stage exists.
"""
return int(paddle.max(self.share_inputs["seq_lens_decoder"])) > 0
def only_prefill(self):
"""
Check whether the current batch is prefill-only.
"""
if_only_prefill = True
decode_exists = None
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
only_prefill_batch_list = []
decode_exists = self.exist_decode()
paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists)
if_only_prefill = all(only_prefill_batch_list)
if_only_prefill = if_only_prefill and not (decode_exists if decode_exists is not None else self.exist_decode())
return if_only_prefill
def only_decode(self):
"""
Check whether the current batch is decode-only.
"""
# Determine whether the whole batch is decode-only (used to pick the CUDA Graph batch type)
if_only_decode = True
prefill_exists = None
# mix ep in single node
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
only_decode_batch_list = []
prefill_exists = self.exist_prefill()
paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
if_only_decode = all(only_decode_batch_list)
if_only_decode = if_only_decode and not (
prefill_exists if prefill_exists is not None else self.exist_prefill()
)
return if_only_decode
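# NOTE (illustrative sketch, not part of this commit): the agreement pattern used by
# only_prefill()/only_decode() above, reduced to a standalone helper. The name
# `batch_is_decode_only` is ours; the real code uses self.exist_prefill()/exist_decode().
import paddle.distributed as dist

def batch_is_decode_only(local_has_prefill: bool) -> bool:
    gathered = []
    # Every rank contributes "I have no prefill requests in this step".
    dist.all_gather_object(gathered, not local_has_prefill)
    # The batch is treated as decode-only only when all ranks agree, so no rank
    # replays a CUDA Graph captured for a different batch type.
    return all(gathered)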
def _init_speculative_proposer(self):
"""
@@ -600,27 +640,81 @@ class GPUModelRunner(ModelRunnerBase):
if self.speculative_method in ["mtp"]:
self.proposer.insert_prefill_inputs(req_dicts, num_running_requests)
def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
"""Set dummy prefill inputs to share_inputs"""
def get_input_length_list(
self, num_tokens: int, batch_size: int, expected_decode_len: int, capture_prefill: bool = False
):
"""
Generate the length lists consumed by `_dummy_prefill_inputs`. When capturing a pure
prefill graph (or running MTP), these lists must be constructed carefully.
This function addresses a specific problem: in the pure prefill stage, variable
input lengths (e.g., `prompt[160, 0]` vs. `prompt[80, 80]`) can lead to different
CUDA grid dimensions for kernels such as `split_q_block`, which prevents CUDA Graph
reuse.
The `split_q_block` kernel calculates the total number of blocks, which directly
determines the `griddim.x` launch parameter of `multi_query_append_attention_kernel`.
The block count for a single sequence is given by:
`num_blocks = ceil((sequence_length * group_size) / block_shape_q)`
Because of the `ceil` function, distributing a fixed token budget across a batch of
shorter sequences yields a larger total block count. For example, with a `group_size`
of 5 and a `block_shape_q` of 64:
- A single sequence of 160 tokens requires `ceil((160 * 5) / 64) = 13` blocks.
- Two sequences of 80 tokens each require `ceil((80 * 5) / 64) * 2 = 7 * 2 = 14` blocks.
To keep the graph replayable, this function therefore builds a "dummy" list of sequence
lengths designed to produce the theoretical maximum `encoder_num_blocks_x_cpu` for the
given `num_tokens` and `batch_size`, so the captured CUDA Graph has the largest possible
grid dimensions. At runtime, whenever the actual block count is less than or equal to
this maximum, the kernel can execute safely via its early-exit mechanism. (A worked
sketch of this arithmetic follows the function below.)
Args:
    num_tokens (int): The total number of tokens across all sequences.
    batch_size (int): The number of sequences (requests) in the batch.
    expected_decode_len (int): Expected number of tokens to generate per request.
    capture_prefill (bool): Whether the lists are built for capturing a pure prefill graph.
Returns:
    Tuple[List[int], List[int], int]: The per-request input lengths (crafted to maximize
    the total block count), the per-request maximum decode lengths, and the cache block
    count used by `_dummy_prefill_inputs`.
"""
# NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
max_dec_len = expected_decode_len + 1
full_length = min(
num_tokens // batch_size,
input_length = min(
num_tokens // (1 if capture_prefill else batch_size),
self.parallel_config.max_model_len - max_dec_len,
)
# NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size is not large enough, which causes the results to become NaN.
# TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
if self.fd_config.parallel_config.enable_expert_parallel:
full_length = min(full_length, 32)
input_length = min(input_length, 32)
input_length = int(full_length * self.cache_config.kv_cache_ratio)
block_num = (
input_length + self.cache_config.block_size - 1
) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
input_length_list = [input_length] * batch_size
if capture_prefill:
if num_tokens < batch_size:
input_length_list = [1] * num_tokens
else:
input_length_list = [1] * (batch_size - 1)
input_length_list.append(num_tokens - batch_size + 1)
len_of_input_length_list = len(input_length_list)
max_dec_len_list = [max_dec_len] * len_of_input_length_list
return input_length_list, max_dec_len_list, block_num
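# NOTE (illustrative sketch, not part of this commit): the arithmetic described in the
# docstring above, using the same example values group_size=5 and block_shape_q=64.
# The helper names are ours.
import math

def total_q_blocks(seq_lens, group_size=5, block_shape_q=64):
    # Mirrors num_blocks = ceil((sequence_length * group_size) / block_shape_q),
    # summed over the batch; this sum drives griddim.x of the attention kernel.
    return sum(math.ceil(l * group_size / block_shape_q) for l in seq_lens)

assert total_q_blocks([160]) == 13       # one long prompt
assert total_q_blocks([80, 80]) == 14    # same token budget, more round-ups

def skewed_dummy_lengths(num_tokens, batch_size):
    # Same shape get_input_length_list builds for capture_prefill=True:
    # (batch_size - 1) length-1 sequences plus one long sequence, intended to
    # maximize the captured block count for this num_tokens/batch_size.
    if num_tokens < batch_size:
        return [1] * num_tokens
    return [1] * (batch_size - 1) + [num_tokens - batch_size + 1]

assert skewed_dummy_lengths(160, 2) == [1, 159]
assert total_q_blocks(skewed_dummy_lengths(160, 2)) == 14   # covers both cases above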
def _dummy_prefill_inputs(self, input_length_list: List[int], max_dec_len_list: List[int], block_num: int):
"""Set dummy prefill inputs to share_inputs"""
batch_size = len(input_length_list)
for i in range(batch_size):
idx = i
input_length = input_length_list[i]
max_dec_len = max_dec_len_list[i]
self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
self.share_inputs["eos_token_id"][:] = np.array(
@@ -745,6 +839,13 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["decoder_tile_ids_per_batch"] = None
self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
self.share_inputs["max_len_tensor_cpu"] = None # CPU
self.share_inputs["encoder_batch_ids"] = None
self.share_inputs["encoder_tile_ids_per_batch"] = None
self.share_inputs["encoder_num_blocks_x_cpu"] = None # CPU
self.share_inputs["kv_batch_ids"] = None
self.share_inputs["kv_tile_ids_per_batch"] = None
self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU
self.share_inputs["max_len_kv_cpu"] = None # CPU
# Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -977,23 +1078,30 @@ class GPUModelRunner(ModelRunnerBase):
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],
caches=self.share_inputs["caches"],
encoder_batch_ids=self.share_inputs["encoder_batch_ids"],
encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"],
encoder_num_blocks_x_cpu=self.share_inputs["encoder_num_blocks_x_cpu"],
kv_batch_ids=self.share_inputs["kv_batch_ids"],
kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"],
kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"],
)
# Update Batch type for cuda graph
only_decode_batch = True
prefill_exists = None
# mix ep in single node
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
only_decode_batch_list = []
prefill_exists = self.exist_prefill()
paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
only_decode_batch = all(only_decode_batch_list)
self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
# Update Batch type for cuda graph for only_decode_batch
if_only_decode = self.only_decode()
only_decode_use_cudagraph = self.use_cudagraph and if_only_decode
# Update config about moe for better performance
# TODO(wanglongzhi): Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. It will be used in MoEMethodBase.apply()
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
self.fd_config.parallel_config.moe_phase.phase = "decode" if if_only_decode else "prefill"
# Update Batch type for cuda graph for only_prefill_batch
only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill()
# Once capturing both prefill-only and decode-only batches is supported, this will become [only_prefill_use_cudagraph or only_decode_use_cudagraph]
self.forward_meta.step_use_cudagraph = (
self.use_cudagraph
and only_decode_batch
and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph
)
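# NOTE (illustrative sketch, not part of this commit): the CUDA Graph gating above as a
# pure function; the argument names are ours.
def step_uses_cudagraph(use_cudagraph: bool, cudagraph_only_prefill: bool,
                        is_only_prefill: bool, is_only_decode: bool) -> bool:
    if not use_cudagraph:
        return False
    # In prefill-only capture mode, replay only for pure-prefill batches;
    # otherwise replay only for pure-decode batches. Mixed batches run eagerly.
    return is_only_prefill if cudagraph_only_prefill else is_only_decode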
# Initialize attention meta data
@@ -1085,14 +1193,31 @@ class GPUModelRunner(ModelRunnerBase):
encoder_block_shape_q = 64
decoder_block_shape_q = 16
decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
group_size = np.ceil(num_heads / self.model_config.kv_num_heads)
decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
(decoder_step_token_num * group_size) / decoder_block_shape_q
)
encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(self.model_config.max_model_len * group_size) / encoder_block_shape_q
)
kv_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
self.model_config.max_model_len / self.fd_config.cache_config.block_size
)
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
# Get the attention backend
attn_cls = get_attention_backend()
attn_backend = attn_cls(
@@ -1112,6 +1237,7 @@ class GPUModelRunner(ModelRunnerBase):
batch_size: paddle.Tensor,
expected_decode_len: int = 1,
in_capturing: bool = False,
capture_prefill: bool = False,
) -> paddle.Tensor:
"""
Use dummy inputs to run before formal execution.
@@ -1119,11 +1245,19 @@ class GPUModelRunner(ModelRunnerBase):
num_tokens:
expected_decode_len: Expected number of tokens generated
in_capturing: Is cuda graph in capturing state
capture_prefill: Capture pure prefill for cuda graph
"""
self._dummy_prefill_inputs(
input_length_list, max_dec_len_list, block_num = self.get_input_length_list(
num_tokens=num_tokens,
batch_size=batch_size,
expected_decode_len=expected_decode_len,
capture_prefill=capture_prefill,
)
self._dummy_prefill_inputs(
input_length_list=input_length_list,
max_dec_len_list=max_dec_len_list,
block_num=block_num,
)
if self.speculative_method in ["mtp"]:
self.proposer.dummy_prefill_inputs(
@@ -1353,14 +1487,30 @@ class GPUModelRunner(ModelRunnerBase):
time_before_capture = time.perf_counter()
expected_decode_len = 1
capture_sizes = self.cudagraph_capture_sizes.copy()
for batch_size in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=self.parallel_config.max_num_batched_tokens,
batch_size=batch_size,
in_capturing=True,
expected_decode_len=expected_decode_len,
)
logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")
if self.fd_config.graph_opt_config.cudagraph_only_prefill:
for num_tokens in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=num_tokens,
batch_size=self.parallel_config.max_num_seqs,
in_capturing=True,
expected_decode_len=expected_decode_len,
capture_prefill=True,
)
logger.info(
f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
)
else:
for batch_size in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=self.parallel_config.max_num_batched_tokens,
batch_size=batch_size,
in_capturing=True,
expected_decode_len=expected_decode_len,
)
logger.info(
f"Warm up the model with the num_tokens:{batch_size}, expected_decode_len:{expected_decode_len}"
)
time_after_capture = time.perf_counter()
logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")