// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/export.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <caffe2/serialize/inline_container.h>
4 #include <torch/csrc/jit/api/module.h>
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/csrc/jit/serialization/export_bytecode.h>
7 #include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
8 #include <torch/csrc/jit/serialization/pickler.h>
9 #include <torch/csrc/jit/serialization/python_print.h>
10 #include <torch/csrc/jit/serialization/storage_context.h>
11 #include <torch/csrc/jit/serialization/type_name_uniquer.h>
12 #include <torch/csrc/onnx/onnx.h>
13 #include <ostream>
14 
// Forward-declare ModelProto so this header does not have to include the
// (heavy) generated ONNX protobuf headers; only smart pointers to
// ModelProto appear in the declarations below.
namespace ONNX_NAMESPACE {
class ModelProto;
}
19 namespace torch::jit {
20 
21 // This map is used to keep track of parameters that should be exported
22 // externally. When `defer_weight_export` is true, the returned map contains
23 // kv pairs that map {external reference name} -> {at::Tensor to be exported}.
24 // It is the responsibility of the caller to export these appropriately.
25 //
26 // For example, when exporting to a zip archive, the caller may write out files
27 // for each entry in the export map, with the filename being the key and the
28 // file contents being the raw tensor data.
29 using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;
30 
31 using SymbolDimMap = std::map<c10::ShapeSymbol, std::string>;
32 using DimSymbolMap = std::map<std::string, c10::ShapeSymbol>;
33 
34 using NodeNameMap = std::unordered_map<const Node*, std::string>;
35 
36 // Used for modularized export settling function and node attributes.
37 using NodeAttrNameMap = std::
38     unordered_map<const Node*, std::unordered_map<std::string, std::string>>;
39 
40 TORCH_API std::tuple<
41     std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
42     RawDataExportMap,
43     SymbolDimMap,
44     bool,
45     NodeNameMap>
46 export_onnx(
47     const std::shared_ptr<Graph>& graph,
48     const std::map<std::string, at::Tensor>& initializers,
49     int64_t onnx_opset_version,
50     const std::unordered_map<
51         std::string,
52         std::unordered_map<int64_t, std::string>>& dynamic_axes,
53     bool defer_weight_export = false,
54     ::torch::onnx::OperatorExportTypes operator_export_type =
55         ::torch::onnx::OperatorExportTypes::ONNX,
56     bool strip_doc_string = true,
57     bool keep_initializers_as_inputs = true,
58     const std::map<std::string, int>& custom_opsets = {},
59     bool add_node_names = true,
60     bool use_external_data_format = false,
61     const std::string& onnx_file_path = std::string(),
62     const NodeAttrNameMap& node_attr_to_name = {});
63 
64 TORCH_API std::string serialize_model_proto_to_string(
65     const std::shared_ptr<::ONNX_NAMESPACE::ModelProto>& model_proto);
66 
67 TORCH_API void check_onnx_proto(const std::string& proto_string);
68 
69 // Serializer for both oldsyle and unified format TorchScript serialization
70 class TORCH_API ScriptModuleSerializer {
71  public:
ScriptModuleSerializer(caffe2::serialize::PyTorchStreamWriter & export_writer)72   explicit ScriptModuleSerializer(
73       caffe2::serialize::PyTorchStreamWriter& export_writer)
74       : writer_(export_writer) {}
75 
76   void writeFiles(const std::string& code_dir);
77   void serialize(
78       const Module& module,
79       const ExtraFilesMap& extra_files,
80       bool bytecode_format,
81       bool save_mobile_debug_info);
82   void serialize_unified_format(Module& module, uint64_t script_module_id);
83   SerializationStorageContext& storage_context();
84 
85   ~ScriptModuleSerializer() = default;
86 
87  private:
88   void convertNamedType(const c10::NamedTypePtr& class_type);
89   void convertTypes(const at::NamedTypePtr& root_type);
90   void writeExtraFiles(const Module& module, const ExtraFilesMap& extra_files);
91   void writeByteCode(const Module& module, bool save_mobile_debug_info);
92   void writeArchive(
93       const IValue& value,
94       const std::string& archive_name,
95       const std::string& archive_dir,
96       const std::string& tensor_dir,
97       bool use_storage_context = false,
98       bool skip_tensor_data = false);
99   void updateSourceRangeTags(const SourceRangeRecords& ranges);
100 
101   caffe2::serialize::PyTorchStreamWriter& writer_;
102   std::vector<at::IValue> constant_table_;
103 
104   std::unordered_set<c10::NamedTypePtr> converted_types_;
105   PrintDepsTable class_deps_;
106   TypeNameUniquer type_name_uniquer_;
107   // qualifier, e.g. '__torch__.Bar' -> PythonPrint for the file that will be
108   // created
109   OrderedDict<std::string, PythonPrint> file_streams_;
110   // Used to keep references of storages around during serialization to solve
111   // for ABA memory reuse problem hit when storages are created/destroyed
112   // during serialization process. Also used to coordinate sharing of storages
113   // between Script and eager modules in torch.package.
114   SerializationStorageContext storage_context_;
115 
116   // Uniquely identifies a SourceRange in a model.
117   // SourceRanges are associated with Nodes of Graphs.
118   // However for mobile deployment we dont intend to ship
119   // full JIT with capabilities of reading code and constructing
120   // graphs.
121   // Instead we serialize the Code generated from graph of the methods.
122   // Code is serialized in bytecode format that contains instructions
123   // corresponding to the nodes of the graph. Since original graph is gone, the
124   // question is how do we identify where the ops, in serialized bytecode, come
125   // from in original model code. We do this in two parts.
126   // 1. Associate a unique tag to SourceRange.
127   // 2. Serialize this unique_tag.
128   //  2.1 Meaning save <byte_offset, source_range_tag, source range> instead of
129   //      <byte_offset, source range>
130   // 3. During serializing model for mobile, i.e. bytecode generation,
131   //    save unique tag of SourceRange corresponding to the Node.
132   // 4. During deserialization, read all the debug_pkl, to construct a map
133   //    of <unique_tag, SourceRange> and use tag saved with OPs in bytecode
134   //    to lookup the source range.
135   // Strictly speaking we will serialize InlinedCallStack directly, which
136   // contains SourceRange. This way we have access to entire callstack and not
137   // just source information about where the node is, since bytecode inlines the
138   // graph before saving it.
139   SourceRangeTagMap source_range_tags_;
140   int64_t current_source_range_tag_{0};
141 };
142 
143 // For testing purposes
144 TORCH_API std::string pretty_print_onnx(
145     const std::shared_ptr<Graph>& graph,
146     const std::map<std::string, at::Tensor>& initializers,
147     int64_t onnx_opset_version,
148     bool defer_weight_export,
149     ::torch::onnx::OperatorExportTypes operator_export_type =
150         ::torch::onnx::OperatorExportTypes::ONNX,
151     bool google_printer = false,
152     bool keep_initializers_as_inputs = true,
153     const std::map<std::string, int>& custom_opsets = {},
154     bool add_node_names = true);
155 
156 TORCH_API void ExportModule(
157     const Module& module,
158     std::ostream& out,
159     const ExtraFilesMap& metadata = ExtraFilesMap(),
160     bool bytecode_format = false,
161     bool save_mobile_debug_info = false,
162     bool use_flatbuffer = false);
163 
164 TORCH_API void ExportModule(
165     const Module& module,
166     const std::string& filename,
167     const ExtraFilesMap& metadata = ExtraFilesMap(),
168     bool bytecode_format = false,
169     bool save_mobile_debug_info = false,
170     bool use_flatbuffer = false);
171 
172 TORCH_API void ExportModule(
173     const Module& module,
174     const std::function<size_t(const void*, size_t)>& writer_func,
175     const ExtraFilesMap& metadata = ExtraFilesMap(),
176     bool bytecode_format = false,
177     bool save_mobile_debug_info = false,
178     bool use_flatbuffer = false);
179 
180 // Write the bytes of a pickle archive and the tensors referenced inside that
181 // archive
182 TORCH_API void writeArchiveAndTensors(
183     const std::string& archive_name,
184     const char* pickle_bytes,
185     size_t size,
186     const std::vector<at::Tensor>& tensors,
187     caffe2::serialize::PyTorchStreamWriter& out);
188 
189 // Surrounding system can install an additional hook to produce extra files
190 // with metadata based on environment every time a module is serialized.
191 using ExportModuleExtraFilesHook = std::function<ExtraFilesMap(const Module&)>;
192 TORCH_API void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook);
193 
194 /**
195  * Generates new bytecode for a Script module and returns what the op list
196  * would be for a LiteScriptModule based off the current code base. If you
197  * have a LiteScriptModule and want to get the currently present
198  * list of ops call _export_operator_list instead.
199  */
200 TORCH_API std::vector<std::string> export_opnames(const Module& m);
201 
202 struct TORCH_API BytecodeEmitMode {
203   static bool is_default_value_for_unspecified_arg_enabled();
204   static void set_default_value_for_unspecified_arg_enabled(bool enabled);
205 
206   static bool is_default_args_before_out_args_enabled();
207   static void set_default_args_before_out_args_enabled(bool enabled);
208 
209   static bool is_emit_promoted_ops_enabled();
210   static void set_default_emit_promoted_ops_enabled(bool enabled);
211 };
212 
213 // RAII guard to switch the way JIT emits the bytecode for inputs.
214 // default_value_for_unspecified_arg:
215 // true: instruction of default argument values (like LOADC) is emitted.
216 // false: instruction of default argument values are not emitted. Instead
217 // they are fetched from operator schema.
218 // default_args_before_out_args (to forward compatibile support
219 // operators allowing out arguments and default arguments):
220 // true: the number of specified arguments will deserialized to (#all_args -
221 // #default_args). false: the number of specified arguments will deserialized to
222 // (#all_args).
223 struct TORCH_API BytecodeEmitModeGuard {
BytecodeEmitModeGuardBytecodeEmitModeGuard224   BytecodeEmitModeGuard(
225       bool enable_default_value_for_unspecified_arg,
226       bool enable_default_args_before_out_args,
227       bool enable_emit_promoted_ops)
228       : prev_default_value_for_unspecified_arg_mode(
229             BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()),
230         prev_default_args_before_out_args(
231             BytecodeEmitMode::is_default_args_before_out_args_enabled()),
232         prev_default_emit_promoted_ops(
233             BytecodeEmitMode::is_emit_promoted_ops_enabled()) {
234     BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
235         enable_default_value_for_unspecified_arg);
236     BytecodeEmitMode::set_default_args_before_out_args_enabled(
237         enable_default_args_before_out_args);
238     BytecodeEmitMode::set_default_emit_promoted_ops_enabled(
239         enable_emit_promoted_ops);
240   }
~BytecodeEmitModeGuardBytecodeEmitModeGuard241   ~BytecodeEmitModeGuard() {
242     BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
243         prev_default_value_for_unspecified_arg_mode);
244     BytecodeEmitMode::set_default_args_before_out_args_enabled(
245         prev_default_args_before_out_args);
246     BytecodeEmitMode::set_default_emit_promoted_ops_enabled(
247         prev_default_emit_promoted_ops);
248   }
249   bool prev_default_value_for_unspecified_arg_mode;
250   bool prev_default_args_before_out_args;
251   bool prev_default_emit_promoted_ops;
252 };
253 
254 TORCH_API IValue to_tuple(std::vector<IValue> ivalues);
255 TORCH_API IValue
256 Table(const std::vector<std::pair<std::string, IValue>>& entries);
257 
258 // TODO remove these switches once interface call is rolled out.
259 TORCH_API void enableMobileInterfaceCallExport();
260 bool getMobileInterfaceCallExport();
261 
262 TORCH_API CompilationOptions getOptionsFromGlobal();
263 
264 TORCH_API void save_jit_module(
265     const Module& module,
266     const std::string& filename,
267     const ExtraFilesMap& extra_files = ExtraFilesMap());
268 
269 TORCH_API DetachedBuffer::UniqueDetachedBuffer save_jit_module_to_bytes(
270     const Module& module,
271     const ExtraFilesMap& extra_files = ExtraFilesMap());
272 
273 TORCH_API void save_jit_module_to_write_func(
274     const Module& module,
275     const ExtraFilesMap& extra_files,
276     bool save_mobile_debug_info,
277     const std::function<size_t(const void*, size_t)>& writer_func);
278 
279 } // namespace torch::jit
280