# mypy: allow-untyped-defs
import warnings

from .base_scheduler import BaseScheduler


__all__ = ["LambdaSL"]


class LambdaSL(BaseScheduler):
    """Sets the sparsity level of each parameter group to the final sl
    times a given function. When last_epoch=-1, sets initial sl as zero.

    Args:
        sparsifier (BaseSparsifier): Wrapped sparsifier.
        sl_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list (or tuple) of
            such functions, one for each group in sparsifier.groups.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: If ``sl_lambda`` is a list or tuple whose length differs
            from the number of groups in ``sparsifier.groups``.

    Example:
        >>> # Assuming sparsifier has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> # xdoctest: +SKIP
        >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False):
        # Stored before delegating to the base initializer, which runs last.
        self.sparsifier = sparsifier

        if not isinstance(sl_lambda, (list, tuple)):
            # A single callable is shared by every group.
            self.sl_lambdas = [sl_lambda] * len(sparsifier.groups)
        else:
            if len(sl_lambda) != len(sparsifier.groups):
                raise ValueError(
                    f"Expected {len(sparsifier.groups)} sl_lambdas, but got {len(sl_lambda)}"
                )
            # Copy so later mutation of the caller's sequence has no effect here.
            self.sl_lambdas = list(sl_lambda)
        super().__init__(sparsifier, last_epoch, verbose)

    def get_sl(self):
        """Return one sparsity level per group for the current ``last_epoch``.

        Each value is the group's entry in ``self.base_sl`` multiplied by the
        matching lambda evaluated at ``self.last_epoch``. A warning is emitted
        when ``self._get_sl_called_within_step`` is falsy, directing callers
        to ``get_last_sl()`` instead.
        """
        # NOTE(review): `_get_sl_called_within_step`, `base_sl` and `last_epoch`
        # are not set in this class; presumably provided by BaseScheduler.
        if not self._get_sl_called_within_step:
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`.",
                stacklevel=2,  # attribute the warning to the caller's line
            )
        return [
            base_sl * lmbda(self.last_epoch)
            for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)
        ]