[feat] add fast_weights_iterator (#3258)

* add fast_weights_iterator

* update

* update
bukejiyu authored 2025-08-07 22:36:46 +08:00, committed by GitHub
parent 5f0b30f6d0, commit 37569cca86
3 changed files with 18 additions and 13 deletions
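For orientation, the sketch below consolidates the two weight iterators as they stand after this commit: safetensors_weights_iterator eagerly materializes each tensor through safetensors' safe_open/get_tensor, while the new fast_weights_iterator yields lazy slices through paddleformers' fast_safe_open/get_slice. This is a minimal standalone rendering of the diffs below, with no error handling.

from paddleformers.utils.safetensors import fast_safe_open
from safetensors import safe_open
from tqdm import tqdm


def safetensors_weights_iterator(safe_tensor_list: list[str]):
    """Eagerly yield (name, tensor) pairs from each safetensors shard."""
    for st_file in tqdm(safe_tensor_list, desc="Loading safetensors checkpoint shards"):
        with safe_open(st_file, framework="np") as f:
            for name in f.keys():
                yield name, f.get_tensor(name)


def fast_weights_iterator(safe_tensor_list: list[str]):
    """Yield (name, lazy slice) pairs via paddleformers' fast_safe_open."""
    for st_file in tqdm(safe_tensor_list, desc="Loading safetensors checkpoint shards"):
        with fast_safe_open(st_file, framework="np") as f:
            for name in f.keys():
                yield name, f.get_slice(name)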

View File

@@ -162,11 +162,6 @@ class LinearBase(nn.Layer):
self.linear_shift = None self.linear_shift = None
self.linear_smooth = None self.linear_smooth = None
if fd_config.model_config.is_quantized:
self.weight_key = f"{prefix}.quant_weight"
self.weight_scale_key = f"{prefix}.weight_scale"
self.act_scale_key = f"{prefix}.activation_scale"
def load_prequant_weight(self, state_dict: dict): def load_prequant_weight(self, state_dict: dict):
""" """
Load the prequantized weight from the state dictionary. Load the prequantized weight from the state dictionary.
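The per-layer key attributes (weight_key, weight_scale_key, act_scale_key) are no longer cached in __init__. A hypothetical sketch of deriving the same names from the layer prefix on demand; this is an assumption for illustration, not the actual FastDeploy implementation:

# Hypothetical helper (assumption): rebuild the quantized-weight key names from
# the layer prefix at load time instead of storing them on the layer.
def prequant_keys(prefix: str) -> tuple[str, str, str]:
    return (
        f"{prefix}.quant_weight",
        f"{prefix}.weight_scale",
        f"{prefix}.activation_scale",
    )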

View File

@@ -24,6 +24,7 @@ from fastsafetensors import SafeTensorsFileLoader, SingleGroup
 from paddleformers.transformers import PretrainedModel
 from paddleformers.transformers.model_utils import load_tp_checkpoint
 from paddleformers.utils.log import logger
+from paddleformers.utils.safetensors import fast_safe_open
 from safetensors import safe_open
 from tqdm import tqdm
@@ -155,9 +156,7 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
     return state_dict

-def safetensors_weights_iterator(
-    safe_tensor_list: list[str],
-):
+def safetensors_weights_iterator(safe_tensor_list: list[str]):
     """
     safetensors_weights_iterator
     """
@@ -165,8 +164,20 @@ def safetensors_weights_iterator(
         safe_tensor_list,
         desc="Loading safetensors checkpoint shards",
     ):
-        from paddleformers.utils.safetensors import fast_safe_open
+        with safe_open(st_file, framework="np") as f:
+            for name in f.keys():
+                param = f.get_tensor(name)
+                yield name, param
+
+
+def fast_weights_iterator(safe_tensor_list: list[str]):
+    """
+    paddleformers' iterator for safetensors
+    """
+    for st_file in tqdm(
+        safe_tensor_list,
+        desc="Loading safetensors checkpoint shards",
+    ):
         with fast_safe_open(st_file, framework="np") as f:
             for name in f.keys():
                 param = f.get_slice(name)
@@ -215,13 +226,12 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete
     """
     load_pre_sharded_checkpoint
     """
-    from fastdeploy.model_executor.layers.utils import get_tensor
     state_dict = {}
     _, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
     weights_iterator = safetensors_weights_iterator(safetensor_files)
     for name, weight in weights_iterator:
-        state_dict[name] = get_tensor(weight)
+        state_dict[name] = weight
     return state_dict
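As a usage sketch, either iterator is consumed the same way: iterate over (name, weight) pairs and collect them into a state dict, as the updated load_pre_sharded_checkpoint above does. The glob-based file discovery here stands in for get_all_safetensors and is only illustrative.

import glob
import os

from fastdeploy.model_executor.load_weight_utils import safetensors_weights_iterator


# Illustrative consumer: discover shard files with glob (instead of get_all_safetensors)
# and build a state dict directly from the iterator's (name, weight) pairs.
def load_rank_state_dict(model_path: str, local_rank: int) -> dict:
    shard_dir = os.path.join(model_path, f"rank{local_rank}")
    safetensor_files = sorted(glob.glob(os.path.join(shard_dir, "*.safetensors")))
    state_dict = {}
    for name, weight in safetensors_weights_iterator(safetensor_files):
        state_dict[name] = weight  # already host tensors; no get_tensor() conversion needed
    return state_dict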

View File

@@ -22,9 +22,9 @@ from paddleformers.utils.log import logger
 from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
 from fastdeploy.model_executor.load_weight_utils import (
+    fast_weights_iterator,
     get_all_safetensors,
     measure_time,
-    safetensors_weights_iterator,
 )
 from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
 from fastdeploy.model_executor.models.model_base import ModelRegistry
@@ -49,7 +49,7 @@ class DefaultModelLoaderV1(BaseModelLoader):
     @measure_time
     def load_weights(self, model, fd_config: FDConfig) -> None:
         _, safetensor_files = get_all_safetensors(fd_config.model_config.model)
-        weights_iterator = safetensors_weights_iterator(safetensor_files)
+        weights_iterator = fast_weights_iterator(safetensor_files)
         model.load_weights(weights_iterator)
         self.clean_memory_fragments()
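A likely motivation for switching DefaultModelLoaderV1 to fast_weights_iterator is that get_slice returns a lazy handle, so a consumer can read only the portion of each weight it needs (for example its tensor-parallel shard) instead of materializing the full array first. The sketch below assumes a safetensors-style slice object exposing get_shape() and slicing; it is an illustration, not FastDeploy's actual load_weights.

import numpy as np


# Illustration (assumption): a tensor-parallel consumer reads only its own row
# shard from a lazy slice, and falls back to the full array for eager tensors.
def take_tp_shard(weight, tp_rank: int, tp_size: int) -> np.ndarray:
    if hasattr(weight, "get_shape"):  # lazy slice yielded by fast_weights_iterator
        rows = weight.get_shape()[0]
        shard = rows // tp_size
        return weight[tp_rank * shard:(tp_rank + 1) * shard]
    return np.asarray(weight)  # eager tensor yielded by safetensors_weights_iterator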