# mypy: allow-untyped-defs

import warnings
import weakref
from functools import wraps

from torch.ao.pruning.sparsifier.base_sparsifier import BaseSparsifier


__all__ = ["BaseScheduler"]


class BaseScheduler:
    """Base class for sparsity-level schedulers.

    On every call to :meth:`step`, the scheduler updates the
    ``sparsity_level`` of each group in the attached sparsifier. Subclasses
    are expected to implement :meth:`get_sl`.
    """

    def __init__(self, sparsifier, last_epoch=-1, verbose=False):
        # Attach sparsifier
        if not isinstance(sparsifier, BaseSparsifier):
            raise TypeError(
                f"{type(sparsifier).__name__} is not an instance of torch.ao.pruning.BaseSparsifier"
            )
        self.sparsifier = sparsifier

        # Initialize epoch and base sparsity levels
        self.base_sl = [group["sparsity_level"] for group in sparsifier.groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `scheduler.step()` is called after
        # `sparsifier.step()`
        def with_counter(method):
            if getattr(method, "_with_counter", False):
                # `sparsifier.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the sparsifier instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1  # type: ignore[union-attr]
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True  # type: ignore[attr-defined]
            return wrapper

        self.sparsifier.step = with_counter(self.sparsifier.step)  # type: ignore[assignment]
        self.sparsifier._step_count = 0  # type: ignore[attr-defined]
        self._step_count: int = 0
        self.verbose = verbose

        # Housekeeping
        self._get_sl_called_within_step: bool = False

        self.step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in ``self.__dict__`` which
        is not the sparsifier.
        """
        return {
            key: value for key, value in self.__dict__.items() if key != "sparsifier"
        }

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_sl(self):
        """Return the last sparsity level computed by the current scheduler."""
        return self._last_sl

    def get_sl(self):
        # Compute sparsity level using chainable form of the scheduler
        # Note: This method is not intended to be called directly, and is only
        # used by the ".step" method. Use .get_last_sl() instead.
        if not self._get_sl_called_within_step:
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`."
            )
        raise NotImplementedError

    def print_sl(self, is_verbose, group, sl, epoch=None):
        """Display the current sparsity level."""
        if is_verbose:
            if epoch is None:
                print(f"Adjusting sparsity level of group {group} to {sl:.4e}.")
            else:
                print(
                    f"Epoch {epoch:5d}: adjusting sparsity level of group {group} to {sl:.4e}."
                )

    def __repr__(self):
        format_string = self.__class__.__name__ + " ("
        format_string += "\n"
        format_string += f"Sparsifier {self.sparsifier}\n"
        format_string += f"    base_sl: {self.base_sl}\n"
        format_string += ")"
        return format_string

    def step(self, epoch=None):
        # Raise a warning if `scheduler.step()` is called before `sparsifier.step()`.
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.sparsifier.step, "_with_counter"):
                warnings.warn(
                    "Seems like `sparsifier.step()` has been overridden after sparsity scheduler "
                    "initialization. Please, make sure to call `sparsifier.step()` before "
                    "`scheduler.step()`.",
                    UserWarning,
                )

            # Warn if the first user call to `scheduler.step()` (i.e. the first
            # call after the one made by `__init__`) happens before any
            # `sparsifier.step()` call.
            elif self.sparsifier._step_count < 1:  # type: ignore[attr-defined]
                warnings.warn(
                    "Detected call of `scheduler.step()` before `sparsifier.step()`. "
                    "You have to make sure you run the sparsifier.step() BEFORE any "
                    "calls to the scheduler.step().",
                    UserWarning,
                )
        self._step_count += 1

        class _enable_get_sl_call:
            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_sl_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_sl_called_within_step = False

        with _enable_get_sl_call(self):
            self.last_epoch += 1
            values = self.get_sl()

        for i, data in enumerate(zip(self.sparsifier.groups, values)):
            param_group, sl = data
            param_group["sparsity_level"] = sl
            self.print_sl(self.verbose, i, sl, epoch)

        self._last_sl = [group["sparsity_level"] for group in self.sparsifier.groups]
        self.sparsifier.enable_mask_update = True

    def _make_sure_a_list(self, var):
        r"""Utility that turns ``var`` into a list with one entry per group.

        A scalar is replicated ``len(self.sparsifier.groups)`` times; a list
        or tuple must already have exactly that length.
        """
        n = len(self.sparsifier.groups)
        if not isinstance(var, (list, tuple)):
            return [var] * n
        else:
            if len(var) != n:
                raise ValueError(f"Expected variable of length {n}, but got {len(var)}")
            return list(var)  # We want the result to be a list, not a tuple
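

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module's public API): `get_sl()` is the
# hook subclasses implement, returning one sparsity level per sparsifier group.
# The `_LinearRampSL` class and its `num_epochs` parameter are hypothetical
# names chosen for this example; concrete schedulers shipped with
# torch.ao.pruning may differ.


class _LinearRampSL(BaseScheduler):
    """Sketch: ramp each group's sparsity linearly from 0 to its base level."""

    def __init__(self, sparsifier, num_epochs, last_epoch=-1, verbose=False):
        # Must be set before super().__init__(), which calls
        # self.step() -> self.get_sl().
        self.num_epochs = num_epochs
        super().__init__(sparsifier, last_epoch, verbose)

    def get_sl(self):
        if not self._get_sl_called_within_step:
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`."
            )
        # Fraction of the ramp completed, clamped to 1.0 after num_epochs.
        fraction = min(self.last_epoch / self.num_epochs, 1.0)
        return [base_sl * fraction for base_sl in self.base_sl]


# In a training loop, the ordering enforced by the warnings in `step()` above
# would look roughly like:
#
#     for epoch in range(num_epochs):
#         train_one_epoch(model)
#         sparsifier.step()  # update masks first
#         scheduler.step()   # then advance the sparsity schedule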