# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import collections
import unittest
from typing import List, Optional, Tuple

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_partitions_from_list_of_nodes,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import export
from torch.fx.node import Node
from torch.fx.passes.operator_support import OperatorSupportBase


class TestGraphPartition(unittest.TestCase):
    def get_graph_module(
        self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...]
    ) -> torch.fx.GraphModule:
        """Export ``module`` with ``inputs`` and lower it to the edge dialect.

        IR validity checks are disabled (``_check_ir_validity=False``).

        Args:
            module: The eager module to export.
            inputs: Example positional inputs used for ``torch.export.export``.

        Returns:
            The graph module of the resulting edge exported program.
        """
        graph_module = (
            to_edge(
                export(module, inputs),
                compile_config=EdgeCompileConfig(
                    _check_ir_validity=False,
                ),
            )
            .exported_program()
            .graph_module
        )

        return graph_module

    # hackily get list of nodes
    def get_node_list(
        self,
        graph_module: torch.fx.GraphModule,
        supported_modules: List[str],
    ) -> List[List[Node]]:
        """Group graph nodes by the innermost module recorded in their metadata.

        A node is kept when the last entry of its ``nn_module_stack`` metadata
        has a module type contained in ``supported_modules``. Placeholder
        arguments of a kept node are added to the same group, at most once per
        group. The exception is the placeholder whose target is ``"x"``, which
        these tests use as the user input.

        Args:
            graph_module: Graph module to scan.
            supported_modules: Module type names to match against the second
                element of each ``nn_module_stack`` entry, e.g.
                ``"torch.nn.modules.conv.Conv2d"``.

        Returns:
            One node list per (qualified name, module type) pair, in the order
            the pairs are first encountered while walking the graph.
        """
        pattern_list_map = collections.defaultdict(list)
        # A set gives O(1) membership checks in the loop below.
        # x is a hack to avoid the user input
        placeholders = {
            node
            for node in graph_module.graph.nodes
            if node.op == "placeholder" and node.target != "x"
        }
        for node in graph_module.graph.nodes:
            if "nn_module_stack" not in node.meta:
                continue
            # The last recorded entry is the innermost module for this node.
            innermost = list(node.meta["nn_module_stack"].values())[-1]
            full_qualified_name = innermost[0]
            owning_module = innermost[1]
            if owning_module not in supported_modules:
                continue
            group = pattern_list_map[(full_qualified_name, owning_module)]
            group.append(node)
            for arg in node.args:
                # A module invoked several times (e.g. conv3 below) shares its
                # parameter placeholders; record each one only once per group.
                if isinstance(arg, Node) and arg in placeholders and arg not in group:
                    group.append(arg)

        return list(pattern_list_map.values())

    def extract_partition_list(
        self,
        graph_module: torch.fx.GraphModule,
        supported_modules: List[str],
        op_support: Optional[OperatorSupportBase] = None,
    ) -> List:
        """Build node groups for ``supported_modules`` and partition the graph.

        Args:
            graph_module: Graph module to partition.
            supported_modules: Module type names forwarded to ``get_node_list``.
            op_support: Optional extra operator-support check forwarded to
                ``generate_partitions_from_list_of_nodes``.

        Returns:
            The partitions returned by ``generate_partitions_from_list_of_nodes``.
        """
        node_list = self.get_node_list(graph_module, supported_modules)

        partition_list = generate_partitions_from_list_of_nodes(
            graph_module, node_list, op_support
        )

        return partition_list

    @staticmethod
    def _sorted_node_names(partition) -> List[str]:
        """Return the names of the nodes in ``partition``, sorted alphabetically."""
        return sorted(node.name for node in partition.nodes)

    def test_partition_list_without_op_support_one_partition(self):
        """
        Check that all supported submodules are lowered into a single partition.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(32, 32, 1)
                self.conv2 = torch.nn.Conv2d(32, 32, 1)
                self.conv3 = torch.nn.Conv2d(32, 32, 1)
                self.relu = torch.nn.ReLU()

            def forward(self, x: torch.Tensor):
                a = self.conv1(x)
                b = self.conv2(a)
                c = self.conv3(b)
                d = self.conv3(c)
                return self.relu(d)

        example_inputs = (torch.rand(1, 32, 16, 16),)
        test_module = TestModule()
        graph_module = self.get_graph_module(test_module, example_inputs)

        supported_module = [
            "torch.nn.modules.conv.Conv2d",
            "torch.nn.modules.activation.ReLU",
        ]
        partition_list = self.extract_partition_list(graph_module, supported_module)

        self.assertEqual(len(partition_list), 1)

    def test_partition_list_without_op_support_two_partitions(self):
        """
        Check that the graph is split into two partitions when supported modules
        are provided but no OperatorSupportBase is provided.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(32, 32, 1)
                self.conv2 = torch.nn.Conv2d(32, 32, 1)
                self.conv3 = torch.nn.Conv2d(32, 32, 1)
                self.relu = torch.nn.ReLU()

            def forward(self, x: torch.Tensor):
                a = self.conv1(x)
                b = self.conv2(a)
                # The unsupported add between conv2 and conv3 splits the graph.
                c = self.conv3(a + b)
                d = self.conv3(c)
                return self.relu(d)

        example_inputs = (torch.rand(1, 32, 16, 16),)
        test_module = TestModule()
        graph_module = self.get_graph_module(test_module, example_inputs)

        supported_module = [
            "torch.nn.modules.conv.Conv2d",
            "torch.nn.modules.activation.ReLU",
        ]
        partition_list = self.extract_partition_list(graph_module, supported_module)

        self.assertEqual(len(partition_list), 2)

        partition_1 = [
            "aten_convolution_default_2",
            "aten_convolution_default_3",
            "aten_relu_default",
            "p_conv3_bias",
            "p_conv3_weight",
        ]
        partition_2 = [
            "aten_convolution_default",
            "aten_convolution_default_1",
            "p_conv1_bias",
            "p_conv1_weight",
            "p_conv2_bias",
            "p_conv2_weight",
        ]

        # Compare the sorted node names of each partition with the expected names.
        self.assertEqual(self._sorted_node_names(partition_list[0]), partition_1)
        self.assertEqual(self._sorted_node_names(partition_list[1]), partition_2)

    def test_graph_partition_with_op_support(self):
        """
        Check that the graph is split into two partitions when both supported
        modules and an OperatorSupportBase are provided.
        """

        class TestOperatorSupport(OperatorSupportBase):
            def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
                return node.op == "call_function" and node.target in [
                    exir_ops.edge.aten.div.Tensor,
                    exir_ops.edge.aten.add.Tensor,
                ]

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(32, 32, 1)
                self.conv2 = torch.nn.Conv2d(32, 32, 1)
                self.conv3 = torch.nn.Conv2d(32, 32, 1)
                self.relu = torch.nn.ReLU()

            def forward(self, x: torch.Tensor):
                a = self.conv1(x)
                b = self.conv2(a)
                c = self.conv3(a + b)
                d = self.conv3(c)
                c, _ = torch.max(c, dim=2)
                d, _ = torch.max(d, dim=2)
                e = d - c
                return self.relu(e)

        example_inputs = (torch.rand(1, 32, 16, 16),)
        test_module = TestModule()
        graph_module = self.get_graph_module(test_module, example_inputs)

        supported_module = [
            "torch.nn.modules.conv.Conv2d",
            "torch.nn.modules.activation.ReLU",
        ]
        partition_list = self.extract_partition_list(
            graph_module, supported_module, TestOperatorSupport()
        )

        self.assertEqual(len(partition_list), 2)

        partition_1 = ["aten_relu_default"]
        partition_2 = [
            "aten_add_tensor",
            "aten_convolution_default",
            "aten_convolution_default_1",
            "aten_convolution_default_2",
            "aten_convolution_default_3",
            "p_conv1_bias",
            "p_conv1_weight",
            "p_conv2_bias",
            "p_conv2_weight",
            "p_conv3_bias",
            "p_conv3_weight",
        ]

        # Compare the sorted node names of each partition with the expected names.
        self.assertEqual(self._sorted_node_names(partition_list[0]), partition_1)
        self.assertEqual(self._sorted_node_names(partition_list[1]), partition_2)