Mirror of https://github.com/yalue/onnxruntime_go.git
Initial commit

- Sets up the onnxruntime environment, but doesn't load or run networks yet.
- The library builds and runs on Windows.
- Still working on getting the Linux (arm64 for now) test to work.
.gitignore (vendored, new file, 3 lines)
@@ -0,0 +1,3 @@
onnx_example_application/onnx_example_application
onnx_example_application/onnx_example_application.exe
LICENSE (new file, 19 lines)
@@ -0,0 +1,19 @@
Copyright (c) 2023 Nathan Otterness

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md (new file, 98 lines)
@@ -0,0 +1,98 @@
Cross-Platform `onnxruntime` Wrapper for Go
===========================================

About
-----

This library seeks to provide an interface for loading and executing neural
networks from Go(lang) code, while remaining as simple to use as possible.

The [onnxruntime](https://github.com/microsoft/onnxruntime) library provides a
way to load and execute ONNX-format neural networks, though the library
primarily supports C and C++ APIs. Several efforts have been made to write
Go(lang) wrappers for the `onnxruntime` library, but, as far as I can tell,
none of these existing Go wrappers support Windows. This is due to the fact
that Microsoft's `onnxruntime` library assumes the user will be using the MSVC
compiler on Windows systems, while CGo on Windows requires using Mingw.

This wrapper works around this issue by manually loading the `onnxruntime`
shared library, removing any dependency on the `onnxruntime` source code beyond
the header files. Naturally, this approach works equally well on non-Windows
systems.

Requirements
------------

To use this library, you'll need a version of Go with cgo support. If you are
not using an amd64 version of Windows or Linux (or if you want to provide your
own library for some other reason), you simply need to provide the correct path
to the shared library when initializing the wrapper. This is shown in the first
few lines of the following example.

Example Usage
-------------

The following example illustrates how this library can be used to load and run
an ONNX network taking a single input tensor and producing a single output
tensor, both of which contain 32-bit floating point values.

```go
import (
    "fmt"
    "github.com/yalue/onnxruntime"
    "os"
)

func main() {
    // This line may be optional; by default the library will try to load
    // "onnxruntime.dll" on Windows, and "onnxruntime.so" on any other system.
    onnxruntime.SetSharedLibraryPath("path/to/onnxruntime.so")

    err := onnxruntime.InitializeEnvironment()
    if err != nil {
        fmt.Printf("Failed initializing onnxruntime: %s\n", err)
        os.Exit(1)
    }
    defer onnxruntime.CleanupEnvironment()

    // We'll assume that network.onnx takes a single 1x2x3 input tensor and
    // produces a 1x2x2 output tensor.
    inputShape := []int64{1, 2, 3}
    outputShape := []int64{1, 2, 2}
    session, err := onnxruntime.CreateSimpleSession("path/to/network.onnx",
        inputShape, outputShape)
    if err != nil {
        fmt.Printf("Error creating session: %s\n", err)
        os.Exit(1)
    }
    defer session.Destroy()

    // Network inputs must be provided as flattened slices of floats, so the
    // 1x2x3 input here contains 6 values. Run() can be called as many times
    // as necessary with a single session.
    err = session.Run([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5})
    if err != nil {
        fmt.Printf("Error running the network: %s\n", err)
        os.Exit(1)
    }

    // This will be a flattened slice containing the elements in the 1x2x2
    // output tensor.
    results := session.Results()

    // ...
}
```

Full Documentation
------------------

The above example uses a single input and produces a single output, all with
`float32` data. The `CreateSimpleSession` function supports this, as it is
expected to be a common use case. However, the library supports far more
options, e.g. using the `CreateSession` function when setting up a session.

The full documentation can be found at [pkg.go.dev](https://pkg.go.dev/github.com/yalue/onnxruntime).
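Since `Run()` takes and `Results()` returns flattened slices, recovering a
specific element means computing a row-major offset from the tensor shape. A
minimal sketch of that arithmetic; the `flatIndex` helper is hypothetical and
not part of the wrapper's API:

```go
package main

import "fmt"

// flatIndex converts row-major multi-dimensional indices into an offset in a
// flattened slice. For shape [1, 2, 2], index (0, i, j) maps to i*2 + j.
func flatIndex(shape, indices []int64) int64 {
	offset := int64(0)
	for d := range shape {
		offset = offset*shape[d] + indices[d]
	}
	return offset
}

func main() {
	outputShape := []int64{1, 2, 2}
	// Stand-in values for what session.Results() would return.
	results := []float32{1.0, 2.0, 3.0, 4.0}
	// Element at position (0, 1, 0) of the 1x2x2 output tensor:
	fmt.Println(results[flatIndex(outputShape, []int64{0, 1, 0})]) // prints 3
}
```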
onnx_example_application/README.md (new file, 7 lines)
@@ -0,0 +1,7 @@
An Example Application Using the `onnxruntime` Go Wrapper
=========================================================

To run this application, navigate to this directory and compile it using
`go build`. Afterwards, run it using `./onnx_example_application` (or
`onnx_example_application.exe` on Windows).
onnx_example_application/example_network.onnx (new binary file)
Binary file not shown.
onnx_example_application/onnx_example_application.go (new file, 38 lines)
@@ -0,0 +1,38 @@
// This application loads a test ONNX network and executes it on some fixed
// data. It serves as an example of how to use the onnxruntime wrapper library.
package main

import (
	"fmt"
	"github.com/yalue/onnxruntime"
	"os"
	"runtime"
)

func run() int {
	if runtime.GOOS == "windows" {
		onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime.dll")
	} else {
		onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime.so")
	}
	e := onnxruntime.InitializeEnvironment()
	if e != nil {
		fmt.Printf("Error initializing the onnxruntime environment: %s\n", e)
		return 1
	}
	fmt.Printf("The onnxruntime environment initialized OK.\n")

	// Ordinarily, it is probably fine to call this using defer, but we do it
	// here just so we can print a status message after the cleanup completes.
	e = onnxruntime.CleanupEnvironment()
	if e != nil {
		fmt.Printf("Error cleaning up the environment: %s\n", e)
		return 1
	}
	fmt.Printf("The onnxruntime environment was cleaned up OK.\n")
	return 0
}

func main() {
	os.Exit(run())
}
onnxruntime.go (new file, 83 lines)
@@ -0,0 +1,83 @@
// This library wraps the C "onnxruntime" library maintained at
// https://github.com/microsoft/onnxruntime. It seeks to provide as simple an
// interface as possible to load and run ONNX-format neural networks from
// Go code.
package onnxruntime

import (
	"fmt"
	"unsafe"
)

// #cgo CFLAGS: -I${SRCDIR}/onnxruntime/include
//
// #include "onnxruntime_wrapper.h"
import "C"

// This string should be the path to onnxruntime.so, or onnxruntime.dll.
var onnxSharedLibraryPath string

// For simplicity, this library maintains a single ORT environment internally.
var ortEnv *C.OrtEnv

// Does two things: converts the given OrtStatus to a Go error, and releases
// the status. If the status is nil, this does nothing and returns nil.
func statusToError(status *C.OrtStatus) error {
	if status == nil {
		return nil
	}
	msg := C.GetErrorMessage(status)
	toReturn := C.GoString(msg)
	C.ReleaseOrtStatus(status)
	return fmt.Errorf("%s", toReturn)
}

// Use this function to set the path to the "onnxruntime.so" or
// "onnxruntime.dll" file. By default, it will be set to "onnxruntime.so"
// on non-Windows systems, and "onnxruntime.dll" on Windows. Users wishing to
// specify a particular location of this library must call this function prior
// to calling onnxruntime.InitializeEnvironment().
func SetSharedLibraryPath(path string) {
	onnxSharedLibraryPath = path
}

// Call this function to initialize the internal onnxruntime environment. If
// this doesn't return an error, the caller will be responsible for calling
// CleanupEnvironment to free the onnxruntime state when no longer needed.
func InitializeEnvironment() error {
	if ortEnv != nil {
		return fmt.Errorf("The onnxruntime has already been initialized")
	}
	// Do the windows- or linux-specific initialization first.
	e := platformInitializeEnvironment()
	if e != nil {
		return fmt.Errorf("Platform-specific initialization failed: %w", e)
	}

	name := C.CString("Golang onnxruntime environment")
	defer C.free(unsafe.Pointer(name))
	status := C.CreateOrtEnv(name, &ortEnv)
	if status != nil {
		return fmt.Errorf("Error creating ORT environment: %w",
			statusToError(status))
	}
	return nil
}

// Call this function to clean up the internal onnxruntime environment when it
// is no longer needed.
func CleanupEnvironment() error {
	var e error
	// TODO: Implement CleanupEnvironment
	// Prior to calling platformCleanup, we need to:
	// - Destroy the environment
	// - Destroy any remaining active sessions?

	// platformCleanup primarily unloads the library, so we need to call it
	// last, after everything else has been destroyed.
	e = platformCleanup()
	if e != nil {
		return fmt.Errorf("Platform-specific cleanup failed: %w", e)
	}
	return nil
}
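The doc comments above spell out the lifecycle contract: initialize at most
once, and once initialization succeeds the caller owns the cleanup. A minimal
sketch of that pattern, assuming the shared library is reachable at its
default path:

```go
package main

import (
	"fmt"

	"github.com/yalue/onnxruntime"
)

func main() {
	// InitializeEnvironment loads the shared library and creates the single
	// internal OrtEnv; calling it a second time without cleanup is an error.
	if e := onnxruntime.InitializeEnvironment(); e != nil {
		fmt.Printf("Initialization failed: %s\n", e)
		return
	}
	// Once initialization succeeds, the caller is responsible for cleanup.
	defer onnxruntime.CleanupEnvironment()

	// ... load and run networks here, once the library supports it ...
}
```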
onnxruntime/include/cpu_provider_factory.h (new file, 19 lines)
@@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "onnxruntime_c_api.h"

#ifdef __cplusplus
extern "C" {
#endif

/**
 * \param use_arena zero: false. non-zero: true.
 */
ORT_EXPORT
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
ORT_ALL_ARGS_NONNULL;

#ifdef __cplusplus
}
#endif
onnxruntime/include/environment.h (new file, 96 lines)
@@ -0,0 +1,96 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <atomic>
#include <memory>
#include "core/common/common.h"
#include "core/common/status.h"
#include "core/platform/threadpool.h"
#include "core/common/logging/logging.h"
#include "core/framework/allocator.h"

struct OrtThreadingOptions;
namespace onnxruntime {
/** TODO: remove this class
   Provides the runtime environment for onnxruntime.
   Create one instance for the duration of execution.
*/
class Environment {
 public:
  /**
     Create and initialize the runtime environment.
     @param logging manager instance that will enable per session logger output using
     session_options.session_logid as the logger id in messages.
     If nullptr, the default LoggingManager MUST have been created previously as it will be used
     for logging. This will use the default logger id in messages.
     See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works.
     @param tp_options optional set of parameters controlling the number of intra and inter op threads for the global
     threadpools.
     @param create_global_thread_pools determine if this function will create the global threadpools or not.
  */
  static Status Create(std::unique_ptr<logging::LoggingManager> logging_manager,
                       std::unique_ptr<Environment>& environment,
                       const OrtThreadingOptions* tp_options = nullptr,
                       bool create_global_thread_pools = false);

  logging::LoggingManager* GetLoggingManager() const {
    return logging_manager_.get();
  }

  void SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager) {
    logging_manager_ = std::move(logging_manager);
  }

  onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPool() const {
    return intra_op_thread_pool_.get();
  }

  onnxruntime::concurrency::ThreadPool* GetInterOpThreadPool() const {
    return inter_op_thread_pool_.get();
  }

  bool EnvCreatedWithGlobalThreadPools() const {
    return create_global_thread_pools_;
  }

  /**
   * Registers an allocator for sharing between multiple sessions.
   * Return an error if an allocator with the same OrtMemoryInfo is already registered.
   */
  Status RegisterAllocator(AllocatorPtr allocator);

  /**
   * Creates and registers an allocator for sharing between multiple sessions.
   * Return an error if an allocator with the same OrtMemoryInfo is already registered.
   */
  Status CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info, const OrtArenaCfg* arena_cfg = nullptr);

  /**
   * Returns the list of registered allocators in this env.
   */
  const std::vector<AllocatorPtr>& GetRegisteredSharedAllocators() const {
    return shared_allocators_;
  }

  /**
   * Removes registered allocator that was previously registered for sharing between multiple sessions.
   */
  Status UnregisterAllocator(const OrtMemoryInfo& mem_info);

  Environment() = default;

 private:
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment);
  Status Initialize(std::unique_ptr<logging::LoggingManager> logging_manager,
                    const OrtThreadingOptions* tp_options = nullptr,
                    bool create_global_thread_pools = false);

  std::unique_ptr<logging::LoggingManager> logging_manager_;
  std::unique_ptr<onnxruntime::concurrency::ThreadPool> intra_op_thread_pool_;
  std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;
  bool create_global_thread_pools_{false};
  std::vector<AllocatorPtr> shared_allocators_;
};
}  // namespace onnxruntime
onnxruntime/include/experimental_onnxruntime_cxx_api.h (new file, 69 lines)
@@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Summary: The experimental Ort C++ API is a wrapper around the Ort C++ API.
//
// This C++ API further simplifies usage and provides support for modern C++ syntax/features
// at the cost of some overhead (i.e. using std::string over char *).
//
// Where applicable, default memory allocator options are used unless explicitly set.
//
// Experimental components are designed as drop-in replacements of the regular API, requiring
// minimal code modifications to use.
//
// Example: Ort::Session -> Ort::Experimental::Session
//
// NOTE: Experimental API components are subject to change based on feedback and provide no
// guarantee of backwards compatibility in future releases.

#pragma once
#include "onnxruntime_cxx_api.h"

namespace Ort {
namespace Experimental {

struct Session : Ort::Session {
  Session(Env& env, std::basic_string<ORTCHAR_T>& model_path, SessionOptions& options)
      : Ort::Session(env, model_path.data(), options){};
  Session(Env& env, void* model_data, size_t model_data_length, SessionOptions& options)
      : Ort::Session(env, model_data, model_data_length, options){};

  // overloaded Run() with sensible defaults
  std::vector<Ort::Value> Run(const std::vector<std::string>& input_names,
                              const std::vector<Ort::Value>& input_values,
                              const std::vector<std::string>& output_names,
                              const RunOptions& run_options = RunOptions());
  void Run(const std::vector<std::string>& input_names,
           const std::vector<Ort::Value>& input_values,
           const std::vector<std::string>& output_names,
           std::vector<Ort::Value>& output_values,
           const RunOptions& run_options = RunOptions());

  // convenience methods that simplify common lower-level API calls
  std::vector<std::string> GetInputNames() const;
  std::vector<std::string> GetOutputNames() const;
  std::vector<std::string> GetOverridableInitializerNames() const;

  // NOTE: shape dimensions may have a negative value to indicate a symbolic/unknown dimension.
  std::vector<std::vector<int64_t> > GetInputShapes() const;
  std::vector<std::vector<int64_t> > GetOutputShapes() const;
  std::vector<std::vector<int64_t> > GetOverridableInitializerShapes() const;
};

struct Value : Ort::Value {
  Value(OrtValue* p)
      : Ort::Value(p){};

  template <typename T>
  static Ort::Value CreateTensor(T* p_data, size_t p_data_element_count, const std::vector<int64_t>& shape);
  static Ort::Value CreateTensor(void* p_data, size_t p_data_byte_count, const std::vector<int64_t>& shape, ONNXTensorElementDataType type);

  template <typename T>
  static Ort::Value CreateTensor(const std::vector<int64_t>& shape);
  static Ort::Value CreateTensor(const std::vector<int64_t>& shape, ONNXTensorElementDataType type);
};

}
}

#include "experimental_onnxruntime_cxx_inline.h"
onnxruntime/include/experimental_onnxruntime_cxx_inline.h (new file, 108 lines)
@@ -0,0 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Do not include this file directly. Please include "experimental_onnxruntime_cxx_api.h" instead.
//
// These are the inline implementations of the C++ header APIs. They are in this separate file as to not clutter
// the main C++ file with implementation details.
namespace Ort {
namespace Experimental {

inline std::vector<Ort::Value> Session::Run(const std::vector<std::string>& input_names, const std::vector<Ort::Value>& input_values,
                                            const std::vector<std::string>& output_names, const RunOptions& run_options) {
  size_t output_names_count = GetOutputNames().size();
  std::vector<Ort::Value> output_values;
  for (size_t i = 0; i < output_names_count; i++) output_values.emplace_back(nullptr);
  Run(input_names, input_values, output_names, output_values, run_options);
  return output_values;
}

inline void Session::Run(const std::vector<std::string>& input_names, const std::vector<Ort::Value>& input_values,
                         const std::vector<std::string>& output_names, std::vector<Ort::Value>& output_values, const RunOptions& run_options) {
  size_t input_names_count = input_names.size();
  size_t output_names_count = output_names.size();
  std::vector<const char*> input_names_(input_names_count, nullptr);
  size_t i = 0;
  for (auto it = input_names.begin(); it != input_names.end(); it++) input_names_[i++] = (*it).c_str();
  std::vector<const char*> output_names_(output_names_count, nullptr);
  i = 0;
  for (auto it = output_names.begin(); it != output_names.end(); it++) output_names_[i++] = (*it).c_str();
  Ort::Session::Run(run_options, input_names_.data(), input_values.data(), input_names_count, output_names_.data(), output_values.data(), output_names_count);
}

inline std::vector<std::string> Session::GetInputNames() const {
  Ort::AllocatorWithDefaultOptions allocator;
  size_t node_count = GetInputCount();
  std::vector<std::string> out(node_count);
  for (size_t i = 0; i < node_count; i++) {
    auto tmp = GetInputNameAllocated(i, allocator);
    out[i] = tmp.get();
  }
  return out;
}

inline std::vector<std::string> Session::GetOutputNames() const {
  Ort::AllocatorWithDefaultOptions allocator;
  size_t node_count = GetOutputCount();
  std::vector<std::string> out(node_count);
  for (size_t i = 0; i < node_count; i++) {
    auto tmp = GetOutputNameAllocated(i, allocator);
    out[i] = tmp.get();
  }
  return out;
}

inline std::vector<std::string> Session::GetOverridableInitializerNames() const {
  Ort::AllocatorWithDefaultOptions allocator;
  size_t init_count = GetOverridableInitializerCount();
  std::vector<std::string> out(init_count);
  for (size_t i = 0; i < init_count; i++) {
    auto tmp = GetOverridableInitializerNameAllocated(i, allocator);
    out[i] = tmp.get();
  }
  return out;
}

inline std::vector<std::vector<int64_t>> Session::GetInputShapes() const {
  size_t node_count = GetInputCount();
  std::vector<std::vector<int64_t>> out(node_count);
  for (size_t i = 0; i < node_count; i++) out[i] = GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
  return out;
}

inline std::vector<std::vector<int64_t>> Session::GetOutputShapes() const {
  size_t node_count = GetOutputCount();
  std::vector<std::vector<int64_t>> out(node_count);
  for (size_t i = 0; i < node_count; i++) out[i] = GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
  return out;
}

inline std::vector<std::vector<int64_t>> Session::GetOverridableInitializerShapes() const {
  size_t init_count = GetOverridableInitializerCount();
  std::vector<std::vector<int64_t>> out(init_count);
  for (size_t i = 0; i < init_count; i++) out[i] = GetOverridableInitializerTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
  return out;
}

template <typename T>
inline Ort::Value Value::CreateTensor(T* p_data, size_t p_data_element_count, const std::vector<int64_t>& shape) {
  return CreateTensor(p_data, p_data_element_count * sizeof(T), shape, TypeToTensorType<T>::type);
}

inline Ort::Value Value::CreateTensor(void* p_data, size_t p_data_byte_count, const std::vector<int64_t>& shape, ONNXTensorElementDataType type) {
  Ort::MemoryInfo info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  return Ort::Value::CreateTensor(info, p_data, p_data_byte_count, shape.data(), shape.size(), type);
}

template <typename T>
inline Ort::Value Value::CreateTensor(const std::vector<int64_t>& shape) {
  return CreateTensor(shape, TypeToTensorType<T>::type);
}

inline Ort::Value Value::CreateTensor(const std::vector<int64_t>& shape, ONNXTensorElementDataType type) {
  Ort::AllocatorWithDefaultOptions allocator;
  return Ort::Value::CreateTensor(allocator, shape.data(), shape.size(), type);
}

}
}
onnxruntime/include/onnxruntime_c_api.h (new file, 3986 lines)
File diff suppressed because it is too large.

onnxruntime/include/onnxruntime_cxx_api.h (new file, 1876 lines)
File diff suppressed because it is too large.

onnxruntime/include/onnxruntime_cxx_inline.h (new file, 1874 lines)
File diff suppressed because it is too large.
onnxruntime/include/onnxruntime_run_options_config_keys.h (new file, 27 lines)
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

/*
 * This file defines RunOptions Config Keys and format of the Config Values.
 *
 * The Naming Convention for a RunOptions Config Key,
 * "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
 * Such as "ep.cuda.use_arena"
 * The Config Key cannot be empty
 * The maximum length of the Config Key is 128
 *
 * The string format of a RunOptions Config Value is defined individually for each Config.
 * The maximum length of the Config Value is 1024
 */

// Key for enabling shrinkages of user listed device memory arenas.
// Expects a list of semi-colon separated key value pairs separated by colon in the following format:
// "device_0:device_id_0;device_1:device_id_1"
// No white-spaces allowed in the provided list string.
// Currently, the only supported devices are : "cpu", "gpu" (case sensitive).
// If "cpu" is included in the list, DisableCpuMemArena() API must not be called (i.e.) arena for cpu should be enabled.
// Example usage: "cpu:0;gpu:0" (or) "gpu:0"
// By default, the value for this key is empty (i.e.) no memory arenas are shrunk
static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage";
onnxruntime/include/onnxruntime_session_options_config_keys.h (new file, 186 lines)
@@ -0,0 +1,186 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

/*
 * This file defines SessionOptions Config Keys and format of the Config Values.
 *
 * The Naming Convention for a SessionOptions Config Key,
 * "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
 * Such as "ep.cuda.use_arena"
 * The Config Key cannot be empty
 * The maximum length of the Config Key is 128
 *
 * The string format of a SessionOptions Config Value is defined individually for each Config.
 * The maximum length of the Config Value is 1024
 */

// Key for disable PrePacking,
// If the config value is set to "1" then the prepacking is disabled, otherwise prepacking is enabled (default value)
static const char* const kOrtSessionOptionsConfigDisablePrepacking = "session.disable_prepacking";

// A value of "1" means allocators registered in the env will be used. "0" means the allocators created in the session
// will be used. Use this to override the usage of env allocators on a per session level.
static const char* const kOrtSessionOptionsConfigUseEnvAllocators = "session.use_env_allocators";

// Set to 'ORT' (case sensitive) to load an ORT format model.
// If unset, model type will default to ONNX unless inferred from filename ('.ort' == ORT format) or bytes to be ORT
static const char* const kOrtSessionOptionsConfigLoadModelFormat = "session.load_model_format";

// Set to 'ORT' (case sensitive) to save optimized model in ORT format when SessionOptions.optimized_model_path is set.
// If unset, format will default to ONNX unless optimized_model_filepath ends in '.ort'.
static const char* const kOrtSessionOptionsConfigSaveModelFormat = "session.save_model_format";

// If a value is "1", flush-to-zero and denormal-as-zero are applied. The default is "0".
// When multiple sessions are created, a main thread doesn't override changes from succeeding session options,
// but threads in session thread pools follow option changes.
// When ORT runs with OpenMP, the same rule is applied, i.e. the first session option to flush-to-zero and
// denormal-as-zero is only applied to global OpenMP thread pool, which doesn't support per-session thread pool.
// Note that an alternative way not using this option at runtime is to train and export a model without denormals
// and that's recommended because turning this option on may hurt model accuracy.
static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.set_denormal_as_zero";

// It controls to run quantization model in QDQ (QuantizelinearDeQuantizelinear) format or not.
// "0": enable. ORT does fusion logic for QDQ format.
// "1": disable. ORT doesn't do fusion logic for QDQ format.
// Its default value is "0"
static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq";

// It controls whether to enable Double QDQ remover and Identical Children Consolidation
// "0": not to disable. ORT does remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs
// "1": disable. ORT doesn't remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs
// Its default value is "0"
static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover";

// If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been
// completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the
// Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to
// 8-bit and back to float, but could impact accuracy. The impact on accuracy will be model specific and depend on
// other factors like whether the model was created using Quantization Aware Training or Post Training Quantization.
// As such, it's best to test to determine if enabling this works well for your scenario.
// The default value is "0"
// Available since version 1.11.
static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup";

// Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0".
// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";

#ifdef ENABLE_TRAINING
// Specifies a list of op types for memory footprint reduction.
// The value should be a ","-delimited list of pair of
// <subgraph string : optimization strategy : number of subgraph to apply>.
// For example, "Gelu+Cast+:1:0,Dropout+:1:1".
// A valid "subgraph string" should be one subgraph representation output by ORT graph transformations.
// "optimization strategy" currently has valid values: 0 - disabled, 1 - recompute.
// "number of subgraph to apply" is used to control how many subgraphs to apply optimization, to avoid "oversaving"
// the memory.
static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.enable_memory_optimizer";

// Specifies the level for detecting subgraphs for memory footprint reduction.
// The value should be an integer. The default value is 0.
static const char* const kOrtSessionOptionsMemoryOptimizerProbeLevel = "optimization.enable_memory_probe_recompute_level";
#endif

// Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0".
// Using device allocators means the memory allocation is made using malloc/new.
static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers";

// Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking
// "0": thread will block if found no job to run
// "1": default, thread will spin a number of times before blocking
static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning";
static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning";

// Key for using model bytes directly for ORT format
// If a session is created using an input byte array contains the ORT format model data,
// By default we will copy the model bytes at the time of session creation to ensure the model bytes
// buffer is valid.
// Setting this option to "1" will disable copy the model bytes, and use the model bytes directly. The caller
// has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed.
static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly";

/// <summary>
/// Key for using the ORT format model flatbuffer bytes directly for initializers.
/// This avoids copying the bytes and reduces peak memory usage during model loading and initialization.
/// Requires `session.use_ort_model_bytes_directly` to be true.
/// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire
/// duration of the InferenceSession.
/// </summary>
static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers =
    "session.use_ort_model_bytes_for_initializers";

// This should only be specified when exporting an ORT format model for use on a different platform.
// If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0"
// Available since version 1.11.
static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed";

// x64 SSE4.1/AVX2/AVX512(with no VNNI) has overflow problem with quantizied matrix multiplication with U8S8.
// To avoid this we need to use slower U8U8 matrix multiplication instead. This option, if
// turned on, use slower U8U8 matrix multiplications. Only effective with AVX2 or AVX512
// platforms.
static const char* const kOrtSessionOptionsAvx2PrecisionMode = "session.x64quantprecision";

// Specifies how minimal build graph optimizations are handled in a full build.
// These optimizations are at the extended level or higher.
// Possible values and their effects are:
// "save": Save runtime optimizations when saving an ORT format model.
// "apply": Only apply optimizations available in a minimal build.
// ""/<unspecified>: Apply optimizations available in a full build.
// Available since version 1.11.
static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations =
    "optimization.minimal_build_optimizations";

// Note: The options specific to an EP should be specified prior to appending that EP to the session options object in
// order for them to take effect.

// Specifies a list of stop op types. Nodes of a type in the stop op types and nodes downstream from them will not be
// run by the NNAPI EP.
// The value should be a ","-delimited list of op types. For example, "Add,Sub".
// If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op
// exclusion, set the value to "".
static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops";

// Enabling dynamic block-sizing for multithreading.
// With a positive value, thread pool will split a task of N iterations to blocks of size starting from:
// N / (num_of_threads * dynamic_block_base)
// As execution progresses, the size will decrease according to the diminishing residual of N,
// meaning the task will be distributed in smaller granularity for better parallelism.
// For some models, it helps to reduce the variance of E2E inference latency and boost performance.
// The feature will not function by default, specify any positive integer, e.g. "4", to enable it.
// Available since version 1.11.
static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base";

// This option allows to decrease CPU usage between infrequent
// requests and forces any TP threads spinning stop immediately when the last of
// concurrent Run() call returns.
// Spinning is restarted on the next Run() call.
// Applies only to internal thread-pools
static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop";

// "1": all inconsistencies encountered during shape and type inference
// will result in failures.
// "0": in some cases warnings will be logged but processing will continue. The default.
// May be useful to expose bugs in models.
static const char* const kOrtSessionOptionsConfigStrictShapeTypeInference = "session.strict_shape_type_inference";

// The file saves configuration for partitioning node among logic streams
static const char* const kNodePartitionConfigFile = "session.node_partition_config_file";

// This Option allows setting affinities for intra op threads.
// Affinity string follows format:
// logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id
// Semicolon isolates configurations among threads, while comma split processors where ith thread expected to attach to.
// e.g. 1,2,3;4,5
// specifies affinities for two threads, with the 1st thread attach to the 1st, 2nd, and 3rd processor, and 2nd thread to the 4th and 5th.
// To ease the configuration, an "interval" is also allowed:
// e.g. 1-8;8-16;17-24
// orders that the 1st thread runs on first eight processors, 2nd thread runs on next eight processors, and so forth.
// Note:
// 1. Once set, the number of thread affinities must equal to intra_op_num_threads - 1, since ort does not set affinity on the main thread which
// is started and managed by the calling app;
// 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups with each has 64 logical processors,
// an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group.
// Hence 64-65 is an invalid configuration, because a windows thread cannot be attached to processors across group boundary.
static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "session.intra_op_thread_affinities";
onnxruntime/include/provider_options.h (new file, 18 lines)
@@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <string>
#include <unordered_map>
#include <vector>

namespace onnxruntime {

// data types for execution provider options

using ProviderOptions = std::unordered_map<std::string, std::string>;
using ProviderOptionsVector = std::vector<ProviderOptions>;
using ProviderOptionsMap = std::unordered_map<std::string, ProviderOptions>;

}  // namespace onnxruntime
onnxruntime/include/tensorrt_provider_factory.h (new file, 14 lines)
@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "onnxruntime_c_api.h"

#ifdef __cplusplus
extern "C" {
#endif

ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id);

#ifdef __cplusplus
}
#endif
onnxruntime/include/tensorrt_provider_options.h (new file, 34 lines)
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

/// <summary>
/// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V2.
/// Please note that this struct is *similar* to OrtTensorRTProviderOptions but only to be used internally.
/// Going forward, new trt provider options are to be supported via this struct and usage of the publicly defined
/// OrtTensorRTProviderOptions will be deprecated over time.
/// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
/// </summary>
struct OrtTensorRTProviderOptionsV2 {
  int device_id;                                // cuda device id.
  int has_user_compute_stream;                  // indicator of user specified CUDA compute stream.
  void* user_compute_stream;                    // user specified CUDA compute stream.
  int trt_max_partition_iterations;             // maximum iterations for TensorRT parser to get capability
  int trt_min_subgraph_size;                    // minimum size of TensorRT subgraphs
  size_t trt_max_workspace_size;                // maximum workspace size for TensorRT.
  int trt_fp16_enable;                          // enable TensorRT FP16 precision. Default 0 = false, nonzero = true
  int trt_int8_enable;                          // enable TensorRT INT8 precision. Default 0 = false, nonzero = true
  const char* trt_int8_calibration_table_name;  // TensorRT INT8 calibration table name.
  int trt_int8_use_native_calibration_table;    // use native TensorRT generated calibration table. Default 0 = false, nonzero = true
  int trt_dla_enable;                           // enable DLA. Default 0 = false, nonzero = true
  int trt_dla_core;                             // DLA core number. Default 0
  int trt_dump_subgraphs;                       // dump TRT subgraph. Default 0 = false, nonzero = true
  int trt_engine_cache_enable;                  // enable engine caching. Default 0 = false, nonzero = true
  const char* trt_engine_cache_path;            // specify engine cache path
  int trt_engine_decryption_enable;             // enable engine decryption. Default 0 = false, nonzero = true
  const char* trt_engine_decryption_lib_path;   // specify engine decryption library path
  int trt_force_sequential_engine_build;        // force building TensorRT engine sequentially. Default 0 = false, nonzero = true
  int trt_context_memory_sharing_enable;        // enable context memory sharing between subgraphs. Default 0 = false, nonzero = true
  int trt_layer_norm_fp32_fallback;             // force Pow + Reduce ops in layer norm to FP32. Default 0 = false, nonzero = true
};
onnxruntime/onnxruntime_LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) Microsoft Corporation

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
onnxruntime_test.go (new file, 66 lines)
@@ -0,0 +1,66 @@
package onnxruntime

import (
	"encoding/json"
	"os"
	"runtime"
	"testing"
)

// This type is read from JSON and used to determine the inputs and expected
// outputs for an ONNX network.
type testInputsInfo struct {
	InputShape      []int     `json:"input_shape"`
	FlattenedInput  []float32 `json:"flattened_input"`
	OutputShape     []int     `json:"output_shape"`
	FlattenedOutput []float32 `json:"flattened_output"`
}

// This must be called prior to running each test.
func InitializeRuntime(t *testing.T) {
	if runtime.GOOS == "windows" {
		SetSharedLibraryPath("test_data/onnxruntime.dll")
	} else {
		if runtime.GOARCH == "arm64" {
			SetSharedLibraryPath("test_data/onnxruntime_arm64.so")
		} else {
			SetSharedLibraryPath("test_data/onnxruntime.so")
		}
	}
	e := InitializeEnvironment()
	if e != nil {
		t.Logf("Failed setting up onnxruntime environment: %s\n", e)
		t.FailNow()
	}
}

// Used to obtain the input and expected output values for a test network
// from the JSON file at the given path.
func parseInputsJSON(path string, t *testing.T) *testInputsInfo {
	toReturn := testInputsInfo{}
	f, e := os.Open(path)
	if e != nil {
		t.Logf("Failed opening %s: %s\n", path, e)
		t.FailNow()
	}
	defer f.Close()
	d := json.NewDecoder(f)
	e = d.Decode(&toReturn)
	if e != nil {
		t.Logf("Failed decoding %s: %s\n", path, e)
		t.FailNow()
	}
	return &toReturn
}

func TestExampleNetwork(t *testing.T) {
	InitializeRuntime(t)
	_ = parseInputsJSON("test_data/example_network_results.json", t)

	// TODO: More tests here to run the network, once that's supported.

	e := CleanupEnvironment()
	if e != nil {
		t.Logf("Failed cleaning up the environment: %s\n", e)
		t.FailNow()
	}
}
onnxruntime_wrapper.c (new file, 27 lines)
@@ -0,0 +1,27 @@
#include "onnxruntime_wrapper.h"

static const OrtApi *ort_api = NULL;

int SetAPIFromBase(OrtApiBase *api_base) {
  if (!api_base) return 1;
  ort_api = api_base->GetApi(ORT_API_VERSION);
  if (!ort_api) return 2;
  return 0;
}

void ReleaseOrtStatus(OrtStatus *status) {
  ort_api->ReleaseStatus(status);
}

void ReleaseOrtEnv(OrtEnv *env) {
  ort_api->ReleaseEnv(env);
}

OrtStatus *CreateOrtEnv(char *name, OrtEnv **env) {
  return ort_api->CreateEnv(ORT_LOGGING_LEVEL_ERROR, name, env);
}

const char *GetErrorMessage(OrtStatus *status) {
  if (!status) return "No error (NULL status)";
  return ort_api->GetErrorMessage(status);
}
onnxruntime_wrapper.h (new file, 44 lines)
@@ -0,0 +1,44 @@
#ifndef ONNXRUNTIME_WRAPPER_H
#define ONNXRUNTIME_WRAPPER_H

// We want to always use the unix-like onnxruntime C APIs, even on Windows, so
// we need to undefine _WIN32 before including onnxruntime_c_api.h. However,
// this requires a careful song-and-dance.

// First, include these common headers, as they get transitively included by
// onnxruntime_c_api.h. We need to include them ourselves, first, so that the
// preprocessor will skip them while _WIN32 is undefined.
#include <stdio.h>
#include <stdlib.h>

// Next, we actually include the header.
#undef _WIN32
#include "onnxruntime_c_api.h"

// ... However, mingw will complain if _WIN32 is *not* defined! So redefine it.
#define _WIN32

#ifdef __cplusplus
extern "C" {
#endif

// Takes a pointer to the api_base struct in order to obtain the OrtApi
// pointer. Intended to be called from Go. Returns nonzero on error.
int SetAPIFromBase(OrtApiBase *api_base);

// Wraps calling ort_api->ReleaseStatus(status)
void ReleaseOrtStatus(OrtStatus *status);

// Wraps calling ort_api->CreateEnv. Returns a non-NULL status on error.
OrtStatus *CreateOrtEnv(char *name, OrtEnv **env);

// Releases the given OrtEnv.
void ReleaseOrtEnv(OrtEnv *env);

// Returns the message associated with the given ORT status.
const char *GetErrorMessage(OrtStatus *status);

#ifdef __cplusplus
}  // extern "C"
#endif
#endif  // ONNXRUNTIME_WRAPPER_H
setup_env.go (new file, 69 lines)
@@ -0,0 +1,69 @@
//go:build !windows

package onnxruntime

import (
	"fmt"
	"unsafe"
)

/*
#cgo LDFLAGS: -ldl

#include <dlfcn.h>
#include "onnxruntime_wrapper.h"

typedef OrtApiBase* (*GetOrtApiBaseFunction)(void);

// Since Go can't call C function pointers directly, we just use this helper
// when calling GetApiBase
OrtApiBase *CallGetAPIBaseFunction(void *fn) {
	OrtApiBase *to_return = ((GetOrtApiBaseFunction) fn)();
	return to_return;
}
*/
import "C"

// This file includes the code for loading the onnxruntime and setting up the
// environment on non-Windows systems. For now, it has only been tested on
// Linux.

// This will contain the handle to the onnxruntime shared library if it has
// been loaded successfully.
var libraryHandle unsafe.Pointer

func platformCleanup() error {
	v, e := C.dlclose(libraryHandle)
	if v != 0 {
		return fmt.Errorf("Error closing the library: %w", e)
	}
	return nil
}

func platformInitializeEnvironment() error {
	cName := C.CString(onnxSharedLibraryPath)
	defer C.free(unsafe.Pointer(cName))
	handle := C.dlopen(cName, C.RTLD_LAZY)
	if handle == nil {
		msg := C.GoString(C.dlerror())
		return fmt.Errorf("Error loading ONNX shared library \"%s\": %s",
			onnxSharedLibraryPath, msg)
	}
	cFunctionName := C.CString("OrtGetApiBase")
	defer C.free(unsafe.Pointer(cFunctionName))
	getAPIBaseProc := C.dlsym(handle, cFunctionName)
	if getAPIBaseProc == nil {
		// Read dlerror() before dlclose(), which may overwrite it.
		msg := C.GoString(C.dlerror())
		C.dlclose(handle)
		return fmt.Errorf("Error looking up OrtGetApiBase in \"%s\": %s",
			onnxSharedLibraryPath, msg)
	}
	ortAPIBase := C.CallGetAPIBaseFunction(getAPIBaseProc)
	tmp := C.SetAPIFromBase((*C.OrtApiBase)(unsafe.Pointer(ortAPIBase)))
	if tmp != 0 {
		C.dlclose(handle)
		return fmt.Errorf("Error setting ORT API base: %d", tmp)
	}
	libraryHandle = handle
	return nil
}
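The `CallGetAPIBaseFunction` helper above is an instance of a general cgo
pattern: Go code cannot invoke a C function pointer directly, so a small C
trampoline performs the call. A self-contained sketch of the same pattern,
using a hypothetical `add_impl` function in place of anything from
`onnxruntime`:

```go
package main

/*
// A function pointer type and a trampoline that invokes it. Go can call a
// plain C function, and that function can call the pointer for us.
typedef int (*adder_fn)(int, int);

static int add_impl(int a, int b) { return a + b; }

// Returns a function pointer as void*, standing in for what dlsym() returns.
static void* get_adder(void) { return (void*) add_impl; }

static int call_adder(void *fn, int a, int b) {
	return ((adder_fn) fn)(a, b);
}
*/
import "C"

import "fmt"

func main() {
	fn := C.get_adder()
	// The trampoline is the only way to invoke the pointer from Go.
	fmt.Println(C.call_adder(fn, 2, 3)) // prints 5
}
```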
setup_env_windows.go (new file, 55 lines)
@@ -0,0 +1,55 @@
//go:build windows

package onnxruntime

// This file includes the Windows-specific code for loading the onnxruntime
// library and setting up the environment.

import (
	"fmt"
	"syscall"
	"unsafe"
)

// #include "onnxruntime_wrapper.h"
import "C"

// This will contain the handle to the onnxruntime dll if it has been loaded
// successfully.
var libraryHandle syscall.Handle

func platformCleanup() error {
	e := syscall.FreeLibrary(libraryHandle)
	libraryHandle = 0
	return e
}

func platformInitializeEnvironment() error {
	handle, e := syscall.LoadLibrary(onnxSharedLibraryPath)
	if e != nil {
		return fmt.Errorf("Error loading ONNX shared library \"%s\": %w",
			onnxSharedLibraryPath, e)
	}
	getApiBaseProc, e := syscall.GetProcAddress(handle, "OrtGetApiBase")
	if e != nil {
		syscall.FreeLibrary(handle)
		return fmt.Errorf("Error finding OrtGetApiBase function in %s: %w",
			onnxSharedLibraryPath, e)
	}
	ortApiBase, _, errno := syscall.SyscallN(uintptr(getApiBaseProc), 0)
	if ortApiBase == 0 {
		syscall.FreeLibrary(handle)
		if errno != 0 {
			return fmt.Errorf("Error calling OrtGetApiBase: %w", errno)
		}
		return fmt.Errorf("Error calling OrtGetApiBase")
	}
	tmp := C.SetAPIFromBase((*C.OrtApiBase)(unsafe.Pointer(ortApiBase)))
	if tmp != 0 {
		syscall.FreeLibrary(handle)
		return fmt.Errorf("Error setting ORT API base: %d", tmp)
	}
	libraryHandle = handle
	return nil
}
test_data/example_network.onnx (new binary file)
Binary file not shown.
test_data/example_network_results.json (new file, 22 lines)
@@ -0,0 +1,22 @@
{
  "input_shape": [
    1,
    1,
    4
  ],
  "flattened_input": [
    0.19473445415496826,
    0.9139836430549622,
    0.7043011784553528,
    0.7685686945915222
  ],
  "output_shape": [
    1,
    1,
    2
  ],
  "flattened_output": [
    2.581585645675659,
    0.6283518075942993
  ]
}
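This file is only self-consistent if each flattened slice's length equals the
product of its shape's dimensions (1*1*4 = 4 input values and 1*1*2 = 2 output
values here). A sketch of the check a test could apply; both helpers are
hypothetical, and the snippet assumes the `testInputsInfo` type and `fmt`
import from `onnxruntime_test.go`:

```go
// elementCount returns the number of elements implied by a tensor shape,
// e.g. [1, 1, 4] -> 4.
func elementCount(shape []int) int {
	n := 1
	for _, d := range shape {
		n *= d
	}
	return n
}

// validateTestInputs checks that the flattened slices in a parsed
// testInputsInfo match the products of their shapes.
func validateTestInputs(info *testInputsInfo) error {
	if len(info.FlattenedInput) != elementCount(info.InputShape) {
		return fmt.Errorf("input has %d values; shape %v implies %d",
			len(info.FlattenedInput), info.InputShape,
			elementCount(info.InputShape))
	}
	if len(info.FlattenedOutput) != elementCount(info.OutputShape) {
		return fmt.Errorf("output has %d values; shape %v implies %d",
			len(info.FlattenedOutput), info.OutputShape,
			elementCount(info.OutputShape))
	}
	return nil
}
```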
test_data/generate_network.py (new file, 131 lines)
@@ -0,0 +1,131 @@
# This script sets up and "trains" a toy pytorch network that maps a 1x4
# vector to a 1x2 vector containing [sum, max difference] of the input
# values. Finally, it exports the network to an ONNX file to use in testing.
import torch
from torch.nn.functional import relu
import json

def fake_dataset(size):
    """ Returns a dataset filled with our fake training data. """
    inputs = torch.rand((size, 1, 4))
    outputs = torch.zeros((size, 1, 2))
    for i in range(size):
        outputs[i][0][0] = inputs[i][0].sum()
        outputs[i][0][1] = inputs[i][0].max() - inputs[i][0].min()
    return torch.utils.data.TensorDataset(inputs, outputs)

class SumAndDiffModel(torch.nn.Module):
    """ Just a standard, fairly minimal, pytorch model for generating the NN.
    """
    def __init__(self):
        super().__init__()
        # We'll do four 1x4 convolutions to make the network more interesting.
        self.conv = torch.nn.Conv1d(1, 4, 4)
        # We'll follow the conv with a FC layer to produce the outputs. The
        # input to the FC layer are the 4 conv outputs concatenated with the
        # original input.
        self.fc = torch.nn.Linear(8, 2)

    def forward(self, data):
        batch_size = len(data)
        conv_out = relu(self.conv(data))
        conv_flattened = torch.flatten(conv_out, start_dim=1)
        data_flattened = torch.flatten(data, start_dim=1)
        combined = torch.cat((conv_flattened, data_flattened), dim=1)
        output = relu(self.fc(combined))
        output = output.view(batch_size, 1, 2)
        return output

def get_test_loss(model, loader, loss_function):
    """ Just runs a single epoch of data from the given loader. Returns the
    average loss per batch. The provided model is expected to be in eval mode.
    """
    i = 0
    total_loss = 0.0
    for in_data, desired_result in loader:
        produced_result = model(in_data)
        loss = loss_function(desired_result, produced_result)
        total_loss += loss.item()
        i += 1
    return total_loss / i

def save_model(model, output_filename):
    """ Saves the model to an onnx file with the given name. Assumes the model
    is in eval mode. """
    print("Saving network to " + output_filename)
    dummy_input = torch.rand(1, 1, 4)
    input_names = ["1x4 Input Vector"]
    output_names = ["1x2 Output Vector"]
    torch.onnx.export(model, dummy_input, output_filename,
        input_names=input_names, output_names=output_names)
    return None

def print_sample(model):
    """ Prints a sample input and output computation using the model. Expects
    the model to be in eval mode. """
    example_input = torch.rand(1, 1, 4)
    result = model(example_input)
    print("Sample model execution:")
    print("  Example input: " + str(example_input))
    print("  Produced output: " + str(result))
    return None

def save_sample_json(model, output_name):
    """ Saves a JSON file containing an input and an output from the network,
    for use when validating execution of the ONNX network. """
    example_input = torch.rand(1, 1, 4)
    result = model(example_input)
    json_content = {}
    json_content["input_shape"] = list(example_input.shape)
    json_content["flattened_input"] = list(example_input.flatten().tolist())
    json_content["output_shape"] = list(result.shape)
    json_content["flattened_output"] = list(result.flatten().tolist())
    with open(output_name, "w") as f:
        json.dump(json_content, f, indent=" ")
    return None

def main():
    print("Generating train and test datasets...")
    train_data = fake_dataset(100 * 1000)
    train_loader = torch.utils.data.DataLoader(dataset=train_data,
        batch_size=16, shuffle=True)
    test_data = fake_dataset(10 * 1000)
    test_loader = torch.utils.data.DataLoader(dataset=test_data,
        batch_size=16)
    model = SumAndDiffModel()
    model.train()
    loss_function = torch.nn.L1Loss(reduction="mean")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.00005)
    for epoch in range(8):
        i = 0
        total_loss = 0.0
        for in_data, desired_result in train_loader:
            i += 1
            produced_result = model(in_data)
            loss = loss_function(desired_result, produced_result)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()
            if (i % 1000) == 1:
                print("Epoch %d, iteration %d. Current loss = %f" % (epoch, i,
                    loss.item()))
        train_loss = total_loss / i
        print(" => Average train-set loss: " + str(train_loss))
        model.eval()
        with torch.no_grad():
            test_loss = get_test_loss(model, test_loader, loss_function)
        model.train()
        print(" => Average test-set loss: " + str(test_loss))

    model.eval()
    with torch.no_grad():
        save_model(model, "example_network.onnx")
        save_sample_json(model, "example_network_results.json")
        print_sample(model)
    print("Done!")

if __name__ == "__main__":
    main()
test_data/onnxruntime.dll (new binary file)
Binary file not shown.