[Backend][Serving] Fix paddle backend get output tensor error (#741)

Fix the paddle backend's no-copy inference path (no_copy_infer): fetch the shared output pointer with the accessor matching the tensor's dtype instead of always reading it as uint8_t.
Author: heliqi
Date: 2022-11-29 18:34:56 +08:00
Committed by: GitHub
Parent: b96c8a4146
Commit: 5dca3a45c9
7 changed files with 133 additions and 20 deletions
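
Taken together, the changes expose a zero-copy inference path on the Python Runtime. A minimal sketch of the new flow under assumptions: the model/params paths and the tensor names "x" and "out" are placeholders, and fd.C.FDTensor as the binding's location is assumed rather than taken from this commit.

    import numpy as np
    import fastdeploy as fd

    option = fd.RuntimeOption()
    option.set_model_path("model.pdmodel", "model.pdiparams")  # placeholder paths
    option.use_paddle_infer_backend()
    runtime = fd.Runtime(option)

    # Share a numpy buffer with an FDTensor (share_buffer=True, no copy).
    tensor = fd.C.FDTensor()
    tensor.from_numpy(np.zeros((1, 3, 224, 224), np.float32), True)

    runtime.bind_input_tensor("x", tensor)  # "x" is a hypothetical input name
    runtime.zero_copy_infer()               # run without packing inputs/outputs
    out = runtime.get_output_tensor("out")  # shares backend memory; None if missing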

View File

@@ -87,14 +87,29 @@ void PaddleTensorToFDTensor(std::unique_ptr<paddle_infer::Tensor>& tensor,
   int size = 0;
   // TODO(liqi): The tensor->data interface of paddle don't return device id
   // and don't support return void*.
-  auto* out_data = tensor->data<uint8_t>(&place, &size);
+  void* out_data = nullptr;
+  if (fd_dtype == FDDataType::FP32) {
+    out_data = tensor->data<float>(&place, &size);
+  } else if (fd_dtype == FDDataType::INT32) {
+    out_data = tensor->data<int>(&place, &size);
+  } else if (fd_dtype == FDDataType::INT64) {
+    out_data = tensor->data<int64_t>(&place, &size);
+  } else if (fd_dtype == FDDataType::INT8) {
+    out_data = tensor->data<int8_t>(&place, &size);
+  } else if (fd_dtype == FDDataType::UINT8) {
+    out_data = tensor->data<uint8_t>(&place, &size);
+  } else {
+    FDASSERT(false, "Unexpected data type(%s) while infer shared with PaddleBackend.",
+             Str(fd_dtype).c_str());
+  }
   Device device = Device::CPU;
   if(place == paddle_infer::PlaceType::kGPU) {
     device = Device::GPU;
   }
+  fd_tensor->name = tensor->name();
   fd_tensor->SetExternalData(
       shape, fd_dtype,
-      reinterpret_cast<void*>(out_data), device);
+      out_data, device);
 }
 }
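The dispatch above is needed because paddle_infer::Tensor::data<T>() is templated on the element type, so the raw pointer can only be fetched through an accessor that matches the FDTensor dtype. As a rough Python analogue of that mapping — a sketch only: the numpy equivalents are assumptions, and only the FDDataType names come from this diff:

    import numpy as np

    # FDDataType name -> numpy dtype, mirroring the C++ branches above.
    FD_TO_NUMPY = {
        "FP32": np.float32,
        "INT32": np.int32,
        "INT64": np.int64,
        "INT8": np.int8,
        "UINT8": np.uint8,
    }

    def typed_view(raw_bytes, fd_dtype, shape):
        # Reinterpret a shared byte buffer with the dtype the backend reported.
        if fd_dtype not in FD_TO_NUMPY:
            raise TypeError("Unexpected data type({}) for shared buffer.".format(fd_dtype))
        return np.frombuffer(raw_bytes, dtype=FD_TO_NUMPY[fd_dtype]).reshape(shape)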

View File

@@ -181,7 +181,8 @@ void BindFDTensor(pybind11::module& m) {
       .def("from_numpy", [](FDTensor& self, pybind11::array& pyarray, bool share_buffer = false) {
         PyArrayToTensor(pyarray, &self, share_buffer);
       })
-      .def("to_dlpack", &FDTensorToDLPack);
+      .def("to_dlpack", &FDTensorToDLPack)
+      .def("print_info", &FDTensor::PrintInfo);
 }
 } // namespace fastdeploy
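The new print_info binding makes tensor metadata inspectable from Python. A hedged usage sketch: fd.C.FDTensor as the module path is an assumption, while from_numpy and print_info are the bindings shown above.

    import numpy as np
    import fastdeploy as fd

    t = fd.C.FDTensor()
    t.from_numpy(np.ones((2, 3), np.float32), True)  # share_buffer=True
    t.print_info()  # dumps name/shape/dtype via FDTensor::PrintInfo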

View File

@@ -128,12 +128,6 @@ void BindRuntime(pybind11::module& m) {
             }
             return self.Compile(warm_tensors, _option);
           })
-      .def("infer",
-           [](Runtime& self, std::vector<FDTensor>& inputs) {
-             std::vector<FDTensor> outputs(self.NumOutputs());
-             self.Infer(inputs, &outputs);
-             return outputs;
-           })
       .def("infer",
            [](Runtime& self, std::map<std::string, pybind11::array>& data) {
              std::vector<FDTensor> inputs(data.size());
@@ -185,6 +179,17 @@ void BindRuntime(pybind11::module& m) {
             std::vector<FDTensor> outputs;
             return self.Infer(inputs, &outputs);
           })
+      .def("bind_input_tensor", &Runtime::BindInputTensor)
+      .def("infer", [](Runtime& self) {
+        self.Infer();
+      })
+      .def("get_output_tensor", [](Runtime& self, const std::string& name) {
+        FDTensor* output = self.GetOutputTensor(name);
+        if(output == nullptr) {
+          return pybind11::cast(nullptr);
+        }
+        return pybind11::cast(*output);
+      })
       .def("num_inputs", &Runtime::NumInputs)
       .def("num_outputs", &Runtime::NumOutputs)
       .def("get_input_info", &Runtime::GetInputInfo)

View File

@@ -610,6 +610,7 @@ FDTensor* Runtime::GetOutputTensor(const std::string& name) {
       return &t;
     }
   }
+  FDWARNING << "The output name [" << name << "] don't exist." << std::endl;
   return nullptr;
 }
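From Python, the nullptr return surfaces as None, so a missing output name can be detected after the warning is logged. A sketch, again reusing the runtime from the earlier setup:

    # Logs "The output name [not_an_output] don't exist." and returns None.
    missing = runtime.get_output_tensor("not_an_output")
    assert missing is None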

View File

@@ -57,11 +57,36 @@ class Runtime:
         """
         assert isinstance(data, dict) or isinstance(
             data, list), "The input data should be type of dict or list."
-        for k, v in data.items():
-            if not v.data.contiguous:
-                data[k] = np.ascontiguousarray(data[k])
+        if isinstance(data, dict):
+            for k, v in data.items():
+                if isinstance(v, np.ndarray) and not v.data.contiguous:
+                    data[k] = np.ascontiguousarray(data[k])
         return self._runtime.infer(data)
 
+    def bind_input_tensor(self, name, fdtensor):
+        """Bind an FDTensor by name; no copy, shares the input memory.
+
+        :param name: (str)The name of the input data.
+        :param fdtensor: (fastdeploy.FDTensor)The input FDTensor.
+        """
+        self._runtime.bind_input_tensor(name, fdtensor)
+
+    def zero_copy_infer(self):
+        """Run inference without passing any parameters; input and output data
+        must go through the bind_input_tensor and get_output_tensor interfaces.
+        """
+        self._runtime.infer()
+
+    def get_output_tensor(self, name):
+        """Get an output FDTensor by name; no copy, shares the backend output memory.
+
+        :param name: (str)The name of the output data.
+        :return: fastdeploy.FDTensor
+        """
+        return self._runtime.get_output_tensor(name)
+
     def compile(self, warm_datas):
         """[Only for Poros backend] compile with prewarm data for poros
@@ -178,7 +203,8 @@ class RuntimeOption:
     @long_to_int.setter
     def long_to_int(self, value):
         assert isinstance(
-            value, bool), "The value to set `long_to_int` must be type of bool."
+            value,
+            bool), "The value to set `long_to_int` must be type of bool."
         self._option.long_to_int = value
 
     @use_nvidia_tf32.setter
@@ -434,7 +460,8 @@ class RuntimeOption:
                 continue
             if hasattr(getattr(self._option, attr), "__call__"):
                 continue
-            message += "  {} : {}\t\n".format(attr, getattr(self._option, attr))
+            message += "  {} : {}\t\n".format(attr,
+                                              getattr(self._option, attr))
         message.strip("\n")
         message += ")"
         return message
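One consequence of the reworked infer() wrapper: only dict inputs holding numpy arrays are made contiguous before the copy-based call, so FDTensor values pass through untouched. A sketch with a hypothetical input name "x" and a deliberately non-contiguous view, reusing the earlier runtime:

    import numpy as np

    arr = np.ones((4, 4), np.float32)[:, ::2]  # non-contiguous view
    # The wrapper sees arr.data.contiguous == False and calls
    # np.ascontiguousarray before handing the dict to the C++ runtime.
    outputs = runtime.infer({"x": arr})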

View File

@@ -37,7 +37,7 @@ nvidia-docker run -i --rm --name build_fd \
            nvcr.io/nvidia/tritonserver:21.10-py3-min \
            bash -c \
            'cd /workspace/fastdeploy/python;
-            rm -rf .setuptools-cmake-build dist;
+            rm -rf .setuptools-cmake-build dist build fastdeploy/libs/third_libs;
             apt-get update;
             apt-get install -y --no-install-recommends patchelf python3-dev python3-pip rapidjson-dev;
             ln -s /usr/bin/python3 /usr/bin/python;
@@ -75,7 +75,7 @@ docker run -i --rm --name build_fd \
            paddlepaddle/fastdeploy:21.10-cpu-only-buildbase \
            bash -c \
            'cd /workspace/fastdeploy/python;
-            rm -rf .setuptools-cmake-build dist;
+            rm -rf .setuptools-cmake-build dist build fastdeploy/libs/third_libs;
             ln -s /usr/bin/python3 /usr/bin/python;
             export WITH_GPU=OFF;
             export ENABLE_ORT_BACKEND=OFF;

View File

@@ -66,6 +66,7 @@ def test_detection_faster_rcnn():
     # with open("faster_rcnn_baseline.pkl", "wb") as f:
     #     pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)
 
+
 def test_detection_faster_rcnn1():
     model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/faster_rcnn_r50_vd_fpn_2x_coco.tgz"
     input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
@@ -84,14 +85,17 @@ def test_detection_faster_rcnn1():
     option = rc.test_option
     option.set_model_path(model_file, params_file)
     option.use_paddle_infer_backend()
-    runtime = fd.Runtime(option);
+    runtime = fd.Runtime(option)
     # compare diff
-    im1 = cv2.imread("./resources/000000014439.jpg")
     for i in range(2):
+        im1 = cv2.imread("./resources/000000014439.jpg")
         input_tensors = preprocessor.run([im1])
-        output_tensors = runtime.infer({"image": input_tensors[0], "scale_factor": input_tensors[1], "im_shape": input_tensors[2]})
+        output_tensors = runtime.infer({
+            "image": input_tensors[0],
+            "scale_factor": input_tensors[1],
+            "im_shape": input_tensors[2]
+        })
         results = postprocessor.run(output_tensors)
         result = results[0]
@@ -115,6 +119,66 @@ def test_detection_faster_rcnn1():
         assert diff_label_ids[scores > score_threshold].max(
         ) < 1e-04, "There's diff in label_ids."
 
+
+# test runtime.zero_copy_infer and bind_input_tensor get_output_tensor
+def test_detection_faster_rcnn2():
+    model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/faster_rcnn_r50_vd_fpn_2x_coco.tgz"
+    input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
+    result_url = "https://bj.bcebos.com/fastdeploy/tests/data/faster_rcnn_baseline.pkl"
+    fd.download_and_decompress(model_url, "resources")
+    fd.download(input_url1, "resources")
+    fd.download(result_url, "resources")
+    model_path = "resources/faster_rcnn_r50_vd_fpn_2x_coco"
+    model_file = os.path.join(model_path, "model.pdmodel")
+    params_file = os.path.join(model_path, "model.pdiparams")
+    config_file = os.path.join(model_path, "infer_cfg.yml")
+    preprocessor = fd.vision.detection.PaddleDetPreprocessor(config_file)
+    postprocessor = fd.vision.detection.PaddleDetPostprocessor()
+
+    option = rc.test_option
+    option.set_model_path(model_file, params_file)
+    option.use_paddle_infer_backend()
+    runtime = fd.Runtime(option)
+    # compare diff
+    input_names = ["image", "scale_factor", "im_shape"]
+    output_names = ["concat_12.tmp_0", "concat_8.tmp_0"]
+    for i in range(2):
+        im1 = cv2.imread("./resources/000000014439.jpg")
+        input_tensors = preprocessor.run([im1.copy(), ])
+        for i, input_tensor in enumerate(input_tensors):
+            runtime.bind_input_tensor(input_names[i], input_tensor)
+        runtime.zero_copy_infer()
+        output_tensors = []
+        for name in output_names:
+            output_tensor = runtime.get_output_tensor(name)
+            output_tensors.append(output_tensor)
+        results = postprocessor.run(output_tensors)
+        result = results[0]
+
+        with open("resources/faster_rcnn_baseline.pkl", "rb") as f:
+            boxes, scores, label_ids = pickle.load(f)
+        pred_boxes = np.array(result.boxes)
+        pred_scores = np.array(result.scores)
+        pred_label_ids = np.array(result.label_ids)
+        diff_boxes = np.fabs(boxes - pred_boxes)
+        diff_scores = np.fabs(scores - pred_scores)
+        diff_label_ids = np.fabs(label_ids - pred_label_ids)
+        print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())
+
+        score_threshold = 0.0
+        assert diff_boxes[scores > score_threshold].max(
+        ) < 1e-04, "There's diff in boxes."
+        assert diff_scores[scores > score_threshold].max(
+        ) < 1e-04, "There's diff in scores."
+        assert diff_label_ids[scores > score_threshold].max(
+        ) < 1e-04, "There's diff in label_ids."
+
+
 if __name__ == "__main__":
     test_detection_faster_rcnn()
     test_detection_faster_rcnn1()
+    test_detection_faster_rcnn2()
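
To exercise only the new zero-copy test, something like the following should work; the module filename is an assumption, not given in this diff.

    import pytest

    # Hypothetical file name; substitute the actual test module path.
    pytest.main(["-q", "test_paddledetection.py::test_detection_faster_rcnn2"])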