[backend][Serving] Fix Paddle backend get output tensor error (#741)

Fix Paddle backend no_copy_infer.
heliqi
2022-11-29 18:34:56 +08:00
committed by GitHub
parent b96c8a4146
commit 5dca3a45c9
7 changed files with 133 additions and 20 deletions


@@ -87,14 +87,29 @@ void PaddleTensorToFDTensor(std::unique_ptr<paddle_infer::Tensor>& tensor,
   int size = 0;
   // TODO(liqi): The tensor->data interface of paddle don't return device id
   // and don't support return void*.
-  auto* out_data = tensor->data<uint8_t>(&place, &size);
+  void* out_data = nullptr;
+  if (fd_dtype == FDDataType::FP32) {
+    out_data = tensor->data<float>(&place, &size);
+  } else if (fd_dtype == FDDataType::INT32) {
+    out_data = tensor->data<int>(&place, &size);
+  } else if (fd_dtype == FDDataType::INT64) {
+    out_data = tensor->data<int64_t>(&place, &size);
+  } else if (fd_dtype == FDDataType::INT8) {
+    out_data = tensor->data<int8_t>(&place, &size);
+  } else if (fd_dtype == FDDataType::UINT8) {
+    out_data = tensor->data<uint8_t>(&place, &size);
+  } else {
+    FDASSERT(false, "Unexpected data type(%s) while infer shared with PaddleBackend.",
+             Str(fd_dtype).c_str());
+  }
   Device device = Device::CPU;
   if(place == paddle_infer::PlaceType::kGPU) {
     device = Device::GPU;
   }
   fd_tensor->name = tensor->name();
   fd_tensor->SetExternalData(
       shape, fd_dtype,
-      reinterpret_cast<void*>(out_data), device);
+      out_data, device);
}
}
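Why the dispatch matters: the old path always requested tensor->data<uint8_t>() regardless of the tensor's real element type, so a FP32 or INT64 output either trips Paddle's dtype check or hands back a byte-reinterpreted buffer. A numpy analogy of that reinterpretation hazard (illustration only, not the backend code itself):

import numpy as np

# A float32 output buffer like the one the Paddle backend returns.
out = np.array([1.0, 2.0, 3.0], dtype=np.float32)

# Viewing the same bytes as uint8 (what a blanket data<uint8_t>() amounts
# to) yields 12 raw IEEE-754 bytes instead of 3 floats.
raw = out.view(np.uint8)
print(raw.shape)  # (12,)
print(raw[:4])    # [  0   0 128  63], the bytes of 1.0f, not the value 1.0

# Asking for the buffer through its true element type, as the fix does,
# leaves the data intact.
assert np.array_equal(out.view(np.float32), out)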


@@ -181,7 +181,8 @@ void BindFDTensor(pybind11::module& m) {
.def("from_numpy", [](FDTensor& self, pybind11::array& pyarray, bool share_buffer = false) {
PyArrayToTensor(pyarray, &self, share_buffer);
})
.def("to_dlpack", &FDTensorToDLPack);
.def("to_dlpack", &FDTensorToDLPack)
.def("print_info", &FDTensor::PrintInfo);
}
} // namespace fastdeploy
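The new print_info binding gives Python callers a quick way to inspect an FDTensor when debugging the shared-memory path. A minimal sketch; the construction path fd.C.FDTensor is an assumption (the docstrings below refer to the type as fastdeploy.FDTensor):

import numpy as np
import fastdeploy as fd

t = fd.C.FDTensor()  # assumed constructor path for the pybind class
t.from_numpy(np.zeros((1, 3), dtype=np.float32), True)  # share_buffer=True
t.print_info()           # binding added by this commit; prints the tensor's metadata
capsule = t.to_dlpack()  # existing binding, now chained before print_info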


@@ -128,12 +128,6 @@ void BindRuntime(pybind11::module& m) {
             }
             return self.Compile(warm_tensors, _option);
           })
-      .def("infer",
-           [](Runtime& self, std::vector<FDTensor>& inputs) {
-             std::vector<FDTensor> outputs(self.NumOutputs());
-             self.Infer(inputs, &outputs);
-             return outputs;
-           })
      .def("infer",
           [](Runtime& self, std::map<std::string, pybind11::array>& data) {
             std::vector<FDTensor> inputs(data.size());
@@ -185,6 +179,17 @@ void BindRuntime(pybind11::module& m) {
             std::vector<FDTensor> outputs;
             return self.Infer(inputs, &outputs);
           })
+      .def("bind_input_tensor", &Runtime::BindInputTensor)
+      .def("infer", [](Runtime& self) {
+        self.Infer();
+      })
+      .def("get_output_tensor", [](Runtime& self, const std::string& name) {
+        FDTensor* output = self.GetOutputTensor(name);
+        if(output == nullptr) {
+          return pybind11::cast(nullptr);
+        }
+        return pybind11::cast(*output);
+      })
      .def("num_inputs", &Runtime::NumInputs)
      .def("num_outputs", &Runtime::NumOutputs)
      .def("get_input_info", &Runtime::GetInputInfo)


@@ -610,6 +610,7 @@ FDTensor* Runtime::GetOutputTensor(const std::string& name) {
      return &t;
    }
  }
+  FDWARNING << "The output name [" << name << "] don't exist." << std::endl;
  return nullptr;
}


@@ -57,11 +57,36 @@ class Runtime:
"""
assert isinstance(data, dict) or isinstance(
data, list), "The input data should be type of dict or list."
if isinstance(data, dict):
for k, v in data.items():
if not v.data.contiguous:
if isinstance(v, np.ndarray) and not v.data.contiguous:
data[k] = np.ascontiguousarray(data[k])
return self._runtime.infer(data)
def bind_input_tensor(self, name, fdtensor):
"""Bind FDTensor by name, no copy and share input memory
:param name: (str)The name of input data.
:param fdtensor: (fastdeploy.FDTensor)The input FDTensor.
"""
self._runtime.bind_input_tensor(name, fdtensor)
def zero_copy_infer(self):
"""No params inference the model.
the input and output data need to pass through the bind_input_tensor and get_output_tensor interfaces.
"""
self._runtime.infer()
def get_output_tensor(self, name):
"""Get output FDTensor by name, no copy and share backend output memory
:param name: (str)The name of output data.
:return fastdeploy.FDTensor
"""
return self._runtime.get_output_tensor(name)
def compile(self, warm_datas):
"""[Only for Poros backend] compile with prewarm data for poros
@@ -178,7 +203,8 @@ class RuntimeOption:
    @long_to_int.setter
    def long_to_int(self, value):
        assert isinstance(
-            value, bool), "The value to set `long_to_int` must be type of bool."
+            value,
+            bool), "The value to set `long_to_int` must be type of bool."
        self._option.long_to_int = value

    @use_nvidia_tf32.setter
@@ -434,7 +460,8 @@ class RuntimeOption:
                continue
            if hasattr(getattr(self._option, attr), "__call__"):
                continue
-            message += "  {} : {}\t\n".format(attr, getattr(self._option, attr))
+            message += "  {} : {}\t\n".format(attr,
+                                              getattr(self._option, attr))
        message.strip("\n")
        message += ")"
        return message
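Putting the three Runtime methods added above together, the zero-copy round trip looks as follows. A minimal sketch: the model paths, input shape, and tensor names are placeholders borrowed from the test at the end of this commit, and the fd.C.FDTensor construction path is an assumption.

import numpy as np
import fastdeploy as fd

option = fd.RuntimeOption()
option.set_model_path("model.pdmodel", "model.pdiparams")  # placeholder paths
option.use_paddle_infer_backend()
runtime = fd.Runtime(option)

# Build an input FDTensor and share its memory with the backend
# instead of copying (construction path assumed, shape a placeholder).
input_tensor = fd.C.FDTensor()
input_tensor.from_numpy(np.zeros((1, 3, 640, 640), dtype=np.float32), True)
runtime.bind_input_tensor("image", input_tensor)

# Inference with no arguments: data flows only through the bound tensors.
runtime.zero_copy_infer()

# Borrow the backend's output memory. An unknown name returns None and
# logs the FDWARNING added in Runtime::GetOutputTensor above.
out = runtime.get_output_tensor("concat_8.tmp_0")
if out is None:
    raise KeyError("unknown output tensor name")

In practice the bound tensors come from a preprocessor, as test_detection_faster_rcnn2 below demonstrates.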


@@ -37,7 +37,7 @@ nvidia-docker run -i --rm --name build_fd \
  nvcr.io/nvidia/tritonserver:21.10-py3-min \
  bash -c \
  'cd /workspace/fastdeploy/python;
-   rm -rf .setuptools-cmake-build dist;
+   rm -rf .setuptools-cmake-build dist build fastdeploy/libs/third_libs;
   apt-get update;
   apt-get install -y --no-install-recommends patchelf python3-dev python3-pip rapidjson-dev;
   ln -s /usr/bin/python3 /usr/bin/python;
@@ -75,7 +75,7 @@ docker run -i --rm --name build_fd \
  paddlepaddle/fastdeploy:21.10-cpu-only-buildbase \
  bash -c \
  'cd /workspace/fastdeploy/python;
-   rm -rf .setuptools-cmake-build dist;
+   rm -rf .setuptools-cmake-build dist build fastdeploy/libs/third_libs;
   ln -s /usr/bin/python3 /usr/bin/python;
   export WITH_GPU=OFF;
   export ENABLE_ORT_BACKEND=OFF;


@@ -66,6 +66,7 @@ def test_detection_faster_rcnn():
    # with open("faster_rcnn_baseline.pkl", "wb") as f:
    #     pickle.dump([np.array(result.boxes), np.array(result.scores), np.array(result.label_ids)], f)


def test_detection_faster_rcnn1():
    model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/faster_rcnn_r50_vd_fpn_2x_coco.tgz"
    input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
@@ -84,14 +85,17 @@ def test_detection_faster_rcnn1():
    option = rc.test_option
    option.set_model_path(model_file, params_file)
    option.use_paddle_infer_backend()
-    runtime = fd.Runtime(option);
+    runtime = fd.Runtime(option)
    # compare diff
-    im1 = cv2.imread("./resources/000000014439.jpg")
-    input_tensors = preprocessor.run([im1])
-    output_tensors = runtime.infer({"image": input_tensors[0], "scale_factor": input_tensors[1], "im_shape": input_tensors[2]})
+    for i in range(2):
+        im1 = cv2.imread("./resources/000000014439.jpg")
+        input_tensors = preprocessor.run([im1])
+        output_tensors = runtime.infer({
+            "image": input_tensors[0],
+            "scale_factor": input_tensors[1],
+            "im_shape": input_tensors[2]
+        })
        results = postprocessor.run(output_tensors)
        result = results[0]
@@ -115,6 +119,66 @@ def test_detection_faster_rcnn1():
        assert diff_label_ids[scores > score_threshold].max(
        ) < 1e-04, "There's diff in label_ids."
+
+
+# test runtime.zero_copy_infer and bind_input_tensor get_output_tensor
+def test_detection_faster_rcnn2():
+    model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/faster_rcnn_r50_vd_fpn_2x_coco.tgz"
+    input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
+    result_url = "https://bj.bcebos.com/fastdeploy/tests/data/faster_rcnn_baseline.pkl"
+    fd.download_and_decompress(model_url, "resources")
+    fd.download(input_url1, "resources")
+    fd.download(result_url, "resources")
+    model_path = "resources/faster_rcnn_r50_vd_fpn_2x_coco"
+    model_file = os.path.join(model_path, "model.pdmodel")
+    params_file = os.path.join(model_path, "model.pdiparams")
+    config_file = os.path.join(model_path, "infer_cfg.yml")
+    preprocessor = fd.vision.detection.PaddleDetPreprocessor(config_file)
+    postprocessor = fd.vision.detection.PaddleDetPostprocessor()
+
+    option = rc.test_option
+    option.set_model_path(model_file, params_file)
+    option.use_paddle_infer_backend()
+    runtime = fd.Runtime(option)
+
+    # compare diff
+    input_names = ["image", "scale_factor", "im_shape"]
+    output_names = ["concat_12.tmp_0", "concat_8.tmp_0"]
+    for i in range(2):
+        im1 = cv2.imread("./resources/000000014439.jpg")
+        input_tensors = preprocessor.run([im1.copy(), ])
+        for i, input_tensor in enumerate(input_tensors):
+            runtime.bind_input_tensor(input_names[i], input_tensor)
+        runtime.zero_copy_infer()
+        output_tensors = []
+        for name in output_names:
+            output_tensor = runtime.get_output_tensor(name)
+            output_tensors.append(output_tensor)
+        results = postprocessor.run(output_tensors)
+        result = results[0]
+
+        with open("resources/faster_rcnn_baseline.pkl", "rb") as f:
+            boxes, scores, label_ids = pickle.load(f)
+        pred_boxes = np.array(result.boxes)
+        pred_scores = np.array(result.scores)
+        pred_label_ids = np.array(result.label_ids)
+        diff_boxes = np.fabs(boxes - pred_boxes)
+        diff_scores = np.fabs(scores - pred_scores)
+        diff_label_ids = np.fabs(label_ids - pred_label_ids)
+
+        print(diff_boxes.max(), diff_scores.max(), diff_label_ids.max())
+        score_threshold = 0.0
+        assert diff_boxes[scores > score_threshold].max(
+        ) < 1e-04, "There's diff in boxes."
+        assert diff_scores[scores > score_threshold].max(
+        ) < 1e-04, "There's diff in scores."
+        assert diff_label_ids[scores > score_threshold].max(
+        ) < 1e-04, "There's diff in label_ids."


if __name__ == "__main__":
    test_detection_faster_rcnn()
    test_detection_faster_rcnn1()
+    test_detection_faster_rcnn2()