xref: /aosp_15_r20/external/pytorch/test/jit/test_upgraders.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport io
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport sys
6*da0073e9SAndroid Build Coastguard Workerimport zipfile
7*da0073e9SAndroid Build Coastguard Workerfrom typing import Union
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport torch
10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker
# Put the test/ root (the parent of the directory holding this file) on
# sys.path so the helper files that live in test/ are importable.
pytorch_test_dir = os.path.abspath(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir)
)
sys.path.append(pytorch_test_dir)
16*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # The test cases here are meant to be driven through test/test_jit.py,
    # so refuse to run when this module is executed as a script.
    _usage_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_message)
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
class TestUpgraders(JitTestCase):
    """Tests for TorchScript operator upgraders and operator version bookkeeping.

    Several tests register temporary entries in process-global registries
    (the upgraders map, the operator version map) or flip the global
    version-calculator flag. Each of them restores that state even when one
    of its assertions fails, so a failure cannot cascade into other tests.
    """

    # Operator targeted by the tests that register a dummy upgrader entry.
    _SUBCMUL_OP_NAME = "aten::_test_serialization_subcmul"

    @staticmethod
    def _fixture_path(file_name):
        """Return the path of ``file_name`` inside ``test/jit/fixtures``."""
        return os.path.join(pytorch_test_dir, "jit", "fixtures", file_name)

    @staticmethod
    def _subcmul_upgrader_entry():
        """Build the dummy ``_UpgraderEntry`` registered for ``_SUBCMUL_OP_NAME``.

        The entry carries bumped version 3, the upgrader name
        ``_test_serialization_subcmul_0_2`` and the operator schema whose
        ``alpha`` defaults to 2.
        """
        upgrader_bumped_version = 3
        upgrader_name = "_test_serialization_subcmul_0_2"
        upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
        return torch._C._UpgraderEntry(
            upgrader_bumped_version, upgrader_name, upgrader_schema
        )

    def _load_model_version(self, loaded_model):
        """Return the version recorded when ``loaded_model`` is saved.

        The model is re-serialized into an in-memory zip archive and the
        version record is read back out of that archive.

        Raises:
            KeyError: if neither known version-record path is in the archive.
        """
        buffer = io.BytesIO()
        torch.jit.save(loaded_model, buffer)
        buffer.seek(0)
        # The location of the version record inside the archive changed
        # between versions 3 and 7, so try one path and fall back to the other.
        with zipfile.ZipFile(buffer) as zipped_model:
            try:
                raw_version = zipped_model.read("archive/version")
            except KeyError:
                raw_version = zipped_model.read("archive/.data/version")
        return int(raw_version.decode("utf-8"))

    def _check_version_and_reload(self, loaded_model, expected_version):
        """Assert the version ``loaded_model`` saves at, then round-trip it.

        Asserts that re-saving ``loaded_model`` records ``expected_version``
        and returns the model obtained by saving and loading it once more.
        """
        buffer = io.BytesIO()
        torch.jit.save(loaded_model, buffer)
        buffer.seek(0)
        self.assertEqual(self._load_model_version(loaded_model), expected_version)
        return torch.jit.load(buffer)

    # TODO (tugsuu) We should ideally be generating these test cases.
    def test_populated_upgrader_graph(self):
        @torch.jit.script
        def f():
            return 0

        buffer = io.BytesIO()
        torch.jit.save(f, buffer)
        buffer.seek(0)
        torch.jit.load(buffer)
        upgraders_size = torch._C._get_upgraders_map_size()
        upgraders_dump = torch._C._dump_upgraders_map()
        # Make sure the upgrader map is populated only once: load the model
        # again and check that the map still has the same content.
        buffer.seek(0)
        torch.jit.load(buffer)
        upgraders_size_second_time = torch._C._get_upgraders_map_size()
        upgraders_dump_second_time = torch._C._dump_upgraders_map()
        self.assertEqual(upgraders_size, upgraders_size_second_time)
        self.assertEqual(upgraders_dump, upgraders_dump_second_time)

    def test_add_value_to_version_map(self):
        map_before_test = torch._C._get_operator_version_map()

        torch._C._test_only_add_entry_to_op_version_map(
            self._SUBCMUL_OP_NAME, self._subcmul_upgrader_entry()
        )
        try:
            map_after_test = torch._C._get_operator_version_map()
            self.assertIn(self._SUBCMUL_OP_NAME, map_after_test)
            self.assertEqual(len(map_after_test) - len(map_before_test), 1)
        finally:
            # Always undo the registration so a failed assertion above cannot
            # leak the dummy entry into other tests.
            torch._C._test_only_remove_entry_to_op_version_map(self._SUBCMUL_OP_NAME)

        map_after_remove_test = torch._C._get_operator_version_map()
        self.assertNotIn(self._SUBCMUL_OP_NAME, map_after_remove_test)
        self.assertEqual(len(map_after_remove_test), len(map_before_test))

    def test_populated_test_upgrader_graph(self):
        @torch.jit.script
        def f():
            return 0

        buffer = io.BytesIO()
        torch.jit.save(f, buffer)
        buffer.seek(0)
        torch.jit.load(buffer)

        # upgrader map should have populated now
        upgraders_size = torch._C._get_upgraders_map_size()

        test_map = {"a": str(torch._C.Graph()), "c": str(torch._C.Graph())}
        torch._C._test_only_populate_upgraders(test_map)
        try:
            upgraders_size_after_test = torch._C._get_upgraders_map_size()
            self.assertEqual(upgraders_size_after_test - upgraders_size, 2)
            upgraders_dump = torch._C._dump_upgraders_map()
            self.assertIn("a", upgraders_dump)
            self.assertIn("c", upgraders_dump)
        finally:
            # Always remove the test upgraders, even if an assertion failed.
            torch._C._test_only_remove_upgraders(test_map)

        upgraders_size_after_remove_test = torch._C._get_upgraders_map_size()
        self.assertEqual(upgraders_size_after_remove_test, upgraders_size)
        upgraders_dump_after_remove_test = torch._C._dump_upgraders_map()
        self.assertNotIn("a", upgraders_dump_after_remove_test)
        self.assertNotIn("c", upgraders_dump_after_remove_test)

    def test_aten_div_tensor_at_3(self):
        loaded_model = torch.jit.load(
            self._fixture_path("test_versioned_div_tensor_v3.pt")
        )
        # There are 3 aten::div in this model, and the upgrader for aten::div
        # uses two div's because of its if/else branch, giving 6 in total.
        FileCheck().check("prim::If").run(loaded_model.graph)
        FileCheck().check_count("aten::div", 6).run(loaded_model.graph)

        loaded_model_twice = self._check_version_and_reload(loaded_model, 4)
        # Compare by code because graph variable names can differ every time.
        self.assertEqual(loaded_model.code, loaded_model_twice.code)

    def test_aten_full_other_variants(self):
        def test_func():
            a = torch.full([4, 5, 6], 4, names=["a", "b", "c"], dtype=torch.int64)
            return a

        scripted_func = torch.jit.script(test_func)
        buffer = io.BytesIO()
        torch.jit.save(scripted_func, buffer)

        # Restore the global version-calculator flag however this test exits.
        self.addCleanup(
            torch._C._calculate_package_version_based_on_upgraders,
            torch._C._get_version_calculator_flag(),
        )

        # False: calculate based on the old version scheme;
        # True: calculate based on the upgraders.
        for use_upgraders in (False, True):
            torch._C._calculate_package_version_based_on_upgraders(use_upgraders)
            buffer.seek(0)
            loaded_func = torch.jit.load(buffer)
            self.assertEqual(self._load_model_version(loaded_func), 5)

    def test_aten_linspace(self):
        loaded_model = torch.jit.load(
            self._fixture_path("test_versioned_linspace_v7.ptl")
        )
        sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
        for a, b in sample_inputs:
            output_with_step, output_without_step = loaded_model(a, b)
            # when no step is given, should have used 100
            self.assertEqual(output_without_step.size(dim=0), 100)
            self.assertEqual(output_with_step.size(dim=0), 5)

        self.assertEqual(self._load_model_version(loaded_model), 8)

    def test_aten_linspace_out(self):
        loaded_model = torch.jit.load(
            self._fixture_path("test_versioned_linspace_out_v7.ptl")
        )
        sample_inputs = (
            (3, 10, torch.empty((100,), dtype=torch.int64)),
            (-10, 10, torch.empty((100,), dtype=torch.int64)),
            (4.0, 6.0, torch.empty((100,), dtype=torch.float64)),
            (3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)),
        )
        for a, b, c in sample_inputs:
            output = loaded_model(a, b, c)
            # when no step is given, should have used 100
            self.assertEqual(output.size(dim=0), 100)

        self.assertEqual(self._load_model_version(loaded_model), 8)

    def test_aten_logspace(self):
        loaded_model = torch.jit.load(
            self._fixture_path("test_versioned_logspace_v8.ptl")
        )
        sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
        for a, b in sample_inputs:
            output_with_step, output_without_step = loaded_model(a, b)
            # when no step is given, should have used 100
            self.assertEqual(output_without_step.size(dim=0), 100)
            self.assertEqual(output_with_step.size(dim=0), 5)

        self.assertEqual(self._load_model_version(loaded_model), 9)

    def test_aten_logspace_out(self):
        loaded_model = torch.jit.load(
            self._fixture_path("test_versioned_logspace_out_v8.ptl")
        )
        sample_inputs = (
            (3, 10, torch.empty((100,), dtype=torch.int64)),
            (-10, 10, torch.empty((100,), dtype=torch.int64)),
            (4.0, 6.0, torch.empty((100,), dtype=torch.float64)),
            (3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)),
        )
        for a, b, c in sample_inputs:
            output = loaded_model(a, b, c)
            # when no step is given, should have used 100
            self.assertEqual(output.size(dim=0), 100)

        self.assertEqual(self._load_model_version(loaded_model), 9)

    def test_aten_test_serialization(self):
        model_path = self._fixture_path("_test_serialization_subcmul_v2.pt")

        # Add the test version entry to the version map; it is removed on
        # exit (even on failure) so other tests see an unmodified map.
        torch._C._test_only_add_entry_to_op_version_map(
            self._SUBCMUL_OP_NAME, self._subcmul_upgrader_entry()
        )
        self.addCleanup(
            torch._C._test_only_remove_entry_to_op_version_map,
            self._SUBCMUL_OP_NAME,
        )

        # Add the test upgrader to the upgraders map, likewise undone on exit.
        @torch.jit.script
        def _test_serialization_subcmul_0_2(
            self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2
        ) -> torch.Tensor:
            return other - (self * alpha)

        test_upgraders = {
            "_test_serialization_subcmul_0_2": str(
                _test_serialization_subcmul_0_2.graph
            )
        }
        torch._C._test_only_populate_upgraders(test_upgraders)
        self.addCleanup(torch._C._test_only_remove_upgraders, test_upgraders)

        # Check that loading finds the test upgrader and applies it to the IR.
        loaded_model = torch.jit.load(model_path)
        FileCheck().check_count("aten::mul", 2).run(loaded_model.graph)
        FileCheck().check_count("aten::sub", 2).run(loaded_model.graph)

        loaded_model_twice = self._check_version_and_reload(loaded_model, 3)
        # Compare by code because graph variable names can differ every time.
        self.assertEqual(loaded_model.code, loaded_model_twice.code)

    def test_aten_div_scalar_at_3(self):
        loaded_model = torch.jit.load(
            self._fixture_path("test_versioned_div_scalar_float_v3.pt")
        )
        FileCheck().check("prim::If").run(loaded_model.graph)
        FileCheck().check_count("aten::div", 2).run(loaded_model.graph)

        loaded_model_twice = self._check_version_and_reload(loaded_model, 4)

        self.assertEqual(
            loaded_model(torch.Tensor([5.0, 3.0]), 2.0),
            loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0),
        )

    def test_aten_div_tensor_out_at_3(self):
        loaded_model = torch.jit.load(
            self._fixture_path("test_versioned_div_tensor_out_v3.pt")
        )
        FileCheck().check("prim::If").run(loaded_model.graph)
        FileCheck().check_count("aten::div", 2).run(loaded_model.graph)

        loaded_model_twice = self._check_version_and_reload(loaded_model, 4)
        # Compare by code because graph variable names can differ every time.
        self.assertEqual(loaded_model.code, loaded_model_twice.code)

    def test_aten_full_at_4(self):
        loaded_model = torch.jit.load(
            self._fixture_path("test_versioned_full_integer_value_v4.pt")
        )
        FileCheck().check_count("aten::Float", 1).run(loaded_model.graph)
        FileCheck().check_count("aten::full", 2).run(loaded_model.graph)

        loaded_model_twice = self._check_version_and_reload(loaded_model, 5)
        # Compare by code because graph variable names can differ every time.
        self.assertEqual(loaded_model.code, loaded_model_twice.code)

    def test_aten_full_out_at_4(self):
        loaded_model = torch.jit.load(
            self._fixture_path("test_versioned_full_preserved_v4.pt")
        )
        FileCheck().check_count("aten::full", 5).run(loaded_model.graph)
        self.assertEqual(self._load_model_version(loaded_model), 5)
349