xref: /aosp_15_r20/external/pytorch/torch/ao/pruning/scheduler/lambda_scheduler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import warnings
3
4from .base_scheduler import BaseScheduler
5
6
__all__ = ["LambdaSL"]


class LambdaSL(BaseScheduler):
    """Scale each group's sparsity level by a user-supplied function of the epoch.

    For every parameter group the scheduled sparsity level is::

        base_sl * sl_lambda(last_epoch)

    where ``base_sl`` is the group's entry in the inherited ``self.base_sl``
    (the final / target sparsity level) and ``last_epoch`` is the scheduler's
    epoch index. Because the result is a plain product, the level at the
    initial epoch is zero only if the lambda returns zero there; for example
    ``lambda epoch: 0.95 ** epoch`` starts at the full base level.

    Args:
        sparsifier (BaseSparsifier): Wrapped sparsifier.
        sl_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list/tuple of such
            functions, one for each group in ``sparsifier.groups``.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: If ``sl_lambda`` is a list or tuple whose length differs
            from the number of groups in ``sparsifier.groups``.

    Example:
        >>> # Assuming sparsifier has two groups.
        >>> lambda1 = lambda epoch: min(1.0, epoch / 30)  # linear ramp, then hold
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> # xdoctest: +SKIP
        >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False):
        self.sparsifier = sparsifier
        num_groups = len(sparsifier.groups)

        # ``self.sl_lambdas`` is populated *before* the base initializer runs.
        # NOTE(review): presumably BaseScheduler.__init__ performs an initial
        # step that calls ``get_sl()`` — confirm against base_scheduler.py.
        if isinstance(sl_lambda, (list, tuple)):
            if len(sl_lambda) != num_groups:
                raise ValueError(
                    f"Expected {num_groups} sl_lambdas, but got {len(sl_lambda)}"
                )
            self.sl_lambdas = list(sl_lambda)
        else:
            # A single callable is shared by every group.
            self.sl_lambdas = [sl_lambda] * num_groups
        super().__init__(sparsifier, last_epoch, verbose)

    def get_sl(self):
        """Return the scheduled sparsity level of every group at ``self.last_epoch``.

        Each entry is ``base_sl * sl_lambda(self.last_epoch)`` for the matching
        group, pairing ``self.sl_lambdas`` with ``self.base_sl`` positionally.

        Emits a warning when ``self._get_sl_called_within_step`` is falsy,
        pointing callers at ``get_last_sl()`` for reading the most recently
        computed levels.
        """
        if not self._get_sl_called_within_step:
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`."
            )
        return [
            base_sl * lmbda(self.last_epoch)
            for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)
        ]
56