mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-10-30 02:11:45 +08:00
Add TensorRT Support
- This change introduces support for enabling the TensorRT execution backend. It is configured in basically the same manner as the CUDA backend, with analogous APIs.
- Added a unit test and benchmark for the TensorRT backend. To try them, run "go test -v -bench=." on a system where TensorRT is installed, using a build of onnxruntime with TensorRT enabled.
This commit is contained in:
@@ -389,8 +389,29 @@ type CUDAProviderOptions struct {
|
||||
o *C.OrtCUDAProviderOptionsV2
|
||||
}
|
||||
|
||||
// Used when setting key-value pair options with certain obnoxious C APIs.
|
||||
// The entries in each of the returned slices must be freed when they're
|
||||
// no longer needed.
|
||||
func mapToCStrings(options map[string]string) ([]*C.char, []*C.char) {
|
||||
keys := make([]*C.char, 0, len(options))
|
||||
values := make([]*C.char, 0, len(options))
|
||||
for k, v := range options {
|
||||
keys = append(keys, C.CString(k))
|
||||
values = append(values, C.CString(v))
|
||||
}
|
||||
return keys, values
|
||||
}
|
||||
|
||||
// Calls free on each entry in the array of C strings.
|
||||
func freeCStrings(s []*C.char) {
|
||||
for i := range s {
|
||||
C.free(unsafe.Pointer(s[i]))
|
||||
s[i] = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Wraps the call to the UpdateCUDAProviderOptions in the onnxruntime C API.
|
||||
// Requires a list of string keys and values for configuring the CUDA backend.
|
||||
// Requires a map of string keys to values for configuring the CUDA backend.
|
||||
// For example, set the key "device_id" to "1" to use GPU 1 rather than 0.
|
||||
//
|
||||
// The onnxruntime headers refer users to
|
||||
@@ -400,20 +421,9 @@ func (o *CUDAProviderOptions) Update(options map[string]string) error {
|
||||
if len(options) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := make([]*C.char, 0, len(options))
|
||||
values := make([]*C.char, 0, len(options))
|
||||
for k, v := range options {
|
||||
keys = append(keys, C.CString(k))
|
||||
values = append(values, C.CString(v))
|
||||
}
|
||||
// We don't need these C strings as soon as UpdateCUDAProviderOptions
|
||||
// returns.
|
||||
defer func() {
|
||||
for i := range keys {
|
||||
C.free(unsafe.Pointer(keys[i]))
|
||||
C.free(unsafe.Pointer(values[i]))
|
||||
}
|
||||
}()
|
||||
keys, values := mapToCStrings(options)
|
||||
defer freeCStrings(keys)
|
||||
defer freeCStrings(values)
|
||||
status := C.UpdateCUDAProviderOptions(o.o, &(keys[0]), &(values[0]),
|
||||
C.int(len(options)))
|
||||
if status != nil {
|
||||
@@ -428,7 +438,7 @@ func (o *CUDAProviderOptions) Update(options map[string]string) error {
|
||||
// called.
|
||||
func (o *CUDAProviderOptions) Destroy() error {
|
||||
if o.o == nil {
|
||||
return fmt.Errorf("The CUDAProviderOptions have not been initialized")
|
||||
return fmt.Errorf("The CUDAProviderOptions are not initialized")
|
||||
}
|
||||
C.ReleaseCUDAProviderOptions(o.o)
|
||||
o.o = nil
|
||||
@@ -454,6 +464,64 @@ func NewCUDAProviderOptions() (*CUDAProviderOptions, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Like the CUDAProviderOptions struct, but used for configuring TensorRT
// options. Instances of this struct must be initialized using
// NewTensorRTProviderOptions() and cleaned up by calling their Destroy()
// function when they are no longer needed.
type TensorRTProviderOptions struct {
	// The underlying C options struct. Owned by this wrapper; set to nil
	// once released by Destroy().
	o *C.OrtTensorRTProviderOptionsV2
}
|
||||
|
||||
// Wraps the call to the UpdateTensorRTProviderOptions in the C API. Requires
|
||||
// a map of string keys to values.
|
||||
//
|
||||
// The onnxruntime headers refer users to
|
||||
// https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc
|
||||
// for the list of available keys and values.
|
||||
func (o *TensorRTProviderOptions) Update(options map[string]string) error {
|
||||
if len(options) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys, values := mapToCStrings(options)
|
||||
defer freeCStrings(keys)
|
||||
defer freeCStrings(values)
|
||||
status := C.UpdateTensorRTProviderOptions(o.o, &(keys[0]), &(values[0]),
|
||||
C.int(len(options)))
|
||||
if status != nil {
|
||||
return statusToError(status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Must be called when the TensorRTProviderOptions are no longer needed, in
|
||||
// order to free internal state. The struct is not needed as soon as you have
|
||||
// passed it to the AppendExecutionProviderTensorRT function.
|
||||
func (o *TensorRTProviderOptions) Destroy() error {
|
||||
if o.o == nil {
|
||||
return fmt.Errorf("The TensorRTProviderOptions are not initialized")
|
||||
}
|
||||
C.ReleaseTensorRTProviderOptions(o.o)
|
||||
o.o = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initializes and returns a TensorRTProviderOptions struct, used when enabling
|
||||
// the TensorRT backend. The caller must call the Destroy() function on the
|
||||
// returned struct when it's no longer needed.
|
||||
func NewTensorRTProviderOptions() (*TensorRTProviderOptions, error) {
|
||||
if !IsInitialized() {
|
||||
return nil, NotInitializedError
|
||||
}
|
||||
var o *C.OrtTensorRTProviderOptionsV2
|
||||
status := C.CreateTensorRTProviderOptions(&o)
|
||||
if status != nil {
|
||||
return nil, statusToError(status)
|
||||
}
|
||||
return &TensorRTProviderOptions{
|
||||
o: o,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Used to set options when creating an ONNXRuntime session. There is currently
|
||||
// not a way to change options after the session is created, apart from
|
||||
// destroying the session and creating a new one. This struct opaquely wraps a
|
||||
@@ -516,6 +584,20 @@ func (o *SessionOptions) AppendExecutionProviderCUDA(
|
||||
return nil
|
||||
}
|
||||
|
||||
// Takes an initialized TensorRTProviderOptions instance, and applies them to
|
||||
// the session options. You'll need to call this if you want the session to use
|
||||
// TensorRT. Returns an error if your device (or onnxruntime library version)
|
||||
// does not support TensorRT. The TensorRTProviderOptions can be destroyed
|
||||
// after this.
|
||||
func (o *SessionOptions) AppendExecutionProviderTensorRT(
|
||||
tensorRTOptions *TensorRTProviderOptions) error {
|
||||
status := C.AppendExecutionProviderTensorRTV2(o.o, tensorRTOptions.o)
|
||||
if status != nil {
|
||||
return statusToError(status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initializes and returns a SessionOptions struct, used when setting options
|
||||
// in new AdvancedSession instances. The caller must call the Destroy()
|
||||
// function on the returned struct when it's no longer needed.
|
||||
|
||||
@@ -812,3 +812,87 @@ func BenchmarkCUDASession(b *testing.B) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Creates a SessionOptions struct that's configured to enable TensorRT.
|
||||
// Basically the same as getCUDASessionOptions; see the comments there.
|
||||
func getTensorRTSessionOptions(t testing.TB) *SessionOptions {
|
||||
trtOptions, e := NewTensorRTProviderOptions()
|
||||
if e != nil {
|
||||
t.Skipf("Error creating TensorRT provider options; %s. "+
|
||||
"Your version of the onnxruntime library may not include "+
|
||||
"TensorRT support. Skipping the remainder of this test.\n", e)
|
||||
}
|
||||
defer trtOptions.Destroy()
|
||||
// Arbitrarily update an option to test trtOptions.Update()
|
||||
e = trtOptions.Update(
|
||||
map[string]string{"trt_max_partition_iterations": "60"})
|
||||
if e != nil {
|
||||
t.Skipf("Error updating TensorRT options: %s. Your system may not "+
|
||||
"support TensorRT, TensorRT may be misconfigured, or it may be "+
|
||||
"incompatible with this build of onnxruntime. Skipping the "+
|
||||
"remainder of this test.\n", e)
|
||||
}
|
||||
sessionOptions, e := NewSessionOptions()
|
||||
if e != nil {
|
||||
t.Logf("Error creating SessionOptions: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
e = sessionOptions.AppendExecutionProviderTensorRT(trtOptions)
|
||||
if e != nil {
|
||||
t.Logf("Error setting TensorRT execution provider: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
return sessionOptions
|
||||
}
|
||||
|
||||
func TestTensorRTSession(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
sessionOptions := getTensorRTSessionOptions(t)
|
||||
defer sessionOptions.Destroy()
|
||||
|
||||
input, output := prepareBenchmarkTensors(t, 1337)
|
||||
defer input.Destroy()
|
||||
defer output.Destroy()
|
||||
session, e := NewAdvancedSession("test_data/example_big_compute.onnx",
|
||||
[]string{"Input"}, []string{"Output"}, []ArbitraryTensor{input},
|
||||
[]ArbitraryTensor{output}, sessionOptions)
|
||||
if e != nil {
|
||||
t.Logf("Error creating session: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
defer session.Destroy()
|
||||
e = session.Run()
|
||||
if e != nil {
|
||||
t.Logf("Error running session with TensorRT: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Ran session with TensorRT OK.\n")
|
||||
}
|
||||
|
||||
func BenchmarkTensorRTSession(b *testing.B) {
|
||||
b.StopTimer()
|
||||
InitializeRuntime(b)
|
||||
defer CleanupRuntime(b)
|
||||
sessionOptions := getTensorRTSessionOptions(b)
|
||||
defer sessionOptions.Destroy()
|
||||
input, output := prepareBenchmarkTensors(b, benchmarkRNGSeed)
|
||||
defer input.Destroy()
|
||||
defer output.Destroy()
|
||||
session, e := NewAdvancedSession("test_data/example_big_compute.onnx",
|
||||
[]string{"Input"}, []string{"Output"}, []ArbitraryTensor{input},
|
||||
[]ArbitraryTensor{output}, sessionOptions)
|
||||
if e != nil {
|
||||
b.Logf("Error creating session: %s\n", e)
|
||||
b.FailNow()
|
||||
}
|
||||
defer session.Destroy()
|
||||
b.StartTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
e = session.Run()
|
||||
if e != nil {
|
||||
b.Logf("Error running iteration %d/%d: %s\n", n+1, b.N, e)
|
||||
b.FailNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,6 +78,25 @@ OrtStatus *UpdateCUDAProviderOptions(OrtCUDAProviderOptionsV2 *o,
|
||||
return ort_api->UpdateCUDAProviderOptions(o, keys, values, num_keys);
|
||||
}
|
||||
|
||||
// Wraps ort_api->CreateTensorRTProviderOptions. Returns NULL on success, or
// an OrtStatus describing the error otherwise.
OrtStatus *CreateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 **o) {
  return ort_api->CreateTensorRTProviderOptions(o);
}
|
||||
|
||||
// Wraps ort_api->ReleaseTensorRTProviderOptions; frees the options struct
// previously created by CreateTensorRTProviderOptions.
void ReleaseTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o) {
  ort_api->ReleaseTensorRTProviderOptions(o);
}
|
||||
|
||||
// Wraps ort_api->UpdateTensorRTProviderOptions. keys and values are parallel
// arrays of num_keys C strings; returns NULL on success.
OrtStatus *UpdateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o,
  const char **keys, const char **values, int num_keys) {
  return ort_api->UpdateTensorRTProviderOptions(o, keys, values, num_keys);
}
|
||||
|
||||
// Wraps ort_api->SessionOptionsAppendExecutionProvider_TensorRT_V2, enabling
// the TensorRT execution provider on the given session options. Returns NULL
// on success.
OrtStatus *AppendExecutionProviderTensorRTV2(OrtSessionOptions *o,
  OrtTensorRTProviderOptionsV2 *tensor_rt_options) {
  return ort_api->SessionOptionsAppendExecutionProvider_TensorRT_V2(o,
    tensor_rt_options);
}
|
||||
|
||||
OrtStatus *CreateSession(void *model_data, size_t model_data_length,
|
||||
OrtEnv *env, OrtSession **out, OrtSessionOptions *options) {
|
||||
OrtStatus *status = NULL;
|
||||
|
||||
@@ -76,6 +76,20 @@ void ReleaseCUDAProviderOptions(OrtCUDAProviderOptionsV2 *o);
|
||||
OrtStatus *UpdateCUDAProviderOptions(OrtCUDAProviderOptionsV2 *o,
|
||||
const char **keys, const char **values, int num_keys);
|
||||
|
||||
// Wraps ort_api->CreateTensorRTProviderOptions
|
||||
OrtStatus *CreateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 **o);
|
||||
|
||||
// Wraps ort_api->ReleaseTensorRTProviderOptions
|
||||
void ReleaseTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o);
|
||||
|
||||
// Wraps ort_api->UpdateTensorRTProviderOptions
|
||||
OrtStatus *UpdateTensorRTProviderOptions(OrtTensorRTProviderOptionsV2 *o,
|
||||
const char **keys, const char **values, int num_keys);
|
||||
|
||||
// Wraps ort_api->SessionOptionsAppendExecutionProvider_TensorRT_V2
|
||||
OrtStatus *AppendExecutionProviderTensorRTV2(OrtSessionOptions *o,
|
||||
OrtTensorRTProviderOptionsV2 *tensor_rt_options);
|
||||
|
||||
// Creates an ORT session using the given model. The given options pointer may
|
||||
// be NULL; if it is, then we'll use default options.
|
||||
OrtStatus *CreateSession(void *model_data, size_t model_data_length,
|
||||
|
||||
Reference in New Issue
Block a user