# Owner(s): ["oncall: mobile"]

import fnmatch
import io
import shutil
import tempfile
from pathlib import Path

import torch
import torch.utils.show_pickle

# from torch.utils.mobile_optimizer import optimize_for_mobile
from torch.jit.mobile import (
    _backport_for_mobile,
    _backport_for_mobile_to_buffer,
    _get_mobile_model_contained_types,
    _get_model_bytecode_version,
    _get_model_ops_and_info,
    _load_for_lite_interpreter,
)
from torch.testing._internal.common_utils import run_tests, TestCase


# Directory two levels above this file; the checked-in models are looked up
# under `<pytorch_test_dir>/cpp/jit/`.
pytorch_test_dir = Path(__file__).resolve().parents[1]

# script_module_v4.ptl and script_module_v5.ptl source code
# class TestModule(torch.nn.Module):
#     def __init__(self, v):
#         super().__init__()
#         self.x = v

#     def forward(self, y: int):
#         increment = torch.ones([2, 4], dtype=torch.float64)
#         return self.x + y + increment

# output_model_path = Path(tmpdirname, "script_module_v5.ptl")
# script_module = torch.jit.script(TestModule(1))
# optimized_scripted_module = optimize_for_mobile(script_module)
# exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter(
#   str(output_model_path))

# Expected `bytecode.pkl` dumps, one per bytecode version. They are used as
# fnmatch patterns after all whitespace has been removed, so `*` is a wildcard
# and the layout below is purely for readability.
SCRIPT_MODULE_V4_BYTECODE_PKL = """
(4,
 ('__torch__.*.TestModule.forward',
  (('instructions',
    (('STOREN', 1, 2),
     ('DROPR', 1, 0),
     ('LOADC', 0, 0),
     ('LOADC', 1, 0),
     ('MOVE', 2, 0),
     ('OP', 0, 0),
     ('LOADC', 1, 0),
     ('OP', 1, 0),
     ('RET', 0, 0))),
   ('operators', (('aten::add', 'int'), ('aten::add', 'Scalar'))),
   ('constants',
    (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),),
       0,
       (2, 4),
       (4, 1),
       False,
       collections.OrderedDict()),
     1)),
   ('types', ()),
   ('register_size', 2)),
  (('arguments',
    ((('name', 'self'),
      ('type', '__torch__.*.TestModule'),
      ('default_value', None)),
     (('name', 'y'), ('type', 'int'), ('default_value', None)))),
   ('returns',
    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
        """

SCRIPT_MODULE_V5_BYTECODE_PKL = """
(5,
 ('__torch__.*.TestModule.forward',
  (('instructions',
    (('STOREN', 1, 2),
     ('DROPR', 1, 0),
     ('LOADC', 0, 0),
     ('LOADC', 1, 0),
     ('MOVE', 2, 0),
     ('OP', 0, 0),
     ('LOADC', 1, 0),
     ('OP', 1, 0),
     ('RET', 0, 0))),
   ('operators', (('aten::add', 'int'), ('aten::add', 'Scalar'))),
   ('constants',
    (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, 'constants/0', 'cpu', 8),),
       0,
       (2, 4),
       (4, 1),
       False,
       collections.OrderedDict()),
     1)),
   ('types', ()),
   ('register_size', 2)),
  (('arguments',
    ((('name', 'self'),
      ('type', '__torch__.*.TestModule'),
      ('default_value', None)),
     (('name', 'y'), ('type', 'int'), ('default_value', None)))),
   ('returns',
    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
        """

SCRIPT_MODULE_V6_BYTECODE_PKL = """
(6,
 ('__torch__.*.TestModule.forward',
  (('instructions',
    (('STOREN', 1, 2),
     ('DROPR', 1, 0),
     ('LOADC', 0, 0),
     ('LOADC', 1, 0),
     ('MOVE', 2, 0),
     ('OP', 0, 0),
     ('OP', 1, 0),
     ('RET', 0, 0))),
   ('operators', (('aten::add', 'int', 2), ('aten::add', 'Scalar', 2))),
   ('constants',
    (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),),
       0,
       (2, 4),
       (4, 1),
       False,
       collections.OrderedDict()),
     1)),
   ('types', ()),
   ('register_size', 2)),
  (('arguments',
    ((('name', 'self'),
      ('type', '__torch__.*.TestModule'),
      ('default_value', None)),
     (('name', 'y'), ('type', 'int'), ('default_value', None)))),
   ('returns',
    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
    """

# Checked-in models keyed by bytecode version: the expected bytecode.pkl dump
# and the model file name under `<pytorch_test_dir>/cpp/jit/`.
SCRIPT_MODULE_BYTECODE_PKL = {
    4: {
        "bytecode_pkl": SCRIPT_MODULE_V4_BYTECODE_PKL,
        "model_name": "script_module_v4.ptl",
    },
}

# The minimum version a model can be backported to.
# Needs to be updated when a bytecode version is completely retired.
MINIMUM_TO_VERSION = 4


class testVariousModelVersions(TestCase):
    """Tests for the mobile bytecode helpers in `torch.jit.mobile`.

    Covers querying a model's bytecode version, backporting a model to an
    older bytecode version (file -> file and file -> buffer), listing a
    model's operators, and listing the types contained in a mobile model.

    The backport tests only do work while the newest model registered in
    `SCRIPT_MODULE_BYTECODE_PKL` is newer than `MINIMUM_TO_VERSION`; with the
    table as it stands in this file (only v4) their bodies are skipped.
    """

    @staticmethod
    def _matches_ignoring_whitespace(actual: str, expected_pattern: str) -> bool:
        """Return whether `actual` matches `expected_pattern`, ignoring whitespace.

        All whitespace is removed from both strings first; `expected_pattern`
        is then applied as an fnmatch (shell-style wildcard) pattern, so `*`
        can be used to skip content that should not be compared.
        """
        actual_clean = "".join(actual.split())
        pattern_clean = "".join(expected_pattern.split())
        return fnmatch.fnmatch(actual_clean, pattern_clean)

    def _assert_bytecode_pkl_matches(
        self, model_path: Path, expected_bytecode_pkl: str
    ) -> None:
        """Dump `bytecode.pkl` of the model at `model_path` and check its content.

        The dump is produced by `torch.utils.show_pickle.main` into an
        in-memory text buffer and compared against `expected_bytecode_pkl`
        with `_matches_ignoring_whitespace`.
        """
        dump = io.StringIO()
        # The "<model file>@*/bytecode.pkl" argument form is the one this file
        # has always passed to show_pickle to select bytecode.pkl in the model.
        torch.utils.show_pickle.main(
            ["", f"{model_path}@*/bytecode.pkl"],
            output_stream=dump,
        )
        actual_bytecode_pkl = dump.getvalue()
        self.assertTrue(
            self._matches_ignoring_whitespace(
                actual_bytecode_pkl, expected_bytecode_pkl
            ),
            msg=f"bytecode.pkl of {model_path.name} does not match the expected "
            f"pattern; actual dump:\n{actual_bytecode_pkl}",
        )

    def test_get_model_bytecode_version(self):
        """Every checked-in model reports the version it is registered under."""
        for version, model_info in SCRIPT_MODULE_BYTECODE_PKL.items():
            model_path = pytorch_test_dir / "cpp" / "jit" / model_info["model_name"]
            self.assertEqual(_get_model_bytecode_version(model_path), version)

    def test_bytecode_values_for_all_backport_functions(self):
        """Backport each checked-in model one version down and check bytecode.pkl."""
        # Find the maximum version of the checked in models, start backporting to the minimum support version,
        # and comparing the bytecode pkl content.
        # It can't be merged to the test `test_all_backport_functions`, because optimization is dynamic and
        # the content might change when optimize function changes. This test focuses
        # on bytecode.pkl content validation. For the content validation, it is not byte to byte check, but
        # wildcard (fnmatch) matching. The wildcard can be used to skip some specific content comparison.
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL)

        with tempfile.TemporaryDirectory() as tmpdirname:
            # A temporary model file will be exported to this path, and run through
            # the bytecode.pkl content check.
            tmp_output_model_path_backport = Path(
                tmpdirname, "tmp_script_module_backport.ptl"
            )

            for current_from_version in range(
                maximum_checked_in_model_version, MINIMUM_TO_VERSION, -1
            ):
                # Use the checked-in model for `current_from_version` as the input.
                model_name = SCRIPT_MODULE_BYTECODE_PKL[current_from_version][
                    "model_name"
                ]
                input_model_path = pytorch_test_dir / "cpp" / "jit" / model_name

                current_to_version = current_from_version - 1
                backport_success = _backport_for_mobile(
                    input_model_path, tmp_output_model_path_backport, current_to_version
                )
                self.assertTrue(backport_success)

                self._assert_bytecode_pkl_matches(
                    tmp_output_model_path_backport,
                    SCRIPT_MODULE_BYTECODE_PKL[current_to_version]["bytecode_pkl"],
                )

            # Explicit removal is kept from the original test. The manual test
            # below notes that the folder needs to be cleaned before the
            # context closes to avoid a "git not clean" error.
            shutil.rmtree(tmpdirname)

    # Please run this test manually when working on backport.
    # This test passes in OSS, but fails internally, likely due to missing step in build
    # def test_all_backport_functions(self):
    #     # Backport from the latest bytecode version to the minimum support version
    #     # Load, run the backport model, and check version
    #     class TestModule(torch.nn.Module):
    #         def __init__(self, v):
    #             super().__init__()
    #             self.x = v

    #         def forward(self, y: int):
    #             increment = torch.ones([2, 4], dtype=torch.float64)
    #             return self.x + y + increment

    #     module_input = 1
    #     expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)

    #     # temporary input model file and output model file will be exported in the temporary folder
    #     with tempfile.TemporaryDirectory() as tmpdirname:
    #         tmp_input_model_path = Path(tmpdirname, "tmp_script_module.ptl")
    #         script_module = torch.jit.script(TestModule(1))
    #         optimized_scripted_module = optimize_for_mobile(script_module)
    #         exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter(str(tmp_input_model_path))

    #         current_from_version = _get_model_bytecode_version(tmp_input_model_path)
    #         current_to_version = current_from_version - 1
    #         tmp_output_model_path = Path(tmpdirname, "tmp_script_module_backport.ptl")

    #         while current_to_version >= MINIMUM_TO_VERSION:
    #             # Backport the latest model to `to_version` to a tmp file "tmp_script_module_backport"
    #             backport_success = _backport_for_mobile(tmp_input_model_path, tmp_output_model_path, current_to_version)
    #             assert(backport_success)

    #             backport_version = _get_model_bytecode_version(tmp_output_model_path)
    #             assert(backport_version == current_to_version)

    #             # Load model and run forward method
    #             mobile_module = _load_for_lite_interpreter(str(tmp_input_model_path))
    #             mobile_module_result = mobile_module(module_input)
    #             torch.testing.assert_close(mobile_module_result, expected_mobile_module_result)
    #             current_to_version -= 1

    #         # Check backport failure case
    #         backport_success = _backport_for_mobile(tmp_input_model_path, tmp_output_model_path, MINIMUM_TO_VERSION - 1)
    #         assert(not backport_success)
    #         # need to clean the folder before it closes, otherwise will run into git not clean error
    #         shutil.rmtree(tmpdirname)

    # Check just the _backport_for_mobile (file to file) mechanism but not the function implementations
    def test_backport_bytecode_from_file_to_file(self):
        """Backport the newest checked-in model to a file, then verify and run it."""
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL)
        if maximum_checked_in_model_version <= MINIMUM_TO_VERSION:
            # Nothing older to backport to.
            return

        input_model_path = (
            pytorch_test_dir
            / "cpp"
            / "jit"
            / SCRIPT_MODULE_BYTECODE_PKL[maximum_checked_in_model_version]["model_name"]
        )
        to_version = maximum_checked_in_model_version - 1

        with tempfile.TemporaryDirectory() as tmpdirname:
            tmp_backport_model_path = Path(
                tmpdirname, "tmp_script_module_v5_backported_to_v4.ptl"
            )
            # backport from file
            success = _backport_for_mobile(
                input_model_path,
                tmp_backport_model_path,
                to_version,
            )
            self.assertTrue(success)

            # Compare against the expectation registered for the target
            # version rather than a hard-coded constant.
            self._assert_bytecode_pkl_matches(
                tmp_backport_model_path,
                SCRIPT_MODULE_BYTECODE_PKL[to_version]["bytecode_pkl"],
            )

            # Load the backported model and run forward method
            mobile_module = _load_for_lite_interpreter(str(tmp_backport_model_path))
            module_input = 1
            mobile_module_result = mobile_module(module_input)
            expected_mobile_module_result = 3 * torch.ones(
                [2, 4], dtype=torch.float64
            )
            torch.testing.assert_close(
                mobile_module_result, expected_mobile_module_result
            )
            # Explicit removal kept from the original test (see the note in
            # the commented-out manual test above).
            shutil.rmtree(tmpdirname)

    # Check just the _backport_for_mobile_to_buffer mechanism but not the function implementations
    def test_backport_bytecode_from_file_to_buffer(self):
        """Backport the newest checked-in model to a buffer, then verify and run it."""
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL)
        if maximum_checked_in_model_version <= MINIMUM_TO_VERSION:
            # Nothing older to backport to.
            return

        input_model_path = (
            pytorch_test_dir
            / "cpp"
            / "jit"
            / SCRIPT_MODULE_BYTECODE_PKL[maximum_checked_in_model_version]["model_name"]
        )
        to_version = maximum_checked_in_model_version - 1

        # Backport the model one version down, into an in-memory buffer
        backported_model_buffer = _backport_for_mobile_to_buffer(
            input_model_path, to_version
        )

        # Check version of the model from backport
        backport_version = _get_model_bytecode_version(
            io.BytesIO(backported_model_buffer)
        )
        self.assertEqual(backport_version, to_version)

        # Load model from backport (fresh stream) and run forward method
        mobile_module = _load_for_lite_interpreter(io.BytesIO(backported_model_buffer))
        module_input = 1
        mobile_module_result = mobile_module(module_input)
        expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)
        torch.testing.assert_close(mobile_module_result, expected_mobile_module_result)

    def test_get_model_ops_and_info(self):
        """The v6 model reports two schema args for both `aten::add` overloads."""
        # TODO update this to be more in the style of the above tests after a backport from 6 -> 5 exists
        script_module_v6 = pytorch_test_dir / "cpp" / "jit" / "script_module_v6.ptl"
        ops_v6 = _get_model_ops_and_info(script_module_v6)
        self.assertEqual(ops_v6["aten::add.int"].num_schema_args, 2)
        self.assertEqual(ops_v6["aten::add.Scalar"].num_schema_args, 2)

    def test_get_mobile_model_contained_types(self):
        """`_get_mobile_model_contained_types` accepts an in-memory lite model."""

        class MyTestModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        sample_input = torch.tensor([1])

        script_module = torch.jit.script(MyTestModule())
        # Run the scripted module once before exporting; the result is not checked.
        script_module(sample_input)

        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
        type_list = _get_mobile_model_contained_types(buffer)
        # NOTE(review): this bound holds for any sized container, so it only
        # verifies that the call succeeds and returns something with a length.
        # Confirm whether a stricter expectation is intended.
        self.assertGreaterEqual(len(type_list), 0)


if __name__ == "__main__":
    run_tests()