xref: /aosp_15_r20/external/pytorch/test/jit/test_hooks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerimport unittest
6*da0073e9SAndroid Build Coastguard Workerfrom typing import Tuple
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport torch
9*da0073e9SAndroid Build Coastguard Workerfrom jit.test_hooks_modules import (
10*da0073e9SAndroid Build Coastguard Worker    create_forward_tuple_input,
11*da0073e9SAndroid Build Coastguard Worker    create_module_forward_multiple_inputs,
12*da0073e9SAndroid Build Coastguard Worker    create_module_forward_single_input,
13*da0073e9SAndroid Build Coastguard Worker    create_module_hook_return_nothing,
14*da0073e9SAndroid Build Coastguard Worker    create_module_multiple_hooks_multiple_inputs,
15*da0073e9SAndroid Build Coastguard Worker    create_module_multiple_hooks_single_input,
16*da0073e9SAndroid Build Coastguard Worker    create_module_no_forward_input,
17*da0073e9SAndroid Build Coastguard Worker    create_module_same_hook_repeated,
18*da0073e9SAndroid Build Coastguard Worker    create_submodule_forward_multiple_inputs,
19*da0073e9SAndroid Build Coastguard Worker    create_submodule_forward_single_input,
20*da0073e9SAndroid Build Coastguard Worker    create_submodule_forward_single_input_return_not_tupled,
21*da0073e9SAndroid Build Coastguard Worker    create_submodule_hook_return_nothing,
22*da0073e9SAndroid Build Coastguard Worker    create_submodule_multiple_hooks_multiple_inputs,
23*da0073e9SAndroid Build Coastguard Worker    create_submodule_multiple_hooks_single_input,
24*da0073e9SAndroid Build Coastguard Worker    create_submodule_no_forward_input,
25*da0073e9SAndroid Build Coastguard Worker    create_submodule_same_hook_repeated,
26*da0073e9SAndroid Build Coastguard Worker    create_submodule_to_call_directly_with_hooks,
27*da0073e9SAndroid Build Coastguard Worker    ModuleDirectforwardSubmodCall,
28*da0073e9SAndroid Build Coastguard Worker    ModuleForwardSingleInput,
29*da0073e9SAndroid Build Coastguard Worker    ModuleForwardTupleInput,
30*da0073e9SAndroid Build Coastguard Worker)
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker
# Make the helper files in test/ importable: append the grandparent directory
# of this file (test/jit/test_hooks.py -> test/) to the end of sys.path.
# NOTE(review): this runs *after* `from jit.test_hooks_modules import ...`
# above, so it cannot be what makes that import succeed. Presumably test/ is
# already on sys.path because this file is loaded through test/test_jit.py
# (see the __main__ guard below). Confirm before relying on this ordering.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
36*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # This module only defines test cases. They are meant to be collected and
    # run through the test/test_jit.py driver, never by executing this file.
    # NOTE(review): when executed directly, the `jit.test_hooks_modules` import
    # above may fail before this guard is reached (test/ is only appended to
    # sys.path after that import). Confirm.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n\tpython test/test_jit.py TESTNAME\n\ninstead."
    )
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker# Tests for JIT forward hooks and pre-hooks
class TestHooks(JitTestCase):
    """Tests for TorchScript support of module forward hooks and pre-hooks.

    Most tests build a module with hooks already attached, using a factory from
    ``jit.test_hooks_modules``, and hand it to ``checkModule`` together with the
    positional arguments (an args tuple) to call it with. The remaining tests
    cover direct ``forward`` invocation and the ``RuntimeError`` messages that
    ``torch.jit.script`` raises for invalid hook definitions.
    """

    def test_module_no_forward_input(self):
        self.checkModule(create_module_no_forward_input(), ())

    def test_submodule_no_forward_input(self):
        self.checkModule(create_submodule_no_forward_input(), ())

    def test_module_forward_multiple_inputs(self):
        self.checkModule(
            create_module_forward_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_module_multiple_hooks_multiple_inputs(self):
        self.checkModule(
            create_module_multiple_hooks_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_module_forward_single_input(self):
        self.checkModule(create_module_forward_single_input(), ("a",))

    def test_module_same_hook_repeated(self):
        self.checkModule(create_module_same_hook_repeated(), ("a",))

    def test_module_hook_return_nothing(self):
        self.checkModule(create_module_hook_return_nothing(), ("a",))

    def test_module_multiple_hooks_single_input(self):
        self.checkModule(create_module_multiple_hooks_single_input(), ("a",))

    def test_submodule_forward_multiple_inputs(self):
        self.checkModule(
            create_submodule_forward_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_submodule_multiple_hooks_multiple_inputs(self):
        self.checkModule(
            create_submodule_multiple_hooks_multiple_inputs(),
            (["a"], "no_pre_hook"),
        )

    def test_submodule_forward_single_input(self):
        self.checkModule(create_submodule_forward_single_input(), ("a",))

    def test_submodule_called_directly_with_hooks(self):
        module = create_submodule_to_call_directly_with_hooks()
        module_scripted = torch.jit.script(module)

        submodule = module.submodule
        scripted_submodule = module_scripted.submodule

        # Calling the submodule on its own (not through its parent) must give
        # the same result in eager and scripted form.
        self.assertEqual(submodule("a"), scripted_submodule("a"))

    def test_submodule_same_hook_repeated(self):
        self.checkModule(create_submodule_same_hook_repeated(), ("a",))

    def test_submodule_hook_return_nothing(self):
        self.checkModule(create_submodule_hook_return_nothing(), ("a",))

    def test_submodule_multiple_hooks_single_input(self):
        # Pass the arguments as a one-element tuple, like every other
        # single-input test here (the original `(["a"])` was just a list).
        self.checkModule(create_submodule_multiple_hooks_single_input(), ("a",))

    def test_forward_tuple_input(self):
        self.checkModule(create_forward_tuple_input(), ((3,),))

    def test_submodule_forward_single_input_return_not_tupled(self):
        self.checkModule(
            create_submodule_forward_single_input_return_not_tupled(), ("a",)
        )

    def test_hook_method_name_collision(self):
        # Hooks can't have the same name as methods.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def foo(self, input: Tuple[str]) -> Tuple[str]:
            assert self.name == "inner_mod_name"
            assert input[0] == "a_outermod"
            return ("pre_hook_override_name",)

        m.submodule.register_forward_pre_hook(foo)

        with self.assertRaisesRegex(
            RuntimeError,
            "Can't define hook: foo on class: .+ "
            "because a method or hook with that name already exists.",
        ):
            torch.jit.script(m)

    def test_hook_hook_name_collision(self):
        # Test edge case of two hooks sharing name but not python definition
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def prehook(self, input: Tuple[str]) -> Tuple[str]:
            return "This is the first hook"

        m.submodule.register_forward_pre_hook(prehook)

        # Intentional redefinition: same name, different Python function.
        def prehook(self, input: Tuple[str]) -> Tuple[str]:  # noqa: F811
            return "This is the second hook"

        m.submodule.register_forward_pre_hook(prehook)

        with self.assertRaisesRegex(
            RuntimeError,
            "Pre-hook '.+' on .+ has at least two different python "
            "definitions. Please use unique names for all hooks.",
        ):
            torch.jit.script(m)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def hook(self, input: Tuple[str], output: str):
            return "This is the first hook"

        m.submodule.register_forward_hook(hook)

        # Intentional redefinition: same name, different Python function.
        def hook(self, input: Tuple[str]):  # noqa: F811
            return "This is the second hook"

        m.submodule.register_forward_hook(hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "Hook '.+' on .+ has at least two different python "
            "definitions. Please use unique names for all hooks.",
        ):
            torch.jit.script(m)

    def test_module_direct_forward_invocation(self):
        # Test that hooks are only invoked when the module is
        # called directly and not when forward is called.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            return ("pre_hook_override_name",)

        def forward_hook(self, input: Tuple[str], output: str):
            assert self.name == "outer_mod_name"
            assert input == ("pre_hook_override_name",)
            output = output + "_fh"
            return output

        m.register_forward_pre_hook(pre_hook)
        m.register_forward_hook(forward_hook)

        m_scripted = torch.jit.script(m)

        # `.forward(...)` bypasses the hooks: eager and scripted must agree.
        self.assertEqual(m.forward("a"), m_scripted.forward("a"))
        # A direct call runs the hooks in both eager and scripted mode...
        self.assertEqual(m("a"), m_scripted("a"))
        # ...and so differs from calling `forward` directly.
        self.assertNotEqual(m_scripted("a"), m_scripted.forward("a"))

    def test_submodule_direct_forward_invocation(self):
        # NOTE(review): judging by its name, ModuleDirectforwardSubmodCall
        # calls `submodule.forward(...)` instead of `submodule(...)`, so the
        # submodule hooks should not fire there. Confirm in test_hooks_modules.
        m_submod_forward_call = ModuleDirectforwardSubmodCall(
            "outer_mod_name", "inner_mod_name"
        )
        m_submod_call = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            return ("pre_hook_override_name",)

        def forward_hook(self, input: Tuple[str], output: str):
            assert input == ("pre_hook_override_name",)
            return output + "_fh"

        m_submod_forward_call.submodule.register_forward_pre_hook(pre_hook)
        m_submod_forward_call.submodule.register_forward_hook(forward_hook)
        m_submod_call.submodule.register_forward_pre_hook(pre_hook)
        m_submod_call.submodule.register_forward_hook(forward_hook)

        m_submod_forward_call_scripted = torch.jit.script(m_submod_forward_call)
        m_submod_call_scripted = torch.jit.script(m_submod_call)

        self.assertEqual(
            m_submod_forward_call_scripted("a"), m_submod_forward_call("a")
        )
        self.assertNotEqual(
            m_submod_forward_call_scripted("a"), m_submod_call_scripted("a")
        )

    # TODO: add this test back once figured out how to print error msg
    @unittest.skip(
        "TODO: re-enable once the hook compilation hint is printed in the error message"
    )
    def test_hook_compilation_hint(self):
        # Tests if hook error message is printed out if erroring after schema check.
        # Useful for when user is scripting hooks while not aware of it.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            assert self.name == "outer_mod_name"
            assert input[4] == "a"  # out of bounds tuple range
            return ("pre_hook_override_name",)

        m.register_forward_pre_hook(pre_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "This error occurred while scripting the forward pre-hook 'pre_hook'",
        ):
            torch.jit.script(m)

    def test_wrong_pre_hook_signatures(self):
        # correct signature: pre_hook_c(self, input: Tuple[str])
        def pre_hook_wrong_input1(self, input: Tuple[None]) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input1)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong inner types for the input tuple argument",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_input2(self, input: Tuple[str], input2: str) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input2)

        with self.assertRaisesRegex(
            RuntimeError,
            "was expected to only have exactly 2 inputs but it had 3 inputs",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_input3(self, input: int) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input3)

        with self.assertRaisesRegex(
            RuntimeError,
            "expected the input argument to be typed as a Tuple but"
            " found type: 'int' instead",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_output(self, input: Tuple[str]) -> int:
            return 1  # expecting Tuple[str], str, or None

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_output)

        with self.assertRaisesRegex(
            RuntimeError,
            "returned the wrong type of: 'int'",
        ):
            torch.jit.script(m)

        def pre_hook_no_output_annotation(self, input: Tuple[str]):
            return 1  # expecting Tuple[str], str, or None

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_no_output_annotation)

        with self.assertRaisesRegex(
            RuntimeError,
            "is missing a return annotation. Return annotations"
            " are required, please add one.",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_tuple_return(self, input: Tuple[Tuple[int]]) -> Tuple[int]:
            return (11,)  # doesn't work with eager, inner tuple lost

        m = ModuleForwardTupleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_tuple_return)

        with self.assertRaisesRegex(
            RuntimeError,
            "When forward has a single tuple input argument, "
            "the return needs to be 'None' or a nested tuple containing "
            r"forward's input tuple argument as in: 'Tuple\[Tuple\[int\]\]'",
        ):
            torch.jit.script(m)

    def test_wrong_hook_signatures(self):
        # correct signature:
        #   def forward_hook(self, input: Tuple[str], output: str)
        def forward_hook_wrong_input1(self, input: Tuple[str, str], output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input1)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong number of contained types for the "
            r"input argument's Tuple. Received type: 'Tuple\[str, str\]'",
        ):
            torch.jit.script(m)

        def forward_hook_wrong_input2(self, input: str, output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input2)

        with self.assertRaisesRegex(
            RuntimeError,
            "expected the input argument to be typed as a Tuple "
            "but found type: 'str' instead.",
        ):
            torch.jit.script(m)

        def forward_hook_wrong_input3(self, input: Tuple[None], output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input3)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong inner types for the input tuple"
            r" argument. Received type: 'Tuple\[NoneType\]'",
        ):
            torch.jit.script(m)

        def forward_hook_wrong_output(self, input: Tuple[str], output: Tuple[str]):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_output)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong type for the output argument. Received"
            r" type: 'Tuple\[str\]'. Expected type: 'str'",
        ):
            torch.jit.script(m)

        # The first hook changes the output type to Tuple[str]; the second
        # hook still declares `output: str`, which must be rejected.
        def forward_hook_correct(self, input: Tuple[str], output: str):
            return (output,)

        def forward_hook_wrong_output_from_prev_hook(
            self, input: Tuple[str], output: str
        ):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_correct)
        m.register_forward_hook(forward_hook_wrong_output_from_prev_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong type for the output argument. "
            r"Received type: 'str'. Expected type: 'Tuple\[str\]'",
        ):
            torch.jit.script(m)
396