xref: /aosp_15_r20/external/pytorch/test/jit/test_upgraders.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import io
4import os
5import sys
6import zipfile
7from typing import Union
8
9import torch
10from torch.testing import FileCheck
11
12
13# Make the helper files in test/ importable
14pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
15sys.path.append(pytorch_test_dir)
16from torch.testing._internal.jit_utils import JitTestCase
17
18
# Running this file directly is unsupported: it must be dispatched through
# test/test_jit.py so the shared JIT test setup runs first.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
25
26
27class TestUpgraders(JitTestCase):
28    def _load_model_version(self, loaded_model):
29        buffer = io.BytesIO()
30        torch.jit.save(loaded_model, buffer)
31        buffer.seek(0)
32        zipped_model = zipfile.ZipFile(buffer)
33        # there was a change in how we store version number
34        # in a package between version 3 and 7.
35        # So we have to check for both.
36        try:
37            version = int(zipped_model.read("archive/version").decode("utf-8"))
38            return version
39        except KeyError:
40            version = int(zipped_model.read("archive/.data/version").decode("utf-8"))
41            return version
42
    # TODO (tugsuu): We should ideally be generating these test cases.
44    def test_populated_upgrader_graph(self):
45        @torch.jit.script
46        def f():
47            return 0
48
49        buffer = io.BytesIO()
50        torch.jit.save(f, buffer)
51        buffer.seek(0)
52        torch.jit.load(buffer)
53        upgraders_size = torch._C._get_upgraders_map_size()
54        upgraders_dump = torch._C._dump_upgraders_map()
55        # make sure we only populate the upgrader map only once
56        # so we load it again and make sure the upgrader map has
57        # same content
58        buffer.seek(0)
59        torch.jit.load(buffer)
60        upgraders_size_second_time = torch._C._get_upgraders_map_size()
61        upgraders_dump_second_time = torch._C._dump_upgraders_map()
62        self.assertTrue(upgraders_size == upgraders_size_second_time)
63        self.assertTrue(upgraders_dump == upgraders_dump_second_time)
64
65    def test_add_value_to_version_map(self):
66        map_before_test = torch._C._get_operator_version_map()
67
68        upgrader_bumped_version = 3
69        upgrader_name = "_test_serialization_subcmul_0_2"
70        upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
71        dummy_entry = torch._C._UpgraderEntry(
72            upgrader_bumped_version, upgrader_name, upgrader_schema
73        )
74
75        torch._C._test_only_add_entry_to_op_version_map(
76            "aten::_test_serialization_subcmul", dummy_entry
77        )
78        map_after_test = torch._C._get_operator_version_map()
79        self.assertTrue("aten::_test_serialization_subcmul" in map_after_test)
80        self.assertTrue(len(map_after_test) - len(map_before_test) == 1)
81        torch._C._test_only_remove_entry_to_op_version_map(
82            "aten::_test_serialization_subcmul"
83        )
84        map_after_remove_test = torch._C._get_operator_version_map()
85        self.assertTrue(
86            "aten::_test_serialization_subcmul" not in map_after_remove_test
87        )
88        self.assertEqual(len(map_after_remove_test), len(map_before_test))
89
90    def test_populated_test_upgrader_graph(self):
91        @torch.jit.script
92        def f():
93            return 0
94
95        buffer = io.BytesIO()
96        torch.jit.save(f, buffer)
97        buffer.seek(0)
98        torch.jit.load(buffer)
99
100        # upgrader map should have populated now
101        upgraders_size = torch._C._get_upgraders_map_size()
102
103        test_map = {"a": str(torch._C.Graph()), "c": str(torch._C.Graph())}
104        torch._C._test_only_populate_upgraders(test_map)
105        upgraders_size_after_test = torch._C._get_upgraders_map_size()
106        self.assertEqual(upgraders_size_after_test - upgraders_size, 2)
107        upgraders_dump = torch._C._dump_upgraders_map()
108        self.assertTrue("a" in upgraders_dump)
109        self.assertTrue("c" in upgraders_dump)
110
111        torch._C._test_only_remove_upgraders(test_map)
112        upgraders_size_after_remove_test = torch._C._get_upgraders_map_size()
113        self.assertTrue(upgraders_size_after_remove_test == upgraders_size)
114        upgraders_dump_after_remove_test = torch._C._dump_upgraders_map()
115        self.assertTrue("a" not in upgraders_dump_after_remove_test)
116        self.assertTrue("c" not in upgraders_dump_after_remove_test)
117
118    def test_aten_div_tensor_at_3(self):
119        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_v3.pt"
120        loaded_model = torch.jit.load(model_path)
121        # there are 3 aten::div in this model
122        # And the upgrader for aten::div uses two
123        # div's because of if/else branch
124        FileCheck().check("prim::If").run(loaded_model.graph)
125        FileCheck().check_count("aten::div", 6).run(loaded_model.graph)
126
127        buffer = io.BytesIO()
128        torch.jit.save(loaded_model, buffer)
129        buffer.seek(0)
130        version = self._load_model_version(loaded_model)
131        self.assertTrue(version == 4)
132        loaded_model_twice = torch.jit.load(buffer)
133        # we check by its code because graph variable names
134        # can be different every time
135        self.assertEqual(loaded_model.code, loaded_model_twice.code)
136
137    def test_aten_full_other_variants(self):
138        def test_func():
139            a = torch.full([4, 5, 6], 4, names=["a", "b", "c"], dtype=torch.int64)
140            return a
141
142        scripted_func = torch.jit.script(test_func)
143        buffer = io.BytesIO()
144        torch.jit.save(scripted_func, buffer)
145
146        current_flag_value = torch._C._get_version_calculator_flag()
147        # calculate based on old version
148        torch._C._calculate_package_version_based_on_upgraders(False)
149        buffer.seek(0)
150        loaded_func = torch.jit.load(buffer)
151        version = self._load_model_version(loaded_func)
152        self.assertTrue(version == 5)
153
154        # calculate based on new version
155        torch._C._calculate_package_version_based_on_upgraders(True)
156        buffer.seek(0)
157        loaded_func = torch.jit.load(buffer)
158        version = self._load_model_version(loaded_func)
159        self.assertTrue(version == 5)
160
161        # make sure we preserve old behaviou
162        torch._C._calculate_package_version_based_on_upgraders(current_flag_value)
163
164    def test_aten_linspace(self):
165        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl"
166        loaded_model = torch.jit.load(model_path)
167        sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
168        for a, b in sample_inputs:
169            output_with_step, output_without_step = loaded_model(a, b)
170            # when no step is given, should have used 100
171            self.assertTrue(output_without_step.size(dim=0) == 100)
172            self.assertTrue(output_with_step.size(dim=0) == 5)
173
174        version = self._load_model_version(loaded_model)
175        self.assertTrue(version == 8)
176
177    def test_aten_linspace_out(self):
178        model_path = (
179            pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
180        )
181        loaded_model = torch.jit.load(model_path)
182        sample_inputs = (
183            (3, 10, torch.empty((100,), dtype=torch.int64)),
184            (-10, 10, torch.empty((100,), dtype=torch.int64)),
185            (4.0, 6.0, torch.empty((100,), dtype=torch.float64)),
186            (3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)),
187        )
188        for a, b, c in sample_inputs:
189            output = loaded_model(a, b, c)
190            # when no step is given, should have used 100
191            self.assertTrue(output.size(dim=0) == 100)
192
193        version = self._load_model_version(loaded_model)
194        self.assertTrue(version == 8)
195
196    def test_aten_logspace(self):
197        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_v8.ptl"
198        loaded_model = torch.jit.load(model_path)
199        sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
200        for a, b in sample_inputs:
201            output_with_step, output_without_step = loaded_model(a, b)
202            # when no step is given, should have used 100
203            self.assertTrue(output_without_step.size(dim=0) == 100)
204            self.assertTrue(output_with_step.size(dim=0) == 5)
205
206        version = self._load_model_version(loaded_model)
207        self.assertTrue(version == 9)
208
209    def test_aten_logspace_out(self):
210        model_path = (
211            pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl"
212        )
213        loaded_model = torch.jit.load(model_path)
214        sample_inputs = (
215            (3, 10, torch.empty((100,), dtype=torch.int64)),
216            (-10, 10, torch.empty((100,), dtype=torch.int64)),
217            (4.0, 6.0, torch.empty((100,), dtype=torch.float64)),
218            (3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)),
219        )
220        for a, b, c in sample_inputs:
221            output = loaded_model(a, b, c)
222            # when no step is given, should have used 100
223            self.assertTrue(output.size(dim=0) == 100)
224
225        version = self._load_model_version(loaded_model)
226        self.assertTrue(version == 9)
227
228    def test_aten_test_serialization(self):
229        model_path = (
230            pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt"
231        )
232
233        # add test version entry to the version map
234        upgrader_bumped_version = 3
235        upgrader_name = "_test_serialization_subcmul_0_2"
236        upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
237        dummy_entry = torch._C._UpgraderEntry(
238            upgrader_bumped_version, upgrader_name, upgrader_schema
239        )
240
241        torch._C._test_only_add_entry_to_op_version_map(
242            "aten::_test_serialization_subcmul", dummy_entry
243        )
244
245        # add test upgrader in the upgraders map
246        @torch.jit.script
247        def _test_serialization_subcmul_0_2(
248            self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2
249        ) -> torch.Tensor:
250            return other - (self * alpha)
251
252        torch._C._test_only_populate_upgraders(
253            {
254                "_test_serialization_subcmul_0_2": str(
255                    _test_serialization_subcmul_0_2.graph
256                )
257            }
258        )
259
260        # test if the server is able to find the test upgraders and apply to IR
261        loaded_model = torch.jit.load(model_path)
262        FileCheck().check_count("aten::mul", 2).run(loaded_model.graph)
263        FileCheck().check_count("aten::sub", 2).run(loaded_model.graph)
264
265        buffer = io.BytesIO()
266        torch.jit.save(loaded_model, buffer)
267        buffer.seek(0)
268        version = self._load_model_version(loaded_model)
269        self.assertTrue(version == 3)
270        loaded_model_twice = torch.jit.load(buffer)
271        # we check by its' code because graph variable names
272        # can be different every time
273        self.assertEqual(loaded_model.code, loaded_model_twice.code)
274        torch._C._test_only_remove_entry_to_op_version_map(
275            "aten::_test_serialization_subcmul"
276        )
277        torch._C._test_only_remove_upgraders(
278            {
279                "_test_serialization_subcmul_0_2": str(
280                    _test_serialization_subcmul_0_2.graph
281                )
282            }
283        )
284
285    def test_aten_div_scalar_at_3(self):
286        model_path = (
287            pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt"
288        )
289        loaded_model = torch.jit.load(model_path)
290        FileCheck().check("prim::If").run(loaded_model.graph)
291        FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
292
293        buffer = io.BytesIO()
294        torch.jit.save(loaded_model, buffer)
295        buffer.seek(0)
296        version = self._load_model_version(loaded_model)
297        self.assertEqual(version, 4)
298        loaded_model_twice = torch.jit.load(buffer)
299
300        self.assertEqual(
301            loaded_model(torch.Tensor([5.0, 3.0]), 2.0),
302            loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0),
303        )
304
305    def test_aten_div_tensor_out_at_3(self):
306        model_path = (
307            pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt"
308        )
309        loaded_model = torch.jit.load(model_path)
310        FileCheck().check("prim::If").run(loaded_model.graph)
311        FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
312
313        buffer = io.BytesIO()
314        torch.jit.save(loaded_model, buffer)
315        buffer.seek(0)
316        version = self._load_model_version(loaded_model)
317        self.assertTrue(version == 4)
318        loaded_model_twice = torch.jit.load(buffer)
319        # we check by its' code because graph variable names
320        # can be different every time
321        self.assertEqual(loaded_model.code, loaded_model_twice.code)
322
323    def test_aten_full_at_4(self):
324        model_path = (
325            pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt"
326        )
327        loaded_model = torch.jit.load(model_path)
328        FileCheck().check_count("aten::Float", 1).run(loaded_model.graph)
329        FileCheck().check_count("aten::full", 2).run(loaded_model.graph)
330
331        buffer = io.BytesIO()
332        torch.jit.save(loaded_model, buffer)
333        buffer.seek(0)
334        version = self._load_model_version(loaded_model)
335        self.assertTrue(version == 5)
336        loaded_model_twice = torch.jit.load(buffer)
337        # we check by its' code because graph variable names
338        # can be different every time
339        self.assertEqual(loaded_model.code, loaded_model_twice.code)
340
341    def test_aten_full_out_at_4(self):
342        model_path = (
343            pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt"
344        )
345        loaded_model = torch.jit.load(model_path)
346        FileCheck().check_count("aten::full", 5).run(loaded_model.graph)
347        version = self._load_model_version(loaded_model)
348        self.assertTrue(version == 5)
349