SourceXtractorPlusPlus  0.19
SourceXtractor++, the next generation SExtractor
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
OnnxModel.cpp
Go to the documentation of this file.
1 /*
2  * OnnxModel.cpp
3  *
4  * Created on: Feb 16, 2021
5  * Author: mschefer
6  */
7 
11 
14 
15 namespace SourceXtractor {
16 
17 OnnxModel::OnnxModel(const std::string& model_path) {
18  m_model_path = model_path;
19 
21  auto allocator = Ort::AllocatorWithDefaultOptions();
22 
23  onnx_logger.info() << "Loading ONNX model " << model_path;
24  m_session = Euclid::make_unique<Ort::Session>(ORT_ENV, model_path.c_str(), Ort::SessionOptions{nullptr});
25 
26  if (m_session->GetOutputCount() != 1) {
27  throw Elements::Exception() << "Only ONNX models with a single output tensor are supported";
28  }
29 
30  for (size_t i=0; i<m_session->GetInputCount(); i++) {
31  auto input_type = m_session->GetInputTypeInfo(i);
32 
33  m_input_names.emplace_back(m_session->GetInputName(i, allocator));
34  m_input_shapes.emplace_back(input_type.GetTensorTypeAndShapeInfo().GetShape());
35  m_input_types.emplace_back(input_type.GetTensorTypeAndShapeInfo().GetElementType());
36  }
37 
38  m_output_name = m_session->GetOutputName(0, allocator);
39  m_domain_name = m_session->GetModelMetadata().GetDomain(allocator);
40  m_graph_name = m_session->GetModelMetadata().GetGraphName(allocator);
41 
42  auto output_type = m_session->GetOutputTypeInfo(0);
43 
44  m_output_shape = output_type.GetTensorTypeAndShapeInfo().GetShape();
45  m_output_type = output_type.GetTensorTypeAndShapeInfo().GetElementType();
46 
47 // onnx_logger.info() << "ONNX model with input of " << formatShape(m_input_shapes[0]);
48 // onnx_logger.info() << "ONNX model with output of " << formatShape(m_output_shape);
49 }
50 
51 }
std::string m_output_name
Output tensor name.
Definition: OnnxModel.h:154
Elements::Logging onnx_logger
Logger for the ONNX plugin.
Definition: OnnxPlugin.cpp:26
void info(const std::string &logMessage)
STL class.
std::vector< std::vector< std::int64_t > > m_input_shapes
Input tensor shapes, one per model input.
Definition: OnnxModel.h:157
STL class.
std::string m_graph_name
Graph name read from the model metadata.
Definition: OnnxModel.h:152
std::unique_ptr< Ort::Session > m_session
Session, one per model. In theory, it is thread-safe.
Definition: OnnxModel.h:160
std::vector< std::int64_t > m_output_shape
Output tensor shape.
Definition: OnnxModel.h:158
std::string m_model_path
Path to the ONNX model.
Definition: OnnxModel.h:159
ONNXTensorElementDataType m_output_type
Output type.
Definition: OnnxModel.h:156
T c_str(T...args)
Ort::Env ORT_ENV
Definition: OnnxCommon.cpp:25
OnnxModel(const std::string &model_path)
Definition: OnnxModel.cpp:17
std::string m_domain_name
Domain name read from the model metadata.
Definition: OnnxModel.h:151
std::vector< ONNXTensorElementDataType > m_input_types
Input tensor element types, one per model input.
Definition: OnnxModel.h:155
static Logging getLogger(const std::string &name="")
T emplace_back(T...args)
std::vector< std::string > m_input_names
Input tensor names, one per model input.
Definition: OnnxModel.h:153