diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu index 4c02ae797..0a20ad6f5 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu @@ -266,8 +266,8 @@ __global__ void permute_scale_kernel( } void W4AFp8GemmScalePermute(const paddle::Tensor& scale) { - const int row = scale.dims()[0]; - const int col = scale.dims()[1]; + const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1; + const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0]; if (col % 16 != 0) { PD_THROW("Only supported when col is divisible by 16."); } diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 2f5ee6d50..dda5a1b0c 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -566,7 +566,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod): self.moe_quant_type = "w4afp8" self.pack_num = 2 - def process_prequanted_weights(self, layer: nn.Layer, state_dict): + def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False): """ Paddle cutlass process prequanted weights. """ @@ -579,9 +579,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod): up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = ( layer.load_experts_weight( - state_dict, - up_gate_proj_expert_weight_key, - down_proj_expert_weight_key, + state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange ) ) @@ -594,13 +592,17 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod): if isinstance(state_dict, list): state_dict = dict(state_dict) - logger.info(f"ep_size:{layer.ep_size}") - if layer.ep_size > 1: for expert_idx in ep_rank_to_expert_id_list: - scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)]) + scale_tensor = get_tensor( + ( + state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)] + if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict + else up_gate_proj_expert_in_scale_key.format(expert_idx) + ), + layer.fd_config.model_config.model, + ) up_gate_proj_in_scale_all_experts.append(scale_tensor) - logger.info(f"up_gate_proj_in_scale_all_experts:{up_gate_proj_in_scale_all_experts}") for expert_idx in logical_expert_ids: up_gate_proj_weight_scale.append( @@ -662,7 +664,7 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod): "down_proj_in_scale": down_proj_in_scale, } for name, tensor in name_tensor_map.items(): - getattr(layer, name).set_value(tensor) + create_and_set_parameter(layer, name, tensor) def create_weights(self, layer: nn.Layer, state_dict): """ diff --git a/scripts/offline_w4a8.py b/scripts/offline_w4a8.py new file mode 100644 index 000000000..fd10b9148 --- /dev/null +++ b/scripts/offline_w4a8.py @@ -0,0 +1,193 @@ +import argparse +import json +import os +import re +import time + +import paddle +from paddleformers.trainer import strtobool +from paddleformers.transformers.configuration_utils import PretrainedConfig +from paddleformers.transformers.model_utils import shard_checkpoint +from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME +from paddleformers.utils.log import logger +from safetensors.numpy import save_file as safe_save_file + +from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer +from fastdeploy.model_executor.layers.utils import get_tensor +from fastdeploy.model_executor.load_weight_utils import ( + get_all_safetensors, + safetensors_weights_iterator, +) +from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute + + +def parse_arguments(): + """ + parse_arguments + """ + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model_name_or_path", + default=None, + required=True, + help="The directory of model.", + ) + + parser.add_argument( + "--output_dir", + default="merged_output", + required=True, + help="The directory of merged model output.", + ) + + parser.add_argument( + "--safe_serialization", + type=strtobool, + default="True", + help="Whether merge the model into safetensors format.", + ) + + parser.add_argument( + "--moe_quant_type", + default="w4a8", + choices=["w4a8", "w4afp8"], + help="The moe quant type of the model.", + ) + + return parser.parse_args() + + +def reorder(): + def fn(weight, moe_quant_type): + from paddle.nn.quant import weight_quantize + + quant_weight, _ = weight_quantize(weight.cuda(), algo=moe_quant_type, arch=80) + return quant_weight.cpu() + + return fn + + +def deal_in_scale(): + def fn(in_scale): + processed_in_scale = 1 / in_scale + return processed_in_scale + + return fn + + +def deal_weight_scale(): + def fn(weight_scale, processed_in_scale, moe_quant_type): + if moe_quant_type == "w4a8": + processed_weight_scale = weight_scale / (127 * 112) / processed_in_scale + return processed_weight_scale + elif moe_quant_type == "w4afp8": + processed_weight_scale = weight_scale / (448 * 7 * 2 ** (-9)) / processed_in_scale + processed_weight_scale = w4afp8_gemm_scale_permute(processed_weight_scale.cuda()) + return processed_weight_scale + + return fn + + +# tmp support w4a8 +def deal_quant(state_dict, save_state_dict, moe_quant_type): + param_mapping = [ + # pattern,fn + (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.activation_scale", deal_in_scale()), + (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.weight_scale", deal_weight_scale()), + (r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.quant_weight", reorder()), + ] + for pattern, fn in param_mapping: + for key in list(state_dict.keys()): + # print(f"deal {key}") + match = re.search(pattern, key) + if match: + # print(f"{key} is match") + weight_or_scale = state_dict.pop(key) + if "weight_scale" in key: + in_scale_key = key.replace("weight_scale", "activation_scale") + in_scale = save_state_dict[in_scale_key] + save_state_dict[key] = fn(weight_or_scale, in_scale, moe_quant_type) + elif "activation_scale" in key: + save_state_dict[key] = fn(weight_or_scale) + else: + save_state_dict[key] = fn(weight_or_scale, moe_quant_type) + + +def save_safetensors(state_dict, args): + """ + save_safetensors + """ + logger.info("Move to numpy.") + for k in list(state_dict.keys()): + if isinstance(state_dict[k], paddle.Tensor): + state_dict[k] = state_dict.pop(k).cpu().numpy() + + logger.info("Save safetensors files.") + shards, index = shard_checkpoint( + state_dict, + max_shard_size="5GB", + weights_name=SAFE_WEIGHTS_NAME, + shard_format="naive", + ) + for shard_file, shard in shards.items(): + save_file = os.path.join(args.output_dir, shard_file) + logger.info(f"Saving {save_file}") + safe_save_file(shard, save_file, metadata={"format": "np"}) + + save_index_file = os.path.join(args.output_dir, SAFE_WEIGHTS_INDEX_NAME) + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2) + "\n" + f.write(content) + + +def main(): + """ + main + """ + args = parse_arguments() + pretrained_config, _ = PretrainedConfig.get_config_dict(args.model_name_or_path) + pretrained_config = PretrainedConfig.from_dict(pretrained_config) + vocab_file_names = [ + "tokenizer.model", + "spm.model", + "ernie_token_100k.model", + ] + for i in range(len(vocab_file_names)): + if os.path.exists(os.path.join(args.model_name_or_path, vocab_file_names[i])): + ErnieBotTokenizer.resource_files_names["vocab_file"] = vocab_file_names[i] + break + tokenizer = ErnieBotTokenizer.from_pretrained(args.model_name_or_path) + _, safetensor_files = get_all_safetensors(args.model_name_or_path) + weights_iterator = safetensors_weights_iterator(safetensor_files) + state_dict = {} + save_state_dict = {} + start = time.perf_counter() + for k, v in weights_iterator: + state_dict[k] = get_tensor(v).cpu() + end = time.perf_counter() + logger.info("Finish Quantize.") + logger.info(f"load and quantize took : {end - start:.6f} seconds") + deal_quant(state_dict, save_state_dict, args.moe_quant_type) + for key in list(state_dict.keys()): + save_state_dict[key] = state_dict.pop(key) + logger.info("Begin to save model") + os.makedirs(args.output_dir, exist_ok=True) + start = time.perf_counter() + if not args.safe_serialization: + paddle.save( + save_state_dict, + os.path.join(args.output_dir, "model_state.pdparams"), + ) + else: + save_safetensors(save_state_dict, args) + pretrained_config.is_permuted = True + pretrained_config.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + end = time.perf_counter() + logger.info(f"save model took: {end - start:.6f} seconds") + logger.info("Finish.") + + +if __name__ == "__main__": + main() diff --git a/scripts/run_offline_w4a8.sh b/scripts/run_offline_w4a8.sh new file mode 100644 index 000000000..d9f8b9126 --- /dev/null +++ b/scripts/run_offline_w4a8.sh @@ -0,0 +1,36 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -ex +rm -rf log +rm -f core* + +export devices=0 +export CUDA_VISIBLE_DEVICES=${devices} +model_path=${1:-"/PATH/MODEL_PATH"} +output_path=${2:-"/PATH/OUTPUT_MODEL"} +moe_quant_type=${3:-"w4a8"} +for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do +unset ${name} +done +export PADDLE_TRAINER_ID=0 +export PADDLE_TRAINERS_NUM=1 +export TRAINER_INSTANCES_NUM=1 +export TRAINER_INSTANCES=`hostname -i` +self_ip=`hostname -i` + +python offline_w4a8.py \ + --model_name_or_path ${model_path} \ + --output_dir ${output_path} \ + --safe_serialization "True" \ + --moe_quant_type ${moe_quant_type}