diff --git a/fastdeploy/backends/ort/ort_backend.cc b/fastdeploy/backends/ort/ort_backend.cc
index 16814d910..3b89be6e5 100644
--- a/fastdeploy/backends/ort/ort_backend.cc
+++ b/fastdeploy/backends/ort/ort_backend.cc
@@ -63,6 +63,10 @@ void OrtBackend::BuildOption(const OrtBackendOption& option) {
   } else {
     OrtCUDAProviderOptions cuda_options;
     cuda_options.device_id = option.gpu_id;
+    if (option.external_stream_) {
+      cuda_options.has_user_compute_stream = 1;
+      cuda_options.user_compute_stream = option.external_stream_;
+    }
     session_options_.AppendExecutionProvider_CUDA(cuda_options);
   }
 }
diff --git a/fastdeploy/backends/ort/ort_backend.h b/fastdeploy/backends/ort/ort_backend.h
index 445d55540..31c769824 100644
--- a/fastdeploy/backends/ort/ort_backend.h
+++ b/fastdeploy/backends/ort/ort_backend.h
@@ -44,6 +44,7 @@ struct OrtBackendOption {
   int execution_mode = -1;
   bool use_gpu = false;
   int gpu_id = 0;
+  void* external_stream_ = nullptr;
 
   // inside parameter, maybe remove next version
   bool remove_multiclass_nms_ = false;
@@ -66,7 +67,8 @@ class OrtBackend : public BaseBackend {
             const OrtBackendOption& option = OrtBackendOption(),
             bool from_memory_buffer = false);
 
-  bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs) override;
+  bool Infer(std::vector<FDTensor>& inputs,
+             std::vector<FDTensor>* outputs) override;
 
   int NumInputs() const override { return inputs_desc_.size(); }
diff --git a/fastdeploy/backends/paddle/paddle_backend.cc b/fastdeploy/backends/paddle/paddle_backend.cc
index 486c3347b..4e33bd441 100644
--- a/fastdeploy/backends/paddle/paddle_backend.cc
+++ b/fastdeploy/backends/paddle/paddle_backend.cc
@@ -22,6 +22,9 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
   option_ = option;
   if (option.use_gpu) {
     config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id);
+    if (option_.external_stream_) {
+      config_.SetExecStream(option_.external_stream_);
+    }
     if (option.enable_trt) {
 #ifdef ENABLE_TRT_BACKEND
       auto precision = paddle_infer::PrecisionType::kFloat32;
diff --git a/fastdeploy/backends/paddle/paddle_backend.h b/fastdeploy/backends/paddle/paddle_backend.h
index 1d4f53db3..e29a5a724 100755
--- a/fastdeploy/backends/paddle/paddle_backend.h
+++ b/fastdeploy/backends/paddle/paddle_backend.h
@@ -54,6 +54,7 @@ struct PaddleBackendOption {
   // gpu device id
   int gpu_id = 0;
   bool enable_pinned_memory = false;
+  void* external_stream_ = nullptr;
 
   std::vector<std::string> delete_pass_names = {};
 };
diff --git a/fastdeploy/backends/tensorrt/trt_backend.cc b/fastdeploy/backends/tensorrt/trt_backend.cc
index 100ce6f7d..563901254 100755
--- a/fastdeploy/backends/tensorrt/trt_backend.cc
+++ b/fastdeploy/backends/tensorrt/trt_backend.cc
@@ -258,8 +258,12 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
         ReaderDtypeToTrtDtype(onnx_reader.outputs[i].dtype);
   }
 
-  FDASSERT(cudaStreamCreate(&stream_) == 0,
+  if (option_.external_stream_) {
+    stream_ = reinterpret_cast<cudaStream_t>(option_.external_stream_);
+  } else {
+    FDASSERT(cudaStreamCreate(&stream_) == 0,
            "[ERROR] Error occurs while calling cudaStreamCreate().");
+  }
   if (!CreateTrtEngineFromOnnx(onnx_content)) {
     FDERROR << "Failed to create tensorrt engine."
             << std::endl;
diff --git a/fastdeploy/backends/tensorrt/trt_backend.h b/fastdeploy/backends/tensorrt/trt_backend.h
index 09f18b2df..0aebba717 100755
--- a/fastdeploy/backends/tensorrt/trt_backend.h
+++ b/fastdeploy/backends/tensorrt/trt_backend.h
@@ -71,6 +71,7 @@ struct TrtBackendOption {
   std::map<std::string, std::vector<int32_t>> opt_shape;
   std::string serialize_file = "";
   bool enable_pinned_memory = false;
+  void* external_stream_ = nullptr;
 
   // inside parameter, maybe remove next version
   bool remove_multiclass_nms_ = false;
diff --git a/fastdeploy/pybind/runtime.cc b/fastdeploy/pybind/runtime.cc
index 70f9a5917..cde6c5b2d 100755
--- a/fastdeploy/pybind/runtime.cc
+++ b/fastdeploy/pybind/runtime.cc
@@ -22,6 +22,7 @@ void BindRuntime(pybind11::module& m) {
       .def("set_model_path", &RuntimeOption::SetModelPath)
       .def("use_gpu", &RuntimeOption::UseGpu)
       .def("use_cpu", &RuntimeOption::UseCpu)
+      .def("set_external_stream", &RuntimeOption::SetExternalStream)
       .def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum)
       .def("use_paddle_backend", &RuntimeOption::UsePaddleBackend)
       .def("use_poros_backend", &RuntimeOption::UsePorosBackend)
@@ -52,6 +53,7 @@ void BindRuntime(pybind11::module& m) {
       .def_readwrite("params_file", &RuntimeOption::params_file)
       .def_readwrite("model_format", &RuntimeOption::model_format)
       .def_readwrite("backend", &RuntimeOption::backend)
+      .def_readwrite("external_stream_", &RuntimeOption::external_stream_)
       .def_readwrite("cpu_thread_num", &RuntimeOption::cpu_thread_num)
       .def_readwrite("device_id", &RuntimeOption::device_id)
       .def_readwrite("device", &RuntimeOption::device)
diff --git a/fastdeploy/runtime.cc b/fastdeploy/runtime.cc
index 5037dc120..087b57755 100755
--- a/fastdeploy/runtime.cc
+++ b/fastdeploy/runtime.cc
@@ -223,6 +223,10 @@ void RuntimeOption::UseGpu(int gpu_id) {
 
 void RuntimeOption::UseCpu() { device = Device::CPU; }
 
+void RuntimeOption::SetExternalStream(void* external_stream) {
+  external_stream_ = external_stream;
+}
+
 void RuntimeOption::SetCpuThreadNum(int thread_num) {
   FDASSERT(thread_num > 0, "The thread_num must be greater than 0.");
   cpu_thread_num = thread_num;
@@ -508,6 +512,7 @@ void Runtime::CreatePaddleBackend() {
   pd_option.delete_pass_names = option.pd_delete_pass_names;
   pd_option.cpu_thread_num = option.cpu_thread_num;
   pd_option.enable_pinned_memory = option.enable_pinned_memory;
+  pd_option.external_stream_ = option.external_stream_;
 #ifdef ENABLE_TRT_BACKEND
   if (pd_option.use_gpu && option.pd_enable_trt) {
     pd_option.enable_trt = true;
@@ -574,6 +579,7 @@ void Runtime::CreateOrtBackend() {
   ort_option.execution_mode = option.ort_execution_mode;
   ort_option.use_gpu = (option.device == Device::GPU) ? true : false;
   ort_option.gpu_id = option.device_id;
+  ort_option.external_stream_ = option.external_stream_;
 
   // TODO(jiangjiajun): inside usage, maybe remove this later
   ort_option.remove_multiclass_nms_ = option.remove_multiclass_nms_;
@@ -613,6 +619,7 @@ void Runtime::CreateTrtBackend() {
   trt_option.opt_shape = option.trt_opt_shape;
   trt_option.serialize_file = option.trt_serialize_file;
   trt_option.enable_pinned_memory = option.enable_pinned_memory;
+  trt_option.external_stream_ = option.external_stream_;
 
   // TODO(jiangjiajun): inside usage, maybe remove this later
   trt_option.remove_multiclass_nms_ = option.remove_multiclass_nms_;
diff --git a/fastdeploy/runtime.h b/fastdeploy/runtime.h
index 021103cb2..26628217b 100755
--- a/fastdeploy/runtime.h
+++ b/fastdeploy/runtime.h
@@ -102,6 +102,8 @@ struct FASTDEPLOY_DECL RuntimeOption {
   /// Use Nvidia GPU to inference
   void UseGpu(int gpu_id = 0);
 
+  void SetExternalStream(void* external_stream);
+
   /*
    * @brief Set number of cpu threads while inference on CPU, by default it will decided by the different backends
    */
@@ -232,6 +234,8 @@ struct FASTDEPLOY_DECL RuntimeOption {
   Device device = Device::CPU;
 
+  void* external_stream_ = nullptr;
+
   bool enable_pinned_memory = false;
 
   // ======Only for ORT Backend========
diff --git a/serving/src/fastdeploy_runtime.cc b/serving/src/fastdeploy_runtime.cc
index 4d9a6d17b..748633b58 100644
--- a/serving/src/fastdeploy_runtime.cc
+++ b/serving/src/fastdeploy_runtime.cc
@@ -379,6 +379,7 @@ TRITONSERVER_Error* ModelState::LoadModel(
   if ((instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_GPU) ||
       (instance_group_kind == TRITONSERVER_INSTANCEGROUPKIND_AUTO)) {
     runtime_options_->UseGpu(instance_group_device_id);
+    runtime_options_->SetExternalStream((void*)stream);
   } else {
     runtime_options_->UseCpu();
   }
@@ -1001,9 +1002,7 @@ TRITONSERVER_Error* ModelInstanceState::Run(
   runtime_->Infer(input_tensors_, &output_tensors_);
 #ifdef TRITON_ENABLE_GPU
   if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
-    // TODO: stream controll
-    cudaDeviceSynchronize();
-    // cudaStreamSynchronize(CudaStream());
+    cudaStreamSynchronize(CudaStream());
   }
 #endif
   return nullptr;
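Taken together, the patch threads a single `external_stream_` pointer from `RuntimeOption` down into the ORT, Paddle, and TensorRT backends, so inference kernels are enqueued on a caller-owned CUDA stream instead of a backend-private one; the Triton serving layer then hands over its instance stream and only needs a stream-level synchronize. A minimal C++ sketch of how the new API might be driven follows; the model paths and the input-filling step are hypothetical placeholders, not part of the patch:

```cpp
#include <cuda_runtime_api.h>
#include <vector>
#include "fastdeploy/runtime.h"

int main() {
  // Create (or reuse) an application-owned CUDA stream.
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  fastdeploy::RuntimeOption option;
  option.SetModelPath("model.pdmodel", "model.pdiparams");  // hypothetical paths
  option.UseGpu(0);
  // Hand the stream to FastDeploy; the selected backend will enqueue its
  // work on it instead of creating its own stream.
  option.SetExternalStream(reinterpret_cast<void*>(stream));

  fastdeploy::Runtime runtime;
  if (!runtime.Init(option)) return -1;

  std::vector<fastdeploy::FDTensor> inputs(runtime.NumInputs());
  // ... fill input tensors (names, shapes, data) ...
  std::vector<fastdeploy::FDTensor> outputs;
  runtime.Infer(inputs, &outputs);

  // Inference ran on our stream, so a stream-level sync suffices;
  // no cudaDeviceSynchronize() is needed.
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}
```

Passing the stream as an opaque `void*` keeps CUDA headers out of the public option structs; each backend casts it back to its native type as needed (TensorRT via `reinterpret_cast<cudaStream_t>`, ONNX Runtime via `user_compute_stream`, Paddle via `SetExecStream`).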