Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-09-26 20:41:53 +08:00)

[CUDAGraph] Support multi output buffers and merge some fixes from feature/exp_0908 (#4062)

* refine cudagraph
* refine cudagraph
* typo
* fix
* fix plugins
* fix
* update
* update
* update
@@ -71,15 +71,9 @@ class InputPreprocessor:
         """
         reasoning_parser_obj = None
         tool_parser_obj = None
-        try:
-            from fastdeploy.plugins.reasoning_parser import (
-                load_reasoning_parser_plugins,
-            )
-
-            reasoning_parser_obj = load_reasoning_parser_plugins()
-        except:
-            if self.reasoning_parser:
-                reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
+        if self.reasoning_parser:
+            reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
         if self.tool_parser:
             tool_parser_obj = ToolParserManager.get_tool_parser(self.tool_parser)

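With the plugin fallback removed, the configured name is always resolved through `ReasoningParserManager`. For context, a minimal sketch of that register/lookup pattern, assuming the decorator API used later in this diff (`register_module`); the `noop` parser itself is hypothetical:

```python
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager


# Hypothetical parser, registered the same way as the built-in ones
# (compare the @ReasoningParserManager.register_module("ernie_x1") hunk below).
@ReasoningParserManager.register_module("noop")
class NoopReasoningParser(ReasoningParser):
    def extract_reasoning_content(self, model_output, request):
        # No reasoning channel: treat the whole output as final content.
        return None, model_output


# What InputPreprocessor now does with self.reasoning_parser:
parser_cls = ReasoningParserManager.get_reasoning_parser("noop")
```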
@@ -14,14 +14,16 @@
 # limitations under the License.
 """

+import os
 from contextlib import contextmanager
-from dataclasses import dataclass
-from typing import Callable, Dict, Optional
+from dataclasses import dataclass, field
+from typing import Callable, Dict, List, Optional

 import paddle.jit.dy2static.utils as jit_utils
 import paddle.nn.layer
 from paddle.device.cuda import graphs

+from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import capture_custom_allreduce
 from fastdeploy.utils import get_logger

@@ -46,8 +48,8 @@ class ConcreteSizeEntry:
     num_finished_warmup: int = 0
     # Captured cuda graph object corresponding to the current real shape
     cuda_graph: Optional[graphs.CUDAGraph] = None
-    # Output buffer of cudagraph
-    output_buffer: Optional[paddle.Tensor] = None
+    # Output buffers of cudagraph
+    output_buffers: List[Optional[paddle.Tensor]] = field(default_factory=list)


 class Dy2StCudaGraphManager:

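The new `output_buffers` field uses `field(default_factory=list)` because dataclasses reject shared mutable defaults; each `ConcreteSizeEntry` needs its own list. A quick self-contained illustration:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Entry:
    # A bare `buffers: List[...] = []` raises ValueError (mutable default);
    # default_factory gives every instance an independent list.
    buffers: List[Optional[int]] = field(default_factory=list)


a, b = Entry(), Entry()
a.buffers.append(1)
assert b.buffers == []  # no cross-instance sharing
```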
@@ -130,9 +132,9 @@ class CudaGraphPiecewiseBackend:
             with self.cuda_graph_manager.run_impl_guard():
                 return entry.runnable(**kwargs)

-    def __call__(self, **kwargs):
+    def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:
         # Get real shape(all num tokens)
-        ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
+        ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
         real_shape = ids_remove_padding.shape[0]
         padding_real_shape = self.real_shape_to_captured_size[real_shape]
         logger.debug(

@@ -173,14 +175,22 @@ class CudaGraphPiecewiseBackend:
             # Capture
             with capture_custom_allreduce():
                 new_grpah.capture_begin()
-                output = entry.runnable(**kwargs)
+                outputs = entry.runnable(**kwargs)
+                if isinstance(outputs, paddle.Tensor):
+                    assert outputs is not None
+                    outputs = [outputs]
                 new_grpah.capture_end()

             # Store output buffer
             entry.cuda_graph = new_grpah
-            entry.output_buffer = paddle.zeros_like(output)
-            output._share_buffer_to(entry.output_buffer)
-            output._clear
+            for output in outputs:
+                if output is not None:
+                    output_buffer = paddle.zeros_like(output)
+                    output._share_buffer_to(output_buffer)
+                    output._clear
+                    entry.output_buffers.append(output_buffer)
+                else:
+                    entry.output_buffers.append(None)

             paddle.device.synchronize()

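The loop above gives every captured output a persistent buffer: a replayed CUDA graph keeps writing to the same device addresses, so `_share_buffer_to` makes each `entry.output_buffers[i]` alias the storage the graph writes into, and every `replay()` refreshes those tensors in place. A minimal sketch of that aliasing pattern, assuming a CUDA build of Paddle (the backend additionally calls `output._clear`, omitted here):

```python
import paddle
from paddle.device.cuda import graphs

paddle.set_device("gpu")
x = paddle.ones([8], dtype="float32")

# Record the work into a CUDA graph instead of executing it eagerly.
graph = graphs.CUDAGraph()
graph.capture_begin()
y = x * 2.0
graph.capture_end()

# Persist the output: output_buffer aliases y's storage, so each replay()
# updates it in place.
output_buffer = paddle.zeros_like(y)
y._share_buffer_to(output_buffer)

graph.replay()
paddle.device.synchronize()
print(output_buffer)  # reflects the most recent replay
```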
@@ -191,7 +201,9 @@ class CudaGraphPiecewiseBackend:
         # Replay
         entry.cuda_graph.replay()
         logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
-        return entry.output_buffer
+        if len(entry.output_buffers) == 1:
+            return entry.output_buffers[0]
+        return entry.output_buffers

     def _create_entry_dict(self):
         """ """

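Since `__call__` now returns either a bare tensor (single output) or a list, a call site that wants uniform handling can normalize the union; a small hedged helper, not part of the diff:

```python
from typing import List, Optional, Union

import paddle


def as_output_list(
    result: Union[paddle.Tensor, List[Optional[paddle.Tensor]]],
) -> List[Optional[paddle.Tensor]]:
    # Mirror the backend's convention: a lone tensor means one output.
    return result if isinstance(result, list) else [result]


single = paddle.zeros([2])
assert as_output_list(single)[0] is single  # single-output fast path
outs = as_output_list([single, None])
assert len(outs) == 2 and outs[1] is None
```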
@@ -221,8 +233,11 @@ class CudaGraphPiecewiseBackend:

     def _save_cudagrpah_dot_files(self, entry):
         """Print CUDAGrpah to dot files"""
+        log_dir = envs.FD_LOG_DIR
+        if not os.path.exists(log_dir):
+            os.makedirs(log_dir, exist_ok=True)
         if entry.cuda_graph:
             entry.cuda_graph.print_to_dot_files(
-                f"./log/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
+                f"{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
                 1 << 0,
             )

@@ -23,5 +23,5 @@ PLUGINS_GROUP = "fastdeploy.input_processor_plugins"
 def load_input_processor_plugins():
     """load_input_processor_plugins"""
     plugins = load_plugins_by_group(group=PLUGINS_GROUP)
-    assert len(plugins) <= 1, "Most one plugin is allowed to be loaded."
+    assert len(plugins) == 1, "Only one plugin is allowed to be loaded."
     return next(iter(plugins.values()))()

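`load_plugins_by_group` discovers plugins via Python entry points; the tightened assertion makes a missing plugin fail loudly instead of silently passing with zero plugins. A sketch of what such a group loader plausibly does, using the stdlib `importlib.metadata` API rather than FastDeploy's actual helper:

```python
from importlib.metadata import entry_points


def load_plugins_by_group_sketch(group: str) -> dict:
    # Collect every entry point registered under `group` and import its target.
    # (Python 3.10+ selection API; FastDeploy's real helper may differ.)
    return {ep.name: ep.load() for ep in entry_points(group=group)}


plugins = load_plugins_by_group_sketch("fastdeploy.input_processor_plugins")
print(f"discovered {len(plugins)} input-processor plugin(s)")
```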
@@ -14,7 +14,7 @@
 # limitations under the License.
 """

-from fastdeploy.plugins.utils import load_plugins_by_group, plugins_loaded
+from fastdeploy.plugins.utils import load_plugins_by_group

 # use for modle runner
 PLUGINS_GROUP = "fastdeploy.model_runner_plugins"

@@ -22,11 +22,6 @@ PLUGINS_GROUP = "fastdeploy.model_runner_plugins"

 def load_model_runner_plugins():
     """load_model_runner_plugins"""
-    global plugins_loaded
-    if plugins_loaded:
-        return
-    plugins_loaded = True
-
     plugins = load_plugins_by_group(group=PLUGINS_GROUP)
-    assert len(plugins) <= 1, "Most one plugin is allowed to be loaded."
+    assert len(plugins) == 1, "Only one plugin is allowed to be loaded."
     return next(iter(plugins.values()))()

@@ -14,7 +14,7 @@
 # limitations under the License.
 """

-from fastdeploy.plugins.utils import load_plugins_by_group
+from fastdeploy.plugins.utils import load_plugins_by_group, plugins_loaded

 # make sure one process only loads plugins once
 PLUGINS_GROUP = "fastdeploy.reasoning_parser_plugins"

@@ -22,6 +22,12 @@ PLUGINS_GROUP = "fastdeploy.reasoning_parser_plugins"

 def load_reasoning_parser_plugins():
     """load_reasoning_parser_plugins"""
+    global plugins_loaded
+    if plugins_loaded:
+        return
+    plugins_loaded = True
+
     plugins = load_plugins_by_group(group=PLUGINS_GROUP)
-    assert len(plugins) <= 1, "Most one plugin is allowed to be loaded."
-    return next(iter(plugins.values()))()
+    # general plugins, we only need to execute the loaded functions
+    for func in plugins.values():
+        func()

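Note the changed contract: the input-processor and model-runner groups must expose exactly one plugin whose return value is used, while this loader now simply calls every registered function for its side effects. A hedged sketch of what such a reasoning-parser plugin callable might look like (module and parser names hypothetical):

```python
def register_my_parsers() -> None:
    # Invoked once by load_reasoning_parser_plugins(); its only job is the
    # side effect of populating the parser registry.
    from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager

    @ReasoningParserManager.register_module("my_parser")
    class MyReasoningParser(ReasoningParser):  # hypothetical
        def extract_reasoning_content(self, model_output, request):
            return None, model_output
```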
@@ -14,6 +14,8 @@
 # limitations under the License.
 """

+from fastdeploy.plugins import load_reasoning_parser_plugins
+
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
 from .ernie_vl_reasoning_parsers import ErnieVLReasoningParser
 from .ernie_x1_reasoning_parsers import ErnieX1ReasoningParser

@@ -26,3 +28,5 @@ __all__ = [
     "Qwen3ReasoningParser",
     "ErnieX1ReasoningParser",
 ]
+
+load_reasoning_parser_plugins()

@@ -1,14 +1,6 @@
 """
 # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 #
-#
-from collections.abc import Sequence
-from typing import Tuple, Union
-
-from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
-from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
-
-#
 # Licensed under the Apache License, Version 2.0 (the "License"
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at

@@ -20,6 +12,13 @@ from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """

+from collections.abc import Sequence
+from typing import Tuple, Union
+
+from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
+from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
+
+
 @ReasoningParserManager.register_module("ernie_x1")

@@ -248,6 +248,11 @@ class PaddleDisWorkerProc:
             create=False,
         )

+    def _broadcast_model_weights_signal(self, src: int, group) -> int:
+        model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
+        paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
+        return model_weights_signal_tensor.item()
+
     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distributed Workers.
         TODO(gongshaotian): support remote calling of functions that control worker.

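The helper stages the host-side signal through a device tensor because `paddle.distributed.broadcast` operates on tensors, then `.item()` copies the agreed value back to the host. A standalone sketch of the same round-trip, runnable with `paddle.distributed.launch` across two or more processes (script name hypothetical):

```python
# broadcast_signal_demo.py -- run with:
#   python -m paddle.distributed.launch --nproc_per_node 2 broadcast_signal_demo.py
import numpy as np
import paddle
import paddle.distributed as dist

dist.init_parallel_env()

# Host-side flag, as in the worker's event loop below.
signal = np.zeros([1], dtype=np.int32)
if dist.get_rank() == 0:
    signal[0] = 1  # e.g. "update weights"

# Stage through a tensor, broadcast from rank 0, copy back to the host.
tensor = paddle.full(shape=[1], fill_value=int(signal[0]), dtype="int32")
dist.broadcast(tensor, src=0)
signal[0] = tensor.item()
print(f"rank {dist.get_rank()} sees signal {signal[0]}")
```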
@@ -257,15 +262,19 @@
         req_ids = []
         num_running_requests = 0

-        self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
+        self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                 if self.model_weights_status.value[0] != 0:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
             if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
-            if self.fd_config.load_config.dynamic_load_weight:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.tp_group)
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.ep_group
+                )
+            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.tensor_parallel_size > 1:
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.tp_group
+                )

             self.insert_step = False
             req_dicts = None

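Two things change together in this hunk: the signal buffer becomes a host-side NumPy array rather than a `paddle.Tensor`, and the in-place collective broadcasts are routed through `_broadcast_model_weights_signal`, which materializes a device tensor only when a broadcast actually happens; the TP broadcast is now also skipped entirely when `tensor_parallel_size` is 1.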
@@ -294,7 +303,9 @@
             else:
                 paddle.distributed.barrier(self.parallel_config.tp_group)
             if self.model_weights_signal[0] != 0:
-                logger.info(f"Rank: {self.local_rank} has updated parameters.")
+                logger.info(
+                    f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
+                )
                 from fastdeploy.rl.dynamic_weight_manager import (
                     DynamicWeightManager,
                 )

@@ -307,6 +318,7 @@
                     self.parallel_config.engine_worker_queue_port,
                 )
                 self.model_weights_signal[0] = 0
+                logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")

             if self.exist_task_signal.value[0] == 1 or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")