# Custom Compiler Passes and Partitioners

## Passes

Passes can be roughly categorized along a couple of axes:

Axis A:
1. Creating one-to-X mapping (for example, decomposition)
2. Creating many-to-one mapping (for example, fusion)

Axis B:
1. Performing forwards iteration (for example, shape propagation)
2. Performing backwards iteration (for example, dead code elimination)

Axis C:
1. Dependent on local node information (e.g. out-variant conversion)
2. Dependent on global graph information (e.g. memory planning)

Our projection of how frequently these use cases occur is:
1. A.1, B.1, C.1
2. A.2
3. B.2, C.2

### Level 1

For level 1 use cases (creating one-to-X mappings, performing forwards iterations,
and looking at local node information), we can utilize a helper class called
[`ExportPass`](https://github.com/pytorch/executorch/blob/d9eef24bb720804aa7b400b05241487510ae0dc2/exir/pass_base.py#L44).
This is an
[interpreter-based](https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern)
approach in which we execute each node and recreate the graph, applying the
specified transformations along the way. This allows us to preserve the IR Spec
by ensuring that all nodes created during the pass meet it, and that metadata
such as stack traces, FakeTensor values, and the torch.nn.Module hierarchy are
preserved and updated according to the transformations made.

To implement this pass, we can create a subclass of
[`ExportPass`](https://github.com/pytorch/executorch/blob/d9eef24bb720804aa7b400b05241487510ae0dc2/exir/pass_base.py#L44)
and implement the exposed functions. When called with a graph module, it will
run the graph module and create a new graph containing the changes specified by
the pass. This means that the graph module passed in must be runnable on CPU,
and this invariant will be maintained after the pass is run.

#### One-to-One Pass

As an example of a one-to-one mapping, if we wanted to replace an op A with
another op B, we can run the given `fx.GraphModule` and, every time we see op A,
return op B instead.

Consider the following example:

```python
class ReplaceInPlaceReluWithOutOfPlaceReluPass(ExportPass):
    """
    relu_ is the in-place version. Replace it with relu, which is the
    out-of-place version.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op != torch.ops.aten.relu_.default:
            return super().call_operator(op, args, kwargs, meta)
        return super().call_operator(torch.ops.aten.relu.default, args, kwargs, meta)

# To create a pass
replace_pass = ReplaceInPlaceReluWithOutOfPlaceReluPass()
# To run a pass
new_graph_module = replace_pass(graph_module).graph_module
```

The `super().call_operator(op, args, kwargs, meta)` call creates a
`call_function` FX node, and returns the result of running the operator with the
given arguments.

#### One-to-X Pass

If we wanted to do a one-to-X mapping, like replacing op A with 2 other ops B and
C, we would then make 2 calls to `super().call_operator` to create 2 FX nodes,
one with op B and another with op C, and return the result of running op C.

For example:
```python
class ReplaceAddWithMulSub(ExportPass):
    """
    Original:
        def f(x, y):
            return x + y

    After pass:
        def f(x, y):
            z = x * y
            return z - y
    """
    def call_operator(self, op, args, kwargs, meta):
        if op != torch.ops.aten.add.Tensor:
            return super().call_operator(op, args, kwargs, meta)

        x, y = args

        mul_res = super().call_operator(
            torch.ops.aten.mul.Tensor,
            args,
            {},
            meta
        )

        return super().call_operator(
            torch.ops.aten.sub.Tensor,
            (mul_res, y),
            {},
            meta
        )
```

#### One-to-None Pass

If we wanted to remove an op, we can just return the value passed into the
function:

```python
class RemoveDetachPass(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        if op not in (
            torch.ops.aten.detach.default,
            torch.ops.aten.detach_copy.default,
        ):
            return super().call_operator(op, args, kwargs, meta)

        assert len(args) == 1
        return args[0]
```

#### Utilizing Local Information

As an example of utilizing local node information, if we wanted to convert all
the scalars within the graph to tensors, we can run the given `fx.GraphModule`
and, for every argument that contains a scalar, convert it to a tensor. It might
look something like:

```python
def args_map(op, fn, args, kwargs):
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    args = list(args)
    kwargs = kwargs.copy()

    # Update the argument based on the function passed
    def update(key, args, schema):
        args[key] = fn(args[key], schema)

    # Update each argument in the schema
    for i, schema in enumerate(op._schema.arguments):
        if schema.name in kwargs:
            update(schema.name, kwargs, schema)
        elif not schema.kwarg_only and i < len(args):
            update(i, args, schema)

    return tuple(args), kwargs

class ScalarToTensorPass(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        def try_coerce(value, arg):
            return (
                torch.tensor(value)
                if isinstance(value, (float, int, bool))
                and type(arg.type) == torch.TensorType
                else value
            )

        args, kwargs = args_map(op, try_coerce, args, kwargs)
        return super().call_operator(op, args, kwargs, meta)
```

### Level 2

For creating many-to-one mappings, we can utilize FX's [subgraph
rewriter](https://github.com/pytorch/pytorch/blob/8597d37536ef11bdf6b0a539ab79af876e1c92f6/torch/fx/subgraph_rewriter.py#L77).
Given a `pattern`, it finds subgraphs of operators matching the pattern and then
replaces each matched subgraph with the `replacement`.

```{note}

    This is an in-place operation.

```

The `pattern` and `replacement` inputs must be callable functions written with
the same ops that are used in the EXIR graph you are matching with (ATen ops)
so that the subgraph rewriter can find the correct pattern in the graph. Inputs
to the pattern/replacement callables will be treated as wildcards.

Consider the following example:

```python
from torch.fx import subgraph_rewriter

def replace_patterns(graph_module):
    def pattern(x, y):
        x = torch.ops.aten.add.Tensor(x, y)
        x = torch.ops.aten.mul.Tensor(x, y)
        return x

    def replacement(x, y):
        return torch.ops.aten.sub.Tensor(x, y)

    replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
        graph_module, pattern, replacement
    )
    return replaced_patterns
```

The subgraph rewriter returns a list of `ReplacedPatterns`:

```python
@dataclass
class ReplacedPatterns:
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: List[Node]
```

```{note}

    The nodes created by the subgraph rewriter will not have the metadata that
    is normally in EXIR nodes (`stack_trace`, `val`, `nn_module_stack`).

```
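
As a minimal sketch of how the returned list can be inspected (continuing from
the `replace_patterns` example above, which returns the result of the rewriter),
each `ReplacedPatterns` entry records where the match was anchored and which
nodes were inserted:

```python
replaced_patterns = replace_patterns(graph_module)

for replaced in replaced_patterns:
    # Node in the original graph where the match started
    print(replaced.anchor)
    # New nodes that the rewriter inserted for this match
    print(replaced.replacements)
    # Mapping from pattern nodes to the nodes they matched in the larger graph
    for pattern_node, matched_node in replaced.nodes_map.items():
        print(pattern_node, "->", matched_node)
```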


### Level 3

For the third way of creating a pass, we can utilize the most basic
[`PassBase`](https://github.com/pytorch/pytorch/blob/8597d37536ef11bdf6b0a539ab79af876e1c92f6/torch/fx/passes/infra/pass_base.py#L22).
To create a pass, we can subclass this and implement the function `call` with
the pass contents. Additionally, we can implement the functions `requires` and
`ensures` which will be called before and after the function `call`. Note that
these functions can also be overridden in `ExportPass`. To run a pass on a graph
module, we can pass the graph module directly to an instance of the class.

Consider the following example:

```python
class ReplaceAddPass(PassBase):

    def __init__(self, replace_op):
        self.replace_op = replace_op

    def call(self, graph_module):
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.target = self.replace_op

    # Optional to implement, will be called before call()
    def requires(self, graph_module) -> None:
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                return
        raise ValueError("No torch.add ops!")

    # Optional to implement, will be called after call()
    def ensures(self, graph_module: torch.fx.GraphModule) -> None:
        pass

# To create a pass
replace_add_with_div = ReplaceAddPass(torch.div)
# To run a pass
replace_add_with_div(graph_module)
```

## Pass Manager

The `PassManager` is a class used to run multiple passes on a given graph
module. When initializing a `PassManager` instance, we pass in a list of passes
that we want to run and set a couple of flags. To run the collection of passes
on a graph module, we can pass the graph module directly to the `PassManager`
instance.

An example:
```python
from executorch.exir.pass_manager import PassManager

pm = PassManager(
    passes=[replace_add_with_div, replace_div_with_mul],
    run_checks_after_each_pass=True,
    suppress_check_failures=False,
)
graph_module_out = pm(graph_module)
```

To add a common set of checks that are run after each pass, we can call the
function `add_checks(check: Callable)`, which takes a callable function as
input. If the `run_checks_after_each_pass` flag is set, the `check` will be
called after each pass is run on the graph module.

An example:
```python
pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])

def check_div_target(graph_module):
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target != torch.div:
            raise ValueError("Target should be div!")

pm.add_checks(check_div_target)

pm(graph_module)    # raises ValueError after replace_div_with_mul pass
```

## Partitioner

There are a couple of common FX-graph based partitioners we can use to partition
the graph. However, these do not necessarily produce a graph that is compliant
with the IR Spec, so be careful when using them.

### Subgraph Matcher

For finding subgraphs within a graph that match a specific pattern, we can
utilize FX's
[`SubgraphMatcher`](https://github.com/pytorch/pytorch/blob/8597d37536ef11bdf6b0a539ab79af876e1c92f6/torch/fx/passes/utils/matcher_utils.py#L51).

Class Attributes:

* `pattern (Graph)`: The targeted matching pattern. Placeholder nodes in the
   graph will be treated as wildcards when matching.
* `match_output (bool)`: If True, the output node in the pattern graph will be
   treated as a part of the targeted pattern. If False, the output node is
   ignored during the match.
* `match_placeholder (bool)`: If True, placeholder nodes in the pattern graph
   will be treated as a part of the targeted pattern. If False, placeholder
   nodes will be treated as wildcards.
* `remove_overlapping_matches (bool)`: If True, in the case of overlapping
   matches, only the first match will be returned.
* `ignore_literals (bool)`: If True, will not check if literals are equal and
   will instead treat them as wildcards.

Consider the following example:

```python
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

class LargeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight = torch.nn.Parameter(torch.ones(3, 3))
        self._bias = torch.nn.Parameter(torch.ones(3, 3))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias, x, self._weight)

large_model_graph = to_edge(export(LargeModel(), large_inputs)).exported_program().graph_module.graph

class PatternModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
        self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

pattern_graph = to_edge(export(PatternModel(), pattern_inputs)).exported_program().graph_module.graph

subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)
```

The `match` function returns a list of `InternalMatch`:

```python
@dataclass
class InternalMatch():
    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)
    # Nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)
    # Nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)
```
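
As a minimal sketch, the matches returned by the `match` call above can be
inspected through these fields:

```python
for match in match_result:
    # Nodes in the target graph from which the match was found
    print(match.anchors)
    # Nodes in the target graph that bound to the pattern's placeholders
    print(match.placeholder_nodes)
    # Mapping from pattern nodes to the matched nodes in the larger graph
    for pattern_node, matched_node in match.nodes_map.items():
        print(pattern_node, "->", matched_node)
```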

### Capability Based Partitioner

To find the largest subgraphs of nodes that support a specific invariant, we can
utilize FX's
[`CapabilityBasedPartitioner`](https://github.com/pytorch/pytorch/blob/8597d37536ef11bdf6b0a539ab79af876e1c92f6/torch/fx/passes/infra/partitioner.py#L34C1-L34C1).

Class Attributes:

* `graph_module (torch.fx.GraphModule)`: The graph module we are partitioning.
* `operator_support (OperatorSupportBase)`: The object used to determine if a
   node in the graph is supported in the partition.
* `allows_single_node_partition (bool)`: If True, allows single node
   partitions to be formed.
* `non_compute_ops (Optional[Sequence[str]])`: A set of ops that are
   considered to be "non-compute" (for example, `torch.ops.aten.view` and
   `_operator.getitem`), so that the partitioner will not create graphs that
   only contain these non-compute ops.
* `allowed_single_node_partition_ops (Optional[Sequence[str]])`: A set of ops
   that are allowed to be in a single node partition.

The
[`OperatorSupportBase`](https://github.com/pytorch/pytorch/blob/8597d37536ef11bdf6b0a539ab79af876e1c92f6/torch/fx/passes/operator_support.py#L28)
class is used by the partitioner to determine if a specific node in the graph
belongs in the partition. This is done by overriding the `is_node_supported`
function. You can chain multiple `OperatorSupportBase` objects by using
[`chain`](https://github.com/pytorch/pytorch/blob/8597d37536ef11bdf6b0a539ab79af876e1c92f6/torch/fx/passes/operator_support.py#L150)
(which returns False if any of the OperatorSupportBase returns False) and
[`any_chain`](https://github.com/pytorch/pytorch/blob/8597d37536ef11bdf6b0a539ab79af876e1c92f6/torch/fx/passes/operator_support.py#L164)
(which returns True if any of the OperatorSupportBase returns True); a sketch of
chaining follows the example below.

Consider the following example:

```python
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

class AddMulOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
        ]

op_support = AddMulOperatorSupport()

capability_partitioner = CapabilityBasedPartitioner(
    graph_module,
    op_support,
)

# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()
```
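
To illustrate the `chain`/`any_chain` helpers mentioned above, here is a minimal
sketch (the support classes here are illustrative) that combines two checks so a
node is accepted if either one supports it:

```python
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase

class AddOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target == torch.ops.aten.add.Tensor

class MulOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target == torch.ops.aten.mul.Tensor

# any_chain: a node is supported if *any* of the chained checks returns True.
add_or_mul_support = any_chain(AddOperatorSupport(), MulOperatorSupport())

# chain: a node is supported only if *every* chained check returns True.
# add_and_mul_support = chain(AddOperatorSupport(), MulOperatorSupport())

capability_partitioner = CapabilityBasedPartitioner(graph_module, add_or_mul_support)
partition_list = capability_partitioner.propose_partitions()
```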

If you look at the capability based partitioner, you may also find a
`fuse_partitions` function which will return a modified graph with the partitions
as submodules, and calls to these submodules in the top-level graph through
`call_module` nodes. However, this is not compliant with the IR Spec because we
do not allow `call_module` nodes.


### Combined

We also provide a combined helper function:
[`generate_pattern_op_partitions`](https://github.com/pytorch/executorch/blob/d9eef24bb720804aa7b400b05241487510ae0dc2/exir/backend/canonical_partitioners/pattern_op_partitioner.py#L59)

Args:
* `graph_module (fx.GraphModule)`: Module that we want to partition.
* `patterns (List[torch.fx.Graph])`: A list of patterns in the form of
   torch.fx.Graph. These graphs can be obtained through the `graph` field from a
   GraphModule obtained by exir.capture (recommended) or symbolic tracing (which
   might not result in an accurate edge dialect graph), or by manually crafting
   a graph module.
* `op_support (OperatorSupportBase)`: An `OperatorSupportBase` that can be created
   in the following ways:
    * Subclassing it directly and implementing `is_node_supported()`
    * Getting the result of `create_op_support()`
    * Getting the result of `create_pattern_support()`
    * Multiple OperatorSupportBase classes chained together with `chain()` or `any_chain()`

Returns:
* A list of partitions (largest possible subgraphs) containing nodes that are
  supported by the union of the given OperatorSupportBase object and the
  given pattern graphs.
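
Putting the arguments listed above together, a minimal usage sketch might look
like the following. The keyword names mirror the argument list above and are
assumptions; check the linked source for the exact signature.

```python
# Hypothetical sketch: argument names follow the list above and may not match
# the exact signature in pattern_op_partitioner.py.
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_pattern_op_partitions,
)

partition_list = generate_pattern_op_partitions(
    graph_module,                        # fx.GraphModule that we want to partition
    patterns=[pattern_graph],            # e.g. pattern_graph from the SubgraphMatcher example
    op_support=AddMulOperatorSupport(),  # or several supports combined via chain()/any_chain()
)

# Each partition is the largest possible subgraph whose nodes are supported by
# the union of op_support and the pattern graphs.
for partition in partition_list:
    print(partition.nodes)
```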


### Source Partitioner

For more complicated use cases in which users want to partition based on
higher-level modules (`torch.nn.Linear` or `torch.nn.functional.linear`) which
are now decomposed into their operators (`aten.permute`, `aten.addmm`), we have
the following [helper function](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/utils/source_matcher_utils.py#L51):

`get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, List[SourcePartition]]`

Args:
* `graph`: The graph we want to partition
* `wanted_sources`: List of sources of nodes that were decomposed from this
    source. This can be a function (ex. `torch.nn.functional.linear`) or a leaf
    module type (ex. `torch.nn.Linear`)

Returns:
* Dictionary mapping sources (ex. `torch.nn.modules.linear.Linear`) to a list of
    `SourcePartition`s that correspond to the list of nodes that were flattened from
    a module of that type.

```python
@dataclass
class SourcePartition():
    # Nodes in a particular partition
    nodes: List[Node]
    # Module type
    module_type: Type
    # Nodes in the graph that are needed as inputs to the partition
    input_nodes: List[Node] = field(default_factory=list)
    # Nodes in the partition that are being used by nodes outside of the partition
    output_nodes: List[Node] = field(default_factory=list)
    # Parameters that are being used
    params: List[str] = field(default_factory=list)
```

An example:

```python
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(3, 3)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3, 5)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

inputs = (torch.randn(3, 3),)
edge_graph = to_edge(export(M(), inputs)).exported_program().graph_module.graph
print(edge_graph)
"""
graph():
    %arg0 : [#users=1] = placeholder[target=arg0]
    %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
    %permute_default : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0,), kwargs = {})
    %_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %permute_default), kwargs = {})
    %_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
    %permute_default_1 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0_1,), kwargs = {})
    %_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %permute_default_1), kwargs = {})
    %relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
    %_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
    %permute_default_2 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant2,), kwargs = {})
    %_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
    %addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %permute_default_2), kwargs = {})
    return [addmm_default_2]
"""

module_partitions = get_source_partitions(edge_graph, [torch.nn.Linear, torch.nn.ReLU])
print(module_partitions)
"""
{<class 'torch.nn.modules.linear.Linear'>: [
    SourcePartition(nodes=[_param_constant0, permute_default, _param_constant1, addmm_default], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
    SourcePartition(nodes=[_param_constant0_1, permute_default_1, _param_constant1_1, addmm_default_1], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
    SourcePartition(nodes=[_param_constant2, permute_default_2, _param_constant3, addmm_default_2], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],

 <class 'torch.nn.modules.activation.ReLU'>: [
    SourcePartition(nodes=[relu_default], module_type=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
"""
```
551