mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[backend][Serving]Fix paddle backend get outout tensor error (#741)
fix paddle backend no_copy_infer
This commit is contained in:
@@ -57,11 +57,36 @@ class Runtime:
|
||||
"""
|
||||
assert isinstance(data, dict) or isinstance(
|
||||
data, list), "The input data should be type of dict or list."
|
||||
for k, v in data.items():
|
||||
if not v.data.contiguous:
|
||||
data[k] = np.ascontiguousarray(data[k])
|
||||
if isinstance(data, dict):
|
||||
for k, v in data.items():
|
||||
if isinstance(v, np.ndarray) and not v.data.contiguous:
|
||||
data[k] = np.ascontiguousarray(data[k])
|
||||
|
||||
return self._runtime.infer(data)
|
||||
|
||||
def bind_input_tensor(self, name, fdtensor):
|
||||
"""Bind FDTensor by name, no copy and share input memory
|
||||
|
||||
:param name: (str)The name of input data.
|
||||
:param fdtensor: (fastdeploy.FDTensor)The input FDTensor.
|
||||
"""
|
||||
self._runtime.bind_input_tensor(name, fdtensor)
|
||||
|
||||
def zero_copy_infer(self):
|
||||
"""No params inference the model.
|
||||
|
||||
the input and output data need to pass through the bind_input_tensor and get_output_tensor interfaces.
|
||||
"""
|
||||
self._runtime.infer()
|
||||
|
||||
def get_output_tensor(self, name):
|
||||
"""Get output FDTensor by name, no copy and share backend output memory
|
||||
|
||||
:param name: (str)The name of output data.
|
||||
:return fastdeploy.FDTensor
|
||||
"""
|
||||
return self._runtime.get_output_tensor(name)
|
||||
|
||||
def compile(self, warm_datas):
|
||||
"""[Only for Poros backend] compile with prewarm data for poros
|
||||
|
||||
@@ -178,7 +203,8 @@ class RuntimeOption:
|
||||
@long_to_int.setter
|
||||
def long_to_int(self, value):
|
||||
assert isinstance(
|
||||
value, bool), "The value to set `long_to_int` must be type of bool."
|
||||
value,
|
||||
bool), "The value to set `long_to_int` must be type of bool."
|
||||
self._option.long_to_int = value
|
||||
|
||||
@use_nvidia_tf32.setter
|
||||
@@ -434,7 +460,8 @@ class RuntimeOption:
|
||||
continue
|
||||
if hasattr(getattr(self._option, attr), "__call__"):
|
||||
continue
|
||||
message += " {} : {}\t\n".format(attr, getattr(self._option, attr))
|
||||
message += " {} : {}\t\n".format(attr,
|
||||
getattr(self._option, attr))
|
||||
message.strip("\n")
|
||||
message += ")"
|
||||
return message
|
||||
|
Reference in New Issue
Block a user