Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[CUDAGraph]Add support for custom all-reduce operators under SOT mode (#4386)
This commit is contained in:
@@ -133,8 +133,9 @@ class CudaGraphPiecewiseBackend:
 self.cuda_graph_manager.state = jit_utils.CUDAGraphState.CAPTURE
 self.cuda_graph_manager.batch_size = entry.real_shape
 entry.captured = True
-with self.cuda_graph_manager.run_impl_guard():
-    entry.runnable(**kwargs)
+with capture_custom_allreduce():
+    with self.cuda_graph_manager.run_impl_guard():
+        entry.runnable(**kwargs)
 
 # Replay
 self.cuda_graph_manager.state = jit_utils.CUDAGraphState.REPLAY
Reference in New Issue
Block a user