# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
import unittest
from typing import Dict, final, List

import executorch.exir as exir

import torch

from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_pattern_op_partitions,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)

from executorch.exir.backend.test.op_partitioner_demo import (
    AddOperatorSupport,
    MatmulOperatorSupport,
)
from executorch.exir.delegate import executorch_call_delegate

from executorch.exir.graph_module import _get_submodule, get_control_flow_submodules
from executorch.exir.lowered_backend_module import get_lowered_submodules
from functorch.experimental import control_flow
from torch.export import ExportedProgram
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, pred1, pred2, y):
        def true_fn(x, pred2):
            def true_nested(y):
                y = y + y
                y = torch.mm(y, y)
                return y

            def false_nested(y):
                return torch.mm(y, y)

            z = control_flow.cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x, _pred2):
            return torch.mm(x, x)

        x = x.cos()
        x = x + y
        y = control_flow.cond(pred1, true_fn, false_fn, [x, pred2])
        return y.sin()

    def get_example_inputs(self):
        return (
            torch.ones(2, 2),
            torch.tensor([False]),
            torch.Tensor([False]),
            torch.ones(2, 2),
        )


@final
class Backend2Demo(BackendDetails):
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        processed_bytes = "Backend2::"
        for node in edge_program.graph.nodes:
            if node.op == "call_function":
                processed_bytes += f"{node.target.__name__};"
        return PreprocessResult(
            processed_bytes=bytes(processed_bytes, encoding="utf8"),
        )


@final
class Backend2PartitionerDemo(Partitioner):
    """
    Partitions all add/mm nodes regardless of order for Backend2
    """

    def __init__(self) -> None:
        self.op_support = any_chain(AddOperatorSupport(), MatmulOperatorSupport())
        self.delegation_spec = DelegationSpec("Backend2Demo", [])
        self.partition_tags = {}

    def _partition_graph_module(
        self, edge_graph_module: torch.fx.GraphModule
    ) -> Dict[str, DelegationSpec]:
        partition_tags: Dict[str, DelegationSpec] = {}
        partition_list = generate_pattern_op_partitions(
            edge_graph_module, op_support=self.op_support
        )

        for _, submodule, _ in get_control_flow_submodules(edge_graph_module):
            submodule_partition_tags = self._partition_graph_module(submodule)
            partition_tags.update(submodule_partition_tags)

        for partition in partition_list:
            for node in partition.nodes:
                delegation_tag = f"backend2_tag{partition.id}"
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec
        return partition_tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )


@final
class Backend1Demo(BackendDetails):
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        assert isinstance(edge_program, ExportedProgram)
        partitioned_module = to_backend(edge_program, Backend2PartitionerDemo())

        # Serialize the partitioned graph, recursing into cond branches and
        # inlining the payloads of already-lowered Backend2 delegates.
        def process(gm):
            processed_bytes = ""
            for node in gm.graph.nodes:
                if node.op == "call_function":
                    if node.target is torch.ops.higher_order.cond:
                        _, true_gm, _ = _get_submodule(gm, node, 1)
                        _, false_gm, _ = _get_submodule(gm, node, 2)
                        processed_bytes += f"{node.target.__name__}({process(true_gm)},{process(false_gm)});"
                    elif node.target is operator.getitem:
                        continue
                    elif node.target is executorch_call_delegate:
                        _, lowered, _ = _get_submodule(gm, node, 0)
                        processed_bytes += f"call_delegate({lowered.processed_bytes});"
                    else:
                        processed_bytes += f"{node.target.__name__};"
            return processed_bytes

        processed_bytes = f"Backend1::({process(partitioned_module.graph_module)})"
        return PreprocessResult(
            processed_bytes=bytes(processed_bytes, encoding="utf8"),
        )


class CondOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target is torch.ops.higher_order.cond


@final
class Backend1PartitionerDemo(Partitioner):
    """
    Partitions all add/mm/cond nodes regardless of order. Since we're
    partitioning the cond ops, we do not need to go into those submodules.
    """

    def __init__(self) -> None:
        self.op_support = any_chain(
            AddOperatorSupport(), MatmulOperatorSupport(), CondOperatorSupport()
        )
        self.delegation_spec = DelegationSpec("Backend1Demo", [])

    def _partition_graph_module(
        self, edge_graph_module: torch.fx.GraphModule
    ) -> Dict[str, DelegationSpec]:
        partition_tags: Dict[str, DelegationSpec] = {}
        partition_list = generate_pattern_op_partitions(
            edge_graph_module, op_support=self.op_support
        )

        for _, submodule, node in get_control_flow_submodules(edge_graph_module):
            # Don't partition the cond submodules because we are lowering the
            # entire cond node, including its submodules.
            if node.target is not control_flow.cond:
                self._partition_graph_module(submodule)

        for partition in partition_list:
            for node in partition.nodes:
                delegation_tag = f"backend1_tag{partition.id}"
                if (
                    node.op == "call_function"
                    and node.target is torch.ops.higher_order.cond
                ):
                    # Also tag the true/false submodules passed as arguments to cond
                    node.args[1].meta["delegation_tag"] = delegation_tag
                    node.args[2].meta["delegation_tag"] = delegation_tag
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec
        return partition_tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )


class TestNestedBackends(unittest.TestCase):
    def test(self) -> None:
        """
        Partitions the cond ops into the delegate
        """

        m = M()
        orig_res = m(*m.get_example_inputs())
        orig = exir.capture(
            m,
            m.get_example_inputs(),
            exir.CaptureConfig(),
        ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))

        partitioned = orig
        partitioned.exported_program = to_backend(
            orig.exported_program, Backend1PartitionerDemo()
        )

        new_res = partitioned(*m.get_example_inputs())[0]
        self.assertTrue(torch.allclose(orig_res, new_res))

        # The toplevel module should have lowered the cond and add ops
        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        toplevel_lowered = toplevel_lowered[0][1]
        self.maxDiff = None
        self.assertEqual(
            str(toplevel_lowered.processed_bytes),
            (
                'b"Backend1::('
                + "call_delegate(b'Backend2::aten.add.Tensor;');"
                + "cond("
                # True function of toplevel cond (nested cond)
                + "cond(call_delegate(b'Backend2::aten.add.Tensor;aten.mm.default;');,call_delegate(b'Backend2::aten.mm.default;'););"
                # True function of toplevel cond (delegated add)
                + "call_delegate(b'Backend2::aten.add.Tensor;');,"
                # False function of toplevel cond
                + "call_delegate(b'Backend2::aten.mm.default;'););)\""
            ),
        )
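

# Optional convenience entry point (an assumption, not part of the original
# test file): it makes the module directly runnable with `python`, while the
# usual path is the repository's test runner.
if __name__ == "__main__":
    unittest.main()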