1# Owner(s): ["module: fx"] 2 3import os 4import sys 5from typing import Callable 6 7import torch 8import torch.nn.functional as F 9from torch.fx import symbolic_trace 10from torch.fx.experimental.proxy_tensor import make_fx 11 12 13pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 14sys.path.append(pytorch_test_dir) 15import unittest 16 17from torch.fx.passes.utils.matcher_utils import SubgraphMatcher 18from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( 19 SubgraphMatcherWithNameNodeMap, 20) 21from torch.testing._internal.common_utils import IS_WINDOWS, run_tests 22from torch.testing._internal.jit_utils import JitTestCase 23 24 25class WrapperModule(torch.nn.Module): 26 def __init__(self, fn: Callable): 27 super().__init__() 28 self.fn = fn 29 30 def forward(self, *args, **kwargs): 31 return self.fn(*args, **kwargs) 32 33 34class TestMatcher(JitTestCase): 35 def test_subgraph_matcher_with_attributes(self): 36 class LargeModel(torch.nn.Module): 37 def __init__(self) -> None: 38 super().__init__() 39 self._weight = torch.nn.Parameter(torch.ones(3, 3)) 40 self._bias = torch.nn.Parameter(torch.ones(3, 3)) 41 42 def forward(self, x): 43 return torch.ops.aten.addmm.default(self._bias, x, self._weight) 44 45 # Large Model graph: 46 # opcode name target args kwargs 47 # ------------- ------------- ------------------ ------------------- -------- 48 # placeholder x x () {} 49 # get_attr _bias _bias () {} 50 # get_attr _weight _weight () {} 51 # call_function addmm_default aten.addmm.default (_bias, x, _weight) {} 52 # output output output (addmm_default,) {} 53 large_model_graph = symbolic_trace(LargeModel()).graph 54 55 class PatternModel(torch.nn.Module): 56 def __init__(self) -> None: 57 super().__init__() 58 self._weight_1 = torch.nn.Parameter(torch.ones(5, 5)) 59 self._bias_1 = torch.nn.Parameter(torch.ones(5, 5)) 60 61 def forward(self, x): 62 return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1) 63 64 pattern_graph = torch.fx.symbolic_trace(PatternModel()).graph 65 66 subgraph_matcher = SubgraphMatcher(pattern_graph) 67 match_result = subgraph_matcher.match(large_model_graph) 68 self.assertEqual(len(match_result), 1) 69 70 def test_subgraph_matcher_with_list(self): 71 def original(x, y): 72 return torch.ops.aten.view(x, [5, y.shape[0]]) 73 74 original_graph = torch.fx.symbolic_trace(original).graph 75 76 def pattern(x, y, z): 77 return torch.ops.aten.view(x, [z, y.shape[0]]) 78 79 pattern_graph = torch.fx.symbolic_trace(pattern).graph 80 81 subgraph_matcher = SubgraphMatcher(pattern_graph) 82 match_result = subgraph_matcher.match(original_graph) 83 self.assertEqual(len(match_result), 1) 84 85 def test_subgraph_matcher_with_list_bad(self): 86 def original(x, y): 87 return torch.ops.aten._reshape_alias_copy.default( 88 x, [1, y.shape[0]], [y.shape[1], y.shape[1]] 89 ) 90 91 original_graph = torch.fx.symbolic_trace(original).graph 92 93 def pattern(x, y, b): 94 return torch.ops.aten._reshape_alias_copy.default( 95 x, [b, y.shape[0], y.shape[1]], [y.shape[1]] 96 ) 97 98 pattern_graph = torch.fx.symbolic_trace(pattern).graph 99 100 subgraph_matcher = SubgraphMatcher(pattern_graph) 101 match_result = subgraph_matcher.match(original_graph) 102 self.assertEqual(len(match_result), 0) 103 104 def test_subgraph_matcher_ignore_literals(self): 105 def original(x): 106 return x + 1 107 108 original_graph = make_fx(original)(torch.ones(3, 3)).graph 109 original_graph.eliminate_dead_code() 110 111 def pattern(x): 112 return x + 2 113 114 pattern_graph = make_fx(pattern)(torch.ones(4, 4)).graph 115 pattern_graph.eliminate_dead_code() 116 117 subgraph_matcher = SubgraphMatcher(pattern_graph) 118 match_result = subgraph_matcher.match(original_graph) 119 self.assertEqual(len(match_result), 0) 120 121 subgraph_matcher = SubgraphMatcher(pattern_graph, ignore_literals=True) 122 match_result = subgraph_matcher.match(original_graph) 123 self.assertEqual(len(match_result), 1) 124 125 def test_variatic_arg_matching(self): 126 inputs = (torch.randn(20, 16, 50, 32),) 127 128 def maxpool(x, kernel_size, stride, padding, dilation): 129 return torch.ops.aten.max_pool2d_with_indices.default( 130 x, kernel_size, stride, padding, dilation 131 ) 132 133 maxpool_graph = torch.fx.symbolic_trace(maxpool).graph 134 135 maxpool_matcher = SubgraphMatcher(maxpool_graph) 136 match_result = maxpool_matcher.match(maxpool_graph) 137 self.assertEqual(len(match_result), 1) 138 139 # Graph only contains "stride" argument 140 maxpool_s = torch.nn.MaxPool2d(kernel_size=2, stride=1).eval() 141 maxpool_s_graph = make_fx(maxpool_s)(*inputs).graph 142 match_s_result = maxpool_matcher.match(maxpool_s_graph) 143 self.assertEqual(len(match_s_result), 1) 144 145 # Graph only contains "padding" argument 146 maxpool_p = torch.nn.MaxPool2d(kernel_size=2, padding=1) 147 maxpool_p_graph = make_fx(maxpool_p)(*inputs).graph 148 match_p_result = maxpool_matcher.match(maxpool_p_graph) 149 self.assertEqual(len(match_p_result), 1) 150 151 # Graph only contains "stride, padding" argument 152 maxpool_sp = torch.nn.MaxPool2d(kernel_size=2, stride=1, padding=1) 153 maxpool_sp_graph = make_fx(maxpool_sp)(*inputs).graph 154 match_sp_result = maxpool_matcher.match(maxpool_sp_graph) 155 self.assertEqual(len(match_sp_result), 1) 156 157 @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") 158 def test_split_to_graph_and_name_node_map(self): 159 """Testing the internal helper function for splitting the pattern graph""" 160 from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( 161 _split_to_graph_and_name_node_map, 162 ) 163 164 def pattern(x, weight): 165 conv = F.conv2d(x, weight) 166 relu = F.relu(conv) 167 relu_mul_by_two = relu * 2 168 return relu, relu_mul_by_two, {"conv": conv, "relu": relu} 169 170 from torch._export import capture_pre_autograd_graph 171 172 example_inputs = ( 173 torch.randn(1, 3, 3, 3) * 10, 174 torch.randn(3, 3, 3, 3), 175 ) 176 pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs) 177 before_split_res = pattern_gm(*example_inputs) 178 pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm) 179 after_split_res = pattern_gm(*example_inputs) 180 self.assertEqual(before_split_res[0], after_split_res[0]) 181 self.assertEqual(before_split_res[1], after_split_res[1]) 182 183 @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") 184 def test_matcher_with_name_node_map_function(self): 185 """Testing SubgraphMatcherWithNameNodeMap with function pattern""" 186 187 def target_graph(x, weight): 188 x = x * 2 189 weight = weight * 3 190 conv = F.conv2d(x, weight) 191 relu = F.relu(conv) 192 relu2 = relu * 2 193 return relu + relu2 194 195 def pattern(x, weight): 196 conv = F.conv2d(x, weight) 197 relu = F.relu(conv) 198 relu_mul_by_two = relu * 2 199 return relu, relu_mul_by_two, {"conv": conv, "relu": relu} 200 201 from torch._export import capture_pre_autograd_graph 202 203 example_inputs = ( 204 torch.randn(1, 3, 3, 3) * 10, 205 torch.randn(3, 3, 3, 3), 206 ) 207 pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs) 208 matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) 209 target_gm = capture_pre_autograd_graph( 210 WrapperModule(target_graph), example_inputs 211 ) 212 internal_matches = matcher.match(target_gm.graph) 213 for internal_match in internal_matches: 214 name_node_map = internal_match.name_node_map 215 assert "conv" in name_node_map 216 assert "relu" in name_node_map 217 name_node_map["conv"].meta["custom_annotation"] = "annotation" 218 # check if we correctly annotated the target graph module 219 for n in target_gm.graph.nodes: 220 if n == name_node_map["conv"]: 221 assert ( 222 "custom_annotation" in n.meta 223 and n.meta["custom_annotation"] == "annotation" 224 ) 225 226 @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") 227 def test_matcher_with_name_node_map_module(self): 228 """Testing SubgraphMatcherWithNameNodeMap with module pattern""" 229 230 class M(torch.nn.Module): 231 def __init__(self) -> None: 232 super().__init__() 233 self.linear = torch.nn.Linear(5, 5) 234 235 def forward(self, x): 236 return self.linear(x) 237 238 class Pattern(torch.nn.Module): 239 def __init__(self) -> None: 240 super().__init__() 241 self.linear = torch.nn.Linear(5, 5) 242 243 def forward(self, x): 244 linear = self.linear(x) 245 # Note: we can't put "weight": self.linear.weight in dictionary since 246 # nn.Parameter is not an allowed output type in dynamo 247 return linear, {"linear": linear, "x": x} 248 249 from torch._export import capture_pre_autograd_graph 250 251 example_inputs = (torch.randn(3, 5),) 252 pattern_gm = capture_pre_autograd_graph(Pattern(), example_inputs) 253 matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) 254 target_gm = capture_pre_autograd_graph(M(), example_inputs) 255 internal_matches = matcher.match(target_gm.graph) 256 for internal_match in internal_matches: 257 name_node_map = internal_match.name_node_map 258 assert "linear" in name_node_map 259 assert "x" in name_node_map 260 name_node_map["linear"].meta["custom_annotation"] = "annotation" 261 # check if we correctly annotated the target graph module 262 for n in target_gm.graph.nodes: 263 if n == name_node_map["linear"]: 264 assert ( 265 "custom_annotation" in n.meta 266 and n.meta["custom_annotation"] == "annotation" 267 ) 268 269 270if __name__ == "__main__": 271 run_tests() 272