fix load_pre_sharded_checkpoint (#3152) (#3169)

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
bukejiyu
2025-08-04 15:44:10 +08:00
committed by GitHub
parent 5f6fc7f7b9
commit 8e789dcb67

View File

@@ -215,11 +215,13 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete
"""
load_pre_sharded_checkpoint
"""
from fastdeploy.model_executor.layers.utils import get_tensor
state_dict = {}
_, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
weights_iterator = safetensors_weights_iterator(safetensor_files)
for name, weight in weights_iterator:
state_dict[name] = weight
state_dict[name] = get_tensor(weight)
return state_dict