xref: /aosp_15_r20/external/pytorch/torch/fx/passes/pass_manager.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from functools import wraps
3from inspect import unwrap
4from typing import Callable, List, Optional
5import logging
6
7logger = logging.getLogger(__name__)
8
9__all__ = [
10    "PassManager",
11    "inplace_wrapper",
12    "log_hook",
13    "loop_pass",
14    "this_before_that_pass_constraint",
15    "these_before_those_pass_constraint",
16]
17
18# for callables which modify object inplace and return something other than
19# the object on which they act
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Convenience wrapper for passes which modify an object inplace. This
    wrapper makes them return the modified object instead.

    Whatever ``fn`` itself returns is discarded; the wrapped callable always
    returns the very object it was called with.

    Args:
        fn (Callable[Object, Any]): pass that mutates its argument in place

    Returns:
        wrapped_fn (Callable[Object, Object]): callable that runs ``fn`` on
            its argument and then returns that same argument
    """

    @wraps(fn)
    def wrapped_fn(gm):
        # Run the pass for its side effects only; its return value is
        # intentionally ignored.
        fn(gm)
        return gm

    return wrapped_fn
38
def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Wrap a callable so that its return value is logged on every call.

    Handy for recording what a pass returned. Because inplace_wrapper swaps
    the pass's return value for the modified object, this wrapper has to be
    applied *before* inplace_wrapper if the original return value is what
    should end up in the log.


    ```
    def my_pass(d: Dict) -> bool:
        changed = False
        if 'foo' in d:
            d['foo'] = 'bar'
            changed = True
        return changed

    pm = PassManager(
        passes=[
            inplace_wrapper(log_hook(my_pass))
        ]
    )
    ```

    Args:
        fn (Callable[Type1, Type2]): callable whose output should be logged
        level: logging level used for the message (e.g. logging.INFO)

    Returns:
        wrapped_fn (Callable[Type1, Type2]): callable that forwards to ``fn``,
            logs the result on the module logger and returns it unchanged
    """

    @wraps(fn)
    def wrapped_fn(gm):
        result = fn(gm)
        # Lazy %-style arguments: formatting only happens if the record is
        # actually emitted at `level`.
        logger.log(level, "Ran pass %s\t Return value: %s", fn, result)
        return result

    return wrapped_fn
77
78
79
def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
    """
    Build a pass that applies ``base_pass`` repeatedly.

    Exactly one of `n_iter` or `predicate` must be specified.

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in loop
        n_iter (int, optional): number of times to apply ``base_pass``; must
            be positive, otherwise the returned pass raises when called
        predicate (Callable[Object, bool], optional): evaluated on the current
            object before each application; looping continues for as long as
            it returns a truthy value

    Returns:
        new_pass (Callable[Object, Object]): the looping pass

    Raises:
        AssertionError: when both or neither of ``n_iter`` / ``predicate``
            are supplied (checked when ``loop_pass`` is called).
        RuntimeError: raised by the returned pass when neither a positive
            ``n_iter`` nor a ``predicate`` is available.
    """
    has_count = n_iter is not None
    has_predicate = predicate is not None
    assert has_count ^ has_predicate, "Exactly one of `n_iter`or `predicate` must be specified."

    @wraps(base_pass)
    def new_pass(source):
        current = source

        # Fixed-count mode takes priority when a positive count is present.
        if n_iter is not None and n_iter > 0:
            for _ in range(n_iter):
                current = base_pass(current)
            return current

        # No usable count and no predicate either: nothing defines the loop.
        if predicate is None:
            raise RuntimeError(
                f"loop_pass must be given positive int n_iter (given "
                f"{n_iter}) xor predicate (given {predicate})"
            )

        # Predicate mode: keep applying while the predicate holds.
        while predicate(current):
            current = base_pass(current)
        return current

    return new_pass
113
114
115# Pass Schedule Constraints:
116#
117# Implemented as 'depends on' operators. A constraint is satisfied iff a list
118# has a valid partial ordering according to this comparison operator.
119def _validate_pass_schedule_constraint(
120    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
121):
122    for i, a in enumerate(passes):
123        for j, b in enumerate(passes[i + 1 :]):
124            if constraint(a, b):
125                continue
126            raise RuntimeError(
127                f"pass schedule constraint violated. Expected {a} before {b}"
128                f" but found {a} at index {i} and {b} at index{j} in pass"
129                f" list."
130            )
131
132
def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Build a 'depends on' function encoding that `this` must be scheduled
    before `that`.

    The returned function takes two passes ``(a, b)`` and returns False only
    for the pair ``(that, this)``; every other pair is accepted.
    """

    def depends_on(a: Callable, b: Callable):
        # Any pair whose first element is not `that` is acceptable outright.
        if a != that:
            return True
        # First element is `that`: acceptable unless the second is `this`.
        return b != this

    return depends_on
143
144
def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Build a 'depends on' function encoding that `these` must be scheduled
    before `those`, comparing passes only after they have been 'unwrapped'
    with `inspect.unwrap`.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]

    constraints = [
        these_before_those_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool])
    """

    def depends_on(a: Callable, b: Callable):
        # Strip wrapper layers from the first pass; if what is left is not
        # `those`, the pair is fine and `b` never needs unwrapping.
        if unwrap(a) != those:
            return True
        # `a` is (a wrapped) `those`: only a (wrapped) `these` after it is
        # a violation.
        return unwrap(b) != these

    return depends_on
174
175
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which modifies an object and returns modified object
        constraints (Optional[List[Callable]]): list of constraints. A
            constraint is a callable which takes two passes (A, B), where A
            is scheduled before B, and returns True if that ordering is
            acceptable and False otherwise. See implementation of
            `this_before_that_pass_constraint` for example.
    """

    passes: List[Callable]
    constraints: List[Callable]
    # Set by `validate()`; reset whenever the schedule or constraints change.
    _validated: bool = False

    def __init__(
        self,
        passes=None,
        constraints=None,
    ):
        # A falsy argument (None or an empty list) is replaced by a new list.
        self.passes = passes or []
        self.constraints = constraints or []

    @classmethod
    def build_from_passlist(cls, passes):
        """Alternate constructor: build a manager from a list of passes only."""
        # Instantiate `cls` rather than PassManager so subclasses get an
        # instance of their own type.
        pm = cls(passes)
        # TODO(alexbeloi): add constraint management/validation
        return pm

    def add_pass(self, _pass: Callable):
        """Append ``_pass`` to the end of the schedule."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Register an additional ordering constraint."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: List[str]):
        """
        Drop every scheduled pass whose ``__name__`` is listed in ``_passes``.

        Passing ``None`` is a no-op.
        """
        if _passes is None:
            return
        self.passes = [ps for ps in self.passes if ps.__name__ not in _passes]
        self._validated = False

    def replace_pass(self, _target, _replacement):
        """
        Substitute ``_replacement`` for every scheduled pass whose
        ``__name__`` equals ``_target.__name__``; other passes keep their
        position.
        """
        self.passes = [
            _replacement if ps.__name__ == _target.__name__ else ps
            for ps in self.passes
        ]
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`
        """
        # Skip the pairwise check when nothing changed since the last
        # successful validation.
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        """
        Validate the schedule, then run every pass in order, feeding each
        pass the output of the previous one. Returns the final output.
        """
        self.validate()
        out = source
        for _pass in self.passes:
            out = _pass(out)
        return out
255