# mypy: allow-untyped-defs
import warnings

from .base_scheduler import BaseScheduler


__all__ = ["CubicSL"]


def _clamp(x, lo, hi):
    """Return ``x`` limited to the closed interval ``[lo, hi]``.

    Assumes ``lo <= hi``; if ``lo > hi`` the result is always ``lo``.
    """
    return max(lo, min(hi, x))


class CubicSL(BaseScheduler):
    r"""Sets the sparsity level of each parameter group following a cubic
    schedule that ramps from an initial level to the final level.

    .. math::

        s_i = s_f + (s_0 - s_f) \cdot \left( 1 - \frac{t - t_0}{n\Delta t} \right)^3

    where :math:`s_i` is the sparsity at epoch :math:`t`, :math:`s_0` is the
    initial sparsity level, :math:`s_f` is the final sparsity level,
    :math:`t_0` is the initial epoch, :math:`\Delta t` is the pruning
    frequency and :math:`n` is the number of pruning steps. The ramp from
    :math:`s_0` to :math:`s_f` therefore spans :math:`n\Delta t` epochs.

    The result is clamped to the range between :math:`s_0` and :math:`s_f`.
    As a result, the level stays at :math:`s_0` before :math:`t_0` (or at 0
    when ``initially_zero`` is set) and stays at :math:`s_f` once the ramp
    is over.

    Note: the formula is evaluated directly at the current epoch, so
    :math:`\Delta t` only affects the result through the ramp length
    :math:`n\Delta t`. It does not quantize the epochs at which the level
    changes.

    The final level :math:`s_f` is read from ``self.base_sl``, which is
    provided by ``BaseScheduler``. It is presumably the sparsity level
    configured on the sparsifier (TODO confirm).

    Args:
        sparsifier (BaseSparsifier): Wrapped sparsifier.
        init_sl (float, list): Initial level of sparsity, :math:`s_0`.
            Default: 0.0.
        init_t (int, list): Initial step, when pruning starts, :math:`t_0`.
            Default: 0.
        delta_t (int, list): Pruning frequency, :math:`\Delta t`.
            Default: 10.
        total_t (int, list): Total number of pruning steps, :math:`n`.
            Default: 100.
        initially_zero (bool, list): If True, sets the level of sparsity to 0
            before init_t (:math:`t_0`). Otherwise, the sparsity level before
            init_t (:math:`t_0`) is set to init_sl (:math:`s_0`).
            Default: False.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Forwarded to ``BaseScheduler``. Default: ``False``.
    """

    def __init__(
        self,
        sparsifier,
        init_sl=0.0,
        init_t=0,
        delta_t=10,
        total_t=100,
        initially_zero=False,
        last_epoch=-1,
        verbose=False,
    ):
        self.sparsifier = sparsifier

        # Each hyper-parameter is normalized via ``_make_sure_a_list``
        # (defined in BaseScheduler). These are set before
        # ``super().__init__``, presumably because the base constructor may
        # already evaluate ``get_sl()`` -- TODO confirm.
        self.init_sl = self._make_sure_a_list(init_sl)
        self.init_t = self._make_sure_a_list(init_t)
        self.delta_t = self._make_sure_a_list(delta_t)
        self.total_t = self._make_sure_a_list(total_t)

        self.initially_zero = self._make_sure_a_list(initially_zero)

        super().__init__(sparsifier, last_epoch, verbose)

    @staticmethod
    def sparsity_compute_fn(s_0, s_f, t, t_0, dt, n, initially_zero=False):
        r"""Computes the current level of sparsity.

        Based on https://arxiv.org/pdf/1710.01878.pdf

        Args:
            s_0: Initial level of sparsity, :math:`s_0`
            s_f: Target level of sparsity, :math:`s_f`
            t: Current step, :math:`t`
            t_0: Initial step, :math:`t_0`
            dt: Pruning frequency, :math:`\Delta t`
            n: Pruning steps, :math:`n`
            initially_zero: Sets the level of sparsity to 0 before t_0.
                If False, sets to s_0

        Returns:
            The sparsity level :math:`s_t` at the current step :math:`t`.
            It is clamped to the closed range between ``s_0`` and ``s_f``,
            whichever of the two is larger. When ``initially_zero`` is set
            and ``t < t_0``, the result is 0.

        Raises:
            ZeroDivisionError: if ``dt * n == 0`` (and the ``initially_zero``
                early return does not apply).
        """
        if initially_zero and t < t_0:
            return 0
        s_t = s_f + (s_0 - s_f) * (1.0 - (t - t_0) / (dt * n)) ** 3
        # Bound by the two endpoints regardless of their order. Clamping with
        # (s_0, s_f) directly would pin the result to s_0 whenever s_0 > s_f.
        lo, hi = min(s_0, s_f), max(s_0, s_f)
        return _clamp(s_t, lo, hi)

    def get_sl(self):
        """Return the scheduled sparsity level of every parameter group.

        The level is evaluated at the current ``self.last_epoch`` using the
        per-group hyper-parameters stored in ``__init__``. The final level is
        taken from ``self.base_sl``.

        A warning is emitted when ``self._get_sl_called_within_step`` is
        falsy. That flag is maintained by ``BaseScheduler``, presumably
        around ``step()`` -- TODO confirm.
        """
        if not self._get_sl_called_within_step:
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`."
            )
        # ``zip`` stops at the shortest list. This assumes every per-group
        # list has one entry per parameter group -- TODO confirm.
        return [
            self.sparsity_compute_fn(
                s_0=initial_sparsity,
                s_f=final_sparsity,
                t=self.last_epoch,
                t_0=initial_epoch,
                dt=delta_epoch,
                n=interval_epochs,
                initially_zero=initially_zero,
            )
            for (
                initial_sparsity,
                final_sparsity,
                initial_epoch,
                delta_epoch,
                interval_epochs,
                initially_zero,
            ) in zip(
                self.init_sl,
                self.base_sl,
                self.init_t,
                self.delta_t,
                self.total_t,
                self.initially_zero,
            )
        ]