Files
onnxruntime_go/test_data/generate_float16_network.py
yalue 42f85c9f58 Add CustomDataTensor type
- This change introduces the CustomDataTensor type, which implements
   the ArbitraryTensor interface but, unlike the typed Tensor[T], is
   backed by an arbitrary slice of user-provided bytes. The user is
   responsible for providing the type of data the tensor is supposed to
   contain, as well as responsible for ensuring the data slice is in the
   correct format for the specified shape.

 - Added some test cases for the new CustomDataTensor type, which most
   notably will enable users to use float16 tensors (provided they
   converted the float16 array into bytes on their own).
2023-09-27 19:49:09 -04:00

36 lines
991 B
Python

# This script creates example_float16.onnx to use in testing.
# It takes one input:
# - "InputA": A 1x2x2x2 16-bit float16 tensor
# It produces one output:
# - "OutputA": A 1x2x2x2 16-bit bfloat16 tensor
#
# The "network" just multiplies each element in the input by 3.0
import torch
class Float16Model(torch.nn.Module):
    """Toy network: scale the input by 3.0 and cast the result to bfloat16."""

    def __init__(self):
        super().__init__()

    def forward(self, input_a):
        """Return ``input_a * 3.0`` converted to a bfloat16 tensor."""
        scaled = input_a * 3.0
        return scaled.to(torch.bfloat16)
def fake_inputs():
    """Return a random float16 tensor shaped like the model's expected input."""
    input_shape = (1, 2, 2, 2)
    return torch.rand(input_shape, dtype=torch.float16)
def main():
    """Build Float16Model and export it as example_float16.onnx.

    Runs one forward pass first as a sanity check that the model accepts
    a float16 input before handing it to the ONNX exporter.
    """
    model = Float16Model()
    model.eval()
    # Reuse the shared helper instead of duplicating the tensor construction.
    input_a = fake_inputs()
    # Sanity-check forward pass; the result itself is not needed.
    model(input_a)
    out_name = "example_float16.onnx"
    # torch.onnx.export expects the example args as a tuple; the original
    # "(input_a)" was just a parenthesized expression, not a 1-tuple.
    torch.onnx.export(model, (input_a,), out_name, input_names=["InputA"],
        output_names=["OutputA"])
    print(f"{out_name} saved OK.")

if __name__ == "__main__":
    main()