Files
onnxruntime_go/setup_env_windows.go
yalue 11b449bb38 Remove training API
- The training API has been deprecated by onnxruntime itself, and it
   will be much easier to remove it rather than deprecate it.

 - The training API wrapper functions have been replaced by stubs that
   return errors in legacy_types.go.

 - The README mentions the old version required for the training API.

 - The Scalar type has been promoted to onnxruntime_go.go, and a test
   has been added for it.
2024-11-13 20:34:22 -05:00

104 lines
3.3 KiB
Go

//go:build windows
package onnxruntime_go
// This file includes the Windows-specific code for loading the onnxruntime
// library and setting up the environment.
import (
"fmt"
"syscall"
"unicode/utf16"
"unicode/utf8"
"unsafe"
)
// #include "onnxruntime_wrapper.h"
import "C"
// libraryHandle holds the handle to the onnxruntime DLL once
// platformInitializeEnvironment has loaded it successfully. It is zero when no
// library is loaded: platformCleanup resets it to zero after freeing the DLL.
var libraryHandle syscall.Handle
func platformCleanup() error {
e := syscall.FreeLibrary(libraryHandle)
libraryHandle = 0
return e
}
// platformInitializeEnvironment loads the onnxruntime DLL, looks up its
// OrtGetApiBase entry point, calls it, and passes the resulting OrtApiBase
// pointer to the C wrapper via SetAPIFromBase. If onnxSharedLibraryPath is
// empty, it is set to "onnxruntime.dll" first.
//
// On success the DLL handle is stored in libraryHandle and nil is returned. On
// any failure after the DLL was loaded, the DLL is freed again before the
// error is returned, and libraryHandle is left untouched.
func platformInitializeEnvironment() error {
	if onnxSharedLibraryPath == "" {
		onnxSharedLibraryPath = "onnxruntime.dll"
	}
	handle, e := syscall.LoadLibrary(onnxSharedLibraryPath)
	if e != nil {
		return fmt.Errorf("Error loading ONNX shared library \"%s\": %w",
			onnxSharedLibraryPath, e)
	}
	getApiBaseProc, e := syscall.GetProcAddress(handle, "OrtGetApiBase")
	if e != nil {
		// Best-effort unload; the lookup failure is the error worth reporting.
		syscall.FreeLibrary(handle)
		return fmt.Errorf("Error finding OrtGetApiBase function in %s: %w",
			onnxSharedLibraryPath, e)
	}
	// SyscallN returns a syscall.Errno value rather than an error interface.
	// It must be kept in its own variable and compared against 0: storing it
	// in an error-typed variable would make a "!= nil" check always true, even
	// when the Errno is 0.
	ortApiBase, _, errno := syscall.SyscallN(uintptr(getApiBaseProc), 0)
	if ortApiBase == 0 {
		// Best-effort unload; the failed call is the error worth reporting.
		syscall.FreeLibrary(handle)
		if errno != 0 {
			return fmt.Errorf("Error calling OrtGetApiBase: %w", errno)
		}
		return fmt.Errorf("Error calling OrtGetApiBase")
	}
	// ortApiBase is an address returned by the DLL, not a Go pointer, so it is
	// only reinterpreted here in order to be passed to the C wrapper.
	tmp := C.SetAPIFromBase((*C.OrtApiBase)(unsafe.Pointer(ortApiBase)))
	if tmp != 0 {
		// Best-effort unload; the wrapper's status code is what gets reported.
		syscall.FreeLibrary(handle)
		return fmt.Errorf("Error setting ORT API base: %d", tmp)
	}
	libraryHandle = handle
	return nil
}
// Converts the given string to a null-terminated UTF-16 string, pointed to by
// a raw *C.char. Note that we actually keep ORTCHAR_T defined to char even
// on Windows, so do _not_ index into this string from Cgo code and expect to
// get correct characters! Instead, this should only be used to obtain pointers
// that are passed to onnxruntime windows DLL functions expecting ORTCHAR_T*
// args. This is required because we undefine _WIN32 for cgo compatibility when
// including onnxruntime_c_api.h, but still interact with a DLL that was
// compiled assuming _WIN32 was defined.
//
// The pointer returned by this function must still be freed using C.free when
// no longer needed. This will return an error if the given string contains
// invalid UTF-8, or if the C buffer can't be allocated. A validly encoded
// U+FFFD replacement character in the input is accepted and converted like any
// other rune. An empty input produces a buffer holding only the terminator.
func createOrtCharString(str string) (*C.char, error) {
	// Assumed common case: the utf16 buffer contains one uint16 per utf8 byte
	// plus one more for the required null terminator in the C buffer.
	dst := make([]uint16, 0, len(str)+1)
	// Convert UTF-8 to UTF-16 by decoding each subsequent rune from str and
	// appending it as UTF-16 to dst.
	for i := 0; i < len(str); {
		r, size := utf8.DecodeRuneInString(str[i:])
		// An invalid encoding is reported as (RuneError, 1). RuneError with a
		// larger size is a genuine U+FFFD in the input and must not be
		// rejected.
		if r == utf8.RuneError && size <= 1 {
			return nil, fmt.Errorf("Invalid UTF-8 rune found in \"%s\"", str)
		}
		i += size
		dst = utf16.AppendRune(dst, r)
	}
	// Make sure dst contains the null terminator. Additionally this guarantees
	// dst is non-empty, so taking &dst[0] below is always valid.
	dst = append(dst, 0)
	// Finally, copy dst into a C-allocated array (2 bytes per UTF-16 code
	// unit) so the caller can release it with C.free, as with C.CString.
	toReturn := C.calloc(C.size_t(len(dst)), 2)
	if toReturn == nil {
		return nil, fmt.Errorf("Error allocating buffer for the utf16 string")
	}
	C.memcpy(toReturn, unsafe.Pointer(&(dst[0])), C.size_t(len(dst))*2)
	return (*C.char)(toReturn), nil
}