mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[feat] add fast_weights_iterator (#3258)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* add fast_weights_iterator
* update
* update
This commit is contained in:
@@ -162,11 +162,6 @@ class LinearBase(nn.Layer):
|
|||||||
self.linear_shift = None
|
self.linear_shift = None
|
||||||
self.linear_smooth = None
|
self.linear_smooth = None
|
||||||
|
|
||||||
if fd_config.model_config.is_quantized:
|
|
||||||
self.weight_key = f"{prefix}.quant_weight"
|
|
||||||
self.weight_scale_key = f"{prefix}.weight_scale"
|
|
||||||
self.act_scale_key = f"{prefix}.activation_scale"
|
|
||||||
|
|
||||||
def load_prequant_weight(self, state_dict: dict):
|
def load_prequant_weight(self, state_dict: dict):
|
||||||
"""
|
"""
|
||||||
Load the prequantized weight from the state dictionary.
|
Load the prequantized weight from the state dictionary.
|
||||||
|
@@ -24,6 +24,7 @@ from fastsafetensors import SafeTensorsFileLoader, SingleGroup
|
|||||||
from paddleformers.transformers import PretrainedModel
|
from paddleformers.transformers import PretrainedModel
|
||||||
from paddleformers.transformers.model_utils import load_tp_checkpoint
|
from paddleformers.transformers.model_utils import load_tp_checkpoint
|
||||||
from paddleformers.utils.log import logger
|
from paddleformers.utils.log import logger
|
||||||
|
from paddleformers.utils.safetensors import fast_safe_open
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@@ -155,9 +156,7 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def safetensors_weights_iterator(
|
def safetensors_weights_iterator(safe_tensor_list: list[str]):
|
||||||
safe_tensor_list: list[str],
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
safetensors_weights_iterator
|
safetensors_weights_iterator
|
||||||
"""
|
"""
|
||||||
@@ -165,8 +164,20 @@ def safetensors_weights_iterator(
|
|||||||
safe_tensor_list,
|
safe_tensor_list,
|
||||||
desc="Loading safetensors checkpoint shards",
|
desc="Loading safetensors checkpoint shards",
|
||||||
):
|
):
|
||||||
from paddleformers.utils.safetensors import fast_safe_open
|
with safe_open(st_file, framework="np") as f:
|
||||||
|
for name in f.keys():
|
||||||
|
param = f.get_tensor(name)
|
||||||
|
yield name, param
|
||||||
|
|
||||||
|
|
||||||
|
def fast_weights_iterator(safe_tensor_list: list[str]):
|
||||||
|
"""
|
||||||
|
paddleformers' iterator for safetensors
|
||||||
|
"""
|
||||||
|
for st_file in tqdm(
|
||||||
|
safe_tensor_list,
|
||||||
|
desc="Loading safetensors checkpoint shards",
|
||||||
|
):
|
||||||
with fast_safe_open(st_file, framework="np") as f:
|
with fast_safe_open(st_file, framework="np") as f:
|
||||||
for name in f.keys():
|
for name in f.keys():
|
||||||
param = f.get_slice(name)
|
param = f.get_slice(name)
|
||||||
@@ -215,13 +226,12 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete
|
|||||||
"""
|
"""
|
||||||
load_pre_sharded_checkpoint
|
load_pre_sharded_checkpoint
|
||||||
"""
|
"""
|
||||||
from fastdeploy.model_executor.layers.utils import get_tensor
|
|
||||||
|
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
_, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
|
_, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
|
||||||
weights_iterator = safetensors_weights_iterator(safetensor_files)
|
weights_iterator = safetensors_weights_iterator(safetensor_files)
|
||||||
for name, weight in weights_iterator:
|
for name, weight in weights_iterator:
|
||||||
state_dict[name] = get_tensor(weight)
|
state_dict[name] = weight
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
@@ -22,9 +22,9 @@ from paddleformers.utils.log import logger
|
|||||||
|
|
||||||
from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
|
from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
|
||||||
from fastdeploy.model_executor.load_weight_utils import (
|
from fastdeploy.model_executor.load_weight_utils import (
|
||||||
|
fast_weights_iterator,
|
||||||
get_all_safetensors,
|
get_all_safetensors,
|
||||||
measure_time,
|
measure_time,
|
||||||
safetensors_weights_iterator,
|
|
||||||
)
|
)
|
||||||
from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
|
from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
|
||||||
from fastdeploy.model_executor.models.model_base import ModelRegistry
|
from fastdeploy.model_executor.models.model_base import ModelRegistry
|
||||||
@@ -49,7 +49,7 @@ class DefaultModelLoaderV1(BaseModelLoader):
|
|||||||
@measure_time
|
@measure_time
|
||||||
def load_weights(self, model, fd_config: FDConfig) -> None:
|
def load_weights(self, model, fd_config: FDConfig) -> None:
|
||||||
_, safetensor_files = get_all_safetensors(fd_config.model_config.model)
|
_, safetensor_files = get_all_safetensors(fd_config.model_config.model)
|
||||||
weights_iterator = safetensors_weights_iterator(safetensor_files)
|
weights_iterator = fast_weights_iterator(safetensor_files)
|
||||||
model.load_weights(weights_iterator)
|
model.load_weights(weights_iterator)
|
||||||
self.clean_memory_fragments()
|
self.clean_memory_fragments()
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user