// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "all_reduce.cuh"

// Fake pointer type, must match fptr_t type in ops.h.
// We use this type alias to indicate when pointers are passed in as int64_t.
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));

fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
                              paddle::Tensor& rank_data,
                              int64_t rank,
                              bool full_nvlink) {
  int world_size = fake_ipc_ptrs.size();
  if (world_size > 8)
    throw std::invalid_argument("world size > 8 is not supported");
  if (world_size % 2 != 0)
    throw std::invalid_argument("Odd num gpus is not supported for now");
  if (rank < 0 || rank >= world_size)
    throw std::invalid_argument("invalid rank passed in");

  paddle::Signal* ipc_ptrs[8];
  for (int i = 0; i < world_size; i++) {
    ipc_ptrs[i] = reinterpret_cast<paddle::Signal*>(fake_ipc_ptrs[i]);
  }
  return (fptr_t) new paddle::CustomAllreduce(ipc_ptrs,
                                              rank_data.data(),
                                              rank_data.numel(),
                                              rank,
                                              world_size,
                                              full_nvlink);
}

/**
 * Performs an out-of-place allreduce and stores result in out.
 *
 * If _reg_buffer is null, assumes inp.data() is already IPC-registered.
 * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
 * copied into _reg_buffer.
 */
void all_reduce(paddle::Tensor& inp,
                paddle::Tensor& out,
                fptr_t _fa,
                fptr_t _reg_buffer,
                int64_t reg_buffer_sz_bytes) {
  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
  auto stream = inp.stream();

  // Note: this byte count assumes a 2-byte element type (float16/bfloat16).
  auto input_size = inp.numel() * 2;
  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
  if (reg_buffer) {
    cudaMemcpyAsync(reg_buffer,
                    inp.data(),
                    input_size,
                    cudaMemcpyDeviceToDevice,
                    stream);
  } else {
    reg_buffer = inp.data();
  }
  switch (out.dtype()) {
    case phi::DataType::FLOAT32: {
      fa->allreduce<float>(stream,
                           reinterpret_cast<float*>(reg_buffer),
                           reinterpret_cast<float*>(out.data()),
                           out.numel());
      break;
    }
    case phi::DataType::FLOAT16: {
      fa->allreduce<half>(stream,
                          reinterpret_cast<half*>(reg_buffer),
                          reinterpret_cast<half*>(out.data()),
                          out.numel());
      break;
    }
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
    case phi::DataType::BFLOAT16: {
      fa->allreduce<nv_bfloat16>(stream,
                                 reinterpret_cast<nv_bfloat16*>(reg_buffer),
                                 reinterpret_cast<nv_bfloat16*>(out.data()),
                                 out.numel());
      break;
    }
#endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}

void dispose(fptr_t _fa) {
  delete reinterpret_cast<paddle::CustomAllreduce*>(_fa);
}

int64_t meta_size() { return sizeof(paddle::Signal); }

void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
  void* ipc_ptrs[8];
  for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
    ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
  }
  fa->register_buffer(ipc_ptrs);
}
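// Illustrative usage sketch (comment only, not compiled): a minimal host-side
// call sequence for the helpers above, assuming one process per rank and that
// the IPC signal/buffer pointers have already been exchanged between ranks.
// Names such as `signal_ptrs`, `buffer_ptrs`, `rank_data`, `inp`, `out`,
// `reg_buffer` and `reg_buffer_sz_bytes` are placeholders, not part of this
// file.
//
//   fptr_t fa = init_custom_all_reduce(signal_ptrs, rank_data, rank,
//                                      /*full_nvlink=*/true);
//   register_buffer(fa, buffer_ptrs);   // IPC-registered communication buffers
//   all_reduce(inp, out, fa, reg_buffer, reg_buffer_sz_bytes);
//   dispose(fa);                        // release the CustomAllreduce handle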
// Use vector<int64_t> to represent byte data for python binding compatibility.
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa) {
  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
  auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
  std::vector<int64_t> bytes(handle.begin(), handle.end());
  return std::make_tuple(bytes, offsets);
}

// Use vector<int64_t> to represent byte data for python binding compatibility.
void register_graph_buffers(fptr_t _fa,
                            const std::vector<std::vector<int64_t>>& handles,
                            const std::vector<std::vector<int64_t>>& offsets) {
  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
  // Reassemble each handle's bytes into a std::string for the C++ API.
  std::vector<std::string> bytes;
  bytes.reserve(handles.size());
  for (int i = 0; i < handles.size(); i++) {
    bytes.emplace_back(handles[i].begin(), handles[i].end());
  }
  fa->register_graph_buffers(bytes, offsets);
}

std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
    int64_t size) {
  auto device_index = phi::backends::gpu::GetCurrentDeviceId();
  void* buffer;
  cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
  auto stream =
      paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream();
  CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));

  // Allocate buffer
  CUDACHECK(cudaMalloc((void**)&buffer, size));
  CUDACHECK(cudaMemsetAsync(buffer, 0, size, stream));
  CUDACHECK(cudaStreamSynchronize(stream));
  CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));

  // Create IPC memhandle for the allocated buffer.
  // Will use it in open_mem_handle.
  auto handle =
      paddle::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))},
                    paddle::DataType::UINT8,
                    paddle::GPUPlace(device_index));
  CUDACHECK(
      cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data(), buffer));

  return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
}

fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
  void* ipc_ptr;
  CUDACHECK(cudaIpcOpenMemHandle(
      (void**)&ipc_ptr,
      *((const cudaIpcMemHandle_t*)mem_handle.data()),
      cudaIpcMemLazyEnablePeerAccess));
  return reinterpret_cast<fptr_t>(ipc_ptr);
}

void free_shared_buffer(fptr_t buffer) {
  CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}

PD_BUILD_STATIC_OP(all_reduce)
    .Inputs({"inp", "out"})
    .Outputs({"new_out"})
    .Attrs({"_fa: int64_t",
            "_reg_buffer: int64_t",
            "reg_buffer_sz_bytes: int64_t"})
    .SetInplaceMap({{"out", "new_out"}})
    .SetKernelFn(PD_KERNEL(all_reduce));
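// Illustrative IPC setup sketch (comment only, not compiled): how the
// shared-buffer helpers above are expected to fit together across ranks.
// The handle-exchange step between processes is assumed to happen elsewhere
// (e.g. through an existing communication channel) and is not shown here;
// `size`, `peer_handle` and `buffer_ptr` are placeholder names.
//
//   // On every rank: allocate a zero-initialized buffer and export its handle.
//   auto [buffer_ptr, handle] = allocate_shared_buffer_and_handle(size);
//   // Exchange `handle` with the other ranks (mechanism not part of this file).
//   // On each rank: open every peer's handle to get a usable device pointer.
//   fptr_t peer_ptr = open_mem_handle(peer_handle);
//   // Pass the collected pointers to init_custom_all_reduce / register_buffer.
//   // At teardown:
//   free_shared_buffer(buffer_ptr);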