Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
Fix performance degradation bug of custom_all_reduce (#2981)
@@ -14,10 +14,11 @@
 # limitations under the License.
 """
 
+from contextlib import contextmanager, nullcontext
+
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
-from contextlib import contextmanager, nullcontext
 
 from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size
 
@@ -53,7 +54,7 @@ try:
         """All-reduce the input tensor across model parallel group."""
         global _TP_AR
         if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
-            _TP_AR.all_reduce(input_, input_)
+            _TP_AR.custom_all_reduce(input_)
         elif paddle.in_dynamic_mode():
            hcg = fleet.get_hybrid_communicate_group()
            mp_group = hcg.get_model_parallel_group()
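The hunk above swaps the call `_TP_AR.all_reduce(input_, input_)` for `_TP_AR.custom_all_reduce(input_)`. For context, here is a minimal sketch of the surrounding dispatch, assuming the usual shape of this wrapper: only the lines visible in the hunk come from the diff, while the function name `tensor_model_parallel_all_reduce`, the module-level `_TP_AR` setup, and the `dist.all_reduce` fallback arguments are reconstructed from context, not verbatim FastDeploy code.

# Sketch only, not the verbatim FastDeploy function. The fast-path call and the
# elif branch mirror the hunk; the _TP_AR setup and the dist.all_reduce
# fallback arguments are assumptions.
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

_TP_AR = None  # assumed: populated elsewhere with a CustomAllreduce instance


def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> None:
    """All-reduce the input tensor across model parallel group."""
    global _TP_AR
    if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
        # Fast path: the custom all-reduce kernel (the call fixed in this commit).
        _TP_AR.custom_all_reduce(input_)
    elif paddle.in_dynamic_mode():
        # Fallback: regular NCCL all-reduce over the model-parallel group.
        hcg = fleet.get_hybrid_communicate_group()
        mp_group = hcg.get_model_parallel_group()
        dist.all_reduce(input_, group=mp_group)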
@@ -133,7 +133,7 @@ class CustomAllreduce:
     def should_custom_ar(self, inp: paddle.Tensor):
         if self.capturing:
             return True
-        inp_size = inp.numel() * inp.element_size()
+        inp_size = inp.shape[0] * inp.shape[1] * inp.element_size()
         # custom allreduce requires input byte size to be multiples of 16
         if inp_size % 16 != 0:
             return False
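The second hunk recomputes the input's byte size from the two static shape entries instead of `inp.numel()`. Presumably (an inference from the commit title, not stated in the diff) the `numel()` call was the source of the per-call overhead on this hot path, since multiplying two Python ints taken from `shape` is essentially free; note the new expression also assumes a 2-D input, whereas `numel()` covered any rank. A small standalone check, with hypothetical names, illustrating the multiple-of-16 rule:

# Hypothetical helper mirroring only the size test shown in should_custom_ar;
# the real method performs further checks after this one that the hunk cuts off.
def byte_size_is_custom_ar_friendly(shape, element_size):
    # Plain Python-int arithmetic on a 2-D shape, as in the fixed line.
    inp_size = shape[0] * shape[1] * element_size
    # custom allreduce requires input byte size to be multiples of 16
    return inp_size % 16 == 0


# fp16 tensor of shape [8, 4096]: 8 * 4096 * 2 = 65536 bytes, a multiple of 16.
assert byte_size_is_custom_ar_friendly([8, 4096], 2)
# fp16 tensor of shape [3, 5]: 3 * 5 * 2 = 30 bytes, not a multiple of 16.
assert not byte_size_is_custom_ar_friendly([3, 5], 2)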