xref: /aosp_15_r20/external/pytorch/test/jit/test_pdt.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5from typing import Any, Dict, List, NamedTuple, Optional, Tuple  # noqa: F401
6
7import torch
8from torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED
9from torch.testing._internal.common_utils import NoTest
10from torch.testing._internal.jit_utils import JitTestCase, make_global
11
12
# Make the helper files in test/ importable
# (the directory appended is the parent of the directory containing this file).
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# Without monkeytype, report the skip on stderr and rebind JitTestCase to
# NoTest so that TestPDT below derives from NoTest instead.
# NOTE(review): NoTest presumably disables test collection - confirm in
# torch.testing._internal.common_utils.
if not _IS_MONKEYTYPE_INSTALLED:
    print(
        "monkeytype is not installed. Skipping tests for Profile-Directed Typing",
        file=sys.stderr,
    )
    JitTestCase = NoTest  # type: ignore[misc, assignment] # noqa: F811

# Refuse direct execution; the error message points at test/test_jit.py as the
# entry point to use instead.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
30
31
32class TestPDT(JitTestCase):
33    """
34    A suite of tests for profile directed typing in TorchScript.
35    """
36
37    def test_nn_module(self):
38        class TestPDTModel(torch.nn.Module):
39            def forward(self, x) -> Any:
40                if isinstance(x, int):
41                    return x + 1
42                elif isinstance(x, float):
43                    return x - 1
44                else:
45                    return x
46
47        make_global(TestPDTModel)
48        pdt_model = TestPDTModel()
49        inp: List[Tuple[Any, ...]] = [
50            (20,),
51            (2.7,),
52            (False,),
53        ]
54        scripted_pdt_model = torch.jit.script(
55            pdt_model, example_inputs={pdt_model: inp}
56        )
57        self.assertEqual(scripted_pdt_model(50), pdt_model(50))
58        self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
59        self.assertTrue(scripted_pdt_model(True), pdt_model(True))
60
61    def test_nested_nn_module_class(self):
62        class NestedPDTInner(torch.nn.Module):
63            def forward(self, x):
64                if isinstance(x, int):
65                    return x * 10
66                return x
67
68        class NestedModulePDTWrapper(torch.nn.Module):
69            def __init__(self, inner):
70                super().__init__()
71                self.inner = inner
72
73            def forward(self, x):
74                return self.inner(x)
75
76        make_global(NestedPDTInner, NestedModulePDTWrapper)
77        inner_pdt_model = NestedPDTInner()
78        wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model)
79        inp: List[Tuple[Any, ...]] = [(20,), (False,)]
80        scripted_pdt_model = torch.jit.script(
81            wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp}
82        )
83        self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30))
84        self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9))
85        self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True))
86
87    def test_nested_nn_module_class_with_args(self):
88        class NestedModulePDTInner(torch.nn.Module):
89            def forward(self, x, y):
90                if isinstance(x, int):
91                    return x * 10 + y
92                return x
93
94        class NestedModulePDTOuter(torch.nn.Module):
95            def __init__(self, inner):
96                super().__init__()
97                self.inner = inner
98
99            def forward(self, x):
100                return self.inner(x, 20)
101
102        make_global(NestedModulePDTInner, NestedModulePDTOuter)
103        inner_pdt_model = NestedModulePDTInner()
104        outer_pdt_model = NestedModulePDTOuter(inner_pdt_model)
105        inner_input: List[Tuple[Any, ...]] = [
106            (10, 10),
107            (1.9, 20),
108        ]
109        outer_input: List[Tuple[Any, ...]] = [(20,), (False,)]
110        scripted_pdt_model = torch.jit.script(
111            outer_pdt_model,
112            example_inputs={
113                inner_pdt_model: inner_input,
114                outer_pdt_model: outer_input,
115            },
116        )
117        self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30))
118        self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9))
119        self.assertTrue(scripted_pdt_model(True), outer_pdt_model(True))
120
121    def test_nested_function_in_forward(self):
122        class NestedFunctionInForward(torch.nn.Module):
123            def forward(self, x):
124                return self.fun(x) + 10
125
126            def fun(self, x):
127                if isinstance(x, bool):
128                    return 0
129                elif isinstance(x, int):
130                    return x + 1
131                return 0
132
133        make_global(NestedFunctionInForward)
134        pdt_model = NestedFunctionInForward()
135        inp: List[Tuple[Any, ...]] = [(-1,), (False,)]
136        scripted_pdt_model = torch.jit.script(
137            pdt_model, example_inputs={pdt_model: inp}
138        )
139        self.assertEqual(scripted_pdt_model(30), pdt_model(30))
140        self.assertEqual(scripted_pdt_model(True), pdt_model(True))
141
142    def test_nn_module_with_export_function(self):
143        class TestModelWithExport(torch.nn.Module):
144            @torch.jit.export
145            def fn(self, x, y) -> Any:
146                assert not (isinstance(x, bool) and isinstance(y, bool))
147                if isinstance(x, int) and isinstance(y, int):
148                    return x + y
149                elif isinstance(x, float) and isinstance(y, float):
150                    return x - y
151                else:
152                    return -1
153
154        make_global(TestModelWithExport)
155        pdt_model = TestModelWithExport()
156        inp: List[Tuple[Any, ...]] = [
157            (
158                20,
159                10,
160            ),
161            (
162                2.7,
163                8.9,
164            ),
165        ]
166        scripted_pdt_model = torch.jit.script(
167            pdt_model, example_inputs={pdt_model.fn: inp}
168        )
169        self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90))
170        self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2))
171        self.assertTrue(
172            scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2)
173        )
174
175    def test_class_methods(self):
176        class PDTModel:
177            def test_sum(self, a):
178                return sum(a)
179
180        make_global(PDTModel)
181        pdt_model = PDTModel()
182        inp: List[Tuple[Any, ...]] = [
183            (
184                [
185                    10,
186                    20,
187                ],
188            ),
189        ]
190        scripted_pdt_model = torch.jit.script(
191            PDTModel, example_inputs={pdt_model.test_sum: inp}
192        )
193        script_model = scripted_pdt_model()
194        self.assertEqual(
195            script_model.test_sum(
196                [
197                    10,
198                    20,
199                    30,
200                ],
201            ),
202            pdt_model.test_sum(
203                [
204                    10,
205                    20,
206                    30,
207                ],
208            ),
209        )
210
211    def test_class_with_multiple_methods(self):
212        class PDTModelWithManyMethods:
213            def test_list_to_dict(self, a):
214                new_dictionary: Dict[float, bool] = {}
215                for element in a:
216                    new_dictionary[element] = True
217                return new_dictionary
218
219            def test_substring(self, a, b):
220                return b in a
221
222        make_global(PDTModelWithManyMethods)
223        pdt_model = PDTModelWithManyMethods()
224        list_inp: List[Tuple[Any, ...]] = [
225            (
226                [
227                    1.2,
228                    2.3,
229                ],
230            ),
231        ]
232        str_inp: List[Tuple[Any, ...]] = [
233            (
234                "abc",
235                "b",
236            ),
237        ]
238        scripted_pdt_model = torch.jit.script(
239            PDTModelWithManyMethods,
240            example_inputs={
241                pdt_model.test_list_to_dict: list_inp,
242                pdt_model.test_substring: str_inp,
243            },
244        )
245        script_model = scripted_pdt_model()
246        self.assertEqual(
247            script_model.test_list_to_dict(
248                [
249                    1.1,
250                    2.2,
251                    3.3,
252                ],
253            ),
254            pdt_model.test_list_to_dict(
255                [
256                    1.1,
257                    2.2,
258                    3.3,
259                ],
260            ),
261        )
262        self.assertEqual(
263            script_model.test_substring(
264                "helloworld",
265                "world",
266            ),
267            pdt_model.test_substring(
268                "helloworld",
269                "world",
270            ),
271        )
272        self.assertEqual(
273            script_model.test_substring(
274                "helloworld",
275                "def",
276            ),
277            pdt_model.test_substring(
278                "helloworld",
279                "def",
280            ),
281        )
282
283    def test_multiple_class_with_same_method(self):
284        class PDTModelOne:
285            def test_find(self, a, b):
286                return b in a.keys()
287
288        class PDTModelTwo:
289            def test_find(self, a, b):
290                return b in a
291
292        make_global(PDTModelOne, PDTModelTwo)
293        pdt_model_one = PDTModelOne()
294        pdt_model_two = PDTModelTwo()
295        dict_inp: List[Tuple[Any, ...]] = [
296            (
297                {
298                    1.2: True,
299                    2.3: False,
300                },
301                1.2,
302            ),
303        ]
304        list_inp: List[Tuple[Any, ...]] = [
305            (
306                [
307                    "abc",
308                    "b",
309                ],
310                "c",
311            ),
312        ]
313        scripted_pdt_model_one = torch.jit.script(
314            PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp}
315        )
316        scripted_pdt_model_two = torch.jit.script(
317            PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp}
318        )
319
320        script_model_one, script_model_two = (
321            scripted_pdt_model_one(),
322            scripted_pdt_model_two(),
323        )
324        self.assertEqual(
325            script_model_one.test_find(
326                {
327                    1.1: True,
328                    2.2: True,
329                    3.3: False,
330                },
331                4.4,
332            ),
333            pdt_model_one.test_find(
334                {
335                    1.1: True,
336                    2.2: True,
337                    3.3: False,
338                },
339                4.4,
340            ),
341        )
342        self.assertEqual(
343            script_model_two.test_find(
344                [
345                    "hello",
346                    "world",
347                ],
348                "world",
349            ),
350            pdt_model_two.test_find(
351                [
352                    "hello",
353                    "world",
354                ],
355                "world",
356            ),
357        )
358
359    def test_pdt(self):
360        def test_sum(a, b):
361            return a + b
362
363        make_global(test_sum)
364        scripted_fn_add = torch.jit.script(test_sum, example_inputs=[(3, 4)])
365        self.assertEqual(scripted_fn_add(10, 2), test_sum(10, 2))
366
367        def test_sub(a, b):
368            return a - b
369
370        make_global(test_sub)
371        scripted_fn_sub = torch.jit.script(test_sub, example_inputs=[(3.9, 4.10)])
372        self.assertEqual(scripted_fn_sub(6.5, 2.9), test_sub(6.5, 2.9))
373
374        def test_mul(a, b):
375            return a * b
376
377        make_global(test_mul)
378        scripted_fn_mul = torch.jit.script(test_mul, example_inputs=[(-10, 9)])
379        self.assertEqual(scripted_fn_mul(-1, 3), test_mul(-1, 3))
380
381        def test_args_complex(real, img):
382            return torch.complex(real, img)
383
384        make_global(test_args_complex)
385        scripted_fn_complex = torch.jit.script(
386            test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))]
387        )
388        arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4)
389        self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2))
390
391        def test_bool(a):
392            if a:
393                return -1
394            else:
395                return 0
396
397        make_global(test_bool)
398        scripted_fn_bool = torch.jit.script(test_bool, example_inputs=[(True,)])
399        self.assertEqual(scripted_fn_bool(True), test_bool(True))
400
401        def test_str(a):
402            if a == "":
403                return False
404            else:
405                return True
406
407        make_global(test_str)
408        scripted_fn_str = torch.jit.script(test_str, example_inputs=[("",)])
409        self.assertEqual(scripted_fn_str("abc"), test_str("abc"))
410
411    def test_pdt_list_and_tuple(self):
412        def test_list_and_tuple(a):
413            return sum(a)
414
415        make_global(test_list_and_tuple)
416
417        scripted_fn_float_list_input = torch.jit.script(
418            test_list_and_tuple, example_inputs=[([4.9, 8.9],)]
419        )
420        self.assertEqual(
421            scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6])
422        )
423
424        scripted_fn_bool_list_input = torch.jit.script(
425            test_list_and_tuple, example_inputs=[([True, False, True],)]
426        )
427        self.assertEqual(
428            scripted_fn_bool_list_input([True, True, True]),
429            test_list_and_tuple([True, True, True]),
430        )
431
432        scripted_fn_int_list_input = torch.jit.script(
433            test_list_and_tuple, example_inputs=[([3, 4, 5],)]
434        )
435        self.assertEqual(
436            scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3])
437        )
438
439        scripted_fn_float_tuple_input = torch.jit.script(
440            test_list_and_tuple, example_inputs=[((4.9, 8.9),)]
441        )
442        self.assertEqual(
443            scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6))
444        )
445
446        scripted_fn_bool_tuple_input = torch.jit.script(
447            test_list_and_tuple, example_inputs=[((True, False, True),)]
448        )
449        self.assertEqual(
450            scripted_fn_bool_tuple_input((True, True, True)),
451            test_list_and_tuple((True, True, True)),
452        )
453
454        scripted_fn_int_tuple_input = torch.jit.script(
455            test_list_and_tuple, example_inputs=[((3, 4, 5),)]
456        )
457        self.assertEqual(
458            scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3))
459        )
460
461    def test_nested_list_and_tuple(self):
462        def test_nested_list(inp):
463            return [sum(v) for v in inp]
464
465        def test_nested_tuple(inp):
466            ans = 0.0
467            for tup in inp:
468                for val in tup:
469                    if val > 0:
470                        ans *= val
471            return ans
472
473        make_global(test_nested_list, test_nested_tuple)
474
475        list_inp = [
476            [
477                1,
478                2,
479                3,
480            ],
481            [
482                5,
483                6,
484                7,
485            ],
486        ]
487        scripted_fn = torch.jit.script(
488            test_nested_list,
489            example_inputs=[
490                (list_inp,),
491            ],
492        )
493        inp = [
494            [
495                0,
496                4,
497                7,
498            ],
499            [
500                8,
501                11,
502            ],
503            [
504                6,
505                -1,
506                -20,
507            ],
508        ]
509        self.assertEqual(
510            scripted_fn(
511                inp,
512            ),
513            test_nested_list(
514                inp,
515            ),
516        )
517
518        list_inp = (
519            [
520                1,
521                2,
522                3,
523            ],
524            [
525                5,
526                6,
527                7,
528            ],
529        )
530        scripted_fn = torch.jit.script(
531            test_nested_list,
532            example_inputs=[
533                (list_inp,),
534            ],
535        )
536        inp = (
537            [
538                0,
539                4,
540                7,
541            ],
542            [
543                8,
544                11,
545            ],
546            [
547                6,
548                -1,
549                -20,
550            ],
551        )
552        self.assertEqual(
553            scripted_fn(
554                inp,
555            ),
556            test_nested_list(
557                inp,
558            ),
559        )
560
561        tup_inp = [
562            (
563                1.0,
564                2.6,
565                3.7,
566            ),
567            (
568                5.7,
569                6.1,
570                1.7,
571            ),
572        ]
573        scripted_fn = torch.jit.script(
574            test_nested_tuple,
575            example_inputs=[
576                (tup_inp,),
577            ],
578        )
579        inp = [
580            (
581                1.0,
582                4.1,
583                7.4,
584            ),
585            (
586                4.8,
587                1.1,
588                -1.2,
589            ),
590            (
591                6.3,
592                -1.3,
593                -2.0,
594            ),
595        ]
596        self.assertEqual(
597            scripted_fn(
598                inp,
599            ),
600            test_nested_tuple(
601                inp,
602            ),
603        )
604
605        tup_inp = (
606            (
607                True,
608                False,
609                True,
610            ),
611            (
612                False,
613                False,
614                False,
615            ),
616        )
617        scripted_fn = torch.jit.script(
618            test_nested_tuple,
619            example_inputs=[
620                (tup_inp,),
621            ],
622        )
623        inp = (
624            (
625                True,
626                True,
627                True,
628            ),
629            (
630                False,
631                False,
632                True,
633            ),
634        )
635        self.assertEqual(
636            scripted_fn(
637                inp,
638            ),
639            test_nested_tuple(
640                inp,
641            ),
642        )
643
644    def test_pdt_dict(self):
645        def test_dict(a):
646            return a["foo"]
647
648        def test_dict_int_list(a):
649            return a[1]
650
651        make_global(test_dict, test_dict_int_list)
652
653        str_bool_inp = {"foo": True, "bar": False}
654        scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)])
655        self.assertEqual(
656            scripted_fn(
657                {"foo": False, "bar": True},
658            ),
659            test_dict(
660                {"foo": False, "bar": True},
661            ),
662        )
663
664        str_list_inp = {0: [True, False], 1: [False, True]}
665        scripted_fn = torch.jit.script(
666            test_dict_int_list, example_inputs=[(str_list_inp,)]
667        )
668        self.assertEqual(
669            scripted_fn(
670                {0: [False, False], 1: [True, True]},
671            ),
672            test_dict_int_list(
673                {0: [False, False], 1: [True, True]},
674            ),
675        )
676
677    def test_any(self):
678        def test_multiple_types(a):
679            assert not isinstance(a, bool)
680            return a
681
682        def test_multiple_type_refinement(a):
683            if isinstance(a, bool):
684                return 1
685            elif isinstance(a, int):
686                return 1 + a
687            elif isinstance(a, float):
688                return 1 + int(a)
689            else:
690                return -1
691
692        make_global(test_multiple_types, test_multiple_type_refinement)
693
694        scripted_fn = torch.jit.script(
695            test_multiple_types, example_inputs=[(1,), ("abc",), (8.9,), ([3, 4, 5],)]
696        )
697        self.assertEqual(scripted_fn(10), test_multiple_types(10))
698        self.assertEqual(scripted_fn("def"), test_multiple_types("def"))
699        self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999))
700        self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14]))
701
702        scripted_fn = torch.jit.script(
703            test_multiple_type_refinement,
704            example_inputs=[
705                (1,),
706                ("abc",),
707                (8.9,),
708                ([3, 4, 5],),
709                (True,),
710                ({"a": True},),
711            ],
712        )
713        self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10))
714        self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def"))
715        self.assertEqual(scripted_fn(7.89999), test_multiple_type_refinement(7.89999))
716        self.assertEqual(
717            scripted_fn([10, 11, 14]), test_multiple_type_refinement([10, 11, 14])
718        )
719        self.assertEqual(scripted_fn(False), test_multiple_type_refinement(False))
720        self.assertEqual(
721            scripted_fn({"abc": True, "def": False}),
722            test_multiple_type_refinement({"abc": True, "def": False}),
723        )
724
725    def test_class_as_profiled_types(self):
726        class UserDefinedClass:
727            def fn(self, b) -> Any:
728                assert b is not None
729                if isinstance(b, int):
730                    return b if b > 0 else -1
731                elif isinstance(b, float):
732                    return b if b > 0.0 else -1.0
733                return 0
734
735        def test_model(a, m):
736            assert not isinstance(a, bool)
737            return m.fn(a)
738
739        make_global(UserDefinedClass, test_model)
740
741        user_class = UserDefinedClass()
742        scripted_fn = torch.jit.script(
743            test_model,
744            example_inputs=[
745                (
746                    10,
747                    user_class,
748                ),
749                (
750                    10.9,
751                    user_class,
752                ),
753            ],
754        )
755        self.assertEqual(
756            scripted_fn(
757                100,
758                user_class,
759            ),
760            test_model(100, user_class),
761        )
762        self.assertEqual(
763            scripted_fn(
764                1.9,
765                user_class,
766            ),
767            test_model(1.9, user_class),
768        )
769
770    def test_class_with_args_as_profiled_types(self):
771        class ClassWithArgs:
772            def __init__(self, a: bool):
773                self.a = a
774
775            def fn(self, b):
776                if self.a:
777                    return b
778                else:
779                    return -1
780
781        def test_model_with_args(a, m):
782            assert not isinstance(a, bool)
783            return m.fn(a)
784
785        make_global(ClassWithArgs, test_model_with_args)
786
787        user_class = ClassWithArgs(False)
788        scripted_fn = torch.jit.script(
789            test_model_with_args,
790            example_inputs=[
791                (
792                    10,
793                    user_class,
794                ),
795                (
796                    10.9,
797                    user_class,
798                ),
799            ],
800        )
801        self.assertEqual(
802            scripted_fn(
803                100,
804                ClassWithArgs(True),
805            ),
806            test_model_with_args(100, ClassWithArgs(True)),
807        )
808
809    def test_nn_parameter_as_arg(self):
810        class TestNNParameter(torch.nn.Module):
811            def __init__(self) -> None:
812                super().__init__()
813                self.inp = torch.nn.Parameter(torch.ones(2, 3))
814
815            def add_nn_parameter_with_int(self, x, y):
816                return torch.add(x, y)
817
818            def forward(self, y):
819                return self.add_nn_parameter_with_int(self.inp, y)
820
821        make_global(TestNNParameter)
822        pdt_model = TestNNParameter()
823        scripted_fn = torch.jit.script(
824            pdt_model,
825            example_inputs={
826                pdt_model: [
827                    (10,),
828                ],
829            },
830        )
831        self.assertEqual(scripted_fn(20), pdt_model(20))
832
833    def test_fx_tracing_with_typing(self):
834        class FXModelOutput(NamedTuple):
835            result: List[int]
836
837        class FXModel(torch.nn.Module):
838            def forward(self, a) -> FXModelOutput:
839                result = FXModelOutput(result=a)
840                return result
841
842        make_global(FXModel, FXModelOutput)
843        pdt_model = FXModel()
844        scripted_fn = torch.jit.script(
845            pdt_model,
846            example_inputs={
847                pdt_model: [
848                    (
849                        [
850                            10,
851                            20,
852                        ],
853                    ),
854                ],
855            },
856        )
857        self.assertEqual(scripted_fn([20]), pdt_model([20]))
858
859    def test_nonetype_as_optional_of_type(self):
860        def test_none(a) -> Any:
861            if a is None:
862                return 0
863            else:
864                return a + torch.ones(1)
865
866        make_global(test_none)
867
868        scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10.6,)])
869        self.assertEqual(
870            scripted_fn(
871                30.9,
872            ),
873            test_none(
874                30.9,
875            ),
876        )
877
878        scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10,)])
879        self.assertEqual(
880            scripted_fn(
881                2,
882            ),
883            test_none(
884                2,
885            ),
886        )
887
888        scripted_fn = torch.jit.script(
889            test_none, example_inputs=[(None,), (torch.Tensor(1),)]
890        )
891        self.assertEqual(
892            scripted_fn(
893                torch.ones(1),
894            ),
895            test_none(
896                torch.ones(1),
897            ),
898        )
899