* fix typos

* fix
This commit is contained in:
co63oc
2025-09-22 14:27:17 +08:00
committed by GitHub
parent 0b62648924
commit c4830ef24c
10 changed files with 34 additions and 50 deletions

View File

@@ -54,14 +54,14 @@ class Test(unittest.TestCase):
fa = CustomAllreduce(model_parallel_group)
for m, n in mns:
data_cusom_ar = paddle.rand([m, n], dtype="bfloat16")
data_paddle = data_cusom_ar.clone()
if fa.should_custom_ar(data_cusom_ar):
fa.custom_all_reduce(data_cusom_ar)
data_custom_ar = paddle.rand([m, n], dtype="bfloat16")
data_paddle = data_custom_ar.clone()
if fa.should_custom_ar(data_custom_ar):
fa.custom_all_reduce(data_custom_ar)
dist.all_reduce(data_paddle)
if dist.get_rank() == 0:
np.testing.assert_allclose(
data_cusom_ar.numpy(),
data_custom_ar.numpy(),
data_paddle.numpy(),
rtol=1e-04,
atol=1e-04,