mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[cp][Loader] 2.2 check paddle version for v1 loader (#4478)
* check * check * check import
This commit is contained in:
@@ -199,3 +199,19 @@ def temporary_dtype(dtype: str):
|
||||
yield
|
||||
finally:
|
||||
paddle.set_default_dtype(orig_dtype)
|
||||
|
||||
|
||||
def is_paddle_support_v1_loader():
|
||||
src_shape = [32, 32]
|
||||
tgt_shape = [1, 32, 64]
|
||||
src_tensor = paddle.ones(src_shape, dtype="float32")
|
||||
tgt_tensor = paddle.zeros(tgt_shape, dtype="float32")
|
||||
for exp_id in range(tgt_shape[0]):
|
||||
# gate
|
||||
gate_tgt = tgt_tensor[exp_id][..., : tgt_shape[2] // 2]
|
||||
gate_tgt.copy_(src_tensor, False)
|
||||
# up
|
||||
up_tgt = tgt_tensor[exp_id][..., tgt_shape[2] // 2 :]
|
||||
up_tgt.copy_(src_tensor, False)
|
||||
is_same = bool(paddle.all(tgt_tensor == 1))
|
||||
return is_same
|
||||
|
||||
Reference in New Issue
Block a user