Mirror of https://github.com/PaddlePaddle/FastDeploy.git
@@ -20,7 +20,6 @@ from typing import Callable, Dict, Optional

 import paddle.nn.layer
 from paddle.device.cuda import graphs
-from paddle.jit.dy2static.utils import CUDAGraphState

 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import capture_custom_allreduce
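Note: this hunk drops the module-level CUDAGraphState import; the hunks below re-import it locally inside the methods that use it, per the NOTE about RLHF version problems. A minimal sketch of that deferred-import pattern, assuming the helper name below (not part of FastDeploy):

    def _load_cuda_graph_state():
        # Resolved lazily so that merely importing this module does not fail on
        # Paddle builds (e.g. some RLHF environments) that do not ship
        # paddle.jit.dy2static.utils.CUDAGraphState.
        from paddle.jit.dy2static.utils import CUDAGraphState

        return CUDAGraphState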
@@ -52,11 +51,16 @@ class ConcreteSizeEntry:

 class Dy2StCudaGraphManager:
     def __init__(self):
+        # NOTE(gongshaotian): Use local import to avoid RLHF version problems
+        from paddle.jit.dy2static.utils import CUDAGraphState
+
         self.state = CUDAGraphState.DISABLE
         self.captured_batch_size = set()
         self.batch_size = -1

     def run_impl(self, original_run_impl, inputs, parameters, attrs):
+        from paddle.jit.dy2static.utils import CUDAGraphState
+
         run_state = self.state
         prog_attrs, cuda_graph_attrs = attrs
         if run_state == CUDAGraphState.REPLAY:
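Note: Dy2StCudaGraphManager acts as a small state machine around dy2st's run_impl hook: run_impl unpacks attrs into (prog_attrs, cuda_graph_attrs) and branches on the current CUDAGraphState before delegating. A hedged usage sketch under the assumption that callers set state and batch_size before a run; only DISABLE and REPLAY are visible in this hunk, and the wrapper function below is illustrative, not FastDeploy code:

    from paddle.jit.dy2static.utils import CUDAGraphState

    manager = Dy2StCudaGraphManager()        # class introduced in this diff
    manager.batch_size = 8                   # assumed: selects which captured graph to use
    manager.state = CUDAGraphState.REPLAY    # replay a previously captured graph

    def patched_run_impl(original_run_impl, inputs, parameters, attrs):
        # Route dy2st execution through the manager so it can act on
        # cuda_graph_attrs according to its current state.
        return manager.run_impl(original_run_impl, inputs, parameters, attrs)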
@@ -100,6 +104,8 @@ class CudaGraphPiecewiseBackend:
         self.cuda_graph_manager = Dy2StCudaGraphManager()

     def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
+        from paddle.jit.dy2static.utils import CUDAGraphState
+
         if not entry.captured:
             # Warmup the model
             for n in range(entry.num_finished_warmup, self.warm_up_size):
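Note: run_static_model warms up an uncaptured ConcreteSizeEntry by running it eagerly from entry.num_finished_warmup up to self.warm_up_size before any CUDA graph capture. A hedged sketch of just that bookkeeping; the run_once callable and the counter update are assumptions, only the range expression comes from the diff:

    def warm_up(entry, warm_up_size, run_once, **kwargs):
        # Run the not-yet-captured entry eagerly until it has completed
        # warm_up_size warmup iterations; keeping num_finished_warmup on the
        # entry lets an interrupted warmup resume where it stopped.
        for _ in range(entry.num_finished_warmup, warm_up_size):
            entry.num_finished_warmup += 1
            run_once(**kwargs)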