xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/export.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/serialization/export.h>
2 
3 #include <ATen/ATen.h>
4 #include <ATen/Utils.h>
5 #include <ATen/core/functional.h>
6 #include <c10/macros/Macros.h>
7 #include <c10/util/Exception.h>
8 #include <c10/util/accumulate.h>
9 #include <c10/util/irange.h>
10 #include <torch/csrc/autograd/symbolic.h>
11 #include <torch/csrc/jit/jit_log.h>
12 #include <torch/csrc/jit/passes/dead_code_elimination.h>
13 #include <torch/csrc/jit/passes/inliner.h>
14 #include <torch/csrc/jit/runtime/instruction.h>
15 #include <torch/csrc/jit/serialization/import_export_constants.h>
16 #include <torch/csrc/jit/serialization/import_export_functions.h>
17 #include <torch/csrc/jit/serialization/import_export_helpers.h>
18 #include <torch/csrc/jit/serialization/onnx.h>
19 #include <torch/csrc/onnx/back_compat.h>
20 #include <torch/csrc/onnx/onnx.h>
21 #include <torch/version.h>
22 #include <optional>
23 
24 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wnewline-eof")
25 #include <onnx/checker.h>
26 C10_DIAGNOSTIC_POP()
27 #include <onnx/onnx_pb.h>
28 #include <onnx/proto_utils.h>
29 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override")
30 #include <onnx/shape_inference/implementation.h>
31 C10_DIAGNOSTIC_POP()
32 
33 #include <memory>
34 #include <regex>
35 #include <set>
36 #include <sstream>
37 #include <string>
38 #include <utility>
39 #include <vector>
40 
41 namespace torch::jit {
42 
get_little_endian_data(const at::Tensor & t)43 static std::string get_little_endian_data(const at::Tensor& t) {
44 #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
45   return std::string(
46       static_cast<char*>(t.data_ptr()), t.element_size() * t.numel());
47 #elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
48   const size_t element_size = t.element_size();
49   const size_t num_elements = t.numel();
50 
51   std::vector<char> data_copy{
52       static_cast<char*>(t.data_ptr()),
53       static_cast<char*>(t.data_ptr()) + element_size * num_elements};
54 
55   for (size_t i = 0; i < num_elements; ++i) {
56     char* start_byte = data_copy.data() + i * element_size;
57     char* end_byte = start_byte + element_size - 1;
58     /* keep swapping */
59     for (size_t count = 0; count < element_size / 2; ++count) {
60       std::swap(*start_byte, *end_byte);
61       ++start_byte;
62       --end_byte;
63     }
64   }
65 
66   return std::string(data_copy.data(), element_size * num_elements);
67 #else
68 #error Unexpected or undefined __BYTE_ORDER__
69 #endif
70 }
71 
writeArchiveAndTensors(const std::string & archive_name,const char * data,size_t size,const std::vector<at::Tensor> & tensors,caffe2::serialize::PyTorchStreamWriter & out)72 void writeArchiveAndTensors(
73     const std::string& archive_name,
74     const char* data,
75     size_t size,
76     const std::vector<at::Tensor>& tensors,
77     caffe2::serialize::PyTorchStreamWriter& out) {
78   std::string prefix = archive_name + "/";
79   size_t i = 0;
80   for (const auto& td : tensors) {
81     WriteableTensorData writable_td = getWriteableTensorData(td);
82     std::string fname = prefix + std::to_string(i++);
83     out.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
84   }
85   std::string fname = archive_name + ".pkl";
86   out.writeRecord(fname, data, size);
87 }
88 
89 namespace {
90 namespace onnx_torch = ::torch::onnx;
91 namespace onnx = ::ONNX_NAMESPACE;
92 
// Sentinel for opset versions this exporter does not support.
const static int kInvalidOpsetVersion = -1;
// Highest ONNX operator-set version this exporter can target; also the last
// valid index into kOpsetVersionToIRVersion below.
const static int kMainOpsetVersion = 20;
// Maps an ONNX opset version (the array index) to the ONNX IR version the
// emitted ModelProto must declare. Entries equal to kInvalidOpsetVersion
// (index 0 and opsets 2-4) are rejected by GraphEncoder's TORCH_CHECK.
// Based on OP_SET_ID_VERSION_MAP in
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
constexpr static std::array<int64_t, kMainOpsetVersion + 1>
    kOpsetVersionToIRVersion = {
        kInvalidOpsetVersion,
        3, // opset 1
        kInvalidOpsetVersion,
        kInvalidOpsetVersion,
        kInvalidOpsetVersion,
        3, // opset 5
        3, // opset 6
        3, // opset 7
        3, // opset 8
        4, // opset 9
        5, // opset 10
        6, // opset 11
        7, // opset 12
        7, // opset 13
        7, // opset 14
        8, // opset 15
        8, // opset 16
        8, // opset 17
        8, // opset 18
        9, // opset 19
        9, // opset 20
};
121 
getNodeStackTraceString(const Node * n)122 std::string getNodeStackTraceString(const Node* n) {
123   return n->sourceRange().str();
124 }
125 
validateBlock(Block * b,onnx_torch::OperatorExportTypes operator_export_type)126 void validateBlock(
127     Block* b,
128     onnx_torch::OperatorExportTypes operator_export_type) {
129   for (auto node : b->nodes()) {
130     for (Block* sub_block : node->blocks()) {
131       validateBlock(sub_block, operator_export_type);
132     }
133     // Macro'ed so we get a marginally better line number on failed export
134 #define FAIL_EXPORT(name)                          \
135   throw std::runtime_error(                        \
136       std::string("ONNX export failed: ") + name + \
137       "\n\nGraph we tried to export:\n" + b->owningGraph()->toString());
138     // Special error messages for certain types of operators
139     if (node->kind() == prim::PythonOp) {
140       if (operator_export_type !=
141           onnx_torch::OperatorExportTypes::ONNX_FALLTHROUGH) {
142         auto py_node = static_cast<PythonOp*>(node);
143         FAIL_EXPORT(
144             "Couldn't export Python operator " + py_node->name() +
145             "\n\nDefined at:\n" + getNodeStackTraceString(node))
146       }
147     } else {
148       if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) {
149         if (operator_export_type !=
150             onnx_torch::OperatorExportTypes::ONNX_FALLTHROUGH) {
151           FAIL_EXPORT(
152               "Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" +
153               getNodeStackTraceString(node));
154         }
155       }
156       bool is_aten_enabled = operator_export_type ==
157               onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
158           operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN ||
159           operator_export_type ==
160               onnx_torch::OperatorExportTypes::ONNX_FALLTHROUGH;
161       if (node->kind().is_aten() && !is_aten_enabled && !node->mustBeNone()) {
162         FAIL_EXPORT(
163             "Couldn't export operator " + node->kind().toDisplayString() +
164             "\n\nDefined at:\n" + getNodeStackTraceString(node));
165       }
166     }
167 #undef FAIL_EXPORT
168   }
169 }
170 
// Entry point for export validation: recursively checks every node in
// `graph` via validateBlock and throws std::runtime_error on the first
// construct that cannot be exported under `operator_export_type`.
void validateGraph(
    const std::shared_ptr<Graph>& graph,
    onnx_torch::OperatorExportTypes operator_export_type) {
  validateBlock(graph->block(), operator_export_type);
}
176 
// Returns the directory component of `rootPath` (separators normalized to
// '/', trailing slashes stripped), or "." when no directory is present.
std::string GetFileRootPath(const std::string& rootPath) {
  // Normalize separators so Windows-style paths are handled uniformly.
  std::string normalized = rootPath;
  std::replace(normalized.begin(), normalized.end(), '\\', '/');
  // Strip trailing slashes before looking for the parent directory.
  const std::regex trailing_slashes("/+$");
  const std::string stripped =
      std::regex_replace(normalized, trailing_slashes, std::string());
  const std::string parent = stripped.substr(0, stripped.find_last_of('/'));
  // If nothing was removed, there is no directory component: use the cwd.
  if (parent == normalized) {
    return std::string(".");
  }
  return parent;
}
190 
// Derives a filesystem-safe file name from a tensor's external-data
// reference by replacing characters that are illegal in file names with '_'.
// Precondition: `external_ref` holds a value (value() would throw otherwise).
std::string GetExternalFileName(
    const std::optional<std::string>& external_ref) {
  std::string sanitized = external_ref.value();
  const std::string forbidden = "\\/:?\"<>|";
  for (size_t idx = 0; idx < sanitized.size(); ++idx) {
    if (forbidden.find(sanitized[idx]) != std::string::npos) {
      sanitized[idx] = '_';
    }
  }
  return sanitized;
}
202 
// Deleter used with std::unique_ptr<FILE, ...> in CreateExternalFile so the
// stream is flushed and closed on every exit path (RAII for C stdio).
void CloseFile(FILE* fp) {
  fclose(fp);
}
206 
CreateExternalFile(const at::Tensor & tensor,const std::string & tensorName,const std::string & onnx_file_path)207 void CreateExternalFile(
208     const at::Tensor& tensor,
209     const std::string& tensorName,
210     const std::string& onnx_file_path) {
211   auto folder = GetFileRootPath(onnx_file_path);
212   std::string fullFilePath = folder + "/" + tensorName;
213   std::unique_ptr<FILE, decltype(&CloseFile)> fp(
214       fopen(fullFilePath.c_str(), "wb"), &CloseFile);
215   if (fp == nullptr) {
216     throw std::runtime_error(
217         std::string("ONNX export failed. Could not open file or directory: ") +
218         fullFilePath);
219   }
220   std::string s = get_little_endian_data(tensor);
221   fwrite(s.c_str(), tensor.element_size(), tensor.numel(), fp.get());
222 } // fclose() called here through CloseFile(), if FILE* is not a null pointer.
223 
// Encodes a TorchScript Graph into an ONNX ModelProto. The entire export is
// performed by the constructor; results are retrieved afterwards through the
// get_* accessors.
class GraphEncoder {
 public:
  GraphEncoder(
      const std::shared_ptr<Graph>& graph,
      int64_t onnx_opset_version,
      onnx_torch::OperatorExportTypes operator_export_type,
      const std::map<std::string, at::Tensor>& initializers,
      const std::unordered_map<
          std::string,
          std::unordered_map<int64_t, std::string>>& dynamic_axes,
      bool defer_weight_export,
      bool strip_doc,
      bool keep_initializers_as_inputs,
      const std::map<std::string, int>& custom_opsets,
      bool add_node_names,
      bool use_external_data_format,
      const std::string& onnx_file_path,
      NodeAttrNameMap node_attr_to_name = {});

  // The encoded model; populated by the constructor.
  std::shared_ptr<onnx::ModelProto> get_model_proto() {
    return model_proto_;
  }

  // Mapping from symbolic shape symbols to the dim_param names that were
  // emitted for them during encoding.
  SymbolDimMap get_symbol_dim_param_map() {
    return symbol_dim_map_;
  }

  // Tensor data withheld from the proto when defer_weight_export is set.
  RawDataExportMap get_raw_data_export_map() {
    return raw_data_export_map_;
  }

  // True when parameters were written as external data files. Note the
  // constructor may turn this on even if the caller passed false, when the
  // proto would exceed protobuf's 2GB limit.
  bool get_use_external_data_format() {
    return use_external_data_format_;
  }

  // Names assigned to ONNX nodes during encoding.
  NodeNameMap get_onnx_node_names() {
    return onnx_node_name_map_;
  }

 private:
  // Using std::map instead of std::unordered_map for initializers
  // in EncodeGraph constructor so that the order in which initializers
  // get written to the ONNX graph is always the deterministic and
  // predictable. While this is not a ONNX requirement, it is needed
  // for testing purposes in tests that use _export_to_pretty_string()
  // for validating ONNX graphs.
  void EncodeGraph(
      onnx::GraphProto* graph_proto,
      const std::shared_ptr<Graph>& graph,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>(),
      const std::
          unordered_map<std::string, std::unordered_map<int64_t, std::string>>&
              dynamic_axes = std::unordered_map<
                  std::string,
                  std::unordered_map<int64_t, std::string>>(),
      bool keep_initializers_as_inputs = true,
      bool add_node_names = true,
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  // Encodes one block (the top-level graph block or a sub-block of
  // If/Loop-style nodes) into `graph_proto`.
  void EncodeBlock(
      onnx::GraphProto* graph_proto,
      const Block* block,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>(),
      const std::
          unordered_map<std::string, std::unordered_map<int64_t, std::string>>&
              dynamic_axes = std::unordered_map<
                  std::string,
                  std::unordered_map<int64_t, std::string>>(),
      bool keep_initializers_as_inputs = true,
      bool add_node_names = true,
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  void AddInitializersIntoGraphProto(
      onnx::GraphProto* graph_proto,
      const Block* block,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>(),
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  // Estimates the serialized size of the graph proto; used by the
  // constructor to decide whether external data files are required.
  unsigned long long int GetGraphProtoSize(
      onnx::GraphProto* graph_proto,
      const std::shared_ptr<Graph>& graph,
      bool add_node_names,
      bool use_external_data_format,
      const std::string& onnx_file_path,
      const std::map<std::string, at::Tensor>& initializers =
          std::map<std::string, at::Tensor>());

  void EncodeNode(
      onnx::GraphProto* graph_proto,
      onnx::NodeProto* node_proto,
      const Node* node,
      bool add_node_names = true,
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  void EncodeTypeProto(
      onnx::TypeProto* type_proto,
      const TypePtr& node_type,
      const std::string& name);

  void EncodeLocalFunctionOpsetImport(
      onnx::FunctionProto* func_proto,
      const Node* n,
      std::unordered_set<std::string>& custom_domains);

  void EncodeLocalFunction(
      onnx::GraphProto* graph_proto,
      onnx::FunctionProto* func_proto,
      const Node* n,
      bool add_node_names = true,
      bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  void EncodeTensor(
      onnx::TensorProto* tensor_proto,
      const at::Tensor& tensor,
      const std::optional<std::string>& external_ref = {},
      const bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  void EncodeIntermediateValueInfo(
      onnx::GraphProto* graph_proto,
      const Value* n);

  void EncodeValueInfo(
      onnx::GraphProto* graph_proto,
      onnx::ValueInfoProto* v,
      const Value* n,
      const std::
          unordered_map<std::string, std::unordered_map<int64_t, std::string>>&
              dynamic_axes = std::unordered_map<
                  std::string,
                  std::unordered_map<int64_t, std::string>>());

  void EncodeValueInfoType(
      onnx::TypeProto* onnx_type,
      const TypePtr& node_type,
      const Value* n,
      const std::unordered_map<
          std::string,
          std::unordered_map<int64_t, std::string>>& dynamic_axes);

  void AddAttribute(
      onnx::NodeProto* node_proto,
      const jit::Symbol name,
      const std::string& ref_attr_name,
      const AttributeKind attr_kind);

  void AddAttribute(
      onnx::NodeProto* node_proto,
      const jit::Node* node,
      const jit::Symbol name,
      const bool use_external_data_format = false,
      const std::string& onnx_file_path = std::string());

  void AddAttribute(onnx::FunctionProto* func_proto, const std::string& name);

  void TensorTypeToONNXType(
      const TensorTypePtr& tensor_type,
      const std::string& dim_name_prefix,
      const std::string& name,
      const std::unordered_map<
          std::string,
          std::unordered_map<int64_t, std::string>>& dynamic_axes,
      onnx::TypeProto_Tensor* onnx_tensor_type,
      bool assign_dim_param = true);

  // Symbolic shape symbol -> emitted dim_param name.
  SymbolDimMap symbol_dim_map_;
  // The model being built; shared so callers can keep it after encoding.
  std::shared_ptr<onnx::ModelProto> model_proto_;
  // Counters used to generate unique sub-graph / node / external-file names.
  size_t num_blocks_{0};
  size_t num_op_nodes_{0};
  size_t num_external_data_{0};
  onnx_torch::OperatorExportTypes operator_export_type_;
  bool strip_doc_;
  // Custom (non-default) opset domains encountered while encoding nodes.
  std::set<std::string> domains_;
  RawDataExportMap raw_data_export_map_;
  bool defer_weight_export_;
  bool use_external_data_format_;
  int64_t onnx_opset_version_;
  std::map<std::string, int> custom_opsets_;
  std::shared_ptr<Graph> graph_;
  NodeAttrNameMap node_attr_to_name_;
  NodeNameMap onnx_node_name_map_;
  // For large models, the parameters can be stored in separate binary files.
  // This parameter sets a threshold on the number of elements in the parameter
  // tensor, beyond which the parameter is stored in a separate file (if
  // use_external_data_format_ is True). This threshold is in place
  // so as not to create too many external files.
  static constexpr size_t ParamSizeThresholdForExternalStorage = 1024;
};
420 
ATenTypeToOnnxType(at::ScalarType at_type)421 onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
422   switch (at_type) {
423     case at::kDouble:
424       return onnx::TensorProto_DataType_DOUBLE;
425     case at::kFloat:
426       return onnx::TensorProto_DataType_FLOAT;
427     case at::kHalf:
428       return onnx::TensorProto_DataType_FLOAT16;
429     case at::kByte:
430       return onnx::TensorProto_DataType_UINT8;
431     case at::kChar:
432       return onnx::TensorProto_DataType_INT8;
433     case at::kShort:
434       return onnx::TensorProto_DataType_INT16;
435     case at::kInt:
436       return onnx::TensorProto_DataType_INT32;
437     case at::kLong:
438       return onnx::TensorProto_DataType_INT64;
439     case at::kBool:
440       return onnx::TensorProto_DataType_BOOL;
441     case at::kQInt8:
442       return onnx::TensorProto_DataType_INT8;
443     case at::kQUInt8:
444       return onnx::TensorProto_DataType_UINT8;
445     case at::kQInt32:
446       return onnx::TensorProto_DataType_INT32;
447     case at::kBFloat16:
448       return onnx::TensorProto_DataType_BFLOAT16;
449     case at::kFloat8_e4m3fn:
450       return onnx_torch::TensorProto_DataType_FLOAT8E4M3FN;
451     case at::kFloat8_e5m2:
452       return onnx_torch::TensorProto_DataType_FLOAT8E5M2;
453     case at::kFloat8_e4m3fnuz:
454       return onnx_torch::TensorProto_DataType_FLOAT8E4M3FNUZ;
455     case at::kFloat8_e5m2fnuz:
456       return onnx_torch::TensorProto_DataType_FLOAT8E5M2FNUZ;
457     default:
458       TORCH_CHECK(
459           false,
460           "ScalarType ",
461           toString(at_type),
462           " is an unexpected tensor scalar type");
463   }
464 }
465 
ATenAttributeKindToOnnxAttributeType(AttributeKind at_kind,const jit::Symbol name)466 onnx::AttributeProto_AttributeType ATenAttributeKindToOnnxAttributeType(
467     AttributeKind at_kind,
468     const jit::Symbol name) {
469   switch (at_kind) {
470     case AttributeKind::f:
471       return onnx::AttributeProto_AttributeType_FLOAT;
472     case AttributeKind::fs:
473       return onnx::AttributeProto_AttributeType_FLOATS;
474     case AttributeKind::i:
475       return onnx::AttributeProto_AttributeType_INT;
476     case AttributeKind::is:
477       return onnx::AttributeProto_AttributeType_INTS;
478     case AttributeKind::s:
479       return onnx::AttributeProto_AttributeType_STRING;
480     case AttributeKind::ss:
481       return onnx::AttributeProto_AttributeType_STRINGS;
482     case AttributeKind::t:
483       return onnx::AttributeProto_AttributeType_TENSOR;
484     case AttributeKind::ts:
485       return onnx::AttributeProto_AttributeType_TENSORS;
486     case AttributeKind::ty:
487       return onnx::AttributeProto_AttributeType_TYPE_PROTO;
488     case AttributeKind::tys:
489       return onnx::AttributeProto_AttributeType_TYPE_PROTOS;
490     case AttributeKind::g:
491       return onnx::AttributeProto_AttributeType_GRAPH;
492     case AttributeKind::gs:
493       return onnx::AttributeProto_AttributeType_GRAPHS;
494     default:
495       std::ostringstream err_msg;
496       err_msg << "attribute \"" << name.toDisplayString()
497               << "\" has unexpected kind: " << toString(at_kind);
498       throw std::runtime_error(err_msg.str());
499   }
500 }
501 
GraphEncoder(const std::shared_ptr<Graph> & graph,int64_t onnx_opset_version,onnx_torch::OperatorExportTypes operator_export_type,const std::map<std::string,at::Tensor> & initializers,const std::unordered_map<std::string,std::unordered_map<int64_t,std::string>> & dynamic_axes,bool defer_weight_export,bool strip_doc,bool keep_initializers_as_inputs,const std::map<std::string,int> & custom_opsets,bool add_node_names,bool use_external_data_format,const std::string & onnx_file_path,NodeAttrNameMap node_attr_to_name)502 GraphEncoder::GraphEncoder(
503     const std::shared_ptr<Graph>& graph,
504     int64_t onnx_opset_version,
505     onnx_torch::OperatorExportTypes operator_export_type,
506     const std::map<std::string, at::Tensor>& initializers,
507     const std::unordered_map<
508         std::string,
509         std::unordered_map<int64_t, std::string>>& dynamic_axes,
510     bool defer_weight_export,
511     bool strip_doc,
512     bool keep_initializers_as_inputs,
513     const std::map<std::string, int>& custom_opsets,
514     bool add_node_names,
515     bool use_external_data_format,
516     const std::string& onnx_file_path,
517     NodeAttrNameMap node_attr_to_name)
518     : model_proto_(std::make_shared<onnx::ModelProto>()),
519 
520       operator_export_type_(operator_export_type),
521       strip_doc_(strip_doc),
522       defer_weight_export_(defer_weight_export),
523       use_external_data_format_(use_external_data_format),
524       onnx_opset_version_(onnx_opset_version),
525       custom_opsets_(custom_opsets),
526       graph_(graph),
527       node_attr_to_name_(std::move(node_attr_to_name)) {
528   model_proto_->set_producer_name("pytorch");
529   TORCH_CHECK(
530       onnx_opset_version > 0 &&
531           static_cast<size_t>(onnx_opset_version) <
532               kOpsetVersionToIRVersion.size() &&
533           kOpsetVersionToIRVersion[onnx_opset_version] != kInvalidOpsetVersion,
534       "Unsupported onnx_opset_version: ",
535       onnx_opset_version);
536 
537   model_proto_->set_ir_version(kOpsetVersionToIRVersion[onnx_opset_version]);
538   model_proto_->set_producer_version(TORCH_VERSION);
539   validateGraph(graph, operator_export_type);
540 
541   // If graph proto size exceed maximum protobuf size of 2GB, set
542   // use_external_data_format to true.
543   if (!use_external_data_format &&
544       GetGraphProtoSize(
545           model_proto_->mutable_graph(),
546           graph,
547           add_node_names,
548           use_external_data_format,
549           onnx_file_path,
550           initializers) > INT_MAX) {
551     GRAPH_DEBUG(
552         "Exporting model exceed maximum protobuf size of 2GB. Storing model parameters in external data files");
553     use_external_data_format = true;
554     // use_external_data_format_ is one of graph_encoder private variable set
555     // for return `use_external_data_format` value.
556     use_external_data_format_ = use_external_data_format;
557   }
558 
559   if (use_external_data_format) {
560     TORCH_CHECK(
561         !onnx_file_path.empty(),
562         "The serialized model is larger than the 2GiB limit imposed by the protobuf library. ",
563         "Therefore the output file must be a file path, so that the ONNX external data can ",
564         "be written to the same directory. Please specify the output file name.");
565   }
566 
567   auto* imp = model_proto_->add_opset_import();
568   // This is the version of ONNX operator set we are targeting
569   imp->set_version(onnx_opset_version);
570 
571   EncodeGraph(
572       model_proto_->mutable_graph(),
573       graph,
574       initializers,
575       dynamic_axes,
576       keep_initializers_as_inputs,
577       add_node_names,
578       use_external_data_format,
579       onnx_file_path);
580 
581   for (const std::string& domain : domains_) {
582     auto* opset = model_proto_->add_opset_import();
583     opset->set_domain(domain);
584     //  Check if domain version is registered. If not, set to version 1
585     auto it = custom_opsets.find(domain);
586     if (it == custom_opsets.end())
587       opset->set_version(1);
588     else {
589       opset->set_version(it->second);
590     }
591   }
592 
593   for (auto const& custom_opset : custom_opsets) {
594     if (!std::count(domains_.begin(), domains_.end(), custom_opset.first)) {
595       TORCH_WARN(
596           "Custom opset domain: '",
597           custom_opset.first,
598           "' provided is not used in the model. ",
599           "Please verify custom opset domain names.");
600     }
601   }
602 }
603 
604 // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
void GraphEncoder::TensorTypeToONNXType(
    const TensorTypePtr& tensor_type,
    const std::string& dim_name_prefix,
    const std::string& name,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    onnx::TypeProto_Tensor* onnx_tensor_type,
    bool assign_dim_param) {
  // Encode the (possibly symbolic) shape, one dim at a time. For each axis:
  //   1. a user-supplied dynamic_axes name wins and becomes the dim_param;
  //   2. otherwise a statically-known size becomes a dim_value;
  //   3. otherwise (if assign_dim_param) a fresh name is generated and cached
  //      in symbol_dim_map_ so the same symbol maps to the same name
  //      everywhere. If none of the above applies the dim is left unset
  //      (fully unknown).
  if (tensor_type->dim()) {
    onnx::TensorShapeProto* shape = onnx_tensor_type->mutable_shape();
    auto sizes = tensor_type->symbolic_sizes().sizes().value();
    for (const auto i : c10::irange(sizes.size())) {
      shape->add_dim();
      if ((dynamic_axes.find(name) != dynamic_axes.end()) &&
          (dynamic_axes.at(name).find(i) != dynamic_axes.at(name).end())) {
        shape->mutable_dim(i)->set_dim_param(dynamic_axes.at(name).at(i));
        // Remember the user-chosen name for this symbol so later uses of the
        // same symbol reuse it.
        if (!sizes[i].is_static()) {
          symbol_dim_map_[sizes[i]] = dynamic_axes.at(name).at(i);
        }
      } else if (sizes[i].is_static()) {
        shape->mutable_dim(i)->set_dim_value(sizes[i].static_size());
      } else if (assign_dim_param) {
        if (symbol_dim_map_.find(sizes[i]) == symbol_dim_map_.end()) {
          symbol_dim_map_[sizes[i]] =
              dim_name_prefix + name + "_dim_" + std::to_string(i);
        }
        shape->mutable_dim(i)->set_dim_param(symbol_dim_map_[sizes[i]]);
      }
    }
  }
  // Dtype is encoded independently of shape; either may be absent.
  if (tensor_type->scalarType()) {
    onnx_tensor_type->set_elem_type(
        ATenTypeToOnnxType(tensor_type->scalarType().value()));
  }
}
641 // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
642 
// Encodes the TorchScript type of value `n` into `onnx_type`:
//   - Tensor      -> tensor type (shape/dtype, if known);
//   - Bool/Int/Float scalars -> tensor type of BOOL/INT64/FLOAT;
//   - List[...]   -> sequence type, recursing on the element type;
//   - Optional[Tensor] / Optional[List[Tensor]] -> optional type.
// Any other type is silently left un-encoded.
void GraphEncoder::EncodeValueInfoType(
    onnx::TypeProto* onnx_type,
    const TypePtr& node_type,
    const Value* n,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes) {
  // Prefix generated dim_param names with the producing op's name so they
  // stay distinct per op; graph inputs (prim::Param) get no prefix.
  std::string dim_name_prefix;
  if (n->node()->kind() != prim::Param) {
    dim_name_prefix = n->node()->kind().toUnqualString();
  }
  if (TensorTypePtr tensor_type = node_type->cast<TensorType>()) {
    if (tensor_type->dim() || tensor_type->scalarType()) {
      // Encode type if either shape or dtype exists.
      onnx::TypeProto_Tensor* onnx_tensor_type =
          onnx_type->mutable_tensor_type();
      // Do not assign dim_param for sequence tensor type.
      // Sequence of tensors could differ in dimension size.
      // Use a dimension with neither dim_value nor dim_param set
      // to denote an unknown dimension.
      // Create and assign dim_param for normal tensor type.
      auto is_sequence_tensor = static_cast<bool>(n->type()->cast<ListType>());
      TensorTypeToONNXType(
          tensor_type,
          dim_name_prefix,
          n->debugName(),
          dynamic_axes,
          onnx_tensor_type,
          !is_sequence_tensor);
    }
  } else if (BoolTypePtr bool_type = node_type->cast<BoolType>()) {
    onnx::TypeProto_Tensor* onnx_tensor_type = onnx_type->mutable_tensor_type();
    onnx_tensor_type->set_elem_type(ATenTypeToOnnxType(at::kBool));
  } else if (IntTypePtr int_type = node_type->cast<IntType>()) {
    onnx::TypeProto_Tensor* onnx_tensor_type = onnx_type->mutable_tensor_type();
    onnx_tensor_type->set_elem_type(ATenTypeToOnnxType(at::kLong));
  } else if (FloatTypePtr float_type = node_type->cast<FloatType>()) {
    onnx::TypeProto_Tensor* onnx_tensor_type = onnx_type->mutable_tensor_type();
    onnx_tensor_type->set_elem_type(ATenTypeToOnnxType(at::kFloat));
  } else if (ListTypePtr list_type = node_type->cast<ListType>()) {
    auto list_elem_type = list_type->getElementType();
    onnx::TypeProto_Sequence* sequence_type =
        onnx_type->mutable_sequence_type();
    onnx::TypeProto* onnx_tensor_type = sequence_type->mutable_elem_type();
    // Recurse: the sequence element may itself be a tensor, scalar, etc.
    EncodeValueInfoType(onnx_tensor_type, list_elem_type, n, dynamic_axes);
  } else if (OptionalTypePtr optional_type = node_type->cast<OptionalType>()) {
    auto elem_type = optional_type->getElementType();
    if (TensorTypePtr tensor_type = elem_type->cast<TensorType>()) {
      onnx::TypeProto_Optional* onnx_optional_type =
          onnx_type->mutable_optional_type();
      onnx::TypeProto_Tensor* onnx_tensor_type =
          onnx_optional_type->mutable_elem_type()->mutable_tensor_type();
      TensorTypeToONNXType(
          tensor_type,
          dim_name_prefix,
          n->debugName(),
          dynamic_axes,
          onnx_tensor_type);
    } else if (ListTypePtr inner_node_type = elem_type->cast<ListType>()) {
      auto list_elem_type = inner_node_type->getElementType();
      // Only Optional[List[Tensor]] is representable; other optional lists
      // fall through un-encoded.
      if (TensorTypePtr tensor_type = list_elem_type->cast<TensorType>()) {
        onnx::TypeProto_Optional* onnx_optional_type =
            onnx_type->mutable_optional_type();
        onnx::TypeProto_Sequence* onnx_optional_sequence_type =
            onnx_optional_type->mutable_elem_type()->mutable_sequence_type();
        onnx::TypeProto_Tensor* onnx_tensor_type =
            onnx_optional_sequence_type->mutable_elem_type()
                ->mutable_tensor_type();
        TensorTypeToONNXType(
            tensor_type,
            dim_name_prefix,
            n->debugName(),
            dynamic_axes,
            onnx_tensor_type);
      }
    }
  }
}
721 
EncodeValueInfo(onnx::GraphProto * graph_proto,onnx::ValueInfoProto * v,const Value * n,const std::unordered_map<std::string,std::unordered_map<int64_t,std::string>> & dynamic_axes)722 void GraphEncoder::EncodeValueInfo(
723     onnx::GraphProto* graph_proto,
724     onnx::ValueInfoProto* v,
725     const Value* n,
726     const std::unordered_map<
727         std::string,
728         std::unordered_map<int64_t, std::string>>& dynamic_axes) {
729   std::string name = n->debugName();
730   v->set_name(name);
731   EncodeValueInfoType(v->mutable_type(), n->type(), n, dynamic_axes);
732 }
733 
// Encodes `graph` into `graph_proto` by delegating to EncodeBlock on the
// graph's top-level block; all arguments are forwarded unchanged.
void GraphEncoder::EncodeGraph(
    onnx::GraphProto* graph_proto,
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    bool keep_initializers_as_inputs,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path) {
  EncodeBlock(
      graph_proto,
      graph->block(),
      initializers,
      dynamic_axes,
      keep_initializers_as_inputs,
      add_node_names,
      use_external_data_format,
      onnx_file_path);
}
755 
// Encodes a JIT block into an ONNX GraphProto: graph name, inputs (subject to
// keep_initializers_as_inputs), outputs, all nodes (with LocalFunctionDef
// nodes routed into model-level function protos), and finally the
// initializers. Called for the main graph and recursively for sub-graphs of
// control-flow nodes.
void GraphEncoder::EncodeBlock(
    onnx::GraphProto* graph_proto,
    const Block* block,
    const std::map<std::string, at::Tensor>& initializers,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    bool keep_initializers_as_inputs,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path) {
  TORCH_INTERNAL_ASSERT(graph_proto != nullptr);
  if (nullptr == block->owningNode()) {
    // Top level main graph.
    graph_proto->set_name("main_graph");
  } else {
    // TODO: Set more meaningful name for sub-graphs.
    // Sub-graphs are numbered via num_blocks_ so names stay unique
    // ("sub_graph", "sub_graph1", "sub_graph2", ...).
    std::string block_name = "sub_graph";
    if (num_blocks_) {
      block_name += std::to_string(num_blocks_);
    }
    num_blocks_++;
    graph_proto->set_name(block_name);
  }

  // Since ONNX IR VERSION 4, initializers do not have to
  // be a subset of graph inputs. We use keep_initializers_as_inputs
  // argument to determine whether to add initializers
  // as inputs or not. If keep_initializers_as_inputs=false,
  // we only add non-parameter inputs as inputs to ONNX graph, and
  // not the initializers (parameters). If keep_initializers_as_inputs
  // =true, we add initializers as inputs too. Setting
  // keep_initializers_as_inputs=false allows better
  // optimizations, such as constant-folding, on ONNX graphs
  // by backends/optimizers.
  if (keep_initializers_as_inputs) {
    for (auto input : block->inputs()) {
      onnx::ValueInfoProto* v = graph_proto->add_input();
      EncodeValueInfo(graph_proto, v, input, dynamic_axes);
    }
  } else {
    for (auto input : block->inputs()) {
      // Skip inputs that are parameters (present in `initializers`).
      auto it = initializers.find(input->debugName());
      if (it == initializers.end()) {
        onnx::ValueInfoProto* v = graph_proto->add_input();
        EncodeValueInfo(graph_proto, v, input, dynamic_axes);
      }
    }
  }
  for (auto output : block->outputs()) {
    onnx::ValueInfoProto* v = graph_proto->add_output();
    EncodeValueInfo(graph_proto, v, output, dynamic_axes);
  }
  for (auto node : block->nodes()) {
    if (node->mustBeNone()) {
      // None nodes are used to implement optional inputs. One
      // way to "not provide" an optional input is to create an
      // Undefined node, and pass its output as that input.
      continue;
    }
    if (node->kind() == ::c10::Symbol::onnx("LocalFunctionDef")) {
      // Local function definitions live on the model proto, not this graph.
      auto* func_proto = model_proto_->add_functions();
      EncodeLocalFunction(
          graph_proto,
          func_proto,
          node,
          add_node_names,
          use_external_data_format,
          onnx_file_path);
      continue;
    }
    auto* n_proto = graph_proto->add_node();
    EncodeNode(
        graph_proto,
        n_proto,
        node,
        add_node_names,
        use_external_data_format,
        onnx_file_path);
  }
  AddInitializersIntoGraphProto(
      graph_proto,
      block,
      initializers,
      use_external_data_format,
      onnx_file_path);
}
843 
AddInitializersIntoGraphProto(onnx::GraphProto * graph_proto,const Block * block,const std::map<std::string,at::Tensor> & initializers,bool use_external_data_format,const std::string & onnx_file_path)844 void GraphEncoder::AddInitializersIntoGraphProto(
845     onnx::GraphProto* graph_proto,
846     const Block* block,
847     const std::map<std::string, at::Tensor>& initializers,
848     bool use_external_data_format,
849     const std::string& onnx_file_path) {
850   TORCH_INTERNAL_ASSERT(block->inputs().size() >= initializers.size());
851   for (auto input : block->inputs()) {
852     auto name_tensor_pair = initializers.find(input->debugName());
853     if (name_tensor_pair == initializers.end()) {
854       continue;
855     }
856     auto p = graph_proto->add_initializer();
857     p->set_name(name_tensor_pair->first);
858     EncodeTensor(
859         p,
860         name_tensor_pair->second,
861         name_tensor_pair->first,
862         use_external_data_format,
863         onnx_file_path);
864   }
865 }
866 
// Estimates the serialized size of the model in bytes without copying any
// tensor payloads: it clones the proto skeleton, adds initializer/constant
// metadata (dims, dtype) only, and accounts for raw tensor bytes
// arithmetically. Used to decide whether the model would exceed protobuf
// size limits (e.g. to force external data format) — an estimate, not an
// exact serialized size, since raw_data fields are never populated here.
unsigned long long int GraphEncoder::GetGraphProtoSize(
    onnx::GraphProto* graph_proto,
    const std::shared_ptr<Graph>& graph,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path,
    const std::map<std::string, at::Tensor>& initializers) {
  // Model size = sum(size(initializers)) + sum(size(onnx_constant_nodes))

  // Add up all Initializers
  // Work on a copy so the caller's proto is left untouched.
  onnx::GraphProto graph_proto_copy = onnx::GraphProto(*graph_proto);
  unsigned long long int size = graph_proto_copy.ByteSizeLong();
  for (auto input : graph->inputs()) {
    auto name_tensor_pair = initializers.find(input->debugName());
    if (name_tensor_pair == initializers.end()) {
      continue;
    }
    auto tensor_proto = graph_proto_copy.add_initializer();
    const at::Tensor& tensor = name_tensor_pair->second;
    for (auto d : tensor.sizes()) {
      tensor_proto->add_dims(d);
    }
    tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.scalar_type()));

    // Don't actually copy the buffer into tensor_proto since that is expensive.
    // All we actually need is its size.
    size += tensor_proto->ByteSizeLong();
    size += tensor.element_size() * tensor.numel();
  }

  // Add up all onnx::Constant nodes that are Tensors
  for (const auto& node : graph->nodes()) {
    if (node->kind() == ::c10::onnx::Constant &&
        node->hasAttribute(attr::value) &&
        node->kindOf(attr::value) == AttributeKind::t) {
      at::Tensor tensor = node->t(attr::value);

      // Don't actually copy the buffer into n_proto since that is expensive.
      // All we actually need is its size.
      auto* n_proto = graph_proto_copy.add_node();
      EncodeNode(
          &graph_proto_copy,
          n_proto,
          node,
          add_node_names,
          use_external_data_format,
          onnx_file_path);

      // Calculate the size of the tensor in bytes
      size += n_proto->ByteSizeLong();
      size += tensor.element_size() * tensor.numel();
    }
  }
  return size;
}
922 
// Encodes a single JIT node into `node_proto`: doc string, inputs/outputs,
// domain (for non-ONNX ops), op type, optional node name, attributes, and —
// for onnx::Loop / onnx::If — the sub-blocks as GRAPH attributes.
void GraphEncoder::EncodeNode(
    onnx::GraphProto* graph_proto,
    onnx::NodeProto* node_proto,
    const Node* node,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path) {
  if (!strip_doc_) {
    node_proto->set_doc_string(node->sourceRange().str());
  }
  for (auto input : node->inputs()) {
    if (input->node()->mustBeNone()) {
      // An omitted optional input is spelled as the empty string in ONNX.
      node_proto->add_input("");
    } else {
      node_proto->add_input(input->debugName());
    }
  }
  for (auto output : node->outputs()) {
    node_proto->add_output(output->debugName());
    EncodeIntermediateValueInfo(graph_proto, output);
  }
  if (!node->kind().is_onnx()) {
    // Non-ONNX ops carry an explicit domain; aten/caffe2 use their canonical
    // domain string, everything else uses the namespace name.
    std::string domain;
    if (node->kind().is_aten() || node->kind().is_caffe2()) {
      domain = node->kind().domainString();
    } else { //  Custom namespace and domain
      domain = node->kind().ns().toUnqualString();
    }
    // TODO: set correct domain for function proto.
    domains_.insert(domain);
    node_proto->set_domain(domain);
  }
  if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
    // Strict ONNX export: no aten/prim/attr ops may remain at this point.
    TORCH_INTERNAL_ASSERT(
        !node->kind().is_aten() && !node->kind().is_prim() &&
        !node->kind().is_attr());
  }
  node_proto->set_op_type(node->kind().toUnqualString());
  const auto node_name_attribute_symbol =
      Symbol::attr(::torch::onnx::kOnnxNodeNameAttribute);
  if (add_node_names) {
    // Default name is "<op_type>_<counter>", overridden by an explicit
    // node-name attribute when present.
    std::string node_name =
        node_proto->op_type() + "_" + std::to_string(num_op_nodes_);
    if (node->hasAttribute(node_name_attribute_symbol)) {
      node_name = node->s(node_name_attribute_symbol);
    }
    node_proto->set_name(node_name);
    onnx_node_name_map_[node] = node_name;
    num_op_nodes_++;
  }
  auto attrs_it = node_attr_to_name_.find(node);
  for (auto attr_name : node->attributeNames()) {
    if (attr_name == node_name_attribute_symbol) {
      // Skip the node name attribute.
      continue;
    }
    if (attrs_it != node_attr_to_name_.end()) {
      // Attribute is mapped to a function-attribute reference name: emit a
      // ref_attr_name attribute instead of an inline value.
      auto attr_it = attrs_it->second.find(attr_name.toUnqualString());
      if (attr_it != attrs_it->second.end()) {
        AddAttribute(
            node_proto, attr_name, attr_it->second, node->kindOf(attr_name));
        continue;
      }
    }
    AddAttribute(
        node_proto, node, attr_name, use_external_data_format, onnx_file_path);
  }
  if (node->kind() == ::c10::onnx::Loop) {
    TORCH_INTERNAL_ASSERT(node->blocks().size() == 1);

    // Loop body is encoded as a GRAPH attribute named "body".
    auto body = node_proto->add_attribute();
    body->set_name("body");
    body->set_type(onnx::AttributeProto_AttributeType_GRAPH);
    auto g = body->mutable_g();
    EncodeBlock(
        g,
        node->blocks()[0],
        {},
        {},
        true,
        true,
        use_external_data_format,
        onnx_file_path);
  }
  if (node->kind() == ::c10::onnx::If) {
    TORCH_INTERNAL_ASSERT(node->blocks().size() == 2);

    // If branches are encoded as GRAPH attributes "then_branch"/"else_branch".
    auto then_branch = node_proto->add_attribute();
    then_branch->set_name("then_branch");
    then_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
    auto true_g = then_branch->mutable_g();
    EncodeBlock(
        true_g,
        node->blocks()[0],
        {},
        {},
        true,
        true,
        use_external_data_format,
        onnx_file_path);

    auto else_branch = node_proto->add_attribute();
    else_branch->set_name("else_branch");
    else_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
    auto false_g = else_branch->mutable_g();
    EncodeBlock(
        false_g,
        node->blocks()[1],
        {},
        {},
        true,
        true,
        use_external_data_format,
        onnx_file_path);
  }
}
1039 
AddAttribute(onnx::NodeProto * node_proto,const jit::Symbol name,const std::string & ref_attr_name,const AttributeKind attr_kind)1040 void GraphEncoder::AddAttribute(
1041     onnx::NodeProto* node_proto,
1042     const jit::Symbol name,
1043     const std::string& ref_attr_name,
1044     const AttributeKind attr_kind) {
1045   auto attr = node_proto->add_attribute();
1046   TORCH_INTERNAL_ASSERT(name.is_attr());
1047   attr->set_name(name.toUnqualString());
1048   attr->set_ref_attr_name(ref_attr_name);
1049   attr->set_type(ATenAttributeKindToOnnxAttributeType(attr_kind, name));
1050 }
1051 
// Encodes the value of attribute `name` on `node` into an AttributeProto,
// dispatching on the attribute's kind (float/int/string/tensor/graph/type
// and their list variants). Tensors may be named so they can be exported to
// external data files; graphs are encoded recursively. Throws for attribute
// kinds with no ONNX mapping.
void GraphEncoder::AddAttribute(
    onnx::NodeProto* node_proto,
    const jit::Node* node,
    const jit::Symbol name,
    const bool use_external_data_format,
    const std::string& onnx_file_path) {
  // Builds a unique name for a tensor attribute so that external data files
  // have distinct file names. Uses the node name when available; otherwise
  // falls back to "<op_type>_<attr>_<counter>", bumping `num_external_data`.
  auto createAttributeTensorName =
      [](const onnx::NodeProto* node_proto,
         onnx::TensorProto* tensor_proto,
         const jit::Symbol attr_name,
         size_t& num_external_data) -> std::string {
    if (tensor_proto->has_name()) {
      return tensor_proto->name();
    }
    if (!node_proto->has_name()) {
      auto name = node_proto->op_type() + "_" + attr_name.toDisplayString() +
          "_" + std::to_string(num_external_data);
      num_external_data++;
      return name;
    } else {
      return node_proto->name() + "_" + attr_name.toDisplayString();
    }
  };

  auto attr = node_proto->add_attribute();
  TORCH_INTERNAL_ASSERT(name.is_attr());
  attr->set_name(name.toUnqualString());
  attr->set_type(
      ATenAttributeKindToOnnxAttributeType(node->kindOf(name), name));
  switch (node->kindOf(name)) {
    case AttributeKind::f:
      // ONNX floats are single precision; JIT stores doubles.
      attr->set_f(static_cast<float>(node->f(name)));
      break;
    case AttributeKind::fs:
      for (auto& v : node->fs(name))
        attr->add_floats(static_cast<float>(v));
      break;
    case AttributeKind::i:
      attr->set_i(node->i(name));
      break;
    case AttributeKind::is:
      for (auto& v : node->is(name))
        attr->add_ints(v);
      break;
    case AttributeKind::s:
      attr->set_s(node->s(name));
      break;
    case AttributeKind::ss:
      for (auto& v : node->ss(name))
        attr->add_strings(v);
      break;
    case AttributeKind::t: {
      auto t = attr->mutable_t();
      // Tensors need a name before EncodeTensor so external data files can
      // be keyed on it.
      if (use_external_data_format && !t->has_name()) {
        t->set_name(
            createAttributeTensorName(node_proto, t, name, num_external_data_));
      }
      EncodeTensor(
          t, node->t(name), {}, use_external_data_format, onnx_file_path);
    } break;
    case AttributeKind::ts:
      for (auto& v : node->ts(name)) {
        auto t = attr->add_tensors();
        if (use_external_data_format && !t->has_name()) {
          t->set_name(createAttributeTensorName(
              node_proto, t, name, num_external_data_));
        }
        EncodeTensor(t, v, {}, use_external_data_format, onnx_file_path);
      }
      break;
    case AttributeKind::ty: {
      attr->set_type(onnx::AttributeProto_AttributeType_TYPE_PROTO);
      auto tp = attr->mutable_tp();
      const TypePtr& node_type = node->ty(name);
      EncodeTypeProto(
          tp, node_type, node_proto->op_type() + "_" + name.toDisplayString());
    } break;
    case AttributeKind::tys: {
      attr->set_type(onnx::AttributeProto_AttributeType_TYPE_PROTOS);
      size_t index = 0;
      for (auto& v : node->tys(name)) {
        auto tp = attr->add_type_protos();
        EncodeTypeProto(
            tp,
            v,
            node_proto->op_type() + "_" + name.toDisplayString() + "_" +
                std::to_string(index));
        index++;
      }
    } break;
    case AttributeKind::g: {
      auto g = attr->mutable_g();
      EncodeGraph(
          g,
          node->g(name),
          {},
          {},
          true,
          true,
          use_external_data_format,
          onnx_file_path);
    } break;
    case AttributeKind::gs:
      for (auto& v : node->gs(name)) {
        auto g = attr->add_graphs();
        EncodeGraph(
            g, v, {}, {}, true, true, use_external_data_format, onnx_file_path);
      }
      break;
    default:
      std::ostringstream err_msg;
      err_msg << "attribute \"" << name.toDisplayString()
              << "\" has unexpected kind: " << toString(node->kindOf(name));
      throw std::runtime_error(err_msg.str());
  }
}
1168 
AddAttribute(onnx::FunctionProto * func_proto,const std::string & name)1169 void GraphEncoder::AddAttribute(
1170     onnx::FunctionProto* func_proto,
1171     const std::string& name) {
1172   TORCH_INTERNAL_ASSERT(nullptr != func_proto);
1173   func_proto->add_attribute(name);
1174 }
1175 
EncodeLocalFunctionOpsetImport(onnx::FunctionProto * func_proto,const Node * n,std::unordered_set<std::string> & custom_domains)1176 void GraphEncoder::EncodeLocalFunctionOpsetImport(
1177     onnx::FunctionProto* func_proto,
1178     const Node* n,
1179     std::unordered_set<std::string>& custom_domains) {
1180   if (!n->kind().is_onnx()) {
1181     std::string domain;
1182     if (n->kind().is_aten() || n->kind().is_caffe2()) {
1183       domain = n->kind().domainString();
1184     } else { //  Custom namespace and domain
1185       domain = n->kind().ns().toUnqualString();
1186     }
1187     domains_.insert(domain);
1188 
1189     if (custom_domains.find(domain) == custom_domains.end()) {
1190       custom_domains.insert(domain);
1191 
1192       auto* custom_imp = func_proto->add_opset_import();
1193       custom_imp->set_domain(domain);
1194       //  Check if domain version is registered. If not, set to version 1
1195       auto it = custom_opsets_.find(domain);
1196       if (it == custom_opsets_.end())
1197         custom_imp->set_version(1);
1198       else {
1199         custom_imp->set_version(it->second);
1200       }
1201     }
1202   }
1203 
1204   for (auto* b : n->blocks()) {
1205     for (auto* sub_n : b->nodes()) {
1206       EncodeLocalFunctionOpsetImport(func_proto, sub_n, custom_domains);
1207     }
1208   }
1209 }
1210 
// Encodes a LocalFunctionDef node `n` into an ONNX FunctionProto: the
// function body graph is stored as a "graph" attribute on the node, its
// name/domain/attributes as string attributes. The body's nodes are encoded
// with EncodeNode and their opset imports collected recursively.
void GraphEncoder::EncodeLocalFunction(
    onnx::GraphProto* graph_proto,
    onnx::FunctionProto* func_proto,
    const Node* n,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path) {
  // The function body graph carried by the LocalFunctionDef node.
  const auto fsub_g = n->g(Symbol::attr("graph"));
  func_proto->set_name(n->s(::c10::attr::name));

  for (auto input : fsub_g->inputs()) {
    func_proto->add_input(input->debugName());
  }
  for (auto output : fsub_g->outputs()) {
    func_proto->add_output(output->debugName());
  }

  // encode attributes names
  if (n->hasAttribute(Symbol::attr("attributes"))) {
    for (const auto& attr_name : n->ss(Symbol::attr("attributes"))) {
      AddAttribute(func_proto, attr_name);
    }
  }

  auto* imp = func_proto->add_opset_import();
  // This is the version of ONNX operator set we are targeting
  imp->set_version(onnx_opset_version_);

  // add for custom domain as well.
  const auto& domain = n->s(Symbol::attr("domain"));
  func_proto->set_domain(domain);
  domains_.insert(domain);
  std::unordered_set<std::string> custom_domains;

  for (auto* fsub_n : fsub_g->nodes()) {
    if (fsub_n->mustBeNone()) {
      // None nodes are used to implement optional inputs. One
      // way to "not provide" an optional input is to create an
      // Undefined node, and pass its output as that input.
      continue;
    }
    auto* n_proto = func_proto->add_node();
    EncodeNode(
        graph_proto,
        n_proto,
        fsub_n,
        add_node_names,
        use_external_data_format,
        onnx_file_path);
    EncodeLocalFunctionOpsetImport(func_proto, fsub_n, custom_domains);
  }
}
1263 
EncodeTypeProto(onnx::TypeProto * type_proto,const TypePtr & node_type,const std::string & name)1264 void GraphEncoder::EncodeTypeProto(
1265     onnx::TypeProto* type_proto,
1266     const TypePtr& node_type,
1267     const std::string& name) {
1268   if (TensorTypePtr tensor_type = node_type->cast<TensorType>()) {
1269     onnx::TypeProto_Tensor* onnx_tensor_type =
1270         type_proto->mutable_tensor_type();
1271     TensorTypeToONNXType(tensor_type, "", name, {}, onnx_tensor_type);
1272   } else if (ListTypePtr list_type = node_type->cast<ListType>()) {
1273     onnx::TypeProto_Sequence* seq_type = type_proto->mutable_sequence_type();
1274     auto elem_type = list_type->getElementType();
1275     EncodeTypeProto(seq_type->mutable_elem_type(), elem_type, name);
1276   }
1277 }
1278 
// Encodes an at::Tensor into a TensorProto. Depending on encoder state the
// payload is (a) deferred to raw_data_export_map_ (raw_data set to the
// "__EXTERNAL" sentinel), (b) written to an external data file next to
// `onnx_file_path` when external format is on and the tensor is large
// enough, or (c) embedded inline as little-endian raw_data.
void GraphEncoder::EncodeTensor(
    onnx::TensorProto* tensor_proto,
    const at::Tensor& tensor,
    const std::optional<std::string>& external_ref,
    const bool use_external_data_format,
    const std::string& onnx_file_path) {
  for (auto d : tensor.sizes()) {
    tensor_proto->add_dims(d);
  }
  tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.scalar_type()));
  at::Tensor t;
  // CPU's HalfTensor doesn't have contiguous(), so first calling contiguous()
  // TODO We don't call .cpu() on quantized tensors as it fails when calling
  // aten::empty() on quantized tensors beyond certain size. Issue #29435.
  if (tensor.is_quantized()) {
    t = tensor.contiguous();
  } else {
    t = tensor.contiguous().cpu();
  }

  // Either defer_weight_export should be true and external_ref must be present,
  // or use_external_data_format should be true, not both at the same time. They
  // can both be false at the same time (for ONNX export for regular model
  // size).
  TORCH_INTERNAL_ASSERT(
      !((defer_weight_export_ && external_ref) && use_external_data_format));
  // Add a buffer to the raw_data_export_map for the caller to dump into an
  // external data store. If external_ref is not specified, we instead dump
  // the contiguous data into the protobuf itself
  if (defer_weight_export_ && external_ref) {
    // For now, we use the name of the tensor as the external lookup name to
    // avoid ONNX protobuf changes.
    TORCH_INTERNAL_ASSERT(external_ref.value() == tensor_proto->name());
    TORCH_INTERNAL_ASSERT(
        raw_data_export_map_.count(external_ref.value()) == 0);
    raw_data_export_map_[external_ref.value()] = t;
    tensor_proto->set_raw_data("__EXTERNAL");
  } else {
    TORCH_INTERNAL_ASSERT(t.is_contiguous());
    // Element count (not bytes) is what's compared against the threshold.
    size_t tensorSize = static_cast<size_t>(c10::multiply_integers(
        std::begin(tensor.sizes()), std::end(tensor.sizes())));
    if (use_external_data_format &&
        tensorSize > ParamSizeThresholdForExternalStorage) {
      TORCH_INTERNAL_ASSERT(!onnx_file_path.empty());
      TORCH_INTERNAL_ASSERT(tensor_proto->has_name());
      auto tensorName = GetExternalFileName(tensor_proto->name());
      CreateExternalFile(t, tensorName, onnx_file_path);
      // Per the ONNX external-data convention, record the file location as a
      // "location" key/value entry and mark the data location EXTERNAL.
      onnx::StringStringEntryProto* location =
          tensor_proto->mutable_external_data()->Add();
      location->set_key("location");
      location->set_value(tensorName);
      tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL);
    } else {
      // According to ParseData function's comments in onnx, tensor data is
      // always little endian.
      tensor_proto->set_raw_data(get_little_endian_data(t));
    }
  }
}
1338 
EncodeIntermediateValueInfo(onnx::GraphProto * graph_proto,const Value * v)1339 void GraphEncoder::EncodeIntermediateValueInfo(
1340     onnx::GraphProto* graph_proto,
1341     const Value* v) {
1342   // Motivation is to encode ValueInfo for onnx local function nodes.
1343   auto n = v->node();
1344   if (n->kind().is_onnx() || n->kind().is_aten()) {
1345     // Encode value info only for non-onnx or non-ATen nodes.
1346     return;
1347   }
1348   if (n->owningGraph() != graph_.get()) {
1349     // Encode value info only for node in main graph.
1350     return;
1351   }
1352   for (const auto* o : graph_->outputs()) {
1353     // Do not encode value info for graph outputs.
1354     if (o == v) {
1355       return;
1356     }
1357   }
1358   auto v_info_p = graph_proto->add_value_info();
1359   EncodeValueInfo(graph_proto, v_info_p, v);
1360 }
1361 
1362 } // namespace
1363 
pretty_print_onnx(const std::shared_ptr<Graph> & graph,const std::map<std::string,at::Tensor> & initializers,int64_t onnx_opset_version,bool defer_weight_export,::torch::onnx::OperatorExportTypes operator_export_type,bool google_printer,bool keep_initializers_as_inputs,const std::map<std::string,int> & custom_opsets,bool add_node_names)1364 std::string pretty_print_onnx(
1365     const std::shared_ptr<Graph>& graph,
1366     const std::map<std::string, at::Tensor>& initializers,
1367     int64_t onnx_opset_version,
1368     bool defer_weight_export,
1369     ::torch::onnx::OperatorExportTypes operator_export_type,
1370     bool google_printer,
1371     bool keep_initializers_as_inputs,
1372     const std::map<std::string, int>& custom_opsets,
1373     bool add_node_names) {
1374   auto graph_encoder = GraphEncoder(
1375       graph,
1376       onnx_opset_version,
1377       operator_export_type,
1378       initializers,
1379       std::unordered_map<
1380           std::string,
1381           std::unordered_map<int64_t, std::string>>{},
1382       defer_weight_export,
1383       true,
1384       keep_initializers_as_inputs,
1385       custom_opsets,
1386       add_node_names,
1387       false,
1388       std::string());
1389   if (google_printer) {
1390     return graph_encoder.get_model_proto()->DebugString();
1391   }
1392   return prettyPrint(*graph_encoder.get_model_proto());
1393 }
1394 
// Runs the full ONNX encoding of `graph` and returns the resulting model
// proto along with the auxiliary maps produced during encoding:
//   - RawDataExportMap: tensor name -> tensor, for weights whose export was
//     deferred (raw_data set to a sentinel in the proto);
//   - SymbolDimMap: symbolic shape symbols -> dimension parameter names;
//   - bool: whether external data format was actually used;
//   - NodeNameMap: JIT node -> assigned ONNX node name.
std::tuple<
    std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
    RawDataExportMap,
    SymbolDimMap,
    bool,
    NodeNameMap>
export_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
    int64_t onnx_opset_version,
    const std::unordered_map<
        std::string,
        std::unordered_map<std::int64_t, std::string>>& dynamic_axes,
    bool defer_weight_export,
    ::torch::onnx::OperatorExportTypes operator_export_type,
    bool strip_doc_string,
    bool keep_initializers_as_inputs,
    const std::map<std::string, int>& custom_opsets,
    bool add_node_names,
    bool use_external_data_format,
    const std::string& onnx_file_path,
    const NodeAttrNameMap& node_attr_to_name) {
  // The GraphEncoder constructor performs the encoding; the getters below
  // just read out its results.
  auto graph_encoder = GraphEncoder(
      graph,
      onnx_opset_version,
      operator_export_type,
      initializers,
      dynamic_axes,
      defer_weight_export,
      strip_doc_string,
      keep_initializers_as_inputs,
      custom_opsets,
      add_node_names,
      use_external_data_format,
      onnx_file_path,
      node_attr_to_name);
  GRAPH_DEBUG("onnx proto:", prettyPrint(*graph_encoder.get_model_proto()));
  return std::make_tuple(
      graph_encoder.get_model_proto(),
      graph_encoder.get_raw_data_export_map(),
      graph_encoder.get_symbol_dim_param_map(),
      graph_encoder.get_use_external_data_format(),
      graph_encoder.get_onnx_node_names());
}
1439 
// Serializes the model proto to its binary protobuf wire format.
std::string serialize_model_proto_to_string(
    const std::shared_ptr<::ONNX_NAMESPACE::ModelProto>& model_proto) {
  return model_proto->SerializeAsString();
}
1444 
check_onnx_proto(const std::string & proto_string)1445 void check_onnx_proto(const std::string& proto_string) {
1446   onnx::ModelProto model;
1447   if (!ParseProtoFromBytes(&model, proto_string.c_str(), proto_string.size())) {
1448     throw std::runtime_error("Invalid ONNX proto string.");
1449     return;
1450   }
1451   // 1. baseline check
1452   // These two checks prevent broken graph being generated
1453   // And errors out exporting if that happens.
1454   onnx::checker::check_model(model);
1455   onnx::shape_inference::InferShapes(model);
1456   // 2. full check
1457   // apply strict mode shape type inference check which examines
1458   // whether it's a valid ONNX graph or not. As for some users, they
1459   // don't need a fully valid ONNX graph to run their model, we simply
1460   // add this information as warning message if it fails.
1461   try {
1462     auto* schema_registry = onnx::OpSchemaRegistry::Instance();
1463     onnx::ShapeInferenceOptions options{
1464         /*check_type_val=*/true,
1465         /*strict_mode_val=*/true};
1466     onnx::shape_inference::InferShapes(model, schema_registry, options);
1467   } catch (const onnx::InferenceError& ex) {
1468     TORCH_WARN(
1469         "The exported ONNX model failed ONNX shape inference. "
1470         "The model will not be executable by the ONNX Runtime. "
1471         "If this is unintended and you believe there is a bug, "
1472         "please report an issue at https://github.com/pytorch/pytorch/issues. "
1473         "Error reported by strict ONNX shape inference: ",
1474         ex.what());
1475   }
1476 }
1477 
1478 } // namespace torch::jit
1479