1# Owner(s): ["module: nn"] 2import pickle 3import unittest 4 5import torch 6import torch.nn as nn 7from torch.nn import Buffer, Parameter 8from torch.nn.parameter import UninitializedBuffer, UninitializedParameter 9from torch.testing._internal.common_cuda import TEST_CUDA 10from torch.testing._internal.common_utils import ( 11 run_tests, 12 suppress_warnings, 13 TEST_PRIVATEUSE1, 14 TestCase, 15) 16 17 18class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module): 19 pass 20 21 22class TestLazyModules(TestCase): 23 @suppress_warnings 24 def test_lazy_module_parameter(self): 25 module = LazyModule() 26 module.register_parameter("test_param", UninitializedParameter()) 27 self.assertTrue(module.has_uninitialized_params()) 28 state_dict = module.state_dict() 29 self.assertIsInstance(state_dict["test_param"], UninitializedParameter) 30 new_module = LazyModule() 31 # An error is raised when there is an attempt to replace an existing parameter 32 # with an uninitialized one 33 new_module.register_parameter("test_param", nn.Parameter(torch.ones(5, 5))) 34 with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"): 35 new_module.load_state_dict(state_dict) 36 # Uninitialized parameters are overriden when the state dict to be loaded contains a valid one 37 new_module = LazyModule() 38 new_module.register_parameter("test_param", nn.Parameter(torch.ones(5, 5))) 39 module.load_state_dict(new_module.state_dict()) 40 self.assertEqual(module.test_param, torch.ones((5, 5))) 41 42 # Uninitialized parameters are left unchanged 43 module = LazyModule() 44 module.register_parameter("test_param", UninitializedParameter()) 45 self.assertTrue(module.has_uninitialized_params()) 46 47 new_module = LazyModule() 48 new_module.register_parameter("test_param", UninitializedParameter()) 49 module.load_state_dict(new_module.state_dict()) 50 self.assertTrue(module.has_uninitialized_params()) 51 52 @suppress_warnings 53 def test_lazy_module_buffer(self): 54 module = LazyModule() 55 module.test_buffer = UninitializedBuffer() 56 self.assertTrue(module.has_uninitialized_params()) 57 state_dict = module.state_dict() 58 self.assertIsInstance(state_dict["test_buffer"], UninitializedBuffer) 59 new_module = LazyModule() 60 # An error is raised when there is an attempt to replace an existing parameter 61 # with an uninitialized one 62 new_module.test_buffer = Buffer(torch.ones(5, 5)) 63 with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"): 64 new_module.load_state_dict(state_dict) 65 # Uninitialized parameters are overriden when the state dict to be loaded contains a valid one 66 new_module = LazyModule() 67 new_module.test_buffer = Buffer(torch.ones(5, 5)) 68 module.load_state_dict(new_module.state_dict()) 69 self.assertEqual(module.test_buffer, torch.ones((5, 5))) 70 71 # Uninitialized parameters are left unchanged 72 module = LazyModule() 73 module.test_buffer = UninitializedBuffer() 74 self.assertTrue(module.has_uninitialized_params()) 75 76 new_module = LazyModule() 77 new_module.test_buffer = UninitializedBuffer() 78 module.load_state_dict(new_module.state_dict()) 79 module.load_state_dict(new_module.state_dict()) 80 self.assertTrue(module.has_uninitialized_params()) 81 82 @suppress_warnings 83 def test_lazy_module_jit_param(self): 84 module = LazyModule() 85 module.register_parameter("test_param", UninitializedParameter()) 86 self.assertTrue(module.has_uninitialized_params()) 87 with self.assertRaisesRegex(RuntimeError, "run a forward pass"): 88 torch.jit.script(module) 89 90 @suppress_warnings 91 def test_lazy_module_jit_buffer(self): 92 module = LazyModule() 93 module.test_buffer = UninitializedBuffer() 94 self.assertTrue(module.has_uninitialized_params()) 95 with self.assertRaisesRegex(RuntimeError, "run a forward pass"): 96 torch.jit.script(module) 97 98 @suppress_warnings 99 def test_lazy_share_memory_param(self): 100 module = LazyModule() 101 module.register_parameter("test_param", UninitializedParameter()) 102 self.assertTrue(module.has_uninitialized_params()) 103 with self.assertRaisesRegex(RuntimeError, "share memory on an uninitialized"): 104 module.share_memory() 105 106 @suppress_warnings 107 def test_lazy_share_memory_buffer(self): 108 module = LazyModule() 109 module.test_buffer = UninitializedBuffer() 110 self.assertTrue(module.has_uninitialized_params()) 111 with self.assertRaisesRegex(RuntimeError, "share memory on an uninitialized"): 112 module.share_memory() 113 114 @suppress_warnings 115 def test_linear(self): 116 module = nn.LazyLinear(10) 117 self.assertIsInstance(module.weight, UninitializedParameter) 118 self.assertIsInstance(module.bias, UninitializedParameter) 119 input = torch.ones(5, 5) 120 module(input) 121 self.assertIsInstance(module, nn.Linear) 122 self.assertNotIsInstance(module, nn.LazyLinear) 123 self.assertTrue(module.weight.shape == (10, 5)) 124 self.assertTrue(module.bias.shape == (10,)) 125 y = module(input) 126 self.assertTrue( 127 torch.equal( 128 torch.nn.functional.linear(input, module.weight, module.bias), y 129 ) 130 ) 131 132 @suppress_warnings 133 def test_lazy_linear_pickle(self): 134 module = nn.LazyLinear(10) 135 self.assertIsInstance(module.weight, UninitializedParameter) 136 self.assertIsInstance(module.bias, UninitializedParameter) 137 module = pickle.loads(pickle.dumps(module)) 138 self.assertIsInstance(module, nn.LazyLinear) 139 self.assertIsInstance(module.weight, UninitializedParameter) 140 self.assertIsInstance(module.bias, UninitializedParameter) 141 input = torch.ones(5, 5) 142 module(input) # fully materialized 143 new_module = pickle.loads(pickle.dumps(module)) 144 self.assertIsInstance(new_module, nn.Linear) 145 self.assertNotIsInstance(new_module, nn.LazyLinear) 146 self.assertTrue(new_module.weight.shape == (10, 5)) 147 self.assertNotIsInstance(new_module.weight, UninitializedParameter) 148 self.assertTrue(new_module.bias.shape == (10,)) 149 self.assertNotIsInstance(new_module.bias, UninitializedParameter) 150 151 @suppress_warnings 152 def test_linear_state(self): 153 module = nn.Linear(5, 10) 154 lazy_module = nn.LazyLinear(10) 155 lazy_module.load_state_dict(module.state_dict()) 156 # Parameters have been initialized but the module won't become a full 157 # Linear one until the first iteration. This is due to 158 # limitations on the state_dict loading logic 159 self.assertFalse(lazy_module.has_uninitialized_params()) 160 self.assertTrue(lazy_module.weight.shape == (10, 5)) 161 self.assertTrue(lazy_module.bias.shape == (10,)) 162 163 module = nn.Linear(5, 10) 164 lazy_module = nn.LazyLinear(10) 165 with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"): 166 module.load_state_dict(lazy_module.state_dict()) 167 168 def _check_lazy_conv( 169 self, 170 cls, 171 lazy_cls, 172 func, 173 init_args, 174 input_shape, 175 expected_weight_shape, 176 expected_bias_shape, 177 *forward_args, 178 **forward_kwargs, 179 ): 180 module = lazy_cls(*init_args) 181 self.assertIsInstance(module.weight, UninitializedParameter) 182 if module.bias is not None: 183 self.assertIsInstance(module.bias, UninitializedParameter) 184 input = torch.ones(*input_shape) 185 module(input, *forward_args, **forward_kwargs) 186 self.assertIsInstance(module, cls) 187 self.assertNotIsInstance(module, lazy_cls) 188 self.assertEqual(module.weight.shape, expected_weight_shape) 189 if module.bias is not None: 190 self.assertEqual(module.bias.shape, expected_bias_shape) 191 y = module(input) 192 self.assertTrue(torch.equal(func(input, module.weight, module.bias), y)) 193 194 def _check_lazy_conv_pickle( 195 self, 196 cls, 197 lazy_cls, 198 init_args, 199 input_shape, 200 expected_weight_shape, 201 expected_bias_shape, 202 ): 203 module = lazy_cls(*init_args) 204 self.assertIsInstance(module.weight, UninitializedParameter) 205 if module.bias is not None: 206 self.assertIsInstance(module.bias, UninitializedParameter) 207 module = pickle.loads(pickle.dumps(module)) 208 self.assertIsInstance(module, lazy_cls) 209 self.assertIsInstance(module.weight, UninitializedParameter) 210 if module.bias is not None: 211 self.assertIsInstance(module.bias, UninitializedParameter) 212 input = torch.ones(*input_shape) 213 module(input) # fully materialized 214 new_module = pickle.loads(pickle.dumps(module)) 215 self.assertIsInstance(new_module, cls) 216 self.assertNotIsInstance(new_module, lazy_cls) 217 self.assertEqual(new_module.weight.shape, expected_weight_shape) 218 self.assertNotIsInstance(new_module.weight, UninitializedParameter) 219 if new_module.bias is not None: 220 self.assertEqual(new_module.bias.shape, expected_bias_shape) 221 self.assertNotIsInstance(new_module.bias, UninitializedParameter) 222 223 def _check_lazy_conv_state( 224 self, gen_module, gen_lazy_module, expected_weight_shape, expected_bias_shape 225 ): 226 module = gen_module() 227 lazy_module = gen_lazy_module() 228 lazy_module.load_state_dict(module.state_dict()) 229 # Parameters have been initialized but the module won't become a full 230 # Conv one until the first iteration. This is due to 231 # limitations on the state_dict loading logic 232 self.assertFalse(lazy_module.has_uninitialized_params()) 233 self.assertEqual(lazy_module.weight.shape, expected_weight_shape) 234 if lazy_module.bias is not None: 235 self.assertEqual(lazy_module.bias.shape, expected_bias_shape) 236 237 module = gen_module() 238 lazy_module = gen_lazy_module() 239 with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"): 240 module.load_state_dict(lazy_module.state_dict()) 241 242 def test_lazy_pre_forward_hook(self): 243 """ 244 This test is to test whether lazymodule can register other pre-forward hook 245 functions successfully. 246 """ 247 248 class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module): 249 def initialize_parameters(self, input): 250 return None 251 252 def forward(self, input): 253 return input 254 255 def hook_function(module, input): 256 return input[0] + 1 257 258 module = TestModule() 259 module.register_forward_pre_hook(hook_function) 260 output = module(torch.zeros(2, 2)) 261 self.assertEqual(output, torch.ones(2, 2)) 262 263 def test_lazy_forward_hook(self): 264 """ 265 This test is to test whether lazymodule can register other forward hook 266 functions successfully. 267 """ 268 269 class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module): 270 def initialize_parameters(self, input): 271 return None 272 273 def forward(self, input): 274 return input 275 276 def hook_function(module, input, output): 277 return input[0] + 1 278 279 module = TestModule() 280 module.register_forward_hook(hook_function) 281 output = module(torch.zeros(2, 2)) 282 self.assertEqual(output, torch.ones(2, 2)) 283 284 @suppress_warnings 285 def test_lazy_conv1d(self): 286 self._check_lazy_conv( 287 nn.Conv1d, 288 nn.LazyConv1d, 289 torch.nn.functional.conv1d, 290 (32, 2), 291 (192, 16, 50), 292 (32, 16, 2), 293 (32,), 294 ) 295 296 @suppress_warnings 297 def test_lazy_conv1d_pickle(self): 298 self._check_lazy_conv_pickle( 299 nn.Conv1d, nn.LazyConv1d, (32, 2), (192, 16, 50), (32, 16, 2), (32,) 300 ) 301 302 @suppress_warnings 303 def test_lazy_conv1d_state(self): 304 self._check_lazy_conv_state( 305 lambda: nn.Conv1d(16, 32, 2), 306 lambda: nn.LazyConv1d(32, 2), 307 (32, 16, 2), 308 (32,), 309 ) 310 311 @suppress_warnings 312 def test_lazy_conv2d(self): 313 self._check_lazy_conv( 314 nn.Conv2d, 315 nn.LazyConv2d, 316 torch.nn.functional.conv2d, 317 (32, 2), 318 (192, 16, 8, 6), 319 (32, 16, 2, 2), 320 (32,), 321 ) 322 323 @suppress_warnings 324 def test_lazy_conv2d_pickle(self): 325 self._check_lazy_conv_pickle( 326 nn.Conv2d, nn.LazyConv2d, (32, 2), (192, 16, 8, 6), (32, 16, 2, 2), (32,) 327 ) 328 329 @suppress_warnings 330 def test_lazy_conv2d_state(self): 331 self._check_lazy_conv_state( 332 lambda: nn.Conv2d(16, 32, 2), 333 lambda: nn.LazyConv2d(32, 2), 334 (32, 16, 2, 2), 335 (32,), 336 ) 337 338 @suppress_warnings 339 def test_lazy_conv3d(self): 340 self._check_lazy_conv( 341 nn.Conv3d, 342 nn.LazyConv3d, 343 torch.nn.functional.conv3d, 344 (32, 2), 345 (192, 16, 8, 7, 6), 346 (32, 16, 2, 2, 2), 347 (32,), 348 ) 349 350 @suppress_warnings 351 def test_lazy_conv3d_pickle(self): 352 self._check_lazy_conv_pickle( 353 nn.Conv3d, 354 nn.LazyConv3d, 355 (32, 2), 356 (192, 16, 8, 7, 6), 357 (32, 16, 2, 2, 2), 358 (32,), 359 ) 360 361 @suppress_warnings 362 def test_lazy_conv3d_state(self): 363 self._check_lazy_conv_state( 364 lambda: nn.Conv3d(16, 32, 2), 365 lambda: nn.LazyConv3d(32, 2), 366 (32, 16, 2, 2, 2), 367 (32,), 368 ) 369 370 @suppress_warnings 371 def test_lazy_conv_transposed1d(self): 372 self._check_lazy_conv( 373 nn.ConvTranspose1d, 374 nn.LazyConvTranspose1d, 375 torch.nn.functional.conv_transpose1d, 376 (32, 2), 377 (192, 16, 50), 378 (16, 32, 2), 379 (32,), 380 ) 381 382 @suppress_warnings 383 def test_lazy_conv_transpose1d_kwargs(self): 384 self._check_lazy_conv( 385 nn.ConvTranspose1d, 386 nn.LazyConvTranspose1d, 387 torch.nn.functional.conv_transpose1d, 388 (32, 2), 389 (192, 16, 50), 390 (16, 32, 2), 391 (32,), 392 output_size=(51,), 393 ) 394 395 @suppress_warnings 396 def test_lazy_conv_transpose1d_pickle(self): 397 self._check_lazy_conv_pickle( 398 nn.ConvTranspose1d, 399 nn.LazyConvTranspose1d, 400 (32, 2), 401 (192, 16, 50), 402 (16, 32, 2), 403 (32,), 404 ) 405 406 @suppress_warnings 407 def test_lazy_conv_transpose1d_state(self): 408 self._check_lazy_conv_state( 409 lambda: nn.ConvTranspose1d(16, 32, 2), 410 lambda: nn.LazyConvTranspose1d(32, 2), 411 (16, 32, 2), 412 (32,), 413 ) 414 415 @suppress_warnings 416 def test_lazy_conv_transpose2d(self): 417 self._check_lazy_conv( 418 nn.ConvTranspose2d, 419 nn.LazyConvTranspose2d, 420 torch.nn.functional.conv_transpose2d, 421 (32, 2), 422 (192, 16, 8, 6), 423 (16, 32, 2, 2), 424 (32,), 425 ) 426 427 @suppress_warnings 428 def test_lazy_conv_transpose2d_kwargs(self): 429 self._check_lazy_conv( 430 nn.ConvTranspose2d, 431 nn.LazyConvTranspose2d, 432 torch.nn.functional.conv_transpose2d, 433 (32, 2), 434 (192, 16, 8, 6), 435 (16, 32, 2, 2), 436 (32,), 437 output_size=(9, 7), 438 ) 439 440 @suppress_warnings 441 def test_lazy_conv_transpose2d_pickle(self): 442 self._check_lazy_conv_pickle( 443 nn.ConvTranspose2d, 444 nn.LazyConvTranspose2d, 445 (32, 2), 446 (192, 16, 8, 6), 447 (16, 32, 2, 2), 448 (32,), 449 ) 450 451 @suppress_warnings 452 def test_lazy_conv_transpose2d_state(self): 453 self._check_lazy_conv_state( 454 lambda: nn.ConvTranspose2d(16, 32, 2), 455 lambda: nn.LazyConvTranspose2d(32, 2), 456 (16, 32, 2, 2), 457 (32,), 458 ) 459 460 @suppress_warnings 461 def test_lazy_conv_transpose3d(self): 462 self._check_lazy_conv( 463 nn.ConvTranspose3d, 464 nn.LazyConvTranspose3d, 465 torch.nn.functional.conv_transpose3d, 466 (32, 2), 467 (192, 16, 8, 7, 6), 468 (16, 32, 2, 2, 2), 469 (32,), 470 ) 471 472 @suppress_warnings 473 def test_lazy_conv_transpose3d_kwargs(self): 474 self._check_lazy_conv( 475 nn.ConvTranspose3d, 476 nn.LazyConvTranspose3d, 477 torch.nn.functional.conv_transpose3d, 478 (32, 2), 479 (192, 16, 8, 7, 6), 480 (16, 32, 2, 2, 2), 481 (32,), 482 output_size=(9, 8, 7), 483 ) 484 485 @suppress_warnings 486 def test_lazy_conv_transpose3d_pickle(self): 487 self._check_lazy_conv_pickle( 488 nn.ConvTranspose3d, 489 nn.LazyConvTranspose3d, 490 (32, 2), 491 (192, 16, 8, 7, 6), 492 (16, 32, 2, 2, 2), 493 (32,), 494 ) 495 496 @suppress_warnings 497 def test_lazy_conv_transpose3d_state(self): 498 self._check_lazy_conv_state( 499 lambda: nn.ConvTranspose3d(16, 32, 2), 500 lambda: nn.LazyConvTranspose3d(32, 2), 501 (16, 32, 2, 2, 2), 502 (32,), 503 ) 504 505 def _check_lazy_norm(self, cls, lazy_cls, input_shape): 506 for affine in [False, True]: 507 for track_running_stats in [False, True]: 508 lazy_module = lazy_cls( 509 affine=affine, track_running_stats=track_running_stats 510 ) 511 512 if affine: 513 self.assertIsInstance(lazy_module.weight, UninitializedParameter) 514 self.assertIsInstance(lazy_module.bias, UninitializedParameter) 515 if track_running_stats: 516 self.assertIsInstance(lazy_module.running_mean, UninitializedBuffer) 517 self.assertIsInstance(lazy_module.running_var, UninitializedBuffer) 518 519 input = torch.ones(*input_shape) 520 lazy_output = lazy_module(input) 521 self.assertIsInstance(lazy_module, cls) 522 self.assertNotIsInstance(lazy_module, lazy_cls) 523 524 num_features = input_shape[1] 525 module = cls( 526 num_features, affine=affine, track_running_stats=track_running_stats 527 ) 528 expected_output = module(input) 529 530 self.assertEqual(lazy_output, expected_output) 531 if module.weight is not None: 532 self.assertEqual(lazy_module.weight.shape, module.weight.shape) 533 self.assertEqual(lazy_module.weight, module.weight) 534 if module.bias is not None: 535 self.assertEqual(lazy_module.bias.shape, module.bias.shape) 536 self.assertEqual(lazy_module.bias, module.bias) 537 if module.running_mean is not None: 538 self.assertEqual( 539 lazy_module.running_mean.shape, module.running_mean.shape 540 ) 541 self.assertEqual(lazy_module.running_mean, module.running_mean) 542 if module.running_var is not None: 543 self.assertEqual( 544 lazy_module.running_var.shape, module.running_var.shape 545 ) 546 self.assertEqual(lazy_module.running_var, module.running_var) 547 if module.num_batches_tracked is not None: 548 self.assertEqual( 549 lazy_module.num_batches_tracked.shape, 550 module.num_batches_tracked.shape, 551 ) 552 self.assertEqual( 553 lazy_module.num_batches_tracked, module.num_batches_tracked 554 ) 555 556 def _check_lazy_norm_pickle(self, cls, lazy_cls, input_shape): 557 for affine in [False, True]: 558 for track_running_stats in [False, True]: 559 module = lazy_cls( 560 affine=affine, track_running_stats=track_running_stats 561 ) 562 module = pickle.loads(pickle.dumps(module)) 563 564 self.assertIsInstance(module, lazy_cls) 565 if affine: 566 self.assertIsInstance(module.weight, UninitializedParameter) 567 self.assertIsInstance(module.bias, UninitializedParameter) 568 if track_running_stats: 569 self.assertIsInstance(module.running_mean, UninitializedBuffer) 570 self.assertIsInstance(module.running_var, UninitializedBuffer) 571 572 input = torch.ones(*input_shape) 573 module(input) # fully materialized 574 module = pickle.loads(pickle.dumps(module)) 575 576 self.assertNotIsInstance(module, lazy_cls) 577 self.assertIsInstance(module, cls) 578 if affine: 579 self.assertNotIsInstance(module.weight, UninitializedParameter) 580 self.assertNotIsInstance(module.bias, UninitializedParameter) 581 if track_running_stats: 582 self.assertNotIsInstance(module.running_mean, UninitializedBuffer) 583 self.assertNotIsInstance(module.running_var, UninitializedBuffer) 584 585 def _check_lazy_batchnorm_state(self, cls, lazy_cls): 586 module = cls(10) 587 lazy_module = lazy_cls(affine=True, track_running_stats=True) 588 lazy_module.load_state_dict(module.state_dict()) 589 # Parameters have been initialized but the module won't become a full 590 # Conv one until the first iteration. This is due to 591 # limitations on the state_dict loading logic 592 self.assertFalse(lazy_module.has_uninitialized_params()) 593 self.assertEqual(lazy_module.weight.shape, (10,)) 594 self.assertEqual(lazy_module.bias.shape, (10,)) 595 self.assertEqual(lazy_module.running_mean.shape, (10,)) 596 self.assertEqual(lazy_module.running_var.shape, (10,)) 597 598 module = cls(10) 599 lazy_module = lazy_cls() 600 with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"): 601 module.load_state_dict(lazy_module.state_dict()) 602 603 def _check_lazy_instancenorm_state(self, cls, lazy_cls): 604 for affine in [False, True]: 605 for track_running_stats in [False, True]: 606 module = cls(10, affine=affine, track_running_stats=track_running_stats) 607 lazy_module = lazy_cls( 608 affine=affine, track_running_stats=track_running_stats 609 ) 610 lazy_module.load_state_dict(module.state_dict()) 611 # Parameters have been initialized but the module won't become a full 612 # InstanceNorm one until the first iteration. This is due to 613 # limitations on the state_dict loading logic 614 self.assertFalse(lazy_module.has_uninitialized_params()) 615 if affine: 616 self.assertEqual(lazy_module.weight.shape, (10,)) 617 self.assertEqual(lazy_module.bias.shape, (10,)) 618 if track_running_stats: 619 self.assertEqual(lazy_module.running_mean.shape, (10,)) 620 self.assertEqual(lazy_module.running_var.shape, (10,)) 621 622 module = cls(10, affine=True, track_running_stats=True) 623 lazy_module = lazy_cls(affine=True, track_running_stats=True) 624 with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"): 625 module.load_state_dict(lazy_module.state_dict()) 626 627 def _check_lazy_norm_with_dict_input(self, cls, lazy_cls, input_shape): 628 input = {"input": torch.ones(*input_shape)} 629 630 lazy_module = lazy_cls() 631 lazy_output = lazy_module(**input) 632 633 num_features = input_shape[1] 634 module = cls(num_features) 635 expected_output = module(**input) 636 637 self.assertEqual(lazy_output, expected_output) 638 639 def test_lazy_batchnorm1d(self): 640 self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6)) 641 self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6)) 642 643 def test_lazy_batchnorm1d_pickle(self): 644 self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6)) 645 self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6)) 646 647 def test_lazy_batchnorm1d_state(self): 648 self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d) 649 self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d) 650 651 def test_lazy_batchnorm2d(self): 652 self._check_lazy_norm(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7)) 653 654 def test_lazy_batchnorm2d_pickle(self): 655 self._check_lazy_norm_pickle(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7)) 656 657 def test_lazy_batchnorm2d_state(self): 658 self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d) 659 self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d) 660 661 def test_lazy_batchnorm3d(self): 662 self._check_lazy_norm(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8)) 663 664 def test_lazy_batchnorm3d_pickle(self): 665 self._check_lazy_norm_pickle( 666 nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8) 667 ) 668 669 def test_lazy_batchnorm3d_state(self): 670 self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d) 671 self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d) 672 673 def test_lazy_instancenorm1d(self): 674 self._check_lazy_norm(nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6)) 675 676 def test_lazy_instancenorm1d_pickle(self): 677 self._check_lazy_norm_pickle( 678 nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6) 679 ) 680 681 def test_lazy_instancenorm1d_state(self): 682 self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d) 683 self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d) 684 685 def test_lazy_instancenorm2d(self): 686 self._check_lazy_norm(nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7)) 687 688 def test_lazy_instancenorm2d_pickle(self): 689 self._check_lazy_norm_pickle( 690 nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7) 691 ) 692 693 def test_lazy_instancenorm2d_state(self): 694 self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d) 695 self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d) 696 697 def test_lazy_instancenorm3d(self): 698 self._check_lazy_norm( 699 nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8) 700 ) 701 702 def test_lazy_instancenorm3d_pickle(self): 703 self._check_lazy_norm_pickle( 704 nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8) 705 ) 706 707 def test_lazy_instancenorm3d_state(self): 708 self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d) 709 self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d) 710 711 def test_lazy_batchnorm_with_dict_input(self): 712 self._check_lazy_norm_with_dict_input( 713 nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6) 714 ) 715 self._check_lazy_norm_with_dict_input( 716 nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7) 717 ) 718 self._check_lazy_norm_with_dict_input( 719 nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8) 720 ) 721 722 @suppress_warnings 723 def test_materialize_dtype(self): 724 module = LazyModule() 725 module.register_parameter("test_param", UninitializedParameter()) 726 module.test_param.materialize(10) 727 self.assertTrue(module.test_param.dtype == torch.get_default_dtype()) 728 module = LazyModule() 729 module.register_parameter("test_param", UninitializedParameter()) 730 module.half() 731 module.test_param.materialize(10) 732 self.assertTrue(module.test_param.dtype == torch.float16) 733 734 @unittest.skipIf( 735 not (TEST_CUDA or TEST_PRIVATEUSE1), "CUDA and PRIVATEUSE1 not available" 736 ) 737 @suppress_warnings 738 def test_materialize_device(self): 739 module = LazyModule() 740 module.register_parameter("test_param", UninitializedParameter()) 741 module.test_param.materialize(10) 742 self.assertTrue(module.test_param.device.type == "cpu") 743 if TEST_CUDA: 744 device = "cuda" 745 elif TEST_PRIVATEUSE1: 746 device = torch._C._get_privateuse1_backend_name() 747 module = LazyModule() 748 module.register_parameter("test_param", UninitializedParameter()) 749 module.to(device) 750 module.test_param.materialize(10) 751 self.assertTrue(module.test_param.device.type == device) 752 753 @suppress_warnings 754 def test_chained_initialization(self): 755 class MyNetwork(torch.nn.Module): 756 def __init__(self) -> None: 757 super().__init__() 758 self.linear_1 = torch.nn.LazyLinear(15) 759 self.linear_2 = torch.nn.LazyLinear(10) 760 761 def forward(self, x): 762 y = self.linear_1(x) 763 return self.linear_2(y) 764 765 net = MyNetwork() 766 net(torch.ones(5, 10)) 767 self.assertTrue(net.linear_1.weight.shape == (15, 10)) 768 self.assertTrue(net.linear_1.bias.shape == (15,)) 769 self.assertTrue(net.linear_2.weight.shape == (10, 15)) 770 self.assertTrue(net.linear_2.bias.shape == (10,)) 771 772 @suppress_warnings 773 def test_optimizer_pass(self): 774 optimizers = [ 775 torch.optim.Adadelta, 776 torch.optim.Adagrad, 777 torch.optim.Adamax, 778 torch.optim.Adam, 779 torch.optim.AdamW, 780 torch.optim.ASGD, 781 torch.optim.SGD, 782 torch.optim.Rprop, 783 torch.optim.RMSprop, 784 torch.optim.LBFGS, 785 torch.optim.NAdam, 786 torch.optim.RAdam, 787 ] 788 789 def run_step(module, optim): 790 self.assertIsInstance( 791 optim.param_groups[0]["params"][0], UninitializedParameter 792 ) 793 module.test_param.materialize(10) 794 self.assertIsInstance(optim.param_groups[0]["params"][0], Parameter) 795 self.assertNotIsInstance( 796 optim.param_groups[0]["params"][0], UninitializedParameter 797 ) 798 for p in module.parameters(): 799 p.grad = torch.rand_like(p) 800 if isinstance(optim, torch.optim.LBFGS): 801 optim.step(lambda: 1.0) 802 else: 803 optim.step() 804 805 for optim_cls in optimizers: 806 module = LazyModule() 807 module.register_parameter("test_param", UninitializedParameter()) 808 if optim_cls is torch.optim.SGD: 809 optim = optim_cls(module.parameters(), lr=0.0) 810 elif optim_cls is torch.optim.Adagrad: 811 with self.assertRaisesRegex(ValueError, "uninitialized parameter"): 812 optim = optim_cls(module.parameters()) 813 continue 814 else: 815 optim = optim_cls(module.parameters()) 816 run_step(module, optim) 817 818 @suppress_warnings 819 def test_weight_norm(self): 820 m = nn.LazyLinear(7) 821 with self.assertRaisesRegex(ValueError, "have uninitialized parameters."): 822 m = torch.nn.utils.weight_norm(m) 823 824 @suppress_warnings 825 def test_spectral_norm(self): 826 m = nn.LazyLinear(7) 827 with self.assertRaisesRegex(ValueError, "have uninitialized parameters."): 828 m = torch.nn.utils.spectral_norm(m) 829 830 @suppress_warnings 831 def test_invalid_functions(self): 832 param = torch.nn.parameter.UninitializedParameter() 833 with self.assertRaisesRegex(ValueError, "uninitialized parameter"): 834 torch.empty_like(param) 835 836 with self.assertRaisesRegex(ValueError, "uninitialized parameter"): 837 torch.add(param, param) 838 839 with self.assertRaisesRegex(ValueError, "uninitialized parameter"): 840 param + param 841 842 843if __name__ == "__main__": 844 run_tests() 845