# Custom Compiler Passes and Partitioners

## Passes

Passes can be roughly categorized along a few axes:

Axis A:
1. Creating one-to-X mapping (for example, decomposition)
2. Creating many-to-one mapping (for example, fusion)

Axis B:
1. Performing forwards iteration (for example, shape propagation)
2. Performing backwards iteration (for example, dead code elimination)

Axis C:
1. Dependent on local node information (for example, out-variant conversion)
2. Dependent on global graph information (for example, memory planning)

Our projection of the relative frequency of these use cases is:
1. A.1, B.1, C.1
2. A.2
3. B.2, C.2

### Level 1

For level 1 use cases (creating one-to-X mappings, performing forwards iterations,
and looking at local node information), we can utilize a helper class called
[`ExportPass`](https://github.com/pytorch/executorch/blob/d9eef24bb720804aa7b400b05241487510ae0dc2/exir/pass_base.py#L44).
This is an
[interpreter-based](https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern)
approach: we execute each node and recreate the graph, applying the specified
transformations along the way. This allows us to preserve the IR Spec by ensuring
that all nodes created during the pass meet it, and that metadata such as stack
traces, FakeTensor values, and the torch.nn.Module hierarchy are preserved and
updated to reflect the transformations made.

To implement this pass, we can create a subclass of
[`ExportPass`](https://github.com/pytorch/executorch/blob/d9eef24bb720804aa7b400b05241487510ae0dc2/exir/pass_base.py#L44)
and implement the exposed functions. When called with a graph module, it will
run the graph module and create a new graph containing the changes specified by
the pass. This means that the graph module passed in must be runnable on CPU,
and this invariant will be maintained after the pass is run.

#### One-to-One Pass

As an example of a one-to-one mapping, if we want to replace an op A with
another op B, we can run the given `fx.GraphModule` and, every time we see op A,
return op B instead.

Consider the following example:

```python
class ReplaceInPlaceReluWithOutOfPlaceReluPass(ExportPass):
    """
    relu_ is the in-place version. Replace it with relu, which is the
    out-of-place version.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op != torch.ops.aten.relu_.default:
            return super().call_operator(op, args, kwargs, meta)
        return super().call_operator(torch.ops.aten.relu.default, args, kwargs, meta)

# To create a pass
replace_pass = ReplaceInPlaceReluWithOutOfPlaceReluPass()
# To run a pass
new_graph_module = replace_pass(graph_module).graph_module
```

The `super().call_operator(op, args, kwargs, meta)` call creates a
`call_function` FX node, and returns the result of running the operator with the
given arguments.
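
To sanity-check the rewrite, you can scan the resulting graph for any remaining
in-place nodes. A minimal sketch, assuming `new_graph_module` is the result from
the example above:

```python
# Minimal sanity check (assumes `new_graph_module` from the example above):
# no in-place relu nodes should remain after the pass.
for node in new_graph_module.graph.nodes:
    if node.op == "call_function" and node.target == torch.ops.aten.relu_.default:
        raise AssertionError("found an in-place relu that should have been replaced")
```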

#### One-to-X Pass

If we want to do a one-to-X mapping, like replacing op A with two other ops B and
C, we can make two calls to `super().call_operator` to create two FX nodes,
one with op B and another with op C, and return the result of running op C.

For example:
```python
class ReplaceAddWithMulSub(ExportPass):
    """
    Original:
        def f(x, y):
            return x + y

    After pass:
        def f(x, y):
            z = x * y
            return z - y
    """
    def call_operator(self, op, args, kwargs, meta):
        if op != torch.ops.aten.add.Tensor:
            return super().call_operator(op, args, kwargs, meta)

        x, y = args

        mul_res = super().call_operator(
            torch.ops.aten.mul.Tensor,
            args,
            {},
            meta
        )

        return super().call_operator(
            torch.ops.aten.sub.Tensor,
            (mul_res, y),
            {},
            meta
        )
```

#### One-to-None Pass

If we want to remove an op, we can just return the value passed into the
function:

```python
class RemoveDetachPass(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        if op not in (
            torch.ops.aten.detach.default,
            torch.ops.aten.detach_copy.default,
        ):
            return super().call_operator(op, args, kwargs, meta)

        assert len(args) == 1
        return args[0]
```

#### Utilizing Local Information

As an example of utilizing local node information, if we want to convert all of
the scalars within the graph to tensors, we can run the given `fx.GraphModule`
and, for every argument that contains a scalar, convert it to a tensor. It might
look something like this:

```python
def args_map(op, fn, args, kwargs):
    assert isinstance(args, tuple)
    assert isinstance(kwargs, dict)
    args = list(args)
    kwargs = kwargs.copy()

    # Update the argument based on the function passed
    def update(key, args, schema):
        args[key] = fn(args[key], schema)

    # Update each argument in the schema
    for i, schema in enumerate(op._schema.arguments):
        if schema.name in kwargs:
            update(schema.name, kwargs, schema)
        elif not schema.kwarg_only and i < len(args):
            update(i, args, schema)

    return tuple(args), kwargs

class ScalarToTensorPass(ExportPass):
    def call_operator(self, op, args, kwargs, meta):
        def try_coerce(value, arg):
            return (
                torch.tensor(value)
                if isinstance(value, (float, int, bool))
                and type(arg.type) == torch.TensorType
                else value
            )

        args, kwargs = args_map(op, try_coerce, args, kwargs)
        return super().call_operator(op, args, kwargs, meta)
```
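
As with the earlier examples, an instance of the pass can be called directly on
a graph module. A minimal usage sketch, assuming `graph_module` is an EXIR graph
module:

```python
# Minimal usage sketch (assumes `graph_module` is an EXIR graph module).
scalar_to_tensor_pass = ScalarToTensorPass()
# The pass returns a PassResult; the rewritten module is on `.graph_module`.
new_graph_module = scalar_to_tensor_pass(graph_module).graph_module
```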

### Level 2

For creating many-to-one mappings, we can utilize FX's [subgraph
rewriter](https://github.com/pytorch/pytorch/blob/8597d37536ef11bdf6b0a539ab79af876e1c92f6/torch/fx/subgraph_rewriter.py#L77).
Given a `pattern`, it creates a subgraph of operators matching the pattern, and
then replaces each matched subgraph with the `replacement`.

```{note}

This is an in-place operation.

```

The `pattern` and `replacement` inputs must be callable functions written with
the same ops that are used in the EXIR graph you are matching against (ATen ops)
so that the subgraph rewriter can find the correct pattern in the graph. Inputs
to the pattern/replacement callables will be treated as wildcards.

Consider the following example:

```python
from torch.fx import subgraph_rewriter

def replace_patterns(graph_module):
    def pattern(x, y):
        x = torch.ops.aten.add.Tensor(x, y)
        x = torch.ops.aten.mul.Tensor(x, y)
        return x

    def replacement(x, y):
        return torch.ops.aten.sub.Tensor(x, y)

    replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
        graph_module, pattern, replacement
    )
    return replaced_patterns
```

The subgraph rewriter returns a list of `ReplacedPatterns`:

```python
@dataclass
class ReplacedPatterns:
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: List[Node]
```

```{note}

The nodes created by the subgraph rewriter will not have the metadata that
is normally in EXIR nodes (`stack_trace`, `val`, `nn_module_stack`).

```


### Level 3

For the third way of creating a pass, we can utilize the most basic
[`PassBase`](https://github.com/pytorch/pytorch/blob/8597d37536ef11bdf6b0a539ab79af876e1c92f6/torch/fx/passes/infra/pass_base.py#L22).
To create a pass, we can subclass this and implement the function `call` with
the pass contents. Additionally, we can implement the functions `requires` and
`ensures`, which will be called before and after `call`, respectively. Note that
these functions can also be overridden in `ExportPass`. To run a pass on a graph
module, we can pass the graph module directly to an instance of the class.

Consider the following example:

```python
class ReplaceAddPass(PassBase):

    def __init__(self, replace_op):
        self.replace_op = replace_op

    def call(self, graph_module):
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.target = self.replace_op

    # Optional to implement, will be called before call()
    def requires(self, graph_module) -> None:
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                return
        raise ValueError("No torch.add ops!")

    # Optional to implement, will be called after call()
    def ensures(self, graph_module: torch.fx.GraphModule) -> None:
        pass

# To create a pass
replace_add_with_div = ReplaceAddPass(torch.div)
# To run a pass
replace_add_with_div(graph_module)
```
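
`call` can also return a `PassResult` (defined alongside `PassBase` in
`torch.fx.passes.infra.pass_base`) to report the rewritten module and whether
anything changed. A minimal sketch of the same idea with an explicit result (the
class name `ReplaceAddReturningResultPass` is just for illustration):

```python
from torch.fx.passes.infra.pass_base import PassBase, PassResult

class ReplaceAddReturningResultPass(PassBase):
    def __init__(self, replace_op):
        self.replace_op = replace_op

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        modified = False
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.target = self.replace_op
                modified = True
        graph_module.recompile()
        # PassResult carries the (possibly rewritten) module and a modified flag.
        return PassResult(graph_module, modified)
```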

## Pass Manager

The `PassManager` is a class used to run multiple passes on a given graph
module. When initializing a `PassManager` instance, we pass in a list of passes
that we want to run and set a couple of flags. To run the collection of passes
on a graph module, we can pass the graph module directly to the `PassManager`
instance.

An example:
```python
from executorch.exir.pass_manager import PassManager

pm = PassManager(
    passes=[replace_add_with_div, replace_div_with_mul],
    run_checks_after_each_pass=True,
    suppress_check_failures=False,
)
graph_module_out = pm(graph_module)
```

To add a common set of checks that are run after each pass, we can call the
function `add_checks(check: Callable)`, which takes a callable function as
input. If the `run_checks_after_each_pass` flag is set, the `check` will be
called after each pass is run on the graph module.

An example:
```python
pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])

def check_div_target(graph_module):
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target != torch.div:
            raise ValueError("Target should be div!")

pm.add_checks(check_div_target)

pm(graph_module)  # raises ValueError after the replace_div_with_mul pass
```

## Partitioner

There are a couple of common FX-graph-based partitioners we can use to partition
the graph. However, these do not necessarily produce a graph that is compliant
with the IR Spec, so be careful when using them.

### Subgraph Matcher

For finding subgraphs within a graph that match a specific pattern, we can
utilize FX's
[`SubgraphMatcher`](https://github.com/pytorch/pytorch/blob/8597d37536ef11bdf6b0a539ab79af876e1c92f6/torch/fx/passes/utils/matcher_utils.py#L51).

Class Attributes:

* `pattern (Graph)`: The targeted matching pattern. Placeholder nodes in the
  graph will be treated as wildcards when matching.
* `match_output (bool)`: If True, the output node in the pattern graph will be
  treated as a part of the targeted pattern. If False, the output node is
  ignored during the match.
* `match_placeholder (bool)`: If True, placeholder nodes in the pattern graph
  will be treated as a part of the targeted pattern. If False, placeholder
  nodes will be used as wildcards.
* `remove_overlapping_matches (bool)`: If True, in the case of overlapping
  matches, only the first match will be returned.
* `ignore_literals (bool)`: If True, will not check if literals are equal and
  will instead treat them as wildcards.

Consider the following example:

```python
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

class LargeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight = torch.nn.Parameter(torch.ones(3, 3))
        self._bias = torch.nn.Parameter(torch.ones(3, 3))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias, x, self._weight)

large_inputs = (torch.ones(3, 3),)
large_model_graph = to_edge(export(LargeModel(), large_inputs)).exported_program().graph_module.graph

class PatternModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
        self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

    def forward(self, x):
        return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

pattern_inputs = (torch.ones(5, 5),)
pattern_graph = to_edge(export(PatternModel(), pattern_inputs)).exported_program().graph_module.graph

subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)
```

The `match` function returns a list of `InternalMatch`:

```python
@dataclass
class InternalMatch():
    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)
    # Nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)
    # Nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)
```
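
To inspect what was matched, you can walk the returned list. A minimal sketch,
assuming `match_result` from the example above:

```python
# Minimal sketch (assumes `match_result` from the example above).
for match in match_result:
    # Anchor node(s) in the larger graph where the pattern was found.
    print("anchors:", [node.name for node in match.anchors])
    # For each node in the pattern graph, the corresponding node in the larger graph.
    for pattern_node, target_node in match.nodes_map.items():
        print(f"{pattern_node.name} -> {target_node.name}")
```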

### Capability Based Partitioner

To find the largest subgraphs of nodes that support a specific invariant, we can
utilize FX's
[`CapabilityBasedPartitioner`](https://github.com/pytorch/pytorch/blob/8597d37536ef11bdf6b0a539ab79af876e1c92f6/torch/fx/passes/infra/partitioner.py#L34C1-L34C1).

Class Attributes:

* `graph_module (torch.fx.GraphModule)`: The graph module we are partitioning.
* `operator_support (OperatorSupportBase)`: The object used to determine if a
  node in the graph is supported in the partition.
* `allows_single_node_partition (bool)`: If True, allows single node
  partitions to be formed.
* `non_compute_ops (Optional[Sequence[str]])`: A set of ops that are
  considered to be "non-compute" (for example, `torch.ops.aten.view` and
  `_operator.getitem`), so that the partitioner will not create graphs that only
  contain these non-compute ops.
* `allowed_single_node_partition_ops (Optional[Sequence[str]])`: A set of ops
  that are allowed to be in a single node partition.

The
[`OperatorSupportBase`](https://github.com/pytorch/pytorch/blob/8597d37536ef11bdf6b0a539ab79af876e1c92f6/torch/fx/passes/operator_support.py#L28)
class is used by the partitioner to determine if a specific node in the graph
belongs in the partition. This is done by overriding the `is_node_supported`
function. You can chain multiple `OperatorSupportBase` objects by using
[`chain`](https://github.com/pytorch/pytorch/blob/8597d37536ef11bdf6b0a539ab79af876e1c92f6/torch/fx/passes/operator_support.py#L150)
(which returns False if any of the OperatorSupportBase objects return False) and
[`any_chain`](https://github.com/pytorch/pytorch/blob/8597d37536ef11bdf6b0a539ab79af876e1c92f6/torch/fx/passes/operator_support.py#L164)
(which returns True if any of the OperatorSupportBase objects returns True). A
sketch of chaining is shown after the example below.

Consider the following example:

```python
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

class AddMulOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
        ]

op_support = AddMulOperatorSupport()

capability_partitioner = CapabilityBasedPartitioner(
    graph_module,
    op_support,
)

# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()
```

If you look at the capability based partitioner, you may also find a
`fuse_partitions` function which will return a modified graph with the partitions
as submodules, and calls to these submodules in the top-level graph through
`call_module` nodes. However, this is not compliant with the IR Spec because we
do not allow `call_module` nodes.
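
As mentioned above, multiple `OperatorSupportBase` checks can be combined. A
minimal sketch using `chain` (logical AND), where `Float32Support` is a
hypothetical extra filter based on the `val` metadata:

```python
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import chain, OperatorSupportBase

class Float32Support(OperatorSupportBase):
    # Hypothetical filter: only accept nodes whose value metadata is float32.
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        val = node.meta.get("val", None)
        return getattr(val, "dtype", None) == torch.float32

# `chain` acts like a logical AND: a node must be accepted by every check.
combined_support = chain(AddMulOperatorSupport(), Float32Support())
capability_partitioner = CapabilityBasedPartitioner(graph_module, combined_support)
partition_list = capability_partitioner.propose_partitions()
```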


### Combined

We also provide a combined helper function:
[`generate_pattern_op_partitions`](https://github.com/pytorch/executorch/blob/d9eef24bb720804aa7b400b05241487510ae0dc2/exir/backend/canonical_partitioners/pattern_op_partitioner.py#L59)

Args:
* `graph_module (fx.GraphModule)`: Module that we want to partition.
* `patterns (List[torch.fx.Graph])`: A list of patterns in the form of
  torch.fx.Graph. These graphs can be obtained through the `graph` field from a
  GraphModule obtained by exir.capture (recommended) or symbolic tracing (which
  might not result in an accurate edge dialect graph), or by manually crafting a
  graph module.
* `op_support (OperatorSupportBase)`: An `OperatorSupportBase` that can be created
  in the following ways:
  * Subclassing it directly and implementing `is_node_supported()`
  * Getting the result of `create_op_support()`
  * Getting the result of `create_pattern_support()`
  * Multiple OperatorSupportBase classes chained together with `chain()` or `any_chain()`

Returns:
* A list of partitions (largest possible subgraphs) containing nodes that are
  supported by the union of the given OperatorSupportBase object and the
  given pattern graphs.


### Source Partitioner

For more complicated use cases in which users want to partition based on
higher-level modules (`torch.nn.Linear` or `torch.nn.functional.linear`) which
are now decomposed into their operators (`aten.permute`, `aten.addmm`), we have
the following [helper function](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/utils/source_matcher_utils.py#L51):

`get_source_partitions(graph: torch.fx.Graph, wanted_sources: List[Any]) -> Dict[Any, SourcePartition]`

Args:
* `graph`: The graph we want to partition
* `wanted_sources`: List of sources of nodes that were decomposed from this
  source. This can be a function (for example, `torch.nn.functional.linear`) or a
  leaf module type (for example, `torch.nn.Linear`)

Returns:
* Dictionary mapping sources (for example, `torch.nn.modules.linear.Linear`) to a
  list of `SourcePartition`s that correspond to the list of nodes that were
  flattened from a module of that type.

```python
@dataclass
class SourcePartition():
    # Nodes in a particular partition
    nodes: List[Node]
    # Module type
    module_type: Type
    # Nodes in the graph that are needed as inputs to the partition
    input_nodes: List[Node] = field(default_factory=list)
    # Nodes in the partition that are being used by nodes outside of the partition
    output_nodes: List[Node] = field(default_factory=list)
    # Parameters that are being used
    params: List[str] = field(default_factory=list)
```

An example:

```python
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(3, 3)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(3, 5)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

inputs = (torch.randn(3, 3),)
edge_graph = to_edge(export(M(), inputs)).exported_program().graph_module.graph
print(edge_graph)
"""
graph():
    %arg0 : [#users=1] = placeholder[target=arg0]
    %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
    %permute_default : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0,), kwargs = {})
    %_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0, %permute_default), kwargs = {})
    %_param_constant0_1 : [#users=1] = get_attr[target=_param_constant0]
    %permute_default_1 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant0_1,), kwargs = {})
    %_param_constant1_1 : [#users=1] = get_attr[target=_param_constant1]
    %addmm_default_1 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1_1, %addmm_default, %permute_default_1), kwargs = {})
    %relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%addmm_default_1,), kwargs = {})
    %_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
    %permute_default_2 : [#users=1] = call_function[target=torch.ops.aten.permute_copy.default](args = (%_param_constant2,), kwargs = {})
    %_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
    %addmm_default_2 : [#users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant3, %relu_default, %permute_default_2), kwargs = {})
    return [addmm_default_2]
"""

module_partitions = get_source_partitions(edge_graph, [torch.nn.Linear, torch.nn.ReLU])
print(module_partitions)
"""
{<class 'torch.nn.modules.linear.Linear'>: [
    SourcePartition(nodes=[_param_constant0, permute_default, _param_constant1, addmm_default], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[arg0], output_nodes=[addmm_default], params=["_param_constant0", "_param_constant1"]),
    SourcePartition(nodes=[_param_constant0_1, permute_default_1, _param_constant1_1, addmm_default_1], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[addmm_default], output_nodes=[addmm_default_1], params=["_param_constant0_1", "_param_constant1_1"]),
    SourcePartition(nodes=[_param_constant2, permute_default_2, _param_constant3, addmm_default_2], module_type=<class 'torch.nn.modules.linear.Linear'>, input_nodes=[relu_default], output_nodes=[addmm_default_2], params=["_param_constant2", "_param_constant3"])],

 <class 'torch.nn.modules.activation.ReLU'>: [
    SourcePartition(nodes=[relu_default], module_type=<class 'torch.nn.modules.activation.ReLU'>, input_nodes=[addmm_default_1], output_nodes=[relu_default], params=[])]}
"""
```
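
Each returned `SourcePartition` can then be consumed directly, for example to
find the boundary nodes of every decomposed `torch.nn.Linear`. A minimal sketch,
assuming `module_partitions` from the example above:

```python
# Minimal sketch (assumes `module_partitions` from the example above).
for partition in module_partitions[torch.nn.Linear]:
    # Nodes that feed the partition from outside, and nodes whose results escape it.
    print("inputs: ", [node.name for node in partition.input_nodes])
    print("outputs:", [node.name for node in partition.output_nodes])
    # All nodes that were flattened from the original torch.nn.Linear call.
    print("nodes:  ", [node.name for node in partition.nodes])
```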