"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import contextlib
import hashlib
import inspect
import json
import os
import pickle
import time
from functools import wraps
from pathlib import Path

import paddle
import paddle.distributed as dist
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.model_utils import load_tp_checkpoint
from paddleformers.utils.log import logger
from paddleformers.utils.safetensors import fast_safe_open
from safetensors import safe_open
from tqdm import tqdm

from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.linear import KVBatchLinear
from fastdeploy.model_executor.models.tp_utils import (
    check_tensor_parallel_prerequisites,
)
from fastdeploy.model_executor.utils import switch_config_context
from fastdeploy.platforms import current_platform


def pdparams_weight_iterator(paddle_file_list: list[str]):
    for pdparams_file in tqdm(
        paddle_file_list,
        desc="Loading pdparams checkpoint shards",
    ):
        state_dict = paddle.load(pdparams_file)
        yield from state_dict.items()
        del state_dict


def load_weights_form_cache(model, weights_iterator):
    params_dict = dict(model.named_parameters())
    for loaded_weight_name, loaded_weight in weights_iterator:
        param = params_dict[loaded_weight_name]
        param.copy_(loaded_weight, False)
        if "embeddings" in loaded_weight_name and getattr(model, "tie_word_embeddings", False):
            model.lm_head.load_state_dict({model.lm_head.weight_key: loaded_weight})
    for _, model_sublayer in model.named_sublayers():
        if isinstance(model_sublayer, KVBatchLinear):
            model_sublayer.process_weights_after_loading()
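
# Note: the iterator passed to load_weights_form_cache must yield names that
# exactly match model.named_parameters(); weights are copied into the existing
# parameters in place, and KVBatchLinear sublayers are post-processed only
# after all shards have been consumed.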


def get_weight_iterator(model_path: str):
    _, files_list, use_safetensors = get_all_weights_file(model_path)
    if use_safetensors:
        weights_iterator = fast_weights_iterator(files_list)
    else:
        weights_iterator = pdparams_weight_iterator(files_list)
    return weights_iterator
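
# Illustrative usage (assumes a hypothetical `checkpoint_dir` holding either
# *.safetensors or *.pdparams shards; not executed here):
#
#     for name, weight in get_weight_iterator(checkpoint_dir):
#         ...  # stream weights shard by shard without building a full state dict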


def is_weight_cache_enabled(fd_config, weight_cache_path=".cache"):
    weight_cache_context = contextlib.nullcontext()
    weight_cache_dir = None
    enable_cache = False
    if envs.FD_ENABLE_MODEL_LOAD_CACHE:
        model_weight_cache_path = os.path.join(fd_config.model_config.model, weight_cache_path)
        # model_type + quantization + tp_size + ep_size
        weight_cache_key = "_".join(
            [
                fd_config.model_config.model_type,
                fd_config.quant_config.name(),
                str(fd_config.parallel_config.tensor_parallel_size),
                str(fd_config.parallel_config.expert_parallel_size),
            ]
        )
        # only support tp now
        hash_key = hashlib.md5(pickle.dumps(weight_cache_key)).hexdigest()
        weight_cache_dir = os.path.join(model_weight_cache_path, hash_key)
        if os.path.exists(weight_cache_dir):
            logger.info(
                f"Loading will prioritize cached models. Users are responsible for ensuring the saved model is correct. If any error occurs, deleting the cache at {weight_cache_dir} may resolve it."
            )
            enable_cache = True
            weight_cache_context = switch_config_context(fd_config.quant_config, "is_quantized", True)
    return enable_cache, weight_cache_dir, weight_cache_context
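
# Sketch of the resulting cache layout (hypothetical values, for illustration
# only): with model_type="qwen2", quant name "wint8", tp=2, ep=1 the cache key
# is "qwen2_wint8_2_1"; its md5 hash names a directory such as
# <model_dir>/.cache/<md5(key)>/, under which save_model below writes one
# rank{tp_rank}/cache.pdparams file per tensor-parallel rank.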


def save_model(model_arg_name="model", config_arg_name="fd_config"):
    @measure_time("Model saving")
    def _save_model(model_dict, weight_cache_dir):
        # Note: ProcessGroupNCCL does not support the deepcopy protocol, so we patch it here.
        paddle.distributed.communication.group.Group.__deepcopy__ = lambda self, _: self
        paddle.distributed.communication.group.Group.to_json = lambda self: repr(self)
        paddle.save(model_dict, weight_cache_dir)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            sig = inspect.signature(func)
            bound_args = sig.bind(*args, **kwargs)
            bound_args.apply_defaults()
            fd_config = bound_args.arguments.get(config_arg_name, None)
            model = bound_args.arguments.get(model_arg_name, None)
            assert fd_config is not None, "fd_config cannot be None"
            assert model is not None, "model cannot be None"
            enable_cache, weight_cache_dir, _ = is_weight_cache_enabled(fd_config)
            if enable_cache:
                tp_weight_cache_dir = os.path.join(
                    weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}"
                )
                context = switch_config_context(fd_config.model_config, "model", tp_weight_cache_dir)
            else:
                context = contextlib.nullcontext()
            with context:
                result = func(*args, **kwargs)
            if (
                envs.FD_ENABLE_MODEL_LOAD_CACHE
                and weight_cache_dir is not None
                and not os.path.exists(weight_cache_dir)
            ):
                assert fd_config.quant_config is not None and getattr(
                    fd_config.quant_config, "is_checkpoint_bf16", False
                ), "Save cache only for dynamic quantization"
                tp_weight_cache_dir = os.path.join(
                    weight_cache_dir, f"rank{str(fd_config.parallel_config.tensor_parallel_rank)}"
                )
                logger.info(f"Saving model to {tp_weight_cache_dir}")
                os.makedirs(
                    tp_weight_cache_dir,
                    exist_ok=True,
                )
                _save_model(model.state_dict(), os.path.join(tp_weight_cache_dir, "cache.pdparams"))
            else:
                reason = "weights already cached" if envs.FD_ENABLE_MODEL_LOAD_CACHE else "cache disabled"
                logger.info(f"Skip saving, {reason}")
            return result

        return wrapper

    return decorator
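
# Illustrative use of the save_model decorator (hypothetical loader function,
# for illustration only; not executed here):
#
#     @save_model(model_arg_name="model", config_arg_name="fd_config")
#     def load_and_quantize(model, fd_config):
#         ...  # load bf16 weights and quantize; the cache is written on the first run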


def measure_time(prefix: str = "Model loading"):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            time_before = time.time()
            result = func(*args, **kwargs)
            time_after = time.time()
            logger.info(f"{prefix} took {time_after - time_before:.3f} seconds")
            return result

        return wrapper

    return decorator
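
# Illustrative usage of measure_time (not executed here):
#
#     @measure_time("Model loading")
#     def load_everything():
#         ...
#
#     # logs: "Model loading took 12.345 seconds"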


def load_reordered_experts(model_path: str, key_name: str):
    with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
        weight_list = json.load(f)["weight_map"]
    safetensor_path = os.path.join(model_path, weight_list[key_name])
    with safe_open(safetensor_path, framework="np", device="cpu") as f:
        if key_name in f.keys():
            weight = f.get_tensor(key_name)
            weight = paddle.Tensor(weight, zero_copy=True)
            weight = weight._copy_to(paddle.framework._current_expected_place(), False)
            return weight


def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool = False):
    """
    Load an expert-parallel (EP) checkpoint, keeping only the expert weights owned by this rank.
    """
    with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
        weight_list = json.load(f)["weight_map"]
    filtered_map = {k: v for k, v in weight_list.items() if ".experts." not in k}
    num_local_ffn_keys = []

    from itertools import chain

    def get_expert_ranges(fd_config):
        """
        Generate expert index ranges based on configuration parameters.

        This function is primarily used in Mixture-of-Experts (MoE) models to generate
        expert index ranges according to configuration parameters. When moe_num_experts
        is a list in the fd_config, it returns a chained combination of two ranges;
        otherwise it returns a single range.

        Args:
            fd_config: FastDeploy configuration object

        Returns:
            If moe_num_experts is a list:
                A chained combination (chain object) of two ranges:
                1. Base range: [num_experts_start_offset, num_experts_start_offset + num_experts_per_rank)
                2. Offset range: [base_range.start + moe_num_experts[0], base_range.stop + moe_num_experts[0])
            Else:
                A single range: [num_experts_start_offset, num_experts_start_offset + num_experts_per_rank)
        """
        base_range = range(
            fd_config.parallel_config.num_experts_start_offset,
            fd_config.parallel_config.num_experts_start_offset + fd_config.parallel_config.num_experts_per_rank,
        )
        if isinstance(fd_config.model_config.moe_num_experts, list):
            return chain(
                base_range,
                range(
                    base_range.start + fd_config.model_config.moe_num_experts[0],
                    base_range.stop + fd_config.model_config.moe_num_experts[0],
                ),
            )
        return base_range
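
    # Worked example (hypothetical values): with num_experts_start_offset=8,
    # num_experts_per_rank=8 and moe_num_experts=[64, 64], get_expert_ranges
    # yields 8..15 followed by 72..79, i.e. this rank's experts in both expert groups.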
    for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
        for j in get_expert_ranges(fd_config):
            up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
            down_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight"
            up_gate_proj_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
            down_proj_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight"
            up_gate_proj_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
            down_proj_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale"
            down_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.activation_scale"
            num_local_ffn_keys.append(up_gate_proj_key)
            num_local_ffn_keys.append(down_proj_key)
            num_local_ffn_keys.append(up_gate_proj_quant_key)
            num_local_ffn_keys.append(down_proj_quant_key)
            num_local_ffn_keys.append(up_gate_proj_scale_key)
            num_local_ffn_keys.append(down_proj_scale_key)
            num_local_ffn_keys.append(down_proj_in_scale_key)

        # for EP w4a8, we need all experts' activation_scale for up_gate_proj
        num_experts = fd_config.model_config.moe_num_experts
        if isinstance(num_experts, list):
            num_experts = num_experts[0]
        for j in range(num_experts):
            up_gate_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
            num_local_ffn_keys.append(up_gate_proj_in_scale_key)

    for k in num_local_ffn_keys:
        if k in weight_list:
            filtered_map[k] = weight_list[k]

    state_dict = {}
    # Get all safetensor file paths that need to be opened
    safetensor_paths = set(filtered_map.values())

    # Open each safetensor file sequentially with a progress bar
    for safetensor_path in tqdm(safetensor_paths, desc="Loading safetensor files", unit="file"):
        with safe_open(
            os.path.join(model_path, safetensor_path),
            framework="np",
            device="cpu",
        ) as f:
            # Check if this file contains keys from filtered_map
            for k in filtered_map:
                if filtered_map[k] == safetensor_path and k in f.keys():
                    weight = f.get_tensor(k)
                    if not return_numpy:
                        weight = paddle.Tensor(weight, zero_copy=True)
                        weight = weight._copy_to(paddle.framework._current_expected_place(), False)
                    state_dict[k] = weight
    return state_dict


def safetensors_weights_iterator(safe_tensor_list: list[str]):
    """
    Iterate over (name, tensor) pairs from a list of safetensors files.
    """
    for st_file in tqdm(
        safe_tensor_list,
        desc="Loading safetensors checkpoint shards",
    ):
        with safe_open(st_file, framework="np") as f:
            for name in f.keys():
                param = f.get_tensor(name)
                yield name, param


def fast_weights_iterator(safe_tensor_list: list[str]):
    """
    paddleformers' fast iterator for safetensors: yields (name, slice) pairs lazily.
    """
    for st_file in tqdm(
        safe_tensor_list,
        desc="Loading safetensors checkpoint shards",
    ):
        with fast_safe_open(st_file, framework="np") as f:
            for name in f.keys():
                param_slice = f.get_slice(name)
                yield name, param_slice
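
# Note: safetensors_weights_iterator above yields fully materialized tensors,
# while fast_weights_iterator yields lazy slice handles, so downstream loaders
# can read only the portion of a tensor they need (for example a TP shard).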


def fastsafetensors_weights_iterator(
    safetensor_list: list[str],
):
    """
    Return an iterator over tensors on GPU from a given safetensor_list.
    """
    world_size = dist.get_world_size()
    if world_size > 1:
        pg = dist.get_group()
        device = f"gpu:{pg.rank}" if paddle.is_compiled_with_cuda() else "cpu"
    else:
        pg = SingleGroup()
        device = f"gpu:{pg.rank()}" if paddle.is_compiled_with_cuda() else "cpu"

    safetensor_files_sub_lists = [
        safetensor_list[i : i + world_size] for i in range(0, len(safetensor_list), world_size)
    ]

    for st_file in tqdm(
        safetensor_files_sub_lists,
        desc="Loading fastsafetensors checkpoint shards",
    ):
        loader = SafeTensorsFileLoader(pg, device, nogds=True, debug_log=False, framework="paddle")
        rank_file_map = {i: [f] for i, f in enumerate(st_file)}
        loader.add_filenames(rank_file_map)
        try:
            fb = loader.copy_files_to_device()
            try:
                keys = list(fb.key_to_rank_lidx.keys())
                for k in keys:
                    t = fb.get_tensor(k)
                    yield k, t
            finally:
                fb.close()
        finally:
            loader.close()
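
# Illustrative usage (assumes a CUDA build and an initialized process group;
# hypothetical `model_dir`; not executed here):
#
#     _, files, _ = get_all_weights_file(model_dir)
#     for name, gpu_tensor in fastsafetensors_weights_iterator(files):
#         ...  # tensors arrive already placed on this rank's device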


def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafetensor: bool = False):
    """
    Load a pre-sharded (per-rank) checkpoint from model_path/rank{local_rank}.
    """
    state_dict = {}
    _, safetensor_files, _ = get_all_weights_file(os.path.join(model_path, f"rank{local_rank}"))
    weights_iterator = safetensors_weights_iterator(safetensor_files)
    for name, weight in weights_iterator:
        state_dict[name] = weight
    return state_dict


def get_all_weights_file(model_path: str):
    """
    Collect weight keys and files under model_path; returns
    (key_name_list, files_list, use_safetensors).
    """
    model_path = Path(model_path)
    use_safetensors = True
    if any(model_path.glob("*.pdparams")):
        key_name_list = []
        files_list = [str(file) for file in model_path.glob("*.pdparams")]
        use_safetensors = False
    else:
        safe_model_path = model_path / "model.safetensors"
        if safe_model_path.exists():
            files_list = [str(safe_model_path)]
            with safe_open(safe_model_path, framework="np", device="cpu") as f:
                key_name_list = f.keys()
            return key_name_list, files_list, use_safetensors
        else:
            index_file = model_path / "model.safetensors.index.json"
            with index_file.open("r") as f:
                weight_map = json.load(f)["weight_map"]
            weight_files_in_index = {str(model_path / weight_map[name]) for name in weight_map}
            key_name_list = list(weight_map.keys())
            files_list = sorted(weight_files_in_index)
    return key_name_list, files_list, use_safetensors
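
# Illustrative return values for a hypothetical sharded safetensors checkpoint
# (not executed here):
#
#     keys, files, use_safetensors = get_all_weights_file("/path/to/model")
#     # keys  -> ["model.embed_tokens.weight", ...]
#     # files -> ["/path/to/model/model-00001-of-00004.safetensors", ...]
#     # use_safetensors -> True (False when only *.pdparams files are present)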


def load_tp_checkpoint_v1(
    model_path: str,
    cls: PretrainedModel,
    fd_config: FDConfig,
    use_fastsafetensor: bool = True,
):
    """
    Load a tensor-parallel checkpoint, applying per-key TP split actions while streaming weights.
    """
    safetensor_keys, safetensor_files, _ = get_all_weights_file(model_path)
    if use_fastsafetensor:
        weights_iterator = fastsafetensors_weights_iterator(safetensor_files)
    else:
        weights_iterator = safetensors_weights_iterator(safetensor_files)

    tensor_parallel_filtered_map = {}
    check_tensor_parallel_prerequisites(
        fd_config,
        cls,
        tensor_parallel_filtered_map,
        safetensor_keys,
    )
    need_tp = bool(tensor_parallel_filtered_map)
    state_dict = {}
    for key, weight in weights_iterator:
        paddle.device.synchronize()
        if need_tp and key in tensor_parallel_filtered_map:
            action = tensor_parallel_filtered_map.pop(key)
            tensor = action(weight).clone()
        else:
            tensor = weight.clone()
        state_dict[key] = tensor
        weight.value().get_tensor()._clear()
    return state_dict


def deal_state_dict(state_dict):
    """Move initialized tensors in state_dict to CUDA pinned memory in place."""
    device = paddle.CUDAPinnedPlace()
    for name, src in state_dict.items():
        if src._is_initialized() and not isinstance(src.place, paddle.CUDAPinnedPlace):
            dst = src._copy_to(device, True)
            dst_tensor = dst.value().get_tensor()
            src_tensor = src.value().get_tensor()
            src_tensor._clear()
            src_tensor._share_data_with(dst_tensor)


def load_composite_checkpoint(
    model_path: str,
    cls: PretrainedModel,
    fd_config: FDConfig,
    return_numpy=True,
):
    """
    Load model weights under one of three parallelism strategies:
    1. Expert Parallel (EP)
    2. Tensor Parallel (TP)
    3. Pre-sharded (pre-split per-rank checkpoints)
    """
    # (TODO: remove in the future)
    if (
        fd_config.parallel_config.use_ep
        and fd_config.speculative_config.model_type != "mtp"
        and fd_config.parallel_config.tensor_parallel_size == 1
    ):
        state_dict = load_ep_checkpoint(model_path, fd_config, return_numpy=True)
    else:
        rank_dirs = [
            f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
        ]
        if len(rank_dirs) > 1:
            if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
                raise ValueError(
                    f"Found {len(rank_dirs)} pre-sharded rank directories; this model can only be "
                    f"loaded with tensor_parallel_size={len(rank_dirs)}."
                )
            state_dict = load_pre_sharded_checkpoint(
                model_path,
                fd_config.parallel_config.tensor_parallel_rank,
                use_fastsafetensor=False,
            )
        else:
            if fd_config.load_config.use_fastsafetensor and (
                current_platform.available() and current_platform.is_cuda()
            ):
                state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True)
                deal_state_dict(state_dict)
            else:
                # NOTE: for very large models, loading on CPU may run out of memory
                state_dict = load_tp_checkpoint(
                    model_path,
                    cls,
                    fd_config.model_config.pretrained_config,
                    return_numpy=return_numpy,
                )
    if not state_dict:
        raise ValueError("No weights were loaded; the resulting state_dict is empty.")
    return state_dict
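
# Illustrative end-to-end usage (hypothetical arguments, for illustration only;
# not executed here):
#
#     state_dict = load_composite_checkpoint(model_dir, cls, fd_config, return_numpy=True)
#     model.set_state_dict(state_dict)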