【Fearture】support qwen2 some func (#2740)

* add rl qwen model support

* fix

* fix
This commit is contained in:
gaoziyuan
2025-07-08 12:03:04 +08:00
committed by GitHub
parent fefbd65cf8
commit 26d5d737dd
13 changed files with 438 additions and 174 deletions

View File

@@ -17,13 +17,15 @@
import paddle
import paddle.distributed as dist
@paddle.jit.marker.unified
def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
"""All-reduce the input tensor across model parallel group."""
if paddle.in_dynamic_mode():
hcg = dist.fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()
dist.all_reduce(input_, group=mp_group)
else:
dist.all_reduce(input_)
try:
@paddle.jit.marker.unified
def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
"""All-reduce the input tensor across model parallel group."""
if paddle.in_dynamic_mode():
hcg = dist.fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()
dist.all_reduce(input_, group=mp_group)
else:
dist.all_reduce(input_)
except:
tensor_model_parallel_all_reduce=None