diff --git a/fastdeploy/fastdeploy_model.h b/fastdeploy/fastdeploy_model.h index 1d7dd58ef..8bb12b91f 100755 --- a/fastdeploy/fastdeploy_model.h +++ b/fastdeploy/fastdeploy_model.h @@ -47,7 +47,7 @@ class FASTDEPLOY_DECL FastDeployModel { std::vector valid_timvx_backends = {}; /** Model's valid directml backends. This member defined all the onnxruntime directml backends have successfully tested for the model */ - std::vector valid_directml_backends = {}; + std::vector valid_directml_backends = {Backend::ORT}; /** Model's valid ascend backends. This member defined all the cann backends have successfully tested for the model */ std::vector valid_ascend_backends = {}; diff --git a/fastdeploy/runtime/backends/ort/ort_backend.cc b/fastdeploy/runtime/backends/ort/ort_backend.cc index db1f03e40..46bf01e9c 100644 --- a/fastdeploy/runtime/backends/ort/ort_backend.cc +++ b/fastdeploy/runtime/backends/ort/ort_backend.cc @@ -98,7 +98,7 @@ bool OrtBackend::BuildOption(const OrtBackendOption& option) { "DML", ORT_API_VERSION, reinterpret_cast(&ortDmlApi)); OrtStatus* onnx_dml_status = ortDmlApi->SessionOptionsAppendExecutionProvider_DML(session_options_, - 0); + option_.device_id); if (onnx_dml_status != nullptr) { FDERROR << "DirectML is not support in your machine, the program will exit." diff --git a/fastdeploy/runtime/runtime_option.cc b/fastdeploy/runtime/runtime_option.cc index a2a232ced..3552ef625 100644 --- a/fastdeploy/runtime/runtime_option.cc +++ b/fastdeploy/runtime/runtime_option.cc @@ -141,7 +141,10 @@ void RuntimeOption::UseAscend() { paddle_lite_option.device = device; } -void RuntimeOption::UseDirectML() { device = Device::DIRECTML; } +void RuntimeOption::UseDirectML(int adapter_id) { + device = Device::DIRECTML; + device_id = adapter_id; +} void RuntimeOption::UseSophgo() { device = Device::SOPHGOTPUD; diff --git a/fastdeploy/runtime/runtime_option.h b/fastdeploy/runtime/runtime_option.h index 205a2184c..0e43d9fb5 100755 --- a/fastdeploy/runtime/runtime_option.h +++ b/fastdeploy/runtime/runtime_option.h @@ -82,7 +82,7 @@ struct FASTDEPLOY_DECL RuntimeOption { void UseAscend(); /// Use onnxruntime DirectML to inference - void UseDirectML(); + void UseDirectML(int adapter_id = 0); /// Use Sophgo to inference void UseSophgo();