[Intel HPU] Enable dist sampler on intel hpu platform (#4445)

This commit is contained in:
Jianyu Li
2025-10-16 19:02:27 +08:00
committed by GitHub
parent 4251ac5e95
commit 3bbe99eae7
2 changed files with 4 additions and 3 deletions

View File

@@ -416,7 +416,7 @@ class Sampler(nn.Layer):
if next_tokens.shape[0] != max_batch:
dim = next_tokens.shape[-1]
tmp_tokens = paddle.full((max_batch, dim), -1, dtype=next_tokens.dtype)
tmp_tokens = paddle.full((max_batch, dim), -1 if local_rank == 0 else 0, dtype=next_tokens.dtype)
tmp_tokens = paddle.scatter(tmp_tokens, batch_ids, next_tokens[: batch_ids.shape[0], :])
return tmp_tokens