onnxruntime_go/onnxruntime_training_test.go

package onnxruntime_go

import (
	"errors"
	"math"
	"math/rand"
	"os"
	"path"
	"testing"
)
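
// TestTrainingNotSupported verifies that attempting to create a training
// session on an onnxruntime build without training support fails with
// trainingNotSupportedError rather than succeeding or returning some other
// error.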
func TestTrainingNotSupported(t *testing.T) {
	InitializeRuntime(t)
	defer CleanupRuntime(t)
	if IsTrainingSupported() {
		t.Skipf("The onnxruntime library supports training; skipping the " +
			"unsupported-training test.")
	}
	options, e := NewSessionOptions()
	if e != nil {
		t.Logf("Failed creating options: %s\n", e)
		t.FailNow()
	}
	defer options.Destroy()
	trainingSession, e := NewTrainingSession(
		"test_data/onnxruntime_training_test/training_artifacts/checkpoint",
		"test_data/onnxruntime_training_test/training_artifacts/training_model.onnx",
		"test_data/onnxruntime_training_test/training_artifacts/eval_model.onnx",
		"test_data/onnxruntime_training_test/training_artifacts/optimizer_model.onnx",
		nil, nil,
		options)
	if !errors.Is(e, trainingNotSupportedError) {
		t.Logf("Creating a training session when the onnxruntime library " +
			"doesn't support training should return trainingNotSupportedError.")
		if e != nil {
			t.Logf("Received this error instead: %s", e.Error())
		} else {
			t.Logf("Received no error instead.")
		}
		t.FailNow()
	}
	if trainingSession != nil {
		if err := trainingSession.Destroy(); err != nil {
			t.Fatalf("cleanup of training session failed with error: %v", err)
		}
	}
}
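
// TestGetInputOutputNames checks that GetInputOutputNames returns the
// expected input and output names for the test training artifacts, both with
// and without an eval model.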
func TestGetInputOutputNames(t *testing.T) {
	InitializeRuntime(t)
	defer CleanupRuntime(t)
	if !IsTrainingSupported() {
		t.Skipf("Training is not supported on this platform/onnxruntime build.")
	}
	artifactsPath := path.Join("test_data", "training_test")
	names, err := GetInputOutputNames(
		path.Join(artifactsPath, "checkpoint"),
		path.Join(artifactsPath, "training_model.onnx"),
		path.Join(artifactsPath, "eval_model.onnx"),
	)
	if err != nil {
		t.Fatalf("Failed getting input and output names with error: %v\n", err)
	}
	expectedTrainInputNames := []string{"input", "target"}
	expectedEvalInputNames := expectedTrainInputNames
	expectedTrainOutputNames := []string{"onnx::reducemean_output::5"}
	expectedEvalOutputNames := expectedTrainOutputNames
	for i, v := range names.TrainingInputNames {
		if v != expectedTrainInputNames[i] {
			t.Fatalf("training input names don't match")
		}
	}
	for i, v := range names.TrainingOutputNames {
		if v != expectedTrainOutputNames[i] {
			t.Fatalf("training output names don't match")
		}
	}
	for i, v := range names.EvalInputNames {
		if v != expectedEvalInputNames[i] {
			t.Fatalf("eval input names don't match")
		}
	}
	for i, v := range names.EvalOutputNames {
		if v != expectedEvalOutputNames[i] {
			t.Fatalf("eval output names don't match")
		}
	}
	// Getting names should also work without an eval model.
	names, err = GetInputOutputNames(
		path.Join(artifactsPath, "checkpoint"),
		path.Join(artifactsPath, "training_model.onnx"),
		"",
	)
	if err != nil {
		t.Fatalf("Failed getting input and output names with error: %v\n", err)
	}
	for i, v := range names.TrainingInputNames {
		if v != expectedTrainInputNames[i] {
			t.Fatalf("training input names don't match")
		}
	}
	for i, v := range names.TrainingOutputNames {
		if v != expectedTrainOutputNames[i] {
			t.Fatalf("training output names don't match")
		}
	}
}
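
// generateBatchData deterministically generates nBatches batches of synthetic
// training data, keyed by batch index. Each sample consists of 4 random
// inputs in [0, 1); its 2-element target is the sum of the inputs followed by
// their range (max - min).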
func generateBatchData(nBatches int, batchSize int) map[int]map[string][]float32 {
	batchData := map[int]map[string][]float32{}
	source := rand.NewSource(1234)
	g := rand.New(source)
	for i := 0; i < nBatches; i++ {
		inputCounter := 0
		outputCounter := 0
		inputSlice := make([]float32, batchSize*4)
		outputSlice := make([]float32, batchSize*2)
		batchData[i] = map[string][]float32{}
		// Generate random data for the batch.
		for n := 0; n < batchSize; n++ {
			var sum float32
			min := float32(1)
			max := float32(-1)
			for j := 0; j < 4; j++ {
				r := g.Float32()
				inputSlice[inputCounter] = r
				inputCounter++
				if r > max {
					max = r
				}
				if r < min {
					min = r
				}
				sum = sum + r
			}
			outputSlice[outputCounter] = sum
			outputSlice[outputCounter+1] = max - min
			outputCounter = outputCounter + 2
		}
		batchData[i]["input"] = inputSlice
		batchData[i]["output"] = outputSlice
	}
	return batchData
}

// TestTraining tests a basic training flow using the bindings to the C API
// for on-device onnxruntime training.
func TestTraining(t *testing.T) {
	InitializeRuntime(t)
	defer CleanupRuntime(t)
	if !IsTrainingSupported() {
		t.Skipf("Training is not supported on this platform/onnxruntime build.")
	}
	trainingArtifactsFolder := path.Join("test_data", "training_test")

	// Training-data dimensions.
	batchSize := 10
	nBatches := 10

	// These tensors hold the inputs, targets, and loss for each training
	// batch.
	batchInputShape := NewShape(int64(batchSize), 1, 4)
	batchTargetShape := NewShape(int64(batchSize), 1, 2)
	batchInputTensor, err := NewEmptyTensor[float32](batchInputShape)
	if err != nil {
		t.Fatalf("training test failed with error: %v", err)
	}
	batchTargetTensor, err := NewEmptyTensor[float32](batchTargetShape)
	if err != nil {
		t.Fatalf("training test failed with error: %v", err)
	}
	lossScalar, err := NewEmptyScalar[float32]()
	if err != nil {
		t.Fatalf("training test failed with error: %v", err)
	}
	trainingSession, errorSessionCreation := NewTrainingSession(
		path.Join(trainingArtifactsFolder, "checkpoint"),
		path.Join(trainingArtifactsFolder, "training_model.onnx"),
		path.Join(trainingArtifactsFolder, "eval_model.onnx"),
		path.Join(trainingArtifactsFolder, "optimizer_model.onnx"),
		[]Value{batchInputTensor, batchTargetTensor}, []Value{lossScalar},
		nil)
	if errorSessionCreation != nil {
		t.Fatalf("session creation failed with error: %v", errorSessionCreation)
	}
	// Clean up after the test run.
	defer func(session *TrainingSession, tensors []Value) {
		var errs []error
		errs = append(errs, session.Destroy())
		for _, v := range tensors {
			errs = append(errs, v.Destroy())
		}
		if e := errors.Join(errs...); e != nil {
			t.Fatalf("cleanup of test failed with error: %v", e)
		}
	}(trainingSession, []Value{batchInputTensor, batchTargetTensor, lossScalar})
	losses := []float32{}
	epochs := 100
	batchData := generateBatchData(nBatches, batchSize)
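
	// Each epoch copies every batch into the input and target tensors, then
	// runs one TrainStep, OptimizerStep, and LazyResetGrad per batch.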
	for epoch := 0; epoch < epochs; epoch++ {
		var epochLoss float32 // Total loss across the epoch.
		for i := 0; i < nBatches; i++ {
			inputSlice := batchInputTensor.GetData()
			outputSlice := batchTargetTensor.GetData()
			copy(inputSlice, batchData[i]["input"])
			copy(outputSlice, batchData[i]["output"])
			// Train on the batch.
			err = trainingSession.TrainStep()
			if err != nil {
				t.Fatalf("train step failed with error: %v", err)
			}
			epochLoss = epochLoss + lossScalar.GetData()
			err = trainingSession.OptimizerStep()
			if err != nil {
				t.Fatalf("optimizer step failed with error: %v", err)
			}
			// The ort training API requires resetting the gradients to zero
			// so that new gradients can be computed for the next batch.
			err = trainingSession.LazyResetGrad()
			if err != nil {
				t.Fatalf("lazy reset grad step failed with error: %v", err)
			}
		}
		if epoch%10 == 0 {
			avgLoss := epochLoss / float32(batchSize*nBatches)
			t.Logf("Epoch %d: average loss %f", epoch+1, avgLoss)
			losses = append(losses, avgLoss)
		}
	}
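
	// Reference average losses, one per 10 epochs, presumably recorded from a
	// prior run of this test. The loose 60% deviation bound below tolerates
	// numerical differences across platforms and onnxruntime versions.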
	expectedLosses := []float32{
		0.125085,
		0.097187,
		0.062333,
		0.024307,
		0.019963,
		0.018476,
		0.017160,
		0.015982,
		0.014845,
		0.013867,
	}
	for i, l := range losses {
		diff := math.Abs(float64(l - expectedLosses[i]))
		deviation := diff / float64(expectedLosses[i])
		if deviation > 0.6 {
			t.Fatalf("loss deviation too large: expected %f, actual %f, "+
				"deviation %f", expectedLosses[i], l, deviation)
		}
	}

	// Test saving the checkpoint state.
	finalCheckpointPath := path.Join("test_data", "training_test", "finalCheckpoint")
	errSaveCheckpoint := trainingSession.SaveCheckpoint(finalCheckpointPath, false)
	if errSaveCheckpoint != nil {
		t.Fatalf("Saving of checkpoint failed with error: %v", errSaveCheckpoint)
	}

	// Test exporting the trained model for inference.
	finalModelPath := path.Join("test_data", "training_test", "final_inference.onnx")
	errExport := trainingSession.ExportModel(finalModelPath, []string{"output"})
	if errExport != nil {
		t.Fatalf("Exporting model failed with error: %v", errExport)
	}
	defer func() {
		e := os.Remove(finalCheckpointPath)
		if e != nil {
			t.Errorf("Error removing final checkpoint file %s: %s", finalCheckpointPath, e)
		}
		e = os.Remove(finalModelPath)
		if e != nil {
			t.Errorf("Error removing final model file %s: %s", finalModelPath, e)
		}
	}()

	// Load the exported model back in and test in-sample predictions for the
	// first batch. (We care about correctness more than generalization here.)
	session, err := NewAdvancedSession(finalModelPath,
		[]string{"input"}, []string{"output"},
		[]Value{batchInputTensor}, []Value{batchTargetTensor}, nil)
	if err != nil {
		t.Fatalf("creation of inference session failed with error: %v", err)
	}
	defer func(s *AdvancedSession) {
		err := s.Destroy()
		if err != nil {
			t.Fatalf("cleanup of inference session failed with error: %v", err)
		}
	}(session)
	// Calling Run() will run the network, reading the current contents of the
	// input tensors and modifying the contents of the output tensors.
	copy(batchInputTensor.GetData(), batchData[0]["input"])
	err = session.Run()
	if err != nil {
		t.Fatalf("run of inference session failed with error: %v", err)
	}
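
	// Each consecutive pair below is the expected (sum, max - min) prediction
	// for one sample of the first batch, presumably recorded from a reference
	// run of the trained model.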
	expectedOutput := []float32{
		2.4524384, 0.65120333,
		2.5457804, 0.6102175,
		1.6276635, 0.573755,
		1.7900972, 0.59951085,
		3.1650176, 0.66626525,
		1.9361509, 0.571084,
		2.0798547, 0.6060241,
		0.9611889, 0.52100605,
		1.4070896, 0.5412475,
		2.1449144, 0.5985652,
	}
	for i, l := range batchTargetTensor.GetData() {
		diff := math.Abs(float64(l - expectedOutput[i]))
		deviation := diff / float64(expectedOutput[i])
		if deviation > 0.6 {
			t.Fatalf("output deviation too large at index %d: expected %f, "+
				"actual %f, deviation %f", i, expectedOutput[i], l, deviation)
		}
	}
}