xref: /aosp_15_r20/external/pytorch/torch/csrc/onnx/init.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <onnx/onnx_pb.h>
2 #include <torch/csrc/onnx/back_compat.h>
3 #include <torch/csrc/onnx/init.h>
4 #include <torch/csrc/onnx/onnx.h>
5 #include <torch/version.h>
6 
7 #include <torch/csrc/Exceptions.h>
8 #include <torch/csrc/jit/passes/onnx.h>
9 #include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
10 #include <torch/csrc/jit/passes/onnx/constant_fold.h>
11 #include <torch/csrc/jit/passes/onnx/deduplicate_initializers.h>
12 #include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h>
13 #include <torch/csrc/jit/passes/onnx/eval_peephole.h>
14 #include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
15 #include <torch/csrc/jit/passes/onnx/function_extraction.h>
16 #include <torch/csrc/jit/passes/onnx/function_substitution.h>
17 #include <torch/csrc/jit/passes/onnx/list_model_parameters.h>
18 #include <torch/csrc/jit/passes/onnx/naming.h>
19 #include <torch/csrc/jit/passes/onnx/onnx_log.h>
20 #include <torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.h>
21 #include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.h>
22 #include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h>
23 #include <torch/csrc/jit/passes/onnx/peephole.h>
24 #include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
25 #include <torch/csrc/jit/passes/onnx/preprocess_for_onnx.h>
26 #include <torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h>
27 #include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
28 #include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
29 #include <torch/csrc/jit/passes/onnx/unpack_quantized_weights.h>
30 #include <torch/csrc/jit/serialization/export.h>
31 
32 namespace torch::onnx {
33 
34 using namespace torch::jit;
35 
initONNXBindings(PyObject * module)36 void initONNXBindings(PyObject* module) {
37   auto m = py::handle(module).cast<py::module>();
38 
39   // ONNX specific passes
40   m.def("_jit_pass_onnx_remove_print", RemovePrintOps)
41       .def("_jit_pass_onnx_preprocess_caffe2", PreprocessCaffe2Ops)
42       .def("_jit_pass_onnx", ToONNX)
43       .def(
44           "_jit_pass_onnx_assign_output_shape",
45           ::torch::wrap_pybind_function(
46               [](std::shared_ptr<Graph>& graph,
47                  const std::vector<at::Tensor>& tensors,
48                  const python::IODescriptor& desc,
49                  bool onnx_shape_inference,
50                  bool is_script,
51                  int opset_version) {
52                 ONNXAssignOutputShape(
53                     graph,
54                     tensors,
55                     desc,
56                     onnx_shape_inference,
57                     is_script,
58                     opset_version);
59               }))
60       .def(
61           "_jit_pass_onnx_function_substitution",
62           wrap_pybind_function(ONNXFunctionCallSubstitution))
63       .def(
64           "_jit_pass_onnx_autograd_function_process",
65           wrap_pybind_function(ONNXAutogradFunctionProcess))
66       .def(
67           "_jit_pass_onnx_peephole",
68           ::torch::wrap_pybind_function([](std::shared_ptr<Graph>& graph,
69                                            int opset_version,
70                                            bool fixed_batch_size) {
71             return PeepholeOptimizeONNX(graph, opset_version, fixed_batch_size);
72           }))
73       .def(
74           "_jit_pass_onnx_preprocess",
75           ::torch::wrap_pybind_function(PreprocessForONNX))
76       .def(
77           "_jit_pass_onnx_eval_peephole",
78           ::torch::wrap_pybind_function(
79               [](std::shared_ptr<Graph>& graph,
80                  std::map<std::string, IValue>& paramsDict) {
81                 EvalPeepholeONNX(graph, paramsDict);
82                 return paramsDict;
83               }),
84           pybind11::return_value_policy::move)
85       .def(
86           "_jit_pass_onnx_cast_all_constant_to_floating",
87           ::torch::wrap_pybind_function(CastAllConstantToFloating))
88       .def(
89           "_jit_pass_onnx_constant_fold",
90           ::torch::wrap_pybind_function(
91               [](std::shared_ptr<Graph>& graph,
92                  std::map<std::string, IValue>& paramsDict,
93                  int opset_version) {
94                 ConstantFoldONNX(
95                     graph,
96                     paramsDict,
97                     opset_version); // overload resolution
98                 return paramsDict;
99               }),
100           pybind11::return_value_policy::move)
101       .def(
102           "_jit_pass_onnx_eliminate_unused_items",
103           ::torch::wrap_pybind_function(
104               [](std::shared_ptr<Graph>& graph,
105                  std::map<std::string, IValue>& paramsDict) {
106                 EliminateUnusedItemsONNX(
107                     graph->block(),
108                     paramsDict); // overload resolution
109                 return paramsDict;
110               }),
111           pybind11::return_value_policy::move)
112       .def(
113           "_jit_pass_onnx_scalar_type_analysis",
114           ::torch::wrap_pybind_function([](std::shared_ptr<Graph>& graph,
115                                            bool lowprecision_cast,
116                                            int opset_version) {
117             return ScalarTypeAnalysisForONNX(
118                 graph, lowprecision_cast, opset_version);
119           }),
120           py::arg("graph"),
121           py::arg("lowprecision_cast") = true,
122           py::arg("opset_version"))
123       .def(
124           "_jit_pass_onnx_remove_inplace_ops_for_onnx",
125           ::torch::wrap_pybind_function(RemoveInplaceOpsForONNX))
126       .def(
127           "_jit_pass_onnx_node_shape_type_inference",
128           ::torch::wrap_pybind_function(
129               [](Node* n,
130                  std::map<std::string, IValue>& params_dict,
131                  int opset_version) {
132                 ONNXShapeTypeInference(n, params_dict, opset_version);
133               }))
134       .def(
135           "_jit_pass_onnx_graph_shape_type_inference",
136           ::torch::wrap_pybind_function(
137               [](std::shared_ptr<Graph>& graph,
138                  std::map<std::string, IValue>& params_dict,
139                  int opset_version) {
140                 ONNXShapeTypeInference(graph, params_dict, opset_version);
141               }),
142           py::arg("graph"),
143           py::arg("params_dict"),
144           py::arg("opset_version"))
145       .def(
146           "_jit_pass_onnx_set_dynamic_input_shape",
147           ::torch::wrap_pybind_function(ONNXSetDynamicInputShape))
148       .def("_jit_pass_onnx_lint", torch::wrap_pybind_function(ONNXLintGraph))
149       .def(
150           "_jit_pass_onnx_function_extraction",
151           ::torch::wrap_pybind_function(
152               torch::jit::onnx::ONNXFunctionExtraction))
153       .def("_jit_pass_onnx_block", torch::wrap_pybind_function(BlockToONNX))
154       .def(
155           "_jit_pass_onnx_unpack_quantized_weights",
156           ::torch::wrap_pybind_function(
157               [](std::shared_ptr<Graph>& graph,
158                  std::map<std::string, IValue>& paramsDict) {
159                 UnpackQuantizedWeights(graph, paramsDict);
160                 return paramsDict;
161               }),
162           pybind11::return_value_policy::move)
163       .def(
164           "_jit_pass_onnx_quantization_insert_permutes",
165           ::torch::wrap_pybind_function(
166               [](std::shared_ptr<Graph>& graph,
167                  std::map<std::string, IValue>& paramsDict) {
168                 insertPermutes(graph, paramsDict);
169                 return paramsDict;
170               }),
171           pybind11::return_value_policy::move)
172       .def(
173           "_jit_onnx_list_model_parameters",
174           ::torch::wrap_pybind_function(
175               [](Module& module) { return list_module_parameters(module); }))
176       .def(
177           "_jit_pass_prepare_division_for_onnx",
178           ::torch::wrap_pybind_function(PrepareDivisionForONNX))
179       .def(
180           "_jit_onnx_convert_pattern_from_subblock",
181           ::torch::wrap_pybind_function(ConvertPatternFromSubblock))
182       .def(
183           "_jit_pass_fixup_onnx_controlflow_node",
184           ::torch::wrap_pybind_function(FixupONNXControlflowNode))
185       .def(
186           "_jit_pass_onnx_deduplicate_initializers",
187           ::torch::wrap_pybind_function(
188               [](std::shared_ptr<Graph>& graph,
189                  std::map<std::string, IValue> params_dict,
190                  bool is_train) {
191                 DeduplicateInitializers(graph, params_dict, is_train);
192                 return params_dict;
193               }),
194           pybind11::return_value_policy::move)
195       .def(
196           "_jit_pass_onnx_clear_scope_records",
197           &torch::jit::onnx::ONNXClearScopeRecords)
198       .def(
199           "_jit_pass_onnx_track_scope_attributes",
200           &torch::jit::onnx::ONNXTrackScopeAttributes)
201       .def(
202           "_jit_is_onnx_log_enabled",
203           ::torch::jit::onnx::is_log_enabled,
204           "Returns whether ONNX logging is enabled or disabled.")
205       .def(
206           "_jit_set_onnx_log_enabled",
207           ::torch::jit::onnx::set_log_enabled,
208           "Enables or disables ONNX logging.")
209       .def(
210           "_jit_set_onnx_log_output_stream",
211           [](const std::string& stream_name = "stdout") -> void {
212             std::shared_ptr<std::ostream> out;
213             if (stream_name == "stdout") {
214               out = std::shared_ptr<std::ostream>(
215                   &std::cout, [](std::ostream*) {});
216             } else if (stream_name == "stderr") {
217               out = std::shared_ptr<std::ostream>(
218                   &std::cerr, [](std::ostream*) {});
219             } else {
220               std::cerr << "ERROR: only `stdout` and `stderr`"
221                         << "are supported as `stream_name`" << std::endl;
222             }
223             ::torch::jit::onnx::set_log_output_stream(out);
224           },
225           "Set specific file stream for ONNX logging.")
226       .def(
227           "_jit_onnx_log",
228           [](const py::args& args) -> void {
229             if (::torch::jit::onnx::is_log_enabled()) {
230               auto& out = ::torch::jit::onnx::_get_log_output_stream();
231               for (auto arg : args) {
232                 out << ::c10::str(arg);
233               }
234               out << std::endl;
235             }
236           },
237           "Write `args` to the previously specified ONNX log stream.")
238       .def(
239           "_jit_pass_onnx_assign_scoped_names_for_node_and_value",
240           ::torch::wrap_pybind_function(
241               ::torch::jit::onnx::AssignScopedNamesForNodeAndValue),
242           "Assign informative scoped names for nodes and values.")
243       .def(
244           "_jit_onnx_create_full_scope_name",
245           ::torch::wrap_pybind_function(
246               ::torch::jit::onnx::ONNXScopeName::createFullScopeName),
247           "Create a full scope name from class name and variable name.");
248 
249   m.def(
250       "_check_onnx_proto",
251       ::torch::wrap_pybind_function([](const std::string& proto_string) {
252         check_onnx_proto(proto_string);
253       }),
254       py::arg("proto_string"));
255 
256   auto onnx = m.def_submodule("_onnx");
257   py::enum_<::ONNX_NAMESPACE::TensorProto_DataType>(onnx, "TensorProtoDataType")
258       .value("UNDEFINED", ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED)
259       .value("FLOAT", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT)
260       .value("UINT8", ::ONNX_NAMESPACE::TensorProto_DataType_UINT8)
261       .value("INT8", ::ONNX_NAMESPACE::TensorProto_DataType_INT8)
262       .value("UINT16", ::ONNX_NAMESPACE::TensorProto_DataType_UINT16)
263       .value("INT16", ::ONNX_NAMESPACE::TensorProto_DataType_INT16)
264       .value("INT32", ::ONNX_NAMESPACE::TensorProto_DataType_INT32)
265       .value("INT64", ::ONNX_NAMESPACE::TensorProto_DataType_INT64)
266       .value("STRING", ::ONNX_NAMESPACE::TensorProto_DataType_STRING)
267       .value("BOOL", ::ONNX_NAMESPACE::TensorProto_DataType_BOOL)
268       .value("FLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
269       .value("DOUBLE", ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE)
270       .value("UINT32", ::ONNX_NAMESPACE::TensorProto_DataType_UINT32)
271       .value("UINT64", ::ONNX_NAMESPACE::TensorProto_DataType_UINT64)
272       .value("COMPLEX64", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64)
273       .value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128)
274       .value("BFLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
275       .value("FLOAT8E4M3FN", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN)
276       .value(
277           "FLOAT8E4M3FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FNUZ)
278       .value("FLOAT8E5M2", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2)
279       .value(
280           "FLOAT8E5M2FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2FNUZ);
281 
282   py::enum_<OperatorExportTypes>(onnx, "OperatorExportTypes")
283       .value("ONNX", OperatorExportTypes::ONNX)
284       .value("ONNX_ATEN", OperatorExportTypes::ONNX_ATEN)
285       .value("ONNX_ATEN_FALLBACK", OperatorExportTypes::ONNX_ATEN_FALLBACK)
286       .value("ONNX_FALLTHROUGH", OperatorExportTypes::ONNX_FALLTHROUGH);
287 
288   py::enum_<TrainingMode>(onnx, "TrainingMode")
289       .value("EVAL", TrainingMode::EVAL)
290       .value("PRESERVE", TrainingMode::PRESERVE)
291       .value("TRAINING", TrainingMode::TRAINING);
292 
293   onnx.attr("PRODUCER_VERSION") = py::str(TORCH_VERSION);
294 }
295 } // namespace torch::onnx
296