1# Owner(s): ["module: unknown"] 2 3import copy 4import logging 5from typing import List 6 7import torch 8import torch.nn as nn 9import torch.nn.functional as F 10from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier import ( 11 ActivationSparsifier, 12) 13from torch.ao.pruning.sparsifier.utils import module_to_fqn 14from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase 15 16 17logging.basicConfig( 18 format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO 19) 20 21 22class Model(nn.Module): 23 def __init__(self) -> None: 24 super().__init__() 25 self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 26 self.conv2 = nn.Conv2d(32, 32, kernel_size=3) 27 self.identity1 = nn.Identity() 28 self.max_pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 29 30 self.linear1 = nn.Linear(4608, 128) 31 self.identity2 = nn.Identity() 32 self.linear2 = nn.Linear(128, 10) 33 34 def forward(self, x): 35 out = self.conv1(x) 36 out = self.conv2(out) 37 out = self.identity1(out) 38 out = self.max_pool1(out) 39 40 batch_size = x.shape[0] 41 out = out.reshape(batch_size, -1) 42 43 out = F.relu(self.identity2(self.linear1(out))) 44 out = self.linear2(out) 45 return out 46 47 48class TestActivationSparsifier(TestCase): 49 def _check_constructor(self, activation_sparsifier, model, defaults, sparse_config): 50 """Helper function to check if the model, defaults and sparse_config are loaded correctly 51 in the activation sparsifier 52 """ 53 sparsifier_defaults = activation_sparsifier.defaults 54 combined_defaults = {**defaults, "sparse_config": sparse_config} 55 56 # more keys are populated in activation sparsifier (eventhough they may be None) 57 assert len(combined_defaults) <= len(activation_sparsifier.defaults) 58 59 for key, config in sparsifier_defaults.items(): 60 # all the keys in combined_defaults should be present in sparsifier defaults 61 assert config == combined_defaults.get(key, None) 62 63 def _check_register_layer( 64 self, activation_sparsifier, defaults, sparse_config, layer_args_list 65 ): 66 """Checks if layers in the model are correctly mapped to it's arguments. 67 68 Args: 69 activation_sparsifier (sparsifier object) 70 activation sparsifier object that is being tested. 71 72 defaults (Dict) 73 all default config (except sparse_config) 74 75 sparse_config (Dict) 76 default sparse config passed to the sparsifier 77 78 layer_args_list (list of tuples) 79 Each entry in the list corresponds to the layer arguments. 80 First entry in the tuple corresponds to all the arguments other than sparse_config 81 Second entry in the tuple corresponds to sparse_config 82 """ 83 # check args 84 data_groups = activation_sparsifier.data_groups 85 assert len(data_groups) == len(layer_args_list) 86 for layer_args in layer_args_list: 87 layer_arg, sparse_config_layer = layer_args 88 89 # check sparse config 90 sparse_config_actual = copy.deepcopy(sparse_config) 91 sparse_config_actual.update(sparse_config_layer) 92 93 name = module_to_fqn(activation_sparsifier.model, layer_arg["layer"]) 94 95 assert data_groups[name]["sparse_config"] == sparse_config_actual 96 97 # assert the rest 98 other_config_actual = copy.deepcopy(defaults) 99 other_config_actual.update(layer_arg) 100 other_config_actual.pop("layer") 101 102 for key, value in other_config_actual.items(): 103 assert key in data_groups[name] 104 assert value == data_groups[name][key] 105 106 # get_mask should raise error 107 with self.assertRaises(ValueError): 108 activation_sparsifier.get_mask(name=name) 109 110 def _check_pre_forward_hook(self, activation_sparsifier, data_list): 111 """Registering a layer attaches a pre-forward hook to that layer. This function 112 checks if the pre-forward hook works as expected. Specifically, checks if the 113 input is aggregated correctly. 114 115 Basically, asserts that the aggregate of input activations is the same as what was 116 computed in the sparsifier. 117 118 Args: 119 activation_sparsifier (sparsifier object) 120 activation sparsifier object that is being tested. 121 122 data_list (list of torch tensors) 123 data input to the model attached to the sparsifier 124 125 """ 126 # can only check for the first layer 127 data_agg_actual = data_list[0] 128 model = activation_sparsifier.model 129 layer_name = module_to_fqn(model, model.conv1) 130 agg_fn = activation_sparsifier.data_groups[layer_name]["aggregate_fn"] 131 132 for i in range(1, len(data_list)): 133 data_agg_actual = agg_fn(data_agg_actual, data_list[i]) 134 135 assert "data" in activation_sparsifier.data_groups[layer_name] 136 assert torch.all( 137 activation_sparsifier.data_groups[layer_name]["data"] == data_agg_actual 138 ) 139 140 return data_agg_actual 141 142 def _check_step(self, activation_sparsifier, data_agg_actual): 143 """Checks if .step() works as expected. Specifically, checks if the mask is computed correctly. 144 145 Args: 146 activation_sparsifier (sparsifier object) 147 activation sparsifier object that is being tested. 148 149 data_agg_actual (torch tensor) 150 aggregated torch tensor 151 152 """ 153 model = activation_sparsifier.model 154 layer_name = module_to_fqn(model, model.conv1) 155 assert layer_name is not None 156 157 reduce_fn = activation_sparsifier.data_groups[layer_name]["reduce_fn"] 158 159 data_reduce_actual = reduce_fn(data_agg_actual) 160 mask_fn = activation_sparsifier.data_groups[layer_name]["mask_fn"] 161 sparse_config = activation_sparsifier.data_groups[layer_name]["sparse_config"] 162 mask_actual = mask_fn(data_reduce_actual, **sparse_config) 163 164 mask_model = activation_sparsifier.get_mask(layer_name) 165 166 assert torch.all(mask_model == mask_actual) 167 168 for config in activation_sparsifier.data_groups.values(): 169 assert "data" not in config 170 171 def _check_squash_mask(self, activation_sparsifier, data): 172 """Makes sure that squash_mask() works as usual. Specifically, checks 173 if the sparsifier hook is attached correctly. 174 This is achieved by only looking at the identity layers and making sure that 175 the output == layer(input * mask). 176 177 Args: 178 activation_sparsifier (sparsifier object) 179 activation sparsifier object that is being tested. 180 181 data (torch tensor) 182 dummy batched data 183 """ 184 185 # create a forward hook for checking output == layer(input * mask) 186 def check_output(name): 187 mask = activation_sparsifier.get_mask(name) 188 features = activation_sparsifier.data_groups[name].get("features") 189 feature_dim = activation_sparsifier.data_groups[name].get("feature_dim") 190 191 def hook(module, input, output): 192 input_data = input[0] 193 if features is None: 194 assert torch.all(mask * input_data == output) 195 else: 196 for feature_idx in range(0, len(features)): 197 feature = torch.Tensor( 198 [features[feature_idx]], device=input_data.device 199 ).long() 200 inp_data_feature = torch.index_select( 201 input_data, feature_dim, feature 202 ) 203 out_data_feature = torch.index_select( 204 output, feature_dim, feature 205 ) 206 207 assert torch.all( 208 mask[feature_idx] * inp_data_feature == out_data_feature 209 ) 210 211 return hook 212 213 for name, config in activation_sparsifier.data_groups.items(): 214 if "identity" in name: 215 config["layer"].register_forward_hook(check_output(name)) 216 217 activation_sparsifier.model(data) 218 219 def _check_state_dict(self, sparsifier1): 220 """Checks if loading and restoring of state_dict() works as expected. 221 Basically, dumps the state of the sparsifier and loads it in the other sparsifier 222 and checks if all the configuration are in line. 223 224 This function is called at various times in the workflow to makes sure that the sparsifier 225 can be dumped and restored at any point in time. 226 """ 227 state_dict = sparsifier1.state_dict() 228 229 new_model = Model() 230 231 # create an empty new sparsifier 232 sparsifier2 = ActivationSparsifier(new_model) 233 234 assert sparsifier2.defaults != sparsifier1.defaults 235 assert len(sparsifier2.data_groups) != len(sparsifier1.data_groups) 236 237 sparsifier2.load_state_dict(state_dict) 238 239 assert sparsifier2.defaults == sparsifier1.defaults 240 241 for name, state in sparsifier2.state.items(): 242 assert name in sparsifier1.state 243 mask1 = sparsifier1.state[name]["mask"] 244 mask2 = state["mask"] 245 246 if mask1 is None: 247 assert mask2 is None 248 else: 249 assert type(mask1) == type(mask2) 250 if isinstance(mask1, List): 251 assert len(mask1) == len(mask2) 252 for idx in range(len(mask1)): 253 assert torch.all(mask1[idx] == mask2[idx]) 254 else: 255 assert torch.all(mask1 == mask2) 256 257 # make sure that the state dict is stored as torch sparse 258 for state in state_dict["state"].values(): 259 mask = state["mask"] 260 if mask is not None: 261 if isinstance(mask, List): 262 for idx in range(len(mask)): 263 assert mask[idx].is_sparse 264 else: 265 assert mask.is_sparse 266 267 dg1, dg2 = sparsifier1.data_groups, sparsifier2.data_groups 268 269 for layer_name, config in dg1.items(): 270 assert layer_name in dg2 271 272 # exclude hook and layer 273 config1 = { 274 key: value 275 for key, value in config.items() 276 if key not in ["hook", "layer"] 277 } 278 config2 = { 279 key: value 280 for key, value in dg2[layer_name].items() 281 if key not in ["hook", "layer"] 282 } 283 284 assert config1 == config2 285 286 @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 287 def test_activation_sparsifier(self): 288 """Simulates the workflow of the activation sparsifier, starting from object creation 289 till squash_mask(). 290 The idea is to check that everything works as expected while in the workflow. 291 """ 292 293 # defining aggregate, reduce and mask functions 294 def agg_fn(x, y): 295 return x + y 296 297 def reduce_fn(x): 298 return torch.mean(x, dim=0) 299 300 def _vanilla_norm_sparsifier(data, sparsity_level): 301 r"""Similar to data norm sparsifier but block_shape = (1,1). 302 Simply, flatten the data, sort it and mask out the values less than threshold 303 """ 304 data_norm = torch.abs(data).flatten() 305 _, sorted_idx = torch.sort(data_norm) 306 threshold_idx = round(sparsity_level * len(sorted_idx)) 307 sorted_idx = sorted_idx[:threshold_idx] 308 309 mask = torch.ones_like(data_norm) 310 mask.scatter_(dim=0, index=sorted_idx, value=0) 311 mask = mask.reshape(data.shape) 312 313 return mask 314 315 # Creating default function and sparse configs 316 # default sparse_config 317 sparse_config = {"sparsity_level": 0.5} 318 319 defaults = {"aggregate_fn": agg_fn, "reduce_fn": reduce_fn} 320 321 # simulate the workflow 322 # STEP 1: make data and activation sparsifier object 323 model = Model() # create model 324 activation_sparsifier = ActivationSparsifier(model, **defaults, **sparse_config) 325 326 # Test Constructor 327 self._check_constructor(activation_sparsifier, model, defaults, sparse_config) 328 329 # STEP 2: Register some layers 330 register_layer1_args = { 331 "layer": model.conv1, 332 "mask_fn": _vanilla_norm_sparsifier, 333 } 334 sparse_config_layer1 = {"sparsity_level": 0.3} 335 336 register_layer2_args = { 337 "layer": model.linear1, 338 "features": [0, 10, 234], 339 "feature_dim": 1, 340 "mask_fn": _vanilla_norm_sparsifier, 341 } 342 sparse_config_layer2 = {"sparsity_level": 0.1} 343 344 register_layer3_args = { 345 "layer": model.identity1, 346 "mask_fn": _vanilla_norm_sparsifier, 347 } 348 sparse_config_layer3 = {"sparsity_level": 0.3} 349 350 register_layer4_args = { 351 "layer": model.identity2, 352 "features": [0, 10, 20], 353 "feature_dim": 1, 354 "mask_fn": _vanilla_norm_sparsifier, 355 } 356 sparse_config_layer4 = {"sparsity_level": 0.1} 357 358 layer_args_list = [ 359 (register_layer1_args, sparse_config_layer1), 360 (register_layer2_args, sparse_config_layer2), 361 ] 362 layer_args_list += [ 363 (register_layer3_args, sparse_config_layer3), 364 (register_layer4_args, sparse_config_layer4), 365 ] 366 367 # Registering.. 368 for layer_args in layer_args_list: 369 layer_arg, sparse_config_layer = layer_args 370 activation_sparsifier.register_layer(**layer_arg, **sparse_config_layer) 371 372 # check if things are registered correctly 373 self._check_register_layer( 374 activation_sparsifier, defaults, sparse_config, layer_args_list 375 ) 376 377 # check state_dict after registering and before model forward 378 self._check_state_dict(activation_sparsifier) 379 380 # check if forward pre hooks actually work 381 # some dummy data 382 data_list = [] 383 num_data_points = 5 384 for _ in range(0, num_data_points): 385 rand_data = torch.randn(16, 1, 28, 28) 386 activation_sparsifier.model(rand_data) 387 data_list.append(rand_data) 388 389 data_agg_actual = self._check_pre_forward_hook(activation_sparsifier, data_list) 390 # check state_dict() before step() 391 self._check_state_dict(activation_sparsifier) 392 393 # STEP 3: sparsifier step 394 activation_sparsifier.step() 395 396 # check state_dict() after step() and before squash_mask() 397 self._check_state_dict(activation_sparsifier) 398 399 # self.check_step() 400 self._check_step(activation_sparsifier, data_agg_actual) 401 402 # STEP 4: squash mask 403 activation_sparsifier.squash_mask() 404 405 self._check_squash_mask(activation_sparsifier, data_list[0]) 406 407 # check state_dict() after squash_mask() 408 self._check_state_dict(activation_sparsifier) 409