Remove training API

- The training API has been deprecated by onnxruntime itself, and it
  will be much easier to remove it than to deprecate it in this wrapper.

- The training API wrapper functions have been replaced by stubs in
  legacy_types.go that return errors.

- The README now mentions the old versions required for the training API.

- The Scalar type has been promoted to onnxruntime_go.go, and a test
  has been added for it.
yalue
2024-11-12 21:50:00 -05:00
committed by Nathan O
parent f9cafdda92
commit 11b449bb38
10 changed files with 202 additions and 1949 deletions


@@ -206,10 +206,13 @@ run and pass.
Training API Support
--------------------
This wrapper supports the onnxruntime training API on limited platforms. See
the `NewTrainingSession` and associated data types or functions to use it. So
far, the training API has only been tested on Linux, on `x86_64` architectures.
The training API has been deprecated as of onnxruntime version 1.20. Rather
than continuing to maintain wrappers for a deprecated API, `onnxruntime_go` has
replaced the wrapper functions for the training API with stubs that return an
error. Users who need to continue using the training API will need to pin
older versions. For example, the following versions should be compatible with
training:
- Version `v1.12.1` of `onnxruntime_go`, and
- Version 1.19.2 of `onnxruntime`.
If you are not sure whether your platform or build of onnxruntime supports
training, you can call `onnxruntime_go.IsTrainingSupported()`, which will
return `true` if training is supported on your system.
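
For example, a startup check might look like the following sketch (assuming
the usual import path for this module, and that the onnxruntime shared library
has already been located, e.g. with `SetSharedLibraryPath`):

```go
import (
    "fmt"

    ort "github.com/yalue/onnxruntime_go"
)

func reportTrainingSupport() error {
    if e := ort.InitializeEnvironment(); e != nil {
        return fmt.Errorf("failed initializing onnxruntime: %w", e)
    }
    defer ort.DestroyEnvironment()
    if !ort.IsTrainingSupported() {
        fmt.Println("Training is not supported; pin onnxruntime_go v1.12.1 " +
            "and onnxruntime 1.19.2 if the training API is still needed.")
    }
    return nil
}
```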


@@ -171,3 +171,77 @@ type ArbitraryTensor = Value
// As with the ArbitraryTensor type, this type alias only exists to facilitate
// renaming an old type without breaking existing code.
type TensorInternalData = ValueInternalData

var TrainingAPIRemovedError error = fmt.Errorf("Support for the training " +
    "API has been removed from onnxruntime_go following its deprecation in " +
    "onnxruntime version 1.20. The last revision of onnxruntime_go " +
    "supporting the training API is version v1.12.1")

// Support for TrainingSessions has been removed from onnxruntime_go following
// the deprecation of the training API in onnxruntime 1.20.0.
type TrainingSession struct{}

// Always returns TrainingAPIRemovedError.
func (s *TrainingSession) ExportModel(path string, outputNames []string) error {
    return TrainingAPIRemovedError
}

// Always returns TrainingAPIRemovedError.
func (s *TrainingSession) SaveCheckpoint(path string,
    saveOptimizerState bool) error {
    return TrainingAPIRemovedError
}

// Always returns TrainingAPIRemovedError.
func (s *TrainingSession) Destroy() error {
    return TrainingAPIRemovedError
}

// Always returns TrainingAPIRemovedError.
func (s *TrainingSession) TrainStep() error {
    return TrainingAPIRemovedError
}

// Always returns TrainingAPIRemovedError.
func (s *TrainingSession) OptimizerStep() error {
    return TrainingAPIRemovedError
}

// Always returns TrainingAPIRemovedError.
func (s *TrainingSession) LazyResetGrad() error {
    return TrainingAPIRemovedError
}

// Support for TrainingInputOutputNames has been removed from onnxruntime_go
// following the deprecation of the training API in onnxruntime 1.20.0.
type TrainingInputOutputNames struct {
    TrainingInputNames  []string
    EvalInputNames      []string
    TrainingOutputNames []string
    EvalOutputNames     []string
}

// Always returns (nil, TrainingAPIRemovedError).
func GetInputOutputNames(checkpointStatePath string, trainingModelPath string,
    evalModelPath string) (*TrainingInputOutputNames, error) {
    return nil, TrainingAPIRemovedError
}

// Always returns false.
func IsTrainingSupported() bool {
    return false
}

// Always returns (nil, TrainingAPIRemovedError).
func NewTrainingSessionWithOnnxData(checkpointData, trainingData, evalData,
    optimizerData []byte, inputs, outputs []Value,
    options *SessionOptions) (*TrainingSession, error) {
    return nil, TrainingAPIRemovedError
}

// Always returns (nil, TrainingAPIRemovedError).
func NewTrainingSession(checkpointStatePath, trainingModelPath, evalModelPath,
    optimizerModelPath string, inputs, outputs []Value,
    options *SessionOptions) (*TrainingSession, error) {
    return nil, TrainingAPIRemovedError
}
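
// The snippet below is an illustrative sketch, not part of legacy_types.go:
// code that still references the training API keeps compiling against the
// stubs above, but every call now fails at runtime with
// TrainingAPIRemovedError. The artifact file names are placeholders, and the
// import path is assumed to be this module's usual one.

import (
    "errors"
    "log"

    ort "github.com/yalue/onnxruntime_go"
)

func tryTraining() {
    _, e := ort.NewTrainingSession("checkpoint", "training_model.onnx",
        "eval_model.onnx", "optimizer_model.onnx", nil, nil, nil)
    if errors.Is(e, ort.TrainingAPIRemovedError) {
        log.Println("training API removed; pin onnxruntime_go v1.12.1 and " +
            "onnxruntime 1.19.2 if training is still required")
    }
}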


@@ -90,9 +90,6 @@ func InitializeEnvironment() error {
return fmt.Errorf("Platform-specific initialization failed: %w", e)
}
// Get the training API pointer if it is supported.
C.SetTrainingApi()
name := C.CString("Golang onnxruntime environment")
defer C.free(unsafe.Pointer(name))
status := C.CreateOrtEnv(name, &ortEnv)
@@ -883,6 +880,90 @@ func (t *CustomDataTensor) GetData() []byte {
return t.data
}

// Scalar is like a tensor, but the underlying go slice is of length 1 and it
// has no dimension. It was introduced for use with the training API, but it
// remains supported since it will conceivably have uses outside of the
// training API.
type Scalar[T TensorData] struct {
    data     []T
    dataSize uintptr
    ortValue *C.OrtValue
}

// Always returns nil for Scalars.
func (s *Scalar[T]) GetShape() Shape {
    return nil
}

func (s *Scalar[T]) ZeroContents() {
    C.memset(unsafe.Pointer(&s.data[0]), 0, C.size_t(s.dataSize))
}

func (s *Scalar[T]) Destroy() error {
    C.ReleaseOrtValue(s.ortValue)
    s.ortValue = nil
    s.data = nil
    s.dataSize = 0
    return nil
}

// GetData returns the underlying data for the scalar. If you want to set the
// scalar's data, use Set.
func (t *Scalar[T]) GetData() T {
    return t.data[0]
}

// Set changes the underlying value of the scalar to the new value.
func (t *Scalar[T]) Set(value T) {
    t.data = []T{value}
}

func (t *Scalar[T]) DataType() C.ONNXTensorElementDataType {
    return GetTensorElementDataType[T]()
}

func (t *Scalar[_]) GetInternals() *ValueInternalData {
    return &ValueInternalData{
        ortValue: t.ortValue,
    }
}

func (t *Scalar[_]) GetONNXType() ONNXType {
    return ONNXTypeTensor
}

// NewEmptyScalar creates a new scalar of type T, initialized to T's zero
// value.
func NewEmptyScalar[T TensorData]() (*Scalar[T], error) {
    var data T
    return NewScalar(data)
}

// NewScalar creates a new scalar of type T backed by a value of type T.
// Note that, unlike tensors, this takes a single value T rather than a []T.
func NewScalar[T TensorData](data T) (*Scalar[T], error) {
    if !IsInitialized() {
        return nil, NotInitializedError
    }
    dataSlice := []T{data}
    var ortValue *C.OrtValue
    dataType := GetTensorElementDataType[T]()
    dataSize := unsafe.Sizeof(dataSlice[0]) * uintptr(1)
    status := C.CreateOrtTensorWithShape(unsafe.Pointer(&dataSlice[0]),
        C.size_t(dataSize), nil, C.int64_t(0), ortMemoryInfo, dataType,
        &ortValue)
    if status != nil {
        return nil, statusToError(status)
    }
    toReturn := Scalar[T]{
        data:     dataSlice,
        dataSize: dataSize,
        ortValue: ortValue,
    }
    return &toReturn, nil
}
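
// The snippet below is an illustrative sketch, not part of this file: a
// Scalar typically holds a single value such as a training loss. It assumes
// InitializeEnvironment has already been called successfully.
func scalarSketch() error {
    loss, e := NewEmptyScalar[float32]()
    if e != nil {
        return e
    }
    defer loss.Destroy()
    loss.Set(0.5)       // Overwrite the single underlying value.
    _ = loss.GetData()  // Reads back 0.5.
    loss.ZeroContents() // Resets the value to 0.
    return nil
}
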
// Holds options required when enabling the CUDA backend for a session. This
// struct wraps C onnxruntime types; users must create instances of this using
// the NewCUDAProviderOptions() function. So, to enable CUDA for a session,


@@ -1461,6 +1461,41 @@ func TestSessionFromDataBuffer(t *testing.T) {
}
}

func TestScalar(t *testing.T) {
    InitializeRuntime(t)
    defer CleanupRuntime(t)
    s, e := NewEmptyScalar[float32]()
    if e != nil {
        t.Fatalf("Error creating empty scalar: %s\n", e)
    }
    if s.GetData() != 0.0 {
        t.Fatalf("Empty scalar not initialized to 0, got %f\n", s.GetData())
    }
    e = s.Destroy()
    if e != nil {
        t.Fatalf("Failed destroying scalar: %s\n", e)
    }
    s2, e := NewScalar(int64(1337))
    if e != nil {
        t.Fatalf("Failed creating int64 scalar: %s\n", e)
    }
    defer s2.Destroy()
    contents := s2.GetData()
    if contents != 1337 {
        t.Fatalf("Incorrect initial contents of s2: %d\n", contents)
    }
    s2.ZeroContents()
    contents = s2.GetData()
    if contents != 0 {
        t.Fatalf("Incorrect value of s2 after zeroing: %d\n", contents)
    }
    s2.Set(1234)
    contents = s2.GetData()
    if contents != 1234 {
        t.Fatalf("Incorrect value of s2: %d (expected 1234)\n", contents)
    }
}
// See the comment in generate_network_big_compute.py for information about
// the inputs and outputs used for testing or benchmarking session options.
func prepareBenchmarkTensors(t testing.TB, seed int64) (*Tensor[float32],


@@ -1,731 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file contains the training c apis.
#pragma once
#include <stdbool.h>
#include "onnxruntime_c_api.h"
/** \page training_c_cpp_api Training C & C++ APIs
*
* Training C and C++ APIs are an extension of the \ref c_cpp_api "onnxruntime core C and C++ APIs" and should be used in conjunction with them.
*
* In order to train a model with onnxruntime, the following training artifacts must be generated:
* - The training onnx model
* - The checkpoint file
* - The optimizer onnx model
* - The eval onnx model (optional)
*
* These training artifacts can be generated as part of an offline step using the python [utilities](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md) made available in the `onnxruntime-training` python package.
*
* After these artifacts have been generated, the C and C++ utilities listed in this documentation can be leveraged to perform training.
*
* If any problem is encountered, please create an [issue](https://github.com/microsoft/onnxruntime/issues/new) with your scenario and requirements, and we will be sure to respond and follow up on the request.
*
* <h1>Training C API</h1>
*
* ::OrtTrainingApi - Training C API functions.
*
* This C structure contains functions that enable users to perform training with onnxruntime.
*
* _Sample Code_:
*
* ```c
* #include <onnxruntime_training_api.h>
*
* OrtApi* g_ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
* OrtTrainingApi* g_ort_training_api = g_ort_api->GetTrainingApi(ORT_API_VERSION);
*
* OrtEnv* env = NULL;
* g_ort_api->CreateEnv(logging_level, logid, &env);
* OrtSessionOptions* session_options = NULL;
* g_ort_api->CreateSessionOptions(&session_options);
*
* OrtCheckpointState* state = NULL;
* g_ort_training_api->LoadCheckpoint(path_to_checkpoint, &state);
*
* OrtTrainingSession* training_session = NULL;
* g_ort_training_api->CreateTrainingSession(env, session_options, training_model_path,
* state, eval_model_path, optimizer_model_path,
* &training_session);
* // Training loop
* {
* g_ort_training_api->TrainStep(...);
* g_ort_training_api->OptimizerStep(...);
* g_ort_training_api->LazyResetGrad(...);
* }
*
* g_ort_training_api->ExportModelForInferencing(training_session, inference_model_path, ...);
* g_ort_training_api->SaveCheckpoint(state, path_to_checkpoint, false);
*
* g_ort_training_api->ReleaseTrainingSession(training_session);
* g_ort_training_api->ReleaseCheckpointState(state);
* ```
*
* > **Note**
* > The ::OrtCheckpointState contains the entire training state that the ::OrtTrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::OrtCheckpointState instance must outlive the lifetime of the ::OrtTrainingSession instance.
*
* <h1>Training C++ API</h1>
*
* @ref TrainingCpp - Training C++ API classes and functions.
*
* These C++ classes and functions enable users to perform training with onnxruntime.
*
* _Sample Code_:
*
* ```cc
* #include <onnxruntime_training_cxx_api.h>
*
* Ort::Env env;
* Ort::SessionOptions session_options;
*
* auto state = Ort::CheckpointState::LoadCheckpoint(path_to_checkpoint);
* auto training_session = Ort::TrainingSession(env, session_options, state, training_model_path,
* eval_model_path, optimizer_model_path);
*
* // Training Loop
* {
* training_session.TrainStep(...);
* training_session.OptimizerStep(...);
* training_session.LazyResetGrad(...);
* }
*
* training_session->ExportModelForInferencing(inference_model_path, ...);
* Ort::CheckpointState::SaveCheckpoint(state, path_to_checkpoint, false);
* ```
* > **Note**
* > The ::Ort::CheckpointState contains the entire training state that the ::Ort::TrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::Ort::CheckpointState instance must outlive the lifetime of the ::Ort::TrainingSession instance.
*/
/** @defgroup TrainingC Ort Training C API
* @{
*/
ORT_RUNTIME_CLASS(TrainingSession); // Type that enables performing training for the given user models.
ORT_RUNTIME_CLASS(CheckpointState); // Type that holds the training states for the training session.
/** \brief Type of property to be added to or returned from the ::OrtCheckpointState.
*/
typedef enum OrtPropertyType {
OrtIntProperty = 0,
OrtFloatProperty = 1,
OrtStringProperty = 2,
} OrtPropertyType;
/** \brief The Training C API that holds onnxruntime training function pointers
*
* All the Training C API functions are defined inside this structure as pointers to functions.
* Call OrtApi::GetTrainingApi to get a pointer to this struct.
*
* \nosubgrouping
*/
struct OrtTrainingApi {
/// \name Accessing The Training Session State
/// @{
/** \brief Load a checkpoint state from a file on disk into checkpoint_state.
*
* This function will parse a checkpoint file, pull relevant data and load the training
* state into the checkpoint_state. This checkpoint state can then be used to create the
* training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
* session will resume training from the given checkpoint state.
* \note Note that the training session created with a checkpoint state uses this state to store the entire
* training state (including model parameters, its gradients, the optimizer states and the properties).
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
* \note Note that the checkpoint file can be either the complete checkpoint or the nominal checkpoint.
*
* \param[in] checkpoint_path Path to the checkpoint file
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
_Outptr_ OrtCheckpointState** checkpoint_state);
/** \brief Save the given state to a checkpoint file on disk.
*
* This function serializes the provided checkpoint state to a file on disk.
* This checkpoint can later be loaded by invoking OrtTrainingApi::LoadCheckpoint to resume
* training from this snapshot of the state.
*
* \param[in] checkpoint_state The checkpoint state to save.
* \param[in] checkpoint_path Path to the checkpoint file.
* \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(SaveCheckpoint, _In_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* checkpoint_path,
const bool include_optimizer_state);
/// @}
/// \name Implementing The Training Loop
/// @{
/** \brief Create a training session that can be used to begin or resume training.
*
* This function creates a training session based on the env and session options provided that can
* begin or resume training from a given checkpoint state for the given onnx models.
* The checkpoint state represents the parameters of the training session which will be moved
* to the device specified by the user through the session options (if necessary).
* The training session requires four training artifacts
* - The training onnx model
* - The evaluation onnx model (optional)
* - The optimizer onnx model
* - The checkpoint file
*
* These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
*
* \param[in] env Environment to be used for the training session.
* \param[in] options Session options that the user can customize for this training session.
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
* \param[in] train_model_path Model to be used to perform training.
* \param[in] eval_model_path Model to be used to perform evaluation.
* \param[in] optimizer_model_path Model to be used to perform gradient descent.
* \param[out] out Created training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
_Outptr_result_maybenull_ OrtTrainingSession** out);
/** \brief Create a training session that can be used to begin or resume training.
* This api provides a way to load all the training artifacts from buffers instead of files.
*
* \param[in] env Environment to be used for the training session.
* \param[in] options Session options that the user can customize for this training session.
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
* \param[in] train_model_data Buffer containing the model data to be used to perform training
* \param[in] train_data_length Length of the buffer containing train_model_data
* \param[in] eval_model_data Buffer containing the model data to be used to perform evaluation
* \param[in] eval_data_length Length of the buffer containing eval_model_data
* \param[in] optim_model_data Buffer containing the model data to be used to perform weight update
* \param[in] optim_data_length Length of the buffer containing optim_model_data
* \param[out] out Created training session.
*
*/
ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env,
_In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state,
_In_ const void* train_model_data, size_t train_data_length,
_In_ const void* eval_model_data, size_t eval_data_length,
_In_ const void* optim_model_data, size_t optim_data_length,
_Outptr_result_maybenull_ OrtTrainingSession** out);
/// @}
/// \name Model IO Information
/// @{
/** \brief Retrieves the number of user outputs in the training model.
*
* This function returns the number of outputs of the training model so that the user can
* allocate space for the number of outputs when OrtTrainingApi::TrainStep is invoked.
*
* \param[in] sess The `this` pointer to the training session.
* \param[out] out Number of user outputs in the training model.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
/** \brief Retrieves the number of user outputs in the eval model.
*
* This function returns the number of outputs of the eval model so that the user can
* allocate space for the number of outputs when OrtTrainingApi::EvalStep is invoked.
*
* \param[in] sess The `this` pointer to the training session.
* \param[out] out Number of user outputs in the eval model.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetEvalModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
/** \brief Retrieves the names of user outputs in the training model.
*
* This function returns the names of outputs of the training model that can be associated with the OrtValue(s)
* returned by the OrtTrainingApi::TrainStep function.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] index Index of the output name requested.
* \param[in] allocator Allocator to use to allocate the memory for the name.
* \param[out] output Name of the training model output at the given index.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
/** \brief Retrieves the names of user outputs in the eval model.
*
* This function returns the names of outputs of the eval model that can be associated with the OrtValue(s) returned
* by the OrtTrainingApi::EvalStep function.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] index Index of the output name requested.
* \param[in] allocator Allocator to use to allocate the memory for the name.
* \param[out] output Name of the eval model output at the given index.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetEvalModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
/// @}
/// \name Implementing The Training Loop
/// @{
/** \brief Reset the gradients of all trainable parameters to zero lazily.
*
* This function sets the internal state of the training session such that the gradients of the trainable
* parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are
* computed on the next invocation of the next OrtTrainingApi::TrainStep.
*
* \param[in] session The `this` pointer to the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(LazyResetGrad, _Inout_ OrtTrainingSession* session);
/** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs
*
* This function performs a training step that computes the outputs of the training model and the gradients
* of the trainable parameters for the given inputs. The train step is performed based on the training model
* that was provided to the training session.
* The OrtTrainingApi::TrainStep is equivalent of running forward propagation and backward propagation in a single
* step.
* The gradients computed are stored inside the training session state so they can be later consumed
* by the OrtTrainingApi::OptimizerStep function.
* The gradients can be lazily reset by invoking the OrtTrainingApi::LazyResetGrad function.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] run_options Run options for this training step.
* \param[in] inputs_len Number of user inputs to the training model.
* \param[in] inputs The user inputs to the training model.
* \param[in] outputs_len Number of user outputs expected from this training step.
* \param[out] outputs User outputs computed by train step.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
_In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
/** \brief Computes the outputs for the eval model for the given inputs
*
* This function performs an eval step that computes the outputs of the eval model for the given inputs.
* The eval step is performed based on the eval model that was provided to the training session.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] run_options Run options for this eval step.
* \param[in] inputs_len Number of user inputs to the eval model.
* \param[in] inputs The user inputs to the eval model.
* \param[in] outputs_len Number of user outputs expected from this eval step.
* \param[out] outputs User outputs computed by eval step.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(EvalStep, _In_ const OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
_In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
/** \brief Sets the learning rate for this training session.
*
* This function allows users to set the learning rate for the training session. The current
* learning rate is maintained by the training session and can be overwritten by invoking
* this function with the desired learning rate. This function should not be used when a valid
* learning rate scheduler is registered. It should be used either to set the learning rate
* derived from a custom learning rate scheduler or to set a constant learning rate to be used
* throughout the training session.
* \note Please note that this function does not set the initial learning rate that may be needed
* by the predefined learning rate schedulers. To set the initial learning rate for learning
* rate schedulers, please look at the function OrtTrainingApi::RegisterLinearLRScheduler.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] learning_rate Desired learning rate to be set.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate);
/** \brief Gets the current learning rate for this training session.
*
* This function allows users to get the learning rate for the training session. The current
* learning rate is maintained by the training session, and users can query it for the purpose
* of implementing their own learning rate schedulers.
*
* \param[in] sess The `this` pointer to the training session.
* \param[out] learning_rate Learning rate currently in use by the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(GetLearningRate, _Inout_ OrtTrainingSession* sess, _Out_ float* learning_rate);
/** \brief Performs the weight updates for the trainable parameters using the optimizer model.
*
* This function performs the weight update step that updates the trainable parameters such that they
* take a step in the direction of their gradients (gradient descent). The optimizer step is performed
* based on the optimizer model that was provided to the training session.
* The updated parameters are stored inside the training state so that they can be used by the next
* OrtTrainingApi::TrainStep function call.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] run_options Run options for this optimizer step.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
_In_opt_ const OrtRunOptions* run_options);
/** \brief Registers a linear learning rate scheduler for the training session.
*
* Register a linear learning rate scheduler that decays the learning rate by linearly updated
* multiplicative factor from the initial learning rate set on the training session to 0. The decay
* is performed after the initial warm up phase where the learning rate is linearly incremented
* from 0 to the initial learning rate provided.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] warmup_step_count Warmup steps for LR warmup.
* \param[in] total_step_count Total step count.
* \param[in] initial_lr The initial learning rate to be used by the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(RegisterLinearLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ const int64_t warmup_step_count,
_In_ const int64_t total_step_count, _In_ const float initial_lr);
/** \brief Update the learning rate based on the registered learning rate scheduler.
*
* Takes a scheduler step that updates the learning rate that is being used by the training session.
* This function should typically be called before invoking the optimizer step for each round,
* or as determined necessary to update the learning rate being used by the training session.
* \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this
* function.
*
* \param[in] sess The `this` pointer to the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess);
/// @}
/// \name Accessing The Training Session State
/// @{
/** \brief Retrieves the size of all the parameters.
*
* Calculates the total number of primitive (datatype of the parameters) elements of all the parameters in the
* training state.
* When trainable_only argument is true, the size is calculated for trainable params only.
*
* \param[in] sess The `this` pointer to the training session.
* \param[out] out Size of all parameter elements.
* \param[in] trainable_only Whether to skip non-trainable parameters
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(GetParametersSize, _Inout_ OrtTrainingSession* sess, _Out_ size_t* out, bool trainable_only);
/** \brief Copy all parameters to a contiguous buffer held by the argument parameters_buffer
*
* The parameters_buffer has to be of the size given by GetParametersSize api call,
* with matching setting for the argument trainable_only. All the target parameters must be of the same
* datatype. The OrtValue must be pre-allocated onto
* the desired device. This is a complementary function to OrtTrainingApi::CopyBufferToParameters.
* Parameter ordering is preserved.
* User is responsible for allocating and freeing the resources used by the parameters_buffer.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] trainable_only Whether to skip non-trainable parameters
* \param[out] parameters_buffer The pre-allocated OrtValue buffer to copy onto.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(CopyParametersToBuffer, _Inout_ OrtTrainingSession* sess,
_Inout_ OrtValue* parameters_buffer, bool trainable_only);
/** \brief Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state
*
* The parameters_buffer argument has to be of the size given by OrtTrainingApi::GetParametersSize api call,
* with matching setting for trainable_only argument. All the target parameters must be of the same
* datatype. This is a complementary function to OrtTrainingApi::CopyParametersToBuffer
* and can be used to load updated buffer values onto the training state.
* Parameter ordering is preserved.
* User is responsible for allocating and freeing the resources used by the parameters_buffer.
* In case the training session was created with a nominal checkpoint, invoking this function is required
* to load the updated parameters onto the checkpoint to complete it.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] trainable_only Whether to skip non-trainable parameters
* \param[out] parameters_buffer The pre-allocated OrtValue buffer to copy from.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(CopyBufferToParameters, _Inout_ OrtTrainingSession* sess,
_Inout_ OrtValue* parameters_buffer, bool trainable_only);
/// @}
/// \name Release Training Resources
/// @{
/** \brief Frees up the memory used up by the training session.
*
* This function frees up any memory that was allocated in the training session. The training
* session can no longer be used after this call.
*
*/
ORT_CLASS_RELEASE(TrainingSession);
/** \brief Frees up the memory used up by the checkpoint state.
*
* This function frees up any memory that was allocated in the checkpoint state. The checkpoint
* state can no longer be used after this call.
* \note Note that the checkpoint state must be released only after the training session has been released.
*
*/
ORT_CLASS_RELEASE(CheckpointState);
/// @}
/// \name Prepare For Inferencing
/// @{
/** \brief Export a model that can be used for inferencing.
*
* If the training session was provided with an eval model, the training session can generate
* an inference model if it knows the inference graph outputs. The input inference graph outputs
* are used to prune the eval model so that the inference model's outputs align with the provided outputs.
* The exported model is saved at the path provided and can be used for inferencing with InferenceSession.
* \note Note that the function re-loads the eval model from the path provided to OrtTrainingApi::CreateTrainingSession
* and expects that this path still be valid.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] inference_model_path Path where the inference model should be serialized to.
* \param[in] graph_outputs_len Size of the graph output names array.
* \param[in] graph_output_names Names of the outputs that are needed in the inference model.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
_In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
_In_reads_(graph_outputs_len) const char* const* graph_output_names);
/// @}
/// \name Training Utilities
/// @{
/** \brief Sets the seed used for random number generation in Onnxruntime.
*
* Use this function to generate reproducible results. It should be noted that completely reproducible
* results are not guaranteed.
*
* \param[in] seed The seed to be set.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(SetSeed, _In_ const int64_t seed);
/// @}
/// \name Model IO Information
/// @{
/** \brief Retrieves the number of user inputs in the training model.
*
* This function returns the number of inputs of the training model so that the user can accordingly
* allocate the OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
*
* \param[in] sess The `this` pointer to the training session.
* \param[out] out Number of user inputs in the training model.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetTrainingModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
/** \brief Retrieves the number of user inputs in the eval model.
*
* This function returns the number of inputs of the eval model so that the user can accordingly
* allocate the OrtValue(s) provided to the OrtTrainingApi::EvalStep function.
*
* \param[in] sess The `this` pointer to the training session.
* \param[out] out Number of user inputs in the eval model.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetEvalModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
/** \brief Retrieves the name of the user input at given index in the training model.
*
* This function returns the names of inputs of the training model that can be associated with the
* OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] index The index of the training model input name requested.
* \param[in] allocator The allocator to use to allocate the memory for the requested name.
* \param[out] output Name of the user input for the training model at the given index.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetTrainingModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
_In_ OrtAllocator* allocator, _Outptr_ char** output);
/** \brief Retrieves the name of the user input at given index in the eval model.
*
* This function returns the names of inputs of the eval model that can be associated with the OrtValue(s) provided
* to the OrtTrainingApi::EvalStep function.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] index The index of the eval model input name requested.
* \param[in] allocator The allocator to use to allocate the memory for the requested name.
* \param[out] output Name of the user input for the eval model at the given index.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(TrainingSessionGetEvalModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
_In_ OrtAllocator* allocator, _Outptr_ char** output);
/// @}
/// \name Accessing The Training Session State
/// @{
/** \brief Adds or updates the given property to/in the checkpoint state.
*
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
* state by the user by calling this function with the corresponding property name and value.
* The given property name must be unique to be able to successfully add the property.
*
* \param[in] checkpoint_state The checkpoint state which should hold the property.
* \param[in] property_name Name of the property being added or updated.
* \param[in] property_type Type of the property associated with the given name.
* \param[in] property_value Property value associated with the given name.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(AddProperty, _Inout_ OrtCheckpointState* checkpoint_state,
_In_ const char* property_name, _In_ enum OrtPropertyType property_type,
_In_ void* property_value);
/** \brief Gets the property value associated with the given name from the checkpoint state.
*
* Gets the property value from an existing entry in the checkpoint state. The property must
* exist in the checkpoint state to be able to retrieve it successfully.
*
* \param[in] checkpoint_state The checkpoint state that is currently holding the property.
* \param[in] property_name Name of the property being retrieved.
* \param[in] allocator Allocator used to allocate the memory for the property_value.
* \param[out] property_type Type of the property associated with the given name.
* \param[out] property_value Property value associated with the given name.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(GetProperty, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* property_name, _Inout_ OrtAllocator* allocator,
_Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
/// @}
/// \name Accessing The Training Session State
/// @{
/** \brief Load a checkpoint state from a buffer into checkpoint_state.
*
* This function will parse a checkpoint bytes buffer, pull relevant data and load the training
* state into the checkpoint_state. This checkpoint state can then be used to create the
* training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
* session will resume training from the given checkpoint state.
* \note Note that the training session created with a checkpoint state uses this state to store the entire
* training state (including model parameters, its gradients, the optimizer states and the properties).
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
*
* \param[in] checkpoint_buffer The buffer containing the checkpoint bytes.
* \param[in] num_bytes Number of bytes in the checkpoint buffer.
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
/** \brief Retrieves the type and shape information of the parameter associated with the given parameter name.
*
* This function retrieves the type and shape of the parameter associated with the given parameter name.
* The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully.
*
* \param[in] checkpoint_state The checkpoint state.
* \param[in] parameter_name Name of the parameter being retrieved.
* \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
/** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
*
* This function updates a model parameter in the checkpoint state with the given parameter data.
* The training session must be already created with the checkpoint state that contains the parameter
* being updated. The given parameter is copied over to the registered device for the training session.
* The parameter must exist in the checkpoint state to be able to update it successfully.
*
* \param[in] checkpoint_state The checkpoint state.
* \param[in] parameter_name Name of the parameter being updated.
* \param[in] parameter The parameter data that should replace the existing parameter data.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _In_ OrtValue* parameter);
/** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
*
* This function retrieves the model parameter data from the checkpoint state for the given parameter name.
* The parameter is copied over and returned as an OrtValue. The training session must be already created
* with the checkpoint state that contains the parameter being retrieved.
* The parameter must exist in the checkpoint state to be able to retrieve it successfully.
*
* \param[in] checkpoint_state The checkpoint state.
* \param[in] parameter_name Name of the parameter being retrieved.
* \param[in] allocator Allocator used to allocate the memory for the parameter.
* \param[out] parameter The parameter data that is retrieved from the checkpoint state.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
*/
ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
_Outptr_ OrtValue** parameter);
/// @}
};
typedef struct OrtTrainingApi OrtTrainingApi;
/// @}


@@ -1,639 +0,0 @@
package onnxruntime_go
// #cgo CFLAGS: -O2 -g
//
// #include "onnxruntime_wrapper.h"
import "C"
import (
"fmt"
"os"
"path/filepath"
"unsafe"
)
var trainingNotSupportedError error = fmt.Errorf("training not supported by onnx library")
// Scalar is like a tensor but the underlying go slice is of length 1 and it has no dimension.
// It can be used to store e.g. the loss from a training cycle.
type Scalar[T TensorData] struct {
data []T
dataSize uintptr
ortValue *C.OrtValue
}
func (s *Scalar[T]) GetShape() Shape {
return nil
}
func (s *Scalar[T]) ZeroContents() {
C.memset(unsafe.Pointer(&s.data[0]), 0, C.size_t(s.dataSize))
}
func (s *Scalar[T]) Destroy() error {
C.ReleaseOrtValue(s.ortValue)
s.ortValue = nil
s.data = nil
s.dataSize = 0
return nil
}
// GetData returns the underlying data for the scalar.
// If you want to explicitly set the scalar's data, use Set.
func (t *Scalar[T]) GetData() T {
return t.data[0]
}
// Set allows to explicitly set the underlying value for the scalar.
func (t *Scalar[T]) Set(value T) {
t.data = []T{value}
}
func (t *Scalar[T]) DataType() C.ONNXTensorElementDataType {
return GetTensorElementDataType[T]()
}
func (t *Scalar[_]) GetInternals() *ValueInternalData {
return &ValueInternalData{
ortValue: t.ortValue,
}
}
func (t *Scalar[_]) GetONNXType() ONNXType {
return ONNXTypeTensor
}
// NewEmptyScalar creates a new scalar of type T.
func NewEmptyScalar[T TensorData]() (*Scalar[T], error) {
var data T
return NewScalar(data)
}
// NewScalar creates a new scalar of type T backed by a value of type T.
// Note that, differently from tensors, this is not a []T but just a value T.
func NewScalar[T TensorData](data T) (*Scalar[T], error) {
if !IsInitialized() {
return nil, NotInitializedError
}
dataSlice := []T{data}
var ortValue *C.OrtValue
dataType := GetTensorElementDataType[T]()
dataSize := unsafe.Sizeof(dataSlice[0]) * uintptr(1)
status := C.CreateOrtTensorWithShape(unsafe.Pointer(&dataSlice[0]),
C.size_t(dataSize), nil, C.int64_t(0), ortMemoryInfo, dataType, &ortValue)
if status != nil {
return nil, statusToError(status)
}
toReturn := Scalar[T]{
data: dataSlice,
dataSize: dataSize,
ortValue: ortValue,
}
return &toReturn, nil
}
// TrainingSession is the type that wraps the C training session object.
type TrainingSession struct {
ortTrainingSession *C.OrtTrainingSession
ortCheckpointState *C.OrtCheckpointState
inputs []*C.OrtValue
outputs []*C.OrtValue
trainingModelPath *C.char
optimizerModelPath *C.char
evalModelPath *C.char
}
// ExportModel is used to export the final trained model to disk. It requires the path for
// the exported model as well as the names of the graph nodes to export.
// Note that currently the final model can only be exported if the session has been
// initialized with NewTrainingSession and the path to the eval model has been provided.
func (s *TrainingSession) ExportModel(path string, outputNames []string) error {
if s.evalModelPath == nil {
return fmt.Errorf("final model can only be exported if the eval model path is " +
"provided at session creation time (see NewTrainingSession)")
}
if path == "" {
return fmt.Errorf("path cannot be empty")
}
dir, _ := filepath.Split(path)
if _, err := os.Stat(dir); dir != "" && os.IsNotExist(err) {
return fmt.Errorf("directory %s does not exist", dir)
}
cOutputNames := make([]*C.char, len(outputNames))
for i, name := range outputNames {
cOutputNames[i] = C.CString(name)
}
cPath, err := createOrtCharString(path)
if err != nil {
return fmt.Errorf("Error converting export path to C string: %w", err)
}
outputLength := C.size_t(len(outputNames))
defer func() {
for i := range cOutputNames {
C.free(unsafe.Pointer(cOutputNames[i]))
}
C.free(unsafe.Pointer(cPath))
}()
status := C.ExportModel(s.ortTrainingSession, cPath, outputLength, &cOutputNames[0])
if status != nil {
return statusToError(status)
}
return nil
}
// SaveCheckpoint can be used to save the current checkpoint state at the specified path.
// This is useful to snapshot the training parameters to continue training later or on
// a different machine.
func (s *TrainingSession) SaveCheckpoint(path string, saveOptimizerState bool) error {
if path == "" {
return fmt.Errorf("path cannot be empty")
}
dir, _ := filepath.Split(path)
if _, err := os.Stat(dir); dir != "" && os.IsNotExist(err) {
return fmt.Errorf("directory %s does not exist", dir)
}
cPath, err := createOrtCharString(path)
if err != nil {
return fmt.Errorf("Error converting path to C string: %w", err)
}
var saveOptimizer int
if saveOptimizerState {
saveOptimizer = 1
}
defer func() {
C.free(unsafe.Pointer(cPath))
}()
status := C.SaveCheckpoint(s.ortCheckpointState, cPath, C.size_t(saveOptimizer))
if status != nil {
return statusToError(status)
}
return nil
}
// Destroy frees all the C memory associated to a training session.
func (s *TrainingSession) Destroy() error {
if s.ortTrainingSession != nil {
C.ReleaseOrtTrainingSession(s.ortTrainingSession)
s.ortTrainingSession = nil
}
// note: checkpoint MUST be released after session
if s.ortCheckpointState != nil {
C.ReleaseCheckpointState(s.ortCheckpointState)
s.ortCheckpointState = nil
}
C.free(unsafe.Pointer(s.trainingModelPath))
s.trainingModelPath = nil
C.free(unsafe.Pointer(s.evalModelPath))
s.evalModelPath = nil
C.free(unsafe.Pointer(s.optimizerModelPath))
s.optimizerModelPath = nil
s.inputs = nil
s.outputs = nil
return nil
}
// TrainStep performs the training step.
func (s *TrainingSession) TrainStep() error {
inputLength := C.size_t(len(s.inputs))
outputLength := C.size_t(len(s.outputs))
status := C.TrainStep(s.ortTrainingSession, inputLength, &s.inputs[0], outputLength, &s.outputs[0])
if status != nil {
return fmt.Errorf("error performing training step: %w", statusToError(status))
}
return nil
}
// OptimizerStep performs the optimizer step.
func (s *TrainingSession) OptimizerStep() error {
status := C.OptimizerStep(s.ortTrainingSession)
if status != nil {
return fmt.Errorf("error performing optimizer step: %w", statusToError(status))
}
return nil
}
// LazyResetGrad performs the lazy gradient-reset step.
func (s *TrainingSession) LazyResetGrad() error {
status := C.LazyResetGrad(s.ortTrainingSession)
if status != nil {
return fmt.Errorf("error performing lazyResetGrad step: %w", statusToError(status))
}
return nil
}
func getInputName(s *C.OrtTrainingSession, i int, model string) (string, error) {
var cName *C.char
var status *C.OrtStatus
switch model {
case "train":
status = C.TrainingSessionGetTrainingInputName(s, C.size_t(i), &cName)
case "eval":
status = C.TrainingSessionGetEvalInputName(s, C.size_t(i), &cName)
default:
return "", fmt.Errorf("%s model not recognized", model)
}
if status != nil {
return "", fmt.Errorf("error getting name: %w", statusToError(status))
}
name, e := convertORTString(cName)
if e != nil {
return "", fmt.Errorf("error converting C name to Go string: %w", e)
}
return name, nil
}
func getOutputName(s *C.OrtTrainingSession, i int, model string) (string, error) {
var cName *C.char
var status *C.OrtStatus
switch model {
case "train":
status = C.TrainingSessionGetTrainingOutputName(s, C.size_t(i), &cName)
case "eval":
status = C.TrainingSessionGetEvalOutputName(s, C.size_t(i), &cName)
default:
return "", fmt.Errorf("%s model not recognized", model)
}
if status != nil {
return "", fmt.Errorf("error getting name: %w", statusToError(status))
}
name, e := convertORTString(cName)
if e != nil {
return "", fmt.Errorf("error converting C name to Go string: %w", e)
}
return name, nil
}
type TrainingInputOutputNames struct {
TrainingInputNames []string
EvalInputNames []string
TrainingOutputNames []string
EvalOutputNames []string
}
// GetInputOutputNames returns the names of the inputs and outputs for the
// training and eval models. The eval model is optional; its path can be an
// empty string.
func GetInputOutputNames(checkpointStatePath string,
trainingModelPath string,
evalModelPath string) (*TrainingInputOutputNames, error) {
options, e := NewSessionOptions()
if e != nil {
return nil, fmt.Errorf("failed creating options with error: %v\n", e)
}
defer options.Destroy()
checkpointData, e := os.ReadFile(checkpointStatePath)
if e != nil {
return nil, fmt.Errorf("error reading %s: %w", checkpointStatePath, e)
}
trainingData, e := os.ReadFile(trainingModelPath)
if e != nil {
return nil, fmt.Errorf("error reading %s: %w", checkpointStatePath, e)
}
var evalData []byte
if evalModelPath != "" {
evalData, e = os.ReadFile(evalModelPath)
if e != nil {
return nil, fmt.Errorf("error reading %s: %w", evalModelPath, e)
}
}
// create checkpoint C object
ortCheckpointState, e := createCCheckpoint(checkpointData)
if e != nil {
return nil, fmt.Errorf("error creating C checkpointState: %w", e)
}
// create session C object
ortTrainingSession, e := createCTrainingSessionWithOnnxData(ortCheckpointState,
trainingData, evalData, nil, options)
if e != nil {
C.ReleaseCheckpointState(ortCheckpointState)
return nil, fmt.Errorf("error creating C training session: %w", e)
}
defer func() {
C.ReleaseOrtTrainingSession(ortTrainingSession)
C.ReleaseCheckpointState(ortCheckpointState)
}()
var inputCountTraining, inputCountEval C.size_t
status := C.TrainingSessionGetInputCount(ortTrainingSession, &inputCountTraining, &inputCountEval)
if status != nil {
return nil, statusToError(status)
}
var outputCountTraining, outputCountEval C.size_t
status = C.TrainingSessionGetOutputCount(ortTrainingSession, &outputCountTraining, &outputCountEval)
if status != nil {
return nil, statusToError(status)
}
trainInputNames := make([]string, inputCountTraining)
trainOutputNames := make([]string, outputCountTraining)
for i := 0; i < int(inputCountTraining); i++ {
name, err := getInputName(ortTrainingSession, i, "train")
if err != nil {
return nil, fmt.Errorf("error retrieving train input name: %w", err)
}
trainInputNames[i] = name
}
for i := 0; i < int(outputCountTraining); i++ {
name, err := getOutputName(ortTrainingSession, i, "train")
if err != nil {
return nil, fmt.Errorf("error retrieving train output name: %w", err)
}
trainOutputNames[i] = name
}
var evalInputNames []string
var evalOutputNames []string
if len(evalData) > 0 {
evalInputNames = make([]string, inputCountEval)
evalOutputNames = make([]string, outputCountEval)
for i := 0; i < int(inputCountEval); i++ {
name, err := getInputName(ortTrainingSession, i, "eval")
if err != nil {
return nil, fmt.Errorf("error retrieving eval input name: %w", err)
}
evalInputNames[i] = name
}
for i := 0; i < int(outputCountEval); i++ {
name, err := getOutputName(ortTrainingSession, i, "eval")
if err != nil {
return nil, fmt.Errorf("error retrieving eval output name: %w", err)
}
evalOutputNames[i] = name
}
}
return &TrainingInputOutputNames{
TrainingInputNames: trainInputNames,
EvalInputNames: evalInputNames,
TrainingOutputNames: trainOutputNames,
EvalOutputNames: evalOutputNames,
}, nil
}
// IsTrainingSupported returns true if the training api is supported
// by the onnxruntime library.
func IsTrainingSupported() bool {
return C.IsTrainingApiSupported() != 0
}
func checkTraining() error {
if !IsInitialized() {
return NotInitializedError
}
if !IsTrainingSupported() {
return trainingNotSupportedError
}
return nil
}
func createCCheckpoint(onnxData []byte) (*C.OrtCheckpointState, error) {
if e := checkTraining(); e != nil {
return nil, e
}
if len(onnxData) == 0 {
return nil, fmt.Errorf("Missing checkpoint data")
}
var ortCheckpointState *C.OrtCheckpointState
status := C.CreateCheckpoint(unsafe.Pointer(&(onnxData[0])), C.size_t(len(onnxData)), &ortCheckpointState)
if status != nil {
return nil, statusToError(status)
}
return ortCheckpointState, nil
}
// createCTrainingSessionWithOnnxData creates a C session from byte data using buffers
func createCTrainingSessionWithOnnxData(checkpointState *C.OrtCheckpointState,
trainingData, evalData, optimizerData []byte,
options *SessionOptions) (*C.OrtTrainingSession, error) {
if e := checkTraining(); e != nil {
return nil, e
}
var ortTrainingSession *C.OrtTrainingSession
var ortSessionOptions *C.OrtSessionOptions
if options != nil {
ortSessionOptions = options.o
}
// eval model is optional
var evalDataPtr unsafe.Pointer
var evalDataSize C.size_t
if len(evalData) > 0 {
evalDataPtr = unsafe.Pointer(&(evalData[0]))
evalDataSize = C.size_t(len(evalData))
}
// optimizer model is also optional when e.g. getting input and output names
var optimizerDataPtr unsafe.Pointer
var optimizerDataSize C.size_t
if len(optimizerData) > 0 {
optimizerDataPtr = unsafe.Pointer(&(optimizerData[0]))
optimizerDataSize = C.size_t(len(optimizerData))
}
status := C.CreateTrainingSessionFromBuffer(
checkpointState,
unsafe.Pointer(&(trainingData[0])), C.size_t(len(trainingData)),
evalDataPtr, evalDataSize,
optimizerDataPtr, optimizerDataSize,
ortEnv, &ortTrainingSession, ortSessionOptions)
if status != nil {
return nil, statusToError(status)
}
return ortTrainingSession, nil
}
// createCTrainingSessionWithPaths creates a C session from paths
func createCtrainingSessionWithPaths(checkpointState *C.OrtCheckpointState,
trainingPath, evalPath, optimizerPath *C.char,
options *SessionOptions) (*C.OrtTrainingSession, error) {
if e := checkTraining(); e != nil {
return nil, e
}
var ortTrainingSession *C.OrtTrainingSession
var ortSessionOptions *C.OrtSessionOptions
if options != nil {
ortSessionOptions = options.o
}
status := C.CreateTrainingSessionFromPaths(checkpointState,
trainingPath, evalPath, optimizerPath, ortEnv, &ortTrainingSession, ortSessionOptions)
if status != nil {
return nil, statusToError(status)
}
return ortTrainingSession, nil
}
// NewTrainingSessionWithOnnxData is like NewTrainingSession, but it accepts
// bytes rather than paths to the training assets. Note that there does not
// seem to currently be a way to export the trained model from a session
// instantiated from bytes. If you wish to export the trained model, you should
// use NewTrainingSession instead.
func NewTrainingSessionWithOnnxData(checkpointData []byte,
trainingData []byte,
evalData []byte,
optimizerData []byte,
inputs, outputs []Value,
options *SessionOptions) (*TrainingSession, error) {
if err := checkTraining(); err != nil {
return nil, err
}
if err := validateInputOutputs(inputs, outputs); err != nil {
return nil, err
}
if len(trainingData) == 0 {
return nil, fmt.Errorf("training data has length zero.")
}
if len(optimizerData) == 0 {
return nil, fmt.Errorf("optimizer data has length zero.")
}
// create checkpoint C object
ortCheckpointState, e := createCCheckpoint(checkpointData)
if e != nil {
return nil, fmt.Errorf("error creating C checkpointState: %w", e)
}
// create session C object
ortTrainingSession, e := createCTrainingSessionWithOnnxData(ortCheckpointState,
trainingData, evalData, optimizerData, options)
if e != nil {
return nil, fmt.Errorf("error creating C training session: %w", e)
}
inputOrtTensors := make([]*C.OrtValue, len(inputs))
outputOrtTensors := make([]*C.OrtValue, len(outputs))
for i, v := range inputs {
inputOrtTensors[i] = v.GetInternals().ortValue
}
for i, v := range outputs {
outputOrtTensors[i] = v.GetInternals().ortValue
}
return &TrainingSession{
ortCheckpointState: ortCheckpointState,
ortTrainingSession: ortTrainingSession,
inputs: inputOrtTensors,
outputs: outputOrtTensors,
}, nil
}
func validateInputOutputs(inputs, outputs []Value) error {
if len(inputs) == 0 {
return fmt.Errorf("inputs must have length greater than zero")
}
if len(outputs) == 0 {
return fmt.Errorf("outputs must have length greater than zero")
}
return nil
}
// NewTrainingSession creates a new training session from paths stored on disk.
// evalModelPath is optional and may be the empty string. If it is not
// provided, only the checkpoint state (and not the final inference model) can
// be exported once training is complete.
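//
// A minimal usage sketch (paths and tensor variables are illustrative; see
// the training tests for a complete example):
//
//	inputs := []Value{batchInputTensor, batchTargetTensor}
//	outputs := []Value{lossScalar}
//	session, err := NewTrainingSession("artifacts/checkpoint",
//	    "artifacts/training_model.onnx", "artifacts/eval_model.onnx",
//	    "artifacts/optimizer_model.onnx", inputs, outputs, nil)
//	if err != nil {
//	    // Handle the error.
//	}
//	defer session.Destroy()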
func NewTrainingSession(checkpointStatePath string,
trainingModelPath string,
evalModelPath string,
optimizerModelPath string,
inputs, outputs []Value,
options *SessionOptions) (*TrainingSession, error) {
if err := checkTraining(); err != nil {
return nil, err
}
if err := validateInputOutputs(inputs, outputs); err != nil {
return nil, err
}
checkPointContent, e := os.ReadFile(checkpointStatePath)
if e != nil {
return nil, fmt.Errorf("reading checkpoint data failed: %s", e.Error())
}
// create checkpoint C object
ortCheckpointState, e := createCCheckpoint(checkPointContent)
if e != nil {
return nil, fmt.Errorf("error creating C checkpointState: %w", e)
}
// create session C object
if _, err := os.Stat(trainingModelPath); os.IsNotExist(err) {
return nil, fmt.Errorf("training model does not exist at path %s", trainingModelPath)
}
cTrainingPath, err := createOrtCharString(trainingModelPath)
if err != nil {
return nil, fmt.Errorf("Error converting training model path to C string: %w", err)
}
if _, err := os.Stat(optimizerModelPath); os.IsNotExist(err) {
return nil, fmt.Errorf("optimizer s does not exist at path %s", optimizerModelPath)
}
cOptimizerPath, err := createOrtCharString(optimizerModelPath)
if err != nil {
return nil, fmt.Errorf("Error converting optimizer path to C string: %w", err)
}
// eval is optional
var cEvalPath *C.char
if evalModelPath != "" {
if _, err := os.Stat(evalModelPath); os.IsNotExist(err) {
return nil, fmt.Errorf("eval model does not exist at path %s", evalModelPath)
}
cEvalPath, err = createOrtCharString(evalModelPath)
if err != nil {
return nil, fmt.Errorf("Error converting eval path to C string: %w", err)
}
} else {
cEvalPath = nil
}
ortTrainingSession, e := createCTrainingSessionWithPaths(ortCheckpointState,
cTrainingPath, cEvalPath, cOptimizerPath, options)
if e != nil {
return nil, fmt.Errorf("error creating C training session: %w", e)
}
inputOrtTensors := make([]*C.OrtValue, len(inputs))
outputOrtTensors := make([]*C.OrtValue, len(outputs))
for i, v := range inputs {
inputOrtTensors[i] = v.GetInternals().ortValue
}
for i, v := range outputs {
outputOrtTensors[i] = v.GetInternals().ortValue
}
return &TrainingSession{
ortCheckpointState: ortCheckpointState,
ortTrainingSession: ortTrainingSession,
inputs: inputOrtTensors,
outputs: outputOrtTensors,
evalModelPath: cEvalPath,
trainingModelPath: cTrainingPath,
optimizerModelPath: cOptimizerPath,
}, nil
}


@@ -1,353 +0,0 @@
package onnxruntime_go
import (
"errors"
"math"
"math/rand"
"os"
"path"
"testing"
)
func TestTrainingNotSupported(t *testing.T) {
InitializeRuntime(t)
defer CleanupRuntime(t)
if IsTrainingSupported() {
t.Skipf("onnxruntime library supports training")
}
options, e := NewSessionOptions()
if e != nil {
t.Logf("Failed creating options: %s\n", e)
t.FailNow()
}
trainingSession, e := NewTrainingSession("test_data/onnxruntime_training_test/training_artifacts/checkpoint",
"test_data/onnxruntime_training_test/training_artifacts/training_model.onnx",
"test_data/onnxruntime_training_test/training_artifacts/eval_model.onnx",
"test_data/onnxruntime_training_test/training_artifacts/optimizer_model.onnx",
nil, nil,
options)
if !errors.Is(e, trainingNotSupportedError) {
t.Logf("Creating training session when onnxruntime lib does not support it should return training not supported error.")
if e != nil {
t.Logf("Received instead error: %s", e.Error())
} else {
t.Logf("Received no error instead")
}
t.FailNow()
}
if trainingSession != nil {
if err := trainingSession.Destroy(); err != nil {
t.Fatalf("cleanup of training session failed with error: %v", e)
}
}
}
func TestGetInputOutputNames(t *testing.T) {
InitializeRuntime(t)
defer CleanupRuntime(t)
if !IsTrainingSupported() {
t.Skipf("Training is not supported on this platform/onnxruntime build.")
}
artifactsPath := path.Join("test_data", "training_test")
names, err := GetInputOutputNames(
path.Join(artifactsPath, "checkpoint"),
path.Join(artifactsPath, "training_model.onnx"),
path.Join(artifactsPath, "eval_model.onnx"),
)
if err != nil {
t.Fatalf("Failed getting input and output names with error: %v\n", err)
}
expectedTrainInputNames := []string{"input", "target"}
expectedEvalInputNames := expectedTrainInputNames
expectedTrainOutputNames := []string{"onnx::reducemean_output::5"}
expectedEvalOutputNames := expectedTrainOutputNames
for i, v := range names.TrainingInputNames {
if v != expectedTrainInputNames[i] {
t.Fatalf("training input names don't match")
}
}
for i, v := range names.TrainingOutputNames {
if v != expectedTrainOutputNames[i] {
t.Fatalf("training output names don't match")
}
}
for i, v := range names.EvalInputNames {
if v != expectedEvalInputNames[i] {
t.Fatalf("eval input names don't match")
}
}
for i, v := range names.EvalOutputNames {
if v != expectedEvalOutputNames[i] {
t.Fatalf("eval output names don't match")
}
}
// without eval model
names, err = GetInputOutputNames(
path.Join(artifactsPath, "checkpoint"),
path.Join(artifactsPath, "training_model.onnx"),
"",
)
if err != nil {
t.Fatalf("Failed getting input and output names with error: %v\n", err)
}
for i, v := range names.TrainingInputNames {
if v != expectedTrainInputNames[i] {
t.Fatalf("training input names don't match")
}
}
for i, v := range names.TrainingOutputNames {
if v != expectedTrainOutputNames[i] {
t.Fatalf("training output names don't match")
}
}
}
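// generateBatchData deterministically generates nBatches batches of random
// training data. Each batch holds batchSize rows of 4 random input values;
// the corresponding target row contains the sum of those values and their
// range (max - min).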
func generateBatchData(nBatches int, batchSize int) map[int]map[string][]float32 {
batchData := map[int]map[string][]float32{}
source := rand.NewSource(1234)
g := rand.New(source)
for i := 0; i < nBatches; i++ {
inputCounter := 0
outputCounter := 0
inputSlice := make([]float32, batchSize*4)
outputSlice := make([]float32, batchSize*2)
batchData[i] = map[string][]float32{}
// generate random data for batch
for n := 0; n < batchSize; n++ {
var sum float32
min := float32(1)
max := float32(-1)
for i := 0; i < 4; i++ {
r := g.Float32()
inputSlice[inputCounter] = r
inputCounter++
if r > max {
max = r
}
if r < min {
min = r
}
sum = sum + r
}
outputSlice[outputCounter] = sum
outputSlice[outputCounter+1] = max - min
outputCounter = outputCounter + 2
}
batchData[i]["input"] = inputSlice
batchData[i]["output"] = outputSlice
}
return batchData
}
// TestTraining tests a basic training flow using the Go bindings to the C API for on-device onnxruntime training.
func TestTraining(t *testing.T) {
InitializeRuntime(t)
defer CleanupRuntime(t)
if !IsTrainingSupported() {
t.Skipf("Training is not supported on this platform/onnxruntime build.")
}
trainingArtifactsFolder := path.Join("test_data", "training_test")
// generate training data
batchSize := 10
nBatches := 10
// holds inputs/outputs and loss for each training batch
batchInputShape := NewShape(int64(batchSize), 1, 4)
batchTargetShape := NewShape(int64(batchSize), 1, 2)
batchInputTensor, err := NewEmptyTensor[float32](batchInputShape)
if err != nil {
t.Fatalf("training test failed with error: %v", err)
}
batchTargetTensor, err := NewEmptyTensor[float32](batchTargetShape)
if err != nil {
t.Fatalf("training test failed with error: %v", err)
}
lossScalar, err := NewEmptyScalar[float32]()
if err != nil {
t.Fatalf("training test failed with error: %v", err)
}
trainingSession, errorSessionCreation := NewTrainingSession(
path.Join(trainingArtifactsFolder, "checkpoint"),
path.Join(trainingArtifactsFolder, "training_model.onnx"),
path.Join(trainingArtifactsFolder, "eval_model.onnx"),
path.Join(trainingArtifactsFolder, "optimizer_model.onnx"),
[]Value{batchInputTensor, batchTargetTensor}, []Value{lossScalar},
nil)
if errorSessionCreation != nil {
t.Fatalf("session creation failed with error: %v", errorSessionCreation)
}
// cleanup after test run
defer func(session *TrainingSession, tensors []Value) {
var errs []error
errs = append(errs, session.Destroy())
for _, t := range tensors {
errs = append(errs, t.Destroy())
}
if e := errors.Join(errs...); e != nil {
t.Fatalf("cleanup of test failed with error: %v", e)
}
}(trainingSession, []Value{batchInputTensor, batchTargetTensor, lossScalar})
losses := []float32{}
epochs := 100
batchData := generateBatchData(nBatches, batchSize)
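	// Training loop: for each batch, copy the batch data into the tensors
	// bound to the session, call TrainStep to compute the loss and gradients,
	// apply the gradients with OptimizerStep, and reset the gradients before
	// the next batch.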
for epoch := 0; epoch < epochs; epoch++ {
var epochLoss float32 // total epoch loss
for i := 0; i < nBatches; i++ {
inputSlice := batchInputTensor.GetData()
outputSlice := batchTargetTensor.GetData()
copy(inputSlice, batchData[i]["input"])
copy(outputSlice, batchData[i]["output"])
// train on batch
err = trainingSession.TrainStep()
if err != nil {
t.Fatalf("train step failed with error: %v", err)
}
epochLoss = epochLoss + lossScalar.GetData()
err = trainingSession.OptimizerStep()
if err != nil {
t.Fatalf("optimizer step failed with error: %v", err)
}
// ort training api - reset the gradients to zero so that new gradients can be computed for next batch
err = trainingSession.LazyResetGrad()
if err != nil {
t.Fatalf("lazy reset grad step failed with error: %v", err)
}
}
if epoch%10 == 0 {
t.Logf("Epoch {%d} Loss {%f}\n", epoch+1, epochLoss/float32(batchSize*nBatches))
losses = append(losses, epochLoss/float32(batchSize*nBatches))
}
}
expectedLosses := []float32{
0.125085,
0.097187,
0.062333,
0.024307,
0.019963,
0.018476,
0.017160,
0.015982,
0.014845,
0.013867,
}
for i, l := range losses {
diff := math.Abs(float64(l - expectedLosses[i]))
deviation := diff / float64(expectedLosses[i])
if deviation > 0.6 {
t.Fatalf("loss deviation too large: expected %f, actual %f, deviation %f", float64(expectedLosses[i]), float64(l), float64(deviation))
}
}
// test the saving of the checkpoint state
finalCheckpointPath := path.Join("test_data", "training_test", "finalCheckpoint")
errSaveCheckpoint := trainingSession.SaveCheckpoint(finalCheckpointPath, false)
if errSaveCheckpoint != nil {
t.Fatalf("Saving of checkpoint failed with error: %v", errSaveCheckpoint)
}
// test the saving of the model
finalModelPath := path.Join("test_data", "training_test", "final_inference.onnx")
errExport := trainingSession.ExportModel(finalModelPath, []string{"output"})
if errExport != nil {
t.Fatalf("Exporting model failed with error: %v", errExport)
}
defer func() {
e := os.Remove(finalCheckpointPath)
if e != nil {
t.Errorf("Error removing final checkpoint file %s: %s", finalCheckpointPath, e)
}
e = os.Remove(finalModelPath)
if e != nil {
t.Errorf("Error removing final model file %s: %s", finalModelPath, e)
}
}()
// load the model back in and test in-sample predictions for the first batch
// (we care about correctness more than generalization here)
session, err := NewAdvancedSession(path.Join("test_data", "training_test", "final_inference.onnx"),
[]string{"input"}, []string{"output"},
[]Value{batchInputTensor}, []Value{batchTargetTensor}, nil)
if err != nil {
t.Fatalf("creation of inference session failed with error: %v", err)
}
defer func(s *AdvancedSession) {
err := s.Destroy()
if err != nil {
t.Fatalf("cleanup of inference session failed with error: %v", err)
}
}(session)
// Calling Run() will run the network, reading the current contents of the
// input tensors and modifying the contents of the output tensors.
copy(batchInputTensor.GetData(), batchData[0]["input"])
err = session.Run()
if err != nil {
t.Fatalf("run of inference session failed with error: %v", err)
}
expectedOutput := []float32{
2.4524384,
0.65120333,
2.5457804,
0.6102175,
1.6276635,
0.573755,
1.7900972,
0.59951085,
3.1650176,
0.66626525,
1.9361509,
0.571084,
2.0798547,
0.6060241,
0.9611889,
0.52100605,
1.4070896,
0.5412475,
2.1449144,
0.5985652,
}
for i, l := range batchTargetTensor.GetData() {
diff := math.Abs(float64(l - expectedOutput[i]))
deviation := diff / float64(expectedOutput[i])
if deviation > 0.6 {
t.Fatalf("deviation too large")
}
}
}


@@ -402,146 +402,3 @@ OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values,
value_type, out);
}
// TRAINING API WRAPPER
static const OrtTrainingApi *ort_training_api = NULL;
void SetTrainingApi() {
ort_training_api = ort_api->GetTrainingApi(ORT_API_VERSION);
}
int IsTrainingApiSupported() {
return ort_training_api != NULL;
}
OrtStatus *CreateCheckpoint(void *checkpoint_data, size_t checkpoint_data_length, OrtCheckpointState **out) {
OrtStatus *status = NULL;
status = ort_training_api->LoadCheckpointFromBuffer(checkpoint_data, checkpoint_data_length, out);
return status;
}
OrtStatus *CreateTrainingSessionFromBuffer(OrtCheckpointState *checkpoint_state,
void *training_model_data, size_t training_model_data_length,
void *eval_model_data, size_t eval_model_data_length,
void *optim_model_data, size_t optim_model_data_length,
OrtEnv *env, OrtTrainingSession **out, OrtSessionOptions *options) {
OrtStatus *status = NULL;
int default_options = 0;
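 // If the caller didn't supply session options, create temporary default
 // options here and release them once the training session has been created.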
if (!options) {
default_options = 1;
status = ort_api->CreateSessionOptions(&options);
if (status) return status;
}
status = ort_training_api->CreateTrainingSessionFromBuffer(env, options, checkpoint_state,
training_model_data, training_model_data_length, eval_model_data, eval_model_data_length,
optim_model_data, optim_model_data_length, out);
if (default_options) {
ort_api->ReleaseSessionOptions(options);
}
return status;
}
OrtStatus *CreateTrainingSessionFromPaths(OrtCheckpointState *checkpoint_state,
char *training_model_path, char *eval_model_path, char *optim_model_path,
OrtEnv *env, OrtTrainingSession **out, OrtSessionOptions *options) {
OrtStatus *status = NULL;
int default_options = 0;
if (!options) {
default_options = 1;
status = ort_api->CreateSessionOptions(&options);
if (status) return status;
}
status = ort_training_api->CreateTrainingSession(env, options, checkpoint_state,
training_model_path, eval_model_path, optim_model_path, out);
if (default_options) {
ort_api->ReleaseSessionOptions(options);
}
return status;
}
OrtStatus *TrainingSessionGetInputCount(OrtTrainingSession *training_session, size_t *result_training, size_t *result_eval) {
OrtStatus *status = NULL;
status = ort_training_api->TrainingSessionGetTrainingModelInputCount(training_session, result_training);
if (status) return status;
status = ort_training_api->TrainingSessionGetEvalModelInputCount(training_session, result_eval);
return status;
}
OrtStatus *TrainingSessionGetOutputCount(OrtTrainingSession *training_session, size_t *result_training, size_t *result_eval) {
OrtStatus *status = NULL;
status = ort_training_api->TrainingSessionGetTrainingModelOutputCount(training_session, result_training);
if (status) return status;
status = ort_training_api->TrainingSessionGetEvalModelOutputCount(training_session, result_eval);
return status;
}
OrtStatus *TrainingSessionGetTrainingInputName(OrtTrainingSession *training_session, size_t i, char **name) {
OrtAllocator *allocator = NULL;
OrtStatus *status = NULL;
status = ort_api->GetAllocatorWithDefaultOptions(&allocator);
if (status) return status;
return ort_training_api->TrainingSessionGetTrainingModelInputName(training_session, i, allocator, name);
}
OrtStatus *TrainingSessionGetTrainingOutputName(OrtTrainingSession *training_session, size_t i, char **name) {
OrtAllocator *allocator = NULL;
OrtStatus *status = NULL;
status = ort_api->GetAllocatorWithDefaultOptions(&allocator);
if (status) return status;
return ort_training_api->TrainingSessionGetTrainingModelOutputName(training_session, i, allocator, name);
}
OrtStatus *TrainingSessionGetEvalInputName(OrtTrainingSession *training_session, size_t i, char **name) {
OrtAllocator *allocator = NULL;
OrtStatus *status = NULL;
status = ort_api->GetAllocatorWithDefaultOptions(&allocator);
if (status) return status;
return ort_training_api->TrainingSessionGetEvalModelInputName(training_session, i, allocator, name);
}
OrtStatus *TrainingSessionGetEvalOutputName(OrtTrainingSession *training_session, size_t i, char **name) {
OrtAllocator *allocator = NULL;
OrtStatus *status = NULL;
status = ort_api->GetAllocatorWithDefaultOptions(&allocator);
if (status) return status;
return ort_training_api->TrainingSessionGetEvalModelOutputName(training_session, i, allocator, name);
}
OrtStatus *TrainStep(OrtTrainingSession *training_session, size_t inputs_len, OrtValue **inputs, size_t output_len, OrtValue **outputs) {
OrtStatus *status = NULL;
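 // The second argument is the OrtRunOptions pointer; passing NULL makes
 // onnxruntime use default run options.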
status = ort_training_api->TrainStep(training_session, NULL, inputs_len, (const OrtValue* const*) inputs, output_len, outputs);
return status;
}
OrtStatus *OptimizerStep(OrtTrainingSession *training_session) {
OrtStatus *status = NULL;
status = ort_training_api->OptimizerStep(training_session, NULL);
return status;
}
OrtStatus *LazyResetGrad(OrtTrainingSession *training_session) {
OrtStatus *status = NULL;
status = ort_training_api->LazyResetGrad(training_session);
return status;
}
OrtStatus *SaveCheckpoint(OrtCheckpointState *checkpoint, char *path, size_t include_optimizer) {
OrtStatus *status = NULL;
status = ort_training_api->SaveCheckpoint(checkpoint, path, include_optimizer);
return status;
}
OrtStatus *ExportModel(OrtTrainingSession *training_session, char *path, size_t outputs_len, char **output_names) {
OrtStatus *status = NULL;
status = ort_training_api->ExportModelForInferencing(training_session, path, outputs_len, (const char* const*) output_names);
return status;
}
void ReleaseOrtTrainingSession(OrtTrainingSession *session) {
ort_training_api->ReleaseTrainingSession(session);
}
void ReleaseCheckpointState(OrtCheckpointState *checkpoint) {
ort_training_api->ReleaseCheckpointState(checkpoint);
}


@@ -15,7 +15,6 @@
// Next, we actually include the header.
#undef _WIN32
#include "onnxruntime_c_api.h"
#include "onnxruntime_training_c_api.h"
// ... However, mingw will complain if _WIN32 is *not* defined! So redefine it.
#define _WIN32
@@ -254,75 +253,6 @@ OrtStatus *GetValueCount(OrtValue *v, size_t *out);
OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values,
enum ONNXType value_type, OrtValue **out);
// TRAINING API WRAPPER
void SetTrainingApi();
// Checks if training api is supported.
int IsTrainingApiSupported();
// Wraps ort_training_api->CreateSessionFromBuffer.
// Creates and ORT checkpoint from the checkpoint data.
OrtStatus *CreateCheckpoint(void *checkpoint_data,
size_t checkpoint_data_length, OrtCheckpointState **out);
// Wraps ort_training_api->CreateTrainingSessionFromBuffer. Creates an ORT
// training session using the given models and checkpoint. The given options
// pointer may be NULL; if it is, then we'll use default options.
OrtStatus *CreateTrainingSessionFromBuffer(OrtCheckpointState *checkpoint_state,
void *training_model_data, size_t training_model_data_length,
void *eval_model_data, size_t eval_model_data_length,
void *optim_model_data, size_t optim_model_data_length,
OrtEnv *env, OrtTrainingSession **out, OrtSessionOptions *options);
// Wraps ort_training_api->CreateTrainingSession.
// Currently this is the only way to create a training session that is able to
// export the final trained model to disk.
OrtStatus *CreateTrainingSessionFromPaths(OrtCheckpointState *checkpoint_state,
char *training_model_path, char *eval_model_path, char *optim_model_path,
OrtEnv *env, OrtTrainingSession **out, OrtSessionOptions *options);
// Wraps ort_training_api->TrainingSessionGetTrainingModelInputCount
// and ort_training_api->TrainingSessionGetEvalgModelInputCount.
OrtStatus *TrainingSessionGetInputCount(OrtTrainingSession *training_session, size_t *result_training, size_t *result_eval);
// Wraps ort_training_api->TrainingSessionGetTrainingModelOutputCounet
// and ort_training_api->TrainingSessionGetEvalgModelOutputCount.
OrtStatus *TrainingSessionGetOutputCount(OrtTrainingSession *training_session, size_t *result_training, size_t *result_eval);
// Wraps ort_training_api->TrainingSessionGetTrainingModelInputName.
OrtStatus *TrainingSessionGetTrainingInputName(OrtTrainingSession *training_session, size_t i, char **name);
// Wraps ort_training_api->TrainingSessionGetEvalModelInputName.
OrtStatus *TrainingSessionGetEvalInputName(OrtTrainingSession *training_session, size_t i, char **name);
// Wraps ort_training_api->TrainingSessionGetTrainingModelOutputName.
OrtStatus *TrainingSessionGetTrainingOutputName(OrtTrainingSession *training_session, size_t i, char **name);
// Wraps ort_training_api->TrainingSessionGetEvalModelOutputName.
OrtStatus *TrainingSessionGetEvalOutputName(OrtTrainingSession *training_session, size_t i, char **name);
// Wraps ort_training_api->TrainStep.
OrtStatus *TrainStep(OrtTrainingSession *training_session, size_t inputs_len, OrtValue **inputs, size_t output_len, OrtValue **outputs);
// Wraps ort_training_api->OptimizerStep.
OrtStatus *OptimizerStep(OrtTrainingSession *training_session);
// Wraps ort_training_api->LazyResetGrad.
OrtStatus *LazyResetGrad(OrtTrainingSession *training_session);
// Wraps ort_training_api->SaveCheckpoint.
OrtStatus *SaveCheckpoint(OrtCheckpointState *checkpoint, char *path, size_t include_optimizer);
// Wraps ort_training_api->ExportModel.
OrtStatus *ExportModel(OrtTrainingSession *training_session, char *path, size_t outputs_len, char **output_names);
// Wraps ort_training_api->ReleaseTrainingSession.
void ReleaseOrtTrainingSession(OrtTrainingSession *session);
// Wraps ort_training_api->ReleaseCheckpointState.
void ReleaseCheckpointState(OrtCheckpointState *checkpoint);
#ifdef __cplusplus
} // extern "C"
#endif


@@ -56,10 +56,6 @@ func platformInitializeEnvironment() error {
return fmt.Errorf("Error setting ORT API base: %d", tmp)
}
// we do not initialize the training API on windows (see setup_env.go)
// because currently we cannot support the conversion from UTF-8 to wide
// character. See https://github.com/yalue/onnxruntime_go/pull/56.
libraryHandle = handle
return nil
}