mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-09-26 19:31:13 +08:00
Remove more training API stuff
- Removed testing data that was used only with onnxruntime_training_test.go, which was already removed. - Renamed legacy_types.go to legacy_code.go, since it's now a parking ground for legacy code that's not just "types". - Fixed some references to the training API in CONTRIBUTING.md.
This commit is contained in:
@@ -41,8 +41,7 @@ Tests
|
||||
-----
|
||||
|
||||
- All new features and bugfixes must include a basic unit test (in
|
||||
`onnxruntime_test.go` or `onnxruntime_training_test.go`) to serve as a
|
||||
sanity check.
|
||||
`onnxruntime_test.go`) to serve as a sanity check.
|
||||
|
||||
- If a test is for a platform-dependent or execution-provider-dependent
|
||||
feature, the test must be skipped if run on an unsupported system.
|
||||
@@ -62,8 +61,8 @@ Adding New Files
|
||||
|
||||
- Apart from testing data, try not to add new source files.
|
||||
|
||||
- Do not add third-party code or headers. The only exceptions for now are
|
||||
`onnxruntime_c_api.h` and `onnxruntime_training_c_api.h`.
|
||||
- Do not add third-party code or headers. The only exception for now is
|
||||
`onnxruntime_c_api.h`.
|
||||
|
||||
- No C++ at all. Developing Go-to-C wrappers is annoying enough as it is.
|
||||
|
||||
|
@@ -1,8 +1,7 @@
|
||||
package onnxruntime_go
|
||||
|
||||
// This file contains Session types that we maintain for compatibility
|
||||
// purposes; the main onnxruntime_go.go file is dedicated to AdvancedSession
|
||||
// now.
|
||||
// This file contains code and types that we maintain for compatibility
|
||||
// purposes, but is not expected to be regularly maintained or udpated.
|
||||
|
||||
import (
|
||||
"fmt"
|
@@ -882,8 +882,7 @@ func (t *CustomDataTensor) GetData() []byte {
|
||||
|
||||
// Scalar is like a tensor but the underlying go slice is of length 1 and it
|
||||
// has no dimension. It was introduced for use with the training API, but
|
||||
// remains supported since it conceivable will have use outside of the training
|
||||
// API.
|
||||
// remains supported since it may be useful apart from the training API.
|
||||
type Scalar[T TensorData] struct {
|
||||
data []T
|
||||
dataSize uintptr
|
||||
|
@@ -1,55 +0,0 @@
|
||||
import torch
|
||||
from torch.nn.functional import relu
|
||||
from pathlib import Path
|
||||
import onnx
|
||||
import onnxruntime.training.artifacts as artifacts
|
||||
|
||||
class SumAndDiffModel(torch.nn.Module):
|
||||
""" Just a standard, fairly minimal, pytorch model for generating the NN.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# We'll do four 1x4 convolutions to make the network more interesting.
|
||||
self.conv = torch.nn.Conv1d(1, 4, 4)
|
||||
# We'll follow the conv with a FC layer to produce the outputs. The
|
||||
# input to the FC layer are the 4 conv outputs concatenated with the
|
||||
# original input.
|
||||
self.fc = torch.nn.Linear(8, 2)
|
||||
|
||||
def forward(self, data):
|
||||
batch_size = data.shape[0]
|
||||
conv_out = relu(self.conv(data))
|
||||
conv_flattened = torch.flatten(conv_out, start_dim=1)
|
||||
data_flattened = torch.flatten(data, start_dim=1)
|
||||
combined = torch.cat((conv_flattened, data_flattened), dim=1)
|
||||
output = relu(self.fc(combined))
|
||||
output = output.view(batch_size, 1, 2)
|
||||
return output
|
||||
|
||||
def main():
|
||||
model = SumAndDiffModel()
|
||||
|
||||
# Export the model to ONNX.
|
||||
training_artifacts_path = Path(".", "training_test")
|
||||
training_artifacts_path.mkdir(exist_ok=True, parents=True)
|
||||
model_name = "training_network"
|
||||
torch.onnx.export(model, torch.zeros(10, 1, 4),
|
||||
Path(".", "training_test", f"{model_name}.onnx").__str__(),
|
||||
input_names=["input"], output_names=["output"])
|
||||
|
||||
# Load the onnx model and generate artifacts
|
||||
onnx_model = onnx.load(Path("training_test", "training_network.onnx"))
|
||||
requires_grad = ["conv.weight", "conv.bias", "fc.weight", "fc.bias"]
|
||||
model.train()
|
||||
|
||||
# Generate the training artifacts.
|
||||
artifacts.generate_artifacts(
|
||||
onnx_model,
|
||||
requires_grad=requires_grad,
|
||||
frozen_params=[],
|
||||
loss=artifacts.LossType.L1Loss,
|
||||
optimizer=artifacts.OptimType.AdamW,
|
||||
artifact_directory=Path(".", "training_test"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user