diff --git a/examples/vision/ocr/PP-OCRv2/cpp/infer.cc b/examples/vision/ocr/PP-OCRv2/cpp/infer.cc
index 9d628689b..2537c12bb 100644
--- a/examples/vision/ocr/PP-OCRv2/cpp/infer.cc
+++ b/examples/vision/ocr/PP-OCRv2/cpp/infer.cc
@@ -29,9 +29,25 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model
   auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
   auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";

-  auto det_model = fastdeploy::vision::ocr::DBDetector(det_model_file, det_params_file, option);
-  auto cls_model = fastdeploy::vision::ocr::Classifier(cls_model_file, cls_params_file, option);
-  auto rec_model = fastdeploy::vision::ocr::Recognizer(rec_model_file, rec_params_file, rec_label_file, option);
+  auto det_option = option;
+  auto cls_option = option;
+  auto rec_option = option;
+
+  // If the TRT backend is used, the dynamic shapes will be set as follows.
+  det_option.SetTrtInputShape("x", {1, 3, 50, 50}, {1, 3, 640, 640},
+                              {1, 3, 1536, 1536});
+  cls_option.SetTrtInputShape("x", {1, 3, 48, 10}, {1, 3, 48, 320}, {1, 3, 48, 1024});
+  rec_option.SetTrtInputShape("x", {1, 3, 32, 10}, {1, 3, 32, 320},
+                              {1, 3, 32, 2304});
+
+  // Users can save the TRT cache file to disk as follows.
+  // det_option.SetTrtCacheFile(det_model_dir + sep + "det_trt_cache.trt");
+  // cls_option.SetTrtCacheFile(cls_model_dir + sep + "cls_trt_cache.trt");
+  // rec_option.SetTrtCacheFile(rec_model_dir + sep + "rec_trt_cache.trt");
+
+  auto det_model = fastdeploy::vision::ocr::DBDetector(det_model_file, det_params_file, det_option);
+  auto cls_model = fastdeploy::vision::ocr::Classifier(cls_model_file, cls_params_file, cls_option);
+  auto rec_model = fastdeploy::vision::ocr::Recognizer(rec_model_file, rec_params_file, rec_label_file, rec_option);

   assert(det_model.Initialized());
   assert(cls_model.Initialized());
diff --git a/examples/vision/ocr/PP-OCRv2/python/infer.py b/examples/vision/ocr/PP-OCRv2/python/infer.py
index 984ede8e7..1bb94eb7e 100644
--- a/examples/vision/ocr/PP-OCRv2/python/infer.py
+++ b/examples/vision/ocr/PP-OCRv2/python/infer.py
@@ -96,19 +96,29 @@
 rec_model_file = os.path.join(args.rec_model, "inference.pdmodel")
 rec_params_file = os.path.join(args.rec_model, "inference.pdiparams")
 rec_label_file = args.rec_label_file
-# The same deployment configuration is used for all three models
-# Users can also configure each model separately as needed
-runtime_option = build_option(args)
+det_option = runtime_option
+cls_option = runtime_option
+rec_option = runtime_option
+
+# When using the TRT backend, set the dynamic shape for each of the three runtimes
+det_option.set_trt_input_shape("x", [1, 3, 50, 50], [1, 3, 640, 640],
+                               [1, 3, 1536, 1536])
+cls_option.set_trt_input_shape("x", [1, 3, 48, 10], [1, 3, 48, 320],
+                               [1, 3, 48, 1024])
+rec_option.set_trt_input_shape("x", [1, 3, 32, 10], [1, 3, 32, 320],
+                               [1, 3, 32, 2304])
+
+# Users can save the TRT engine file to disk
+# det_option.set_trt_cache_file(args.det_model + "/det_trt_cache.trt")
+# cls_option.set_trt_cache_file(args.cls_model + "/cls_trt_cache.trt")
+# rec_option.set_trt_cache_file(args.rec_model + "/rec_trt_cache.trt")

 det_model = fd.vision.ocr.DBDetector(
-    det_model_file, det_params_file, runtime_option=runtime_option)
+    det_model_file, det_params_file, runtime_option=det_option)
 cls_model = fd.vision.ocr.Classifier(
-    cls_model_file, cls_params_file, runtime_option=runtime_option)
+    cls_model_file, cls_params_file, runtime_option=cls_option)
 rec_model = fd.vision.ocr.Recognizer(
-    rec_model_file,
-    rec_params_file,
-    rec_label_file,
-    runtime_option=runtime_option)
+    rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option)

 # Create PP-OCR, which chains the three models; cls_model is optional and can be set to None if not needed
 ppocr_v2 = fd.vision.ocr.PPOCRv2(
diff --git a/examples/vision/ocr/PP-OCRv3/cpp/infer.cc b/examples/vision/ocr/PP-OCRv3/cpp/infer.cc
index 333dbaa3f..da950872b 100644
--- a/examples/vision/ocr/PP-OCRv3/cpp/infer.cc
+++ b/examples/vision/ocr/PP-OCRv3/cpp/infer.cc
@@ -29,9 +29,25 @@ void InitAndInfer(const std::string& det_model_dir, const std::string& cls_model
   auto rec_model_file = rec_model_dir + sep + "inference.pdmodel";
   auto rec_params_file = rec_model_dir + sep + "inference.pdiparams";

-  auto det_model = fastdeploy::vision::ocr::DBDetector(det_model_file, det_params_file, option);
-  auto cls_model = fastdeploy::vision::ocr::Classifier(cls_model_file, cls_params_file, option);
-  auto rec_model = fastdeploy::vision::ocr::Recognizer(rec_model_file, rec_params_file, rec_label_file, option);
+  auto det_option = option;
+  auto cls_option = option;
+  auto rec_option = option;
+
+  // If the TRT backend is used, the dynamic shapes will be set as follows.
+  det_option.SetTrtInputShape("x", {1, 3, 50, 50}, {1, 3, 640, 640},
+                              {1, 3, 1536, 1536});
+  cls_option.SetTrtInputShape("x", {1, 3, 48, 10}, {1, 3, 48, 320}, {1, 3, 48, 1024});
+  rec_option.SetTrtInputShape("x", {1, 3, 48, 10}, {1, 3, 48, 320},
+                              {1, 3, 48, 2304});
+
+  // Users can save the TRT cache file to disk as follows.
+  // det_option.SetTrtCacheFile(det_model_dir + sep + "det_trt_cache.trt");
+  // cls_option.SetTrtCacheFile(cls_model_dir + sep + "cls_trt_cache.trt");
+  // rec_option.SetTrtCacheFile(rec_model_dir + sep + "rec_trt_cache.trt");
+
+  auto det_model = fastdeploy::vision::ocr::DBDetector(det_model_file, det_params_file, det_option);
+  auto cls_model = fastdeploy::vision::ocr::Classifier(cls_model_file, cls_params_file, cls_option);
+  auto rec_model = fastdeploy::vision::ocr::Recognizer(rec_model_file, rec_params_file, rec_label_file, rec_option);

   assert(det_model.Initialized());
   assert(cls_model.Initialized());
diff --git a/examples/vision/ocr/PP-OCRv3/python/infer.py b/examples/vision/ocr/PP-OCRv3/python/infer.py
index 46df9c507..43b9b630c 100644
--- a/examples/vision/ocr/PP-OCRv3/python/infer.py
+++ b/examples/vision/ocr/PP-OCRv3/python/infer.py
@@ -100,15 +100,29 @@
 rec_label_file = args.rec_label_file
 # Users can also configure each model separately as needed
 runtime_option = build_option(args)
+det_option = runtime_option
+cls_option = runtime_option
+rec_option = runtime_option
+
+# When using the TRT backend, set the dynamic shape for each of the three runtimes
+det_option.set_trt_input_shape("x", [1, 3, 50, 50], [1, 3, 640, 640],
+                               [1, 3, 1536, 1536])
+cls_option.set_trt_input_shape("x", [1, 3, 48, 10], [1, 3, 48, 320],
+                               [1, 3, 48, 1024])
+rec_option.set_trt_input_shape("x", [1, 3, 48, 10], [1, 3, 48, 320],
+                               [1, 3, 48, 2304])
+
+# Users can save the TRT engine file to disk
+# det_option.set_trt_cache_file(args.det_model + "/det_trt_cache.trt")
+# cls_option.set_trt_cache_file(args.cls_model + "/cls_trt_cache.trt")
+# rec_option.set_trt_cache_file(args.rec_model + "/rec_trt_cache.trt")
+
 det_model = fd.vision.ocr.DBDetector(
-    det_model_file, det_params_file, runtime_option=runtime_option)
+    det_model_file, det_params_file, runtime_option=det_option)
 cls_model = fd.vision.ocr.Classifier(
-    cls_model_file, cls_params_file, runtime_option=runtime_option)
+    cls_model_file, cls_params_file, runtime_option=cls_option)
 rec_model = fd.vision.ocr.Recognizer(
-    rec_model_file,
-    rec_params_file,
-    rec_label_file,
-    runtime_option=runtime_option)
+    rec_model_file, rec_params_file, rec_label_file, runtime_option=rec_option)

 # Create PP-OCR, which chains the three models; cls_model is optional and can be set to None if not needed
 ppocr_v3 = fd.vision.ocr.PPOCRv3(
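
Taken together, these hunks give each of the three PP-OCR models its own runtime option so that TensorRT dynamic shapes (and, optionally, cached engines) can be configured per model. The sketch below shows one way the patched PP-OCRv3 Python example could be exercised end to end. It is a minimal sketch, not the shipped example: the model directories, label file, and image path are placeholders, the per-model RuntimeOption objects are built independently for clarity rather than copied from one shared option as in the patch, and the FastDeploy calls are taken from the snippets above rather than verified against a particular release.

import cv2
import fastdeploy as fd

# Placeholder paths: point these at the exported PP-OCRv3 inference models and a test image.
det_dir = "ch_PP-OCRv3_det_infer"
cls_dir = "ch_ppocr_mobile_v2.0_cls_infer"
rec_dir = "ch_PP-OCRv3_rec_infer"
rec_label_file = "ppocr_keys_v1.txt"


def build_trt_option(min_shape, opt_shape, max_shape, cache_file=None):
    # GPU + TensorRT backend, with the dynamic shape range for this model's "x" input.
    opt = fd.RuntimeOption()
    opt.use_gpu(0)
    opt.use_trt_backend()
    opt.set_trt_input_shape("x", min_shape, opt_shape, max_shape)
    if cache_file is not None:
        # Optionally persist the built TRT engine so later runs skip engine building.
        opt.set_trt_cache_file(cache_file)
    return opt


# Shape ranges mirror the values added by the PP-OCRv3 hunks above.
det_option = build_trt_option([1, 3, 50, 50], [1, 3, 640, 640], [1, 3, 1536, 1536])
cls_option = build_trt_option([1, 3, 48, 10], [1, 3, 48, 320], [1, 3, 48, 1024])
rec_option = build_trt_option([1, 3, 48, 10], [1, 3, 48, 320], [1, 3, 48, 2304])

det_model = fd.vision.ocr.DBDetector(
    det_dir + "/inference.pdmodel", det_dir + "/inference.pdiparams",
    runtime_option=det_option)
cls_model = fd.vision.ocr.Classifier(
    cls_dir + "/inference.pdmodel", cls_dir + "/inference.pdiparams",
    runtime_option=cls_option)
rec_model = fd.vision.ocr.Recognizer(
    rec_dir + "/inference.pdmodel", rec_dir + "/inference.pdiparams",
    rec_label_file, runtime_option=rec_option)

# cls_model is optional; pass None to skip text orientation classification.
ppocr_v3 = fd.vision.ocr.PPOCRv3(
    det_model=det_model, cls_model=cls_model, rec_model=rec_model)

im = cv2.imread("12.jpg")
result = ppocr_v3.predict(im)
print(result)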