mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
@@ -20,7 +20,7 @@ from typing import Callable, Dict, Optional
|
||||
import paddle.device.cuda.graphs as graphs
|
||||
import paddle.nn.layer
|
||||
|
||||
from fastdeploy.config import LLMConfig
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("cudagrpah_piecewise_backend",
|
||||
@@ -33,7 +33,7 @@ class ConcreteSizeEntry:
|
||||
# Concrete batch size
|
||||
runtime_bs: int
|
||||
# The size is in cudagraph_capture_sizes
|
||||
use_cuda_graph: bool = True
|
||||
use_cudagraph: bool = True
|
||||
# Has runtime-bs been captured before
|
||||
captured: bool = False
|
||||
|
||||
@@ -56,45 +56,56 @@ class CudaGraphPiecewiseBackend:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_config: LLMConfig,
|
||||
fd_config: FDConfig,
|
||||
runnable: Callable,
|
||||
):
|
||||
self.llm_config = llm_config
|
||||
self.fd_config = fd_config
|
||||
self.runnable = runnable
|
||||
self.cuda_graph_capture_size = llm_config.graph_opt_config.cudagraph_capture_sizes
|
||||
self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
|
||||
self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
|
||||
self.batch_size_to_captured_size = fd_config.graph_opt_config.batch_size_to_captured_size
|
||||
|
||||
# runtime_bs -> ConcreteSizeEntry
|
||||
self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
|
||||
|
||||
for shape in self.cuda_graph_capture_size:
|
||||
for shape in self.cudagraph_capture_sizes:
|
||||
self.concrete_size_entries[shape] = ConcreteSizeEntry(
|
||||
runtime_bs=shape)
|
||||
|
||||
print("create all batch size entry")
|
||||
print("[CUDA GRAPH] Created all batch size entry ")
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
# Get batch size
|
||||
input_ids: paddle.Tensor = kwargs['input_ids']
|
||||
batch_size = input_ids.shape[0]
|
||||
entry = self.concrete_size_entries.get(batch_size)
|
||||
ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
|
||||
batch_size = ids_remove_padding.shape[0]
|
||||
|
||||
padding_batch_size = self.batch_size_to_captured_size[batch_size]
|
||||
# print(
|
||||
# f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, ",
|
||||
# f"The padded batch size is :{padding_batch_size}"
|
||||
# )
|
||||
|
||||
entry = self.concrete_size_entries.get(padding_batch_size)
|
||||
assert entry is not None, f"Batch size:{padding_batch_size} is not in cuda graph capture list."
|
||||
if entry.runnable is None:
|
||||
entry.runnable = self.runnable
|
||||
print(
|
||||
f"[CUDA GRAPH] new entry lazy initialize with batch size {batch_size}"
|
||||
)
|
||||
# print(
|
||||
# f"[CUDA GRAPH] New entry lazy initialize with batch size {padding_batch_size}"
|
||||
# )
|
||||
|
||||
if not entry.use_cuda_graph:
|
||||
if not entry.use_cudagraph:
|
||||
return entry.runnable(**kwargs)
|
||||
|
||||
# Capture a new cuda graph
|
||||
if entry.cuda_graph is None:
|
||||
# Warmup the model
|
||||
for n in range(entry.num_finished_warmup):
|
||||
for n in range(entry.num_finished_warmup, self.warm_up_size):
|
||||
entry.num_finished_warmup += 1
|
||||
entry.runnable(**kwargs)
|
||||
print(
|
||||
f"[CUDA GRAPH] warm up for batch size "
|
||||
f"{batch_size}, finished ({n+1}/{entry.num_finished_warmup}) times"
|
||||
)
|
||||
# print(
|
||||
# "[CUDA GRAPH] Warm up for batch size ",
|
||||
# f"{padding_batch_size}, finished ({n+1}/{entry.num_finished_warmup}) times"
|
||||
# )
|
||||
|
||||
# Store input addresses for debug
|
||||
input_addresses = [
|
||||
@@ -118,11 +129,11 @@ class CudaGraphPiecewiseBackend:
|
||||
output._clear
|
||||
|
||||
paddle.device.synchronize()
|
||||
print(
|
||||
f"[CUDA GRAPH] cuda graph captured for batch size {batch_size}"
|
||||
)
|
||||
# print(
|
||||
# f"[CUDA GRAPH] CUDAGraph captured for batch size {padding_batch_size}"
|
||||
# )
|
||||
|
||||
# Replay
|
||||
entry.cuda_graph.replay()
|
||||
print(f"[CUDA GRAPH] cuda graph replayed for batch size {batch_size}")
|
||||
# print(f"[CUDA GRAPH] CUDAGraph replayed for batch size {padding_batch_size}")
|
||||
return entry.output_buffer
|
||||
|
@@ -19,14 +19,14 @@ from typing import Callable, Optional, TypeVar
|
||||
|
||||
import paddle.nn.layer
|
||||
|
||||
from fastdeploy.config import LLMConfig
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.graph_optimization.graph_optimization_backend import \
|
||||
GraphOptBackend
|
||||
|
||||
_T = TypeVar("_T", bound=type[paddle.nn.Layer])
|
||||
|
||||
|
||||
def support_graph_opt(cls: Optional[_T] = None) -> _T:
|
||||
def support_graph_optimization(cls: Optional[_T] = None) -> _T:
|
||||
"""
|
||||
A decorator for wrapping models or layers with CUDA graph support.
|
||||
This enables efficient kernel launch sequencing for improved GPU performance.
|
||||
@@ -34,7 +34,7 @@ def support_graph_opt(cls: Optional[_T] = None) -> _T:
|
||||
Example usage:
|
||||
|
||||
'''
|
||||
@support_graph_opt
|
||||
@support_graph_optimization
|
||||
class ErnieBot(paddle.nn.Layer):
|
||||
def __init__(**kwargs):
|
||||
...
|
||||
@@ -49,15 +49,13 @@ def support_graph_opt(cls: Optional[_T] = None) -> _T:
|
||||
cls.__bases__ = cls.__bases__ + (GraphOptWrapper, )
|
||||
origin_init = cls.__init__
|
||||
|
||||
def __init__(self, llm_config: LLMConfig, **kwargs):
|
||||
def __init__(self, fd_config: FDConfig, **kwargs):
|
||||
""" Decorator model.__init__() func """
|
||||
origin_init(self, llm_config=llm_config, **kwargs)
|
||||
self.use_graph_opt = (
|
||||
not (llm_config.graph_opt_config.graph_opt_level == 0
|
||||
and not llm_config.graph_opt_config.use_cudagraph))
|
||||
origin_init(self, fd_config=fd_config, **kwargs)
|
||||
self.use_graph_opt = fd_config.graph_opt_config.graph_opt_level > 0 or fd_config.graph_opt_config.use_cudagraph
|
||||
if self.use_graph_opt:
|
||||
GraphOptWrapper.__init__(self,
|
||||
llm_config=llm_config,
|
||||
fd_config=fd_config,
|
||||
graph_opt_backend=None)
|
||||
else:
|
||||
# Not use graph optimization
|
||||
@@ -81,10 +79,10 @@ class GraphOptWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
graph_opt_backend: Optional[Callable] = None,
|
||||
llm_config: LLMConfig = None,
|
||||
fd_config: FDConfig = None,
|
||||
):
|
||||
if graph_opt_backend is None:
|
||||
graph_opt_backend = GraphOptBackend(self.forward, llm_config)
|
||||
graph_opt_backend = GraphOptBackend(self.forward, fd_config)
|
||||
self.graph_opt_backend = graph_opt_backend
|
||||
|
||||
@abstractmethod
|
||||
|
@@ -16,7 +16,9 @@
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from fastdeploy.config import LLMConfig
|
||||
from paddle.jit.dy2static.utils import Backend
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import \
|
||||
CudaGraphPiecewiseBackend
|
||||
|
||||
@@ -24,38 +26,39 @@ from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend im
|
||||
class GraphOptBackend:
|
||||
""" """
|
||||
|
||||
llm_config: LLMConfig
|
||||
fd_config: FDConfig
|
||||
cudagraph_piecewise_backend: Optional[CudaGraphPiecewiseBackend] = None
|
||||
|
||||
def __init__(self, runnable: Callable, llm_config: LLMConfig):
|
||||
def __init__(self, runnable: Callable, fd_config: FDConfig):
|
||||
self.runnable = runnable
|
||||
self.llm_config = llm_config
|
||||
self.fd_config = fd_config
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
# 1. TODO(gongshaotian): Static graph
|
||||
if self.llm_config.graph_opt_config.graph_opt_level > 0:
|
||||
self.max_captre_batch = fd_config.graph_opt_config.cudagraph_capture_sizes[
|
||||
0]
|
||||
if self.fd_config.graph_opt_config.graph_opt_level > 0:
|
||||
# 1. Prepare cuda grpah input buffers (contain output of subgraphs)
|
||||
|
||||
# 2. Convert dynamic grpah to static graph
|
||||
if self.llm_config.graph_opt_config.graph_opt_level > 1:
|
||||
# with cinn
|
||||
pass
|
||||
else:
|
||||
# not use cinn
|
||||
pass
|
||||
from paddle.jit import sot
|
||||
backend = (Backend.CINN
|
||||
if self.fd_config.graph_opt_config.graph_opt_level > 1
|
||||
else Backend.PHI)
|
||||
self.runnable = sot.symbolic_translate(self.runnable,
|
||||
training=False,
|
||||
backend=backend)
|
||||
|
||||
# 3. Split the static graph and get a list of callable obj
|
||||
def __call__(self, **kwargs):
|
||||
if not self.fd_config.graph_opt_config.use_cudagraph:
|
||||
return self.runnable(**kwargs)
|
||||
if self.cudagraph_piecewise_backend is None:
|
||||
self.cudagraph_piecewise_backend = CudaGraphPiecewiseBackend(
|
||||
fd_config=self.fd_config, runnable=self.runnable)
|
||||
|
||||
# 4. Get piecewise cuda grpah backend list
|
||||
assert kwargs["forward_meta"].ids_remove_padding is not None
|
||||
batch_size = kwargs["forward_meta"].ids_remove_padding.shape[0]
|
||||
|
||||
return self.runnable # Fake return value
|
||||
|
||||
# 2. Dynamic graph
|
||||
if ((not kwargs["forward_meta"].step_use_cudagraph)
|
||||
or (batch_size > self.max_captre_batch)):
|
||||
return self.runnable(**kwargs)
|
||||
else:
|
||||
print(self.cudagraph_piecewise_backend is None)
|
||||
if self.cudagraph_piecewise_backend is None:
|
||||
self.cudagraph_piecewise_backend = CudaGraphPiecewiseBackend(
|
||||
llm_config=self.llm_config, runnable=self.runnable)
|
||||
# TODO(gongshaotian): handling kwargs
|
||||
assert kwargs["input_ids"] is not None
|
||||
return self.cudagraph_piecewise_backend.__call__(**kwargs)
|
||||
|
Reference in New Issue
Block a user