mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-31 11:56:44 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			179 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			179 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import argparse
 | |
| import json
 | |
| import os
 | |
| import re
 | |
| import time
 | |
| 
 | |
| import paddle
 | |
| from paddleformers.trainer import strtobool
 | |
| from paddleformers.transformers.configuration_utils import PretrainedConfig
 | |
| from paddleformers.transformers.model_utils import shard_checkpoint
 | |
| from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
 | |
| from paddleformers.utils.log import logger
 | |
| from safetensors.numpy import save_file as safe_save_file
 | |
| 
 | |
| from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
 | |
| from fastdeploy.model_executor.layers.utils import get_tensor
 | |
| from fastdeploy.model_executor.load_weight_utils import (
 | |
|     get_all_safetensors,
 | |
|     safetensors_weights_iterator,
 | |
| )
 | |
| 
 | |
| 
 | |
def parse_arguments():
    """
    Parse the command-line options of the W4A8 weight-conversion script.

    Options:
        --model_name_or_path: directory of the source model. Required.
        --output_dir: directory the converted model is written to.
            Defaults to ``merged_output``.
        --safe_serialization: parsed with ``strtobool``; defaults to true.
            ``main`` writes sharded safetensors when it is true and a single
            ``model_state.pdparams`` file otherwise.

    Returns:
        argparse.Namespace holding the options above.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name_or_path",
        required=True,
        help="The directory of model.",
    )

    # Previously declared with ``required=True`` as well, which made the
    # ``merged_output`` default unreachable. It is now optional so the
    # default actually applies; explicit invocations behave as before.
    parser.add_argument(
        "--output_dir",
        default="merged_output",
        help="The directory of merged model output.",
    )

    # argparse runs string defaults through ``type``, so "True" is converted
    # by strtobool just like a value given on the command line.
    parser.add_argument(
        "--safe_serialization",
        type=strtobool,
        default="True",
        help="Whether merge the model into safetensors format.",
    )

    return parser.parse_args()
 | |
| 
 | |
| 
 | |
def reorder():
    """
    Build the converter applied to MoE expert ``quant_weight`` tensors.

    Returns:
        A callable ``fn(weight)`` that does three things:

        * copies ``weight`` to the GPU,
        * runs ``paddle.nn.quant.weight_quantize`` with ``algo="w4a8"`` and
          ``arch=80``,
        * returns the first output on the CPU, discarding the second output.
    """

    def _quantize_to_w4a8(weight):
        # The import lives inside the closure, so building the converter
        # does not import paddle.nn.quant; only calling it does.
        from paddle.nn.quant import weight_quantize

        gpu_weight = weight.cuda()
        quantized, _ = weight_quantize(gpu_weight, algo="w4a8", arch=80)
        return quantized.cpu()

    return _quantize_to_w4a8
 | |
| 
 | |
| 
 | |
def deal_in_scale():
    """
    Build the converter applied to MoE expert ``activation_scale`` entries.

    Returns:
        A callable mapping ``in_scale`` to its reciprocal, ``1 / in_scale``.
    """
    return lambda in_scale: 1 / in_scale
 | |
| 
 | |
| 
 | |
def deal_weight_scale():
    """
    Build the converter applied to MoE expert ``weight_scale`` entries.

    Returns:
        A callable ``fn(weight_scale, processed_in_scale)`` returning
        ``weight_scale / (127 * 112) / processed_in_scale``.
        ``processed_in_scale`` is the already-converted activation scale,
        i.e. the output of the ``deal_in_scale`` converter.
    """
    # NOTE(review): the meaning of the 127 * 112 factor is not documented
    # in this file (presumably quantization range constants) -- confirm.
    divisor = 127 * 112

    def _rescale(weight_scale, processed_in_scale):
        return weight_scale / divisor / processed_in_scale

    return _rescale
 | |
| 
 | |
| 
 | |
# Temporary: only W4A8 MoE expert tensors are converted here.
def deal_quant(state_dict, save_state_dict):
    """
    Move and convert the W4A8 MoE expert tensors of ``state_dict``.

    Every key matching ``layers.<i>.mlp.experts.<j>.<proj>.activation_scale``,
    ``...weight_scale`` or ``...quant_weight`` goes through three steps:

    1. It is popped from ``state_dict``.
    2. It is passed through the matching converter.
    3. The result is stored under the same key in ``save_state_dict``.

    Keys that match none of the patterns stay in ``state_dict`` untouched.

    The passes run in a fixed order. Activation scales are converted first,
    because each weight scale is combined with the already-converted
    activation scale of the same projection, read back from
    ``save_state_dict``.

    Args:
        state_dict: source mapping; matched entries are removed from it.
        save_state_dict: destination mapping; converted entries are added.

    Raises:
        KeyError: if a ``weight_scale`` entry has no converted
            ``activation_scale`` counterpart.
    """
    # (pattern, converter, needs_in_scale). Whether the converter also takes
    # the activation scale is tied to the pattern that matched, not
    # re-derived from the key text. A key whose prefix merely contains
    # "weight_scale" therefore cannot select the two-argument branch.
    w4a8_mapping = [
        (
            re.compile(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.activation_scale"),
            deal_in_scale(),
            False,
        ),
        (
            re.compile(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.weight_scale"),
            deal_weight_scale(),
            True,
        ),
        (
            re.compile(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.quant_weight"),
            reorder(),
            False,
        ),
    ]
    for pattern, fn, needs_in_scale in w4a8_mapping:
        # Snapshot the keys: matched entries are popped during the loop.
        for key in list(state_dict):
            if not pattern.search(key):
                continue
            if needs_in_scale:
                in_scale_key = key.replace("weight_scale", "activation_scale")
                if in_scale_key not in save_state_dict:
                    raise KeyError(
                        f"{key}: missing converted activation scale {in_scale_key!r}"
                    )
                save_state_dict[key] = fn(state_dict.pop(key), save_state_dict[in_scale_key])
            else:
                save_state_dict[key] = fn(state_dict.pop(key))
 | |
| 
 | |
| 
 | |
def save_safetensors(state_dict, args):
    """
    Write ``state_dict`` to ``args.output_dir`` as sharded safetensors files.

    Steps:

    1. Every ``paddle.Tensor`` value is replaced in place by a numpy array
       (the caller's dict is modified).
    2. The dict is split into shards of at most 5GB by ``shard_checkpoint``.
    3. Each shard is saved with ``{"format": "np"}`` metadata.
    4. The shard index is written as JSON to ``SAFE_WEIGHTS_INDEX_NAME``.

    Args:
        state_dict: mapping of parameter name to tensor/array; mutated.
        args: parsed CLI namespace; only ``args.output_dir`` is read.
    """
    logger.info("Move to numpy.")
    for name in list(state_dict):
        if not isinstance(state_dict[name], paddle.Tensor):
            continue
        # pop + re-insert: each converted entry moves to the end of the dict.
        state_dict[name] = state_dict.pop(name).cpu().numpy()

    logger.info("Save safetensors files.")
    shards, index = shard_checkpoint(
        state_dict,
        max_shard_size="5GB",
        weights_name=SAFE_WEIGHTS_NAME,
        shard_format="naive",
    )
    output_dir = args.output_dir
    for shard_name, shard_tensors in shards.items():
        shard_path = os.path.join(output_dir, shard_name)
        logger.info(f"Saving {shard_path}")
        safe_save_file(shard_tensors, shard_path, metadata={"format": "np"})

    index_path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
    with open(index_path, "w", encoding="utf-8") as index_file:
        index_file.write(json.dumps(index, indent=2) + "\n")
 | |
| 
 | |
| 
 | |
def main():
    """
    Convert a W4A8 MoE checkpoint and write the result to ``args.output_dir``.

    Steps:
      1. Load the model config and ``ErnieBotTokenizer`` from
         ``args.model_name_or_path``. Before loading, the first vocab file
         found among ``tokenizer.model`` / ``spm.model`` /
         ``ernie_token_100k.model`` is registered on the tokenizer class.
      2. Read every safetensors weight onto the CPU.
      3. Convert the W4A8 expert tensors with ``deal_quant``; all other
         tensors are carried over unchanged.
      4. Save the weights (safetensors shards, or ``model_state.pdparams``
         when ``--safe_serialization`` is false). Then save the config with
         ``is_permuted = True``, followed by the tokenizer.
    """
    args = parse_arguments()
    pretrained_config, _ = PretrainedConfig.get_config_dict(args.model_name_or_path)
    pretrained_config = PretrainedConfig.from_dict(pretrained_config)

    vocab_file_names = [
        "tokenizer.model",
        "spm.model",
        "ernie_token_100k.model",
    ]
    for vocab_file_name in vocab_file_names:
        if os.path.exists(os.path.join(args.model_name_or_path, vocab_file_name)):
            # NOTE: this mutates a class-level attribute of ErnieBotTokenizer.
            ErnieBotTokenizer.resource_files_names["vocab_file"] = vocab_file_name
            break
    tokenizer = ErnieBotTokenizer.from_pretrained(args.model_name_or_path)

    _, safetensor_files = get_all_safetensors(args.model_name_or_path)
    weights_iterator = safetensors_weights_iterator(safetensor_files)
    state_dict = {}
    save_state_dict = {}
    start = time.perf_counter()
    for k, v in weights_iterator:
        state_dict[k] = get_tensor(v).cpu()
    # Quantize inside the timed region. Previously the timer stopped before
    # deal_quant ran, so "load and quantize took" only measured loading.
    deal_quant(state_dict, save_state_dict)
    end = time.perf_counter()
    logger.info("Finish Quantize.")
    logger.info(f"load and quantize took : {end - start:.6f} seconds")

    # Carry over every tensor deal_quant did not convert.
    for key in list(state_dict):
        save_state_dict[key] = state_dict.pop(key)

    logger.info("Begin to save model")
    os.makedirs(args.output_dir, exist_ok=True)
    start = time.perf_counter()
    if not args.safe_serialization:
        paddle.save(
            save_state_dict,
            os.path.join(args.output_dir, "model_state.pdparams"),
        )
    else:
        save_safetensors(save_state_dict, args)
    pretrained_config.is_permuted = True
    pretrained_config.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    end = time.perf_counter()
    logger.info(f"save model took: {end - start:.6f} seconds")
    logger.info("Finish.")
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Run the conversion only when this file is executed as a script.
    main()
 | 
