mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-31 11:56:44 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			125 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			125 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """
 | |
| 
 | |
| import argparse
 | |
| import json
 | |
| import os
 | |
| 
 | |
| import paddle
 | |
| from paddleformers.transformers.model_utils import shard_checkpoint
 | |
| from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
 | |
| from paddleformers.utils.log import logger
 | |
| from safetensors import safe_open
 | |
| from safetensors.numpy import save_file as safe_save_file
 | |
| 
 | |
| 
 | |
def parse_args():
    """
    Parse the command-line options of this script.

    Two mandatory string options are accepted: ``-i/--input_dir`` and
    ``-o/--output_dir``.

    Returns:
        argparse.Namespace: namespace exposing ``input_dir`` and ``output_dir``.
    """
    cli = argparse.ArgumentParser(description="Extract and save MTP weights from safetensors.")

    # (short flag, long flag, help text) for each required directory option.
    directory_options = (
        ("-i", "--input_dir", "Path to the input safetensors model directory."),
        ("-o", "--output_dir", "Path to the output directory for saving processed weights."),
    )
    for short_flag, long_flag, help_text in directory_options:
        cli.add_argument(short_flag, long_flag, type=str, required=True, help=help_text)

    return cli.parse_args()
 | |
| 
 | |
| 
 | |
def extract_mtp_weights(input_dir: str) -> dict:
    """
    Load all MTP-related weights from safetensors files in input_dir.

    The checkpoint index file (``SAFE_WEIGHTS_INDEX_NAME``) is read to find the
    shard files that hold at least one weight whose name contains ``"mtp"``.
    Only those shards are opened, and only the entries whose key contains
    ``"mtp"`` are kept.

    Args:
        input_dir: Directory containing the index file and the safetensors shards.

    Returns:
        dict: Mapping from weight name to the tensor loaded with ``framework="np"``.
        Shards listed in the index but missing on disk are skipped with a warning,
        so the result may be partial.

    Raises:
        FileNotFoundError: If the index file does not exist in ``input_dir``.
    """
    index_path = os.path.join(input_dir, SAFE_WEIGHTS_INDEX_NAME)
    if not os.path.isfile(index_path):
        raise FileNotFoundError(f"Index file not found: {index_path}")

    with open(index_path, "r", encoding="utf-8") as f:
        index = json.load(f)

    weight_map = index.get("weight_map", {})
    # De-duplicate with a set, then sort: iterating a set of strings directly
    # yields a different order from run to run (hash randomization), which would
    # make the key order of the returned dict -- and anything derived from it,
    # such as the output shard layout -- irreproducible.
    required_files = sorted({v for k, v in weight_map.items() if "mtp" in k})
    logger.info(f"Found {len(required_files)} shards with MTP weights.")

    state_dict = {}
    for file_name in required_files:
        file_path = os.path.join(input_dir, file_name)
        if not os.path.isfile(file_path):
            # Best effort: keep going with the shards that are available.
            logger.warning(f"Shard not found: {file_path}")
            continue
        logger.info(f"Loading shard: {file_path}")
        with safe_open(file_path, framework="np", device="cpu") as f:
            for k in f.keys():
                # A shard can also hold non-MTP weights; keep only the MTP ones.
                if "mtp" in k:
                    state_dict[k] = f.get_tensor(k)

    logger.info(f"Loaded {len(state_dict)} MTP weights.")
    return state_dict
 | |
| 
 | |
| 
 | |
def save_safetensors(state_dict: dict, output_dir: str):
    """
    Save state_dict as safetensors shards into output_dir.

    Steps:
      1. Create ``output_dir`` if it does not exist.
      2. Replace every ``paddle.Tensor`` value of ``state_dict`` by its numpy
         array. This is done in place, so the caller's dict is modified.
      3. Split the weights with ``shard_checkpoint`` (max shard size ``"5GB"``).
      4. Write each shard, then the index file ``SAFE_WEIGHTS_INDEX_NAME``.

    Args:
        state_dict: Mapping from weight name to ``paddle.Tensor`` or numpy array.
        output_dir: Destination directory for the shards and the index file.
    """
    os.makedirs(output_dir, exist_ok=True)

    logger.info("Converting tensors to numpy arrays.")
    # Iterate over a snapshot of the keys because the dict is mutated in the loop.
    for k in list(state_dict.keys()):
        if isinstance(state_dict[k], paddle.Tensor):
            # Pop first so the dict no longer references the tensor once the
            # array has been stored.
            tensor = state_dict.pop(k)
            array = tensor.cpu().numpy()
            state_dict[k] = array

    logger.info("Sharding and saving safetensors.")
    shards, index = shard_checkpoint(
        state_dict,
        max_shard_size="5GB",
        weights_name=SAFE_WEIGHTS_NAME,
        shard_format="naive",
    )

    for shard_file, shard in shards.items():
        save_path = os.path.join(output_dir, shard_file)
        logger.info(f"Saving shard: {save_path}")
        safe_save_file(shard, save_path, metadata={"format": "np"})

    if index is None:
        # NOTE(review): shard_checkpoint appears to return None as the index when
        # everything fits in a single shard (as the Hugging Face helper of the
        # same name does) -- confirm against paddleformers. Dumping None would
        # write the literal "null", so rebuild an equivalent index from the
        # shards that were just saved.
        index = {
            "metadata": {
                "total_size": sum(int(array.nbytes) for shard in shards.values() for array in shard.values())
            },
            "weight_map": {name: shard_file for shard_file, shard in shards.items() for name in shard},
        }

    index_path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2)
    logger.info(f"Saved index file: {index_path}")
 | |
| 
 | |
| 
 | |
def main():
    """
    Run the script: parse the command line, load the MTP weights from the input
    directory and write them as safetensors shards into the output directory.
    """
    cli_args = parse_args()
    source_dir, target_dir = cli_args.input_dir, cli_args.output_dir
    logger.info(f"Input dir: {source_dir}")
    logger.info(f"Output dir: {target_dir}")

    # Load first, then save; the loaded weights are handed straight to the writer.
    save_safetensors(extract_mtp_weights(source_dir), target_dir)
    logger.info("MTP weights extracted and saved successfully.")


# Only run the extraction when executed as a script, not when imported.
if __name__ == "__main__":
    main()
 | 
