# Owner(s): ["oncall: jit"]

import os
import sys
import unittest
from typing import Tuple

import torch


# Make the helper files in test/ importable. This has to happen *before* the
# `jit.test_hooks_modules` import below, otherwise that import only works when
# test/ already happens to be on sys.path.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from jit.test_hooks_modules import (
    create_forward_tuple_input,
    create_module_forward_multiple_inputs,
    create_module_forward_single_input,
    create_module_hook_return_nothing,
    create_module_multiple_hooks_multiple_inputs,
    create_module_multiple_hooks_single_input,
    create_module_no_forward_input,
    create_module_same_hook_repeated,
    create_submodule_forward_multiple_inputs,
    create_submodule_forward_single_input,
    create_submodule_forward_single_input_return_not_tupled,
    create_submodule_hook_return_nothing,
    create_submodule_multiple_hooks_multiple_inputs,
    create_submodule_multiple_hooks_single_input,
    create_submodule_no_forward_input,
    create_submodule_same_hook_repeated,
    create_submodule_to_call_directly_with_hooks,
    ModuleDirectforwardSubmodCall,
    ModuleForwardSingleInput,
    ModuleForwardTupleInput,
)
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


# Tests for JIT forward hooks and pre-hooks
class TestHooks(JitTestCase):
    """Tests for TorchScript support of forward hooks and forward pre-hooks.

    The positive tests hand a module (built by a helper in
    ``jit.test_hooks_modules``) plus its call arguments to ``checkModule``.
    The negative tests register hooks with bad names or bad signatures and
    assert that ``torch.jit.script`` raises a ``RuntimeError`` whose message
    matches the expected text.
    """

    # ------------------------------------------------------------------
    # Positive cases: module-level hooks
    # ------------------------------------------------------------------

    def test_module_no_forward_input(self):
        self.checkModule(create_module_no_forward_input(), ())

    def test_submodule_no_forward_input(self):
        self.checkModule(create_submodule_no_forward_input(), ())

    def test_module_forward_multiple_inputs(self):
        self.checkModule(
            create_module_forward_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_module_multiple_hooks_multiple_inputs(self):
        self.checkModule(
            create_module_multiple_hooks_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_module_forward_single_input(self):
        self.checkModule(create_module_forward_single_input(), ("a",))

    def test_module_same_hook_repeated(self):
        self.checkModule(create_module_same_hook_repeated(), ("a",))

    def test_module_hook_return_nothing(self):
        self.checkModule(create_module_hook_return_nothing(), ("a",))

    def test_module_multiple_hooks_single_input(self):
        self.checkModule(create_module_multiple_hooks_single_input(), ("a",))

    # ------------------------------------------------------------------
    # Positive cases: submodule-level hooks
    # ------------------------------------------------------------------

    def test_submodule_forward_multiple_inputs(self):
        self.checkModule(
            create_submodule_forward_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_submodule_multiple_hooks_multiple_inputs(self):
        self.checkModule(
            create_submodule_multiple_hooks_multiple_inputs(),
            (["a"], "no_pre_hook"),
        )

    def test_submodule_forward_single_input(self):
        self.checkModule(create_submodule_forward_single_input(), ("a",))

    def test_submodule_called_directly_with_hooks(self):
        # Call the submodule itself (not through its parent) in both eager
        # and scripted form and compare the results.
        module = create_submodule_to_call_directly_with_hooks()
        module_scripted = torch.jit.script(module)

        submodule = module.submodule
        scripted_submodule = module_scripted.submodule

        self.assertEqual(submodule("a"), scripted_submodule("a"))

    def test_submodule_same_hook_repeated(self):
        self.checkModule(create_submodule_same_hook_repeated(), ("a",))

    def test_submodule_hook_return_nothing(self):
        self.checkModule(create_submodule_hook_return_nothing(), ("a",))

    def test_submodule_multiple_hooks_single_input(self):
        # Pass the args as a 1-tuple like every sibling test; the previous
        # `(["a"])` was just a parenthesized list, not a tuple.
        self.checkModule(create_submodule_multiple_hooks_single_input(), ("a",))

    def test_forward_tuple_input(self):
        self.checkModule(create_forward_tuple_input(), ((3,),))

    def test_submodule_forward_single_input_return_not_tupled(self):
        self.checkModule(
            create_submodule_forward_single_input_return_not_tupled(), ("a",)
        )

    # ------------------------------------------------------------------
    # Name collisions
    # ------------------------------------------------------------------

    def test_hook_method_name_collision(self):
        # Hooks can't have the same name as methods.
        # NOTE(review): this relies on the submodule of
        # ModuleForwardSingleInput already defining a method named `foo`.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def foo(self, input: Tuple[str]) -> Tuple[str]:
            assert self.name == "inner_mod_name"
            assert input[0] == "a_outermod"
            return ("pre_hook_override_name",)

        m.submodule.register_forward_pre_hook(foo)

        with self.assertRaisesRegex(
            RuntimeError,
            "Can't define hook: foo on class: .+ "
            "because a method or hook with that name already exists.",
        ):
            torch.jit.script(m)

    def test_hook_hook_name_collision(self):
        # Test edge case of two hooks sharing a name but not a python definition.
        # The hook bodies never compile: scripting fails on the name clash first.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def prehook(self, input: Tuple[str]) -> Tuple[str]:
            return "This is the first hook"

        m.submodule.register_forward_pre_hook(prehook)

        def prehook(self, input: Tuple[str]) -> Tuple[str]:
            return "This is the second hook"

        m.submodule.register_forward_pre_hook(prehook)

        with self.assertRaisesRegex(
            RuntimeError,
            "Pre-hook '.+' on .+ has at least two different python "
            "definitions. Please use unique names for all hooks.",
        ):
            torch.jit.script(m)

        # Same check for forward hooks.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def hook(self, input: Tuple[str], output: str):
            return "This is the first hook"

        m.submodule.register_forward_hook(hook)

        def hook(self, input: Tuple[str]):
            return "This is the second hook"

        m.submodule.register_forward_hook(hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "Hook '.+' on .+ has at least two different python "
            "definitions. Please use unique names for all hooks.",
        ):
            torch.jit.script(m)

    # ------------------------------------------------------------------
    # Direct `forward` invocation bypasses hooks
    # ------------------------------------------------------------------

    def test_module_direct_forward_invocation(self):
        # Test that hooks are only invoked when the module is
        # called directly and not when forward is called.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            return ("pre_hook_override_name",)

        def forward_hook(self, input: Tuple[str], output: str):
            assert self.name == "outer_mod_name"
            assert input == ("pre_hook_override_name",)
            output = output + "_fh"
            return output

        m.register_forward_pre_hook(pre_hook)
        m.register_forward_hook(forward_hook)

        m_scripted = torch.jit.script(m)

        self.assertEqual(m.forward("a"), m_scripted.forward("a"))
        # Calling the scripted module runs the hooks; `.forward` does not,
        # so the two results must differ.
        self.assertNotEqual(m_scripted("a"), m_scripted.forward("a"))

    def test_submodule_direct_forward_invocation(self):
        # One parent reaches its submodule via `submodule.forward(...)`
        # (ModuleDirectforwardSubmodCall), the other via `submodule(...)`.
        # Both submodules get identical hooks; the results must differ.
        m_submod_forward_call = ModuleDirectforwardSubmodCall(
            "outer_mod_name", "inner_mod_name"
        )
        m_submod_call = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            return ("pre_hook_override_name",)

        def forward_hook(self, input: Tuple[str], output: str):
            assert input == ("pre_hook_override_name",)
            return output + "_fh"

        m_submod_forward_call.submodule.register_forward_pre_hook(pre_hook)
        m_submod_forward_call.submodule.register_forward_hook(forward_hook)
        m_submod_call.submodule.register_forward_pre_hook(pre_hook)
        m_submod_call.submodule.register_forward_hook(forward_hook)

        m_submod_forward_call_scripted = torch.jit.script(m_submod_forward_call)
        m_submod_call_scripted = torch.jit.script(m_submod_call)

        self.assertEqual(
            m_submod_forward_call_scripted("a"), m_submod_forward_call("a")
        )
        self.assertNotEqual(
            m_submod_forward_call_scripted("a"), m_submod_call_scripted("a")
        )

    # ------------------------------------------------------------------
    # Error reporting for bad hook definitions
    # ------------------------------------------------------------------

    # TODO: add this test back once figured out how to print error msg
    @unittest.skip("hook compilation hint error message is not surfaced yet")
    def test_hook_compilation_hint(self):
        # Tests if hook error message is printed out if erroring after schema check.
        # Useful for when user is scripting hooks while not aware of it.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            assert self.name == "outer_mod_name"
            assert input[4] == "a"  # out of bounds tuple range
            return ("pre_hook_override_name",)

        m.register_forward_pre_hook(pre_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "This error occurred while scripting the forward pre-hook 'pre_hook'",
        ):
            torch.jit.script(m)

    def test_wrong_pre_hook_signatures(self):
        # correct signature: pre_hook_c(self, input: Tuple[str])
        # Each case below breaks that signature in exactly one way.

        # Wrong element type inside the input tuple.
        def pre_hook_wrong_input1(self, input: Tuple[None]) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input1)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong inner types for the input tuple argument",
        ):
            torch.jit.script(m)

        # Extra positional parameter.
        def pre_hook_wrong_input2(self, input: Tuple[str], input2: str) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input2)

        with self.assertRaisesRegex(
            RuntimeError,
            "was expected to only have exactly 2 inputs but it had 3 inputs",
        ):
            torch.jit.script(m)

        # Input not annotated as a Tuple.
        def pre_hook_wrong_input3(self, input: int) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input3)

        with self.assertRaisesRegex(
            RuntimeError,
            "expected the input argument to be typed as a Tuple but"
            " found type: 'int' instead",
        ):
            torch.jit.script(m)

        # Wrong return type.
        def pre_hook_wrong_output(self, input: Tuple[str]) -> int:
            return 1  # expecting Tuple[str], str, or None

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_output)

        with self.assertRaisesRegex(
            RuntimeError,
            "returned the wrong type of: 'int'",
        ):
            torch.jit.script(m)

        # Missing return annotation.
        def pre_hook_no_output_annotation(self, input: Tuple[str]):
            return 1  # expecting Tuple[str], str, or None

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_no_output_annotation)

        with self.assertRaisesRegex(
            RuntimeError,
            "is missing a return annotation. Return annotations"
            " are required, please add one.",
        ):
            torch.jit.script(m)

        # For a forward taking a single tuple argument, the pre-hook must
        # return that tuple wrapped in another tuple (or None).
        def pre_hook_wrong_tuple_return(self, input: Tuple[Tuple[int]]) -> Tuple[int]:
            return (11,)  # doesn't work with eager, inner tuple lost

        m = ModuleForwardTupleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_tuple_return)

        with self.assertRaisesRegex(
            RuntimeError,
            "When forward has a single tuple input argument, "
            "the return needs to be 'None' or a nested tuple containing "
            r"forward's input tuple argument as in: 'Tuple\[Tuple\[int\]\]'",
        ):
            torch.jit.script(m)

    def test_wrong_hook_signatures(self):
        # correct signature:
        #   def forward_hook(self, input: Tuple[str], output: str)
        # Each case below breaks that signature in exactly one way.

        # Input tuple has too many elements.
        def forward_hook_wrong_input1(self, input: Tuple[str, str], output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input1)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong number of contained types for the "
            r"input argument's Tuple. Received type: 'Tuple\[str, str\]'",
        ):
            torch.jit.script(m)

        # Input not annotated as a Tuple.
        def forward_hook_wrong_input2(self, input: str, output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input2)

        with self.assertRaisesRegex(
            RuntimeError,
            "expected the input argument to be typed as a Tuple "
            "but found type: 'str' instead.",
        ):
            torch.jit.script(m)

        # Wrong element type inside the input tuple.
        def forward_hook_wrong_input3(self, input: Tuple[None], output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input3)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong inner types for the input tuple"
            r" argument. Received type: 'Tuple\[NoneType\]'",
        ):
            torch.jit.script(m)

        # Output annotation does not match forward's return type.
        def forward_hook_wrong_output(self, input: Tuple[str], output: Tuple[str]):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_output)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong type for the output argument. Received"
            r" type: 'Tuple\[str\]'. Expected type: 'str'",
        ):
            torch.jit.script(m)

        # The first hook changes the output type to Tuple[str], so the second
        # hook's `output: str` annotation no longer matches what it receives.
        def forward_hook_correct(self, input: Tuple[str], output: str):
            return (output,)

        def forward_hook_wrong_output_from_prev_hook(
            self, input: Tuple[str], output: str
        ):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_correct)
        m.register_forward_hook(forward_hook_wrong_output_from_prev_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong type for the output argument. "
            r"Received type: 'str'. Expected type: 'Tuple\[str\]'",
        ):
            torch.jit.script(m)