mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 00:06:38 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
122
scripts/extract_mtp_weight_from_safetensor.py
Normal file
122
scripts/extract_mtp_weight_from_safetensor.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import paddle
|
||||
from paddleformers.transformers.model_utils import shard_checkpoint
|
||||
from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
||||
from paddleformers.utils.log import logger
|
||||
from safetensors import safe_open
|
||||
from safetensors.numpy import save_file as safe_save_file
|
||||
|
||||
|
||||
def parse_args():
|
||||
""""""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extract and save MTP weights from safetensors.")
|
||||
parser.add_argument("-i",
|
||||
"--input_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the input safetensors model directory.")
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output directory for saving processed weights.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def extract_mtp_weights(input_dir: str) -> dict:
|
||||
"""
|
||||
Load all MTP-related weights from safetensors files in input_dir.
|
||||
"""
|
||||
index_path = os.path.join(input_dir, SAFE_WEIGHTS_INDEX_NAME)
|
||||
if not os.path.isfile(index_path):
|
||||
raise FileNotFoundError(f"Index file not found: {index_path}")
|
||||
|
||||
with open(index_path, "r", encoding="utf-8") as f:
|
||||
index = json.load(f)
|
||||
|
||||
weight_map = index.get("weight_map", {})
|
||||
required_files = {v for k, v in weight_map.items() if "mtp" in k}
|
||||
logger.info(f"Found {len(required_files)} shards with MTP weights.")
|
||||
|
||||
state_dict = {}
|
||||
for file_name in required_files:
|
||||
file_path = os.path.join(input_dir, file_name)
|
||||
if not os.path.isfile(file_path):
|
||||
logger.warning(f"Shard not found: {file_path}")
|
||||
continue
|
||||
logger.info(f"Loading shard: {file_path}")
|
||||
with safe_open(file_path, framework="np", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
if "mtp" in k:
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
|
||||
logger.info(f"Loaded {len(state_dict)} MTP weights.")
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_safetensors(state_dict: dict, output_dir: str):
|
||||
"""
|
||||
Save state_dict as safetensors shards into output_dir.
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
logger.info("Converting tensors to numpy arrays.")
|
||||
for k in list(state_dict.keys()):
|
||||
if isinstance(state_dict[k], paddle.Tensor):
|
||||
tensor = state_dict.pop(k)
|
||||
array = tensor.cpu().numpy()
|
||||
state_dict[k] = array
|
||||
|
||||
logger.info("Sharding and saving safetensors.")
|
||||
shards, index = shard_checkpoint(
|
||||
state_dict,
|
||||
max_shard_size="5GB",
|
||||
weights_name=SAFE_WEIGHTS_NAME,
|
||||
shard_format="naive",
|
||||
)
|
||||
|
||||
for shard_file, shard in shards.items():
|
||||
save_path = os.path.join(output_dir, shard_file)
|
||||
logger.info(f"Saving shard: {save_path}")
|
||||
safe_save_file(shard, save_path, metadata={"format": "np"})
|
||||
|
||||
index_path = os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME)
|
||||
with open(index_path, "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, indent=2)
|
||||
logger.info(f"Saved index file: {index_path}")
|
||||
|
||||
|
||||
def main():
|
||||
""""""
|
||||
args = parse_args()
|
||||
logger.info(f"Input dir: {args.input_dir}")
|
||||
logger.info(f"Output dir: {args.output_dir}")
|
||||
|
||||
state_dict = extract_mtp_weights(args.input_dir)
|
||||
save_safetensors(state_dict, args.output_dir)
|
||||
logger.info("MTP weights extracted and saved successfully.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user