1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5from textwrap import dedent 6 7import torch 8from torch.testing._internal import jit_utils 9 10 11# Make the helper files in test/ importable 12pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 13sys.path.append(pytorch_test_dir) 14from torch.testing._internal.jit_utils import JitTestCase 15 16 17if __name__ == "__main__": 18 raise RuntimeError( 19 "This test file is not meant to be run directly, use:\n\n" 20 "\tpython test/test_jit.py TESTNAME\n\n" 21 "instead." 22 ) 23 24 25# Tests various JIT-related utility functions. 26class TestJitUtils(JitTestCase): 27 # Tests that POSITIONAL_OR_KEYWORD arguments are captured. 28 def test_get_callable_argument_names_positional_or_keyword(self): 29 def fn_positional_or_keyword_args_only(x, y): 30 return x + y 31 32 self.assertEqual( 33 ["x", "y"], 34 torch._jit_internal.get_callable_argument_names( 35 fn_positional_or_keyword_args_only 36 ), 37 ) 38 39 # Tests that POSITIONAL_ONLY arguments are ignored. 40 def test_get_callable_argument_names_positional_only(self): 41 code = dedent( 42 """ 43 def fn_positional_only_arg(x, /, y): 44 return x + y 45 """ 46 ) 47 48 fn_positional_only_arg = jit_utils._get_py3_code(code, "fn_positional_only_arg") 49 self.assertEqual( 50 ["y"], 51 torch._jit_internal.get_callable_argument_names(fn_positional_only_arg), 52 ) 53 54 # Tests that VAR_POSITIONAL arguments are ignored. 55 def test_get_callable_argument_names_var_positional(self): 56 # Tests that VAR_POSITIONAL arguments are ignored. 57 def fn_var_positional_arg(x, *arg): 58 return x + arg[0] 59 60 self.assertEqual( 61 ["x"], 62 torch._jit_internal.get_callable_argument_names(fn_var_positional_arg), 63 ) 64 65 # Tests that KEYWORD_ONLY arguments are ignored. 66 def test_get_callable_argument_names_keyword_only(self): 67 def fn_keyword_only_arg(x, *, y): 68 return x + y 69 70 self.assertEqual( 71 ["x"], torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg) 72 ) 73 74 # Tests that VAR_KEYWORD arguments are ignored. 75 def test_get_callable_argument_names_var_keyword(self): 76 def fn_var_keyword_arg(**args): 77 return args["x"] + args["y"] 78 79 self.assertEqual( 80 [], torch._jit_internal.get_callable_argument_names(fn_var_keyword_arg) 81 ) 82 83 # Tests that a function signature containing various different types of 84 # arguments are ignored. 85 def test_get_callable_argument_names_hybrid(self): 86 code = dedent( 87 """ 88 def fn_hybrid_args(x, /, y, *args, **kwargs): 89 return x + y + args[0] + kwargs['z'] 90 """ 91 ) 92 fn_hybrid_args = jit_utils._get_py3_code(code, "fn_hybrid_args") 93 self.assertEqual( 94 ["y"], torch._jit_internal.get_callable_argument_names(fn_hybrid_args) 95 ) 96 97 def test_checkscriptassertraisesregex(self): 98 def fn(): 99 tup = (1, 2) 100 return tup[2] 101 102 self.checkScriptRaisesRegex(fn, (), Exception, "range", name="fn") 103 104 s = dedent( 105 """ 106 def fn(): 107 tup = (1, 2) 108 return tup[2] 109 """ 110 ) 111 112 self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn") 113 114 def test_no_tracer_warn_context_manager(self): 115 torch._C._jit_set_tracer_state_warn(True) 116 with jit_utils.NoTracerWarnContextManager() as no_warn: 117 self.assertEqual(False, torch._C._jit_get_tracer_state_warn()) 118 self.assertEqual(True, torch._C._jit_get_tracer_state_warn()) 119