Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 00:33:03 +08:00)
custom all reduce support cuda graph (#2938)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* Support enabling cuda graph and custom all reduce at the same time, and fix the overwritten custom all reduce flag
* Rename communication_op to communication
@@ -201,6 +201,8 @@ class ParallelConfig:
         # disable any whitespace for guided decoding
         self.disable_any_whitespace: bool = True
         self.pod_ip: str = None
+        # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
+        self.enable_custom_all_reduce: bool = False
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)

@@ -213,8 +215,6 @@ class ParallelConfig:
             self.moe_phase = MoEPhase.DECODER
         else:
             raise NotImplementedError
-        # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
-        self.enable_custom_all_reduce: bool = False
 
         # pd_disaggregation
         use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))

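The two ParallelConfig hunks above are the "fix the overwritten custom all reduce flag" part of the commit message: the default used to be assigned after the `for key, value in args.items():` loop, so a user-supplied value was silently reset to False. A minimal sketch of the failure mode (hypothetical `ParallelConfigSketch`, not the real class):

class ParallelConfigSketch:
    def __init__(self, args: dict):
        self.enable_custom_all_reduce: bool = False  # default now precedes the loop
        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)
        # before this change, the default assignment sat here and
        # clobbered enable_custom_all_reduce back to False

cfg = ParallelConfigSketch({"enable_custom_all_reduce": True})
print(cfg.enable_custom_all_reduce)  # True (was False before the fix)
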
@@ -17,12 +17,23 @@
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
+from contextlib import contextmanager, nullcontext
 
 from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size
 
 _TP_AR = None
 
 
+@contextmanager
+def capture_custom_allreduce():
+    global _TP_AR
+    ar_context = nullcontext()
+    if _TP_AR is not None:
+        ar_context = _TP_AR.capture()
+    with ar_context:
+        yield
+
+
 def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
     hcg = fleet.get_hybrid_communicate_group()
     model_parallel_group = hcg.get_model_parallel_group()

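The new `capture_custom_allreduce()` helper defaults to a `nullcontext()`, so graph-capture code can enter it unconditionally; it only switches the communicator into capture mode once `use_custom_allreduce` has set `_TP_AR`. A minimal usage sketch (illustrative; with tensor parallelism off this is a no-op):

from fastdeploy.distributed.communication import capture_custom_allreduce

with capture_custom_allreduce():  # plain nullcontext while _TP_AR is None
    pass  # CUDA graph capture of the model forward would happen here
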
@@ -22,6 +22,8 @@ from typing import Any, Dict, List, Optional
 
 cudaError_t = ctypes.c_int
 cudaMemcpyKind = ctypes.c_int
+cudaStream_t = ctypes.c_void_p
+cudaStreamCaptureStatus = ctypes.c_int
 
 
 class cudaIpcMemHandle_t(ctypes.Structure):

@@ -108,6 +110,14 @@ class CudaRTLibrary:
                 ctypes.c_uint,
             ],
         ),
+        Function(
+            "cudaStreamIsCapturing",
+            cudaError_t,
+            [
+                cudaStream_t,
+                ctypes.POINTER(cudaStreamCaptureStatus)
+            ]
+        ),
     ]
 
     # class attribute to store the mapping from the path to the library

@@ -187,3 +197,9 @@ class CudaRTLibrary:
             self.funcs["cudaIpcOpenMemHandle"](ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)
         )
         return devPtr
+
+    def cudaStreamIsCapturing(self, stream: cudaStream_t) -> ctypes.c_int:
+        is_capturing = ctypes.c_int()
+        self.CUDART_CHECK(
+            self.funcs["cudaStreamIsCapturing"](stream, is_capturing)
+        )
+        return is_capturing

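`cudaStreamIsCapturing` fills an out-parameter, which is why the wrapper can pass `is_capturing` without `byref`: ctypes converts a `c_int` instance to a pointer when the declared argtype is a `POINTER(...)` type. A standalone sketch of the same call against the raw runtime (an assumption-laden example, not FastDeploy code; the library file name varies by install):

import ctypes

cudart = ctypes.CDLL("libcudart.so")  # assumes the CUDA runtime is on the loader path
cudart.cudaStreamIsCapturing.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int)]

status = ctypes.c_int()
err = cudart.cudaStreamIsCapturing(None, ctypes.byref(status))  # None = legacy default stream
# status.value: 0 = cudaStreamCaptureStatusNone, 1 = ...Active, 2 = ...Invalidated
print(err, status.value)
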
@@ -56,8 +56,7 @@ class CustomAllreduce:
         is bind to a unique device, and all communicators in this group
         are in the same node.
         """
-        self._IS_CAPTURING = False
-        self.disabled = True
+        self.capturing = False
         self.group = group
 
         if not custom_ar:

@@ -78,8 +77,6 @@ class CustomAllreduce:
         if world_size < 2:
             return
 
-        self.disabled = False
-
         # Buffers memory are owned by this Python class and passed to C++.
         # Meta data composes of two parts: meta data for synchronization and a
         # temporary buffer for storing intermediate allreduce results.

@@ -95,13 +92,13 @@ class CustomAllreduce:
         # is enough for 131072 such tuples. The largest model I've seen only
         # needs less than 10000 of registered tuples.
         self.rank_data = paddle.empty([8 * 1024 * 1024], dtype=paddle.uint8)
 
         self.max_size = max_size
-        self.rank = rank
         self.world_size = world_size
         self.full_nvlink = True
         self._ptr = init_custom_all_reduce(self.meta_ptrs, self.rank_data, rank, self.full_nvlink)
         register_buffer(self._ptr, self.buffer_ptrs)
-        print("zss init custom allreduce", self._ptr)
         _instances.append(self)
 
     @staticmethod

@@ -112,7 +109,6 @@ class CustomAllreduce:
         """
         lib = cuda_wrapper.CudaRTLibrary()
         pointer = lib.cudaMalloc(size_in_bytes)
-        # lib.cudaMemset(pointer, 2, size_in_bytes)
         handle = lib.cudaIpcGetMemHandle(pointer)
         rank = dist.get_rank(group=group)
         handles = []

@@ -135,8 +131,8 @@ class CustomAllreduce:
         lib.cudaFree(ctypes.c_void_p(pointers[rank]))
 
     def should_custom_ar(self, inp: paddle.Tensor):
-        if self.disabled:
-            return False
+        if self.capturing:
+            return True
         inp_size = inp.numel() * inp.element_size()
         # custom allreduce requires input byte size to be multiples of 16
         if inp_size % 16 != 0:

@@ -167,6 +163,19 @@ class CustomAllreduce:
         all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size)
         return out
 
+    def start_capture(self):
+        """
+        set CUDA graph flag: True.
+        """
+        self.capturing = True
+
+    def stop_capture(self):
+        """
+        set CUDA graph flag: False and register the graph buffers.
+        """
+        self.capturing = False
+        self.register_graph_buffers()
+
     @contextmanager
     def capture(self):
         """

@@ -175,44 +184,44 @@ class CustomAllreduce:
         It records all the buffer addresses used in the CUDA graph.
         """
         try:
-            self._IS_CAPTURING = True
+            self.capturing = True
             yield
         finally:
-            self._IS_CAPTURING = False
-            if not self.disabled:
-                self.register_graph_buffers()
+            self.capturing = False
+            self.register_graph_buffers()
 
     def register_graph_buffers(self):
+        """
+        Register the graph buffers collected CUDA graph during capture.
+        """
         handle, offset = get_graph_buffer_ipc_meta(self._ptr)
-        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
-        all_data[self.rank] = [handle, offset]
+        all_datas = []
+        all_data = [handle, offset]
 
-        ranks = sorted(dist.get_process_group_ranks(group=self.group))
-        for i, rank in enumerate(ranks):
-            dist.broadcast_object_list(all_data[i], src=rank, group=self.group, device="cpu")
+        dist.all_gather_object(all_datas, all_data, group=self.group)
 
-        # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
+        handles = [d[0] for d in all_datas]
+        offsets = [d[1] for d in all_datas]
         register_graph_buffers(self._ptr, handles, offsets)
 
     def custom_all_reduce(self, input: paddle.Tensor) -> Optional[paddle.Tensor]:
         """The main allreduce API that provides support for cuda graph."""
-        # When custom allreduce is disabled, this will be None.
-        if self.disabled or not self.should_custom_ar(input):
-            return None
-        if self._IS_CAPTURING:
-            if paddle.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
+        if self.capturing:
+            lib = cuda_wrapper.CudaRTLibrary()
+            stream = paddle.device.current_stream()
+            stream_capturing = lib.cudaStreamIsCapturing(stream)
+            if stream_capturing.value == 1:
+                # 1 is cudaStreamCaptureStatusActive: The stream is capturing.
+                return self.all_reduce(input, input, registered=True)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return paddle.empty_like(input)
         else:
-            return self.all_reduce(input, registered=False)
+            return self.all_reduce(input, input, registered=False)
 
     def close(self):
-        if not self.disabled and self._ptr:
+        if self._ptr:
             dispose(self._ptr)
             self._ptr = 0
             self.free_shared_buffer(self.group, self.meta_ptrs, rank=self.rank)

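Two behavioral changes sit in this hunk. First, `register_graph_buffers` swaps per-rank `broadcast_object_list` calls for a single `all_gather_object`, which hands every rank the full list of `[handle, offset]` pairs in rank order. Second, `custom_all_reduce` now asks the CUDA runtime whether the current stream is capturing instead of relying on a `disabled` flag that other code paths could overwrite. A hedged sketch of the gather pattern (hypothetical helper name; assumes `paddle.distributed.init_parallel_env()` has run):

import paddle.distributed as dist

def gather_ipc_meta(handle, offset, group=None):
    gathered = []  # filled with one [handle, offset] entry per rank, in rank order
    dist.all_gather_object(gathered, [handle, offset], group=group)
    handles = [d[0] for d in gathered]
    offsets = [d[1] for d in gathered]
    return handles, offsets
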
@@ -898,6 +898,5 @@ class EngineArgs:
             graph_optimization_config=graph_opt_cfg,
             guided_decoding_backend=self.guided_decoding_backend,
             disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
-            enable_custom_all_reduce=self.enable_custom_all_reduce,
             enable_logprob=self.enable_logprob,
         )

@@ -656,7 +656,6 @@ class Config:
         reasoning_parser: str = None,
         guided_decoding_backend: Optional[str] = None,
         disable_any_whitespace: bool = False,
-        enable_custom_all_reduce: bool = False,
         enable_logprob: bool = False,
     ):
         """

@@ -1024,7 +1024,7 @@ class LLMEngine:
             "do_profile": self.do_profile,
             "dynamic_load_weight": self.cfg.model_config.dynamic_load_weight,
             "disable_any_whitespace": self.cfg.disable_any_whitespace,
-            "enable-custom-all-reduce": self.cfg.parallel_config.enable_custom_all_reduce,
+            "enable_custom_all_reduce": self.cfg.parallel_config.enable_custom_all_reduce,
             "enable_logprob": self.cfg.enable_logprob,
             "enable_mm": self.cfg.enable_mm,
         }

@@ -22,6 +22,7 @@ from paddle.device.cuda import graphs
 
 from fastdeploy.config import FDConfig
 from fastdeploy.utils import get_logger
+from fastdeploy.distributed.communication import capture_custom_allreduce
 
 logger = get_logger("cudagrpah_piecewise_backend", "cudagraph_piecewise_backend.log")
 

@@ -109,9 +110,11 @@ class CudaGraphPiecewiseBackend:
             paddle.device.synchronize()
 
             # Capture
-            new_grpah.capture_begin()
-            output = entry.runnable(**kwargs)
-            new_grpah.capture_end()
+            with capture_custom_allreduce():
+                new_grpah.capture_begin()
+                output = entry.runnable(**kwargs)
+                new_grpah.capture_end()
 
             # Store output buffer
             entry.cuda_graph = new_grpah

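The ordering here matters: `capture_custom_allreduce()` opens before `capture_begin()` and closes after `capture_end()`, so the communicator's `capturing` flag covers the whole recording and `register_graph_buffers()` fires only once capture has finished. A reduced sketch of the flow (illustrative; `runnable` stands in for the compiled entry):

import paddle
from paddle.device.cuda import graphs
from fastdeploy.distributed.communication import capture_custom_allreduce

def capture_entry(runnable, **kwargs):
    graph = graphs.CUDAGraph()
    paddle.device.synchronize()
    with capture_custom_allreduce():  # sets capturing=True on the communicator
        graph.capture_begin()
        output = runnable(**kwargs)   # allreduce buffer addresses get recorded
        graph.capture_end()
    return graph, output              # buffers registered on context exit
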
@@ -17,7 +17,7 @@
 import paddle
 from paddle import nn
 
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.utils import ceil_div
 

@@ -190,7 +190,7 @@ class GCUFusedMoeMethod(MoEMethodBase):
         fused_moe_out = fused_moe_out.reshape_([token_num, hidden_size])
 
         if layer.tp_size > 1:
-            from fastdeploy.distributed.communication_op import (
+            from fastdeploy.distributed.communication import (
                 tensor_model_parallel_all_reduce,
             )
 

@@ -18,7 +18,7 @@ import paddle
 from paddle import nn
 
 from fastdeploy.config import FDConfig
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.platforms import current_platform
 
 from .utils import _set_var_distributed, divide, get_tensor

@@ -20,7 +20,7 @@ from paddle.nn.quant import weight_quantize
 from paddleformers.utils.log import logger
 
 import fastdeploy
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.platforms import current_platform
 
 from ..utils import create_and_set_parameter, get_tensor

@@ -19,7 +19,7 @@ from paddle import nn
 from paddleformers.utils.log import logger
 
 import fastdeploy
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm
 

@@ -18,7 +18,7 @@ import paddle
 from paddle import nn
 
 import fastdeploy
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.ops.gpu import (
     MoeWna16MarlinGemmApi,
     tritonmoe_preprocess_func,

@@ -18,7 +18,7 @@ import paddle
 from paddle import nn
 
 import fastdeploy
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import create_and_set_parameter, get_tensor
 from fastdeploy.utils import ceil_div
 

@@ -18,7 +18,7 @@ import paddle
 from paddle import nn
 
 import fastdeploy
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.utils import ceil_div
 
 from ..quantization.quant_base import QuantMethodBase

@@ -82,7 +82,7 @@ class XPUMoEMethod(MoEMethodBase):
             False, # moe group, used in deepseek
         )
         if layer.tp_size > 1:
-            from fastdeploy.distributed.communication_op import (
+            from fastdeploy.distributed.communication import (
                 tensor_model_parallel_all_reduce,
             )
 

@@ -210,7 +210,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
             False, # moe group, used in deepseek
         )
         if layer.tp_size > 1:
-            from fastdeploy.distributed.communication_op import (
+            from fastdeploy.distributed.communication import (
                 tensor_model_parallel_all_reduce,
             )
 

@@ -25,7 +25,7 @@ from paddleformers.transformers import PretrainedModel
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.activation import SiluAndMul
 from fastdeploy.model_executor.layers.attention.attention import Attention

@@ -28,7 +28,7 @@ from paddleformers.transformers.configuration_utils import PretrainedConfig
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.graph_optimization.decorator import (
     support_graph_optimization,
 )

@@ -62,7 +62,7 @@ class GpuWorker(WorkerBase):
         gc.collect()
         paddle.device.cuda.empty_cache()
         if self.parallel_config.enable_custom_all_reduce:
-            from fastdeploy.distributed.communication_op import use_custom_allreduce
+            from fastdeploy.distributed.communication import use_custom_allreduce
 
             use_custom_allreduce()
         else:

@@ -498,7 +498,7 @@ def parse_args():
         help="enable prefix cache",
     )
     parser.add_argument(
-        "--enable-custom-all-reduce",
+        "--enable_custom_all_reduce",
         action="store_true",
         help="enable custom all-reduce",
     )

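Functionally both spellings land on the same attribute, since argparse converts internal hyphens in long option names to underscores when deriving `dest`; the rename just makes the CLI flag, the launch-args key in `LLMEngine`, and the config field textually identical. A quick check:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--enable_custom_all_reduce", action="store_true", help="enable custom all-reduce")
print(p.parse_args(["--enable_custom_all_reduce"]).enable_custom_all_reduce)  # True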