package onnxruntime_go

import (
	"errors"
	"fmt"
	"math"
	"math/rand"
	"os"
	"path"
	"testing"
)
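
// TestTrainingNotSupported verifies that attempting to create a training
// session returns trainingNotSupportedError when the loaded onnxruntime
// library was built without training support. The test is skipped on builds
// that do support training.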
func TestTrainingNotSupported(t *testing.T) {
	InitializeRuntime(t)
	defer CleanupRuntime(t)

	if IsTrainingSupported() {
		t.Skipf("onnxruntime library supports training; this test only " +
			"applies to builds without training support")
	}

	options, e := NewSessionOptions()
	if e != nil {
		t.Logf("Failed creating options: %s\n", e)
		t.FailNow()
	}
	defer options.Destroy()

	trainingSession, e := NewTrainingSession(
		"test_data/onnxruntime_training_test/training_artifacts/checkpoint",
		"test_data/onnxruntime_training_test/training_artifacts/training_model.onnx",
		"test_data/onnxruntime_training_test/training_artifacts/eval_model.onnx",
		"test_data/onnxruntime_training_test/training_artifacts/optimizer_model.onnx",
		nil, nil,
		options)

	if !errors.Is(e, trainingNotSupportedError) {
		t.Logf("Creating a training session when the onnxruntime library " +
			"does not support training should return trainingNotSupportedError.")
		if e != nil {
			t.Logf("Received error instead: %s", e.Error())
		} else {
			t.Logf("Received no error instead")
		}
		t.FailNow()
	}
	if trainingSession != nil {
		if err := trainingSession.Destroy(); err != nil {
			t.Fatalf("cleanup of training session failed with error: %v", err)
		}
	}
}
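
// TestGetInputOutputNames checks that GetInputOutputNames reports the
// expected input and output names for the training and eval models under
// test_data/training_test, both with and without an eval model path.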
func TestGetInputOutputNames(t *testing.T) {
	InitializeRuntime(t)
	defer CleanupRuntime(t)

	if !IsTrainingSupported() {
		t.Skipf("Training is not supported on this platform/onnxruntime build.")
	}

	artifactsPath := path.Join("test_data", "training_test")

	names, err := GetInputOutputNames(
		path.Join(artifactsPath, "checkpoint"),
		path.Join(artifactsPath, "training_model.onnx"),
		path.Join(artifactsPath, "eval_model.onnx"),
	)
	if err != nil {
		t.Fatalf("Failed getting input and output names with error: %v\n", err)
	}

	expectedTrainInputNames := []string{"input", "target"}
	expectedEvalInputNames := expectedTrainInputNames
	expectedTrainOutputNames := []string{"onnx::reducemean_output::5"}
	expectedEvalOutputNames := expectedTrainOutputNames

	for i, v := range names.TrainingInputNames {
		if v != expectedTrainInputNames[i] {
			t.Fatalf("training input names don't match: got %q, expected %q",
				v, expectedTrainInputNames[i])
		}
	}
	for i, v := range names.TrainingOutputNames {
		if v != expectedTrainOutputNames[i] {
			t.Fatalf("training output names don't match: got %q, expected %q",
				v, expectedTrainOutputNames[i])
		}
	}
	for i, v := range names.EvalInputNames {
		if v != expectedEvalInputNames[i] {
			t.Fatalf("eval input names don't match: got %q, expected %q",
				v, expectedEvalInputNames[i])
		}
	}
	for i, v := range names.EvalOutputNames {
		if v != expectedEvalOutputNames[i] {
			t.Fatalf("eval output names don't match: got %q, expected %q",
				v, expectedEvalOutputNames[i])
		}
	}

	// Passing an empty string for the eval model path should still produce
	// the training names.
	names, err = GetInputOutputNames(
		path.Join(artifactsPath, "checkpoint"),
		path.Join(artifactsPath, "training_model.onnx"),
		"",
	)
	if err != nil {
		t.Fatalf("Failed getting input and output names with error: %v\n", err)
	}

	for i, v := range names.TrainingInputNames {
		if v != expectedTrainInputNames[i] {
			t.Fatalf("training input names don't match: got %q, expected %q",
				v, expectedTrainInputNames[i])
		}
	}
	for i, v := range names.TrainingOutputNames {
		if v != expectedTrainOutputNames[i] {
			t.Fatalf("training output names don't match: got %q, expected %q",
				v, expectedTrainOutputNames[i])
		}
	}
}
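
// generateBatchData deterministically creates nBatches batches of synthetic
// training data, using a fixed seed so that runs are reproducible. Each
// sample consists of 4 inputs drawn uniformly from [0, 1) and 2 targets: the
// sum of the inputs and their range (max - min). The returned map is keyed by
// batch index; each batch maps "input" and "output" to flat float32 slices.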
func generateBatchData(nBatches int, batchSize int) map[int]map[string][]float32 {
	batchData := map[int]map[string][]float32{}

	source := rand.NewSource(1234)
	g := rand.New(source)

	for i := 0; i < nBatches; i++ {
		inputCounter := 0
		outputCounter := 0
		inputSlice := make([]float32, batchSize*4)
		outputSlice := make([]float32, batchSize*2)
		batchData[i] = map[string][]float32{}

		// Generate random data for the batch.
		for n := 0; n < batchSize; n++ {
			var sum float32
			// Float32 returns values in [0, 1), so these initial values are
			// guaranteed to be replaced on the first iteration.
			min := float32(1)
			max := float32(-1)
			for j := 0; j < 4; j++ {
				r := g.Float32()
				inputSlice[inputCounter] = r
				inputCounter++
				if r > max {
					max = r
				}
				if r < min {
					min = r
				}
				sum = sum + r
			}
			outputSlice[outputCounter] = sum
			outputSlice[outputCounter+1] = max - min
			outputCounter = outputCounter + 2
		}
		batchData[i]["input"] = inputSlice
		batchData[i]["output"] = outputSlice
	}
	return batchData
}

// TestTraining tests a basic training flow using the bindings to the
// onnxruntime C API for on-device training: it trains a small model on
// synthetic data, checks that the loss decreases as expected, saves a
// checkpoint, exports the trained model, and verifies its in-sample
// predictions.
func TestTraining(t *testing.T) {
	InitializeRuntime(t)
	defer CleanupRuntime(t)

	if !IsTrainingSupported() {
		t.Skipf("Training is not supported on this platform/onnxruntime build.")
	}

	trainingArtifactsFolder := path.Join("test_data", "training_test")

	// Dimensions of the generated training data.
	batchSize := 10
	nBatches := 10

	// Tensors holding the inputs/targets for each training batch, plus a
	// scalar for the loss.
	batchInputShape := NewShape(int64(batchSize), 1, 4)
	batchTargetShape := NewShape(int64(batchSize), 1, 2)
	batchInputTensor, err := NewEmptyTensor[float32](batchInputShape)
	if err != nil {
		t.Fatalf("training test failed with error: %v", err)
	}
	batchTargetTensor, err := NewEmptyTensor[float32](batchTargetShape)
	if err != nil {
		t.Fatalf("training test failed with error: %v", err)
	}
	lossScalar, err := NewEmptyScalar[float32]()
	if err != nil {
		t.Fatalf("training test failed with error: %v", err)
	}

	trainingSession, errorSessionCreation := NewTrainingSession(
		path.Join(trainingArtifactsFolder, "checkpoint"),
		path.Join(trainingArtifactsFolder, "training_model.onnx"),
		path.Join(trainingArtifactsFolder, "eval_model.onnx"),
		path.Join(trainingArtifactsFolder, "optimizer_model.onnx"),
		[]ArbitraryTensor{batchInputTensor, batchTargetTensor},
		[]ArbitraryTensor{lossScalar},
		nil)
	if errorSessionCreation != nil {
		t.Fatalf("session creation failed with error: %v", errorSessionCreation)
	}

	// Clean up the session and tensors after the test run.
	defer func(session *TrainingSession, tensors []ArbitraryTensor) {
		var errs []error
		errs = append(errs, session.Destroy())
		for _, tensor := range tensors {
			errs = append(errs, tensor.Destroy())
		}
		if e := errors.Join(errs...); e != nil {
			t.Fatalf("cleanup of test failed with error: %v", e)
		}
	}(trainingSession, []ArbitraryTensor{batchInputTensor, batchTargetTensor, lossScalar})

	losses := []float32{}
	epochs := 100
	batchData := generateBatchData(nBatches, batchSize)
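
	// Each epoch iterates over all batches: copy a batch into the session's
	// input tensors, run TrainStep to compute the loss and gradients, apply
	// them with OptimizerStep, then queue a gradient reset with LazyResetGrad
	// before the next batch.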
	for epoch := 0; epoch < epochs; epoch++ {
		var epochLoss float32 // total loss over the epoch

		for i := 0; i < nBatches; i++ {
			inputSlice := batchInputTensor.GetData()
			outputSlice := batchTargetTensor.GetData()
			copy(inputSlice, batchData[i]["input"])
			copy(outputSlice, batchData[i]["output"])

			// Train on the batch.
			err = trainingSession.TrainStep()
			if err != nil {
				t.Fatalf("train step failed with error: %v", err)
			}

			epochLoss = epochLoss + lossScalar.GetData()

			err = trainingSession.OptimizerStep()
			if err != nil {
				t.Fatalf("optimizer step failed with error: %v", err)
			}

			// ort training api: reset the gradients to zero so that new
			// gradients can be computed for the next batch.
			err = trainingSession.LazyResetGrad()
			if err != nil {
				t.Fatalf("lazy reset grad step failed with error: %v", err)
			}
		}
		// Record the average per-sample loss every 10 epochs.
		if epoch%10 == 0 {
			fmt.Printf("Epoch %d: loss %f\n", epoch+1,
				epochLoss/float32(batchSize*nBatches))
			losses = append(losses, epochLoss/float32(batchSize*nBatches))
		}
	}
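
	// Reference losses for every 10th epoch. The check below only requires
	// each recorded loss to be within a 60% relative deviation of its
	// reference value, a deliberately loose bound.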
	expectedLosses := []float32{
		0.125085,
		0.097187,
		0.062333,
		0.024307,
		0.019963,
		0.018476,
		0.017160,
		0.015982,
		0.014845,
		0.013867,
	}

	for i, l := range losses {
		diff := math.Abs(float64(l - expectedLosses[i]))
		deviation := diff / float64(expectedLosses[i])
		if deviation > 0.6 {
			t.Fatalf("loss deviation too large: expected %f, actual %f, deviation %f",
				float64(expectedLosses[i]), float64(l), float64(deviation))
		}
	}

	// Test saving the checkpoint state.
	finalCheckpointPath := path.Join("test_data", "training_test", "finalCheckpoint")
	errSaveCheckpoint := trainingSession.SaveCheckpoint(finalCheckpointPath, false)
	if errSaveCheckpoint != nil {
		t.Fatalf("Saving of checkpoint failed with error: %v", errSaveCheckpoint)
	}

	// Test exporting the trained model for inference.
	finalModelPath := path.Join("test_data", "training_test", "final_inference.onnx")
	errExport := trainingSession.ExportModel(finalModelPath, []string{"output"})
	if errExport != nil {
		t.Fatalf("Exporting model failed with error: %v", errExport)
	}

	// Remove the generated files when the test finishes.
	defer func() {
		e := os.Remove(finalCheckpointPath)
		if e != nil {
			t.Errorf("Error removing final checkpoint file %s: %s", finalCheckpointPath, e)
		}
		e = os.Remove(finalModelPath)
		if e != nil {
			t.Errorf("Error removing final model file %s: %s", finalModelPath, e)
		}
	}()

	// Load the exported model back in and test in-sample predictions for the
	// first batch (we care about correctness more than generalization here).
	// The target tensor is reused as the inference session's output tensor.
	session, err := NewAdvancedSession(finalModelPath,
		[]string{"input"}, []string{"output"},
		[]ArbitraryTensor{batchInputTensor}, []ArbitraryTensor{batchTargetTensor}, nil)
	if err != nil {
		t.Fatalf("creation of inference session failed with error: %v", err)
	}
	defer func(s *AdvancedSession) {
		err := s.Destroy()
		if err != nil {
			t.Fatalf("cleanup of inference session failed with error: %v", err)
		}
	}(session)

	// Calling Run() will run the network, reading the current contents of the
	// input tensors and modifying the contents of the output tensors.
	copy(batchInputTensor.GetData(), batchData[0]["input"])
	err = session.Run()
	if err != nil {
		t.Fatalf("run of inference session failed with error: %v", err)
	}

	expectedOutput := []float32{
		2.4524384,
		0.65120333,
		2.5457804,
		0.6102175,
		1.6276635,
		0.573755,
		1.7900972,
		0.59951085,
		3.1650176,
		0.66626525,
		1.9361509,
		0.571084,
		2.0798547,
		0.6060241,
		0.9611889,
		0.52100605,
		1.4070896,
		0.5412475,
		2.1449144,
		0.5985652,
	}

	for i, l := range batchTargetTensor.GetData() {
		diff := math.Abs(float64(l - expectedOutput[i]))
		deviation := diff / float64(expectedOutput[i])
		if deviation > 0.6 {
			t.Fatalf("output deviation too large at index %d: expected %f, "+
				"actual %f, deviation %f", i, expectedOutput[i], l, deviation)
		}
	}
}