Add TensorRT Support

- This change introduces support for enabling the TensorRT execution
  backend. It is configured in much the same manner as the CUDA
  backend, with analogous APIs (see the usage sketch below).

- Added a unit test and benchmark for the TensorRT backend. To try them,
  run "go test -v -bench=." on a system where TensorRT is installed,
  using a build of onnxruntime with TensorRT support enabled.
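
For reference, here is a minimal sketch of how an application might enable
TensorRT with these APIs, mirroring the CUDA flow. The import path and the
InitializeEnvironment/DestroyEnvironment setup are assumptions about the
library's usual entry points and are not part of this diff:

package main

import (
	"fmt"

	ort "github.com/yalue/onnxruntime_go" // Assumed import path.
)

func main() {
	// Assumes the onnxruntime shared library (built with TensorRT support)
	// is already discoverable; otherwise point to it with
	// ort.SetSharedLibraryPath before initializing.
	if e := ort.InitializeEnvironment(); e != nil {
		panic(e)
	}
	defer ort.DestroyEnvironment()

	// Create TensorRT provider options, analogous to NewCUDAProviderOptions.
	trtOptions, e := ort.NewTensorRTProviderOptions()
	if e != nil {
		// Likely a build of onnxruntime without TensorRT support.
		panic(e)
	}
	defer trtOptions.Destroy()

	// Optionally adjust TensorRT settings via string key/value pairs.
	e = trtOptions.Update(map[string]string{
		"trt_max_partition_iterations": "60",
	})
	if e != nil {
		panic(e)
	}

	// Attach the TensorRT provider to a SessionOptions struct; pass these
	// options when creating a session, e.g. with NewAdvancedSession.
	sessionOptions, e := ort.NewSessionOptions()
	if e != nil {
		panic(e)
	}
	defer sessionOptions.Destroy()
	if e = sessionOptions.AppendExecutionProviderTensorRT(trtOptions); e != nil {
		panic(e)
	}
	fmt.Println("SessionOptions configured for TensorRT.")
}
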
Author: yalue
Date: 2023-08-28 11:51:04 -04:00
parent 5ba50745d5
commit 4768a7ba69
4 changed files with 215 additions and 16 deletions

@@ -812,3 +812,87 @@ func BenchmarkCUDASession(b *testing.B) {
		}
	}
}
// Creates a SessionOptions struct that's configured to enable TensorRT.
// Basically the same as getCUDASessionOptions; see the comments there.
func getTensorRTSessionOptions(t testing.TB) *SessionOptions {
	trtOptions, e := NewTensorRTProviderOptions()
	if e != nil {
		t.Skipf("Error creating TensorRT provider options; %s. "+
			"Your version of the onnxruntime library may not include "+
			"TensorRT support. Skipping the remainder of this test.\n", e)
	}
	defer trtOptions.Destroy()
	// Arbitrarily update an option to test trtOptions.Update()
	e = trtOptions.Update(
		map[string]string{"trt_max_partition_iterations": "60"})
	if e != nil {
		t.Skipf("Error updating TensorRT options: %s. Your system may not "+
			"support TensorRT, TensorRT may be misconfigured, or it may be "+
			"incompatible with this build of onnxruntime. Skipping the "+
			"remainder of this test.\n", e)
	}
	sessionOptions, e := NewSessionOptions()
	if e != nil {
		t.Logf("Error creating SessionOptions: %s\n", e)
		t.FailNow()
	}
	e = sessionOptions.AppendExecutionProviderTensorRT(trtOptions)
	if e != nil {
		t.Logf("Error setting TensorRT execution provider: %s\n", e)
		t.FailNow()
	}
	return sessionOptions
}
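// Verifies that a session can be created and run with the TensorRT
// execution provider enabled.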
func TestTensorRTSession(t *testing.T) {
	InitializeRuntime(t)
	defer CleanupRuntime(t)
	sessionOptions := getTensorRTSessionOptions(t)
	defer sessionOptions.Destroy()
	input, output := prepareBenchmarkTensors(t, 1337)
	defer input.Destroy()
	defer output.Destroy()
	session, e := NewAdvancedSession("test_data/example_big_compute.onnx",
		[]string{"Input"}, []string{"Output"}, []ArbitraryTensor{input},
		[]ArbitraryTensor{output}, sessionOptions)
	if e != nil {
		t.Logf("Error creating session: %s\n", e)
		t.FailNow()
	}
	defer session.Destroy()
	e = session.Run()
	if e != nil {
		t.Logf("Error running session with TensorRT: %s\n", e)
		t.FailNow()
	}
	t.Logf("Ran session with TensorRT OK.\n")
}
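// Times repeated Run() calls on a session using the TensorRT execution
// provider; setup work is excluded from the measurement.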
func BenchmarkTensorRTSession(b *testing.B) {
	b.StopTimer()
	InitializeRuntime(b)
	defer CleanupRuntime(b)
	sessionOptions := getTensorRTSessionOptions(b)
	defer sessionOptions.Destroy()
	input, output := prepareBenchmarkTensors(b, benchmarkRNGSeed)
	defer input.Destroy()
	defer output.Destroy()
	session, e := NewAdvancedSession("test_data/example_big_compute.onnx",
		[]string{"Input"}, []string{"Output"}, []ArbitraryTensor{input},
		[]ArbitraryTensor{output}, sessionOptions)
	if e != nil {
		b.Logf("Error creating session: %s\n", e)
		b.FailNow()
	}
	defer session.Destroy()
	b.StartTimer()
	for n := 0; n < b.N; n++ {
		e = session.Run()
		if e != nil {
			b.Logf("Error running iteration %d/%d: %s\n", n+1, b.N, e)
			b.FailNow()
		}
	}
}