mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-09 10:30:37 +08:00
ORT support UINT8 and INT8 input and output
This commit is contained in:
@@ -239,7 +239,8 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
|
||||
type_info.GetTensorTypeAndShapeInfo().GetShape();
|
||||
ONNXTensorElementDataType data_type =
|
||||
type_info.GetTensorTypeAndShapeInfo().GetElementType();
|
||||
inputs_desc_.emplace_back(OrtValueInfo{input_name_ptr.get(), shape, data_type});
|
||||
inputs_desc_.emplace_back(
|
||||
OrtValueInfo{input_name_ptr.get(), shape, data_type});
|
||||
}
|
||||
|
||||
size_t n_outputs = session_.GetOutputCount();
|
||||
@@ -250,7 +251,8 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
|
||||
type_info.GetTensorTypeAndShapeInfo().GetShape();
|
||||
ONNXTensorElementDataType data_type =
|
||||
type_info.GetTensorTypeAndShapeInfo().GetElementType();
|
||||
outputs_desc_.emplace_back(OrtValueInfo{output_name_ptr.get(), shape, data_type});
|
||||
outputs_desc_.emplace_back(
|
||||
OrtValueInfo{output_name_ptr.get(), shape, data_type});
|
||||
|
||||
Ort::MemoryInfo out_memory_info("Cpu", OrtDeviceAllocator, 0,
|
||||
OrtMemTypeDefault);
|
||||
@@ -283,6 +285,12 @@ void OrtBackend::OrtValueToFDTensor(const Ort::Value& value, FDTensor* tensor,
|
||||
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
|
||||
dtype = FDDataType::FP16;
|
||||
numel *= sizeof(float16);
|
||||
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) {
|
||||
dtype = FDDataType::UINT8;
|
||||
numel *= sizeof(uint8_t);
|
||||
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
|
||||
dtype = FDDataType::INT8;
|
||||
numel *= sizeof(int8_t);
|
||||
} else {
|
||||
FDASSERT(
|
||||
false,
|
||||
|
@@ -50,6 +50,10 @@ FDDataType GetFdDtype(const ONNXTensorElementDataType& ort_dtype) {
|
||||
return FDDataType::INT64;
|
||||
} else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
|
||||
return FDDataType::FP16;
|
||||
} else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) {
|
||||
return FDDataType::UINT8;
|
||||
} else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
|
||||
return FDDataType::INT8;
|
||||
}
|
||||
FDERROR << "Unrecognized ort data type:" << ort_dtype << "." << std::endl;
|
||||
return FDDataType::FP32;
|
||||
|
Reference in New Issue
Block a user