[Backend] add sophgo backend (#1015)

* Add Sophgo device

Add the Sophgo backend to FastDeploy.

Add ResNet50, YOLOv5s, and LiteSeg examples.

* Replace the Sophgo lib with download links; fix a bug in model.cc

* Fix code style

* Remove unused files; rename the Sophgo device and the Sophgo backend

* Add Python support for Sophgo and add Python examples

* Remove unused rows in CMake according to PR review

Co-authored-by: Zilong Xing <zilong.xing@sophgo.com>
Authored by Dantès on 2023-01-04 15:49:17 +08:00, committed by GitHub
parent 0c292c0766 · commit 34bea7649d
41 changed files with 1583 additions and 9 deletions
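Before the per-file diffs, a minimal usage sketch of the API this commit adds, assuming only the calls visible below (UseSophgo, ModelFormat::SOPHGO, the YOLOv5 constructor); the bmodel and image paths are hypothetical:

#include <iostream>
#include "fastdeploy/vision.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.UseSophgo();  // selects Device::SOPHGOTPUD and Backend::SOPHGOTPU

  // A .bmodel is routed to the Sophgo backend via ModelFormat::SOPHGO;
  // the params_file argument stays empty for this format.
  auto model = fastdeploy::vision::detection::YOLOv5(
      "yolov5s.bmodel", "", option, fastdeploy::ModelFormat::SOPHGO);
  if (!model.Initialized()) return -1;

  cv::Mat im = cv::imread("test.jpg");
  fastdeploy::vision::DetectionResult res;
  model.Predict(&im, &res);
  std::cout << res.Str() << std::endl;
  return 0;
}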

View File

@@ -0,0 +1,290 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/backends/sophgo/sophgo_backend.h"
#include <assert.h>
namespace fastdeploy {

SophgoBackend::~SophgoBackend() { bm_dev_free(handle_); }
/***************************************************************
 * @name GetSDKAndDeviceVersion
 * @brief get the Sophgo SDK and device version
 * @param None
 * @return bool
 * @note None
 ***************************************************************/
bool SophgoBackend::GetSDKAndDeviceVersion() {
  // Version query is not implemented yet; always report success.
  return true;
}
/***************************************************************
 * @name BuildOption
 * @brief save option
 * @param SophgoBackendOption
 * @return None
 * @note None
 ***************************************************************/
void SophgoBackend::BuildOption(const SophgoBackendOption& option) {
  // this->option_ = option;
  // save cpu_name
  // this->option_.cpu_name = option.cpu_name;
}
/***************************************************************
 * @name InitFromSophgo
 * @brief Initialize the Sophgo model
 * @param model_file: Path of the Sophgo bmodel file.
 *        option: config
 * @return bool
 * @note None
 ***************************************************************/
bool SophgoBackend::InitFromSophgo(const std::string& model_file,
                                   const SophgoBackendOption& option) {
  // LoadModel
  if (!this->LoadModel((char*)model_file.data())) {
    FDERROR << "load model failed" << std::endl;
    return false;
  }
  // GetSDKAndDeviceVersion
  if (!this->GetSDKAndDeviceVersion()) {
    FDERROR << "get SDK and device version failed" << std::endl;
    return false;
  }
  // BuildOption
  this->BuildOption(option);
  // GetModelInputOutputInfos
  if (!this->GetModelInputOutputInfos()) {
    FDERROR << "get model input output infos failed" << std::endl;
    return false;
  }
  return true;
}
/***************************************************************
 * @name LoadModel
 * @brief read the Sophgo bmodel
 * @param model: Path of the Sophgo bmodel file.
 * @return bool
 * @note None
 ***************************************************************/
bool SophgoBackend::LoadModel(void* model) {
  unsigned int card_num = 0;
  bm_status_t status = bm_get_card_num(&card_num);
  assert(BM_SUCCESS == status);
  status = bm_dev_request(&handle_, 0);
  assert(BM_SUCCESS == status);
  p_bmrt_ = bmrt_create(handle_);
  assert(NULL != p_bmrt_);
  bool load_status = bmrt_load_bmodel(p_bmrt_, (char*)model);
  assert(load_status);
  int network_num = bmrt_get_network_number(p_bmrt_);
  (void)network_num;  // informational; only the first network is used
  const char** net_names = NULL;
  bmrt_get_network_names(p_bmrt_, &net_names);
  net_name_ = net_names[0];
  free(net_names);
  net_info_ = bmrt_get_network_info(p_bmrt_, net_name_.c_str());
  assert(NULL != net_info_);
  return true;
}
/***************************************************************
 * @name GetModelInputOutputInfos
 * @brief Get the detailed input and output infos of the model
 * @param None
 * @return bool
 * @note None
 ***************************************************************/
bool SophgoBackend::GetModelInputOutputInfos() {
  // Shapes and dtypes are taken from the first stage of the network.
  inputs_desc_.resize(net_info_->input_num);
  bm_shape_t* input_shapes = net_info_->stages->input_shapes;
  bm_data_type_t* input_dtypes = net_info_->input_dtypes;
  for (int idx = 0; idx < net_info_->input_num; idx++) {
    std::string temp_name = (net_info_->input_names)[idx];
    std::vector<int> temp_shape{};
    temp_shape.resize(input_shapes[idx].num_dims);
    for (int i = 0; i < input_shapes[idx].num_dims; i++) {
      temp_shape[i] = input_shapes[idx].dims[i];
    }
    // SophgoType to FDDataType, indexed per input rather than always
    // reading the first dtype.
    FDDataType temp_dtype = SophgoTensorTypeToFDDataType(input_dtypes[idx]);
    TensorInfo temp_input_info = {temp_name, temp_shape, temp_dtype};
    inputs_desc_[idx] = temp_input_info;
  }
  outputs_desc_.resize(net_info_->output_num);
  bm_shape_t* output_shapes = net_info_->stages->output_shapes;
  bm_data_type_t* output_dtypes = net_info_->output_dtypes;
  for (int idx = 0; idx < net_info_->output_num; idx++) {
    std::string temp_name = (net_info_->output_names)[idx];
    std::vector<int> temp_shape{};
    temp_shape.resize(output_shapes[idx].num_dims);
    for (int i = 0; i < output_shapes[idx].num_dims; i++) {
      temp_shape[i] = output_shapes[idx].dims[i];
    }
    FDDataType temp_dtype = SophgoTensorTypeToFDDataType(output_dtypes[idx]);
    TensorInfo temp_output_info = {temp_name, temp_shape, temp_dtype};
    outputs_desc_[idx] = temp_output_info;
  }
  return true;
}
TensorInfo SophgoBackend::GetInputInfo(int index) {
  FDASSERT(index < NumInputs(),
           "The index: %d should be less than the number of inputs: %d.",
           index, NumInputs())
  return inputs_desc_[index];
}

std::vector<TensorInfo> SophgoBackend::GetInputInfos() { return inputs_desc_; }

TensorInfo SophgoBackend::GetOutputInfo(int index) {
  FDASSERT(index < NumOutputs(),
           "The index: %d should be less than the number of outputs: %d.",
           index, NumOutputs())
  return outputs_desc_[index];
}

std::vector<TensorInfo> SophgoBackend::GetOutputInfos() {
  return outputs_desc_;
}
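// Infer flow (sketch): allocate device memory for each input, copy host data
// to the device, launch the network, synchronize, copy every output back to
// the host FDTensor, then release the device buffers.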
bool SophgoBackend::Infer(std::vector<FDTensor>& inputs,
                          std::vector<FDTensor>* outputs,
                          bool copy_to_fd) {
  int input_size = inputs.size();
  assert(input_size != 0);
  assert(input_size == NumInputs());
  std::vector<bm_tensor_t> input_tensors(input_size);
  bm_status_t status = BM_SUCCESS;
  bm_data_type_t* input_dtypes = net_info_->input_dtypes;
  for (int i = 0; i < input_size; i++) {
    status = bm_malloc_device_byte(handle_, &input_tensors[i].device_mem,
                                   net_info_->max_input_bytes[i]);
    assert(BM_SUCCESS == status);
    input_tensors[i].dtype = input_dtypes[i];
    input_tensors[i].st_mode = BM_STORE_1N;
    // Take the i-th input shape of the first stage.
    input_tensors[i].shape = net_info_->stages[0].input_shapes[i];
    bm_memcpy_s2d_partial(handle_, input_tensors[i].device_mem,
                          (void*)inputs[i].Data(),
                          bmrt_tensor_bytesize(&input_tensors[i]));
  }

  int output_size = NumOutputs();
  std::vector<bm_tensor_t> output_tensors(output_size);
  for (int i = 0; i < output_size; i++) {
    status = bm_malloc_device_byte(handle_, &output_tensors[i].device_mem,
                                   net_info_->max_output_bytes[i]);
    assert(BM_SUCCESS == status);
  }

  bool launch_status =
      bmrt_launch_tensor_ex(p_bmrt_, net_name_.c_str(), input_tensors.data(),
                            net_info_->input_num, output_tensors.data(),
                            net_info_->output_num, true, false);
  assert(launch_status);
  status = bm_thread_sync(handle_);
  assert(status == BM_SUCCESS);

  outputs->resize(outputs_desc_.size());
  for (int i = 0; i < output_size; i++) {
    int temp_bytesize = bmrt_tensor_bytesize(&output_tensors[i]);  // bytes
    void* temp_out = malloc(temp_bytesize);
    bm_memcpy_d2s_partial(handle_, temp_out, output_tensors[i].device_mem,
                          temp_bytesize);
    std::vector<int64_t> temp_shape;
    temp_shape.resize(outputs_desc_[i].shape.size());
    for (size_t j = 0; j < outputs_desc_[i].shape.size(); ++j) {
      temp_shape[j] = outputs_desc_[i].shape[j];
    }
    (*outputs)[i].Resize(temp_shape, outputs_desc_[i].dtype,
                         outputs_desc_[i].name);
    memcpy((*outputs)[i].MutableData(), temp_out, (*outputs)[i].Nbytes());
    free(temp_out);
  }

  // Release the device buffers allocated above.
  for (int i = 0; i < input_size; i++) {
    bm_free_device(handle_, input_tensors[i].device_mem);
  }
  for (int i = 0; i < output_size; i++) {
    bm_free_device(handle_, output_tensors[i].device_mem);
  }
  return true;
}
/***************************************************************
 * @name SophgoTensorTypeToFDDataType
 * @brief Change SophgoTensorType to FDDataType
 * @param bm_data_type_t
 * @return FDDataType
 * @note None
 ***************************************************************/
FDDataType SophgoBackend::SophgoTensorTypeToFDDataType(bm_data_type_t type) {
  if (type == BM_FLOAT16) {
    // NOTE: FP16 is reported to FastDeploy as FP32; the bytes themselves
    // are not converted.
    return FDDataType::FP32;
  }
  if (type == BM_FLOAT32) {
    return FDDataType::FP32;
  }
  if (type == BM_INT8) {
    return FDDataType::INT8;
  }
  if (type == BM_INT16) {
    return FDDataType::INT16;
  }
  if (type == BM_INT32) {
    return FDDataType::INT32;
  }
  if (type == BM_UINT8) {
    return FDDataType::UINT8;
  }
  FDERROR << "FDDataType doesn't support this type" << std::endl;
  return FDDataType::UNKNOWN1;
}
/***************************************************************
 * @name FDDataTypeToSophgoTensorType
 * @brief Change FDDataType to SophgoTensorType
 * @param FDDataType
 * @return bm_data_type_t
 * @note None
 ***************************************************************/
bm_data_type_t SophgoBackend::FDDataTypeToSophgoTensorType(
    fastdeploy::FDDataType type) {
  if (type == FDDataType::FP16) {
    return BM_FLOAT16;
  }
  if (type == FDDataType::FP32) {
    return BM_FLOAT32;
  }
  if (type == FDDataType::INT8) {
    return BM_INT8;
  }
  if (type == FDDataType::INT16) {
    return BM_INT16;
  }
  if (type == FDDataType::INT32) {
    return BM_INT32;
  }
  if (type == FDDataType::UINT8) {
    return BM_UINT8;
  }
  FDERROR << "Sophgo_tensor_type doesn't support this type" << std::endl;
  return BM_FLOAT32;
}
}  // namespace fastdeploy

View File

@@ -0,0 +1,75 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/backends/backend.h"
#include "fastdeploy/core/fd_tensor.h"
#include "bmruntime_interface.h" // NOLINT
#include "bmlib_runtime.h" // NOLINT
#include "fastdeploy/backends/sophgo/sophgo_config.h"
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
namespace fastdeploy {
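// Currently empty; reserved for future Sophgo-specific tuning options.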
struct SophgoBackendOption {};
class SophgoBackend : public BaseBackend {
 public:
  SophgoBackend() = default;
  virtual ~SophgoBackend();
  bool LoadModel(void* model);
  bool GetSDKAndDeviceVersion();
  bool GetModelInputOutputInfos();
  void BuildOption(const SophgoBackendOption& option);
  bool InitFromSophgo(const std::string& model_file,
                      const SophgoBackendOption& option = SophgoBackendOption());

  int NumInputs() const override {
    return static_cast<int>(inputs_desc_.size());
  }
  int NumOutputs() const override {
    return static_cast<int>(outputs_desc_.size());
  }

  TensorInfo GetInputInfo(int index) override;
  TensorInfo GetOutputInfo(int index) override;
  std::vector<TensorInfo> GetInputInfos() override;
  std::vector<TensorInfo> GetOutputInfos() override;
  bool Infer(std::vector<FDTensor>& inputs,
             std::vector<FDTensor>* outputs,
             bool copy_to_fd = true) override;

 private:
  std::vector<TensorInfo> inputs_desc_;
  std::vector<TensorInfo> outputs_desc_;
  std::string net_name_;
  bm_handle_t handle_;
  void* p_bmrt_ = nullptr;
  bool infer_init = false;
  const bm_net_info_t* net_info_ = nullptr;
  // SophgoBackendOption option_;

  static FDDataType SophgoTensorTypeToFDDataType(bm_data_type_t type);
  static bm_data_type_t FDDataTypeToSophgoTensorType(FDDataType type);
};
}  // namespace fastdeploy

View File

@@ -56,6 +56,9 @@ std::string Str(const Device& d) {
case Device::RKNPU:
out = "Device::RKNPU";
break;
case Device::SOPHGOTPUD:
out = "Device::SOPHGOTPUD";
break;
case Device::IPU:
out = "Device::IPU";
break;
@@ -85,6 +88,9 @@ std::ostream& operator<<(std::ostream& out,const Device& d){
case Device::RKNPU:
out << "Device::RKNPU";
break;
case Device::SOPHGOTPUD:
out << "Device::SOPHGOTPUD";
break;
case Device::TIMVX:
out << "Device::TIMVX";
break;
@@ -205,8 +211,10 @@ std::string Str(const ModelFormat& f) {
return "ModelFormat::PADDLE";
} else if (f == ModelFormat::ONNX) {
return "ModelFormat::ONNX";
} else if (f == ModelFormat::RKNN) {
return "ModelFormat::RKNN";
} else if (f == ModelFormat::SOPHGO) {
return "ModelFormat::SOPHGO";
} else if (f == ModelFormat::TORCHSCRIPT) {
return "ModelFormat::TORCHSCRIPT";
}
@@ -220,6 +228,8 @@ std::ostream& operator<<(std::ostream& out, const ModelFormat& format) {
out << "ModelFormat::ONNX";
} else if (format == ModelFormat::RKNN) {
out << "ModelFormat::RKNN";
} else if (format == ModelFormat::SOPHGO) {
out << "ModelFormat::SOPHGO";
} else if (format == ModelFormat::TORCHSCRIPT) {
out << "ModelFormat::TORCHSCRIPT";
}

View File

@@ -22,7 +22,8 @@
namespace fastdeploy {
enum FASTDEPLOY_DECL Device {CPU, GPU, RKNPU, IPU, TIMVX, KUNLUNXIN, ASCEND,
SOPHGOTPUD};
FASTDEPLOY_DECL std::string Str(const Device& d);
@@ -72,6 +73,7 @@ enum ModelFormat {
ONNX, ///< Model with ONNX format
RKNN, ///< Model with RKNN format
TORCHSCRIPT, ///< Model with TorchScript format
SOPHGO, ///< Model with SOPHGO format
};
FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out,

View File

@@ -50,6 +50,7 @@ bool FastDeployModel::InitRuntimeWithSpecifiedBackend() {
bool use_gpu = (runtime_option.device == Device::GPU);
bool use_ipu = (runtime_option.device == Device::IPU);
bool use_rknpu = (runtime_option.device == Device::RKNPU);
bool use_sophgotpu = (runtime_option.device == Device::SOPHGOTPUD);
bool use_timvx = (runtime_option.device == Device::TIMVX);
bool use_ascend = (runtime_option.device == Device::ASCEND);
bool use_kunlunxin = (runtime_option.device == Device::KUNLUNXIN);
@@ -64,6 +65,11 @@ bool FastDeployModel::InitRuntimeWithSpecifiedBackend() {
FDERROR << "The valid rknpu backends of model " << ModelName() << " are " << Str(valid_rknpu_backends) << ", " << runtime_option.backend << " is not supported." << std::endl;
return false;
}
} else if (use_sophgotpu) {
if (!IsSupported(valid_sophgonpu_backends, runtime_option.backend)) {
FDERROR << "The valid rknpu backends of model " << ModelName() << " are " << Str(valid_rknpu_backends) << ", " << runtime_option.backend << " is not supported." << std::endl;
return false;
}
} else if (use_timvx) {
if (!IsSupported(valid_timvx_backends, runtime_option.backend)) {
FDERROR << "The valid timvx backends of model " << ModelName() << " are " << Str(valid_timvx_backends) << ", " << runtime_option.backend << " is not supported." << std::endl;
@@ -118,6 +124,8 @@ bool FastDeployModel::InitRuntimeWithSpecifiedDevice() {
return CreateASCENDBackend();
} else if (runtime_option.device == Device::KUNLUNXIN) {
return CreateKunlunXinBackend();
} else if (runtime_option.device == Device::SOPHGOTPUD) {
return CreateSophgoNPUBackend();
} else if (runtime_option.device == Device::IPU) {
#ifdef WITH_IPU
return CreateIpuBackend();
@@ -218,6 +226,30 @@ bool FastDeployModel::CreateRKNPUBackend() {
return false;
}
bool FastDeployModel::CreateSophgoNPUBackend() {
if (valid_sophgonpu_backends.empty()) {
FDERROR << "There's no valid npu backends for model: " << ModelName()
<< std::endl;
return false;
}
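// Try the registered Sophgo backends in order and use the first available one.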
for (size_t i = 0; i < valid_sophgonpu_backends.size(); ++i) {
if (!IsBackendAvailable(valid_sophgonpu_backends[i])) {
continue;
}
runtime_option.backend = valid_sophgonpu_backends[i];
runtime_ = std::unique_ptr<Runtime>(new Runtime());
if (!runtime_->Init(runtime_option)) {
return false;
}
runtime_initialized_ = true;
return true;
}
FDERROR << "Cannot find an available npu backend to load this model."
<< std::endl;
return false;
}
bool FastDeployModel::CreateTimVXBackend() {
if (valid_timvx_backends.size() == 0) {
FDERROR << "There's no valid timvx backends for model: " << ModelName()

View File

@@ -54,6 +54,9 @@ class FASTDEPLOY_DECL FastDeployModel {
/** Model's valid hardware backends. This member defined all the gpu backends have successfully tested for the model
*/
std::vector<Backend> valid_rknpu_backends = {};
/** Model's valid hardware backends. This member defines all the Sophgo NPU backends that have been successfully tested for the model
*/
std::vector<Backend> valid_sophgonpu_backends = {};
/// Get number of inputs for this model
virtual int NumInputsOfRuntime() { return runtime_->NumInputs(); }
@@ -148,6 +151,7 @@ class FASTDEPLOY_DECL FastDeployModel {
bool CreateGpuBackend();
bool CreateIpuBackend();
bool CreateRKNPUBackend();
bool CreateSophgoNPUBackend();
bool CreateTimVXBackend();
bool CreateKunlunXinBackend();
bool CreateASCENDBackend();

View File

@@ -24,6 +24,7 @@ void BindRuntime(pybind11::module& m) {
.def("use_gpu", &RuntimeOption::UseGpu)
.def("use_cpu", &RuntimeOption::UseCpu)
.def("use_rknpu2", &RuntimeOption::UseRKNPU2)
.def("use_sophgo", &RuntimeOption::UseSophgo)
.def("use_ascend", &RuntimeOption::UseAscend)
.def("use_kunlunxin", &RuntimeOption::UseKunlunXin)
.def("set_external_stream", &RuntimeOption::SetExternalStream)
@@ -241,19 +242,22 @@ void BindRuntime(pybind11::module& m) {
.value("POROS", Backend::POROS)
.value("PDINFER", Backend::PDINFER)
.value("RKNPU2", Backend::RKNPU2)
.value("SOPHGOTPU", Backend::SOPHGOTPU)
.value("LITE", Backend::LITE);
pybind11::enum_<ModelFormat>(m, "ModelFormat", pybind11::arithmetic(),
"ModelFormat for inference.")
.value("PADDLE", ModelFormat::PADDLE)
.value("TORCHSCRIPT", ModelFormat::TORCHSCRIPT)
.value("RKNN", ModelFormat::RKNN)
.value("SOPHGO", ModelFormat::SOPHGO)
.value("ONNX", ModelFormat::ONNX);
pybind11::enum_<Device>(m, "Device", pybind11::arithmetic(),
"Device for inference.")
.value("CPU", Device::CPU)
.value("GPU", Device::GPU)
.value("IPU", Device::IPU)
.value("RKNPU", Device::RKNPU);
.value("RKNPU", Device::RKNPU)
.value("SOPHGOTPU", Device::SOPHGOTPUD);
pybind11::enum_<FDDataType>(m, "FDDataType", pybind11::arithmetic(),
"Data type of FastDeploy.")

View File

@@ -45,6 +45,10 @@
#include "fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.h"
#endif
#ifdef ENABLE_SOPHGO_BACKEND
#include "fastdeploy/backends/sophgo/sophgo_backend.h"
#endif
namespace fastdeploy {
std::vector<Backend> GetAvailableBackends() {
@@ -69,6 +73,9 @@ std::vector<Backend> GetAvailableBackends() {
#endif
#ifdef ENABLE_RKNPU2_BACKEND
backends.push_back(Backend::RKNPU2);
#endif
#ifdef ENABLE_SOPHGO_BACKEND
backends.push_back(Backend::SOPHGOTPU);
#endif
return backends;
}
@@ -94,6 +101,8 @@ std::string Str(const Backend& b) {
return "Backend::POROS";
} else if (b == Backend::RKNPU2) {
return "Backend::RKNPU2";
} else if (b == Backend::SOPHGOTPU) {
return "Backend::SOPHGOTPU";
} else if (b == Backend::OPENVINO) {
return "Backend::OPENVINO";
} else if (b == Backend::LITE) {
@@ -113,6 +122,8 @@ std::ostream& operator<<(std::ostream& out, const Backend& backend) {
out << "Backend::OPENVINO";
} else if (backend == Backend::RKNPU2) {
out << "Backend::RKNPU2";
} else if (backend == Backend::SOPHGOTPU) {
out << "Backend::SOPHGOTPU";
} else if (backend == Backend::POROS) {
out << "Backend::POROS";
} else if (backend == Backend::LITE) {
@@ -158,6 +169,15 @@ bool CheckModelFormat(const std::string& model_file,
<< model_file << std::endl;
return false;
}
} else if (model_format == ModelFormat::SOPHGO) {
if (model_file.size() < 7 ||
model_file.substr(model_file.size() - 7, 7) != ".bmodel") {
FDERROR
<< "With model format of ModelFormat::SOPHGO, the model file "
"should end with `.bmodel`, but now it's "
<< model_file << std::endl;
return false;
}
} else {
FDERROR
<< "Only support model format with frontend ModelFormat::PADDLE / "
@@ -185,6 +205,10 @@ ModelFormat GuessModelFormat(const std::string& model_file) {
model_file.substr(model_file.size() - 5, 5) == ".rknn") {
FDINFO << "Model Format: RKNN." << std::endl;
return ModelFormat::RKNN;
} else if (model_file.size() > 7 &&
model_file.substr(model_file.size() - 7, 7) == ".bmodel") {
FDINFO << "Model Format: SOPHGO." << std::endl;
return ModelFormat::SOPHGO;
}
FDERROR << "Cannot guess which model format you are using, please set "
@@ -288,6 +312,11 @@ void RuntimeOption::UseAscend(){
device = Device::ASCEND;
}
void RuntimeOption::UseSophgo() {
device = Device::SOPHGOTPUD;
UseSophgoBackend();
}
void RuntimeOption::SetExternalStream(void* external_stream) {
external_stream_ = external_stream;
}
@@ -323,6 +352,15 @@ void RuntimeOption::UseOrtBackend() {
#endif
}
// use Sophgo backend
void RuntimeOption::UseSophgoBackend() {
#ifdef ENABLE_SOPHGO_BACKEND
backend = Backend::SOPHGOTPU;
#else
FDASSERT(false, "The FastDeploy didn't compile with SophgoBackend.");
#endif
}
// use poros backend
void RuntimeOption::UsePorosBackend() {
#ifdef ENABLE_POROS_BACKEND
@@ -564,6 +602,8 @@ bool Runtime::Init(const RuntimeOption& _option) {
option.backend = Backend::OPENVINO;
} else if (IsBackendAvailable(Backend::RKNPU2)) {
option.backend = Backend::RKNPU2;
} else if (IsBackendAvailable(Backend::SOPHGOTPU)) {
option.backend = Backend::SOPHGOTPU;
} else {
FDERROR << "Please define backend in RuntimeOption, current it's "
"Backend::UNKNOWN."
@@ -623,7 +663,15 @@ bool Runtime::Init(const RuntimeOption& _option) {
FDINFO << "Runtime initialized with Backend::RKNPU2 in "
<< Str(option.device) << "." << std::endl;
} else if (option.backend == Backend::SOPHGOTPU) {
FDASSERT(option.device == Device::SOPHGOTPUD,
"Backend::SOPHGOTPU only supports Device::SOPHGOTPUD");
CreateSophgoNPUBackend();
FDINFO << "Runtime initialized with Backend::SOPHGOTPU in "
<< Str(option.device) << "." << std::endl;
} else {
FDERROR << "Runtime only support "
"Backend::ORT/Backend::TRT/Backend::PDINFER/Backend::POROS as "
"backend now."
@@ -926,6 +974,21 @@ void Runtime::CreateRKNPU2Backend() {
#endif
}
void Runtime::CreateSophgoNPUBackend() {
#ifdef ENABLE_SOPHGO_BACKEND
auto sophgo_option = SophgoBackendOption();
FDASSERT(option.model_format == ModelFormat::SOPHGO,
"SophgoBackend only supports model format of ModelFormat::SOPHGO");
backend_ = utils::make_unique<SophgoBackend>();
auto casted_backend = dynamic_cast<SophgoBackend*>(backend_.get());
FDASSERT(casted_backend->InitFromSophgo(option.model_file, sophgo_option),
"Load model from .bmodel file failed while initializing SophgoBackend.");
#else
FDASSERT(false, "SophgoBackend is not available, please compiled with "
"ENABLE_SOPHGO_BACKEND=ON.");
#endif
}
Runtime* Runtime::Clone(void* stream, int device_id) {
Runtime* runtime = new Runtime();
if (option.backend != Backend::OPENVINO &&

View File

@@ -43,6 +43,7 @@ enum Backend {
OPENVINO, ///< Intel OpenVINO, support Paddle/ONNX format, CPU only
LITE, ///< Paddle Lite, support Paddle format model, ARM CPU only
RKNPU2, ///< RKNPU2, support RKNN format model, Rockchip NPU only
SOPHGOTPU, ///< SOPHGOTPU, support SOPHGO format model, Sophgo TPU only
};
FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out,
@@ -151,6 +152,9 @@ struct FASTDEPLOY_DECL RuntimeOption {
bool adaptive_seqlen = false,
bool enable_multi_stream = false);
/// Use Sophgo to inference
void UseSophgo();
void SetExternalStream(void* external_stream);
/*
@@ -170,6 +174,9 @@ struct FASTDEPLOY_DECL RuntimeOption {
/// Set ONNX Runtime as inference backend, support CPU/GPU
void UseOrtBackend();
/// Set SOPHGO Runtime as inference backend, only support Sophgo TPU
void UseSophgoBackend();
/// Set TensorRT as inference backend, only support GPU
void UseTrtBackend();
@@ -576,6 +583,7 @@ struct FASTDEPLOY_DECL Runtime {
void CreateOpenVINOBackend();
void CreateLiteBackend();
void CreateRKNPU2Backend();
void CreateSophgoNPUBackend();
std::unique_ptr<BaseBackend> backend_;
std::vector<FDTensor> input_tensors_;
std::vector<FDTensor> output_tensors_;

View File

@@ -32,7 +32,10 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file,
valid_ascend_backends = {Backend::LITE};
valid_kunlunxin_backends = {Backend::LITE};
valid_ipu_backends = {Backend::PDINFER};
} else if (model_format == ModelFormat::SOPHGO) {
valid_sophgonpu_backends = {Backend::SOPHGOTPU};
} else {
valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
valid_gpu_backends = {Backend::ORT, Backend::TRT};
valid_rknpu_backends = {Backend::RKNPU2};

View File

@@ -24,6 +24,8 @@ YOLOv5::YOLOv5(const std::string& model_file, const std::string& params_file,
if (model_format == ModelFormat::ONNX) {
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT};
valid_gpu_backends = {Backend::ORT, Backend::TRT};
} else if (model_format == ModelFormat::SOPHGO) {
valid_sophgonpu_backends = {Backend::SOPHGOTPU};
} else {
valid_cpu_backends = {Backend::PDINFER, Backend::ORT, Backend::LITE};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};

View File

@@ -41,6 +41,7 @@ class FASTDEPLOY_DECL PicoDet : public PPDetBase {
valid_rknpu_backends = {Backend::RKNPU2};
valid_kunlunxin_backends = {Backend::LITE};
valid_ascend_backends = {Backend::LITE};
valid_sophgonpu_backends = {Backend::SOPHGOTPU};
initialized = Initialize();
}

View File

@@ -25,12 +25,18 @@ PaddleSegModel::PaddleSegModel(const std::string& model_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) : preprocessor_(config_file),
postprocessor_(config_file) {
if (model_format == ModelFormat::SOPHGO) {
valid_sophgonpu_backends = {Backend::SOPHGOTPU};
} else {
valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::ORT, Backend::LITE};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
}
valid_rknpu_backends = {Backend::RKNPU2};
valid_timvx_backends = {Backend::LITE};
valid_kunlunxin_backends = {Backend::LITE};
valid_ascend_backends = {Backend::LITE};
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
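To close the loop on the segmentation change above, a hedged sketch of driving PaddleSegModel through the new SOPHGO branch; the file names are hypothetical and the constructor signature is assumed to match the existing PaddleSegModel API (model_file, params_file, config_file, option, format):

fastdeploy::RuntimeOption option;
option.UseSophgo();
auto seg = fastdeploy::vision::segmentation::PaddleSegModel(
    "liteseg.bmodel", "", "deploy.yaml", option,
    fastdeploy::ModelFormat::SOPHGO);  // params_file stays empty for .bmodel
cv::Mat im = cv::imread("seg_test.jpg");
fastdeploy::vision::SegmentationResult result;
seg.Predict(&im, &result);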