diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
index 30a28d293..6e7619faf 100644
--- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
+++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
@@ -20,7 +20,6 @@ from typing import Callable, Dict, Optional
 
 import paddle.nn.layer
 from paddle.device.cuda import graphs
-from paddle.jit.dy2static.utils import CUDAGraphState
 
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import capture_custom_allreduce
@@ -52,11 +51,16 @@ class ConcreteSizeEntry:
 
 class Dy2StCudaGraphManager:
     def __init__(self):
+        # NOTE(gongshaotian): Use local import to avoid RLHF version problems
+        from paddle.jit.dy2static.utils import CUDAGraphState
+
         self.state = CUDAGraphState.DISABLE
         self.captured_batch_size = set()
         self.batch_size = -1
 
     def run_impl(self, original_run_impl, inputs, parameters, attrs):
+        from paddle.jit.dy2static.utils import CUDAGraphState
+
         run_state = self.state
         prog_attrs, cuda_graph_attrs = attrs
         if run_state == CUDAGraphState.REPLAY:
@@ -100,6 +104,8 @@ class CudaGraphPiecewiseBackend:
         self.cuda_graph_manager = Dy2StCudaGraphManager()
 
     def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
+        from paddle.jit.dy2static.utils import CUDAGraphState
+
         if not entry.captured:
             # Warmup the model
             for n in range(entry.num_finished_warmup, self.warm_up_size):
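
For context, a minimal sketch of the deferred-import pattern this patch applies (reduced to the relevant lines; the class name mirrors the diff). Moving the `CUDAGraphState` import from module scope into the methods that use it means the symbol is only resolved when a method actually runs, so importing the module no longer fails on Paddle builds that do not ship `paddle.jit.dy2static.utils.CUDAGraphState`, e.g. the RLHF environment mentioned in the NOTE:

```python
# Sketch of the deferred-import pattern used in the patch, assuming a
# Paddle build that provides CUDAGraphState at call time. With a
# module-level import, merely importing this file raises ImportError on
# builds that lack the symbol; with a local import, older builds can
# still import (but not exercise) the module.

class Dy2StCudaGraphManager:
    def __init__(self):
        # Resolved at call time, not at module-import time.
        from paddle.jit.dy2static.utils import CUDAGraphState

        self.state = CUDAGraphState.DISABLE
```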