diff --git a/README.md b/README.md index c7eab42..abe2db2 100644 --- a/README.md +++ b/README.md @@ -206,10 +206,13 @@ run and pass. Training API Support -------------------- -This wrapper supports the onnxruntime training API on limited platforms. See -the `NewTrainingSession` and associated data types or functions to use it. So -far, the training API has only been tested on Linux, on `x86_64` architectures. +The training API has been deprecated as of onnxruntime version 1.20. Rather +than continuing to maintain wrappers for a deprecated API, `onnxruntime_go` has +replaced the wrapper functions for the training API with stubs that return an +error. Users who need to continue to use the training API will need to use an +older version. For example the following versions should be compatible with +training: + + - Version `v1.12.1` of `onnxruntime_go`, and + - Version 1.19.2 of `onnxruntime`. -If you are not sure whether your platform or build of onnxruntime supports -training, you can call `onnxruntime_go.IsTrainingSupported()`, which will -return `true` if training is supported on your system. diff --git a/legacy_types.go b/legacy_types.go index 69bfb47..5408647 100644 --- a/legacy_types.go +++ b/legacy_types.go @@ -171,3 +171,77 @@ type ArbitraryTensor = Value // As with the ArbitraryTensor type, this type alias only exists to facilitate // renaming an old type without breaking existing code. type TensorInternalData = ValueInternalData + +var TrainingAPIRemovedError error = fmt.Errorf("Support for the training " + + "API has been removed from onnxruntime_go following its deprecation in " + + "onnxruntime versions 1.19.2 and later. The last revision of " + + "onnxruntime_go supporting the training API is version v1.12.1") + +// Support for TrainingSessions has been removed from onnxruntime_go following +// the deprecation of the training API in onnxruntime 1.20.0. +type TrainingSession struct{} + +// Always returns TrainingAPIRemovedError. +func (s *TrainingSession) ExportModel(path string, outputNames []string) error { + return TrainingAPIRemovedError +} + +// Always returns TrainingAPIRemovedError. +func (s *TrainingSession) SaveCheckpoint(path string, + saveOptimizerState bool) error { + return TrainingAPIRemovedError +} + +// Always returns TrainingAPIRemovedError. +func (s *TrainingSession) Destroy() error { + return TrainingAPIRemovedError +} + +// Always returns TrainingAPIRemovedError. +func (s *TrainingSession) TrainStep() error { + return TrainingAPIRemovedError +} + +// Always returns TrainingAPIRemovedError. +func (s *TrainingSession) OptimizerStep() error { + return TrainingAPIRemovedError +} + +// Always returns TrainingAPIRemovedError. +func (s *TrainingSession) LazyResetGrad() error { + return TrainingAPIRemovedError +} + +// Support for TrainingInputOutputNames has been removed from onnxruntime_go +// following the deprecation of the training API in onnxruntime 1.20.0. +type TrainingInputOutputNames struct { + TrainingInputNames []string + EvalInputNames []string + TrainingOutputNames []string + EvalOutputNames []string +} + +// Always returns (nil, TrainingAPIRemovedError). +func GetInputOutputNames(checkpointStatePath string, trainingModelPath string, + evalModelPath string) (*TrainingInputOutputNames, error) { + return nil, TrainingAPIRemovedError +} + +// Always returns false. +func IsTrainingSupported() bool { + return false +} + +// Always returns (nil, TrainingAPIRemovedError). +func NewTrainingSessionWithOnnxData(checkpointData, trainingData, evalData, + optimizerData []byte, inputs, outputs []Value, + options *SessionOptions) (*TrainingSession, error) { + return nil, TrainingAPIRemovedError +} + +// Always returns (nil, TrainingAPIRemovedError). +func NewTrainingSession(checkpointStatePath, trainingModelPath, evalModelPath, + optimizerModelPath string, inputs, outputs []Value, + options *SessionOptions) (*TrainingSession, error) { + return nil, TrainingAPIRemovedError +} diff --git a/onnxruntime_go.go b/onnxruntime_go.go index 40cb1a9..b5ed546 100644 --- a/onnxruntime_go.go +++ b/onnxruntime_go.go @@ -90,9 +90,6 @@ func InitializeEnvironment() error { return fmt.Errorf("Platform-specific initialization failed: %w", e) } - // Get the training API pointer if it is supported. - C.SetTrainingApi() - name := C.CString("Golang onnxruntime environment") defer C.free(unsafe.Pointer(name)) status := C.CreateOrtEnv(name, &ortEnv) @@ -883,6 +880,90 @@ func (t *CustomDataTensor) GetData() []byte { return t.data } +// Scalar is like a tensor but the underlying go slice is of length 1 and it +// has no dimension. It was introduced for use with the training API, but +// remains supported since it conceivable will have use outside of the training +// API. +type Scalar[T TensorData] struct { + data []T + dataSize uintptr + ortValue *C.OrtValue +} + +// Always returns nil for Scalars. +func (s *Scalar[T]) GetShape() Shape { + return nil +} + +func (s *Scalar[T]) ZeroContents() { + C.memset(unsafe.Pointer(&s.data[0]), 0, C.size_t(s.dataSize)) +} + +func (s *Scalar[T]) Destroy() error { + C.ReleaseOrtValue(s.ortValue) + s.ortValue = nil + s.data = nil + s.dataSize = 0 + return nil +} + +// GetData returns the undelying data for the scalar. If you want to set the +// scalar's data, use Set. +func (t *Scalar[T]) GetData() T { + return t.data[0] +} + +// Changes the underlying value of the scalar to the new value. +func (t *Scalar[T]) Set(value T) { + t.data = []T{value} +} + +func (t *Scalar[T]) DataType() C.ONNXTensorElementDataType { + return GetTensorElementDataType[T]() +} + +func (t *Scalar[_]) GetInternals() *ValueInternalData { + return &ValueInternalData{ + ortValue: t.ortValue, + } +} + +func (t *Scalar[_]) GetONNXType() ONNXType { + return ONNXTypeTensor +} + +// NewEmptyScalar creates a new scalar of type T. +func NewEmptyScalar[T TensorData]() (*Scalar[T], error) { + var data T + return NewScalar(data) +} + +// NewScalar creates a new scalar of type T backed by a value of type T. +// Note that, differently from tensors, this is not a []T but just a value T. +func NewScalar[T TensorData](data T) (*Scalar[T], error) { + if !IsInitialized() { + return nil, NotInitializedError + } + + dataSlice := []T{data} + var ortValue *C.OrtValue + dataType := GetTensorElementDataType[T]() + dataSize := unsafe.Sizeof(dataSlice[0]) * uintptr(1) + + status := C.CreateOrtTensorWithShape(unsafe.Pointer(&dataSlice[0]), + C.size_t(dataSize), nil, C.int64_t(0), ortMemoryInfo, dataType, &ortValue) + if status != nil { + return nil, statusToError(status) + } + toReturn := Scalar[T]{ + data: dataSlice, + dataSize: dataSize, + ortValue: ortValue, + } + return &toReturn, nil +} + + // Holds options required when enabling the CUDA backend for a session. This // struct wraps C onnxruntime types; users must create instances of this using // the NewCUDAProviderOptions() function. So, to enable CUDA for a session, diff --git a/onnxruntime_test.go b/onnxruntime_test.go index 4fcf18b..fcdb836 100644 --- a/onnxruntime_test.go +++ b/onnxruntime_test.go @@ -1461,6 +1461,41 @@ func TestSessionFromDataBuffer(t *testing.T) { } } +func TestScalar(t *testing.T) { + InitializeRuntime(t) + defer CleanupRuntime(t) + s, e := NewEmptyScalar[float32]() + if e != nil { + t.Fatalf("Error creating empty scalar: %s\n", e) + } + if s.GetData() != 0.0 { + t.Fatalf("Empty scalar not initialized to 0: %s\n", e) + } + e = s.Destroy() + if e != nil { + t.Fatalf("Failed destroying scalar: %s\n", e) + } + s2, e := NewScalar(int64(1337)) + if e != nil { + t.Fatalf("Failed creating int64 scalar: %s\n", e) + } + defer s2.Destroy() + contents := s2.GetData() + if contents != 1337 { + t.Fatalf("Incorrect initial contents of s2: %d\n", contents) + } + s2.ZeroContents() + contents = s2.GetData() + if contents != 0 { + t.Fatalf("Incorrect value of s2 after zeroing: %d\n", contents) + } + s2.Set(1234) + contents = s2.GetData() + if contents != 1234 { + t.Fatalf("Incorrect value of s2: %d (expected 1234)\n", contents) + } +} + // See the comment in generate_network_big_compute.py for information about // the inputs and outputs used for testing or benchmarking session options. func prepareBenchmarkTensors(t testing.TB, seed int64) (*Tensor[float32], diff --git a/onnxruntime_training_c_api.h b/onnxruntime_training_c_api.h deleted file mode 100644 index ed6d151..0000000 --- a/onnxruntime_training_c_api.h +++ /dev/null @@ -1,731 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// This file contains the training c apis. - -#pragma once -#include -#include "onnxruntime_c_api.h" - -/** \page training_c_cpp_api Training C & C++ APIs - * - * Training C and C++ APIs are an extension of the \ref c_cpp_api "onnxruntime core C and C++ APIs" and should be used in conjunction with them. - * - * In order to train a model with onnxruntime, the following training artifacts must be generated: - * - The training onnx model - * - The checkpoint file - * - The optimizer onnx model - * - The eval onnx model model (optional) - * - * These training artifacts can be generated as part of an offline step using the python [utilities](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md) made available in the `onnxruntime-training` python package. - * - * After these artifacts have been generated, the C and C++ utilities listed in this documentation can be leveraged to perform training. - * - * If any problem is encountered, please create an [issue](https://github.com/microsoft/onnxruntime/issues/new) with your scenario and requirements, and we will be sure to respond and follow up on the request. - * - *

Training C API

- * - * ::OrtTrainingApi - Training C API functions. - * - * This C structure contains functions that enable users to perform training with onnxruntime. - * - * _Sample Code_: - * - * ```c - * #include - * - * OrtApi* g_ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION); - * OrtTrainingApi* g_ort_training_api = g_ort_api->GetTrainingApi(ORT_API_VERSION); - * - * OrtEnv* env = NULL; - * g_ort_api->CreateEnv(logging_level, logid, &env); - * OrtSessionOptions* session_options = NULL; - * g_ort_api->CreateSessionOptions(&session_options); - * - * OrtCheckpointState* state = NULL; - * g_ort_training_api->LoadCheckpoint(path_to_checkpoint, &state); - * - * OrtTrainingSession* training_session = NULL; - * g_ort_training_api->CreateTrainingSession(env, session_options, training_model_path, - * state, eval_model_path, optimizer_model_path, - * &training_session); - * // Training loop - * { - * g_ort_training_api->TrainStep(...); - * g_ort_training_api->OptimizerStep(...); - * g_ort_training_api->LazyResetGrad(...); - * } - * - * g_ort_training_api->ExportModelForInferencing(training_session, inference_model_path, ...); - * g_ort_training_api->SaveCheckpoint(state, path_to_checkpoint, false); - * - * g_ort_training_api->ReleaseTrainingSession(training_session); - * g_ort_training_api->ReleaseCheckpointState(state); - * ``` - * - * > **Note** - * > The ::OrtCheckpointState contains the entire training state that the ::OrtTrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::OrtCheckpointState instance must outlive the lifetime of the ::OrtTrainingSession instance. - * - *

Training C++ API

- * - * @ref TrainingCpp - Training C++ API classes and functions. - * - * These C++ classes and functions enable users to perform training with onnxruntime. - * - * _Sample Code_: - * - * ```cc - * #include - * - * Ort::Env env; - * Ort::SessionOptions session_options; - * - * auto state = Ort::CheckpointState::LoadCheckpoint(path_to_checkpoint); - * auto training_session = Ort::TrainingSession(env, session_options, state, training_model_path, - * eval_model_path, optimizer_model_path); - * - * // Training Loop - * { - * training_session.TrainStep(...); - * training_session.OptimizerStep(...); - * training_session.LazyResetGrad(...); - * } - * - * training_session->ExportModelForInferencing(inference_model_path, ...); - * Ort::CheckpointState::SaveCheckpoint(state, path_to_checkpoint, false); - * ``` - * > **Note** - * > The ::Ort::CheckpointState contains the entire training state that the ::Ort::TrainingSession uses. As a result, the training session must always have access to the state. That is to say, the ::Ort::CheckpointState instance must outlive the lifetime of the ::Ort::TrainingSession instance. - */ - -/** @defgroup TrainingC Ort Training C API - * @{ - */ -ORT_RUNTIME_CLASS(TrainingSession); // Type that enables performing training for the given user models. -ORT_RUNTIME_CLASS(CheckpointState); // Type that holds the training states for the training session. - -/** \brief Type of property to be added to or returned from the ::OrtCheckpointState. - */ -typedef enum OrtPropertyType { - OrtIntProperty = 0, - OrtFloatProperty = 1, - OrtStringProperty = 2, -} OrtPropertyType; - -/** \brief The Training C API that holds onnxruntime training function pointers - * - * All the Training C API functions are defined inside this structure as pointers to functions. - * Call OrtApi::GetTrainingApi to get a pointer to this struct. - * - * \nosubgrouping - */ -struct OrtTrainingApi { - /// \name Accessing The Training Session State - /// @{ - - /** \brief Load a checkpoint state from a file on disk into checkpoint_state. - * - * This function will parse a checkpoint file, pull relevant data and load the training - * state into the checkpoint_state. This checkpoint state can then be used to create the - * training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training - * session will resume training from the given checkpoint state. - * \note Note that the training session created with a checkpoint state uses this state to store the entire - * training state (including model parameters, its gradients, the optimizer states and the properties). - * As a result, it is required that the checkpoint state outlive the lifetime of the training session. - * \note Note that the checkpoint file can be either the complete checkpoint or the nominal checkpoint. - * - * \param[in] checkpoint_path Path to the checkpoint file - * \param[out] checkpoint_state Checkpoint state that contains the states of the training session. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(LoadCheckpoint, _In_ const ORTCHAR_T* checkpoint_path, - _Outptr_ OrtCheckpointState** checkpoint_state); - - /** \brief Save the given state to a checkpoint file on disk. - * - * This function serializes the provided checkpoint state to a file on disk. - * This checkpoint can later be loaded by invoking OrtTrainingApi::LoadCheckpoint to resume - * training from this snapshot of the state. - * - * \param[in] checkpoint_state The checkpoint state to save. - * \param[in] checkpoint_path Path to the checkpoint file. - * \param[in] include_optimizer_state Flag to indicate whether to save the optimizer state or not. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(SaveCheckpoint, _In_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* checkpoint_path, - const bool include_optimizer_state); - - /// @} - - /// \name Implementing The Training Loop - /// @{ - /** \brief Create a training session that can be used to begin or resume training. - * - * This function creates a training session based on the env and session options provided that can - * begin or resume training from a given checkpoint state for the given onnx models. - * The checkpoint state represents the parameters of the training session which will be moved - * to the device specified by the user through the session options (if necessary). - * The training session requires four training artifacts - * - The training onnx model - * - The evaluation onnx model (optional) - * - The optimizer onnx model - * - The checkpoint file - * - * These artifacts can be generated using the `onnxruntime-training` python [utility](https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/training/onnxblock/README.md). - * - * \param[in] env Environment to be used for the training session. - * \param[in] options Session options that the user can customize for this training session. - * \param[in] checkpoint_state Training states that the training session uses as a starting point for training. - * \param[in] train_model_path Model to be used to perform training. - * \param[in] eval_model_path Model to be used to perform evaluation. - * \param[in] optimizer_model_path Model to be used to perform gradient descent. - * \param[out] out Created training session. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, - _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, - _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, - _Outptr_result_maybenull_ OrtTrainingSession** out); - - /** \brief Create a training session that can be used to begin or resume training. - * This api provides a way to load all the training artifacts from buffers instead of files. - * - * \param[in] env Environment to be used for the training session. - * \param[in] options Session options that the user can customize for this training session. - * \param[in] checkpoint_state Training states that the training session uses as a starting point for training. - * \param[in] train_model_data Buffer containing the model data to be used to perform training - * \param[in] train_data_length Length of the buffer containing train_model_data - * \param[in] eval_model_data Buffer containing the model data to be used to perform evaluation - * \param[in] eval_data_length Length of the buffer containing eval_model_data - * \param[in] optim_model_data Buffer containing the model data to be used to perform weight update - * \param[in] optim_data_length Length of the buffer containing optim_model_data - * \param[out] out Created training session. - * - */ - ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, - _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, - _In_ const void* train_model_data, size_t train_data_length, - _In_ const void* eval_model_data, size_t eval_data_length, - _In_ const void* optim_model_data, size_t optim_data_length, - _Outptr_result_maybenull_ OrtTrainingSession** out); - - /// @} - - /// \name Model IO Information - /// @{ - - /** \brief Retrieves the number of user outputs in the training model. - * - * This function returns the number of outputs of the training model so that the user can - * allocate space for the number of outputs when OrtTrainingApi::TrainStep is invoked. - * - * \param[in] sess The `this` pointer to the training session. - * \param[out] out Number of user outputs in the training model. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); - - /** \brief Retrieves the number of user outputs in the eval model. - * - * This function returns the number of outputs of the eval model so that the user can - * allocate space for the number of outputs when OrtTrainingApi::EvalStep is invoked. - * - * \param[in] sess The `this` pointer to the training session. - * \param[out] out Number of user outputs in the eval model. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(TrainingSessionGetEvalModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); - - /** \brief Retrieves the names of user outputs in the training model. - * - * This function returns the names of outputs of the training model that can be associated with the OrtValue(s) - * returned by the OrtTrainingApi::TrainStep function. - * - * \param[in] sess The `this` pointer to the training session. - * \param[in] index Index of the output name requested. - * \param[in] allocator Allocator to use to allocate the memory for the name. - * \param[out] output Name of the training model output at the given index. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(TrainingSessionGetTrainingModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output); - - /** \brief Retrieves the names of user outputs in the eval model. - * - * This function returns the names of outputs of the eval model that can be associated with the OrtValue(s) returned - * by the OrtTrainingApi::EvalStep function. - * - * \param[in] sess The `this` pointer to the training session. - * \param[in] index Index of the output name requested. - * \param[in] allocator Allocator to use to allocate the memory for the name. - * \param[out] output Name of the eval model output at the given index. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(TrainingSessionGetEvalModelOutputName, _In_ const OrtTrainingSession* sess, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** output); - - /// @} - - /// \name Implementing The Training Loop - /// @{ - - /** \brief Reset the gradients of all trainable parameters to zero lazily. - * - * This function sets the internal state of the training session such that the gradients of the trainable - * parameters in the OrtCheckpointState will be scheduled to be reset just before the new gradients are - * computed on the next invocation of the next OrtTrainingApi::TrainStep. - * - * \param[in] session The `this` pointer to the training session. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(LazyResetGrad, _Inout_ OrtTrainingSession* session); - - /** \brief Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs - * - * This function performs a training step that computes the outputs of the training model and the gradients - * of the trainable parameters for the given inputs. The train step is performed based on the training model - * that was provided to the training session. - * The OrtTrainingApi::TrainStep is equivalent of running forward propagation and backward propagation in a single - * step. - * The gradients computed are stored inside the training session state so they can be later consumed - * by the OrtTrainingApi::OptimizerStep function. - * The gradients can be lazily reset by invoking the OrtTrainingApi::LazyResetGrad function. - * - * \param[in] sess The `this` pointer to the training session. - * \param[in] run_options Run options for this training step. - * \param[in] inputs_len Number of user inputs to the training model. - * \param[in] inputs The user inputs to the training model. - * \param[in] outputs_len Number of user outputs expected from this training step. - * \param[out] outputs User outputs computed by train step. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, - _In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, - _In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); - - /** \brief Computes the outputs for the eval model for the given inputs - * - * This function performs an eval step that computes the outputs of the eval model for the given inputs. - * The eval step is performed based on the eval model that was provided to the training session. - * - * \param[in] sess The `this` pointer to the training session. - * \param[in] run_options Run options for this eval step. - * \param[in] inputs_len Number of user inputs to the eval model. - * \param[in] inputs The user inputs to the eval model. - * \param[in] outputs_len Number of user outputs expected from this eval step. - * \param[out] outputs User outputs computed by eval step. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(EvalStep, _In_ const OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, - _In_ size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, - _In_ size_t outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); - - /** \brief Sets the learning rate for this training session. - * - * This function allows users to set the learning rate for the training session. The current - * learning rate is maintained by the training session and can be overwritten by invoking - * this function with the desired learning rate. This function should not be used when a valid - * learning rate scheduler is registered. It should be used either to set the learning rate - * derived from a custom learning rate scheduler or to set a constant learning rate to be used - * throughout the training session. - * \note Please note that this function does not set the initial learning rate that may be needed - * by the predefined learning rate schedulers. To set the initial learning rate for learning - * rate schedulers, please look at the function OrtTrainingApi::RegisterLinearLRScheduler. - * - * \param[in] sess The `this` pointer to the training session. - * \param[in] learning_rate Desired learning rate to be set. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(SetLearningRate, _Inout_ OrtTrainingSession* sess, _In_ float learning_rate); - - /** \brief Gets the current learning rate for this training session. - * - * This function allows users to get the learning rate for the training session. The current - * learning rate is maintained by the training session, and users can query it for the purpose - * of implementing their own learning rate schedulers. - * - * \param[in] sess The `this` pointer to the training session. - * \param[out] learning_rate Learning rate currently in use by the training session. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(GetLearningRate, _Inout_ OrtTrainingSession* sess, _Out_ float* learning_rate); - - /** \brief Performs the weight updates for the trainable parameters using the optimizer model. - * - * This function performs the weight update step that updates the trainable parameters such that they - * take a step in the direction of their gradients (gradient descent). The optimizer step is performed - * based on the optimizer model that was provided to the training session. - * The updated parameters are stored inside the training state so that they can be used by the next - * OrtTrainingApi::TrainStep function call. - * - * \param[in] sess The `this` pointer to the training session. - * \param[in] run_options Run options for this optimizer step. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(OptimizerStep, _Inout_ OrtTrainingSession* sess, - _In_opt_ const OrtRunOptions* run_options); - - /** \brief Registers a linear learning rate scheduler for the training session. - * - * Register a linear learning rate scheduler that decays the learning rate by linearly updated - * multiplicative factor from the initial learning rate set on the training session to 0. The decay - * is performed after the initial warm up phase where the learning rate is linearly incremented - * from 0 to the initial learning rate provided. - * - * \param[in] sess The `this` pointer to the training session. - * \param[in] warmup_step_count Warmup steps for LR warmup. - * \param[in] total_step_count Total step count. - * \param[in] initial_lr The initial learning rate to be used by the training session. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(RegisterLinearLRScheduler, _Inout_ OrtTrainingSession* sess, _In_ const int64_t warmup_step_count, - _In_ const int64_t total_step_count, _In_ const float initial_lr); - - /** \brief Update the learning rate based on the registered learing rate scheduler. - * - * Takes a scheduler step that updates the learning rate that is being used by the training session. - * This function should typically be called before invoking the optimizer step for each round, - * or as determined necessary to update the learning rate being used by the training session. - * \note Please note that a valid predefined learning rate scheduler must be first registered to invoke this - * function. - * - * \param[in] sess The `this` pointer to the training session. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(SchedulerStep, _Inout_ OrtTrainingSession* sess); - - /// @} - - /// \name Accessing The Training Session State - /// @{ - /** \brief Retrieves the size of all the parameters. - * - * Calculates the total number of primitive (datatype of the parameters) elements of all the parameters in the - * training state. - * When trainable_only argument is true, the size is calculated for trainable params only. - * - * \param[in] sess The `this` pointer to the training session. - * \param[out] out Size of all parameter elements. - * \param[in] trainable_only Whether to skip non-trainable parameters - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(GetParametersSize, _Inout_ OrtTrainingSession* sess, _Out_ size_t* out, bool trainable_only); - - /** \brief Copy all parameters to a contiguous buffer held by the argument parameters_buffer - * - * The parameters_buffer has to be of the size given by GetParametersSize api call, - * with matching setting for the argument trainable_only. All the target parameters must be of the same - * datatype. The OrtValue must be pre-allocated onto - * the desired device. This is a complementary function to OrtTrainingApi::CopyBufferToParameters. - * Parameter ordering is preserved. - * User is responsible for allocating and freeing the resources used by the parameters_buffer. - * - * \param[in] sess The `this` pointer to the training session. - * \param[in] trainable_only Whether to skip non-trainable parameters - * \param[out] parameters_buffer The pre-allocated OrtValue buffer to copy onto. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(CopyParametersToBuffer, _Inout_ OrtTrainingSession* sess, - _Inout_ OrtValue* parameters_buffer, bool trainable_only); - - /** \brief Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state - * - * The parameters_buffer argument has to be of the size given by OrtTrainingApi::GetParametersSize api call, - * with matching setting for trainable_only argument. All the target parameters must be of the same - * datatype. This is a complementary function to OrtTrainingApi::CopyParametersToBuffer - * and can be used to load updated buffer values onto the training state. - * Parameter ordering is preserved. - * User is responsible for allocating and freeing the resources used by the parameters_buffer. - * In case the training session was created with a nominal checkpoint, invoking this function is required - * to load the updated parameters onto the checkpoint to complete it. - * - * \param[in] sess The `this` pointer to the training session. - * \param[in] trainable_only Whether to skip non-trainable parameters - * \param[out] parameters_buffer The pre-allocated OrtValue buffer to copy from. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(CopyBufferToParameters, _Inout_ OrtTrainingSession* sess, - _Inout_ OrtValue* parameters_buffer, bool trainable_only); - - /// @} - - /// \name Release Training Resources - /// @{ - - /** \brief Frees up the memory used up by the training session. - * - * This function frees up any memory that was allocated in the training session. The training - * session can no longer be used after this call. - * - */ - ORT_CLASS_RELEASE(TrainingSession); - - /** \brief Frees up the memory used up by the checkpoint state. - * - * This function frees up any memory that was allocated in the checkpoint state. The checkpoint - * state can no longer be used after this call. - * \note Note that the checkpoint state must be released only after the training session has been released. - * - */ - ORT_CLASS_RELEASE(CheckpointState); - - /// @} - - /// \name Prepare For Inferencing - /// @{ - /** \brief Export a model that can be used for inferencing. - * - * If the training session was provided with an eval model, the training session can generate - * an inference model if it knows the inference graph outputs. The input inference graph outputs - * are used to prune the eval model so that the inference model's outputs align with the provided outputs. - * The exported model is saved at the path provided and can be used for inferencing with InferenceSession. - * \note Note that the function re-loads the eval model from the path provided to OrtTrainingApi::CreateTrainingSession - * and expects that this path still be valid. - * - * \param[in] sess The `this` pointer to the training session. - * \param[in] inference_model_path Path where the inference model should be serialized to. - * \param[in] graph_outputs_len Size of the graph output names array. - * \param[in] graph_output_names Names of the outputs that are needed in the inference model. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(ExportModelForInferencing, _Inout_ OrtTrainingSession* sess, - _In_ const ORTCHAR_T* inference_model_path, size_t graph_outputs_len, - _In_reads_(graph_outputs_len) const char* const* graph_output_names); - - /// @} - - /// \name Training Utilities - /// @{ - /** \brief Sets the seed used for random number generation in Onnxruntime. - * - * Use this function to generate reproducible results. It should be noted that completely reproducible - * results are not guaranteed. - * - * \param[in] seed The seed to be set. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(SetSeed, _In_ const int64_t seed); - - /// @} - - /// \name Model IO Information - /// @{ - /** \brief Retrieves the number of user inputs in the training model. - * - * This function returns the number of inputs of the training model so that the user can accordingly - * allocate the OrtValue(s) provided to the OrtTrainingApi::TrainStep function. - * - * \param[in] sess The `this` pointer to the training session. - * \param[out] out Number of user inputs in the training model. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(TrainingSessionGetTrainingModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); - - /** \brief Retrieves the number of user inputs in the eval model. - * - * This function returns the number of inputs of the eval model so that the user can accordingly - * allocate the OrtValue(s) provided to the OrtTrainingApi::EvalStep function. - * - * \param[in] sess The `this` pointer to the training session. - * \param[out] out Number of user inputs in the eval model. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(TrainingSessionGetEvalModelInputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); - - /** \brief Retrieves the name of the user input at given index in the training model. - * - * This function returns the names of inputs of the training model that can be associated with the - * OrtValue(s) provided to the OrtTrainingApi::TrainStep function. - * - * \param[in] sess The `this` pointer to the training session. - * \param[in] index The index of the training model input name requested. - * \param[in] allocator The allocator to use to allocate the memory for the requested name. - * \param[out] output Name of the user input for the training model at the given index. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(TrainingSessionGetTrainingModelInputName, _In_ const OrtTrainingSession* sess, size_t index, - _In_ OrtAllocator* allocator, _Outptr_ char** output); - - /** \brief Retrieves the name of the user input at given index in the eval model. - * - * This function returns the names of inputs of the eval model that can be associated with the OrtValue(s) provided - * to the OrtTrainingApi::EvalStep function. - * - * \param[in] sess The `this` pointer to the training session. - * \param[in] index The index of the eval model input name requested. - * \param[in] allocator The allocator to use to allocate the memory for the requested name. - * \param[out] output Name of the user input for the eval model at the given index. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(TrainingSessionGetEvalModelInputName, _In_ const OrtTrainingSession* sess, size_t index, - _In_ OrtAllocator* allocator, _Outptr_ char** output); - - /// @} - - /// \name Accessing The Training Session State - /// @{ - - /** \brief Adds or updates the given property to/in the checkpoint state. - * - * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint - * state by the user by calling this function with the corresponding property name and value. - * The given property name must be unique to be able to successfully add the property. - * - * \param[in] checkpoint_state The checkpoint state which should hold the property. - * \param[in] property_name Name of the property being added or updated. - * \param[in] property_type Type of the property associated with the given name. - * \param[in] property_value Property value associated with the given name. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(AddProperty, _Inout_ OrtCheckpointState* checkpoint_state, - _In_ const char* property_name, _In_ enum OrtPropertyType property_type, - _In_ void* property_value); - - /** \brief Gets the property value associated with the given name from the checkpoint state. - * - * Gets the property value from an existing entry in the checkpoint state. The property must - * exist in the checkpoint state to be able to retrieve it successfully. - * - * \param[in] checkpoint_state The checkpoint state that is currently holding the property. - * \param[in] property_name Name of the property being retrieved. - * \param[in] allocator Allocator used to allocate the memory for the property_value. - * \param[out] property_type Type of the property associated with the given name. - * \param[out] property_value Property value associated with the given name. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(GetProperty, _In_ const OrtCheckpointState* checkpoint_state, - _In_ const char* property_name, _Inout_ OrtAllocator* allocator, - _Out_ enum OrtPropertyType* property_type, _Outptr_ void** property_value); - - /// @} - - /// \name Accessing The Training Session State - /// @{ - - /** \brief Load a checkpoint state from a buffer into checkpoint_state. - * - * This function will parse a checkpoint bytes buffer, pull relevant data and load the training - * state into the checkpoint_state. This checkpoint state can then be used to create the - * training session by invoking OrtTrainingApi::CreateTrainingSession. By doing so, the training - * session will resume training from the given checkpoint state. - * \note Note that the training session created with a checkpoint state uses this state to store the entire - * training state (including model parameters, its gradients, the optimizer states and the properties). - * As a result, it is required that the checkpoint state outlive the lifetime of the training session. - * - * \param[in] checkpoint_buffer Path to the checkpoint bytes buffer. - * \param[in] num_bytes Number of bytes in the checkpoint buffer. - * \param[out] checkpoint_state Checkpoint state that contains the states of the training session. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, - _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); - - /** \brief Retrieves the type and shape information of the parameter associated with the given parameter name. - * - * This function retrieves the type and shape of the parameter associated with the given parameter name. - * The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully. - * - * \param[in] checkpoint_state The checkpoint state. - * \param[in] parameter_name Name of the parameter being retrieved. - * \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, - _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape); - - /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name. - * - * This function updates a model parameter in the checkpoint state with the given parameter data. - * The training session must be already created with the checkpoint state that contains the parameter - * being updated. The given parameter is copied over to the registered device for the training session. - * The parameter must exist in the checkpoint state to be able to update it successfully. - * - * \param[in] checkpoint_state The checkpoint state. - * \param[in] parameter_name Name of the parameter being updated. - * \param[in] parameter The parameter data that should replace the existing parameter data. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, - _In_ const char* parameter_name, _In_ OrtValue* parameter); - - /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. - * - * This function retrieves the model parameter data from the checkpoint state for the given parameter name. - * The parameter is copied over and returned as an OrtValue. The training session must be already created - * with the checkpoint state that contains the parameter being retrieved. - * The parameter must exist in the checkpoint state to be able to retrieve it successfully. - * - * \param[in] checkpoint_state The checkpoint state. - * \param[in] parameter_name Name of the parameter being retrieved. - * \param[in] allocator Allocator used to allocate the memory for the parameter. - * \param[out] parameter The parameter data that is retrieved from the checkpoint state. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - */ - ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, - _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, - _Outptr_ OrtValue** parameter); - - /// @} -}; - -typedef struct OrtTrainingApi OrtTrainingApi; - -/// @} diff --git a/onnxruntime_training_go.go b/onnxruntime_training_go.go deleted file mode 100644 index 05be9e5..0000000 --- a/onnxruntime_training_go.go +++ /dev/null @@ -1,639 +0,0 @@ -package onnxruntime_go - -// #cgo CFLAGS: -O2 -g -// -// #include "onnxruntime_wrapper.h" -import "C" -import ( - "fmt" - "os" - "path/filepath" - "unsafe" -) - -var trainingNotSupportedError error = fmt.Errorf("training not supported by onnx library") - -// Scalar is like a tensor but the underlying go slice is of length 1 and it has no dimension. -// It can be used to store e.g. the loss from a training cycle. -type Scalar[T TensorData] struct { - data []T - dataSize uintptr - ortValue *C.OrtValue -} - -func (s *Scalar[T]) GetShape() Shape { - return nil -} - -func (s *Scalar[T]) ZeroContents() { - C.memset(unsafe.Pointer(&s.data[0]), 0, C.size_t(s.dataSize)) -} - -func (s *Scalar[T]) Destroy() error { - C.ReleaseOrtValue(s.ortValue) - s.ortValue = nil - s.data = nil - s.dataSize = 0 - return nil -} - -// GetData returns the undelying data for the scalar. -// If you want to explicitly set the scalar's data, use Set. -func (t *Scalar[T]) GetData() T { - return t.data[0] -} - -// Set allows to explicitly set the underlying value for the scalar. -func (t *Scalar[T]) Set(value T) { - t.data = []T{value} -} - -func (t *Scalar[T]) DataType() C.ONNXTensorElementDataType { - return GetTensorElementDataType[T]() -} - -func (t *Scalar[_]) GetInternals() *ValueInternalData { - return &ValueInternalData{ - ortValue: t.ortValue, - } -} - -func (t *Scalar[_]) GetONNXType() ONNXType { - return ONNXTypeTensor -} - -// NewEmptyScalar creates a new scalar of type T. -func NewEmptyScalar[T TensorData]() (*Scalar[T], error) { - var data T - return NewScalar(data) -} - -// NewScalar creates a new scalar of type T backed by a value of type T. -// Note that, differently from tensors, this is not a []T but just a value T. -func NewScalar[T TensorData](data T) (*Scalar[T], error) { - if !IsInitialized() { - return nil, NotInitializedError - } - - dataSlice := []T{data} - var ortValue *C.OrtValue - dataType := GetTensorElementDataType[T]() - dataSize := unsafe.Sizeof(dataSlice[0]) * uintptr(1) - - status := C.CreateOrtTensorWithShape(unsafe.Pointer(&dataSlice[0]), - C.size_t(dataSize), nil, C.int64_t(0), ortMemoryInfo, dataType, &ortValue) - if status != nil { - return nil, statusToError(status) - } - toReturn := Scalar[T]{ - data: dataSlice, - dataSize: dataSize, - ortValue: ortValue, - } - return &toReturn, nil -} - -// TraininSession is the type that wraps the C training session object. -type TrainingSession struct { - ortTrainingSession *C.OrtTrainingSession - ortCheckpointState *C.OrtCheckpointState - inputs []*C.OrtValue - outputs []*C.OrtValue - trainingModelPath *C.char - optimizerModelPath *C.char - evalModelPath *C.char -} - -// ExportModel is used to export the final trained model to disk. It requires the path for -// the exported model as well as the names of the graph nodes to export. -// Note that currently the final model can only be exported if the session has been -// initialized with NewTrainingSession and the path to the eval model has been provided. -func (s *TrainingSession) ExportModel(path string, outputNames []string) error { - if s.evalModelPath == nil { - return fmt.Errorf("final model can only be exported if the eval model path is " + - "provided at session creation time (see NewTrainingSession)") - } - if path == "" { - return fmt.Errorf("path cannot be empty") - } - dir, _ := filepath.Split(path) - if _, err := os.Stat(dir); dir != "" && os.IsNotExist(err) { - return fmt.Errorf("directory %s does not exist", dir) - } - - cOutputNames := make([]*C.char, len(outputNames)) - for i, name := range outputNames { - cOutputNames[i] = C.CString(name) - } - cPath, err := createOrtCharString(path) - if err != nil { - return fmt.Errorf("Error converting export path to C string: %w", err) - } - outputLength := C.size_t(len(outputNames)) - defer func() { - for i := range cOutputNames { - C.free(unsafe.Pointer(cOutputNames[i])) - } - C.free(unsafe.Pointer(cPath)) - }() - status := C.ExportModel(s.ortTrainingSession, cPath, outputLength, &cOutputNames[0]) - if status != nil { - return statusToError(status) - } - return nil -} - -// SaveCheckpoint can be used to save the current checkpoint state at the specified path. -// This is useful to snapshot the training parameters to continue training later or on -// a different machine. -func (s *TrainingSession) SaveCheckpoint(path string, saveOptimizerState bool) error { - if path == "" { - return fmt.Errorf("path cannot be empty") - } - dir, _ := filepath.Split(path) - if _, err := os.Stat(dir); dir != "" && os.IsNotExist(err) { - return fmt.Errorf("directory %s does not exist", dir) - } - - cPath, err := createOrtCharString(path) - if err != nil { - return fmt.Errorf("Error converting path to C string: %w", err) - } - var saveOptimizer int - if saveOptimizerState { - saveOptimizer = 1 - } - - defer func() { - C.free(unsafe.Pointer(cPath)) - }() - - status := C.SaveCheckpoint(s.ortCheckpointState, cPath, C.size_t(saveOptimizer)) - if status != nil { - return statusToError(status) - } - return nil -} - -// Destroy frees all the C memory associated to a training session. -func (s *TrainingSession) Destroy() error { - if s.ortTrainingSession != nil { - C.ReleaseOrtTrainingSession(s.ortTrainingSession) - s.ortTrainingSession = nil - } - // note: checkpoint MUST be released after session - if s.ortCheckpointState != nil { - C.ReleaseCheckpointState(s.ortCheckpointState) - s.ortCheckpointState = nil - } - C.free(unsafe.Pointer(s.trainingModelPath)) - s.trainingModelPath = nil - C.free(unsafe.Pointer(s.evalModelPath)) - s.evalModelPath = nil - C.free(unsafe.Pointer(s.optimizerModelPath)) - s.optimizerModelPath = nil - s.inputs = nil - s.outputs = nil - return nil -} - -// TrainStep performs the training step. -func (s *TrainingSession) TrainStep() error { - inputLength := C.size_t(len(s.inputs)) - outputLength := C.size_t(len(s.outputs)) - status := C.TrainStep(s.ortTrainingSession, inputLength, &s.inputs[0], outputLength, &s.outputs[0]) - if status != nil { - return fmt.Errorf("error performing training step: %w", statusToError(status)) - } - return nil -} - -// TrainStep performs the optimizer step. -func (s *TrainingSession) OptimizerStep() error { - status := C.OptimizerStep(s.ortTrainingSession) - if status != nil { - return fmt.Errorf("error performing optimizer step: %w", statusToError(status)) - } - return nil -} - -// TrainStep performs the LazyResetGrad step. -func (s *TrainingSession) LazyResetGrad() error { - status := C.LazyResetGrad(s.ortTrainingSession) - if status != nil { - return fmt.Errorf("error performing lazyResetGrad step: %w", statusToError(status)) - } - return nil -} - -func getInputName(s *C.OrtTrainingSession, i int, model string) (string, error) { - var cName *C.char - var status *C.OrtStatus - switch model { - case "train": - status = C.TrainingSessionGetTrainingInputName(s, C.size_t(i), &cName) - case "eval": - status = C.TrainingSessionGetEvalInputName(s, C.size_t(i), &cName) - default: - return "", fmt.Errorf("%s model not recognized", model) - } - if status != nil { - return "", fmt.Errorf("error getting name: %w", statusToError(status)) - } - - name, e := convertORTString(cName) - if e != nil { - return "", fmt.Errorf("error converting C name to Go string: %w", e) - } - return name, nil -} - -func getOutputName(s *C.OrtTrainingSession, i int, model string) (string, error) { - var cName *C.char - var status *C.OrtStatus - switch model { - case "train": - status = C.TrainingSessionGetTrainingOutputName(s, C.size_t(i), &cName) - case "eval": - status = C.TrainingSessionGetEvalOutputName(s, C.size_t(i), &cName) - default: - return "", fmt.Errorf("%s model not recognized", model) - } - if status != nil { - return "", fmt.Errorf("error getting name: %w", statusToError(status)) - } - name, e := convertORTString(cName) - if e != nil { - return "", fmt.Errorf("error converting C name to Go string: %w", e) - } - return name, nil -} - -type TrainingInputOutputNames struct { - TrainingInputNames []string - EvalInputNames []string - TrainingOutputNames []string - EvalOutputNames []string -} - -// GetInputOutputNames returns the names of the training inputs and outputs -// for the training and validation models. Eval model is optional and can be empty -// string. -func GetInputOutputNames(checkpointStatePath string, - trainingModelPath string, - evalModelPath string) (*TrainingInputOutputNames, error) { - - options, e := NewSessionOptions() - if e != nil { - return nil, fmt.Errorf("failed creating options with error: %v\n", e) - } - defer options.Destroy() - - checkpointData, e := os.ReadFile(checkpointStatePath) - if e != nil { - return nil, fmt.Errorf("error reading %s: %w", checkpointStatePath, e) - } - - trainingData, e := os.ReadFile(trainingModelPath) - if e != nil { - return nil, fmt.Errorf("error reading %s: %w", checkpointStatePath, e) - } - - var evalData []byte - if evalModelPath != "" { - evalData, e = os.ReadFile(evalModelPath) - if e != nil { - return nil, fmt.Errorf("error reading %s: %w", evalModelPath, e) - } - } - - // create checkpoint C object - ortCheckpointState, e := createCCheckpoint(checkpointData) - if e != nil { - return nil, fmt.Errorf("error creating C checkpointState: %w", e) - } - - // create session C object - ortTrainingSession, e := createCTrainingSessionWithOnnxData(ortCheckpointState, - trainingData, evalData, nil, options) - if e != nil { - C.ReleaseCheckpointState(ortCheckpointState) - return nil, fmt.Errorf("error creating C training session: %w", e) - } - defer func() { - C.ReleaseOrtTrainingSession(ortTrainingSession) - C.ReleaseCheckpointState(ortCheckpointState) - }() - - var inputCountTraining, inputCountEval C.size_t - status := C.TrainingSessionGetInputCount(ortTrainingSession, &inputCountTraining, &inputCountEval) - if status != nil { - return nil, statusToError(status) - } - - var outputCountTraining, outputCountEval C.size_t - status = C.TrainingSessionGetOutputCount(ortTrainingSession, &outputCountTraining, &outputCountEval) - if status != nil { - return nil, statusToError(status) - } - - trainInputNames := make([]string, inputCountTraining) - trainOutputNames := make([]string, outputCountTraining) - - for i := 0; i < int(inputCountTraining); i++ { - name, err := getInputName(ortTrainingSession, i, "train") - if err != nil { - return nil, fmt.Errorf("error retrieving train input name: %w", err) - } - trainInputNames[i] = name - } - - for i := 0; i < int(outputCountTraining); i++ { - name, err := getOutputName(ortTrainingSession, i, "train") - if err != nil { - return nil, fmt.Errorf("error retrieving train output name: %w", err) - } - trainOutputNames[i] = name - } - - var evalInputNames []string - var evalOutputNames []string - - if len(evalData) > 0 { - evalInputNames = make([]string, inputCountEval) - evalOutputNames = make([]string, outputCountEval) - - for i := 0; i < int(inputCountEval); i++ { - name, err := getInputName(ortTrainingSession, i, "eval") - if err != nil { - return nil, fmt.Errorf("error retrieving eval input name: %w", err) - } - evalInputNames[i] = name - } - - for i := 0; i < int(outputCountTraining); i++ { - name, err := getOutputName(ortTrainingSession, i, "eval") - if err != nil { - return nil, fmt.Errorf("error retrieving eval output name: %w", err) - } - evalOutputNames[i] = name - } - } - - return &TrainingInputOutputNames{ - TrainingInputNames: trainInputNames, - EvalInputNames: evalInputNames, - TrainingOutputNames: trainOutputNames, - EvalOutputNames: evalOutputNames, - }, nil -} - -// IsTrainingSupported returns true if the training api is supported -// by the onnxruntime library. -func IsTrainingSupported() bool { - return C.IsTrainingApiSupported() != 0 -} - -func checkTraining() error { - if !IsInitialized() { - return NotInitializedError - } - if !IsTrainingSupported() { - return trainingNotSupportedError - } - return nil -} - -func createCCheckpoint(onnxData []byte) (*C.OrtCheckpointState, error) { - if e := checkTraining(); e != nil { - return nil, e - } - if len(onnxData) == 0 { - return nil, fmt.Errorf("Missing checkpoint data") - } - var ortCheckpointState *C.OrtCheckpointState - status := C.CreateCheckpoint(unsafe.Pointer(&(onnxData[0])), C.size_t(len(onnxData)), &ortCheckpointState) - if status != nil { - return nil, statusToError(status) - } - return ortCheckpointState, nil -} - -// createCTrainingSessionWithOnnxData creates a C session from byte data using buffers -func createCTrainingSessionWithOnnxData(checkpointState *C.OrtCheckpointState, - trainingData, evalData, optimizerData []byte, - options *SessionOptions) (*C.OrtTrainingSession, error) { - if e := checkTraining(); e != nil { - return nil, e - } - - var ortTrainingSession *C.OrtTrainingSession - var ortSessionOptions *C.OrtSessionOptions - if options != nil { - ortSessionOptions = options.o - } - - // eval model is optional - var evalDataPtr unsafe.Pointer - var evalDataSize C.size_t - if len(evalData) > 0 { - evalDataPtr = unsafe.Pointer(&(evalData[0])) - evalDataSize = C.size_t(len(evalData)) - } - - // optimizer model is also optional when e.g. getting input and output names - var optimizerDataPtr unsafe.Pointer - var optimizerDataSize C.size_t - if len(optimizerData) > 0 { - optimizerDataPtr = unsafe.Pointer(&(optimizerData[0])) - optimizerDataSize = C.size_t(len(optimizerData)) - } - - status := C.CreateTrainingSessionFromBuffer( - checkpointState, - unsafe.Pointer(&(trainingData[0])), C.size_t(len(trainingData)), - evalDataPtr, evalDataSize, - optimizerDataPtr, optimizerDataSize, - ortEnv, &ortTrainingSession, ortSessionOptions) - if status != nil { - return nil, statusToError(status) - } - return ortTrainingSession, nil -} - -// createCTrainingSessionWithPaths creates a C session from paths -func createCtrainingSessionWithPaths(checkpointState *C.OrtCheckpointState, - trainingPath, evalPath, optimizerPath *C.char, - options *SessionOptions) (*C.OrtTrainingSession, error) { - if e := checkTraining(); e != nil { - return nil, e - } - - var ortTrainingSession *C.OrtTrainingSession - var ortSessionOptions *C.OrtSessionOptions - if options != nil { - ortSessionOptions = options.o - } - - status := C.CreateTrainingSessionFromPaths(checkpointState, - trainingPath, evalPath, optimizerPath, ortEnv, &ortTrainingSession, ortSessionOptions) - - if status != nil { - return nil, statusToError(status) - } - return ortTrainingSession, nil -} - -// NewTrainingSessionWithOnnxData is like NewTrainingSession, but it accepts -// bytes rather than paths to the training assets. Note that there does not -// seem to currently be a way to export the trained model from a session -// instantiated from bytes. If you wish to export the trained model, you should -// use NewTrainingSession instead. -func NewTrainingSessionWithOnnxData(checkpointData []byte, - trainingData []byte, - evalData []byte, - optimizerData []byte, - inputs, outputs []Value, - options *SessionOptions) (*TrainingSession, error) { - - if err := checkTraining(); err != nil { - return nil, err - } - - if err := validateInputOutputs(inputs, outputs); err != nil { - return nil, err - } - - if len(trainingData) == 0 { - return nil, fmt.Errorf("training data has length zero.") - } - if len(optimizerData) == 0 { - return nil, fmt.Errorf("optimizer data has length zero.") - } - - // create checkpoint C object - ortCheckpointState, e := createCCheckpoint(checkpointData) - if e != nil { - return nil, fmt.Errorf("error creating C checkpointState: %w", e) - } - - // create session C object - ortTrainingSession, e := createCTrainingSessionWithOnnxData(ortCheckpointState, - trainingData, evalData, optimizerData, options) - if e != nil { - return nil, fmt.Errorf("error creating C training session: %w", e) - } - - inputOrtTensors := make([]*C.OrtValue, len(inputs)) - outputOrtTensors := make([]*C.OrtValue, len(outputs)) - for i, v := range inputs { - inputOrtTensors[i] = v.GetInternals().ortValue - } - for i, v := range outputs { - outputOrtTensors[i] = v.GetInternals().ortValue - } - - return &TrainingSession{ - ortCheckpointState: ortCheckpointState, - ortTrainingSession: ortTrainingSession, - inputs: inputOrtTensors, - outputs: outputOrtTensors, - }, nil -} - -func validateInputOutputs(inputs, outputs []Value) error { - if len(inputs) == 0 { - return fmt.Errorf("inputs must have length greater than zero") - } - if len(outputs) == 0 { - return fmt.Errorf("outputs must have length greater than zero") - } - return nil -} - -// NewTrainingSession creates a new training session from paths stored on disk. -// evalModelPath is optional and can be the empty string. In case it is not -// provided, only the checkpoint state can be exported once training is complete -// (and not the final inference model). -func NewTrainingSession(checkpointStatePath string, - trainingModelPath string, - evalModelPath string, - optimizerModelPath string, - inputs, outputs []Value, - options *SessionOptions) (*TrainingSession, error) { - - if err := checkTraining(); err != nil { - return nil, err - } - - if err := validateInputOutputs(inputs, outputs); err != nil { - return nil, err - } - - checkPointContent, e := os.ReadFile(checkpointStatePath) - if e != nil { - return nil, fmt.Errorf("reading checkpoint data failed: %s", e.Error()) - } - - // create checkpoint C object - ortCheckpointState, e := createCCheckpoint(checkPointContent) - if e != nil { - return nil, fmt.Errorf("error creating C checkpointState: %w", e) - } - - // create session C object - if _, err := os.Stat(trainingModelPath); os.IsNotExist(err) { - return nil, fmt.Errorf("training model does not exist at path %s", trainingModelPath) - } - cTrainingPath, err := createOrtCharString(trainingModelPath) - if err != nil { - return nil, fmt.Errorf("Error converting training model path to C string: %w", err) - } - - if _, err := os.Stat(optimizerModelPath); os.IsNotExist(err) { - return nil, fmt.Errorf("optimizer s does not exist at path %s", optimizerModelPath) - } - cOptimizerPath, err := createOrtCharString(optimizerModelPath) - if err != nil { - return nil, fmt.Errorf("Error converting optimizer path to C string: %w", err) - } - - // eval is optional - var cEvalPath *C.char - if evalModelPath != "" { - if _, err := os.Stat(evalModelPath); os.IsNotExist(err) { - return nil, fmt.Errorf("eval model does not exist at path %s", evalModelPath) - } - cEvalPath, err = createOrtCharString(evalModelPath) - if err != nil { - return nil, fmt.Errorf("Error converting eval path to C string: %w", err) - } - } else { - cEvalPath = nil - } - - ortTrainingSession, e := createCtrainingSessionWithPaths(ortCheckpointState, - cTrainingPath, cEvalPath, cOptimizerPath, options) - if e != nil { - return nil, fmt.Errorf("error creating C training session: %w", e) - } - - inputOrtTensors := make([]*C.OrtValue, len(inputs)) - outputOrtTensors := make([]*C.OrtValue, len(outputs)) - for i, v := range inputs { - inputOrtTensors[i] = v.GetInternals().ortValue - } - for i, v := range outputs { - outputOrtTensors[i] = v.GetInternals().ortValue - } - - return &TrainingSession{ - ortCheckpointState: ortCheckpointState, - ortTrainingSession: ortTrainingSession, - inputs: inputOrtTensors, - outputs: outputOrtTensors, - evalModelPath: cEvalPath, - trainingModelPath: cTrainingPath, - optimizerModelPath: cOptimizerPath, - }, nil -} diff --git a/onnxruntime_training_test.go b/onnxruntime_training_test.go deleted file mode 100644 index 6fa5bb5..0000000 --- a/onnxruntime_training_test.go +++ /dev/null @@ -1,353 +0,0 @@ -package onnxruntime_go - -import ( - "errors" - "math" - "math/rand" - "os" - "path" - "testing" -) - -func TestTrainingNotSupported(t *testing.T) { - InitializeRuntime(t) - defer CleanupRuntime(t) - - if IsTrainingSupported() { - t.Skipf("onnxruntime library supports training") - } - - options, e := NewSessionOptions() - if e != nil { - t.Logf("Failed creating options: %s\n", e) - t.FailNow() - } - - trainingSession, e := NewTrainingSession("test_data/onnxruntime_training_test/training_artifacts/checkpoint", - "test_data/onnxruntime_training_test/training_artifacts/training_model.onnx", - "test_data/onnxruntime_training_test/training_artifacts/eval_model.onnx", - "test_data/onnxruntime_training_test/training_artifacts/optimizer_model.onnx", - nil, nil, - options) - - if !errors.Is(e, trainingNotSupportedError) { - t.Logf("Creating training session when onnxruntime lib does not support it should return training not supported error.") - if e != nil { - t.Logf("Received instead error: %s", e.Error()) - } else { - t.Logf("Received no error instead") - } - t.FailNow() - } - if trainingSession != nil { - if err := trainingSession.Destroy(); err != nil { - t.Fatalf("cleanup of training session failed with error: %v", e) - } - } -} - -func TestGetInputOutputNames(t *testing.T) { - InitializeRuntime(t) - defer CleanupRuntime(t) - - if !IsTrainingSupported() { - t.Skipf("Training is not supported on this platform/onnxruntime build.") - } - - artifactsPath := path.Join("test_data", "training_test") - - names, err := GetInputOutputNames( - path.Join(artifactsPath, "checkpoint"), - path.Join(artifactsPath, "training_model.onnx"), - path.Join(artifactsPath, "eval_model.onnx"), - ) - - if err != nil { - t.Fatalf("Failed getting input and output names with error: %v\n", err) - } - - expectedTrainInputNames := []string{"input", "target"} - expectedEvalInputNames := expectedTrainInputNames - expectedTrainOutputNames := []string{"onnx::reducemean_output::5"} - expectedEvalOutputNames := expectedTrainOutputNames - - for i, v := range names.TrainingInputNames { - if v != expectedTrainInputNames[i] { - t.Fatalf("training input names don't match") - } - } - for i, v := range names.TrainingOutputNames { - if v != expectedTrainOutputNames[i] { - t.Fatalf("training output names don't match") - } - } - for i, v := range names.EvalInputNames { - if v != expectedEvalInputNames[i] { - t.Fatalf("eval input names don't match") - } - } - for i, v := range names.EvalOutputNames { - if v != expectedEvalOutputNames[i] { - t.Fatalf("eval output names don't match") - } - } - - // without eval model - names, err = GetInputOutputNames( - path.Join(artifactsPath, "checkpoint"), - path.Join(artifactsPath, "training_model.onnx"), - "", - ) - - if err != nil { - t.Fatalf("Failed getting input and output names with error: %v\n", err) - } - - for i, v := range names.TrainingInputNames { - if v != expectedTrainInputNames[i] { - t.Fatalf("training input names don't match") - } - } - for i, v := range names.TrainingOutputNames { - if v != expectedTrainOutputNames[i] { - t.Fatalf("training output names don't match") - } - } -} - -func generateBatchData(nBatches int, batchSize int) map[int]map[string][]float32 { - batchData := map[int]map[string][]float32{} - - source := rand.NewSource(1234) - g := rand.New(source) - - for i := 0; i < nBatches; i++ { - inputCounter := 0 - outputCounter := 0 - inputSlice := make([]float32, batchSize*4) - outputSlice := make([]float32, batchSize*2) - batchData[i] = map[string][]float32{} - - // generate random data for batch - for n := 0; n < batchSize; n++ { - var sum float32 - min := float32(1) - max := float32(-1) - for i := 0; i < 4; i++ { - r := g.Float32() - inputSlice[inputCounter] = r - inputCounter++ - if r > max { - max = r - } - if r < min { - min = r - } - sum = sum + r - } - outputSlice[outputCounter] = sum - outputSlice[outputCounter+1] = max - min - outputCounter = outputCounter + 2 - } - batchData[i]["input"] = inputSlice - batchData[i]["output"] = outputSlice - } - return batchData -} - -// TestTraining tests a basic training flow using the bindings to the C api for on-device onnxruntime training -func TestTraining(t *testing.T) { - InitializeRuntime(t) - defer CleanupRuntime(t) - - if !IsTrainingSupported() { - t.Skipf("Training is not supported on this platform/onnxruntime build.") - } - - trainingArtifactsFolder := path.Join("test_data", "training_test") - - // generate training data - batchSize := 10 - nBatches := 10 - - // holds inputs/outputs and loss for each training batch - batchInputShape := NewShape(int64(batchSize), 1, 4) - batchTargetShape := NewShape(int64(batchSize), 1, 2) - batchInputTensor, err := NewEmptyTensor[float32](batchInputShape) - if err != nil { - t.Fatalf("training test failed with error: %v", err) - } - batchTargetTensor, err := NewEmptyTensor[float32](batchTargetShape) - if err != nil { - t.Fatalf("training test failed with error: %v", err) - } - lossScalar, err := NewEmptyScalar[float32]() - if err != nil { - t.Fatalf("training test failed with error: %v", err) - } - - trainingSession, errorSessionCreation := NewTrainingSession( - path.Join(trainingArtifactsFolder, "checkpoint"), - path.Join(trainingArtifactsFolder, "training_model.onnx"), - path.Join(trainingArtifactsFolder, "eval_model.onnx"), - path.Join(trainingArtifactsFolder, "optimizer_model.onnx"), - []Value{batchInputTensor, batchTargetTensor}, []Value{lossScalar}, - nil) - - if errorSessionCreation != nil { - t.Fatalf("session creation failed with error: %v", errorSessionCreation) - } - - // cleanup after test run - defer func(session *TrainingSession, tensors []Value) { - var errs []error - errs = append(errs, session.Destroy()) - for _, t := range tensors { - errs = append(errs, t.Destroy()) - } - if e := errors.Join(errs...); e != nil { - t.Fatalf("cleanup of test failed with error: %v", e) - } - }(trainingSession, []Value{batchInputTensor, batchTargetTensor, lossScalar}) - - losses := []float32{} - epochs := 100 - batchData := generateBatchData(nBatches, batchSize) - - for epoch := 0; epoch < epochs; epoch++ { - var epochLoss float32 // total epoch loss - - for i := 0; i < nBatches; i++ { - inputSlice := batchInputTensor.GetData() - outputSlice := batchTargetTensor.GetData() - - copy(inputSlice, batchData[i]["input"]) - copy(outputSlice, batchData[i]["output"]) - - // train on batch - err = trainingSession.TrainStep() - if err != nil { - t.Fatalf("train step failed with error: %v", err) - } - - epochLoss = epochLoss + lossScalar.GetData() - - err = trainingSession.OptimizerStep() - if err != nil { - t.Fatalf("optimizer step failed with error: %v", err) - } - - // ort training api - reset the gradients to zero so that new gradients can be computed for next batch - err = trainingSession.LazyResetGrad() - if err != nil { - t.Fatalf("lazy reset grad step failed with error: %v", err) - } - } - if epoch%10 == 0 { - t.Logf("Epoch {%d} Loss {%f}\n", epoch+1, epochLoss/float32(batchSize*nBatches)) - losses = append(losses, epochLoss/float32(batchSize*nBatches)) - } - } - - expectedLosses := []float32{ - 0.125085, - 0.097187, - 0.062333, - 0.024307, - 0.019963, - 0.018476, - 0.017160, - 0.015982, - 0.014845, - 0.013867, - } - - for i, l := range losses { - diff := math.Abs(float64(l - expectedLosses[i])) - deviation := diff / float64(expectedLosses[i]) - if deviation > 0.6 { - t.Fatalf("loss deviation too large: expected %f, actual %f, deviation %f", float64(expectedLosses[i]), float64(l), float64(deviation)) - } - } - - // test the saving of the checkpoint state - finalCheckpointPath := path.Join("test_data", "training_test", "finalCheckpoint") - errSaveCheckpoint := trainingSession.SaveCheckpoint(finalCheckpointPath, false) - if errSaveCheckpoint != nil { - t.Fatalf("Saving of checkpoint failed with error: %v", errSaveCheckpoint) - } - - // test the saving of the model - finalModelPath := path.Join("test_data", "training_test", "final_inference.onnx") - errExport := trainingSession.ExportModel(finalModelPath, []string{"output"}) - if errExport != nil { - t.Fatalf("Exporting model failed with error: %v", errExport) - } - - defer func() { - e := os.Remove(finalCheckpointPath) - if e != nil { - t.Errorf("Error removing final checkpoint file %s: %s", finalCheckpointPath, e) - } - e = os.Remove(finalModelPath) - if e != nil { - t.Errorf("Error removing final model file %s: %s", finalModelPath, e) - } - }() - - // load the model back in and test in-sample predictions for the first batch - // (we care about correctness more than generalization here) - session, err := NewAdvancedSession(path.Join("test_data", "training_test", "final_inference.onnx"), - []string{"input"}, []string{"output"}, - []Value{batchInputTensor}, []Value{batchTargetTensor}, nil) - - if err != nil { - t.Fatalf("creation of inference session failed with error: %v", err) - } - - defer func(s *AdvancedSession) { - err := s.Destroy() - if err != nil { - t.Fatalf("cleanup of inference session failed with error: %v", err) - } - }(session) - - // Calling Run() will run the network, reading the current contents of the - // input tensors and modifying the contents of the output tensors. - copy(batchInputTensor.GetData(), batchData[0]["input"]) - err = session.Run() - if err != nil { - t.Fatalf("run of inference session failed with error: %v", err) - } - - expectedOutput := []float32{ - 2.4524384, - 0.65120333, - 2.5457804, - 0.6102175, - 1.6276635, - 0.573755, - 1.7900972, - 0.59951085, - 3.1650176, - 0.66626525, - 1.9361509, - 0.571084, - 2.0798547, - 0.6060241, - 0.9611889, - 0.52100605, - 1.4070896, - 0.5412475, - 2.1449144, - 0.5985652, - } - - for i, l := range batchTargetTensor.GetData() { - diff := math.Abs(float64(l - expectedOutput[i])) - deviation := diff / float64(expectedOutput[i]) - if deviation > 0.6 { - t.Fatalf("deviation too large") - } - } -} diff --git a/onnxruntime_wrapper.c b/onnxruntime_wrapper.c index 663813a..f67b192 100644 --- a/onnxruntime_wrapper.c +++ b/onnxruntime_wrapper.c @@ -402,146 +402,3 @@ OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values, value_type, out); } -// TRAINING API WRAPPER - -static const OrtTrainingApi *ort_training_api = NULL; - -void SetTrainingApi() { - ort_training_api = ort_api->GetTrainingApi(ORT_API_VERSION); -} - -int IsTrainingApiSupported() { - return ort_training_api != NULL; -} - -OrtStatus *CreateCheckpoint(void *checkpoint_data, size_t checkpoint_data_length, OrtCheckpointState **out) { - OrtStatus *status = NULL; - status = ort_training_api->LoadCheckpointFromBuffer(checkpoint_data, checkpoint_data_length, out); - return status; -} - -OrtStatus *CreateTrainingSessionFromBuffer(OrtCheckpointState *checkpoint_state, - void *training_model_data, size_t training_model_data_length, - void *eval_model_data, size_t eval_model_data_length, - void *optim_model_data, size_t optim_model_data_length, - OrtEnv *env, OrtTrainingSession **out, OrtSessionOptions *options) { - OrtStatus *status = NULL; - int default_options = 0; - if (!options) { - default_options = 1; - status = ort_api->CreateSessionOptions(&options); - if (status) return status; - } - status = ort_training_api->CreateTrainingSessionFromBuffer(env, options, checkpoint_state, - training_model_data, training_model_data_length, eval_model_data, eval_model_data_length, - optim_model_data, optim_model_data_length, out); - if (default_options) { - ort_api->ReleaseSessionOptions(options); - } - return status; -} - -OrtStatus *CreateTrainingSessionFromPaths(OrtCheckpointState *checkpoint_state, - char *training_model_path, char *eval_model_path, char *optim_model_path, - OrtEnv *env, OrtTrainingSession **out, OrtSessionOptions *options) { - OrtStatus *status = NULL; - int default_options = 0; - if (!options) { - default_options = 1; - status = ort_api->CreateSessionOptions(&options); - if (status) return status; - } - status = ort_training_api->CreateTrainingSession(env, options, checkpoint_state, - training_model_path, eval_model_path, optim_model_path, out); - if (default_options) { - ort_api->ReleaseSessionOptions(options); - } - return status; -} - -OrtStatus *TrainingSessionGetInputCount(OrtTrainingSession *training_session, size_t *result_training, size_t *result_eval) { - OrtStatus *status = NULL; - status = ort_training_api->TrainingSessionGetTrainingModelInputCount(training_session, result_training); - if (status) return status; - status = ort_training_api->TrainingSessionGetEvalModelInputCount(training_session, result_eval); - return status; -} - -OrtStatus *TrainingSessionGetOutputCount(OrtTrainingSession *training_session, size_t *result_training, size_t *result_eval) { - OrtStatus *status = NULL; - status = ort_training_api->TrainingSessionGetTrainingModelOutputCount(training_session, result_training); - if (status) return status; - status = ort_training_api->TrainingSessionGetEvalModelOutputCount(training_session, result_eval); - return status; -} - -OrtStatus *TrainingSessionGetTrainingInputName(OrtTrainingSession *training_session, size_t i, char **name) { - OrtAllocator *allocator = NULL; - OrtStatus *status = NULL; - status = ort_api->GetAllocatorWithDefaultOptions(&allocator); - if (status) return status; - return ort_training_api->TrainingSessionGetTrainingModelInputName(training_session, i, allocator, name); -} - -OrtStatus *TrainingSessionGetTrainingOutputName(OrtTrainingSession *training_session, size_t i, char **name) { - OrtAllocator *allocator = NULL; - OrtStatus *status = NULL; - status = ort_api->GetAllocatorWithDefaultOptions(&allocator); - if (status) return status; - return ort_training_api->TrainingSessionGetTrainingModelOutputName(training_session, i, allocator, name); -} - -OrtStatus *TrainingSessionGetEvalInputName(OrtTrainingSession *training_session, size_t i, char **name) { - OrtAllocator *allocator = NULL; - OrtStatus *status = NULL; - status = ort_api->GetAllocatorWithDefaultOptions(&allocator); - if (status) return status; - return ort_training_api->TrainingSessionGetEvalModelInputName(training_session, i, allocator, name); -} - -OrtStatus *TrainingSessionGetEvalOutputName(OrtTrainingSession *training_session, size_t i, char **name) { - OrtAllocator *allocator = NULL; - OrtStatus *status = NULL; - status = ort_api->GetAllocatorWithDefaultOptions(&allocator); - if (status) return status; - return ort_training_api->TrainingSessionGetEvalModelOutputName(training_session, i, allocator, name); -} - -OrtStatus *TrainStep(OrtTrainingSession *training_session, size_t inputs_len, OrtValue **inputs, size_t output_len, OrtValue **outputs) { - OrtStatus *status = NULL; - status = ort_training_api->TrainStep(training_session, NULL, inputs_len, (const OrtValue* const*) inputs, output_len, outputs); - return status; -} - -OrtStatus *OptimizerStep(OrtTrainingSession *training_session) { - OrtStatus *status = NULL; - status = ort_training_api->OptimizerStep(training_session, NULL); - return status; -} - -OrtStatus *LazyResetGrad(OrtTrainingSession *training_session) { - OrtStatus *status = NULL; - status = ort_training_api->LazyResetGrad(training_session); - return status; -} - -OrtStatus *SaveCheckpoint(OrtCheckpointState *checkpoint, char *path, size_t include_optimizer) { - OrtStatus *status = NULL; - status = ort_training_api->SaveCheckpoint(checkpoint, path, include_optimizer); - return status; -} - -OrtStatus *ExportModel(OrtTrainingSession *training_session, char *path, size_t outputs_len, char **output_names) { - OrtStatus *status = NULL; - status = ort_training_api->ExportModelForInferencing(training_session, path, outputs_len, (const char* const*) output_names); - return status; -} - -void ReleaseOrtTrainingSession(OrtTrainingSession *session) { - ort_training_api->ReleaseTrainingSession(session); -} - -void ReleaseCheckpointState(OrtCheckpointState *checkpoint) { - ort_training_api->ReleaseCheckpointState(checkpoint); -} - diff --git a/onnxruntime_wrapper.h b/onnxruntime_wrapper.h index 7865660..dc4b749 100644 --- a/onnxruntime_wrapper.h +++ b/onnxruntime_wrapper.h @@ -15,7 +15,6 @@ // Next, we actually include the header. #undef _WIN32 #include "onnxruntime_c_api.h" -#include "onnxruntime_training_c_api.h" // ... However, mingw will complain if _WIN32 is *not* defined! So redefine it. #define _WIN32 @@ -254,75 +253,6 @@ OrtStatus *GetValueCount(OrtValue *v, size_t *out); OrtStatus *CreateOrtValue(OrtValue **in, size_t num_values, enum ONNXType value_type, OrtValue **out); -// TRAINING API WRAPPER - -void SetTrainingApi(); - -// Checks if training api is supported. -int IsTrainingApiSupported(); - -// Wraps ort_training_api->CreateSessionFromBuffer. -// Creates and ORT checkpoint from the checkpoint data. -OrtStatus *CreateCheckpoint(void *checkpoint_data, - size_t checkpoint_data_length, OrtCheckpointState **out); - -// Wraps ort_training_api->CreateTrainingSessionFromBuffer. Creates an ORT -// training session using the given models and checkpoint. The given options -// pointer may be NULL; if it is, then we'll use default options. -OrtStatus *CreateTrainingSessionFromBuffer(OrtCheckpointState *checkpoint_state, - void *training_model_data, size_t training_model_data_length, - void *eval_model_data, size_t eval_model_data_length, - void *optim_model_data, size_t optim_model_data_length, - OrtEnv *env, OrtTrainingSession **out, OrtSessionOptions *options); - -// Wraps ort_training_api->CreateTrainingSession. -// Currently this is the only way to create a training session that is able to -// export the final trained model to disk. -OrtStatus *CreateTrainingSessionFromPaths(OrtCheckpointState *checkpoint_state, - char *training_model_path, char *eval_model_path, char *optim_model_path, - OrtEnv *env, OrtTrainingSession **out, OrtSessionOptions *options); - -// Wraps ort_training_api->TrainingSessionGetTrainingModelInputCount -// and ort_training_api->TrainingSessionGetEvalgModelInputCount. -OrtStatus *TrainingSessionGetInputCount(OrtTrainingSession *training_session, size_t *result_training, size_t *result_eval); - -// Wraps ort_training_api->TrainingSessionGetTrainingModelOutputCounet -// and ort_training_api->TrainingSessionGetEvalgModelOutputCount. -OrtStatus *TrainingSessionGetOutputCount(OrtTrainingSession *training_session, size_t *result_training, size_t *result_eval); - -// Wraps ort_training_api->TrainingSessionGetTrainingModelInputName. -OrtStatus *TrainingSessionGetTrainingInputName(OrtTrainingSession *training_session, size_t i, char **name); - -// Wraps ort_training_api->TrainingSessionGetEvalModelInputName. -OrtStatus *TrainingSessionGetEvalInputName(OrtTrainingSession *training_session, size_t i, char **name); - -// Wraps ort_training_api->TrainingSessionGetTrainingModelOutputName. -OrtStatus *TrainingSessionGetTrainingOutputName(OrtTrainingSession *training_session, size_t i, char **name); - -// Wraps ort_training_api->TrainingSessionGetEvalModelOutputName. -OrtStatus *TrainingSessionGetEvalOutputName(OrtTrainingSession *training_session, size_t i, char **name); - -// Wraps ort_training_api->TrainStep. -OrtStatus *TrainStep(OrtTrainingSession *training_session, size_t inputs_len, OrtValue **inputs, size_t output_len, OrtValue **outputs); - -// Wraps ort_training_api->OptimizerStep. -OrtStatus *OptimizerStep(OrtTrainingSession *training_session); - -// Wraps ort_training_api->LazyResetGrad. -OrtStatus *LazyResetGrad(OrtTrainingSession *training_session); - -// Wraps ort_training_api->SaveCheckpoint. -OrtStatus *SaveCheckpoint(OrtCheckpointState *checkpoint, char *path, size_t include_optimizer); - -// Wraps ort_training_api->ExportModel. -OrtStatus *ExportModel(OrtTrainingSession *training_session, char *path, size_t outputs_len, char **output_names); - -// Wraps ort_training_api->ReleaseTrainingSession. -void ReleaseOrtTrainingSession(OrtTrainingSession *session); - -// Wraps ort_training_api->ReleaseCheckpointState. -void ReleaseCheckpointState(OrtCheckpointState *checkpoint); - #ifdef __cplusplus } // extern "C" #endif diff --git a/setup_env_windows.go b/setup_env_windows.go index f97af36..6b636e3 100644 --- a/setup_env_windows.go +++ b/setup_env_windows.go @@ -56,10 +56,6 @@ func platformInitializeEnvironment() error { return fmt.Errorf("Error setting ORT API base: %d", tmp) } - // we do not initialize the training API on windows (see setup_env.go) - // because currently we cannot support the conversion from UTF-8 to wide - // character. See https://github.com/yalue/onnxruntime_go/pull/56. - libraryHandle = handle return nil }