add draft model using cudagraph switch

This commit is contained in:
gongshaotian
2025-10-27 14:56:34 +08:00
committed by lizhenyun01
parent 006c7e5a0d
commit fa85956c6f
2 changed files with 5 additions and 1 deletions

View File

@@ -591,6 +591,9 @@ class GraphOptimizationConfig:
""" Whether to use shared memory pool for multi capture_size """
self.use_unique_memory_pool: bool = False
""" Whether to use cudagraph for draft model."""
self.draft_model_use_cudagraph: bool = True
self.max_capture_size: int = None
self.real_shape_to_captured_size: dict[int, int] = None
# CINN Config ...

View File

@@ -83,6 +83,7 @@ class MTPProposer(Proposer):
self._init_model_inputs()
# CUDA Graph
self.draft_model_use_cudagraph = self.graph_opt_config.draft_model_use_cudagraph
self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
@@ -618,7 +619,7 @@ class MTPProposer(Proposer):
attn_backend.init_attention_metadata(self.forward_meta)
# TODO(gongshaotian): Use CUDAGraph with Draft Model
self.forward_meta.step_use_cudagraph = step_use_cudagraph
self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.draft_model_use_cudagraph
def exist_prefill(self):
"""