1 #include <torch/csrc/jit/serialization/export.h>
2
3 #include <ATen/ATen.h>
4 #include <ATen/Utils.h>
5 #include <ATen/core/functional.h>
6 #include <c10/macros/Macros.h>
7 #include <c10/util/Exception.h>
8 #include <c10/util/accumulate.h>
9 #include <c10/util/irange.h>
10 #include <torch/csrc/autograd/symbolic.h>
11 #include <torch/csrc/jit/jit_log.h>
12 #include <torch/csrc/jit/passes/dead_code_elimination.h>
13 #include <torch/csrc/jit/passes/inliner.h>
14 #include <torch/csrc/jit/runtime/instruction.h>
15 #include <torch/csrc/jit/serialization/import_export_constants.h>
16 #include <torch/csrc/jit/serialization/import_export_functions.h>
17 #include <torch/csrc/jit/serialization/import_export_helpers.h>
18 #include <torch/csrc/jit/serialization/onnx.h>
19 #include <torch/csrc/onnx/back_compat.h>
20 #include <torch/csrc/onnx/onnx.h>
21 #include <torch/version.h>
22 #include <optional>
23
24 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wnewline-eof")
25 #include <onnx/checker.h>
26 C10_DIAGNOSTIC_POP()
27 #include <onnx/onnx_pb.h>
28 #include <onnx/proto_utils.h>
29 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override")
30 #include <onnx/shape_inference/implementation.h>
31 C10_DIAGNOSTIC_POP()
32
33 #include <memory>
34 #include <regex>
35 #include <set>
36 #include <sstream>
37 #include <string>
38 #include <utility>
39 #include <vector>
40
41 namespace torch::jit {
42
get_little_endian_data(const at::Tensor & t)43 static std::string get_little_endian_data(const at::Tensor& t) {
44 #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
45 return std::string(
46 static_cast<char*>(t.data_ptr()), t.element_size() * t.numel());
47 #elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
48 const size_t element_size = t.element_size();
49 const size_t num_elements = t.numel();
50
51 std::vector<char> data_copy{
52 static_cast<char*>(t.data_ptr()),
53 static_cast<char*>(t.data_ptr()) + element_size * num_elements};
54
55 for (size_t i = 0; i < num_elements; ++i) {
56 char* start_byte = data_copy.data() + i * element_size;
57 char* end_byte = start_byte + element_size - 1;
58 /* keep swapping */
59 for (size_t count = 0; count < element_size / 2; ++count) {
60 std::swap(*start_byte, *end_byte);
61 ++start_byte;
62 --end_byte;
63 }
64 }
65
66 return std::string(data_copy.data(), element_size * num_elements);
67 #else
68 #error Unexpected or undefined __BYTE_ORDER__
69 #endif
70 }
71
writeArchiveAndTensors(const std::string & archive_name,const char * data,size_t size,const std::vector<at::Tensor> & tensors,caffe2::serialize::PyTorchStreamWriter & out)72 void writeArchiveAndTensors(
73 const std::string& archive_name,
74 const char* data,
75 size_t size,
76 const std::vector<at::Tensor>& tensors,
77 caffe2::serialize::PyTorchStreamWriter& out) {
78 std::string prefix = archive_name + "/";
79 size_t i = 0;
80 for (const auto& td : tensors) {
81 WriteableTensorData writable_td = getWriteableTensorData(td);
82 std::string fname = prefix + std::to_string(i++);
83 out.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
84 }
85 std::string fname = archive_name + ".pkl";
86 out.writeRecord(fname, data, size);
87 }
88
89 namespace {
90 namespace onnx_torch = ::torch::onnx;
91 namespace onnx = ::ONNX_NAMESPACE;
92
// Sentinel marking opset versions that cannot be targeted by this exporter.
const static int kInvalidOpsetVersion = -1;
// Newest ONNX opset this exporter supports (upper bound of the table below).
const static int kMainOpsetVersion = 20;
// Based on OP_SET_ID_VERSION_MAP in
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
// Index = opset version, value = ONNX IR version required to express it.
constexpr static std::array<int64_t, kMainOpsetVersion + 1>
    kOpsetVersionToIRVersion = {
        kInvalidOpsetVersion, // there is no opset 0
        3, // opset 1
        kInvalidOpsetVersion, // opset 2 (unsupported, see ctor TORCH_CHECK)
        kInvalidOpsetVersion, // opset 3 (unsupported)
        kInvalidOpsetVersion, // opset 4 (unsupported)
        3, // opset 5
        3, // opset 6
        3, // opset 7
        3, // opset 8
        4, // opset 9
        5, // opset 10
        6, // opset 11
        7, // opset 12
        7, // opset 13
        7, // opset 14
        8, // opset 15
        8, // opset 16
        8, // opset 17
        8, // opset 18
        9, // opset 19
        9, // opset 20
};
121
getNodeStackTraceString(const Node * n)122 std::string getNodeStackTraceString(const Node* n) {
123 return n->sourceRange().str();
124 }
125
// Recursively checks that every node in `b` (including nodes in nested
// sub-blocks) can be exported under `operator_export_type`; throws
// std::runtime_error with the offending graph attached otherwise.
void validateBlock(
    Block* b,
    onnx_torch::OperatorExportTypes operator_export_type) {
  for (auto node : b->nodes()) {
    // Validate nested blocks (e.g. If/Loop bodies) first.
    for (Block* sub_block : node->blocks()) {
      validateBlock(sub_block, operator_export_type);
    }
    // Macro'ed so we get a marginally better line number on failed export
#define FAIL_EXPORT(name)                          \
  throw std::runtime_error(                        \
      std::string("ONNX export failed: ") + name + \
      "\n\nGraph we tried to export:\n" + b->owningGraph()->toString());
    // Special error messages for certain types of operators
    if (node->kind() == prim::PythonOp) {
      // Python operators have no ONNX equivalent; they may only pass
      // through when ONNX_FALLTHROUGH was requested.
      if (operator_export_type !=
          onnx_torch::OperatorExportTypes::ONNX_FALLTHROUGH) {
        auto py_node = static_cast<PythonOp*>(node);
        FAIL_EXPORT(
            "Couldn't export Python operator " + py_node->name() +
            "\n\nDefined at:\n" + getNodeStackTraceString(node))
      }
    } else {
      // A lone PackPadded/PadPacked cannot be exported; earlier passes are
      // expected to eliminate them in matched pairs.
      if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) {
        if (operator_export_type !=
            onnx_torch::OperatorExportTypes::ONNX_FALLTHROUGH) {
          FAIL_EXPORT(
              "Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" +
              getNodeStackTraceString(node));
        }
      }
      // aten ops are only allowed under export modes that can represent
      // them (ATEN, ATEN_FALLBACK, or FALLTHROUGH).
      bool is_aten_enabled = operator_export_type ==
              onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
          operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN ||
          operator_export_type ==
              onnx_torch::OperatorExportTypes::ONNX_FALLTHROUGH;
      if (node->kind().is_aten() && !is_aten_enabled && !node->mustBeNone()) {
        FAIL_EXPORT(
            "Couldn't export operator " + node->kind().toDisplayString() +
            "\n\nDefined at:\n" + getNodeStackTraceString(node));
      }
    }
#undef FAIL_EXPORT
  }
}
170
validateGraph(const std::shared_ptr<Graph> & graph,onnx_torch::OperatorExportTypes operator_export_type)171 void validateGraph(
172 const std::shared_ptr<Graph>& graph,
173 onnx_torch::OperatorExportTypes operator_export_type) {
174 validateBlock(graph->block(), operator_export_type);
175 }
176
// Returns the directory component of `rootPath` with separators normalized
// to '/'. If the path has no directory component (a bare file name),
// returns "." (the current working directory).
std::string GetFileRootPath(const std::string& rootPath) {
  std::string normalized = rootPath;
  // First, making slash consistent.
  std::replace(normalized.begin(), normalized.end(), '\\', '/');
  // Second, remove trailing slashes, if any. Done with plain string ops
  // instead of std::regex: a regex engine is heavyweight and slow for a
  // simple suffix trim, and this runs on every external-data export.
  const size_t last_non_slash = normalized.find_last_not_of('/');
  std::string root = (last_non_slash == std::string::npos)
      ? std::string() // path was empty or all slashes
      : normalized.substr(0, last_non_slash + 1);
  // Everything before the last '/' is the folder; if there is no '/',
  // substr(0, npos) returns the whole string and the check below fires.
  std::string folder = root.substr(0, root.find_last_of('/'));
  if (folder == normalized) { // If no root folder specified, select cwd.
    return std::string(".");
  }
  return folder;
}
190
// Derives a filesystem-safe file name from a tensor's external reference by
// replacing characters that are illegal in file names with '_'.
// Precondition: `external_ref` holds a value (value() throws otherwise).
std::string GetExternalFileName(
    const std::optional<std::string>& external_ref) {
  const std::string kIllegalChars = "\\/:?\"<>|";
  std::string file_name = external_ref.value();
  for (auto& ch : file_name) {
    if (kIllegalChars.find(ch) != std::string::npos) {
      ch = '_';
    }
  }
  return file_name;
}
202
// Deleter for FILE* handles held in std::unique_ptr (see CreateExternalFile):
// closes the stream when the owning pointer goes out of scope.
void CloseFile(FILE* fp) {
  fclose(fp);
}
206
CreateExternalFile(const at::Tensor & tensor,const std::string & tensorName,const std::string & onnx_file_path)207 void CreateExternalFile(
208 const at::Tensor& tensor,
209 const std::string& tensorName,
210 const std::string& onnx_file_path) {
211 auto folder = GetFileRootPath(onnx_file_path);
212 std::string fullFilePath = folder + "/" + tensorName;
213 std::unique_ptr<FILE, decltype(&CloseFile)> fp(
214 fopen(fullFilePath.c_str(), "wb"), &CloseFile);
215 if (fp == nullptr) {
216 throw std::runtime_error(
217 std::string("ONNX export failed. Could not open file or directory: ") +
218 fullFilePath);
219 }
220 std::string s = get_little_endian_data(tensor);
221 fwrite(s.c_str(), tensor.element_size(), tensor.numel(), fp.get());
222 } // fclose() called here through CloseFile(), if FILE* is not a null pointer.
223
// Encoder that converts a JIT Graph (plus its initializers) into an ONNX
// ModelProto. All encoding work happens in the constructor; the getters
// below expose the finished proto and the side tables produced while
// encoding.
class GraphEncoder {
 public:
  GraphEncoder(
      const std::shared_ptr<Graph>& graph,
      int64_t onnx_opset_version,
      onnx_torch::OperatorExportTypes operator_export_type,
      const std::map<std::string, at::Tensor>& initializers,
      const std::unordered_map<
          std::string,
          std::unordered_map<int64_t, std::string>>& dynamic_axes,
      bool defer_weight_export,
      bool strip_doc,
      bool keep_initializers_as_inputs,
      const std::map<std::string, int>& custom_opsets,
      bool add_node_names,
      bool use_external_data_format,
      const std::string& onnx_file_path,
      NodeAttrNameMap node_attr_to_name = {});

  // The encoded model (complete once the constructor returns).
  std::shared_ptr<onnx::ModelProto> get_model_proto() {
    return model_proto_;
  }

  // Shape-symbol -> dim_param name assignments made during encoding.
  SymbolDimMap get_symbol_dim_param_map() {
    return symbol_dim_map_;
  }

  // Tensor data kept out of the proto (populated while encoding tensors;
  // see EncodeTensor / defer_weight_export).
  RawDataExportMap get_raw_data_export_map() {
    return raw_data_export_map_;
  }

  // True if parameters were stored in external data files — either
  // requested by the caller or forced by the 2GB protobuf limit.
  bool get_use_external_data_format() {
    return use_external_data_format_;
  }

  // JIT node -> ONNX node name mapping assigned during encoding.
  NodeNameMap get_onnx_node_names() {
    return onnx_node_name_map_;
  }

 private:
  // Using std::map instead of std::unordered_map for initializers
  // in EncodeGraph constructor so that the order in which initializers
  // get written to the ONNX graph is always the deterministic and
  // predictable. While this is not a ONNX requirement, it is needed
  // for testing purposes in tests that use _export_to_pretty_string()
  // for validating ONNX graphs.
  void EncodeGraph(
      onnx::GraphProto* graph_proto,
      const std::shared_ptr<Graph>& graph,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>(),
      const std::
          unordered_map<std::string, std::unordered_map<int64_t, std::string>>&
              dynamic_axes = std::unordered_map<
                  std::string,
                  std::unordered_map<int64_t, std::string>>(),
      bool keep_initializers_as_inputs = true,
      bool add_node_names = true,
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  // Encodes one Block (the top-level graph or an If/Loop sub-graph).
  void EncodeBlock(
      onnx::GraphProto* graph_proto,
      const Block* block,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>(),
      const std::
          unordered_map<std::string, std::unordered_map<int64_t, std::string>>&
              dynamic_axes = std::unordered_map<
                  std::string,
                  std::unordered_map<int64_t, std::string>>(),
      bool keep_initializers_as_inputs = true,
      bool add_node_names = true,
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  // Appends a TensorProto initializer for each block input present in
  // `initializers`.
  void AddInitializersIntoGraphProto(
      onnx::GraphProto* graph_proto,
      const Block* block,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>(),
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  // Estimates the serialized size of the graph; used by the constructor to
  // decide whether the 2GB protobuf limit forces external data files.
  unsigned long long int GetGraphProtoSize(
      onnx::GraphProto* graph_proto,
      const std::shared_ptr<Graph>& graph,
      bool add_node_names,
      bool use_external_data_format,
      const std::string& onnx_file_path,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>());

  // Encodes a single node into `node_proto`.
  void EncodeNode(
      onnx::GraphProto* graph_proto,
      onnx::NodeProto* node_proto,
      const Node* node,
      bool add_node_names = true,
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  void EncodeTypeProto(
      onnx::TypeProto* type_proto,
      const TypePtr& node_type,
      const std::string& name);

  void EncodeLocalFunctionOpsetImport(
      onnx::FunctionProto* func_proto,
      const Node* n,
      std::unordered_set<std::string>& custom_domains);

  // Encodes an onnx::LocalFunctionDef node as a model-level FunctionProto.
  void EncodeLocalFunction(
      onnx::GraphProto* graph_proto,
      onnx::FunctionProto* func_proto,
      const Node* n,
      bool add_node_names = true,
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  // Encodes a tensor's payload; may write the data to an external file when
  // use_external_data_format is set.
  void EncodeTensor(
      onnx::TensorProto* tensor_proto,
      const at::Tensor& tensor,
      const std::optional<std::string>& external_ref = {},
      const bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  void EncodeIntermediateValueInfo(
      onnx::GraphProto* graph_proto,
      const Value* n);

  // Encodes name + type information for a graph input/output value.
  void EncodeValueInfo(
      onnx::GraphProto* graph_proto,
      onnx::ValueInfoProto* v,
      const Value* n,
      const std::
          unordered_map<std::string, std::unordered_map<int64_t, std::string>>&
              dynamic_axes = std::unordered_map<
                  std::string,
                  std::unordered_map<int64_t, std::string>>());

  void EncodeValueInfoType(
      onnx::TypeProto* onnx_type,
      const TypePtr& node_type,
      const Value* n,
      const std::unordered_map<
          std::string,
          std::unordered_map<int64_t, std::string>>& dynamic_axes);

  void AddAttribute(
      onnx::NodeProto* node_proto,
      const jit::Symbol name,
      const std::string& ref_attr_name,
      const AttributeKind attr_kind);

  void AddAttribute(
      onnx::NodeProto* node_proto,
      const jit::Node* node,
      const jit::Symbol name,
      const bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  void AddAttribute(onnx::FunctionProto* func_proto, const std::string& name);

  // Fills an ONNX tensor type (shape + element type) from a JIT TensorType,
  // honoring user-declared dynamic axes.
  void TensorTypeToONNXType(
      const TensorTypePtr& tensor_type,
      const std::string& dim_name_prefix,
      const std::string& name,
      const std::unordered_map<
          std::string,
          std::unordered_map<int64_t, std::string>>& dynamic_axes,
      onnx::TypeProto_Tensor* onnx_tensor_type,
      bool assign_dim_param = true);

  SymbolDimMap symbol_dim_map_;
  std::shared_ptr<onnx::ModelProto> model_proto_;
  size_t num_blocks_{0};
  size_t num_op_nodes_{0};
  size_t num_external_data_{0};
  onnx_torch::OperatorExportTypes operator_export_type_;
  bool strip_doc_;
  // Custom (non-default) operator domains encountered while encoding; the
  // constructor emits one opset_import per entry.
  std::set<std::string> domains_;
  RawDataExportMap raw_data_export_map_;
  bool defer_weight_export_;
  bool use_external_data_format_;
  int64_t onnx_opset_version_;
  std::map<std::string, int> custom_opsets_;
  std::shared_ptr<Graph> graph_;
  NodeAttrNameMap node_attr_to_name_;
  NodeNameMap onnx_node_name_map_;
  // For large models, the parameters can be stored in separate binary files.
  // This parameter sets a threshold on the number of elements in the parameter
  // tensor, beyond which the parameter is stored in a separate file (if
  // use_external_data_format_ is True). This threshold is in place
  // so as not to create too many external files.
  static constexpr size_t ParamSizeThresholdForExternalStorage = 1024;
};
420
ATenTypeToOnnxType(at::ScalarType at_type)421 onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
422 switch (at_type) {
423 case at::kDouble:
424 return onnx::TensorProto_DataType_DOUBLE;
425 case at::kFloat:
426 return onnx::TensorProto_DataType_FLOAT;
427 case at::kHalf:
428 return onnx::TensorProto_DataType_FLOAT16;
429 case at::kByte:
430 return onnx::TensorProto_DataType_UINT8;
431 case at::kChar:
432 return onnx::TensorProto_DataType_INT8;
433 case at::kShort:
434 return onnx::TensorProto_DataType_INT16;
435 case at::kInt:
436 return onnx::TensorProto_DataType_INT32;
437 case at::kLong:
438 return onnx::TensorProto_DataType_INT64;
439 case at::kBool:
440 return onnx::TensorProto_DataType_BOOL;
441 case at::kQInt8:
442 return onnx::TensorProto_DataType_INT8;
443 case at::kQUInt8:
444 return onnx::TensorProto_DataType_UINT8;
445 case at::kQInt32:
446 return onnx::TensorProto_DataType_INT32;
447 case at::kBFloat16:
448 return onnx::TensorProto_DataType_BFLOAT16;
449 case at::kFloat8_e4m3fn:
450 return onnx_torch::TensorProto_DataType_FLOAT8E4M3FN;
451 case at::kFloat8_e5m2:
452 return onnx_torch::TensorProto_DataType_FLOAT8E5M2;
453 case at::kFloat8_e4m3fnuz:
454 return onnx_torch::TensorProto_DataType_FLOAT8E4M3FNUZ;
455 case at::kFloat8_e5m2fnuz:
456 return onnx_torch::TensorProto_DataType_FLOAT8E5M2FNUZ;
457 default:
458 TORCH_CHECK(
459 false,
460 "ScalarType ",
461 toString(at_type),
462 " is an unexpected tensor scalar type");
463 }
464 }
465
ATenAttributeKindToOnnxAttributeType(AttributeKind at_kind,const jit::Symbol name)466 onnx::AttributeProto_AttributeType ATenAttributeKindToOnnxAttributeType(
467 AttributeKind at_kind,
468 const jit::Symbol name) {
469 switch (at_kind) {
470 case AttributeKind::f:
471 return onnx::AttributeProto_AttributeType_FLOAT;
472 case AttributeKind::fs:
473 return onnx::AttributeProto_AttributeType_FLOATS;
474 case AttributeKind::i:
475 return onnx::AttributeProto_AttributeType_INT;
476 case AttributeKind::is:
477 return onnx::AttributeProto_AttributeType_INTS;
478 case AttributeKind::s:
479 return onnx::AttributeProto_AttributeType_STRING;
480 case AttributeKind::ss:
481 return onnx::AttributeProto_AttributeType_STRINGS;
482 case AttributeKind::t:
483 return onnx::AttributeProto_AttributeType_TENSOR;
484 case AttributeKind::ts:
485 return onnx::AttributeProto_AttributeType_TENSORS;
486 case AttributeKind::ty:
487 return onnx::AttributeProto_AttributeType_TYPE_PROTO;
488 case AttributeKind::tys:
489 return onnx::AttributeProto_AttributeType_TYPE_PROTOS;
490 case AttributeKind::g:
491 return onnx::AttributeProto_AttributeType_GRAPH;
492 case AttributeKind::gs:
493 return onnx::AttributeProto_AttributeType_GRAPHS;
494 default:
495 std::ostringstream err_msg;
496 err_msg << "attribute \"" << name.toDisplayString()
497 << "\" has unexpected kind: " << toString(at_kind);
498 throw std::runtime_error(err_msg.str());
499 }
500 }
501
// Builds the complete ModelProto for `graph`:
//   1. validates that every node is exportable,
//   2. decides whether external data files are needed (2GB protobuf limit),
//   3. encodes the graph body, and
//   4. records opset imports for the default and all custom domains.
GraphEncoder::GraphEncoder(
    const std::shared_ptr<Graph>& graph,
    int64_t onnx_opset_version,
    onnx_torch::OperatorExportTypes operator_export_type,
    const std::map<std::string, at::Tensor>& initializers,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    bool defer_weight_export,
    bool strip_doc,
    bool keep_initializers_as_inputs,
    const std::map<std::string, int>& custom_opsets,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path,
    NodeAttrNameMap node_attr_to_name)
    : model_proto_(std::make_shared<onnx::ModelProto>()),

      operator_export_type_(operator_export_type),
      strip_doc_(strip_doc),
      defer_weight_export_(defer_weight_export),
      use_external_data_format_(use_external_data_format),
      onnx_opset_version_(onnx_opset_version),
      custom_opsets_(custom_opsets),
      graph_(graph),
      node_attr_to_name_(std::move(node_attr_to_name)) {
  model_proto_->set_producer_name("pytorch");
  // Reject opset versions outside the table, or marked invalid in it.
  TORCH_CHECK(
      onnx_opset_version > 0 &&
          static_cast<size_t>(onnx_opset_version) <
              kOpsetVersionToIRVersion.size() &&
          kOpsetVersionToIRVersion[onnx_opset_version] != kInvalidOpsetVersion,
      "Unsupported onnx_opset_version: ",
      onnx_opset_version);

  model_proto_->set_ir_version(kOpsetVersionToIRVersion[onnx_opset_version]);
  model_proto_->set_producer_version(TORCH_VERSION);
  validateGraph(graph, operator_export_type);

  // If graph proto size exceed maximum protobuf size of 2GB, set
  // use_external_data_format to true.
  if (!use_external_data_format &&
      GetGraphProtoSize(
          model_proto_->mutable_graph(),
          graph,
          add_node_names,
          use_external_data_format,
          onnx_file_path,
          initializers) > INT_MAX) {
    GRAPH_DEBUG(
        "Exporting model exceed maximum protobuf size of 2GB. Storing model parameters in external data files");
    use_external_data_format = true;
    // use_external_data_format_ is one of graph_encoder private variable set
    // for return `use_external_data_format` value.
    use_external_data_format_ = use_external_data_format;
  }

  if (use_external_data_format) {
    // External data is written next to the output file, so we must have one.
    TORCH_CHECK(
        !onnx_file_path.empty(),
        "The serialized model is larger than the 2GiB limit imposed by the protobuf library. ",
        "Therefore the output file must be a file path, so that the ONNX external data can ",
        "be written to the same directory. Please specify the output file name.");
  }

  auto* imp = model_proto_->add_opset_import();
  // This is the version of ONNX operator set we are targeting
  imp->set_version(onnx_opset_version);

  EncodeGraph(
      model_proto_->mutable_graph(),
      graph,
      initializers,
      dynamic_axes,
      keep_initializers_as_inputs,
      add_node_names,
      use_external_data_format,
      onnx_file_path);

  // domains_ was filled while encoding nodes; every custom domain needs its
  // own opset import in the model.
  for (const std::string& domain : domains_) {
    auto* opset = model_proto_->add_opset_import();
    opset->set_domain(domain);
    // Check if domain version is registered. If not, set to version 1
    auto it = custom_opsets.find(domain);
    if (it == custom_opsets.end())
      opset->set_version(1);
    else {
      opset->set_version(it->second);
    }
  }

  // Warn about user-supplied custom opsets the model never referenced.
  for (auto const& custom_opset : custom_opsets) {
    if (!std::count(domains_.begin(), domains_.end(), custom_opset.first)) {
      TORCH_WARN(
          "Custom opset domain: '",
          custom_opset.first,
          "' provided is not used in the model. ",
          "Please verify custom opset domain names.");
    }
  }
}
603
604 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
// Fills `onnx_tensor_type` (shape and element type) from a JIT TensorType.
// For each dimension i of `name`:
//   - if the user declared a dynamic axis for (name, i), emit that
//     dim_param (and record it in symbol_dim_map_ when the size is not
//     static, so later uses of the same symbol reuse the name);
//   - else if the size is statically known, emit a concrete dim_value;
//   - else, if assign_dim_param, invent a stable dim_param name and cache
//     it in symbol_dim_map_; otherwise leave the dim unset (unknown).
void GraphEncoder::TensorTypeToONNXType(
    const TensorTypePtr& tensor_type,
    const std::string& dim_name_prefix,
    const std::string& name,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    onnx::TypeProto_Tensor* onnx_tensor_type,
    bool assign_dim_param) {
  if (tensor_type->dim()) {
    onnx::TensorShapeProto* shape = onnx_tensor_type->mutable_shape();
    // NOTE(review): assumes symbolic_sizes() has a value whenever dim() is
    // set — .value() would throw otherwise.
    auto sizes = tensor_type->symbolic_sizes().sizes().value();
    for (const auto i : c10::irange(sizes.size())) {
      shape->add_dim();
      if ((dynamic_axes.find(name) != dynamic_axes.end()) &&
          (dynamic_axes.at(name).find(i) != dynamic_axes.at(name).end())) {
        shape->mutable_dim(i)->set_dim_param(dynamic_axes.at(name).at(i));
        if (!sizes[i].is_static()) {
          symbol_dim_map_[sizes[i]] = dynamic_axes.at(name).at(i);
        }
      } else if (sizes[i].is_static()) {
        shape->mutable_dim(i)->set_dim_value(sizes[i].static_size());
      } else if (assign_dim_param) {
        // First time this symbol is seen: synthesize a deterministic name.
        if (symbol_dim_map_.find(sizes[i]) == symbol_dim_map_.end()) {
          symbol_dim_map_[sizes[i]] =
              dim_name_prefix + name + "_dim_" + std::to_string(i);
        }
        shape->mutable_dim(i)->set_dim_param(symbol_dim_map_[sizes[i]]);
      }
      // else: dim left with neither dim_value nor dim_param (unknown size).
    }
  }
  if (tensor_type->scalarType()) {
    onnx_tensor_type->set_elem_type(
        ATenTypeToOnnxType(tensor_type->scalarType().value()));
  }
}
641 // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
642
// Encodes the ONNX TypeProto for a value, dispatching on its JIT type:
// Tensor -> tensor_type (when shape or dtype is known); Bool/Int/Float
// scalars -> tensor of the corresponding dtype; List -> sequence type
// (recursing on the element type); Optional -> optional of tensor or of
// sequence-of-tensor. Other types are silently left unencoded.
void GraphEncoder::EncodeValueInfoType(
    onnx::TypeProto* onnx_type,
    const TypePtr& node_type,
    const Value* n,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes) {
  // Synthesized dim_param names are prefixed with the producing op's name;
  // graph inputs (prim::Param) get no prefix.
  std::string dim_name_prefix;
  if (n->node()->kind() != prim::Param) {
    dim_name_prefix = n->node()->kind().toUnqualString();
  }
  if (TensorTypePtr tensor_type = node_type->cast<TensorType>()) {
    if (tensor_type->dim() || tensor_type->scalarType()) {
      // Encode type if either shape or dtype exists.
      onnx::TypeProto_Tensor* onnx_tensor_type =
          onnx_type->mutable_tensor_type();
      // Do not assign dim_param for sequence tensor type.
      // Sequence of tensors could differ in dimension size.
      // Use a dimension with neither dim_value nor dim_param set
      // to denote an unknown dimension.
      // Create and assign dim_param for normal tensor type.
      auto is_sequence_tensor = static_cast<bool>(n->type()->cast<ListType>());
      TensorTypeToONNXType(
          tensor_type,
          dim_name_prefix,
          n->debugName(),
          dynamic_axes,
          onnx_tensor_type,
          !is_sequence_tensor);
    }
  } else if (BoolTypePtr bool_type = node_type->cast<BoolType>()) {
    onnx::TypeProto_Tensor* onnx_tensor_type = onnx_type->mutable_tensor_type();
    onnx_tensor_type->set_elem_type(ATenTypeToOnnxType(at::kBool));
  } else if (IntTypePtr int_type = node_type->cast<IntType>()) {
    onnx::TypeProto_Tensor* onnx_tensor_type = onnx_type->mutable_tensor_type();
    onnx_tensor_type->set_elem_type(ATenTypeToOnnxType(at::kLong));
  } else if (FloatTypePtr float_type = node_type->cast<FloatType>()) {
    onnx::TypeProto_Tensor* onnx_tensor_type = onnx_type->mutable_tensor_type();
    onnx_tensor_type->set_elem_type(ATenTypeToOnnxType(at::kFloat));
  } else if (ListTypePtr list_type = node_type->cast<ListType>()) {
    // Lists become ONNX sequences; recurse on the element type.
    auto list_elem_type = list_type->getElementType();
    onnx::TypeProto_Sequence* sequence_type =
        onnx_type->mutable_sequence_type();
    onnx::TypeProto* onnx_tensor_type = sequence_type->mutable_elem_type();
    EncodeValueInfoType(onnx_tensor_type, list_elem_type, n, dynamic_axes);
  } else if (OptionalTypePtr optional_type = node_type->cast<OptionalType>()) {
    // Only Optional[Tensor] and Optional[List[Tensor]] are representable.
    auto elem_type = optional_type->getElementType();
    if (TensorTypePtr tensor_type = elem_type->cast<TensorType>()) {
      onnx::TypeProto_Optional* onnx_optional_type =
          onnx_type->mutable_optional_type();
      onnx::TypeProto_Tensor* onnx_tensor_type =
          onnx_optional_type->mutable_elem_type()->mutable_tensor_type();
      TensorTypeToONNXType(
          tensor_type,
          dim_name_prefix,
          n->debugName(),
          dynamic_axes,
          onnx_tensor_type);
    } else if (ListTypePtr inner_node_type = elem_type->cast<ListType>()) {
      auto list_elem_type = inner_node_type->getElementType();
      if (TensorTypePtr tensor_type = list_elem_type->cast<TensorType>()) {
        onnx::TypeProto_Optional* onnx_optional_type =
            onnx_type->mutable_optional_type();
        onnx::TypeProto_Sequence* onnx_optional_sequence_type =
            onnx_optional_type->mutable_elem_type()->mutable_sequence_type();
        onnx::TypeProto_Tensor* onnx_tensor_type =
            onnx_optional_sequence_type->mutable_elem_type()
                ->mutable_tensor_type();
        TensorTypeToONNXType(
            tensor_type,
            dim_name_prefix,
            n->debugName(),
            dynamic_axes,
            onnx_tensor_type);
      }
    }
  }
}
721
EncodeValueInfo(onnx::GraphProto * graph_proto,onnx::ValueInfoProto * v,const Value * n,const std::unordered_map<std::string,std::unordered_map<int64_t,std::string>> & dynamic_axes)722 void GraphEncoder::EncodeValueInfo(
723 onnx::GraphProto* graph_proto,
724 onnx::ValueInfoProto* v,
725 const Value* n,
726 const std::unordered_map<
727 std::string,
728 std::unordered_map<int64_t, std::string>>& dynamic_axes) {
729 std::string name = n->debugName();
730 v->set_name(name);
731 EncodeValueInfoType(v->mutable_type(), n->type(), n, dynamic_axes);
732 }
733
// Encodes `graph` into `graph_proto`. Thin wrapper: a Graph is encoded by
// encoding its top-level block; see EncodeBlock for parameter semantics.
void GraphEncoder::EncodeGraph(
    onnx::GraphProto* graph_proto,
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    bool keep_initializers_as_inputs,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path) {
  EncodeBlock(
      graph_proto,
      graph->block(),
      initializers,
      dynamic_axes,
      keep_initializers_as_inputs,
      add_node_names,
      use_external_data_format,
      onnx_file_path);
}
755
// Encodes `block` into `graph_proto`: names the graph, emits value_info
// for inputs and outputs, encodes every node (routing LocalFunctionDef
// nodes into model-level FunctionProtos and skipping None nodes), and
// finally attaches initializers.
void GraphEncoder::EncodeBlock(
    onnx::GraphProto* graph_proto,
    const Block* block,
    const std::map<std::string, at::Tensor>& initializers,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    bool keep_initializers_as_inputs,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path) {
  TORCH_INTERNAL_ASSERT(graph_proto != nullptr);
  if (nullptr == block->owningNode()) {
    // Top level main graph.
    graph_proto->set_name("main_graph");
  } else {
    // TODO: Set more meaningful name for sub-graphs.
    // Sub-graphs are numbered by encounter order: sub_graph, sub_graph1, ...
    std::string block_name = "sub_graph";
    if (num_blocks_) {
      block_name += std::to_string(num_blocks_);
    }
    num_blocks_++;
    graph_proto->set_name(block_name);
  }

  // Since ONNX IR VERSION 4, initializers do not have to
  // be a subset of graph inputs. We use keep_initializers_as_inputs
  // argument to determine whether to add initializers
  // as inputs or not. If keep_initializers_as_inputs=false,
  // we only add non-parameter inputs as inputs to ONNX graph, and
  // not the initializers (parameters). If keep_initializers_as_inputs
  // =true, we add initializers as inputs too. Setting
  // keep_initializers_as_inputs=false allows better
  // optimizations, such as constant-folding, on ONNX graphs
  // by backends/optimizers.
  if (keep_initializers_as_inputs) {
    for (auto input : block->inputs()) {
      onnx::ValueInfoProto* v = graph_proto->add_input();
      EncodeValueInfo(graph_proto, v, input, dynamic_axes);
    }
  } else {
    for (auto input : block->inputs()) {
      auto it = initializers.find(input->debugName());
      if (it == initializers.end()) {
        onnx::ValueInfoProto* v = graph_proto->add_input();
        EncodeValueInfo(graph_proto, v, input, dynamic_axes);
      }
    }
  }
  for (auto output : block->outputs()) {
    onnx::ValueInfoProto* v = graph_proto->add_output();
    EncodeValueInfo(graph_proto, v, output, dynamic_axes);
  }
  for (auto node : block->nodes()) {
    if (node->mustBeNone()) {
      // None nodes are used to implement optional inputs. One
      // way to "not provide" an optional input is to create an
      // Undefined node, and pass its output as that input.
      continue;
    }
    if (node->kind() == ::c10::Symbol::onnx("LocalFunctionDef")) {
      // Local function definitions live on the model, not on the graph.
      auto* func_proto = model_proto_->add_functions();
      EncodeLocalFunction(
          graph_proto,
          func_proto,
          node,
          add_node_names,
          use_external_data_format,
          onnx_file_path);
      continue;
    }
    auto* n_proto = graph_proto->add_node();
    EncodeNode(
        graph_proto,
        n_proto,
        node,
        add_node_names,
        use_external_data_format,
        onnx_file_path);
  }
  AddInitializersIntoGraphProto(
      graph_proto,
      block,
      initializers,
      use_external_data_format,
      onnx_file_path);
}
843
AddInitializersIntoGraphProto(onnx::GraphProto * graph_proto,const Block * block,const std::map<std::string,at::Tensor> & initializers,bool use_external_data_format,const std::string & onnx_file_path)844 void GraphEncoder::AddInitializersIntoGraphProto(
845 onnx::GraphProto* graph_proto,
846 const Block* block,
847 const std::map<std::string, at::Tensor>& initializers,
848 bool use_external_data_format,
849 const std::string& onnx_file_path) {
850 TORCH_INTERNAL_ASSERT(block->inputs().size() >= initializers.size());
851 for (auto input : block->inputs()) {
852 auto name_tensor_pair = initializers.find(input->debugName());
853 if (name_tensor_pair == initializers.end()) {
854 continue;
855 }
856 auto p = graph_proto->add_initializer();
857 p->set_name(name_tensor_pair->first);
858 EncodeTensor(
859 p,
860 name_tensor_pair->second,
861 name_tensor_pair->first,
862 use_external_data_format,
863 onnx_file_path);
864 }
865 }
866
GetGraphProtoSize(onnx::GraphProto * graph_proto,const std::shared_ptr<Graph> & graph,bool add_node_names,bool use_external_data_format,const std::string & onnx_file_path,const std::map<std::string,at::Tensor> & initializers)867 unsigned long long int GraphEncoder::GetGraphProtoSize(
868 onnx::GraphProto* graph_proto,
869 const std::shared_ptr<Graph>& graph,
870 bool add_node_names,
871 bool use_external_data_format,
872 const std::string& onnx_file_path,
873 const std::map<std::string, at::Tensor>& initializers) {
874 // Model size = sum(size(initializers)) + sum(size(onnx_constant_nodes))
875
876 // Add up all Initializers
877 onnx::GraphProto graph_proto_copy = onnx::GraphProto(*graph_proto);
878 unsigned long long int size = graph_proto_copy.ByteSizeLong();
879 for (auto input : graph->inputs()) {
880 auto name_tensor_pair = initializers.find(input->debugName());
881 if (name_tensor_pair == initializers.end()) {
882 continue;
883 }
884 auto tensor_proto = graph_proto_copy.add_initializer();
885 const at::Tensor& tensor = name_tensor_pair->second;
886 for (auto d : tensor.sizes()) {
887 tensor_proto->add_dims(d);
888 }
889 tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.scalar_type()));
890
891 // Don't actually copy the buffer into tensor_proto since that is expensive.
892 // All we actually need is its size.
893 size += tensor_proto->ByteSizeLong();
894 size += tensor.element_size() * tensor.numel();
895 }
896
897 // Add up all onnx::Constant nodes that are Tensors
898 for (const auto& node : graph->nodes()) {
899 if (node->kind() == ::c10::onnx::Constant &&
900 node->hasAttribute(attr::value) &&
901 node->kindOf(attr::value) == AttributeKind::t) {
902 at::Tensor tensor = node->t(attr::value);
903
904 // Don't actually copy the buffer into n_proto since that is expensive.
905 // All we actually need is its size.
906 auto* n_proto = graph_proto_copy.add_node();
907 EncodeNode(
908 &graph_proto_copy,
909 n_proto,
910 node,
911 add_node_names,
912 use_external_data_format,
913 onnx_file_path);
914
915 // Calculate the size of the tensor in bytes
916 size += n_proto->ByteSizeLong();
917 size += tensor.element_size() * tensor.numel();
918 }
919 }
920 return size;
921 }
922
// Serializes one JIT `node` into `node_proto`.
// Inputs/outputs are referenced by debug name; an input produced by a
// None node is emitted as "" (the ONNX convention for an absent optional
// input). Non-onnx ops get a domain set on the proto and collected into
// domains_. Attributes are emitted either as references to local-function
// attributes (when node_attr_to_name_ has a mapping for this node) or as
// concrete values. onnx::Loop / onnx::If block subgraphs are encoded
// recursively as graph-typed attributes.
void GraphEncoder::EncodeNode(
    onnx::GraphProto* graph_proto,
    onnx::NodeProto* node_proto,
    const Node* node,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path) {
  if (!strip_doc_) {
    // Preserve the original source location as the node's doc string.
    node_proto->set_doc_string(node->sourceRange().str());
  }
  for (auto input : node->inputs()) {
    if (input->node()->mustBeNone()) {
      // Empty input name denotes a "not provided" optional input.
      node_proto->add_input("");
    } else {
      node_proto->add_input(input->debugName());
    }
  }
  for (auto output : node->outputs()) {
    node_proto->add_output(output->debugName());
    EncodeIntermediateValueInfo(graph_proto, output);
  }
  if (!node->kind().is_onnx()) {
    std::string domain;
    if (node->kind().is_aten() || node->kind().is_caffe2()) {
      domain = node->kind().domainString();
    } else { // Custom namespace and domain
      domain = node->kind().ns().toUnqualString();
    }
    // TODO: set correct domain for function proto.
    domains_.insert(domain);
    node_proto->set_domain(domain);
  }
  if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
    // Pure ONNX export must not contain aten/prim/attr ops.
    TORCH_INTERNAL_ASSERT(
        !node->kind().is_aten() && !node->kind().is_prim() &&
        !node->kind().is_attr());
  }
  node_proto->set_op_type(node->kind().toUnqualString());
  const auto node_name_attribute_symbol =
      Symbol::attr(::torch::onnx::kOnnxNodeNameAttribute);
  if (add_node_names) {
    // Default name is "<op_type>_<running count>", unless an explicit
    // node-name attribute was attached upstream.
    std::string node_name =
        node_proto->op_type() + "_" + std::to_string(num_op_nodes_);
    if (node->hasAttribute(node_name_attribute_symbol)) {
      node_name = node->s(node_name_attribute_symbol);
    }
    node_proto->set_name(node_name);
    onnx_node_name_map_[node] = node_name;
    num_op_nodes_++;
  }
  auto attrs_it = node_attr_to_name_.find(node);
  for (auto attr_name : node->attributeNames()) {
    if (attr_name == node_name_attribute_symbol) {
      // Skip the node name attribute.
      continue;
    }
    if (attrs_it != node_attr_to_name_.end()) {
      auto attr_it = attrs_it->second.find(attr_name.toUnqualString());
      if (attr_it != attrs_it->second.end()) {
        // Encode as a reference to a local-function attribute.
        AddAttribute(
            node_proto, attr_name, attr_it->second, node->kindOf(attr_name));
        continue;
      }
    }
    // Encode the concrete attribute value.
    AddAttribute(
        node_proto, node, attr_name, use_external_data_format, onnx_file_path);
  }
  if (node->kind() == ::c10::onnx::Loop) {
    TORCH_INTERNAL_ASSERT(node->blocks().size() == 1);

    // The loop body becomes a nested graph attribute named "body".
    auto body = node_proto->add_attribute();
    body->set_name("body");
    body->set_type(onnx::AttributeProto_AttributeType_GRAPH);
    auto g = body->mutable_g();
    EncodeBlock(
        g,
        node->blocks()[0],
        {},
        {},
        true,
        true,
        use_external_data_format,
        onnx_file_path);
  }
  if (node->kind() == ::c10::onnx::If) {
    TORCH_INTERNAL_ASSERT(node->blocks().size() == 2);

    // If is encoded with two nested graph attributes, one per branch.
    auto then_branch = node_proto->add_attribute();
    then_branch->set_name("then_branch");
    then_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
    auto true_g = then_branch->mutable_g();
    EncodeBlock(
        true_g,
        node->blocks()[0],
        {},
        {},
        true,
        true,
        use_external_data_format,
        onnx_file_path);

    auto else_branch = node_proto->add_attribute();
    else_branch->set_name("else_branch");
    else_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
    auto false_g = else_branch->mutable_g();
    EncodeBlock(
        false_g,
        node->blocks()[1],
        {},
        {},
        true,
        true,
        use_external_data_format,
        onnx_file_path);
  }
}
1039
AddAttribute(onnx::NodeProto * node_proto,const jit::Symbol name,const std::string & ref_attr_name,const AttributeKind attr_kind)1040 void GraphEncoder::AddAttribute(
1041 onnx::NodeProto* node_proto,
1042 const jit::Symbol name,
1043 const std::string& ref_attr_name,
1044 const AttributeKind attr_kind) {
1045 auto attr = node_proto->add_attribute();
1046 TORCH_INTERNAL_ASSERT(name.is_attr());
1047 attr->set_name(name.toUnqualString());
1048 attr->set_ref_attr_name(ref_attr_name);
1049 attr->set_type(ATenAttributeKindToOnnxAttributeType(attr_kind, name));
1050 }
1051
// Encodes the concrete value of attribute `name` of `node` into a new
// AttributeProto on `node_proto`, dispatching on the JIT AttributeKind.
// Tensor-valued attributes may be given a synthesized name (so external
// data files can be keyed by it) and may be spilled to disk when
// `use_external_data_format` is set. Graph-valued attributes recurse
// through EncodeGraph. Throws std::runtime_error for unsupported kinds.
void GraphEncoder::AddAttribute(
    onnx::NodeProto* node_proto,
    const jit::Node* node,
    const jit::Symbol name,
    const bool use_external_data_format,
    const std::string& onnx_file_path) {
  // Builds a unique-ish name for an unnamed tensor attribute, needed when
  // the tensor payload is stored externally. Prefers the owning node's
  // name; otherwise falls back to "<op_type>_<attr>_<counter>" and bumps
  // the counter.
  auto createAttributeTensorName =
      [](const onnx::NodeProto* node_proto,
         onnx::TensorProto* tensor_proto,
         const jit::Symbol attr_name,
         size_t& num_external_data) -> std::string {
    if (tensor_proto->has_name()) {
      return tensor_proto->name();
    }
    if (!node_proto->has_name()) {
      auto name = node_proto->op_type() + "_" + attr_name.toDisplayString() +
          "_" + std::to_string(num_external_data);
      num_external_data++;
      return name;
    } else {
      return node_proto->name() + "_" + attr_name.toDisplayString();
    }
  };

  auto attr = node_proto->add_attribute();
  TORCH_INTERNAL_ASSERT(name.is_attr());
  attr->set_name(name.toUnqualString());
  attr->set_type(
      ATenAttributeKindToOnnxAttributeType(node->kindOf(name), name));
  switch (node->kindOf(name)) {
    case AttributeKind::f:
      // ONNX attribute floats are single precision; narrow from double.
      attr->set_f(static_cast<float>(node->f(name)));
      break;
    case AttributeKind::fs:
      for (auto& v : node->fs(name))
        attr->add_floats(static_cast<float>(v));
      break;
    case AttributeKind::i:
      attr->set_i(node->i(name));
      break;
    case AttributeKind::is:
      for (auto& v : node->is(name))
        attr->add_ints(v);
      break;
    case AttributeKind::s:
      attr->set_s(node->s(name));
      break;
    case AttributeKind::ss:
      for (auto& v : node->ss(name))
        attr->add_strings(v);
      break;
    case AttributeKind::t: {
      // Single tensor value; named up front so external-data files can
      // reference it.
      auto t = attr->mutable_t();
      if (use_external_data_format && !t->has_name()) {
        t->set_name(
            createAttributeTensorName(node_proto, t, name, num_external_data_));
      }
      EncodeTensor(
          t, node->t(name), {}, use_external_data_format, onnx_file_path);
    } break;
    case AttributeKind::ts:
      // List of tensors; each gets the same naming treatment.
      for (auto& v : node->ts(name)) {
        auto t = attr->add_tensors();
        if (use_external_data_format && !t->has_name()) {
          t->set_name(createAttributeTensorName(
              node_proto, t, name, num_external_data_));
        }
        EncodeTensor(t, v, {}, use_external_data_format, onnx_file_path);
      }
      break;
    case AttributeKind::ty: {
      attr->set_type(onnx::AttributeProto_AttributeType_TYPE_PROTO);
      auto tp = attr->mutable_tp();
      const TypePtr& node_type = node->ty(name);
      EncodeTypeProto(
          tp, node_type, node_proto->op_type() + "_" + name.toDisplayString());
    } break;
    case AttributeKind::tys: {
      attr->set_type(onnx::AttributeProto_AttributeType_TYPE_PROTOS);
      size_t index = 0;
      for (auto& v : node->tys(name)) {
        auto tp = attr->add_type_protos();
        EncodeTypeProto(
            tp,
            v,
            node_proto->op_type() + "_" + name.toDisplayString() + "_" +
                std::to_string(index));
        index++;
      }
    } break;
    case AttributeKind::g: {
      // Single subgraph attribute; encoded as a full nested graph.
      auto g = attr->mutable_g();
      EncodeGraph(
          g,
          node->g(name),
          {},
          {},
          true,
          true,
          use_external_data_format,
          onnx_file_path);
    } break;
    case AttributeKind::gs:
      for (auto& v : node->gs(name)) {
        auto g = attr->add_graphs();
        EncodeGraph(
            g, v, {}, {}, true, true, use_external_data_format, onnx_file_path);
      }
      break;
    default:
      std::ostringstream err_msg;
      err_msg << "attribute \"" << name.toDisplayString()
              << "\" has unexpected kind: " << toString(node->kindOf(name));
      throw std::runtime_error(err_msg.str());
  }
}
1168
AddAttribute(onnx::FunctionProto * func_proto,const std::string & name)1169 void GraphEncoder::AddAttribute(
1170 onnx::FunctionProto* func_proto,
1171 const std::string& name) {
1172 TORCH_INTERNAL_ASSERT(nullptr != func_proto);
1173 func_proto->add_attribute(name);
1174 }
1175
EncodeLocalFunctionOpsetImport(onnx::FunctionProto * func_proto,const Node * n,std::unordered_set<std::string> & custom_domains)1176 void GraphEncoder::EncodeLocalFunctionOpsetImport(
1177 onnx::FunctionProto* func_proto,
1178 const Node* n,
1179 std::unordered_set<std::string>& custom_domains) {
1180 if (!n->kind().is_onnx()) {
1181 std::string domain;
1182 if (n->kind().is_aten() || n->kind().is_caffe2()) {
1183 domain = n->kind().domainString();
1184 } else { // Custom namespace and domain
1185 domain = n->kind().ns().toUnqualString();
1186 }
1187 domains_.insert(domain);
1188
1189 if (custom_domains.find(domain) == custom_domains.end()) {
1190 custom_domains.insert(domain);
1191
1192 auto* custom_imp = func_proto->add_opset_import();
1193 custom_imp->set_domain(domain);
1194 // Check if domain version is registered. If not, set to version 1
1195 auto it = custom_opsets_.find(domain);
1196 if (it == custom_opsets_.end())
1197 custom_imp->set_version(1);
1198 else {
1199 custom_imp->set_version(it->second);
1200 }
1201 }
1202 }
1203
1204 for (auto* b : n->blocks()) {
1205 for (auto* sub_n : b->nodes()) {
1206 EncodeLocalFunctionOpsetImport(func_proto, sub_n, custom_domains);
1207 }
1208 }
1209 }
1210
EncodeLocalFunction(onnx::GraphProto * graph_proto,onnx::FunctionProto * func_proto,const Node * n,bool add_node_names,bool use_external_data_format,const std::string & onnx_file_path)1211 void GraphEncoder::EncodeLocalFunction(
1212 onnx::GraphProto* graph_proto,
1213 onnx::FunctionProto* func_proto,
1214 const Node* n,
1215 bool add_node_names,
1216 bool use_external_data_format,
1217 const std::string& onnx_file_path) {
1218 const auto fsub_g = n->g(Symbol::attr("graph"));
1219 func_proto->set_name(n->s(::c10::attr::name));
1220
1221 for (auto input : fsub_g->inputs()) {
1222 func_proto->add_input(input->debugName());
1223 }
1224 for (auto output : fsub_g->outputs()) {
1225 func_proto->add_output(output->debugName());
1226 }
1227
1228 // encode attributes names
1229 if (n->hasAttribute(Symbol::attr("attributes"))) {
1230 for (const auto& attr_name : n->ss(Symbol::attr("attributes"))) {
1231 AddAttribute(func_proto, attr_name);
1232 }
1233 }
1234
1235 auto* imp = func_proto->add_opset_import();
1236 // This is the version of ONNX operator set we are targeting
1237 imp->set_version(onnx_opset_version_);
1238
1239 // add for custom domain as well.
1240 const auto& domain = n->s(Symbol::attr("domain"));
1241 func_proto->set_domain(domain);
1242 domains_.insert(domain);
1243 std::unordered_set<std::string> custom_domains;
1244
1245 for (auto* fsub_n : fsub_g->nodes()) {
1246 if (fsub_n->mustBeNone()) {
1247 // None nodes are used to implement optional inputs. One
1248 // way to "not provide" an optional input is to create an
1249 // Undefined node, and pass its output as that input.
1250 continue;
1251 }
1252 auto* n_proto = func_proto->add_node();
1253 EncodeNode(
1254 graph_proto,
1255 n_proto,
1256 fsub_n,
1257 add_node_names,
1258 use_external_data_format,
1259 onnx_file_path);
1260 EncodeLocalFunctionOpsetImport(func_proto, fsub_n, custom_domains);
1261 }
1262 }
1263
EncodeTypeProto(onnx::TypeProto * type_proto,const TypePtr & node_type,const std::string & name)1264 void GraphEncoder::EncodeTypeProto(
1265 onnx::TypeProto* type_proto,
1266 const TypePtr& node_type,
1267 const std::string& name) {
1268 if (TensorTypePtr tensor_type = node_type->cast<TensorType>()) {
1269 onnx::TypeProto_Tensor* onnx_tensor_type =
1270 type_proto->mutable_tensor_type();
1271 TensorTypeToONNXType(tensor_type, "", name, {}, onnx_tensor_type);
1272 } else if (ListTypePtr list_type = node_type->cast<ListType>()) {
1273 onnx::TypeProto_Sequence* seq_type = type_proto->mutable_sequence_type();
1274 auto elem_type = list_type->getElementType();
1275 EncodeTypeProto(seq_type->mutable_elem_type(), elem_type, name);
1276 }
1277 }
1278
// Serializes an at::Tensor into `tensor_proto`: dims, ONNX dtype, and the
// payload. The payload goes to one of three places:
//   1. raw_data_export_map_ (deferred export), when defer_weight_export_
//      is set and `external_ref` names the tensor;
//   2. a side file next to `onnx_file_path`, when use_external_data_format
//      is set and the tensor exceeds the size threshold;
//   3. inline raw_data in the proto (little-endian), otherwise.
void GraphEncoder::EncodeTensor(
    onnx::TensorProto* tensor_proto,
    const at::Tensor& tensor,
    const std::optional<std::string>& external_ref,
    const bool use_external_data_format,
    const std::string& onnx_file_path) {
  for (auto d : tensor.sizes()) {
    tensor_proto->add_dims(d);
  }
  tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.scalar_type()));
  at::Tensor t;
  // CPU's HalfTensor doesn't have contiguous(), so first calling contiguous()
  // TODO We don't call .cpu() on quantized tensors as it fails when calling
  // aten::empty() on quantized tensors beyond certain size. Issue #29435.
  if (tensor.is_quantized()) {
    t = tensor.contiguous();
  } else {
    t = tensor.contiguous().cpu();
  }

  // Either defer_weight_export should be true and external_ref must be present,
  // or use_external_data_format should be true, not both at the same time. They
  // can both be false at the same time (for ONNX export for regular model
  // size).
  TORCH_INTERNAL_ASSERT(
      !((defer_weight_export_ && external_ref) && use_external_data_format));
  // Add a buffer to the raw_data_export_map for the caller to dump into an
  // external data store. If external_ref is not specified, we instead dump
  // the contiguous data into the protobuf itself
  if (defer_weight_export_ && external_ref) {
    // For now, we use the name of the tensor as the external lookup name to
    // avoid ONNX protobuf changes.
    TORCH_INTERNAL_ASSERT(external_ref.value() == tensor_proto->name());
    TORCH_INTERNAL_ASSERT(
        raw_data_export_map_.count(external_ref.value()) == 0);
    raw_data_export_map_[external_ref.value()] = t;
    // Sentinel value signaling that the real payload lives in the map.
    tensor_proto->set_raw_data("__EXTERNAL");
  } else {
    TORCH_INTERNAL_ASSERT(t.is_contiguous());
    // Element count (product of dims), used to decide whether the tensor
    // is large enough to be stored in an external data file.
    size_t tensorSize = static_cast<size_t>(c10::multiply_integers(
        std::begin(tensor.sizes()), std::end(tensor.sizes())));
    if (use_external_data_format &&
        tensorSize > ParamSizeThresholdForExternalStorage) {
      TORCH_INTERNAL_ASSERT(!onnx_file_path.empty());
      TORCH_INTERNAL_ASSERT(tensor_proto->has_name());
      auto tensorName = GetExternalFileName(tensor_proto->name());
      CreateExternalFile(t, tensorName, onnx_file_path);
      // Record the external file location per the ONNX external-data
      // convention (key "location").
      onnx::StringStringEntryProto* location =
          tensor_proto->mutable_external_data()->Add();
      location->set_key("location");
      location->set_value(tensorName);
      tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL);
    } else {
      // According to ParseData function's comments in onnx, tensor data is
      // always little endian.
      tensor_proto->set_raw_data(get_little_endian_data(t));
    }
  }
}
1338
EncodeIntermediateValueInfo(onnx::GraphProto * graph_proto,const Value * v)1339 void GraphEncoder::EncodeIntermediateValueInfo(
1340 onnx::GraphProto* graph_proto,
1341 const Value* v) {
1342 // Motivation is to encode ValueInfo for onnx local function nodes.
1343 auto n = v->node();
1344 if (n->kind().is_onnx() || n->kind().is_aten()) {
1345 // Encode value info only for non-onnx or non-ATen nodes.
1346 return;
1347 }
1348 if (n->owningGraph() != graph_.get()) {
1349 // Encode value info only for node in main graph.
1350 return;
1351 }
1352 for (const auto* o : graph_->outputs()) {
1353 // Do not encode value info for graph outputs.
1354 if (o == v) {
1355 return;
1356 }
1357 }
1358 auto v_info_p = graph_proto->add_value_info();
1359 EncodeValueInfo(graph_proto, v_info_p, v);
1360 }
1361
1362 } // namespace
1363
pretty_print_onnx(const std::shared_ptr<Graph> & graph,const std::map<std::string,at::Tensor> & initializers,int64_t onnx_opset_version,bool defer_weight_export,::torch::onnx::OperatorExportTypes operator_export_type,bool google_printer,bool keep_initializers_as_inputs,const std::map<std::string,int> & custom_opsets,bool add_node_names)1364 std::string pretty_print_onnx(
1365 const std::shared_ptr<Graph>& graph,
1366 const std::map<std::string, at::Tensor>& initializers,
1367 int64_t onnx_opset_version,
1368 bool defer_weight_export,
1369 ::torch::onnx::OperatorExportTypes operator_export_type,
1370 bool google_printer,
1371 bool keep_initializers_as_inputs,
1372 const std::map<std::string, int>& custom_opsets,
1373 bool add_node_names) {
1374 auto graph_encoder = GraphEncoder(
1375 graph,
1376 onnx_opset_version,
1377 operator_export_type,
1378 initializers,
1379 std::unordered_map<
1380 std::string,
1381 std::unordered_map<int64_t, std::string>>{},
1382 defer_weight_export,
1383 true,
1384 keep_initializers_as_inputs,
1385 custom_opsets,
1386 add_node_names,
1387 false,
1388 std::string());
1389 if (google_printer) {
1390 return graph_encoder.get_model_proto()->DebugString();
1391 }
1392 return prettyPrint(*graph_encoder.get_model_proto());
1393 }
1394
1395 std::tuple<
1396 std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
1397 RawDataExportMap,
1398 SymbolDimMap,
1399 bool,
1400 NodeNameMap>
export_onnx(const std::shared_ptr<Graph> & graph,const std::map<std::string,at::Tensor> & initializers,int64_t onnx_opset_version,const std::unordered_map<std::string,std::unordered_map<std::int64_t,std::string>> & dynamic_axes,bool defer_weight_export,::torch::onnx::OperatorExportTypes operator_export_type,bool strip_doc_string,bool keep_initializers_as_inputs,const std::map<std::string,int> & custom_opsets,bool add_node_names,bool use_external_data_format,const std::string & onnx_file_path,const NodeAttrNameMap & node_attr_to_name)1401 export_onnx(
1402 const std::shared_ptr<Graph>& graph,
1403 const std::map<std::string, at::Tensor>& initializers,
1404 int64_t onnx_opset_version,
1405 const std::unordered_map<
1406 std::string,
1407 std::unordered_map<std::int64_t, std::string>>& dynamic_axes,
1408 bool defer_weight_export,
1409 ::torch::onnx::OperatorExportTypes operator_export_type,
1410 bool strip_doc_string,
1411 bool keep_initializers_as_inputs,
1412 const std::map<std::string, int>& custom_opsets,
1413 bool add_node_names,
1414 bool use_external_data_format,
1415 const std::string& onnx_file_path,
1416 const NodeAttrNameMap& node_attr_to_name) {
1417 auto graph_encoder = GraphEncoder(
1418 graph,
1419 onnx_opset_version,
1420 operator_export_type,
1421 initializers,
1422 dynamic_axes,
1423 defer_weight_export,
1424 strip_doc_string,
1425 keep_initializers_as_inputs,
1426 custom_opsets,
1427 add_node_names,
1428 use_external_data_format,
1429 onnx_file_path,
1430 node_attr_to_name);
1431 GRAPH_DEBUG("onnx proto:", prettyPrint(*graph_encoder.get_model_proto()));
1432 return std::make_tuple(
1433 graph_encoder.get_model_proto(),
1434 graph_encoder.get_raw_data_export_map(),
1435 graph_encoder.get_symbol_dim_param_map(),
1436 graph_encoder.get_use_external_data_format(),
1437 graph_encoder.get_onnx_node_names());
1438 }
1439
serialize_model_proto_to_string(const std::shared_ptr<::ONNX_NAMESPACE::ModelProto> & model_proto)1440 std::string serialize_model_proto_to_string(
1441 const std::shared_ptr<::ONNX_NAMESPACE::ModelProto>& model_proto) {
1442 return model_proto->SerializeAsString();
1443 }
1444
check_onnx_proto(const std::string & proto_string)1445 void check_onnx_proto(const std::string& proto_string) {
1446 onnx::ModelProto model;
1447 if (!ParseProtoFromBytes(&model, proto_string.c_str(), proto_string.size())) {
1448 throw std::runtime_error("Invalid ONNX proto string.");
1449 return;
1450 }
1451 // 1. baseline check
1452 // These two checks prevent broken graph being generated
1453 // And errors out exporting if that happens.
1454 onnx::checker::check_model(model);
1455 onnx::shape_inference::InferShapes(model);
1456 // 2. full check
1457 // apply strict mode shape type inference check which examines
1458 // whether it's a valid ONNX graph or not. As for some users, they
1459 // don't need a fully valid ONNX graph to run their model, we simply
1460 // add this information as warning message if it fails.
1461 try {
1462 auto* schema_registry = onnx::OpSchemaRegistry::Instance();
1463 onnx::ShapeInferenceOptions options{
1464 /*check_type_val=*/true,
1465 /*strict_mode_val=*/true};
1466 onnx::shape_inference::InferShapes(model, schema_registry, options);
1467 } catch (const onnx::InferenceError& ex) {
1468 TORCH_WARN(
1469 "The exported ONNX model failed ONNX shape inference. "
1470 "The model will not be executable by the ONNX Runtime. "
1471 "If this is unintended and you believe there is a bug, "
1472 "please report an issue at https://github.com/pytorch/pytorch/issues. "
1473 "Error reported by strict ONNX shape inference: ",
1474 ex.what());
1475 }
1476 }
1477
1478 } // namespace torch::jit
1479