mirror of
				https://github.com/yalue/onnxruntime_go.git
				synced 2025-10-31 18:52:43 +08:00 
			
		
		
		
	 02239b0937
			
		
	
	02239b0937
	
	
	
		
			
- The training test printed some stuff to stdout rather than the testing log.
- The temporary buffer for UTF16 conversion was over-allocated.
		
			
				
	
	
		
			354 lines
		
	
	
		
			9.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			354 lines
		
	
	
		
			9.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package onnxruntime_go
 | |
| 
 | |
| import (
 | |
| 	"errors"
 | |
| 	"math"
 | |
| 	"math/rand"
 | |
| 	"os"
 | |
| 	"path"
 | |
| 	"testing"
 | |
| )
 | |
| 
 | |
| func TestTrainingNotSupported(t *testing.T) {
 | |
| 	InitializeRuntime(t)
 | |
| 	defer CleanupRuntime(t)
 | |
| 
 | |
| 	if IsTrainingSupported() {
 | |
| 		t.Skipf("onnxruntime library supports training")
 | |
| 	}
 | |
| 
 | |
| 	options, e := NewSessionOptions()
 | |
| 	if e != nil {
 | |
| 		t.Logf("Failed creating options: %s\n", e)
 | |
| 		t.FailNow()
 | |
| 	}
 | |
| 
 | |
| 	trainingSession, e := NewTrainingSession("test_data/onnxruntime_training_test/training_artifacts/checkpoint",
 | |
| 		"test_data/onnxruntime_training_test/training_artifacts/training_model.onnx",
 | |
| 		"test_data/onnxruntime_training_test/training_artifacts/eval_model.onnx",
 | |
| 		"test_data/onnxruntime_training_test/training_artifacts/optimizer_model.onnx",
 | |
| 		nil, nil,
 | |
| 		options)
 | |
| 
 | |
| 	if !errors.Is(e, trainingNotSupportedError) {
 | |
| 		t.Logf("Creating training session when onnxruntime lib does not support it should return training not supported error.")
 | |
| 		if e != nil {
 | |
| 			t.Logf("Received instead error: %s", e.Error())
 | |
| 		} else {
 | |
| 			t.Logf("Received no error instead")
 | |
| 		}
 | |
| 		t.FailNow()
 | |
| 	}
 | |
| 	if trainingSession != nil {
 | |
| 		if err := trainingSession.Destroy(); err != nil {
 | |
| 			t.Fatalf("cleanup of training session failed with error: %v", e)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestGetInputOutputNames(t *testing.T) {
 | |
| 	InitializeRuntime(t)
 | |
| 	defer CleanupRuntime(t)
 | |
| 
 | |
| 	if !IsTrainingSupported() {
 | |
| 		t.Skipf("Training is not supported on this platform/onnxruntime build.")
 | |
| 	}
 | |
| 
 | |
| 	artifactsPath := path.Join("test_data", "training_test")
 | |
| 
 | |
| 	names, err := GetInputOutputNames(
 | |
| 		path.Join(artifactsPath, "checkpoint"),
 | |
| 		path.Join(artifactsPath, "training_model.onnx"),
 | |
| 		path.Join(artifactsPath, "eval_model.onnx"),
 | |
| 	)
 | |
| 
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed getting input and output names with error: %v\n", err)
 | |
| 	}
 | |
| 
 | |
| 	expectedTrainInputNames := []string{"input", "target"}
 | |
| 	expectedEvalInputNames := expectedTrainInputNames
 | |
| 	expectedTrainOutputNames := []string{"onnx::reducemean_output::5"}
 | |
| 	expectedEvalOutputNames := expectedTrainOutputNames
 | |
| 
 | |
| 	for i, v := range names.TrainingInputNames {
 | |
| 		if v != expectedTrainInputNames[i] {
 | |
| 			t.Fatalf("training input names don't match")
 | |
| 		}
 | |
| 	}
 | |
| 	for i, v := range names.TrainingOutputNames {
 | |
| 		if v != expectedTrainOutputNames[i] {
 | |
| 			t.Fatalf("training output names don't match")
 | |
| 		}
 | |
| 	}
 | |
| 	for i, v := range names.EvalInputNames {
 | |
| 		if v != expectedEvalInputNames[i] {
 | |
| 			t.Fatalf("eval input names don't match")
 | |
| 		}
 | |
| 	}
 | |
| 	for i, v := range names.EvalOutputNames {
 | |
| 		if v != expectedEvalOutputNames[i] {
 | |
| 			t.Fatalf("eval output names don't match")
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// without eval model
 | |
| 	names, err = GetInputOutputNames(
 | |
| 		path.Join(artifactsPath, "checkpoint"),
 | |
| 		path.Join(artifactsPath, "training_model.onnx"),
 | |
| 		"",
 | |
| 	)
 | |
| 
 | |
| 	if err != nil {
 | |
| 		t.Fatalf("Failed getting input and output names with error: %v\n", err)
 | |
| 	}
 | |
| 
 | |
| 	for i, v := range names.TrainingInputNames {
 | |
| 		if v != expectedTrainInputNames[i] {
 | |
| 			t.Fatalf("training input names don't match")
 | |
| 		}
 | |
| 	}
 | |
| 	for i, v := range names.TrainingOutputNames {
 | |
| 		if v != expectedTrainOutputNames[i] {
 | |
| 			t.Fatalf("training output names don't match")
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func generateBatchData(nBatches int, batchSize int) map[int]map[string][]float32 {
 | |
| 	batchData := map[int]map[string][]float32{}
 | |
| 
 | |
| 	source := rand.NewSource(1234)
 | |
| 	g := rand.New(source)
 | |
| 
 | |
| 	for i := 0; i < nBatches; i++ {
 | |
| 		inputCounter := 0
 | |
| 		outputCounter := 0
 | |
| 		inputSlice := make([]float32, batchSize*4)
 | |
| 		outputSlice := make([]float32, batchSize*2)
 | |
| 		batchData[i] = map[string][]float32{}
 | |
| 
 | |
| 		// generate random data for batch
 | |
| 		for n := 0; n < batchSize; n++ {
 | |
| 			var sum float32
 | |
| 			min := float32(1)
 | |
| 			max := float32(-1)
 | |
| 			for i := 0; i < 4; i++ {
 | |
| 				r := g.Float32()
 | |
| 				inputSlice[inputCounter] = r
 | |
| 				inputCounter++
 | |
| 				if r > max {
 | |
| 					max = r
 | |
| 				}
 | |
| 				if r < min {
 | |
| 					min = r
 | |
| 				}
 | |
| 				sum = sum + r
 | |
| 			}
 | |
| 			outputSlice[outputCounter] = sum
 | |
| 			outputSlice[outputCounter+1] = max - min
 | |
| 			outputCounter = outputCounter + 2
 | |
| 		}
 | |
| 		batchData[i]["input"] = inputSlice
 | |
| 		batchData[i]["output"] = outputSlice
 | |
| 	}
 | |
| 	return batchData
 | |
| }
 | |
| 
 | |
// TestTraining tests a basic training flow using the bindings to the C api for on-device onnxruntime training
//
// The flow exercised here: create a training session from checkpoint +
// training/eval/optimizer models, run 100 epochs of train/optimize/reset
// steps over deterministic synthetic batches (see generateBatchData),
// compare sampled epoch losses against hard-coded expectations, save the
// final checkpoint, export an inference model, and verify the exported
// model's in-sample predictions on the first batch.
func TestTraining(t *testing.T) {
	InitializeRuntime(t)
	defer CleanupRuntime(t)

	if !IsTrainingSupported() {
		t.Skipf("Training is not supported on this platform/onnxruntime build.")
	}

	trainingArtifactsFolder := path.Join("test_data", "training_test")

	// generate training data
	batchSize := 10
	nBatches := 10

	// holds inputs/outputs and loss for each training batch
	// Shapes are (batch, 1, 4) inputs and (batch, 1, 2) targets, matching
	// generateBatchData's 4-values-in / sum-and-spread-out layout.
	batchInputShape := NewShape(int64(batchSize), 1, 4)
	batchTargetShape := NewShape(int64(batchSize), 1, 2)
	batchInputTensor, err := NewEmptyTensor[float32](batchInputShape)
	if err != nil {
		t.Fatalf("training test failed with error: %v", err)
	}
	batchTargetTensor, err := NewEmptyTensor[float32](batchTargetShape)
	if err != nil {
		t.Fatalf("training test failed with error: %v", err)
	}
	// Scalar output written by each TrainStep with the batch loss.
	lossScalar, err := NewEmptyScalar[float32]()
	if err != nil {
		t.Fatalf("training test failed with error: %v", err)
	}

	// The session is bound to the tensors above: TrainStep reads the input/
	// target tensors and writes the loss scalar in place each call.
	trainingSession, errorSessionCreation := NewTrainingSession(
		path.Join(trainingArtifactsFolder, "checkpoint"),
		path.Join(trainingArtifactsFolder, "training_model.onnx"),
		path.Join(trainingArtifactsFolder, "eval_model.onnx"),
		path.Join(trainingArtifactsFolder, "optimizer_model.onnx"),
		[]Value{batchInputTensor, batchTargetTensor}, []Value{lossScalar},
		nil)

	if errorSessionCreation != nil {
		t.Fatalf("session creation failed with error: %v", errorSessionCreation)
	}

	// cleanup after test run
	// All Destroy errors are collected so every resource gets a teardown
	// attempt before the test reports a failure.
	defer func(session *TrainingSession, tensors []Value) {
		var errs []error
		errs = append(errs, session.Destroy())
		for _, t := range tensors {
			errs = append(errs, t.Destroy())
		}
		if e := errors.Join(errs...); e != nil {
			t.Fatalf("cleanup of test failed with error: %v", e)
		}
	}(trainingSession, []Value{batchInputTensor, batchTargetTensor, lossScalar})

	losses := []float32{}
	epochs := 100
	batchData := generateBatchData(nBatches, batchSize)

	for epoch := 0; epoch < epochs; epoch++ {
		var epochLoss float32 // total epoch loss

		for i := 0; i < nBatches; i++ {
			// Copy the current batch into the session-bound tensors; the
			// tensor memory itself stays fixed across steps.
			inputSlice := batchInputTensor.GetData()
			outputSlice := batchTargetTensor.GetData()

			copy(inputSlice, batchData[i]["input"])
			copy(outputSlice, batchData[i]["output"])

			// train on batch
			err = trainingSession.TrainStep()
			if err != nil {
				t.Fatalf("train step failed with error: %v", err)
			}

			epochLoss = epochLoss + lossScalar.GetData()

			err = trainingSession.OptimizerStep()
			if err != nil {
				t.Fatalf("optimizer step failed with error: %v", err)
			}

			// ort training api - reset the gradients to zero so that new gradients can be computed for next batch
			err = trainingSession.LazyResetGrad()
			if err != nil {
				t.Fatalf("lazy reset grad step failed with error: %v", err)
			}
		}
		// Sample the mean loss every 10th epoch (10 samples total, matching
		// expectedLosses below).
		if epoch%10 == 0 {
			t.Logf("Epoch {%d} Loss {%f}\n", epoch+1, epochLoss/float32(batchSize*nBatches))
			losses = append(losses, epochLoss/float32(batchSize*nBatches))
		}
	}

	// Reference losses recorded from a known-good run; the seed in
	// generateBatchData is fixed, so these are reproducible up to
	// numeric/backend variation.
	expectedLosses := []float32{
		0.125085,
		0.097187,
		0.062333,
		0.024307,
		0.019963,
		0.018476,
		0.017160,
		0.015982,
		0.014845,
		0.013867,
	}

	// Allow up to 60% relative deviation to absorb platform and
	// onnxruntime-version differences.
	for i, l := range losses {
		diff := math.Abs(float64(l - expectedLosses[i]))
		deviation := diff / float64(expectedLosses[i])
		if deviation > 0.6 {
			t.Fatalf("loss deviation too large: expected %f, actual %f, deviation %f", float64(expectedLosses[i]), float64(l), float64(deviation))
		}
	}

	// test the saving of the checkpoint state
	finalCheckpointPath := path.Join("test_data", "training_test", "finalCheckpoint")
	errSaveCheckpoint := trainingSession.SaveCheckpoint(finalCheckpointPath, false)
	if errSaveCheckpoint != nil {
		t.Fatalf("Saving of checkpoint failed with error: %v", errSaveCheckpoint)
	}

	// test the saving of the model
	finalModelPath := path.Join("test_data", "training_test", "final_inference.onnx")
	errExport := trainingSession.ExportModel(finalModelPath, []string{"output"})
	if errExport != nil {
		t.Fatalf("Exporting model failed with error: %v", errExport)
	}

	// Remove the files written above once the test finishes. NOTE(review):
	// registered only after both writes succeed, so an earlier Fatalf leaves
	// them on disk.
	defer func() {
		e := os.Remove(finalCheckpointPath)
		if e != nil {
			t.Errorf("Error removing final checkpoint file %s: %s", finalCheckpointPath, e)
		}
		e = os.Remove(finalModelPath)
		if e != nil {
			t.Errorf("Error removing final model file %s: %s", finalModelPath, e)
		}
	}()

	// load the model back in and test in-sample predictions for the first batch
	// (we care about correctness more than generalization here)
	// The target tensor is reused here as the inference OUTPUT buffer, so
	// Run() overwrites its training-data contents with predictions.
	session, err := NewAdvancedSession(path.Join("test_data", "training_test", "final_inference.onnx"),
		[]string{"input"}, []string{"output"},
		[]Value{batchInputTensor}, []Value{batchTargetTensor}, nil)

	if err != nil {
		t.Fatalf("creation of inference session failed with error: %v", err)
	}

	defer func(s *AdvancedSession) {
		err := s.Destroy()
		if err != nil {
			t.Fatalf("cleanup of inference session failed with error: %v", err)
		}
	}(session)

	// Calling Run() will run the network, reading the current contents of the
	// input tensors and modifying the contents of the output tensors.
	copy(batchInputTensor.GetData(), batchData[0]["input"])
	err = session.Run()
	if err != nil {
		t.Fatalf("run of inference session failed with error: %v", err)
	}

	// Expected predictions for batch 0 from a known-good run: alternating
	// (sum, spread) pairs for each of the 10 samples.
	expectedOutput := []float32{
		2.4524384,
		0.65120333,
		2.5457804,
		0.6102175,
		1.6276635,
		0.573755,
		1.7900972,
		0.59951085,
		3.1650176,
		0.66626525,
		1.9361509,
		0.571084,
		2.0798547,
		0.6060241,
		0.9611889,
		0.52100605,
		1.4070896,
		0.5412475,
		2.1449144,
		0.5985652,
	}

	// Same 60% relative-deviation tolerance as the loss comparison above.
	for i, l := range batchTargetTensor.GetData() {
		diff := math.Abs(float64(l - expectedOutput[i]))
		deviation := diff / float64(expectedOutput[i])
		if deviation > 0.6 {
			t.Fatalf("deviation too large")
		}
	}
}
 |