Initial commit

- Sets up the onnxruntime environment, but doesn't load or run networks yet.

 - The thing builds and runs on Windows.

 - Still working on getting the Linux (arm64 for now) test to work.
yalue
2023-01-28 13:47:09 -05:00
commit 75e3434f7e
30 changed files with 8993 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
onnx_example_application/onnx_example_application
onnx_example_application/onnx_example_application.exe

19
LICENSE Normal file
View File

@@ -0,0 +1,19 @@
Copyright (c) 2023 Nathan Otterness
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

98
README.md Normal file
View File

@@ -0,0 +1,98 @@
Cross-Platform `onnxruntime` Wrapper for Go
===========================================
About
-----
This library seeks to provide an interface for loading and executing neural
networks from Go(lang) code, while remaining as simple to use as possible.
The [onnxruntime](https://github.com/microsoft/onnxruntime) library provides a
way to load and execute ONNX-format neural networks, though the library
primarily supports C and C++ APIs. Several other projects provide Go(lang)
wrappers for the `onnxruntime` library, but as far as I can tell, none of
them support Windows. This is because Microsoft's `onnxruntime` library
assumes the MSVC compiler will be used on Windows, while cgo on Windows
requires MinGW. This wrapper works around the issue by loading the
`onnxruntime` shared library at runtime, removing any dependency on the
`onnxruntime` source code beyond the header files. Naturally, this approach
works equally well on non-Windows systems.
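To illustrate the approach, here is a minimal, self-contained sketch of
loading the library at runtime and resolving `OrtGetApiBase` on a Linux-like
system. It is not part of this library (the helper name and the library path
are placeholders); the real implementation lives in `setup_env.go` and
`setup_env_windows.go` in this commit, which additionally hand the resulting
API table to the C wrapper via `SetAPIFromBase`.

```
package main

/*
#cgo LDFLAGS: -ldl
#include <dlfcn.h>
#include <stdlib.h>

// Go can't call a C function pointer directly, so this small helper invokes
// the looked-up OrtGetApiBase. For this sketch we only treat its result as an
// opaque pointer.
typedef void* (*get_api_base_fn)(void);

static void* call_get_api_base(void *fn) {
    return ((get_api_base_fn) fn)();
}
*/
import "C"

import (
    "fmt"
    "unsafe"
)

func main() {
    // Placeholder path; point this at wherever your onnxruntime build lives.
    path := C.CString("./onnxruntime.so")
    defer C.free(unsafe.Pointer(path))
    handle := C.dlopen(path, C.RTLD_LAZY)
    if handle == nil {
        fmt.Println("dlopen failed:", C.GoString(C.dlerror()))
        return
    }
    defer C.dlclose(handle)

    symbol := C.CString("OrtGetApiBase")
    defer C.free(unsafe.Pointer(symbol))
    fn := C.dlsym(handle, symbol)
    if fn == nil {
        fmt.Println("dlsym failed:", C.GoString(C.dlerror()))
        return
    }
    // The returned pointer is the table of onnxruntime API function pointers.
    fmt.Printf("Resolved OrtGetApiBase; API base lives at %p\n",
        C.call_get_api_base(fn))
}
```

On Windows, the same lookup is done with `syscall.LoadLibrary` and
`syscall.GetProcAddress` instead of `dlopen`/`dlsym`.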
Requirements
------------
To use this library, you'll need a version of Go with cgo support. If you are
not using an amd64 version of Windows or Linux (or if you want to provide your
own library for some other reason), you simply need to provide the correct path
to the shared library when initializing the wrapper. This is seen in the first
few lines of the following example.
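For example, a program that must run on several platforms can choose the
library file at runtime before initializing the wrapper. The following sketch
mirrors the per-platform selection done by this commit's tests; the helper
name and file paths are placeholders for whatever you ship with your
application.

```
package main

import (
    "github.com/yalue/onnxruntime"
    "runtime"
)

// pickLibraryPath is a hypothetical helper: it selects a shared library file
// based on the current OS and architecture, similar to the logic used in
// onnxruntime_test.go.
func pickLibraryPath() string {
    if runtime.GOOS == "windows" {
        return "lib/onnxruntime.dll"
    }
    if runtime.GOARCH == "arm64" {
        return "lib/onnxruntime_arm64.so"
    }
    return "lib/onnxruntime.so"
}

func main() {
    onnxruntime.SetSharedLibraryPath(pickLibraryPath())
    // Continue with onnxruntime.InitializeEnvironment() as in the example
    // below.
}
```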
Example Usage
-------------
The following example illustrates how this library can be used to load and run
an ONNX network taking a single input tensor and producing a single output
tensor, both of which contain 32-bit floating point values.
```
package main

import (
    "fmt"
    "github.com/yalue/onnxruntime"
    "os"
)

func main() {
    // This line is optional; by default, the library will try to load
    // "onnxruntime.dll" on Windows, and "onnxruntime.so" on any other system.
    onnxruntime.SetSharedLibraryPath("path/to/onnxruntime.so")

    err := onnxruntime.InitializeEnvironment()
    if err != nil {
        fmt.Printf("Failed initializing onnxruntime: %s\n", err)
        os.Exit(1)
    }
    defer onnxruntime.CleanupEnvironment()

    // We'll assume that network.onnx takes a single 1x2x3 input tensor and
    // produces a 1x2x2 output tensor.
    inputShape := []int64{1, 2, 3}
    outputShape := []int64{1, 2, 2}
    session, err := onnxruntime.CreateSimpleSession("path/to/network.onnx",
        inputShape, outputShape)
    if err != nil {
        fmt.Printf("Error creating session: %s\n", err)
        os.Exit(1)
    }
    defer session.Destroy()

    // Network inputs must be provided as flattened slices of floats. Run()
    // can be called as many times as necessary with a single session.
    err = session.Run([]float32{0.0, 0.1, 0.2, 0.3, 0.4, 0.5})
    if err != nil {
        fmt.Printf("Error running the network: %s\n", err)
        os.Exit(1)
    }

    // This will be a flattened slice containing the elements in the 1x2x2
    // output tensor.
    results := session.Results()
    // ...
}
```
Full Documentation
------------------
The above example uses a single input and produces a single output, all with
`float32` data. The `CreateSimpleSession` function supports this, as it is
expected to be a common use case. However, the library supports far more
options, e.g. using the `CreateSession` function when setting up a session.
The full documentation can be found at [pkg.go.dev](https://pkg.go.dev/github.com/yalue/onnxruntime).

3
go.mod Normal file
View File

@@ -0,0 +1,3 @@
module github.com/yalue/onnxruntime
go 1.19

View File

@@ -0,0 +1,7 @@
An Example Application Using the `onnxruntime` Go Wrapper
=========================================================
To run this application, navigate to this directory and compile it using
`go build`. Afterwards, run it using `./onnx_example_application` (or
`onnx_example_application.exe` on Windows).

Binary file not shown.

View File

@@ -0,0 +1,38 @@
// This application loads a test ONNX network and executes it on some fixed
// data. It serves as an example of how to use the onnxruntime wrapper library.
package main
import (
"fmt"
"github.com/yalue/onnxruntime"
"os"
"runtime"
)
func run() int {
if runtime.GOOS == "windows" {
onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime.dll")
} else {
onnxruntime.SetSharedLibraryPath("../test_data/onnxruntime.so")
}
e := onnxruntime.InitializeEnvironment()
if e != nil {
fmt.Printf("Error initializing the onnxruntime environment: %s\n", e)
return 1
}
fmt.Printf("The onnxruntime environment initialized OK.\n")
// Ordinarily, it is probably fine to call this using defer, but we do it
// here just so we can print a status message after the cleanup completes.
e = onnxruntime.CleanupEnvironment()
if e != nil {
fmt.Printf("Error cleaning up the environment: %s\n", e)
return 1
}
fmt.Printf("The onnxruntime environment was cleaned up OK.\n")
return 0
}
func main() {
os.Exit(run())
}

83
onnxruntime.go Normal file
View File

@@ -0,0 +1,83 @@
// This library wraps the C "onnxruntime" library maintained at
// https://github.com/microsoft/onnxruntime. It seeks to provide as simple an
// interface as possible to load and run ONNX-format neural networks from
// Go code.
package onnxruntime
import (
"fmt"
"unsafe"
)
// #cgo CFLAGS: -I${SRCDIR}/onnxruntime/include
//
// #include "onnxruntime_wrapper.h"
import "C"
// This string should be the path to onnxruntime.so, or onnxruntime.dll.
var onnxSharedLibraryPath string
// For simplicity, this library maintains a single ORT environment internally.
var ortEnv *C.OrtEnv
// Does two things: converts the given OrtStatus to a Go error, and releases
// the status. If the status is nil, this does nothing and returns nil.
func statusToError(status *C.OrtStatus) error {
if status == nil {
return nil
}
msg := C.GetErrorMessage(status)
toReturn := C.GoString(msg)
C.ReleaseOrtStatus(status)
return fmt.Errorf("%s", toReturn)
}
// Use this function to set the path to the "onnxruntime.so" or
// "onnxruntime.dll" function. By default, it will be set to "onnxruntime.so"
// on non-Windows systems, and "onnxruntime.dll" on Windows. Users wishing to
// specify a particular location of this library must call this function prior
// to calling onnxruntime.InitializeEnvironment().
func SetSharedLibraryPath(path string) {
onnxSharedLibraryPath = path
}
// Call this function to initialize the internal onnxruntime environment. If
// this doesn't return an error, the caller will be responsible for calling
// CleanupEnvironment to free the onnxruntime state when no longer needed.
func InitializeEnvironment() error {
if ortEnv != nil {
return fmt.Errorf("The onnxruntime has already been initialized")
}
// Do the windows- or linux- specific initialization first.
e := platformInitializeEnvironment()
if e != nil {
return fmt.Errorf("Platform-specific initialization failed: %w", e)
}
name := C.CString("Golang onnxruntime environment")
defer C.free(unsafe.Pointer(name))
status := C.CreateOrtEnv(name, &ortEnv)
if status != nil {
return fmt.Errorf("Error creating ORT environment: %w",
statusToError(status))
}
return nil
}
// Call this function to clean up the internal onnxruntime environment when
// it is no longer needed.
func CleanupEnvironment() error {
var e error
// TODO: Implement CleanupEnvironment
// Prior to calling platformCleanup, we need to:
// - Destroy the environment
// - Destroy any remaining active sessions?
// platformCleanup primarily unloads the library, so we need to call it
// last, after everything else that uses the library has been destroyed.
e = platformCleanup()
if e != nil {
return fmt.Errorf("Platform-specific cleanup failed: %w", e)
}
return nil
}

View File

@@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "onnxruntime_c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \param use_arena zero: false. non-zero: true.
*/
ORT_EXPORT
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
ORT_ALL_ARGS_NONNULL;
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,96 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <atomic>
#include <memory>
#include "core/common/common.h"
#include "core/common/status.h"
#include "core/platform/threadpool.h"
#include "core/common/logging/logging.h"
#include "core/framework/allocator.h"
struct OrtThreadingOptions;
namespace onnxruntime {
/** TODO: remove this class
Provides the runtime environment for onnxruntime.
Create one instance for the duration of execution.
*/
class Environment {
public:
/**
Create and initialize the runtime environment.
@param logging manager instance that will enable per session logger output using
session_options.session_logid as the logger id in messages.
If nullptr, the default LoggingManager MUST have been created previously as it will be used
for logging. This will use the default logger id in messages.
See core/common/logging/logging.h for details, and how LoggingManager::DefaultLogger works.
@param tp_options optional set of parameters controlling the number of intra and inter op threads for the global
threadpools.
@param create_global_thread_pools determine if this function will create the global threadpools or not.
*/
static Status Create(std::unique_ptr<logging::LoggingManager> logging_manager,
std::unique_ptr<Environment>& environment,
const OrtThreadingOptions* tp_options = nullptr,
bool create_global_thread_pools = false);
logging::LoggingManager* GetLoggingManager() const {
return logging_manager_.get();
}
void SetLoggingManager(std::unique_ptr<onnxruntime::logging::LoggingManager> logging_manager) {
logging_manager_ = std::move(logging_manager);
}
onnxruntime::concurrency::ThreadPool* GetIntraOpThreadPool() const {
return intra_op_thread_pool_.get();
}
onnxruntime::concurrency::ThreadPool* GetInterOpThreadPool() const {
return inter_op_thread_pool_.get();
}
bool EnvCreatedWithGlobalThreadPools() const {
return create_global_thread_pools_;
}
/**
* Registers an allocator for sharing between multiple sessions.
* Return an error if an allocator with the same OrtMemoryInfo is already registered.
*/
Status RegisterAllocator(AllocatorPtr allocator);
/**
* Creates and registers an allocator for sharing between multiple sessions.
* Return an error if an allocator with the same OrtMemoryInfo is already registered.
*/
Status CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info, const OrtArenaCfg* arena_cfg = nullptr);
/**
* Returns the list of registered allocators in this env.
*/
const std::vector<AllocatorPtr>& GetRegisteredSharedAllocators() const {
return shared_allocators_;
}
/**
* Removes registered allocator that was previously registered for sharing between multiple sessions.
*/
Status UnregisterAllocator(const OrtMemoryInfo& mem_info);
Environment() = default;
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment);
Status Initialize(std::unique_ptr<logging::LoggingManager> logging_manager,
const OrtThreadingOptions* tp_options = nullptr,
bool create_global_thread_pools = false);
std::unique_ptr<logging::LoggingManager> logging_manager_;
std::unique_ptr<onnxruntime::concurrency::ThreadPool> intra_op_thread_pool_;
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;
bool create_global_thread_pools_{false};
std::vector<AllocatorPtr> shared_allocators_;
};
} // namespace onnxruntime

View File

@@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Summary: The experimental Ort C++ API is a wrapper around the Ort C++ API.
//
// This C++ API further simplifies usage and provides support for modern C++ syntax/features
// at the cost of some overhead (i.e. using std::string over char *).
//
// Where applicable, default memory allocator options are used unless explicitly set.
//
// Experimental components are designed as drop-in replacements of the regular API, requiring
// minimal code modifications to use.
//
// Example: Ort::Session -> Ort::Experimental::Session
//
// NOTE: Experimental API components are subject to change based on feedback and provide no
// guarantee of backwards compatibility in future releases.
#pragma once
#include "onnxruntime_cxx_api.h"
namespace Ort {
namespace Experimental {
struct Session : Ort::Session {
Session(Env& env, std::basic_string<ORTCHAR_T>& model_path, SessionOptions& options)
: Ort::Session(env, model_path.data(), options){};
Session(Env& env, void* model_data, size_t model_data_length, SessionOptions& options)
: Ort::Session(env, model_data, model_data_length, options){};
// overloaded Run() with sensible defaults
std::vector<Ort::Value> Run(const std::vector<std::string>& input_names,
const std::vector<Ort::Value>& input_values,
const std::vector<std::string>& output_names,
const RunOptions& run_options = RunOptions());
void Run(const std::vector<std::string>& input_names,
const std::vector<Ort::Value>& input_values,
const std::vector<std::string>& output_names,
std::vector<Ort::Value>& output_values,
const RunOptions& run_options = RunOptions());
// convenience methods that simplify common lower-level API calls
std::vector<std::string> GetInputNames() const;
std::vector<std::string> GetOutputNames() const;
std::vector<std::string> GetOverridableInitializerNames() const;
// NOTE: shape dimensions may have a negative value to indicate a symbolic/unknown dimension.
std::vector<std::vector<int64_t> > GetInputShapes() const;
std::vector<std::vector<int64_t> > GetOutputShapes() const;
std::vector<std::vector<int64_t> > GetOverridableInitializerShapes() const;
};
struct Value : Ort::Value {
Value(OrtValue* p)
: Ort::Value(p){};
template <typename T>
static Ort::Value CreateTensor(T* p_data, size_t p_data_element_count, const std::vector<int64_t>& shape);
static Ort::Value CreateTensor(void* p_data, size_t p_data_byte_count, const std::vector<int64_t>& shape, ONNXTensorElementDataType type);
template <typename T>
static Ort::Value CreateTensor(const std::vector<int64_t>& shape);
static Ort::Value CreateTensor(const std::vector<int64_t>& shape, ONNXTensorElementDataType type);
};
}
}
#include "experimental_onnxruntime_cxx_inline.h"

View File

@@ -0,0 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Do not include this file directly. Please include "experimental_onnxruntime_cxx_api.h" instead.
//
// These are the inline implementations of the C++ header APIs. They are in this separate file as to not clutter
// the main C++ file with implementation details.
namespace Ort {
namespace Experimental {
inline std::vector<Ort::Value> Session::Run(const std::vector<std::string>& input_names, const std::vector<Ort::Value>& input_values,
const std::vector<std::string>& output_names, const RunOptions& run_options) {
size_t output_names_count = GetOutputNames().size();
std::vector<Ort::Value> output_values;
for (size_t i = 0; i < output_names_count; i++) output_values.emplace_back(nullptr);
Run(input_names, input_values, output_names, output_values, run_options);
return output_values;
}
inline void Session::Run(const std::vector<std::string>& input_names, const std::vector<Ort::Value>& input_values,
const std::vector<std::string>& output_names, std::vector<Ort::Value>& output_values, const RunOptions& run_options) {
size_t input_names_count = input_names.size();
size_t output_names_count = output_names.size();
std::vector<const char*> input_names_(input_names_count, nullptr);
size_t i = 0;
for (auto it = input_names.begin(); it != input_names.end(); it++) input_names_[i++] = (*it).c_str();
std::vector<const char*> output_names_(output_names_count, nullptr);
i = 0;
for (auto it = output_names.begin(); it != output_names.end(); it++) output_names_[i++] = (*it).c_str();
Ort::Session::Run(run_options, input_names_.data(), input_values.data(), input_names_count, output_names_.data(), output_values.data(), output_names_count);
}
inline std::vector<std::string> Session::GetInputNames() const {
Ort::AllocatorWithDefaultOptions allocator;
size_t node_count = GetInputCount();
std::vector<std::string> out(node_count);
for (size_t i = 0; i < node_count; i++) {
auto tmp = GetInputNameAllocated(i, allocator);
out[i] = tmp.get();
}
return out;
}
inline std::vector<std::string> Session::GetOutputNames() const {
Ort::AllocatorWithDefaultOptions allocator;
size_t node_count = GetOutputCount();
std::vector<std::string> out(node_count);
for (size_t i = 0; i < node_count; i++) {
auto tmp = GetOutputNameAllocated(i, allocator);
out[i] = tmp.get();
}
return out;
}
inline std::vector<std::string> Session::GetOverridableInitializerNames() const {
Ort::AllocatorWithDefaultOptions allocator;
size_t init_count = GetOverridableInitializerCount();
std::vector<std::string> out(init_count);
for (size_t i = 0; i < init_count; i++) {
auto tmp = GetOverridableInitializerNameAllocated(i, allocator);
out[i] = tmp.get();
}
return out;
}
inline std::vector<std::vector<int64_t>> Session::GetInputShapes() const {
size_t node_count = GetInputCount();
std::vector<std::vector<int64_t>> out(node_count);
for (size_t i = 0; i < node_count; i++) out[i] = GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
return out;
}
inline std::vector<std::vector<int64_t>> Session::GetOutputShapes() const {
size_t node_count = GetOutputCount();
std::vector<std::vector<int64_t>> out(node_count);
for (size_t i = 0; i < node_count; i++) out[i] = GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
return out;
}
inline std::vector<std::vector<int64_t>> Session::GetOverridableInitializerShapes() const {
size_t init_count = GetOverridableInitializerCount();
std::vector<std::vector<int64_t>> out(init_count);
for (size_t i = 0; i < init_count; i++) out[i] = GetOverridableInitializerTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
return out;
}
template <typename T>
inline Ort::Value Value::CreateTensor(T* p_data, size_t p_data_element_count, const std::vector<int64_t>& shape) {
return CreateTensor(p_data, p_data_element_count * sizeof(T), shape, TypeToTensorType<T>::type);
}
inline Ort::Value Value::CreateTensor(void* p_data, size_t p_data_byte_count, const std::vector<int64_t>& shape, ONNXTensorElementDataType type) {
Ort::MemoryInfo info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
return Ort::Value::CreateTensor(info, p_data, p_data_byte_count, shape.data(), shape.size(), type);
}
template <typename T>
inline Ort::Value Value::CreateTensor(const std::vector<int64_t>& shape) {
return CreateTensor(shape, TypeToTensorType<T>::type);
}
inline Ort::Value Value::CreateTensor(const std::vector<int64_t>& shape, ONNXTensorElementDataType type) {
Ort::AllocatorWithDefaultOptions allocator;
return Ort::Value::CreateTensor(allocator, shape.data(), shape.size(), type);
}
}
}

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
/*
* This file defines RunOptions Config Keys and format of the Config Values.
*
* The Naming Convention for a RunOptions Config Key,
* "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
* Such as "ep.cuda.use_arena"
* The Config Key cannot be empty
* The maximum length of the Config Key is 128
*
* The string format of a RunOptions Config Value is defined individually for each Config.
* The maximum length of the Config Value is 1024
*/
// Key for enabling shrinkages of user listed device memory arenas.
// Expects a list of semi-colon separated key value pairs separated by colon in the following format:
// "device_0:device_id_0;device_1:device_id_1"
// No white-spaces allowed in the provided list string.
// Currently, the only supported devices are : "cpu", "gpu" (case sensitive).
// If "cpu" is included in the list, DisableCpuMemArena() API must not be called (i.e.) arena for cpu should be enabled.
// Example usage: "cpu:0;gpu:0" (or) "gpu:0"
// By default, the value for this key is empty (i.e.) no memory arenas are shrunk
static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage";

View File

@@ -0,0 +1,186 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
/*
* This file defines SessionOptions Config Keys and format of the Config Values.
*
* The Naming Convention for a SessionOptions Config Key,
* "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
* Such as "ep.cuda.use_arena"
* The Config Key cannot be empty
* The maximum length of the Config Key is 128
*
* The string format of a SessionOptions Config Value is defined individually for each Config.
* The maximum length of the Config Value is 1024
*/
// Key for disable PrePacking,
// If the config value is set to "1" then the prepacking is disabled, otherwise prepacking is enabled (default value)
static const char* const kOrtSessionOptionsConfigDisablePrepacking = "session.disable_prepacking";
// A value of "1" means allocators registered in the env will be used. "0" means the allocators created in the session
// will be used. Use this to override the usage of env allocators on a per session level.
static const char* const kOrtSessionOptionsConfigUseEnvAllocators = "session.use_env_allocators";
// Set to 'ORT' (case sensitive) to load an ORT format model.
// If unset, model type will default to ONNX unless inferred from filename ('.ort' == ORT format) or bytes to be ORT
static const char* const kOrtSessionOptionsConfigLoadModelFormat = "session.load_model_format";
// Set to 'ORT' (case sensitive) to save optimized model in ORT format when SessionOptions.optimized_model_path is set.
// If unset, format will default to ONNX unless optimized_model_filepath ends in '.ort'.
static const char* const kOrtSessionOptionsConfigSaveModelFormat = "session.save_model_format";
// If a value is "1", flush-to-zero and denormal-as-zero are applied. The default is "0".
// When multiple sessions are created, a main thread doesn't override changes from succeeding session options,
// but threads in session thread pools follow option changes.
// When ORT runs with OpenMP, the same rule is applied, i.e. the first session option to flush-to-zero and
// denormal-as-zero is only applied to global OpenMP thread pool, which doesn't support per-session thread pool.
// Note that an alternative way not using this option at runtime is to train and export a model without denormals
// and that's recommended because turning this option on may hurt model accuracy.
static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.set_denormal_as_zero";
// It controls to run quantization model in QDQ (QuantizelinearDeQuantizelinear) format or not.
// "0": enable. ORT does fusion logic for QDQ format.
// "1": disable. ORT doesn't do fusion logic for QDQ format.
// Its default value is "0"
static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq";
// It controls whether to enable Double QDQ remover and Identical Children Consolidation
// "0": not to disable. ORT does remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs
// "1": disable. ORT doesn't remove the middle 2 Nodes from a Q->(QD->Q)->QD pairs
// Its default value is "0"
static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover";
// If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been
// completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the
// Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to
// 8-bit and back to float, but could impact accuracy. The impact on accuracy will be model specific and depend on
// other factors like whether the model was created using Quantization Aware Training or Post Training Quantization.
// As such, it's best to test to determine if enabling this works well for your scenario.
// The default value is "0"
// Available since version 1.11.
static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup";
// Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0".
// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
#ifdef ENABLE_TRAINING
// Specifies a list of op types for memory footprint reduction.
// The value should be a ","-delimited list of pair of
// <subgraph string : optimization strategy : number of subgraph to apply>.
// For example, "Gelu+Cast+:1:0,Dropout+:1:1".
// A valid "subgraph string" should be one subgraph representation output by ORT graph transformations.
// "optimization strategy" currently has valid values: 0 - disabled, 1 - recompute.
// "number of subgraph to apply" is used to control how many subgraphs to apply optimization, to avoid "oversaving"
// the memory.
static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.enable_memory_optimizer";
// Specifies the level for detecting subgraphs for memory footprint reduction.
// The value should be an integer. The default value is 0.
static const char* const kOrtSessionOptionsMemoryOptimizerProbeLevel = "optimization.enable_memory_probe_recompute_level";
#endif
// Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0".
// Using device allocators means the memory allocation is made using malloc/new.
static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers";
// Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking
// "0": thread will block if found no job to run
// "1": default, thread will spin a number of times before blocking
static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning";
static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning";
// Key for using model bytes directly for ORT format
// If a session is created using an input byte array contains the ORT format model data,
// By default we will copy the model bytes at the time of session creation to ensure the model bytes
// buffer is valid.
// Setting this option to "1" will disable copy the model bytes, and use the model bytes directly. The caller
// has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed.
static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly";
/// <summary>
/// Key for using the ORT format model flatbuffer bytes directly for initializers.
/// This avoids copying the bytes and reduces peak memory usage during model loading and initialization.
/// Requires `session.use_ort_model_bytes_directly` to be true.
/// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire
/// duration of the InferenceSession.
/// </summary>
static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers =
"session.use_ort_model_bytes_for_initializers";
// This should only be specified when exporting an ORT format model for use on a different platform.
// If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0"
// Available since version 1.11.
static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed";
// x64 SSE4.1/AVX2/AVX512(with no VNNI) has overflow problem with quantizied matrix multiplication with U8S8.
// To avoid this we need to use slower U8U8 matrix multiplication instead. This option, if
// turned on, use slower U8U8 matrix multiplications. Only effective with AVX2 or AVX512
// platforms.
static const char* const kOrtSessionOptionsAvx2PrecisionMode = "session.x64quantprecision";
// Specifies how minimal build graph optimizations are handled in a full build.
// These optimizations are at the extended level or higher.
// Possible values and their effects are:
// "save": Save runtime optimizations when saving an ORT format model.
// "apply": Only apply optimizations available in a minimal build.
// ""/<unspecified>: Apply optimizations available in a full build.
// Available since version 1.11.
static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations =
"optimization.minimal_build_optimizations";
// Note: The options specific to an EP should be specified prior to appending that EP to the session options object in
// order for them to take effect.
// Specifies a list of stop op types. Nodes of a type in the stop op types and nodes downstream from them will not be
// run by the NNAPI EP.
// The value should be a ","-delimited list of op types. For example, "Add,Sub".
// If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op
// exclusion, set the value to "".
static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops";
// Enabling dynamic block-sizing for multithreading.
// With a positive value, thread pool will split a task of N iterations to blocks of size starting from:
// N / (num_of_threads * dynamic_block_base)
// As execution progresses, the size will decrease according to the diminishing residual of N,
// meaning the task will be distributed in smaller granularity for better parallelism.
// For some models, it helps to reduce the variance of E2E inference latency and boost performance.
// The feature will not function by default, specify any positive integer, e.g. "4", to enable it.
// Available since version 1.11.
static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base";
// This option allows to decrease CPU usage between infrequent
// requests and forces any TP threads spinning stop immediately when the last of
// concurrent Run() call returns.
// Spinning is restarted on the next Run() call.
// Applies only to internal thread-pools
static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop";
// "1": all inconsistencies encountered during shape and type inference
// will result in failures.
// "0": in some cases warnings will be logged but processing will continue. The default.
// May be useful to expose bugs in models.
static const char* const kOrtSessionOptionsConfigStrictShapeTypeInference = "session.strict_shape_type_inference";
// The file saves configuration for partitioning node among logic streams
static const char* const kNodePartitionConfigFile = "session.node_partition_config_file";
// This Option allows setting affinities for intra op threads.
// Affinity string follows format:
// logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id
// Semicolon isolates configurations among threads, while comma split processors where ith thread expected to attach to.
// e.g.1,2,3;4,5
// specifies affinities for two threads, with the 1st thread attach to the 1st, 2nd, and 3rd processor, and 2nd thread to the 4th and 5th.
// To ease the configuration, an "interval" is also allowed:
// e.g. 1-8;8-16;17-24
// orders that the 1st thread runs on first eight processors, 2nd thread runs on next eight processors, and so forth.
// Note:
// 1. Once set, the number of thread affinities must equal to intra_op_num_threads - 1, since ort does not set affinity on the main thread which
// is started and managed by the calling app;
// 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups with each has 64 logical processors,
// an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group.
// Hence 64-65 is an invalid configuration, because a windows thread cannot be attached to processors across group boundary.
static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "session.intra_op_thread_affinities";

View File

@@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
namespace onnxruntime {
// data types for execution provider options
using ProviderOptions = std::unordered_map<std::string, std::string>;
using ProviderOptionsVector = std::vector<ProviderOptions>;
using ProviderOptionsMap = std::unordered_map<std::string, ProviderOptions>;
} // namespace onnxruntime

View File

@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "onnxruntime_c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id);
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
/// <summary>
/// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT_V2.
/// Please note that this struct is *similar* to OrtTensorRTProviderOptions but only to be used internally.
/// Going forward, new trt provider options are to be supported via this struct and usage of the publicly defined
/// OrtTensorRTProviderOptions will be deprecated over time.
/// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
/// </summary>
struct OrtTensorRTProviderOptionsV2 {
int device_id; // cuda device id.
int has_user_compute_stream; // indicator of user specified CUDA compute stream.
void* user_compute_stream; // user specified CUDA compute stream.
int trt_max_partition_iterations; // maximum iterations for TensorRT parser to get capability
int trt_min_subgraph_size; // minimum size of TensorRT subgraphs
size_t trt_max_workspace_size; // maximum workspace size for TensorRT.
int trt_fp16_enable; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true
int trt_int8_enable; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true
const char* trt_int8_calibration_table_name; // TensorRT INT8 calibration table name.
int trt_int8_use_native_calibration_table; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true
int trt_dla_enable; // enable DLA. Default 0 = false, nonzero = true
int trt_dla_core; // DLA core number. Default 0
int trt_dump_subgraphs; // dump TRT subgraph. Default 0 = false, nonzero = true
int trt_engine_cache_enable; // enable engine caching. Default 0 = false, nonzero = true
const char* trt_engine_cache_path; // specify engine cache path
int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true
const char* trt_engine_decryption_lib_path; // specify engine decryption library path
int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true
int trt_context_memory_sharing_enable; // enable context memory sharing between subgraphs. Default 0 = false, nonzero = true
int trt_layer_norm_fp32_fallback; // force Pow + Reduce ops in layer norm to FP32. Default 0 = false, nonzero = true
};

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) Microsoft Corporation
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

66
onnxruntime_test.go Normal file
View File

@@ -0,0 +1,66 @@
package onnxruntime
import (
"encoding/json"
"os"
"runtime"
"testing"
)
// This type is read from JSON and used to determine the inputs and expected
// outputs for an ONNX network.
type testInputsInfo struct {
InputShape []int `json:"input_shape"`
FlattenedInput []float32 `json:"flattened_input"`
OutputShape []int `json:"output_shape"`
FlattenedOutput []float32 `json:"flattened_output"`
}
// This must be called prior to running each test.
func InitializeRuntime(t *testing.T) {
if runtime.GOOS == "windows" {
SetSharedLibraryPath("test_data/onnxruntime.dll")
} else {
if runtime.GOARCH == "arm64" {
SetSharedLibraryPath("test_data/onnxruntime_arm64.so")
} else {
SetSharedLibraryPath("test_data/onnxruntime.so")
}
}
e := InitializeEnvironment()
if e != nil {
t.Logf("Failed setting up onnxruntime environment: %s\n", e)
t.FailNow()
}
}
// Reads the given JSON file to obtain the test network's input shape and
// contents, along with its expected output shape and contents.
func parseInputsJSON(path string, t *testing.T) *testInputsInfo {
toReturn := testInputsInfo{}
f, e := os.Open(path)
if e != nil {
t.Logf("Failed opening %s: %s\n", path, e)
t.FailNow()
}
defer f.Close()
d := json.NewDecoder(f)
e = d.Decode(&toReturn)
if e != nil {
t.Logf("Failed decoding %s: %s\n", path, e)
t.FailNow()
}
return &toReturn
}
func TestExampleNetwork(t *testing.T) {
InitializeRuntime(t)
_ = parseInputsJSON("test_data/example_network_results.json", t)
// TODO: More tests here to run the network, once that's supported.
e := CleanupEnvironment()
if e != nil {
t.Logf("Failed cleaning up the environment: %s\n", e)
t.FailNow()
}
}

27
onnxruntime_wrapper.c Normal file
View File

@@ -0,0 +1,27 @@
#include "onnxruntime_wrapper.h"
static const OrtApi *ort_api = NULL;
int SetAPIFromBase(OrtApiBase *api_base) {
if (!api_base) return 1;
ort_api = api_base->GetApi(ORT_API_VERSION);
if (!ort_api) return 2;
return 0;
}
void ReleaseOrtStatus(OrtStatus *status) {
ort_api->ReleaseStatus(status);
}
void ReleaseOrtEnv(OrtEnv *env) {
ort_api->ReleaseEnv(env);
}
OrtStatus *CreateOrtEnv(char *name, OrtEnv **env) {
return ort_api->CreateEnv(ORT_LOGGING_LEVEL_ERROR, name, env);
}
const char *GetErrorMessage(OrtStatus *status) {
if (!status) return "No error (NULL status)";
return ort_api->GetErrorMessage(status);
}

44
onnxruntime_wrapper.h Normal file
View File

@@ -0,0 +1,44 @@
#ifndef ONNXRUNTIME_WRAPPER_H
#define ONNXRUNTIME_WRAPPER_H
// We want to always use the unix-like onnxruntime C APIs, even on Windows, so
// we need to undefine _WIN32 before including onnxruntime_c_api.h. However,
// this requires a careful song-and-dance.
// First, include these common headers, as they get transitively included by
// onnxruntime_c_api.h. We need to include them ourselves, first, so that the
// preprocessor will skip them while _WIN32 is undefined.
#include <stdio.h>
#include <stdlib.h>
// Next, we actually include the header.
#undef _WIN32
#include "onnxruntime_c_api.h"
// ... However, mingw will complain if _WIN32 is *not* defined! So redefine it.
#define _WIN32
#ifdef __cplusplus
extern "C" {
#endif
// Takes a pointer to the api_base struct in order to obtain the OrtApi
// pointer. Intended to be called from Go. Returns nonzero on error.
int SetAPIFromBase(OrtApiBase *api_base);
// Wraps calling ort_api->ReleaseStatus(status)
void ReleaseOrtStatus(OrtStatus *status);
// Wraps calling ort_api->CreateEnv. Returns a non-NULL status on error.
OrtStatus *CreateOrtEnv(char *name, OrtEnv **env);
// Releases the given OrtEnv.
void ReleaseOrtEnv(OrtEnv *env);
// Returns the message associated with the given ORT status.
const char *GetErrorMessage(OrtStatus *status);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // ONNXRUNTIME_WRAPPER_H

69
setup_env.go Normal file
View File

@@ -0,0 +1,69 @@
//go:build !windows
package onnxruntime
import (
"fmt"
"unsafe"
)
/*
#cgo LDFLAGS: -ldl
#include <dlfcn.h>
#include "onnxruntime_wrapper.h"
typedef OrtApiBase* (*GetOrtApiBaseFunction)(void);
// Since Go can't call C function pointers directly, we just use this helper
// when calling GetApiBase
OrtApiBase *CallGetAPIBaseFunction(void *fn) {
OrtApiBase *to_return = ((GetOrtApiBaseFunction) fn)();
return to_return;
}
*/
import "C"
// This file includes the code for loading the onnxruntime shared library and
// setting up the environment on non-Windows systems. For now, it has only
// been tested on Linux.
// This will contain the handle to the onnxruntime shared library if it has
// been loaded successfully.
var libraryHandle unsafe.Pointer
func platformCleanup() error {
v, e := C.dlclose(libraryHandle)
if v != 0 {
return fmt.Errorf("Error closing the library: %w", e)
}
return nil
}
func platformInitializeEnvironment() error {
cName := C.CString(onnxSharedLibraryPath)
defer C.free(unsafe.Pointer(cName))
handle := C.dlopen(cName, C.RTLD_LAZY)
if handle == nil {
msg := C.GoString(C.dlerror())
return fmt.Errorf("Error loading ONNX shared library \"%s\": %s",
onnxSharedLibraryPath, msg)
}
cFunctionName := C.CString("OrtGetApiBase")
defer C.free(unsafe.Pointer(cFunctionName))
getAPIBaseProc := C.dlsym(handle, cFunctionName)
if getAPIBaseProc == nil {
// Read dlerror() before calling dlclose(), so the dlsym failure message
// isn't overwritten by the dlclose call.
msg := C.GoString(C.dlerror())
C.dlclose(handle)
return fmt.Errorf("Error looking up OrtGetApiBase in \"%s\": %s",
onnxSharedLibraryPath, msg)
}
ortAPIBase := C.CallGetAPIBaseFunction(getAPIBaseProc)
tmp := C.SetAPIFromBase((*C.OrtApiBase)(unsafe.Pointer(ortAPIBase)))
if tmp != 0 {
C.dlclose(handle)
return fmt.Errorf("Error setting ORT API base: %d", tmp)
}
libraryHandle = handle
return nil
}

55
setup_env_windows.go Normal file
View File

@@ -0,0 +1,55 @@
//go:build windows
package onnxruntime
// This file includes the Windows-specific code for loading the onnxruntime
// library and setting up the environment.
import (
"fmt"
"syscall"
"unsafe"
)
// #include "onnxruntime_wrapper.h"
import "C"
// This will contain the handle to the onnxruntime dll if it has been loaded
// successfully.
var libraryHandle syscall.Handle
func platformCleanup() error {
e := syscall.FreeLibrary(libraryHandle)
libraryHandle = 0
return e
}
func platformInitializeEnvironment() error {
handle, e := syscall.LoadLibrary(onnxSharedLibraryPath)
if e != nil {
return fmt.Errorf("Error loading ONNX shared library \"%s\": %w",
onnxSharedLibraryPath, e)
}
getApiBaseProc, e := syscall.GetProcAddress(handle, "OrtGetApiBase")
if e != nil {
syscall.FreeLibrary(handle)
return fmt.Errorf("Error finding OrtGetApiBase function in %s: %w",
onnxSharedLibraryPath, e)
}
ortApiBase, _, e := syscall.SyscallN(uintptr(getApiBaseProc), 0)
if ortApiBase == 0 {
syscall.FreeLibrary(handle)
if e != nil {
return fmt.Errorf("Error calling OrtGetApiBase: %w", e)
} else {
return fmt.Errorf("Error calling OrtGetApiBase")
}
}
tmp := C.SetAPIFromBase((*C.OrtApiBase)(unsafe.Pointer(ortApiBase)))
if tmp != 0 {
syscall.FreeLibrary(handle)
return fmt.Errorf("Error setting ORT API base: %d", tmp)
}
libraryHandle = handle
return nil
}

Binary file not shown.

View File

@@ -0,0 +1,22 @@
{
"input_shape": [
1,
1,
4
],
"flattened_input": [
0.19473445415496826,
0.9139836430549622,
0.7043011784553528,
0.7685686945915222
],
"output_shape": [
1,
1,
2
],
"flattened_output": [
2.581585645675659,
0.6283518075942993
]
}

View File

@@ -0,0 +1,131 @@
# This script sets up and "trains" a toy pytorch network that maps a 1x4
# vector to a 1x2 vector containing [sum, max difference] of the input
# values. Finally, it exports the network to an ONNX file to use in testing.
import torch
from torch.nn.functional import relu
import json
def fake_dataset(size):
    """ Returns a dataset filled with our fake training data. """
    inputs = torch.rand((size, 1, 4))
    outputs = torch.zeros((size, 1, 2))
    for i in range(size):
        outputs[i][0][0] = inputs[i][0].sum()
        outputs[i][0][1] = inputs[i][0].max() - inputs[i][0].min()
    return torch.utils.data.TensorDataset(inputs, outputs)

class SumAndDiffModel(torch.nn.Module):
    """ Just a standard, fairly minimal, pytorch model for generating the NN.
    """
    def __init__(self):
        super().__init__()
        # We'll do four 1x4 convolutions to make the network more interesting.
        self.conv = torch.nn.Conv1d(1, 4, 4)
        # We'll follow the conv with a FC layer to produce the outputs. The
        # input to the FC layer are the 4 conv outputs concatenated with the
        # original input.
        self.fc = torch.nn.Linear(8, 2)

    def forward(self, data):
        batch_size = len(data)
        conv_out = relu(self.conv(data))
        conv_flattened = torch.flatten(conv_out, start_dim=1)
        data_flattened = torch.flatten(data, start_dim=1)
        combined = torch.cat((conv_flattened, data_flattened), dim=1)
        output = relu(self.fc(combined))
        output = output.view(batch_size, 1, 2)
        return output

def get_test_loss(model, loader, loss_function):
    """ Just runs a single epoch of data from the given loader. Returns the
    average loss per batch. The provided model is expected to be in eval mode.
    """
    i = 0
    total_loss = 0.0
    for in_data, desired_result in loader:
        produced_result = model(in_data)
        loss = loss_function(desired_result, produced_result)
        total_loss += loss.item()
        i += 1
    return total_loss / i

def save_model(model, output_filename):
    """ Saves the model to an onnx file with the given name. Assumes the model
    is in eval mode. """
    print("Saving network to " + output_filename)
    dummy_input = torch.rand(1, 1, 4)
    input_names = ["1x4 Input Vector"]
    output_names = ["1x2 Output Vector"]
    torch.onnx.export(model, dummy_input, output_filename,
        input_names=input_names, output_names=output_names)
    return None

def print_sample(model):
    """ Prints a sample input and output computation using the model. Expects
    the model to be in eval mode. """
    example_input = torch.rand(1, 1, 4)
    result = model(example_input)
    print("Sample model execution:")
    print(" Example input: " + str(example_input))
    print(" Produced output: " + str(result))
    return None

def save_sample_json(model, output_name):
    """ Saves a JSON file containing an input and an output from the network,
    for use when validating execution of the ONNX network. """
    example_input = torch.rand(1, 1, 4)
    result = model(example_input)
    json_content = {}
    json_content["input_shape"] = list(example_input.shape)
    json_content["flattened_input"] = list(example_input.flatten().tolist())
    json_content["output_shape"] = list(result.shape)
    json_content["flattened_output"] = list(result.flatten().tolist())
    with open(output_name, "w") as f:
        json.dump(json_content, f, indent=" ")
    return None

def main():
    print("Generating train and test datasets...")
    train_data = fake_dataset(100 * 1000)
    train_loader = torch.utils.data.DataLoader(dataset=train_data,
        batch_size=16, shuffle=True)
    test_data = fake_dataset(10 * 1000)
    test_loader = torch.utils.data.DataLoader(dataset=test_data,
        batch_size=16)
    model = SumAndDiffModel()
    model.train()
    loss_function = torch.nn.L1Loss(reduction="mean")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.00005)
    for epoch in range(8):
        i = 0
        total_loss = 0.0
        for in_data, desired_result in train_loader:
            i += 1
            produced_result = model(in_data)
            loss = loss_function(desired_result, produced_result)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()
            if (i % 1000) == 1:
                print("Epoch %d, iteration %d. Current loss = %f" % (epoch, i,
                    loss.item()))
        train_loss = total_loss / i
        print(" => Average train-set loss: " + str(train_loss))
        model.eval()
        with torch.no_grad():
            test_loss = get_test_loss(model, test_loader, loss_function)
        model.train()
        print(" => Average test-set loss: " + str(test_loss))
    model.eval()
    with torch.no_grad():
        save_model(model, "example_network.onnx")
        save_sample_json(model, "example_network_results.json")
        print_sample(model)
    print("Done!")

if __name__ == "__main__":
    main()

BIN
test_data/onnxruntime.dll Normal file

Binary file not shown.