# mypy: allow-untyped-defs
from functools import wraps
from inspect import unwrap
from typing import Callable, List, Optional
import logging

logger = logging.getLogger(__name__)

__all__ = [
    "PassManager",
    "inplace_wrapper",
    "log_hook",
    "loop_pass",
    "this_before_that_pass_constraint",
    "these_before_those_pass_constraint",
]


# for callables which modify object inplace and return something other than
# the object on which they act
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Convenience wrapper for passes which modify an object inplace. This
    wrapper makes them return the modified object instead.

    Args:
        fn (Callable[Object, Any])

    Returns:
        wrapped_fn (Callable[Object, Object])
    """

    @wraps(fn)
    def wrapped_fn(gm):
        # The pass's own return value is deliberately discarded; the
        # (mutated) input object is returned in its place.
        fn(gm)
        return gm

    return wrapped_fn


def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Logs callable output.

    This is useful for logging output of passes. Note inplace_wrapper replaces
    the pass output with the modified object. If we want to log the original
    output, apply this wrapper before inplace_wrapper.

    ```
    def my_pass(d: Dict) -> bool:
        changed = False
        if 'foo' in d:
            d['foo'] = 'bar'
            changed = True
        return changed

    pm = PassManager(
        passes=[
            inplace_wrapper(log_hook(my_pass))
        ]
    )
    ```

    Args:
        fn (Callable[Type1, Type2])
        level: logging level (e.g. logging.INFO)

    Returns:
        wrapped_fn (Callable[Type1, Type2])
    """

    @wraps(fn)
    def wrapped_fn(gm):
        val = fn(gm)
        logger.log(level, "Ran pass %s\t Return value: %s", fn, val)
        return val

    return wrapped_fn


def loop_pass(
    base_pass: Callable,
    n_iter: Optional[int] = None,
    predicate: Optional[Callable] = None,
):
    """
    Convenience wrapper for passes which need to be applied multiple times.

    Exactly one of `n_iter` or `predicate` must be specified.

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in loop
        n_iter (int, optional): number of times to loop pass
        predicate (Callable[Object, bool], optional): evaluated on the current
            output before each iteration; the pass is re-applied for as long
            as it returns True.

    Returns:
        new_pass (Callable[Object, Object]): wrapped pass (built with
        functools.wraps, so `inspect.unwrap` recovers `base_pass`).

    Raises:
        AssertionError: at wrap time if both or neither of `n_iter` and
            `predicate` are given.
        RuntimeError: at call time if `n_iter` was given but is not positive.
    """
    assert (n_iter is not None) ^ (
        predicate is not None
    ), "Exactly one of `n_iter`or `predicate` must be specified."

    @wraps(base_pass)
    def new_pass(source):
        output = source
        if n_iter is not None and n_iter > 0:
            for _ in range(n_iter):
                output = base_pass(output)
        elif predicate is not None:
            while predicate(output):
                output = base_pass(output)
        else:
            raise RuntimeError(
                f"loop_pass must be given positive int n_iter (given "
                f"{n_iter}) xor predicate (given {predicate})"
            )
        return output

    return new_pass


# Pass Schedule Constraints:
#
# Implemented as 'depends on' operators. A constraint is satisfied iff a list
# has a valid partial ordering according to this comparison operator.
def _validate_pass_schedule_constraint(
    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
):
    """
    Check every ordered pair of passes against `constraint`.

    For every pair `(a, b)` where `a` appears earlier in `passes` than `b`,
    `constraint(a, b)` must return a truthy value.

    Raises:
        RuntimeError: for the first pair on which `constraint` is falsy. The
            reported indices are positions in `passes`.
    """
    for i, a in enumerate(passes):
        # start=i + 1 keeps `j` an absolute index into `passes` rather than an
        # offset into the slice.
        for j, b in enumerate(passes[i + 1 :], start=i + 1):
            if constraint(a, b):
                continue
            # `a` already precedes `b`, so the violated requirement is that
            # `b` should have come before `a`.
            raise RuntimeError(
                f"pass schedule constraint violated. Expected {b} before {a}"
                f" but found {a} at index {i} and {b} at index {j} in pass"
                f" list."
            )


def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Defines a partial order ('depends on' function) where `this` must occur
    before `that`.

    Args:
        this (Callable): pass which should occur first
        that (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Callable, Callable], bool]): returns False only
        for the pair (that, this), i.e. when `that` is scheduled before `this`.
    """

    def depends_on(a: Callable, b: Callable):
        return a != that or b != this

    return depends_on


def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Defines a partial order ('depends on' function) where `these` must occur
    before `those`. The passes being compared are first unwrapped with
    `inspect.unwrap`, so passes wrapped via `functools.wraps` (e.g. by
    `loop_pass`) are matched against the original callables.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]

    constraints = [
        these_before_those_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool])
    """

    def depends_on(a: Callable, b: Callable):
        return unwrap(a) != those or unwrap(b) != these

    return depends_on


class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which modifies an object and returns modified object
        constraints (Optional[List[Callable]]): list of constraints. A
            constraint is a callable which takes two passes (A, B), where A is
            scheduled before B, and returns False if that ordering violates
            the constraint (True otherwise). See implementation of
            `this_before_that_pass_constraint` for example.
    """

    passes: List[Callable]
    constraints: List[Callable]
    # Cached result of validate(); reset by every mutator below.
    _validated: bool = False

    def __init__(
        self,
        passes=None,
        constraints=None,
    ):
        # NOTE: a non-empty list argument is stored by reference, so
        # add_pass/add_constraint also mutate the caller's list.
        self.passes = passes or []
        self.constraints = constraints or []

    @classmethod
    def build_from_passlist(cls, passes):
        # Use `cls` so that subclasses get an instance of the subclass.
        pm = cls(passes)
        # TODO(alexbeloi): add constraint management/validation
        return pm

    def add_pass(self, _pass: Callable):
        """Append a pass to the end of the schedule."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Register an additional schedule constraint."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: Optional[List[str]]):
        """Remove every pass whose `__name__` is in `_passes` (no-op on None)."""
        if _passes is None:
            return
        self.passes = [ps for ps in self.passes if ps.__name__ not in _passes]
        self._validated = False

    def replace_pass(self, _target, _replacement):
        """Replace every pass whose `__name__` matches `_target.__name__`."""
        target_name = _target.__name__
        self.passes = [
            _replacement if ps.__name__ == target_name else ps
            for ps in self.passes
        ]
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`

        Raises:
            RuntimeError: if any constraint is violated.
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        """Validate the schedule, then thread `source` through every pass."""
        self.validate()
        out = source
        for _pass in self.passes:
            out = _pass(out)
        return out