[Graph Optimization] Add dy_runnable and introduce cudagraph_switch_threshold for cudagraph mode switching (#4578)

* add new branch for sot

* reorder

* fix batch bug
This commit is contained in:
Ryan
2025-10-24 18:36:52 +08:00
committed by GitHub
parent e02a812880
commit f42ed6d5f2
2 changed files with 8 additions and 4 deletions

View File

@@ -102,8 +102,8 @@ class GraphOptBackend:
def __init__(self, runnable: Callable, fd_config: FDConfig):
self.runnable = runnable
+self.dy_runnable = self.runnable
self.fd_config = fd_config
self.max_captre_size = fd_config.graph_opt_config.cudagraph_capture_sizes[0]
if self.fd_config.graph_opt_config.graph_opt_level > 0:
# 1. Prepare cuda graph input buffers (contain output of subgraphs)
@@ -118,6 +118,10 @@ class GraphOptBackend:
backend,
).__get__(self.runnable.__self__)
+self.cudagraph_switch_threshold = (
+1024 if self.fd_config.graph_opt_config.graph_opt_level > 0 else self.max_captre_size
+)
def __call__(self, **kwargs):
if not self.fd_config.graph_opt_config.use_cudagraph:
return self.runnable(**kwargs)
@@ -129,8 +133,8 @@ class GraphOptBackend:
assert kwargs["forward_meta"].ids_remove_padding is not None
real_shape = kwargs["forward_meta"].ids_remove_padding.shape[0]
-if (not kwargs["forward_meta"].step_use_cudagraph) or (real_shape > self.max_captre_size):
-return self.runnable(**kwargs)
+if (not kwargs["forward_meta"].step_use_cudagraph) or (real_shape > self.cudagraph_switch_threshold):
+return self.dy_runnable(**kwargs)
else:
return self.cudagraph_piecewise_backend.__call__(**kwargs)