mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-23 08:39:33 +08:00
support w4afp8 offline quant (#3438)
This commit is contained in:
@@ -266,8 +266,8 @@ __global__ void permute_scale_kernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
|
void W4AFp8GemmScalePermute(const paddle::Tensor& scale) {
|
||||||
const int row = scale.dims()[0];
|
const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1;
|
||||||
const int col = scale.dims()[1];
|
const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0];
|
||||||
if (col % 16 != 0) {
|
if (col % 16 != 0) {
|
||||||
PD_THROW("Only supported when col is divisible by 16.");
|
PD_THROW("Only supported when col is divisible by 16.");
|
||||||
}
|
}
|
||||||
|
@@ -566,7 +566,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
|||||||
self.moe_quant_type = "w4afp8"
|
self.moe_quant_type = "w4afp8"
|
||||||
self.pack_num = 2
|
self.pack_num = 2
|
||||||
|
|
||||||
def process_prequanted_weights(self, layer: nn.Layer, state_dict):
|
def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
|
||||||
"""
|
"""
|
||||||
Paddle cutlass process prequanted weights.
|
Paddle cutlass process prequanted weights.
|
||||||
"""
|
"""
|
||||||
@@ -579,9 +579,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
|||||||
|
|
||||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
|
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
|
||||||
layer.load_experts_weight(
|
layer.load_experts_weight(
|
||||||
state_dict,
|
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
|
||||||
up_gate_proj_expert_weight_key,
|
|
||||||
down_proj_expert_weight_key,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -594,13 +592,17 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
|||||||
if isinstance(state_dict, list):
|
if isinstance(state_dict, list):
|
||||||
state_dict = dict(state_dict)
|
state_dict = dict(state_dict)
|
||||||
|
|
||||||
logger.info(f"ep_size:{layer.ep_size}")
|
|
||||||
|
|
||||||
if layer.ep_size > 1:
|
if layer.ep_size > 1:
|
||||||
for expert_idx in ep_rank_to_expert_id_list:
|
for expert_idx in ep_rank_to_expert_id_list:
|
||||||
scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
|
scale_tensor = get_tensor(
|
||||||
|
(
|
||||||
|
state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)]
|
||||||
|
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
|
||||||
|
else up_gate_proj_expert_in_scale_key.format(expert_idx)
|
||||||
|
),
|
||||||
|
layer.fd_config.model_config.model,
|
||||||
|
)
|
||||||
up_gate_proj_in_scale_all_experts.append(scale_tensor)
|
up_gate_proj_in_scale_all_experts.append(scale_tensor)
|
||||||
logger.info(f"up_gate_proj_in_scale_all_experts:{up_gate_proj_in_scale_all_experts}")
|
|
||||||
|
|
||||||
for expert_idx in logical_expert_ids:
|
for expert_idx in logical_expert_ids:
|
||||||
up_gate_proj_weight_scale.append(
|
up_gate_proj_weight_scale.append(
|
||||||
@@ -662,7 +664,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
|
|||||||
"down_proj_in_scale": down_proj_in_scale,
|
"down_proj_in_scale": down_proj_in_scale,
|
||||||
}
|
}
|
||||||
for name, tensor in name_tensor_map.items():
|
for name, tensor in name_tensor_map.items():
|
||||||
getattr(layer, name).set_value(tensor)
|
create_and_set_parameter(layer, name, tensor)
|
||||||
|
|
||||||
def create_weights(self, layer: nn.Layer, state_dict):
|
def create_weights(self, layer: nn.Layer, state_dict):
|
||||||
"""
|
"""
|
||||||
|
193
scripts/offline_w4a8.py
Normal file
193
scripts/offline_w4a8.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddleformers.trainer import strtobool
|
||||||
|
from paddleformers.transformers.configuration_utils import PretrainedConfig
|
||||||
|
from paddleformers.transformers.model_utils import shard_checkpoint
|
||||||
|
from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
||||||
|
from paddleformers.utils.log import logger
|
||||||
|
from safetensors.numpy import save_file as safe_save_file
|
||||||
|
|
||||||
|
from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
|
||||||
|
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||||
|
from fastdeploy.model_executor.load_weight_utils import (
|
||||||
|
get_all_safetensors,
|
||||||
|
safetensors_weights_iterator,
|
||||||
|
)
|
||||||
|
from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arguments():
|
||||||
|
"""
|
||||||
|
parse_arguments
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name_or_path",
|
||||||
|
default=None,
|
||||||
|
required=True,
|
||||||
|
help="The directory of model.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
default="merged_output",
|
||||||
|
required=True,
|
||||||
|
help="The directory of merged model output.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--safe_serialization",
|
||||||
|
type=strtobool,
|
||||||
|
default="True",
|
||||||
|
help="Whether merge the model into safetensors format.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--moe_quant_type",
|
||||||
|
default="w4a8",
|
||||||
|
choices=["w4a8", "w4afp8"],
|
||||||
|
help="The moe quant type of the model.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def reorder():
|
||||||
|
def fn(weight, moe_quant_type):
|
||||||
|
from paddle.nn.quant import weight_quantize
|
||||||
|
|
||||||
|
quant_weight, _ = weight_quantize(weight.cuda(), algo=moe_quant_type, arch=80)
|
||||||
|
return quant_weight.cpu()
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def deal_in_scale():
|
||||||
|
def fn(in_scale):
|
||||||
|
processed_in_scale = 1 / in_scale
|
||||||
|
return processed_in_scale
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def deal_weight_scale():
|
||||||
|
def fn(weight_scale, processed_in_scale, moe_quant_type):
|
||||||
|
if moe_quant_type == "w4a8":
|
||||||
|
processed_weight_scale = weight_scale / (127 * 112) / processed_in_scale
|
||||||
|
return processed_weight_scale
|
||||||
|
elif moe_quant_type == "w4afp8":
|
||||||
|
processed_weight_scale = weight_scale / (448 * 7 * 2 ** (-9)) / processed_in_scale
|
||||||
|
processed_weight_scale = w4afp8_gemm_scale_permute(processed_weight_scale.cuda())
|
||||||
|
return processed_weight_scale
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
# tmp support w4a8
|
||||||
|
def deal_quant(state_dict, save_state_dict, moe_quant_type):
|
||||||
|
param_mapping = [
|
||||||
|
# pattern,fn
|
||||||
|
(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.activation_scale", deal_in_scale()),
|
||||||
|
(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.weight_scale", deal_weight_scale()),
|
||||||
|
(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.quant_weight", reorder()),
|
||||||
|
]
|
||||||
|
for pattern, fn in param_mapping:
|
||||||
|
for key in list(state_dict.keys()):
|
||||||
|
# print(f"deal {key}")
|
||||||
|
match = re.search(pattern, key)
|
||||||
|
if match:
|
||||||
|
# print(f"{key} is match")
|
||||||
|
weight_or_scale = state_dict.pop(key)
|
||||||
|
if "weight_scale" in key:
|
||||||
|
in_scale_key = key.replace("weight_scale", "activation_scale")
|
||||||
|
in_scale = save_state_dict[in_scale_key]
|
||||||
|
save_state_dict[key] = fn(weight_or_scale, in_scale, moe_quant_type)
|
||||||
|
elif "activation_scale" in key:
|
||||||
|
save_state_dict[key] = fn(weight_or_scale)
|
||||||
|
else:
|
||||||
|
save_state_dict[key] = fn(weight_or_scale, moe_quant_type)
|
||||||
|
|
||||||
|
|
||||||
|
def save_safetensors(state_dict, args):
|
||||||
|
"""
|
||||||
|
save_safetensors
|
||||||
|
"""
|
||||||
|
logger.info("Move to numpy.")
|
||||||
|
for k in list(state_dict.keys()):
|
||||||
|
if isinstance(state_dict[k], paddle.Tensor):
|
||||||
|
state_dict[k] = state_dict.pop(k).cpu().numpy()
|
||||||
|
|
||||||
|
logger.info("Save safetensors files.")
|
||||||
|
shards, index = shard_checkpoint(
|
||||||
|
state_dict,
|
||||||
|
max_shard_size="5GB",
|
||||||
|
weights_name=SAFE_WEIGHTS_NAME,
|
||||||
|
shard_format="naive",
|
||||||
|
)
|
||||||
|
for shard_file, shard in shards.items():
|
||||||
|
save_file = os.path.join(args.output_dir, shard_file)
|
||||||
|
logger.info(f"Saving {save_file}")
|
||||||
|
safe_save_file(shard, save_file, metadata={"format": "np"})
|
||||||
|
|
||||||
|
save_index_file = os.path.join(args.output_dir, SAFE_WEIGHTS_INDEX_NAME)
|
||||||
|
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||||
|
content = json.dumps(index, indent=2) + "\n"
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""
|
||||||
|
main
|
||||||
|
"""
|
||||||
|
args = parse_arguments()
|
||||||
|
pretrained_config, _ = PretrainedConfig.get_config_dict(args.model_name_or_path)
|
||||||
|
pretrained_config = PretrainedConfig.from_dict(pretrained_config)
|
||||||
|
vocab_file_names = [
|
||||||
|
"tokenizer.model",
|
||||||
|
"spm.model",
|
||||||
|
"ernie_token_100k.model",
|
||||||
|
]
|
||||||
|
for i in range(len(vocab_file_names)):
|
||||||
|
if os.path.exists(os.path.join(args.model_name_or_path, vocab_file_names[i])):
|
||||||
|
ErnieBotTokenizer.resource_files_names["vocab_file"] = vocab_file_names[i]
|
||||||
|
break
|
||||||
|
tokenizer = ErnieBotTokenizer.from_pretrained(args.model_name_or_path)
|
||||||
|
_, safetensor_files = get_all_safetensors(args.model_name_or_path)
|
||||||
|
weights_iterator = safetensors_weights_iterator(safetensor_files)
|
||||||
|
state_dict = {}
|
||||||
|
save_state_dict = {}
|
||||||
|
start = time.perf_counter()
|
||||||
|
for k, v in weights_iterator:
|
||||||
|
state_dict[k] = get_tensor(v).cpu()
|
||||||
|
end = time.perf_counter()
|
||||||
|
logger.info("Finish Quantize.")
|
||||||
|
logger.info(f"load and quantize took : {end - start:.6f} seconds")
|
||||||
|
deal_quant(state_dict, save_state_dict, args.moe_quant_type)
|
||||||
|
for key in list(state_dict.keys()):
|
||||||
|
save_state_dict[key] = state_dict.pop(key)
|
||||||
|
logger.info("Begin to save model")
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
start = time.perf_counter()
|
||||||
|
if not args.safe_serialization:
|
||||||
|
paddle.save(
|
||||||
|
save_state_dict,
|
||||||
|
os.path.join(args.output_dir, "model_state.pdparams"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
save_safetensors(save_state_dict, args)
|
||||||
|
pretrained_config.is_permuted = True
|
||||||
|
pretrained_config.save_pretrained(args.output_dir)
|
||||||
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
|
end = time.perf_counter()
|
||||||
|
logger.info(f"save model took: {end - start:.6f} seconds")
|
||||||
|
logger.info("Finish.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
36
scripts/run_offline_w4a8.sh
Normal file
36
scripts/run_offline_w4a8.sh
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
set -ex
|
||||||
|
rm -rf log
|
||||||
|
rm -f core*
|
||||||
|
|
||||||
|
export devices=0
|
||||||
|
export CUDA_VISIBLE_DEVICES=${devices}
|
||||||
|
model_path=${1:-"/PATH/MODEL_PATH"}
|
||||||
|
output_path=${2:-"/PATH/OUTPUT_MODEL"}
|
||||||
|
moe_quant_type=${3:-"w4a8"}
|
||||||
|
for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do
|
||||||
|
unset ${name}
|
||||||
|
done
|
||||||
|
export PADDLE_TRAINER_ID=0
|
||||||
|
export PADDLE_TRAINERS_NUM=1
|
||||||
|
export TRAINER_INSTANCES_NUM=1
|
||||||
|
export TRAINER_INSTANCES=`hostname -i`
|
||||||
|
self_ip=`hostname -i`
|
||||||
|
|
||||||
|
python offline_w4a8.py \
|
||||||
|
--model_name_or_path ${model_path} \
|
||||||
|
--output_dir ${output_path} \
|
||||||
|
--safe_serialization "True" \
|
||||||
|
--moe_quant_type ${moe_quant_type}
|
Reference in New Issue
Block a user