[cp][Loader] 2.2 check paddle version for v1 loader (#4478)

* check

* check

* check import
This commit is contained in:
chen
2025-10-20 15:27:59 +08:00
committed by GitHub
parent c13e6ae481
commit 8d2aaf3ba4
2 changed files with 20 additions and 0 deletions

View File

@@ -199,3 +199,19 @@ def temporary_dtype(dtype: str):
yield
finally:
paddle.set_default_dtype(orig_dtype)
def is_paddle_support_v1_loader():
src_shape = [32, 32]
tgt_shape = [1, 32, 64]
src_tensor = paddle.ones(src_shape, dtype="float32")
tgt_tensor = paddle.zeros(tgt_shape, dtype="float32")
for exp_id in range(tgt_shape[0]):
# gate
gate_tgt = tgt_tensor[exp_id][..., : tgt_shape[2] // 2]
gate_tgt.copy_(src_tensor, False)
# up
up_tgt = tgt_tensor[exp_id][..., tgt_shape[2] // 2 :]
up_tgt.copy_(src_tensor, False)
is_same = bool(paddle.all(tgt_tensor == 1))
return is_same