Add TensorRT Support

- This change introduces support for enabling the TensorRT execution
   backend. It is configured in the same manner as the CUDA backend,
   with analogous APIs.

 - Added a unit test and benchmark for the TensorRT backend. To try them,
   run "go test -v -bench=." on a system where TensorRT is installed,
   using a build of onnxruntime with TensorRT support enabled.
This commit is contained in:
yalue
2023-08-28 11:51:04 -04:00
parent 5ba50745d5
commit 4768a7ba69
4 changed files with 215 additions and 16 deletions

View File

@@ -389,8 +389,29 @@ type CUDAProviderOptions struct {
o *C.OrtCUDAProviderOptionsV2
}
// Used when setting key-value pair options with certain obnoxious C APIs.
// Converts the map into two parallel slices of C strings (keys and values).
// The entries in each of the returned slices must be freed when they're
// no longer needed.
func mapToCStrings(options map[string]string) ([]*C.char, []*C.char) {
	count := len(options)
	keys := make([]*C.char, 0, count)
	values := make([]*C.char, 0, count)
	for key, value := range options {
		keys = append(keys, C.CString(key))
		values = append(values, C.CString(value))
	}
	return keys, values
}
// Calls free on each entry in the array of C strings, and nils out the
// slice entries so the freed pointers can't be accidentally reused.
func freeCStrings(s []*C.char) {
	for i := 0; i < len(s); i++ {
		C.free(unsafe.Pointer(s[i]))
		s[i] = nil
	}
}
// Wraps the call to the UpdateCUDAProviderOptions in the onnxruntime C API.
// Requires a list of string keys and values for configuring the CUDA backend.
// Requires a map of string keys to values for configuring the CUDA backend.
// For example, set the key "device_id" to "1" to use GPU 1 rather than 0.
//
// The onnxruntime headers refer users to
@@ -400,20 +421,9 @@ func (o *CUDAProviderOptions) Update(options map[string]string) error {
if len(options) == 0 {
return nil
}
keys := make([]*C.char, 0, len(options))
values := make([]*C.char, 0, len(options))
for k, v := range options {
keys = append(keys, C.CString(k))
values = append(values, C.CString(v))
}
// We don't need these C strings as soon as UpdateCUDAProviderOptions
// returns.
defer func() {
for i := range keys {
C.free(unsafe.Pointer(keys[i]))
C.free(unsafe.Pointer(values[i]))
}
}()
keys, values := mapToCStrings(options)
defer freeCStrings(keys)
defer freeCStrings(values)
status := C.UpdateCUDAProviderOptions(o.o, &(keys[0]), &(values[0]),
C.int(len(options)))
if status != nil {
@@ -428,7 +438,7 @@ func (o *CUDAProviderOptions) Update(options map[string]string) error {
// called.
func (o *CUDAProviderOptions) Destroy() error {
if o.o == nil {
return fmt.Errorf("The CUDAProviderOptions have not been initialized")
return fmt.Errorf("The CUDAProviderOptions are not initialized")
}
C.ReleaseCUDAProviderOptions(o.o)
o.o = nil
@@ -454,6 +464,64 @@ func NewCUDAProviderOptions() (*CUDAProviderOptions, error) {
}, nil
}
// Like the CUDAProviderOptions struct, but used for configuring TensorRT
// options. Instances of this struct must be initialized using
// NewTensorRTProviderOptions() and cleaned up by calling their Destroy()
// function when they are no longer needed.
type TensorRTProviderOptions struct {
	// Opaque pointer to the underlying C options struct. Owned by this
	// wrapper; released (and set to nil) by Destroy().
	o *C.OrtTensorRTProviderOptionsV2
}
// Wraps the call to the UpdateTensorRTProviderOptions in the C API. Requires
// a map of string keys to values. Returns an error if the options have not
// been initialized via NewTensorRTProviderOptions (or were already destroyed).
//
// The onnxruntime headers refer users to
// https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc
// for the list of available keys and values.
func (o *TensorRTProviderOptions) Update(options map[string]string) error {
	// Mirror the check done in Destroy(): passing a nil options pointer to
	// the C API would crash rather than report an error.
	if o.o == nil {
		return fmt.Errorf("The TensorRTProviderOptions are not initialized")
	}
	if len(options) == 0 {
		return nil
	}
	keys, values := mapToCStrings(options)
	// The C strings are only needed for the duration of the C call.
	defer freeCStrings(keys)
	defer freeCStrings(values)
	// len(options) > 0 here, so indexing element 0 is safe.
	status := C.UpdateTensorRTProviderOptions(o.o, &(keys[0]), &(values[0]),
		C.int(len(options)))
	if status != nil {
		return statusToError(status)
	}
	return nil
}
// Must be called when the TensorRTProviderOptions are no longer needed, in
// order to free internal state. The struct is not needed as soon as you have
// passed it to the AppendExecutionProviderTensorRT function.
func (o *TensorRTProviderOptions) Destroy() error {
	if o.o != nil {
		C.ReleaseTensorRTProviderOptions(o.o)
		// Clear the pointer so repeated Destroy() calls report an error
		// instead of double-freeing.
		o.o = nil
		return nil
	}
	return fmt.Errorf("The TensorRTProviderOptions are not initialized")
}
// Initializes and returns a TensorRTProviderOptions struct, used when enabling
// the TensorRT backend. The caller must call the Destroy() function on the
// returned struct when it's no longer needed.
func NewTensorRTProviderOptions() (*TensorRTProviderOptions, error) {
	if !IsInitialized() {
		return nil, NotInitializedError
	}
	var opts *C.OrtTensorRTProviderOptionsV2
	if status := C.CreateTensorRTProviderOptions(&opts); status != nil {
		return nil, statusToError(status)
	}
	return &TensorRTProviderOptions{o: opts}, nil
}
// Used to set options when creating an ONNXRuntime session. There is currently
// not a way to change options after the session is created, apart from
// destroying the session and creating a new one. This struct opaquely wraps a
@@ -516,6 +584,20 @@ func (o *SessionOptions) AppendExecutionProviderCUDA(
return nil
}
// Takes an initialized TensorRTProviderOptions instance, and applies them to
// the session options. You'll need to call this if you want the session to use
// TensorRT. Returns an error if your device (or onnxruntime library version)
// does not support TensorRT. The TensorRTProviderOptions can be destroyed
// after this.
func (o *SessionOptions) AppendExecutionProviderTensorRT(
	tensorRTOptions *TensorRTProviderOptions) error {
	if status := C.AppendExecutionProviderTensorRTV2(o.o,
		tensorRTOptions.o); status != nil {
		return statusToError(status)
	}
	return nil
}
// Initializes and returns a SessionOptions struct, used when setting options
// in new AdvancedSession instances. The caller must call the Destroy()
// function on the returned struct when it's no longer needed.

View File

@@ -812,3 +812,87 @@ func BenchmarkCUDASession(b *testing.B) {
}
}
}
// Creates a SessionOptions struct that's configured to enable TensorRT.
// Basically the same as getCUDASessionOptions; see the comments there.
func getTensorRTSessionOptions(t testing.TB) *SessionOptions {
	trtOptions, err := NewTensorRTProviderOptions()
	if err != nil {
		t.Skipf("Error creating TensorRT provider options; %s. "+
			"Your version of the onnxruntime library may not include "+
			"TensorRT support. Skipping the remainder of this test.\n", err)
	}
	defer trtOptions.Destroy()
	// Arbitrarily update an option to test trtOptions.Update()
	err = trtOptions.Update(
		map[string]string{"trt_max_partition_iterations": "60"})
	if err != nil {
		t.Skipf("Error updating TensorRT options: %s. Your system may not "+
			"support TensorRT, TensorRT may be misconfigured, or it may be "+
			"incompatible with this build of onnxruntime. Skipping the "+
			"remainder of this test.\n", err)
	}
	sessionOptions, err := NewSessionOptions()
	if err != nil {
		// Fatalf is equivalent to Logf followed by FailNow.
		t.Fatalf("Error creating SessionOptions: %s\n", err)
	}
	err = sessionOptions.AppendExecutionProviderTensorRT(trtOptions)
	if err != nil {
		t.Fatalf("Error setting TensorRT execution provider: %s\n", err)
	}
	return sessionOptions
}
// Runs a single inference on the example model with the TensorRT execution
// provider enabled, skipping when TensorRT support is unavailable.
func TestTensorRTSession(t *testing.T) {
	InitializeRuntime(t)
	defer CleanupRuntime(t)
	sessionOptions := getTensorRTSessionOptions(t)
	defer sessionOptions.Destroy()
	input, output := prepareBenchmarkTensors(t, 1337)
	defer input.Destroy()
	defer output.Destroy()
	session, err := NewAdvancedSession("test_data/example_big_compute.onnx",
		[]string{"Input"}, []string{"Output"}, []ArbitraryTensor{input},
		[]ArbitraryTensor{output}, sessionOptions)
	if err != nil {
		// Fatalf is equivalent to Logf followed by FailNow.
		t.Fatalf("Error creating session: %s\n", err)
	}
	defer session.Destroy()
	if err = session.Run(); err != nil {
		t.Fatalf("Error running session with TensorRT: %s\n", err)
	}
	t.Logf("Ran session with TensorRT OK.\n")
}
// Measures repeated inference on the example model with the TensorRT
// execution provider enabled. Setup (runtime init, session creation) is
// excluded from the timed region.
func BenchmarkTensorRTSession(b *testing.B) {
	b.StopTimer()
	InitializeRuntime(b)
	defer CleanupRuntime(b)
	sessionOptions := getTensorRTSessionOptions(b)
	defer sessionOptions.Destroy()
	input, output := prepareBenchmarkTensors(b, benchmarkRNGSeed)
	defer input.Destroy()
	defer output.Destroy()
	session, err := NewAdvancedSession("test_data/example_big_compute.onnx",
		[]string{"Input"}, []string{"Output"}, []ArbitraryTensor{input},
		[]ArbitraryTensor{output}, sessionOptions)
	if err != nil {
		// Fatalf is equivalent to Logf followed by FailNow.
		b.Fatalf("Error creating session: %s\n", err)
	}
	defer session.Destroy()
	b.StartTimer()
	for n := 0; n < b.N; n++ {
		if err = session.Run(); err != nil {
			b.Fatalf("Error running iteration %d/%d: %s\n", n+1, b.N, err)
		}
	}
}

View File

@@ -78,6 +78,25 @@ OrtStatus *UpdateCUDAProviderOptions(OrtCUDAProviderOptionsV2 *o,
return ort_api->UpdateCUDAProviderOptions(o, keys, values, num_keys);
}
// Wraps ort_api->CreateTensorRTProviderOptions: allocates a new TensorRT
// provider options struct into *o. Returns NULL on success.
OrtStatus *CreateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 **o) {
  return ort_api->CreateTensorRTProviderOptions(o);
}
// Wraps ort_api->ReleaseTensorRTProviderOptions: frees an options struct
// previously obtained from CreateTensorRTProviderOptions.
void ReleaseTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o) {
  ort_api->ReleaseTensorRTProviderOptions(o);
}
// Wraps ort_api->UpdateTensorRTProviderOptions: applies num_keys key/value
// string pairs (parallel arrays) to the given options struct.
OrtStatus *UpdateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o,
  const char **keys, const char **values, int num_keys) {
  return ort_api->UpdateTensorRTProviderOptions(o, keys, values, num_keys);
}
// Wraps ort_api->SessionOptionsAppendExecutionProvider_TensorRT_V2: enables
// the TensorRT execution provider on the given session options.
OrtStatus *AppendExecutionProviderTensorRTV2(OrtSessionOptions *o,
  OrtTensorRTProviderOptionsV2 *tensor_rt_options) {
  return ort_api->SessionOptionsAppendExecutionProvider_TensorRT_V2(o,
    tensor_rt_options);
}
OrtStatus *CreateSession(void *model_data, size_t model_data_length,
OrtEnv *env, OrtSession **out, OrtSessionOptions *options) {
OrtStatus *status = NULL;

View File

@@ -76,6 +76,20 @@ void ReleaseCUDAProviderOptions(OrtCUDAProviderOptionsV2 *o);
OrtStatus *UpdateCUDAProviderOptions(OrtCUDAProviderOptionsV2 *o,
const char **keys, const char **values, int num_keys);
// Wraps ort_api->CreateTensorRTProviderOptions
OrtStatus *CreateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 **o);
// Wraps ort_api->ReleaseTensorRTProviderOptions
void ReleaseTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o);
// Wraps ort_api->UpdateTensorRTProviderOptions
OrtStatus *UpdateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o,
const char **keys, const char **values, int num_keys);
// Wraps ort_api->SessionOptionsAppendExecutionProvider_TensorRT_V2
OrtStatus *AppendExecutionProviderTensorRTV2(OrtSessionOptions *o,
OrtTensorRTProviderOptionsV2 *tensor_rt_options);
// Creates an ORT session using the given model. The given options pointer may
// be NULL; if it is, then we'll use default options.
OrtStatus *CreateSession(void *model_data, size_t model_data_length,