mirror of
https://github.com/yalue/onnxruntime_go.git
synced 2025-10-07 08:01:14 +08:00

- Sets up the onnxruntime environment, but doesn't load or run networks yet. - The things builds and run on Windows. - Still working on getting the Linux (arm64 for now) test to work.
70 lines
3.0 KiB
C++
70 lines
3.0 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
// Summary: The experimental Ort C++ API is a wrapper around the Ort C++ API.
|
|
//
|
|
// This C++ API further simplifies usage and provides support for modern C++ syntax/features
|
|
// at the cost of some overhead (i.e. using std::string over char *).
|
|
//
|
|
// Where applicable, default memory allocator options are used unless explicitly set.
|
|
//
|
|
// Experimental components are designed as drop-in replacements of the regular API, requiring
|
|
// minimal code modifications to use.
|
|
//
|
|
// Example: Ort::Session -> Ort::Experimental::Session
|
|
//
|
|
// NOTE: Experimental API components are subject to change based on feedback and provide no
|
|
// guarantee of backwards compatibility in future releases.
|
|
|
|
#pragma once
|
|
#include "onnxruntime_cxx_api.h"
|
|
|
|
namespace Ort {
|
|
namespace Experimental {
|
|
|
|
struct Session : Ort::Session {
|
|
Session(Env& env, std::basic_string<ORTCHAR_T>& model_path, SessionOptions& options)
|
|
: Ort::Session(env, model_path.data(), options){};
|
|
Session(Env& env, void* model_data, size_t model_data_length, SessionOptions& options)
|
|
: Ort::Session(env, model_data, model_data_length, options){};
|
|
|
|
// overloaded Run() with sensible defaults
|
|
std::vector<Ort::Value> Run(const std::vector<std::string>& input_names,
|
|
const std::vector<Ort::Value>& input_values,
|
|
const std::vector<std::string>& output_names,
|
|
const RunOptions& run_options = RunOptions());
|
|
void Run(const std::vector<std::string>& input_names,
|
|
const std::vector<Ort::Value>& input_values,
|
|
const std::vector<std::string>& output_names,
|
|
std::vector<Ort::Value>& output_values,
|
|
const RunOptions& run_options = RunOptions());
|
|
|
|
// convenience methods that simplify common lower-level API calls
|
|
std::vector<std::string> GetInputNames() const;
|
|
std::vector<std::string> GetOutputNames() const;
|
|
std::vector<std::string> GetOverridableInitializerNames() const;
|
|
|
|
// NOTE: shape dimensions may have a negative value to indicate a symbolic/unknown dimension.
|
|
std::vector<std::vector<int64_t> > GetInputShapes() const;
|
|
std::vector<std::vector<int64_t> > GetOutputShapes() const;
|
|
std::vector<std::vector<int64_t> > GetOverridableInitializerShapes() const;
|
|
};
|
|
|
|
struct Value : Ort::Value {
|
|
Value(OrtValue* p)
|
|
: Ort::Value(p){};
|
|
|
|
template <typename T>
|
|
static Ort::Value CreateTensor(T* p_data, size_t p_data_element_count, const std::vector<int64_t>& shape);
|
|
static Ort::Value CreateTensor(void* p_data, size_t p_data_byte_count, const std::vector<int64_t>& shape, ONNXTensorElementDataType type);
|
|
|
|
template <typename T>
|
|
static Ort::Value CreateTensor(const std::vector<int64_t>& shape);
|
|
static Ort::Value CreateTensor(const std::vector<int64_t>& shape, ONNXTensorElementDataType type);
|
|
};
|
|
|
|
}
|
|
}
|
|
|
|
#include "experimental_onnxruntime_cxx_inline.h"
|