Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 08:37:06 +08:00
[LLM] First commit the llm deployment code
fastdeploy/model_executor/graph_optimization/__init__.py  (new file, 15 lines)
@@ -0,0 +1,15 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py  (new file, 128 lines)
@@ -0,0 +1,128 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from dataclasses import dataclass
from typing import Callable, Dict, Optional

import paddle
import paddle.device.cuda.graphs as graphs

from fastdeploy.config import LLMConfig
from fastdeploy.utils import get_logger

logger = get_logger("cudagraph_piecewise_backend",
                    "cudagraph_piecewise_backend.log")


@dataclass
class ConcreteSizeEntry:
    """Record the concrete information corresponding to one batch size."""
    # Concrete batch size
    runtime_bs: int
    # Whether this size is in cudagraph_capture_sizes
    use_cuda_graph: bool = True
    # Whether runtime_bs has been captured before
    captured: bool = False

    # Callable object to be captured (dynamic graph or static graph backend)
    runnable: Callable = None  # type: ignore
    # Number of completed warmups
    num_finished_warmup: int = 0
    # Captured CUDA graph object corresponding to the current batch size
    cuda_graph: Optional[graphs.CUDAGraph] = None
    # Output buffer of the CUDA graph
    output_buffer: Optional[paddle.Tensor] = None

    # For CUDA graph debugging: track the input addresses during capture
    # and check whether they are the same during replay
    input_addresses: Optional[list[int]] = None


class CudaGraphPiecewiseBackend:
    """Lazily capture one CUDA graph per batch size and replay it on later calls."""

    def __init__(
        self,
        llm_config: LLMConfig,
        runnable: Callable,
    ):
        self.llm_config = llm_config
        self.runnable = runnable
        self.cuda_graph_capture_sizes = llm_config.graph_opt_config.cudagraph_capture_sizes
        # runtime_bs -> ConcreteSizeEntry
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        for shape in self.cuda_graph_capture_sizes:
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_bs=shape)

        logger.info("[CUDA GRAPH] created entries for all capture sizes")

    def __call__(self, **kwargs):
        # Get batch size
        input_ids: paddle.Tensor = kwargs['input_ids']
        batch_size = input_ids.shape[0]
        entry = self.concrete_size_entries.get(batch_size)
        if entry is None:
            # Batch size not registered for capture: fall back to eager execution
            return self.runnable(**kwargs)
        if entry.runnable is None:
            entry.runnable = self.runnable
            logger.info(
                f"[CUDA GRAPH] new entry lazily initialized with batch size {batch_size}"
            )

        if not entry.use_cuda_graph:
            return entry.runnable(**kwargs)

        # Capture a new CUDA graph
        if entry.cuda_graph is None:
            # Warm up the model once so lazy allocations and autotuning
            # happen outside the captured region
            if entry.num_finished_warmup == 0:
                entry.num_finished_warmup += 1
                entry.runnable(**kwargs)
                logger.info(
                    f"[CUDA GRAPH] warm up for batch size "
                    f"{batch_size}, finished {entry.num_finished_warmup} time(s)")

            # Store input addresses for debugging
            entry.input_addresses = [
                x.data_ptr() for (_, x) in kwargs.items()
                if isinstance(x, paddle.Tensor)
            ]

            new_graph = graphs.CUDAGraph()
            paddle.device.synchronize()

            # Capture
            new_graph.capture_begin()
            output = entry.runnable(**kwargs)
            new_graph.capture_end()

            # Store the output buffer: share the captured output's storage so
            # replay results land in a stable tensor, then release the original
            entry.cuda_graph = new_graph
            entry.output_buffer = paddle.zeros_like(output)
            output._share_buffer_to(entry.output_buffer)
            output._clear()

            paddle.device.synchronize()
            logger.info(
                f"[CUDA GRAPH] cuda graph captured for batch size {batch_size}")

        # Replay
        entry.cuda_graph.replay()
        logger.debug(
            f"[CUDA GRAPH] cuda graph replayed for batch size {batch_size}")
        return entry.output_buffer
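
For orientation, here is a minimal usage sketch of the piecewise backend above. The SimpleNamespace config stub and toy_forward are hypothetical stand-ins (the real backend expects a fastdeploy LLMConfig and a model forward), and actually running it requires a CUDA-enabled Paddle build:

from types import SimpleNamespace

import paddle

from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import \
    CudaGraphPiecewiseBackend


def toy_forward(input_ids: paddle.Tensor, **kwargs) -> paddle.Tensor:
    # Hypothetical stand-in for a real model forward pass
    return input_ids * 2


# Stub exposing only the attribute the backend reads; a real LLMConfig
# would normally be passed here
cfg = SimpleNamespace(
    graph_opt_config=SimpleNamespace(cudagraph_capture_sizes=[1, 2, 4, 8]))

backend = CudaGraphPiecewiseBackend(llm_config=cfg, runnable=toy_forward)
ids = paddle.ones([4, 32], dtype="int64")
out = backend(input_ids=ids)  # first call warms up and captures; later calls replay

Note the contract this implies: replay reuses the tensor addresses recorded at capture time, so callers must feed the same input buffers on every call for a given batch size.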
fastdeploy/model_executor/graph_optimization/decorator.py  (new file, 96 lines)
@@ -0,0 +1,96 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from abc import abstractmethod
from typing import Callable, Optional, TypeVar

import paddle

from fastdeploy.config import LLMConfig
from fastdeploy.model_executor.graph_optimization.graph_optimization_backend import \
    GraphOptBackend

_T = TypeVar("_T", bound=type[paddle.nn.Layer])


def support_graph_opt(cls: _T) -> _T:
    """
    A decorator that wraps models or layers with CUDA graph support,
    enabling efficient kernel launch sequencing for improved GPU performance.

    Example usage:

    '''
    @support_graph_opt
    class ErnieBot(paddle.nn.Layer):
        def __init__(self, **kwargs):
            ...

        def forward(self, x: paddle.Tensor, y: paddle.Tensor):
            ...
    '''
    """
    if GraphOptWrapper in cls.__bases__:
        # Already decorated
        return cls

    cls.__bases__ = cls.__bases__ + (GraphOptWrapper, )
    origin_init = cls.__init__

    def __init__(self, llm_config: LLMConfig, **kwargs):
        """ Decorated model.__init__() func """
        origin_init(self, llm_config=llm_config, **kwargs)
        # Graph optimization is disabled only when the optimization level
        # is 0 and CUDA graph is off
        self.use_graph_opt = (llm_config.graph_opt_config.graph_opt_level > 0
                              or llm_config.graph_opt_config.use_cudagraph)
        if self.use_graph_opt:
            GraphOptWrapper.__init__(self,
                                     llm_config=llm_config,
                                     graph_opt_backend=None)

    def __call__(self, **kwargs):
        """ Decorated model.__call__() func """
        if not self.use_graph_opt:
            return self.forward(**kwargs)

        return self.graph_opt_backend(**kwargs)

    cls.__init__ = __init__
    cls.__call__ = __call__
    return cls


class GraphOptWrapper:
    """Mixin that routes a model's calls through a GraphOptBackend."""

    def __init__(
        self,
        graph_opt_backend: Optional[Callable] = None,
        llm_config: Optional[LLMConfig] = None,
    ):
        if graph_opt_backend is None:
            graph_opt_backend = GraphOptBackend(self.forward, llm_config)
        self.graph_opt_backend = graph_opt_backend

    @abstractmethod
    def forward(self, **kwargs):
        """Model forward pass; provided by the decorated class."""

    def __call__(self, **kwargs):
        return self.graph_opt_backend(**kwargs)
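
To show how the pieces fit together, a hedged sketch of decorating a model; ErnieBotToy and its layer sizes are hypothetical, and the decorated class must accept llm_config in __init__ because the patched __init__ above forwards it:

from types import SimpleNamespace

import paddle

from fastdeploy.model_executor.graph_optimization.decorator import support_graph_opt


@support_graph_opt
class ErnieBotToy(paddle.nn.Layer):
    """Hypothetical model used only to illustrate the decorator."""

    def __init__(self, llm_config, **kwargs):
        super().__init__()
        self.llm_config = llm_config
        self.embedding = paddle.nn.Embedding(1000, 64)

    def forward(self, input_ids: paddle.Tensor, **kwargs) -> paddle.Tensor:
        return self.embedding(input_ids)


# Config stub with the fields the decorator and backends inspect
cfg = SimpleNamespace(graph_opt_config=SimpleNamespace(
    graph_opt_level=0, use_cudagraph=True,
    cudagraph_capture_sizes=[1, 2, 4, 8]))
model = ErnieBotToy(llm_config=cfg)
# out = model(input_ids=ids)  # dispatched through GraphOptBackend (needs CUDA)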
fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py  (new file, 61 lines)
@@ -0,0 +1,61 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Callable, Optional

from fastdeploy.config import LLMConfig
from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import \
    CudaGraphPiecewiseBackend


class GraphOptBackend:
    """Dispatch execution to the static graph or dynamic graph path based on config."""

    llm_config: LLMConfig
    cudagraph_piecewise_backend: Optional[CudaGraphPiecewiseBackend] = None

    def __init__(self, runnable: Callable, llm_config: LLMConfig):
        self.runnable = runnable
        self.llm_config = llm_config

    def __call__(self, **kwargs):
        # 1. TODO(gongshaotian): Static graph
        if self.llm_config.graph_opt_config.graph_opt_level > 0:
            # 1. Prepare cuda graph input buffers (contain outputs of subgraphs)

            # 2. Convert the dynamic graph to a static graph
            if self.llm_config.graph_opt_config.graph_opt_level > 1:
                # with CINN
                pass
            else:
                # without CINN
                pass

            # 3. Split the static graph and get a list of callable objects

            # 4. Get the piecewise cuda graph backend list

            return self.runnable  # Fake return value

        # 2. Dynamic graph
        else:
            if self.cudagraph_piecewise_backend is None:
                self.cudagraph_piecewise_backend = CudaGraphPiecewiseBackend(
                    llm_config=self.llm_config, runnable=self.runnable)
            # TODO(gongshaotian): handle arbitrary kwargs
            assert kwargs["input_ids"] is not None
            return self.cudagraph_piecewise_backend(**kwargs)
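
The dispatch above implies three modes driven by two config fields. A hypothetical sketch of what llm_config.graph_opt_config might look like; the field names are taken from the attributes the code reads, while the defaults are illustrative:

from dataclasses import dataclass, field
from typing import List


@dataclass
class GraphOptConfig:
    # 0: dynamic graph; >0: static graph; >1: static graph compiled with CINN
    graph_opt_level: int = 0
    # Enable piecewise CUDA graph capture/replay on the dynamic graph path
    use_cudagraph: bool = True
    # Batch sizes for which CudaGraphPiecewiseBackend pre-registers entries
    cudagraph_capture_sizes: List[int] = field(
        default_factory=lambda: [1, 2, 4, 8])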