"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import argparse
import json
import os

import paddle
from paddleformers.transformers.model_utils import shard_checkpoint
from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from paddleformers.utils.log import logger
from safetensors import safe_open
from safetensors.numpy import save_file as safe_save_file


def parse_args():
    """
    Parse the command-line arguments of the script.

    Returns:
        argparse.Namespace: namespace with ``input_dir`` and ``output_dir``,
        both required string paths.
    """
    parser = argparse.ArgumentParser(description="Extract and save MTP weights from safetensors.")
    parser.add_argument(
        "-i",
        "--input_dir",
        type=str,
        required=True,
        help="Path to the input safetensors model directory.",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        required=True,
        help="Path to the output directory for saving processed weights.",
    )
    return parser.parse_args()


def extract_mtp_weights(input_dir: str) -> dict:
    """
    Load all MTP-related weights from safetensors files in input_dir.

    The safetensors index file in ``input_dir`` is read to find the shards
    whose weight names contain ``"mtp"``; only those shards are opened, and
    only tensors whose name contains ``"mtp"`` are kept.

    Args:
        input_dir: directory holding the sharded safetensors checkpoint and
            its index file.

    Returns:
        dict: mapping of weight name to the tensor loaded with the numpy
        framework of ``safe_open``.

    Raises:
        FileNotFoundError: if the index file does not exist in ``input_dir``.
    """
    index_path = os.path.join(input_dir, SAFE_WEIGHTS_INDEX_NAME)
    if not os.path.isfile(index_path):
        raise FileNotFoundError(f"Index file not found: {index_path}")

    with open(index_path, "r", encoding="utf-8") as f:
        index = json.load(f)

    weight_map = index.get("weight_map", {})
    required_files = {v for k, v in weight_map.items() if "mtp" in k}
    logger.info(f"Found {len(required_files)} shards with MTP weights.")

    state_dict = {}
    # Sorted so the load order (and therefore the key order of the result and
    # of the saved shards) is reproducible; set iteration order is not.
    for file_name in sorted(required_files):
        file_path = os.path.join(input_dir, file_name)
        if not os.path.isfile(file_path):
            # Best effort: a missing shard is reported and skipped.
            logger.warning(f"Shard not found: {file_path}")
            continue

        logger.info(f"Loading shard: {file_path}")
        with safe_open(file_path, framework="np", device="cpu") as f:
            for k in f.keys():
                if "mtp" in k:
                    state_dict[k] = f.get_tensor(k)

    logger.info(f"Loaded {len(state_dict)} MTP weights.")
    return state_dict


def _build_index_from_shards(shards: dict) -> dict:
    """
    Build a safetensors index (``metadata`` + ``weight_map``) from shards.

    Args:
        shards: mapping of shard file name to that shard's
            ``{weight name: array}`` dict.

    Returns:
        dict: ``{"metadata": {"total_size": int}, "weight_map": {...}}`` where
        ``weight_map`` maps every weight name to its shard file name.
    """
    total_size = 0
    weight_map = {}
    for shard_file, shard in shards.items():
        for name, array in shard.items():
            weight_map[name] = shard_file
            # Objects without ``nbytes`` contribute 0 to the size.
            total_size += int(getattr(array, "nbytes", 0))
    return {"metadata": {"total_size": total_size}, "weight_map": weight_map}


def save_safetensors(state_dict: dict, output_dir: str):
    """
    Save state_dict as safetensors shards into output_dir.

    ``paddle.Tensor`` values are converted to numpy arrays in place (the
    passed dict is modified), the weights are split into shards of at most
    5GB, each shard is written with ``safetensors.numpy.save_file`` and an
    index file is written next to them.

    Args:
        state_dict: mapping of weight name to ``paddle.Tensor`` or numpy array.
        output_dir: destination directory; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    logger.info("Converting tensors to numpy arrays.")
    for k in list(state_dict.keys()):
        if isinstance(state_dict[k], paddle.Tensor):
            # Pop before converting so the dict does not keep the tensor
            # referenced while its numpy copy is created.
            tensor = state_dict.pop(k)
            array = tensor.cpu().numpy()
            state_dict[k] = array

    logger.info("Sharding and saving safetensors.")
    shards, index = shard_checkpoint(
        state_dict,
        max_shard_size="5GB",
        weights_name=SAFE_WEIGHTS_NAME,
        shard_format="naive",
    )

    for shard_file, shard in shards.items():
        save_path = os.path.join(output_dir, shard_file)
        logger.info(f"Saving shard: {save_path}")
        safe_save_file(shard, save_path, metadata={"format": "np"})

    if index is None:
        # NOTE(review): shard_checkpoint presumably returns no index when all
        # weights fit in a single shard - confirm against paddleformers.
        # Dumping None would write "null", which extract_mtp_weights (and any
        # loader expecting a "weight_map") cannot use, so build one instead.
        index = _build_index_from_shards(shards)

    index_path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2)
    logger.info(f"Saved index file: {index_path}")


def main():
    """
    Extract the MTP weights from the input directory and save them as
    safetensors shards (plus index) into the output directory.
    """
    args = parse_args()
    logger.info(f"Input dir: {args.input_dir}")
    logger.info(f"Output dir: {args.output_dir}")

    state_dict = extract_mtp_weights(args.input_dir)
    save_safetensors(state_dict, args.output_dir)
    logger.info("MTP weights extracted and saved successfully.")


if __name__ == "__main__":
    main()