[backend][Serving]Fix paddle backend get output tensor error (#741)

fix paddle backend no_copy_infer
This commit is contained in:
heliqi
2022-11-29 18:34:56 +08:00
committed by GitHub
parent b96c8a4146
commit 5dca3a45c9
7 changed files with 133 additions and 20 deletions

View File

@@ -57,11 +57,36 @@ class Runtime:
"""
assert isinstance(data, dict) or isinstance(
data, list), "The input data should be type of dict or list."
for k, v in data.items():
if not v.data.contiguous:
data[k] = np.ascontiguousarray(data[k])
if isinstance(data, dict):
for k, v in data.items():
if isinstance(v, np.ndarray) and not v.data.contiguous:
data[k] = np.ascontiguousarray(data[k])
return self._runtime.infer(data)
def bind_input_tensor(self, name, fdtensor):
    """Bind an FDTensor to the named model input.

    The call is forwarded unchanged to the underlying runtime, which is
    expected to share the tensor's memory instead of copying it.

    :param name: (str)Name of the model input to bind.
    :param fdtensor: (fastdeploy.FDTensor)Tensor holding the input data.
    """
    backend = self._runtime
    backend.bind_input_tensor(name, fdtensor)
def zero_copy_infer(self):
    """Run inference without passing any input arguments.

    Inputs must be supplied beforehand through bind_input_tensor, and the
    results are read afterwards through get_output_tensor. Nothing is
    returned by this method.
    """
    backend = self._runtime
    backend.infer()
def get_output_tensor(self, name):
    """Fetch the output FDTensor registered under the given name.

    The tensor comes straight from the underlying runtime, which is
    expected to share the backend's output memory rather than copy it.

    :param name: (str)Name of the model output.
    :return fastdeploy.FDTensor
    """
    output_tensor = self._runtime.get_output_tensor(name)
    return output_tensor
def compile(self, warm_datas):
"""[Only for Poros backend] compile with prewarm data for poros
@@ -178,7 +203,8 @@ class RuntimeOption:
@long_to_int.setter
def long_to_int(self, value):
assert isinstance(
value, bool), "The value to set `long_to_int` must be type of bool."
value,
bool), "The value to set `long_to_int` must be type of bool."
self._option.long_to_int = value
@use_nvidia_tf32.setter
@@ -434,7 +460,8 @@ class RuntimeOption:
continue
if hasattr(getattr(self._option, attr), "__call__"):
continue
message += " {} : {}\t\n".format(attr, getattr(self._option, attr))
message += " {} : {}\t\n".format(attr,
getattr(self._option, attr))
message.strip("\n")
message += ")"
return message