Sync v2.0 version of code to github repo

Author: Jiang-Jia-Jun
Date:   2025-06-29 23:29:37 +00:00
Parent: d151496038
Commit: 92c2cfa2e7

597 changed files with 78776 additions and 22905 deletions

View File

@@ -20,7 +20,7 @@ from typing import Callable, Dict, Optional
 import paddle.device.cuda.graphs as graphs
 import paddle.nn.layer

-from fastdeploy.config import LLMConfig
+from fastdeploy.config import FDConfig
 from fastdeploy.utils import get_logger

 logger = get_logger("cudagrpah_piecewise_backend",
@@ -33,7 +33,7 @@ class ConcreteSizeEntry:
     # Concrete batch size
     runtime_bs: int
     # The size is in cudagraph_capture_sizes
-    use_cuda_graph: bool = True
+    use_cudagraph: bool = True
     # Has runtime-bs been captured before
     captured: bool = False
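For reference, the fields the backend touches elsewhere in this diff (runnable, num_finished_warmup, cuda_graph, output_buffer, input_addresses) suggest the full dataclass looks roughly like the sketch below; the defaults and type annotations here are assumptions, not the verbatim source.

from dataclasses import dataclass
from typing import Callable, List, Optional

@dataclass
class ConcreteSizeEntry:
    # Concrete (padded) batch size this entry serves
    runtime_bs: int
    # Whether this size runs through CUDAGraph at all
    use_cudagraph: bool = True
    # Whether a graph has already been captured for this size
    captured: bool = False
    # Eager callable used for warmup and as the non-graph fallback
    runnable: Optional[Callable] = None
    # Warmup iterations finished so far (checked against cudagraph_num_of_warmups)
    num_finished_warmup: int = 0
    # Captured graph, its shared output buffer, and recorded input addresses for debugging
    cuda_graph: Optional[object] = None
    output_buffer: Optional[object] = None
    input_addresses: Optional[List[int]] = None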
@@ -56,45 +56,56 @@ class CudaGraphPiecewiseBackend:
     def __init__(
         self,
-        llm_config: LLMConfig,
+        fd_config: FDConfig,
         runnable: Callable,
     ):
-        self.llm_config = llm_config
+        self.fd_config = fd_config
         self.runnable = runnable
-        self.cuda_graph_capture_size = llm_config.graph_opt_config.cudagraph_capture_sizes
+        self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
+        self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
+        self.batch_size_to_captured_size = fd_config.graph_opt_config.batch_size_to_captured_size

         # runtime_bs -> ConcreteSizeEntry
         self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

-        for shape in self.cuda_graph_capture_size:
+        for shape in self.cudagraph_capture_sizes:
             self.concrete_size_entries[shape] = ConcreteSizeEntry(
                 runtime_bs=shape)

-        print("create all batch size entry")
+        print("[CUDA GRAPH] Created all batch size entry ")

     def __call__(self, **kwargs):
         # Get batch size
-        input_ids: paddle.Tensor = kwargs['input_ids']
-        batch_size = input_ids.shape[0]
-        entry = self.concrete_size_entries.get(batch_size)
+        ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
+        batch_size = ids_remove_padding.shape[0]
+        padding_batch_size = self.batch_size_to_captured_size[batch_size]
+        # print(
+        #     f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, ",
+        #     f"The padded batch size is :{padding_batch_size}"
+        # )
+
+        entry = self.concrete_size_entries.get(padding_batch_size)
+        assert entry is not None, f"Batch size:{padding_batch_size} is not in cuda graph capture list."

         if entry.runnable is None:
             entry.runnable = self.runnable
-            print(
-                f"[CUDA GRAPH] new entry lazy initialize with batch size {batch_size}"
-            )
+            # print(
+            #     f"[CUDA GRAPH] New entry lazy initialize with batch size {padding_batch_size}"
+            # )

-        if not entry.use_cuda_graph:
+        if not entry.use_cudagraph:
             return entry.runnable(**kwargs)

         # Capture a new cuda graph
         if entry.cuda_graph is None:
             # Warmup the model
-            for n in range(entry.num_finished_warmup):
+            for n in range(entry.num_finished_warmup, self.warm_up_size):
                 entry.num_finished_warmup += 1
                 entry.runnable(**kwargs)
-                print(
-                    f"[CUDA GRAPH] warm up for batch size "
-                    f"{batch_size}, finished ({n+1}/{entry.num_finished_warmup}) times"
-                )
+                # print(
+                #     "[CUDA GRAPH] Warm up for batch size ",
+                #     f"{padding_batch_size}, finished ({n+1}/{entry.num_finished_warmup}) times"
+                # )

             # Store input addresses for debug
             input_addresses = [
@@ -118,11 +129,11 @@ class CudaGraphPiecewiseBackend:
             output._clear
             paddle.device.synchronize()

-            print(
-                f"[CUDA GRAPH] cuda graph captured for batch size {batch_size}"
-            )
+            # print(
+            #     f"[CUDA GRAPH] CUDAGraph captured for batch size {padding_batch_size}"
+            # )

         # Replay
         entry.cuda_graph.replay()
-        print(f"[CUDA GRAPH] cuda graph replayed for batch size {batch_size}")
+        # print(f"[CUDA GRAPH] CUDAGraph replayed for batch size {padding_batch_size}")
         return entry.output_buffer
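The call path above no longer keys graph entries on the raw batch size: it first pads the batch up to a captured size through batch_size_to_captured_size. The real table is built in graph_opt_config (not shown in this diff); a minimal sketch of how such a mapping could be constructed, assuming each runtime batch size is padded up to the nearest capture size:

def build_batch_size_to_captured_size(capture_sizes, max_batch_size):
    """Map every possible runtime batch size to the smallest captured size that covers it."""
    ordered = sorted(capture_sizes)
    mapping = {}
    for bs in range(1, max_batch_size + 1):
        mapping[bs] = next((s for s in ordered if s >= bs), ordered[-1])
    return mapping

# Example: with capture sizes [1, 2, 4, 8], a runtime batch of 3 replays the graph captured for 4.
assert build_batch_size_to_captured_size([1, 2, 4, 8], 8)[3] == 4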

View File

@@ -19,14 +19,14 @@ from typing import Callable, Optional, TypeVar
 import paddle.nn.layer

-from fastdeploy.config import LLMConfig
+from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.graph_optimization.graph_optimization_backend import \
     GraphOptBackend

 _T = TypeVar("_T", bound=type[paddle.nn.Layer])


-def support_graph_opt(cls: Optional[_T] = None) -> _T:
+def support_graph_optimization(cls: Optional[_T] = None) -> _T:
     """
     A decorator for wrapping models or layers with CUDA graph support.
     This enables efficient kernel launch sequencing for improved GPU performance.
@@ -34,7 +34,7 @@ def support_graph_opt(cls: Optional[_T] = None) -> _T:
     Example usage:
     '''
-    @support_graph_opt
+    @support_graph_optimization
     class ErnieBot(paddle.nn.Layer):
         def __init__(**kwargs):
             ...
@@ -49,15 +49,13 @@ def support_graph_opt(cls: Optional[_T] = None) -> _T:
     cls.__bases__ = cls.__bases__ + (GraphOptWrapper, )
     origin_init = cls.__init__

-    def __init__(self, llm_config: LLMConfig, **kwargs):
+    def __init__(self, fd_config: FDConfig, **kwargs):
         """ Decorator model.__init__() func """
-        origin_init(self, llm_config=llm_config, **kwargs)
-        self.use_graph_opt = (
-            not (llm_config.graph_opt_config.graph_opt_level == 0
-                 and not llm_config.graph_opt_config.use_cudagraph))
+        origin_init(self, fd_config=fd_config, **kwargs)
+        self.use_graph_opt = fd_config.graph_opt_config.graph_opt_level > 0 or fd_config.graph_opt_config.use_cudagraph

         if self.use_graph_opt:
             GraphOptWrapper.__init__(self,
-                                     llm_config=llm_config,
+                                     fd_config=fd_config,
                                      graph_opt_backend=None)
         else:
             # Not use graph optimization
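The rewritten use_graph_opt flag is a De Morgan simplification of the old condition: not (level == 0 and not use_cudagraph) is equivalent to level != 0 or use_cudagraph, which for a non-negative graph_opt_level matches the new level > 0 or use_cudagraph. A quick standalone check of that equivalence (hypothetical values, not project code):

for level in (0, 1, 2):
    for use_cudagraph in (False, True):
        old = not (level == 0 and not use_cudagraph)
        new = level > 0 or use_cudagraph
        assert old == new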
@@ -81,10 +79,10 @@ class GraphOptWrapper:
     def __init__(
         self,
         graph_opt_backend: Optional[Callable] = None,
-        llm_config: LLMConfig = None,
+        fd_config: FDConfig = None,
     ):
         if graph_opt_backend is None:
-            graph_opt_backend = GraphOptBackend(self.forward, llm_config)
+            graph_opt_backend = GraphOptBackend(self.forward, fd_config)
         self.graph_opt_backend = graph_opt_backend

     @abstractmethod
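The decorator's mechanism, visible across these hunks, is to append GraphOptWrapper to the decorated class's __bases__ and replace its __init__ so the wrapper is initialized only when graph optimization is enabled. A stripped-down sketch of that pattern with toy names (BaseLayer, Wrapper, support_opt, TinyModel are stand-ins, not FastDeploy APIs):

class BaseLayer:  # stand-in for paddle.nn.Layer
    pass

class Wrapper:  # stand-in for GraphOptWrapper
    def __init__(self, backend=None):
        self.backend = backend

def support_opt(cls):
    # Inject Wrapper as an extra base class at decoration time
    cls.__bases__ = cls.__bases__ + (Wrapper, )
    origin_init = cls.__init__

    def __init__(self, enabled: bool = True, **kwargs):
        origin_init(self, **kwargs)  # run the original constructor first
        if enabled:
            Wrapper.__init__(self, backend=None)

    cls.__init__ = __init__
    return cls

@support_opt
class TinyModel(BaseLayer):
    def __init__(self, **kwargs):
        self.name = kwargs.get("name", "tiny")

m = TinyModel(enabled=True, name="demo")
assert isinstance(m, Wrapper) and m.backend is None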

View File

@@ -16,7 +16,9 @@
 from typing import Callable, Optional

-from fastdeploy.config import LLMConfig
+from paddle.jit.dy2static.utils import Backend
+
+from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import \
     CudaGraphPiecewiseBackend
@@ -24,38 +26,39 @@ from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend im
 class GraphOptBackend:
     """ """
-    llm_config: LLMConfig
+    fd_config: FDConfig
     cudagraph_piecewise_backend: Optional[CudaGraphPiecewiseBackend] = None

-    def __init__(self, runnable: Callable, llm_config: LLMConfig):
+    def __init__(self, runnable: Callable, fd_config: FDConfig):
         self.runnable = runnable
-        self.llm_config = llm_config
+        self.fd_config = fd_config

-    def __call__(self, **kwargs):
-        # 1. TODO(gongshaotian): Static graph
-        if self.llm_config.graph_opt_config.graph_opt_level > 0:
+        self.max_captre_batch = fd_config.graph_opt_config.cudagraph_capture_sizes[
+            0]
+
+        if self.fd_config.graph_opt_config.graph_opt_level > 0:
             # 1. Prepare cuda graph input buffers (contain output of subgraphs)
             # 2. Convert dynamic graph to static graph
-            if self.llm_config.graph_opt_config.graph_opt_level > 1:
-                # with cinn
-                pass
-            else:
-                # not use cinn
-                pass
+            from paddle.jit import sot
+            backend = (Backend.CINN
+                       if self.fd_config.graph_opt_config.graph_opt_level > 1
+                       else Backend.PHI)
+            self.runnable = sot.symbolic_translate(self.runnable,
+                                                   training=False,
+                                                   backend=backend)

-            # 3. Split the static graph and get a list of callable obj
-            # 4. Get piecewise cuda grpah backend list
-
-            return self.runnable  # Fake return value
-        # 2. Dynamic graph
+    def __call__(self, **kwargs):
+        if not self.fd_config.graph_opt_config.use_cudagraph:
+            return self.runnable(**kwargs)
+
+        if self.cudagraph_piecewise_backend is None:
+            self.cudagraph_piecewise_backend = CudaGraphPiecewiseBackend(
+                fd_config=self.fd_config, runnable=self.runnable)
+
+        assert kwargs["forward_meta"].ids_remove_padding is not None
+        batch_size = kwargs["forward_meta"].ids_remove_padding.shape[0]
+
+        if ((not kwargs["forward_meta"].step_use_cudagraph)
+                or (batch_size > self.max_captre_batch)):
+            return self.runnable(**kwargs)
         else:
-            print(self.cudagraph_piecewise_backend is None)
-            if self.cudagraph_piecewise_backend is None:
-                self.cudagraph_piecewise_backend = CudaGraphPiecewiseBackend(
-                    llm_config=self.llm_config, runnable=self.runnable)
-            # TODO(gongshaotian): handling kwargs
-            assert kwargs["input_ids"] is not None
             return self.cudagraph_piecewise_backend.__call__(**kwargs)
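After this change GraphOptBackend decides per call whether to replay a graph: it falls back to the (possibly sot-translated) eager runnable when use_cudagraph is off, when the step's forward_meta opts out, or when the batch exceeds max_captre_batch (read from cudagraph_capture_sizes[0], so the capture list is presumably ordered largest-first). A small standalone sketch of that dispatch decision; should_use_cudagraph is a hypothetical helper, not part of this diff:

def should_use_cudagraph(use_cudagraph: bool, step_use_cudagraph: bool,
                         batch_size: int, max_capture_batch: int) -> bool:
    # Mirrors the checks in __call__: global config switch, per-step flag,
    # then the largest captured batch size as a ceiling.
    if not use_cudagraph:
        return False
    if not step_use_cudagraph:
        return False
    return batch_size <= max_capture_batch

# A batch larger than the biggest captured size runs eagerly; otherwise the
# piecewise CUDAGraph backend handles it.
assert should_use_cudagraph(True, True, 16, 8) is False
assert should_use_cudagraph(True, True, 8, 8) is True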