mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Fix performance degradation bug of custom_all_reduce (#2981)
This commit is contained in:
@@ -14,10 +14,11 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager, nullcontext
|
||||
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
from paddle.distributed import fleet
|
||||
from contextlib import contextmanager, nullcontext
|
||||
|
||||
from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
|
||||
@@ -53,7 +54,7 @@ try:
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
global _TP_AR
|
||||
if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
|
||||
_TP_AR.all_reduce(input_, input_)
|
||||
_TP_AR.custom_all_reduce(input_)
|
||||
elif paddle.in_dynamic_mode():
|
||||
hcg = fleet.get_hybrid_communicate_group()
|
||||
mp_group = hcg.get_model_parallel_group()
|
||||
|
Reference in New Issue
Block a user