# Owner(s): ["oncall: mobile"]

import fnmatch
import io
import shutil
import tempfile
from pathlib import Path

import torch
import torch.utils.show_pickle

# from torch.utils.mobile_optimizer import optimize_for_mobile
from torch.jit.mobile import (
    _backport_for_mobile,
    _backport_for_mobile_to_buffer,
    _get_mobile_model_contained_types,
    _get_model_bytecode_version,
    _get_model_ops_and_info,
    _load_for_lite_interpreter,
)
from torch.testing._internal.common_utils import run_tests, TestCase


pytorch_test_dir = Path(__file__).resolve().parents[1]
# Source code used to generate script_module_v4.ptl and script_module_v5.ptl:
# class TestModule(torch.nn.Module):
#     def __init__(self, v):
#         super().__init__()
#         self.x = v

#     def forward(self, y: int):
#         increment = torch.ones([2, 4], dtype=torch.float64)
#         return self.x + y + increment

# output_model_path = Path(tmpdirname, "script_module_v5.ptl")
# script_module = torch.jit.script(TestModule(1))
# optimized_scripted_module = optimize_for_mobile(script_module)
# exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter(
#   str(output_model_path))
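# The resulting .ptl files are checked in under test/cpp/jit/ (see the
# model path construction in the tests below).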

SCRIPT_MODULE_V4_BYTECODE_PKL = """
(4,
 ('__torch__.*.TestModule.forward',
  (('instructions',
    (('STOREN', 1, 2),
     ('DROPR', 1, 0),
     ('LOADC', 0, 0),
     ('LOADC', 1, 0),
     ('MOVE', 2, 0),
     ('OP', 0, 0),
     ('LOADC', 1, 0),
     ('OP', 1, 0),
     ('RET', 0, 0))),
   ('operators', (('aten::add', 'int'), ('aten::add', 'Scalar'))),
   ('constants',
    (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),),
       0,
       (2, 4),
       (4, 1),
       False,
       collections.OrderedDict()),
     1)),
   ('types', ()),
   ('register_size', 2)),
  (('arguments',
    ((('name', 'self'),
      ('type', '__torch__.*.TestModule'),
      ('default_value', None)),
     (('name', 'y'), ('type', 'int'), ('default_value', None)))),
   ('returns',
    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
        """

SCRIPT_MODULE_V5_BYTECODE_PKL = """
(5,
 ('__torch__.*.TestModule.forward',
  (('instructions',
    (('STOREN', 1, 2),
     ('DROPR', 1, 0),
     ('LOADC', 0, 0),
     ('LOADC', 1, 0),
     ('MOVE', 2, 0),
     ('OP', 0, 0),
     ('LOADC', 1, 0),
     ('OP', 1, 0),
     ('RET', 0, 0))),
   ('operators', (('aten::add', 'int'), ('aten::add', 'Scalar'))),
   ('constants',
    (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, 'constants/0', 'cpu', 8),),
       0,
       (2, 4),
       (4, 1),
       False,
       collections.OrderedDict()),
     1)),
   ('types', ()),
   ('register_size', 2)),
  (('arguments',
    ((('name', 'self'),
      ('type', '__torch__.*.TestModule'),
      ('default_value', None)),
     (('name', 'y'), ('type', 'int'), ('default_value', None)))),
   ('returns',
    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
        """

SCRIPT_MODULE_V6_BYTECODE_PKL = """
(6,
 ('__torch__.*.TestModule.forward',
  (('instructions',
    (('STOREN', 1, 2),
     ('DROPR', 1, 0),
     ('LOADC', 0, 0),
     ('LOADC', 1, 0),
     ('MOVE', 2, 0),
     ('OP', 0, 0),
     ('OP', 1, 0),
     ('RET', 0, 0))),
   ('operators', (('aten::add', 'int', 2), ('aten::add', 'Scalar', 2))),
   ('constants',
    (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),),
       0,
       (2, 4),
       (4, 1),
       False,
       collections.OrderedDict()),
     1)),
   ('types', ()),
   ('register_size', 2)),
  (('arguments',
    ((('name', 'self'),
      ('type', '__torch__.*.TestModule'),
      ('default_value', None)),
     (('name', 'y'), ('type', 'int'), ('default_value', None)))),
   ('returns',
    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
    """

SCRIPT_MODULE_BYTECODE_PKL = {
    4: {
        "bytecode_pkl": SCRIPT_MODULE_V4_BYTECODE_PKL,
        "model_name": "script_module_v4.ptl",
    },
}
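# Another checked-in model can be covered by adding an entry mapping its
# bytecode version to the expected pickle and file name, e.g. (hypothetical
# entry, assuming script_module_v5.ptl is checked in under test/cpp/jit):
#   5: {
#       "bytecode_pkl": SCRIPT_MODULE_V5_BYTECODE_PKL,
#       "model_name": "script_module_v5.ptl",
#   },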

# The minimum version a model can be backported to.
# Needs to be updated when a bytecode version is completely retired.
MINIMUM_TO_VERSION = 4


class TestVariousModelVersions(TestCase):
    def test_get_model_bytecode_version(self):
        def check_model_version(model_path, expect_version):
            actual_version = _get_model_bytecode_version(model_path)
            assert actual_version == expect_version

        for version, model_info in SCRIPT_MODULE_BYTECODE_PKL.items():
            model_path = pytorch_test_dir / "cpp" / "jit" / model_info["model_name"]
            check_model_version(model_path, version)

    def test_bytecode_values_for_all_backport_functions(self):
        # Find the maximum version among the checked-in models, backport step
        # by step down to the minimum supported version, and compare the
        # bytecode.pkl content at each step.
        # This can't be merged into the test `test_all_backport_functions`,
        # because optimization is dynamic and the content might change when
        # the optimize function changes. This test focuses on bytecode.pkl
        # content validation. The validation is not a byte-for-byte check but
        # a wildcard (fnmatch) match, so specific content can be skipped.
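        # For example, wildcards in the expected pickle let the comparison
        # skip run-specific qualified names (a minimal illustration, not one
        # of the checked-in expectations):
        #   fnmatch.fnmatch(
        #       "__torch__.foo.TestModule.forward",
        #       "__torch__.*.TestModule.forward",
        #   )  # -> True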
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
        current_from_version = maximum_checked_in_model_version

        with tempfile.TemporaryDirectory() as tmpdirname:
            while current_from_version > MINIMUM_TO_VERSION:
                # Look up the checked-in model for the current version
                model_name = SCRIPT_MODULE_BYTECODE_PKL[current_from_version][
                    "model_name"
                ]
                input_model_path = pytorch_test_dir / "cpp" / "jit" / model_name

                # A temporary model file will be exported to this path and run
                # through the bytecode.pkl content check.
                tmp_output_model_path_backport = Path(
                    tmpdirname, "tmp_script_module_backport.ptl"
                )

                current_to_version = current_from_version - 1
                backport_success = _backport_for_mobile(
                    input_model_path, tmp_output_model_path_backport, current_to_version
                )
                assert backport_success

                expect_bytecode_pkl = SCRIPT_MODULE_BYTECODE_PKL[current_to_version][
                    "bytecode_pkl"
                ]

                buf = io.StringIO()
                torch.utils.show_pickle.main(
                    [
                        "",
                        tmpdirname
                        + "/"
                        + tmp_output_model_path_backport.name
                        + "@*/bytecode.pkl",
                    ],
                    output_stream=buf,
                )
                output = buf.getvalue()

                actual_result_clean = "".join(output.split())
                expect_result_clean = "".join(expect_bytecode_pkl.split())
                is_match = fnmatch.fnmatch(actual_result_clean, expect_result_clean)
                assert is_match

                current_from_version -= 1
            shutil.rmtree(tmpdirname)

    # Please run this test manually when working on backport.
    # This test passes in OSS but fails internally, likely due to a missing
    # step in the build process.
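    # To run it, also uncomment the `optimize_for_mobile` import at the top
    # of this file; the commented-out body below assumes that symbol is in
    # scope.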
    # def test_all_backport_functions(self):
    #     # Backport from the latest bytecode version to the minimum supported
    #     # version; load and run the backported model, and check its version.
    #     class TestModule(torch.nn.Module):
    #         def __init__(self, v):
    #             super().__init__()
    #             self.x = v

    #         def forward(self, y: int):
    #             increment = torch.ones([2, 4], dtype=torch.float64)
    #             return self.x + y + increment

    #     module_input = 1
    #     expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)

    #     # Temporary input and output model files are exported into the
    #     # temporary folder.
    #     with tempfile.TemporaryDirectory() as tmpdirname:
    #         tmp_input_model_path = Path(tmpdirname, "tmp_script_module.ptl")
    #         script_module = torch.jit.script(TestModule(1))
    #         optimized_scripted_module = optimize_for_mobile(script_module)
    #         exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter(str(tmp_input_model_path))

    #         current_from_version = _get_model_bytecode_version(tmp_input_model_path)
    #         current_to_version = current_from_version - 1
    #         tmp_output_model_path = Path(tmpdirname, "tmp_script_module_backport.ptl")

    #         while current_to_version >= MINIMUM_TO_VERSION:
    #             # Backport the latest model to `to_version`, writing the tmp
    #             # file "tmp_script_module_backport.ptl"
    #             backport_success = _backport_for_mobile(tmp_input_model_path, tmp_output_model_path, current_to_version)
    #             assert backport_success

    #             backport_version = _get_model_bytecode_version(tmp_output_model_path)
    #             assert backport_version == current_to_version

    #             # Load the backported model and run its forward method
    #             mobile_module = _load_for_lite_interpreter(str(tmp_output_model_path))
    #             mobile_module_result = mobile_module(module_input)
    #             torch.testing.assert_close(mobile_module_result, expected_mobile_module_result)
    #             current_to_version -= 1

    #         # Check the backport failure case
    #         backport_success = _backport_for_mobile(tmp_input_model_path, tmp_output_model_path, MINIMUM_TO_VERSION - 1)
    #         assert not backport_success
    #         # Clean the folder before the context manager closes; otherwise a
    #         # "git not clean" error can occur.
    #         shutil.rmtree(tmpdirname)

    # Check just the `_backport_for_mobile` file-to-file mechanism, not the
    # function implementations.
    def test_backport_bytecode_from_file_to_file(self):
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
        script_module_v5_path = (
            pytorch_test_dir
            / "cpp"
            / "jit"
            / SCRIPT_MODULE_BYTECODE_PKL[maximum_checked_in_model_version]["model_name"]
        )

        if maximum_checked_in_model_version > MINIMUM_TO_VERSION:
            with tempfile.TemporaryDirectory() as tmpdirname:
                tmp_backport_model_path = Path(
                    tmpdirname, "tmp_script_module_v5_backported_to_v4.ptl"
                )
                # Backport from file
                success = _backport_for_mobile(
                    script_module_v5_path,
                    tmp_backport_model_path,
                    maximum_checked_in_model_version - 1,
                )
                assert success

                buf = io.StringIO()
                torch.utils.show_pickle.main(
                    [
                        "",
                        tmpdirname
                        + "/"
                        + tmp_backport_model_path.name
                        + "@*/bytecode.pkl",
                    ],
                    output_stream=buf,
                )
                output = buf.getvalue()

                expected_result = SCRIPT_MODULE_V4_BYTECODE_PKL
                actual_result_clean = "".join(output.split())
                expect_result_clean = "".join(expected_result.split())
                is_match = fnmatch.fnmatch(actual_result_clean, expect_result_clean)
                assert is_match

                # Load model v4 and run forward method
                mobile_module = _load_for_lite_interpreter(str(tmp_backport_model_path))
                module_input = 1
                mobile_module_result = mobile_module(module_input)
                expected_mobile_module_result = 3 * torch.ones(
                    [2, 4], dtype=torch.float64
                )
                torch.testing.assert_close(
                    mobile_module_result, expected_mobile_module_result
                )
                shutil.rmtree(tmpdirname)

    # Check just the `_backport_for_mobile_to_buffer` mechanism, not the
    # function implementations.
    def test_backport_bytecode_from_file_to_buffer(self):
        maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
        script_module_v5_path = (
            pytorch_test_dir
            / "cpp"
            / "jit"
            / SCRIPT_MODULE_BYTECODE_PKL[maximum_checked_in_model_version]["model_name"]
        )

        if maximum_checked_in_model_version > MINIMUM_TO_VERSION:
            # Backport model to v4
            script_module_v4_buffer = _backport_for_mobile_to_buffer(
                script_module_v5_path, maximum_checked_in_model_version - 1
            )

            # Check the version of the backported model v4
            bytesio = io.BytesIO(script_module_v4_buffer)
            backport_version = _get_model_bytecode_version(bytesio)
            assert backport_version == maximum_checked_in_model_version - 1

            # Load the backported model v4 and run its forward method
            bytesio = io.BytesIO(script_module_v4_buffer)
            mobile_module = _load_for_lite_interpreter(bytesio)
            module_input = 1
            mobile_module_result = mobile_module(module_input)
            expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)
            torch.testing.assert_close(
                mobile_module_result, expected_mobile_module_result
            )

    def test_get_model_ops_and_info(self):
        # TODO: update this to be more in the style of the above tests once a
        # backport from 6 -> 5 exists.
        script_module_v6 = pytorch_test_dir / "cpp" / "jit" / "script_module_v6.ptl"
        ops_v6 = _get_model_ops_and_info(script_module_v6)
        assert ops_v6["aten::add.int"].num_schema_args == 2
        assert ops_v6["aten::add.Scalar"].num_schema_args == 2

    def test_get_mobile_model_contained_types(self):
        class MyTestModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        sample_input = torch.tensor([1])

        script_module = torch.jit.script(MyTestModule())
        script_module_result = script_module(sample_input)

        buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        type_list = _get_mobile_model_contained_types(buffer)
        # Smoke check: the assertion below is trivially true; the point is
        # that the type extraction above completed without raising.
        assert len(type_list) >= 0


if __name__ == "__main__":
    run_tests()