mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Graph Optimization] Refactor default capture list (#4617)
* fix bug and refine code * add debug count * refine code
This commit is contained in:
@@ -862,7 +862,7 @@ class GraphOptimizationConfig:
|
||||
self.real_shape_to_captured_size[bs] = end
|
||||
self.real_shape_to_captured_size[self.max_capture_size] = self.max_capture_size
|
||||
|
||||
def _set_cudagraph_sizes(self, max_num_seqs: int = 0):
|
||||
def _set_cudagraph_sizes(self, max_capture_size: int = 0):
|
||||
"""
|
||||
Calculate a series of candidate capture sizes,
|
||||
and then extract a portion of them as the capture list for the CUDA graph based on user input.
|
||||
@@ -874,7 +874,7 @@ class GraphOptimizationConfig:
|
||||
# Shape [256, 288, ... 992, 1024]
|
||||
draft_capture_sizes += [32 * i for i in range(9, 33)]
|
||||
|
||||
draft_capture_sizes.append(max_num_seqs)
|
||||
draft_capture_sizes.append(max_capture_size)
|
||||
self.cudagraph_capture_sizes = sorted(draft_capture_sizes)
|
||||
|
||||
def to_json_string(self):
|
||||
@@ -1391,19 +1391,22 @@ class FDConfig:
|
||||
self.cache_config: CacheConfig = cache_config # type: ignore
|
||||
self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
|
||||
self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config
|
||||
# Initialize cuda graph capture list
|
||||
if self.graph_opt_config.cudagraph_capture_sizes is None:
|
||||
self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.scheduler_config.max_num_seqs)
|
||||
|
||||
# Initialize cuda graph capture list
|
||||
max_capture_shape = self.scheduler_config.max_num_seqs
|
||||
if self.speculative_config is not None and self.speculative_config.method == "mtp":
|
||||
max_capture_shape = self.scheduler_config.max_num_seqs * (
|
||||
self.speculative_config.num_speculative_tokens + 1
|
||||
)
|
||||
assert max_capture_shape % 2 == 0, "CUDAGraph only supports capturing even token nums in MTP scenarios."
|
||||
if self.graph_opt_config.cudagraph_only_prefill:
|
||||
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=512)
|
||||
elif self.speculative_config is not None and self.speculative_config.method == "mtp":
|
||||
max_shape = self.scheduler_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1)
|
||||
if max_shape % 2 == 1:
|
||||
max_shape = max_shape + 1
|
||||
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=min(512, max_shape))
|
||||
max_capture_shape = 512
|
||||
else:
|
||||
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=self.scheduler_config.max_num_seqs)
|
||||
max_capture_shape = min(512, max_capture_shape)
|
||||
|
||||
if self.graph_opt_config.cudagraph_capture_sizes is None:
|
||||
self.graph_opt_config._set_cudagraph_sizes(max_capture_size=max_capture_shape)
|
||||
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=max_capture_shape)
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.ips = ips
|
||||
|
||||
Reference in New Issue
Block a user