mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 09:07:10 +08:00
@@ -54,14 +54,14 @@ class Test(unittest.TestCase):
|
||||
fa = CustomAllreduce(model_parallel_group)
|
||||
|
||||
for m, n in mns:
|
||||
data_cusom_ar = paddle.rand([m, n], dtype="bfloat16")
|
||||
data_paddle = data_cusom_ar.clone()
|
||||
if fa.should_custom_ar(data_cusom_ar):
|
||||
fa.custom_all_reduce(data_cusom_ar)
|
||||
data_custom_ar = paddle.rand([m, n], dtype="bfloat16")
|
||||
data_paddle = data_custom_ar.clone()
|
||||
if fa.should_custom_ar(data_custom_ar):
|
||||
fa.custom_all_reduce(data_custom_ar)
|
||||
dist.all_reduce(data_paddle)
|
||||
if dist.get_rank() == 0:
|
||||
np.testing.assert_allclose(
|
||||
data_cusom_ar.numpy(),
|
||||
data_custom_ar.numpy(),
|
||||
data_paddle.numpy(),
|
||||
rtol=1e-04,
|
||||
atol=1e-04,
|
||||
|
Reference in New Issue
Block a user