# mypy: allow-untyped-defs
import os

import torch
from torch.jit._serialization import validate_map_location


def _is_path(f) -> bool:
    """Return True if ``f`` names a file on disk (``str`` or ``os.PathLike``) rather than a buffer."""
    return isinstance(f, (str, os.PathLike))


def _check_file_path(f) -> None:
    """Raise ``ValueError`` if the path ``f`` does not exist or is a directory.

    Done on the Python side, before the path is handed to ``torch._C``, so the
    caller gets a clear error message.
    """
    if not os.path.exists(f):
        raise ValueError(f"The provided filename {f} does not exist")
    if os.path.isdir(f):
        raise ValueError(f"The provided filename {f} is a directory")


def _load_for_lite_interpreter(f, map_location=None):
    r"""
    Load a :class:`LiteScriptModule` saved with :func:`torch.jit._save_for_lite_interpreter`.

    Args:
        f: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name
        map_location: a string or torch.device used to dynamically remap
            storages to an alternative set of devices.

    Returns:
        A :class:`LiteScriptModule` object.

    Raises:
        ValueError: if ``f`` is a path that does not exist or is a directory.

    Example:

    .. testcode::

        import torch
        import io

        # Load LiteScriptModule from saved file path
        torch.jit.mobile._load_for_lite_interpreter('lite_script_module.pt')

        # Load LiteScriptModule from io.BytesIO object
        with open('lite_script_module.pt', 'rb') as f:
            buffer = io.BytesIO(f.read())

        # Load all tensors to the original device
        torch.jit.mobile._load_for_lite_interpreter(buffer)
    """
    is_path = _is_path(f)
    # Validate the path before map_location so a bad path is reported first.
    if is_path:
        _check_file_path(f)

    map_location = validate_map_location(map_location)

    if is_path:
        cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location)
    else:
        cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
            f.read(), map_location
        )

    return LiteScriptModule(cpp_module)


class LiteScriptModule:
    """Thin Python wrapper around a module loaded by the lite interpreter.

    Every call is delegated to the wrapped ``torch._C`` object stored in
    ``self._c``; positional arguments are packed into a tuple before being
    forwarded.
    """

    def __init__(self, cpp_module):
        # The C++ module returned by torch._C._load_for_lite_interpreter(_from_buffer).
        self._c = cpp_module
        super().__init__()

    def __call__(self, *args):
        """Run :meth:`forward` with the given positional arguments."""
        return self.forward(*args)

    def find_method(self, method_name):
        """Delegate to the wrapped module's ``find_method``."""
        return self._c.find_method(method_name)

    def forward(self, *args):
        """Call the wrapped module's ``forward`` with ``args`` packed as a tuple."""
        return self._c.forward(args)

    def run_method(self, method_name, *args):
        """Run ``method_name`` on the wrapped module with ``args`` packed as a tuple."""
        return self._c.run_method(method_name, args)


def _export_operator_list(module: LiteScriptModule):
    r"""Return a set of root operator names (with overload name) that are used by any method in this mobile module."""
    return torch._C._export_operator_list(module._c)


def _get_model_bytecode_version(f_input) -> int:
    r"""Take a file-like object or file name and return the model's bytecode version.

    Args:
        f_input: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name

    Returns:
        version: An integer. If the integer is -1, the version is invalid. A warning
            will show in the log.

    Raises:
        ValueError: if ``f_input`` is a path that does not exist or is a directory.

    Example:
    .. testcode::

        from torch.jit.mobile import _get_model_bytecode_version

        # Get bytecode version from a saved file path
        version = _get_model_bytecode_version("path/to/model.ptl")

    """
    if _is_path(f_input):
        _check_file_path(f_input)
        return torch._C._get_model_bytecode_version(os.fspath(f_input))
    return torch._C._get_model_bytecode_version_from_buffer(f_input.read())


def _get_mobile_model_contained_types(f_input) -> set[str]:
    r"""Take a file-like object or file name and return a set of strings, like ("int", "Optional").

    Args:
        f_input: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name

    Returns:
        type_list: A set of strings, like ("int", "Optional"). These are types used in bytecode.

    Raises:
        ValueError: if ``f_input`` is a path that does not exist or is a directory.

    Example:

    .. testcode::

        from torch.jit.mobile import _get_mobile_model_contained_types

        # Get type list from a saved file path
        type_list = _get_mobile_model_contained_types("path/to/model.ptl")

    """
    if _is_path(f_input):
        _check_file_path(f_input)
        return torch._C._get_mobile_model_contained_types(os.fspath(f_input))
    return torch._C._get_mobile_model_contained_types_from_buffer(f_input.read())


def _backport_for_mobile(f_input, f_output, to_version) -> bool:
    r"""Backport a model (file name or file-like object) to ``to_version``, writing it to ``f_output``.

    Args:
        f_input: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name
        f_output: path to new model destination (``str`` or ``os.PathLike``)
        to_version: the expected output model bytecode version

    Returns:
        success: A boolean. If backport success, return true, otherwise false

    Raises:
        ValueError: if ``f_input`` is a path that does not exist or is a directory.
        TypeError: if ``f_output`` is not a ``str`` or ``os.PathLike``.
    """
    input_is_path = _is_path(f_input)
    if input_is_path:
        _check_file_path(f_input)

    # The destination is always a path; previously a path input with a
    # non-path destination fell into the buffer branch and failed with an
    # AttributeError on ``str.read``.
    if not _is_path(f_output):
        raise TypeError(
            "f_output must be a str or os.PathLike destination path, "
            f"got {type(f_output).__name__}"
        )
    output_path = os.fspath(f_output)

    if input_is_path:
        return torch._C._backport_for_mobile(
            os.fspath(f_input), output_path, to_version
        )
    return torch._C._backport_for_mobile_from_buffer(
        f_input.read(), output_path, to_version
    )


def _backport_for_mobile_to_buffer(f_input, to_version):
    r"""Backport a model (file name or file-like object) to ``to_version`` and return it in memory.

    Args:
        f_input: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name
        to_version: the expected output model bytecode version

    Returns:
        The backported model as produced by
        ``torch._C._backport_for_mobile(_from_buffer)_to_buffer``.
        NOTE(review): exact return type and failure value are defined in C++ — confirm.

    Raises:
        ValueError: if ``f_input`` is a path that does not exist or is a directory.
    """
    if _is_path(f_input):
        _check_file_path(f_input)
        return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version)
    return torch._C._backport_for_mobile_from_buffer_to_buffer(
        f_input.read(), to_version
    )


def _get_model_ops_and_info(f_input):
    r"""Retrieve the root (top level) operators of a model and their corresponding compatibility info.

    These root operators can call other operators within them (traced ops), and
    a root op can call many different traced ops depending on internal code paths in the root op.
    These traced ops are not returned by this function. Those operators are abstracted into the
    runtime as an implementation detail (and the traced ops themselves can also call other operators)
    making retrieving them difficult and their value from this api negligible since they will differ
    between which runtime version the model is run on. Because of this, there is a false positive this
    api can't prevent in a compatibility usecase. All the root ops of a model are present in a
    target runtime, but not all the traced ops are which prevents a model from being able to run.

    Args:
        f_input: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name

    Returns:
        Operators and info: A Dictionary mapping strings (the qualified names of the root operators)
        of the model to their OperatorInfo structs.

    Raises:
        ValueError: if ``f_input`` is a path that does not exist or is a directory.

    Example:

    .. testcode::

        from torch.jit.mobile import _get_model_ops_and_info

        # Get the root operators and their info from a saved file path
        ops_and_info = _get_model_ops_and_info("path/to/model.ptl")

    """
    if _is_path(f_input):
        _check_file_path(f_input)
        return torch._C._get_model_ops_and_info(os.fspath(f_input))
    # In-memory models must go through the buffer binding; the filename
    # binding would interpret the model bytes as a path.
    return torch._C._get_model_ops_and_info_from_buffer(f_input.read())