xref: /aosp_15_r20/external/pytorch/torch/ao/pruning/scheduler/cubic_scheduler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import warnings
3
4from .base_scheduler import BaseScheduler
5
6
# Public API of this module: only the scheduler class is exported by
# ``from ... import *``; helpers such as ``_clamp`` stay private.
__all__ = ["CubicSL"]
8
9
def _clamp(x, lo, hi):
    """Bound ``x`` from above by ``hi``, then from below by ``lo``.

    Equivalent to ``max(lo, min(hi, x))``. Because the lower bound is applied
    last, the result is always ``lo`` when ``lo > hi``; callers that may pass
    the bounds in either order must sort them first.
    """
    capped = x if x < hi else hi
    return capped if capped > lo else lo
12
13
class CubicSL(BaseScheduler):
    r"""Sets the sparsity level of each parameter group using a cubic schedule.

    The level moves from the initial sparsity :math:`s_0` towards the final
    sparsity :math:`s_f` following

    .. math::

        s_t = s_f + (s_0 - s_f) \cdot \left( 1 - \frac{t - t_0}{n\Delta t} \right)^3

    where :math:`s_t` is the sparsity at epoch :math:`t`, :math:`s_0` is the
    initial sparsity level, :math:`s_f` is the final sparsity level,
    :math:`t_0` is the initial epoch, :math:`\Delta t` is the pruning
    frequency and :math:`n` is the number of pruning steps. The schedule
    therefore reaches :math:`s_f` at epoch :math:`t_0 + n\Delta t`.

    The computed value is clamped to the closed interval between :math:`s_0`
    and :math:`s_f` (whichever order they are in), so it stays at :math:`s_0`
    before :math:`t_0` and at :math:`s_f` after :math:`t_0 + n\Delta t`.
    Note that :math:`t` is not rounded to a multiple of :math:`\Delta t`: the
    level is re-evaluated at every epoch, and :math:`\Delta t` and :math:`n`
    only enter the formula through their product :math:`n\Delta t`.

    The final level :math:`s_f` of each group is read from ``self.base_sl``,
    which is presumably populated by ``BaseScheduler`` from the sparsifier's
    groups (not visible here).

    Args:
        sparsifier (BaseSparsifier): Wrapped sparsifier.
        init_sl (float, list): Initial level of sparsity (:math:`s_0`)
        init_t (int, list): Initial step, when pruning starts (:math:`t_0`)
        delta_t (int, list): Pruning frequency (:math:`\Delta t`)
        total_t (int, list): Total number of pruning steps (:math:`n`)
        initially_zero (bool, list): If True, sets the level of sparsity to 0
            before init_t (:math:`t_0`). Otherwise, the sparsity level before
            init_t (:math:`t_0`) is set to init_sl (:math:`s_0`)
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(
        self,
        sparsifier,
        init_sl=0.0,
        init_t=0,
        delta_t=10,
        total_t=100,
        initially_zero=False,
        last_epoch=-1,
        verbose=False,
    ):
        # Stored before the per-group lists are built; presumably
        # ``_make_sure_a_list`` (inherited, not shown) needs the sparsifier to
        # know how many groups to expand a scalar into -- confirm.
        self.sparsifier = sparsifier

        # Per-group schedule parameters, one entry per parameter group.
        self.init_sl = self._make_sure_a_list(init_sl)
        self.init_t = self._make_sure_a_list(init_t)
        self.delta_t = self._make_sure_a_list(delta_t)
        self.total_t = self._make_sure_a_list(total_t)

        self.initially_zero = self._make_sure_a_list(initially_zero)

        # Called last so that the schedule attributes above already exist in
        # case the base initializer computes a first level (TODO confirm).
        super().__init__(sparsifier, last_epoch, verbose)

    @staticmethod
    def sparsity_compute_fn(s_0, s_f, t, t_0, dt, n, initially_zero=False):
        r"""Computes the current level of sparsity.

        Based on https://arxiv.org/pdf/1710.01878.pdf

        Args:
            s_0: Initial level of sparsity, :math:`s_0`
            s_f: Target level of sparsity, :math:`s_f`
            t: Current step, :math:`t`
            t_0: Initial step, :math:`t_0`
            dt: Pruning frequency, :math:`\Delta t`
            n: Pruning steps, :math:`n`
            initially_zero: Sets the level of sparsity to 0 before t_0.
                If False, sets to s_0

        Returns:
            The sparsity level :math:`s_t` at the current step :math:`t`,
            clamped to the closed interval between ``s_0`` and ``s_f``. This
            holds for both increasing (``s_0 <= s_f``) and decreasing
            (``s_0 > s_f``) schedules.

        Raises:
            ZeroDivisionError: with plain Python numbers, if ``dt * n`` is 0.
        """
        if initially_zero and t < t_0:
            return 0
        s_t = s_f + (s_0 - s_f) * (1.0 - (t - t_0) / (dt * n)) ** 3
        # Clamp between s_0 and s_f regardless of their order. Using
        # (s_0, s_f) directly as (lower, upper) would pin every result to s_0
        # whenever the schedule decreases (s_0 > s_f).
        lower, upper = min(s_0, s_f), max(s_0, s_f)
        return max(lower, min(upper, s_t))

    def get_sl(self):
        """Return the scheduled sparsity level of every parameter group.

        Evaluates :meth:`sparsity_compute_fn` at ``self.last_epoch`` for each
        group, pairing the per-group schedule parameters with
        ``self.base_sl`` as the final level :math:`s_f`. Emits a warning when
        ``self._get_sl_called_within_step`` is false (presumably: when called
        outside of ``step()``); ``get_last_sl()`` should be used instead to
        read the most recently computed levels.

        Returns:
            list: One sparsity level per parameter group.
        """
        if not self._get_sl_called_within_step:
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`."
            )
        return [
            self.sparsity_compute_fn(
                s_0=initial_sparsity,
                s_f=final_sparsity,
                t=self.last_epoch,
                t_0=initial_epoch,
                dt=delta_epoch,
                n=interval_epochs,
                initially_zero=initially_zero,
            )
            for initial_sparsity, final_sparsity, initial_epoch, delta_epoch, interval_epochs, initially_zero in zip(
                self.init_sl,
                self.base_sl,
                self.init_t,
                self.delta_t,
                self.total_t,
                self.initially_zero,
            )
        ]
113        ]
114