[Graph Optimization] Refactor default capture list (#4617)

* fix bug and refine code

* add debug count

* refine code
This commit is contained in:
RAM
2025-10-28 21:31:02 +08:00
committed by GitHub
parent 0a0c74e717
commit fff5fb5e39
3 changed files with 36 additions and 15 deletions

View File

@@ -34,6 +34,10 @@ from fastdeploy.model_executor.graph_optimization.utils import in_profile_run_mo
from fastdeploy.model_executor.graph_optimization.utils import (
in_sot_warmup_mode as in_warmup_mode,
)
from fastdeploy.utils import get_logger
logger = get_logger("cudagrpah_piecewise_backend", "cudagraph_piecewise_backend.log")
P = ParamSpec("P")
T = TypeVar("T")
@@ -105,6 +109,9 @@ class GraphOptBackend:
self.dy_runnable = self.runnable
self.fd_config = fd_config
self.max_captre_size = fd_config.graph_opt_config.cudagraph_capture_sizes[0]
self._debug_count_cudagraph_replay = 0
self._debug_count_total_step = 0
if self.fd_config.graph_opt_config.graph_opt_level > 0:
# 1. Prepare cuda graph input buffers (contain output of subgraphs)
@@ -123,6 +130,7 @@ class GraphOptBackend:
)
def __call__(self, **kwargs):
self._debug_count_total_step += 1
if not self.fd_config.graph_opt_config.use_cudagraph:
return self.runnable(**kwargs)
if self.cudagraph_piecewise_backend is None:
@@ -136,6 +144,10 @@ class GraphOptBackend:
if (not kwargs["forward_meta"].step_use_cudagraph) or (real_shape > self.cudagraph_switch_threshold):
return self.dy_runnable(**kwargs)
else:
self._debug_count_cudagraph_replay += 1
logger.debug(
f"[CUDA GRAPH][ID:{id(self.cudagraph_piecewise_backend)}] Total step count: {self._debug_count_total_step}, CUDAGraph replay count: {self._debug_count_cudagraph_replay}"
)
return self.cudagraph_piecewise_backend.__call__(**kwargs)
def clear_cudagraph_piecewise_backend(self):