1# Owner(s): ["module: unknown"] 2 3import copy 4import itertools 5import logging 6import math 7from typing import Tuple 8 9import torch 10from torch import nn 11from torch.ao.pruning._experimental.data_sparsifier import ( 12 BaseDataSparsifier, 13 DataNormSparsifier, 14) 15from torch.ao.pruning._experimental.data_sparsifier.quantization_utils import ( 16 post_training_sparse_quantize, 17) 18from torch.nn.utils.parametrize import is_parametrized 19from torch.testing._internal.common_utils import TestCase 20 21 22logging.basicConfig( 23 format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO 24) 25 26 27class ImplementedSparsifier(BaseDataSparsifier): 28 def __init__(self, **kwargs): 29 super().__init__(**kwargs) 30 31 def update_mask(self, name, data, **kwargs): 32 mask = self.get_mask(name) 33 mask[0] = 0 34 linear_state = self.state[name] 35 linear_state["step_count"] = linear_state.get("step_count", 0) + 1 36 37 38class _BaseDataSparsiferTestCase(TestCase): 39 r"""This helper test class takes in any supported type of and runs some tests. 40 The user is required to pass in the data that needs to sparsified and the 41 runner will run some tests that needs to be passed in order for the data 42 type to be supported. 43 TODO: Change the structure by creating a separate test case class for each 44 member function 45 """ 46 47 def run_all_checks(self, data_list, data_with_config, defaults): 48 self.check_constructor(data_list, data_with_config, defaults) 49 self.check_squash_mask(data_list, data_with_config, defaults) 50 self.check_add_data(data_list, data_with_config, defaults) 51 self.check_step(data_list, data_with_config, defaults) 52 self.check_state_dict(data_list, data_with_config, defaults) 53 self.check_memory_reference(data_list, data_with_config, defaults) 54 55 @staticmethod 56 def _get_name_data_config(some_data, defaults=None): 57 if isinstance(some_data, Tuple): 58 # dealing with data_list 59 name, data = some_data 60 config = defaults 61 else: 62 # dealing with data_with_config 63 name, data, config = ( 64 some_data["name"], 65 some_data["data"], 66 some_data["config"], 67 ) 68 return name, data, config 69 70 @staticmethod 71 def _make_sparsifier( 72 data_list, 73 data_with_config, 74 defaults, 75 sparsifier_type=None, 76 sparsifier_kwargs=None, 77 ): 78 if sparsifier_type is None: 79 sparsifier = ImplementedSparsifier(data_list=data_list, **defaults) 80 else: 81 kwargs = copy.deepcopy(defaults) 82 kwargs.update(sparsifier_kwargs) 83 kwargs["data_list"] = data_list 84 sparsifier = sparsifier_type(**kwargs) 85 assert len(sparsifier.data_groups) == len(data_list) 86 for data_config_dict in data_with_config: 87 name, data, config = ( 88 data_config_dict["name"], 89 data_config_dict["data"], 90 data_config_dict["config"], 91 ) 92 sparsifier.add_data(name=name, data=data, **config) 93 return sparsifier 94 95 def check_constructor(self, data_list, data_with_config, defaults, **kwargs): 96 sparsifier = self._make_sparsifier( 97 data_list, data_with_config, defaults=defaults, **kwargs 98 ) 99 self.assertEqual( 100 len(sparsifier.data_groups), 101 len(data_list) + len(data_with_config), 102 msg="Sparsifier data groups don't match the input " 103 f"({len(sparsifier.data_groups)} vs. " 104 f"{len(data_list) + len(data_with_config)}).", 105 ) 106 107 all_data = data_list + data_with_config 108 109 for some_data in all_data: 110 name, _, config = self._get_name_data_config(some_data, defaults=defaults) 111 self.assertIn(name, sparsifier.data_groups) 112 self.assertEqual(sparsifier.data_groups[name], config) 113 114 def check_step(self, data_list, data_with_config, defaults, **kwargs): 115 sparsifier = self._make_sparsifier( 116 data_list, data_with_config, defaults=defaults, **kwargs 117 ) 118 all_data = data_list + data_with_config 119 120 # Check data and mask before doing the step 121 for some_data in all_data: 122 name, data, _ = self._get_name_data_config(some_data) 123 data = sparsifier._extract_weight(data) 124 sparsified_data = sparsifier.get_data(name=name, return_original=False) 125 original_data = sparsifier.get_data(name=name, return_original=True) 126 mask = sparsifier.get_mask(name=name) 127 self.assertEqual(sparsified_data, data) 128 self.assertEqual(original_data, data) 129 self.assertEqualBroadcasting(mask[0], 1) 130 131 step_count = 3 132 133 for _ in range(0, step_count): 134 sparsifier.step() 135 for some_data in all_data: 136 name, data, _ = self._get_name_data_config(some_data) 137 data = sparsifier._extract_weight(data) 138 sparsified_data = sparsifier.get_data(name=name, return_original=False) 139 original_data = sparsifier.get_data(name=name, return_original=True) 140 mask = sparsifier.get_mask(name=name) 141 self.assertEqualBroadcasting(sparsified_data[0], 0) 142 self.assertEqual(original_data, data) 143 self.assertEqualBroadcasting(mask[0], 0) 144 assert "step_count" in sparsifier.state[name] 145 assert sparsifier.state[name]["step_count"] == 3 146 147 def check_squash_mask(self, data_list, data_with_config, defaults, **kwargs): 148 sparsifier = self._make_sparsifier( 149 data_list, data_with_config, defaults=defaults, **kwargs 150 ) 151 all_data = data_list + data_with_config 152 for some_data in all_data: 153 name, _, _ = self._get_name_data_config(some_data) 154 assert hasattr(sparsifier._container, name) 155 assert is_parametrized(sparsifier._container, name) 156 sparsifier.step() 157 sparsifier.squash_mask() 158 159 for some_data in all_data: 160 name, _, _ = self._get_name_data_config(some_data) 161 assert not is_parametrized( 162 sparsifier._container, name 163 ) # not parametrized anymore 164 with self.assertRaises(ValueError): 165 sparsifier.get_data(name, return_original=True) 166 167 def check_add_data(self, data_list, data_with_config, defaults, **kwargs): 168 sparsifier = self._make_sparsifier( 169 data_list, data_with_config, defaults=defaults, **kwargs 170 ) 171 all_data = data_list + data_with_config 172 for some_data in all_data: 173 name1, data1, config = self._get_name_data_config( 174 some_data, defaults=defaults 175 ) 176 data1 = sparsifier._extract_weight(data1) 177 data1_old = copy.deepcopy(data1) 178 assert torch.all(data1 == sparsifier.get_data(name=name1)) 179 180 sparsifier.step() 181 mask = sparsifier.get_mask(name1) 182 183 data2 = torch.randn( 184 data1.shape 185 ) # add another data with the same shape as original data 186 sparsifier.add_data(name=name1, data=data2) 187 assert torch.all(data2 == sparsifier.get_data(name=name1)) 188 189 assert torch.all( 190 sparsifier.get_mask(name1) == mask 191 ) # mask should not change 192 assert torch.all(data1_old == data1) 193 194 assert ( 195 sparsifier.data_groups[name1] == config 196 ) # if replaced old_config should match new config 197 198 def check_state_dict(self, data_list, data_with_config, defaults, **kwargs): 199 sparsifier1 = self._make_sparsifier( 200 data_list, data_with_config, defaults=defaults, **kwargs 201 ) 202 sparsifier2 = self._make_sparsifier( 203 data_list=[data_list[0]], data_with_config=[], defaults=defaults, **kwargs 204 ) 205 sparsifier1.step() 206 207 state_dict1 = sparsifier1.state_dict() 208 209 assert sparsifier1.state != sparsifier2.state 210 name, _, _ = self._get_name_data_config(data_list[0]) 211 self.assertNotEqual(sparsifier1.get_mask(name), sparsifier2.get_mask(name)) 212 213 sparsifier2.load_state_dict(state_dict1) 214 assert len(sparsifier1.state) == len(sparsifier2.state) 215 assert len(sparsifier1.data_groups) == len(sparsifier2.data_groups) 216 217 state1 = state_dict1["state"] 218 for name in state1.keys(): 219 # compare mask 220 assert name in sparsifier2.state 221 assert "mask" in sparsifier2.state[name] 222 assert "mask" in sparsifier1.state[name] 223 mask1, mask2 = state1[name]["mask"], sparsifier2.state[name]["mask"] 224 assert mask1.is_sparse and not mask2.is_sparse 225 assert torch.all( 226 mask1.to_dense() == mask2 227 ) # mask1 is stored as sparse coo now 228 229 # compare data_groups 230 dg1, dg2 = sparsifier1.data_groups, sparsifier2.data_groups 231 assert name in dg1 and name in dg2 232 assert dg1[name] == dg2[name] 233 234 # compare container 235 container1, container2 = sparsifier1._container, sparsifier2._container 236 assert torch.all(getattr(container1, name) == getattr(container2, name)) 237 assert is_parametrized(container1, name) == is_parametrized( 238 container2, name 239 ) 240 if is_parametrized(container1, name): 241 param1 = getattr(container1.parametrizations, name)[0] 242 param2 = getattr(container2.parametrizations, name)[0] 243 assert hasattr(param1, "mask") 244 assert hasattr(param2, "mask") 245 self.assertEqual(param1.__dict__, param2.__dict__) 246 247 def check_memory_reference(self, data_list, data_with_config, defaults, **kwargs): 248 """Checks if the data is truly "attached" to the sparsifier. Meaning, when the 249 data is changed outside of the sparsifier, the changes must be reflected on the data 250 inside the data sparsifier as well. 251 This makes sure that the sparsifier is holding the memory reference of the data and 252 not copies. 253 254 This test modifies the data and asserts that data in the sparsifier is changed as well 255 """ 256 sparsifier = self._make_sparsifier( 257 data_list, data_with_config, defaults=defaults, **kwargs 258 ) 259 all_data = data_list + data_with_config 260 for some_data in all_data: 261 name, data, _ = self._get_name_data_config(some_data) 262 weight = sparsifier._extract_weight(data) 263 weight.data = weight + torch.randn(*weight.shape) 264 contained_data = sparsifier.get_data(name=name) 265 assert ( 266 weight.data.storage().data_ptr() 267 == contained_data.data.storage().data_ptr() 268 ) 269 assert torch.all(contained_data == weight) 270 271 272class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase): 273 r"""This helper test class takes in any supported type of and runs some tests. 274 This inherits the TestBaseDataSparsifierRuner wherein some functions are 275 over-ridden to take accomodate the specific sparsifier. 276 TODO: Change the structure by creating a separate test case class for each 277 member function 278 """ 279 280 def run_all_checks(self, data_list, defaults, data_with_config, norm_type="L1"): 281 assert norm_type in ["L1", "L2"] 282 kwargs = { 283 "sparsifier_type": DataNormSparsifier, 284 "sparsifier_kwargs": {"norm": norm_type}, 285 } 286 self.check_constructor(data_list, data_with_config, defaults, **kwargs) 287 self.check_squash_mask(data_list, data_with_config, defaults, **kwargs) 288 self.check_add_data(data_list, data_with_config, defaults, **kwargs) 289 self.check_state_dict(data_list, data_with_config, defaults, **kwargs) 290 self.check_step(data_list, data_with_config, defaults, norm_type=norm_type) 291 self.check_step_2_of_4(norm_type=norm_type) 292 self.check_sparsity_level( 293 data_list, data_with_config, defaults, norm_type=norm_type 294 ) 295 self.check_memory_reference(data_list, data_with_config, defaults, **kwargs) 296 297 @staticmethod 298 def _get_bounds_on_actual_sparsity(config, tensor_shape): 299 r"""This function gets the bounds on actual sparsity. 300 Note:: 301 Although we specify the sparsity_level parameter, this does not mean that 302 the actual sparsity obtained after sparsification is the same as sparsity_level. 303 The actual sparsity depends largely on the shape and the data itself. 304 """ 305 sparsity_level = config["sparsity_level"] 306 zeros_per_block = config["zeros_per_block"] 307 sparse_block_shape = config["sparse_block_shape"] 308 309 height, width = tensor_shape[-2], tensor_shape[-1] 310 block_height, block_width = sparse_block_shape 311 number_blocks = math.ceil(height / block_height) * math.ceil( 312 width / block_width 313 ) 314 values_per_block = block_height * block_width 315 316 if zeros_per_block == 0: 317 return (1.0, 1.0) 318 else: 319 # min value assumes zeros_per_block is 1 320 min_values_sparsified = round(number_blocks * sparsity_level) 321 # max value assumes actual zeros_per_block 322 max_values_sparsified = min_values_sparsified * min( 323 values_per_block, zeros_per_block 324 ) 325 lower_bound = min_values_sparsified / (height * width) 326 upper_bound = min(1.0, max_values_sparsified / (height * width)) 327 328 lower_bound, upper_bound = round(lower_bound, 3), round(upper_bound, 3) 329 return lower_bound, upper_bound 330 331 def check_step(self, data_list, data_with_config, defaults, norm_type="L1"): 332 sparsifier = self._make_sparsifier( 333 data_list, 334 data_with_config, 335 defaults, 336 sparsifier_type=DataNormSparsifier, 337 sparsifier_kwargs={"norm": norm_type}, 338 ) 339 all_data = data_list + data_with_config 340 341 # mask before step() should not be sparsified 342 for some_data in all_data: 343 name, _, _ = self._get_name_data_config(some_data) 344 mask = sparsifier.get_mask(name=name) 345 assert (1.0 - mask.mean()) == 0 # checking sparsity level is 0 346 347 sparsifier.step() 348 349 for some_data in all_data: 350 name, _, _ = self._get_name_data_config(some_data) 351 mask = sparsifier.get_mask(name=name) 352 config = sparsifier.data_groups[name] 353 lb, ub = self._get_bounds_on_actual_sparsity(config, mask.shape) 354 mask = mask.to(torch.float) 355 actual_sparsity = round(1 - mask.mean().item(), 3) 356 assert actual_sparsity >= lb and actual_sparsity <= ub 357 assert ( 358 actual_sparsity > 0.0 359 ) # exact sparsity level cannot be achieved due to size of tensor 360 361 iters_before_collapse = 100 362 363 test_sparsifier = DataNormSparsifier( 364 sparsity_level=0.5, 365 sparse_block_shape=(1, 4), 366 zeros_per_block=4, 367 norm=norm_type, 368 ) 369 370 for _ in range(iters_before_collapse): 371 new_data = torch.randn(20, 20) 372 test_sparsifier.add_data(name="test_data", data=new_data) 373 test_sparsifier.step() 374 mask = test_sparsifier.get_mask(name="test_data") 375 mask = mask.to(torch.float) 376 assert (1.0 - mask.mean().item()) > 0 # some sparsity achieved 377 378 def check_step_2_of_4(self, norm_type): 379 # overriding default config for test purposes 380 default_config = { 381 "sparsity_level": 1.0, 382 "zeros_per_block": 2, 383 "sparse_block_shape": (1, 4), 384 } 385 data_list = [("test_data", torch.randn(4, 4))] 386 387 sparsifier = DataNormSparsifier( 388 data_list=data_list, norm=norm_type, **default_config 389 ) 390 sparsifier.step() 391 392 for some_data in data_list: 393 name, _ = some_data 394 mask = sparsifier.get_mask(name=name) 395 mask = mask.to(torch.float) 396 self.assertAlmostEqual(1.0 - mask.mean().item(), 0.5, places=2) 397 for row in mask: 398 for idx in range(0, len(row), 4): 399 block = row[idx : idx + 4] 400 block, _ = block.sort() 401 assert (block[:2] == 0).all() 402 assert (block[2:] != 0).all() 403 404 def check_sparsity_level( 405 self, data_list, data_with_config, defaults, norm_type="L1" 406 ): 407 sparsity_levels = [-1.0, 0.0, 0.5, 1.0, 2.0] 408 sparse_block_shapes = [(1, 1), (1, 4), (2, 2), (4, 1)] 409 zeros_per_blocks = [0, 1, 2, 3, 4] 410 sparsifier = DataNormSparsifier(data_list=data_list, norm=norm_type) 411 412 testcases = itertools.tee( 413 itertools.product(sparsity_levels, sparse_block_shapes, zeros_per_blocks) 414 ) 415 416 assert ( 417 len(data_with_config) > 0 418 and "name" in data_with_config[0] 419 and "data" in data_with_config[0] 420 ) 421 # get some data 422 name, data = data_with_config[0]["name"], data_with_config[0]["data"] 423 for idx, (sl, sbs, zpb) in enumerate(testcases[0]): 424 new_name = f"{name}_{idx}" 425 if zpb > sbs[0] * sbs[1]: 426 continue 427 current_config = { 428 "sparsity_level": sl, 429 "sparse_block_shape": sbs, 430 "zeros_per_block": zpb, 431 } 432 sparsifier.add_data(name=new_name, data=data, **current_config) 433 if zpb > sbs[0] * sbs[1]: 434 continue 435 436 sparsifier.step() 437 sparsifier.squash_mask() 438 for idx, (sl, sbs, zpb) in enumerate(testcases[0]): 439 new_name = f"{name}_{idx}" 440 sparsified_data = sparsifier.get_data(name=new_name, original=False) 441 # sparse mask 442 sparse_mask = (sparsified_data == 0).float() 443 if zpb == 0: 444 assert sparse_mask.mean() == 0 445 else: 446 # Ratio of individual zeros in the tensor 447 true_sl = min(max(sl, 0.0), 1.0) 448 true_sl = true_sl * zpb / sbs[0] / sbs[1] 449 assert sparse_mask.mean() == true_sl 450 451 452class TestBaseDataSparsifier(_BaseDataSparsiferTestCase): 453 """To add unit tests to support new data types for the BaseDataSparsifier, create the following 454 data_list: List of tuples of name, data to be added to the constructor 455 defaults: default config for the above data in data_list 456 data_with_config: list of dictionaries defining name, data and config (look test_tensors()) 457 458 Once the above is done, create an instance of TestBaseDataSparsifierType and call all the run_tests() 459 """ 460 461 def test_tensors(self): 462 tensor1, tensor2, tensor3 = ( 463 torch.randn(3, 3), 464 torch.randn(4, 4), 465 torch.randn(5, 5), 466 ) 467 tensor4, tensor5 = torch.randn(1, 1), torch.randn(4, 4) 468 data_list = [("tensor1", tensor1), ("tensor2", tensor2), ("tensor3", tensor3)] 469 defaults = {"test": 3} 470 471 data_with_config = [ 472 {"name": "tensor4", "data": tensor4, "config": {"test": 7}}, 473 {"name": "tensor5", "data": tensor5, "config": {"test": 8}}, 474 ] 475 self.run_all_checks( 476 data_list=data_list, defaults=defaults, data_with_config=data_with_config 477 ) 478 479 def test_nn_parameters(self): 480 param1, param2, param3 = ( 481 nn.Parameter(torch.randn(3, 3)), 482 nn.Parameter(torch.randn(4, 4)), 483 nn.Parameter(torch.randn(5, 5)), 484 ) 485 param4, param5 = nn.Parameter(torch.randn(1, 1)), nn.Parameter( 486 torch.randn(4, 4) 487 ) 488 data_list = [("param1", param1), ("param2", param2), ("param3", param3)] 489 defaults = {"test": 3} 490 491 data_with_config = [ 492 {"name": "param4", "data": param4, "config": {"test": 7}}, 493 {"name": "param5", "data": param5, "config": {"test": 8}}, 494 ] 495 self.run_all_checks( 496 data_list=data_list, defaults=defaults, data_with_config=data_with_config 497 ) 498 499 def test_nn_embeddings(self): 500 ( 501 emb1, 502 emb2, 503 ) = nn.Embedding( 504 10, 3 505 ), nn.Embedding(20, 3) 506 emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3) 507 508 emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3) 509 data_list = [ 510 ("emb1", emb1), 511 ("emb1_bag", emb1_bag), 512 ("emb2", emb2), 513 ("emb2_bag", emb2_bag), 514 ] 515 defaults = {"test": 3} 516 517 data_with_config = [ 518 {"name": "emb3", "data": emb3, "config": {"test": 7}}, 519 {"name": "emb3_bag", "data": emb3_bag, "config": {"test": 8}}, 520 ] 521 self.run_all_checks( 522 data_list=data_list, defaults=defaults, data_with_config=data_with_config 523 ) 524 525 526class TestNormDataSparsifiers(_NormDataSparsifierTestCase): 527 """To add unit tests to support new data types for the NormDataSparsifier, create the following 528 data_list: List of tuples of name, data to be added to the constructor 529 defaults: default config for the above data in data_list 530 data_with_config: list of dictionaries defining name, data and config (look test_tensors()) 531 532 Once the above is done, create an instance of _NormDataSparsifierTestRunner and call run_tests() 533 """ 534 535 def test_tensors(self): 536 tensor1, tensor2, tensor3 = ( 537 torch.randn(1, 10), 538 torch.randn(4, 4), 539 torch.randn(1, 5), 540 ) 541 tensor4, tensor5 = torch.randn(1, 2), torch.randn(4, 4) 542 data_list = [("tensor1", tensor1), ("tensor2", tensor2), ("tensor3", tensor3)] 543 defaults = { 544 "sparsity_level": 0.5, 545 "sparse_block_shape": (1, 4), 546 "zeros_per_block": 4, 547 } 548 549 data_with_config = [ 550 { 551 "name": "tensor4", 552 "data": tensor4, 553 "config": { 554 "sparsity_level": 0.7, 555 "sparse_block_shape": (2, 3), 556 "zeros_per_block": 6, 557 }, 558 }, 559 { 560 "name": "tensor5", 561 "data": tensor5, 562 "config": { 563 "sparsity_level": 0.3, 564 "sparse_block_shape": (2, 3), 565 "zeros_per_block": 6, 566 }, 567 }, 568 ] 569 self.run_all_checks( 570 data_list=data_list, 571 defaults=defaults, 572 data_with_config=data_with_config, 573 norm_type="L1", 574 ) 575 self.run_all_checks( 576 data_list=data_list, 577 defaults=defaults, 578 data_with_config=data_with_config, 579 norm_type="L2", 580 ) 581 582 def test_nn_parameters(self): 583 param1, param2, param3 = ( 584 nn.Parameter(torch.randn(1, 8)), 585 nn.Parameter(torch.randn(4, 4)), 586 nn.Parameter(torch.randn(5, 5)), 587 ) 588 param4, param5 = nn.Parameter(torch.randn(10, 10)), nn.Parameter( 589 torch.randn(4, 4) 590 ) 591 data_list = [("param1", param1), ("param2", param2), ("param3", param3)] 592 defaults = { 593 "sparsity_level": 0.5, 594 "sparse_block_shape": (1, 4), 595 "zeros_per_block": 4, 596 } 597 598 data_with_config = [ 599 { 600 "name": "param4", 601 "data": param4, 602 "config": { 603 "sparsity_level": 0.7, 604 "sparse_block_shape": (2, 3), 605 "zeros_per_block": 6, 606 }, 607 }, 608 { 609 "name": "param5", 610 "data": param5, 611 "config": { 612 "sparsity_level": 0.3, 613 "sparse_block_shape": (2, 3), 614 "zeros_per_block": 6, 615 }, 616 }, 617 ] 618 self.run_all_checks( 619 data_list=data_list, 620 defaults=defaults, 621 data_with_config=data_with_config, 622 norm_type="L1", 623 ) 624 self.run_all_checks( 625 data_list=data_list, 626 defaults=defaults, 627 data_with_config=data_with_config, 628 norm_type="L2", 629 ) 630 631 def test_nn_embeddings(self): 632 ( 633 emb1, 634 emb2, 635 ) = nn.Embedding( 636 10, 3 637 ), nn.Embedding(20, 3) 638 emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3) 639 640 emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3) 641 data_list = [ 642 ("emb1", emb1), 643 ("emb1_bag", emb1_bag), 644 ("emb2", emb2), 645 ("emb2_bag", emb2_bag), 646 ] 647 defaults = { 648 "sparsity_level": 0.5, 649 "sparse_block_shape": (1, 4), 650 "zeros_per_block": 4, 651 } 652 653 data_with_config = [ 654 { 655 "name": "emb3", 656 "data": emb3, 657 "config": { 658 "sparsity_level": 0.7, 659 "sparse_block_shape": (2, 3), 660 "zeros_per_block": 6, 661 }, 662 }, 663 { 664 "name": "emb3_bag", 665 "data": emb3_bag, 666 "config": { 667 "sparsity_level": 0.3, 668 "sparse_block_shape": (2, 3), 669 "zeros_per_block": 6, 670 }, 671 }, 672 ] 673 self.run_all_checks( 674 data_list=data_list, 675 defaults=defaults, 676 data_with_config=data_with_config, 677 norm_type="L1", 678 ) 679 680 self.run_all_checks( 681 data_list=data_list, 682 defaults=defaults, 683 data_with_config=data_with_config, 684 norm_type="L2", 685 ) 686 687 688class Model(nn.Module): 689 def __init__(self) -> None: 690 super().__init__() 691 self.emb1 = nn.Embedding(100, 3) 692 self.embbag1 = nn.EmbeddingBag(200, 32) 693 self.emb_seq = nn.Sequential(nn.Embedding(150, 3), nn.EmbeddingBag(100, 3)) 694 self.linear1 = nn.Linear(32, 32) 695 self.linear2 = nn.Linear(16, 16) 696 697 698class TestQuantizationUtils(TestCase): 699 def test_ptq_sparsify_first(self): 700 """The expectation is post_training_sparse_quantize function 701 1. Takes in a model 702 2. Sparsifies the embeddings 703 3. Quantize the embeddings 704 705 This unit test checks that 706 1. Embeddings and EmbeddingBags are sparsified to the right sparsity levels 707 2. Embeddings and EmbeddingBags are quantized 708 3. Linear modules are not quantized 709 """ 710 model = Model() 711 712 sparse_config = {"sparsity_level": 0.80, "sparse_block_shape": (1, 1)} 713 select_embeddings = [model.embbag1, model.emb1] 714 post_training_sparse_quantize( 715 model, 716 data_sparsifier_class=DataNormSparsifier, 717 sparsify_first=True, 718 select_embeddings=select_embeddings, 719 **sparse_config, 720 ) 721 722 assert type(model.emb1) == torch.ao.nn.quantized.modules.embedding_ops.Embedding 723 assert ( 724 type(model.embbag1) 725 == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag 726 ) 727 assert type(model.emb_seq[0] == nn.Embedding) 728 assert type(model.emb_seq[1] == nn.EmbeddingBag) 729 assert type(model.linear1) == nn.Linear 730 assert type(model.linear2) == nn.Linear 731 732 dequant_emb1 = torch.dequantize(model.emb1.weight()) 733 dequant_embbag1 = torch.dequantize(model.embbag1.weight()) 734 735 threshold = 1e-2 736 737 sl_emb1 = (torch.abs(dequant_emb1) < threshold).float().mean() 738 sl_embbag1 = (torch.abs(dequant_embbag1) < threshold).float().mean() 739 740 assert abs(sl_emb1 - 0.80) <= 0.05 # +- 5% leeway 741 assert abs(sl_embbag1 - 0.80) <= 0.05 # +- 5% leeway 742 743 def test_ptq_quantize_first(self): 744 """The expectation is post_training_sparse_quantize function 745 1. Takes in a model 746 2. Quantize the embeddings 747 3. Sparsifies the embeddings 748 749 This unit test checks that 750 1. Embeddings and EmbeddingBags are sparsified to the right sparsity levels 751 2. Embeddings and EmbeddingBags are quantized 752 3. Linear modules are not quantized 753 """ 754 model = Model() 755 756 sparse_config = {"sparsity_level": 0.8, "sparse_block_shape": (1, 1)} 757 post_training_sparse_quantize( 758 model, DataNormSparsifier, sparsify_first=False, **sparse_config 759 ) 760 761 assert type(model.emb1) == torch.ao.nn.quantized.modules.embedding_ops.Embedding 762 assert ( 763 type(model.embbag1) 764 == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag 765 ) 766 assert type( 767 model.emb_seq[0] == torch.ao.nn.quantized.modules.embedding_ops.Embedding 768 ) 769 assert type( 770 model.emb_seq[1] == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag 771 ) 772 assert type(model.linear1) == nn.Linear # not quantized 773 assert type(model.linear2) == nn.Linear # not quantized 774 775 dequant_emb1 = torch.dequantize(model.emb1.weight()) 776 dequant_embbag1 = torch.dequantize(model.embbag1.weight()) 777 dequant_emb_seq_0 = torch.dequantize(model.emb_seq[0].weight()) 778 dequant_emb_seq_1 = torch.dequantize(model.emb_seq[1].weight()) 779 780 # higher threshold as quantization occurs before sparsity 781 threshold = ( 782 1 # zero points seem to have higher magnitude with sparsity occuring after 783 ) 784 785 sl_emb1 = (torch.abs(dequant_emb1) < threshold).float().mean() 786 sl_embbag1 = (torch.abs(dequant_embbag1) < threshold).float().mean() 787 sl_emb_seq_0 = (torch.abs(dequant_emb_seq_0) < threshold).float().mean() 788 sl_emb_seq_1 = (torch.abs(dequant_emb_seq_1) < threshold).float().mean() 789 790 assert abs(sl_emb1 - 0.80) <= 0.05 # +- 5% leeway 791 assert abs(sl_embbag1 - 0.80) <= 0.05 # +- 5% leeway 792 assert abs(sl_emb_seq_0 - 0.80) <= 0.05 # +- 5% leeway 793 assert abs(sl_emb_seq_1 - 0.80) <= 0.05 # +- 5% leeway 794