Minor cleanup

- The training test printed some stuff to stdout rather than the
   testing log.

 - The temporary buffer for UTF16 conversion was over-allocated.
This commit is contained in:
yalue
2024-09-14 23:24:31 -04:00
parent b5a29a1390
commit 02239b0937
2 changed files with 5 additions and 4 deletions

View File

@@ -2,7 +2,6 @@ package onnxruntime_go
import (
"errors"
"fmt"
"math"
"math/rand"
"os"
@@ -245,7 +244,7 @@ func TestTraining(t *testing.T) {
}
}
if epoch%10 == 0 {
fmt.Printf("Epoch {%d} Loss {%f}\n", epoch+1, epochLoss/float32(batchSize*nBatches))
t.Logf("Epoch {%d} Loss {%f}\n", epoch+1, epochLoss/float32(batchSize*nBatches))
losses = append(losses, epochLoss/float32(batchSize*nBatches))
}
}

View File

@@ -78,7 +78,9 @@ func platformInitializeEnvironment() error {
// non-UTF8 characters.
func createOrtCharString(str string) (*C.char, error) {
src := []uint8(str)
dst := make([]uint16, 0, (len(src)+1)*2)
// Assumed common case: the utf16 buffer contains one uint16 per utf8 byte
// plus one more for the required null terminator in the C buffer.
dst := make([]uint16, 0, len(src)+1)
// Convert UTF-8 to UTF-16 by reading each subsequent rune from src and
// appending it as UTF-16 to dst.
for len(src) > 0 {
@@ -97,7 +99,7 @@ func createOrtCharString(str string) (*C.char, error) {
// C.CString.
toReturn := C.calloc(C.size_t(len(dst)), 2)
if toReturn == nil {
return nil, fmt.Errorf("Error allocating C buffer to hold utf16 str")
return nil, fmt.Errorf("Error allocating buffer for the utf16 string")
}
C.memcpy(toReturn, unsafe.Pointer(&(dst[0])), C.size_t(len(dst))*2)