[Other]Refactor PaddleSeg with preprocessor && postprocessor && support batch (#639)

* Refactor PaddleSeg with preprocessor && postprocessor

* Fix bugs

* Delete redundancy code

* Modify by comments

* Refactor according to comments

* Add batch evaluation

* Add single test script

* Add ppliteseg single test script && fix eval(raise) error

* fix bug

* Fix evaluation segmentation.py batch predict

* Fix segmentation evaluation bug

* Fix evaluation segmentation bugs

Co-authored-by: Jason <jiangjiajun@baidu.com>
Author: huangjianhui
Date: 2022-11-28 15:50:12 +08:00 (committed by GitHub)
Parent: d0307192f9
Commit: 312e1b097d
26 changed files with 1173 additions and 449 deletions


@@ -20,7 +20,7 @@ import math
import time
def eval_segmentation(model, data_dir):
def eval_segmentation(model, data_dir, batch_size=1):
import cv2
from .utils import Cityscapes
from .utils import f1_score, calculate_area, mean_iou, accuracy, kappa
@@ -39,6 +39,8 @@ def eval_segmentation(model, data_dir):
start_time = 0
end_time = 0
average_inference_time = 0
im_list = []
label_list = []
for image_label_path, i in zip(file_list,
trange(
image_num, desc="Inference Progress")):
@@ -46,19 +48,31 @@ def eval_segmentation(model, data_dir):
start_time = time.time()
im = cv2.imread(image_label_path[0])
label = cv2.imread(image_label_path[1], cv2.IMREAD_GRAYSCALE)
result = model.predict(im)
label_list.append(label)
if batch_size == 1:
result = model.predict(im)
results = [result]
else:
im_list.append(im)
# Keep collecting images until a full batch is formed; the leftover images at the end, if fewer than batch_size, are processed as one smaller batch
if (i + 1) % batch_size != 0 and i != image_num - 1:
continue
results = model.batch_predict(im_list)
if i == image_num - 1:
end_time = time.time()
average_inference_time = round(
(end_time - start_time) / (image_num - twenty_percent_image_num),
4)
pred = np.array(result.label_map).reshape(result.shape[0],
result.shape[1])
intersect_area, pred_area, label_area = calculate_area(pred, label,
num_classes)
intersect_area_all = intersect_area_all + intersect_area
pred_area_all = pred_area_all + pred_area
label_area_all = label_area_all + label_area
average_inference_time = round(
(end_time - start_time) /
(image_num - twenty_percent_image_num), 4)
for result, label in zip(results, label_list):
pred = np.array(result.label_map).reshape(result.shape[0],
result.shape[1])
intersect_area, pred_area, label_area = calculate_area(pred, label,
num_classes)
intersect_area_all = intersect_area_all + intersect_area
pred_area_all = pred_area_all + pred_area
label_area_all = label_area_all + label_area
im_list.clear()
label_list.clear()
class_iou, miou = mean_iou(intersect_area_all, pred_area_all,
label_area_all)
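
For context, here is a minimal usage sketch of the batched evaluation path. The model files, the dataset directory, and the fd.vision.evaluation entry point are assumptions for illustration, not part of this diff:

import fastdeploy as fd

# Load a PaddleSeg model (paths are placeholders).
model = fd.vision.segmentation.PaddleSegModel(
    "ppliteseg/model.pdmodel",
    "ppliteseg/model.pdiparams",
    "ppliteseg/deploy.yaml")

# With batch_size > 1, eval_segmentation collects images into im_list and calls
# model.batch_predict(); any leftover images at the end form one smaller batch.
fd.vision.evaluation.eval_segmentation(
    model, data_dir="cityscapes_dataset", batch_size=4)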


@@ -13,4 +13,4 @@
# limitations under the License.
from __future__ import absolute_import
from .ppseg import PaddleSegModel
from .ppseg import *
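
With the wildcard export, the new preprocessor and postprocessor classes become importable next to the model class. A sketch, assuming the standard fastdeploy package layout:

from fastdeploy.vision.segmentation import (
    PaddleSegModel, PaddleSegPreprocessor, PaddleSegPostprocessor)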


@@ -41,35 +41,55 @@ class PaddleSegModel(FastDeployModel):
model_format)
assert self.initialized, "PaddleSeg model initialize failed."
def predict(self, input_image):
def predict(self, image):
"""Predict the segmentation result for an input image
:param image: (numpy.ndarray) The input image data, a 3-D array with HWC layout in BGR format
:return: SegmentationResult
"""
return self._model.predict(input_image)
return self._model.predict(image)
def disable_normalize_and_permute(self):
return self._model.disable_normalize_and_permute()
def batch_predict(self, image_list):
"""Predict the segmentation results for a batch of input image
:param image_list: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
:return list of SegmentationResult
"""
return self._model.batch_predict(image_list)
@property
def apply_softmax(self):
"""Atrribute of PaddleSeg model. Stating Whether applying softmax operator in the postprocess, default value is False
:return: value of apply_softmax(bool)
def preprocessor(self):
"""Get PaddleSegPreprocessor object of the loaded model
:return PaddleSegPreprocessor
"""
return self._model.apply_softmax
return self._model.preprocessor
@apply_softmax.setter
def apply_softmax(self, value):
"""Set attribute apply_softmax of PaddleSeg model.
:param value: (bool)The value to set apply_softmax
@property
def postprocessor(self):
"""Get PaddleSegPostprocessor object of the loaded model
:return PaddleSegPostprocessor
"""
assert isinstance(
value,
bool), "The value to set `apply_softmax` must be type of bool."
self._model.apply_softmax = value
return self._model.postprocessor
class PaddleSegPreprocessor:
def __init__(self, config_file):
"""Create a preprocessor for PaddleSegModel from configuration file
:param config_file: (str)Path of configuration file, e.g ppliteseg/deploy.yaml
"""
self._preprocessor = C.vision.segmentation.PaddleSegPreprocessor(
config_file)
def run(self, input_ims):
"""Preprocess input images for PaddleSegModel
:param input_ims: (list of numpy.ndarray) The input images
:return: list of FDTensor
"""
return self._preprocessor.run(input_ims)
def disable_normalize_and_permute(self):
"""To disable normalize and hwc2chw in preprocessing step.
"""
return self._preprocessor.disable_normalize_and_permute()
@property
def is_vertical_screen(self):
@@ -77,7 +97,7 @@ class PaddleSegModel(FastDeployModel):
:return: value of is_vertical_screen(bool)
"""
return self._model.is_vertical_screen
return self._preprocessor.is_vertical_screen
@is_vertical_screen.setter
def is_vertical_screen(self, value):
@@ -88,4 +108,59 @@ class PaddleSegModel(FastDeployModel):
assert isinstance(
value,
bool), "The value to set `is_vertical_screen` must be type of bool."
self._model.is_vertical_screen = value
self._preprocessor.is_vertical_screen = value
class PaddleSegPostprocessor:
def __init__(self, config_file):
"""Create a postprocessor for PaddleSegModel from configuration file
:param config_file: (str)Path of configuration file, e.g ppliteseg/deploy.yaml
"""
self._postprocessor = C.vision.segmentation.PaddleSegPostprocessor(
config_file)
def run(self, runtime_results, imgs_info):
"""Postprocess the runtime results for PaddleSegModel
:param runtime_results: (list of FDTensor) The output FDTensor results from the runtime
:param imgs_info: The shape info map of the original input images; key is "shape_info", value is [[image_height, image_width]]
:return: list of SegmentationResult (if runtime_results comes from a batch of samples, the length of this list equals the batch size)
"""
return self._postprocessor.run(runtime_results, imgs_info)
@property
def apply_softmax(self):
"""Atrribute of PaddleSeg model. Stating Whether applying softmax operator in the postprocess, default value is False
:return: value of apply_softmax(bool)
"""
return self._postprocessor.apply_softmax
@apply_softmax.setter
def apply_softmax(self, value):
"""Set attribute apply_softmax of PaddleSeg model.
:param value: (bool)The value to set apply_softmax
"""
assert isinstance(
value,
bool), "The value to set `apply_softmax` must be type of bool."
self._postprocessor.apply_softmax = value
@property
def store_score_map(self):
"""Atrribute of PaddleSeg model. Stating Whether storing score map in the SegmentationResult, default value is False
:return: value of store_score_map(bool)
"""
return self._postprocessor.store_score_map
@store_score_map.setter
def store_score_map(self, value):
"""Set attribute store_score_map of PaddleSeg model.
:param value: (bool)The value to set store_score_map
"""
assert isinstance(
value,
bool), "The value to set `store_score_map` must be type of bool."
self._postprocessor.store_score_map = value
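
Putting the refactored pieces together, a hedged end-to-end sketch; file paths and test images are placeholders, while the method and attribute names are the ones introduced in this diff:

import cv2
import fastdeploy as fd

model = fd.vision.segmentation.PaddleSegModel(
    "ppliteseg/model.pdmodel",
    "ppliteseg/model.pdiparams",
    "ppliteseg/deploy.yaml")

# Pre/post-processing options now live on dedicated objects instead of the model.
model.preprocessor.is_vertical_screen = False
model.postprocessor.apply_softmax = True
model.postprocessor.store_score_map = False

# Single-image and batched prediction.
im = cv2.imread("test.jpg")
result = model.predict(im)
results = model.batch_predict([im, im])  # one SegmentationResult per input image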