diff --git a/fastdeploy/runtime/backends/ort/ort_backend.cc b/fastdeploy/runtime/backends/ort/ort_backend.cc index 3938187cc..a8aa1e7cd 100644 --- a/fastdeploy/runtime/backends/ort/ort_backend.cc +++ b/fastdeploy/runtime/backends/ort/ort_backend.cc @@ -239,7 +239,8 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file, type_info.GetTensorTypeAndShapeInfo().GetShape(); ONNXTensorElementDataType data_type = type_info.GetTensorTypeAndShapeInfo().GetElementType(); - inputs_desc_.emplace_back(OrtValueInfo{input_name_ptr.get(), shape, data_type}); + inputs_desc_.emplace_back( + OrtValueInfo{input_name_ptr.get(), shape, data_type}); } size_t n_outputs = session_.GetOutputCount(); @@ -250,7 +251,8 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file, type_info.GetTensorTypeAndShapeInfo().GetShape(); ONNXTensorElementDataType data_type = type_info.GetTensorTypeAndShapeInfo().GetElementType(); - outputs_desc_.emplace_back(OrtValueInfo{output_name_ptr.get(), shape, data_type}); + outputs_desc_.emplace_back( + OrtValueInfo{output_name_ptr.get(), shape, data_type}); Ort::MemoryInfo out_memory_info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); @@ -283,6 +285,12 @@ void OrtBackend::OrtValueToFDTensor(const Ort::Value& value, FDTensor* tensor, } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { dtype = FDDataType::FP16; numel *= sizeof(float16); + } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) { + dtype = FDDataType::UINT8; + numel *= sizeof(uint8_t); + } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) { + dtype = FDDataType::INT8; + numel *= sizeof(int8_t); } else { FDASSERT( false, diff --git a/fastdeploy/runtime/backends/ort/utils.cc b/fastdeploy/runtime/backends/ort/utils.cc index 52f6247b6..1ec11b600 100644 --- a/fastdeploy/runtime/backends/ort/utils.cc +++ b/fastdeploy/runtime/backends/ort/utils.cc @@ -50,6 +50,10 @@ FDDataType GetFdDtype(const ONNXTensorElementDataType& ort_dtype) { return FDDataType::INT64; } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { return FDDataType::FP16; + } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) { + return FDDataType::UINT8; + } else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) { + return FDDataType::INT8; } FDERROR << "Unrecognized ort data type:" << ort_dtype << "." << std::endl; return FDDataType::FP32;