feat: select adapter id for DirectML

This commit is contained in:
MistEO
2024-11-20 14:37:26 +08:00
parent 3bb05ac574
commit 2507a172f8
3 changed files with 6 additions and 3 deletions

View File

@@ -98,7 +98,7 @@ bool OrtBackend::BuildOption(const OrtBackendOption& option) {
"DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ortDmlApi)); "DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ortDmlApi));
OrtStatus* onnx_dml_status = OrtStatus* onnx_dml_status =
ortDmlApi->SessionOptionsAppendExecutionProvider_DML(session_options_, ortDmlApi->SessionOptionsAppendExecutionProvider_DML(session_options_,
0); option_.device_id);
if (onnx_dml_status != nullptr) { if (onnx_dml_status != nullptr) {
FDERROR FDERROR
<< "DirectML is not support in your machine, the program will exit." << "DirectML is not support in your machine, the program will exit."

View File

@@ -141,7 +141,10 @@ void RuntimeOption::UseAscend() {
paddle_lite_option.device = device; paddle_lite_option.device = device;
} }
void RuntimeOption::UseDirectML() { device = Device::DIRECTML; } void RuntimeOption::UseDirectML(int adapter_id) {
device = Device::DIRECTML;
device_id = adapter_id;
}
void RuntimeOption::UseSophgo() { void RuntimeOption::UseSophgo() {
device = Device::SOPHGOTPUD; device = Device::SOPHGOTPUD;

View File

@@ -82,7 +82,7 @@ struct FASTDEPLOY_DECL RuntimeOption {
void UseAscend(); void UseAscend();
/// Use onnxruntime DirectML to inference /// Use onnxruntime DirectML to inference
void UseDirectML(); void UseDirectML(int adapter_id = 0);
/// Use Sophgo to inference /// Use Sophgo to inference
void UseSophgo(); void UseSophgo();