mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -215,11 +215,13 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete
|
|||||||
"""
|
"""
|
||||||
load_pre_sharded_checkpoint
|
load_pre_sharded_checkpoint
|
||||||
"""
|
"""
|
||||||
|
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||||
|
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
_, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
|
_, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
|
||||||
weights_iterator = safetensors_weights_iterator(safetensor_files)
|
weights_iterator = safetensors_weights_iterator(safetensor_files)
|
||||||
for name, weight in weights_iterator:
|
for name, weight in weights_iterator:
|
||||||
state_dict[name] = weight
|
state_dict[name] = get_tensor(weight)
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user