Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 08:37:06 +08:00
polish code with new pre-commit rule (#2923)
fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
@@ -17,19 +17,19 @@
 from dataclasses import dataclass
 from typing import Callable, Dict, Optional
 
-import paddle.device.cuda.graphs as graphs
 import paddle.nn.layer
+from paddle.device.cuda import graphs
 
 from fastdeploy.config import FDConfig
 from fastdeploy.utils import get_logger
 
-logger = get_logger("cudagrpah_piecewise_backend",
-                    "cudagraph_piecewise_backend.log")
+logger = get_logger("cudagrpah_piecewise_backend", "cudagraph_piecewise_backend.log")
 
 
 @dataclass
 class ConcreteSizeEntry:
-    """ Record the concrete information corresponding to the current batch size """
+    """Record the concrete information corresponding to the current batch size"""
+
     # Concrete batch size
     runtime_bs: int
     # The size is in cudagraph_capture_sizes
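The hunk stops before the rest of the dataclass. Reading ahead in this same diff, the backend later touches entry.runnable, entry.use_cudagraph, entry.input_addresses, entry.cuda_graph, and entry.output_buffer, so the full entry plausibly has roughly the shape below. This is a sketch inferred from those usages, not a copy of the file; the default values and field comments are assumptions.

    from dataclasses import dataclass
    from typing import Callable, List, Optional

    import paddle

    @dataclass
    class ConcreteSizeEntry:
        """Record the concrete information corresponding to the current batch size"""

        # Concrete batch size (one of cudagraph_capture_sizes)
        runtime_bs: int
        # Whether this size should go through CUDA graph capture/replay
        use_cudagraph: bool = True
        # Lazily filled with the backend's runnable on first call
        runnable: Optional[Callable] = None
        # data_ptr() of each tensor input, stored for debugging
        input_addresses: Optional[List[int]] = None
        # The captured graph and the buffer its output is shared into
        cuda_graph: Optional[object] = None
        output_buffer: Optional[paddle.Tensor] = None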
@@ -48,7 +48,7 @@ class ConcreteSizeEntry:
 
 
 class CudaGraphPiecewiseBackend:
-    """ Manage the capture and replay of CUDA graphs at the subgraph level. """
+    """Manage the capture and replay of CUDA graphs at the subgraph level."""
 
     def __init__(
         self,
@@ -65,12 +65,10 @@ class CudaGraphPiecewiseBackend:
         self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
 
         for shape in self.cudagraph_capture_sizes:
-            self.concrete_size_entries[shape] = ConcreteSizeEntry(
-                runtime_bs=shape)
+            self.concrete_size_entries[shape] = ConcreteSizeEntry(runtime_bs=shape)
 
         logger.info(
-            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, "
-            "Created all batch sizes entry."
+            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all batch sizes entry."
         )
 
     def __call__(self, **kwargs):
@@ -87,9 +85,7 @@ class CudaGraphPiecewiseBackend:
         assert entry is not None, f"Batch size:{padding_batch_size} is not in cuda graph capture list."
         if entry.runnable is None:
             entry.runnable = self.runnable
-            logger.debug(
-                f"[CUDA GRAPH] New entry lazy initialize with batch size {padding_batch_size}"
-            )
+            logger.debug(f"[CUDA GRAPH] New entry lazy initialize with batch size {padding_batch_size}")
 
         if not entry.use_cudagraph:
             return entry.runnable(**kwargs)
@@ -106,10 +102,7 @@ class CudaGraphPiecewiseBackend:
             )
 
             # Store input addresses for debug
-            input_addresses = [
-                x.data_ptr() for (_, x) in kwargs.items()
-                if isinstance(x, paddle.Tensor)
-            ]
+            input_addresses = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)]
             entry.input_addresses = input_addresses
 
             new_grpah = graphs.CUDAGraph()
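Why store data_ptr() values at capture time? A CUDA graph replays kernels against the exact device addresses it was captured with, so if a caller later passes tensors that live elsewhere, replay silently reads and writes stale memory. The stored addresses make that debuggable; the diff only records them, but a check would look roughly like this (assert_same_addresses is a hypothetical helper, not part of the file):

    import paddle

    def assert_same_addresses(entry, **kwargs):
        # Hypothetical debug helper: inputs passed at replay time must sit at
        # the addresses recorded when the graph was captured.
        current = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)]
        assert current == entry.input_addresses, (
            "[CUDA GRAPH] input addresses changed since capture: "
            f"{entry.input_addresses} -> {current}"
        )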
@@ -127,13 +120,9 @@ class CudaGraphPiecewiseBackend:
                 output._clear
 
             paddle.device.synchronize()
-            logger.debug(
-                f"[CUDA GRAPH] CUDAGraph captured for batch size {padding_batch_size}"
-            )
+            logger.debug(f"[CUDA GRAPH] CUDAGraph captured for batch size {padding_batch_size}")
 
         # Replay
         entry.cuda_graph.replay()
-        logger.debug(
-            f"[CUDA GRAPH] CUDAGraph replayed for batch size {padding_batch_size}"
-        )
+        logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for batch size {padding_batch_size}")
         return entry.output_buffer
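For orientation, the capture/replay cycle this backend builds on (paddle.device.cuda.graphs.CUDAGraph with capture_begin/capture_end/replay) works roughly as below. A minimal sketch that needs a CUDA build of Paddle; the backend above additionally shares the subgraph output into entry.output_buffer so callers keep reading the same tensor across replays.

    import paddle
    from paddle.device.cuda import graphs

    # Inputs must live in fixed buffers: replay reuses the captured
    # addresses, so we update them in place instead of reallocating.
    x = paddle.ones([8, 16])

    g = graphs.CUDAGraph()
    g.capture_begin()
    y = x * 2.0 + 1.0  # recorded into the graph rather than run eagerly
    g.capture_end()

    paddle.assign(paddle.full([8, 16], 3.0), x)  # write new inputs in place
    g.replay()  # re-launch the captured kernels; y now reflects the new x
    paddle.device.synchronize()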
fastdeploy/model_executor/graph_optimization/decorator.py
@@ -20,8 +20,9 @@ from typing import Callable, Optional, TypeVar
 import paddle.nn.layer
 
 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.graph_optimization.graph_optimization_backend import \
-    GraphOptBackend
+from fastdeploy.model_executor.graph_optimization.graph_optimization_backend import (
+    GraphOptBackend,
+)
 
 _T = TypeVar("_T", bound=type[paddle.nn.Layer])
 
@@ -46,23 +47,21 @@ def support_graph_optimization(cls: Optional[_T] = None) -> _T:
     if GraphOptWrapper in cls.__bases__:
         return cls
     else:
-        cls.__bases__ = cls.__bases__ + (GraphOptWrapper, )
+        cls.__bases__ = cls.__bases__ + (GraphOptWrapper,)
         origin_init = cls.__init__
 
     def __init__(self, fd_config: FDConfig, **kwargs):
-        """ Decorator model.__init__() func """
+        """Decorator model.__init__() func"""
        origin_init(self, fd_config=fd_config, **kwargs)
         self.use_graph_opt = fd_config.graph_opt_config.graph_opt_level > 0 or fd_config.graph_opt_config.use_cudagraph
         if self.use_graph_opt:
-            GraphOptWrapper.__init__(self,
-                                     fd_config=fd_config,
-                                     graph_opt_backend=None)
+            GraphOptWrapper.__init__(self, fd_config=fd_config, graph_opt_backend=None)
         else:
             # Not use graph optimization
             return
 
     def __call__(self, **kwargs):
-        """ Decorator model.__call__() func """
+        """Decorator model.__call__() func"""
         if not self.use_graph_opt:
             return self.forward(**kwargs)
 
@@ -74,7 +73,7 @@ def support_graph_optimization(cls: Optional[_T] = None) -> _T:
 
 
 class GraphOptWrapper:
-    """ The wrapper for GraphOptBackend """
+    """The wrapper for GraphOptBackend"""
 
     def __init__(
         self,
@@ -87,7 +86,7 @@ class GraphOptWrapper:
 
     @abstractmethod
     def forward(self, **kwargs):
-        """ Abstract methods for implementing model.forward() """
+        """Abstract methods for implementing model.forward()"""
         pass
 
     def __call__(self, **kwargs):
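No call site appears in this diff, but putting the pieces together, a model opts in roughly as follows. MyModel and its layer are illustrative stand-ins, and the import path assumes the decorator lives in the decorator module named above. Once decorated, model(**kwargs) enters GraphOptWrapper.__call__, which routes either to plain forward() or through a GraphOptBackend, depending on the config.

    from paddle import nn

    from fastdeploy.config import FDConfig
    from fastdeploy.model_executor.graph_optimization.decorator import (
        support_graph_optimization,
    )

    @support_graph_optimization
    class MyModel(nn.Layer):  # hypothetical model, for illustration only
        def __init__(self, fd_config: FDConfig, **kwargs):
            super().__init__()
            self.linear = nn.Linear(512, 512)

        def forward(self, **kwargs):
            # Reached via GraphOptWrapper.__call__ when graph opt is enabled,
            # or called directly when it is not.
            return self.linear(kwargs["hidden_states"])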
fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py
@@ -19,8 +19,9 @@ from typing import Callable, Optional
 from paddle.jit.dy2static.utils import Backend
 
 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import \
-    CudaGraphPiecewiseBackend
+from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import (
+    CudaGraphPiecewiseBackend,
+)
 
 
 class GraphOptBackend:
@@ -36,32 +37,28 @@ class GraphOptBackend:
         self.runnable = runnable
         self.fd_config = fd_config
 
-        self.max_captre_batch = fd_config.graph_opt_config.cudagraph_capture_sizes[
-            0]
+        self.max_captre_batch = fd_config.graph_opt_config.cudagraph_capture_sizes[0]
 
         if self.fd_config.graph_opt_config.graph_opt_level > 0:
             # 1. Prepare cuda grpah input buffers (contain output of subgraphs)
 
             # 2. Convert dynamic grpah to static graph
             from paddle.jit import sot
-            backend = (Backend.CINN
-                       if self.fd_config.graph_opt_config.graph_opt_level > 1
-                       else Backend.PHI)
-            self.runnable = sot.symbolic_translate(self.runnable,
-                                                   training=False,
-                                                   backend=backend)
+
+            backend = Backend.CINN if self.fd_config.graph_opt_config.graph_opt_level > 1 else Backend.PHI
+            self.runnable = sot.symbolic_translate(self.runnable, training=False, backend=backend)
 
     def __call__(self, **kwargs):
         if not self.fd_config.graph_opt_config.use_cudagraph:
             return self.runnable(**kwargs)
         if self.cudagraph_piecewise_backend is None:
             self.cudagraph_piecewise_backend = CudaGraphPiecewiseBackend(
-                fd_config=self.fd_config, runnable=self.runnable)
+                fd_config=self.fd_config, runnable=self.runnable
+            )
 
         assert kwargs["forward_meta"].ids_remove_padding is not None
         batch_size = kwargs["forward_meta"].ids_remove_padding.shape[0]
 
-        if ((not kwargs["forward_meta"].step_use_cudagraph)
-                or (batch_size > self.max_captre_batch)):
+        if (not kwargs["forward_meta"].step_use_cudagraph) or (batch_size > self.max_captre_batch):
             return self.runnable(**kwargs)
         else:
             return self.cudagraph_piecewise_backend.__call__(**kwargs)
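Condensed, the routing in GraphOptBackend.__call__ is: run the plain runnable when CUDA graphs are globally disabled, when the current step opts out via forward_meta.step_use_cudagraph, or when the batch exceeds the largest captured size; otherwise hand off to the piecewise backend. A restatement of that logic (lazy backend construction elided):

    def dispatch(backend, **kwargs):
        # Illustrative restatement of GraphOptBackend.__call__, not the file itself.
        meta = kwargs["forward_meta"]
        batch_size = meta.ids_remove_padding.shape[0]

        if not backend.fd_config.graph_opt_config.use_cudagraph:
            return backend.runnable(**kwargs)  # CUDA graphs disabled globally
        if not meta.step_use_cudagraph:
            return backend.runnable(**kwargs)  # this step opted out
        if batch_size > backend.max_captre_batch:
            return backend.runnable(**kwargs)  # larger than any captured size
        return backend.cudagraph_piecewise_backend(**kwargs)

Note that max_captre_batch is read from cudagraph_capture_sizes[0], which suggests the capture list is kept sorted largest-first; the assert on ids_remove_padding guards the batch-size read.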