# Owner(s): ["oncall: mobile"]

import inspect
import io
from typing import Dict, List

import torch
import torch.utils.bundled_inputs
from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import (
    AnnotatedNestedModel,
    AnnotatedSingleLayerLinearModel,
    QuantizationLiteTestCase,
    TwoLayerLinearModel,
)
from torch.testing._internal.common_utils import (
    run_tests,
    TemporaryFileName,
    TestCase,
)


class TestLiteScriptModule(TestCase):
    def getScriptExportImportCopy(
        self, m, save_mobile_debug_info=True, also_test_file=False
    ):
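        """Script ``m``, export it for the lite interpreter (to an in-memory
        buffer, or to a file when ``also_test_file`` is True), and load the
        result back as a mobile module."""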
        m_scripted = torch.jit.script(m)

        if not also_test_file:
            buffer = io.BytesIO(
                m_scripted._save_to_buffer_for_lite_interpreter(
                    _save_mobile_debug_info=save_mobile_debug_info
                )
            )
            buffer.seek(0)
            mobile_module = _load_for_lite_interpreter(buffer)
            return mobile_module

        with TemporaryFileName() as fname:
            m_scripted._save_for_lite_interpreter(
                fname, _save_mobile_debug_info=save_mobile_debug_info
            )
            mobile_module = _load_for_lite_interpreter(fname)
            return mobile_module

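    # A loaded mobile module should produce the same results whether it is
    # invoked directly, via .forward(), or via run_method("forward", ...).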
    def test_load_mobile_module(self):
        class MyTestModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        input = torch.tensor([1])

        script_module = torch.jit.script(MyTestModule())
        script_module_result = script_module(input)

        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        mobile_module_result = mobile_module(input)
        torch.testing.assert_close(script_module_result, mobile_module_result)

        mobile_module_forward_result = mobile_module.forward(input)
        torch.testing.assert_close(script_module_result, mobile_module_forward_result)

        mobile_module_run_method_result = mobile_module.run_method("forward", input)
        torch.testing.assert_close(
            script_module_result, mobile_module_run_method_result
        )

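    # When a module is exported with _save_mobile_debug_info=True, the archive
    # should contain callstack_debug_map.pkl, and runtime errors should report
    # the module hierarchy (e.g. top(B)::<unknown>.A0(A)::forward.aten::mul).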
    def test_save_mobile_module_with_debug_info_with_trace(self):
        class A(torch.nn.Module):
            def forward(self, x, y):
                return x * y

        class B(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.A0 = A()
                self.A1 = A()

            def forward(self, x, y, z):
                return self.A0(x, y) + self.A1(y, z)

        for export_method in ["trace", "script"]:
            x = torch.rand((2, 3))
            y = torch.rand((2, 3))
            z = torch.rand((2, 3))
            if export_method == "trace":
                trace_module = torch.jit.trace(B(), [x, y, z])
            else:
                trace_module = torch.jit.script(B())
            exported_module = trace_module._save_to_buffer_for_lite_interpreter(
                _save_mobile_debug_info=True
            )
            buffer = io.BytesIO(exported_module)
            buffer.seek(0)

            assert b"callstack_debug_map.pkl" in exported_module

            mobile_module = _load_for_lite_interpreter(buffer)
            with self.assertRaisesRegex(
                RuntimeError,
                r"Module hierarchy:top\(B\)::<unknown>.A0\(A\)::forward.aten::mul",
            ):
                x = torch.rand((2, 3))
                y = torch.rand((8, 10))
                z = torch.rand((8, 10))
                mobile_module(x, y, z)
            with self.assertRaisesRegex(
                RuntimeError,
                r"Module hierarchy:top\(B\)::<unknown>.A1\(A\)::forward.aten::mul",
            ):
                x = torch.rand((2, 3))
                y = torch.rand((2, 3))
                z = torch.rand((8, 10))
                mobile_module(x, y, z)

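    # Same checks as test_load_mobile_module, but with the module saved with
    # _save_mobile_debug_info=True; debug info must not change the results.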
    def test_load_mobile_module_with_debug_info(self):
        class MyTestModule(torch.nn.Module):
            def forward(self, x):
                return x + 5

        input = torch.tensor([3])

        script_module = torch.jit.script(MyTestModule())
        script_module_result = script_module(input)

        buffer = io.BytesIO(
            script_module._save_to_buffer_for_lite_interpreter(
                _save_mobile_debug_info=True
            )
        )
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        mobile_module_result = mobile_module(input)
        torch.testing.assert_close(script_module_result, mobile_module_result)

        mobile_module_forward_result = mobile_module.forward(input)
        torch.testing.assert_close(script_module_result, mobile_module_forward_result)

        mobile_module_run_method_result = mobile_module.run_method("forward", input)
        torch.testing.assert_close(
            script_module_result, mobile_module_run_method_result
        )

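    # find_method() should report bundled-input accessors only after
    # augment_model_with_bundled_inputs() has added them, and the bundled
    # inputs retrieved via run_method() should reproduce the original output.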
    def test_find_and_run_method(self):
        class MyTestModule(torch.nn.Module):
            def forward(self, arg):
                return arg

        input = (torch.tensor([1]),)

        script_module = torch.jit.script(MyTestModule())
        script_module_result = script_module(*input)

        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs")
        self.assertFalse(has_bundled_inputs)

        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            script_module, [input], []
        )

        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs")
        self.assertTrue(has_bundled_inputs)

        bundled_inputs = mobile_module.run_method("get_all_bundled_inputs")
        mobile_module_result = mobile_module.forward(*bundled_inputs[0])
        torch.testing.assert_close(script_module_result, mobile_module_result)

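    # Default arguments must be respected both for script-to-script calls
    # (B.forward calling A.forward) and for Python-to-script calls into the
    # loaded mobile module.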
    def test_method_calls_with_optional_arg(self):
        class A(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            # opt arg in script-to-script invocation
            def forward(self, x, two: int = 2):
                return x + two

        class B(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.A0 = A()

            # opt arg in Python-to-script invocation
            def forward(self, x, one: int = 1):
                return self.A0(x) + one

        script_module = torch.jit.script(B())
        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
        mobile_module = _load_for_lite_interpreter(buffer)

        input = torch.tensor([5])
        script_module_forward_result = script_module.forward(input)
        mobile_module_forward_result = mobile_module.forward(input)
        torch.testing.assert_close(
            script_module_forward_result, mobile_module_forward_result
        )

        # Update only the reference result; the stale mobile result should
        # now differ.
        script_module_forward_result = script_module.forward(input, 2)
        self.assertFalse(
            (script_module_forward_result == mobile_module_forward_result).all().item()
        )

        # Recompute the mobile result; the two should match again.
        mobile_module_forward_result = mobile_module.forward(input, 2)
        torch.testing.assert_close(
            script_module_forward_result, mobile_module_forward_result
        )

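    # Arbitrary (non-nn.Module) Python classes cannot be serialized for the
    # lite interpreter; export should fail with an actionable error message.
    # The two tests after this one check the same restriction for List/Dict
    # return values whose element type is an nn.Module subclass.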
    def test_unsupported_classtype(self):
        class Foo:
            def __init__(self) -> None:
                return

            def func(self, x: int, y: int):
                return x + y

        class MyTestModule(torch.nn.Module):
            def forward(self, arg):
                f = Foo()
                return f.func(1, 2)

        script_module = torch.jit.script(MyTestModule())
        with self.assertRaisesRegex(
            RuntimeError,
            r"Workaround: instead of using arbitrary class type \(class Foo\(\)\), "
            r"define a pytorch class \(class Foo\(torch\.nn\.Module\)\)\. "
            r"The problematic type is: ",
        ):
            script_module._save_to_buffer_for_lite_interpreter()

    def test_unsupported_return_list_with_module_class(self):
        class Foo(torch.nn.Module):
            pass

        class MyTestModuleForListWithModuleClass(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Foo()

            def forward(self):
                my_list: List[Foo] = [self.foo]
                return my_list

        script_module = torch.jit.script(MyTestModuleForListWithModuleClass())
        with self.assertRaisesRegex(
            RuntimeError,
            r"^Returning a list or dictionary with pytorch class type "
            r"is not supported in mobile module "
            r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. "
            r"Workaround\: instead of using pytorch class as their element type\, "
            r"use a combination of list\, dictionary\, and single types\.$",
        ):
            script_module._save_to_buffer_for_lite_interpreter()

    def test_unsupported_return_dict_with_module_class(self):
        class Foo(torch.nn.Module):
            pass

        class MyTestModuleForDictWithModuleClass(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Foo()

            def forward(self):
                my_dict: Dict[int, Foo] = {1: self.foo}
                return my_dict

        script_module = torch.jit.script(MyTestModuleForDictWithModuleClass())
        with self.assertRaisesRegex(
            RuntimeError,
            r"^Returning a list or dictionary with pytorch class type "
            r"is not supported in mobile module "
            r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. "
            r"Workaround\: instead of using pytorch class as their element type\, "
            r"use a combination of list\, dictionary\, and single types\.$",
        ):
            script_module._save_to_buffer_for_lite_interpreter()

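    # _export_operator_list() should return exactly the set of operators used
    # by the compiled methods of the loaded mobile module.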
    def test_module_export_operator_list(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.ones((20, 1, 5, 5))
                self.bias = torch.ones(20)

            def forward(self, input):
                x1 = torch.zeros(2, 2)
                x2 = torch.empty_like(torch.empty(2, 2))
                x3 = torch._convolution(
                    input,
                    self.weight,
                    self.bias,
                    [1, 1],
                    [0, 0],
                    [1, 1],
                    False,
                    [0, 0],
                    1,
                    False,
                    False,
                    True,
                    True,
                )
                return (x1, x2, x3)

        m = torch.jit.script(Foo())

        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)

        expected_ops = {
            "aten::_convolution",
            "aten::empty.memory_format",
            "aten::empty_like",
            "aten::zeros",
        }
        actual_ops = _export_operator_list(mobile_module)
        self.assertEqual(actual_ops, expected_ops)

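    # The tests below exercise source-range (debug info) reporting: when a
    # module is saved with mobile debug info, runtime errors should point at
    # the originating Python source lines; without it, they should not.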
    def test_source_range_simple(self):
        class FooTest(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, w):
                return torch.mm(x, w.t())

        ft = FooTest()
        loaded = self.getScriptExportImportCopy(ft)
        _, lineno = inspect.getsourcelines(FooTest)

        with self.assertRaisesRegex(
            RuntimeError, f'test_lite_script_module.py", line {lineno + 3}'
        ):
            loaded(torch.rand(3, 4), torch.rand(30, 40))

    def test_source_range_raise_exception(self):
        class FooTest2(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self):
                raise RuntimeError("foo")

        _, lineno = inspect.getsourcelines(FooTest2)

        # In C++, the exception thrown here is a torch::jit::JITException,
        # which does not extend c10::Error, so additional context cannot be
        # added to the exception message while preserving the correct C++
        # stack trace for symbolication. That is, the debug handle string
        # showing where in the Python code the exception occurred cannot be
        # added without first changing torch::jit::JITException to extend
        # c10::Error.
        with self.assertRaisesRegex(torch.jit.Error, "foo"):
            ft = FooTest2()
            loaded = self.getScriptExportImportCopy(ft)
            loaded()

    def test_source_range_function_call(self):
        class FooTest3(torch.jit.ScriptModule):
            @torch.jit.script_method
            def add_method(self, x, w):
                return x + w

            @torch.jit.script_method
            def forward(self, x, y, w):
                x = x * y
                x = x + 2
                return self.add_method(x, w)

        ft = FooTest3()
        loaded = self.getScriptExportImportCopy(ft)
        _, lineno = inspect.getsourcelines(FooTest3)

        try:
            loaded(torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40))
        except RuntimeError as e:
            error_message = f"{e}"
        self.assertTrue(
            f'test_lite_script_module.py", line {lineno + 3}' in error_message
        )
        self.assertTrue(
            f'test_lite_script_module.py", line {lineno + 9}' in error_message
        )
        self.assertTrue("top(FooTest3)" in error_message)

    def test_source_range_no_debug_info(self):
        class FooTest4(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x, w):
                return torch.mm(x, w.t())

        ft = FooTest4()
        loaded = self.getScriptExportImportCopy(ft, save_mobile_debug_info=False)

        try:
            loaded(torch.rand(3, 4), torch.rand(30, 40))
        except RuntimeError as e:
            error_message = f"{e}"
        self.assertTrue("test_lite_script_module.py" not in error_message)

    def test_source_range_raise_exc(self):
        class FooTest5(torch.jit.ScriptModule):
            def __init__(self, val: int):
                super().__init__()
                self.val = val

            @torch.jit.script_method
            def add_method(self, val: int, x, w):
                if val == self.val:
                    raise RuntimeError("self.val and val are same")
                return x + w

            @torch.jit.script_method
            def forward(self, val: int, x, y, w):
                x = x * y
                x = x + 2
                return self.add_method(val, x, w)

        ft = FooTest5(42)
        loaded = self.getScriptExportImportCopy(ft)
        _, lineno = inspect.getsourcelines(FooTest5)

        try:
            loaded(42, torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40))
        except torch.jit.Error as e:
            error_message = f"{e}"

        # As in test_source_range_raise_exception, the C++ exception here is a
        # torch::jit::JITException, which does not extend c10::Error, so the
        # debug handle string showing where in the Python code the exception
        # occurred cannot be added to the message.
        self.assertTrue("self.val and val are same" in error_message)

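    # Calls made through a torch.jit.interface should still produce a full
    # TorchScript traceback (with <--- HERE markers) in the mobile module,
    # provided interface-call export is enabled.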
    def test_stacktrace_interface_call(self):
        @torch.jit.interface
        class Forward(torch.nn.Module):
            def forward(self, x) -> torch.Tensor:
                pass

            def forwardError(self, x) -> torch.Tensor:
                pass

        class B(torch.nn.Module):
            def forward(self, x):
                return x

            def forwardError(self, x):
                return self.call() + x

            def call(self):
                return torch.ones(-1)

        class A(torch.nn.Module):
            b: Forward

            def __init__(self) -> None:
                super().__init__()
                self.b = B()

            def forward(self):
                self.b.forward(torch.ones(1))
                self.b.forwardError(torch.ones(1))

        a = torch.jit.script(A())
        torch._C._enable_mobile_interface_call_export()
        buffer = io.BytesIO(
            a._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)
        )
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)
        try:
            mobile_module()
            self.fail("Expected mobile_module() to raise a RuntimeError")
        except RuntimeError as exp:
            FileCheck().check("Trying to create tensor with negative dimension").check(
                "Traceback of TorchScript"
            ).check("self.b.forwardError").check_next(
                "~~~~~~~~~~~~~~~~~~~ <--- HERE"
            ).check("return self.call").check_next(
                "~~~~~~~~~ <--- HERE"
            ).check("return torch.ones").check_next(
                "~~~~~~~~~~ <--- HERE"
            ).run(str(exp))


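# Quantized-model coverage: each test below quantizes a model eagerly and
# checks that the scripted and lite-interpreter versions produce the same
# outputs, via QuantizationLiteTestCase._compare_script_and_mobile.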
class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):
    def test_single_layer(self):
        input = torch.rand(2, 5, dtype=torch.float)
        quantized_model = self._create_quantized_model(
            model_class=AnnotatedSingleLayerLinearModel, qengine="qnnpack"
        )
        self._compare_script_and_mobile(model=quantized_model, input=input)

    def test_two_layer(self):
        input = torch.rand(2, 5, dtype=torch.float)
        quantized_model = self._create_quantized_model(model_class=TwoLayerLinearModel)
        self._compare_script_and_mobile(model=quantized_model, input=input)

    def test_annotated_nested(self):
        input = torch.rand(2, 5, dtype=torch.float)
        quantized_model = self._create_quantized_model(
            model_class=AnnotatedNestedModel, qengine="qnnpack"
        )
        self._compare_script_and_mobile(model=quantized_model, input=input)

    def test_quantization_example(self):
        # From the example in the Static Quantization section of
        # https://pytorch.org/docs/stable/quantization.html
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.relu = torch.nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

        model_fp32 = M()

        model_fp32.eval()
        model_fp32.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        model_fp32_fused = torch.ao.quantization.fuse_modules(
            model_fp32, [["conv", "relu"]]
        )
        model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
        input_fp32 = torch.randn(4, 1, 4, 4)
        model_fp32_prepared(input_fp32)
        model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

        input = torch.randn(4, 1, 4, 4)
        self._compare_script_and_mobile(model=model_int8, input=input)

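    # bundle_inputs() must round-trip inputs with dynamic container types
    # (Dict[int, Tensor]) through save/load for the lite interpreter.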
    def test_bundled_input_with_dynamic_type(self):
        class Model(torch.nn.Module):
            def forward(
                self,
                x: Dict[int, torch.Tensor],
                y: Dict[int, torch.Tensor],
                z: Dict[int, torch.Tensor],
            ):
                return x

        model = Model()
        script_module = torch.jit.script(model)

        sample_input = {
            script_module.forward: [
                (
                    {0: torch.ones(1)},
                    {1: torch.ones(1)},
                    {2: torch.ones(1)},
                )
            ]
        }

        bundled_model = torch.utils.bundled_inputs.bundle_inputs(
            script_module, sample_input
        )

        buf = bundled_model._save_to_buffer_for_lite_interpreter()
        mobile_module = _load_for_lite_interpreter(io.BytesIO(buf))

        i = mobile_module.run_method("get_all_bundled_inputs")

        self.assertEqual(
            i[0],
            (
                {0: torch.ones(1)},
                {1: torch.ones(1)},
                {2: torch.ones(1)},
            ),
        )


if __name__ == "__main__":
    run_tests()