# Owner(s): ["oncall: jit"]

import io
import os
import sys
from itertools import product
from typing import Union

import hypothesis.strategies as st
from hypothesis import example, given, settings

import torch


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.jit.mobile import _load_for_lite_interpreter
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestSaveLoadForOpVersion(JitTestCase):
    # Helper that returns the module after saving and loading
    def _save_load_module(self, m):
        scripted_module = torch.jit.script(m())
        buffer = io.BytesIO()
        torch.jit.save(scripted_module, buffer)
        buffer.seek(0)
        return torch.jit.load(buffer)

    def _save_load_mobile_module(self, m):
        scripted_module = torch.jit.script(m())
        buffer = io.BytesIO(scripted_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        return _load_for_lite_interpreter(buffer)
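
    # Both helpers emulate a full serialization round trip through an in-memory
    # buffer. A module saved this way always carries the current operator
    # versions, so it serves as the baseline that the historic (versioned)
    # fixture files are compared against below.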

    # Helper which returns the result of a function or the exception the
    #   function threw.
    def _try_fn(self, fn, *args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            return e

    def _verify_no(self, kind, m):
        self._verify_count(kind, m, 0)

    def _verify_count(self, kind, m, count):
        node_count = sum(str(n).count(kind) for n in m.graph.nodes())
        self.assertEqual(node_count, count)
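
    # Example (with a hypothetical scripted module m): _verify_no("aten::div", m)
    # asserts that "aten::div" appears nowhere in the graph's nodes, while
    # _verify_count("aten::div", m, 2) asserts the substring occurs exactly
    # twice across the printed nodes.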

    """
    Tests that verify TorchScript remaps aten::div(_) from operator versions 0-3
    to call either aten::true_divide(_), when an input is a floating point type,
    or truncating aten::divide(_) otherwise.
    NOTE: these tests also compare against the current div behavior, since
      div behavior has not yet been updated.
    """

    @settings(
        max_examples=10, deadline=200000
    )  # A total of 10 examples will be generated
    @given(
        sample_input=st.tuples(
            st.integers(min_value=5, max_value=199),
            st.floats(min_value=5.0, max_value=199.0),
        )
    )  # Generate a pair (integer, float)
    @example((2, 3, 2.0, 3.0))  # Ensure this example will be covered
    def test_versioned_div_tensor(self, sample_input):
        def historic_div(self, other):
            if self.is_floating_point() or other.is_floating_point():
                return self.true_divide(other)
            return self.divide(other, rounding_mode="trunc")

        # Tensor x Tensor
        class MyModule(torch.nn.Module):
            def forward(self, a, b):
                result_0 = a / b
                result_1 = torch.div(a, b)
                result_2 = a.div(b)

                return result_0, result_1, result_2

        # Loads historic module
        try:
            v3_mobile_module = _load_for_lite_interpreter(
                pytorch_test_dir
                + "/cpp/jit/upgrader_models/test_versioned_div_tensor_v2.ptl"
            )
        except Exception:
            self.skipTest("Failed to load fixture!")

        current_mobile_module = self._save_load_mobile_module(MyModule)

        for val_a, val_b in product(sample_input, sample_input):
            a = torch.tensor((val_a,))
            b = torch.tensor((val_b,))

            def _helper(m, fn):
                m_results = self._try_fn(m, a, b)
                fn_result = self._try_fn(fn, a, b)

                if isinstance(m_results, Exception):
                    self.assertTrue(isinstance(fn_result, Exception))
                else:
                    for result in m_results:
                        self.assertEqual(result, fn_result)

            _helper(v3_mobile_module, historic_div)
            _helper(current_mobile_module, torch.div)

    @settings(
        max_examples=10, deadline=200000
    )  # A total of 10 examples will be generated
    @given(
        sample_input=st.tuples(
            st.integers(min_value=5, max_value=199),
            st.floats(min_value=5.0, max_value=199.0),
        )
    )  # Generate a pair (integer, float)
    @example((2, 3, 2.0, 3.0))  # Ensure this example will be covered
    def test_versioned_div_tensor_inplace(self, sample_input):
        def historic_div_(self, other):
            if self.is_floating_point() or other.is_floating_point():
                return self.true_divide_(other)
            return self.divide_(other, rounding_mode="trunc")

        class MyModule(torch.nn.Module):
            def forward(self, a, b):
                a /= b
                return a

        try:
            v3_mobile_module = _load_for_lite_interpreter(
                pytorch_test_dir
                + "/cpp/jit/upgrader_models/test_versioned_div_tensor_inplace_v2.ptl"
            )
        except Exception:
            self.skipTest("Failed to load fixture!")

        current_mobile_module = self._save_load_mobile_module(MyModule)

        for val_a, val_b in product(sample_input, sample_input):
            a = torch.tensor((val_a,))
            b = torch.tensor((val_b,))

            def _helper(m, fn):
                fn_result = self._try_fn(fn, a.clone(), b)
                m_result = self._try_fn(m, a, b)
                if isinstance(m_result, Exception):
                    self.assertTrue(isinstance(fn_result, Exception))
                else:
                    self.assertEqual(m_result, fn_result)
                    self.assertEqual(m_result, a)

            _helper(v3_mobile_module, historic_div_)

            # Recreates a since it was modified in place
            a = torch.tensor((val_a,))
            _helper(current_mobile_module, torch.Tensor.div_)
    @settings(
        max_examples=10, deadline=200000
    )  # A total of 10 examples will be generated
    @given(
        sample_input=st.tuples(
            st.integers(min_value=5, max_value=199),
            st.floats(min_value=5.0, max_value=199.0),
        )
    )  # Generate a pair (integer, float)
    @example((2, 3, 2.0, 3.0))  # Ensure this example will be covered
    def test_versioned_div_tensor_out(self, sample_input):
        def historic_div_out(self, other, out):
            if (
                self.is_floating_point()
                or other.is_floating_point()
                or out.is_floating_point()
            ):
                return torch.true_divide(self, other, out=out)
            return torch.divide(self, other, out=out, rounding_mode="trunc")

        class MyModule(torch.nn.Module):
            def forward(self, a, b, out):
                return a.div(b, out=out)

        try:
            v3_mobile_module = _load_for_lite_interpreter(
                pytorch_test_dir
                + "/cpp/jit/upgrader_models/test_versioned_div_tensor_out_v2.ptl"
            )
        except Exception:
            self.skipTest("Failed to load fixture!")

        current_mobile_module = self._save_load_mobile_module(MyModule)

        for val_a, val_b in product(sample_input, sample_input):
            a = torch.tensor((val_a,))
            b = torch.tensor((val_b,))

            for out in (torch.empty((1,)), torch.empty((1,), dtype=torch.long)):

                def _helper(m, fn):
                    if fn is torch.div:
                        fn_result = self._try_fn(fn, a, b, out=out.clone())
                    else:
                        fn_result = self._try_fn(fn, a, b, out.clone())
                    m_result = self._try_fn(m, a, b, out)

                    if isinstance(m_result, Exception):
                        self.assertTrue(isinstance(fn_result, Exception))
                    else:
                        self.assertEqual(m_result, fn_result)
                        self.assertEqual(m_result, out)

                _helper(v3_mobile_module, historic_div_out)
                _helper(current_mobile_module, torch.div)

    @settings(
        max_examples=10, deadline=200000
    )  # A total of 10 examples will be generated
    @given(
        sample_input=st.tuples(
            st.integers(min_value=5, max_value=199),
            st.floats(min_value=5.0, max_value=199.0),
        )
    )  # Generate a pair (integer, float)
    @example((2, 3, 2.0, 3.0))  # Ensure this example will be covered
    def test_versioned_div_scalar(self, sample_input):
        def historic_div_scalar_float(self, other: float):
            return torch.true_divide(self, other)

        def historic_div_scalar_int(self, other: int):
            if self.is_floating_point():
                return torch.true_divide(self, other)
            return torch.divide(self, other, rounding_mode="trunc")

        class MyModuleFloat(torch.nn.Module):
            def forward(self, a, b: float):
                return a / b

        class MyModuleInt(torch.nn.Module):
            def forward(self, a, b: int):
                return a / b

        try:
            v3_mobile_module_float = _load_for_lite_interpreter(
                pytorch_test_dir
                + "/jit/fixtures/test_versioned_div_scalar_float_v2.ptl"
            )
            v3_mobile_module_int = _load_for_lite_interpreter(
                pytorch_test_dir
                + "/cpp/jit/upgrader_models/test_versioned_div_scalar_int_v2.ptl"
            )
        except Exception:
            self.skipTest("Failed to load fixture!")

        current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
        current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)

        for val_a, val_b in product(sample_input, sample_input):
            a = torch.tensor((val_a,))
            b = val_b

            def _helper(m, fn):
                m_result = self._try_fn(m, a, b)
                fn_result = self._try_fn(fn, a, b)

                if isinstance(m_result, Exception):
                    self.assertTrue(isinstance(fn_result, Exception))
                else:
                    self.assertEqual(m_result, fn_result)

            if isinstance(b, float):
                # The float case compares the historic module directly against
                # the current one
                _helper(v3_mobile_module_float, current_mobile_module_float)
                _helper(current_mobile_module_float, torch.div)
            else:
                _helper(v3_mobile_module_int, historic_div_scalar_int)
                _helper(current_mobile_module_int, torch.div)

    @settings(
        max_examples=10, deadline=200000
    )  # A total of 10 examples will be generated
    @given(
        sample_input=st.tuples(
            st.integers(min_value=5, max_value=199),
            st.floats(min_value=5.0, max_value=199.0),
        )
    )  # Generate a pair (integer, float)
    @example((2, 3, 2.0, 3.0))  # Ensure this example will be covered
    def test_versioned_div_scalar_reciprocal(self, sample_input):
        def historic_div_scalar_float_reciprocal(self, other: float):
            return other / self

        def historic_div_scalar_int_reciprocal(self, other: int):
            if self.is_floating_point():
                return other / self
            return torch.divide(other, self, rounding_mode="trunc")

        class MyModuleFloat(torch.nn.Module):
            def forward(self, a, b: float):
                return b / a

        class MyModuleInt(torch.nn.Module):
            def forward(self, a, b: int):
                return b / a

        try:
            v3_mobile_module_float = _load_for_lite_interpreter(
                pytorch_test_dir
                + "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl"
            )
            v3_mobile_module_int = _load_for_lite_interpreter(
                pytorch_test_dir
                + "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl"
            )
        except Exception:
            self.skipTest("Failed to load fixture!")

        current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
        current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)

        for val_a, val_b in product(sample_input, sample_input):
            a = torch.tensor((val_a,))
            b = val_b

            def _helper(m, fn):
                m_result = self._try_fn(m, a, b)
                # Reverses the argument order for torch.div
                if fn is torch.div:
                    fn_result = self._try_fn(torch.div, b, a)
                else:
                    fn_result = self._try_fn(fn, a, b)

                if isinstance(m_result, Exception):
                    self.assertTrue(isinstance(fn_result, Exception))
                elif fn is torch.div or a.is_floating_point():
                    self.assertEqual(m_result, fn_result)
                else:
                    # Skip the comparison when fn is not torch.div and a is
                    # integral, because the historic module performs truncated
                    # division while the current one performs true division
                    pass

            if isinstance(b, float):
                _helper(v3_mobile_module_float, current_mobile_module_float)
                _helper(current_mobile_module_float, torch.div)
            else:
                _helper(v3_mobile_module_int, current_mobile_module_int)
                _helper(current_mobile_module_int, torch.div)

    @settings(
        max_examples=10, deadline=200000
    )  # A total of 10 examples will be generated
    @given(
        sample_input=st.tuples(
            st.integers(min_value=5, max_value=199),
            st.floats(min_value=5.0, max_value=199.0),
        )
    )  # Generate a pair (integer, float)
    @example((2, 3, 2.0, 3.0))  # Ensure this example will be covered
    def test_versioned_div_scalar_inplace(self, sample_input):
        def historic_div_scalar_float_inplace(self, other: float):
            return self.true_divide_(other)

        def historic_div_scalar_int_inplace(self, other: int):
            if self.is_floating_point():
                return self.true_divide_(other)

            return self.divide_(other, rounding_mode="trunc")

        class MyModuleFloat(torch.nn.Module):
            def forward(self, a, b: float):
                a /= b
                return a

        class MyModuleInt(torch.nn.Module):
            def forward(self, a, b: int):
                a /= b
                return a

        try:
            v3_mobile_module_float = _load_for_lite_interpreter(
                pytorch_test_dir
                + "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl"
            )
            v3_mobile_module_int = _load_for_lite_interpreter(
                pytorch_test_dir
                + "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl"
            )
        except Exception:
            self.skipTest("Failed to load fixture!")

        # NOTE: unlike the other tests, these modules are reloaded through the
        # full JIT rather than the lite interpreter
        current_module_float = self._save_load_module(MyModuleFloat)
        current_module_int = self._save_load_module(MyModuleInt)

        for val_a, val_b in product(sample_input, sample_input):
            a = torch.tensor((val_a,))
            b = val_b

            def _helper(m, fn):
                m_result = self._try_fn(m, a, b)
                fn_result = self._try_fn(fn, a, b)

                if isinstance(m_result, Exception):
                    self.assertTrue(isinstance(fn_result, Exception))
                else:
                    self.assertEqual(m_result, fn_result)

            if isinstance(b, float):
                _helper(current_module_float, torch.Tensor.div_)
            else:
                _helper(current_module_int, torch.Tensor.div_)
    # NOTE: Scalar division was already true division in op version 3,
    #   so this test verifies the behavior is unchanged.
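    #   For example, 5.0 / 3 evaluates to roughly 1.6667 under both the
    #   versioned fixture and the current runtime, because scalar / scalar
    #   division never used the truncating overloads.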
    def test_versioned_div_scalar_scalar(self):
        class MyModule(torch.nn.Module):
            def forward(self, a: float, b: int, c: float, d: int):
                result_0 = a / b
                result_1 = a / c
                result_2 = b / c
                result_3 = b / d
                return (result_0, result_1, result_2, result_3)

        try:
            v3_mobile_module = _load_for_lite_interpreter(
                pytorch_test_dir
                + "/cpp/jit/upgrader_models/test_versioned_div_scalar_scalar_v2.ptl"
            )
        except Exception:
            self.skipTest("Failed to load fixture!")

        current_mobile_module = self._save_load_mobile_module(MyModule)

        def _helper(m, fn):
            vals = (5.0, 3, 2.0, 7)
            m_result = m(*vals)
            fn_result = fn(*vals)
            for mr, hr in zip(m_result, fn_result):
                self.assertEqual(mr, hr)

        _helper(v3_mobile_module, current_mobile_module)

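    # The v7 linspace fixtures below were serialized before torch.linspace
    # required an explicit steps argument; on load, the upgrader fills in the
    # historic default of 100 steps, so an old call is rewritten roughly as:
    #
    #   torch.linspace(a, b)  ->  torch.linspace(a, b, steps=100)
    #
    # Both tests check that the upgraded module matches a freshly scripted one.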
    def test_versioned_linspace(self):
        class Module(torch.nn.Module):
            def forward(
                self, a: Union[int, float, complex], b: Union[int, float, complex]
            ):
                c = torch.linspace(a, b, steps=5)
                d = torch.linspace(a, b, steps=100)
                return c, d

        scripted_module = torch.jit.load(
            pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl"
        )

        buffer = io.BytesIO(scripted_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        v7_mobile_module = _load_for_lite_interpreter(buffer)

        current_mobile_module = self._save_load_mobile_module(Module)

        sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
        for a, b in sample_inputs:
            (output_with_step, output_without_step) = v7_mobile_module(a, b)
            (current_with_step, current_without_step) = current_mobile_module(a, b)
            # When no steps argument was given, the upgrader should have used 100
            self.assertEqual(output_without_step.size(dim=0), 100)
            self.assertEqual(output_with_step.size(dim=0), 5)
            # Outputs should be equal to the newest version's
            self.assertEqual(output_with_step, current_with_step)
            self.assertEqual(output_without_step, current_without_step)

    def test_versioned_linspace_out(self):
        class Module(torch.nn.Module):
            def forward(
                self,
                a: Union[int, float, complex],
                b: Union[int, float, complex],
                out: torch.Tensor,
            ):
                return torch.linspace(a, b, steps=100, out=out)

        model_path = (
            pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
        )
        loaded_model = torch.jit.load(model_path)
        buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        v7_mobile_module = _load_for_lite_interpreter(buffer)
        current_mobile_module = self._save_load_mobile_module(Module)

        sample_inputs = (
            (
                3,
                10,
                torch.empty((100,), dtype=torch.int64),
                torch.empty((100,), dtype=torch.int64),
            ),
            (
                -10,
                10,
                torch.empty((100,), dtype=torch.int64),
                torch.empty((100,), dtype=torch.int64),
            ),
            (
                4.0,
                6.0,
                torch.empty((100,), dtype=torch.float64),
                torch.empty((100,), dtype=torch.float64),
            ),
            (
                3 + 4j,
                4 + 5j,
                torch.empty((100,), dtype=torch.complex64),
                torch.empty((100,), dtype=torch.complex64),
            ),
        )
        for start, end, out_for_old, out_for_new in sample_inputs:
            output = v7_mobile_module(start, end, out_for_old)
            output_current = current_mobile_module(start, end, out_for_new)
            # When no steps argument was given, the upgrader should have used 100
            self.assertEqual(output.size(dim=0), 100)
            # The upgraded model should match the new version's output
            self.assertEqual(output, output_current)

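    # torch.logspace went through the same steps-argument change one operator
    # version later (v8 fixtures): models saved without an explicit steps value
    # are upgraded to steps=100 on load.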
    def test_versioned_logspace(self):
        class Module(torch.nn.Module):
            def forward(
                self, a: Union[int, float, complex], b: Union[int, float, complex]
            ):
                c = torch.logspace(a, b, steps=5)
                d = torch.logspace(a, b, steps=100)
                return c, d

        scripted_module = torch.jit.load(
            pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_v8.ptl"
        )

        buffer = io.BytesIO(scripted_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        v8_mobile_module = _load_for_lite_interpreter(buffer)

        current_mobile_module = self._save_load_mobile_module(Module)

        sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
        for a, b in sample_inputs:
            (output_with_step, output_without_step) = v8_mobile_module(a, b)
            (current_with_step, current_without_step) = current_mobile_module(a, b)
            # When no steps argument was given, the upgrader should have used 100
            self.assertEqual(output_without_step.size(dim=0), 100)
            self.assertEqual(output_with_step.size(dim=0), 5)
            # Outputs should be equal to the newest version's
            self.assertEqual(output_with_step, current_with_step)
            self.assertEqual(output_without_step, current_without_step)

    def test_versioned_logspace_out(self):
        class Module(torch.nn.Module):
            def forward(
                self,
                a: Union[int, float, complex],
                b: Union[int, float, complex],
                out: torch.Tensor,
            ):
                return torch.logspace(a, b, steps=100, out=out)

        model_path = (
            pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl"
        )
        loaded_model = torch.jit.load(model_path)
        buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        v8_mobile_module = _load_for_lite_interpreter(buffer)
        current_mobile_module = self._save_load_mobile_module(Module)

        sample_inputs = (
            (
                3,
                10,
                torch.empty((100,), dtype=torch.int64),
                torch.empty((100,), dtype=torch.int64),
            ),
            (
                -10,
                10,
                torch.empty((100,), dtype=torch.int64),
                torch.empty((100,), dtype=torch.int64),
            ),
            (
                4.0,
                6.0,
                torch.empty((100,), dtype=torch.float64),
                torch.empty((100,), dtype=torch.float64),
            ),
            (
                3 + 4j,
                4 + 5j,
                torch.empty((100,), dtype=torch.complex64),
                torch.empty((100,), dtype=torch.complex64),
            ),
        )
        for start, end, out_for_old, out_for_new in sample_inputs:
            output = v8_mobile_module(start, end, out_for_old)
            output_current = current_mobile_module(start, end, out_for_new)
            # When no steps argument was given, the upgrader should have used 100
            self.assertEqual(output.size(dim=0), 100)
            # The upgraded model should match the new version's output
            self.assertEqual(output, output_current)
619