# Owner(s): ["module: nn"]
import re
import unittest
from copy import deepcopy
from itertools import product

import torch
import torch.nn as nn
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfCrossRef,
    skipIfTorchDynamo,
    swap,
    TEST_NUMPY,
    TestCase,
)
from torch.utils._pytree import tree_map


if TEST_NUMPY:
    import numpy as np


class TestLoadStateDict(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    @swap([True, False])
    def test_load_state_dict_invalid(self):
        m = torch.nn.Linear(2, 2, bias=False)

        state_dict = {"weight": np.random.randn(2, 2)}
        with self.assertRaisesRegex(
            RuntimeError,
            "expected torch.Tensor or Tensor-like object from checkpoint but received",
        ):
            m.load_state_dict(state_dict)

        state_dict = {"weight": ((1.0, 1.0), (2.0, 2.0))}
        with self.assertRaisesRegex(
            RuntimeError,
            "expected torch.Tensor or Tensor-like object from checkpoint but received",
        ):
            m.load_state_dict(state_dict)

    @swap([True, False])
    def test_load_state_dict_type(self):
        m = nn.Module()

        with self.assertRaisesRegex(
            TypeError, "Expected state_dict to be dict-like, got"
        ):
            m.load_state_dict("")
        with self.assertRaisesRegex(
            TypeError, "Expected state_dict to be dict-like, got"
        ):
            m.load_state_dict(2)

    @swap([True, False])
    @skipIfTorchDynamo("dynamo installs weakrefs on some params")
    def test_load_state_dict(self):
        l = nn.Linear(5, 5)
        block = nn.Module()
        block.conv1 = nn.Conv2d(3, 3, 3, bias=True)
        block.conv2 = nn.Conv2d(3, 3, 3, bias=False)
        net = nn.Module()
        net.linear1 = l
        net.linear2 = l
        net.bn = nn.BatchNorm2d(2)
        net.block = block
        net.add_module("empty", None)
        conv1_bias_dtype = block.conv1.bias.dtype

        state_dict = net.state_dict()
        state_dict.update(
            {
                "linear1.weight": torch.ones(5, 5),
                "block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
                "bn.running_mean": torch.randn(2),
            }
        )
        # Also test if a DDP state_dict can be loaded from a local model.
        ddp_state_dict = net.state_dict()
        ddp_state_dict.update(
            {
                "module.linear1.weight": torch.ones(5, 5),
                "module.block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
                "module.bn.running_mean": torch.randn(2),
            }
        )
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            ddp_state_dict, "module."
        )
        for sd in [state_dict, ddp_state_dict]:
            incompatible_keys = net.load_state_dict(sd)
            self.assertEqual(len(incompatible_keys.missing_keys), 0)
            self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
            self.assertNotIn("Incompatible", str(incompatible_keys))
            self.assertEqual(net.linear1.weight, sd["linear1.weight"])
            self.assertEqual(net.block.conv1.bias, sd["block.conv1.bias"])
            self.assertEqual(net.bn.running_mean, sd["bn.running_mean"])

        state_dict = net.state_dict()
        state_dict.update({"extra": torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 0)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn("extra", incompatible_keys.unexpected_keys)
        self.assertIn("Incompatible", str(incompatible_keys))

        state_dict = net.state_dict()
        state_dict.update({"extra.param": torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 0)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn("extra.param", incompatible_keys.unexpected_keys)

        state_dict = net.state_dict()
        del state_dict["linear1.weight"]
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 1)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
        self.assertIn("linear1.weight", incompatible_keys.missing_keys)
        state_dict.update({"extra.param": torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 1)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn("linear1.weight", incompatible_keys.missing_keys)
        self.assertIn("extra.param", incompatible_keys.unexpected_keys)

        state_dict = net.state_dict()
        state_dict.update({"bn.running_mean": torch.rand(14, 4)})  # wrong size
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        self.assertRaises(
            RuntimeError, lambda: net.load_state_dict(state_dict, strict=False)
        )

        state_dict = net.state_dict()
        old_state_dict = deepcopy(state_dict)
        state_dict = {
            "linear1.weight": torch.ones(5, 5),
            "block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
            "bn.running_mean": torch.randn(2),
            "nonexistent_key": torch.rand(3),
        }
        net.load_state_dict(state_dict, strict=False)
        self.assertEqual(net.linear1.weight, state_dict["linear1.weight"])
        self.assertEqual(net.block.conv1.bias, state_dict["block.conv1.bias"])
        self.assertEqual(net.bn.running_mean, state_dict["bn.running_mean"])
        new_state_dict = net.state_dict()
        del old_state_dict["linear1.weight"]
        del old_state_dict["block.conv1.bias"]
        del old_state_dict["bn.running_mean"]
        for k, v in old_state_dict.items():
            self.assertTrue(v.equal(new_state_dict[k]))

    @swap([True, False])
    def test_load_state_dict_BC(self):
        # BatchNormNd
        # Added num_batches_tracked buffer at version 2. For state dict with
        # earlier versions or no versions, it should provide default value of 0.
        bn = nn.BatchNorm2d(3)
        state_dict = bn.state_dict()
        del state_dict["num_batches_tracked"]
        state_dict._metadata[""]["version"] = 1  # version 1
        bn.load_state_dict(state_dict)
        self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
        self.assertEqual(bn.num_batches_tracked.item(), 0)
        del state_dict._metadata[""]["version"]  # no version
        bn.load_state_dict(state_dict)
        self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
        self.assertEqual(bn.num_batches_tracked.item(), 0)

    @swap([True, False])
    def test_load_state_dict_child(self):
        base_module = nn.Linear(1, 1)
        model = base_module
        for _ in range(3):
            model = nn.Sequential(*[deepcopy(model) for _ in range(10)])

        def hook_fn(
            module,
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        ):
            module_state_dict = module.state_dict()
            self.assertEqual(len(module_state_dict.keys()), len(state_dict.keys()))

        model[0][0].register_load_state_dict_pre_hook(hook_fn)
        model.load_state_dict(model.state_dict(), strict=True)

    # fails swapping as LSTM installs weak references on the parameters
    @swap([False])
    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
    def test_load_state_dict_ref_cycle(self):
        # load_state_dict shouldn't cause a reference cycle involving Tensors
        import gc

        m = torch.nn.LSTM(16, 16, bidirectional=True)

        gc.collect()
        m.load_state_dict(deepcopy(m).state_dict())
        refcycles = gc.collect()

        self.assertEqual(refcycles, 0)

    @swap([True, False])
    def test_load_state_dict_custom(self):
        class CustomState(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.ones(1))
                self.sub = torch.nn.Linear(5, 5)

            def _save_to_state_dict(self, destination, prefix, keep_vars):
                destination[prefix + "serialized"] = self.param.data + 1

            def _load_from_state_dict(
                self,
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            ):
                # skip some of the error handling
                self.param.data.copy_(state_dict[prefix + "serialized"] - 1)

        # use sequential to verify nesting
        m = nn.Sequential(CustomState())
        with torch.no_grad():
            m[0].param[0] = 10
            m[0].sub.weight[0, 0] = 555
        state_dict = m.state_dict()
        self.assertEqual(state_dict["0.serialized"].item(), 11)
        self.assertIn("0.sub.weight", state_dict)
        self.assertNotIn("0.param", state_dict)
        del m
        mm = nn.Sequential(CustomState())
        self.assertEqual(mm[0].param[0].item(), 1)
        mm.load_state_dict(state_dict)
        self.assertEqual(mm[0].param[0].item(), 10)
        self.assertEqual(mm[0].sub.weight[0, 0].item(), 555)

    @swap([True, False])
    @parametrize("keep_vars", [True, False])
    def test_load_state_dict_assign_meta(self, keep_vars):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)
                self.x = nn.Parameter(torch.rand(5), requires_grad=False)

            def forward(self, input):
                return self.x + self.bn(self.fc1(input))

        swap = torch.__future__.get_swap_module_params_on_conversion()
        net = MyModule()
        state_dict = net.state_dict(keep_vars=keep_vars)
        for v in state_dict.values():
            v.requires_grad_(False)

        with torch.device("meta"):
            net_meta = MyModule()

        net_meta_state_dict_old = net_meta.state_dict(keep_vars=True)
        net_meta.load_state_dict(state_dict, assign=True)

        # Make sure parameters and persistent buffers were assigned
        net_meta_state_dict = net_meta.state_dict(keep_vars=True)
        for key in state_dict.keys():
            if key in net_meta._parameters:
                if keep_vars and not swap:
                    # state_dict[key] is an nn.Parameter
                    self.assertTrue(state_dict[key] is net_meta_state_dict[key])
                else:
                    if swap:
                        self.assertTrue(
                            net_meta_state_dict[key] is net_meta_state_dict_old[key]
                        )
                    else:
                        # state_dict[key] is not an nn.Parameter so it will be detached when wrapping with a Parameter
                        self.assertTrue(
                            net_meta_state_dict[key] is not net_meta_state_dict_old[key]
                        )
                        self.assertEqual(
                            net_meta_state_dict_old[key].requires_grad,
                            net_meta_state_dict[key].requires_grad,
                        )
                    self.assertEqual(
                        net_meta_state_dict_old[key].requires_grad,
                        net_meta_state_dict[key].requires_grad,
                    )
                self.assertEqual(state_dict[key], net_meta_state_dict[key])
            elif (
                key in net_meta._buffers
                and key not in net_meta._non_persistent_buffers_set
            ):
                self.assertTrue(state_dict[key] is net_meta_state_dict[key])
                self.assertEqual(state_dict[key], net_meta_state_dict[key])

        # Make sure that ordering of parameters and buffers is preserved
        net_named_parameters = net.named_parameters()
        net_named_buffers = net.named_buffers()
        net_meta_named_parameters = net_meta.named_parameters()
        net_meta_named_buffers = net_meta.named_buffers()

        for (n1, _), (n2, _) in zip(net_named_parameters, net_meta_named_parameters):
            self.assertEqual(n1, n2)

        for (n1, _), (n2, _) in zip(net_named_buffers, net_meta_named_buffers):
            self.assertEqual(n1, n2)

        # Make sure outputs are the same
        t = torch.randn(4, 3)
        out_net = net(t)
        out_net_meta = net_meta(t.clone())

        self.assertEqual(out_net, out_net_meta)

    @swap([True, False])
    def test_load_state_dict_assign_with_optimizer(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        opt = torch.optim.Adam(net.parameters(), lr=1000)
        x = torch.randn(4, 3)
        num_iters = 3

        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

        opt_state_dict = deepcopy(opt.state_dict())
        net_state_dict = deepcopy(net.state_dict())

        with torch.device("meta"):
            net_meta = MyModule()

        net_meta.load_state_dict(net_state_dict, assign=True)
        # must create optimizer only after loading state_dict when assign=True
        opt2 = torch.optim.Adam(net_meta.parameters(), lr=1000)
        opt2.load_state_dict(opt_state_dict)

        y = x.clone()
        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

            opt2.zero_grad()
            out2 = net_meta(y)
            out2.sum().backward()
            opt2.step()

        self.assertEqual(opt.state_dict(), opt2.state_dict())
        self.assertEqual(net.state_dict(), net_meta.state_dict())

    @swap([True, False])
    def test_load_state_dict_assign_shape_stride(self):
        # Assigned tensor is allowed to have different properties than initial
        # tensor except for shape
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        state_dict = net.state_dict()
        # loading should be ok if stride is different
        state_dict["fc1.weight"] = torch.randn(3, 5).transpose(0, 1)
        net2 = MyModule()
        net2.load_state_dict(state_dict, strict=False, assign=True)

        state_dict["fc1.weight"] = torch.randn(2, 4)
        with self.assertRaisesRegex(
            RuntimeError, "size mismatch for fc1.weight: copying a param with shape"
        ):
            net2.load_state_dict(state_dict, strict=False, assign=True)

    @swap([True, False])
    def test_load_state_dict_warn_assign(self):
        with torch.device("meta"):
            m = torch.nn.Linear(3, 5)
        state_dict = m.state_dict()
        state_dict["weight"] = torch.empty_like(state_dict["weight"], device="cpu")
        with self.assertWarnsRegex(
            UserWarning,
            "for weight: copying from a non-meta parameter in the checkpoint to a meta",
        ):
            m.load_state_dict(state_dict)

    @swap([True, False])
    def test_load_state_dict_with_unexpected_key(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = torch.nn.Linear(5, 10)

        m = MyModule()

        # Unexpected key & strict = True
        with self.assertRaisesRegex(RuntimeError, "Unexpected key"):
            state_dict = m.state_dict()
            state_dict["fc1.bad_suffix"] = torch.randn(5, 10)
            m.load_state_dict(state_dict)

        # Unexpected key & strict = False
        state_dict = m.load_state_dict(state_dict, strict=False)
        self.assertIn("fc1.bad_suffix", state_dict.unexpected_keys)

        # Unexpected key whose prefix matches a valid key & strict = True
        with self.assertRaisesRegex(RuntimeError, "Unexpected key"):
            state_dict = m.state_dict()
            state_dict["fc1.weight.bad_suffix"] = torch.randn(5, 10)
            m.load_state_dict(state_dict)

        # Unexpected key whose prefix matches a valid key & strict = False
        state_dict = m.load_state_dict(state_dict, strict=False)
        self.assertIn("fc1.weight.bad_suffix", state_dict.unexpected_keys)


# Shared __torch_function__ handler for the MyLoadTensor* subclasses below.
# It overrides Tensor.module_load so that loading a state_dict returns a tensor
# of the expected subclass type, and makes detach() preserve the subclass.
def load_torch_function_handler(cls, func, types, args=(), kwargs=None):
    kwargs = {} if kwargs is None else kwargs

    def module_load(dest, src, assign=False):
        if isinstance(dest, cls):
            if assign:
                return src.detach()
            else:
                if type(src) is torch.Tensor:
                    return cls(src)
                elif type(src) is cls:
                    return src.detach()
                else:
                    if isinstance(src, MyWrapperLoadTensor):
                        return cls(src._data)
                    return cls(src)
        else:
            assert isinstance(
                src, cls
            ), f"Expected isinstance(src, {cls}) but got {type(src)}"
            assert (
                type(dest) == torch.Tensor
                or type(dest) == torch.nn.Parameter
                or issubclass(cls, type(dest))
            )
            if assign:
                return src.detach()
            else:
                if isinstance(src, MyWrapperLoadTensor):
                    if type(dest) not in {torch.Tensor, torch.nn.Parameter}:
                        return type(dest)(src._data)
                    else:
                        return src._data.detach()
                else:
                    return torch.Tensor(src)

    if func is torch.Tensor.module_load:
        return module_load(*args, **kwargs)
    else:
        with torch._C.DisableTorchFunctionSubclass():
            # detach must return instance of same subclass for nn.Parameter()
            if func == torch.Tensor.detach:
                ret = func(*args, **kwargs)
                if not isinstance(ret, cls):
                    return cls(ret)
                return ret
            return func(*args, **kwargs)


class MyLoadTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return load_torch_function_handler(cls, func, types, args, kwargs)


# We use MyLoadTensor2 to test the (tensor subclass, wrapper tensor subclass)
# case where neither inherits from the other.
class MyLoadTensor2(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return load_torch_function_handler(cls, func, types, args, kwargs)


class MyBrokenLoadTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        if func is torch.Tensor.module_load:
            # wrong as this doesn't detach!
            return args[1]
        else:
            with torch._C.DisableTorchFunctionSubclass():
                # detach must return instance of same subclass for nn.Parameter()
                if func == torch.Tensor.detach:
                    return cls(func(*args, **kwargs))
                return func(*args, **kwargs)


class MyWrapperLoadTensor(MyLoadTensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        t = torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            dtype=data.dtype,
            layout=data.layout,
            device=data.device,
            requires_grad=data.requires_grad,
            strides=data.stride(),
            storage_offset=data.storage_offset(),
        )
        return t

    def __init__(self, data: torch.Tensor):
        self._data = data

    def __repr__(self):
        return f"MyWrapperLoadTensor({self._data.__repr__()})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t._data if isinstance(t, MyWrapperLoadTensor) else t

        def wrap(t):
            return MyWrapperLoadTensor(t) if isinstance(t, torch.Tensor) else t

        kwargs = {} if kwargs is None else kwargs
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(wrap, out)


class TestLoadStateDictSwap(TestCase):
    @skipIfCrossRef
    @skipIfTorchDynamo("Can't swap with dynamo as dynamo installs weakrefs")
    @swap([True])
    @parametrize("assign", [True, False])
    def test_swap_subclass(self, assign):
        def _create_model(subclass=None):
            m = torch.nn.Linear(2, 3, bias=False)
            m.buf = torch.nn.Buffer(torch.randn(2, 3))
            if subclass is not None:
                m.weight = torch.nn.Parameter(subclass(m.weight))
                m.buf = subclass(m.buf)
            return m

        def _test(m_subclass=None, sd_subclass=None):
            m = _create_model(m_subclass)
            sd = _create_model(sd_subclass).state_dict()
            m.load_state_dict(sd, assign=assign)
            self.assertEqual(m.weight, sd["weight"])
            self.assertEqual(m.buf, sd["buf"])
            self.assertTrue(isinstance(m.weight, torch.nn.Parameter))
            self.assertTrue(not isinstance(m.buf, torch.nn.Parameter))

            weight_type, buf_type = (torch.nn.Parameter, torch.Tensor)
            if assign:
                if sd_subclass is not None:
                    weight_type, buf_type = (sd_subclass, sd_subclass)
            else:
                if m_subclass is not None:
                    weight_type, buf_type = (m_subclass, m_subclass)

            self.assertTrue(type(m.weight) is weight_type)
            self.assertTrue(type(m.buf) is buf_type)

        # (MyLoadTensor, MyWrapperLoadTensor) tests the behavior of (superclass, subclass)
        subclasses = [None, MyLoadTensor, MyLoadTensor2, MyWrapperLoadTensor]
        for m_s, sd_s in product(subclasses, subclasses):
            _test(m_s, sd_s)

        # MyBrokenLoadTensor should error since its module_load doesn't call .detach()
        with self.assertRaisesRegex(
            RuntimeError, re.escape("Error(s) in loading state_dict for Linear:")
        ):
            _test(None, MyBrokenLoadTensor)


instantiate_parametrized_tests(TestLoadStateDict)
instantiate_parametrized_tests(TestLoadStateDictSwap)

if __name__ == "__main__":
    TestCase._default_dtype_check_enabled = True
    run_tests()