#pragma once

#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/python_print.h>
#include <torch/csrc/jit/serialization/storage_context.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
#include <torch/csrc/onnx/onnx.h>
#include <ostream>

// Forward declaration only: this header does not need the full protobuf
// definition of ModelProto, just the ability to name it in smart pointers.
namespace ONNX_NAMESPACE {
class ModelProto;
}

namespace torch::jit {

// This map is used to keep track of parameters that should be exported
// externally. When `defer_weight_export` is true, the returned map contains
// kv pairs that map {external reference name} -> {at::Tensor to be exported}.
// It is the responsibility of the caller to export these appropriately.
//
// For example, when exporting to a zip archive, the caller may write out files
// for each entry in the export map, with the filename being the key and the
// file contents being the raw tensor data.
using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;

// Mappings between graph shape symbols and the textual names used for
// dynamic dimensions in the exported model — one map per lookup direction.
using SymbolDimMap = std::map<c10::ShapeSymbol, std::string>;
using DimSymbolMap = std::map<std::string, c10::ShapeSymbol>;

// Maps a graph Node to the name assigned to it during export.
using NodeNameMap = std::unordered_map<const Node*, std::string>;

// Used for modularized export settling function and node attributes.
// Maps a Node to its {attribute name -> assigned name} table, used by the
// modularized ONNX export path (see the comment above this alias's group).
using NodeAttrNameMap = std::
    unordered_map<const Node*, std::unordered_map<std::string, std::string>>;

// Converts a JIT graph (plus its initializer tensors) into an ONNX
// ModelProto.
//
// Returns a tuple of:
//   - the ModelProto,
//   - a RawDataExportMap (see RawDataExportMap; populated when
//     defer_weight_export is true),
//   - the SymbolDimMap naming dynamic dimensions,
//   - a bool flag (semantics defined by the implementation — not visible
//     from this header; confirm against export.cpp before relying on it),
//   - the NodeNameMap of names assigned to graph nodes.
//
// `dynamic_axes` maps {value name} -> {dim index -> dim name}. The remaining
// parameters mirror the knobs of torch.onnx.export; `onnx_file_path` is
// presumably only consulted when use_external_data_format is true — verify
// against the implementation.
TORCH_API std::tuple<
    std::shared_ptr<::ONNX_NAMESPACE::ModelProto>,
    RawDataExportMap,
    SymbolDimMap,
    bool,
    NodeNameMap>
export_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
    int64_t onnx_opset_version,
    const std::unordered_map<
        std::string,
        std::unordered_map<int64_t, std::string>>& dynamic_axes,
    bool defer_weight_export = false,
    ::torch::onnx::OperatorExportTypes operator_export_type =
        ::torch::onnx::OperatorExportTypes::ONNX,
    bool strip_doc_string = true,
    bool keep_initializers_as_inputs = true,
    const std::map<std::string, int>& custom_opsets = {},
    bool add_node_names = true,
    bool use_external_data_format = false,
    const std::string& onnx_file_path = std::string(),
    const NodeAttrNameMap& node_attr_to_name = {});

// Serializes a ModelProto into its binary (protobuf wire format) string.
TORCH_API std::string serialize_model_proto_to_string(
    const std::shared_ptr<::ONNX_NAMESPACE::ModelProto>& model_proto);

// Validates a serialized ONNX proto string (error behavior defined by the
// implementation).
TORCH_API void check_onnx_proto(const std::string& proto_string);

// Serializer for both old-style and unified format TorchScript serialization.
class TORCH_API ScriptModuleSerializer {
 public:
  // The writer is held by reference and must outlive this serializer.
  explicit ScriptModuleSerializer(
      caffe2::serialize::PyTorchStreamWriter& export_writer)
      : writer_(export_writer) {}

  // Writes the accumulated code files under `code_dir` in the archive.
  void writeFiles(const std::string& code_dir);
  // Serializes `module` (optionally in bytecode format, optionally with
  // mobile debug info) through the stream writer.
  void serialize(
      const Module& module,
      const ExtraFilesMap& extra_files,
      bool bytecode_format,
      bool save_mobile_debug_info);
  // Unified-format entry point; `script_module_id` identifies the module
  // within the surrounding (e.g. torch.package) archive.
  void serialize_unified_format(Module& module, uint64_t script_module_id);
  // Accessor for the shared storage context (see storage_context_ below).
  SerializationStorageContext& storage_context();

  ~ScriptModuleSerializer() = default;

 private:
  void convertNamedType(const c10::NamedTypePtr& class_type);
  void convertTypes(const at::NamedTypePtr& root_type);
  void writeExtraFiles(const Module& module, const ExtraFilesMap& extra_files);
  void writeByteCode(const Module& module, bool save_mobile_debug_info);
  void writeArchive(
      const IValue& value,
      const std::string& archive_name,
      const std::string& archive_dir,
      const std::string& tensor_dir,
      bool use_storage_context = false,
      bool skip_tensor_data = false);
  void updateSourceRangeTags(const SourceRangeRecords& ranges);

  caffe2::serialize::PyTorchStreamWriter& writer_;
  std::vector<at::IValue> constant_table_;

  std::unordered_set<c10::NamedTypePtr> converted_types_;
  PrintDepsTable class_deps_;
  TypeNameUniquer type_name_uniquer_;
  // qualifier, e.g. '__torch__.Bar' -> PythonPrint for the file that will be
  // created
  OrderedDict<std::string, PythonPrint> file_streams_;
  // Used to keep references of storages around during serialization to solve
  // for ABA memory reuse problem hit when storages are created/destroyed
  // during serialization process. Also used to coordinate sharing of storages
  // between Script and eager modules in torch.package.
  SerializationStorageContext storage_context_;

  // Uniquely identifies a SourceRange in a model.
  // SourceRanges are associated with Nodes of Graphs.
  // However for mobile deployment we don't intend to ship
  // full JIT with capabilities of reading code and constructing
  // graphs.
  // Instead we serialize the Code generated from graph of the methods.
  // Code is serialized in bytecode format that contains instructions
  // corresponding to the nodes of the graph. Since original graph is gone, the
  // question is how do we identify where the ops, in serialized bytecode, come
  // from in original model code. We do this in two parts.
  // 1. Associate a unique tag to SourceRange.
  // 2. Serialize this unique_tag.
  //  2.1 Meaning save <byte_offset, source_range_tag, source range> instead of
  //  <byte_offset, source range>
  // 3. During serializing model for mobile, i.e. bytecode generation,
  //    save unique tag of SourceRange corresponding to the Node.
  // 4. During deserialization, read all the debug_pkl, to construct a map
  //    of <unique_tag, SourceRange> and use tag saved with OPs in bytecode
  //    to lookup the source range.
  // Strictly speaking we will serialize InlinedCallStack directly, which
  // contains SourceRange. This way we have access to entire callstack and not
  // just source information about where the node is, since bytecode inlines the
  // graph before saving it.
  SourceRangeTagMap source_range_tags_;
  int64_t current_source_range_tag_{0};
};

// For testing purposes
TORCH_API std::string pretty_print_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::map<std::string, at::Tensor>& initializers,
    int64_t onnx_opset_version,
    bool defer_weight_export,
    ::torch::onnx::OperatorExportTypes operator_export_type =
        ::torch::onnx::OperatorExportTypes::ONNX,
    bool google_printer = false,
    bool keep_initializers_as_inputs = true,
    const std::map<std::string, int>& custom_opsets = {},
    bool add_node_names = true);

// Serializes `module` to an output stream.
TORCH_API void ExportModule(
    const Module& module,
    std::ostream& out,
    const ExtraFilesMap& metadata = ExtraFilesMap(),
    bool bytecode_format = false,
    bool save_mobile_debug_info = false,
    bool use_flatbuffer = false);

// Serializes `module` to the file at `filename`.
TORCH_API void ExportModule(
    const Module& module,
    const std::string& filename,
    const ExtraFilesMap& metadata = ExtraFilesMap(),
    bool bytecode_format = false,
    bool save_mobile_debug_info = false,
    bool use_flatbuffer = false);

// Serializes `module` through a caller-supplied write callback; the callback
// receives (data pointer, size) and returns the number of bytes it consumed.
TORCH_API void ExportModule(
    const Module& module,
    const std::function<size_t(const void*, size_t)>& writer_func,
    const ExtraFilesMap& metadata = ExtraFilesMap(),
    bool bytecode_format = false,
    bool save_mobile_debug_info = false,
    bool use_flatbuffer = false);

// Write the bytes of a pickle archive and the tensors referenced inside that
// archive
TORCH_API void writeArchiveAndTensors(
    const std::string& archive_name,
    const char* pickle_bytes,
    size_t size,
    const std::vector<at::Tensor>& tensors,
    caffe2::serialize::PyTorchStreamWriter& out);

// Surrounding system can install an additional hook to produce extra files
// with metadata based on environment every time a module is serialized.
using ExportModuleExtraFilesHook = std::function<ExtraFilesMap(const Module&)>;
TORCH_API void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook);

/**
 * Generates new bytecode for a Script module and returns what the op list
 * would be for a LiteScriptModule based off the current code base. If you
 * have a LiteScriptModule and want to get the currently present
 * list of ops call _export_operator_list instead.
 */
TORCH_API std::vector<std::string> export_opnames(const Module& m);

// Process-global flags controlling how the JIT emits mobile bytecode; see
// BytecodeEmitModeGuard below for the meaning of each flag and for scoped
// (RAII) manipulation.
struct TORCH_API BytecodeEmitMode {
  static bool is_default_value_for_unspecified_arg_enabled();
  static void set_default_value_for_unspecified_arg_enabled(bool enabled);

  static bool is_default_args_before_out_args_enabled();
  static void set_default_args_before_out_args_enabled(bool enabled);

  static bool is_emit_promoted_ops_enabled();
  static void set_default_emit_promoted_ops_enabled(bool enabled);
};

// RAII guard to switch the way JIT emits the bytecode for inputs.
// default_value_for_unspecified_arg:
//   true: instruction of default argument values (like LOADC) is emitted.
//   false: instruction of default argument values are not emitted. Instead
//   they are fetched from operator schema.
218 // default_args_before_out_args (to forward compatibile support 219 // operators allowing out arguments and default arguments): 220 // true: the number of specified arguments will deserialized to (#all_args - 221 // #default_args). false: the number of specified arguments will deserialized to 222 // (#all_args). 223 struct TORCH_API BytecodeEmitModeGuard { BytecodeEmitModeGuardBytecodeEmitModeGuard224 BytecodeEmitModeGuard( 225 bool enable_default_value_for_unspecified_arg, 226 bool enable_default_args_before_out_args, 227 bool enable_emit_promoted_ops) 228 : prev_default_value_for_unspecified_arg_mode( 229 BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()), 230 prev_default_args_before_out_args( 231 BytecodeEmitMode::is_default_args_before_out_args_enabled()), 232 prev_default_emit_promoted_ops( 233 BytecodeEmitMode::is_emit_promoted_ops_enabled()) { 234 BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled( 235 enable_default_value_for_unspecified_arg); 236 BytecodeEmitMode::set_default_args_before_out_args_enabled( 237 enable_default_args_before_out_args); 238 BytecodeEmitMode::set_default_emit_promoted_ops_enabled( 239 enable_emit_promoted_ops); 240 } ~BytecodeEmitModeGuardBytecodeEmitModeGuard241 ~BytecodeEmitModeGuard() { 242 BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled( 243 prev_default_value_for_unspecified_arg_mode); 244 BytecodeEmitMode::set_default_args_before_out_args_enabled( 245 prev_default_args_before_out_args); 246 BytecodeEmitMode::set_default_emit_promoted_ops_enabled( 247 prev_default_emit_promoted_ops); 248 } 249 bool prev_default_value_for_unspecified_arg_mode; 250 bool prev_default_args_before_out_args; 251 bool prev_default_emit_promoted_ops; 252 }; 253 254 TORCH_API IValue to_tuple(std::vector<IValue> ivalues); 255 TORCH_API IValue 256 Table(const std::vector<std::pair<std::string, IValue>>& entries); 257 258 // TODO remove these switches once interface call is rolled out. 
// Process-wide switch enabling export of interface calls in mobile bytecode;
// the getter reports the current setting. NOTE(review): the getter is not
// TORCH_API — presumably intended for in-library use only; confirm before
// relying on it across the DLL boundary.
TORCH_API void enableMobileInterfaceCallExport();
bool getMobileInterfaceCallExport();

// Builds CompilationOptions from the process-global emit settings.
TORCH_API CompilationOptions getOptionsFromGlobal();

// Saves a Script module to `filename`. DetachedBuffer in the _to_bytes
// variant below comes from flatbuffer_serializer.h, so this family is
// presumably the flatbuffer save path — confirm against the implementation.
TORCH_API void save_jit_module(
    const Module& module,
    const std::string& filename,
    const ExtraFilesMap& extra_files = ExtraFilesMap());

// Saves a Script module into an in-memory buffer and returns ownership of it.
TORCH_API DetachedBuffer::UniqueDetachedBuffer save_jit_module_to_bytes(
    const Module& module,
    const ExtraFilesMap& extra_files = ExtraFilesMap());

// Saves a Script module through a caller-supplied write callback; the
// callback receives (data pointer, size) and returns bytes consumed.
TORCH_API void save_jit_module_to_write_func(
    const Module& module,
    const ExtraFilesMap& extra_files,
    bool save_mobile_debug_info,
    const std::function<size_t(const void*, size_t)>& writer_func);

} // namespace torch::jit