# Owner(s): ["module: nn"]
import gc
import math
import pickle
import unittest
import warnings
import weakref
from collections import namedtuple, OrderedDict
from copy import deepcopy
from functools import partial
from tempfile import NamedTemporaryFile
from typing import Any, Dict, List, Tuple

import torch
import torch.nn as nn
from torch.testing._internal.common_nn import _create_basic_net, NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_WINDOWS,
    parametrize as parametrize_test,
    run_tests,
    skipIfTorchDynamo,
    swap,
    TestCase,
)


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.seq1 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)])
        self.seq2 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.seq2(self.seq1(x))


ToyNamedTuple = namedtuple("ToyNamedTuple", "content")


class ToyModel(nn.Module):
    def __init__(self, with_named_tuple=False) -> None:
        super().__init__()
        self.net1 = Net()
        self.net2 = Net()
        self.with_named_tuple = with_named_tuple

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = self.net2(self.net1(x))
        if self.with_named_tuple:
            return ToyNamedTuple(res)
        else:
            return (res,)


def forward_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    inp: Tuple[torch.Tensor],
    out: torch.Tensor,
) -> None:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
    self.assertEqual(len(inp), 1)


def forward_pre_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    inp: Tuple[torch.Tensor],
) -> None:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
    self.assertEqual(len(inp), 1)


def full_backward_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    grad_input: Tuple[torch.Tensor],
    grad_output: Tuple[torch.Tensor],
) -> None:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
    self.assertEqual(len(grad_input), 1)
    self.assertEqual(len(grad_output), 1)


def full_backward_pre_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    grad_input: Tuple[torch.Tensor],
) -> None:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
    self.assertEqual(len(grad_input), 1)


class KwargModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = Net()
        self.net2 = Net()

    def forward(self, x: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor:
        if bias is not None:
            x = x + bias
        return x

    def internal_forward_hook(
        self,
        module: nn.Module,
        args: Tuple[torch.Tensor],
        kwargs: Dict[str, Any],
        out: torch.Tensor,
    ):
        return out + kwargs["bias"]


class FailsInForwardModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = Net()

    def forward(self, x: torch.Tensor, fail: bool = True) -> torch.Tensor:
        if fail:
            raise RuntimeError("failing in forward")
        return self.net1(x)


def kwarg_forward_pre_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    args: Tuple[torch.Tensor],
    kwargs: Dict[str, Any],
) -> Tuple[Any, Any]:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
    self.assertEqual(len(args), 1)
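    # double the bias kwarg so callers can verify that the modified kwargs were used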
    kwargs["bias"] = 2 * kwargs["bias"]
    return args, kwargs


def kwarg_forward_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    args: Tuple[torch.Tensor],
    kwargs: Dict[str, Any],
    out: torch.Tensor,
) -> Any:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
    self.assertEqual(len(args), 1)

    out = out + kwargs["bias"]
    return out


class DummyContextManager:
    def __init__(self, inp):
        self.input = inp

    def __enter__(self, *args, **kwargs):
        self.input.append(2)

    def __exit__(self, *args, **kwargs):
        self.input.append(-1)


class TestModuleHooks(TestCase):
    @parametrize_test("named_tuple", (True, False))
    def test_forward_hooks(self, named_tuple):
        fired_hooks: List[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        hook = partial(forward_hook, self, fired_hooks, model.net1.seq2)
        model.net1.seq2.register_forward_hook(partial(hook, 0))
        model.net1.seq2.register_forward_hook(partial(hook, 1), prepend=True)
        model.net1.seq2.register_forward_hook(partial(hook, 2))
        model.net1.seq2.register_forward_hook(partial(hook, 3))
        model.net1.seq2.register_forward_hook(partial(hook, 4), prepend=True)
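        # prepended hooks fire first, most recently prepended first (4, then 1),
        # followed by the appended hooks in registration order (0, 2, 3)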
        expected = [4, 1, 0, 2, 3]

        self.assertEqual(fired_hooks, [])
        out = model(x)
        self.assertEqual(fired_hooks, expected)
        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
        out[0].sum().backward()
        self.assertEqual(fired_hooks, expected)
        model(x)[0].sum().backward()
        self.assertEqual(fired_hooks, expected + expected)

    @parametrize_test("named_tuple", (True, False))
    def test_forward_pre_hooks(self, named_tuple):
        fired_hooks: List[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        hook = partial(forward_pre_hook, self, fired_hooks, model.net2.seq1)
        model.net2.seq1.register_forward_pre_hook(partial(hook, 0), prepend=True)
        model.net2.seq1.register_forward_pre_hook(partial(hook, 1))
        model.net2.seq1.register_forward_pre_hook(partial(hook, 2))
        model.net2.seq1.register_forward_pre_hook(partial(hook, 3))
        model.net2.seq1.register_forward_pre_hook(partial(hook, 4), prepend=True)
        expected = [4, 0, 1, 2, 3]

        self.assertEqual(fired_hooks, [])
        out = model(x)
        self.assertEqual(fired_hooks, expected)
        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
        out[0].sum().backward()
        self.assertEqual(fired_hooks, expected)
        model(x)[0].sum().backward()
        self.assertEqual(fired_hooks, expected + expected)

    @parametrize_test("named_tuple", (True, False))
    def test_full_backward_hooks(self, named_tuple):
        fired_hooks: List[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        hook = partial(full_backward_hook, self, fired_hooks, model.net1)
        model.net1.register_full_backward_hook(partial(hook, 0))
        model.net1.register_full_backward_hook(partial(hook, 1))
        model.net1.register_full_backward_hook(partial(hook, 2))
        model.net1.register_full_backward_hook(partial(hook, 3), prepend=True)
        model.net1.register_full_backward_hook(partial(hook, 4), prepend=True)
        expected = [4, 3, 0, 1, 2]

        self.assertEqual(fired_hooks, [])
        out = model(x)
        self.assertEqual(fired_hooks, [])
        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
        out[0].sum().backward()
        self.assertEqual(fired_hooks, expected)
        model(x)[0].sum().backward()
        self.assertEqual(fired_hooks, expected + expected)

    @parametrize_test("named_tuple", (True, False))
    def test_full_backward_pre_hooks(self, named_tuple):
        fired_hooks: List[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        hook = partial(full_backward_pre_hook, self, fired_hooks, model.net1)
        model.net1.register_full_backward_pre_hook(partial(hook, 0), prepend=True)
        model.net1.register_full_backward_pre_hook(partial(hook, 1), prepend=True)
        model.net1.register_full_backward_pre_hook(partial(hook, 2))
        model.net1.register_full_backward_pre_hook(partial(hook, 3))
        model.net1.register_full_backward_pre_hook(partial(hook, 4))
        expected = [1, 0, 2, 3, 4]

        self.assertEqual(fired_hooks, [])
        out = model(x)
        self.assertEqual(fired_hooks, [])
        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
        out[0].sum().backward()
        self.assertEqual(fired_hooks, expected)
        model(x)[0].sum().backward()
        self.assertEqual(fired_hooks, expected + expected)

        # Backward pre hook can affect subsequent gradient computation
        for rg in [True, False]:
            a = torch.ones(2, requires_grad=rg)
            model = nn.Linear(2, 2)

            def fn(_unused_module, grad_output):
                return (grad_output[0] * 0,)

            model.register_full_backward_pre_hook(fn)

            out = model(a)
            out.sum().backward()
            self.assertEqual(model.weight.grad, torch.zeros(2, 2))
            if rg:
                self.assertEqual(a.grad, torch.zeros_like(a))
            else:
                self.assertIsNone(a.grad)

    @parametrize_test("named_tuple", (True, False))
    def test_mixed_hooks(self, named_tuple):
        fired_hooks: List[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        model.register_forward_pre_hook(
            partial(forward_pre_hook, self, fired_hooks, model, 0)
        )
        model.register_forward_hook(partial(forward_hook, self, fired_hooks, model, 1))
        model.register_full_backward_pre_hook(
            partial(full_backward_pre_hook, self, fired_hooks, model, 2)
        )
        model.register_full_backward_hook(
            partial(full_backward_hook, self, fired_hooks, model, 3)
        )

        self.assertEqual(fired_hooks, [])
        out = model(x)
        self.assertEqual(fired_hooks, [0, 1])
        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
        out[0].sum().backward()
        self.assertEqual(fired_hooks, [0, 1, 2, 3])
        model(x)[0].sum().backward()
        self.assertEqual(fired_hooks, [0, 1, 2, 3, 0, 1, 2, 3])

    def test_kwarg_hooks(self):
        # 1. test forward pre hook
        fired_hooks: List[int] = []
        x: torch.Tensor = torch.ones(10, 10)
        bias: torch.Tensor = torch.ones(10, 10)
        model = KwargModel()
        model.register_forward_pre_hook(
            partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0),
            with_kwargs=True,
        )

        # forward-pre: bias' = bias * 2
        # So, out = x + bias * 2
        self.assertEqual(fired_hooks, [])
        out = model(x, bias=bias)
        self.assertEqual(fired_hooks, [0])
        self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5)

        # 2. test forward pre and forward hooks
        fired_hooks: List[int] = []
        x: torch.Tensor = torch.ones(10, 10)
        bias: torch.Tensor = torch.ones(10, 10)
        model = KwargModel()
        model.register_forward_hook(
            partial(kwarg_forward_hook, self, fired_hooks, model, 1),
            with_kwargs=True,
        )
        model.register_forward_pre_hook(
            partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0),
            with_kwargs=True,
        )

        # forward-pre: bias' = bias * 2
        # forward: out = x + bias'
        # forward-post: out = out + bias'
        # So, out = x + bias * 4
        self.assertEqual(fired_hooks, [])
        out = model(x, bias=bias)
        self.assertEqual(fired_hooks, [0, 1])
        self.assertEqual(out, x + 4 * bias, rtol=0, atol=1e-5)

        # 3. test nn.Module member method as forward-post hook
        x: torch.Tensor = torch.ones(10, 10)
        bias: torch.Tensor = torch.ones(10, 10)
        model = KwargModel()
        model.register_forward_hook(model.internal_forward_hook, with_kwargs=True)

        # forward: out = x + bias
        # forward-post: out = out + bias
        # So, out = x + bias * 2
        out = model(x, bias=bias)
        self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5)

    def test_remove_kwarg_hooks(self):
        # test forward pre and forward hooks
        fired_hooks: List[int] = []
        x: torch.Tensor = torch.ones(10, 10)
        bias: torch.Tensor = torch.ones(10, 10)
        model = KwargModel()
        forward_hook_handle = model.register_forward_hook(
            partial(kwarg_forward_hook, self, fired_hooks, model, 1),
            with_kwargs=True,
        )
        forward_pre_hook_handle = model.register_forward_pre_hook(
            partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0),
            with_kwargs=True,
        )

        # forward-pre: bias' = bias * 2
        # forward: out = x + bias'
        # forward-post: out = out + bias'
        # So, out = x + bias * 4
        self.assertEqual(fired_hooks, [])
        out = model(x, bias=bias)
        self.assertEqual(fired_hooks, [0, 1])
        self.assertEqual(out, x + 4 * bias, rtol=0, atol=1e-5)

        # forward-pre: bias' = bias * 2
        # forward: out = x + bias'
        # So, out = x + bias * 2
        forward_hook_handle.remove()
        out = model(x, bias=bias)
        self.assertEqual(fired_hooks, [0, 1, 0])
        self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5)
        self.assertFalse(forward_hook_handle.id in model._forward_hooks_with_kwargs)

        # forward: out = x + bias
        # So, out = x + bias
        forward_pre_hook_handle.remove()
        out = model(x, bias=bias)
        self.assertEqual(fired_hooks, [0, 1, 0])
        self.assertEqual(out, x + bias, rtol=0, atol=1e-5)
        self.assertFalse(
            forward_pre_hook_handle.id in model._forward_pre_hooks_with_kwargs
        )

    def test_always_called_forward_hooks(self):
        x: torch.Tensor = torch.ones(10, 10)
        model = FailsInForwardModel()
        stack = []
        ctx = None
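        # DummyContextManager appends 2 on __enter__ and -1 on __exit__, so a
        # balanced [2, -1] pair in `stack` means the context was exited properly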

        def setup_context():
            nonlocal ctx
            ctx = DummyContextManager(stack)

        def ctx_setup_hook(m, i):
            setup_context()
            ctx.__enter__()

        def ctx_setup_failure_hook(m, i):
            setup_context()
            ctx.__enter__()
            raise RuntimeError("failing in ctx setup")

        def ctx_shutdown_hook(m, i, o):
            ctx.__exit__()

        def ctx_shutdown_failure_hook(m, i, o):
            ctx.__exit__()
            raise RuntimeError("failing in ctx shutdown")

        def throw_hook(m, i, o):
            raise RuntimeError("failing in throw")

        forward_pre_hook_handle = model.register_forward_pre_hook(ctx_setup_hook)
        forward_hook_handle = model.register_forward_hook(
            ctx_shutdown_hook, always_call=True
        )
        self.assertTrue(len(model._forward_hooks_always_called) == 1)

        # make sure always_called forward hook runs when model.forward raises RuntimeError
        with self.assertRaisesRegex(RuntimeError, "failing in forward"):
            model(x)
        self.assertEqual(stack, [2, -1])

        # make sure that always_called forward hook does not run twice if there is no error
        model(x, fail=False)
        self.assertEqual(stack, [2, -1, 2, -1])

        # make sure always_called forward hook runs when forward pre hook raises RuntimeError
        forward_pre_hook_handle.remove()
        model.register_forward_pre_hook(ctx_setup_failure_hook)

        with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
            model(x, fail=False)
        self.assertEqual(stack, [2, -1, 2, -1, 2, -1])

        # make sure always_called hook runs when another always_called forward hook raises an error
        forward_hook_handle2 = model.register_forward_hook(
            throw_hook, prepend=True, always_call=True
        )

        # the error raised should be the one from the pre-hook, not from the always_call hook
        with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
            model(x, fail=False)
        self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1])

        # make sure that always called forward hooks are properly removed
        forward_hook_handle.remove()
        forward_hook_handle2.remove()
        self.assertTrue(len(model._forward_hooks_always_called) == 0)

        # make sure that always called forward hook is not run twice if it fails while running
        forward_hook_handle3 = model.register_forward_hook(
            ctx_shutdown_failure_hook, always_call=True
        )
        with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
            model(x, fail=False)
        self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1, 2, -1])

        forward_hook_handle3.remove()

        global_forward_hook_handle = nn.modules.module.register_module_forward_hook(
            ctx_shutdown_hook, always_call=True
        )
        self.assertTrue(len(nn.modules.module._global_forward_hooks_always_called) == 1)
        # make sure global forward hook runs when forward pre hook raises RuntimeError
        with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
            model(x, fail=False)
        self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1, 2, -1, 2, -1])

        # make sure the always_call global forward hook is properly removed
        global_forward_hook_handle.remove()
        self.assertTrue(len(nn.modules.module._global_forward_hooks_always_called) == 0)
        with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
            model(x)
        self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1, 2, -1, 2, -1, 2])

    def test_bw_hook_warning_for_non_tensor_or_tuple(self):
        # Verify that full backward hooks emit a warning and do not fire
        # when the forward result is not a Tensor or a tuple of Tensors.
        counter = {"forward": 0, "backward": 0}

        def fw_pre_hook(module: nn.Module, _inputs):
            counter["forward"] += 1

        def fw_hook(module: nn.Module, _inputs, _outputs):
            counter["forward"] += 1

        def bw_hook(module: nn.Module, _inputs, _outputs):
            counter["backward"] += 1

        class TestModule(nn.Module):
            def forward(self, dict):
                inp = dict["x"]
                x = torch.nn.functional.softmax(inp, dim=0)
                return {"x": x}

        x = torch.ones(2, requires_grad=True)
        model = TestModule()
        model.register_forward_pre_hook(fw_pre_hook)
        model.register_forward_hook(fw_hook)
        model.register_full_backward_pre_hook(bw_hook)
        model.register_full_backward_hook(bw_hook)

        with warnings.catch_warnings(record=True) as w:
            y = model({"x": x})["x"]
            loss = y.sum()
            loss.backward()

        self.assertEqual(counter["forward"], 2)
        self.assertEqual(counter["backward"], 0)
        self.assertEqual(len(w), 1)
        self.assertTrue("should be a Tensor or a tuple of Tensors" in str(w[0].message))


def _hook_to_pickle(*args, **kwargs):
    pass


class TestStateDictHooks(TestCase):
    @swap([True, False])
    def test_load_state_dict_pre_hook(self):
        m = nn.Linear(10, 10)
        m_state_dict = m.state_dict()

        m_load = nn.Linear(10, 10)

        hook_called = 0

        def hook_without_module(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        ):
            self.assertEqual(m_state_dict, state_dict)
            nonlocal hook_called
            hook_called += 1

        def hook_with_module(
            module,
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        ):
            self.assertEqual(m_state_dict, state_dict)
            self.assertTrue(m_load is module)
            nonlocal hook_called
            hook_called += 1

        hook_called = 0
        # Test private API since this sets with_module=False which diverges from public API
        m_load._register_load_state_dict_pre_hook(hook_without_module)
        m_load.load_state_dict(m_state_dict)
        self.assertEqual(1, hook_called)

        hook_called = 0
        m_load.register_load_state_dict_pre_hook(hook_with_module)
        m_load.load_state_dict(m_state_dict)
        self.assertEqual(2, hook_called)

        # Test private API with with_module=True
        hook_called = 0
        m_load._register_load_state_dict_pre_hook(hook_with_module, True)
        m_load.load_state_dict(m_state_dict)
        self.assertEqual(3, hook_called)

    def test_no_extra_ref_to_module(self):
        try:
            gc.disable()
            m = nn.Linear(10, 10)

            m.register_load_state_dict_pre_hook(_hook_to_pickle)
            weak_m = weakref.ref(m)
            del m

            self.assertEqual(weak_m(), None)
        finally:
            gc.enable()

    def test_pickled_hook(self):
        m = nn.Linear(10, 10)
        m.register_load_state_dict_pre_hook(_hook_to_pickle)
        pickle.loads(pickle.dumps(m))

    @swap([True, False])
    def test_load_state_dict_module_pre_hook(self):
        hook_called = 0

        # Test with module instance method as hook
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Parameter(torch.rand(10))

            def my_pre_load_hook(
                self,
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            ):
                assert [] == error_msgs
                assert [] == unexpected_keys
                assert [] == missing_keys
                assert strict
                nonlocal hook_called
                hook_called += 1

            def my_pre_load_hook_with_module(
                self,
                module,
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            ):
                assert [] == error_msgs
                assert [] == unexpected_keys
                assert [] == missing_keys
                assert strict
                assert self is module
                nonlocal hook_called
                hook_called += 1

        # Test that hooks registered on a submodule are also called
        # appropriately, i.e. with the submodule as module argument in
        # my_pre_load_hook_with_module.
        class MyModuleContainer(nn.Module):
            def __init__(self, mod):
                super().__init__()
                self.mod = mod

        for ctor in [MyModuleContainer, lambda x: x]:
            m = ctor(MyModule())
            state_dict = m.state_dict()
            if isinstance(m, MyModuleContainer):
                mod = m.mod
            else:
                mod = m

            hook_called = 0
            # Test private API since this sets with_module=False which diverges from public API
            mod._register_load_state_dict_pre_hook(mod.my_pre_load_hook)
            m.load_state_dict(state_dict)
            self.assertEqual(1, hook_called)

            hook_called = 0
            mod.register_load_state_dict_pre_hook(mod.my_pre_load_hook_with_module)
            m.load_state_dict(state_dict)
            self.assertEqual(2, hook_called)

    @swap([True, False])
    def test_load_state_dict_post_hook(self):
        hook_called = 0

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Parameter(torch.rand(10))

            def my_post_load_hook(self, module, incompatible_keys):
                assert module is self
                nonlocal hook_called
                incompatible_keys.missing_keys.append("foo")
                incompatible_keys.unexpected_keys.append("bar")
                hook_called += 1

        nested = MyModule()
        wrapped = nn.ModuleList([nested])
        handle = nested.register_load_state_dict_post_hook(
            nested.my_post_load_hook,
        )
        # Hook must be called even if it is wrapped
        ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False)
        self.assertEqual(hook_called, 1)
        # Ensure that the hook modified missing_keys and unexpected_keys
        missing = ret.missing_keys
        unexpected = ret.unexpected_keys
        self.assertEqual(missing, ["foo"])
        self.assertEqual(unexpected, ["bar"])
        # When called with strict=True, the error raised should mention the
        # missing and unexpected keys the hook added.
        with self.assertRaisesRegex(RuntimeError, "foo.*\n.*bar"):
            wrapped.load_state_dict(wrapped.state_dict(), strict=True)
        self.assertEqual(hook_called, 2)
        # Removing the hook via handle.remove() should cause it not to
        # fire anymore.
        handle.remove()
        # Hook did not run so it should not have added any keys
        ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False)
        self.assertEqual(ret.missing_keys, [])
        self.assertEqual(ret.unexpected_keys, [])
        # hook_called should not have been incremented
        self.assertEqual(hook_called, 2)

        def load_hook_clear_incompatible(module, incompatible_keys):
            incompatible_keys.missing_keys.clear()
            incompatible_keys.unexpected_keys.clear()

        nested.register_load_state_dict_post_hook(load_hook_clear_incompatible)
        state_dict = wrapped.state_dict()
        state_dict["extra"] = torch.ones(1)
        # load state_dict with strict=True should not throw.
        ret = wrapped.load_state_dict(state_dict, strict=True)
        # explicitly ensure that the post hook cleared out incompatible_keys
        self.assertEqual([], ret.missing_keys)
        self.assertEqual([], ret.unexpected_keys)

    @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
    @swap([True, False])
    def test_load_state_dict_post_hook_backward_compatibility(self):
        def my_post_load_hook(mod, _):
            nonlocal called
            called = True

        for m in [nn.Softmin(10), nn.Softmax(10), nn.LogSoftmax(10)]:
            called = False
            sd = deepcopy(m.state_dict())
            self.assertTrue(hasattr(m, "_load_state_dict_post_hooks"))
            # Simulate an older model that did not have this attr
            delattr(m, "_load_state_dict_post_hooks")
            # Save and load, and ensure that load_state_dict works (without proper
            # BC we would run into errors because this attribute would be expected).
            # In particular, Softmax runs into the issue described here:
            # https://github.com/pytorch/pytorch/issues/77280
            with NamedTemporaryFile() as f:
                # Note that torch.save / torch.load is not recommended to save/load
                # modules.
                torch.save(m, f.name)
                # weights_only=False as this is legacy code that saves the model
                m = torch.load(f.name, weights_only=False)
                m.load_state_dict(sd)
                self.assertFalse(called)

            # Ensure hooks can be registered and called.
            m.register_load_state_dict_post_hook(my_post_load_hook)
            m.load_state_dict(sd)
            self.assertTrue(called)

    def _test_register_state_dict_pre_hook(self, model, submodule):
        _state_dict_prefix = "foo."
        state_dict_pre_hook_count = 0
        keep_var_setting = False

        def my_state_dict_pre_hook(module, prefix, keep_vars):
            self.assertEqual(keep_vars, keep_var_setting)
            nonlocal state_dict_pre_hook_count
            state_dict_pre_hook_count += 1
            self.assertTrue(prefix.startswith(_state_dict_prefix))

        model.register_state_dict_pre_hook(my_state_dict_pre_hook)
        # Test to ensure submodules run the hook as well.
        submodule.register_state_dict_pre_hook(my_state_dict_pre_hook)

        def check_results(model):
            nonlocal state_dict_pre_hook_count, keep_var_setting
            for keep_var_setting in [True, False]:
                _ = model.state_dict(
                    prefix=_state_dict_prefix, keep_vars=keep_var_setting
                )
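                # one firing for the model-level hook and one for the submodule hook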
                self.assertEqual(2, state_dict_pre_hook_count)
                state_dict_pre_hook_count = 0

        # Test state dict works as expected after model construction
        check_results(model)
        # Test state dict works as expected after forward
        model(torch.ones(10, 3))
        check_results(model)

    def test_register_state_dict_pre_hook(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Sequential(
                    nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)
                )

            def forward(self, x):
                return self.a(x)

        mod = MyModule()
        self._test_register_state_dict_pre_hook(mod, mod.a)

    def test_register_state_dict_pre_hook_lazy_module(self):
        class MyLazyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer1 = nn.LazyLinear(8)
                self.layer2 = nn.LazyLinear(5)

            def forward(self, x):
                return self.layer2(self.layer1(x))

        mod = MyLazyModule()
        self._test_register_state_dict_pre_hook(mod, mod.layer1)

    @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
    def test_register_state_dict_pre_hook_backward_compat(self):
        called = False

        def my_state_dict_pre_hook(*args, **kwargs):
            nonlocal called
            called = True

        m = nn.Linear(1, 1)
        self.assertTrue(hasattr(m, "_state_dict_pre_hooks"))
        delattr(m, "_state_dict_pre_hooks")
        # Save and load, ensure we can still call state_dict
        # without running into issues.
        with NamedTemporaryFile() as f:
            # Note that torch.save / torch.load is not recommended
            # to save / load modules.
            torch.save(m, f.name)
            # weights_only=False as this is legacy code that saves the model
            m = torch.load(f.name, weights_only=False)

        # Ensure we can run state_dict without issues
        _ = m.state_dict()
        self.assertFalse(called)
        m.register_state_dict_pre_hook(my_state_dict_pre_hook)
        _ = m.state_dict()
        self.assertTrue(called)

    @parametrize_test("private", [True, False])
    def test_register_state_dict_post_hook(self, private):
        m = nn.Transformer(
            d_model=4, nhead=2, num_encoder_layers=2, num_decoder_layers=2
        )

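        # Re-wrap this module's own parameters in the returned state dict as
        # nn.Parameter, so _check_sd below can tell which submodules ran the hook.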
        def linear_state_dict_post_hook(module, state_dict, prefix, local_metadata):
            for name, param in module.named_parameters(recurse=False):
                state_dict[prefix + name] = torch.nn.Parameter(
                    state_dict[prefix + name]
                )

        def register_linear_hook(module):
            if isinstance(module, nn.Linear):
                hook_registration_fn = (
                    module._register_state_dict_hook
                    if private
                    else module.register_state_dict_post_hook
                )
                hook_registration_fn(linear_state_dict_post_hook)

        def _check_sd(state_dict):
            for k, v in state_dict.items():
                if "linear" in k or "out_proj" in k:
                    self.assertTrue(isinstance(v, torch.nn.Parameter))
                else:
                    self.assertFalse(isinstance(v, torch.nn.Parameter))

        # verify that return type of hook registered on child submodules has no effect
        # regardless of whether using public or private API
        m.apply(register_linear_hook)
        _check_sd(m.state_dict())

        # verify that the return value of a hook registered on the root module is
        # rejected by the public API but honored by the private API
        hook_registration_fn = (
            m._register_state_dict_hook if private else m.register_state_dict_post_hook
        )

        def fn(m, s, p, l):
            return OrderedDict()

        handle = hook_registration_fn(fn)
        if private:
            self.assertFalse(hasattr(fn, "_from_public_api"))
            self.assertTrue(len(m.state_dict()) == 0)
        else:
            self.assertTrue(hasattr(fn, "_from_public_api"))
            with self.assertRaisesRegex(
                RuntimeError, "state_dict post-hook must return None"
            ):
                sd = m.state_dict()
            with self.assertRaisesRegex(
                RuntimeError, "previously registered via register_state_dict_post_hook"
            ):
                m._register_state_dict_hook(fn)


class TestModuleGlobalHooks(TestCase):
    def tearDown(self):
        nn.modules.module._global_backward_hooks = OrderedDict()
        nn.modules.module._global_forward_hooks = OrderedDict()
        nn.modules.module._global_forward_pre_hooks = OrderedDict()

    @skipIfTorchDynamo("TorchDynamo does not work well with hooks")
    def test_module_global_hooks(self):
        module = nn.Sigmoid

        module_1 = module()
        module_2 = module()
        module_3 = module()

        input = torch.ones(5, 5, requires_grad=True)

        counter = {"forwards": 0, "backwards": 0}

        def fw_hook(inc, h_module, input, output):
            self.assertIsInstance(input, tuple)
            self.assertTrue(isinstance(output, torch.Tensor))
            self.assertTrue(isinstance(h_module, module))
            self.assertEqual(input[0], torch.ones(5, 5))
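            # sigmoid(1) == 1 / (1 + e**-1) == 1 / (1 + 1 / e)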
            self.assertEqual(output, torch.empty(5, 5).fill_(1 / (1 + 1 / math.e)))
            counter["forwards"] += inc

        def bw_hook(inc, h_module, grad_input, grad_output):
            self.assertIsInstance(grad_input, tuple)
            self.assertIsInstance(grad_output, tuple)
            self.assertTrue(isinstance(h_module, module))
            self.assertEqual(grad_output[0], torch.ones(5, 5) * 2)
            counter["backwards"] += inc

        test_fwd = nn.modules.module.register_module_forward_hook(
            lambda *args: fw_hook(1, *args)
        )

        module_1(input)
        module_2(input)
        module_3(input)
        self.assertEqual(counter["forwards"], 3)
        self.assertEqual(counter["backwards"], 0)

        test_bwd = nn.modules.module.register_module_backward_hook(
            lambda *args: bw_hook(1, *args)
        )

        output_1 = module_1(input)
        output_2 = module_2(input)
        output_3 = module_3(input)
        self.assertEqual(counter["forwards"], 6)
        self.assertEqual(counter["backwards"], 0)

        output_1.backward(torch.ones(5, 5) * 2, retain_graph=True)
        output_2.backward(torch.ones(5, 5) * 2, retain_graph=False)
        output_3.backward(torch.ones(5, 5) * 2, retain_graph=False)
        self.assertEqual(counter["forwards"], 6)
        self.assertEqual(counter["backwards"], 3)

        output_1.backward(torch.ones(5, 5) * 2, retain_graph=True)
        self.assertEqual(counter["forwards"], 6)
        self.assertEqual(counter["backwards"], 4)

        test2_fwd = nn.modules.module.register_module_forward_hook(
            lambda *args: fw_hook(2, *args)
        )

        output = module_1(input)
        output = module_2(input)
        output = module_3(input)
        self.assertEqual(counter["forwards"], 15)
        self.assertEqual(counter["backwards"], 4)

        test2_bwd = nn.modules.module.register_module_backward_hook(
            lambda *args: bw_hook(2, *args)
        )

        module_1(input).backward(torch.ones(5, 5) * 2)
        self.assertEqual(counter["forwards"], 18)
        self.assertEqual(counter["backwards"], 7)

        test2_bwd.remove()

        module_2(input).backward(torch.ones(5, 5) * 2)
        self.assertEqual(counter["forwards"], 21)
        self.assertEqual(counter["backwards"], 8)

        test2_fwd.remove()

        module_3(input).backward(torch.ones(5, 5) * 2)
        self.assertEqual(counter["forwards"], 22)
        self.assertEqual(counter["backwards"], 9)

        test_fwd.remove()
        test_bwd.remove()

    def test_module_global_hook_invalid_outputs(self):
        module = nn.Sigmoid()
        input = torch.randn(5, 5, requires_grad=True)

        def bw_fail1(self, grad_input, grad_output):
            return grad_input[:-1]

        def bw_fail2(self, grad_input, grad_output):
            return grad_input + (torch.randn(2, 2),)

        with nn.modules.module.register_module_backward_hook(bw_fail1):
            with self.assertRaisesRegex(RuntimeError, "got 0, but expected 1"):
                module(input).sum().backward()

        with nn.modules.module.register_module_backward_hook(bw_fail2):
            with self.assertRaisesRegex(RuntimeError, "got 2, but expected 1"):
                module(input).sum().backward()

    def test_module_backward_global_hook_writeable(self):
        module = nn.Sigmoid()
        input = torch.randn(5, 5, requires_grad=True)
        sig_x = torch.sigmoid(input)

        def bw_hook(module, grad_input, grad_output):
            for grad in grad_input:
                self.assertTrue(isinstance(grad, torch.Tensor))
            for grad in grad_output:
                self.assertTrue(isinstance(grad, torch.Tensor))
            return tuple(gi * 2 for gi in grad_input)

        nn.modules.module.register_module_backward_hook(bw_hook)
        module(input).backward(torch.ones(5, 5))
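        # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)); the hook doubles grad_input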
        expected_grad = sig_x * (1 - sig_x) * 2
        self.assertEqual(input.grad, expected_grad)

    @skipIfTorchDynamo("TorchDynamo does not work well with hooks")
    def test_module_global_forward_preforward_hook_writeable(self):
        module = nn.Sigmoid()
        input = torch.randn(5, 5, requires_grad=True)
        sig_x = torch.sigmoid(input)

        def forward_pre_hook(m, input):
            return torch.nn.functional.relu(input[0])

        def forward_hook(m, input, output):
            return -output

        nn.modules.module.register_module_forward_pre_hook(forward_pre_hook)
        nn.modules.module.register_module_forward_hook(forward_hook)
        output = module(input)
        expected_res = -torch.sigmoid(torch.nn.functional.relu(input))
        self.assertEqual(output, expected_res)
        output.backward(torch.ones(5, 5) * 2, retain_graph=True)
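        # the pre-hook applied relu, so gradients only flow where input > 0; the
        # forward hook negated the output, so the sigmoid gradient picks up a minus sign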
        mask = input > 0
        expected_grad = -sig_x * (1 - sig_x) * 2 * mask
        self.assertEqual(input.grad, expected_grad)

    def test_module_forward_preforward_hook_removable(self):
        """
        Test that when multiple forward pre-hook functions are registered and
        used correctly, a handle can be removed during the forward pre-hook call.
        """
        module = nn.Sigmoid()

        def removable_hook(m, input):
            nonlocal handle
            handle.remove()
            return input

        def removable_hook_2(m, input):
            nonlocal handle_2
            handle_2.remove()
            return input

        handle = module.register_forward_pre_hook(removable_hook)
        handle_2 = module.register_forward_pre_hook(removable_hook_2)

        # make sure hook registration is successful
        self.assertEqual(len(handle.hooks_dict_ref()), 2)
        self.assertEqual(len(handle_2.hooks_dict_ref()), 2)

        input = torch.randn(2, 2)
        output = module(input)
        self.assertEqual(torch.sigmoid(input), output)

        # make sure hook removal is successful
        self.assertFalse(handle.id in handle.hooks_dict_ref())
        self.assertFalse(handle_2.id in handle.hooks_dict_ref())
        self.assertEqual(len(handle.hooks_dict_ref()), 0)
        self.assertEqual(len(handle_2.hooks_dict_ref()), 0)

    def test_module_forward_forward_hook_removable(self):
        """
        Test that when multiple forward hook functions are registered and used
        correctly, a handle can be removed during the forward hook call.
        """
        module = nn.Sigmoid()

        def removable_hook(m, input, output):
            nonlocal handle
            handle.remove()
            return output

        def removable_hook_2(m, input, output):
            nonlocal handle_2
            handle_2.remove()
            return output

        handle = module.register_forward_hook(removable_hook)
        handle_2 = module.register_forward_hook(removable_hook_2)

        # make sure hook registration is successful
1137        self.assertEqual(len(handle.hooks_dict_ref()), 2)
1138        self.assertEqual(len(handle_2.hooks_dict_ref()), 2)
1139
1140        input = torch.randn(2, 2)
1141        output = module(input)
1142        self.assertEqual(torch.sigmoid(input), output)
1143
1144        # make sure hook removal is successful
1145        self.assertFalse(handle.id in handle.hooks_dict_ref())
1146        self.assertFalse(handle_2.id in handle.hooks_dict_ref())
1147        self.assertEqual(len(handle.hooks_dict_ref()), 0)
1148        self.assertEqual(len(handle_2.hooks_dict_ref()), 0)
1149
1150    @skipIfTorchDynamo("TorchDynamo does not work well with hooks")
1151    def test_global_and_local_hooks_order(self):
1152        module = nn.Sigmoid()
1153
1154        global_forward_pre_called = False
1155        local_forward_pre_called = False
1156        global_forward_called = False
1157        local_forward_called = False
1158        global_backward_called = False
1159        local_backward_called = False
1160
1161        def global_forward_pre_hook(m, input):
1162            nonlocal global_forward_pre_called
1163            self.assertTrue(not local_forward_pre_called)
1164            global_forward_pre_called = True
1165            return input
1166
1167        def local_forward_pre_hook(m, input):
1168            nonlocal local_forward_pre_called
1169            self.assertTrue(global_forward_pre_called)
1170            local_forward_pre_called = True
1171            return input
1172
1173        def global_forward_hook(m, input, output):
1174            nonlocal global_forward_called
1175            self.assertTrue(not local_forward_called)
1176            global_forward_called = True
1177            return output
1178
1179        def local_forward_hook(m, input, output):
1180            nonlocal local_forward_called
1181            self.assertTrue(global_forward_called)
1182            local_forward_called = True
1183            return output
1184
1185        def global_backward_hook(m, input, output):
1186            nonlocal global_backward_called
1187            self.assertTrue(not local_backward_called)
1188            global_backward_called = True
1189            return input
1190
1191        def local_backward_hook(m, input, output):
1192            nonlocal local_backward_called
1193            self.assertTrue(global_backward_called)
1194            local_backward_called = True
1195            return input
1196
1197        input = torch.randn(5, 5, requires_grad=True)
1198        nn.modules.module.register_module_forward_pre_hook(global_forward_pre_hook)
1199        module.register_forward_pre_hook(local_forward_pre_hook)
1200        nn.modules.module.register_module_forward_hook(global_forward_hook)
1201        module.register_forward_hook(local_forward_hook)
1202        nn.modules.module.register_module_backward_hook(global_backward_hook)
1203        module.register_backward_hook(local_backward_hook)
1204
1205        output = module(input)
1206        self.assertTrue(
1207            local_forward_called
1208            and local_forward_pre_called
1209            and global_forward_called
1210            and global_forward_pre_called
1211        )
1212
1213        output.backward(torch.ones(5, 5), retain_graph=True)
1214        self.assertTrue(local_backward_called and global_backward_called)
1215
1216
1217class TestModuleHookNN(NNTestCase):
1218    _do_cuda_memory_leak_check = True
1219    _do_cuda_non_default_stream = True
1220
1221    def _test_hooks(self, backward_register_fn):
1222        module = nn.Sigmoid()
1223        input = torch.ones(5, 5, requires_grad=True)
1224
1225        counter = {"forwards": 0, "backwards": 0}
1226
1227        def fw_hook(inc, h_module, input, output):
1228            self.assertIsInstance(input, tuple)
1229            self.assertTrue(isinstance(output, torch.Tensor))
1230            self.assertTrue(h_module is module)
1231            self.assertEqual(input[0], torch.ones(5, 5))
1232            self.assertEqual(output, torch.empty(5, 5).fill_(1 / (1 + 1 / math.e)))
1233            counter["forwards"] += inc
1234
1235        def bw_hook(inc, h_module, grad_input, grad_output):
1236            self.assertIsInstance(grad_input, tuple)
1237            self.assertIsInstance(grad_output, tuple)
1238            self.assertTrue(h_module is module)
1239            self.assertEqual(grad_output[0], torch.ones(5, 5) * 2)
1240            counter["backwards"] += inc
1241
1242        # backward_pre_hook expects callback with only `module` and `grad_output`
1243        # as arguments.
1244        def bw_pre_hook(inc, h_module, grad_output):
1245            self.assertIsInstance(grad_output, tuple)
1246            self.assertTrue(h_module is module)
1247            self.assertEqual(grad_output[0], torch.ones(5, 5) * 2)
1248            counter["backwards"] += inc
1249
1250        test_fwd = module.register_forward_hook(lambda *args: fw_hook(1, *args))
1251
1252        module(input)
1253        module(input)
1254        self.assertEqual(counter["forwards"], 2)
1255        self.assertEqual(counter["backwards"], 0)
1256
1257        bw_hook_fn = (
1258            bw_pre_hook
1259            if backward_register_fn == "register_full_backward_pre_hook"
1260            else bw_hook
1261        )
1262        test_bwd = getattr(module, backward_register_fn)(
1263            lambda *args: bw_hook_fn(1, *args)
1264        )
1265
1266        output = module(input)
1267        self.assertEqual(counter["forwards"], 3)
1268        self.assertEqual(counter["backwards"], 0)
1269
1270        output.backward(torch.ones(5, 5) * 2, retain_graph=True)
1271        self.assertEqual(counter["forwards"], 3)
1272        self.assertEqual(counter["backwards"], 1)
1273
1274        output.backward(torch.ones(5, 5) * 2, retain_graph=True)
1275        self.assertEqual(counter["forwards"], 3)
1276        self.assertEqual(counter["backwards"], 2)
1277
1278        test2_fwd = module.register_forward_hook(lambda *args: fw_hook(2, *args))
1279
1280        output = module(input)
1281        self.assertEqual(counter["forwards"], 6)
1282        self.assertEqual(counter["backwards"], 2)
1283
1284        test2_bwd = getattr(module, backward_register_fn)(
1285            lambda *args: bw_hook_fn(2, *args)
1286        )
1287
1288        module(input).backward(torch.ones(5, 5) * 2)
1289        self.assertEqual(counter["forwards"], 9)
1290        self.assertEqual(counter["backwards"], 5)
1291
1292        test2_bwd.remove()
1293
1294        module(input).backward(torch.ones(5, 5) * 2)
1295        self.assertEqual(counter["forwards"], 12)
1296        self.assertEqual(counter["backwards"], 6)
1297
1298        test2_fwd.remove()
1299
1300        module(input).backward(torch.ones(5, 5) * 2)
1301        self.assertEqual(counter["forwards"], 13)
1302        self.assertEqual(counter["backwards"], 7)
1303
1304        test_fwd.remove()
1305        test_bwd.remove()
1306
1307    def test_hooks(self):
1308        self._test_hooks("register_backward_hook")
1309        self._test_hooks("register_full_backward_hook")
1310        self._test_hooks("register_full_backward_pre_hook")
1311
1312    def test_hook_cpp(self):
1313        bn = nn.BatchNorm1d(5)
1314
1315        def hook(module, grad_inputs, grad_outputs):
1316            self.assertEqual(len(grad_inputs), 1)
1317            self.assertEqual(len(grad_outputs), 1)
1318            self.assertEqual(module, bn)
1319
1320        bn.register_full_backward_hook(hook)
1321        output = bn(torch.randn(5, 5, requires_grad=True))
1322        output.sum().backward()
1323
1324    def test_backward_hooks_interaction(self):
1325        # Test to make sure that the grad_outputs
1326        # updated by full_backward_pre_hook are received by
1327        # the full_backward_hook
1328        module = torch.nn.Sigmoid()
1329
1330        cnt = {"backward_cnt": 0}
1331
1332        def bw_pre_hook(m, grad_output):
1333            cnt["backward_cnt"] += 1
1334            return (grad_output[0] * 0.5,)
1335
1336        def bw_hook(m, grad_in, grad_output):
1337            self.assertEqual(torch.full_like(grad_output[0], 0.5), grad_output[0])
1338            cnt["backward_cnt"] += 1
1339            return grad_output
1340
1341        module.register_full_backward_pre_hook(bw_pre_hook)
1342        module.register_full_backward_hook(bw_hook)
1343
1344        t = torch.ones(1, 2, requires_grad=True)
1345        module(t).sum().backward()
1346        self.assertEqual(cnt["backward_cnt"], 2)
1347
1348    def test_hook_invalid_outputs(self):
1349        module = nn.Sigmoid()
1350        input = torch.randn(5, 5, requires_grad=True)
1351
1352        def bw_fail1(self, grad_input, grad_output):
1353            return grad_input[:-1]
1354
1355        def bw_fail2(self, grad_input, grad_output):
1356            return grad_input + (torch.randn(2, 2),)
1357
1358        with module.register_backward_hook(bw_fail1):
1359            with self.assertRaisesRegex(RuntimeError, "got 0, but expected 1"):
1360                module(input).sum().backward()
1361
1362        with module.register_backward_hook(bw_fail2):
1363            with self.assertRaisesRegex(RuntimeError, "got 2, but expected 1"):
1364                module(input).sum().backward()
1365
1366        def bw_pre_fail1(self, grad_output):
1367            return ()
1368
1369        def bw_pre_fail2(self, grad_output):
1370            return grad_output + (torch.randn(2, 2),)
1371
1372        with module.register_full_backward_pre_hook(bw_pre_fail1):
1373            with self.assertRaisesRegex(RuntimeError, "got 0, but expected 1"):
1374                module(input).sum().backward()
1375
1376        with module.register_full_backward_pre_hook(bw_pre_fail2):
1377            with self.assertRaisesRegex(RuntimeError, "got 2, but expected 1"):
1378                module(input).sum().backward()
1379
1380    def test_hook_requires_grad(self):
1381        test_self = self
1382
1383        class MyModule(nn.Module):
1384            def forward(self, arg1, arg2, arg3):
1385                test_self.assertTrue(arg1.requires_grad)
1386                test_self.assertFalse(arg2.requires_grad)
1387                test_self.assertTrue(arg3.requires_grad)
1388                return arg1.sum() + arg2.sum() + arg3.sum()
1389
1390        inp = torch.rand(2, requires_grad=True)
1391        mod = MyModule()
1392
1393        mod(inp, inp.detach(), inp)
1394        # Ensure that requires grad is properly propagated
1395        mod.register_full_backward_hook(lambda mod, gI, gO: None)
1396        mod(inp, inp.detach(), inp)
1397
1398    def test_hook_no_requires_grad(self):
1399        mod = nn.Linear(2, 3)
1400
1401        inp = torch.rand(1, 2)
1402
1403        return_val = "None"
1404        hook_called = [0]
1405
1406        def hook(mod, grad_input, grad_output):
1407            hook_called[0] += 1
1408            for gI in grad_input:
1409                self.assertIsNone(gI)
1410            for gO in grad_output:
1411                self.assertEqual(gO.size(), (1, 3))
1412
1413            if return_val == "grad_input":
1414                return grad_input
1415            elif return_val == "invalid":
1416                # If the inputs were requiring gradients, this would be
1417                # a valid return
1418                return inp
1419            elif return_val == "None":
1420                return None
1421            else:
1422                raise RuntimeError("Invalid return_val string")
1423
1424        mod.register_full_backward_hook(hook)
1425
1426        # This should run and trigger the hook properly
1427        mod(inp).sum().backward()
1428        self.assertEqual(hook_called[0], 1)
1429
1430        return_val = "grad_input"
1431
1432        mod(inp).sum().backward()
1433        self.assertEqual(hook_called[0], 2)
1434
1435        return_val = "invalid"
1436        with self.assertRaisesRegex(RuntimeError, "where no input requires gradient"):
1437            mod(inp).sum().backward()
1438
1439    def test_hook_last_arg_requires_grad(self):
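        # Only the last argument requires grad; the call should still succeed with a full backward hook registered.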
1440        mod = nn.L1Loss()
1441        inp = torch.rand(1, requires_grad=True)
1442        mod.register_full_backward_hook(lambda m, gI, gO: None)
1443
1444        try:
1445            mod(inp.detach(), inp)
1446        except Exception as ex:
1447            self.fail(f"Unexpected exception: {ex}")
1448
1449    def test_hook_extra_input(self):
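        # Non-Tensor arguments and outputs appear as None in grad_input / grad_output.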
1450        class MyModule(nn.Module):
1451            def forward(self, non_tensor, tensor):
1452                return tensor.clone(), non_tensor
1453
1454        inp = torch.rand(2, requires_grad=True)
1455        mod = MyModule()
1456
1457        def hook(mod, grad_input, grad_output):
1458            self.assertIsNone(grad_input[0])
1459            self.assertIsInstance(grad_input[1], torch.Tensor)
1460
1461            self.assertIsInstance(grad_output[0], torch.Tensor)
1462            self.assertIsNone(grad_output[1])
1463
1464        mod.register_full_backward_hook(hook)
1465        out, _ = mod(True, inp)
1466        out.sum().backward()
1467
1468    def test_hook_inplace(self):
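        # Inputs and outputs wrapped by the full backward hook machinery are views; modifying them in place should raise.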
1469        class MyModule(nn.Module):
1470            def forward(self, inp, do_inplace):
1471                self.inp = inp  # keep the (possibly wrapped) input view around for later checks
1472                if do_inplace:
1473                    inp += 1
1474                return inp.clone()
1475
1476        hook_called = [0]
1477
1478        def hook(mod, grad_input, grad_output):
1479            hook_called[0] += 1
1480
1481        def hook_pre(mod, grad_output):
1482            hook_called[0] += 1
1483
1484        inp = torch.rand(10, requires_grad=True)
1485        mod = MyModule()
1486        for hook_fn, register_fn in [
1487            (hook, mod.register_full_backward_hook),
1488            (hook_pre, mod.register_full_backward_pre_hook),
1489        ]:
1490            hook_called[0] = 0
1491            with register_fn(hook_fn):
1492                # Without any in-place modification, forward and backward should run normally
1493                mod(inp, False).sum().backward()
1494                self.assertEqual(hook_called[0], 1)
1495
1496                # Modifying the input in place should raise an error
1497                with self.assertRaisesRegex(
1498                    RuntimeError,
1499                    "Output 0 of BackwardHookFunctionBackward is "
1500                    "a view and is being modified inplace.",
1501                ):
1502                    mod(inp.clone(), True)
1503
1504                # Re-using the saved input view after its base has been modified
1505                # in place should raise an error
1506                local_inp = inp.clone()
1507                out = mod(local_inp, False)
1508                local_inp[0] *= 1
1509                with self.assertRaisesRegex(
1510                    RuntimeError,
1511                    "Output 0 of BackwardHookFunctionBackward is "
1512                    "a view and its base or another view",
1513                ):
1514                    # Any operation involving the view will fail here
1515                    mod.inp + 2
1516
1517                # Modifying the output in place should raise an error
1518                out = mod(inp, False)
1519                with self.assertRaisesRegex(
1520                    RuntimeError,
1521                    "BackwardHookFunctionBackward is a view "
1522                    "and is being modified inplace.",
1523                ):
1524                    out += 1
1525
1526    def test_hook_non_full_warning(self):
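        # The deprecated register_backward_hook should emit a FutureWarning for forwards it cannot handle reliably.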
1527        def noop(*args):
1528            pass
1529
1530        a = torch.rand(2, requires_grad=True)
1531        b = torch.rand(2, requires_grad=True)
1532
1533        # Check an invalid input container (a list instead of a Tensor or tuple of Tensors)
1534        class MyModule(nn.Module):
1535            def forward(self, l):
1536                return l[0].clone(), l[1].clone()
1537
1538        m = MyModule()
1539        m.register_backward_hook(noop)
1540
1541        with self.assertWarnsRegex(
1542            FutureWarning,
1543            "does not take as input a single Tensor or a tuple of Tensors",
1544        ):
1545            m([a, b])
1546
1547        # Check an invalid output container (a list instead of a Tensor or tuple of Tensors)
1548        class MyModule(nn.Module):
1549            def forward(self, a, b):
1550                return [a.clone(), b.clone()]
1551
1552        m = MyModule()
1553        m.register_backward_hook(noop)
1554
1555        with self.assertWarnsRegex(
1556            FutureWarning, "does not return a single Tensor or a tuple of Tensors"
1557        ):
1558            m(a, b)
1559
1560        # Check invalid outputs that come from different autograd Nodes
1561        class MyModule(nn.Module):
1562            def forward(self, a, b):
1563                return a.clone(), b.clone()
1564
1565        m = MyModule()
1566        m.register_backward_hook(noop)
1567
1568        with self.assertWarnsRegex(
1569            FutureWarning, "outputs are generated by different autograd Nodes"
1570        ):
1571            m(a, b)
1572
1573        # Check an invalid forward that contains multiple autograd Nodes
1574        class MyModule(nn.Module):
1575            def forward(self, a):
1576                return a.clone().clone()
1577
1578        m = MyModule()
1579        m.register_backward_hook(noop)
1580
1581        with self.assertWarnsRegex(
1582            FutureWarning, "the forward contains multiple autograd Nodes"
1583        ):
1584            m(a)
1585
1586    def test_hook_backward_size(self):
1587        # Build a module with multiple operations in forward
1588        # and different sizes for the inputs and the output
1589        class MyModule(nn.Module):
1590            def forward(self, arg1, arg2):
1591                tmp = arg1.sum() * arg2
1592                tmp = tmp + arg2.sum() * arg1.sum()
1593                tmp = tmp.sum().view(1)
1594                tmp = tmp.expand(8).contiguous()
1595                return tmp
1596
1597        module = MyModule()
1598        inp1 = torch.randn(5, 5, requires_grad=True)
1599        inp2 = torch.randn(10, 10, requires_grad=True)
1600
1601        def bw_hook(module, grad_input, grad_output):
1602            self.assertEqual(len(grad_input), 2)
1603            self.assertEqual(grad_input[0].size(), torch.Size([5, 5]))
1604            self.assertEqual(grad_input[1].size(), torch.Size([10, 10]))
1605            self.assertEqual(len(grad_output), 1)
1606            self.assertEqual(grad_output[0].size(), torch.Size([8]))
1607
1608        with module.register_full_backward_hook(bw_hook):
1609            module(inp1, inp2).sum().backward()
1610
1611    def test_hook_backward_writeable(self):
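        # Gradients returned by a backward hook should replace the original grad_input.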
1612        module = nn.Sigmoid()
1613        input = torch.randn(5, 5, requires_grad=True)
1614        sig_x = torch.nn.functional.sigmoid(input)
1615
1616        def bw_hook(module, grad_input, grad_output):
1617            for grad in grad_input:
1618                self.assertTrue(isinstance(grad, torch.Tensor))
1619            for grad in grad_output:
1620                self.assertTrue(isinstance(grad, torch.Tensor))
1621            return tuple(gi * 2 for gi in grad_input)
1622
1623        module.register_backward_hook(bw_hook)
1624        module(input).backward(torch.ones(5, 5))
1625        expected_grad = sig_x * (1 - sig_x) * 2
1626        self.assertEqual(input.grad, expected_grad)
1627
1628    def test_hook_forward_preforward_writable(self):
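        # Values returned by forward pre-hooks and forward hooks should replace the module input and output.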
1629        module = nn.Sigmoid()
1630        input = torch.randn(5, 5, requires_grad=True)
1631        sig_x = torch.nn.functional.sigmoid(input)
1632
1633        def forward_pre_hook(m, input):
1634            return torch.nn.functional.relu(input[0])
1635
1636        def forward_hook(m, input, output):
1637            return -output
1638
1639        module.register_forward_pre_hook(forward_pre_hook)
1640        module.register_forward_hook(forward_hook)
1641        output = module(input)
1642        expected_res = -torch.nn.functional.sigmoid(torch.nn.functional.relu(input))
1643        self.assertEqual(output, expected_res)
1644        output.backward(torch.ones(5, 5) * 2, retain_graph=True)
1645        mask = input > 0
1646        expected_grad = -sig_x * (1 - sig_x) * 2 * mask
1647        self.assertEqual(input.grad, expected_grad)
1648
1649    def test_hook_buffer_registration(self):
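        # Buffer registration hooks should run for every buffer registered on the module (and may optionally return the buffer).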
1650        for return_buffer in (True, False):
1651
1652            def buffer_registration_hook(module, name, buffer):
1653                buffer.registered = True
1654                if return_buffer:
1655                    return buffer
1656
1657            handle = torch.nn.modules.module.register_module_buffer_registration_hook(
1658                buffer_registration_hook
1659            )
1660            try:
1661                l, n, s = _create_basic_net()
1662                for b in s.buffers():
1663                    self.assertTrue(getattr(b, "registered", False))
1664            finally:
1665                handle.remove()
1666
1667    def test_hook_submodule_registration(self):
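        # Module registration hooks should run for every submodule registered (and may optionally return the submodule).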
1668        for return_submodule in (True, False):
1669
1670            def module_registration_hook(module, name, submodule):
1671                module.registered = True
1672                submodule.registered = True
1673                if return_submodule:
1674                    return submodule
1675
1676            handle = torch.nn.modules.module.register_module_module_registration_hook(
1677                module_registration_hook
1678            )
1679            try:
1680                l, n, s = _create_basic_net()
1681                for m in s.modules():
1682                    self.assertTrue(getattr(m, "registered", False))
1683            finally:
1684                handle.remove()
1685
1686    def test_hook_parameter_registration(self):
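        # Parameter registration hooks should run for every parameter registered (and may optionally return the parameter).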
1687        for return_parameter in (True, False):
1688
1689            def parameter_registration_hook(module, name, parameter):
1690                parameter.registered = True
1691                if return_parameter:
1692                    return parameter
1693
1694            handle = (
1695                torch.nn.modules.module.register_module_parameter_registration_hook(
1696                    parameter_registration_hook
1697                )
1698            )
1699            try:
1700                l, n, s = _create_basic_net()
1701                for p in s.parameters():
1702                    self.assertTrue(getattr(p, "registered", False))
1703            finally:
1704                handle.remove()
1705
1706
1707instantiate_parametrized_tests(TestModuleHooks)
1708instantiate_parametrized_tests(TestStateDictHooks)
1709
1710if __name__ == "__main__":
1711    run_tests()
1712