1# Owner(s): ["oncall: mobile"] 2 3import inspect 4import io 5from tempfile import TemporaryFileName 6from typing import Dict, List 7 8import torch 9import torch.utils.bundled_inputs 10from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter 11from torch.testing import FileCheck 12from torch.testing._internal.common_quantization import ( 13 AnnotatedNestedModel, 14 AnnotatedSingleLayerLinearModel, 15 QuantizationLiteTestCase, 16 TwoLayerLinearModel, 17) 18from torch.testing._internal.common_utils import run_tests, TestCase 19 20 21class TestLiteScriptModule(TestCase): 22 def getScriptExportImportCopy( 23 self, m, save_mobile_debug_info=True, also_test_file=False 24 ): 25 m_scripted = torch.jit.script(m) 26 27 if not also_test_file: 28 buffer = io.BytesIO( 29 m_scripted._save_to_buffer_for_lite_interpreter( 30 _save_mobile_debug_info=save_mobile_debug_info 31 ) 32 ) 33 buffer.seek(0) 34 mobile_module = _load_for_lite_interpreter(buffer) 35 return mobile_module 36 37 with TemporaryFileName() as fname: 38 m_scripted._save_for_lite_interpreter( 39 fname, _save_mobile_debug_info=save_mobile_debug_info 40 ) 41 mobile_module = _load_for_lite_interpreter(fname) 42 return mobile_module 43 44 def test_load_mobile_module(self): 45 class MyTestModule(torch.nn.Module): 46 def forward(self, x): 47 return x + 10 48 49 input = torch.tensor([1]) 50 51 script_module = torch.jit.script(MyTestModule()) 52 script_module_result = script_module(input) 53 54 buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) 55 buffer.seek(0) 56 mobile_module = _load_for_lite_interpreter(buffer) 57 58 mobile_module_result = mobile_module(input) 59 torch.testing.assert_close(script_module_result, mobile_module_result) 60 61 mobile_module_forward_result = mobile_module.forward(input) 62 torch.testing.assert_close(script_module_result, mobile_module_forward_result) 63 64 mobile_module_run_method_result = mobile_module.run_method("forward", input) 65 torch.testing.assert_close( 66 script_module_result, mobile_module_run_method_result 67 ) 68 69 def test_save_mobile_module_with_debug_info_with_trace(self): 70 class A(torch.nn.Module): 71 def forward(self, x, y): 72 return x * y 73 74 class B(torch.nn.Module): 75 def __init__(self) -> None: 76 super().__init__() 77 self.A0 = A() 78 self.A1 = A() 79 80 def forward(self, x, y, z): 81 return self.A0(x, y) + self.A1(y, z) 82 83 for export_method in ["trace", "script"]: 84 x = torch.rand((2, 3)) 85 y = torch.rand((2, 3)) 86 z = torch.rand((2, 3)) 87 if export_method == "trace": 88 trace_module = torch.jit.trace(B(), [x, y, z]) 89 else: 90 trace_module = torch.jit.script(B()) 91 exported_module = trace_module._save_to_buffer_for_lite_interpreter( 92 _save_mobile_debug_info=True 93 ) 94 buffer = io.BytesIO(exported_module) 95 buffer.seek(0) 96 97 assert b"callstack_debug_map.pkl" in exported_module 98 99 mobile_module = _load_for_lite_interpreter(buffer) 100 with self.assertRaisesRegex( 101 RuntimeError, 102 r"Module hierarchy:top\(B\)::<unknown>.A0\(A\)::forward.aten::mul", 103 ): 104 x = torch.rand((2, 3)) 105 y = torch.rand((8, 10)) 106 z = torch.rand((8, 10)) 107 mobile_module(x, y, z) 108 with self.assertRaisesRegex( 109 RuntimeError, 110 r"Module hierarchy:top\(B\)::<unknown>.A1\(A\)::forward.aten::mul", 111 ): 112 x = torch.rand((2, 3)) 113 y = torch.rand((2, 3)) 114 z = torch.rand((8, 10)) 115 mobile_module(x, y, z) 116 117 def test_load_mobile_module_with_debug_info(self): 118 class MyTestModule(torch.nn.Module): 119 def forward(self, x): 120 return x + 5 121 122 input = torch.tensor([3]) 123 124 script_module = torch.jit.script(MyTestModule()) 125 script_module_result = script_module(input) 126 127 buffer = io.BytesIO( 128 script_module._save_to_buffer_for_lite_interpreter( 129 _save_mobile_debug_info=True 130 ) 131 ) 132 buffer.seek(0) 133 mobile_module = _load_for_lite_interpreter(buffer) 134 135 mobile_module_result = mobile_module(input) 136 torch.testing.assert_close(script_module_result, mobile_module_result) 137 138 mobile_module_forward_result = mobile_module.forward(input) 139 torch.testing.assert_close(script_module_result, mobile_module_forward_result) 140 141 mobile_module_run_method_result = mobile_module.run_method("forward", input) 142 torch.testing.assert_close( 143 script_module_result, mobile_module_run_method_result 144 ) 145 146 def test_find_and_run_method(self): 147 class MyTestModule(torch.nn.Module): 148 def forward(self, arg): 149 return arg 150 151 input = (torch.tensor([1]),) 152 153 script_module = torch.jit.script(MyTestModule()) 154 script_module_result = script_module(*input) 155 156 buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) 157 buffer.seek(0) 158 mobile_module = _load_for_lite_interpreter(buffer) 159 160 has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs") 161 self.assertFalse(has_bundled_inputs) 162 163 torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 164 script_module, [input], [] 165 ) 166 167 buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) 168 buffer.seek(0) 169 mobile_module = _load_for_lite_interpreter(buffer) 170 171 has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs") 172 self.assertTrue(has_bundled_inputs) 173 174 bundled_inputs = mobile_module.run_method("get_all_bundled_inputs") 175 mobile_module_result = mobile_module.forward(*bundled_inputs[0]) 176 torch.testing.assert_close(script_module_result, mobile_module_result) 177 178 def test_method_calls_with_optional_arg(self): 179 class A(torch.nn.Module): 180 def __init__(self) -> None: 181 super().__init__() 182 183 # opt arg in script-to-script invocation 184 def forward(self, x, two: int = 2): 185 return x + two 186 187 class B(torch.nn.Module): 188 def __init__(self) -> None: 189 super().__init__() 190 self.A0 = A() 191 192 # opt arg in Python-to-script invocation 193 def forward(self, x, one: int = 1): 194 return self.A0(x) + one 195 196 script_module = torch.jit.script(B()) 197 buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) 198 mobile_module = _load_for_lite_interpreter(buffer) 199 200 input = torch.tensor([5]) 201 script_module_forward_result = script_module.forward(input) 202 mobile_module_forward_result = mobile_module.forward(input) 203 torch.testing.assert_close( 204 script_module_forward_result, mobile_module_forward_result 205 ) 206 207 # change ref only 208 script_module_forward_result = script_module.forward(input, 2) 209 self.assertFalse( 210 (script_module_forward_result == mobile_module_forward_result).all().item() 211 ) 212 213 # now both match again 214 mobile_module_forward_result = mobile_module.forward(input, 2) 215 torch.testing.assert_close( 216 script_module_forward_result, mobile_module_forward_result 217 ) 218 219 def test_unsupported_classtype(self): 220 class Foo: 221 def __init__(self) -> None: 222 return 223 224 def func(self, x: int, y: int): 225 return x + y 226 227 class MyTestModule(torch.nn.Module): 228 def forward(self, arg): 229 f = Foo() 230 return f.func(1, 2) 231 232 script_module = torch.jit.script(MyTestModule()) 233 with self.assertRaisesRegex( 234 RuntimeError, 235 r"Workaround: instead of using arbitrary class type \(class Foo\(\)\), " 236 r"define a pytorch class \(class Foo\(torch\.nn\.Module\)\)\. " 237 r"The problematic type is: ", 238 ): 239 script_module._save_to_buffer_for_lite_interpreter() 240 241 def test_unsupported_return_list_with_module_class(self): 242 class Foo(torch.nn.Module): 243 pass 244 245 class MyTestModuleForListWithModuleClass(torch.nn.Module): 246 def __init__(self) -> None: 247 super().__init__() 248 self.foo = Foo() 249 250 def forward(self): 251 my_list: List[Foo] = [self.foo] 252 return my_list 253 254 script_module = torch.jit.script(MyTestModuleForListWithModuleClass()) 255 with self.assertRaisesRegex( 256 RuntimeError, 257 r"^Returning a list or dictionary with pytorch class type " 258 r"is not supported in mobile module " 259 r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. " 260 r"Workaround\: instead of using pytorch class as their element type\, " 261 r"use a combination of list\, dictionary\, and single types\.$", 262 ): 263 script_module._save_to_buffer_for_lite_interpreter() 264 265 def test_unsupported_return_dict_with_module_class(self): 266 class Foo(torch.nn.Module): 267 pass 268 269 class MyTestModuleForDictWithModuleClass(torch.nn.Module): 270 def __init__(self) -> None: 271 super().__init__() 272 self.foo = Foo() 273 274 def forward(self): 275 my_dict: Dict[int, Foo] = {1: self.foo} 276 return my_dict 277 278 script_module = torch.jit.script(MyTestModuleForDictWithModuleClass()) 279 with self.assertRaisesRegex( 280 RuntimeError, 281 r"^Returning a list or dictionary with pytorch class type " 282 r"is not supported in mobile module " 283 r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. " 284 r"Workaround\: instead of using pytorch class as their element type\, " 285 r"use a combination of list\, dictionary\, and single types\.$", 286 ): 287 script_module._save_to_buffer_for_lite_interpreter() 288 289 def test_module_export_operator_list(self): 290 class Foo(torch.nn.Module): 291 def __init__(self) -> None: 292 super().__init__() 293 self.weight = torch.ones((20, 1, 5, 5)) 294 self.bias = torch.ones(20) 295 296 def forward(self, input): 297 x1 = torch.zeros(2, 2) 298 x2 = torch.empty_like(torch.empty(2, 2)) 299 x3 = torch._convolution( 300 input, 301 self.weight, 302 self.bias, 303 [1, 1], 304 [0, 0], 305 [1, 1], 306 False, 307 [0, 0], 308 1, 309 False, 310 False, 311 True, 312 True, 313 ) 314 return (x1, x2, x3) 315 316 m = torch.jit.script(Foo()) 317 318 buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter()) 319 buffer.seek(0) 320 mobile_module = _load_for_lite_interpreter(buffer) 321 322 expected_ops = { 323 "aten::_convolution", 324 "aten::empty.memory_format", 325 "aten::empty_like", 326 "aten::zeros", 327 } 328 actual_ops = _export_operator_list(mobile_module) 329 self.assertEqual(actual_ops, expected_ops) 330 331 def test_source_range_simple(self): 332 class FooTest(torch.jit.ScriptModule): 333 @torch.jit.script_method 334 def forward(self, x, w): 335 return torch.mm(x, w.t()) 336 337 ft = FooTest() 338 loaded = self.getScriptExportImportCopy(ft) 339 _, lineno = inspect.getsourcelines(FooTest) 340 341 with self.assertRaisesRegex( 342 RuntimeError, f'test_lite_script_module.py", line {lineno + 3}' 343 ): 344 loaded(torch.rand(3, 4), torch.rand(30, 40)) 345 346 def test_source_range_raise_exception(self): 347 class FooTest2(torch.jit.ScriptModule): 348 @torch.jit.script_method 349 def forward(self): 350 raise RuntimeError("foo") 351 352 _, lineno = inspect.getsourcelines(FooTest2) 353 354 # In C++ code, the type of exception thrown is torch::jit::JITException 355 # which does not extend c10::Error, and hence it isn't possible to add 356 # additional context to the exception message and preserve the correct 357 # C++ stack trace for symbolication. i.e. it isn't possible to add 358 # the debug handle string to show where in the Python code the exception 359 # occured w/o first changing 360 # torch::jit::JITException to extend c10::Error. 361 with self.assertRaisesRegex(torch.jit.Error, "foo"): 362 ft = FooTest2() 363 loaded = self.getScriptExportImportCopy(ft) 364 loaded() 365 366 def test_source_range_function_call(self): 367 class FooTest3(torch.jit.ScriptModule): 368 @torch.jit.script_method 369 def add_method(self, x, w): 370 return x + w 371 372 @torch.jit.script_method 373 def forward(self, x, y, w): 374 x = x * y 375 x = x + 2 376 return self.add_method(x, w) 377 378 ft = FooTest3() 379 loaded = self.getScriptExportImportCopy(ft) 380 _, lineno = inspect.getsourcelines(FooTest3) 381 382 try: 383 loaded(torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40)) 384 except RuntimeError as e: 385 error_message = f"{e}" 386 self.assertTrue( 387 f'test_lite_script_module.py", line {lineno + 3}' in error_message 388 ) 389 self.assertTrue( 390 f'test_lite_script_module.py", line {lineno + 9}' in error_message 391 ) 392 self.assertTrue("top(FooTest3)" in error_message) 393 394 def test_source_range_no_debug_info(self): 395 class FooTest4(torch.jit.ScriptModule): 396 @torch.jit.script_method 397 def forward(self, x, w): 398 return torch.mm(x, w.t()) 399 400 ft = FooTest4() 401 loaded = self.getScriptExportImportCopy(ft, save_mobile_debug_info=False) 402 403 try: 404 loaded(torch.rand(3, 4), torch.rand(30, 40)) 405 except RuntimeError as e: 406 error_message = f"{e}" 407 self.assertTrue("test_lite_script_module.py" not in error_message) 408 409 def test_source_range_raise_exc(self): 410 class FooTest5(torch.jit.ScriptModule): 411 def __init__(self, val: int): 412 super().__init__() 413 self.val = val 414 415 @torch.jit.script_method 416 def add_method(self, val: int, x, w): 417 if val == self.val: 418 raise RuntimeError("self.val and val are same") 419 return x + w 420 421 @torch.jit.script_method 422 def forward(self, val: int, x, y, w): 423 x = x * y 424 x = x + 2 425 return self.add_method(val, x, w) 426 427 ft = FooTest5(42) 428 loaded = self.getScriptExportImportCopy(ft) 429 _, lineno = inspect.getsourcelines(FooTest5) 430 431 try: 432 loaded(42, torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40)) 433 except torch.jit.Error as e: 434 error_message = f"{e}" 435 436 # In C++ code, the type of exception thrown is torch::jit::JITException 437 # which does not extend c10::Error, and hence it isn't possible to add 438 # additional context to the exception message and preserve the correct 439 # C++ stack trace for symbolication. i.e. it isn't possible to add 440 # the debug handle string to show where in the Python code the exception 441 # occured w/o first changing 442 # torch::jit::JITException to extend c10::Error. 443 self.assertTrue("self.val and val are same" in error_message) 444 445 def test_stacktrace_interface_call(self): 446 @torch.jit.interface 447 class Forward(torch.nn.Module): 448 def forward(self, x) -> torch.Tensor: 449 pass 450 451 def forwardError(self, x) -> torch.Tensor: 452 pass 453 454 class B(torch.nn.Module): 455 def forward(self, x): 456 return x 457 458 def forwardError(self, x): 459 return self.call() + x 460 461 def call(self): 462 return torch.ones(-1) 463 464 class A(torch.nn.Module): 465 b: Forward 466 467 def __init__(self) -> None: 468 super().__init__() 469 self.b = B() 470 471 def forward(self): 472 self.b.forward(torch.ones(1)) 473 self.b.forwardError(torch.ones(1)) 474 475 a = torch.jit.script(A()) 476 torch._C._enable_mobile_interface_call_export() 477 buffer = io.BytesIO( 478 a._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True) 479 ) 480 buffer.seek(0) 481 mobile_module = _load_for_lite_interpreter(buffer) 482 try: 483 mobile_module() 484 self.assertTrue(False) 485 except RuntimeError as exp: 486 FileCheck().check("Trying to create tensor with negative dimension").check( 487 "Traceback of TorchScript" 488 ).check("self.b.forwardError").check_next( 489 "~~~~~~~~~~~~~~~~~~~ <--- HERE" 490 ).check( 491 "return self.call" 492 ).check_next( 493 "~~~~~~~~~ <--- HERE" 494 ).check( 495 "return torch.ones" 496 ).check_next( 497 "~~~~~~~~~~ <--- HERE" 498 ).run( 499 str(exp) 500 ) 501 502 503class TestLiteScriptQuantizedModule(QuantizationLiteTestCase): 504 def test_single_layer(self): 505 input = torch.rand(2, 5, dtype=torch.float) 506 quantized_model = self._create_quantized_model( 507 model_class=AnnotatedSingleLayerLinearModel, qengine="qnnpack" 508 ) 509 self._compare_script_and_mobile(model=quantized_model, input=input) 510 511 def test_two_layer(self): 512 input = torch.rand(2, 5, dtype=torch.float) 513 quantized_model = self._create_quantized_model(model_class=TwoLayerLinearModel) 514 self._compare_script_and_mobile(model=quantized_model, input=input) 515 516 def test_annotated_nested(self): 517 input = torch.rand(2, 5, dtype=torch.float) 518 quantized_model = self._create_quantized_model( 519 model_class=AnnotatedNestedModel, qengine="qnnpack" 520 ) 521 self._compare_script_and_mobile(model=quantized_model, input=input) 522 523 def test_quantization_example(self): 524 # From the example in Static Quantization section of https://pytorch.org/docs/stable/quantization.html 525 class M(torch.nn.Module): 526 def __init__(self) -> None: 527 super().__init__() 528 self.quant = torch.ao.quantization.QuantStub() 529 self.conv = torch.nn.Conv2d(1, 1, 1) 530 self.relu = torch.nn.ReLU() 531 self.dequant = torch.ao.quantization.DeQuantStub() 532 533 def forward(self, x): 534 x = self.quant(x) 535 x = self.conv(x) 536 x = self.relu(x) 537 x = self.dequant(x) 538 return x 539 540 model_fp32 = M() 541 542 model_fp32.eval() 543 model_fp32.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") 544 model_fp32_fused = torch.ao.quantization.fuse_modules( 545 model_fp32, [["conv", "relu"]] 546 ) 547 model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused) 548 input_fp32 = torch.randn(4, 1, 4, 4) 549 model_fp32_prepared(input_fp32) 550 model_int8 = torch.ao.quantization.convert(model_fp32_prepared) 551 552 input = torch.randn(4, 1, 4, 4) 553 self._compare_script_and_mobile(model=model_int8, input=input) 554 555 def test_bundled_input_with_dynamic_type(self): 556 class Model(torch.nn.Module): 557 def forward( 558 self, 559 x: Dict[int, torch.Tensor], 560 y: Dict[int, torch.Tensor], 561 z: Dict[int, torch.Tensor], 562 ): 563 return x 564 565 model = Model() 566 script_module = torch.jit.script(model) 567 568 sample_input = { 569 script_module.forward: [ 570 ( 571 {0: torch.ones(1)}, 572 {1: torch.ones(1)}, 573 {2: torch.ones(1)}, 574 ) 575 ] 576 } 577 578 bundled_model = torch.utils.bundled_inputs.bundle_inputs( 579 script_module, sample_input 580 ) 581 582 buf = bundled_model._save_to_buffer_for_lite_interpreter() 583 mobile_module = _load_for_lite_interpreter(io.BytesIO(buf)) 584 585 i = mobile_module.run_method("get_all_bundled_inputs") 586 587 self.assertEqual( 588 i[0], 589 ( 590 {0: torch.ones(1)}, 591 {1: torch.ones(1)}, 592 {2: torch.ones(1)}, 593 ), 594 ) 595 596 597if __name__ == "__main__": 598 run_tests() 599