diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 85f88cf12..3379e0cb7 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -623,6 +623,8 @@ int64_t open_mem_handle(paddle::Tensor& mem_handle);
 
 void free_shared_buffer(int64_t buffer);
 
+void clear_ipc_handles(int64_t _fa);
+
 // speculative decoding Kernel
 std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
     const paddle::Tensor& input_ids,
@@ -1229,6 +1231,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
 
   m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");
 
+  m.def("clear_ipc_handles", &clear_ipc_handles, "clear_ipc_handles");
+
   m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");
 
   m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");
diff --git a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu
index 0de212773..cb89cf79a 100644
--- a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu
+++ b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu
@@ -122,10 +122,14 @@ void register_graph_buffers(fptr_t _fa,
   for (int i = 0; i < handles.size(); i++) {
     bytes.emplace_back(handles[i].begin(), handles[i].end());
   }
-  bytes.reserve(handles.size());
   fa->register_graph_buffers(bytes, offsets);
 }
 
+void clear_ipc_handles(fptr_t _fa) {
+  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
+  fa->clear_ipc_handles();
+}
+
 std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
     int64_t size) {
diff --git a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh
index 341dbf5b5..fea3d63fe 100644
--- a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh
+++ b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh
@@ -517,10 +517,15 @@ class CustomAllreduce {
 #undef KL
   }
 
-  ~CustomAllreduce() {
+  void clear_ipc_handles() {
     for (auto [_, ptr] : ipc_handles_) {
       CUDACHECK(cudaIpcCloseMemHandle(ptr));
     }
+    ipc_handles_.clear();
+  }
+
+  ~CustomAllreduce() {
+    clear_ipc_handles();
   }
 };
 }  // namespace paddle
diff --git a/fastdeploy/distributed/communication.py b/fastdeploy/distributed/communication.py
index 52f2fbddb..5c78d125c 100644
--- a/fastdeploy/distributed/communication.py
+++ b/fastdeploy/distributed/communication.py
@@ -42,6 +42,12 @@ def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
         _TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)
 
 
+def custom_ar_clear_ipc_handles():
+    global _TP_AR
+    if _TP_AR is not None:
+        _TP_AR.clear_ipc_handles()
+
+
 try:
 
     @paddle.jit.marker.unified
diff --git a/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py b/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py
index 9a38b728e..b2e61c71d 100644
--- a/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py
+++ b/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py
@@ -25,6 +25,7 @@ from paddle.distributed.communication.group import Group
 from fastdeploy.distributed.custom_all_reduce import cuda_wrapper
 from fastdeploy.model_executor.ops.gpu import (
     all_reduce,
+    clear_ipc_handles,
     dispose,
     get_graph_buffer_ipc_meta,
     init_custom_all_reduce,
@@ -220,6 +221,9 @@ class CustomAllreduce:
         else:
             return self.all_reduce(input, input, registered=False)
 
+    def clear_ipc_handles(self):
+        clear_ipc_handles(self._ptr)
+
     def close(self):
         if self._ptr:
             dispose(self._ptr)
diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
index 8c64fe3cd..d31cf7464 100644
--- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
+++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
@@ -25,7 +25,10 @@ from paddle.device.cuda import graphs
 
 from fastdeploy import envs
 from fastdeploy.config import FDConfig
-from fastdeploy.distributed.communication import capture_custom_allreduce
+from fastdeploy.distributed.communication import (
+    capture_custom_allreduce,
+    custom_ar_clear_ipc_handles,
+)
 from fastdeploy.utils import get_logger
 
 logger = get_logger("cudagrpah_piecewise_backend", "cudagraph_piecewise_backend.log")
@@ -227,6 +230,7 @@ class CudaGraphPiecewiseBackend:
     def clear_graph(self):
         """ """
         # Clear graphs
+        custom_ar_clear_ipc_handles()
         for id, entry in self.concrete_size_entries.items():
             if entry.cuda_graph:
                 del entry.cuda_graph
diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py
index c5422843b..eec76dda4 100644
--- a/fastdeploy/rl/dynamic_weight_manager.py
+++ b/fastdeploy/rl/dynamic_weight_manager.py
@@ -66,6 +66,7 @@ class DynamicWeightManager:
 
         # step1 : restart paddle process group
         if not self.first_load:
+            paddle.distributed.restart_process_group()
            paddle.distributed.restart_process_group(self.parallel_config.tp_group)
             if self.parallel_config.enable_expert_parallel:
                 paddle.distributed.restart_process_group(self.parallel_config.ep_group)
@@ -148,6 +149,7 @@ class DynamicWeightManager:
         if self.parallel_config.enable_expert_parallel:
             paddle.distributed.barrier(self.parallel_config.ep_group)
             paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
+            paddle.distributed.shutdown_process_group()
         self._update_shared_status(pid, ModelWeightsStatus.CLEARED)
 
     def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
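
Reviewer note (not part of the diff): a minimal sketch of the call order this change enables, using only names introduced above plus a placeholder `backend` standing in for a CudaGraphPiecewiseBackend instance. It illustrates why the destructor was refactored to reuse clear_ipc_handles(): the handle map is emptied after closing, so an explicit early clear followed by the eventual destructor does not close the same IPC handles twice.

    # Hypothetical usage sketch; `backend` is a placeholder, and a real run
    # requires an initialized tensor-parallel process group.
    from fastdeploy.distributed.communication import (
        custom_ar_clear_ipc_handles,
        use_custom_allreduce,
    )

    use_custom_allreduce()  # creates the process-wide _TP_AR CustomAllreduce

    # ... capture and replay CUDA graphs containing custom all-reduce kernels ...

    # Tearing the graphs down (e.g. before an RL dynamic weight reload):
    # clear_graph() now calls custom_ar_clear_ipc_handles() first, so every
    # cudaIpcOpenMemHandle mapping is closed before the captured graphs and
    # their buffers are released.
    backend.clear_graph()

    # Because the C++ clear_ipc_handles() also empties ipc_handles_, a second
    # call is a safe no-op, and ~CustomAllreduce() later iterates an empty map
    # instead of closing the same handles twice.
    custom_ar_clear_ipc_handles()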