1# mypy: allow-untyped-defs 2"""Serialization. 3 4This module contains functionality for serializing TorchScript modules, notably: 5 * torch.jit.save 6 * torch.jit.load 7 8This is not intended to be imported directly; please use the exposed 9functionalities in `torch.jit`. 10""" 11 12import os 13 14import torch 15from torch._jit_internal import _get_model_id 16from torch._utils_internal import log_torchscript_usage 17from torch.jit._recursive import wrap_cpp_module 18from torch.serialization import validate_cuda_device 19 20 21def save(m, f, _extra_files=None): 22 r""" 23 Save an offline version of this module for use in a separate process. 24 25 The saved module serializes all of the methods, submodules, parameters, and 26 attributes of this module. It can be loaded into the C++ API using 27 ``torch::jit::load(filename)`` or into the Python API with 28 :func:`torch.jit.load <torch.jit.load>`. 29 30 To be able to save a module, it must not make any calls to native Python 31 functions. This means that all submodules must be subclasses of 32 :class:`ScriptModule` as well. 33 34 .. DANGER:: 35 All modules, no matter their device, are always loaded onto the CPU 36 during loading. This is different from :func:`torch.load`'s semantics 37 and may change in the future. 38 39 Args: 40 m: A :class:`ScriptModule` to save. 41 f: A file-like object (has to implement write and flush) or a string 42 containing a file name. 43 _extra_files: Map from filename to contents which will be stored as part of `f`. 44 45 .. note:: 46 torch.jit.save attempts to preserve the behavior of some operators 47 across versions. For example, dividing two integer tensors in 48 PyTorch 1.5 performed floor division, and if the module 49 containing that code is saved in PyTorch 1.5 and loaded in PyTorch 1.6 50 its division behavior will be preserved. The same module saved in 51 PyTorch 1.6 will fail to load in PyTorch 1.5, however, since the 52 behavior of division changed in 1.6, and 1.5 does not know how to 53 replicate the 1.6 behavior. 54 55 Example: 56 .. testcode:: 57 58 import torch 59 import io 60 61 class MyModule(torch.nn.Module): 62 def forward(self, x): 63 return x + 10 64 65 m = torch.jit.script(MyModule()) 66 67 # Save to file 68 torch.jit.save(m, 'scriptmodule.pt') 69 # This line is equivalent to the previous 70 m.save("scriptmodule.pt") 71 72 # Save to io.BytesIO buffer 73 buffer = io.BytesIO() 74 torch.jit.save(m, buffer) 75 76 # Save with extra files 77 extra_files = {'foo.txt': b'bar'} 78 torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files) 79 """ 80 log_torchscript_usage("save", model_id=_get_model_id(m)) 81 if _extra_files is None: 82 _extra_files = {} 83 if isinstance(f, (str, os.PathLike)): 84 m.save(f, _extra_files=_extra_files) 85 else: 86 ret = m.save_to_buffer(_extra_files=_extra_files) 87 f.write(ret) 88 89 90def load(f, map_location=None, _extra_files=None, _restore_shapes=False): 91 r""" 92 Load a :class:`ScriptModule` or :class:`ScriptFunction` previously saved with :func:`torch.jit.save <torch.jit.save>`. 93 94 All previously saved modules, no matter their device, are first loaded onto CPU, 95 and then are moved to the devices they were saved from. If this fails (e.g. 96 because the run time system doesn't have certain devices), an exception is 97 raised. 98 99 Args: 100 f: a file-like object (has to implement read, readline, tell, and seek), 101 or a string containing a file name 102 map_location (string or torch.device): A simplified version of 103 ``map_location`` in `torch.jit.save` used to dynamically remap 104 storages to an alternative set of devices. 105 _extra_files (dictionary of filename to content): The extra 106 filenames given in the map would be loaded and their content 107 would be stored in the provided map. 108 _restore_shapes (bool): Whether or not to retrace the module on load using stored inputs 109 110 Returns: 111 A :class:`ScriptModule` object. 112 113 Example: 114 .. testcode:: 115 116 import torch 117 import io 118 119 torch.jit.load('scriptmodule.pt') 120 121 # Load ScriptModule from io.BytesIO object 122 with open('scriptmodule.pt', 'rb') as f: 123 buffer = io.BytesIO(f.read()) 124 125 # Load all tensors to the original device 126 torch.jit.load(buffer) 127 128 # Load all tensors onto CPU, using a device 129 buffer.seek(0) 130 torch.jit.load(buffer, map_location=torch.device('cpu')) 131 132 # Load all tensors onto CPU, using a string 133 buffer.seek(0) 134 torch.jit.load(buffer, map_location='cpu') 135 136 # Load with extra files. 137 extra_files = {'foo.txt': ''} # values will be replaced with data 138 torch.jit.load('scriptmodule.pt', _extra_files=extra_files) 139 print(extra_files['foo.txt']) 140 141 .. testoutput:: 142 :hide: 143 144 ... 145 146 .. testcleanup:: 147 148 import os 149 os.remove("scriptmodule.pt") 150 """ 151 if isinstance(f, (str, os.PathLike)): 152 if not os.path.exists(f): # type: ignore[type-var] 153 raise ValueError(f"The provided filename {f} does not exist") # type: ignore[str-bytes-safe] 154 if os.path.isdir(f): 155 raise ValueError(f"The provided filename {f} is a directory") # type: ignore[str-bytes-safe] 156 157 map_location = validate_map_location(map_location) 158 if _extra_files is None: 159 _extra_files = {} 160 161 cu = torch._C.CompilationUnit() 162 if isinstance(f, (str, os.PathLike)): 163 cpp_module = torch._C.import_ir_module(cu, os.fspath(f), map_location, _extra_files, _restore_shapes) # type: ignore[call-arg] 164 else: 165 cpp_module = torch._C.import_ir_module_from_buffer( 166 cu, f.read(), map_location, _extra_files, _restore_shapes 167 ) # type: ignore[call-arg] 168 169 # TODO: Pretty sure this approach loses ConstSequential status and such 170 ret = wrap_cpp_module(cpp_module) 171 log_torchscript_usage("load", model_id=_get_model_id(ret)) 172 return ret 173 174 175def validate_map_location(map_location=None): 176 if isinstance(map_location, str): 177 map_location = torch.device(map_location) 178 elif not (map_location is None or isinstance(map_location, torch.device)): 179 raise ValueError( 180 "map_location should be either None, string or torch.device, " 181 "but got type: " + str(type(map_location)) 182 ) 183 184 if str(map_location).startswith("cuda"): 185 validate_cuda_device(map_location) 186 187 return map_location 188 189 190def jit_module_from_flatbuffer(f): 191 if isinstance(f, (str, os.PathLike)): 192 f = os.fspath(f) 193 return wrap_cpp_module(torch._C._load_jit_module_from_file(f)) 194 else: 195 return wrap_cpp_module(torch._C._load_jit_module_from_bytes(f.read())) 196 197 198def save_jit_module_to_flatbuffer(m, f, _extra_files=None): 199 r""" 200 Save an offline version of this module for use in a separate process. 201 202 The saved module serializes all of the methods, submodules, parameters, and 203 attributes of this module. It can be loaded into the C++ API using 204 ``torch::jit::load_jit_module_from_file(filename)`` or into the Python API with 205 :func:`torch.jit.jit_module_from_flatbuffer<torch.jit.jit_module_from_flatbuffer>`. 206 207 To be able to save a module, it must not make any calls to native Python 208 functions. This means that all submodules must be subclasses of 209 :class:`ScriptModule` as well. 210 211 .. DANGER:: 212 All modules, no matter their device, are always loaded onto the CPU 213 during loading. This is different from :func:`torch.load`'s semantics 214 and may change in the future. 215 216 Args: 217 m: A :class:`ScriptModule` to save. 218 f: A string for file path 219 220 221 Example: 222 .. testcode:: 223 224 import torch 225 import io 226 227 class MyModule(torch.nn.Module): 228 def forward(self, x): 229 return x + 10 230 231 m = torch.jit.script(MyModule()) 232 233 # Save to file 234 torch.jit.save_jit_module_to_flatbuffer(m, 'scriptmodule.ff') 235 """ 236 extra_files = _extra_files 237 if extra_files is None: 238 extra_files = {} 239 240 if isinstance(f, (str, os.PathLike)): 241 f = os.fspath(f) 242 torch._C._save_jit_module(m._c, f, extra_files) 243 else: 244 s = torch._C._save_jit_module_to_bytes(m._c, extra_files) 245 f.write(s) 246 247 248def get_flatbuffer_module_info(path_or_file): 249 r"""Get some information regarding a model file in flatbuffer format. 250 251 Args: 252 path_or_file: Either str, Path or file like object (BytesIO OK). 253 If it's str or Path, we will read the file referenced by that 254 path as Bytes. 255 256 Returns: 257 A dict with metadata on what that file contains, currently looks like 258 this: 259 { 260 'bytecode_version': 4, # int 261 'operator_version': 4, # int 262 'function_names': { 263 '__torch__.___torch_mangle_0.Foo.forward'}, # set 264 'type_names': set(), # set 265 'opname_to_num_args': {'aten::linear': 3} # Dict[str, int] 266 } 267 """ 268 if isinstance(path_or_file, (str, os.PathLike)): 269 with open(path_or_file, "rb") as f: 270 all_bytes = f.read() 271 else: 272 all_bytes = path_or_file.read() 273 return torch._C._get_module_info_from_flatbuffer(all_bytes) 274