diff --git a/docs/model_compression/quant.md b/docs/model_compression/quant.md
new file mode 100644
index 000000000..db61c2d1f
--- /dev/null
+++ b/docs/model_compression/quant.md
@@ -0,0 +1,37 @@
+# Deploying Quantized Models with FastDeploy
+Quantization is a popular model compression technique: a quantized model is smaller and runs faster. FastDeploy supports deploying quantized models to help users speed up inference.
+
+## 1. Quantized Model Support Across FastDeploy Backends
+
+Currently, several inference backends in FastDeploy can deploy quantized models on different hardware:
+
+| Hardware / Backend | ONNXRuntime | Paddle Inference | TensorRT |
+| :-----------| :-------- | :--------------- | :------- |
+| CPU | Supported | Supported | |
+| GPU | | | Supported |
+
+
+## 2. How to Quantize a Model
+
+### Quantization methods
+Models can be quantized with PaddleSlim. There are two main approaches: quantization-aware training (QAT) and post-training (offline) quantization (PTQ). QAT obtains the quantized model through training, while PTQ quantizes a model without any training. FastDeploy can deploy the quantized models produced by either approach.
+The two methods compare as follows:
+| Method | Quantization time | Quantized model accuracy | Model size | Inference speed |
+| :-----------| :--------| :-------| :------- | :------- |
+| Offline quantization (PTQ) | No training required, fast | Slightly lower than QAT | Same for both | Same for both |
+| Quantization-aware training (QAT) | Training required, slow | Small loss relative to the unquantized model | Same for both | Same for both |
+
+### Quantize a model with the fastdeploy_quant command
+FastDeploy provides one-command model quantization. Please refer to the following document:
+- [FastDeploy one-command model quantization](../../tools/quantization/)
+Once the quantized model is produced, it can be deployed with FastDeploy.
+
+## 3. Deploying Quantized Models with FastDeploy
+To deploy, simply pass the quantized model path and the usual parameters to FastDeploy.
+Please refer to the example documents (a minimal Python sketch follows the list below):
+- [YOLOv5s quantized model Python deployment](../../examples/slim/yolov5s/python/)
+- [YOLOv5s quantized model C++ deployment](../../examples/slim/yolov5s/cpp/)
+- [YOLOv6s quantized model Python deployment](../../examples/slim/yolov6s/python/)
+- [YOLOv6s quantized model C++ deployment](../../examples/slim/yolov6s/cpp/)
+- [YOLOv7 quantized model Python deployment](../../examples/slim/yolov7/python/)
+- [YOLOv7 quantized model C++ deployment](../../examples/slim/yolov7/cpp/)
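+
+The sketch below illustrates the general deployment flow for the quantized YOLOv5s model. It is a minimal example under stated assumptions: the `fastdeploy` Python package is installed, and `./yolov5s_ptq_model/` is a directory produced by `fastdeploy_quant` containing `model.pdmodel` and `model.pdiparams`. Class and option names may differ across FastDeploy versions, so treat the linked example docs above as authoritative.
+
+```python
+import cv2
+import fastdeploy as fd
+
+# Use TensorRT on GPU, the backend that supports this quantized model.
+option = fd.RuntimeOption()
+option.use_gpu()
+option.use_trt_backend()
+
+# Assumed paths: a quantized model exported by `fastdeploy_quant`.
+model = fd.vision.detection.YOLOv5(
+    "./yolov5s_ptq_model/model.pdmodel",
+    "./yolov5s_ptq_model/model.pdiparams",
+    runtime_option=option,
+    model_format=fd.ModelFormat.PADDLE)
+
+im = cv2.imread("test.jpg")  # any local test image
+result = model.predict(im)
+print(result)
+```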
diff --git a/tools/quantization/configs/classification/mobilenetv1_ssld_quant.yaml b/tools/quantization/configs/classification/mobilenetv1_ssld_quant.yaml
new file mode 100644
index 000000000..aa4ae5a71
--- /dev/null
+++ b/tools/quantization/configs/classification/mobilenetv1_ssld_quant.yaml
@@ -0,0 +1,49 @@
+Global:
+  model_dir: ./MobileNetV1_ssld_infer/
+  format: 'paddle'
+  model_filename: inference.pdmodel
+  params_filename: inference.pdiparams
+  image_path: ./ImageNet_val_640
+  arch: MobileNetV1
+  input_list: ['input']
+  preprocess: cls_image_preprocess
+
+
+Distillation:
+  alpha: 1.0
+  loss: l2
+  node:
+  - softmax_0.tmp_0
+
+
+Quantization:
+  use_pact: true
+  activation_bits: 8
+  is_full_quantize: false
+  onnx_format: true
+  activation_quantize_type: moving_average_abs_max
+  weight_quantize_type: channel_wise_abs_max
+  not_quant_pattern:
+  - skip_quant
+  quantize_op_types:
+  - conv2d
+  - depthwise_conv2d
+  weight_bits: 8
+
+
+TrainConfig:
+  train_iter: 5000
+  learning_rate:
+    type: CosineAnnealingDecay
+    learning_rate: 0.015
+    T_max: 8000
+  optimizer_builder:
+    optimizer:
+      type: Momentum
+    weight_decay: 0.00002
+  origin_metric: 0.70898
+
+
+PTQ:
+  calibration_method: 'avg'    # options: avg, abs_max, hist, KL, mse
+  skip_tensor_list: None
diff --git a/tools/quantization/configs/classification/resnet50_vd_quant.yaml b/tools/quantization/configs/classification/resnet50_vd_quant.yaml
new file mode 100644
index 000000000..ab2264e50
--- /dev/null
+++ b/tools/quantization/configs/classification/resnet50_vd_quant.yaml
@@ -0,0 +1,47 @@
+Global:
+  model_dir: ./ResNet50_vd_infer/
+  format: 'paddle'
+  model_filename: inference.pdmodel
+  params_filename: inference.pdiparams
+  image_path: ./ImageNet_val_640
+  arch: ResNet50
+  input_list: ['input']
+  preprocess: cls_image_preprocess
+
+
+Distillation:
+  alpha: 1.0
+  loss: l2
+  node:
+  - softmax_0.tmp_0
+
+Quantization:
+  use_pact: true
+  activation_bits: 8
+  is_full_quantize: false
+  onnx_format: true
+  activation_quantize_type: moving_average_abs_max
+  weight_quantize_type: channel_wise_abs_max
+  not_quant_pattern:
+  - skip_quant
+  quantize_op_types:
+  - conv2d
+  - depthwise_conv2d
+  weight_bits: 8
+
+TrainConfig:
+  train_iter: 5000
+  learning_rate:
+    type: CosineAnnealingDecay
+    learning_rate: 0.015
+    T_max: 8000
+  optimizer_builder:
+    optimizer:
+      type: Momentum
+    weight_decay: 0.00002
+  origin_metric: 0.7912
+
+
+PTQ:
+  calibration_method: 'avg'    # options: avg, abs_max, hist, KL, mse
+  skip_tensor_list: None
diff --git a/tools/quantization/configs/detection/yolov5s_quant.yaml b/tools/quantization/configs/detection/yolov5s_quant.yaml
new file mode 100644
index 000000000..58d4332e1
--- /dev/null
+++ b/tools/quantization/configs/detection/yolov5s_quant.yaml
@@ -0,0 +1,35 @@
+Global:
+  model_dir: ./yolov5s.onnx
+  format: 'onnx'
+  model_filename: model.pdmodel
+  params_filename: model.pdiparams
+  image_path: ./COCO_val_320
+  arch: YOLOv5
+  input_list: ['x2paddle_images']
+  preprocess: yolo_image_preprocess
+
+Distillation:
+  alpha: 1.0
+  loss: soft_label
+
+Quantization:
+  onnx_format: true
+  use_pact: true
+  activation_quantize_type: 'moving_average_abs_max'
+  quantize_op_types:
+  - conv2d
+  - depthwise_conv2d
+
+
+PTQ:
+  calibration_method: 'avg'    # options: avg, abs_max, hist, KL, mse
+  skip_tensor_list: None
+
+TrainConfig:
+  train_iter: 3000
+  learning_rate: 0.00001
+  optimizer_builder:
+    optimizer:
+      type: SGD
+    weight_decay: 4.0e-05
+  target_metric: 0.365
diff --git a/tools/quantization/configs/detection/yolov6s_quant.yaml b/tools/quantization/configs/detection/yolov6s_quant.yaml
new file mode 100644
index 000000000..1c35ff394
--- /dev/null
+++ b/tools/quantization/configs/detection/yolov6s_quant.yaml
@@ -0,0 +1,35 @@
+Global:
+  model_dir: ./yolov6s.onnx
+  format: 'onnx'
+  model_filename: model.pdmodel
+  params_filename: model.pdiparams
+  image_path: ./COCO_val_320
+  arch: YOLOv6
+  input_list: ['x2paddle_image_arrays']
+  preprocess: yolo_image_preprocess
+
+Distillation:
+  alpha: 1.0
+  loss: soft_label
+
+Quantization:
+  onnx_format: true
+  activation_quantize_type: 'moving_average_abs_max'
+  quantize_op_types:
+  - conv2d
+  - depthwise_conv2d
+
+
+PTQ:
+  calibration_method: 'avg'    # options: avg, abs_max, hist, KL, mse
+  skip_tensor_list: ['conv2d_2.w_0', 'conv2d_15.w_0', 'conv2d_46.w_0', 'conv2d_11.w_0', 'conv2d_49.w_0']
+
+TrainConfig:
+  train_iter: 8000
+  learning_rate:
+    type: CosineAnnealingDecay
+    learning_rate: 0.00003
+    T_max: 8000
+  optimizer_builder:
+    optimizer:
+      type: SGD
+    weight_decay: 0.00004
diff --git a/tools/quantization/configs/detection/yolov7_quant.yaml b/tools/quantization/configs/detection/yolov7_quant.yaml
new file mode 100644
index 000000000..f04e8dc63
--- /dev/null
+++ b/tools/quantization/configs/detection/yolov7_quant.yaml
@@ -0,0 +1,34 @@
+Global:
+  model_dir: ./yolov7.onnx
+  format: 'onnx'
+  model_filename: model.pdmodel
+  params_filename: model.pdiparams
+  image_path: ./COCO_val_320
+  arch: YOLOv7
+  input_list: ['x2paddle_images']
+  preprocess: yolo_image_preprocess
+
+Distillation:
+  alpha: 1.0
+  loss: soft_label
+
+Quantization:
+  onnx_format: true
+  activation_quantize_type: 'moving_average_abs_max'
+  quantize_op_types:
+  - conv2d
+  - depthwise_conv2d
+
+PTQ:
+  calibration_method: 'avg'    # options: avg, abs_max, hist, KL, mse
+  skip_tensor_list: None
+
+TrainConfig:
+  train_iter: 3000
+  learning_rate:
+    type: CosineAnnealingDecay
+    learning_rate: 0.00003
+    T_max: 8000
+  optimizer_builder:
+    optimizer:
+      type: SGD
+    weight_decay: 0.00004
diff --git a/tools/quantization/configs/readme.md b/tools/quantization/configs/readme.md
new file mode 100644
index 000000000..782584815
--- /dev/null
+++ b/tools/quantization/configs/readme.md
@@ -0,0 +1,48 @@
+# FastDeploy Quantization Config Files
+A FastDeploy quantization config file contains global settings, quantization-aware (distillation) training settings, offline quantization settings, and training settings.
+Besides using the config files provided in this directory directly, users can modify them as needed.
+
+## Annotated example
+
+```yaml
+# Global settings
+Global:
+  model_dir: ./yolov7-tiny.onnx     # path of the input model
+  format: 'onnx'                    # input model format, 'onnx' or 'paddle'
+  model_filename: model.pdmodel     # model file name of the Paddle model
+  params_filename: model.pdiparams  # parameter file name of the Paddle model
+  image_path: ./COCO_val_320        # calibration dataset for PTQ, or training set for quantization-aware training
+  arch: YOLOv7                      # model family
+
+# Distillation settings for quantization-aware training
+Distillation:
+  alpha: 1.0
+  loss: soft_label
+
+# Quantization settings for quantization-aware training
+Quantization:
+  onnx_format: true
+  activation_quantize_type: 'moving_average_abs_max'
+  quantize_op_types:
+  - conv2d
+  - depthwise_conv2d
+
+# Offline quantization (PTQ) settings
+PTQ:
+  calibration_method: 'avg'         # calibration algorithm: avg, abs_max, hist, KL, or mse
+  skip_tensor_list: None            # tensors excluded from offline quantization
+
+
+# Training settings
+TrainConfig:
+  train_iter: 3000
+  learning_rate:
+    type: CosineAnnealingDecay
+    learning_rate: 0.00003
+    T_max: 8000
+  optimizer_builder:
+    optimizer:
+      type: SGD
+    weight_decay: 0.00004
+
+```
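+
+For reference, the sketch below shows how the quantization tool reads such a file; it mirrors the logic in `fdquant/fdquant.py`, and the config path is an assumed example:
+
+```python
+from paddleslim.common import load_config
+
+all_config = load_config("./configs/detection/yolov7_quant.yaml")
+global_config = all_config["Global"]  # model paths, dataset and preprocess settings
+ptq_config = all_config["PTQ"]        # offline quantization settings
+print(global_config["model_dir"], ptq_config["calibration_method"])
+```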
diff --git a/tools/quantization/fdquant/__init__.py b/tools/quantization/fdquant/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tools/quantization/fdquant/dataset.py b/tools/quantization/fdquant/dataset.py
new file mode 100644
index 000000000..a373d973d
--- /dev/null
+++ b/tools/quantization/fdquant/dataset.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+
+
+def generate_scale(im, target_shape):
+    # Compute the per-axis scale factors that fit the image into target_shape
+    # while keeping the aspect ratio.
+    origin_shape = im.shape[:2]
+    im_size_min = np.min(origin_shape)
+    im_size_max = np.max(origin_shape)
+    target_size_min = np.min(target_shape)
+    target_size_max = np.max(target_shape)
+    im_scale = float(target_size_min) / float(im_size_min)
+    if np.round(im_scale * im_size_max) > target_size_max:
+        im_scale = float(target_size_max) / float(im_size_max)
+    im_scale_x = im_scale
+    im_scale_y = im_scale
+
+    return im_scale_y, im_scale_x
+
+
+def yolo_image_preprocess(img, target_shape=[640, 640]):
+    # Resize image while keeping the aspect ratio
+    im_scale_y, im_scale_x = generate_scale(img, target_shape)
+    img = cv2.resize(
+        img,
+        None,
+        None,
+        fx=im_scale_x,
+        fy=im_scale_y,
+        interpolation=cv2.INTER_LINEAR)
+    # Pad to target_shape with the YOLO gray value (114)
+    im_h, im_w = img.shape[:2]
+    h, w = target_shape[:]
+    if h != im_h or w != im_w:
+        canvas = np.ones((h, w, 3), dtype=np.float32)
+        canvas *= np.array([114.0, 114.0, 114.0], dtype=np.float32)
+        canvas[0:im_h, 0:im_w, :] = img.astype(np.float32)
+        img = canvas
+    # Scale to [0, 1] and convert HWC -> CHW
+    img = np.transpose(img / 255, [2, 0, 1])
+
+    return img.astype(np.float32)
+
+
+def cls_resize_short(img, target_size):
+    # Resize so that the shorter side equals target_size
+    img_h, img_w = img.shape[:2]
+    percent = float(target_size) / min(img_w, img_h)
+    w = int(round(img_w * percent))
+    h = int(round(img_h * percent))
+
+    return cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
+
+
+def crop_image(img, target_size, center):
+    # Center crop (center=True) or random crop (center=False) a square patch
+    height, width = img.shape[:2]
+    size = target_size
+
+    if center:
+        w_start = (width - size) // 2
+        h_start = (height - size) // 2
+    else:
+        w_start = np.random.randint(0, width - size + 1)
+        h_start = np.random.randint(0, height - size + 1)
+    w_end = w_start + size
+    h_end = h_start + size
+
+    return img[h_start:h_end, w_start:w_end, :]
+
+
+def cls_image_preprocess(img):
+    # Standard ImageNet-style classification preprocessing
+    # resize
+    img = cls_resize_short(img, target_size=256)
+    # crop
+    img = crop_image(img, target_size=224, center=True)
+
+    # to CHW image & normalize
+    img = np.transpose(img / 255, [2, 0, 1])
+
+    img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
+    img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
+    img -= img_mean
+    img /= img_std
+
+    return img.astype(np.float32)
+
+
+def ppdet_resize_no_keepratio(img, target_shape=[640, 640]):
+    # Resize to target_shape without keeping the aspect ratio;
+    # also return the per-axis scale factors used.
+    im_shape = img.shape
+
+    resize_h, resize_w = target_shape
+    im_scale_y = resize_h / im_shape[0]
+    im_scale_x = resize_w / im_shape[1]
+
+    scale_factor = np.asarray([im_scale_y, im_scale_x], dtype=np.float32)
+    return cv2.resize(
+        img, None, None, fx=im_scale_x, fy=im_scale_y,
+        interpolation=cv2.INTER_CUBIC), scale_factor
+
+
+def ppdet_normalize(img, is_scale=True):
+    # Normalize with ImageNet mean/std; optionally scale to [0, 1] first
+    mean = [0.485, 0.456, 0.406]
+    std = [0.229, 0.224, 0.225]
+    img = img.astype(np.float32, copy=False)
+
+    if is_scale:
+        scale = 1.0 / 255.0
+        img *= scale
+
+    mean = np.array(mean)[np.newaxis, np.newaxis, :]
+    std = np.array(std)[np.newaxis, np.newaxis, :]
+    img -= mean
+    img /= std
+    return img
+
+
+def hwc_to_chw(img):
+    img = img.transpose((2, 0, 1))
+    return img
+
+
+def ppdet_image_preprocess(img):
+    # PaddleDetection-style preprocessing: plain resize, scale, normalize, HWC -> CHW
+    img, scale_factor = ppdet_resize_no_keepratio(img, target_shape=[640, 640])
+
+    img = np.transpose(img / 255, [2, 0, 1])
+
+    img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
+    img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
+    img -= img_mean
+    img /= img_std
+
+    return img.astype(np.float32), scale_factor
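+
+
+if __name__ == "__main__":
+    # Minimal smoke test for the preprocess functions above; it is not used by
+    # the quantization tool itself. The image path is an assumption: point it
+    # at any local image.
+    sample = cv2.imread("test.jpg")
+    if sample is None:
+        raise FileNotFoundError("put a test image at ./test.jpg to run this check")
+    print("yolo:", yolo_image_preprocess(sample).shape)  # (3, 640, 640)
+    print("cls:", cls_image_preprocess(sample).shape)    # (3, 224, 224)
+    out, scale = ppdet_image_preprocess(sample)
+    print("ppdet:", out.shape, "scale_factor:", scale)   # (3, 640, 640)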
diff --git a/tools/quantization/fdquant/fdquant.py b/tools/quantization/fdquant/fdquant.py
new file mode 100644
index 000000000..4d2bb511e
--- /dev/null
+++ b/tools/quantization/fdquant/fdquant.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import time
+
+import paddle
+from paddleslim.common import load_config, load_onnx_model
+from paddleslim.auto_compression import AutoCompression
+from paddleslim.quant import quant_post_static
+
+# The preprocess functions (e.g. yolo_image_preprocess) are resolved by name
+# via eval(), so the star import is intentional.
+from fdquant.dataset import *
+
+
+def argsparser():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        '--config_path',
+        type=str,
+        default=None,
+        help="path of the compression strategy config file.",
+        required=True)
+    parser.add_argument(
+        '--method',
+        type=str,
+        default=None,
+        help="quantization method: choose 'PTQ' or 'QAT'.",
+        required=True)
+    parser.add_argument(
+        '--save_dir',
+        type=str,
+        default='output',
+        help="directory to save the compressed model.")
+    parser.add_argument(
+        '--devices',
+        type=str,
+        default='gpu',
+        help="device used for compression: cpu, gpu, xpu or npu.")
+
+    return parser
+
+
+def reader_wrapper(reader, input_list=None):
+    # Wrap a data loader so that each sample becomes a feed dict keyed by the
+    # model's input names.
+    def gen():
+        for data_list in reader:
+            in_dict = {}
+            for data in data_list:
+                for i, input_name in enumerate(input_list):
+                    in_dict[input_name] = data[i]
+            yield in_dict
+
+    return gen
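+
+
+# For illustration: with yolov5s_quant.yaml (input_list: ['x2paddle_images'],
+# preprocess: yolo_image_preprocess), each sample produced by the wrapped
+# reader is a feed dict keyed by input name, e.g.
+#     {'x2paddle_images': <float32 image tensor>}
+# which is the format both AutoCompression and quant_post_static expect from
+# their data loaders.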
+
+
+def main():
+
+    time_s = time.time()
+
+    paddle.enable_static()
+    parser = argsparser()
+    FLAGS = parser.parse_args()
+
+    assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu']
+    paddle.set_device(FLAGS.devices)
+
+    global global_config
+    all_config = load_config(FLAGS.config_path)
+    assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
+    global_config = all_config["Global"]
+    input_list = global_config['input_list']
+
+    assert os.path.exists(global_config[
+        'image_path']), "image_path does not exist!"
+    paddle.vision.image.set_image_backend('cv2')
+    # The transform can be customized via the 'preprocess' key in the config.
+    train_dataset = paddle.vision.datasets.ImageFolder(
+        global_config['image_path'],
+        transform=eval(global_config['preprocess']))
+    train_loader = paddle.io.DataLoader(
+        train_dataset,
+        batch_size=1,
+        shuffle=True,
+        drop_last=True,
+        num_workers=0)
+    train_loader = reader_wrapper(train_loader, input_list=input_list)
+    eval_func = None
+
+    # QAT: ACT (auto compression) quantization-aware training
+    if FLAGS.method == 'QAT':
+        ac = AutoCompression(
+            model_dir=global_config['model_dir'],
+            model_filename=global_config['model_filename'],
+            params_filename=global_config['params_filename'],
+            train_dataloader=train_loader,
+            save_dir=FLAGS.save_dir,
+            config=all_config,
+            eval_callback=eval_func)
+        ac.compress()
+
+    # PTQ: post-training (offline) quantization
+    if FLAGS.method == 'PTQ':
+
+        # Read the PTQ config
+        assert "PTQ" in all_config, f"Key 'PTQ' not found in config file. \n{all_config}"
+        ptq_config = all_config["PTQ"]
+
+        # Initialize the executor
+        place = paddle.CUDAPlace(
+            0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+
+        # Read the model in ONNX or Paddle format
+        if global_config['format'] == 'onnx':
+            load_onnx_model(global_config["model_dir"])
+            # Mirror the '<name>_infer' path that load_onnx_model saves the
+            # converted Paddle model to.
+            inference_model_path = global_config["model_dir"].rstrip().rstrip(
+                '.onnx') + '_infer'
+        else:
+            inference_model_path = global_config["model_dir"].rstrip('/')
+
+        # YAML parses the literal `None` in the configs as the string 'None',
+        # so normalize it before passing it on.
+        skip_tensor_list = ptq_config.get('skip_tensor_list')
+        if skip_tensor_list in (None, 'None'):
+            skip_tensor_list = None
+
+        quant_post_static(
+            executor=exe,
+            model_dir=inference_model_path,
+            quantize_model_path=FLAGS.save_dir,
+            data_loader=train_loader,
+            model_filename=global_config["model_filename"],
+            params_filename=global_config["params_filename"],
+            batch_size=32,
+            batch_nums=10,
+            algo=ptq_config['calibration_method'],
+            hist_percent=0.999,
+            is_full_quantize=False,
+            bias_correction=False,
+            onnx_format=True,
+            skip_tensor_list=skip_tensor_list)
+
+    time_total = time.time() - time_s
+    print(f"Finished compression; total time used: {time_total:.1f} seconds.")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/quantization/readme.md b/tools/quantization/readme.md
new file mode 100644
index 000000000..600f79441
--- /dev/null
+++ b/tools/quantization/readme.md
@@ -0,0 +1,120 @@
+# FastDeploy One-Command Model Quantization
+FastDeploy provides one-command model quantization, supporting both offline quantization and quantization-aware training with distillation. This document uses YOLOv5s as an example to show how to install and run FastDeploy's one-command quantization.
+
+## 1. Installation
+
+### Environment dependencies
+
+1. Install the develop version of PaddlePaddle, following the official website:
+```
+https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/develop/install/pip/linux-pip.html
+```
+
+2. Install the develop version of PaddleSlim:
+```bash
+git clone https://github.com/PaddlePaddle/PaddleSlim.git && cd PaddleSlim
+python setup.py install
+```
+
+### Installing FastDeploy-Quantization
+Run the following command in the current directory:
+```
+python setup.py install
+```
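+
+A quick sanity check for the environment (a minimal sketch; `paddle.utils.run_check` is PaddlePaddle's built-in installation check):
+```python
+import paddle
+import paddleslim
+
+paddle.utils.run_check()       # verifies the PaddlePaddle installation
+print(paddleslim.__version__)  # confirms PaddleSlim is importable
+```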
+
+## 2. Usage
+
+### One-command offline quantization example
+
+#### Offline quantization (PTQ)
+
+##### 1. Prepare the model and the calibration dataset
+You need to prepare the model to quantize and a calibration dataset yourself.
+For this example, run the following commands to download the yolov5s.onnx model and the sample calibration dataset we prepared:
+
+```shell
+# Download yolov5s.onnx
+wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s.onnx
+
+# Download the dataset: this calibration set is the first 320 images of COCO val2017
+wget https://bj.bcebos.com/paddlehub/fastdeploy/COCO_val_320.tar.gz
+tar -xvf COCO_val_320.tar.gz
+```
+
+##### 2. Run one-command quantization with fastdeploy_quant:
+
+```shell
+fastdeploy_quant --config_path=./configs/detection/yolov5s_quant.yaml --method='PTQ' --save_dir='./yolov5s_ptq_model/'
+```
+
+##### 3. Parameters
+
+| Parameter | Purpose |
+| -------------------- | ------------------------------------------------------------ |
+| --config_path | quantization config file required by one-command quantization. [Details](./configs/readme.md) |
+| --method | quantization method: PTQ for offline quantization, QAT for quantization-aware (distillation) training |
+| --save_dir | output directory of the quantized model, deployable directly with FastDeploy |
+
+Note: fastdeploy_quant currently only supports quantizing YOLOv5, YOLOv6 and YOLOv7 models.
+
+
+#### Quantization-aware (distillation) training (QAT)
+
+##### 1. Prepare the model and the training dataset
+FastDeploy's quantization-aware training currently only trains on unlabeled images and does not support evaluating model accuracy during training.
+The dataset should consist of images from the real deployment scenario; the number of images depends on the dataset size and should cover the deployment scenarios as much as possible. In this example we use the first 320 images of the COCO2017 validation set:
+
+```shell
+# Download yolov5s.onnx
+wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s.onnx
+
+# Download the dataset: the first 320 images of the COCO2017 validation set
+wget https://bj.bcebos.com/paddlehub/fastdeploy/COCO_val_320.tar.gz
+tar -xvf COCO_val_320.tar.gz
+```
+
+##### 2. Run one-command quantization with fastdeploy_quant:
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+fastdeploy_quant --config_path=./configs/detection/yolov5s_quant.yaml --method='QAT' --save_dir='./yolov5s_qat_model/'
+```
+
+##### 3. Parameters
+
+| Parameter | Purpose |
+| -------------------- | ------------------------------------------------------------ |
+| --config_path | quantization config file required by one-command quantization. [Details](./configs/readme.md) |
+| --method | quantization method: PTQ for offline quantization, QAT for quantization-aware (distillation) training |
+| --save_dir | output directory of the quantized model, deployable directly with FastDeploy |
+
+Note: fastdeploy_quant currently only supports quantizing YOLOv5, YOLOv6 and YOLOv7 models.
+
+
+## 3. Deploying Quantized Models with FastDeploy
+Once you have the quantized model, simply pass its path and the usual parameters to FastDeploy to deploy it.
+Please refer to the example documents:
+- [YOLOv5s quantized model Python deployment](../../examples/slim/yolov5s/python/)
+- [YOLOv5s quantized model C++ deployment](../../examples/slim/yolov5s/cpp/)
+- [YOLOv6s quantized model Python deployment](../../examples/slim/yolov6s/python/)
+- [YOLOv6s quantized model C++ deployment](../../examples/slim/yolov6s/cpp/)
+- [YOLOv7 quantized model Python deployment](../../examples/slim/yolov7/python/)
+- [YOLOv7 quantized model C++ deployment](../../examples/slim/yolov7/cpp/)
+
+## 4. Benchmark
+The table below shows end-to-end inference performance of models deployed with FastDeploy, before and after quantization.
+- Test images are from COCO val2017.
+- Latency is the average end-to-end latency (including pre/post-processing), in milliseconds.
+- CPU: Intel(R) Xeon(R) Gold 6271C; GPU: Tesla T4; TensorRT version 8.4.15; CPU thread count fixed to 1 in all tests.
+
+| Model | Backend | Hardware | FP32 latency | INT8 latency | Speedup | FP32 mAP | INT8 mAP |
+| ------------------- | -----------------|-----------| -------- |-------- |-------- | --------- |-------- |
+| YOLOv5s | TensorRT | GPU | 14.13 | 11.22 | 1.26 | 37.6 | 36.6 |
+| YOLOv5s | ONNX Runtime | CPU | 183.68 | 100.39 | 1.83 | 37.6 | 33.1 |
+| YOLOv5s | Paddle Inference | CPU | 226.36 | 152.27 | 1.48 | 37.6 | 36.8 |
+| YOLOv6s | TensorRT | GPU | 12.89 | 8.92 | 1.45 | 42.5 | 40.6 |
+| YOLOv6s | ONNX Runtime | CPU | 345.85 | 131.81 | 2.60 | 42.5 | 36.1 |
+| YOLOv6s | Paddle Inference | CPU | 366.41 | 131.70 | 2.78 | 42.5 | 41.2 |
+| YOLOv7 | TensorRT | GPU | 30.43 | 15.40 | 1.98 | 51.1 | 50.8 |
+| YOLOv7 | ONNX Runtime | CPU | 971.27 | 471.88 | 2.06 | 51.1 | 42.5 |
+| YOLOv7 | Paddle Inference | CPU | 1015.70 | 562.41 | 1.82 | 51.1 | 46.3 |
diff --git a/tools/quantization/requirements.txt b/tools/quantization/requirements.txt
new file mode 100644
index 000000000..b9b109222
--- /dev/null
+++ b/tools/quantization/requirements.txt
@@ -0,0 +1 @@
+paddleslim
diff --git a/tools/quantization/setup.py b/tools/quantization/setup.py
new file mode 100644
index 000000000..a0c0c2fc0
--- /dev/null
+++ b/tools/quantization/setup.py
@@ -0,0 +1,25 @@
+import setuptools
+import fdquant
+
+long_description = "FDQuant is a toolkit for model quantization of FastDeploy.\n\n"
+long_description += "Usage: fastdeploy_quant --config_path=./yolov7_tiny_qat_dis.yaml --method='QAT' --save_dir='../v7_qat_outmodel/' \n"
+
+with open("requirements.txt") as fin:
+    REQUIRED_PACKAGES = fin.read()
+
+setuptools.setup(
+    name="fastdeploy-quantization",  # name of the package
+    description="A toolkit for model quantization of FastDeploy.",
+    long_description=long_description,
+    long_description_content_type="text/plain",
+    packages=setuptools.find_packages(),
+    install_requires=REQUIRED_PACKAGES,
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: Apache Software License",
+        "Operating System :: OS Independent",
+    ],
+    license='Apache 2.0',
+    entry_points={
+        'console_scripts': ['fastdeploy_quant=fdquant.fdquant:main', ]
+    })
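+
+# After `python setup.py install`, the `fastdeploy_quant` console script
+# defined above becomes available on PATH; see tools/quantization/readme.md
+# for usage.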