1# Owner(s): ["module: fx"] 2 3import os 4import sys 5import unittest 6 7import torch 8 9 10pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 11sys.path.append(pytorch_test_dir) 12from torch._dynamo.eval_frame import is_dynamo_supported 13from torch.fx.passes.tools_common import legalize_graph 14from torch.fx.passes.utils.source_matcher_utils import ( 15 check_subgraphs_connected, 16 get_source_partitions, 17) 18from torch.testing._internal.common_utils import ( 19 instantiate_parametrized_tests, 20 parametrize, 21) 22from torch.testing._internal.jit_utils import JitTestCase 23 24 25class TestSourceMatcher(JitTestCase): 26 @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") 27 def test_module_partitioner_linear_relu_linear(self): 28 class M(torch.nn.Module): 29 def __init__(self) -> None: 30 super().__init__() 31 self.linear1 = torch.nn.Linear(3, 3) 32 self.relu = torch.nn.ReLU() 33 self.linear2 = torch.nn.Linear(3, 5) 34 35 def forward(self, x): 36 x = self.linear1(x) 37 x = self.linear1(x) 38 x = self.relu(x) 39 x = self.linear2(x) 40 return x 41 42 inputs = (torch.randn(3, 3),) 43 gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs) 44 gm.graph.eliminate_dead_code() 45 46 module_partitions = get_source_partitions( 47 gm.graph, [torch.nn.Linear, torch.nn.ReLU] 48 ) 49 50 self.assertEqual(len(module_partitions), 2) 51 self.assertEqual(len(module_partitions[torch.nn.Linear]), 3) 52 self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1) 53 54 self.assertFalse( 55 check_subgraphs_connected( 56 module_partitions[torch.nn.Linear][0], 57 module_partitions[torch.nn.ReLU][0], 58 ) 59 ) 60 self.assertTrue( 61 check_subgraphs_connected( 62 module_partitions[torch.nn.Linear][1], 63 module_partitions[torch.nn.ReLU][0], 64 ) 65 ) 66 self.assertFalse( 67 check_subgraphs_connected( 68 module_partitions[torch.nn.Linear][2], 69 module_partitions[torch.nn.ReLU][0], 70 ) 71 ) 72 73 @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") 74 def test_module_partitioner_conv_relu_maxpool(self): 75 class M(torch.nn.Module): 76 def __init__(self, constant_tensor: torch.Tensor) -> None: 77 super().__init__() 78 self.constant_tensor = constant_tensor 79 self.conv1 = torch.nn.Conv2d( 80 in_channels=3, out_channels=16, kernel_size=3, padding=1 81 ) 82 self.conv2 = torch.nn.Conv2d( 83 in_channels=16, out_channels=16, kernel_size=3, padding=1 84 ) 85 self.conv3 = torch.nn.Conv2d( 86 in_channels=16, out_channels=16, kernel_size=3, padding=1 87 ) 88 self.relu = torch.nn.ReLU() 89 self.maxpool = torch.nn.MaxPool2d(kernel_size=3) 90 91 def forward(self, x: torch.Tensor) -> torch.Tensor: 92 a = self.conv1(x) 93 b = self.conv2(a) 94 c = a + self.constant_tensor 95 z = self.conv3(b + c) 96 return self.maxpool(self.relu(z)) 97 98 inputs = (torch.randn(1, 3, 256, 256),) 99 gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), aten_graph=True)( 100 *inputs 101 ) 102 gm.graph.eliminate_dead_code() 103 104 module_partitions = get_source_partitions( 105 gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d] 106 ) 107 108 self.assertEqual(len(module_partitions), 3) 109 self.assertEqual(len(module_partitions[torch.nn.Conv2d]), 3) 110 self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1) 111 self.assertEqual(len(module_partitions[torch.nn.MaxPool2d]), 1) 112 113 self.assertFalse( 114 check_subgraphs_connected( 115 module_partitions[torch.nn.Conv2d][0], 116 module_partitions[torch.nn.ReLU][0], 117 ) 118 ) 119 self.assertFalse( 120 check_subgraphs_connected( 121 module_partitions[torch.nn.Conv2d][1], 122 module_partitions[torch.nn.ReLU][0], 123 ) 124 ) 125 self.assertTrue( 126 check_subgraphs_connected( 127 module_partitions[torch.nn.Conv2d][2], 128 module_partitions[torch.nn.ReLU][0], 129 ) 130 ) 131 self.assertFalse( 132 check_subgraphs_connected( 133 module_partitions[torch.nn.MaxPool2d][0], 134 module_partitions[torch.nn.ReLU][0], 135 ) 136 ) 137 self.assertTrue( 138 check_subgraphs_connected( 139 module_partitions[torch.nn.ReLU][0], 140 module_partitions[torch.nn.MaxPool2d][0], 141 ) 142 ) 143 144 @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") 145 def test_module_partitioner_functional_conv_relu_conv(self): 146 class FunctionalConv2d(torch.nn.Module): 147 def __init__(self) -> None: 148 super().__init__() 149 self.stride = (1, 1) 150 self.padding = (0, 0) 151 self.dilation = (1, 1) 152 self.groups = 1 153 154 def forward(self, x, weight, bias): 155 return torch.nn.functional.conv2d( 156 x, 157 weight, 158 bias, 159 self.stride, 160 self.padding, 161 self.dilation, 162 self.groups, 163 ) 164 165 class M(torch.nn.Module): 166 def __init__(self) -> None: 167 super().__init__() 168 self.conv1 = FunctionalConv2d() 169 self.conv2 = FunctionalConv2d() 170 171 def forward(self, x, weight, bias): 172 x = self.conv1(x, weight, bias) 173 x = torch.nn.functional.relu(x) 174 x = self.conv2(x, weight, bias) 175 return x 176 177 inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3)) 178 gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs) 179 gm.graph.eliminate_dead_code() 180 181 module_partitions = get_source_partitions( 182 gm.graph, [torch.nn.functional.conv2d] 183 ) 184 185 self.assertEqual(len(module_partitions), 1) 186 self.assertEqual(len(module_partitions[torch.nn.functional.conv2d]), 2) 187 188 @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") 189 def test_module_partitioner_functional_linear_relu_linear(self): 190 class M(torch.nn.Module): 191 def __init__(self) -> None: 192 super().__init__() 193 194 def forward(self, x, weight, bias): 195 x = torch.nn.functional.linear(x, weight, bias) 196 x = torch.nn.functional.linear(x, weight, bias) 197 x = torch.nn.functional.relu(x) 198 x = torch.nn.functional.linear(x, weight, bias) 199 x = torch.nn.functional.linear(x, weight, bias) 200 x = torch.nn.functional.relu(x) 201 return x 202 203 inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5)) 204 gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs) 205 gm.graph.eliminate_dead_code() 206 207 module_partitions = get_source_partitions( 208 gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu] 209 ) 210 211 self.assertEqual(len(module_partitions), 2) 212 self.assertEqual(len(module_partitions[torch.nn.functional.linear]), 4) 213 self.assertEqual(len(module_partitions[torch.nn.functional.relu]), 2) 214 215 @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") 216 def test_legalize_slice(self): 217 class M(torch.nn.Module): 218 def forward(self, x, y): 219 b = x.item() 220 torch._check_is_size(b) 221 torch._check(b + 1 < y.size(0)) 222 return y[: b + 1] 223 224 ep = torch.export.export(M(), (torch.tensor(4), torch.randn(10))) 225 fake_inputs = [ 226 node.meta["val"] for node in ep.graph.nodes if node.op == "placeholder" 227 ] 228 gm = ep.module() 229 with fake_inputs[0].fake_mode: 230 torch.fx.Interpreter(gm).run(*fake_inputs) 231 legalized_gm = legalize_graph(gm) 232 with fake_inputs[0].fake_mode: 233 torch.fx.Interpreter(legalized_gm).run(*fake_inputs) 234 235 @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") 236 @parametrize("strict", (True, False)) 237 def test_module_partitioner_linear_relu_linear_torch_fn_export(self, strict: bool): 238 class M(torch.nn.Module): 239 def __init__(self) -> None: 240 super().__init__() 241 self.linear1 = torch.nn.Linear(3, 3) 242 self.relu = torch.nn.ReLU() 243 self.linear2 = torch.nn.Linear(3, 5) 244 245 def forward(self, x): 246 x = self.linear1(x) 247 x = self.linear1(x) 248 x = self.relu(x) 249 x = self.linear2(x) 250 return x 251 252 inputs = (torch.randn(3, 3),) 253 gm = torch.export.export(M(), inputs, strict=strict).module() 254 gm.graph.eliminate_dead_code() 255 256 # Remove "source_fn_stack" meta to let partitioner use "torch_fn" only. 257 # TODO: remove this after we fix "torch_fn". T199561090 258 for node in gm.graph.nodes: 259 node.meta["source_fn_stack"] = None 260 261 module_partitions = get_source_partitions(gm.graph, ["linear", "relu"]) 262 263 self.assertEqual(len(module_partitions), 2) 264 self.assertEqual(len(module_partitions["linear"]), 3) 265 self.assertEqual(len(module_partitions["relu"]), 1) 266 267 self.assertFalse( 268 check_subgraphs_connected( 269 module_partitions["linear"][0], 270 module_partitions["relu"][0], 271 ) 272 ) 273 self.assertTrue( 274 check_subgraphs_connected( 275 module_partitions["linear"][1], 276 module_partitions["relu"][0], 277 ) 278 ) 279 self.assertFalse( 280 check_subgraphs_connected( 281 module_partitions["linear"][2], 282 module_partitions["relu"][0], 283 ) 284 ) 285 286 @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") 287 @parametrize("strict", (True, False)) 288 def test_module_partitioner_conv_relu_maxpool_torch_fn_export(self, strict: bool): 289 class M(torch.nn.Module): 290 def __init__(self, constant_tensor: torch.Tensor) -> None: 291 super().__init__() 292 self.constant_tensor = constant_tensor 293 self.conv1 = torch.nn.Conv2d( 294 in_channels=3, out_channels=16, kernel_size=3, padding=1 295 ) 296 self.conv2 = torch.nn.Conv2d( 297 in_channels=16, out_channels=16, kernel_size=3, padding=1 298 ) 299 self.conv3 = torch.nn.Conv2d( 300 in_channels=16, out_channels=16, kernel_size=3, padding=1 301 ) 302 self.relu = torch.nn.ReLU() 303 self.maxpool = torch.nn.MaxPool2d(kernel_size=3) 304 305 def forward(self, x: torch.Tensor) -> torch.Tensor: 306 a = self.conv1(x) 307 b = self.conv2(a) 308 c = a + self.constant_tensor 309 z = self.conv3(b + c) 310 return self.maxpool(self.relu(z)) 311 312 inputs = (torch.randn(1, 3, 256, 256),) 313 gm = torch.export.export( 314 M(torch.ones(1, 16, 256, 256)), inputs, strict=strict 315 ).module() 316 gm.graph.eliminate_dead_code() 317 318 # Remove "source_fn_stack" meta to let partitioner use "torch_fn" only. 319 # TODO: remove this after we fix "torch_fn". T199561090 320 for node in gm.graph.nodes: 321 node.meta["source_fn_stack"] = None 322 323 module_partitions = get_source_partitions( 324 gm.graph, ["conv2d", "relu", "max_pool2d"] 325 ) 326 327 self.assertEqual(len(module_partitions), 3) 328 self.assertEqual(len(module_partitions["conv2d"]), 3) 329 self.assertEqual(len(module_partitions["relu"]), 1) 330 self.assertEqual(len(module_partitions["max_pool2d"]), 1) 331 332 self.assertFalse( 333 check_subgraphs_connected( 334 module_partitions["conv2d"][0], 335 module_partitions["relu"][0], 336 ) 337 ) 338 self.assertFalse( 339 check_subgraphs_connected( 340 module_partitions["conv2d"][1], 341 module_partitions["relu"][0], 342 ) 343 ) 344 self.assertTrue( 345 check_subgraphs_connected( 346 module_partitions["conv2d"][2], 347 module_partitions["relu"][0], 348 ) 349 ) 350 self.assertFalse( 351 check_subgraphs_connected( 352 module_partitions["max_pool2d"][0], 353 module_partitions["relu"][0], 354 ) 355 ) 356 self.assertTrue( 357 check_subgraphs_connected( 358 module_partitions["relu"][0], 359 module_partitions["max_pool2d"][0], 360 ) 361 ) 362 363 @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") 364 @parametrize("strict", (True, False)) 365 def test_module_partitioner_functional_conv_relu_conv_torch_fn_export( 366 self, strict: bool 367 ): 368 class FunctionalConv2d(torch.nn.Module): 369 def __init__(self) -> None: 370 super().__init__() 371 self.stride = (1, 1) 372 self.padding = (0, 0) 373 self.dilation = (1, 1) 374 self.groups = 1 375 376 def forward(self, x, weight, bias): 377 return torch.nn.functional.conv2d( 378 x, 379 weight, 380 bias, 381 self.stride, 382 self.padding, 383 self.dilation, 384 self.groups, 385 ) 386 387 class M(torch.nn.Module): 388 def __init__(self) -> None: 389 super().__init__() 390 self.conv1 = FunctionalConv2d() 391 self.conv2 = FunctionalConv2d() 392 393 def forward(self, x, weight, bias): 394 x = self.conv1(x, weight, bias) 395 x = torch.nn.functional.relu(x) 396 x = self.conv2(x, weight, bias) 397 return x 398 399 inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3)) 400 gm = torch.export.export(M(), inputs, strict=strict).module() 401 gm.graph.eliminate_dead_code() 402 403 # Remove "source_fn_stack" meta to let partitioner use "torch_fn" only. 404 # TODO: remove this after we fix "torch_fn". T199561090 405 for node in gm.graph.nodes: 406 node.meta["source_fn_stack"] = None 407 408 module_partitions = get_source_partitions(gm.graph, ["conv2d"]) 409 410 self.assertEqual(len(module_partitions), 1) 411 self.assertEqual(len(module_partitions["conv2d"]), 2) 412 413 @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported") 414 @parametrize("strict", (True, False)) 415 def test_module_partitioner_functional_linear_relu_linear_torch_fn_export( 416 self, strict: bool 417 ): 418 class M(torch.nn.Module): 419 def __init__(self) -> None: 420 super().__init__() 421 422 def forward(self, x, weight, bias): 423 x = torch.nn.functional.linear(x, weight, bias) 424 x = torch.nn.functional.linear(x, weight, bias) 425 x = torch.nn.functional.relu(x) 426 x = torch.nn.functional.linear(x, weight, bias) 427 x = torch.nn.functional.linear(x, weight, bias) 428 x = torch.nn.functional.relu(x) 429 return x 430 431 inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5)) 432 gm = torch.export.export(M(), inputs, strict=strict).module() 433 gm.graph.eliminate_dead_code() 434 435 # Remove "source_fn_stack" meta to let partitioner use "torch_fn" only. 436 # TODO: remove this after we fix "torch_fn". T199561090 437 for node in gm.graph.nodes: 438 node.meta["source_fn_stack"] = None 439 440 module_partitions = get_source_partitions(gm.graph, ["linear", "relu"]) 441 442 self.assertEqual(len(module_partitions), 2) 443 self.assertEqual(len(module_partitions["linear"]), 4) 444 self.assertEqual(len(module_partitions["relu"]), 2) 445 446 447instantiate_parametrized_tests(TestSourceMatcher) 448