Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Executor] Experimental Feature: Support Prefill in CUDA Graph (#3459)
* Support prefill in CUDA Graph
* Refactor the GetBlockShapeAndSplitKVBlock kernel (V2 through V2.5)
* Fix the encoder_num_blocks_x_cpu issue
* Add an early-exit mechanism for the attention kernel
* Fix the test case for append-attention
* Update test code and add annotations to the related tensors
* Move get_input_length_list
* Fix test code
* Add annotations about the attention kernel's early exit
* Add further annotations about the attention kernel's early exit
* Address review comments
* Fix MTP

Co-authored-by: RAM <gstian5555@outlook.com>
@@ -142,6 +142,7 @@ class GPUModelRunner(ModelRunnerBase):
         self.use_cudagraph = self.graph_opt_config.use_cudagraph
         self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
         self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
+        self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill

         # Initialize share inputs
         self._init_share_inputs(self.parallel_config.max_num_seqs)
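For reference, a minimal stand-in for the graph-optimization options read above might look like the sketch below. The field names mirror the attributes consumed from `graph_opt_config`; the class itself, its defaults, and its types are illustrative assumptions, not FastDeploy's actual config class.

from dataclasses import dataclass, field
from typing import List

@dataclass
class GraphOptSettingsSketch:
    use_cudagraph: bool = False                 # master switch for CUDA Graph capture/replay
    cudagraph_capture_sizes: List[int] = field(default_factory=lambda: [1, 2, 4, 8])
    sot_warmup_sizes: List[int] = field(default_factory=list)
    cudagraph_only_prefill: bool = False        # new in this change: capture prefill-only graphs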
@@ -177,10 +178,49 @@ class GPUModelRunner(ModelRunnerBase):
         """
        check whether prefill stage exist
         """
-        if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0:
-            return 1
-        else:
-            return 0
+        return int(paddle.max(self.share_inputs["seq_lens_encoder"])) > 0
+
+    def exist_decode(self):
+        """
+        check whether decode stage exist
+        """
+        return int(paddle.max(self.share_inputs["seq_lens_decoder"])) > 0
+
+    def only_prefill(self):
+        """
+        check whether prefill only
+        """
+        if_only_prefill = True
+        decode_exists = None
+        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+            only_prefill_batch_list = []
+            decode_exists = self.exist_decode()
+            paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists)
+            if_only_prefill = all(only_prefill_batch_list)
+
+        if_only_prefill = if_only_prefill and not (decode_exists if decode_exists is not None else self.exist_decode())
+
+        return if_only_prefill
+
+    def only_decode(self):
+        """
+        check whether decode only
+        """
+        # Update Batch type for cuda graph for if_only_decode
+        if_only_decode = True
+        prefill_exists = None
+        # mix ep in single node
+        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+            only_decode_batch_list = []
+            prefill_exists = self.exist_prefill()
+            paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
+            if_only_decode = all(only_decode_batch_list)
+
+        if_only_decode = if_only_decode and not (
+            prefill_exists if prefill_exists is not None else self.exist_prefill()
+        )
+
+        return if_only_decode

     def _init_speculative_proposer(self):
         """
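The `only_prefill`/`only_decode` helpers above reach a cross-rank agreement before a graph is replayed: in mixed expert-parallel deployments every rank must see the same batch type, which is why the local flag is all-gathered and combined. A minimal standalone sketch of that consensus step is shown below; it assumes an already-initialized `paddle.distributed` environment and an illustrative `local_has_decode` flag computed from this rank's `seq_lens_decoder`, and is not the runner's own code.

import paddle.distributed as dist

def batch_is_prefill_only(local_has_decode: bool) -> bool:
    # Gather one bool per rank: "this rank has no decode tokens in the current batch".
    gathered = []
    dist.all_gather_object(gathered, not local_has_decode)
    # The batch counts as prefill-only for graph replay only if every rank agrees.
    return all(gathered)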
@@ -600,27 +640,81 @@ class GPUModelRunner(ModelRunnerBase):
         if self.speculative_method in ["mtp"]:
             self.proposer.insert_prefill_inputs(req_dicts, num_running_requests)

-    def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
-        """Set dummy prefill inputs to share_inputs"""
+    def get_input_length_list(
+        self, num_tokens: int, batch_size: int, expected_decode_len: int, capture_prefill: bool = False
+    ):
+        """
+        Generates some list for _dummy_prefill_inputs, when capture pure prefill or mtp,
+        the list should be carefully constructed.
+
+        This function addresses a specific problem: in the pure prefill stage, variable
+        input lengths (e.g., `prompt[160, 0]` vs. `prompt[80, 80]`) can lead to different
+        CUDA Grid dimensions for kernels like `split_q_block`. This prevents CUDA Graph
+        reuse.
+
+        The `split_q_block` kernel calculates the total number of blocks, which directly
+        determines the `griddim.x` launch parameter for the `multi_query_append_attention_kernel`.
+        The blocks for a single sequence are determined by the formula:
+        `num_blocks = ceil((sequence_length * group_size) / block_shape_q)`
+
+        Due to the `ceil` (ceiling) function, distributing a total number of tokens across
+        a batch of shorter sequences will result in a larger total block count. For example,
+        with a `group_size` of 5 and `block_shape_q` of 64:
+        - A single sequence of 160 tokens requires `ceil((160 * 5) / 64) = 13` blocks.
+        - Two sequences of 80 tokens each require `ceil((80 * 5) / 64) * 2 = 7 * 2 = 14` blocks.
+
+        To ensure graph replayability, this function creates a "dummy" list of sequence
+        lengths that's designed to produce the theoretical maximum `encoder_num_blocks_x_cpu`
+        for the given `num_tokens` and `batch_size`. This strategy ensures the captured
+        CUDA Graph has the largest possible grid dimensions. At runtime, if the actual number
+        of blocks is less than or equal to this maximum, the kernel can safely execute by
+        using an early-exit mechanism.
+
+        Args:
+            num_tokens (int): The total number of tokens across all sequences.
+            batch_size (int): The number of sequences (requests) in the batch.
+
+        Returns:
+            List[int]: A list of integers representing the sequence length for each request.
+                This list is crafted to maximize the total number of blocks.
+        """
         # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token
         max_dec_len = expected_decode_len + 1
-        full_length = min(
-            num_tokens // batch_size,
+        input_length = min(
+            num_tokens // (1 if capture_prefill else batch_size),
             self.parallel_config.max_model_len - max_dec_len,
         )

         # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan.
         # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP.
         if self.fd_config.parallel_config.enable_expert_parallel:
-            full_length = min(full_length, 32)
+            input_length = min(input_length, 32)

-        input_length = int(full_length * self.cache_config.kv_cache_ratio)
         block_num = (
             input_length + self.cache_config.block_size - 1
         ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num

+        input_length_list = [input_length] * batch_size
+
+        if capture_prefill:
+            if num_tokens < batch_size:
+                input_length_list = [1] * num_tokens
+            else:
+                input_length_list = [1] * (batch_size - 1)
+                input_length_list.append(num_tokens - batch_size + 1)
+
+        len_of_input_length_list = len(input_length_list)
+        max_dec_len_list = [max_dec_len] * len_of_input_length_list
+
+        return input_length_list, max_dec_len_list, block_num
+
+    def _dummy_prefill_inputs(self, input_length_list: List[int], max_dec_len_list: List[int], block_num: int):
+        """Set dummy prefill inputs to share_inputs"""
+        batch_size = len(input_length_list)
         for i in range(batch_size):
             idx = i
+            input_length = input_length_list[i]
+            max_dec_len = max_dec_len_list[i]
             self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.share_inputs["eos_token_id"][:] = np.array(
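The docstring above carries the core argument of this change, so a small self-contained rendering of its arithmetic may help. This is an illustrative sketch only, using plain `math.ceil` with the `group_size = 5` and `block_shape_q = 64` values taken from the docstring's example; it is not the CUDA `split_q_block` kernel itself.

import math

def total_q_blocks(seq_lens, group_size=5, block_shape_q=64):
    # Per-sequence blocks: ceil((sequence_length * group_size) / block_shape_q), summed over the batch.
    return sum(math.ceil(length * group_size / block_shape_q) for length in seq_lens)

def capture_prefill_lengths(num_tokens, batch_size):
    # Mirrors the capture_prefill branch above: (batch_size - 1) length-1 requests plus one long tail.
    if num_tokens < batch_size:
        return [1] * num_tokens
    return [1] * (batch_size - 1) + [num_tokens - batch_size + 1]

print(total_q_blocks([160]))                            # 13 blocks
print(total_q_blocks([80, 80]))                         # 14 blocks
print(total_q_blocks(capture_prefill_lengths(160, 2)))  # [1, 159] -> 1 + 13 = 14 blocks

Capturing with the length list that maximizes the block count means the recorded `griddim.x` is an upper bound; at replay time the real batch can only need that many blocks or fewer, and any surplus thread blocks take the early-exit path described in the docstring.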
@@ -745,6 +839,13 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["decoder_tile_ids_per_batch"] = None
         self.share_inputs["decoder_num_blocks_cpu"] = None  # Pinning Memory
         self.share_inputs["max_len_tensor_cpu"] = None  # CPU
+        self.share_inputs["encoder_batch_ids"] = None
+        self.share_inputs["encoder_tile_ids_per_batch"] = None
+        self.share_inputs["encoder_num_blocks_x_cpu"] = None  # CPU
+        self.share_inputs["kv_batch_ids"] = None
+        self.share_inputs["kv_tile_ids_per_batch"] = None
+        self.share_inputs["kv_num_blocks_x_cpu"] = None  # CPU
+        self.share_inputs["max_len_kv_cpu"] = None  # CPU

         # Initialize rotary position embedding
         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -977,23 +1078,30 @@ class GPUModelRunner(ModelRunnerBase):
             cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
             block_tables=self.share_inputs["block_tables"],
             caches=self.share_inputs["caches"],
+            encoder_batch_ids=self.share_inputs["encoder_batch_ids"],
+            encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"],
+            encoder_num_blocks_x_cpu=self.share_inputs["encoder_num_blocks_x_cpu"],
+            kv_batch_ids=self.share_inputs["kv_batch_ids"],
+            kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"],
+            kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"],
+            max_len_kv_cpu=self.share_inputs["max_len_kv_cpu"],
         )

-        # Update Batch type for cuda graph
-        only_decode_batch = True
-        prefill_exists = None
-        # mix ep in single node
-        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
-            only_decode_batch_list = []
-            prefill_exists = self.exist_prefill()
-            paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
-            only_decode_batch = all(only_decode_batch_list)
-            self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
+        # Update Batch type for cuda graph for only_decode_batch
+        if_only_decode = self.only_decode()
+        only_decode_use_cudagraph = self.use_cudagraph and if_only_decode
+
+        # Update config about moe for better performance
+        # TODO(wanglongzhi): Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. It will be used in MoEMethodBase.apply()
+        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+            self.fd_config.parallel_config.moe_phase.phase = "decode" if if_only_decode else "prefill"
+
+        # Update Batch type for cuda graph for only_prefill_batch
+        only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill()

+        # When support capture both prefill-only and decode-only, this will use [only_prefill_use_cudagraph or only_decode_use_cudagraph]
         self.forward_meta.step_use_cudagraph = (
-            self.use_cudagraph
-            and only_decode_batch
-            and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
+            only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph
         )

         # Initialzie attention meta data
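Pulling the replay decision out of the diff, the per-step gate reduces to the small predicate below. This is a hedged restatement with illustrative parameter names, not an exact copy of the runner's code.

def step_can_replay_graph(use_cudagraph: bool, cudagraph_only_prefill: bool,
                          if_only_prefill: bool, if_only_decode: bool) -> bool:
    # Replay is only safe when the whole batch matches the captured batch type:
    # prefill-only batches when prefill graphs were captured, otherwise decode-only batches.
    if not use_cudagraph:
        return False
    if cudagraph_only_prefill:
        return if_only_prefill
    return if_only_decode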
@@ -1085,14 +1193,31 @@ class GPUModelRunner(ModelRunnerBase):
         encoder_block_shape_q = 64
         decoder_block_shape_q = 16
         decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
+        group_size = np.ceil(num_heads / self.model_config.kv_num_heads)
+
         decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
-            (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
+            (decoder_step_token_num * group_size) / decoder_block_shape_q
         )
+        encode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
+            (self.model_config.max_model_len * group_size) / encoder_block_shape_q
+        )
+        kv_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
+            self.model_config.max_model_len / self.fd_config.cache_config.block_size
+        )
         self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
         self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
         self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
         self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
+
+        self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
+        self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
+        self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
+
+        self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
+        self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
+        self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
+        self.share_inputs["max_len_kv_cpu"] = paddle.full([1], 0, dtype="int32").cpu()

         # Get the attention backend
         attn_cls = get_attention_backend()
         attn_backend = attn_cls(
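To make the buffer sizing above concrete, here is a hedged numeric walk-through with made-up model dimensions. The constants `encoder_block_shape_q = 64` and `decoder_block_shape_q = 16` come from the code; every other value below is an illustrative assumption, not a FastDeploy default.

import math

max_num_seqs = 4
max_model_len = 8192
block_size = 64
num_heads, kv_num_heads = 32, 8
num_speculative_tokens = 0
encoder_block_shape_q, decoder_block_shape_q = 64, 16

group_size = math.ceil(num_heads / kv_num_heads)      # 4
decoder_step_token_num = num_speculative_tokens + 1   # 1

decode_max_tile_size = max_num_seqs * math.ceil(decoder_step_token_num * group_size / decoder_block_shape_q)
encode_max_tile_size = max_num_seqs * math.ceil(max_model_len * group_size / encoder_block_shape_q)
kv_max_tile_size = max_num_seqs * math.ceil(max_model_len / block_size)

print(decode_max_tile_size)  # 4 * ceil(4 / 16)     = 4
print(encode_max_tile_size)  # 4 * ceil(32768 / 64) = 2048
print(kv_max_tile_size)      # 4 * ceil(8192 / 64)  = 512

These upper bounds are what make it safe to allocate the encoder/kv bookkeeping tensors once and reuse them across captured graphs.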
@@ -1112,6 +1237,7 @@ class GPUModelRunner(ModelRunnerBase):
         batch_size: paddle.Tensor,
         expected_decode_len: int = 1,
         in_capturing: bool = False,
+        capture_prefill: bool = False,
     ) -> paddle.Tensor:
         """
         Use dummy inputs to run before formal execution.
@@ -1119,11 +1245,19 @@ class GPUModelRunner(ModelRunnerBase):
            num_tokens:
            expected_decode_len: Expected number of tokens generated
            in_capturing: Is cuda graph in capturing state
+           capture_prefill: Capture pure prefill for cuda graph
        """
-        self._dummy_prefill_inputs(
+
+        input_length_list, max_dec_len_list, block_num = self.get_input_length_list(
             num_tokens=num_tokens,
             batch_size=batch_size,
             expected_decode_len=expected_decode_len,
+            capture_prefill=capture_prefill,
         )
+        self._dummy_prefill_inputs(
+            input_length_list=input_length_list,
+            max_dec_len_list=max_dec_len_list,
+            block_num=block_num,
+        )
         if self.speculative_method in ["mtp"]:
             self.proposer.dummy_prefill_inputs(
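As a concrete trace of this call sequence (illustrative numbers): with `num_tokens=32`, `batch_size=4`, `expected_decode_len=1`, and `capture_prefill=True`, `get_input_length_list` returns `input_length_list=[1, 1, 1, 29]` and `max_dec_len_list=[2, 2, 2, 2]`, which `_dummy_prefill_inputs` then writes into `share_inputs` before the dummy forward pass.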
@@ -1353,14 +1487,30 @@ class GPUModelRunner(ModelRunnerBase):
         time_before_capture = time.perf_counter()
         expected_decode_len = 1
         capture_sizes = self.cudagraph_capture_sizes.copy()
-        for batch_size in sorted(capture_sizes, reverse=True):
-            self._dummy_run(
-                num_tokens=self.parallel_config.max_num_batched_tokens,
-                batch_size=batch_size,
-                in_capturing=True,
-                expected_decode_len=expected_decode_len,
-            )
-            logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")
+
+        if self.fd_config.graph_opt_config.cudagraph_only_prefill:
+            for num_tokens in sorted(capture_sizes, reverse=True):
+                self._dummy_run(
+                    num_tokens=num_tokens,
+                    batch_size=self.parallel_config.max_num_seqs,
+                    in_capturing=True,
+                    expected_decode_len=expected_decode_len,
+                    capture_prefill=True,
+                )
+                logger.info(
+                    f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
+                )
+        else:
+            for batch_size in sorted(capture_sizes, reverse=True):
+                self._dummy_run(
+                    num_tokens=self.parallel_config.max_num_batched_tokens,
+                    batch_size=batch_size,
+                    in_capturing=True,
+                    expected_decode_len=expected_decode_len,
+                )
+                logger.info(
+                    f"Warm up the model with the num_tokens:{batch_size}, expected_decode_len:{expected_decode_len}"
+                )

         time_after_capture = time.perf_counter()
         logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
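The warm-up loop above keys prefill-only graphs by total token count and decode-only graphs by batch size. A hedged sketch of that selection follows; the helper name and return shape are illustrative, not the runner's API.

from typing import List, Tuple

def capture_plan(capture_sizes: List[int], cudagraph_only_prefill: bool,
                 max_num_seqs: int, max_num_batched_tokens: int) -> List[Tuple[int, int]]:
    # Returns (num_tokens, batch_size) pairs for the dummy runs:
    # prefill-only graphs sweep the token count at a fixed max batch size,
    # decode-only graphs sweep the batch size at a fixed token budget.
    if cudagraph_only_prefill:
        return [(num_tokens, max_num_seqs) for num_tokens in sorted(capture_sizes, reverse=True)]
    return [(max_num_batched_tokens, batch_size) for batch_size in sorted(capture_sizes, reverse=True)]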