mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-10-04 14:53:07 +08:00
Initial Sequence support
- This adds support for ONNX sequences, along with basic unit tests for the sequence types. - This also introduces a new test network generated via sklearn in an included script. - The test using sklearn_randomforest.onnx has not been written yet, but I wanted to commit after doing all the work so far. The existing tests all pass.
This commit is contained in:
45
test_data/generate_sklearn_network.py
Normal file
45
test_data/generate_sklearn_network.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# This script is a modified version of the example from
|
||||
# https://pypi.org/project/skl2onnx/, which we use to produce
|
||||
# sklearn_randomforest.onnx. sklearn makes heavy use of onnxruntime maps and
|
||||
# sequences in its networks, so this is used for testing those data types.
|
||||
|
||||
import numpy as np
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
iris = load_iris()
|
||||
inputs, outputs = iris.data, iris.target
|
||||
inputs = inputs.astype(np.float32)
|
||||
inputs_train, inputs_test, outputs_train, outputs_test = train_test_split(inputs, outputs)
|
||||
classifier = RandomForestClassifier()
|
||||
classifier.fit(inputs_train, outputs_train)
|
||||
|
||||
# Convert into ONNX format.
|
||||
from skl2onnx import to_onnx
|
||||
output_filename = "sklearn_randomforest.onnx"
|
||||
onnx_content = to_onnx(classifier, inputs[:1])
|
||||
with open(output_filename, "wb") as f:
|
||||
f.write(onnx_content.SerializeToString())
|
||||
|
||||
# Compute the prediction with onnxruntime.
|
||||
import onnxruntime as ort
|
||||
|
||||
def float_formatter(f):
|
||||
return f"{float(f):.06f}"
|
||||
|
||||
np.set_printoptions(formatter = {'float_kind': float_formatter})
|
||||
session = ort.InferenceSession(output_filename)
|
||||
print(f"Input names: {[n.name for n in session.get_inputs()]!s}")
|
||||
print(f"Output names: {[o.name for o in session.get_outputs()]!s}")
|
||||
example_inputs = inputs_test.astype(np.float32)[:6]
|
||||
print(f"Inputs shape = {example_inputs.shape!s}")
|
||||
onnx_predictions = session.run(["output_label", "output_probability"],
|
||||
{"X": example_inputs})
|
||||
labels = onnx_predictions[0]
|
||||
probabilities = onnx_predictions[1]
|
||||
|
||||
print(f"Inputs to network: {example_inputs.astype(np.float32)}")
|
||||
print(f"ONNX predicted labels: {labels!s}")
|
||||
print(f"ONNX predicted probabilities: {probabilities!s}")
|
||||
|
Reference in New Issue
Block a user