SourceXtractorPlusPlus  0.19
SourceXtractor++, the next generation SExtractor
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
OnnxModel.h
Go to the documentation of this file.
1 /*
2  * OnnxModel.h
3  *
4  * Created on: Feb 16, 2021
5  * Author: mschefer
6  */
7 
8 #ifndef _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_
9 #define _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_
10 
11 #include <memory>
12 #include <vector>
13 #include <list>
14 #include <iostream>
15 #include <numeric>
16 
17 #include <onnxruntime_cxx_api.h>
18 
19 namespace SourceXtractor {
20 
/**
 * Wrapper around an ONNX Runtime session (m_session) for a model file.
 *
 * The constructor is defined in OnnxModel.cpp (not shown). run() and runMultiInput()
 * wrap caller-owned buffers as CPU tensors. They run the session and have the single
 * model output written into a caller-owned buffer.
 *
 * NOTE(review): this rendering of the header is missing several source lines:
 *  - the input_shape/output_shape declarations at listing lines 32, 36, 71 and 87;
 *  - the accessor signatures at listing lines 114-138;
 *  - most private member declarations at listing lines 151-160.
 * The comments below restore only what the Doxygen index on this page states and
 * do not guess at the rest.
 *
 * NOTE(review): std::map, std::string, std::multiplies and Elements::Exception are
 * used here. The visible include list has no <map>, <string>, <functional> or
 * Elements exception header, so this presumably relies on transitive includes — confirm.
 */
21 class OnnxModel {
22 public:
23 
  /**
   * Construct the wrapper for the ONNX model at @p model_path.
   * Implemented in OnnxModel.cpp (not shown in this listing).
   */
24  explicit OnnxModel(const std::string& model_path);
25 
  /**
   * Run the model with a single input tensor and a single output tensor.
   *
   * The first dimension of both shapes is overwritten with 1 (presumably the batch
   * dimension) before the element counts are computed. The tensors are then built
   * directly on the caller's buffers, so the result is written in place into
   * @p output_data.
   *
   * @param input_data  Buffer for the first model input. Must hold at least as many
   *                    elements as the (batch-1) input shape requires.
   * @param output_data Buffer for the model output. Must hold at least as many
   *                    elements as the (batch-1) output shape requires.
   * @throws Elements::Exception if either buffer is too small.
   */
26  template<typename T, typename U>
27  void run(std::vector<T>& input_data, std::vector<U>& output_data) const {
28  Ort::RunOptions run_options;
29  auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
30 
31  // Build the tensor shapes (first dimension forced to 1) and the element counts they imply
  // [listing line 32, missing here, declares the mutable input_shape; presumably a copy of m_input_shapes[0] — confirm]
33  input_shape[0] = 1;
  // NOTE(review): the init value 1u makes std::accumulate accumulate in unsigned int.
  // Each product is therefore narrowed back to 32 bits; size_t{1} would avoid truncation.
  // If a remaining dimension were negative (a dynamic dim), it would convert to a huge
  // size_t here, and the size check below would then fail.
34  size_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1u, std::multiplies<size_t>());
35 
  // [listing line 36, missing here, declares the mutable output_shape; presumably a copy of m_output_shape — confirm]
37  output_shape[0] = 1;
  // NOTE(review): same 1u (unsigned int) accumulator issue as input_size above.
38  size_t output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1u, std::multiplies<size_t>());
39 
40  // Check that the input and output buffers are large enough for their shapes
41  if (input_data.size() < input_size || output_data.size() < output_size) {
42  throw Elements::Exception() << "OnnxModel: Insufficient buffer size ";
43  }
44 
45  // Wrap the caller's buffers as tensors (no copy). The full buffer length, not
  // input_size/output_size, is passed as the element count.
46  auto input_tensor = Ort::Value::CreateTensor<T>(
47  mem_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size());
48  auto output_tensor = Ort::Value::CreateTensor<U>(
49  mem_info, output_data.data(), output_data.size(), output_shape.data(), output_shape.size());
50 
51  // Run the model: bind the first model input and the only output by name
52  const char *input_name = m_input_names[0].c_str();
53  const char *output_name = m_output_name.c_str();
54 
55  m_session->Run(run_options, &input_name, &input_tensor, 1, &output_name, &output_tensor, 1);
56  }
57 
  /**
   * Run the model with one input tensor per model input and a single output tensor.
   *
   * Inputs are bound in the order of m_input_names. Each name is looked up in
   * @p input_data. The output is written in place into @p output_data.
   *
   * @param input_data  Map from input tensor name to that input's buffer.
   * @param output_data Buffer for the model output.
   * @throws Elements::Exception if any buffer is smaller than its (batch-1) shape requires.
   *
   * NOTE(review): input_data[...] uses std::map::operator[]. A missing input name
   * therefore inserts an empty vector into the caller's map, and the failure is reported
   * as "Insufficient buffer size". find()/at() would avoid mutating the map and would
   * give a clearer error.
   */
58  template<typename T, typename U>
59  void runMultiInput(std::map<std::string, std::vector<T>>& input_data, std::vector<U>& output_data) const {
60  Ort::RunOptions run_options;
61  auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
62 
  // Parallel arrays handed to Session::Run: names[i] labels tensors[i]
63  std::vector<const char *> input_names;
64  std::vector<Ort::Value> input_tensors;
65 
  // NOTE(review): size_t -> int narrowing; harmless for realistic input counts
66  int inputs_nb = m_input_names.size();
67  for (int i=0; i<inputs_nb; i++) {
68  input_names.emplace_back(m_input_names[i].c_str());
69 
70  // Build this input's shape (first dimension forced to 1) and its element count
  // [listing line 71, missing here, declares the mutable input_shape for input i; presumably a copy of m_input_shapes[i] — confirm]
72  input_shape[0] = 1;
  // NOTE(review): 1u init value -> unsigned int accumulator (see run())
73  size_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1u, std::multiplies<size_t>());
74 
75  // Check that this input's buffer is large enough
76  if (input_data[m_input_names[i]].size() < input_size) {
77  throw Elements::Exception() << "OnnxModel: Insufficient buffer size ";
78  }
79 
  // Wrap the caller's buffer for this input (no copy)
80  input_tensors.emplace_back(Ort::Value::CreateTensor<T>(
81  mem_info, input_data[m_input_names[i]].data(), input_data[m_input_names[i]].size(),
82  input_shape.data(), input_shape.size()));
83  }
84 
85  // Output name and shape
86  const char *output_name = m_output_name.c_str();
  // [listing line 87, missing here, declares the mutable output_shape; presumably a copy of m_output_shape — confirm]
88  output_shape[0] = 1;
89 
90  // Element count the output buffer must provide
  // NOTE(review): 1u init value -> unsigned int accumulator (see run())
91  size_t output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1u, std::multiplies<size_t>());
92 
93  // Check that the output buffer is large enough
94  if (output_data.size() < output_size) {
95  throw Elements::Exception() << "OnnxModel: Insufficient buffer size ";
96  }
97 
  // Wrap the caller's output buffer (no copy); ONNX Runtime writes the result into it
98  auto output_tensor = Ort::Value::CreateTensor<U>(
99  mem_info, output_data.data(), output_data.size(), output_shape.data(), output_shape.size());
100 
101  // Run the model with all inputs and the single output
  // NOTE(review): &input_names[0] is undefined if the model has no inputs. Presumably
  // the constructor guarantees at least one input — confirm.
102  m_session->Run(run_options, &input_names[0], &input_tensors[0], inputs_nb, &output_name, &output_tensor, 1);
103  }
104 
105 
  /// Element type of the first input tensor.
106  ONNXTensorElementDataType getInputType() const {
107  return m_input_types[0];
108  }
109 
  /// Element type of the model's single output tensor.
110  ONNXTensorElementDataType getOutputType() const {
111  return m_output_type;
112  }
113 
  // [listing line 114, missing here; per the index: const std::vector<std::int64_t>& getInputShape() const {]
  // Returns the stored shape of the first input (first dimension not overridden).
115  return m_input_shapes[0];
116  }
117 
  // [listing line 118, missing here; per the index: const std::vector<std::int64_t>& getOutputShape() const {]
  // Returns the stored output shape (first dimension not overridden).
119  return m_output_shape;
120  }
121 
  // [listing line 122, missing here; per the index: std::string getDomain() const {]
  // Returns the model's domain name.
123  return m_domain_name;
124  }
125 
  // [listing line 126, missing here; per the index: std::string getGraphName() const {]
  // Returns the model's graph name.
127  return m_graph_name;
128  }
129 
  // [listing line 130, missing here; per the index: std::string getInputName() const {]
  // Returns the name of the first input tensor.
131  return m_input_names[0];
132  }
133 
  // [listing line 134, missing here; per the index: std::string getOutputName() const {]
  // Returns the name of the single output tensor.
135  return m_output_name;
136  }
137 
  // [listing line 138, missing here; per the index: std::string getModelPath() const {]
  // Returns the path the model was constructed from.
139  return m_model_path;
140  }
141 
  /// Number of model inputs (one per entry of m_input_names).
142  size_t getInputNb() const {
143  return m_input_names.size();
144  }
145 
  /// Number of model outputs. Always 1: this class only binds the single m_output_name.
146  size_t getOutputNb() const {
147  return 1U;
148  }
149 
150 private:
  // [listing lines 151-155, missing here; per the index they declare m_domain_name,
  //  m_graph_name, m_input_names, m_output_name and m_input_types]
  // Element type of the single output tensor.
156  ONNXTensorElementDataType m_output_type;
  // [listing lines 157-160, missing here; per the index they declare m_input_shapes,
  //  m_output_shape, m_model_path and m_session (std::unique_ptr<Ort::Session>)]
161 };
162 
163 }
164 
165 
166 #endif /* _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_ */
std::string m_output_name
Output tensor name.
Definition: OnnxModel.h:154
std::string getOutputName() const
Definition: OnnxModel.h:134
std::string getInputName() const
Definition: OnnxModel.h:130
const std::vector< std::int64_t > & getOutputShape() const
Definition: OnnxModel.h:118
void run(std::vector< T > &input_data, std::vector< U > &output_data) const
Definition: OnnxModel.h:27
T end(T...args)
const std::vector< std::int64_t > & getInputShape() const
Definition: OnnxModel.h:114
STL class.
std::vector< std::vector< std::int64_t > > m_input_shapes
Input tensor shapes.
Definition: OnnxModel.h:157
STL class.
ONNXTensorElementDataType getInputType() const
Definition: OnnxModel.h:106
std::string getDomain() const
Definition: OnnxModel.h:122
std::string getGraphName() const
Definition: OnnxModel.h:126
std::string m_graph_name
Graph name.
Definition: OnnxModel.h:152
T data(T...args)
void runMultiInput(std::map< std::string, std::vector< T >> &input_data, std::vector< U > &output_data) const
Definition: OnnxModel.h:59
ONNXTensorElementDataType getOutputType() const
Definition: OnnxModel.h:110
std::string getModelPath() const
Definition: OnnxModel.h:138
size_t getOutputNb() const
Definition: OnnxModel.h:146
std::unique_ptr< Ort::Session > m_session
Session, one per model. ONNX Runtime allows concurrent Run() calls on a single session, so it is expected to be thread-safe.
Definition: OnnxModel.h:160
std::vector< std::int64_t > m_output_shape
Output tensor shape.
Definition: OnnxModel.h:158
std::string m_model_path
Path to the ONNX model.
Definition: OnnxModel.h:159
T size(T...args)
ONNXTensorElementDataType m_output_type
Output type.
Definition: OnnxModel.h:156
STL class.
T begin(T...args)
T c_str(T...args)
OnnxModel(const std::string &model_path)
Definition: OnnxModel.cpp:17
std::string m_domain_name
Domain name.
Definition: OnnxModel.h:151
std::vector< ONNXTensorElementDataType > m_input_types
Input tensor types.
Definition: OnnxModel.h:155
T accumulate(T...args)
size_t getInputNb() const
Definition: OnnxModel.h:142
T emplace_back(T...args)
std::vector< std::string > m_input_names
Input tensor names.
Definition: OnnxModel.h:153