# Owner(s): ["module: nn"]
import gc
import math
import pickle
import unittest
import warnings
import weakref
from collections import namedtuple, OrderedDict
from copy import deepcopy
from functools import partial
from tempfile import NamedTemporaryFile
from typing import Any, Dict, List, Tuple

import torch
import torch.nn as nn
from torch.testing._internal.common_nn import _create_basic_net, NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_WINDOWS,
    parametrize as parametrize_test,
    run_tests,
    skipIfTorchDynamo,
    swap,
    TestCase,
)


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.seq1 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)])
        self.seq2 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.seq2(self.seq1(x))


ToyNamedTuple = namedtuple("ToyNamedTuple", "content")


class ToyModel(nn.Module):
    def __init__(self, with_named_tuple=False) -> None:
        super().__init__()
        self.net1 = Net()
        self.net2 = Net()
        self.with_named_tuple = with_named_tuple

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = self.net2(self.net1(x))
        if self.with_named_tuple:
            return ToyNamedTuple(res)
        else:
            return (res,)


def forward_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    inp: Tuple[torch.Tensor],
    out: torch.Tensor,
) -> None:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
    self.assertEqual(len(inp), 1)


def forward_pre_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    inp: Tuple[torch.Tensor],
) -> None:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
    self.assertEqual(len(inp), 1)


def full_backward_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    grad_input: Tuple[torch.Tensor],
    grad_output: Tuple[torch.Tensor],
) -> None:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
    self.assertEqual(len(grad_input), 1)
    self.assertEqual(len(grad_output), 1)


def full_backward_pre_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    grad_input: Tuple[torch.Tensor],
) -> None:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
    self.assertEqual(len(grad_input), 1)


class KwargModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = Net()
        self.net2 = Net()

    def forward(self, x: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor:
        if bias is not None:
            x = x + bias
        return x

    def internal_forward_hook(
        self,
        module: nn.Module,
        args: Tuple[torch.Tensor],
        kwargs: Dict[str, Any],
        out: torch.Tensor,
    ):
        return out + kwargs["bias"]


class FailsInForwardModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = Net()

    def forward(self, x: torch.Tensor, fail: bool = True) -> torch.Tensor:
        if fail:
            raise RuntimeError("failing in forward")
        return self.net1(x)


def kwarg_forward_pre_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    args: Tuple[torch.Tensor],
    kwargs: Dict[str, Any],
) -> Tuple[Any, Any]:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
    self.assertEqual(len(args), 1)
    kwargs["bias"] = 2 * kwargs["bias"]
    return args, kwargs


def kwarg_forward_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    args: Tuple[torch.Tensor],
    kwargs: Dict[str, Any],
    out: torch.Tensor,
) -> Any:
    fired_hooks.append(hook_id)
    self.assertEqual(id(module), id(expected_module))
    self.assertEqual(len(args), 1)

    out = out + kwargs["bias"]
    return out


class DummyContextManager:
    def __init__(self, inp):
        self.input = inp

    def __enter__(self, *args, **kwargs):
        self.input.append(2)

    def __exit__(self, *args, **kwargs):
        self.input.append(-1)


class TestModuleHooks(TestCase):
    @parametrize_test("named_tuple", (True, False))
    def test_forward_hooks(self, named_tuple):
        fired_hooks: List[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        hook = partial(forward_hook, self, fired_hooks, model.net1.seq2)
        model.net1.seq2.register_forward_hook(partial(hook, 0))
        model.net1.seq2.register_forward_hook(partial(hook, 1), prepend=True)
        model.net1.seq2.register_forward_hook(partial(hook, 2))
        model.net1.seq2.register_forward_hook(partial(hook, 3))
        model.net1.seq2.register_forward_hook(partial(hook, 4), prepend=True)
        expected = [4, 1, 0, 2, 3]

        self.assertEqual(fired_hooks, [])
        out = model(x)
        self.assertEqual(fired_hooks, expected)
        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
        out[0].sum().backward()
        self.assertEqual(fired_hooks, expected)
        model(x)[0].sum().backward()
        self.assertEqual(fired_hooks, expected + expected)

    @parametrize_test("named_tuple", (True, False))
    def test_forward_pre_hooks(self, named_tuple):
        fired_hooks: List[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        hook = partial(forward_pre_hook, self, fired_hooks, model.net2.seq1)
        model.net2.seq1.register_forward_pre_hook(partial(hook, 0), prepend=True)
        model.net2.seq1.register_forward_pre_hook(partial(hook, 1))
        model.net2.seq1.register_forward_pre_hook(partial(hook, 2))
        model.net2.seq1.register_forward_pre_hook(partial(hook, 3))
        model.net2.seq1.register_forward_pre_hook(partial(hook, 4), prepend=True)
        expected = [4, 0, 1, 2, 3]

        self.assertEqual(fired_hooks, [])
        out = model(x)
        self.assertEqual(fired_hooks, expected)
        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
        out[0].sum().backward()
        self.assertEqual(fired_hooks, expected)
        model(x)[0].sum().backward()
        self.assertEqual(fired_hooks, expected + expected)

    @parametrize_test("named_tuple", (True, False))
    def test_full_backward_hooks(self, named_tuple):
        fired_hooks: List[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        hook = partial(full_backward_hook, self, fired_hooks, model.net1)
        model.net1.register_full_backward_hook(partial(hook, 0))
        model.net1.register_full_backward_hook(partial(hook, 1))
        model.net1.register_full_backward_hook(partial(hook, 2))
        model.net1.register_full_backward_hook(partial(hook, 3), prepend=True)
        model.net1.register_full_backward_hook(partial(hook, 4), prepend=True)
        expected = [4, 3, 0, 1, 2]

        self.assertEqual(fired_hooks, [])
        out = model(x)
        self.assertEqual(fired_hooks, [])
        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
        out[0].sum().backward()
        self.assertEqual(fired_hooks, expected)
        model(x)[0].sum().backward()
        self.assertEqual(fired_hooks, expected + expected)

    @parametrize_test("named_tuple", (True, False))
    def test_full_backward_pre_hooks(self, named_tuple):
        fired_hooks: List[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        hook = partial(full_backward_pre_hook, self, fired_hooks, model.net1)
        model.net1.register_full_backward_pre_hook(partial(hook, 0), prepend=True)
        model.net1.register_full_backward_pre_hook(partial(hook, 1), prepend=True)
        model.net1.register_full_backward_pre_hook(partial(hook, 2))
        model.net1.register_full_backward_pre_hook(partial(hook, 3))
        model.net1.register_full_backward_pre_hook(partial(hook, 4))
        expected = [1, 0, 2, 3, 4]

        self.assertEqual(fired_hooks, [])
        out = model(x)
        self.assertEqual(fired_hooks, [])
        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
        out[0].sum().backward()
        self.assertEqual(fired_hooks, expected)
        model(x)[0].sum().backward()
        self.assertEqual(fired_hooks, expected + expected)

        # Backward pre hook can affect subsequent gradient computation
        for rg in [True, False]:
            a = torch.ones(2, requires_grad=rg)
            model = nn.Linear(2, 2)

            def fn(_unused_module, grad_output):
                return (grad_output[0] * 0,)

            model.register_full_backward_pre_hook(fn)

            out = model(a)
            out.sum().backward()
            self.assertEqual(model.weight.grad, torch.zeros(2, 2))
            if rg:
                self.assertEqual(a.grad, torch.zeros_like(a))
            else:
                self.assertIsNone(a.grad)

    @parametrize_test("named_tuple", (True, False))
    def test_mixed_hooks(self, named_tuple):
        fired_hooks: List[int] = []
        model = ToyModel(named_tuple)
        x = torch.randn(10, 10)
        model.register_forward_pre_hook(
            partial(forward_pre_hook, self, fired_hooks, model, 0)
        )
        model.register_forward_hook(partial(forward_hook, self, fired_hooks, model, 1))
        model.register_full_backward_pre_hook(
            partial(full_backward_pre_hook, self, fired_hooks, model, 2)
        )
        model.register_full_backward_hook(
            partial(full_backward_hook, self, fired_hooks, model, 3)
        )

        self.assertEqual(fired_hooks, [])
        out = model(x)
        self.assertEqual(fired_hooks, [0, 1])
        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
        out[0].sum().backward()
        self.assertEqual(fired_hooks, [0, 1, 2, 3])
        model(x)[0].sum().backward()
        self.assertEqual(fired_hooks, [0, 1, 2, 3, 0, 1, 2, 3])

    def test_kwarg_hooks(self):
        # 1. test forward pre hook
        fired_hooks: List[int] = []
        x: torch.Tensor = torch.ones(10, 10)
        bias: torch.Tensor = torch.ones(10, 10)
        model = KwargModel()
        model.register_forward_pre_hook(
            partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0),
            with_kwargs=True,
        )

        # forward-pre: bias' = bias * 2
        # So, out = x + bias * 2
        self.assertEqual(fired_hooks, [])
        out = model(x, bias=bias)
        self.assertEqual(fired_hooks, [0])
        self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5)

        # 2. test forward pre and forward hooks
        fired_hooks: List[int] = []
        x: torch.Tensor = torch.ones(10, 10)
        bias: torch.Tensor = torch.ones(10, 10)
        model = KwargModel()
        model.register_forward_hook(
            partial(kwarg_forward_hook, self, fired_hooks, model, 1),
            with_kwargs=True,
        )
        model.register_forward_pre_hook(
            partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0),
            with_kwargs=True,
        )

        # forward-pre: bias' = bias * 2
        # forward: out = x + bias'
        # forward-post: out = out + bias'
        # So, out = x + bias * 4
        self.assertEqual(fired_hooks, [])
        out = model(x, bias=bias)
        self.assertEqual(fired_hooks, [0, 1])
        self.assertEqual(out, x + 4 * bias, rtol=0, atol=1e-5)

        # 3. test nn.Module member method as forward-post hook
        x: torch.Tensor = torch.ones(10, 10)
        bias: torch.Tensor = torch.ones(10, 10)
        model = KwargModel()
        model.register_forward_hook(model.internal_forward_hook, with_kwargs=True)

        # forward: out = x + bias
        # forward-post: out = out + bias
        # So, out = x + bias * 2
        out = model(x, bias=bias)
        self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5)

    def test_remove_kwarg_hooks(self):
        # test forward pre and forward hooks
        fired_hooks: List[int] = []
        x: torch.Tensor = torch.ones(10, 10)
        bias: torch.Tensor = torch.ones(10, 10)
        model = KwargModel()
        forward_hook_handle = model.register_forward_hook(
            partial(kwarg_forward_hook, self, fired_hooks, model, 1),
            with_kwargs=True,
        )
        forward_pre_hook_handle = model.register_forward_pre_hook(
            partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0),
            with_kwargs=True,
        )

        # forward-pre: bias' = bias * 2
        # forward: out = x + bias'
        # forward-post: out = out + bias'
        # So, out = x + bias * 4
        self.assertEqual(fired_hooks, [])
        out = model(x, bias=bias)
        self.assertEqual(fired_hooks, [0, 1])
        self.assertEqual(out, x + 4 * bias, rtol=0, atol=1e-5)

        # forward-pre: bias' = bias * 2
        # forward: out = x + bias'
        # So, out = x + bias * 2
        forward_hook_handle.remove()
        out = model(x, bias=bias)
        self.assertEqual(fired_hooks, [0, 1, 0])
        self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5)
        self.assertFalse(forward_hook_handle.id in model._forward_hooks_with_kwargs)

        # forward: out = x + bias
        # So, out = x + bias
        forward_pre_hook_handle.remove()
        out = model(x, bias=bias)
        self.assertEqual(fired_hooks, [0, 1, 0])
        self.assertEqual(out, x + bias, rtol=0, atol=1e-5)
        self.assertFalse(
            forward_pre_hook_handle.id in model._forward_pre_hooks_with_kwargs
        )

    def test_always_called_forward_hooks(self):
        x: torch.Tensor = torch.ones(10, 10)
        model = FailsInForwardModel()
        stack = []
        ctx = None

        def setup_context():
            nonlocal ctx
            ctx = DummyContextManager(stack)

        def ctx_setup_hook(m, i):
            setup_context()
            ctx.__enter__()

        def ctx_setup_failure_hook(m, i):
            setup_context()
            ctx.__enter__()
            raise RuntimeError("failing in ctx setup")

        def ctx_shutdown_hook(m, i, o):
            ctx.__exit__()

        def ctx_shutdown_failure_hook(m, i, o):
            ctx.__exit__()
            raise RuntimeError("failing in ctx shutdown")

        def throw_hook(m, i, o):
            raise RuntimeError("failing in throw")

        forward_pre_hook_handle = model.register_forward_pre_hook(ctx_setup_hook)
        forward_hook_handle = model.register_forward_hook(
            ctx_shutdown_hook, always_call=True
        )
        self.assertTrue(len(model._forward_hooks_always_called) == 1)

        # make sure always_called forward hook runs when model.forward raises RuntimeError
        with self.assertRaisesRegex(RuntimeError, "failing in forward"):
            model(x)
        self.assertEqual(stack, [2, -1])

        # make sure that always_called forward hook does not run twice if there is no error
        model(x, fail=False)
        self.assertEqual(stack, [2, -1, 2, -1])

        # make sure always_called forward hook runs when forward pre hook raises RuntimeError
        forward_pre_hook_handle.remove()
        model.register_forward_pre_hook(ctx_setup_failure_hook)

        with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
            model(x, fail=False)
        self.assertEqual(stack, [2, -1, 2, -1, 2, -1])

        # make sure always_called hook runs when another always_called forward hook raises an error
        forward_hook_handle2 = model.register_forward_hook(
            throw_hook, prepend=True, always_call=True
        )

        # error raised should not be error of the forced hook
        with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
            model(x, fail=False)
        self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1])

        # make sure that always called forward hooks are properly removed
        forward_hook_handle.remove()
        forward_hook_handle2.remove()
        self.assertTrue(len(model._forward_hooks_always_called) == 0)

        # make sure that always called forward hook is not run twice if it fails while running
        forward_hook_handle3 = model.register_forward_hook(
            ctx_shutdown_failure_hook, always_call=True
        )
        with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
            model(x, fail=False)
        self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1, 2, -1])

        forward_hook_handle3.remove()

        global_forward_hook_handle = nn.modules.module.register_module_forward_hook(
            ctx_shutdown_hook, always_call=True
        )
        self.assertTrue(len(nn.modules.module._global_forward_hooks_always_called) == 1)
        # make sure global forward hook runs when forward pre hook raises RuntimeError
        with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
            model(x, fail=False)
        self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1, 2, -1, 2, -1])

        # make sure forced global forward hook is properly removed
        global_forward_hook_handle.remove()
        self.assertTrue(len(nn.modules.module._global_forward_hooks_always_called) == 0)
        with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
            model(x)
        self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1, 2, -1, 2, -1, 2])

    def test_bw_hook_warning_for_non_tensor_or_tuple(self):
        # Test to verify that backward hook raises warning
        # if result is not a Tensor or tuple of Tensors.
        counter = {"forward": 0, "backward": 0}

        def fw_pre_hook(module: nn.Module, _inputs):
            counter["forward"] += 1

        def fw_hook(module: nn.Module, _inputs, _outputs):
            counter["forward"] += 1

        def bw_hook(module: nn.Module, _inputs, _outputs):
            counter["backward"] += 1

        class TestModule(nn.Module):
            def forward(self, dict):
                inp = dict["x"]
                x = torch.nn.functional.softmax(inp, dim=0)
                return {"x": x}

        x = torch.ones(2, requires_grad=True)
        model = TestModule()
        model.register_forward_pre_hook(fw_pre_hook)
        model.register_forward_hook(fw_hook)
        model.register_full_backward_pre_hook(bw_hook)
        model.register_full_backward_hook(bw_hook)

        with warnings.catch_warnings(record=True) as w:
            y = model({"x": x})["x"]
            loss = y.sum()
            loss.backward()

        self.assertEqual(counter["forward"], 2)
        self.assertEqual(counter["backward"], 0)
        self.assertEqual(len(w), 1)
        self.assertTrue("should be a Tensor or a tuple of Tensors" in str(w[0].message))


def _hook_to_pickle(*args, **kwargs):
    pass


class TestStateDictHooks(TestCase):
    @swap([True, False])
    def test_load_state_dict_pre_hook(self):
        m = nn.Linear(10, 10)
        m_state_dict = m.state_dict()

        m_load = nn.Linear(10, 10)

        hook_called = 0

        def hook_without_module(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        ):
            self.assertEqual(m_state_dict, state_dict)
            nonlocal hook_called
            hook_called += 1

        def hook_with_module(
            module,
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        ):
            self.assertEqual(m_state_dict, state_dict)
            self.assertTrue(m_load is module)
            nonlocal hook_called
            hook_called += 1

        hook_called = 0
        # Test private API since this sets with_module=False which diverges from public API
        m_load._register_load_state_dict_pre_hook(hook_without_module)
        m_load.load_state_dict(m_state_dict)
        self.assertEqual(1, hook_called)

        hook_called = 0
        m_load.register_load_state_dict_pre_hook(hook_with_module)
        m_load.load_state_dict(m_state_dict)
        self.assertEqual(2, hook_called)

        # Test private API with with_module=True
        hook_called = 0
        m_load._register_load_state_dict_pre_hook(hook_with_module, True)
        m_load.load_state_dict(m_state_dict)
        self.assertEqual(3, hook_called)

    def test_no_extra_ref_to_module(self):
        try:
            gc.disable()
            m = nn.Linear(10, 10)

            m.register_load_state_dict_pre_hook(_hook_to_pickle)
            weak_m = weakref.ref(m)
            del m

            self.assertEqual(weak_m(), None)
        finally:
            gc.enable()

    def test_pickled_hook(self):
        m = nn.Linear(10, 10)
        m.register_load_state_dict_pre_hook(_hook_to_pickle)
        pickle.loads(pickle.dumps(m))

    @swap([True, False])
    def test_load_state_dict_module_pre_hook(self):
        hook_called = 0

        # Test with module instance method as hook
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Parameter(torch.rand(10))

            def my_pre_load_hook(
                self,
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            ):
                assert [] == error_msgs
                assert [] == unexpected_keys
                assert [] == missing_keys
                assert strict
                nonlocal hook_called
                hook_called += 1

            def my_pre_load_hook_with_module(
                self,
                module,
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            ):
                assert [] == error_msgs
                assert [] == unexpected_keys
                assert [] == missing_keys
                assert strict
                assert self is module
                nonlocal hook_called
                hook_called += 1

        # Test that hooks registered on a submodule are also called
        # appropriately, i.e. with the submodule as module argument in
        # my_pre_load_hook_with_module.
        class MyModuleContainer(nn.Module):
            def __init__(self, mod):
                super().__init__()
                self.mod = mod

        for ctor in [MyModuleContainer, lambda x: x]:
            m = ctor(MyModule())
            state_dict = m.state_dict()
            if isinstance(m, MyModuleContainer):
                mod = m.mod
            else:
                mod = m

            hook_called = 0
            # Test private API since this sets with_module=False which diverges from public API
            mod._register_load_state_dict_pre_hook(mod.my_pre_load_hook)
            m.load_state_dict(state_dict)
            self.assertEqual(1, hook_called)

            hook_called = 0
            mod.register_load_state_dict_pre_hook(mod.my_pre_load_hook_with_module)
            m.load_state_dict(state_dict)
            self.assertEqual(2, hook_called)

    @swap([True, False])
    def test_load_state_dict_post_hook(self):
        hook_called = 0

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Parameter(torch.rand(10))

            def my_post_load_hook(self, module, incompatible_keys):
                assert module is self
                nonlocal hook_called
                incompatible_keys.missing_keys.append("foo")
                incompatible_keys.unexpected_keys.append("bar")
                hook_called += 1

        nested = MyModule()
        wrapped = nn.ModuleList([nested])
        handle = nested.register_load_state_dict_post_hook(
            nested.my_post_load_hook,
        )
        # Hook must be called even if it is wrapped
        ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False)
        self.assertEqual(hook_called, 1)
        # Ensure that the hook modified missing_keys and unexpected_keys
        missing = ret.missing_keys
        unexpected = ret.unexpected_keys
        self.assertEqual(missing, ["foo"])
        self.assertEqual(unexpected, ["bar"])
        # When called with strict=True, the error raised should mention the
        # missing and unexpected keys the hook added.
        with self.assertRaisesRegex(RuntimeError, "foo.*\n.*bar"):
            wrapped.load_state_dict(wrapped.state_dict(), strict=True)
        self.assertEqual(hook_called, 2)
        # Removing the hook via handle.remove() should cause it not to
        # fire anymore.
        handle.remove()
        # Hook did not run so it should not have added any keys
        ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False)
        self.assertEqual(ret.missing_keys, [])
        self.assertEqual(ret.unexpected_keys, [])
        # hook_called should not have been incremented
        self.assertEqual(hook_called, 2)

        def load_hook_clear_incompatible(module, incompatible_keys):
            incompatible_keys.missing_keys.clear()
            incompatible_keys.unexpected_keys.clear()

        nested.register_load_state_dict_post_hook(load_hook_clear_incompatible)
        state_dict = wrapped.state_dict()
        state_dict["extra"] = torch.ones(1)
        # load state_dict with strict=True should not throw.
        ret = wrapped.load_state_dict(state_dict, strict=True)
        # explicitly ensure that the post hook cleared out incompatible_keys
        self.assertEqual([], ret.missing_keys)
        self.assertEqual([], ret.unexpected_keys)

    @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
    @swap([True, False])
    def test_load_state_dict_post_hook_backward_compatibility(self):
        def my_post_load_hook(mod, _):
            nonlocal called
            called = True

        for m in [nn.Softmin(10), nn.Softmax(10), nn.LogSoftmax(10)]:
            called = False
            sd = deepcopy(m.state_dict())
            self.assertTrue(hasattr(m, "_load_state_dict_post_hooks"))
            # Simulate an older model that did not have this attr
            delattr(m, "_load_state_dict_post_hooks")
            # Save and load, and ensure that load_state_dict works (without proper
            # BC we would run into errors because this attribute would be expected).
            # In particular, Softmax runs into the issue described here:
            # https://github.com/pytorch/pytorch/issues/77280
            with NamedTemporaryFile() as f:
                # Note that torch.save / torch.load is not recommended to save/load
                # modules.
                torch.save(m, f.name)
                # weights_only=False as this is legacy code that saves the model
                m = torch.load(f.name, weights_only=False)
                m.load_state_dict(sd)
                self.assertFalse(called)

            # Ensure hooks can be registered and called.
            m.register_load_state_dict_post_hook(my_post_load_hook)
            m.load_state_dict(sd)
            self.assertTrue(called)

    def _test_register_state_dict_pre_hook(self, model, submodule):
        _state_dict_prefix = "foo."
        state_dict_pre_hook_count = 0
        keep_var_setting = False

        def my_state_dict_pre_hook(module, prefix, keep_vars):
            self.assertEqual(keep_vars, keep_var_setting)
            nonlocal state_dict_pre_hook_count
            state_dict_pre_hook_count += 1
            self.assertTrue(prefix.startswith(_state_dict_prefix))

        model.register_state_dict_pre_hook(my_state_dict_pre_hook)
        # Test to ensure submodules run the hook as well.
        submodule.register_state_dict_pre_hook(my_state_dict_pre_hook)

        def check_results(model):
            nonlocal state_dict_pre_hook_count, keep_var_setting
            for keep_var_setting in [True, False]:
                _ = model.state_dict(
                    prefix=_state_dict_prefix, keep_vars=keep_var_setting
                )
                self.assertEqual(2, state_dict_pre_hook_count)
                state_dict_pre_hook_count = 0

        # Test state dict works as expected after model construction
        check_results(model)
        # Test state dict works as expected after forward
        model(torch.ones(10, 3))
        check_results(model)

    def test_register_state_dict_pre_hook(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Sequential(
                    nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)
                )

            def forward(self, x):
                return self.a(x)

        mod = MyModule()
        self._test_register_state_dict_pre_hook(mod, mod.a)

    def test_register_state_dict_pre_hook_lazy_module(self):
        class MyLazyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer1 = nn.LazyLinear(8)
                self.layer2 = nn.LazyLinear(5)

            def forward(self, x):
                return self.layer2(self.layer1(x))

        mod = MyLazyModule()
        self._test_register_state_dict_pre_hook(mod, mod.layer1)

    @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
    def test_register_state_dict_pre_hook_backward_compat(self):
        called = False

        def my_state_dict_pre_hook(*args, **kwargs):
            nonlocal called
            called = True

        m = nn.Linear(1, 1)
        self.assertTrue(hasattr(m, "_state_dict_pre_hooks"))
        delattr(m, "_state_dict_pre_hooks")
        # Save and load, ensure we can still call state_dict
        # without running into issues.
        with NamedTemporaryFile() as f:
            # Note that torch.save / torch.load is not recommended
            # to save / load modules.
            torch.save(m, f.name)
            # weights_only=False as this is legacy code that saves the model
            m = torch.load(f.name, weights_only=False)

        # Ensure we can run state_dict without issues
        _ = m.state_dict()
        self.assertFalse(called)
        m.register_state_dict_pre_hook(my_state_dict_pre_hook)
        _ = m.state_dict()
        self.assertTrue(called)

    @parametrize_test("private", [True, False])
    def test_register_state_dict_post_hook(self, private):
        m = nn.Transformer(
            d_model=4, nhead=2, num_encoder_layers=2, num_decoder_layers=2
        )

        def linear_state_dict_post_hook(module, state_dict, prefix, local_metadata):
            for name, param in module.named_parameters(recurse=False):
                state_dict[prefix + name] = torch.nn.Parameter(
                    state_dict[prefix + name]
                )

        def register_linear_hook(module):
            if isinstance(module, nn.Linear):
                hook_registration_fn = (
                    module._register_state_dict_hook
                    if private
                    else module.register_state_dict_post_hook
                )
                hook_registration_fn(linear_state_dict_post_hook)

        def _check_sd(state_dict):
            for k, v in m.state_dict().items():
                if "linear" in k or "out_proj" in k:
                    self.assertTrue(isinstance(v, torch.nn.Parameter))
                else:
                    self.assertFalse(isinstance(v, torch.nn.Parameter))

        # verify that return type of hook registered on child submodules has no effect
        # regardless of whether using public or private API
        m.apply(register_linear_hook)
        _check_sd(m.state_dict())

        # verify that return type of hook registered on root module has no effect
        # for public API but has effect for private API
        hook_registration_fn = (
            m._register_state_dict_hook if private else m.register_state_dict_post_hook
        )

        def fn(m, s, p, l):
            return OrderedDict()

        handle = hook_registration_fn(fn)
        if private:
            self.assertFalse(hasattr(fn, "_from_public_api"))
            self.assertTrue(len(m.state_dict()) == 0)
        else:
            self.assertTrue(hasattr(fn, "_from_public_api"))
            with self.assertRaisesRegex(
                RuntimeError, "state_dict post-hook must return None"
            ):
                sd = m.state_dict()
            with self.assertRaisesRegex(
                RuntimeError, "previously registered via register_state_dict_post_hook"
            ):
                m._register_state_dict_hook(fn)


class TestModuleGlobalHooks(TestCase):
    def tearDown(self):
        nn.modules.module._global_backward_hooks = OrderedDict()
        nn.modules.module._global_forward_hooks = OrderedDict()
        nn.modules.module._global_forward_pre_hooks = OrderedDict()

    @skipIfTorchDynamo("TorchDynamo does not work well with hooks")
    def test_module_global_hooks(self):
        module = nn.Sigmoid

        module_1 = module()
        module_2 = module()
        module_3 = module()

        input = torch.ones(5, 5, requires_grad=True)

        counter = {"forwards": 0, "backwards": 0}

        def fw_hook(inc, h_module, input, output):
            self.assertIsInstance(input, tuple)
            self.assertTrue(isinstance(output, torch.Tensor))
            self.assertTrue(isinstance(h_module, module))
            self.assertEqual(input[0], torch.ones(5, 5))
            self.assertEqual(output, torch.empty(5, 5).fill_(1 / (1 + 1 / math.e)))
            counter["forwards"] += inc

        def bw_hook(inc, h_module, grad_input, grad_output):
            self.assertIsInstance(grad_input, tuple)
            self.assertIsInstance(grad_output, tuple)
            self.assertTrue(isinstance(h_module, module))
            self.assertEqual(grad_output[0], torch.ones(5, 5) * 2)
            counter["backwards"] += inc

        test_fwd = nn.modules.module.register_module_forward_hook(
            lambda *args: fw_hook(1, *args)
        )

        module_1(input)
        module_2(input)
        module_3(input)
        self.assertEqual(counter["forwards"], 3)
        self.assertEqual(counter["backwards"], 0)

        test_bwd = nn.modules.module.register_module_backward_hook(
            lambda *args: bw_hook(1, *args)
        )

        output_1 = module_1(input)
        output_2 = module_2(input)
        output_3 = module_3(input)
        self.assertEqual(counter["forwards"], 6)
        self.assertEqual(counter["backwards"], 0)

        output_1.backward(torch.ones(5, 5) * 2, retain_graph=True)
        output_2.backward(torch.ones(5, 5) * 2, retain_graph=False)
        output_3.backward(torch.ones(5, 5) * 2, retain_graph=False)
        self.assertEqual(counter["forwards"], 6)
        self.assertEqual(counter["backwards"], 3)

        output_1.backward(torch.ones(5, 5) * 2, retain_graph=True)
        self.assertEqual(counter["forwards"], 6)
        self.assertEqual(counter["backwards"], 4)

        test2_fwd = nn.modules.module.register_module_forward_hook(
            lambda *args: fw_hook(2, *args)
        )

        output = module_1(input)
        output = module_2(input)
        output = module_3(input)
        self.assertEqual(counter["forwards"], 15)
        self.assertEqual(counter["backwards"], 4)

        test2_bwd = nn.modules.module.register_module_backward_hook(
            lambda *args: bw_hook(2, *args)
        )

        module_1(input).backward(torch.ones(5, 5) * 2)
        self.assertEqual(counter["forwards"], 18)
        self.assertEqual(counter["backwards"], 7)

        test2_bwd.remove()

        module_2(input).backward(torch.ones(5, 5) * 2)
        self.assertEqual(counter["forwards"], 21)
        self.assertEqual(counter["backwards"], 8)

        test2_fwd.remove()

        module_3(input).backward(torch.ones(5, 5) * 2)
        self.assertEqual(counter["forwards"], 22)
        self.assertEqual(counter["backwards"], 9)

        test_fwd.remove()
        test_bwd.remove()

    def test_module_global_hook_invalid_outputs(self):
        module = nn.Sigmoid()
        input = torch.randn(5, 5, requires_grad=True)

        def bw_fail1(self, grad_input, grad_output):
            return grad_input[:-1]

        def bw_fail2(self, grad_input, grad_output):
            return grad_input + (torch.randn(2, 2),)

        with nn.modules.module.register_module_backward_hook(bw_fail1):
            with self.assertRaisesRegex(RuntimeError, "got 0, but expected 1"):
                module(input).sum().backward()

        with nn.modules.module.register_module_backward_hook(bw_fail2):
            with self.assertRaisesRegex(RuntimeError, "got 2, but expected 1"):
                module(input).sum().backward()

    def test_module_backward_global_hook_writeable(self):
        module = nn.Sigmoid()
        input = torch.randn(5, 5, requires_grad=True)
        sig_x = torch.sigmoid(input)

        def bw_hook(module, grad_input, grad_output):
            for grad in grad_input:
                self.assertTrue(isinstance(grad, torch.Tensor))
            for grad in grad_output:
                self.assertTrue(isinstance(grad, torch.Tensor))
            return tuple(gi * 2 for gi in grad_input)

        nn.modules.module.register_module_backward_hook(bw_hook)
        module(input).backward(torch.ones(5, 5))
        expected_grad = sig_x * (1 - sig_x) * 2
        self.assertEqual(input.grad, expected_grad)

    @skipIfTorchDynamo("TorchDynamo does not work well with hooks")
    def test_module_global_forward_preforward_hook_writeable(self):
        module = nn.Sigmoid()
        input = torch.randn(5, 5, requires_grad=True)
        sig_x = torch.sigmoid(input)

        def forward_pre_hook(m, input):
            return torch.nn.functional.relu(input[0])

        def forward_hook(m, input, output):
            return -output

        nn.modules.module.register_module_forward_pre_hook(forward_pre_hook)
        nn.modules.module.register_module_forward_hook(forward_hook)
        output = module(input)
        expected_res = -torch.sigmoid(torch.nn.functional.relu(input))
        self.assertEqual(output, expected_res)
        output.backward(torch.ones(5, 5) * 2, retain_graph=True)
        mask = input > 0
        expected_grad = -sig_x * (1 - sig_x) * 2 * mask
        self.assertEqual(input.grad, expected_grad)

    def test_module_forward_preforward_hook_removable(self):
        """
        Test that, when multiple forward pre-hooks are registered and run
        correctly, a hook handle can be removed from within the forward
        pre-hook call itself.
        """
        module = nn.Sigmoid()

        def removable_hook(m, input):
            nonlocal handle
            handle.remove()
            return input

        def removable_hook_2(m, input):
            nonlocal handle_2
            handle_2.remove()
            return input

        handle = module.register_forward_pre_hook(removable_hook)
        handle_2 = module.register_forward_pre_hook(removable_hook_2)

        # make sure hook registration is successful
        self.assertEqual(len(handle.hooks_dict_ref()), 2)
        self.assertEqual(len(handle_2.hooks_dict_ref()), 2)

        input = torch.randn(2, 2)
        output = module(input)
        self.assertEqual(torch.sigmoid(input), output)

        # make sure hook removal is successful
        self.assertFalse(handle.id in handle.hooks_dict_ref())
        self.assertFalse(handle_2.id in handle.hooks_dict_ref())
        self.assertEqual(len(handle.hooks_dict_ref()), 0)
        self.assertEqual(len(handle_2.hooks_dict_ref()), 0)

    def test_module_forward_forward_hook_removable(self):
        """
        Test that, when multiple forward hooks are registered and run
        correctly, a hook handle can be removed from within the forward
        hook call itself.
        """
        module = nn.Sigmoid()

        def removable_hook(m, input, output):
            nonlocal handle
            handle.remove()
            return output

        def removable_hook_2(m, input, output):
            nonlocal handle_2
            handle_2.remove()
            return output

        handle = module.register_forward_hook(removable_hook)
        handle_2 = module.register_forward_hook(removable_hook_2)

        # make sure hook registration is successful
        self.assertEqual(len(handle.hooks_dict_ref()), 2)
        self.assertEqual(len(handle_2.hooks_dict_ref()), 2)

        input = torch.randn(2, 2)
        output = module(input)
        self.assertEqual(torch.sigmoid(input), output)

        # make sure hook removal is successful
        self.assertFalse(handle.id in handle.hooks_dict_ref())
        self.assertFalse(handle_2.id in handle.hooks_dict_ref())
        self.assertEqual(len(handle.hooks_dict_ref()), 0)
        self.assertEqual(len(handle_2.hooks_dict_ref()), 0)

    @skipIfTorchDynamo("TorchDynamo does not work well with hooks")
    def test_global_and_local_hooks_order(self):
        module = nn.Sigmoid()

        global_forward_pre_called = False
        local_forward_pre_called = False
        global_forward_called = False
        local_forward_called = False
        global_backward_called = False
        local_backward_called = False

        def global_forward_pre_hook(m, input):
            nonlocal global_forward_pre_called
            self.assertTrue(not local_forward_pre_called)
            global_forward_pre_called = True
            return input

        def local_forward_pre_hook(m, input):
            nonlocal local_forward_pre_called
            self.assertTrue(global_forward_pre_called)
            local_forward_pre_called = True
            return input

        def global_forward_hook(m, input, output):
            nonlocal global_forward_called
            self.assertTrue(not local_forward_called)
            global_forward_called = True
            return output

        def local_forward_hook(m, input, output):
            nonlocal local_forward_called
            self.assertTrue(global_forward_called)
            local_forward_called = True
            return output

        def global_backward_hook(m, input, output):
            nonlocal global_backward_called
            self.assertTrue(not local_backward_called)
            global_backward_called = True
            return input

        def local_backward_hook(m, input, output):
            nonlocal local_backward_called
            self.assertTrue(global_backward_called)
            local_backward_called = True
            return input

        input = torch.randn(5, 5, requires_grad=True)
        nn.modules.module.register_module_forward_pre_hook(global_forward_pre_hook)
        module.register_forward_pre_hook(local_forward_pre_hook)
        nn.modules.module.register_module_forward_hook(global_forward_hook)
        module.register_forward_hook(local_forward_hook)
        nn.modules.module.register_module_backward_hook(global_backward_hook)
        module.register_backward_hook(local_backward_hook)

        output = module(input)
        self.assertTrue(
            local_forward_called
            and local_forward_pre_called
            and global_forward_called
            and global_forward_pre_called
        )

        output.backward(torch.ones(5, 5), retain_graph=True)
        self.assertTrue(local_backward_called and global_backward_called)


class TestModuleHookNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def _test_hooks(self, backward_register_fn):
        module = nn.Sigmoid()
        input = torch.ones(5, 5, requires_grad=True)

        counter = {"forwards": 0, "backwards": 0}

        def fw_hook(inc, h_module, input, output):
            self.assertIsInstance(input, tuple)
            self.assertTrue(isinstance(output, torch.Tensor))
            self.assertTrue(h_module is module)
            self.assertEqual(input[0], torch.ones(5, 5))
            self.assertEqual(output, torch.empty(5, 5).fill_(1 / (1 + 1 / math.e)))
            counter["forwards"] += inc

        def bw_hook(inc, h_module, grad_input, grad_output):
            self.assertIsInstance(grad_input, tuple)
            self.assertIsInstance(grad_output, tuple)
            self.assertTrue(h_module is module)
            self.assertEqual(grad_output[0], torch.ones(5, 5) * 2)
            counter["backwards"] += inc

        # backward_pre_hook expects callback with only `module` and `grad_output`
        # as arguments.
        def bw_pre_hook(inc, h_module, grad_output):
            self.assertIsInstance(grad_output, tuple)
            self.assertTrue(h_module is module)
            self.assertEqual(grad_output[0], torch.ones(5, 5) * 2)
            counter["backwards"] += inc

        test_fwd = module.register_forward_hook(lambda *args: fw_hook(1, *args))

        module(input)
        module(input)
        self.assertEqual(counter["forwards"], 2)
        self.assertEqual(counter["backwards"], 0)

        bw_hook_fn = (
            bw_pre_hook
            if backward_register_fn == "register_full_backward_pre_hook"
            else bw_hook
        )
        test_bwd = getattr(module, backward_register_fn)(
            lambda *args: bw_hook_fn(1, *args)
        )

        output = module(input)
        self.assertEqual(counter["forwards"], 3)
        self.assertEqual(counter["backwards"], 0)

        output.backward(torch.ones(5, 5) * 2, retain_graph=True)
        self.assertEqual(counter["forwards"], 3)
        self.assertEqual(counter["backwards"], 1)

        output.backward(torch.ones(5, 5) * 2, retain_graph=True)
        self.assertEqual(counter["forwards"], 3)
        self.assertEqual(counter["backwards"], 2)

        test2_fwd = module.register_forward_hook(lambda *args: fw_hook(2, *args))

        output = module(input)
        self.assertEqual(counter["forwards"], 6)
        self.assertEqual(counter["backwards"], 2)

        test2_bwd = getattr(module, backward_register_fn)(
            lambda *args: bw_hook_fn(2, *args)
        )

        module(input).backward(torch.ones(5, 5) * 2)
        self.assertEqual(counter["forwards"], 9)
        self.assertEqual(counter["backwards"], 5)

        test2_bwd.remove()

        module(input).backward(torch.ones(5, 5) * 2)
        self.assertEqual(counter["forwards"], 12)
        self.assertEqual(counter["backwards"], 6)

        test2_fwd.remove()

        module(input).backward(torch.ones(5, 5) * 2)
        self.assertEqual(counter["forwards"], 13)
        self.assertEqual(counter["backwards"], 7)

        test_fwd.remove()
        test_bwd.remove()

    def test_hooks(self):
        self._test_hooks("register_backward_hook")
        self._test_hooks("register_full_backward_hook")
        self._test_hooks("register_full_backward_pre_hook")

    def test_hook_cpp(self):
        bn = nn.BatchNorm1d(5)

        def hook(module, grad_inputs, grad_outputs):
            self.assertEqual(len(grad_inputs), 1)
            self.assertEqual(len(grad_outputs), 1)
            self.assertEqual(module, bn)

        bn.register_full_backward_hook(hook)
        output = bn(torch.randn(5, 5, requires_grad=True))
        output.sum().backward()

    def test_backward_hooks_interaction(self):
        # Test to make sure that the grad_outputs
        # updated by full_backward_pre_hook are received by
        # the full_backward_hook
        module = torch.nn.Sigmoid()

        cnt = {"backward_cnt": 0}

        def bw_pre_hook(m, grad_output):
            cnt["backward_cnt"] += 1
            return (grad_output[0] * 0.5,)

        def bw_hook(m, grad_in, grad_output):
            self.assertEqual(torch.full_like(grad_output[0], 0.5), grad_output[0])
            cnt["backward_cnt"] += 1
            return grad_output

        module.register_full_backward_pre_hook(bw_pre_hook)
        module.register_full_backward_hook(bw_hook)

        t = torch.ones(1, 2, requires_grad=True)
        module(t).sum().backward()
        self.assertEqual(cnt["backward_cnt"], 2)

    def test_hook_invalid_outputs(self):
        module = nn.Sigmoid()
        input = torch.randn(5, 5, requires_grad=True)

        def bw_fail1(self, grad_input, grad_output):
            return grad_input[:-1]

        def bw_fail2(self, grad_input, grad_output):
            return grad_input + (torch.randn(2, 2),)

        with module.register_backward_hook(bw_fail1):
            with self.assertRaisesRegex(RuntimeError, "got 0, but expected 1"):
                module(input).sum().backward()

        with module.register_backward_hook(bw_fail2):
            with self.assertRaisesRegex(RuntimeError, "got 2, but expected 1"):
                module(input).sum().backward()

        def bw_pre_fail1(self, grad_output):
            return ()

        def bw_pre_fail2(self, grad_output):
            return grad_output + (torch.randn(2, 2),)

        with module.register_full_backward_pre_hook(bw_pre_fail1):
            with self.assertRaisesRegex(RuntimeError, "got 0, but expected 1"):
                module(input).sum().backward()

        with module.register_full_backward_pre_hook(bw_pre_fail2):
            with self.assertRaisesRegex(RuntimeError, "got 2, but expected 1"):
                module(input).sum().backward()

    def test_hook_requires_grad(self):
        test_self = self

        class MyModule(nn.Module):
            def forward(self, arg1, arg2, arg3):
                test_self.assertTrue(arg1.requires_grad)
                test_self.assertFalse(arg2.requires_grad)
                test_self.assertTrue(arg3.requires_grad)
                return arg1.sum() + arg2.sum() + arg3.sum()

        inp = torch.rand(2, requires_grad=True)
        mod = MyModule()

        mod(inp, inp.detach(), inp)
        # Ensure that requires grad is properly propagated
        mod.register_full_backward_hook(lambda mod, gI, gO: None)
        mod(inp, inp.detach(), inp)

    def test_hook_no_requires_grad(self):
        mod = nn.Linear(2, 3)

        inp = torch.rand(1, 2)

        return_val = "None"
        hook_called = [0]

        def hook(mod, grad_input, grad_output):
            hook_called[0] += 1
            for gI in grad_input:
                self.assertIsNone(gI)
            for gO in grad_output:
                self.assertEqual(gO.size(), (1, 3))

            if return_val == "grad_input":
                return grad_input
            elif return_val == "invalid":
                # If the inputs were requiring gradients, this would be
                # a valid return
                return inp
            elif return_val == "None":
                return None
            else:
                raise RuntimeError("Invalid return_val string")

        mod.register_full_backward_hook(hook)

        # This should run and trigger the hook properly
        mod(inp).sum().backward()
        self.assertEqual(hook_called[0], 1)

        return_val = "grad_input"

        mod(inp).sum().backward()
        self.assertEqual(hook_called[0], 2)

        return_val = "invalid"
        with self.assertRaisesRegex(RuntimeError, "where no input requires gradient"):
            mod(inp).sum().backward()

    def test_hook_last_arg_requires_grad(self):
        mod = nn.L1Loss()
        inp = torch.rand(1, requires_grad=True)
        mod.register_full_backward_hook(lambda m, gI, gO: None)

        try:
            mod(inp.detach(), inp)
        except Exception as ex:
            self.fail(f"Unexpected exception: {ex}")

    def test_hook_extra_input(self):
        class MyModule(nn.Module):
            def forward(self, non_tensor, tensor):
                return tensor.clone(), non_tensor

        inp = torch.rand(2, requires_grad=True)
        mod = MyModule()

        def hook(mod, grad_input, grad_output):
            self.assertIsNone(grad_input[0])
            self.assertIsInstance(grad_input[1], torch.Tensor)

            self.assertIsInstance(grad_output[0], torch.Tensor)
            self.assertIsNone(grad_output[1])

        mod.register_full_backward_hook(hook)
        out, _ = mod(True, inp)
        out.sum().backward()

    def test_hook_inplace(self):
        class MyModule(nn.Module):
            def forward(self, inp, do_inplace):
                self.inp = inp
                if do_inplace:
                    inp += 1
                return inp.clone()

        hook_called = [0]

        def hook(mod, grad_input, grad_output):
            hook_called[0] += 1

        def hook_pre(mod, grad_output):
            hook_called[0] += 1

        inp = torch.rand(10, requires_grad=True)
        mod = MyModule()
        for hook_fn, register_fn in [
            (hook, mod.register_full_backward_hook),
            (hook_pre, mod.register_full_backward_pre_hook),
        ]:
            hook_called[0] = 0
            with register_fn(hook_fn):
                # No inplace should work
                mod(inp, False).sum().backward()
                self.assertEqual(hook_called[0], 1)

                # Input inplace error should throw an error
                with self.assertRaisesRegex(
                    RuntimeError,
                    "Output 0 of BackwardHookFunctionBackward is "
                    "a view and is being modified inplace.",
                ):
                    mod(inp.clone(), True)

                # Input inplace error should throw an error if we try to re-use the view after they have
                # been modified
                local_inp = inp.clone()
                out = mod(local_inp, False)
                local_inp[0] *= 1
                with self.assertRaisesRegex(
                    RuntimeError,
                    "Output 0 of BackwardHookFunctionBackward is "
                    "a view and its base or another view",
                ):
                    # Any operation involving the view will fail here
                    mod.inp + 2

                # Output inplace error should throw an error
                out = mod(inp, False)
                with self.assertRaisesRegex(
                    RuntimeError,
                    "BackwardHookFunctionBackward is a view "
                    "and is being modified inplace.",
                ):
                    out += 1

    def test_hook_non_full_warning(self):
        def noop(*args):
            pass

        a = torch.rand(2, requires_grad=True)
        b = torch.rand(2, requires_grad=True)

        # Check invalid input container
        class MyModule(nn.Module):
            def forward(self, l):
                return l[0].clone(), l[1].clone()

        m = MyModule()
        m.register_backward_hook(noop)

        with self.assertWarnsRegex(
            FutureWarning,
            "does not take as input a single Tensor or a tuple of Tensors",
        ):
            m([a, b])

        # Check invalid output container
        class MyModule(nn.Module):
            def forward(self, a, b):
                return [a.clone(), b.clone()]

        m = MyModule()
        m.register_backward_hook(noop)

        with self.assertWarnsRegex(
            FutureWarning, "does not return a single Tensor or a tuple of Tensors"
        ):
            m(a, b)

        # Check invalid output from different Nodes
        class MyModule(nn.Module):
            def forward(self, a, b):
                return a.clone(), b.clone()

        m = MyModule()
        m.register_backward_hook(noop)

        with self.assertWarnsRegex(
            FutureWarning, "outputs are generated by different autograd Nodes"
        ):
            m(a, b)

        # Check invalid forward with multiple Nodes
        class MyModule(nn.Module):
            def forward(self, a):
                return a.clone().clone()

        m = MyModule()
        m.register_backward_hook(noop)

        with self.assertWarnsRegex(
            FutureWarning, "the forward contains multiple autograd Nodes"
        ):
            m(a)

    def test_hook_backward_size(self):
        # Make module with multiple operations in forward
        # And different size for input and outputs
        class MyModule(nn.Module):
            def forward(self, arg1, arg2):
                tmp = arg1.sum() * arg2
                tmp = tmp + arg2.sum() * arg1.sum()
                tmp = tmp.sum().view(1)
                tmp = tmp.expand(8).contiguous()
                return tmp

        module = MyModule()
        inp1 = torch.randn(5, 5, requires_grad=True)
        inp2 = torch.randn(10, 10, requires_grad=True)

        def bw_hook(module, grad_input, grad_output):
            self.assertEqual(len(grad_input), 2)
            self.assertEqual(grad_input[0].size(), torch.Size([5, 5]))
            self.assertEqual(grad_input[1].size(), torch.Size([10, 10]))
            self.assertEqual(len(grad_output), 1)
            self.assertEqual(grad_output[0].size(), torch.Size([8]))

        with module.register_full_backward_hook(bw_hook):
            module(inp1, inp2).sum().backward()

    def test_hook_backward_writeable(self):
        module = nn.Sigmoid()
        input = torch.randn(5, 5, requires_grad=True)
        sig_x = torch.nn.functional.sigmoid(input)

        def bw_hook(module, grad_input, grad_output):
            for grad in grad_input:
                self.assertTrue(isinstance(grad, torch.Tensor))
            for grad in grad_output:
                self.assertTrue(isinstance(grad, torch.Tensor))
            return tuple(gi * 2 for gi in grad_input)

        module.register_backward_hook(bw_hook)
        module(input).backward(torch.ones(5, 5))
        expected_grad = sig_x * (1 - sig_x) * 2
        self.assertEqual(input.grad, expected_grad)

    def test_hook_forward_preforward_writable(self):
        module = nn.Sigmoid()
        input = torch.randn(5, 5, requires_grad=True)
        sig_x = torch.nn.functional.sigmoid(input)

        def forward_pre_hook(m, input):
            return torch.nn.functional.relu(input[0])

        def forward_hook(m, input, output):
            return -output

        module.register_forward_pre_hook(forward_pre_hook)
        module.register_forward_hook(forward_hook)
        output = module(input)
        expected_res = -torch.nn.functional.sigmoid(torch.nn.functional.relu(input))
        self.assertEqual(output, expected_res)
        output.backward(torch.ones(5, 5) * 2, retain_graph=True)
        mask = input > 0
        expected_grad = -sig_x * (1 - sig_x) * 2 * mask
        self.assertEqual(input.grad, expected_grad)

    def test_hook_buffer_registration(self):
        for return_buffer in (True, False):

            def buffer_registration_hook(module, name, buffer):
                buffer.registered = True
                if return_buffer:
                    return buffer

            handle = torch.nn.modules.module.register_module_buffer_registration_hook(
                buffer_registration_hook
            )
            try:
                l, n, s = _create_basic_net()
                for b in s.buffers():
                    self.assertTrue(getattr(b, "registered", False))
            finally:
                handle.remove()

    def test_hook_submodule_registration(self):
        for return_submodule in (True, False):

            def module_registration_hook(module, name, submodule):
                module.registered = True
                submodule.registered = True
                if return_submodule:
                    return submodule

            handle = torch.nn.modules.module.register_module_module_registration_hook(
                module_registration_hook
            )
            try:
                l, n, s = _create_basic_net()
                for m in s.modules():
                    self.assertTrue(getattr(m, "registered", False))
            finally:
                handle.remove()

    def test_hook_parameter_registration(self):
        for return_parameter in (True, False):

            def parameter_registration_hook(module, name, parameter):
                parameter.registered = True
                if return_parameter:
                    return parameter

            handle = (
                torch.nn.modules.module.register_module_parameter_registration_hook(
                    parameter_registration_hook
                )
            )
            try:
                l, n, s = _create_basic_net()
                for p in s.parameters():
                    self.assertTrue(getattr(p, "registered", False))
            finally:
                handle.remove()


instantiate_parametrized_tests(TestModuleHooks)
instantiate_parametrized_tests(TestStateDictHooks)

if __name__ == "__main__":
    run_tests()