From 4768a7ba691d15b4d33d0ab6793dee78f347d95d Mon Sep 17 00:00:00 2001 From: yalue Date: Mon, 28 Aug 2023 11:51:04 -0400 Subject: [PATCH] Add TensorRT Support - This change introduces support for enabling the TensorRT execution backend. It is configured in basically the same manner as the CUDA backend, with analogous APIs. - Added a unit test and benchmark for the TensorRT backend. To try it, run "go test -v -bench=." on a system where TensorRT is installed, and you're using a build of onnxruntime with TensorRT enabled. --- onnxruntime_go.go | 114 ++++++++++++++++++++++++++++++++++++------ onnxruntime_test.go | 84 +++++++++++++++++++++++++++++++ onnxruntime_wrapper.c | 19 +++++++ onnxruntime_wrapper.h | 14 ++++++ 4 files changed, 215 insertions(+), 16 deletions(-) diff --git a/onnxruntime_go.go b/onnxruntime_go.go index d14ac89..3d17bb1 100644 --- a/onnxruntime_go.go +++ b/onnxruntime_go.go @@ -389,8 +389,29 @@ type CUDAProviderOptions struct { o *C.OrtCUDAProviderOptionsV2 } +// Used when setting key-value pair options with certain obnoxious C APIs. +// The entries in each of the returned slices must be freed when they're +// no longer needed. +func mapToCStrings(options map[string]string) ([]*C.char, []*C.char) { + keys := make([]*C.char, 0, len(options)) + values := make([]*C.char, 0, len(options)) + for k, v := range options { + keys = append(keys, C.CString(k)) + values = append(values, C.CString(v)) + } + return keys, values +} + +// Calls free on each entry in the array of C strings. +func freeCStrings(s []*C.char) { + for i := range s { + C.free(unsafe.Pointer(s[i])) + s[i] = nil + } +} + // Wraps the call to the UpdateCUDAProviderOptions in the onnxruntime C API. -// Requires a list of string keys and values for configuring the CUDA backend. +// Requires a map of string keys to values for configuring the CUDA backend. // For example, set the key "device_id" to "1" to use GPU 1 rather than 0. // // The onnxruntime headers refer users to @@ -400,20 +421,9 @@ func (o *CUDAProviderOptions) Update(options map[string]string) error { if len(options) == 0 { return nil } - keys := make([]*C.char, 0, len(options)) - values := make([]*C.char, 0, len(options)) - for k, v := range options { - keys = append(keys, C.CString(k)) - values = append(values, C.CString(v)) - } - // We don't need these C strings as soon as UpdateCUDAProviderOptions - // returns. - defer func() { - for i := range keys { - C.free(unsafe.Pointer(keys[i])) - C.free(unsafe.Pointer(values[i])) - } - }() + keys, values := mapToCStrings(options) + defer freeCStrings(keys) + defer freeCStrings(values) status := C.UpdateCUDAProviderOptions(o.o, &(keys[0]), &(values[0]), C.int(len(options))) if status != nil { @@ -428,7 +438,7 @@ func (o *CUDAProviderOptions) Update(options map[string]string) error { // called. func (o *CUDAProviderOptions) Destroy() error { if o.o == nil { - return fmt.Errorf("The CUDAProviderOptions have not been initialized") + return fmt.Errorf("The CUDAProviderOptions are not initialized") } C.ReleaseCUDAProviderOptions(o.o) o.o = nil @@ -454,6 +464,64 @@ func NewCUDAProviderOptions() (*CUDAProviderOptions, error) { }, nil } +// Like the CUDAProviderOptions struct, but used for configuring TensorRT +// options. Instances of this struct must be initialized using +// NewTensorRTProviderOptions() and cleaned up by calling their Destroy() +// function when they are no longer needed. +type TensorRTProviderOptions struct { + o *C.OrtTensorRTProviderOptionsV2 +} + +// Wraps the call to the UpdateTensorRTProviderOptions in the C API. Requires +// a map of string keys to values. +// +// The onnxruntime headers refer users to +// https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc +// for the list of available keys and values. +func (o *TensorRTProviderOptions) Update(options map[string]string) error { + if len(options) == 0 { + return nil + } + keys, values := mapToCStrings(options) + defer freeCStrings(keys) + defer freeCStrings(values) + status := C.UpdateTensorRTProviderOptions(o.o, &(keys[0]), &(values[0]), + C.int(len(options))) + if status != nil { + return statusToError(status) + } + return nil +} + +// Must be called when the TensorRTProviderOptions are no longer needed, in +// order to free internal state. The struct is not needed as soon as you have +// passed it to the AppendExecutionProviderTensorRT function. +func (o *TensorRTProviderOptions) Destroy() error { + if o.o == nil { + return fmt.Errorf("The TensorRTProviderOptions are not initialized") + } + C.ReleaseTensorRTProviderOptions(o.o) + o.o = nil + return nil +} + +// Initializes and returns a TensorRTProviderOptions struct, used when enabling +// the TensorRT backend. The caller must call the Destroy() function on the +// returned struct when it's no longer needed. +func NewTensorRTProviderOptions() (*TensorRTProviderOptions, error) { + if !IsInitialized() { + return nil, NotInitializedError + } + var o *C.OrtTensorRTProviderOptionsV2 + status := C.CreateTensorRTProviderOptions(&o) + if status != nil { + return nil, statusToError(status) + } + return &TensorRTProviderOptions{ + o: o, + }, nil +} + // Used to set options when creating an ONNXRuntime session. There is currently // not a way to change options after the session is created, apart from // destroying the session and creating a new one. This struct opaquely wraps a @@ -516,6 +584,20 @@ func (o *SessionOptions) AppendExecutionProviderCUDA( return nil } +// Takes an initialized TensorRTProviderOptions instance, and applies them to +// the session options. You'll need to call this if you want the session to use +// TensorRT. Returns an error if your device (or onnxruntime library version) +// does not support TensorRT. The TensorRTProviderOptions can be destroyed +// after this. +func (o *SessionOptions) AppendExecutionProviderTensorRT( + tensorRTOptions *TensorRTProviderOptions) error { + status := C.AppendExecutionProviderTensorRTV2(o.o, tensorRTOptions.o) + if status != nil { + return statusToError(status) + } + return nil +} + // Initializes and returns a SessionOptions struct, used when setting options // in new AdvancedSession instances. The caller must call the Destroy() // function on the returned struct when it's no longer needed. diff --git a/onnxruntime_test.go b/onnxruntime_test.go index bf34267..4ed207a 100644 --- a/onnxruntime_test.go +++ b/onnxruntime_test.go @@ -812,3 +812,87 @@ func BenchmarkCUDASession(b *testing.B) { } } } + +// Creates a SessionOptions struct that's configured to enable TensorRT. +// Basically the same as getCUDASessionOptions; see the comments there. +func getTensorRTSessionOptions(t testing.TB) *SessionOptions { + trtOptions, e := NewTensorRTProviderOptions() + if e != nil { + t.Skipf("Error creating TensorRT provider options; %s. "+ + "Your version of the onnxruntime library may not include "+ + "TensorRT support. Skipping the remainder of this test.\n", e) + } + defer trtOptions.Destroy() + // Arbitrarily update an option to test trtOptions.Update() + e = trtOptions.Update( + map[string]string{"trt_max_partition_iterations": "60"}) + if e != nil { + t.Skipf("Error updating TensorRT options: %s. Your system may not "+ + "support TensorRT, TensorRT may be misconfigured, or it may be "+ + "incompatible with this build of onnxruntime. Skipping the "+ + "remainder of this test.\n", e) + } + sessionOptions, e := NewSessionOptions() + if e != nil { + t.Logf("Error creating SessionOptions: %s\n", e) + t.FailNow() + } + e = sessionOptions.AppendExecutionProviderTensorRT(trtOptions) + if e != nil { + t.Logf("Error setting TensorRT execution provider: %s\n", e) + t.FailNow() + } + return sessionOptions +} + +func TestTensorRTSession(t *testing.T) { + InitializeRuntime(t) + defer CleanupRuntime(t) + sessionOptions := getTensorRTSessionOptions(t) + defer sessionOptions.Destroy() + + input, output := prepareBenchmarkTensors(t, 1337) + defer input.Destroy() + defer output.Destroy() + session, e := NewAdvancedSession("test_data/example_big_compute.onnx", + []string{"Input"}, []string{"Output"}, []ArbitraryTensor{input}, + []ArbitraryTensor{output}, sessionOptions) + if e != nil { + t.Logf("Error creating session: %s\n", e) + t.FailNow() + } + defer session.Destroy() + e = session.Run() + if e != nil { + t.Logf("Error running session with TensorRT: %s\n", e) + t.FailNow() + } + t.Logf("Ran session with TensorRT OK.\n") +} + +func BenchmarkTensorRTSession(b *testing.B) { + b.StopTimer() + InitializeRuntime(b) + defer CleanupRuntime(b) + sessionOptions := getTensorRTSessionOptions(b) + defer sessionOptions.Destroy() + input, output := prepareBenchmarkTensors(b, benchmarkRNGSeed) + defer input.Destroy() + defer output.Destroy() + session, e := NewAdvancedSession("test_data/example_big_compute.onnx", + []string{"Input"}, []string{"Output"}, []ArbitraryTensor{input}, + []ArbitraryTensor{output}, sessionOptions) + if e != nil { + b.Logf("Error creating session: %s\n", e) + b.FailNow() + } + defer session.Destroy() + b.StartTimer() + for n := 0; n < b.N; n++ { + e = session.Run() + if e != nil { + b.Logf("Error running iteration %d/%d: %s\n", n+1, b.N, e) + b.FailNow() + } + } +} diff --git a/onnxruntime_wrapper.c b/onnxruntime_wrapper.c index 2d37979..ceda1dd 100644 --- a/onnxruntime_wrapper.c +++ b/onnxruntime_wrapper.c @@ -78,6 +78,25 @@ OrtStatus *UpdateCUDAProviderOptions(OrtCUDAProviderOptionsV2 *o, return ort_api->UpdateCUDAProviderOptions(o, keys, values, num_keys); } +OrtStatus *CreateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 **o) { + return ort_api->CreateTensorRTProviderOptions(o); +} + +void ReleaseTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o) { + ort_api->ReleaseTensorRTProviderOptions(o); +} + +OrtStatus *UpdateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o, + const char **keys, const char **values, int num_keys) { + return ort_api->UpdateTensorRTProviderOptions(o, keys, values, num_keys); +} + +OrtStatus *AppendExecutionProviderTensorRTV2(OrtSessionOptions *o, + OrtTensorRTProviderOptionsV2 *tensor_rt_options) { + return ort_api->SessionOptionsAppendExecutionProvider_TensorRT_V2(o, + tensor_rt_options); +} + OrtStatus *CreateSession(void *model_data, size_t model_data_length, OrtEnv *env, OrtSession **out, OrtSessionOptions *options) { OrtStatus *status = NULL; diff --git a/onnxruntime_wrapper.h b/onnxruntime_wrapper.h index cf8170d..6399e1f 100644 --- a/onnxruntime_wrapper.h +++ b/onnxruntime_wrapper.h @@ -76,6 +76,20 @@ void ReleaseCUDAProviderOptions(OrtCUDAProviderOptionsV2 *o); OrtStatus *UpdateCUDAProviderOptions(OrtCUDAProviderOptionsV2 *o, const char **keys, const char **values, int num_keys); +// Wraps ort_api->CreateTensorRTProviderOptions +OrtStatus *CreateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 **o); + +// Wraps ort_api->ReleaseTensorRTProviderOptions +void ReleaseTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o); + +// Wraps ort_api->UpdateTensorRTProviderOptions +OrtStatus *UpdateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o, + const char **keys, const char **values, int num_keys); + +// Wraps ort_api->SessionOptionsAppendExecutionProvider_TensorRT_V2 +OrtStatus *AppendExecutionProviderTensorRTV2(OrtSessionOptions *o, + OrtTensorRTProviderOptionsV2 *tensor_rt_options); + // Creates an ORT session using the given model. The given options pointer may // be NULL; if it is, then we'll use default options. OrtStatus *CreateSession(void *model_data, size_t model_data_length,