mirror of https://github.com/PaddlePaddle/FastDeploy.git
Add custom op declaration for all_reduce (#3473)
* add custom op declaration
* roll back try except
@@ -530,7 +530,7 @@ paddle::Tensor FusedHadamardQuantFp8Func(
 int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
                                paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);
 
-void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
+void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, int64_t _fa,
                 int64_t reg_buffer, int64_t reg_buffer_sz_bytes);
 
 void dispose(int64_t _fa);
@@ -49,7 +49,7 @@ fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
  * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
  * copied into _reg_buffer.
  */
-void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
+void all_reduce(paddle::Tensor& inp, paddle::Tensor& out, fptr_t _fa,
                 fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
   auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
   auto stream = inp.stream();
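The reorder is not cosmetic: Paddle's custom-op machinery hands the kernel function the op's .Inputs(...) tensors first and its .Attrs(...) scalars after, so registering all_reduce via PD_BUILD_STATIC_OP (next hunk) requires inp and out to precede the three int64_t arguments. A minimal call-site sketch with placeholder values — the fptr_t handle comes from init_custom_all_reduce in real use; this is an illustration, not code from the patch:

    import paddle

    inp = paddle.ones([1024], dtype="float32")
    out = paddle.empty_like(inp)
    fa = 0        # stand-in for the fptr_t returned by init_custom_all_reduce
    reg_buf = 0   # 0 => inp's buffer is already IPC-registered, no staging copy
    buf_size = 0

    # Old order: all_reduce(fa, inp, out, reg_buf, buf_size)
    # New order: tensor inputs first, then the int64 attributes:
    # all_reduce(inp, out, fa, reg_buf, buf_size)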
@@ -163,3 +163,12 @@ fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
 void free_shared_buffer(fptr_t buffer) {
   CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
 }
+
+
+PD_BUILD_STATIC_OP(all_reduce)
+    .Inputs({"inp",
+             "out"})
+    .Outputs({"new_out"})
+    .Attrs({"_fa: int64_t", "_reg_buffer: int64_t", "reg_buffer_sz_bytes: int64_t"})
+    .SetInplaceMap({{"out", "new_out"}})
+    .SetKernelFn(PD_KERNEL(all_reduce));
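This registration block is the declaration the commit title refers to: inp and out are tensor inputs, new_out is the sole output, and the pointer/size arguments ride along as int64_t attributes. SetInplaceMap({{"out", "new_out"}}) tells the framework that the kernel writes its result into out's storage and re-exposes that storage under the name new_out. A toy Python stand-in illustrating the inplace contract — fake_all_reduce is hypothetical; the real kernel reduces across ranks:

    import paddle

    def fake_all_reduce(inp, out):
        # Hypothetical stand-in: fill the caller's buffer in place and
        # return it, mirroring SetInplaceMap({{"out", "new_out"}}).
        paddle.assign(inp, output=out)
        return out

    inp = paddle.ones([4], dtype="float32")
    out = paddle.empty_like(inp)
    new_out = fake_all_reduce(inp, out)
    # out already holds the result; new_out names the same storage.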
@@ -158,9 +158,9 @@ class CustomAllreduce:
         if out is None:
             out = paddle.empty_like(inp)
         if registered:
-            all_reduce(self._ptr, inp, out, 0, 0)
+            all_reduce(inp, out, self._ptr, 0, 0)
         else:
-            all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size)
+            all_reduce(inp, out, self._ptr, self.buffer_ptrs[self.rank], self.max_size)
         return out
 
     def start_capture(self):
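On the Python side only the argument order changes; the two branches follow the C++ comment above. A hedged sketch spelling out what each branch means — comm stands in for a CustomAllreduce instance, and the op's import path is not shown in this patch:

    import paddle

    # Assumes the compiled custom op is importable, e.g. (path hypothetical):
    # from fastdeploy.model_executor.ops.gpu import all_reduce

    def dispatch(comm, inp, out=None, registered=False):
        # Mirrors the wrapper above, with the branch meanings spelled out.
        if out is None:
            out = paddle.empty_like(inp)
        if registered:
            # inp's buffer was IPC-registered ahead of time: zero attrs,
            # the kernel reads inp directly with no staging copy.
            all_reduce(inp, out, comm._ptr, 0, 0)
        else:
            # inp is first copied into this rank's registered staging
            # buffer, so payloads must fit within comm.max_size bytes.
            all_reduce(inp, out, comm._ptr, comm.buffer_ptrs[comm.rank], comm.max_size)
        return out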
@@ -89,6 +89,9 @@ class MLAAttentionMetadata(AttentionMetadata):
     kv_signal_metadata: Optional[paddle.Tensor] = None
     kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
+    max_enc_len_this_time: Optional[paddle.Tensor] = None
+    max_dec_len_this_time: Optional[paddle.Tensor] = None
+
 
 class MLAAttentionBackend(AttentionBackend):
     """
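This hunk is independent of the all_reduce change: MLAAttentionMetadata gains two per-step fields, by their names the maximum encoder (prefill) and decoder lengths of the current batch, held as tensors like their neighbors. A minimal dataclass sketch of the added fields — the field names match the patch, the interpretation is inferred, and the real class carries many more fields:

    from dataclasses import dataclass, field
    from typing import List, Optional

    import paddle

    @dataclass
    class MetadataSketch:
        # Stand-in mirroring the fields added above.
        kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
        max_enc_len_this_time: Optional[paddle.Tensor] = None
        max_dec_len_this_time: Optional[paddle.Tensor] = None

    meta = MetadataSketch()
    meta.max_dec_len_this_time = paddle.to_tensor([1], dtype="int32")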