FastDeploy/fastdeploy/model_executor/load_weight_utils.py

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
import os
import time

import paddle
import paddle.distributed as dist
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
from paddleformers.transformers import PretrainedModel
from paddleformers.transformers.model_utils import load_tp_checkpoint
from paddleformers.utils.log import logger
from paddleformers.utils.safetensors import fast_safe_open
from safetensors import safe_open
from tqdm import tqdm

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.models.tp_utils import (
    check_tensor_parallel_prerequisites,
)
from fastdeploy.platforms import current_platform


def measure_time(func):
    """Decorator that logs how long the wrapped (loading) function takes."""

    def wrapper(*args, **kwargs):
        time_before_load = time.time()
        result = func(*args, **kwargs)
        time_after_load = time.time()
        logger.info(f"Model loading took {time_after_load - time_before_load} seconds")
        return result

    return wrapper

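# Illustrative usage sketch (the loader name below is hypothetical): any loading
# routine can be timed by applying the decorator; its return value is passed
# through unchanged and the elapsed wall-clock time is written to the logger.
#
#     @measure_time
#     def load_my_checkpoint(path):
#         ...
#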

def load_reordered_experts(model_path: str, key_name: str):
    """
    Load a single expert weight tensor named ``key_name`` from the safetensors
    shard recorded in ``model.safetensors.index.json`` under ``model_path`` and
    copy it to the current device.
    """
    from safetensors import safe_open

    with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
        weight_list = json.load(f)["weight_map"]
    safetensor_path = os.path.join(model_path, weight_list[key_name])

    with safe_open(safetensor_path, framework="np", device="cpu") as f:
        if key_name in f.keys():
            weight = f.get_tensor(key_name)
            weight = paddle.Tensor(weight, zero_copy=True)
            weight = weight._copy_to(paddle.framework._current_expected_place(), False)
    return weight

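# Illustrative usage sketch (the key below is hypothetical, following the
# "ernie.layers.{i}.mlp.experts.{j}..." naming used elsewhere in this module):
#
#     w = load_reordered_experts(
#         "/path/to/model",
#         "ernie.layers.4.mlp.experts.7.down_proj.weight",
#     )
#
# The key must appear both in the index file and in the shard it points to.
#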

def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool = False):
    """
    Load a checkpoint for expert-parallel (EP) execution: keep every non-expert
    weight, plus only the expert weights (and their quantization scales) whose
    indices fall in this rank's expert range.
    """
    with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
        weight_list = json.load(f)["weight_map"]
    filtered_map = {k: v for k, v in weight_list.items() if ".experts." not in k}
    num_local_ffn_keys = []

    from itertools import chain

    def get_expert_ranges(fd_config):
        """
        Generate expert index ranges based on configuration parameters.

        This function is primarily used in Mixture-of-Experts (MoE) models to
        generate expert index ranges according to configuration parameters. When
        moe_num_experts in fd_config is a list, it returns a chained combination
        of two ranges; otherwise it returns a single range.

        Args:
            fd_config: FastDeploy configuration object

        Returns:
            If moe_num_experts is a list:
                A chained combination (chain object) of two ranges:
                1. Base range: [num_experts_start_offset, num_experts_start_offset + num_experts_per_rank)
                2. Offset range: [base_range.start + moe_num_experts[0], base_range.stop + moe_num_experts[0])
            Else:
                A single range: [num_experts_start_offset, num_experts_start_offset + num_experts_per_rank)
        """
        base_range = range(
            fd_config.parallel_config.num_experts_start_offset,
            fd_config.parallel_config.num_experts_start_offset + fd_config.parallel_config.num_experts_per_rank,
        )
        if isinstance(fd_config.model_config.moe_num_experts, list):
            return chain(
                base_range,
                range(
                    base_range.start + fd_config.model_config.moe_num_experts[0],
                    base_range.stop + fd_config.model_config.moe_num_experts[0],
                ),
            )
        return base_range

    for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
        for j in get_expert_ranges(fd_config):
            up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
            down_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight"
            up_gate_proj_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
            down_proj_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight"
            up_gate_proj_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
            down_proj_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale"
            down_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.activation_scale"
            num_local_ffn_keys.append(up_gate_proj_key)
            num_local_ffn_keys.append(down_proj_key)
            num_local_ffn_keys.append(up_gate_proj_quant_key)
            num_local_ffn_keys.append(down_proj_quant_key)
            num_local_ffn_keys.append(up_gate_proj_scale_key)
            num_local_ffn_keys.append(down_proj_scale_key)
            num_local_ffn_keys.append(down_proj_in_scale_key)

        # For EP w4a8, every expert's activation_scale for up_gate_proj is needed.
        num_experts = fd_config.model_config.moe_num_experts
        if isinstance(num_experts, list):
            num_experts = num_experts[0]
        for j in range(num_experts):
            up_gate_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
            num_local_ffn_keys.append(up_gate_proj_in_scale_key)

    for k in num_local_ffn_keys:
        if k in weight_list:
            filtered_map[k] = weight_list[k]

    state_dict = {}
    # Gather the set of safetensor shard files that must be opened.
    safetensor_paths = set(filtered_map.values())

    # Open each safetensor file sequentially with a progress bar.
    for safetensor_path in tqdm(safetensor_paths, desc="Loading safetensor files", unit="file"):
        with safe_open(
            os.path.join(model_path, safetensor_path),
            framework="np",
            device="cpu",
        ) as f:
            # Only pick up keys from filtered_map that live in this shard.
            for k in filtered_map:
                if filtered_map[k] == safetensor_path and k in f.keys():
                    weight = f.get_tensor(k)
                    if not return_numpy:
                        weight = paddle.Tensor(weight, zero_copy=True)
                        weight = weight._copy_to(paddle.framework._current_expected_place(), False)
                    state_dict[k] = weight
    return state_dict

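# Worked example of the expert-range selection above (numbers are hypothetical):
# with num_experts_start_offset=16, num_experts_per_rank=8 and
# moe_num_experts=[64, 64], get_expert_ranges yields indices 16..23 (this rank's
# slice of the first expert group) followed by 80..87 (the same slice shifted by
# moe_num_experts[0]=64 into the second group). With a scalar moe_num_experts it
# would yield only 16..23.
#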

def safetensors_weights_iterator(safe_tensor_list: list[str]):
    """
    Iterate over the tensors stored in a list of safetensors files, yielding
    (name, numpy array) pairs one shard at a time.
    """
    for st_file in tqdm(
        safe_tensor_list,
        desc="Loading safetensors checkpoint shards",
    ):
        with safe_open(st_file, framework="np") as f:
            for name in f.keys():
                param = f.get_tensor(name)
                yield name, param

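# Illustrative usage sketch (the file name is hypothetical): the iterator is
# typically consumed into a state dict,
#
#     state_dict = dict(safetensors_weights_iterator(["model-00001-of-00004.safetensors"]))
#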

def fast_weights_iterator(safe_tensor_list: list[str]):
    """
    paddleformers' fast iterator over safetensors files: yields
    (name, paddle tensor) pairs via fast_safe_open slices converted by get_tensor.
    """
    for st_file in tqdm(
        safe_tensor_list,
        desc="Loading safetensors checkpoint shards",
    ):
        with fast_safe_open(st_file, framework="np") as f:
            for name in f.keys():
                param_slice = f.get_slice(name)
                paddle_tensor = get_tensor(param_slice)
                yield name, paddle_tensor

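# Compared with safetensors_weights_iterator above, this variant goes through
# paddleformers' fast_safe_open / get_slice path and yields paddle tensors
# (already converted by get_tensor) rather than raw numpy arrays.
#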

def fastsafetensors_weights_iterator(
    safetensor_list: list[str],
):
    """
    Return an iterator over tensors on GPU from a given safetensor_list.
    """
    world_size = dist.get_world_size()
    if world_size > 1:
        pg = dist.get_group()
        device = f"gpu:{pg.rank}" if paddle.is_compiled_with_cuda() else "cpu"
    else:
        pg = SingleGroup()
        device = f"gpu:{pg.rank()}" if paddle.is_compiled_with_cuda() else "cpu"

    safetensor_files_sub_lists = [
        safetensor_list[i : i + world_size] for i in range(0, len(safetensor_list), world_size)
    ]

    for st_file in tqdm(
        safetensor_files_sub_lists,
        desc="Loading fastsafetensors checkpoint shards",
    ):
        loader = SafeTensorsFileLoader(pg, device, nogds=True, debug_log=False, framework="paddle")
        rank_file_map = {i: [f] for i, f in enumerate(st_file)}
        loader.add_filenames(rank_file_map)
        try:
            fb = loader.copy_files_to_device()
            try:
                keys = list(fb.key_to_rank_lidx.keys())
                for k in keys:
                    t = fb.get_tensor(k)
                    yield k, t
            finally:
                fb.close()
        finally:
            loader.close()

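# Sketch of how files are distributed above (numbers are hypothetical): with
# world_size=2 and four shard files [a, b, c, d], the sub-lists become
# [[a, b], [c, d]]; in each round rank 0 is assigned the first file and rank 1
# the second via rank_file_map, and fastsafetensors then exposes every tensor
# in those files through fb.get_tensor(key).
#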

def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafetensor: bool = False):
    """
    Load a pre-sharded (pre-split) checkpoint: weights are read from the
    rank{local_rank} sub-directory of model_path.
    """
    state_dict = {}
    _, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
    weights_iterator = safetensors_weights_iterator(safetensor_files)
    for name, weight in weights_iterator:
        state_dict[name] = weight
    return state_dict

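# Expected on-disk layout for the pre-sharded path (directory names follow the
# f"rank{local_rank}" pattern used above; file names are illustrative):
#
#     model_path/
#         rank0/model.safetensors   (or model.safetensors.index.json + shards)
#         rank1/model.safetensors
#         ...
#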

def get_all_safetensors(model_path: str):
    """
    Return (key_name_list, safetensor_list) for a model directory, handling both
    a single model.safetensors file and an indexed multi-shard checkpoint.
    """
    safe_model_path = os.path.join(model_path, "model.safetensors")
    if os.path.exists(safe_model_path):
        safetensor_list = [safe_model_path]
        with safe_open(safe_model_path, framework="np", device="cpu") as f:
            key_name_list = f.keys()
        return key_name_list, safetensor_list
    else:
        with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
            weight_map = json.load(f)["weight_map"]
        weight_files_in_index = set()
        for weight_name in weight_map:
            weight_files_in_index.add(os.path.join(model_path, weight_map[weight_name]))
        key_name_list = list(set(weight_map.keys()))
        safetensor_list = list(weight_files_in_index)
        safetensor_list.sort()
        return key_name_list, safetensor_list

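# Illustrative usage sketch (the path is hypothetical):
#
#     keys, files = get_all_safetensors("/path/to/model")
#     for name, tensor in safetensors_weights_iterator(files):
#         ...
#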

def load_tp_checkpoint_v1(
    model_path: str,
    cls: PretrainedModel,
    fd_config: FDConfig,
    use_fastsafetensor: bool = True,
):
    """
    Load a tensor-parallel (TP) checkpoint: iterate over all safetensors shards
    and apply the per-key tensor-parallel split actions collected by
    check_tensor_parallel_prerequisites.
    """
    safetensor_keys, safetensor_files = get_all_safetensors(model_path)

    if use_fastsafetensor:
        weights_iterator = fastsafetensors_weights_iterator(safetensor_files)
    else:
        weights_iterator = safetensors_weights_iterator(safetensor_files)

    tensor_parallel_filtered_map = {}
    check_tensor_parallel_prerequisites(
        fd_config,
        cls,
        tensor_parallel_filtered_map,
        safetensor_keys,
    )
    need_tp = True if tensor_parallel_filtered_map else False

    state_dict = {}
    for key, weight in weights_iterator:
        paddle.device.synchronize()
        if need_tp and key in tensor_parallel_filtered_map:
            action = tensor_parallel_filtered_map.pop(key)
            tensor = action(weight).clone()
        else:
            tensor = weight.clone()
        state_dict[key] = tensor
        weight.value().get_tensor()._clear()
    return state_dict

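# Note on the loading loop above: each yielded weight is cloned into state_dict
# and the source tensor's storage is then released via ._clear(), so buffers
# produced by the iterator do not accumulate alongside the copied weights.
#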

def deal_state_dict(state_dict):
    """Move every initialized tensor in state_dict into CUDA pinned memory in place."""
    device = paddle.CUDAPinnedPlace()
    for name, src in state_dict.items():
        if src._is_initialized() and not isinstance(src.place, paddle.CUDAPinnedPlace):
            dst = src._copy_to(device, True)
            dst_tensor = dst.value().get_tensor()
            src_tensor = src.value().get_tensor()
            src_tensor._clear()
            src_tensor._share_data_with(dst_tensor)

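# The clear-then-share sequence above keeps the original Tensor objects in
# state_dict valid (callers keep their references) while their storage now
# lives in CUDA pinned memory, which allows faster host-to-device transfers.
#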

def load_composite_checkpoint(
    model_path: str,
    cls: PretrainedModel,
    fd_config: FDConfig,
    return_numpy=True,
):
    """
    This method supports loading model weights under three parallelism strategies:
    1. Expert Parallel (EP)
    2. Tensor Parallel (TP)
    3. Pre-sharded (pre-split)
    """
    rank_dirs = [
        f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
    ]
    if len(rank_dirs) > 1:
        if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
            raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
        state_dict = load_pre_sharded_checkpoint(
            model_path,
            fd_config.parallel_config.tensor_parallel_rank,
            use_fastsafetensor=False,
        )
    else:
        if fd_config.load_config.use_fastsafetensor and (current_platform.available() and current_platform.is_cuda()):
            state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True)
            deal_state_dict(state_dict)
        else:
            state_dict = load_tp_checkpoint(
                model_path,
                cls,
                fd_config.model_config.pretrained_config,
                return_numpy=return_numpy,
            )
    if not state_dict:
        raise ValueError("Weights not found in state_dict!")
    return state_dict

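# Illustrative usage sketch (the path and model class are hypothetical; the
# config object is assumed to come from the caller):
#
#     state_dict = load_composite_checkpoint(
#         "/path/to/model",
#         Ernie4_5_MoeForCausalLM,
#         fd_config,
#     )
#     model.set_state_dict(state_dict)
#
# Dispatch order: rank* sub-directories trigger pre-sharded loading; otherwise
# fastsafetensors-based TP loading is used on CUDA when
# load_config.use_fastsafetensor is set, else paddleformers' load_tp_checkpoint.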