# mypy: allow-untyped-defs
import abc
import copy
import sys
import warnings
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.ao.pruning.sparsifier import base_sparsifier, utils
from torch.nn.utils import parametrize


if not sys.warnoptions:
    # to suppress repeated warnings when being used in a training loop.
    warnings.simplefilter("once")

__all__ = ["BaseDataSparsifier"]

EMBEDDING_TYPES = {
    nn.Embedding,
    nn.EmbeddingBag,
}

SUPPORTED_TYPES = {
    torch.Tensor,
    nn.Parameter,
    *EMBEDDING_TYPES,
}


class _Container(nn.Module):
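    """Empty module that holds the registered data (as buffers) and their
    sparsity parametrizations on behalf of the sparsifier."""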
    pass


class BaseDataSparsifier(base_sparsifier.BaseSparsifier):
    r"""
    Base Data Sparsifier class for all Data sparsifiers.
    The abstract class accepts raw torch tensors / embeddings / embedding bags (refer to SUPPORTED_TYPES above)
    to prepare for sparsification.
    In this case, the mask (and parametrizations) is owned by the class and not by the user.
    Specifically, the container object inside the class maintains the mask and parametrizations of the input data.

    Args:
        data_list (list of tuples)
            list of (name, data) tuples to sparsify. Lookup SUPPORTED_TYPES
            for the type of data. Internally, a container module handles the data sparsification.

        defaults (dict)
            default configurations will be attached to the
            configuration. Only the keys that don't exist in the `config` will
            be updated.

    Example::
        >>> # xdoctest: +SKIP
        >>> data_list = [('tensor_1', torch.randn(3, 3)), ('tensor_2', torch.randn(4, 4))]
        >>> defaults = {'sparsity_level': 0.7}
        >>> # some sparsifier that inherits BaseDataSparsifier
        >>> sparsifier = DerivedDataSparsifier(data_list=data_list, **defaults)
        >>> new_tensor_to_add = {'name': 'tensor_3', 'data': torch.randn(5, 5), 'sparsity_level': 0.3}
        >>> sparsifier.add_data(**new_tensor_to_add)
        >>> # tensor_1 and tensor_2 will have a sparsity_level of 0.7 but tensor_3 will have sparsity_level=0.3
    """

    def __init__(self, data_list: Optional[List[Tuple[str, Any]]] = None, **defaults):
        super().__init__(defaults=defaults)

        self._container = _Container()

        self.data_groups: Dict[str, Dict] = defaultdict(dict)  # name -> {**config}
        if data_list is not None:
            # add each data entry with the default config
            for name, data in data_list:
                self.add_data(name, data, **self.defaults)

    def prepare(self):
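        r"""Unsupported for data sparsifiers: data is registered directly via
        ``add_data`` rather than by preparing a model."""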
        raise NotImplementedError("this function is undefined for this class")

    def _extract_weight(self, data):
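        r"""Returns the sparsifiable weight tensor of a supported data object."""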
        # extract the weight parameter instead of underlying data
        if type(data) in [torch.Tensor, nn.Parameter]:
            return data
        elif type(data) in EMBEDDING_TYPES:
            return data.weight

    def add_data(self, name: str, data, reuse_mask=True, **config):
        r"""Configures and parametrizes the internal container model with name and data.

        **Note**:
            1. If data with the given name already exists, it is replaced.
            2. While replacing, the old mask is reused when `reuse_mask=True`.
            3. If `reuse_mask=True`, the replacing data needs to have the same shape as the old data.
            4. By default, the config of the replaced data is used for the replacing data, unless something
               is specified in the config dictionary.
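
        Example (a minimal sketch; ``MyDataSparsifier`` is a hypothetical subclass of
        this abstract base)::

            >>> # xdoctest: +SKIP
            >>> sparsifier = MyDataSparsifier(sparsity_level=0.5)
            >>> sparsifier.add_data('tensor_1', torch.randn(3, 3))
            >>> # replacing data of the same name (and shape) reuses the old mask
            >>> sparsifier.add_data('tensor_1', torch.randn(3, 3), reuse_mask=True)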
        """
        assert (
            type(data) in SUPPORTED_TYPES
        ), "specified data type not supported at the moment"
        local_args = copy.deepcopy(self.defaults)
        local_args.update(config)
        weight = self._extract_weight(data)

        # Bookkeeping in the container class
        mask = local_args.get("mask", torch.ones_like(weight))
        param_class = local_args.get("parametrization", utils.FakeSparsity)

        if name in self.state:
            # If the named data already exists - replace
            warnings.warn(
                "Replacing existing data of the same name - did you mean a different name?"
            )

            # reuse old config
            old_args = self.data_groups[name]
            local_args = copy.deepcopy(old_args)
            local_args.update(config)

            if reuse_mask:
                current_data = self.get_data(name=name)
                assert (
                    weight.shape == current_data.shape
                ), "to retain the old mask, the shape of the new data must be the same as the previous one"
                mask = self.get_mask(
                    name=name
                )  # reuse mask instead of creating a new one

            self._delete_data(name=name)

        # nn.Parameter would deepcopy the weight internally, so register a buffer instead
        self._container.register_buffer(name=name, tensor=weight)
        parametrize.register_parametrization(self._container, name, param_class(mask))
        self.state[name]["mask"] = mask
        self.data_groups[name] = local_args
        return getattr(self._container, name)

    def get_data(self, name: str, return_original: bool = True):
        r"""Returns the weight tensor (or data)

        Args:
            - name: name of the data to be returned
            - return_original: if True, returns the weight tensor without applying the parametrization;
              otherwise, returns the sparsified (parametrized) version
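
        Example (assuming ``'tensor_1'`` was previously added via ``add_data``)::

            >>> # xdoctest: +SKIP
            >>> original = sparsifier.get_data('tensor_1')  # unmasked weight
            >>> sparsified = sparsifier.get_data('tensor_1', return_original=False)  # mask applied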
        """
        if name not in self.data_groups:
            raise ValueError("data with specified name does not exist")

        if return_original:
            if not parametrize.is_parametrized(self._container, name):
                raise ValueError("mask squashed - original data value does not exist")
            data = getattr(self._container.parametrizations, name).original
            return data
        else:
            return getattr(self._container, name)

    def _convert_mask(self, states, sparse_coo=True):
        r"""Converts the masks to sparse coo or dense tensors depending on the `sparse_coo` argument."""
        states = copy.deepcopy(states)
        for state in states.values():
            if sparse_coo:
                state["mask"] = state["mask"].to_sparse_coo()
            else:
                state["mask"] = state["mask"].to_dense()

        return states

    def state_dict(self):
        r"""Returns the state of the sparsifier as a :class:`dict`.

        It contains:
        * state - contains name -> mask mapping.
        * data_groups - a dictionary containing the sparsity configuration of each data,
            keyed by the name of the data
        * container_state_dict - the state dictionary of the internal
            container model used for sparsification
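
        A save/restore round trip might look like this (the file name and the
        subclass are hypothetical)::

            >>> # xdoctest: +SKIP
            >>> torch.save(sparsifier.state_dict(), 'sparsifier.pt')
            >>> fresh_sparsifier = MyDataSparsifier(sparsity_level=0.5)
            >>> fresh_sparsifier.load_state_dict(torch.load('sparsifier.pt'))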
        """
        state = self._convert_mask(self.state)
        return {
            "state": state,
            "data_groups": self.data_groups,
            "_container": self._container.state_dict(),
        }

    def _load_container_from_state(self, states, data_groups, container_state_dict):
        r"""Restores the state of the container based on the data present in states and data_groups.
        If the data was parametrized, it is added to the container and then parametrized;
        otherwise it is just set as an attribute on the container.
        """
        for name, state in states.items():
            config = data_groups.get(name, None)
            if config is None:
                raise RuntimeError(f"Error loading {name}")

            # check if the data with such a name was parametrized; if so, parametrize
            # otherwise just set the attribute and continue
            parametrized_name = f"parametrizations.{name}.original"
            parametrized = False
            if name in container_state_dict:
                # the parametrization was probably removed for this data
                data = container_state_dict.get(name)

            elif parametrized_name in container_state_dict:
                # so the weight was parametrized
                data = container_state_dict.get(parametrized_name)
                parametrized = True

            else:
                raise RuntimeError(f"Error loading {name}")

            self._container.register_buffer(name=name, tensor=data)

            if parametrized:
                # register the parametrization if the data was parametrized
                mask = state.get("mask", torch.ones_like(data))
                param_class = config.get(
                    "parametrization", utils.FakeSparsity
                )  # change once public_api for utils is fixed!
                parametrize.register_parametrization(
                    self._container, name, param_class(mask)
                )

    def load_state_dict(self, state_dict, strict=True):
        r"""Restores the state of the sparsifier based on the state_dict.

        Args:
        * state_dict - the dictionary to which the current sparsifier needs to be restored
        * strict - if True, the sparsifier is reset and restored exactly to the state in state_dict;
            if False, the current sparsifier is not reset before loading the state_dict, i.e. data added
            before loading the state_dict is not erased.
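
        Example (``saved_state_dict`` is assumed to come from a prior ``state_dict()`` call)::

            >>> # xdoctest: +SKIP
            >>> sparsifier.add_data('tensor_new', torch.randn(2, 2))
            >>> # strict=False keeps 'tensor_new' alongside the restored entries
            >>> sparsifier.load_state_dict(saved_state_dict, strict=False)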
        """
        states = copy.deepcopy(state_dict["state"])
        data_groups = copy.deepcopy(state_dict["data_groups"])
        container_state_dict = copy.deepcopy(state_dict["_container"])

        states = self._convert_mask(
            states, sparse_coo=False
        )  # convert sparse coo mask to dense
        if strict:
            # if strict load -> then reset container
            self._container = _Container()

        self._load_container_from_state(states, data_groups, container_state_dict)

        if not strict:
            states.update(self.state)
            data_groups.update(self.data_groups)

        self.__setstate__({"state": states, "data_groups": data_groups})

    def __setstate__(self, state):
        if "_container" in state:  # If container object is in state then load model
            container_dict = state.pop("_container")
            self._container = _Container()
            state["state"] = self._convert_mask(
                state["state"], sparse_coo=False
            )  # convert sparse coo mask to dense
            self._load_container_from_state(
                state["state"], state["data_groups"], container_dict
            )

        self.__dict__.update(state)

    def __getstate__(self):
        state = self._convert_mask(self.state)
        return {
            "defaults": self.defaults,
            "state": state,
            "data_groups": self.data_groups,
            "_container": self._container.state_dict(),
        }

    def __repr__(self):
        format_string = self.__class__.__name__ + " ("
        for name, sparse_args in self.data_groups.items():
            format_string += "\n"
            format_string += "\tData Group\n"
            format_string += f"\t    name: {name}\n"
            for key in sorted(sparse_args.keys()):
                if key == "data":
                    continue
                format_string += f"\t    {key}: {sparse_args[key]}\n"
        format_string += ")"
        return format_string

    def get_mask(self, name: str):
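        r"""Returns the mask currently associated with the named data."""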
        if name not in self.state:
            raise ValueError("data with specified name does not exist")
        return self.state[name]["mask"]

    def squash_mask(self, *args, leave_parametrized=True, names=None, **kwargs):
        r"""Squashes the sparse masks into the appropriate tensors. Also accepts a list
        of names to squash masks for. If None, squashes the masks for all the keys.
        kwargs:
            * names: list of names to squash masks for
            * leave_parametrized: if True - applies the mask before squashing
                                  if False - does not apply the mask before squashing
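
        Example (assuming ``'tensor_1'`` was registered earlier)::

            >>> # xdoctest: +SKIP
            >>> sparsifier.squash_mask(names=['tensor_1'])
            >>> # 'tensor_1' now permanently holds its masked values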
        """
        if names is None:
            names = list(self.data_groups.keys())
        for name in names:
            parametrize.remove_parametrizations(
                self._container, name, leave_parametrized=leave_parametrized
            )

    def step(self):
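        r"""Updates the mask for every registered data entry by calling
        ``update_mask`` with that entry's config."""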
        if not self.enable_mask_update:
            return
        with torch.no_grad():
            for name, config in self.data_groups.items():
                # get non-sparsified data
                data = self.get_data(name)
                # need name for the mask otherwise can directly pass mask?
                self.update_mask(name, data, **config)

    @abc.abstractmethod
    def update_mask(self, name, data, **kwargs):
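        r"""Abstract method that computes the new mask for ``name`` based on ``data``.
        Implementations are expected to update ``self.state[name]['mask']`` in place;
        ``step()`` calls this for every registered data entry with that entry's config.

        A minimal sketch of an override in a hypothetical subclass, masking out the
        smallest-magnitude entries::

            >>> # xdoctest: +SKIP
            >>> def update_mask(self, name, data, sparsity_level, **kwargs):
            ...     mask = self.get_mask(name)
            ...     k = int(sparsity_level * data.numel())
            ...     if k > 0:
            ...         # zero out the k smallest-magnitude entries, keep the rest
            ...         _, idx = torch.topk(data.abs().flatten(), k, largest=False)
            ...         new_mask = torch.ones_like(mask).flatten()
            ...         new_mask[idx] = 0
            ...         mask.data = new_mask.reshape(mask.shape)
        """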
        pass

    def _delete_data(self, name):
        """Detaches some data from the sparsifier.

        Args:
            name (str)
                Name of the data to be removed from the sparsifier

        Note:
            Currently private. Used as a helper when replacing data of the same name.
        """
        self.squash_mask(
            names=[name], leave_parametrized=False
        )  # do not apply the mask while deleting
        delattr(self._container, name)
        self.state.pop(name)
        self.data_groups.pop(name)