# Owner(s): ["module: unknown"]
import copy
import logging
import random

import torch
from torch import nn
from torch.ao.pruning._experimental.pruner import (
    BaseStructuredSparsifier,
    FakeStructuredSparsity,
    FPGMPruner,
    LSTMSaliencyPruner,
    SaliencyPruner,
)
from torch.nn.utils import parametrize
from torch.testing._internal.common_pruning import (
    Conv2dActivation,
    Conv2dBias,
    Conv2dPadBias,
    Conv2dPool,
    Conv2dPoolFlatten,
    Conv2dPoolFlattenFunctional,
    LinearActivation,
    LinearActivationFunctional,
    LinearBias,
    LSTMLayerNormLinearModel,
    LSTMLinearModel,
    rows_are_subset,
    SimpleConv2d,
    SimpleLinear,
)
from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

DEVICES = {
    torch.device("cpu"),
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
}


class SimplePruner(BaseStructuredSparsifier):
    def update_mask(self, module, tensor_name, **kwargs):
        getattr(module.parametrizations, tensor_name)[0].mask[1] = False


class ImplementedPruner(BaseStructuredSparsifier):
    def update_mask(self, module, tensor_name, **kwargs):
        """Prunes 1/3 of the weight output channels, so the resulting module has 33.3% pruning"""
        num_rows = len(module.parametrizations[tensor_name][0].mask)
        prune = random.sample(list(range(num_rows)), num_rows // 3)
        module.parametrizations[tensor_name][0].mask[prune] = False


class BottomHalfLSTMPruner(BaseStructuredSparsifier):
    """
    Pruner that will remove the bottom half of the rows.
    This is primarily meant for testing purposes.
    """

    def update_mask(self, module, tensor_name, **kwargs):
        for p in getattr(module.parametrizations, tensor_name):
            if isinstance(p, FakeStructuredSparsity):
                mask = p.mask
                masks = torch.split(mask, len(mask) // 4)
                for small in masks:
                    num = len(small)
                    small[num // 2 :] = False
                new_mask = torch.cat(masks)
                mask.data = new_mask.data


class TestSaliencyPruner(TestCase):
    def test_saliency_pruner_update_mask(self):
        """Test that we prune out the rows with the lowest saliency (the first two rows)"""
        model = SimpleLinear()
        with torch.no_grad():
            model.linear1.weight = nn.Parameter(
                torch.Tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]])
            )
        pruning_config = [{"tensor_fqn": "linear1.weight", "sparsity_level": 0.5}]
        pruner = SaliencyPruner({})

        pruner.prepare(model, pruning_config)
        pruner.enable_mask_update = True
        pruner.step()
        pruned_model = pruner.prune()

        # with sparsity_level=0.5, the two rows with the lowest saliency
        # (smallest L1 norm) should be removed
        expected = torch.Tensor([[3, 3, 3, 3], [4, 4, 4, 4]])
        pruned = pruned_model.linear1.weight

        assert expected.shape == pruned.shape
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()

    def test_lstm_saliency_pruner_update_mask(self):
        model = LSTMLinearModel(
            input_dim=2,
            hidden_dim=2,
            output_dim=2,
            num_layers=1,
        )

        manual_weights = torch.Tensor(
            [[1, 1], [2, 2], [2, 2], [1, 1], [-1, -1], [-2, -2], [-2, -2], [-1, -1]]
        )

        with torch.no_grad():
            model.lstm.weight_ih_l0 = nn.Parameter(manual_weights)
            model.lstm.weight_hh_l0 = nn.Parameter(torch.Tensor(manual_weights))
            model.lstm.bias_ih_l0 = nn.Parameter(manual_weights[:, 0])
            model.lstm.bias_hh_l0 = nn.Parameter(manual_weights[:, 0])

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
        ]
        lstm_input = torch.ones((1, 2))
        fx_pruner = LSTMSaliencyPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)
        fx_pruner.enable_mask_update = True
        fx_pruner.step()

        model.eval()
        pruned_model = fx_pruner.prune()
        pruned_model.eval()

        # make sure both models run
        model(lstm_input)
        pruned_model(lstm_input)

        # make sure lowest saliency rows are pruned: the packed LSTM weight is
        # split into four gate blocks, and within each block the lower-saliency
        # half of the rows (see manual_weights above) should be removed
        expected = torch.Tensor([[2, 2], [2, 2], [-2, -2], [-2, -2]])
        pruned = model.lstm.weight_ih_l0
        assert expected.shape == pruned.shape
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()

        expected = torch.Tensor([[2], [2], [-2], [-2]])
        pruned = model.lstm.weight_hh_l0
        assert expected.shape == pruned.shape
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()

        expected = torch.Tensor([2, 2, -2, -2])
        for pruned in [model.lstm.bias_ih_l0, model.lstm.bias_hh_l0]:
            assert expected.shape == pruned.shape
            assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()


class TestBaseStructuredSparsifier(TestCase):
    def _check_pruner_prepared(self, model, pruner, device):
        for config in pruner.groups:
            module = config["module"]
            assert module.weight.device.type == device.type
            # Check mask exists
            assert config["tensor_fqn"] in pruner.state
            # Check parametrization exists and is correct
            assert parametrize.is_parametrized(module)
            assert hasattr(module, "parametrizations")
            # Assume that this is the 1st/only parametrization
            assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity

    def _check_pruner_valid_before_step(self, model, pruner, device):
        for config in pruner.groups:
            modules = []
            if type(config["module"]) is tuple:
                modules.extend(config["module"])
            else:
                module = config["module"]
                modules.append(module)
            for module in modules:
                assert module.weight.device.type == device.type
                assert module.parametrizations.weight[0].mask.dtype == torch.bool

    def _check_pruner_valid_after_step(self, model, pruner, mask, device):
        for config in pruner.groups:
            modules = []
            if type(config["module"]) is tuple:
                modules.extend(config["module"])
            else:
                module = config["module"]
                modules.append(module)
            for module in modules:
                assert module.weight.device.type == device.type
                total = module.parametrizations.weight[0].mask.numel()
                assert (
                    module.parametrizations.weight[0].mask.count_nonzero()
                    == total - mask
                )

    def _test_constructor_on_device(self, model, device):
        self.assertRaisesRegex(
            TypeError,
            "BaseStructuredSparsifier.*update_mask",
            BaseStructuredSparsifier,
        )
        model1 = copy.deepcopy(model).to(device)
        pruner = SimplePruner(None)
        pruner.prepare(model1, None)
        pruner.enable_mask_update = True
        for g in pruner.groups:
            module = g["module"]
            assert module.weight.device.type == device.type
        assert len(pruner.groups) == 5
        pruner.step()
        # Can instantiate the model with configs
        model2 = copy.deepcopy(model).to(device)
        pruner = SimplePruner({"test": 3})
        pruner.prepare(model2, [{"tensor_fqn": "seq.0.weight"}])
        assert len(pruner.groups) == 1
        assert pruner.groups[0]["module_fqn"] == "seq.0"
        assert "test" in pruner.groups[0]
        assert pruner.groups[0]["test"] == 3

    def test_constructor(self):
        model = SimpleLinear()
        for device in DEVICES:
            self._test_constructor_on_device(model, torch.device(device))

    def _test_prepare_linear_on_device(self, model, device):
        model = copy.deepcopy(model).to(device)
        x = torch.ones(128, 7, device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, None)
        self._check_pruner_prepared(model, pruner, device)
        assert model(x).shape == (128, 10)

    def test_prepare_linear(self):
        models = [
            SimpleLinear(),
            LinearBias(),
            LinearActivation(),
            LinearActivationFunctional(),
        ]  # without and with bias
        for device in DEVICES:
            for model in models:
                self._test_prepare_linear_on_device(model, torch.device(device))

    def _test_prepare_conv2d_on_device(self, model, expected_shape, config, device):
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, config)
        self._check_pruner_prepared(model, pruner, device)
        assert model(x).shape == expected_shape

    def test_prepare_conv2d(self):
        models = [
            SimpleConv2d(),
            Conv2dBias(),
            Conv2dActivation(),
            Conv2dPadBias(),
            Conv2dPool(),
        ]
        shapes = [
            (1, 52, 20, 20),
            (1, 52, 18, 18),
            (1, 52, 18, 18),
            (1, 52, 24, 24),
            (1, 52, 3, 3),
        ]
        configs = [None, None, None, None, None]
        for device in DEVICES:
            for model, shape, config in zip(models, shapes, configs):
                model = model.to(device)
                self._test_prepare_conv2d_on_device(
                    model, shape, config, torch.device(device)
                )

    def _test_step_linear_on_device(self, model, device):
        model = model.to(device)
        x = torch.ones(7, 7, device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, None)
        pruner.enable_mask_update = True
        self._check_pruner_valid_before_step(model, pruner, device)
        pruner.step()
        self._check_pruner_valid_after_step(model, pruner, 1, device)

    def test_step_linear(self):
        models = [
            SimpleLinear(),
            LinearBias(),
            LinearActivation(),
            LinearActivationFunctional(),
        ]
        for device in DEVICES:
            for model in models:
                self._test_step_linear_on_device(model, torch.device(device))

    def _test_step_conv2d_on_device(self, model, expected_shape, config, device):
        model = model.to(device)
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        self._check_pruner_valid_before_step(model, pruner, device)
        pruner.step()
        self._check_pruner_valid_after_step(model, pruner, 1, device)
        assert model(x).shape == expected_shape

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_step_conv2d(self):
        models = [
            SimpleConv2d(),
            Conv2dBias(),
            Conv2dActivation(),
            Conv2dPadBias(),
            Conv2dPool(),
        ]
        shapes = [
            (1, 52, 20, 20),
            (1, 52, 18, 18),
            (1, 52, 18, 18),
            (1, 52, 24, 24),
            (1, 52, 3, 3),
        ]
        configs = [None, None, None, None, None]
        for device in DEVICES:
            for model, shape, config in zip(models, shapes, configs):
                self._test_step_conv2d_on_device(
                    model, shape, config, torch.device(device)
                )

    def _check_pruner_pruned(self, model, pruner, device):
        for config in pruner.groups:
            module = config["module"]
            assert not hasattr(module, "parametrizations")
            assert not hasattr(module, "mask")
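
    # Note: the end-to-end tests below all follow the same prepare -> step -> prune
    # flow exercised by the helpers above: prepare() attaches a FakeStructuredSparsity
    # parametrization to each configured weight, step() lets the pruner's
    # update_mask() zero out rows of the mask, and prune() removes the parametrizations
    # and materializes the smaller modules. A rough sketch of that flow, mirroring
    # the calls used in these tests (not a separate API guarantee), would be:
    #
    #   pruner = ImplementedPruner({"prune_bias": True})
    #   pruner.prepare(model, [{"tensor_fqn": "seq.0.weight"}])
    #   pruner.enable_mask_update = True
    #   pruner.step()             # update the structured masks
    #   pruned = pruner.prune()   # resize the modules in place / return pruned model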

    def _test_linear_on_device(
        self, model, config, expected_shape, device, also_prune_bias
    ):
        model = model.to(device)
        model.eval()
        num_original_params = sum(p.numel() for p in model.parameters())
        x = torch.ones(128, 7, device=device)

        pruner = ImplementedPruner({"prune_bias": also_prune_bias})
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()

        y_expected = model(x)

        assert y_expected.shape == (128, 10)
        self._check_pruner_prepared(model, pruner, device)

        # Pruning step
        pruned = pruner.prune()
        y_pruned = pruned(x)
        num_pruned_params = sum(p.numel() for p in pruned.parameters())

        assert y_pruned.shape == expected_shape
        self._check_pruner_pruned(model, pruner, device)
        if y_pruned.shape == y_expected.shape:
            assert torch.isclose(y_expected, y_pruned, rtol=1e-05, atol=1e-07).all()
            assert num_pruned_params < num_original_params

    def test_prune_linear_linear(self):
        r"""test pruning linear-> linear modules"""
        configs, shapes = [], []
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((128, 10))

        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
                {"tensor_fqn": "linear1.weight"},
            ]
        )
        shapes.append((128, 10))

        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((128, 10))
        for device in DEVICES:
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_linear_on_device(
                        SimpleLinear(),
                        config,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_linear_bias_linear(self):
        # linear(bias) -> linear(no bias)
        configs, shapes = [], []
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
            ]
        )
        shapes.append((128, 10))

        # linear(bias) -> linear(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.2.weight"},
                {"tensor_fqn": "seq.3.weight"},
            ]
        )
        shapes.append((128, 10))

        # linear(no bias) -> linear(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((128, 10))

        for device in DEVICES:
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_linear_on_device(
                        LinearBias(),
                        config,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_linear_activation_linear(self):
        config = [
            {"tensor_fqn": "seq.0.weight"},
            {"tensor_fqn": "seq.2.weight"},
            {"tensor_fqn": "seq.4.weight"},
            {"tensor_fqn": "linear1.weight"},
        ]
        shape = (128, 10)

        for device in DEVICES:
            for also_prune_bias in [True, False]:
                # test version with nn.Modules
                self._test_linear_on_device(
                    LinearActivation(),
                    config,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )
                # test functional version
                self._test_linear_on_device(
                    LinearActivationFunctional(),
                    config,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )

    def _test_conv2d_on_device(
        self, model, config, x, expected_shape, device, also_prune_bias
    ):
        model = model.to(device)
        num_original_params = sum(p.numel() for p in model.parameters())
        model.eval()

        pruner = ImplementedPruner({"prune_bias": also_prune_bias})
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()

        y_expected = model(x)
        assert y_expected.shape == expected_shape

        self._check_pruner_prepared(model, pruner, device)

        # Fusion step
        pruned = pruner.prune()
        y_pruned = pruned(x)
        num_pruned_params = sum(p.numel() for p in pruned.parameters())

        assert y_pruned.shape == expected_shape
        self._check_pruner_pruned(model, pruner, device)
        if y_pruned.shape == y_expected.shape:
            # TODO This rtol is a little high, need to double check if something specific is causing this to fail
            assert torch.isclose(
                y_expected,
                y_pruned,
                rtol=1e-3,
                atol=1e-3,
            ).all(), f"fail for {type(model)}"
            # only time this should be equal is when all layers have padding and we can't prune
            assert num_pruned_params <= num_original_params

    def test_prune_conv2d_conv2d(self):
        configs, shapes = [], []
        # all within sequential blocks
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
            ]
        )
        shapes.append((1, 52, 20, 20))
        # prune across sequential blocks
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "conv2d1.weight"},
            ]
        )
        shapes.append((1, 52, 20, 20))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        SimpleConv2d(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_bias_conv2d(self):
        # Conv2d with Bias and no Activation
        configs, shapes = [], []
        # conv2d(bias) -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(no bias) -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "conv2d1.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(bias) -> conv2d(no bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        Conv2dBias(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_activation_conv2d(self):
        # Conv2d with Activation and no Bias
        configs, shapes = [], []

        # conv2d(no bias) -> activation -> conv2d(no bias)
        configs.append(
            [
                {"tensor_fqn": "seq.4.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(bias) -> activation -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(bias) -> activation -> conv2d(no bias)
        configs.append(
            [
                {"tensor_fqn": "seq.2.weight"},
                {"tensor_fqn": "seq.4.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(no bias) -> activation -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "conv2d1.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        Conv2dActivation(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_padding_conv2d(self):
        # Conv2d with Padded layers after Bias layers
        configs, shapes = [], []

        # conv(padded, bias) -> conv(padded, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.4.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))

        # conv(no bias, no pad) -> conv(padded, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))

        # conv(padded, bias) -> conv(no bias, no pad)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))
        # conv(pad, bias) -> conv(no pad, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.6.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))
        # conv(no pad, bias) -> conv(pad, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.8.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        Conv2dPadBias(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_pool_conv2d(self):
        # Conv2d with Pooling layers
        config = [
            {"tensor_fqn": "seq.0.weight"},
            {"tensor_fqn": "seq.3.weight"},
            {"tensor_fqn": "conv2d1.weight"},
            {"tensor_fqn": "conv2d2.weight"},
        ]
        shape = (1, 52, 3, 3)

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                self._test_conv2d_on_device(
                    Conv2dPool(),
                    config,
                    x,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )
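
    # Note on the next test: the Conv2dPoolFlatten models end in a Linear, so the
    # (1, 13) output shape asserted below implicitly checks that, after conv
    # channels are pruned ahead of the flatten, the following Linear's input
    # features were resized consistently (otherwise the pruned forward pass would
    # fail). This is an inference from the assertions in this file, not a
    # statement of the pruner's documented behavior.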

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_complex_conv2d(self):
        """Test fusion for models that contain Conv2d & Linear modules.
        Currently supports: Conv2d-Pool2d-Flatten-Linear, Skip-add"""
        config = [
            {"tensor_fqn": "seq.0.weight"},
            {"tensor_fqn": "seq.3.weight"},
            {"tensor_fqn": "conv2d1.weight"},
            {"tensor_fqn": "conv2d2.weight"},
        ]
        shape = (1, 13)

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                self._test_conv2d_on_device(
                    Conv2dPoolFlattenFunctional(),
                    config,
                    x,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )
                self._test_conv2d_on_device(
                    Conv2dPoolFlatten(),
                    config,
                    x,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )

    def test_prune_lstm_linear_multiple_layer(self):
        """
        Test fusion support for LSTM(multi-layer) -> Linear
        """
        model = LSTMLinearModel(
            input_dim=8,
            hidden_dim=8,
            output_dim=8,
            num_layers=2,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
            {"tensor_fqn": "lstm.weight_ih_l1"},
            {"tensor_fqn": "lstm.weight_hh_l1"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)

        fx_pruner.enable_mask_update = True
        fx_pruner.step()

        model.eval()
        _, _ = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        _, _ = pruned_model(lstm_input)

        expected_params = dict(model.named_parameters())
        for name, param in model.named_parameters():
            assert name in expected_params
            # We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics
            # Instead we check that the weights of the new LSTM are a subset of the weights of
            # the old LSTM
            assert rows_are_subset(param, expected_params[name])
            del expected_params[name]

        # assert we haven't deleted any keys
        assert len(expected_params) == 0

    def test_prune_lstm_linear_single_layer(self):
        """
        Test fusion support for LSTM (single-layer) -> Linear
        """
        model = LSTMLinearModel(
            input_dim=8,
            hidden_dim=8,
            output_dim=8,
            num_layers=1,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)
        fx_pruner.enable_mask_update = True
        fx_pruner.step()
        model.eval()

        out_expected, lstm_out_expected = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        out_pruned, lstm_out_pruned = pruned_model(lstm_input)
        r, c = lstm_out_expected.size()

        # We cannot check that y_expected == y_pruned as usual because
        # zeros vs. missing elements yield different numerical results.
        # Instead, we check that the pruned output matches the first half of the
        # expected results, since we are using a BottomHalfLSTMPruner
        assert torch.isclose(
            lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07
        ).all()
        # also check that the output of the linear layer has the same shape; this means
        # we've resized the linear columns correctly
        assert out_expected.shape == out_pruned.shape

    def test_prune_lstm_layernorm_linear_multiple_layer(self):
        """
        Test fusion support for LSTM(multi-layer) -> LayerNorm -> Linear
        """
        model = LSTMLayerNormLinearModel(
            input_dim=8,
            output_dim=8,
            hidden_dim=8,
            num_layers=2,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
            {"tensor_fqn": "lstm.weight_ih_l1"},
            {"tensor_fqn": "lstm.weight_hh_l1"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)

        fx_pruner.enable_mask_update = True
        fx_pruner.step()

        model.eval()
        _, _ = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        _, _ = pruned_model(lstm_input)

        expected_params = dict(model.named_parameters())
        for name, param in model.named_parameters():
            assert name in expected_params
            # We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics
            # Instead we check that the weights of the new LSTM are a subset of the weights of
            # the old LSTM
            assert rows_are_subset(param, expected_params[name])
            del expected_params[name]

        # assert we haven't deleted any keys
        assert len(expected_params) == 0

    def test_prune_lstm_layernorm_linear_single_layer(self):
        """
        Test fusion support for LSTM (single-layer) -> Linear
        """
        model = LSTMLinearModel(
            input_dim=8,
            hidden_dim=8,
            output_dim=8,
            num_layers=1,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)
        fx_pruner.enable_mask_update = True
        fx_pruner.step()
        model.eval()

        out_expected, lstm_out_expected = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        out_pruned, lstm_out_pruned = pruned_model(lstm_input)
        r, c = lstm_out_expected.size()

        # We cannot check that y_expected == y_pruned as usual because
        # zeros vs. missing elements yield different numerical results.
        # Instead, we check that the pruned output matches the first half of the
        # expected results, since we are using a BottomHalfLSTMPruner
        assert torch.isclose(
            lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07
        ).all()
        # also check that the output of the linear layer has the same shape; this means
        # we've resized the linear columns correctly
        assert out_expected.shape == out_pruned.shape


class TestFPGMPruner(TestCase):
    """
    Test case for the implementation of paper:
    `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_.
    """

    class SimpleConvFPGM(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv2d1 = nn.Conv2d(
                in_channels=1, out_channels=3, kernel_size=3, padding=1, bias=False
            )
            # Manually set the filter weights for demonstration purposes
            """
            Three filters' weights are manually set to values 3.0, 2.0, and 0.1.
            Different from the norm-based decision that would prune the filter with value 0.1,
            FPGM will prune the one with value 2.0.
            """
            weights = torch.tensor([3.0, 2.0, 0.1])  # weight value for each filter
            weights = weights[:, None, None, None]  # broadcasting
            self.conv2d1.weight.data.copy_(
                torch.ones(self.conv2d1.weight.shape) * weights
            )

            # Second Convolutional Layer
            self.conv2d2 = nn.Conv2d(
                in_channels=3, out_channels=4, kernel_size=3, padding=1, bias=False
            )
            weights = torch.tensor([6.0, 7.0, 0.4, 0.5])
            weights = weights[:, None, None, None]
            self.conv2d2.weight.data.copy_(
                torch.ones(self.conv2d2.weight.shape) * weights
            )

        def forward(self, x):
            x = self.conv2d1(x)
            x = self.conv2d2(x)
            return x
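
    # Why FPGM keeps the 0.1 filter above (informal summary of the referenced
    # paper): FPGM scores each filter by the sum of its pairwise distances to all
    # other filters and prunes the filters closest to the geometric median, i.e.
    # the most redundant ones. For conv2d1 the flattened filters are all-3.0,
    # all-2.0 and all-0.1 vectors of length 9, so the distance sums work out to
    # roughly [11.7, 8.7, 14.4] (also spelled out in test_compute_distance below);
    # the 2.0 filter is the most replaceable and gets pruned, while a plain
    # norm-based criterion would have dropped the 0.1 filter instead. A rough
    # sketch of the same computation the test performs:
    #
    #   flat = model.conv2d1.weight.flatten(1)      # shape (3, 9)
    #   torch.cdist(flat, flat, p=2).sum(dim=1)     # ~[11.7, 8.7, 14.4]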

    def test_compute_distance(self, device="cpu"):
        """Test the distance computation function"""
        model = TestFPGMPruner.SimpleConvFPGM().to(device)
        pruner = FPGMPruner(0.3)
        dist_conv1 = pruner._compute_distance(model.conv2d1.weight)

        # compute the distance matrix using torch.cdist
        flattened_filters = torch.Tensor(
            [
                [
                    3.0000,
                    3.0000,
                    3.0000,
                    3.0000,
                    3.0000,
                    3.0000,
                    3.0000,
                    3.0000,
                    3.0000,
                ],
                [
                    2.0000,
                    2.0000,
                    2.0000,
                    2.0000,
                    2.0000,
                    2.0000,
                    2.0000,
                    2.0000,
                    2.0000,
                ],
                [
                    0.1000,
                    0.1000,
                    0.1000,
                    0.1000,
                    0.1000,
                    0.1000,
                    0.1000,
                    0.1000,
                    0.1000,
                ],
            ]
        )

        """
        Expected distance matrix should have the following values:
            [0.0000, 3.0000, 8.7000],
            [3.0000, 0.0000, 5.7000],
            [8.7000, 5.7000, 0.0000],
        the distance should therefore be:
            [11.7000, 8.7000, 14.4000]
        """
        expected_dist_matrix_conv1 = torch.cdist(
            flattened_filters, flattened_filters, p=2
        )
        expected_dist_conv1 = torch.sum(torch.abs(expected_dist_matrix_conv1), 1)
        assert torch.isclose(
            dist_conv1, expected_dist_conv1, rtol=1e-05, atol=1e-07
        ).all()

    def _test_update_mask_on_single_layer(self, expected_conv1, device):
        """Test that pruning is conducted based on the pair-wise distance measurement instead of absolute norm value"""
        # test pruning with one layer of conv2d
        model = TestFPGMPruner.SimpleConvFPGM().to(device)
        x = torch.ones((1, 1, 32, 32), device=device)
        pruner = FPGMPruner(0.3)
        config = [{"tensor_fqn": "conv2d1.weight"}]
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()
        assert (
            pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item()
            is not False
        ), "do not prune the least-norm filter"

        # fusion step
        pruned_model = pruner.prune()

        pruned_y = pruned_model(x)
        # assert shapes
        expected_conv1 = expected_conv1.to(device)
        assert pruned_y.shape == (1, 4, 32, 32)
        assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
        assert pruned_model.conv2d2.weight.shape == (
            4,
            2,
            3,
            3,
        ), "conv2d2 should have input channel pruned"
        # assert value
        assert torch.isclose(
            pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07
        ).all()

    def _test_update_mask_on_multiple_layer(
        self, expected_conv1, expected_conv2, device
    ):
        # the second setting
        model = TestFPGMPruner.SimpleConvFPGM().to(device)
        x = torch.ones((1, 1, 32, 32), device=device)
        pruner = FPGMPruner(0.3)
        config = [
            {"tensor_fqn": "conv2d1.weight"},
            {"tensor_fqn": "conv2d2.weight", "sparsity_level": 0.5},
        ]
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()
        # Get the masks for the two least-norm filters
        mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1]
        mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2]
        # Check if either of the least-norm filters is not pruned
        assert (
            mask1.item() is not False or mask2.item() is not False
        ), "Do not prune all least-norm filters"

        # fusion step
        pruned_model = pruner.prune()
        pruned_y = pruned_model(x)
        # assert shapes
        expected_conv1 = expected_conv1.to(device)
        expected_conv2 = expected_conv2.to(device)
        assert pruned_y.shape == (1, 2, 32, 32)
        assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
        assert pruned_model.conv2d2.weight.shape == expected_conv2.shape
        # assert values
        assert torch.isclose(
            pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07
        ).all()
        assert torch.isclose(
            pruned_model.conv2d2.weight, expected_conv2, rtol=1e-05, atol=1e-07
        ).all()

    def test_update_mask(self):
        weights = torch.tensor([3.0, 0.1])
        expected_conv1 = torch.ones((2, 1, 3, 3)) * weights[:, None, None, None]

        weights = torch.tensor([7.0, 0.4])
        expected_conv2 = torch.ones((2, 2, 3, 3)) * weights[:, None, None, None]

        for device in DEVICES:
            self._test_update_mask_on_single_layer(expected_conv1, device)
            self._test_update_mask_on_multiple_layer(
                expected_conv1, expected_conv2, device
            )