# xref: /aosp_15_r20/external/pytorch/test/jit/test_jit_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5from textwrap import dedent
6
7import torch
8from torch.testing._internal import jit_utils
9
10
# Locate the repository's test/ directory (the parent of the directory that
# contains this file, after resolving symlinks) and append it to sys.path so
# helper modules stored there can be imported.
_here = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(_here)
del _here
sys.path.append(pytorch_test_dir)
14from torch.testing._internal.jit_utils import JitTestCase
15
16
# These tests are meant to be driven through test/test_jit.py; executing this
# module on its own is a usage error, so fail fast with a pointer to the
# supported entry point.
if __name__ == "__main__":
    _usage_hint = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_hint)
23
24
25# Tests various JIT-related utility functions.
26class TestJitUtils(JitTestCase):
27    # Tests that POSITIONAL_OR_KEYWORD arguments are captured.
28    def test_get_callable_argument_names_positional_or_keyword(self):
29        def fn_positional_or_keyword_args_only(x, y):
30            return x + y
31
32        self.assertEqual(
33            ["x", "y"],
34            torch._jit_internal.get_callable_argument_names(
35                fn_positional_or_keyword_args_only
36            ),
37        )
38
39    # Tests that POSITIONAL_ONLY arguments are ignored.
40    def test_get_callable_argument_names_positional_only(self):
41        code = dedent(
42            """
43            def fn_positional_only_arg(x, /, y):
44                return x + y
45        """
46        )
47
48        fn_positional_only_arg = jit_utils._get_py3_code(code, "fn_positional_only_arg")
49        self.assertEqual(
50            ["y"],
51            torch._jit_internal.get_callable_argument_names(fn_positional_only_arg),
52        )
53
54    # Tests that VAR_POSITIONAL arguments are ignored.
55    def test_get_callable_argument_names_var_positional(self):
56        # Tests that VAR_POSITIONAL arguments are ignored.
57        def fn_var_positional_arg(x, *arg):
58            return x + arg[0]
59
60        self.assertEqual(
61            ["x"],
62            torch._jit_internal.get_callable_argument_names(fn_var_positional_arg),
63        )
64
65    # Tests that KEYWORD_ONLY arguments are ignored.
66    def test_get_callable_argument_names_keyword_only(self):
67        def fn_keyword_only_arg(x, *, y):
68            return x + y
69
70        self.assertEqual(
71            ["x"], torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg)
72        )
73
74    # Tests that VAR_KEYWORD arguments are ignored.
75    def test_get_callable_argument_names_var_keyword(self):
76        def fn_var_keyword_arg(**args):
77            return args["x"] + args["y"]
78
79        self.assertEqual(
80            [], torch._jit_internal.get_callable_argument_names(fn_var_keyword_arg)
81        )
82
83    # Tests that a function signature containing various different types of
84    # arguments are ignored.
85    def test_get_callable_argument_names_hybrid(self):
86        code = dedent(
87            """
88            def fn_hybrid_args(x, /, y, *args, **kwargs):
89                return x + y + args[0] + kwargs['z']
90        """
91        )
92        fn_hybrid_args = jit_utils._get_py3_code(code, "fn_hybrid_args")
93        self.assertEqual(
94            ["y"], torch._jit_internal.get_callable_argument_names(fn_hybrid_args)
95        )
96
97    def test_checkscriptassertraisesregex(self):
98        def fn():
99            tup = (1, 2)
100            return tup[2]
101
102        self.checkScriptRaisesRegex(fn, (), Exception, "range", name="fn")
103
104        s = dedent(
105            """
106        def fn():
107            tup = (1, 2)
108            return tup[2]
109        """
110        )
111
112        self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn")
113
114    def test_no_tracer_warn_context_manager(self):
115        torch._C._jit_set_tracer_state_warn(True)
116        with jit_utils.NoTracerWarnContextManager() as no_warn:
117            self.assertEqual(False, torch._C._jit_get_tracer_state_warn())
118        self.assertEqual(True, torch._C._jit_get_tracer_state_warn())
119