mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-11-03 11:30:51 +08:00
Convert paths to UTF16 on Windows, enable path-based APIs - This change adds a function to convert UTF8 to UTF16 strings on Windows, enabling direct usage of the onnxruntime CreateSession functions in lieu of always buffering files and using CreateSessionFromArray. - Adding functionality for converting paths to UTF16 also enables the training API on Windows. - This is an in-progress commit that still may require some touchups, as well as proper test cases for the now-underutilized *WithONNXData functions. - Added a simple .onnx file to test_data with a name containing non-ascii characters. - Used the new file to test that the non-ASCII paths work correctly on Windows and Linux, in both the current and "legacy" session API. - Removed the old "example_network.onnx" and associated tests. This was an overengineered idea from when I first started the library.
25 lines
708 B
Python
25 lines
708 B
Python
# This script generates an .onnx file whose name contains several non-ASCII
# characters. The model takes a 1x2 int32 tensor and produces a 1-element
# int32 output containing the sum of the 2 inputs.
|
|
import torch
|
|
|
|
class AddModel(torch.nn.Module):
    """Minimal module that adds up the values along dimension 1 of its input."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Return ``inputs`` summed over dim 1, as an int32 tensor.

        torch.sum promotes integer inputs to int64 by default, so the result
        is explicitly cast back to int32.
        """
        row_sums = torch.sum(inputs, dim=1)
        return row_sums.to(torch.int32)
|
|
|
|
def main():
    """Export an AddModel instance to an ONNX file with a non-ASCII name.

    The file name is relative, so the file is written to the current working
    directory. The graph's input is named "in" and its output "out".
    """
    net = AddModel()
    # Switch to inference mode before the model is traced for export.
    net.eval()
    # Example input used to trace the graph: a (1, 2) int32 tensor of ones.
    sample = torch.ones((1, 2), dtype=torch.int32)
    file_name = "example ż 大 김.onnx"
    torch.onnx.export(
        net,
        (sample,),
        file_name,
        input_names=["in"],
        output_names=["out"],
    )
    print(f"{file_name} saved OK.")
|
|
|
|
# Only generate the file when run as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|