ORT support UINT8 and INT8 input and output

This commit is contained in:
yunyaoXYY
2023-03-21 02:23:23 +00:00
parent 3cc72765dd
commit b3e16e9966
2 changed files with 14 additions and 2 deletions

View File

@@ -239,7 +239,8 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
type_info.GetTensorTypeAndShapeInfo().GetShape();
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
inputs_desc_.emplace_back(OrtValueInfo{input_name_ptr.get(), shape, data_type});
inputs_desc_.emplace_back(
OrtValueInfo{input_name_ptr.get(), shape, data_type});
}
size_t n_outputs = session_.GetOutputCount();
@@ -250,7 +251,8 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
type_info.GetTensorTypeAndShapeInfo().GetShape();
ONNXTensorElementDataType data_type =
type_info.GetTensorTypeAndShapeInfo().GetElementType();
outputs_desc_.emplace_back(OrtValueInfo{output_name_ptr.get(), shape, data_type});
outputs_desc_.emplace_back(
OrtValueInfo{output_name_ptr.get(), shape, data_type});
Ort::MemoryInfo out_memory_info("Cpu", OrtDeviceAllocator, 0,
OrtMemTypeDefault);
@@ -283,6 +285,12 @@ void OrtBackend::OrtValueToFDTensor(const Ort::Value& value, FDTensor* tensor,
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
dtype = FDDataType::FP16;
numel *= sizeof(float16);
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) {
dtype = FDDataType::UINT8;
numel *= sizeof(uint8_t);
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
dtype = FDDataType::INT8;
numel *= sizeof(int8_t);
} else {
FDASSERT(
false,

View File

@@ -50,6 +50,10 @@ FDDataType GetFdDtype(const ONNXTensorElementDataType& ort_dtype) {
return FDDataType::INT64;
} else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
return FDDataType::FP16;
} else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) {
return FDDataType::UINT8;
} else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
return FDDataType::INT8;
}
FDERROR << "Unrecognized ort data type:" << ort_dtype << "." << std::endl;
return FDDataType::FP32;