mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
@@ -20,7 +20,6 @@ from typing import Callable, Dict, Optional
|
|||||||
|
|
||||||
import paddle.nn.layer
|
import paddle.nn.layer
|
||||||
from paddle.device.cuda import graphs
|
from paddle.device.cuda import graphs
|
||||||
from paddle.jit.dy2static.utils import CUDAGraphState
|
|
||||||
|
|
||||||
from fastdeploy.config import FDConfig
|
from fastdeploy.config import FDConfig
|
||||||
from fastdeploy.distributed.communication import capture_custom_allreduce
|
from fastdeploy.distributed.communication import capture_custom_allreduce
|
||||||
@@ -52,11 +51,16 @@ class ConcreteSizeEntry:
|
|||||||
|
|
||||||
class Dy2StCudaGraphManager:
|
class Dy2StCudaGraphManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
# NOTE(gongshaotian): Use local import to avoid RLHF version problems
|
||||||
|
from paddle.jit.dy2static.utils import CUDAGraphState
|
||||||
|
|
||||||
self.state = CUDAGraphState.DISABLE
|
self.state = CUDAGraphState.DISABLE
|
||||||
self.captured_batch_size = set()
|
self.captured_batch_size = set()
|
||||||
self.batch_size = -1
|
self.batch_size = -1
|
||||||
|
|
||||||
def run_impl(self, original_run_impl, inputs, parameters, attrs):
|
def run_impl(self, original_run_impl, inputs, parameters, attrs):
|
||||||
|
from paddle.jit.dy2static.utils import CUDAGraphState
|
||||||
|
|
||||||
run_state = self.state
|
run_state = self.state
|
||||||
prog_attrs, cuda_graph_attrs = attrs
|
prog_attrs, cuda_graph_attrs = attrs
|
||||||
if run_state == CUDAGraphState.REPLAY:
|
if run_state == CUDAGraphState.REPLAY:
|
||||||
@@ -100,6 +104,8 @@ class CudaGraphPiecewiseBackend:
|
|||||||
self.cuda_graph_manager = Dy2StCudaGraphManager()
|
self.cuda_graph_manager = Dy2StCudaGraphManager()
|
||||||
|
|
||||||
def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
|
def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
|
||||||
|
from paddle.jit.dy2static.utils import CUDAGraphState
|
||||||
|
|
||||||
if not entry.captured:
|
if not entry.captured:
|
||||||
# Warmup the model
|
# Warmup the model
|
||||||
for n in range(entry.num_finished_warmup, self.warm_up_size):
|
for n in range(entry.num_finished_warmup, self.warm_up_size):
|
||||||
|
Reference in New Issue
Block a user