1# Owner(s): ["module: unknown"] 2 3import itertools 4import logging 5import re 6 7import torch 8from torch import nn 9from torch.ao.pruning import ( 10 BaseSparsifier, 11 FakeSparsity, 12 NearlyDiagonalSparsifier, 13 WeightNormSparsifier, 14) 15from torch.nn.utils.parametrize import is_parametrized 16from torch.testing._internal.common_pruning import ( 17 ImplementedSparsifier, 18 MockSparseLinear, 19 SimpleLinear, 20) 21from torch.testing._internal.common_utils import TestCase 22 23 24logging.basicConfig( 25 format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO 26) 27 28 29class TestBaseSparsifier(TestCase): 30 def test_constructor(self): 31 # Cannot instantiate the abstract base 32 self.assertRaises(TypeError, BaseSparsifier) 33 # Can instantiate the model with no configs 34 model = SimpleLinear() 35 sparsifier = ImplementedSparsifier(test=3) 36 sparsifier.prepare(model, config=None) 37 assert len(sparsifier.groups) == 5 38 sparsifier.step() 39 # Can instantiate the model with configs 40 sparsifier = ImplementedSparsifier(test=3) 41 sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}]) 42 assert len(sparsifier.groups) == 1 43 assert sparsifier.groups[0]["tensor_fqn"] == "linear1.weight" 44 assert "test" in sparsifier.groups[0] 45 assert sparsifier.groups[0]["test"] == 3 46 47 def test_prepare_config(self): 48 model = SimpleLinear() 49 sparsifier = ImplementedSparsifier(test=3) 50 # Make sure there are no parametrizations before `prepare` 51 assert not hasattr(model.seq[0], "parametrizations") 52 assert not hasattr(model.linear1, "parametrizations") 53 assert not hasattr(model.linear2, "parametrizations") 54 sparsifier.prepare( 55 model, 56 config=[ 57 {"tensor_fqn": "seq.0.weight", "test": 42}, 58 # No 'linear1' to make sure it will be skipped in the sparsification 59 {"tensor_fqn": "linear2.weight"}, 60 ], 61 ) 62 assert len(sparsifier.groups) == 2 63 # Check if default argument is not assigned if explicit 64 assert sparsifier.groups[0]["tensor_fqn"] == "seq.0.weight" 65 assert sparsifier.groups[0]["test"] == 42 66 # Check if FQN and module are pointing to the same location 67 assert sparsifier.groups[1]["tensor_fqn"] == "linear2.weight" 68 assert sparsifier.groups[1]["module"] == model.linear2 69 # Check if parameterizations are attached 70 assert hasattr(model.seq[0], "parametrizations") 71 assert not hasattr(model.linear1, "parametrizations") 72 assert hasattr(model.linear2, "parametrizations") 73 74 def test_step(self): 75 model = SimpleLinear() 76 sparsifier = ImplementedSparsifier(test=3) 77 sparsifier.enable_mask_update = True 78 sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}]) 79 sparsifier.step() 80 assert torch.all(model.linear1.parametrizations.weight[0].mask[0] == 0) 81 82 def test_state_dict(self): 83 step_count = 3 84 model0 = SimpleLinear() 85 sparsifier0 = ImplementedSparsifier(test=3) 86 sparsifier0.prepare(model0, [{"tensor_fqn": "linear1.weight"}]) 87 mask = model0.linear1.parametrizations["weight"][0].mask 88 mask.data = torch.arange(mask.shape[0] * mask.shape[1]).reshape(mask.shape) 89 for step in range(step_count): 90 sparsifier0.step() 91 state_dict = sparsifier0.state_dict() 92 93 # Check the expected keys in the state_dict 94 assert "state" in state_dict 95 assert "step_count" in state_dict["state"]["linear1.weight"] 96 assert state_dict["state"]["linear1.weight"]["step_count"] == 3 97 assert "groups" in state_dict 98 assert "test" in state_dict["groups"][0] 99 assert "tensor_fqn" in state_dict["groups"][0] 100 assert state_dict["groups"][0]["tensor_fqn"] == "linear1.weight" 101 102 # Check loading static_dict creates an equivalent model 103 model1 = SimpleLinear() 104 sparsifier1 = ImplementedSparsifier() 105 sparsifier1.prepare(model1, None) 106 107 assert sparsifier0.state != sparsifier1.state 108 109 # Make sure the masks are different in the beginning 110 for mg in sparsifier0.groups: 111 if mg["tensor_fqn"] == "linear1.weight": 112 mask0 = mg["module"].parametrizations.weight[0].mask 113 for mg in sparsifier1.groups: 114 if mg["tensor_fqn"] == "linear1.weight": 115 mask1 = mg["module"].parametrizations.weight[0].mask 116 self.assertNotEqual(mask0, mask1) 117 118 sparsifier1.load_state_dict(state_dict) 119 120 # Make sure the states are loaded, and are correct 121 assert sparsifier0.state == sparsifier1.state 122 123 # Make sure the masks (and all dicts) are the same after loading 124 assert len(sparsifier0.groups) == len(sparsifier1.groups) 125 for idx in range(len(sparsifier0.groups)): 126 mg0 = sparsifier0.groups[idx] 127 mg1 = sparsifier1.groups[idx] 128 for key in mg0.keys(): 129 assert key in mg1 130 if key == "module": 131 # We cannot compare modules as they are different 132 param0 = mg0[key].parametrizations.weight[0] 133 param1 = mg1[key].parametrizations.weight[0] 134 assert hasattr(param0, "mask") 135 assert hasattr(param1, "mask") 136 self.assertEqual(param0.__dict__, param1.__dict__) 137 else: 138 assert mg0[key] == mg1[key] 139 140 def test_convert(self): 141 model = SimpleLinear() 142 sparsifier = ImplementedSparsifier(test=3) 143 sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}]) 144 new_model = sparsifier.convert( 145 model, mapping={nn.Linear: MockSparseLinear}, inplace=False 146 ) 147 148 assert isinstance(new_model.linear1, MockSparseLinear) 149 assert isinstance(new_model.seq[0], nn.Linear) 150 assert isinstance(new_model.linear2, nn.Linear) 151 152 def test_mask_squash(self): 153 model = SimpleLinear() 154 sparsifier = ImplementedSparsifier(test=3) 155 sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}]) 156 assert hasattr(model.linear1.parametrizations.weight[0], "mask") 157 assert is_parametrized(model.linear1, "weight") 158 assert not is_parametrized(model.seq[0], "weight") 159 160 sparsifier.squash_mask() 161 assert not is_parametrized(model.seq[0], "weight") 162 assert not is_parametrized(model.linear1, "weight") 163 164 def test_mask_squash_with_params1(self): 165 model = SimpleLinear() 166 sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1) 167 sparsifier.prepare( 168 model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}] 169 ) 170 sparsifier.squash_mask( 171 params_to_keep_per_layer={"linear1": ("foo", "bar"), "seq.0": ("baz",)} 172 ) 173 assert not is_parametrized(model.seq[0], "weight") 174 assert not is_parametrized(model.linear1, "weight") 175 assert hasattr(model.seq[0], "sparse_params") 176 assert hasattr(model.linear1, "sparse_params") 177 assert model.seq[0].sparse_params.get("foo", None) is None 178 assert model.seq[0].sparse_params.get("bar", None) is None 179 assert model.seq[0].sparse_params.get("baz", None) == 1 180 assert model.linear1.sparse_params.get("foo", None) == 3 181 assert model.linear1.sparse_params.get("bar", None) == 2 182 assert model.linear1.sparse_params.get("baz", None) is None 183 184 def test_mask_squash_with_params2(self): 185 model = SimpleLinear() 186 sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1) 187 sparsifier.prepare( 188 model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}] 189 ) 190 sparsifier.squash_mask(params_to_keep=("foo", "bar")) 191 assert not is_parametrized(model.seq[0], "weight") 192 assert not is_parametrized(model.linear1, "weight") 193 assert hasattr(model.seq[0], "sparse_params") 194 assert hasattr(model.linear1, "sparse_params") 195 assert model.seq[0].sparse_params.get("foo", None) == 3 196 assert model.seq[0].sparse_params.get("bar", None) == 2 197 assert model.seq[0].sparse_params.get("baz", None) is None 198 assert model.linear1.sparse_params.get("foo", None) == 3 199 assert model.linear1.sparse_params.get("bar", None) == 2 200 assert model.linear1.sparse_params.get("baz", None) is None 201 202 def test_mask_squash_with_params3(self): 203 model = SimpleLinear() 204 sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1) 205 sparsifier.prepare( 206 model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}] 207 ) 208 sparsifier.squash_mask( 209 params_to_keep=("foo", "bar"), params_to_keep_per_layer={"seq.0": ("baz",)} 210 ) 211 assert not is_parametrized(model.seq[0], "weight") 212 assert not is_parametrized(model.linear1, "weight") 213 assert hasattr(model.seq[0], "sparse_params") 214 assert hasattr(model.linear1, "sparse_params") 215 assert model.seq[0].sparse_params.get("foo", None) == 3 216 assert model.seq[0].sparse_params.get("bar", None) == 2 217 assert model.seq[0].sparse_params.get("baz", None) == 1 218 assert model.linear1.sparse_params.get("foo", None) == 3 219 assert model.linear1.sparse_params.get("bar", None) == 2 220 assert model.linear1.sparse_params.get("baz", None) is None 221 222 223class TestWeightNormSparsifier(TestCase): 224 def test_constructor(self): 225 model = SimpleLinear() 226 sparsifier = WeightNormSparsifier() 227 sparsifier.prepare(model, config=None) 228 for g in sparsifier.groups: 229 assert isinstance(g["module"], nn.Linear) 230 # The groups are unordered 231 assert g["module_fqn"] in ("seq.0", "seq.1", "seq.2", "linear1", "linear2") 232 233 def test_step(self): 234 model = SimpleLinear() 235 sparsifier = WeightNormSparsifier(sparsity_level=0.5) 236 sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}]) 237 for g in sparsifier.groups: 238 # Before step 239 module = g["module"] 240 assert ( 241 1.0 - module.parametrizations["weight"][0].mask.mean() 242 ) == 0 # checking sparsity level is 0 243 sparsifier.enable_mask_update = True 244 sparsifier.step() 245 self.assertAlmostEqual( 246 model.linear1.parametrizations["weight"][0].mask.mean().item(), 247 0.5, 248 places=2, 249 ) 250 for g in sparsifier.groups: 251 # After step 252 module = g["module"] 253 assert ( 254 1.0 - module.parametrizations["weight"][0].mask.mean() 255 ) > 0 # checking sparsity level has increased 256 # Test if the mask collapses to all zeros if the weights are randomized 257 iters_before_collapse = 1000 258 for _ in range(iters_before_collapse): 259 model.linear1.weight.data = torch.randn(model.linear1.weight.shape) 260 sparsifier.step() 261 for g in sparsifier.groups: 262 # After step 263 module = g["module"] 264 assert ( 265 1.0 - module.parametrizations["weight"][0].mask.mean() 266 ) > 0 # checking sparsity level did not collapse 267 268 def test_step_2_of_4(self): 269 model = SimpleLinear() 270 sparsifier = WeightNormSparsifier( 271 sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2 272 ) 273 sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}]) 274 sparsifier.step() 275 # make sure the sparsity level is approximately 50% 276 mask = model.linear1.parametrizations["weight"][0].mask.to( 277 torch.float 278 ) # mean works on float only 279 self.assertAlmostEqual(mask.mean().item(), 0.5, places=2) 280 # Make sure each block has exactly 50% zeros 281 module = sparsifier.groups[0]["module"] 282 mask = module.parametrizations["weight"][0].mask 283 for row in mask: 284 for idx in range(0, len(row), 4): 285 block = row[idx : idx + 4] 286 block, _ = block.sort() 287 assert (block[:2] == 0).all() 288 assert (block[2:] != 0).all() 289 290 def test_prepare(self): 291 model = SimpleLinear() 292 sparsifier = WeightNormSparsifier() 293 sparsifier.prepare(model, config=None) 294 for g in sparsifier.groups: 295 module = g["module"] 296 # Check mask exists 297 assert hasattr(module.parametrizations["weight"][0], "mask") 298 # Check parametrization exists and is correct 299 assert is_parametrized(module, "weight") 300 assert type(module.parametrizations.weight[0]) == FakeSparsity 301 302 def test_mask_squash(self): 303 model = SimpleLinear() 304 sparsifier = WeightNormSparsifier() 305 sparsifier.prepare(model, config=None) 306 sparsifier.squash_mask() 307 for g in sparsifier.groups: 308 module = g["module"] 309 assert not is_parametrized(module, "weight") 310 assert not hasattr(module, "mask") 311 312 def test_sparsity_levels(self): 313 sparsity_levels = [-1.0, 0.0, 0.5, 1.0, 2.0] 314 sparse_block_shapes = [(1, 1), (1, 4), (2, 2), (4, 1)] 315 zeros_per_blocks = [0, 1, 2, 3, 4] 316 317 testcases = itertools.tee( 318 itertools.product(sparsity_levels, sparse_block_shapes, zeros_per_blocks) 319 ) 320 # Create a config and model with all the testcases 321 model = nn.Sequential() 322 sparsifier = WeightNormSparsifier() 323 324 sparsity_per_layer_config = [] 325 p = re.compile(r"[-\.\s]") 326 for sl, sbs, zpb in testcases[0]: 327 # Make sure the number of zeros is not > values in a block 328 if zpb > sbs[0] * sbs[1]: 329 continue 330 layer_name = f"{sl}_{sbs}_{zpb}" 331 layer_name = p.sub("_", layer_name) 332 333 layer = nn.Linear(12, 12, bias=False) 334 layer.weight = nn.Parameter(torch.ones(12, 12)) 335 model.add_module(layer_name, layer) 336 config = { 337 "tensor_fqn": layer_name + ".weight", 338 "sparsity_level": sl, 339 "sparse_block_shape": sbs, 340 "zeros_per_block": zpb, 341 } 342 sparsity_per_layer_config.append(config) 343 344 sparsifier.prepare(model, sparsity_per_layer_config) 345 sparsifier.step() 346 sparsifier.squash_mask() 347 model.eval() 348 349 for sl, sbs, zpb in testcases[1]: 350 if zpb > sbs[0] * sbs[1]: 351 continue 352 layer_name = f"{sl}_{sbs}_{zpb}" 353 layer_name = p.sub("_", layer_name) 354 layer = getattr(model, layer_name) 355 356 # Level of sparsity is achieved 357 sparse_mask = (layer.weight == 0).float() 358 if zpb == 0: 359 assert sparse_mask.mean() == 0 360 else: 361 # Ratio of individual zeros in the tensor 362 true_sl = min(max(sl, 0.0), 1.0) 363 true_sl = true_sl * zpb / sbs[0] / sbs[1] 364 assert sparse_mask.mean() == true_sl 365 366 367class TestNearlyDiagonalSparsifier(TestCase): 368 def test_constructor(self): 369 model = SimpleLinear() 370 sparsifier = NearlyDiagonalSparsifier(nearliness=1) 371 sparsifier.prepare(model, config=None) 372 for g in sparsifier.groups: 373 assert isinstance(g["module"], nn.Linear) 374 # The groups are unordered 375 assert g["module_fqn"] in ("seq.0", "seq.1", "seq.2", "linear1", "linear2") 376 377 def test_step(self): 378 model = SimpleLinear() 379 sparsifier = NearlyDiagonalSparsifier(nearliness=1) 380 sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}]) 381 382 for g in sparsifier.groups: 383 # Before step 384 module = g["module"] 385 assert ( 386 1.0 - module.parametrizations["weight"][0].mask.mean() 387 ) == 0 # checking sparsity level is 0 388 389 sparsifier.enable_mask_update = True 390 sparsifier.step() 391 mask = module.parametrizations["weight"][0].mask 392 height, width = mask.shape 393 assert torch.all(mask == torch.eye(height, width)) 394 395 for g in sparsifier.groups: 396 # After step 397 module = g["module"] 398 assert ( 399 1.0 - module.parametrizations["weight"][0].mask.mean() 400 ) > 0 # checking sparsity level has increased 401 402 # Test if the mask collapses to all zeros if the weights are randomized 403 iters_before_collapse = 1000 404 for _ in range(iters_before_collapse): 405 model.linear1.weight.data = torch.randn(model.linear1.weight.shape) 406 sparsifier.step() 407 for g in sparsifier.groups: 408 # After step 409 module = g["module"] 410 assert ( 411 1.0 - module.parametrizations["weight"][0].mask.mean() 412 ) > 0 # checking sparsity level did not collapse 413 414 def test_prepare(self): 415 model = SimpleLinear() 416 sparsifier = NearlyDiagonalSparsifier(nearliness=1) 417 sparsifier.prepare(model, config=None) 418 for g in sparsifier.groups: 419 module = g["module"] 420 # Check mask exists 421 assert hasattr(module.parametrizations["weight"][0], "mask") 422 # Check parametrization exists and is correct 423 assert is_parametrized(module, "weight") 424 assert type(module.parametrizations.weight[0]) == FakeSparsity 425 426 def test_mask_squash(self): 427 model = SimpleLinear() 428 sparsifier = NearlyDiagonalSparsifier(nearliness=1) 429 sparsifier.prepare(model, config=None) 430 sparsifier.step() 431 sparsifier.squash_mask() 432 for g in sparsifier.groups: 433 module = g["module"] 434 assert not is_parametrized(module, "weight") 435 assert not hasattr(module, "mask") 436 weights = module.weight 437 height, width = weights.shape 438 assert torch.all( 439 weights == torch.eye(height, width) * weights 440 ) # only diagonal to be present 441 442 def test_sparsity_levels(self): 443 nearliness_levels = list(range(-1, 100)) 444 model = nn.Sequential() 445 446 p = re.compile(r"[-\.\s]") 447 for nearliness in nearliness_levels: 448 sparsifier = NearlyDiagonalSparsifier(nearliness=1) 449 layer_name = f"{nearliness}" 450 layer_name = p.sub("_", layer_name) 451 452 layer = nn.Linear(32, 32, bias=False) 453 layer.weight = nn.Parameter(torch.ones(32, 32)) 454 width, height = layer.weight.shape 455 model.add_module(layer_name, layer) 456 config = {"tensor_fqn": layer_name + ".weight", "nearliness": nearliness} 457 458 sparsifier.prepare(model, [config]) 459 # should raise a ValueError when nearliness arg is illegal 460 if (nearliness > 0 and nearliness % 2 == 0) or ( 461 nearliness // 2 >= min(width, height) 462 ): 463 with self.assertRaises(ValueError): 464 sparsifier.step() 465 else: 466 sparsifier.step() 467 sparsifier.squash_mask() 468 model.eval() 469 470 layer = getattr(model, layer_name) 471 # verify that mask created corresponds to the nearliness 472 self._verify_nearliness(layer.weight, nearliness) 473 474 # helper function to verify nearliness of a mask 475 def _verify_nearliness(self, mask: torch.Tensor, nearliness: int): 476 if nearliness <= 0: 477 assert torch.all(mask == torch.zeros(mask.shape[0], mask.shape[1])) 478 else: 479 height, width = mask.shape 480 dist_to_diagonal = nearliness // 2 481 for row in range(0, height): 482 for col in range(0, width): 483 if abs(row - col) <= dist_to_diagonal: 484 assert mask[row, col] == 1 485 else: 486 assert mask[row, col] == 0 487