mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[feature]2.2 custom_allreduce support cudagraph recapture (#4307)
* custom_allreduce support cudagraph recapture * delete code * add shut_down/restart default group
This commit is contained in:
@@ -616,6 +616,8 @@ int64_t open_mem_handle(paddle::Tensor& mem_handle);
|
||||
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
||||
void clear_ipc_handles(int64_t _fa);
|
||||
|
||||
// speculative decoding Kernel
|
||||
std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
|
||||
const paddle::Tensor& input_ids,
|
||||
@@ -1204,6 +1206,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
|
||||
m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");
|
||||
|
||||
m.def("clear_ipc_handles", &clear_ipc_handles, "clear_ipc_handles");
|
||||
|
||||
m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");
|
||||
|
||||
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");
|
||||
|
@@ -122,10 +122,14 @@ void register_graph_buffers(fptr_t _fa,
|
||||
for (int i = 0; i < handles.size(); i++) {
|
||||
bytes.emplace_back(handles[i].begin(), handles[i].end());
|
||||
}
|
||||
bytes.reserve(handles.size());
|
||||
fa->register_graph_buffers(bytes, offsets);
|
||||
}
|
||||
|
||||
void clear_ipc_handles(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
|
||||
fa->clear_ipc_handles();
|
||||
}
|
||||
|
||||
std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size) {
|
||||
|
||||
|
@@ -517,10 +517,15 @@ class CustomAllreduce {
|
||||
#undef KL
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
void clear_ipc_handles(){
|
||||
for (auto [_, ptr] : ipc_handles_) {
|
||||
CUDACHECK(cudaIpcCloseMemHandle(ptr));
|
||||
}
|
||||
ipc_handles_.clear();
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
clear_ipc_handles();
|
||||
}
|
||||
};
|
||||
} // namespace paddle
|
||||
|
@@ -42,6 +42,12 @@ def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
|
||||
_TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)
|
||||
|
||||
|
||||
def custom_ar_clear_ipc_handles():
|
||||
global _TP_AR
|
||||
if _TP_AR is not None:
|
||||
_TP_AR.clear_ipc_handles()
|
||||
|
||||
|
||||
try:
|
||||
|
||||
@paddle.jit.marker.unified
|
||||
|
@@ -25,6 +25,7 @@ from paddle.distributed.communication.group import Group
|
||||
from fastdeploy.distributed.custom_all_reduce import cuda_wrapper
|
||||
from fastdeploy.model_executor.ops.gpu import (
|
||||
all_reduce,
|
||||
clear_ipc_handles,
|
||||
dispose,
|
||||
get_graph_buffer_ipc_meta,
|
||||
init_custom_all_reduce,
|
||||
@@ -220,6 +221,9 @@ class CustomAllreduce:
|
||||
else:
|
||||
return self.all_reduce(input, input, registered=False)
|
||||
|
||||
def clear_ipc_handles(self):
|
||||
clear_ipc_handles(self._ptr)
|
||||
|
||||
def close(self):
|
||||
if self._ptr:
|
||||
dispose(self._ptr)
|
||||
|
@@ -23,7 +23,10 @@ import paddle.nn.layer
|
||||
from paddle.device.cuda import graphs
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.distributed.communication import capture_custom_allreduce
|
||||
from fastdeploy.distributed.communication import (
|
||||
capture_custom_allreduce,
|
||||
custom_ar_clear_ipc_handles,
|
||||
)
|
||||
from fastdeploy.utils import get_logger
|
||||
|
||||
logger = get_logger("cudagrpah_piecewise_backend", "cudagraph_piecewise_backend.log")
|
||||
@@ -208,6 +211,7 @@ class CudaGraphPiecewiseBackend:
|
||||
def clear_graph(self):
|
||||
""" """
|
||||
# Clear graphs
|
||||
custom_ar_clear_ipc_handles()
|
||||
for id, entry in self.concrete_size_entries.items():
|
||||
if entry.cuda_graph:
|
||||
del entry.cuda_graph
|
||||
|
@@ -66,6 +66,7 @@ class DynamicWeightManager:
|
||||
paddle.device.cuda.empty_cache()
|
||||
|
||||
if not self.first_load:
|
||||
paddle.distributed.restart_process_group()
|
||||
paddle.distributed.restart_process_group(self.parallel_config.tp_group)
|
||||
if self.parallel_config.enable_expert_parallel:
|
||||
paddle.distributed.restart_process_group(self.parallel_config.ep_group)
|
||||
|
Reference in New Issue
Block a user