mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-09-26 19:31:13 +08:00
Remove training API
- The training API has been deprecated by onnxruntime itself, and it will be much easier to remove it rather than deprecate it. - The training API wrapper functions have been replaced by stubs that return errors in legacy_types.go. - The README mentions the old version required for the training API. - The Scalar type has been promoted to onnxruntime_go.go, and a test has been added for it.
This commit is contained in:
15
README.md
15
README.md
@@ -206,10 +206,13 @@ run and pass.
|
||||
Training API Support
|
||||
--------------------
|
||||
|
||||
This wrapper supports the onnxruntime training API on limited platforms. See
|
||||
the `NewTrainingSession` and associated data types or functions to use it. So
|
||||
far, the training API has only been tested on Linux, on `x86_64` architectures.
|
||||
The training API has been deprecated as of onnxruntime version 1.20. Rather
|
||||
than continuing to maintain wrappers for a deprecated API, `onnxruntime_go` has
|
||||
replaced the wrapper functions for the training API with stubs that return an
|
||||
error. Users who need to continue to use the training API will need to use an
|
||||
older version. For example, the following versions should be compatible with
|
||||
training:
|
||||
|
||||
- Version `v1.12.1` of `onnxruntime_go`, and
|
||||
- Version 1.19.2 of `onnxruntime`.
|
||||
|
||||
If you are not sure whether your platform or build of onnxruntime supports
|
||||
training, you can call `onnxruntime_go.IsTrainingSupported()`, which will
|
||||
return `true` if training is supported on your system.
|
||||
|
@@ -171,3 +171,77 @@ type ArbitraryTensor = Value
|
||||
// As with the ArbitraryTensor type, this type alias only exists to facilitate
|
||||
// renaming an old type without breaking existing code.
|
||||
type TensorInternalData = ValueInternalData
|
||||
|
||||
// TrainingAPIRemovedError is the error returned by every training-API stub in
// this package. The training API was deprecated as of onnxruntime 1.20, so the
// wrappers were removed rather than maintained. Users who still need training
// must use onnxruntime_go v1.12.1 together with onnxruntime 1.19.2, the last
// onnxruntime release that predates the deprecation.
var TrainingAPIRemovedError error = fmt.Errorf("Support for the training " +
	"API has been removed from onnxruntime_go following its deprecation in " +
	"onnxruntime versions 1.20 and later. The last revision of " +
	"onnxruntime_go supporting the training API is version v1.12.1, which " +
	"should be used with onnxruntime version 1.19.2")
|
||||
|
||||
// TrainingSession is an empty placeholder type. Support for training sessions
// has been removed from onnxruntime_go following the deprecation of the
// training API in onnxruntime 1.20.0. Every method below is a stub that
// ignores its arguments and returns TrainingAPIRemovedError.
|
||||
|
||||
type TrainingSession struct{}
|
||||
|
||||
// ExportModel is a stub: it ignores path and outputNames and always returns
// TrainingAPIRemovedError.
|
||||
func (s *TrainingSession) ExportModel(path string, outputNames []string) error {
|
||||
return TrainingAPIRemovedError
|
||||
}
|
||||
|
||||
// SaveCheckpoint is a stub: it ignores path and saveOptimizerState and always
// returns TrainingAPIRemovedError.
|
||||
func (s *TrainingSession) SaveCheckpoint(path string,
|
||||
saveOptimizerState bool) error {
|
||||
return TrainingAPIRemovedError
|
||||
}
|
||||
|
||||
// Destroy is a stub: there is nothing to release, and it always returns
// TrainingAPIRemovedError rather than nil.
|
||||
func (s *TrainingSession) Destroy() error {
|
||||
return TrainingAPIRemovedError
|
||||
}
|
||||
|
||||
// TrainStep is a stub: it always returns TrainingAPIRemovedError.
|
||||
func (s *TrainingSession) TrainStep() error {
|
||||
return TrainingAPIRemovedError
|
||||
}
|
||||
|
||||
// OptimizerStep is a stub: it always returns TrainingAPIRemovedError.
|
||||
func (s *TrainingSession) OptimizerStep() error {
|
||||
return TrainingAPIRemovedError
|
||||
}
|
||||
|
||||
// LazyResetGrad is a stub: it always returns TrainingAPIRemovedError.
|
||||
func (s *TrainingSession) LazyResetGrad() error {
|
||||
return TrainingAPIRemovedError
|
||||
}
|
||||
|
||||
// TrainingInputOutputNames holds four lists of names: the inputs and outputs
// of the training model and of the eval model. Support for it has been removed
// from onnxruntime_go following the deprecation of the training API in
// onnxruntime 1.20.0; the type remains as the declared return type of
// GetInputOutputNames, which now always returns nil.
|
||||
|
||||
type TrainingInputOutputNames struct {
|
||||
TrainingInputNames []string
|
||||
EvalInputNames []string
|
||||
TrainingOutputNames []string
|
||||
EvalOutputNames []string
|
||||
}
|
||||
|
||||
// GetInputOutputNames is a stub: it ignores all three paths and always returns
// (nil, TrainingAPIRemovedError).
|
||||
func GetInputOutputNames(checkpointStatePath string, trainingModelPath string,
|
||||
evalModelPath string) (*TrainingInputOutputNames, error) {
|
||||
return nil, TrainingAPIRemovedError
|
||||
}
|
||||
|
||||
// IsTrainingSupported always returns false; it performs no check of the
// platform or of the loaded onnxruntime library.
|
||||
func IsTrainingSupported() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// NewTrainingSessionWithOnnxData is a stub: it ignores every argument, creates
// nothing, and always returns (nil, TrainingAPIRemovedError).
|
||||
func NewTrainingSessionWithOnnxData(checkpointData, trainingData, evalData,
|
||||
optimizerData []byte, inputs, outputs []Value,
|
||||
options *SessionOptions) (*TrainingSession, error) {
|
||||
return nil, TrainingAPIRemovedError
|
||||
}
|
||||
|
||||
// NewTrainingSession is a stub: it ignores every argument, creates nothing,
// and always returns (nil, TrainingAPIRemovedError).
|
||||
func NewTrainingSession(checkpointStatePath, trainingModelPath, evalModelPath,
|
||||
optimizerModelPath string, inputs, outputs []Value,
|
||||
options *SessionOptions) (*TrainingSession, error) {
|
||||
return nil, TrainingAPIRemovedError
|
||||
}
|
||||
|
@@ -90,9 +90,6 @@ func InitializeEnvironment() error {
|
||||
return fmt.Errorf("Platform-specific initialization failed: %w", e)
|
||||
}
|
||||
|
||||
// Get the training API pointer if it is supported.
|
||||
C.SetTrainingApi()
|
||||
|
||||
name := C.CString("Golang onnxruntime environment")
|
||||
defer C.free(unsafe.Pointer(name))
|
||||
status := C.CreateOrtEnv(name, &ortEnv)
|
||||
@@ -883,6 +880,90 @@ func (t *CustomDataTensor) GetData() []byte {
|
||||
return t.data
|
||||
}
|
||||
|
||||
// Scalar is like a tensor but the underlying go slice is of length 1 and it
|
||||
// has no dimension. It was introduced for use with the training API, but
|
||||
// remains supported since it conceivable will have use outside of the training
|
||||
// API.
|
||||
type Scalar[T TensorData] struct {
|
||||
data []T
|
||||
dataSize uintptr
|
||||
ortValue *C.OrtValue
|
||||
}
|
||||
|
||||
// Always returns nil for Scalars.
|
||||
func (s *Scalar[T]) GetShape() Shape {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Scalar[T]) ZeroContents() {
|
||||
C.memset(unsafe.Pointer(&s.data[0]), 0, C.size_t(s.dataSize))
|
||||
}
|
||||
|
||||
func (s *Scalar[T]) Destroy() error {
|
||||
C.ReleaseOrtValue(s.ortValue)
|
||||
s.ortValue = nil
|
||||
s.data = nil
|
||||
s.dataSize = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetData returns the undelying data for the scalar. If you want to set the
|
||||
// scalar's data, use Set.
|
||||
func (t *Scalar[T]) GetData() T {
|
||||
return t.data[0]
|
||||
}
|
||||
|
||||
// Changes the underlying value of the scalar to the new value.
|
||||
func (t *Scalar[T]) Set(value T) {
|
||||
t.data = []T{value}
|
||||
}
|
||||
|
||||
func (t *Scalar[T]) DataType() C.ONNXTensorElementDataType {
|
||||
return GetTensorElementDataType[T]()
|
||||
}
|
||||
|
||||
func (t *Scalar[_]) GetInternals() *ValueInternalData {
|
||||
return &ValueInternalData{
|
||||
ortValue: t.ortValue,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Scalar[_]) GetONNXType() ONNXType {
|
||||
return ONNXTypeTensor
|
||||
}
|
||||
|
||||
// NewEmptyScalar creates a new scalar of type T.
|
||||
func NewEmptyScalar[T TensorData]() (*Scalar[T], error) {
|
||||
var data T
|
||||
return NewScalar(data)
|
||||
}
|
||||
|
||||
// NewScalar creates a new scalar of type T backed by a value of type T.
|
||||
// Note that, differently from tensors, this is not a []T but just a value T.
|
||||
func NewScalar[T TensorData](data T) (*Scalar[T], error) {
|
||||
if !IsInitialized() {
|
||||
return nil, NotInitializedError
|
||||
}
|
||||
|
||||
dataSlice := []T{data}
|
||||
var ortValue *C.OrtValue
|
||||
dataType := GetTensorElementDataType[T]()
|
||||
dataSize := unsafe.Sizeof(dataSlice[0]) * uintptr(1)
|
||||
|
||||
status := C.CreateOrtTensorWithShape(unsafe.Pointer(&dataSlice[0]),
|
||||
C.size_t(dataSize), nil, C.int64_t(0), ortMemoryInfo, dataType, &ortValue)
|
||||
if status != nil {
|
||||
return nil, statusToError(status)
|
||||
}
|
||||
toReturn := Scalar[T]{
|
||||
data: dataSlice,
|
||||
dataSize: dataSize,
|
||||
ortValue: ortValue,
|
||||
}
|
||||
return &toReturn, nil
|
||||
}
|
||||
|
||||
|
||||
// Holds options required when enabling the CUDA backend for a session. This
|
||||
// struct wraps C onnxruntime types; users must create instances of this using
|
||||
// the NewCUDAProviderOptions() function. So, to enable CUDA for a session,
|
||||
|
@@ -1461,6 +1461,41 @@ func TestSessionFromDataBuffer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestScalar(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
s, e := NewEmptyScalar[float32]()
|
||||
if e != nil {
|
||||
t.Fatalf("Error creating empty scalar: %s\n", e)
|
||||
}
|
||||
if s.GetData() != 0.0 {
|
||||
t.Fatalf("Empty scalar not initialized to 0: %s\n", e)
|
||||
}
|
||||
e = s.Destroy()
|
||||
if e != nil {
|
||||
t.Fatalf("Failed destroying scalar: %s\n", e)
|
||||
}
|
||||
s2, e := NewScalar(int64(1337))
|
||||
if e != nil {
|
||||
t.Fatalf("Failed creating int64 scalar: %s\n", e)
|
||||
}
|
||||
defer s2.Destroy()
|
||||
contents := s2.GetData()
|
||||
if contents != 1337 {
|
||||
t.Fatalf("Incorrect initial contents of s2: %d\n", contents)
|
||||
}
|
||||
s2.ZeroContents()
|
||||
contents = s2.GetData()
|
||||
if contents != 0 {
|
||||
t.Fatalf("Incorrect value of s2 after zeroing: %d\n", contents)
|
||||
}
|
||||
s2.Set(1234)
|
||||
contents = s2.GetData()
|
||||
if contents != 1234 {
|
||||
t.Fatalf("Incorrect value of s2: %d (expected 1234)\n", contents)
|
||||
}
|
||||
}
|
||||
|
||||
// See the comment in generate_network_big_compute.py for information about
|
||||
// the inputs and outputs used for testing or benchmarking session options.
|
||||
func prepareBenchmarkTensors(t testing.TB, seed int64) (*Tensor[float32],
|
||||
|
731
onnxruntime_training_c_api.h
vendored
731
onnxruntime_training_c_api.h
vendored
@@ -1,731 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// This file contains the training c apis.
|
||||
|
||||
#pragma once
|
||||
#include <stdbool.h>
|
||||
#include "onnxruntime_c_api.h"
|
||||
|
||||
/** \page training_c_cpp_api Training C & C++ APIs
|
||||
*
|
||||
* Training C and C++ APIs are an extension of the \ref c_cpp_api "onnxruntime core C and C++ APIs" and should be used in conjunction with them.
|
||||
*
|
||||
* In order to train a model with onnxruntime, the following training artifacts must be generated:
|
||||
* - The training onnx model
|
||||
* - The checkpoint file
|
||||
* - The optimizer onnx model
|
||||
* - The eval onnx model (optional)
|
||||
*
|
||||
* These training artifacts can be generated as part of an offline step using the python [utilities](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md) made available in the `onnxruntime-training` python package.
|
||||
*
|
||||
* After these artifacts have been generated, the C and C++ utilities listed in this documentation can be leveraged to perform training.
|
||||
*
|
||||
* If any problem is encountered, please create an [issue](https://github.com/microsoft/onnxruntime/issues/new) with your scenario and requirements, and we will be sure to respond and follow up on the request.
|
||||
*
|
||||
* <h1>Training C API</h1>
|
||||
*
|
||||
* ::OrtTrainingApi - Training C API functions.
|
||||
*
|
||||
* This C structure contains functions that enable users to perform training with onnxruntime.
|
||||
*
|
||||
* _Sample Code_:
|
||||
*
|
||||
* ```c
|
||||
* #include <onnxruntime_training_api.h>
|
||||
*
|
||||
* OrtApi* g_ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||
* OrtTrainingApi* g_ort_training_api = g_ort_api->GetTrainingApi(ORT_API_VERSION);
|
||||
*
|
||||
* OrtEnv* env = NULL;
|
||||
* g_ort_api->CreateEnv(logging_level, logid, &env);
|
||||
* OrtSessionOptions* session_options = NULL;
|
||||
* g_ort_api->CreateSessionOptions(&session_options);
|
||||
*
|
||||
* OrtCheckpointState* state = NULL;
|
||||
* g_ort_training_api->LoadCheckpoint(path_to_checkpoint, &state);
|
||||
*
|
||||
* OrtTrainingSession* training_session = NULL;
|
||||
* g_ort_training_api->CreateTrainingSession(env, session_options, training_model_path,
|
||||
* state, eval_model_path, optimizer_model_path,
|
||||
* &training_session);
|
||||
* // Training loop
|
||||
* {
|
||||
* g_ort_training_api->TrainStep(...);
|
||||
* g_ort_training_api->OptimizerStep(...);
|
||||
* g_ort_training_api->LazyResetGrad(...);
|
||||
* }
|
||||
*
|
||||
* g_ort_training_api->ExportModelForInferencing(training_session, inference_model_path, ...);
|
||||
* g_ort_training_api->SaveCheckpoint(state, path_to_checkpoint, false);
|
||||
*
|
||||
* g_ort_training_api->ReleaseTrainingSession(training_session);
|
||||
* g_ort_training_api->ReleaseCheckpointState(state);
|
||||
* ```
|
||||
*
|
||||
* > **Note**
|
||||
* > The ::OrtCheckpointState contains the entire training state that the ::OrtTrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::OrtCheckpointState instance must outlive the lifetime of the ::OrtTrainingSession instance.
|
||||
*
|
||||
* <h1>Training C++ API</h1>
|
||||
*
|
||||
* @ref TrainingCpp - Training C++ API classes and functions.
|
||||
*
|
||||
* These C++ classes and functions enable users to perform training with onnxruntime.
|
||||
*
|
||||
* _Sample Code_:
|
||||
*
|
||||
* ```cc
|
||||
* #include <onnxruntime_training_cxx_api.h>
|
||||
*
|
||||
* Ort::Env env;
|
||||
* Ort::SessionOptions session_options;
|
||||
*
|
||||
* auto state = Ort::CheckpointState::LoadCheckpoint(path_to_checkpoint);
|
||||
* auto training_session = Ort::TrainingSession(env, session_options, state, training_model_path,
|
||||
* eval_model_path, optimizer_model_path);
|
||||
*
|
||||
* // Training Loop
|
||||
* {
|
||||
* training_session.TrainStep(...);
|
||||
* training_session.OptimizerStep(...);
|
||||
* training_session.LazyResetGrad(...);
|
||||
* }
|
||||
*
|
||||
* training_session->ExportModelForInferencing(inference_model_path, ...);
|
||||
* Ort::CheckpointState::SaveCheckpoint(state, path_to_checkpoint, false);
|
||||
* ```
|
||||
* > **Note**
|
||||
* > The ::Ort::CheckpointState contains the entire training state that the ::Ort::TrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::Ort::CheckpointState instance must outlive the lifetime of the ::Ort::TrainingSession instance.
|
||||
*/
|
||||
|
||||
/** @defgroup TrainingC Ort Training C API
|
||||
* @{
|
||||
*/
|
||||
ORT_RUNTIME_CLASS(TrainingSession); // Type that enables performing training for the given user models.
|
||||
ORT_RUNTIME_CLASS(CheckpointState); // Type that holds the training states for the training session.
|
||||
|
||||
/** \brief Type of property to be added to or returned from the ::OrtCheckpointState.
 *
 * NOTE(review): the per-member notes below are inferred from the enumerator
 * names only; confirm against the property accessors of OrtTrainingApi.
|
||||
*/
|
||||
typedef enum OrtPropertyType {
|
||||
OrtIntProperty = 0, /* presumably an integer-valued property - confirm */
|
||||
OrtFloatProperty = 1, /* presumably a float-valued property - confirm */
|
||||
OrtStringProperty = 2, /* presumably a string-valued property - confirm */
|
||||
} OrtPropertyType;
|
||||
|
||||
/** \brief The Training C API that holds onnxruntime training function pointers
|
||||
*
|
||||
* All the Training C API functions are defined inside this structure as pointers to functions.
|
||||
* Call OrtApi::GetTrainingApi to get a pointer to this struct.
|
||||
*
|
||||
* \nosubgrouping
|
||||
*/
|
||||
struct OrtTrainingApi {
|
||||
/// \name Accessing The Training Session State
|
||||
/// @{
|
||||
|
||||
/** \brief Load a checkpoint state from a file on disk into checkpoint_state.
|
||||
*
|
||||
* This function will parse a checkpoint file, pull relevant data and load the training
|
||||
* state into the checkpoint_state. This checkpoint state can then be used to create the
|
||||
* training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
|
||||
* session will resume training from the given checkpoint state.
|
||||
* \note Note that the training session created with a checkpoint state uses this state to store the entire
|
||||
* training state (including model parameters, its gradients, the optimizer states and the properties).
|
||||
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
|
||||
* \note Note that the checkpoint file can be either the complete checkpoint or the nominal checkpoint.
|
||||
*
|
||||
* \param[in] checkpoint_path Path to the checkpoint file
|
||||
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path,
|
||||
_Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
|
||||
/** \brief Save the given state to a checkpoint file on disk.
|
||||
*
|
||||
* This function serializes the provided checkpoint state to a file on disk.
|
||||
* This checkpoint can later be loaded by invoking OrtTrainingApi::LoadCheckpoint to resume
|
||||
* training from this snapshot of the state.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state to save.
|
||||
* \param[in] checkpoint_path Path to the checkpoint file.
|
||||
* \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(SaveCheckpoint, _In_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* checkpoint_path,
|
||||
const bool include_optimizer_state);
|
||||
|
||||
/// @}
|
||||
|
||||
/// \name Implementing The Training Loop
|
||||
/// @{
|
||||
/** \brief Create a training session that can be used to begin or resume training.
|
||||
*
|
||||
* This function creates a training session based on the env and session options provided that can
|
||||
* begin or resume training from a given checkpoint state for the given onnx models.
|
||||
* The checkpoint state represents the parameters of the training session which will be moved
|
||||
* to the device specified by the user through the session options (if necessary).
|
||||
* The training session requires four training artifacts
|
||||
* - The training onnx model
|
||||
* - The evaluation onnx model (optional)
|
||||
* - The optimizer onnx model
|
||||
* - The checkpoint file
|
||||
*
|
||||
* These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md).
|
||||
*
|
||||
* \param[in] env Environment to be used for the training session.
|
||||
* \param[in] options Session options that the user can customize for this training session.
|
||||
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
|
||||
* \param[in] train_model_path Model to be used to perform training.
|
||||
* \param[in] eval_model_path Model to be used to perform evaluation.
|
||||
* \param[in] optimizer_model_path Model to be used to perform gradient descent.
|
||||
* \param[out] out Created training session.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options,
|
||||
_Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path,
|
||||
_In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path,
|
||||
_Outptr_result_maybenull_ OrtTrainingSession** out);
|
||||
|
||||
/** \brief Create a training session that can be used to begin or resume training.
|
||||
* This api provides a way to load all the training artifacts from buffers instead of files.
|
||||
*
|
||||
* \param[in] env Environment to be used for the training session.
|
||||
* \param[in] options Session options that the user can customize for this training session.
|
||||
* \param[in] checkpoint_state Training states that the training session uses as a starting point for training.
|
||||
* \param[in] train_model_data Buffer containing the model data to be used to perform training
|
||||
* \param[in] train_data_length Length of the buffer containing train_model_data
|
||||
* \param[in] eval_model_data Buffer containing the model data to be used to perform evaluation
|
||||
* \param[in] eval_data_length Length of the buffer containing eval_model_data
|
||||
* \param[in] optim_model_data Buffer containing the model data to be used to perform weight update
|
||||
* \param[in] optim_data_length Length of the buffer containing optim_model_data
|
||||
* \param[out] out Created training session.
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env,
|
||||
_In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state,
|
||||
_In_ const void* train_model_data, size_t train_data_length,
|
||||
_In_ const void* eval_model_data, size_t eval_data_length,
|
||||
_In_ const void* optim_model_data, size_t optim_data_length,
|
||||
_Outptr_result_maybenull_ OrtTrainingSession** out);
|
||||
|
||||
/// @}
|
||||
|
||||
/// \name Model IO Information
|
||||
/// @{
|
||||
|
||||
/** \brief Retrieves the number of user outputs in the training model.
|
||||
*
|
||||
* This function returns the number of outputs of the training model so that the user can
|
||||
* allocate space for the number of outputs when OrtTrainingApi::TrainStep is invoked.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[out] out Number of user outputs in the training model.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||
|
||||
/** \brief Retrieves the number of user outputs in the eval model.
|
||||
*
|
||||
* This function returns the number of outputs of the eval model so that the user can
|
||||
* allocate space for the number of outputs when OrtTrainingApi::EvalStep is invoked.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[out] out Number of user outputs in the eval model.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(TrainingSessionGetEvalModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||
|
||||
/** \brief Retrieves the names of user outputs in the training model.
|
||||
*
|
||||
* This function returns the names of outputs of the training model that can be associated with the OrtValue(s)
|
||||
* returned by the OrtTrainingApi::TrainStep function.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[in] index Index of the output name requested.
|
||||
* \param[in] allocator Allocator to use to allocate the memory for the name.
|
||||
* \param[out] output Name of the training model output at the given index.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
|
||||
|
||||
/** \brief Retrieves the names of user outputs in the eval model.
|
||||
*
|
||||
* This function returns the names of outputs of the eval model that can be associated with the OrtValue(s) returned
|
||||
* by the OrtTrainingApi::EvalStep function.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[in] index Index of the output name requested.
|
||||
* \param[in] allocator Allocator to use to allocate the memory for the name.
|
||||
* \param[out] output Name of the eval model output at the given index.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(TrainingSessionGetEvalModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output);
|
||||
|
||||
/// @}
|
||||
|
||||
/// \name Implementing The Training Loop
|
||||
/// @{
|
||||
|
||||
/** \brief Reset the gradients of all trainable parameters to zero lazily.
|
||||
*
|
||||
* This function sets the internal state of the training session such that the gradients of the trainable
|
||||
* parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are
|
||||
* computed on the next invocation of the next OrtTrainingApi::TrainStep.
|
||||
*
|
||||
* \param[in] session The `this` pointer to the training session.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(LazyResetGrad, _Inout_ OrtTrainingSession* session);
|
||||
|
||||
/** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs
|
||||
*
|
||||
* This function performs a training step that computes the outputs of the training model and the gradients
|
||||
* of the trainable parameters for the given inputs. The train step is performed based on the training model
|
||||
* that was provided to the training session.
|
||||
* The OrtTrainingApi::TrainStep is equivalent of running forward propagation and backward propagation in a single
|
||||
* step.
|
||||
* The gradients computed are stored inside the training session state so they can be later consumed
|
||||
* by the OrtTrainingApi::OptimizerStep function.
|
||||
* The gradients can be lazily reset by invoking the OrtTrainingApi::LazyResetGrad function.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[in] run_options Run options for this training step.
|
||||
* \param[in] inputs_len Number of user inputs to the training model.
|
||||
* \param[in] inputs The user inputs to the training model.
|
||||
* \param[in] outputs_len Number of user outputs expected from this training step.
|
||||
* \param[out] outputs User outputs computed by train step.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
||||
_In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
|
||||
_In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
||||
|
||||
/** \brief Computes the outputs for the eval model for the given inputs
|
||||
*
|
||||
* This function performs an eval step that computes the outputs of the eval model for the given inputs.
|
||||
* The eval step is performed based on the eval model that was provided to the training session.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[in] run_options Run options for this eval step.
|
||||
* \param[in] inputs_len Number of user inputs to the eval model.
|
||||
* \param[in] inputs The user inputs to the eval model.
|
||||
* \param[in] outputs_len Number of user outputs expected from this eval step.
|
||||
* \param[out] outputs User outputs computed by eval step.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(EvalStep, _In_ const OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options,
|
||||
_In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs,
|
||||
_In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs);
|
||||
|
||||
/** \brief Sets the learning rate for this training session.
|
||||
*
|
||||
* This function allows users to set the learning rate for the training session. The current
|
||||
* learning rate is maintained by the training session and can be overwritten by invoking
|
||||
* this function with the desired learning rate. This function should not be used when a valid
|
||||
* learning rate scheduler is registered. It should be used either to set the learning rate
|
||||
* derived from a custom learning rate scheduler or to set a constant learning rate to be used
|
||||
* throughout the training session.
|
||||
* \note Please note that this function does not set the initial learning rate that may be needed
|
||||
* by the predefined learning rate schedulers. To set the initial learning rate for learning
|
||||
* rate schedulers, please look at the function OrtTrainingApi::RegisterLinearLRScheduler.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[in] learning_rate Desired learning rate to be set.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate);
|
||||
|
||||
/** \brief Gets the current learning rate for this training session.
|
||||
*
|
||||
* This function allows users to get the learning rate for the training session. The current
|
||||
* learning rate is maintained by the training session, and users can query it for the purpose
|
||||
* of implementing their own learning rate schedulers.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[out] learning_rate Learning rate currently in use by the training session.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(GetLearningRate, _Inout_ OrtTrainingSession* sess, _Out_ float* learning_rate);
|
||||
|
||||
/** \brief Performs the weight updates for the trainable parameters using the optimizer model.
|
||||
*
|
||||
* This function performs the weight update step that updates the trainable parameters such that they
|
||||
* take a step in the direction of their gradients (gradient descent). The optimizer step is performed
|
||||
* based on the optimizer model that was provided to the training session.
|
||||
* The updated parameters are stored inside the training state so that they can be used by the next
|
||||
* OrtTrainingApi::TrainStep function call.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[in] run_options Run options for this optimizer step.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess,
|
||||
_In_opt_ const OrtRunOptions* run_options);
|
||||
|
||||
/** \brief Registers a linear learning rate scheduler for the training session.
|
||||
*
|
||||
* Register a linear learning rate scheduler that decays the learning rate by a linearly updated
|
||||
* multiplicative factor from the initial learning rate set on the training session to 0. The decay
|
||||
* is performed after the initial warm up phase where the learning rate is linearly incremented
|
||||
* from 0 to the initial learning rate provided.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[in] warmup_step_count Warmup steps for LR warmup.
|
||||
* \param[in] total_step_count Total step count.
|
||||
* \param[in] initial_lr The initial learning rate to be used by the training session.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(RegisterLinearLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ const int64_t warmup_step_count,
|
||||
_In_ const int64_t total_step_count, _In_ const float initial_lr);
|
||||
|
||||
/** \brief Update the learning rate based on the registered learning rate scheduler.
|
||||
*
|
||||
* Takes a scheduler step that updates the learning rate that is being used by the training session.
|
||||
* This function should typically be called before invoking the optimizer step for each round,
|
||||
* or as determined necessary to update the learning rate being used by the training session.
|
||||
* \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this
|
||||
* function.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess);
|
||||
|
||||
/// @}
|
||||
|
||||
/// \name Accessing The Training Session State
|
||||
/// @{
|
||||
/** \brief Retrieves the size of all the parameters.
|
||||
*
|
||||
* Calculates the total number of primitive (datatype of the parameters) elements of all the parameters in the
|
||||
* training state.
|
||||
* When trainable_only argument is true, the size is calculated for trainable params only.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[out] out Size of all parameter elements.
|
||||
* \param[in] trainable_only Whether to skip non-trainable parameters
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(GetParametersSize, _Inout_ OrtTrainingSession* sess, _Out_ size_t* out, bool trainable_only);
|
||||
|
||||
/** \brief Copy all parameters to a contiguous buffer held by the argument parameters_buffer
|
||||
*
|
||||
* The parameters_buffer has to be of the size given by GetParametersSize api call,
|
||||
* with matching setting for the argument trainable_only. All the target parameters must be of the same
|
||||
* datatype. The OrtValue must be pre-allocated onto
|
||||
* the desired device. This is a complementary function to OrtTrainingApi::CopyBufferToParameters.
|
||||
* Parameter ordering is preserved.
|
||||
* User is responsible for allocating and freeing the resources used by the parameters_buffer.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[in] trainable_only Whether to skip non-trainable parameters
|
||||
* \param[out] parameters_buffer The pre-allocated OrtValue buffer to copy onto.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(CopyParametersToBuffer, _Inout_ OrtTrainingSession* sess,
|
||||
_Inout_ OrtValue* parameters_buffer, bool trainable_only);
|
||||
|
||||
/** \brief Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state
|
||||
*
|
||||
* The parameters_buffer argument has to be of the size given by OrtTrainingApi::GetParametersSize api call,
|
||||
* with matching setting for trainable_only argument. All the target parameters must be of the same
|
||||
* datatype. This is a complementary function to OrtTrainingApi::CopyParametersToBuffer
|
||||
* and can be used to load updated buffer values onto the training state.
|
||||
* Parameter ordering is preserved.
|
||||
* User is responsible for allocating and freeing the resources used by the parameters_buffer.
|
||||
* In case the training session was created with a nominal checkpoint, invoking this function is required
|
||||
* to load the updated parameters onto the checkpoint to complete it.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[in] trainable_only Whether to skip non-trainable parameters
|
||||
* \param[in] parameters_buffer The pre-allocated OrtValue buffer to copy from.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(CopyBufferToParameters, _Inout_ OrtTrainingSession* sess,
|
||||
_Inout_ OrtValue* parameters_buffer, bool trainable_only);
|
||||
|
||||
/// @}
|
||||
|
||||
/// \name Release Training Resources
|
||||
/// @{
|
||||
|
||||
/** \brief Frees up the memory used up by the training session.
|
||||
*
|
||||
* This function frees up any memory that was allocated in the training session. The training
|
||||
* session can no longer be used after this call.
|
||||
*
|
||||
*/
|
||||
ORT_CLASS_RELEASE(TrainingSession);
|
||||
|
||||
/** \brief Frees up the memory used up by the checkpoint state.
|
||||
*
|
||||
* This function frees up any memory that was allocated in the checkpoint state. The checkpoint
|
||||
* state can no longer be used after this call.
|
||||
* \note Note that the checkpoint state must be released only after the training session has been released.
|
||||
*
|
||||
*/
|
||||
ORT_CLASS_RELEASE(CheckpointState);
|
||||
|
||||
/// @}
|
||||
|
||||
/// \name Prepare For Inferencing
|
||||
/// @{
|
||||
/** \brief Export a model that can be used for inferencing.
|
||||
*
|
||||
* If the training session was provided with an eval model, the training session can generate
|
||||
* an inference model if it knows the inference graph outputs. The input inference graph outputs
|
||||
* are used to prune the eval model so that the inference model's outputs align with the provided outputs.
|
||||
* The exported model is saved at the path provided and can be used for inferencing with InferenceSession.
|
||||
* \note Note that the function re-loads the eval model from the path provided to OrtTrainingApi::CreateTrainingSession
|
||||
* and expects that this path still be valid.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[in] inference_model_path Path where the inference model should be serialized to.
|
||||
* \param[in] graph_outputs_len Size of the graph output names array.
|
||||
* \param[in] graph_output_names Names of the outputs that are needed in the inference model.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess,
|
||||
_In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len,
|
||||
_In_reads_(graph_outputs_len) const char* const* graph_output_names);
|
||||
|
||||
/// @}
|
||||
|
||||
/// \name Training Utilities
|
||||
/// @{
|
||||
/** \brief Sets the seed used for random number generation in Onnxruntime.
|
||||
*
|
||||
* Use this function to generate reproducible results. It should be noted that completely reproducible
|
||||
* results are not guaranteed.
|
||||
*
|
||||
* \param[in] seed The seed to be set.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(SetSeed, _In_ const int64_t seed);
|
||||
|
||||
/// @}
|
||||
|
||||
/// \name Model IO Information
|
||||
/// @{
|
||||
/** \brief Retrieves the number of user inputs in the training model.
|
||||
*
|
||||
* This function returns the number of inputs of the training model so that the user can accordingly
|
||||
* allocate the OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[out] out Number of user inputs in the training model.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(TrainingSessionGetTrainingModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||
|
||||
/** \brief Retrieves the number of user inputs in the eval model.
|
||||
*
|
||||
* This function returns the number of inputs of the eval model so that the user can accordingly
|
||||
* allocate the OrtValue(s) provided to the OrtTrainingApi::EvalStep function.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[out] out Number of user inputs in the eval model.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(TrainingSessionGetEvalModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out);
|
||||
|
||||
/** \brief Retrieves the name of the user input at given index in the training model.
|
||||
*
|
||||
* This function returns the names of inputs of the training model that can be associated with the
|
||||
* OrtValue(s) provided to the OrtTrainingApi::TrainStep function.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[in] index The index of the training model input name requested.
|
||||
* \param[in] allocator The allocator to use to allocate the memory for the requested name.
|
||||
* \param[out] output Name of the user input for the training model at the given index.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(TrainingSessionGetTrainingModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
|
||||
_In_ OrtAllocator* allocator, _Outptr_ char** output);
|
||||
|
||||
/** \brief Retrieves the name of the user input at given index in the eval model.
|
||||
*
|
||||
* This function returns the names of inputs of the eval model that can be associated with the OrtValue(s) provided
|
||||
* to the OrtTrainingApi::EvalStep function.
|
||||
*
|
||||
* \param[in] sess The `this` pointer to the training session.
|
||||
* \param[in] index The index of the eval model input name requested.
|
||||
* \param[in] allocator The allocator to use to allocate the memory for the requested name.
|
||||
* \param[out] output Name of the user input for the eval model at the given index.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(TrainingSessionGetEvalModelInputName, _In_ const OrtTrainingSession* sess, size_t index,
|
||||
_In_ OrtAllocator* allocator, _Outptr_ char** output);
|
||||
|
||||
/// @}
|
||||
|
||||
/// \name Accessing The Training Session State
|
||||
/// @{
|
||||
|
||||
/** \brief Adds or updates the given property to/in the checkpoint state.
|
||||
*
|
||||
* Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint
|
||||
* state by the user by calling this function with the corresponding property name and value.
|
||||
* The given property name must be unique to be able to successfully add the property.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state which should hold the property.
|
||||
* \param[in] property_name Name of the property being added or updated.
|
||||
* \param[in] property_type Type of the property associated with the given name.
|
||||
* \param[in] property_value Property value associated with the given name.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(AddProperty, _Inout_ OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* property_name, _In_ enum OrtPropertyType property_type,
|
||||
_In_ void* property_value);
|
||||
|
||||
/** \brief Gets the property value associated with the given name from the checkpoint state.
|
||||
*
|
||||
* Gets the property value from an existing entry in the checkpoint state. The property must
|
||||
* exist in the checkpoint state to be able to retrieve it successfully.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state that is currently holding the property.
|
||||
* \param[in] property_name Name of the property being retrieved.
|
||||
* \param[in] allocator Allocator used to allocate the memory for the property_value.
|
||||
* \param[out] property_type Type of the property associated with the given name.
|
||||
* \param[out] property_value Property value associated with the given name.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(GetProperty, _In_ const OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* property_name, _Inout_ OrtAllocator* allocator,
|
||||
_Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value);
|
||||
|
||||
/// @}
|
||||
|
||||
/// \name Accessing The Training Session State
|
||||
/// @{
|
||||
|
||||
/** \brief Load a checkpoint state from a buffer into checkpoint_state.
|
||||
*
|
||||
* This function will parse a checkpoint bytes buffer, pull relevant data and load the training
|
||||
* state into the checkpoint_state. This checkpoint state can then be used to create the
|
||||
* training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training
|
||||
* session will resume training from the given checkpoint state.
|
||||
* \note Note that the training session created with a checkpoint state uses this state to store the entire
|
||||
* training state (including model parameters, its gradients, the optimizer states and the properties).
|
||||
* As a result, it is required that the checkpoint state outlive the lifetime of the training session.
|
||||
*
|
||||
* \param[in] checkpoint_buffer Pointer to the checkpoint bytes buffer.
|
||||
* \param[in] num_bytes Number of bytes in the checkpoint buffer.
|
||||
* \param[out] checkpoint_state Checkpoint state that contains the states of the training session.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer,
|
||||
_In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state);
|
||||
|
||||
/** \brief Retrieves the type and shape information of the parameter associated with the given parameter name.
|
||||
*
|
||||
* This function retrieves the type and shape of the parameter associated with the given parameter name.
|
||||
* The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state.
|
||||
* \param[in] parameter_name Name of the parameter being retrieved.
|
||||
* \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape);
|
||||
|
||||
/** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
|
||||
*
|
||||
* This function updates a model parameter in the checkpoint state with the given parameter data.
|
||||
* The training session must be already created with the checkpoint state that contains the parameter
|
||||
* being updated. The given parameter is copied over to the registered device for the training session.
|
||||
* The parameter must exist in the checkpoint state to be able to update it successfully.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state.
|
||||
* \param[in] parameter_name Name of the parameter being updated.
|
||||
* \param[in] parameter The parameter data that should replace the existing parameter data.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _In_ OrtValue* parameter);
|
||||
|
||||
/** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
|
||||
*
|
||||
* This function retrieves the model parameter data from the checkpoint state for the given parameter name.
|
||||
* The parameter is copied over and returned as an OrtValue. The training session must be already created
|
||||
* with the checkpoint state that contains the parameter being retrieved.
|
||||
* The parameter must exist in the checkpoint state to be able to retrieve it successfully.
|
||||
*
|
||||
* \param[in] checkpoint_state The checkpoint state.
|
||||
* \param[in] parameter_name Name of the parameter being retrieved.
|
||||
* \param[in] allocator Allocator used to allocate the memory for the parameter.
|
||||
* \param[out] parameter The parameter data that is retrieved from the checkpoint state.
|
||||
*
|
||||
* \snippet{doc} snippets.dox OrtStatus Return Value
|
||||
*
|
||||
*/
|
||||
ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state,
|
||||
_In_ const char* parameter_name, _Inout_ OrtAllocator* allocator,
|
||||
_Outptr_ OrtValue** parameter);
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
/// Lets the struct be referred to as `OrtTrainingApi` without the `struct` keyword.
typedef struct OrtTrainingApi OrtTrainingApi;
|
||||
|
||||
/// @}
|
@@ -1,639 +0,0 @@
|
||||
package onnxruntime_go
|
||||
|
||||
// #cgo CFLAGS: -O2 -g
|
||||
//
|
||||
// #include "onnxruntime_wrapper.h"
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// trainingNotSupportedError is returned by checkTraining when
// IsTrainingSupported reports false.
var trainingNotSupportedError error = fmt.Errorf("training not supported by onnx library")
|
||||
|
||||
// Scalar is like a tensor but the underlying go slice is of length 1 and it has no dimension.
|
||||
// It can be used to store e.g. the loss from a training cycle.
|
||||
type Scalar[T TensorData] struct {
|
||||
data []T
|
||||
dataSize uintptr
|
||||
ortValue *C.OrtValue
|
||||
}
|
||||
|
||||
func (s *Scalar[T]) GetShape() Shape {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Scalar[T]) ZeroContents() {
|
||||
C.memset(unsafe.Pointer(&s.data[0]), 0, C.size_t(s.dataSize))
|
||||
}
|
||||
|
||||
func (s *Scalar[T]) Destroy() error {
|
||||
C.ReleaseOrtValue(s.ortValue)
|
||||
s.ortValue = nil
|
||||
s.data = nil
|
||||
s.dataSize = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetData returns the undelying data for the scalar.
|
||||
// If you want to explicitly set the scalar's data, use Set.
|
||||
func (t *Scalar[T]) GetData() T {
|
||||
return t.data[0]
|
||||
}
|
||||
|
||||
// Set allows to explicitly set the underlying value for the scalar.
|
||||
func (t *Scalar[T]) Set(value T) {
|
||||
t.data = []T{value}
|
||||
}
|
||||
|
||||
func (t *Scalar[T]) DataType() C.ONNXTensorElementDataType {
|
||||
return GetTensorElementDataType[T]()
|
||||
}
|
||||
|
||||
func (t *Scalar[_]) GetInternals() *ValueInternalData {
|
||||
return &ValueInternalData{
|
||||
ortValue: t.ortValue,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Scalar[_]) GetONNXType() ONNXType {
|
||||
return ONNXTypeTensor
|
||||
}
|
||||
|
||||
// NewEmptyScalar creates a new scalar of type T.
|
||||
func NewEmptyScalar[T TensorData]() (*Scalar[T], error) {
|
||||
var data T
|
||||
return NewScalar(data)
|
||||
}
|
||||
|
||||
// NewScalar creates a new scalar of type T backed by a value of type T.
|
||||
// Note that, differently from tensors, this is not a []T but just a value T.
|
||||
func NewScalar[T TensorData](data T) (*Scalar[T], error) {
|
||||
if !IsInitialized() {
|
||||
return nil, NotInitializedError
|
||||
}
|
||||
|
||||
dataSlice := []T{data}
|
||||
var ortValue *C.OrtValue
|
||||
dataType := GetTensorElementDataType[T]()
|
||||
dataSize := unsafe.Sizeof(dataSlice[0]) * uintptr(1)
|
||||
|
||||
status := C.CreateOrtTensorWithShape(unsafe.Pointer(&dataSlice[0]),
|
||||
C.size_t(dataSize), nil, C.int64_t(0), ortMemoryInfo, dataType, &ortValue)
|
||||
if status != nil {
|
||||
return nil, statusToError(status)
|
||||
}
|
||||
toReturn := Scalar[T]{
|
||||
data: dataSlice,
|
||||
dataSize: dataSize,
|
||||
ortValue: ortValue,
|
||||
}
|
||||
return &toReturn, nil
|
||||
}
|
||||
|
||||
// TraininSession is the type that wraps the C training session object.
|
||||
type TrainingSession struct {
|
||||
ortTrainingSession *C.OrtTrainingSession
|
||||
ortCheckpointState *C.OrtCheckpointState
|
||||
inputs []*C.OrtValue
|
||||
outputs []*C.OrtValue
|
||||
trainingModelPath *C.char
|
||||
optimizerModelPath *C.char
|
||||
evalModelPath *C.char
|
||||
}
|
||||
|
||||
// ExportModel is used to export the final trained model to disk. It requires the path for
|
||||
// the exported model as well as the names of the graph nodes to export.
|
||||
// Note that currently the final model can only be exported if the session has been
|
||||
// initialized with NewTrainingSession and the path to the eval model has been provided.
|
||||
func (s *TrainingSession) ExportModel(path string, outputNames []string) error {
|
||||
if s.evalModelPath == nil {
|
||||
return fmt.Errorf("final model can only be exported if the eval model path is " +
|
||||
"provided at session creation time (see NewTrainingSession)")
|
||||
}
|
||||
if path == "" {
|
||||
return fmt.Errorf("path cannot be empty")
|
||||
}
|
||||
dir, _ := filepath.Split(path)
|
||||
if _, err := os.Stat(dir); dir != "" && os.IsNotExist(err) {
|
||||
return fmt.Errorf("directory %s does not exist", dir)
|
||||
}
|
||||
|
||||
cOutputNames := make([]*C.char, len(outputNames))
|
||||
for i, name := range outputNames {
|
||||
cOutputNames[i] = C.CString(name)
|
||||
}
|
||||
cPath, err := createOrtCharString(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error converting export path to C string: %w", err)
|
||||
}
|
||||
outputLength := C.size_t(len(outputNames))
|
||||
defer func() {
|
||||
for i := range cOutputNames {
|
||||
C.free(unsafe.Pointer(cOutputNames[i]))
|
||||
}
|
||||
C.free(unsafe.Pointer(cPath))
|
||||
}()
|
||||
status := C.ExportModel(s.ortTrainingSession, cPath, outputLength, &cOutputNames[0])
|
||||
if status != nil {
|
||||
return statusToError(status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveCheckpoint can be used to save the current checkpoint state at the specified path.
|
||||
// This is useful to snapshot the training parameters to continue training later or on
|
||||
// a different machine.
|
||||
func (s *TrainingSession) SaveCheckpoint(path string, saveOptimizerState bool) error {
|
||||
if path == "" {
|
||||
return fmt.Errorf("path cannot be empty")
|
||||
}
|
||||
dir, _ := filepath.Split(path)
|
||||
if _, err := os.Stat(dir); dir != "" && os.IsNotExist(err) {
|
||||
return fmt.Errorf("directory %s does not exist", dir)
|
||||
}
|
||||
|
||||
cPath, err := createOrtCharString(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error converting path to C string: %w", err)
|
||||
}
|
||||
var saveOptimizer int
|
||||
if saveOptimizerState {
|
||||
saveOptimizer = 1
|
||||
}
|
||||
|
||||
defer func() {
|
||||
C.free(unsafe.Pointer(cPath))
|
||||
}()
|
||||
|
||||
status := C.SaveCheckpoint(s.ortCheckpointState, cPath, C.size_t(saveOptimizer))
|
||||
if status != nil {
|
||||
return statusToError(status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Destroy frees all the C memory associated to a training session.
|
||||
func (s *TrainingSession) Destroy() error {
|
||||
if s.ortTrainingSession != nil {
|
||||
C.ReleaseOrtTrainingSession(s.ortTrainingSession)
|
||||
s.ortTrainingSession = nil
|
||||
}
|
||||
// note: checkpoint MUST be released after session
|
||||
if s.ortCheckpointState != nil {
|
||||
C.ReleaseCheckpointState(s.ortCheckpointState)
|
||||
s.ortCheckpointState = nil
|
||||
}
|
||||
C.free(unsafe.Pointer(s.trainingModelPath))
|
||||
s.trainingModelPath = nil
|
||||
C.free(unsafe.Pointer(s.evalModelPath))
|
||||
s.evalModelPath = nil
|
||||
C.free(unsafe.Pointer(s.optimizerModelPath))
|
||||
s.optimizerModelPath = nil
|
||||
s.inputs = nil
|
||||
s.outputs = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// TrainStep performs the training step.
|
||||
func (s *TrainingSession) TrainStep() error {
|
||||
inputLength := C.size_t(len(s.inputs))
|
||||
outputLength := C.size_t(len(s.outputs))
|
||||
status := C.TrainStep(s.ortTrainingSession, inputLength, &s.inputs[0], outputLength, &s.outputs[0])
|
||||
if status != nil {
|
||||
return fmt.Errorf("error performing training step: %w", statusToError(status))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TrainStep performs the optimizer step.
|
||||
func (s *TrainingSession) OptimizerStep() error {
|
||||
status := C.OptimizerStep(s.ortTrainingSession)
|
||||
if status != nil {
|
||||
return fmt.Errorf("error performing optimizer step: %w", statusToError(status))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TrainStep performs the LazyResetGrad step.
|
||||
func (s *TrainingSession) LazyResetGrad() error {
|
||||
status := C.LazyResetGrad(s.ortTrainingSession)
|
||||
if status != nil {
|
||||
return fmt.Errorf("error performing lazyResetGrad step: %w", statusToError(status))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getInputName(s *C.OrtTrainingSession, i int, model string) (string, error) {
|
||||
var cName *C.char
|
||||
var status *C.OrtStatus
|
||||
switch model {
|
||||
case "train":
|
||||
status = C.TrainingSessionGetTrainingInputName(s, C.size_t(i), &cName)
|
||||
case "eval":
|
||||
status = C.TrainingSessionGetEvalInputName(s, C.size_t(i), &cName)
|
||||
default:
|
||||
return "", fmt.Errorf("%s model not recognized", model)
|
||||
}
|
||||
if status != nil {
|
||||
return "", fmt.Errorf("error getting name: %w", statusToError(status))
|
||||
}
|
||||
|
||||
name, e := convertORTString(cName)
|
||||
if e != nil {
|
||||
return "", fmt.Errorf("error converting C name to Go string: %w", e)
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func getOutputName(s *C.OrtTrainingSession, i int, model string) (string, error) {
|
||||
var cName *C.char
|
||||
var status *C.OrtStatus
|
||||
switch model {
|
||||
case "train":
|
||||
status = C.TrainingSessionGetTrainingOutputName(s, C.size_t(i), &cName)
|
||||
case "eval":
|
||||
status = C.TrainingSessionGetEvalOutputName(s, C.size_t(i), &cName)
|
||||
default:
|
||||
return "", fmt.Errorf("%s model not recognized", model)
|
||||
}
|
||||
if status != nil {
|
||||
return "", fmt.Errorf("error getting name: %w", statusToError(status))
|
||||
}
|
||||
name, e := convertORTString(cName)
|
||||
if e != nil {
|
||||
return "", fmt.Errorf("error converting C name to Go string: %w", e)
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
// TrainingInputOutputNames holds the input and output names of the training
// and eval models, as returned by GetInputOutputNames.
type TrainingInputOutputNames struct {
	// Input names of the training and eval models.
	TrainingInputNames, EvalInputNames []string
	// Output names of the training and eval models.
	TrainingOutputNames, EvalOutputNames []string
}
|
||||
|
||||
// GetInputOutputNames returns the names of the training inputs and outputs
|
||||
// for the training and validation models. Eval model is optional and can be empty
|
||||
// string.
|
||||
func GetInputOutputNames(checkpointStatePath string,
|
||||
trainingModelPath string,
|
||||
evalModelPath string) (*TrainingInputOutputNames, error) {
|
||||
|
||||
options, e := NewSessionOptions()
|
||||
if e != nil {
|
||||
return nil, fmt.Errorf("failed creating options with error: %v\n", e)
|
||||
}
|
||||
defer options.Destroy()
|
||||
|
||||
checkpointData, e := os.ReadFile(checkpointStatePath)
|
||||
if e != nil {
|
||||
return nil, fmt.Errorf("error reading %s: %w", checkpointStatePath, e)
|
||||
}
|
||||
|
||||
trainingData, e := os.ReadFile(trainingModelPath)
|
||||
if e != nil {
|
||||
return nil, fmt.Errorf("error reading %s: %w", checkpointStatePath, e)
|
||||
}
|
||||
|
||||
var evalData []byte
|
||||
if evalModelPath != "" {
|
||||
evalData, e = os.ReadFile(evalModelPath)
|
||||
if e != nil {
|
||||
return nil, fmt.Errorf("error reading %s: %w", evalModelPath, e)
|
||||
}
|
||||
}
|
||||
|
||||
// create checkpoint C object
|
||||
ortCheckpointState, e := createCCheckpoint(checkpointData)
|
||||
if e != nil {
|
||||
return nil, fmt.Errorf("error creating C checkpointState: %w", e)
|
||||
}
|
||||
|
||||
// create session C object
|
||||
ortTrainingSession, e := createCTrainingSessionWithOnnxData(ortCheckpointState,
|
||||
trainingData, evalData, nil, options)
|
||||
if e != nil {
|
||||
C.ReleaseCheckpointState(ortCheckpointState)
|
||||
return nil, fmt.Errorf("error creating C training session: %w", e)
|
||||
}
|
||||
defer func() {
|
||||
C.ReleaseOrtTrainingSession(ortTrainingSession)
|
||||
C.ReleaseCheckpointState(ortCheckpointState)
|
||||
}()
|
||||
|
||||
var inputCountTraining, inputCountEval C.size_t
|
||||
status := C.TrainingSessionGetInputCount(ortTrainingSession, &inputCountTraining, &inputCountEval)
|
||||
if status != nil {
|
||||
return nil, statusToError(status)
|
||||
}
|
||||
|
||||
var outputCountTraining, outputCountEval C.size_t
|
||||
status = C.TrainingSessionGetOutputCount(ortTrainingSession, &outputCountTraining, &outputCountEval)
|
||||
if status != nil {
|
||||
return nil, statusToError(status)
|
||||
}
|
||||
|
||||
trainInputNames := make([]string, inputCountTraining)
|
||||
trainOutputNames := make([]string, outputCountTraining)
|
||||
|
||||
for i := 0; i < int(inputCountTraining); i++ {
|
||||
name, err := getInputName(ortTrainingSession, i, "train")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving train input name: %w", err)
|
||||
}
|
||||
trainInputNames[i] = name
|
||||
}
|
||||
|
||||
for i := 0; i < int(outputCountTraining); i++ {
|
||||
name, err := getOutputName(ortTrainingSession, i, "train")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving train output name: %w", err)
|
||||
}
|
||||
trainOutputNames[i] = name
|
||||
}
|
||||
|
||||
var evalInputNames []string
|
||||
var evalOutputNames []string
|
||||
|
||||
if len(evalData) > 0 {
|
||||
evalInputNames = make([]string, inputCountEval)
|
||||
evalOutputNames = make([]string, outputCountEval)
|
||||
|
||||
for i := 0; i < int(inputCountEval); i++ {
|
||||
name, err := getInputName(ortTrainingSession, i, "eval")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving eval input name: %w", err)
|
||||
}
|
||||
evalInputNames[i] = name
|
||||
}
|
||||
|
||||
for i := 0; i < int(outputCountTraining); i++ {
|
||||
name, err := getOutputName(ortTrainingSession, i, "eval")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving eval output name: %w", err)
|
||||
}
|
||||
evalOutputNames[i] = name
|
||||
}
|
||||
}
|
||||
|
||||
return &TrainingInputOutputNames{
|
||||
TrainingInputNames: trainInputNames,
|
||||
EvalInputNames: evalInputNames,
|
||||
TrainingOutputNames: trainOutputNames,
|
||||
EvalOutputNames: evalOutputNames,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IsTrainingSupported returns true if the training api is supported
|
||||
// by the onnxruntime library.
|
||||
func IsTrainingSupported() bool {
|
||||
return C.IsTrainingApiSupported() != 0
|
||||
}
|
||||
|
||||
func checkTraining() error {
|
||||
if !IsInitialized() {
|
||||
return NotInitializedError
|
||||
}
|
||||
if !IsTrainingSupported() {
|
||||
return trainingNotSupportedError
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func createCCheckpoint(onnxData []byte) (*C.OrtCheckpointState, error) {
|
||||
if e := checkTraining(); e != nil {
|
||||
return nil, e
|
||||
}
|
||||
if len(onnxData) == 0 {
|
||||
return nil, fmt.Errorf("Missing checkpoint data")
|
||||
}
|
||||
var ortCheckpointState *C.OrtCheckpointState
|
||||
status := C.CreateCheckpoint(unsafe.Pointer(&(onnxData[0])), C.size_t(len(onnxData)), &ortCheckpointState)
|
||||
if status != nil {
|
||||
return nil, statusToError(status)
|
||||
}
|
||||
return ortCheckpointState, nil
|
||||
}
|
||||
|
||||
// createCTrainingSessionWithOnnxData creates a C session from byte data using buffers
|
||||
func createCTrainingSessionWithOnnxData(checkpointState *C.OrtCheckpointState,
|
||||
trainingData, evalData, optimizerData []byte,
|
||||
options *SessionOptions) (*C.OrtTrainingSession, error) {
|
||||
if e := checkTraining(); e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
var ortTrainingSession *C.OrtTrainingSession
|
||||
var ortSessionOptions *C.OrtSessionOptions
|
||||
if options != nil {
|
||||
ortSessionOptions = options.o
|
||||
}
|
||||
|
||||
// eval model is optional
|
||||
var evalDataPtr unsafe.Pointer
|
||||
var evalDataSize C.size_t
|
||||
if len(evalData) > 0 {
|
||||
evalDataPtr = unsafe.Pointer(&(evalData[0]))
|
||||
evalDataSize = C.size_t(len(evalData))
|
||||
}
|
||||
|
||||
// optimizer model is also optional when e.g. getting input and output names
|
||||
var optimizerDataPtr unsafe.Pointer
|
||||
var optimizerDataSize C.size_t
|
||||
if len(optimizerData) > 0 {
|
||||
optimizerDataPtr = unsafe.Pointer(&(optimizerData[0]))
|
||||
optimizerDataSize = C.size_t(len(optimizerData))
|
||||
}
|
||||
|
||||
status := C.CreateTrainingSessionFromBuffer(
|
||||
checkpointState,
|
||||
unsafe.Pointer(&(trainingData[0])), C.size_t(len(trainingData)),
|
||||
evalDataPtr, evalDataSize,
|
||||
optimizerDataPtr, optimizerDataSize,
|
||||
ortEnv, &ortTrainingSession, ortSessionOptions)
|
||||
if status != nil {
|
||||
return nil, statusToError(status)
|
||||
}
|
||||
return ortTrainingSession, nil
|
||||
}
|
||||
|
||||
// createCTrainingSessionWithPaths creates a C session from paths
|
||||
func createCtrainingSessionWithPaths(checkpointState *C.OrtCheckpointState,
|
||||
trainingPath, evalPath, optimizerPath *C.char,
|
||||
options *SessionOptions) (*C.OrtTrainingSession, error) {
|
||||
if e := checkTraining(); e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
var ortTrainingSession *C.OrtTrainingSession
|
||||
var ortSessionOptions *C.OrtSessionOptions
|
||||
if options != nil {
|
||||
ortSessionOptions = options.o
|
||||
}
|
||||
|
||||
status := C.CreateTrainingSessionFromPaths(checkpointState,
|
||||
trainingPath, evalPath, optimizerPath, ortEnv, &ortTrainingSession, ortSessionOptions)
|
||||
|
||||
if status != nil {
|
||||
return nil, statusToError(status)
|
||||
}
|
||||
return ortTrainingSession, nil
|
||||
}
|
||||
|
||||
// NewTrainingSessionWithOnnxData is like NewTrainingSession, but it accepts
|
||||
// bytes rather than paths to the training assets. Note that there does not
|
||||
// seem to currently be a way to export the trained model from a session
|
||||
// instantiated from bytes. If you wish to export the trained model, you should
|
||||
// use NewTrainingSession instead.
|
||||
func NewTrainingSessionWithOnnxData(checkpointData []byte,
|
||||
trainingData []byte,
|
||||
evalData []byte,
|
||||
optimizerData []byte,
|
||||
inputs, outputs []Value,
|
||||
options *SessionOptions) (*TrainingSession, error) {
|
||||
|
||||
if err := checkTraining(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := validateInputOutputs(inputs, outputs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(trainingData) == 0 {
|
||||
return nil, fmt.Errorf("training data has length zero.")
|
||||
}
|
||||
if len(optimizerData) == 0 {
|
||||
return nil, fmt.Errorf("optimizer data has length zero.")
|
||||
}
|
||||
|
||||
// create checkpoint C object
|
||||
ortCheckpointState, e := createCCheckpoint(checkpointData)
|
||||
if e != nil {
|
||||
return nil, fmt.Errorf("error creating C checkpointState: %w", e)
|
||||
}
|
||||
|
||||
// create session C object
|
||||
ortTrainingSession, e := createCTrainingSessionWithOnnxData(ortCheckpointState,
|
||||
trainingData, evalData, optimizerData, options)
|
||||
if e != nil {
|
||||
return nil, fmt.Errorf("error creating C training session: %w", e)
|
||||
}
|
||||
|
||||
inputOrtTensors := make([]*C.OrtValue, len(inputs))
|
||||
outputOrtTensors := make([]*C.OrtValue, len(outputs))
|
||||
for i, v := range inputs {
|
||||
inputOrtTensors[i] = v.GetInternals().ortValue
|
||||
}
|
||||
for i, v := range outputs {
|
||||
outputOrtTensors[i] = v.GetInternals().ortValue
|
||||
}
|
||||
|
||||
return &TrainingSession{
|
||||
ortCheckpointState: ortCheckpointState,
|
||||
ortTrainingSession: ortTrainingSession,
|
||||
inputs: inputOrtTensors,
|
||||
outputs: outputOrtTensors,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func validateInputOutputs(inputs, outputs []Value) error {
|
||||
if len(inputs) == 0 {
|
||||
return fmt.Errorf("inputs must have length greater than zero")
|
||||
}
|
||||
if len(outputs) == 0 {
|
||||
return fmt.Errorf("outputs must have length greater than zero")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewTrainingSession creates a new training session from paths stored on disk.
|
||||
// evalModelPath is optional and can be the empty string. In case it is not
|
||||
// provided, only the checkpoint state can be exported once training is complete
|
||||
// (and not the final inference model).
|
||||
func NewTrainingSession(checkpointStatePath string,
|
||||
trainingModelPath string,
|
||||
evalModelPath string,
|
||||
optimizerModelPath string,
|
||||
inputs, outputs []Value,
|
||||
options *SessionOptions) (*TrainingSession, error) {
|
||||
|
||||
if err := checkTraining(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := validateInputOutputs(inputs, outputs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
checkPointContent, e := os.ReadFile(checkpointStatePath)
|
||||
if e != nil {
|
||||
return nil, fmt.Errorf("reading checkpoint data failed: %s", e.Error())
|
||||
}
|
||||
|
||||
// create checkpoint C object
|
||||
ortCheckpointState, e := createCCheckpoint(checkPointContent)
|
||||
if e != nil {
|
||||
return nil, fmt.Errorf("error creating C checkpointState: %w", e)
|
||||
}
|
||||
|
||||
// create session C object
|
||||
if _, err := os.Stat(trainingModelPath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("training model does not exist at path %s", trainingModelPath)
|
||||
}
|
||||
cTrainingPath, err := createOrtCharString(trainingModelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error converting training model path to C string: %w", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(optimizerModelPath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("optimizer s does not exist at path %s", optimizerModelPath)
|
||||
}
|
||||
cOptimizerPath, err := createOrtCharString(optimizerModelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error converting optimizer path to C string: %w", err)
|
||||
}
|
||||
|
||||
// eval is optional
|
||||
var cEvalPath *C.char
|
||||
if evalModelPath != "" {
|
||||
if _, err := os.Stat(evalModelPath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("eval model does not exist at path %s", evalModelPath)
|
||||
}
|
||||
cEvalPath, err = createOrtCharString(evalModelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error converting eval path to C string: %w", err)
|
||||
}
|
||||
} else {
|
||||
cEvalPath = nil
|
||||
}
|
||||
|
||||
ortTrainingSession, e := createCtrainingSessionWithPaths(ortCheckpointState,
|
||||
cTrainingPath, cEvalPath, cOptimizerPath, options)
|
||||
if e != nil {
|
||||
return nil, fmt.Errorf("error creating C training session: %w", e)
|
||||
}
|
||||
|
||||
inputOrtTensors := make([]*C.OrtValue, len(inputs))
|
||||
outputOrtTensors := make([]*C.OrtValue, len(outputs))
|
||||
for i, v := range inputs {
|
||||
inputOrtTensors[i] = v.GetInternals().ortValue
|
||||
}
|
||||
for i, v := range outputs {
|
||||
outputOrtTensors[i] = v.GetInternals().ortValue
|
||||
}
|
||||
|
||||
return &TrainingSession{
|
||||
ortCheckpointState: ortCheckpointState,
|
||||
ortTrainingSession: ortTrainingSession,
|
||||
inputs: inputOrtTensors,
|
||||
outputs: outputOrtTensors,
|
||||
evalModelPath: cEvalPath,
|
||||
trainingModelPath: cTrainingPath,
|
||||
optimizerModelPath: cOptimizerPath,
|
||||
}, nil
|
||||
}
|
@@ -1,353 +0,0 @@
|
||||
package onnxruntime_go
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTrainingNotSupported(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
|
||||
if IsTrainingSupported() {
|
||||
t.Skipf("onnxruntime library supports training")
|
||||
}
|
||||
|
||||
options, e := NewSessionOptions()
|
||||
if e != nil {
|
||||
t.Logf("Failed creating options: %s\n", e)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
trainingSession, e := NewTrainingSession("test_data/onnxruntime_training_test/training_artifacts/checkpoint",
|
||||
"test_data/onnxruntime_training_test/training_artifacts/training_model.onnx",
|
||||
"test_data/onnxruntime_training_test/training_artifacts/eval_model.onnx",
|
||||
"test_data/onnxruntime_training_test/training_artifacts/optimizer_model.onnx",
|
||||
nil, nil,
|
||||
options)
|
||||
|
||||
if !errors.Is(e, trainingNotSupportedError) {
|
||||
t.Logf("Creating training session when onnxruntime lib does not support it should return training not supported error.")
|
||||
if e != nil {
|
||||
t.Logf("Received instead error: %s", e.Error())
|
||||
} else {
|
||||
t.Logf("Received no error instead")
|
||||
}
|
||||
t.FailNow()
|
||||
}
|
||||
if trainingSession != nil {
|
||||
if err := trainingSession.Destroy(); err != nil {
|
||||
t.Fatalf("cleanup of training session failed with error: %v", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInputOutputNames(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
|
||||
if !IsTrainingSupported() {
|
||||
t.Skipf("Training is not supported on this platform/onnxruntime build.")
|
||||
}
|
||||
|
||||
artifactsPath := path.Join("test_data", "training_test")
|
||||
|
||||
names, err := GetInputOutputNames(
|
||||
path.Join(artifactsPath, "checkpoint"),
|
||||
path.Join(artifactsPath, "training_model.onnx"),
|
||||
path.Join(artifactsPath, "eval_model.onnx"),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed getting input and output names with error: %v\n", err)
|
||||
}
|
||||
|
||||
expectedTrainInputNames := []string{"input", "target"}
|
||||
expectedEvalInputNames := expectedTrainInputNames
|
||||
expectedTrainOutputNames := []string{"onnx::reducemean_output::5"}
|
||||
expectedEvalOutputNames := expectedTrainOutputNames
|
||||
|
||||
for i, v := range names.TrainingInputNames {
|
||||
if v != expectedTrainInputNames[i] {
|
||||
t.Fatalf("training input names don't match")
|
||||
}
|
||||
}
|
||||
for i, v := range names.TrainingOutputNames {
|
||||
if v != expectedTrainOutputNames[i] {
|
||||
t.Fatalf("training output names don't match")
|
||||
}
|
||||
}
|
||||
for i, v := range names.EvalInputNames {
|
||||
if v != expectedEvalInputNames[i] {
|
||||
t.Fatalf("eval input names don't match")
|
||||
}
|
||||
}
|
||||
for i, v := range names.EvalOutputNames {
|
||||
if v != expectedEvalOutputNames[i] {
|
||||
t.Fatalf("eval output names don't match")
|
||||
}
|
||||
}
|
||||
|
||||
// without eval model
|
||||
names, err = GetInputOutputNames(
|
||||
path.Join(artifactsPath, "checkpoint"),
|
||||
path.Join(artifactsPath, "training_model.onnx"),
|
||||
"",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed getting input and output names with error: %v\n", err)
|
||||
}
|
||||
|
||||
for i, v := range names.TrainingInputNames {
|
||||
if v != expectedTrainInputNames[i] {
|
||||
t.Fatalf("training input names don't match")
|
||||
}
|
||||
}
|
||||
for i, v := range names.TrainingOutputNames {
|
||||
if v != expectedTrainOutputNames[i] {
|
||||
t.Fatalf("training output names don't match")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// generateBatchData builds nBatches batches of pseudo-random data from a
// generator seeded with the fixed value 1234, so repeated calls return the
// same values. The result maps a batch index to a map with two entries:
//
//   - "input": batchSize*4 values, four consecutive random float32 values per
//     sample;
//   - "output": batchSize*2 values, for each sample the sum of its four
//     inputs followed by the difference between its largest and smallest.
func generateBatchData(nBatches int, batchSize int) map[int]map[string][]float32 {
	rng := rand.New(rand.NewSource(1234))
	batches := map[int]map[string][]float32{}

	for batch := 0; batch < nBatches; batch++ {
		features := make([]float32, batchSize*4)
		targets := make([]float32, batchSize*2)

		for sample := 0; sample < batchSize; sample++ {
			var total float32
			// Float32 yields values in [0, 1), so these starting bounds are
			// always replaced by the first draw.
			lowest, highest := float32(1), float32(-1)
			for k := 0; k < 4; k++ {
				r := rng.Float32()
				features[sample*4+k] = r
				if r > highest {
					highest = r
				}
				if r < lowest {
					lowest = r
				}
				total = total + r
			}
			targets[sample*2] = total
			targets[sample*2+1] = highest - lowest
		}

		batches[batch] = map[string][]float32{
			"input":  features,
			"output": targets,
		}
	}
	return batches
}
|
||||
|
||||
// TestTraining tests a basic training flow using the bindings to the C api for on-device onnxruntime training
|
||||
func TestTraining(t *testing.T) {
|
||||
InitializeRuntime(t)
|
||||
defer CleanupRuntime(t)
|
||||
|
||||
if !IsTrainingSupported() {
|
||||
t.Skipf("Training is not supported on this platform/onnxruntime build.")
|
||||
}
|
||||
|
||||
trainingArtifactsFolder := path.Join("test_data", "training_test")
|
||||
|
||||
// generate training data
|
||||
batchSize := 10
|
||||
nBatches := 10
|
||||
|
||||
// holds inputs/outputs and loss for each training batch
|
||||
batchInputShape := NewShape(int64(batchSize), 1, 4)
|
||||
batchTargetShape := NewShape(int64(batchSize), 1, 2)
|
||||
batchInputTensor, err := NewEmptyTensor[float32](batchInputShape)
|
||||
if err != nil {
|
||||
t.Fatalf("training test failed with error: %v", err)
|
||||
}
|
||||
batchTargetTensor, err := NewEmptyTensor[float32](batchTargetShape)
|
||||
if err != nil {
|
||||
t.Fatalf("training test failed with error: %v", err)
|
||||
}
|
||||
lossScalar, err := NewEmptyScalar[float32]()
|
||||
if err != nil {
|
||||
t.Fatalf("training test failed with error: %v", err)
|
||||
}
|
||||
|
||||
trainingSession, errorSessionCreation := NewTrainingSession(
|
||||
path.Join(trainingArtifactsFolder, "checkpoint"),
|
||||
path.Join(trainingArtifactsFolder, "training_model.onnx"),
|
||||
path.Join(trainingArtifactsFolder, "eval_model.onnx"),
|
||||
path.Join(trainingArtifactsFolder, "optimizer_model.onnx"),
|
||||
[]Value{batchInputTensor, batchTargetTensor}, []Value{lossScalar},
|
||||
nil)
|
||||
|
||||
if errorSessionCreation != nil {
|
||||
t.Fatalf("session creation failed with error: %v", errorSessionCreation)
|
||||
}
|
||||
|
||||
// cleanup after test run
|
||||
defer func(session *TrainingSession, tensors []Value) {
|
||||
var errs []error
|
||||
errs = append(errs, session.Destroy())
|
||||
for _, t := range tensors {
|
||||
errs = append(errs, t.Destroy())
|
||||
}
|
||||
if e := errors.Join(errs...); e != nil {
|
||||
t.Fatalf("cleanup of test failed with error: %v", e)
|
||||
}
|
||||
}(trainingSession, []Value{batchInputTensor, batchTargetTensor, lossScalar})
|
||||
|
||||
losses := []float32{}
|
||||
epochs := 100
|
||||
batchData := generateBatchData(nBatches, batchSize)
|
||||
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
var epochLoss float32 // total epoch loss
|
||||
|
||||
for i := 0; i < nBatches; i++ {
|
||||
inputSlice := batchInputTensor.GetData()
|
||||
outputSlice := batchTargetTensor.GetData()
|
||||
|
||||
copy(inputSlice, batchData[i]["input"])
|
||||
copy(outputSlice, batchData[i]["output"])
|
||||
|
||||
// train on batch
|
||||
err = trainingSession.TrainStep()
|
||||
if err != nil {
|
||||
t.Fatalf("train step failed with error: %v", err)
|
||||
}
|
||||
|
||||
epochLoss = epochLoss + lossScalar.GetData()
|
||||
|
||||
err = trainingSession.OptimizerStep()
|
||||
if err != nil {
|
||||
t.Fatalf("optimizer step failed with error: %v", err)
|
||||
}
|
||||
|
||||
// ort training api - reset the gradients to zero so that new gradients can be computed for next batch
|
||||
err = trainingSession.LazyResetGrad()
|
||||
if err != nil {
|
||||
t.Fatalf("lazy reset grad step failed with error: %v", err)
|
||||
}
|
||||
}
|
||||
if epoch%10 == 0 {
|
||||
t.Logf("Epoch {%d} Loss {%f}\n", epoch+1, epochLoss/float32(batchSize*nBatches))
|
||||
losses = append(losses, epochLoss/float32(batchSize*nBatches))
|
||||
}
|
||||
}
|
||||
|
||||
expectedLosses := []float32{
|
||||
0.125085,
|
||||
0.097187,
|
||||
0.062333,
|
||||
0.024307,
|
||||
0.019963,
|
||||
0.018476,
|
||||
0.017160,
|
||||
0.015982,
|
||||
0.014845,
|
||||
0.013867,
|
||||
}
|
||||
|
||||
for i, l := range losses {
|
||||
diff := math.Abs(float64(l - expectedLosses[i]))
|
||||
deviation := diff / float64(expectedLosses[i])
|
||||
if deviation > 0.6 {
|
||||
t.Fatalf("loss deviation too large: expected %f, actual %f, deviation %f", float64(expectedLosses[i]), float64(l), float64(deviation))
|
||||
}
|
||||
}
|
||||
|
||||
// test the saving of the checkpoint state
|
||||
finalCheckpointPath := path.Join("test_data", "training_test", "finalCheckpoint")
|
||||
errSaveCheckpoint := trainingSession.SaveCheckpoint(finalCheckpointPath, false)
|
||||
if errSaveCheckpoint != nil {
|
||||
t.Fatalf("Saving of checkpoint failed with error: %v", errSaveCheckpoint)
|
||||
}
|
||||
|
||||
// test the saving of the model
|
||||
finalModelPath := path.Join("test_data", "training_test", "final_inference.onnx")
|
||||
errExport := trainingSession.ExportModel(finalModelPath, []string{"output"})
|
||||
if errExport != nil {
|
||||
t.Fatalf("Exporting model failed with error: %v", errExport)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
e := os.Remove(finalCheckpointPath)
|
||||
if e != nil {
|
||||
t.Errorf("Error removing final checkpoint file %s: %s", finalCheckpointPath, e)
|
||||
}
|
||||
e = os.Remove(finalModelPath)
|
||||
if e != nil {
|
||||
t.Errorf("Error removing final model file %s: %s", finalModelPath, e)
|
||||
}
|
||||
}()
|
||||
|
||||
// load the model back in and test in-sample predictions for the first batch
|
||||
// (we care about correctness more than generalization here)
|
||||
session, err := NewAdvancedSession(path.Join("test_data", "training_test", "final_inference.onnx"),
|
||||
[]string{"input"}, []string{"output"},
|
||||
[]Value{batchInputTensor}, []Value{batchTargetTensor}, nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("creation of inference session failed with error: %v", err)
|
||||
}
|
||||
|
||||
defer func(s *AdvancedSession) {
|
||||
err := s.Destroy()
|
||||
if err != nil {
|
||||
t.Fatalf("cleanup of inference session failed with error: %v", err)
|
||||
}
|
||||
}(session)
|
||||
|
||||
// Calling Run() will run the network, reading the current contents of the
|
||||
// input tensors and modifying the contents of the output tensors.
|
||||
copy(batchInputTensor.GetData(), batchData[0]["input"])
|
||||
err = session.Run()
|
||||
if err != nil {
|
||||
t.Fatalf("run of inference session failed with error: %v", err)
|
||||
}
|
||||
|
||||
expectedOutput := []float32{
|
||||
2.4524384,
|
||||
0.65120333,
|
||||
2.5457804,
|
||||
0.6102175,
|
||||
1.6276635,
|
||||
0.573755,
|
||||
1.7900972,
|
||||
0.59951085,
|
||||
3.1650176,
|
||||
0.66626525,
|
||||
1.9361509,
|
||||
0.571084,
|
||||
2.0798547,
|
||||
0.6060241,
|
||||
0.9611889,
|
||||
0.52100605,
|
||||
1.4070896,
|
||||
0.5412475,
|
||||
2.1449144,
|
||||
0.5985652,
|
||||
}
|
||||
|
||||
for i, l := range batchTargetTensor.GetData() {
|
||||
diff := math.Abs(float64(l - expectedOutput[i]))
|
||||
deviation := diff / float64(expectedOutput[i])
|
||||
if deviation > 0.6 {
|
||||
t.Fatalf("deviation too large")
|
||||
}
|
||||
}
|
||||
}
|
@@ -402,146 +402,3 @@ OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values,
|
||||
value_type, out);
|
||||
}
|
||||
|
||||
// TRAINING API WRAPPER
|
||||
|
||||
static const OrtTrainingApi *ort_training_api = NULL;
|
||||
|
||||
void SetTrainingApi() {
|
||||
ort_training_api = ort_api->GetTrainingApi(ORT_API_VERSION);
|
||||
}
|
||||
|
||||
int IsTrainingApiSupported() {
|
||||
return ort_training_api != NULL;
|
||||
}
|
||||
|
||||
// Loads a checkpoint state from an in-memory buffer by forwarding to
// ort_training_api->LoadCheckpointFromBuffer. The resulting state is written
// to *out; the returned status is whatever the API call returned.
OrtStatus *CreateCheckpoint(void *checkpoint_data,
    size_t checkpoint_data_length, OrtCheckpointState **out) {
  return ort_training_api->LoadCheckpointFromBuffer(checkpoint_data,
    checkpoint_data_length, out);
}
|
||||
|
||||
OrtStatus *CreateTrainingSessionFromBuffer(OrtCheckpointState *checkpoint_state,
|
||||
void *training_model_data, size_t training_model_data_length,
|
||||
void *eval_model_data, size_t eval_model_data_length,
|
||||
void *optim_model_data, size_t optim_model_data_length,
|
||||
OrtEnv *env, OrtTrainingSession **out, OrtSessionOptions *options) {
|
||||
OrtStatus *status = NULL;
|
||||
int default_options = 0;
|
||||
if (!options) {
|
||||
default_options = 1;
|
||||
status = ort_api->CreateSessionOptions(&options);
|
||||
if (status) return status;
|
||||
}
|
||||
status = ort_training_api->CreateTrainingSessionFromBuffer(env, options, checkpoint_state,
|
||||
training_model_data, training_model_data_length, eval_model_data, eval_model_data_length,
|
||||
optim_model_data, optim_model_data_length, out);
|
||||
if (default_options) {
|
||||
ort_api->ReleaseSessionOptions(options);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
OrtStatus *CreateTrainingSessionFromPaths(OrtCheckpointState *checkpoint_state,
|
||||
char *training_model_path, char *eval_model_path, char *optim_model_path,
|
||||
OrtEnv *env, OrtTrainingSession **out, OrtSessionOptions *options) {
|
||||
OrtStatus *status = NULL;
|
||||
int default_options = 0;
|
||||
if (!options) {
|
||||
default_options = 1;
|
||||
status = ort_api->CreateSessionOptions(&options);
|
||||
if (status) return status;
|
||||
}
|
||||
status = ort_training_api->CreateTrainingSession(env, options, checkpoint_state,
|
||||
training_model_path, eval_model_path, optim_model_path, out);
|
||||
if (default_options) {
|
||||
ort_api->ReleaseSessionOptions(options);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
// Queries the number of inputs of the training model (into *result_training)
// and then of the eval model (into *result_eval). If the training-model query
// fails, its status is returned and the eval model is not queried.
OrtStatus *TrainingSessionGetInputCount(OrtTrainingSession *training_session,
    size_t *result_training, size_t *result_eval) {
  OrtStatus *status =
    ort_training_api->TrainingSessionGetTrainingModelInputCount(
      training_session, result_training);
  if (status != NULL) return status;
  return ort_training_api->TrainingSessionGetEvalModelInputCount(
    training_session, result_eval);
}
|
||||
|
||||
// Queries the number of outputs of the training model (into *result_training)
// and then of the eval model (into *result_eval). If the training-model query
// fails, its status is returned and the eval model is not queried.
OrtStatus *TrainingSessionGetOutputCount(OrtTrainingSession *training_session,
    size_t *result_training, size_t *result_eval) {
  OrtStatus *status =
    ort_training_api->TrainingSessionGetTrainingModelOutputCount(
      training_session, result_training);
  if (status != NULL) return status;
  return ort_training_api->TrainingSessionGetEvalModelOutputCount(
    training_session, result_eval);
}
|
||||
|
||||
// Looks up the name of training-model input i. Obtains the default allocator
// via ort_api->GetAllocatorWithDefaultOptions and passes it, together with
// name, to TrainingSessionGetTrainingModelInputName. Returns the allocator
// lookup's status if that fails.
// NOTE(review): presumably *name must later be freed with the same
// allocator - confirm against the caller.
OrtStatus *TrainingSessionGetTrainingInputName(
    OrtTrainingSession *training_session, size_t i, char **name) {
  OrtAllocator *default_allocator = NULL;
  OrtStatus *status = ort_api->GetAllocatorWithDefaultOptions(
    &default_allocator);
  if (status != NULL) return status;
  return ort_training_api->TrainingSessionGetTrainingModelInputName(
    training_session, i, default_allocator, name);
}
|
||||
|
||||
// Looks up the name of training-model output i. Obtains the default allocator
// via ort_api->GetAllocatorWithDefaultOptions and passes it, together with
// name, to TrainingSessionGetTrainingModelOutputName. Returns the allocator
// lookup's status if that fails.
// NOTE(review): presumably *name must later be freed with the same
// allocator - confirm against the caller.
OrtStatus *TrainingSessionGetTrainingOutputName(
    OrtTrainingSession *training_session, size_t i, char **name) {
  OrtAllocator *default_allocator = NULL;
  OrtStatus *status = ort_api->GetAllocatorWithDefaultOptions(
    &default_allocator);
  if (status != NULL) return status;
  return ort_training_api->TrainingSessionGetTrainingModelOutputName(
    training_session, i, default_allocator, name);
}
|
||||
|
||||
// Looks up the name of eval-model input i. Obtains the default allocator via
// ort_api->GetAllocatorWithDefaultOptions and passes it, together with name,
// to TrainingSessionGetEvalModelInputName. Returns the allocator lookup's
// status if that fails.
// NOTE(review): presumably *name must later be freed with the same
// allocator - confirm against the caller.
OrtStatus *TrainingSessionGetEvalInputName(
    OrtTrainingSession *training_session, size_t i, char **name) {
  OrtAllocator *default_allocator = NULL;
  OrtStatus *status = ort_api->GetAllocatorWithDefaultOptions(
    &default_allocator);
  if (status != NULL) return status;
  return ort_training_api->TrainingSessionGetEvalModelInputName(
    training_session, i, default_allocator, name);
}
|
||||
|
||||
// Looks up the name of eval-model output i. Obtains the default allocator via
// ort_api->GetAllocatorWithDefaultOptions and passes it, together with name,
// to TrainingSessionGetEvalModelOutputName. Returns the allocator lookup's
// status if that fails.
// NOTE(review): presumably *name must later be freed with the same
// allocator - confirm against the caller.
OrtStatus *TrainingSessionGetEvalOutputName(
    OrtTrainingSession *training_session, size_t i, char **name) {
  OrtAllocator *default_allocator = NULL;
  OrtStatus *status = ort_api->GetAllocatorWithDefaultOptions(
    &default_allocator);
  if (status != NULL) return status;
  return ort_training_api->TrainingSessionGetEvalModelOutputName(
    training_session, i, default_allocator, name);
}
|
||||
|
||||
// Forwards to ort_training_api->TrainStep with the given input and output
// value arrays. NULL is passed in place of run options.
OrtStatus *TrainStep(OrtTrainingSession *training_session, size_t inputs_len,
    OrtValue **inputs, size_t output_len, OrtValue **outputs) {
  return ort_training_api->TrainStep(training_session, NULL, inputs_len,
    (const OrtValue* const*) inputs, output_len, outputs);
}
|
||||
|
||||
// Forwards to ort_training_api->OptimizerStep for the given session. NULL is
// passed in place of run options.
OrtStatus *OptimizerStep(OrtTrainingSession *training_session) {
  return ort_training_api->OptimizerStep(training_session, NULL);
}
|
||||
|
||||
// Forwards to ort_training_api->LazyResetGrad for the given session and
// returns its status.
OrtStatus *LazyResetGrad(OrtTrainingSession *training_session) {
  return ort_training_api->LazyResetGrad(training_session);
}
|
||||
|
||||
// Forwards to ort_training_api->SaveCheckpoint, writing the given checkpoint
// state to path. include_optimizer is passed through unchanged as the API's
// flag argument.
OrtStatus *SaveCheckpoint(OrtCheckpointState *checkpoint, char *path,
    size_t include_optimizer) {
  return ort_training_api->SaveCheckpoint(checkpoint, path,
    include_optimizer);
}
|
||||
|
||||
// Forwards to ort_training_api->ExportModelForInferencing, passing path and
// the outputs_len entries of output_names through unchanged.
OrtStatus *ExportModel(OrtTrainingSession *training_session, char *path,
    size_t outputs_len, char **output_names) {
  return ort_training_api->ExportModelForInferencing(training_session, path,
    outputs_len, (const char* const*) output_names);
}
|
||||
|
||||
// Releases a training session by forwarding to
// ort_training_api->ReleaseTrainingSession.
void ReleaseOrtTrainingSession(OrtTrainingSession *session) {
  ort_training_api->ReleaseTrainingSession(session);
}
|
||||
|
||||
// Releases a checkpoint state by forwarding to
// ort_training_api->ReleaseCheckpointState.
void ReleaseCheckpointState(OrtCheckpointState *checkpoint) {
  ort_training_api->ReleaseCheckpointState(checkpoint);
}
|
||||
|
||||
|
@@ -15,7 +15,6 @@
|
||||
// Next, we actually include the header.
|
||||
#undef _WIN32
|
||||
#include "onnxruntime_c_api.h"
|
||||
#include "onnxruntime_training_c_api.h"
|
||||
|
||||
// ... However, mingw will complain if _WIN32 is *not* defined! So redefine it.
|
||||
#define _WIN32
|
||||
@@ -254,75 +253,6 @@ OrtStatus *GetValueCount(OrtValue *v, size_t *out);
|
||||
OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values,
|
||||
enum ONNXType value_type, OrtValue **out);
|
||||
|
||||
// TRAINING API WRAPPER
|
||||
|
||||
void SetTrainingApi();
|
||||
|
||||
// Checks if training api is supported.
|
||||
int IsTrainingApiSupported();
|
||||
|
||||
// Wraps ort_training_api->LoadCheckpointFromBuffer.
|
||||
// Creates an ORT checkpoint from the checkpoint data.
|
||||
OrtStatus *CreateCheckpoint(void *checkpoint_data,
|
||||
size_t checkpoint_data_length, OrtCheckpointState **out);
|
||||
|
||||
// Wraps ort_training_api->CreateTrainingSessionFromBuffer. Creates an ORT
|
||||
// training session using the given models and checkpoint. The given options
|
||||
// pointer may be NULL; if it is, then we'll use default options.
|
||||
OrtStatus *CreateTrainingSessionFromBuffer(OrtCheckpointState *checkpoint_state,
|
||||
void *training_model_data, size_t training_model_data_length,
|
||||
void *eval_model_data, size_t eval_model_data_length,
|
||||
void *optim_model_data, size_t optim_model_data_length,
|
||||
OrtEnv *env, OrtTrainingSession **out, OrtSessionOptions *options);
|
||||
|
||||
// Wraps ort_training_api->CreateTrainingSession.
|
||||
// Currently this is the only way to create a training session that is able to
|
||||
// export the final trained model to disk.
|
||||
OrtStatus *CreateTrainingSessionFromPaths(OrtCheckpointState *checkpoint_state,
|
||||
char *training_model_path, char *eval_model_path, char *optim_model_path,
|
||||
OrtEnv *env, OrtTrainingSession **out, OrtSessionOptions *options);
|
||||
|
||||
// Wraps ort_training_api->TrainingSessionGetTrainingModelInputCount
|
||||
// and ort_training_api->TrainingSessionGetEvalModelInputCount.
|
||||
OrtStatus *TrainingSessionGetInputCount(OrtTrainingSession *training_session, size_t *result_training, size_t *result_eval);
|
||||
|
||||
// Wraps ort_training_api->TrainingSessionGetTrainingModelOutputCount
|
||||
// and ort_training_api->TrainingSessionGetEvalModelOutputCount.
|
||||
OrtStatus *TrainingSessionGetOutputCount(OrtTrainingSession *training_session, size_t *result_training, size_t *result_eval);
|
||||
|
||||
// Wraps ort_training_api->TrainingSessionGetTrainingModelInputName.
|
||||
OrtStatus *TrainingSessionGetTrainingInputName(OrtTrainingSession *training_session, size_t i, char **name);
|
||||
|
||||
// Wraps ort_training_api->TrainingSessionGetEvalModelInputName.
|
||||
OrtStatus *TrainingSessionGetEvalInputName(OrtTrainingSession *training_session, size_t i, char **name);
|
||||
|
||||
// Wraps ort_training_api->TrainingSessionGetTrainingModelOutputName.
|
||||
OrtStatus *TrainingSessionGetTrainingOutputName(OrtTrainingSession *training_session, size_t i, char **name);
|
||||
|
||||
// Wraps ort_training_api->TrainingSessionGetEvalModelOutputName.
|
||||
OrtStatus *TrainingSessionGetEvalOutputName(OrtTrainingSession *training_session, size_t i, char **name);
|
||||
|
||||
// Wraps ort_training_api->TrainStep.
|
||||
OrtStatus *TrainStep(OrtTrainingSession *training_session, size_t inputs_len, OrtValue **inputs, size_t output_len, OrtValue **outputs);
|
||||
|
||||
// Wraps ort_training_api->OptimizerStep.
|
||||
OrtStatus *OptimizerStep(OrtTrainingSession *training_session);
|
||||
|
||||
// Wraps ort_training_api->LazyResetGrad.
|
||||
OrtStatus *LazyResetGrad(OrtTrainingSession *training_session);
|
||||
|
||||
// Wraps ort_training_api->SaveCheckpoint.
|
||||
OrtStatus *SaveCheckpoint(OrtCheckpointState *checkpoint, char *path, size_t include_optimizer);
|
||||
|
||||
// Wraps ort_training_api->ExportModelForInferencing.
|
||||
OrtStatus *ExportModel(OrtTrainingSession *training_session, char *path, size_t outputs_len, char **output_names);
|
||||
|
||||
// Wraps ort_training_api->ReleaseTrainingSession.
|
||||
void ReleaseOrtTrainingSession(OrtTrainingSession *session);
|
||||
|
||||
// Wraps ort_training_api->ReleaseCheckpointState.
|
||||
void ReleaseCheckpointState(OrtCheckpointState *checkpoint);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
@@ -56,10 +56,6 @@ func platformInitializeEnvironment() error {
|
||||
return fmt.Errorf("Error setting ORT API base: %d", tmp)
|
||||
}
|
||||
|
||||
// we do not initialize the training API on windows (see setup_env.go)
|
||||
// because currently we cannot support the conversion from UTF-8 to wide
|
||||
// character. See https://github.com/yalue/onnxruntime_go/pull/56.
|
||||
|
||||
libraryHandle = handle
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user