xref: /aosp_15_r20/external/pytorch/torch/ao/pruning/scheduler/base_scheduler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2
3import warnings
4import weakref
5from functools import wraps
6
7from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier
8
9
# Public names exported by `from ... import *`: only the scheduler base class.
__all__ = ["BaseScheduler"]
11
12
class BaseScheduler:
    """Base class for schedulers that adjust a sparsifier's ``sparsity_level``.

    The scheduler records each group's initial ``sparsity_level`` in
    ``base_sl``. On every :meth:`step` it writes the values returned by
    :meth:`get_sl` back into ``sparsifier.groups``. Subclasses must implement
    :meth:`get_sl`.

    Args:
        sparsifier: the :class:`BaseSparsifier` whose groups are updated.
        last_epoch: index of the last epoch. It is incremented before every
            :meth:`get_sl` call. The constructor performs one step, so the
            default ``-1`` becomes ``0`` once construction finishes.
        verbose: if ``True``, print each group's new sparsity level on step.

    Raises:
        TypeError: if ``sparsifier`` is not a :class:`BaseSparsifier`.
    """

    def __init__(self, sparsifier, last_epoch=-1, verbose=False):
        # Attach sparsifier
        if not isinstance(sparsifier, BaseSparsifier):
            raise TypeError(
                f"{type(sparsifier).__name__} is not an instance of torch.ao.pruning.BaseSparsifier"
            )
        self.sparsifier = sparsifier

        # Initialize epoch and base sparsity levels. `base_sl` snapshots each
        # group's sparsity level before this scheduler first modifies it.
        self.base_sl = [group["sparsity_level"] for group in sparsifier.groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `scheduler.step()` is called after
        # `sparsifier.step()`. To detect the wrong order, `sparsifier.step` is
        # wrapped so that every call increments `sparsifier._step_count`.
        def with_counter(method):
            if getattr(method, "_with_counter", False):
                # `sparsifier.step()` has already been replaced (e.g. by another
                # scheduler), return it as-is so calls are not double-counted.
                return method

            # Keep a weak reference to the sparsifier instance to prevent
            # cyclic references: the wrapper is stored on that very instance.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1  # type: ignore[union-attr]
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True  # type: ignore[attr-defined]
            return wrapper

        self.sparsifier.step = with_counter(self.sparsifier.step)  # type: ignore[assignment]
        self.sparsifier._step_count = 0  # type: ignore[attr-defined]
        self._step_count: int = 0
        self.verbose = verbose

        # Housekeeping: True only while `step` is calling `get_sl`.
        self._get_sl_called_within_step: bool = False

        # Apply the initial schedule (moves `last_epoch` from -1 to 0 by default).
        self.step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the sparsifier. Values are not copied.
        """
        return {
            key: value for key, value in self.__dict__.items() if key != "sparsifier"
        }

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`. Its entries are copied
                directly into the instance's ``__dict__``.
        """
        self.__dict__.update(state_dict)

    def get_last_sl(self):
        """Return last computed sparsity level by current scheduler.

        This is one value per sparsifier group, as stored by the most recent
        :meth:`step`.
        """
        return self._last_sl

    def get_sl(self):
        """Compute the new sparsity level for every sparsifier group.

        This method is for subclasses to override. It is meant to be called only
        from :meth:`step`; use :meth:`get_last_sl` to read the current levels.
        The base implementation warns when called outside :meth:`step` and
        always raises ``NotImplementedError``.
        """
        # Compute sparsity level using chainable form of the scheduler
        # Note: This method is not intended to be called directly, and is only
        #       used by the ".step" method. Use .get_last_sl() instead.
        if not self._get_sl_called_within_step:
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`."
            )
        raise NotImplementedError

    def print_sl(self, is_verbose, group, sl, epoch=None):
        """Display the current sparsity level.

        Args:
            is_verbose: print only when truthy.
            group: index of the group being adjusted.
            sl: new sparsity level (formatted with ``.4e``).
            epoch: optional epoch number included in the message (``5d``).
        """
        if is_verbose:
            if epoch is None:
                print(f"Adjusting sparsity level of group {group} to {sl:.4e}.")
            else:
                print(
                    f"Epoch {epoch:5d}: adjusting sparsity level of group {group} to {sl:.4e}."
                )

    def __repr__(self):
        format_string = self.__class__.__name__ + " ("
        format_string += "\n"
        format_string += f"Sparsifier {self.sparsifier}\n"
        format_string += f"    base_sl: {self.base_sl}\n"
        format_string += ")"
        return format_string

    def step(self, epoch=None):
        """Advance the schedule by one epoch and apply the new sparsity levels.

        The method does the following, in order:

        * increments ``last_epoch``;
        * calls :meth:`get_sl`;
        * writes each returned value into the matching
          ``sparsifier.groups[i]["sparsity_level"]``;
        * caches the levels for :meth:`get_last_sl`;
        * sets ``sparsifier.enable_mask_update = True``.

        Args:
            epoch: used only in the verbose message printed by :meth:`print_sl`.
        """
        # Raise warning if trying to call scheduler step before the sparsifier.
        # https://github.com/pytorch/pytorch/issues/20124
        # `_step_count == 1` means this is the first call after the one that
        # `__init__` makes.
        if self._step_count == 1:
            if not hasattr(self.sparsifier.step, "_with_counter"):
                warnings.warn(
                    "Seems like `sparsifier.step()` has been overridden after sparsity scheduler "
                    "initialization. Please, make sure to call `sparsifier.step()` before "
                    "`scheduler.step()`.",
                    UserWarning,
                )

            # Just check if there were two first scheduler.step() calls before sparsifier.step()
            elif self.sparsifier._step_count < 1:  # type: ignore[attr-defined]
                warnings.warn(
                    "Detected call of `scheduler.step()` before `sparsifier.step()`. "
                    "You have to make sure you run the sparsifier.step() BEFORE any "
                    "calls to the scheduler.step().",
                    UserWarning,
                )
        self._step_count += 1

        # `get_sl` reads this flag to warn when it is called directly instead
        # of via `step`. try/finally guarantees the flag is reset even when
        # `get_sl` raises.
        self._get_sl_called_within_step = True
        try:
            self.last_epoch += 1
            values = self.get_sl()
        finally:
            self._get_sl_called_within_step = False

        # NOTE: `zip` stops at the shorter input, so a length mismatch between
        # `get_sl()` and the groups is silently ignored.
        for i, (param_group, sl) in enumerate(zip(self.sparsifier.groups, values)):
            param_group["sparsity_level"] = sl
            self.print_sl(self.verbose, i, sl, epoch)

        self._last_sl = [group["sparsity_level"] for group in self.sparsifier.groups]
        self.sparsifier.enable_mask_update = True

    def _make_sure_a_list(self, var):
        r"""Utility that extends it to the same length as the .groups, ensuring it is a list.

        A scalar is repeated once per group. A list or tuple must already have
        one entry per group, and a copy of it is returned as a list.

        Raises:
            ValueError: if a list/tuple has the wrong length.
        """
        n = len(self.sparsifier.groups)
        if not isinstance(var, (list, tuple)):
            return [var] * n
        else:
            if len(var) != n:
                raise ValueError(f"Expected variable of length {n}, but got {len(var)}")
            return list(var)  # We want the result to be in a list, not tuple
170            return list(var)  # We want the result to be in a list, not tuple
171