mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Intel HPU] Enable dist sampler on intel hpu platform (#4445)
This commit is contained in:
@@ -416,7 +416,7 @@ class Sampler(nn.Layer):
|
||||
|
||||
if next_tokens.shape[0] != max_batch:
|
||||
dim = next_tokens.shape[-1]
|
||||
tmp_tokens = paddle.full((max_batch, dim), -1, dtype=next_tokens.dtype)
|
||||
tmp_tokens = paddle.full((max_batch, dim), -1 if local_rank == 0 else 0, dtype=next_tokens.dtype)
|
||||
tmp_tokens = paddle.scatter(tmp_tokens, batch_ids, next_tokens[: batch_ids.shape[0], :])
|
||||
return tmp_tokens
|
||||
|
||||
|
||||
Reference in New Issue
Block a user