polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

View File

@@ -26,36 +26,35 @@ from safetensors import safe_open
from tqdm import tqdm
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.models.tp_utils import \
check_tensor_parallel_prerequisites
from fastdeploy.model_executor.models.tp_utils import (
check_tensor_parallel_prerequisites,
)
from fastdeploy.platforms import current_platform
def load_ep_checkpoint(model_path: str,
fd_config: FDConfig,
return_numpy: bool = False):
def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool = False):
"""
load ep checkpoint
"""
with open(os.path.join(model_path, "model.safetensors.index.json"),
"r") as f:
with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
weight_list = json.load(f)["weight_map"]
filtered_map = {k: v for k, v in weight_list.items() if "experts" not in k}
num_local_ffn_keys = []
from itertools import chain
def get_expert_ranges(fd_config):
"""
Generate expert index ranges based on configuration parameters
This function is primarily used in Mixture-of-Experts (MoE) models to generate
expert index ranges according to configuration parameters. When moe_num_experts
is a list in the fd_config, it returns a chained combination of two ranges, otherwise
returns a single range.
Args:
fd_config: FastDeploy Configuration object
Returns:
If moe_num_experts is a list:
Returns a chained combination (chain object) of two ranges:
@@ -66,25 +65,28 @@ def load_ep_checkpoint(model_path: str,
"""
base_range = range(
fd_config.parallel_config.num_experts_start_offset,
fd_config.parallel_config.num_experts_start_offset + fd_config.parallel_config.num_experts_per_rank
fd_config.parallel_config.num_experts_start_offset + fd_config.parallel_config.num_experts_per_rank,
)
if isinstance(fd_config.model_config.moe_num_experts, list):
return chain(base_range,
range(base_range.start + fd_config.model_config.moe_num_experts[0], base_range.stop + fd_config.model_config.moe_num_experts[0]))
return chain(
base_range,
range(
base_range.start + fd_config.model_config.moe_num_experts[0],
base_range.stop + fd_config.model_config.moe_num_experts[0],
),
)
return base_range
for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
for j in get_expert_ranges(fd_config):
up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
down_proj_key = (f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight")
down_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight"
up_gate_proj_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
down_proj_quant_key = (
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight")
down_proj_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight"
up_gate_proj_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
down_proj_scale_key = (
f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale")
down_proj_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale"
num_local_ffn_keys.append(up_gate_proj_key)
num_local_ffn_keys.append(down_proj_key)
num_local_ffn_keys.append(up_gate_proj_quant_key)
@@ -101,31 +103,32 @@ def load_ep_checkpoint(model_path: str,
safetensor_paths = set(filtered_map.values())
# Open each safetensor file sequentially with progress bar
for safetensor_path in tqdm(safetensor_paths,
desc="Loading safetensor files",
unit="file"):
with safe_open(os.path.join(model_path, safetensor_path),
framework="np",
device="cpu") as f:
for safetensor_path in tqdm(safetensor_paths, desc="Loading safetensor files", unit="file"):
with safe_open(
os.path.join(model_path, safetensor_path),
framework="np",
device="cpu",
) as f:
# Check if this file contains keys from filtered_map
for k in filtered_map:
if filtered_map[k] == safetensor_path and k in f.keys():
weight = f.get_tensor(k)
if not return_numpy:
weight = paddle.Tensor(weight, zero_copy=True)
weight = weight._copy_to(
paddle.framework._current_expected_place(), False)
weight = weight._copy_to(paddle.framework._current_expected_place(), False)
state_dict[k] = weight
return state_dict
def safetensors_weights_iterator(safe_tensor_list: list[str], ):
def safetensors_weights_iterator(
safe_tensor_list: list[str],
):
"""
safetensors_weights_iterator
"""
for st_file in tqdm(
safe_tensor_list,
desc="Loading safetensors checkpoint shards",
safe_tensor_list,
desc="Loading safetensors checkpoint shards",
):
with safe_open(st_file, framework="np") as f:
for name in f.keys():
@@ -133,7 +136,9 @@ def safetensors_weights_iterator(safe_tensor_list: list[str], ):
yield name, param
def fastsafetensors_weights_iterator(safetensor_list: list[str], ):
def fastsafetensors_weights_iterator(
safetensor_list: list[str],
):
"""
Return an iterator over tensors on GPU from a given safetensor_list.
"""
@@ -143,23 +148,17 @@ def fastsafetensors_weights_iterator(safetensor_list: list[str], ):
device = f"gpu:{pg.rank}" if paddle.is_compiled_with_cuda() else "cpu"
else:
pg = SingleGroup()
device = f"gpu:{pg.rank()}" if paddle.is_compiled_with_cuda(
) else "cpu"
device = f"gpu:{pg.rank()}" if paddle.is_compiled_with_cuda() else "cpu"
safetensor_files_sub_lists = [
safetensor_list[i:i + world_size]
for i in range(0, len(safetensor_list), world_size)
safetensor_list[i : i + world_size] for i in range(0, len(safetensor_list), world_size)
]
for st_file in tqdm(
safetensor_files_sub_lists,
desc="Loading fastsafetensors checkpoint shards",
safetensor_files_sub_lists,
desc="Loading fastsafetensors checkpoint shards",
):
loader = SafeTensorsFileLoader(pg,
device,
nogds=True,
debug_log=False,
framework="paddle")
loader = SafeTensorsFileLoader(pg, device, nogds=True, debug_log=False, framework="paddle")
rank_file_map = {i: [f] for i, f in enumerate(st_file)}
loader.add_filenames(rank_file_map)
try:
@@ -175,15 +174,12 @@ def fastsafetensors_weights_iterator(safetensor_list: list[str], ):
loader.close()
def load_pre_sharded_checkpoint(model_path: str,
local_rank: int,
use_fastsafetensor: bool = False):
def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafetensor: bool = False):
"""
load_pre_sharded_checkpoint
"""
state_dict = {}
_, safetensor_files = get_all_safetensors(
os.path.join(model_path, f"rank{local_rank}"))
_, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
weights_iterator = safetensors_weights_iterator(safetensor_files)
for name, weight in weights_iterator:
state_dict[name] = weight
@@ -201,13 +197,11 @@ def get_all_safetensors(model_path: str):
key_name_list = f.keys()
return key_name_list, safetensor_list
else:
with open(os.path.join(model_path, "model.safetensors.index.json"),
"r") as f:
with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
weight_map = json.load(f)["weight_map"]
weight_files_in_index = set()
for weight_name in weight_map:
weight_files_in_index.add(
os.path.join(model_path, weight_map[weight_name]))
weight_files_in_index.add(os.path.join(model_path, weight_map[weight_name]))
key_name_list = list(set(weight_map.keys()))
safetensor_list = list(weight_files_in_index)
safetensor_list.sort()
@@ -256,8 +250,7 @@ def deal_state_dict(state_dict):
"""deal_state_dict"""
device = paddle.CUDAPinnedPlace()
for name, src in state_dict.items():
if src._is_initialized() and not isinstance(src.place,
paddle.CUDAPinnedPlace):
if src._is_initialized() and not isinstance(src.place, paddle.CUDAPinnedPlace):
dst = src._copy_to(device, True)
dst_tensor = dst.value().get_tensor()
src_tensor = src.value().get_tensor()
@@ -277,22 +270,15 @@ def load_composite_checkpoint(
# 2. Tensor Parallel (TP)
# 3. Pre-sharded (pre-split)
"""
if fd_config.parallel_config.use_ep and \
fd_config.speculative_config.model_type != "mtp":
state_dict = load_ep_checkpoint(model_path,
fd_config,
return_numpy=True)
if fd_config.parallel_config.use_ep and fd_config.speculative_config.model_type != "mtp":
state_dict = load_ep_checkpoint(model_path, fd_config, return_numpy=True)
else:
rank_dirs = [
f for f in os.listdir(model_path) if f.startswith("rank")
and os.path.isdir(os.path.join(model_path, f))
f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
]
if len(rank_dirs) > 1:
if fd_config.parallel_config.tensor_parallel_size != len(
rank_dirs):
raise ValueError(
f"Your model only supports loading with tp{len(rank_dirs)}"
)
if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
state_dict = load_pre_sharded_checkpoint(
model_path,
fd_config.parallel_config.tensor_parallel_rank,
@@ -300,18 +286,17 @@ def load_composite_checkpoint(
)
else:
if fd_config.load_config.use_fastsafetensor and (
current_platform.available()
and current_platform.is_cuda()):
state_dict = load_tp_checkpoint_v1(model_path,
cls,
fd_config,
use_fastsafetensor=True)
current_platform.available() and current_platform.is_cuda()
):
state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True)
deal_state_dict(state_dict)
else:
state_dict = load_tp_checkpoint(model_path,
cls,
fd_config.model_config.pretrained_config,
return_numpy=return_numpy)
state_dict = load_tp_checkpoint(
model_path,
cls,
fd_config.model_config.pretrained_config,
return_numpy=return_numpy,
)
if not state_dict:
raise ValueError("weight not found in state_dict !")
return state_dict