mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-09-27 03:35:58 +08:00

- The training API has been deprecated by onnxruntime itself, and it will be much easier to remove it rather than deprecate it. - The training API wrapper functions have been replaced by stubs that return errors in legacy_types.go. - The README mentions the old version required for the training API. - The Scalar type has been promoted to onnxruntime_go.go, and a test has been added for it.
104 lines
3.3 KiB
Go
104 lines
3.3 KiB
Go
//go:build windows
|
|
|
|
package onnxruntime_go
|
|
|
|
// This file includes the Windows-specific code for loading the onnxruntime
|
|
// library and setting up the environment.
|
|
|
|
import (
|
|
"fmt"
|
|
"syscall"
|
|
"unicode/utf16"
|
|
"unicode/utf8"
|
|
"unsafe"
|
|
)
|
|
|
|
// #include "onnxruntime_wrapper.h"
|
|
import "C"
|
|
|
|
// This will contain the handle to the onnxruntime dll if it has been loaded
|
|
// successfully.
|
|
var libraryHandle syscall.Handle
|
|
|
|
func platformCleanup() error {
|
|
e := syscall.FreeLibrary(libraryHandle)
|
|
libraryHandle = 0
|
|
return e
|
|
}
|
|
|
|
func platformInitializeEnvironment() error {
|
|
if onnxSharedLibraryPath == "" {
|
|
onnxSharedLibraryPath = "onnxruntime.dll"
|
|
}
|
|
handle, e := syscall.LoadLibrary(onnxSharedLibraryPath)
|
|
if e != nil {
|
|
return fmt.Errorf("Error loading ONNX shared library \"%s\": %w",
|
|
onnxSharedLibraryPath, e)
|
|
}
|
|
getApiBaseProc, e := syscall.GetProcAddress(handle, "OrtGetApiBase")
|
|
if e != nil {
|
|
syscall.FreeLibrary(handle)
|
|
return fmt.Errorf("Error finding OrtGetApiBase function in %s: %w",
|
|
onnxSharedLibraryPath, e)
|
|
}
|
|
ortApiBase, _, e := syscall.SyscallN(uintptr(getApiBaseProc), 0)
|
|
if ortApiBase == 0 {
|
|
syscall.FreeLibrary(handle)
|
|
if e != nil {
|
|
return fmt.Errorf("Error calling OrtGetApiBase: %w", e)
|
|
} else {
|
|
return fmt.Errorf("Error calling OrtGetApiBase")
|
|
}
|
|
}
|
|
tmp := C.SetAPIFromBase((*C.OrtApiBase)(unsafe.Pointer(ortApiBase)))
|
|
if tmp != 0 {
|
|
syscall.FreeLibrary(handle)
|
|
return fmt.Errorf("Error setting ORT API base: %d", tmp)
|
|
}
|
|
|
|
libraryHandle = handle
|
|
return nil
|
|
}
|
|
|
|
// Converts the given string to a UTF-16 string, pointed to by a raw
|
|
// *C.char. Note that we actually keep ORTCHAR_T defined to char even
|
|
// on Windows, so do _not_ index into this string from Cgo code and expect to
|
|
// get correct characters! Instead, this should only be used to obtain pointers
|
|
// that are passed to onnxruntime windows DLL functions expecting ORTCHAR_T*
|
|
// args. This is required because we undefine _WIN32 for cgo compatibility when
|
|
// including onnxruntime_c_api.h, but still interact with a DLL that was
|
|
// compiled assuming _WIN32 was defined.
|
|
//
|
|
// The pointer returned by this function must still be freed using C.free when
|
|
// no longer needed. This will return an error if the given string contains
|
|
// non-UTF8 characters.
|
|
func createOrtCharString(str string) (*C.char, error) {
|
|
src := []uint8(str)
|
|
// Assumed common case: the utf16 buffer contains one uint16 per utf8 byte
|
|
// plus one more for the required null terminator in the C buffer.
|
|
dst := make([]uint16, 0, len(src)+1)
|
|
// Convert UTF-8 to UTF-16 by reading each subsequent rune from src and
|
|
// appending it as UTF-16 to dst.
|
|
for len(src) > 0 {
|
|
r, size := utf8.DecodeRune(src)
|
|
if r == utf8.RuneError {
|
|
return nil, fmt.Errorf("Invalid UTF-8 rune found in \"%s\"", str)
|
|
}
|
|
src = src[size:]
|
|
dst = utf16.AppendRune(dst, r)
|
|
}
|
|
// Make sure dst contains the null terminator. Additionally this will cause
|
|
// us to return an empty string if the original string was empty.
|
|
dst = append(dst, 0)
|
|
|
|
// Finally, we need to copy dst into a C array for compatibility with
|
|
// C.CString.
|
|
toReturn := C.calloc(C.size_t(len(dst)), 2)
|
|
if toReturn == nil {
|
|
return nil, fmt.Errorf("Error allocating buffer for the utf16 string")
|
|
}
|
|
C.memcpy(toReturn, unsafe.Pointer(&(dst[0])), C.size_t(len(dst))*2)
|
|
|
|
return (*C.char)(toReturn), nil
|
|
}
|