Add custom op declaration for all_reduce (#3473)

* add custom op declaration

* roll back try except
This commit is contained in:
Ryan
2025-08-20 20:29:58 +08:00
committed by GitHub
parent 33ff0bfe38
commit bcdfc1d6b9
4 changed files with 16 additions and 4 deletions

View File

@@ -530,7 +530,7 @@ paddle::Tensor FusedHadamardQuantFp8Func(
int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs, int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
paddle::Tensor& rank_data, int64_t rank, bool full_nvlink); paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);
void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out, void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, int64_t _fa,
int64_t reg_buffer, int64_t reg_buffer_sz_bytes); int64_t reg_buffer, int64_t reg_buffer_sz_bytes);
void dispose(int64_t _fa); void dispose(int64_t _fa);

View File

@@ -49,7 +49,7 @@ fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer. * copied into _reg_buffer.
*/ */
void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out, void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, fptr_t _fa,
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) { fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa); auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
auto stream = inp.stream(); auto stream = inp.stream();
@@ -163,3 +163,12 @@ fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
void free_shared_buffer(fptr_t buffer) { void free_shared_buffer(fptr_t buffer) {
CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer))); CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
} }
PD_BUILD_STATIC_OP(all_reduce)
.Inputs({"inp",
"out"})
.Outputs({"new_out"})
.Attrs({"_fa: int64_t", "_reg_buffer: int64_t", "reg_buffer_sz_bytes: int64_t"})
.SetInplaceMap({{"out", "new_out"}})
.SetKernelFn(PD_KERNEL(all_reduce));

View File

@@ -158,9 +158,9 @@ class CustomAllreduce:
if out is None: if out is None:
out = paddle.empty_like(inp) out = paddle.empty_like(inp)
if registered: if registered:
all_reduce(self._ptr, inp, out, 0, 0) all_reduce(inp, out, self._ptr, 0, 0)
else: else:
all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size) all_reduce(inp, out, self._ptr, self.buffer_ptrs[self.rank], self.max_size)
return out return out
def start_capture(self): def start_capture(self):

View File

@@ -89,6 +89,9 @@ class MLAAttentionMetadata(AttentionMetadata):
kv_signal_metadata: Optional[paddle.Tensor] = None kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list) kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
max_enc_len_this_time: Optional[paddle.Tensor] = None
max_dec_len_this_time: Optional[paddle.Tensor] = None
class MLAAttentionBackend(AttentionBackend): class MLAAttentionBackend(AttentionBackend):
""" """