mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -215,11 +215,13 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete
|
||||
"""
|
||||
load_pre_sharded_checkpoint
|
||||
"""
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
|
||||
state_dict = {}
|
||||
_, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
|
||||
weights_iterator = safetensors_weights_iterator(safetensor_files)
|
||||
for name, weight in weights_iterator:
|
||||
state_dict[name] = weight
|
||||
state_dict[name] = get_tensor(weight)
|
||||
return state_dict
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user