From 37569cca86c010fac03d6ee967d56dfcc9aff487 Mon Sep 17 00:00:00 2001 From: bukejiyu <52310069+bukejiyu@users.noreply.github.com> Date: Thu, 7 Aug 2025 22:36:46 +0800 Subject: [PATCH] [feat]add fast_weights_iterator (#3258) * add fast_weights_iterator * update * update --- fastdeploy/model_executor/layers/linear.py | 5 ----- .../model_executor/load_weight_utils.py | 22 ++++++++++++++----- .../model_loader/default_loader_v1.py | 4 ++-- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index fe685a275..c6a62c935 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -162,11 +162,6 @@ class LinearBase(nn.Layer): self.linear_shift = None self.linear_smooth = None - if fd_config.model_config.is_quantized: - self.weight_key = f"{prefix}.quant_weight" - self.weight_scale_key = f"{prefix}.weight_scale" - self.act_scale_key = f"{prefix}.activation_scale" - def load_prequant_weight(self, state_dict: dict): """ Load the prequantized weight from the state dictionary. diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py index 01f81ac13..712cff972 100644 --- a/fastdeploy/model_executor/load_weight_utils.py +++ b/fastdeploy/model_executor/load_weight_utils.py @@ -24,6 +24,7 @@ from fastsafetensors import SafeTensorsFileLoader, SingleGroup from paddleformers.transformers import PretrainedModel from paddleformers.transformers.model_utils import load_tp_checkpoint from paddleformers.utils.log import logger +from paddleformers.utils.safetensors import fast_safe_open from safetensors import safe_open from tqdm import tqdm @@ -155,9 +156,7 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool return state_dict -def safetensors_weights_iterator( - safe_tensor_list: list[str], -): +def safetensors_weights_iterator(safe_tensor_list: list[str]): """ safetensors_weights_iterator """ @@ -165,8 +164,20 @@ def safetensors_weights_iterator( safe_tensor_list, desc="Loading safetensors checkpoint shards", ): - from paddleformers.utils.safetensors import fast_safe_open + with safe_open(st_file, framework="np") as f: + for name in f.keys(): + param = f.get_tensor(name) + yield name, param + +def fast_weights_iterator(safe_tensor_list: list[str]): + """ + paddleformers' iterator for safetensors + """ + for st_file in tqdm( + safe_tensor_list, + desc="Loading safetensors checkpoint shards", + ): with fast_safe_open(st_file, framework="np") as f: for name in f.keys(): param = f.get_slice(name) @@ -215,13 +226,12 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete """ load_pre_sharded_checkpoint """ - from fastdeploy.model_executor.layers.utils import get_tensor state_dict = {} _, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}")) weights_iterator = safetensors_weights_iterator(safetensor_files) for name, weight in weights_iterator: - state_dict[name] = get_tensor(weight) + state_dict[name] = weight return state_dict diff --git a/fastdeploy/model_executor/model_loader/default_loader_v1.py b/fastdeploy/model_executor/model_loader/default_loader_v1.py index 1ccb7f742..4d79772e5 100644 --- a/fastdeploy/model_executor/model_loader/default_loader_v1.py +++ b/fastdeploy/model_executor/model_loader/default_loader_v1.py @@ -22,9 +22,9 @@ from paddleformers.utils.log import logger from fastdeploy.config import FDConfig, LoadConfig, ModelConfig from fastdeploy.model_executor.load_weight_utils import ( + fast_weights_iterator, get_all_safetensors, measure_time, - safetensors_weights_iterator, ) from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader from fastdeploy.model_executor.models.model_base import ModelRegistry @@ -49,7 +49,7 @@ class DefaultModelLoaderV1(BaseModelLoader): @measure_time def load_weights(self, model, fd_config: FDConfig) -> None: _, safetensor_files = get_all_safetensors(fd_config.model_config.model) - weights_iterator = safetensors_weights_iterator(safetensor_files) + weights_iterator = fast_weights_iterator(safetensor_files) model.load_weights(weights_iterator) self.clean_memory_fragments()