mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[feat]add fast_weights_iterator (#3258)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* add fast_weights_iterator * update * update
This commit is contained in:
@@ -24,6 +24,7 @@ from fastsafetensors import SafeTensorsFileLoader, SingleGroup
|
||||
from paddleformers.transformers import PretrainedModel
|
||||
from paddleformers.transformers.model_utils import load_tp_checkpoint
|
||||
from paddleformers.utils.log import logger
|
||||
from paddleformers.utils.safetensors import fast_safe_open
|
||||
from safetensors import safe_open
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -155,9 +156,7 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
|
||||
return state_dict
|
||||
|
||||
|
||||
def safetensors_weights_iterator(
|
||||
safe_tensor_list: list[str],
|
||||
):
|
||||
def safetensors_weights_iterator(safe_tensor_list: list[str]):
|
||||
"""
|
||||
safetensors_weights_iterator
|
||||
"""
|
||||
@@ -165,8 +164,20 @@ def safetensors_weights_iterator(
|
||||
safe_tensor_list,
|
||||
desc="Loading safetensors checkpoint shards",
|
||||
):
|
||||
from paddleformers.utils.safetensors import fast_safe_open
|
||||
with safe_open(st_file, framework="np") as f:
|
||||
for name in f.keys():
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
|
||||
|
||||
def fast_weights_iterator(safe_tensor_list: list[str]):
|
||||
"""
|
||||
paddleformers' iterator for safetensors
|
||||
"""
|
||||
for st_file in tqdm(
|
||||
safe_tensor_list,
|
||||
desc="Loading safetensors checkpoint shards",
|
||||
):
|
||||
with fast_safe_open(st_file, framework="np") as f:
|
||||
for name in f.keys():
|
||||
param = f.get_slice(name)
|
||||
@@ -215,13 +226,12 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete
|
||||
"""
|
||||
load_pre_sharded_checkpoint
|
||||
"""
|
||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||
|
||||
state_dict = {}
|
||||
_, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
|
||||
weights_iterator = safetensors_weights_iterator(safetensor_files)
|
||||
for name, weight in weights_iterator:
|
||||
state_dict[name] = get_tensor(weight)
|
||||
state_dict[name] = weight
|
||||
return state_dict
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user