mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-12-24 13:38:00 +08:00
- This change introduces the CustomDataTensor type, which implements the ArbitraryTensor interface but, unlike the typed Tensor[T], is backed by an arbitrary slice of user-provided bytes. The user is responsible for providing the type of data the tensor is supposed to contain, as well as responsible for ensuring the data slice is in the correct format for the specified shape. - Added some test cases for the new CustomDataTensor type, which most notably will enable users to use float16 tensors (provided they converted the float16 array into bytes on their own).
36 lines
991 B
Python
36 lines
991 B
Python
# This script creates example_float16.onnx to use in testing.
|
|
# It takes one input:
|
|
# - "InputA": A 1x2x2x2 16-bit float16 tensor
|
|
# It produces one output:
|
|
# - "OutputA": A 1x2x2x2 16-bit bfloat16 tensor
|
|
#
|
|
# The "network" just multiplies each element in the input by 3.0
|
|
import torch
|
|
|
|
class Float16Model(torch.nn.Module):
    """Trivial "network" that triples its input and emits bfloat16.

    The module has no parameters or submodules; it exists only so that a
    small ONNX file with a bfloat16 output can be exported for testing.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input_a):
        """Multiply every element of ``input_a`` by 3.0, then cast the
        product to ``torch.bfloat16`` and return it.
        """
        tripled = input_a * 3.0
        return tripled.type(torch.bfloat16)
|
|
|
|
def fake_inputs():
    """Return a random float16 tensor of shape (1, 2, 2, 2)."""
    shape = (1, 2, 2, 2)
    return torch.rand(shape, dtype=torch.float16)
|
|
|
|
def main():
    """Export Float16Model to "example_float16.onnx" and report success.

    Builds the model, switches it to eval mode, runs it once on a sample
    input from fake_inputs(), and then exports it with torch.onnx.export
    using "InputA" and "OutputA" as the graph's input and output names.
    A confirmation line is printed once the export call returns.
    """
    model = Float16Model()
    model.eval()
    # Use the shared helper instead of repeating the shape and dtype here.
    input_a = fake_inputs()
    # One eager forward pass, so a broken model fails before the export
    # starts; the result itself is not needed.
    model(input_a)

    out_name = "example_float16.onnx"
    # The trailing comma matters: "(input_a)" is just a parenthesized
    # tensor, while the export arguments are meant to be a 1-tuple.
    torch.onnx.export(model, (input_a,), out_name, input_names=["InputA"],
                      output_names=["OutputA"])
    print(f"{out_name} saved OK.")
|
|
|
|
# Generate the .onnx file only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|