mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-11-02 21:34:01 +08:00
Minor cleanup
- The training test printed some stuff to stdout rather than the testing log. - The temporary buffer for UTF16 conversion was over-allocated.
This commit is contained in:
@@ -2,7 +2,6 @@ package onnxruntime_go
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"os"
|
||||
@@ -245,7 +244,7 @@ func TestTraining(t *testing.T) {
|
||||
}
|
||||
}
|
||||
if epoch%10 == 0 {
|
||||
fmt.Printf("Epoch {%d} Loss {%f}\n", epoch+1, epochLoss/float32(batchSize*nBatches))
|
||||
t.Logf("Epoch {%d} Loss {%f}\n", epoch+1, epochLoss/float32(batchSize*nBatches))
|
||||
losses = append(losses, epochLoss/float32(batchSize*nBatches))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,7 +78,9 @@ func platformInitializeEnvironment() error {
|
||||
// non-UTF8 characters.
|
||||
func createOrtCharString(str string) (*C.char, error) {
|
||||
src := []uint8(str)
|
||||
dst := make([]uint16, 0, (len(src)+1)*2)
|
||||
// Assumed common case: the utf16 buffer contains one uint16 per utf8 byte
|
||||
// plus one more for the required null terminator in the C buffer.
|
||||
dst := make([]uint16, 0, len(src)+1)
|
||||
// Convert UTF-8 to UTF-16 by reading each subsequent rune from src and
|
||||
// appending it as UTF-16 to dst.
|
||||
for len(src) > 0 {
|
||||
@@ -97,7 +99,7 @@ func createOrtCharString(str string) (*C.char, error) {
|
||||
// C.CString.
|
||||
toReturn := C.calloc(C.size_t(len(dst)), 2)
|
||||
if toReturn == nil {
|
||||
return nil, fmt.Errorf("Error allocating C buffer to hold utf16 str")
|
||||
return nil, fmt.Errorf("Error allocating buffer for the utf16 string")
|
||||
}
|
||||
C.memcpy(toReturn, unsafe.Pointer(&(dst[0])), C.size_t(len(dst))*2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user