mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-10-29 01:42:27 +08:00
Add TensorRT Support
- This change introduces support for enabling the TensorRT execution backend. It is configured in basically the same manner as the CUDA backend, with analogous APIs. - Added a unit test and benchmark for the TensorRT backend. To try it, run "go test -v -bench=." on a system where TensorRT is installed, and you're using a build of onnxruntime with TensorRT enabled.
This commit is contained in:
@@ -812,3 +812,87 @@ func BenchmarkCUDASession(b *testing.B) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Creates a SessionOptions struct that's configured to enable TensorRT.
|
||||
// Basically the same as getCUDASessionOptions; see the comments there.
|
||||
func getTensorRTSessionOptions(t testing.TB) *SessionOptions {
|
||||
trtOptions, e := NewTensorRTProviderOptions()
|
||||
if e != nil {
|
||||
t.Skipf("Error creating TensorRT provider options; %s. "+
|
||||
"Your version of the onnxruntime library may not include "+
|
||||
"TensorRT support. Skipping the remainder of this test.\n", e)
|
||||
}
|
||||
defer trtOptions.Destroy()
|
||||
// Arbitrarily update an option to test trtOptions.Update()
|
||||
e = trtOptions.Update(
|
||||
map[string]string{"trt_max_partition_iterations": "60"})
|
||||
if e != nil {
|
||||
t.Skipf("Error updating TensorRT options: %s. Your system may not "+
|
||||
"support TensorRT, TensorRT may be misconfigured, or it may be "+
|
||||
"incompatible with this build of onnxruntime. Skipping the "+
|
||||
"remainder of this test.\n", e)
|
||||
}
|
||||
sessionOptions, e := NewSessionOptions()
|
||||
if e != nil {
|
||||
t.Logf("Error creating SessionOptions: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
e = sessionOptions.AppendExecutionProviderTensorRT(trtOptions)
|
||||
if e != nil {
|
||||
t.Logf("Error setting TensorRT execution provider: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
return sessionOptions
|
||||
}
|
||||
|
||||
func TestTensorRTSession(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
sessionOptions := getTensorRTSessionOptions(t)
|
||||
defer sessionOptions.Destroy()
|
||||
|
||||
input, output := prepareBenchmarkTensors(t, 1337)
|
||||
defer input.Destroy()
|
||||
defer output.Destroy()
|
||||
session, e := NewAdvancedSession("test_data/example_big_compute.onnx",
|
||||
[]string{"Input"}, []string{"Output"}, []ArbitraryTensor{input},
|
||||
[]ArbitraryTensor{output}, sessionOptions)
|
||||
if e != nil {
|
||||
t.Logf("Error creating session: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
defer session.Destroy()
|
||||
e = session.Run()
|
||||
if e != nil {
|
||||
t.Logf("Error running session with TensorRT: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
t.Logf("Ran session with TensorRT OK.\n")
|
||||
}
|
||||
|
||||
func BenchmarkTensorRTSession(b *testing.B) {
|
||||
b.StopTimer()
|
||||
InitializeRuntime(b)
|
||||
defer CleanupRuntime(b)
|
||||
sessionOptions := getTensorRTSessionOptions(b)
|
||||
defer sessionOptions.Destroy()
|
||||
input, output := prepareBenchmarkTensors(b, benchmarkRNGSeed)
|
||||
defer input.Destroy()
|
||||
defer output.Destroy()
|
||||
session, e := NewAdvancedSession("test_data/example_big_compute.onnx",
|
||||
[]string{"Input"}, []string{"Output"}, []ArbitraryTensor{input},
|
||||
[]ArbitraryTensor{output}, sessionOptions)
|
||||
if e != nil {
|
||||
b.Logf("Error creating session: %s\n", e)
|
||||
b.FailNow()
|
||||
}
|
||||
defer session.Destroy()
|
||||
b.StartTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
e = session.Run()
|
||||
if e != nil {
|
||||
b.Logf("Error running iteration %d/%d: %s\n", n+1, b.N, e)
|
||||
b.FailNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user