# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch import exir
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.backend.utils import (
    format_delegated_graph,
    get_delegates,
    get_non_lowered_nodes,
    is_identical_graph,
)

from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
from torch.export import export, ExportedProgram
from torch.fx import symbolic_trace
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
from torch.library import Library

T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default


class TestUtils(unittest.TestCase):
    def test_identical_graph_with_unused_args(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                # y is an unused argument
                return x

        m = MyModule()
        graph_module: torch.fx.GraphModule = symbolic_trace(m)
        is_matched = is_identical_graph(graph_module, graph_module)
        self.assertTrue(is_matched)

    def test_identical_graph_with_used_args(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x, y

        m = MyModule()
        graph_module: torch.fx.GraphModule = symbolic_trace(m)
        is_matched = is_identical_graph(graph_module, graph_module)
        self.assertTrue(is_matched)

    def test_identical_graph_for_linear(self):
        graph_module: torch.fx.GraphModule = symbolic_trace(torch.nn.Linear(10, 10))
        is_matched = is_identical_graph(graph_module, graph_module)
        self.assertTrue(is_matched)

    def test_identical_graph_for_composite_module(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x):
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        graph_module: torch.fx.GraphModule = symbolic_trace(MyModule())
        is_matched = is_identical_graph(graph_module, graph_module)
        self.assertTrue(is_matched)

    def test_not_identical_graph_for_args(self):
        class MyModule1(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                # y is an unused argument
                return x + 1

        class MyModule2(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + 1, y + 2

        graph_module_1: torch.fx.GraphModule = (
            to_edge(
                export(
                    MyModule1(),
                    (torch.rand(3, 4), torch.rand(3, 4)),
                )
            )
            .exported_program()
            .graph_module
        )
        graph_module_2: torch.fx.GraphModule = (
            to_edge(
                export(
                    MyModule2(),
                    (torch.rand(3, 4), torch.rand(3, 4)),
                )
            )
            .exported_program()
            .graph_module
        )
        is_matched = is_identical_graph(graph_module_1, graph_module_2)
        self.assertFalse(is_matched)

    def test_match_attrs(self):
        class LargeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.ones(3, 3))
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, x):
                a = x + self.weight
                b = self.linear(x)
                return a, b

        inputs = (torch.ones(3, 3),)

        large_model = (
            to_edge(
                export(
                    LargeModel(),
                    inputs,
                ),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            )
            .exported_program()
            .graph_module
        )

        pattern = (
            to_edge(
                export(torch.nn.Linear(3, 3), inputs),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            )
            .exported_program()
            .graph_module.graph
        )

        subgraph_matcher = SubgraphMatcher(pattern)
        match_result = subgraph_matcher.match(large_model.graph)

        # Should find exactly one match
        self.assertEqual(len(match_result), 1)

    def test_invalid_partitioner_without_partition_tags(self):
        """
        Tests that lowering fails with a clear error when a partitioner
        does not populate the required `partition_tags` field.
        """

        class InvalidPartitioner(Partitioner):
            """
            An invalid partitioner that never sets `partition_tags`.
            """

            def __init__(self) -> None:
                # A valid partitioner should have partition_tags
                self.test = "a"

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                return PartitionResult(
                    tagged_exported_program=edge_exported_program, partition_tags=None
                )

        exported_program = to_edge(
            export(
                torch.nn.Linear(3, 3),
                (torch.randn(3, 3),),
            )
        )

        error_msg = r"needs a `partition_tags` field containing a mapping of tags to delegate spec"
        with self.assertRaisesRegex(
            AssertionError,
            error_msg,
        ):
            _ = to_backend(exported_program.exported_program(), InvalidPartitioner())

    test_lib = Library("test_lib", "DEF")

    # Bind a demo pattern to an op on the scratch test_lib library; the
    # Python body is a placeholder, as only the schema binding matters here.
    @staticmethod
    @bind_pattern_to_op(
        test_lib, "test_q_linear(Tensor x, Tensor weight, Tensor bias) -> Tensor"
    )
    def q_linear(x, weight, bias):
        return x

    def test_get_non_lowered_nodes(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
        edge = to_edge(export(m, inputs))
        edge = edge.to_backend(AddMulPartitionerDemo())
        cpu_nodes = get_non_lowered_nodes(edge.exported_program().graph)
        # Only the sub node is not lowerable
        self.assertEqual(len(cpu_nodes), 1)

    def test_get_delegates(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
        edge = to_edge(export(m, inputs))
        edge = edge.to_backend(AddMulPartitionerDemo())
        delegates = get_delegates(edge.exported_program().graph)
        # There will be 2 delegates: (mm + add) -> sub -> (mm + add)
        self.assertEqual(len(delegates), 2)
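
    def test_delegate_node_naming(self):
        # A minimal sanity sketch, assuming the `lowered_module_*` naming
        # convention that to_backend uses for lowered submodules (the same
        # convention get_delegates() keys off of); the model and the expected
        # delegate count here are illustrative assumptions, not upstream
        # guarantees.
        class Model(torch.nn.Module):
            def forward(self, a, x, b):
                return torch.mm(a, x) + b

        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
        edge = to_edge(export(Model(), inputs)).to_backend(AddMulPartitionerDemo())
        delegates = get_delegates(edge.exported_program().graph)
        # mm + add form one contiguous add/mul partition -> one delegate
        self.assertEqual(len(delegates), 1)
        for node in delegates:
            self.assertTrue(node.name.startswith("lowered_module_"))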
    def test_print_delegated_graph(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

        edge = to_edge(export(m, inputs)).to_backend(AddMulPartitionerDemo())

        graph_str = format_delegated_graph(edge.exported_program().graph_module)
        self.assertIn(
            "BackendWithCompilerDemo",
            graph_str,
            "Expect to find the backend id in the formatted graph string",
        )
        self.assertIn(
            "executorch.exir.dialects.edge._ops.aten.mm.default",
            graph_str,
            "Expect to see aten.mm in the delegated graph",
        )
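

if __name__ == "__main__":
    # Conventional unittest entry point so the suite can also be run
    # directly (e.g. `python test_utils.py`) in addition to a test runner.
    unittest.main()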