# Owner(s): ["module: unknown"]

import copy
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.pruning._experimental.activation_sparsifier.activation_sparsifier import (
    ActivationSparsifier,
)
from torch.ao.pruning.sparsifier.utils import module_to_fqn
from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3)
        self.identity1 = nn.Identity()
        self.max_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.linear1 = nn.Linear(4608, 128)
        self.identity2 = nn.Identity()
        self.linear2 = nn.Linear(128, 10)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.identity1(out)
        out = self.max_pool1(out)

        batch_size = x.shape[0]
        out = out.reshape(batch_size, -1)

        out = F.relu(self.identity2(self.linear1(out)))
        out = self.linear2(out)
        return out

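# Shape check for Model (illustrative): with a 1x28x28 input, conv1 gives
# 32x26x26, conv2 gives 32x24x24, and max_pool1 gives 32x12x12, so the
# flattened size is 32 * 12 * 12 = 4608, matching linear1's in_features.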

class TestActivationSparsifier(TestCase):
    def _check_constructor(self, activation_sparsifier, model, defaults, sparse_config):
        """Helper function to check that the model, defaults and sparse_config are loaded
        correctly in the activation sparsifier.
        """
        sparsifier_defaults = activation_sparsifier.defaults
        combined_defaults = {**defaults, "sparse_config": sparse_config}

        # more keys are populated in the activation sparsifier (even though they may be None)
        assert len(combined_defaults) <= len(activation_sparsifier.defaults)

        for key, config in sparsifier_defaults.items():
            # every key in the sparsifier defaults must match combined_defaults;
            # keys that are absent from combined_defaults must map to None
            assert config == combined_defaults.get(key, None)

    def _check_register_layer(
        self, activation_sparsifier, defaults, sparse_config, layer_args_list
    ):
        """Checks that the layers in the model are correctly mapped to their arguments.

        Args:
            activation_sparsifier (sparsifier object)
                activation sparsifier object that is being tested.

            defaults (Dict)
                all default configs (except sparse_config)

            sparse_config (Dict)
                default sparse config passed to the sparsifier

            layer_args_list (list of tuples)
                Each entry in the list corresponds to the layer arguments.
                The first entry in the tuple holds all the arguments other than sparse_config;
                the second entry holds the sparse_config.
        """
        # check args
        data_groups = activation_sparsifier.data_groups
        assert len(data_groups) == len(layer_args_list)
        for layer_args in layer_args_list:
            layer_arg, sparse_config_layer = layer_args

            # check sparse config
            sparse_config_actual = copy.deepcopy(sparse_config)
            sparse_config_actual.update(sparse_config_layer)

            name = module_to_fqn(activation_sparsifier.model, layer_arg["layer"])

            assert data_groups[name]["sparse_config"] == sparse_config_actual

            # assert the rest of the config
            other_config_actual = copy.deepcopy(defaults)
            other_config_actual.update(layer_arg)
            other_config_actual.pop("layer")

            for key, value in other_config_actual.items():
                assert key in data_groups[name]
                assert value == data_groups[name][key]

            # get_mask should raise an error before step() has computed a mask
            with self.assertRaises(ValueError):
                activation_sparsifier.get_mask(name=name)

    def _check_pre_forward_hook(self, activation_sparsifier, data_list):
        """Registering a layer attaches a pre-forward hook to that layer. This function
        checks if the pre-forward hook works as expected. Specifically, checks if the
        input is aggregated correctly.

        Basically, asserts that the aggregate of input activations is the same as what was
        computed in the sparsifier.

        Args:
            activation_sparsifier (sparsifier object)
                activation sparsifier object that is being tested.

            data_list (list of torch tensors)
                data input to the model attached to the sparsifier

        """
        # can only check for the first layer
        data_agg_actual = data_list[0]
        model = activation_sparsifier.model
        layer_name = module_to_fqn(model, model.conv1)
        agg_fn = activation_sparsifier.data_groups[layer_name]["aggregate_fn"]

        for i in range(1, len(data_list)):
            data_agg_actual = agg_fn(data_agg_actual, data_list[i])

        assert "data" in activation_sparsifier.data_groups[layer_name]
        assert torch.all(
            activation_sparsifier.data_groups[layer_name]["data"] == data_agg_actual
        )

        return data_agg_actual

    def _check_step(self, activation_sparsifier, data_agg_actual):
        """Checks if .step() works as expected. Specifically, checks if the mask is computed correctly.

        Args:
            activation_sparsifier (sparsifier object)
                activation sparsifier object that is being tested.

            data_agg_actual (torch tensor)
                aggregated torch tensor

        """
        model = activation_sparsifier.model
        layer_name = module_to_fqn(model, model.conv1)
        assert layer_name is not None

        reduce_fn = activation_sparsifier.data_groups[layer_name]["reduce_fn"]

        data_reduce_actual = reduce_fn(data_agg_actual)
        mask_fn = activation_sparsifier.data_groups[layer_name]["mask_fn"]
        sparse_config = activation_sparsifier.data_groups[layer_name]["sparse_config"]
        mask_actual = mask_fn(data_reduce_actual, **sparse_config)

        mask_model = activation_sparsifier.get_mask(layer_name)

        assert torch.all(mask_model == mask_actual)

        # step() should also have cleared the accumulated data for every layer
        for config in activation_sparsifier.data_groups.values():
            assert "data" not in config

    def _check_squash_mask(self, activation_sparsifier, data):
        """Makes sure that squash_mask() works as expected. Specifically, checks
        that the sparsifier hook is attached correctly.
        This is achieved by looking only at the identity layers and making sure that
        the output == layer(input * mask).

        Args:
            activation_sparsifier (sparsifier object)
                activation sparsifier object that is being tested.

            data (torch tensor)
                dummy batched data
        """

        # create a forward hook for checking output == layer(input * mask)
        def check_output(name):
            mask = activation_sparsifier.get_mask(name)
            features = activation_sparsifier.data_groups[name].get("features")
            feature_dim = activation_sparsifier.data_groups[name].get("feature_dim")

            def hook(module, input, output):
                input_data = input[0]
                if features is None:
                    assert torch.all(mask * input_data == output)
                else:
                    for feature_idx in range(len(features)):
                        feature = torch.tensor(
                            [features[feature_idx]], device=input_data.device
                        ).long()
                        inp_data_feature = torch.index_select(
                            input_data, feature_dim, feature
                        )
                        out_data_feature = torch.index_select(
                            output, feature_dim, feature
                        )

                        assert torch.all(
                            mask[feature_idx] * inp_data_feature == out_data_feature
                        )

            return hook

        for name, config in activation_sparsifier.data_groups.items():
            if "identity" in name:
                config["layer"].register_forward_hook(check_output(name))

        activation_sparsifier.model(data)

    def _check_state_dict(self, sparsifier1):
        """Checks that dumping and restoring the state_dict() works as expected.
        Basically, dumps the state of one sparsifier, loads it into a fresh sparsifier,
        and checks that all the configurations match.

        This function is called at various points in the workflow to make sure that
        the sparsifier can be dumped and restored at any point in time.
        """
        state_dict = sparsifier1.state_dict()

        new_model = Model()

        # create an empty new sparsifier
        sparsifier2 = ActivationSparsifier(new_model)

        assert sparsifier2.defaults != sparsifier1.defaults
        assert len(sparsifier2.data_groups) != len(sparsifier1.data_groups)

        sparsifier2.load_state_dict(state_dict)

        assert sparsifier2.defaults == sparsifier1.defaults

        for name, state in sparsifier2.state.items():
            assert name in sparsifier1.state
            mask1 = sparsifier1.state[name]["mask"]
            mask2 = state["mask"]

            if mask1 is None:
                assert mask2 is None
            else:
                assert type(mask1) == type(mask2)
                if isinstance(mask1, list):
                    assert len(mask1) == len(mask2)
                    for idx in range(len(mask1)):
                        assert torch.all(mask1[idx] == mask2[idx])
                else:
                    assert torch.all(mask1 == mask2)

        # make sure that the masks in the state dict are stored as torch sparse tensors
        for state in state_dict["state"].values():
            mask = state["mask"]
            if mask is not None:
                if isinstance(mask, list):
                    for idx in range(len(mask)):
                        assert mask[idx].is_sparse
                else:
                    assert mask.is_sparse
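        # (A sparse COO mask round-trips losslessly: mask.to_dense() recovers the
        # dense tensor, so storing the masks sparsely in the state dict is safe.)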

        dg1, dg2 = sparsifier1.data_groups, sparsifier2.data_groups

        for layer_name, config in dg1.items():
            assert layer_name in dg2

            # exclude hook and layer
            config1 = {
                key: value
                for key, value in config.items()
                if key not in ["hook", "layer"]
            }
            config2 = {
                key: value
                for key, value in dg2[layer_name].items()
                if key not in ["hook", "layer"]
            }

            assert config1 == config2

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_activation_sparsifier(self):
        """Simulates the workflow of the activation sparsifier, from object creation
        through squash_mask().
        The idea is to check that everything works as expected throughout the workflow.
        """

        # defining aggregate, reduce and mask functions
        def agg_fn(x, y):
            return x + y

        def reduce_fn(x):
            return torch.mean(x, dim=0)

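        # Roles of these functions as exercised by this test: agg_fn accumulates
        # activations across forward passes (a running sum here), reduce_fn
        # collapses the batch dimension of the aggregate (a mean here), and the
        # mask_fn below turns the reduced data into a binary mask.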
        def _vanilla_norm_sparsifier(data, sparsity_level):
            r"""Similar to the data norm sparsifier but with block_shape = (1, 1).
            Simply flattens the data, sorts it, and masks out the values below the threshold.
            """
            data_norm = torch.abs(data).flatten()
            _, sorted_idx = torch.sort(data_norm)
            threshold_idx = round(sparsity_level * len(sorted_idx))
            sorted_idx = sorted_idx[:threshold_idx]

            mask = torch.ones_like(data_norm)
            mask.scatter_(dim=0, index=sorted_idx, value=0)
            mask = mask.reshape(data.shape)

            return mask

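        # Worked example of the sparsifier above (illustrative values): the two
        # smallest of four magnitudes are masked out at sparsity_level=0.5.
        example_mask = _vanilla_norm_sparsifier(
            torch.tensor([0.1, -0.5, 0.2, 0.9]), sparsity_level=0.5
        )
        assert torch.equal(example_mask, torch.tensor([0.0, 1.0, 0.0, 1.0]))
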
        # Creating default function and sparse configs
        # default sparse_config
        sparse_config = {"sparsity_level": 0.5}

        defaults = {"aggregate_fn": agg_fn, "reduce_fn": reduce_fn}

        # simulate the workflow
        # STEP 1: make data and activation sparsifier object
        model = Model()  # create model
        activation_sparsifier = ActivationSparsifier(model, **defaults, **sparse_config)
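        # The defaults and sparse_config passed to the constructor act as fallbacks:
        # each layer registered below can override them (e.g. with its own
        # sparsity_level) via register_layer().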

        # Test Constructor
        self._check_constructor(activation_sparsifier, model, defaults, sparse_config)

        # STEP 2: Register some layers
        register_layer1_args = {
            "layer": model.conv1,
            "mask_fn": _vanilla_norm_sparsifier,
        }
        sparse_config_layer1 = {"sparsity_level": 0.3}

        register_layer2_args = {
            "layer": model.linear1,
            "features": [0, 10, 234],
            "feature_dim": 1,
            "mask_fn": _vanilla_norm_sparsifier,
        }
        sparse_config_layer2 = {"sparsity_level": 0.1}

        register_layer3_args = {
            "layer": model.identity1,
            "mask_fn": _vanilla_norm_sparsifier,
        }
        sparse_config_layer3 = {"sparsity_level": 0.3}

        register_layer4_args = {
            "layer": model.identity2,
            "features": [0, 10, 20],
            "feature_dim": 1,
            "mask_fn": _vanilla_norm_sparsifier,
        }
        sparse_config_layer4 = {"sparsity_level": 0.1}

        layer_args_list = [
            (register_layer1_args, sparse_config_layer1),
            (register_layer2_args, sparse_config_layer2),
            (register_layer3_args, sparse_config_layer3),
            (register_layer4_args, sparse_config_layer4),
        ]

        # Registering...
        for layer_args in layer_args_list:
            layer_arg, sparse_config_layer = layer_args
            activation_sparsifier.register_layer(**layer_arg, **sparse_config_layer)

        # check if things are registered correctly
        self._check_register_layer(
            activation_sparsifier, defaults, sparse_config, layer_args_list
        )

        # check state_dict after registering and before model forward
        self._check_state_dict(activation_sparsifier)

        # check if the forward pre-hooks actually work
        # some dummy data
        data_list = []
        num_data_points = 5
        for _ in range(num_data_points):
            rand_data = torch.randn(16, 1, 28, 28)
            activation_sparsifier.model(rand_data)
            data_list.append(rand_data)

        data_agg_actual = self._check_pre_forward_hook(activation_sparsifier, data_list)
        # check state_dict() before step()
        self._check_state_dict(activation_sparsifier)

        # STEP 3: sparsifier step
        activation_sparsifier.step()

        # check state_dict() after step() and before squash_mask()
        self._check_state_dict(activation_sparsifier)

        self._check_step(activation_sparsifier, data_agg_actual)

        # STEP 4: squash mask
        activation_sparsifier.squash_mask()

        self._check_squash_mask(activation_sparsifier, data_list[0])

        # check state_dict() after squash_mask()
        self._check_state_dict(activation_sparsifier)
409