[Executor]CUDAGraph support Speculate Decode (#4258)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled

* [Executor]CUDAGraph support Speculate Decode

* fix problem

* solve problem

* fix

* fast compile

* CUDAGraph + mtp support eb5(only target model)

* Revert "fast compile"

This reverts commit 3cfe8373ed.

* fix precommit

* solve comment

* fix comment about #pragram unroll

---------

Co-authored-by: gongshaotian <gstain5555@outlook.com>
Co-authored-by: gongshaotian <gstian5555@outlook.com>
This commit is contained in:
Jundong Liu
2025-10-13 15:21:41 +08:00
committed by GitHub
parent 07db281647
commit 0b7a5778ab
16 changed files with 265 additions and 134 deletions

View File

@@ -115,7 +115,7 @@ class GraphOptBackend:
self.runnable = runnable
self.fd_config = fd_config
self.max_captre_batch = fd_config.graph_opt_config.cudagraph_capture_sizes[0]
self.max_captre_size = fd_config.graph_opt_config.cudagraph_capture_sizes[0]
if self.fd_config.graph_opt_config.graph_opt_level > 0:
# 1. Prepare cuda grpah input buffers (contain output of subgraphs)
@@ -138,9 +138,9 @@ class GraphOptBackend:
)
assert kwargs["forward_meta"].ids_remove_padding is not None
batch_size = kwargs["forward_meta"].ids_remove_padding.shape[0]
real_shape = kwargs["forward_meta"].ids_remove_padding.shape[0]
if (not kwargs["forward_meta"].step_use_cudagraph) or (batch_size > self.max_captre_batch):
if (not kwargs["forward_meta"].step_use_cudagraph) or (real_shape > self.max_captre_size):
return self.runnable(**kwargs)
else:
return self.cudagraph_piecewise_backend.__call__(**kwargs)