| """
 | |
| # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """
 | |
| 
 | |
| import json
 | |
| import os
 | |
| import time
 | |
| 
 | |
| import paddle
 | |
| import paddle.distributed as dist
 | |
| from fastsafetensors import SafeTensorsFileLoader, SingleGroup
 | |
| from paddleformers.transformers import PretrainedModel
 | |
| from paddleformers.transformers.model_utils import load_tp_checkpoint
 | |
| from paddleformers.utils.log import logger
 | |
| from paddleformers.utils.safetensors import fast_safe_open
 | |
| from safetensors import safe_open
 | |
| from tqdm import tqdm
 | |
| 
 | |
| from fastdeploy.config import FDConfig
 | |
| from fastdeploy.model_executor.models.tp_utils import (
 | |
|     check_tensor_parallel_prerequisites,
 | |
| )
 | |
| from fastdeploy.platforms import current_platform
 | |
| 
 | |
| 
 | |
def measure_time(func):
    """Decorator that logs the wall-clock time taken by the wrapped model-loading call."""

    def wrapper(*args, **kwargs):
        time_before_load = time.time()
        result = func(*args, **kwargs)
        time_after_load = time.time()
        logger.info(f"Model loading took {time_after_load - time_before_load} seconds")
        return result

    return wrapper


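# Illustrative usage sketch (not part of the original module; the function name
# below is hypothetical): `measure_time` can wrap any loading entry point so
# that its duration is logged, e.g.
#
#     @measure_time
#     def load_model_weights(model_path):
#         ...
#
# The decorated call then logs "Model loading took ... seconds" via `logger`.

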
def load_reordered_experts(model_path: str, key_name: str):
    """
    Load the single expert weight identified by key_name from the checkpoint at
    model_path and place it on the current device. Returns None if the key is
    not present in its safetensor file.
    """
    with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
        weight_list = json.load(f)["weight_map"]
    safetensor_path = os.path.join(model_path, weight_list[key_name])
    with safe_open(safetensor_path, framework="np", device="cpu") as f:
        if key_name in f.keys():
            weight = f.get_tensor(key_name)
            weight = paddle.Tensor(weight, zero_copy=True)
            weight = weight._copy_to(paddle.framework._current_expected_place(), False)
            return weight


def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfig, return_numpy: bool = False):
    """
    Load a checkpoint for expert-parallel (EP) execution, keeping all non-expert
    weights plus only the expert weights owned by the current rank.
    """
    with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
        weight_list = json.load(f)["weight_map"]
    filtered_map = {k: v for k, v in weight_list.items() if ".experts." not in k}
    num_local_ffn_keys = []

    from itertools import chain

    def get_expert_ranges(fd_config):
        """
        Generate expert index ranges based on configuration parameters

        This function is primarily used in Mixture-of-Experts (MoE) models to generate
        expert index ranges according to configuration parameters. When moe_num_experts
        is a list in the fd_config, it returns a chained combination of two ranges, otherwise
        returns a single range.

        Args:
            fd_config: FastDeploy Configuration object

        Returns:
            If moe_num_experts is a list:
                Returns a chained combination (chain object) of two ranges:
                    1. Base range: [num_experts_start_offset, num_experts_start_offset + num_experts_per_rank)
                    2. Offset range: [base_range.start + moe_num_experts[0], base_range.stop + moe_num_experts[0])
            Else:
                Returns single range: [num_experts_start_offset, num_experts_start_offset + num_experts_per_rank)
        """
        base_range = range(
            fd_config.parallel_config.num_experts_start_offset,
            fd_config.parallel_config.num_experts_start_offset + fd_config.parallel_config.num_experts_per_rank,
        )
        if isinstance(fd_config.model_config.moe_num_experts, list):
            return chain(
                base_range,
                range(
                    base_range.start + fd_config.model_config.moe_num_experts[0],
                    base_range.stop + fd_config.model_config.moe_num_experts[0],
                ),
            )
        return base_range

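    # Worked example (hypothetical values, for illustration only): with
    # num_experts_start_offset=64, num_experts_per_rank=32 and
    # moe_num_experts=[128, 1], get_expert_ranges yields experts 64..95 from the
    # base range followed by 192..223 from the offset range (base shifted by 128).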
    for i in range(fd_config.model_config.moe_layer_start_index, fd_config.model_config.num_hidden_layers):
        for j in get_expert_ranges(fd_config):
            up_gate_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight"
            down_proj_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight"

            up_gate_proj_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.quant_weight"
            down_proj_quant_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.quant_weight"

            up_gate_proj_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.weight_scale"
            down_proj_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.weight_scale"

            down_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.down_proj.activation_scale"
            num_local_ffn_keys.append(up_gate_proj_key)
            num_local_ffn_keys.append(down_proj_key)
            num_local_ffn_keys.append(up_gate_proj_quant_key)
            num_local_ffn_keys.append(down_proj_quant_key)
            num_local_ffn_keys.append(up_gate_proj_scale_key)
            num_local_ffn_keys.append(down_proj_scale_key)
            num_local_ffn_keys.append(down_proj_in_scale_key)

        # for EP w4a8, we need every expert's activation_scale for up_gate_proj
        num_experts = fd_config.model_config.moe_num_experts
        if isinstance(num_experts, list):
            num_experts = num_experts[0]

        for j in range(num_experts):
            up_gate_proj_in_scale_key = f"ernie.layers.{i}.mlp.experts.{j}.up_gate_proj.activation_scale"
            num_local_ffn_keys.append(up_gate_proj_in_scale_key)

    for k in num_local_ffn_keys:
        if k in weight_list:
            filtered_map[k] = weight_list[k]

    if fd_config.parallel_config.tensor_parallel_size > 1:
        tp_actions = cls._get_tensor_parallel_mappings(fd_config.model_config.pretrained_config)
        new_actions = {k: v for k, v in tp_actions.items() if k not in num_local_ffn_keys}

    state_dict = {}
    # Get all safetensor file paths that need to be opened
    safetensor_paths = set(filtered_map.values())

    # Open each safetensor file sequentially with progress bar
    for safetensor_path in tqdm(safetensor_paths, desc="Loading safetensor files", unit="file"):
        with safe_open(
            os.path.join(model_path, safetensor_path),
            framework="np",
            device="cpu",
        ) as f:
            # Check if this file contains keys from filtered_map
            for k in filtered_map:
                if filtered_map[k] == safetensor_path and k in f.keys():
                    weight = f.get_tensor(k)
                    if fd_config.parallel_config.tensor_parallel_size > 1:
                        if k in new_actions:
                            weight = new_actions[k](weight)
                    if not return_numpy:
                        weight = paddle.Tensor(weight, zero_copy=True)
                        weight = weight._copy_to(paddle.framework._current_expected_place(), False)
                    state_dict[k] = weight
    return state_dict


def safetensors_weights_iterator(safe_tensor_list: list[str]):
    """
    Iterate over the given safetensor files, yielding (name, tensor) pairs with
    tensors materialized as numpy arrays.
    """
    for st_file in tqdm(
        safe_tensor_list,
        desc="Loading safetensors checkpoint shards",
    ):
        with safe_open(st_file, framework="np") as f:
            for name in f.keys():
                param = f.get_tensor(name)
                yield name, param


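# Illustrative usage sketch (not original code; variable names are placeholders):
# the iterator can be consumed directly to build an in-memory state dict of
# numpy arrays, e.g.
#
#     state_dict = dict(safetensors_weights_iterator(safetensor_files))

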
def fast_weights_iterator(safe_tensor_list: list[str]):
    """
    Iterate over the given safetensor files using paddleformers' fast_safe_open,
    yielding (name, slice) pairs where each slice is a lazy handle rather than a
    fully materialized tensor.
    """
    for st_file in tqdm(
        safe_tensor_list,
        desc="Loading safetensors checkpoint shards",
    ):
        with fast_safe_open(st_file, framework="np") as f:
            for name in f.keys():
                param_slice = f.get_slice(name)
                yield name, param_slice


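# Note (descriptive comment, added for clarity): unlike safetensors_weights_iterator
# above, fast_weights_iterator yields slice handles, so callers can read just the
# sub-range of each tensor they need (e.g. one tensor-parallel shard) instead of
# materializing the full tensor.

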
def fastsafetensors_weights_iterator(
    safetensor_list: list[str],
):
    """
    Return an iterator over tensors on GPU from a given safetensor_list.
    """
    world_size = dist.get_world_size()
    if world_size > 1:
        pg = dist.get_group()
        device = f"gpu:{pg.rank}" if paddle.is_compiled_with_cuda() else "cpu"
    else:
        pg = SingleGroup()
        device = f"gpu:{pg.rank()}" if paddle.is_compiled_with_cuda() else "cpu"

    safetensor_files_sub_lists = [
        safetensor_list[i : i + world_size] for i in range(0, len(safetensor_list), world_size)
    ]

    for st_file in tqdm(
        safetensor_files_sub_lists,
        desc="Loading fastsafetensors checkpoint shards",
    ):
        loader = SafeTensorsFileLoader(pg, device, nogds=True, debug_log=False, framework="paddle")
        rank_file_map = {i: [f] for i, f in enumerate(st_file)}
        loader.add_filenames(rank_file_map)
        try:
            fb = loader.copy_files_to_device()
            try:
                keys = list(fb.key_to_rank_lidx.keys())
                for k in keys:
                    t = fb.get_tensor(k)
                    yield k, t
            finally:
                fb.close()
        finally:
            loader.close()


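# Note (descriptive comment, added for clarity): the file list is split into
# groups of world_size, so each pass of the loop assigns at most one file per
# rank, copies those files to the target device, yields their tensors, and then
# releases the buffers before moving on to the next group.

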
def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafetensor: bool = False):
    """
    Load a checkpoint that has already been sharded per tensor-parallel rank,
    reading only the safetensor files under the rank{local_rank} subdirectory.
    """

    state_dict = {}
    _, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
    weights_iterator = safetensors_weights_iterator(safetensor_files)
    for name, weight in weights_iterator:
        state_dict[name] = weight
    return state_dict


def get_all_safetensors(model_path: str):
    """
    Return the weight key names and the list of safetensor file paths for the
    checkpoint at model_path, handling both single-file checkpoints
    (model.safetensors) and indexed checkpoints (model.safetensors.index.json).
    """
    safe_model_path = os.path.join(model_path, "model.safetensors")
    if os.path.exists(safe_model_path):
        safetensor_list = [safe_model_path]
        with safe_open(safe_model_path, framework="np", device="cpu") as f:
            key_name_list = f.keys()
        return key_name_list, safetensor_list
    else:
        with open(os.path.join(model_path, "model.safetensors.index.json"), "r") as f:
            weight_map = json.load(f)["weight_map"]
        weight_files_in_index = set()
        for weight_name in weight_map:
            weight_files_in_index.add(os.path.join(model_path, weight_map[weight_name]))
        key_name_list = list(set(weight_map.keys()))
        safetensor_list = list(weight_files_in_index)
        safetensor_list.sort()
    return key_name_list, safetensor_list


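# Illustrative usage sketch (not original code; the path is a placeholder):
#
#     keys, files = get_all_safetensors("/path/to/model_dir")
#     for name, weight in safetensors_weights_iterator(files):
#         ...

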
def load_tp_checkpoint_v1(
    model_path: str,
    cls: PretrainedModel,
    fd_config: FDConfig,
    use_fastsafetensor: bool = True,
):
    """
    Load a checkpoint and apply the tensor-parallel split actions defined by cls,
    streaming weights through either the fastsafetensors iterator or the plain
    safetensors iterator.
    """

    safetensor_keys, safetensor_files = get_all_safetensors(model_path)

    if use_fastsafetensor:
        weights_iterator = fastsafetensors_weights_iterator(safetensor_files)
    else:
        weights_iterator = safetensors_weights_iterator(safetensor_files)

    tensor_parallel_filtered_map = {}
    check_tensor_parallel_prerequisites(
        fd_config,
        cls,
        tensor_parallel_filtered_map,
        safetensor_keys,
    )
    need_tp = bool(tensor_parallel_filtered_map)
    state_dict = {}
    for key, weight in weights_iterator:
        paddle.device.synchronize()
        if need_tp and key in tensor_parallel_filtered_map:
            action = tensor_parallel_filtered_map.pop(key)
            tensor = action(weight).clone()
        else:
            tensor = weight.clone()
        state_dict[key] = tensor
        # Release the source buffer once the (possibly split) copy is stored.
        weight.value().get_tensor()._clear()
    return state_dict


def deal_state_dict(state_dict):
    """Move every initialized tensor in state_dict to CUDA pinned memory in place."""
    device = paddle.CUDAPinnedPlace()
    for name, src in state_dict.items():
        if src._is_initialized() and not isinstance(src.place, paddle.CUDAPinnedPlace):
            dst = src._copy_to(device, True)
            dst_tensor = dst.value().get_tensor()
            src_tensor = src.value().get_tensor()
            src_tensor._clear()
            src_tensor._share_data_with(dst_tensor)


def load_composite_checkpoint(
    model_path: str,
    cls: PretrainedModel,
    fd_config: FDConfig,
    return_numpy=True,
):
    """
    Load model weights under one of three parallelism strategies:
    1. Expert Parallel (EP)
    2. Tensor Parallel (TP)
    3. Pre-sharded (pre-split per rank)
    """
    # (TODO: remove in the future)
    if fd_config.parallel_config.use_ep and fd_config.speculative_config.model_type != "mtp":
        state_dict = load_ep_checkpoint(cls, model_path, fd_config, return_numpy=True)
    else:
        rank_dirs = [
            f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
        ]
        if len(rank_dirs) > 1:
            if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
                raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
            state_dict = load_pre_sharded_checkpoint(
                model_path,
                fd_config.parallel_config.tensor_parallel_rank,
                use_fastsafetensor=False,
            )
        else:
            if fd_config.load_config.use_fastsafetensor and (
                current_platform.available() and current_platform.is_cuda()
            ):
                state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True)
                deal_state_dict(state_dict)
            else:
                # NOTE: for very large models, loading this way may exhaust CPU memory
                state_dict = load_tp_checkpoint(
                    model_path,
                    cls,
                    fd_config.model_config.pretrained_config,
                    return_numpy=return_numpy,
                )
    if not state_dict:
        raise ValueError("weight not found in state_dict!")
    return state_dict
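

# Illustrative call sketch (hypothetical, not from the original file): a model
# class would typically pass itself plus the runtime config, then hand the
# resulting dict to the paddle layer, e.g.
#
#     state_dict = load_composite_checkpoint(model_path, MyPretrainedModel, fd_config)
#     model.set_state_dict(state_dict)
#
# where MyPretrainedModel and model are placeholders for the concrete model
# class and its instantiated paddle.nn.Layer.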
