mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature] support custom all-reduce (#2758)
* [Feature] support custom all-reduce * add vllm adapted
This commit is contained in:
167
fastdeploy/distributed/custom_all_reduce/cuda_wrapper.py
Normal file
167
fastdeploy/distributed/custom_all_reduce/cuda_wrapper.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This file is a pure Python wrapper for the cudart library.
|
||||
It avoids the need to compile a separate shared library, and is
|
||||
convenient for use when we just need to call a few functions.
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
cudaError_t = ctypes.c_int
|
||||
cudaMemcpyKind = ctypes.c_int
|
||||
|
||||
|
||||
class cudaIpcMemHandle_t(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
|
||||
|
||||
def find_loaded_library(lib_name) -> Optional[str]:
|
||||
"""
|
||||
According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
|
||||
the file `/proc/self/maps` contains the memory maps of the process, which includes the
|
||||
shared libraries loaded by the process. We can use this file to find the path of the
|
||||
a loaded library.
|
||||
""" # noqa
|
||||
found = False
|
||||
with open("/proc/self/maps") as f:
|
||||
for line in f:
|
||||
if lib_name in line:
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
# the library is not loaded in the current process
|
||||
return None
|
||||
# if lib_name is libcudart, we need to match a line with:
|
||||
# address /path/to/libcudart-hash.so.11.0
|
||||
start = line.index("/")
|
||||
path = line[start:].strip()
|
||||
filename = path.split("/")[-1]
|
||||
assert filename.rpartition(".so")[0].startswith(
|
||||
lib_name
|
||||
), f"Unexpected filename: {filename} for library {lib_name}"
|
||||
return path
|
||||
|
||||
|
||||
class CudaRTLibrary:
|
||||
exported_functions = [
|
||||
# cudaError_t cudaSetDevice ( int device )
|
||||
Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
|
||||
# cudaError_t cudaDeviceSynchronize ( void )
|
||||
Function("cudaDeviceSynchronize", cudaError_t, []),
|
||||
# cudaError_t cudaDeviceReset ( void )
|
||||
Function("cudaDeviceReset", cudaError_t, []),
|
||||
# const char* cudaGetErrorString ( cudaError_t error )
|
||||
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
|
||||
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
|
||||
Function("cudaMalloc", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
|
||||
# cudaError_t cudaFree ( void* devPtr )
|
||||
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
|
||||
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
|
||||
Function("cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
|
||||
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
|
||||
Function("cudaMemcpy", cudaError_t, [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind]),
|
||||
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
|
||||
Function("cudaIpcGetMemHandle", cudaError_t, [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
|
||||
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
|
||||
Function(
|
||||
"cudaIpcOpenMemHandle", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint]
|
||||
),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
if so_file is None:
|
||||
so_file = find_loaded_library("libcudart")
|
||||
if so_file is None:
|
||||
pass
|
||||
# so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var
|
||||
assert so_file is not None, (
|
||||
"libcudart is not loaded in the current process, " "try setting VLLM_CUDART_SO_PATH"
|
||||
)
|
||||
if so_file not in CudaRTLibrary.path_to_library_cache:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
CudaRTLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = CudaRTLibrary.path_to_library_cache[so_file]
|
||||
|
||||
if so_file not in CudaRTLibrary.path_to_dict_mapping:
|
||||
_funcs = {}
|
||||
for func in CudaRTLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def CUDART_CHECK(self, result: cudaError_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.cudaGetErrorString(result)
|
||||
raise RuntimeError(f"CUDART error: {error_str}")
|
||||
|
||||
def cudaGetErrorString(self, error: cudaError_t) -> str:
|
||||
return self.funcs["cudaGetErrorString"](error).decode("utf-8")
|
||||
|
||||
def cudaSetDevice(self, device: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
|
||||
|
||||
def cudaDeviceSynchronize(self) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
|
||||
|
||||
def cudaDeviceReset(self) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
|
||||
|
||||
def cudaMalloc(self, size: int) -> ctypes.c_void_p:
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
|
||||
return devPtr
|
||||
|
||||
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
|
||||
|
||||
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
|
||||
|
||||
def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int) -> None:
|
||||
cudaMemcpyDefault = 4
|
||||
kind = cudaMemcpyDefault
|
||||
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
|
||||
|
||||
def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
|
||||
handle = cudaIpcMemHandle_t()
|
||||
self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr))
|
||||
return handle
|
||||
|
||||
def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
|
||||
cudaIpcMemLazyEnablePeerAccess = 1
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(
|
||||
self.funcs["cudaIpcOpenMemHandle"](ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)
|
||||
)
|
||||
return devPtr
|
Reference in New Issue
Block a user