1# Owner(s): ["module: inductor"] 2 3import collections 4import unittest 5from typing import List 6 7import torch 8import torch._inductor 9import torch._inductor.fx_passes.group_batch_fusion 10from torch._dynamo.utils import counters, optimus_scuba_log 11from torch._inductor.test_case import run_tests, TestCase 12from torch.testing._internal.inductor_utils import HAS_CUDA 13 14 15try: 16 # importing this will register fbgemm lowerings for inductor 17 import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings # noqa: F401 18 19 has_fbgemm = True 20except Exception: 21 has_fbgemm = False 22 23requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") 24 25 26class TestHighwaySelfGating(torch.nn.Module): 27 def __init__( 28 self, 29 d_model: int, 30 size: int, 31 device="cuda", 32 ) -> None: 33 super().__init__() 34 self.size = size 35 self.device = device 36 self.gating_proj = torch.nn.Linear(d_model, d_model).to(self.device) 37 self.transform_proj = torch.nn.Linear(d_model, d_model).to(self.device) 38 self.gating_func = torch.nn.Sigmoid().to(self.device) 39 40 self.d_model = d_model 41 42 def forward( 43 self, 44 inputs: List[torch.Tensor], 45 ) -> torch.Tensor: 46 results = [] 47 for i in range(self.size): 48 x = inputs[i] 49 gating_proj = self.gating_proj(x) 50 transform_proj = self.transform_proj(x) 51 x = gating_proj * self.gating_func(transform_proj) 52 results.append(x) 53 54 return torch.cat(results, dim=-1) 55 56 57class MyModule(torch.nn.Module): 58 def __init__(self, z: int, has_bias: bool, device="cuda") -> None: 59 super().__init__() 60 self.z = z 61 self.device = device 62 self.seq_len = 10 63 self.seq1 = [ 64 torch.nn.Linear(z, z, has_bias).to(self.device) for _ in range(self.seq_len) 65 ] 66 self.seq2 = [ 67 torch.nn.Linear(z, z, has_bias).to(self.device) for _ in range(self.seq_len) 68 ] 69 self.seq3 = [ 70 torch.nn.Linear(z, z, has_bias).to(self.device) for _ in range(self.seq_len) 71 ] 72 73 def forward(self, x: torch.Tensor) -> torch.Tensor: 74 x1 = [x + 0.1 * i for i in range(self.seq_len)] 75 x2 = [self.seq1[i](x1[i]) for i in range(self.seq_len)] 76 x3 = [x2[i] - 0.1 * i for i in range(self.seq_len)] 77 x4 = [x1[i] for i in range(3)] + [x3[i] for i in range(3, self.seq_len)] 78 x5 = [self.seq2[i](x4[i]) for i in range(self.seq_len)] 79 x6 = [x5[i] + 0.1 * (self.seq_len - i) for i in range(self.seq_len)] 80 x7 = ( 81 [x1[i] for i in range(4)] 82 + [x3[i] for i in range(6, 8)] 83 + [x6[i] for i in range(4)] 84 ) 85 x8 = [self.seq3[i](x7[i]) for i in range(self.seq_len)] 86 x9 = torch.cat(x8, dim=1) 87 return x9 88 89 90class MyModule2(torch.nn.Module): 91 def __init__(self) -> None: 92 super().__init__() 93 self.linear0 = torch.nn.Linear(6, 8) 94 self.linear1 = torch.nn.Linear(8, 8) 95 self.linear2 = torch.nn.Linear(10, 8) 96 self.linear3 = torch.nn.Linear(6, 8) 97 self.linear4 = torch.nn.Linear(8, 8) 98 self.linear5 = torch.nn.Linear(10, 8) 99 self.bn0 = torch.nn.BatchNorm1d(8) 100 self.bn1 = torch.nn.BatchNorm1d(8) 101 self.bn2 = torch.nn.BatchNorm1d(8) 102 103 def forward(self, x: torch.Tensor) -> torch.Tensor: 104 t = torch.split(x, [6, 8, 10], dim=1) 105 a0 = self.bn0(self.linear0(t[0] + 0.1)) 106 a1 = self.bn1(self.linear1(t[1] + 0.2)) 107 a2 = self.bn2(self.linear2(t[2] + 0.3)) 108 a3 = self.linear3(torch.sin(t[0])) 109 a4 = self.linear4(torch.cos(t[1])) 110 a5 = self.linear5(torch.sin(t[2] * 0.5)) 111 112 b = torch.cat([a0, a1, a2, a3, a4, a5]) 113 return torch.sigmoid(b) 114 115 116class MyModule3(torch.nn.Module): 117 def __init__(self, device, has_weight=True, has_bias=True): 118 super().__init__() 119 self.device = device 120 self.scale0 = torch.nn.ParameterList( 121 [torch.nn.Parameter(torch.randn(10)) for _ in range(5)] 122 ).to(self.device) 123 self.bias0 = torch.nn.ParameterList( 124 [torch.nn.Parameter(torch.randn(10)) for _ in range(5)] 125 ).to(self.device) 126 self.scale1 = ( 127 torch.nn.ParameterList( 128 [torch.nn.Parameter(torch.randn(5, 10)) for _ in range(5)] 129 ).to(self.device) 130 if has_weight 131 else [None for _ in range(5)] 132 ) 133 self.bias1 = ( 134 torch.nn.ParameterList( 135 [torch.nn.Parameter(torch.randn(5, 10)) for _ in range(5)] 136 ).to(self.device) 137 if has_bias 138 else [None for _ in range(5)] 139 ) 140 141 def forward(self, x): 142 l1_out = torch.split(x.to(self.device), 10, dim=2) 143 post_l1 = [ 144 torch.nn.functional.layer_norm( 145 l1_out[i], (10,), weight=self.scale0[i], bias=self.bias0[i] 146 ) 147 for i in range(len(l1_out)) 148 ] 149 l1_out = torch.cat(post_l1, dim=2) 150 151 l2_out = torch.split(l1_out, 10, dim=2) 152 post_l2 = [ 153 torch.nn.functional.layer_norm( 154 l2_out[i], (5, 10), weight=self.scale1[i], bias=self.bias1[i] 155 ) 156 for i in range(len(l2_out)) 157 ] 158 159 return torch.cat(post_l2, dim=2) 160 161 162class MyModule4(torch.nn.Module): 163 def __init__(self, z, device, has_bias): 164 super().__init__() 165 self.z = z 166 self.device = device 167 self.has_bias = has_bias 168 self.seq_len = 10 169 self.weights1 = [ 170 torch.nn.Parameter(torch.randn(z - i % 5, z)).to(self.device) 171 for i in range(self.seq_len) 172 ] 173 self.weights2 = [ 174 torch.nn.Parameter(torch.randn(z - i % 5, z)).to(self.device) 175 for i in range(self.seq_len) 176 ] 177 178 if has_bias: 179 self.biases1 = [ 180 torch.nn.Parameter(torch.randn(z - i % 5)).to(self.device) 181 for i in range(self.seq_len) 182 ] 183 self.biases2 = [ 184 torch.nn.Parameter(torch.randn(z - i % 5)).to(self.device) 185 for i in range(self.seq_len) 186 ] 187 188 def forward(self, x): 189 x = x + 1.2 190 x1 = [ 191 torch.nn.functional.linear( 192 x, self.weights1[i], self.biases1[i] if self.has_bias else None 193 ) 194 for i in range(self.seq_len) 195 ] 196 x2 = torch.cat(x1, dim=1) 197 x3 = torch.split(x2, 10, dim=1) 198 x4 = torch.cat(x3) 199 x5 = [ 200 torch.nn.functional.linear( 201 x4, self.weights2[i], self.biases2[i] if self.has_bias else None 202 ) 203 for i in range(self.seq_len) 204 ] 205 x6 = torch.cat(x5, dim=1) 206 return torch.sigmoid(x6) 207 208 209class MyModule5(torch.nn.Module): 210 def __init__(self, device, has_bias=True): 211 super().__init__() 212 self.device = device 213 214 self.weights = torch.nn.ParameterList( 215 [torch.nn.Parameter(torch.randn(50, 100)).to(self.device) for _ in range(5)] 216 ) 217 218 self.biases = ( 219 ([torch.nn.Parameter(torch.randn(50)).to(self.device) for _ in range(5)]) 220 if has_bias 221 else [None for _ in range(5)] 222 ) 223 224 def forward(self, x): 225 l1_out = torch.split(x.to(self.device), 100, dim=1) 226 l1_linear = [ 227 torch.nn.functional.linear(l1_out[i], self.weights[i], self.biases[i]) 228 for i in range(len(l1_out)) 229 ] 230 l1_out = torch.cat(l1_linear, dim=1) 231 return torch.sin(l1_out) 232 233 234class TestPoitwiseOps(torch.nn.Module): 235 def __init__(self, device, has_bias=True): 236 super().__init__() 237 self.device = device 238 239 def forward(self, x): 240 inputs = torch.split(x.to(self.device), 500, dim=1) 241 x_split = torch.split(inputs[0].to(self.device), 50, dim=1) 242 y_split = torch.split(inputs[1].to(self.device), 50, dim=1) 243 tanh_1 = [torch.tanh(x_split[i]) for i in range(len(x_split))] 244 tanh_2 = [torch.tanh(y_split[i]) for i in range(len(y_split))] 245 sigmoid_1 = [torch.sigmoid(tanh_1[i]) for i in range(len(tanh_1))] 246 sigmoid_2 = [torch.sigmoid(tanh_2[i]) for i in range(len(tanh_2))] 247 relu_1 = [torch.nn.functional.relu(sigmoid_1[i]) for i in range(len(sigmoid_1))] 248 relu_2 = [torch.nn.functional.relu(sigmoid_2[i]) for i in range(len(sigmoid_2))] 249 add = [torch.add(relu_1[i], relu_2[i]) for i in range(len(relu_1))] 250 mul = [torch.mul(add[i], add[i]) for i in range(len(add))] 251 sub = [torch.sub(mul[i], mul[i]) for i in range(len(mul))] 252 div = [torch.div(sub[i], sub[i]) for i in range(len(sub))] 253 return torch.cat(div, dim=1) 254 255 256class TestPoitwiseOpsPostGrad(torch.nn.Module): 257 def __init__(self, device): 258 super().__init__() 259 self.device = device 260 261 def forward(self, x): 262 inputs = torch.ops.aten.split(x.to(self.device), 500, dim=1) 263 x_split = torch.ops.aten.split(inputs[0].to(self.device), 50, dim=1) 264 y_split = torch.ops.aten.split(inputs[1].to(self.device), 50, dim=1) 265 tanh_1 = [torch.ops.aten.tanh(x_split[i]) for i in range(len(x_split))] 266 tanh_2 = [torch.ops.aten.tanh(y_split[i]) for i in range(len(y_split))] 267 sigmoid_1 = [torch.ops.aten.sigmoid(tanh_1[i]) for i in range(len(tanh_1))] 268 sigmoid_2 = [torch.ops.aten.sigmoid(tanh_2[i]) for i in range(len(tanh_2))] 269 relu_1 = [torch.ops.aten.relu(sigmoid_1[i]) for i in range(len(sigmoid_1))] 270 relu_2 = [torch.ops.aten.relu(sigmoid_2[i]) for i in range(len(sigmoid_2))] 271 add = [torch.ops.aten.add(relu_1[i], relu_2[i]) for i in range(len(relu_1))] 272 return torch.cat(add, dim=1) 273 274 275@requires_cuda 276@torch._inductor.config.patch( 277 pre_grad_fusion_options={ 278 "batch_linear": {}, 279 "batch_linear_lhs": {}, 280 "batch_layernorm": {}, 281 "batch_tanh": {}, 282 "batch_relu": {}, 283 "batch_sigmoid": {}, 284 }, 285 post_grad_fusion_options={ 286 "batch_aten_add": {}, 287 "batch_aten_mul": {}, 288 "batch_aten_sub": {}, 289 "batch_aten_div": {}, 290 "group_linear": {"require_fbgemm": True}, 291 }, 292) 293class TestGroupBatchFusion(TestCase): 294 def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3): 295 if len(set(ref_dict.keys())) != len(set(res_dict.keys())): 296 return False 297 for key1 in ref_dict.keys(): 298 key2 = "_orig_mod." + key1 299 assert key2 in res_dict, f"{key1} does not exist in traced module" 300 if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol): 301 return False 302 return True 303 304 def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3): 305 ref = module(*input) 306 res = traced(*input) 307 self.assertEqual(ref, res, rtol=rtol, atol=atol) 308 309 def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3): 310 ref_params = dict(module.named_parameters()) 311 res_params = dict(traced.named_parameters()) 312 self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol)) 313 314 def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): 315 ref_grad = {key: param.grad for key, param in module.named_parameters()} 316 res_grad = {key: param.grad for key, param in traced.named_parameters()} 317 self.assertTrue( 318 self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol) 319 ) 320 321 @unittest.skipIf(not has_fbgemm, "requires fbgemm") 322 def test_group_linear_fusion(self): 323 z = 10 324 for has_bias in [True, False]: 325 counters.clear() 326 module = MyModule(z, has_bias).to("cuda") 327 input = [torch.randn(z, z, device="cuda")] 328 traced = torch.compile(module) 329 ref = module(*input) 330 res = traced(*input) 331 self.compare_pred(module, traced, input) 332 self.assertEqual( 333 counters["inductor"]["group_linear"], 334 2, 335 ) 336 self.assertNotIn("group_batch_fusion_pre_grad", optimus_scuba_log) 337 ref.sum().backward() 338 res.sum().backward() 339 self.compare_parameters(module, traced) 340 self.compare_gradients(module, traced) 341 self.assertEqual( 342 counters["inductor"]["group_linear"], 343 4, 344 ) 345 self.assertEqual( 346 counters["inductor"]["batch_aten_add"], 347 3, 348 ) 349 self.assertIn("GroupLinearFusion", optimus_scuba_log) 350 counters.clear() 351 352 @unittest.skipIf(not has_fbgemm, "requires fbgemm") 353 def test_group_linear_fusion_different_shapes(self): 354 counters.clear() 355 module = MyModule2().eval().to("cuda") 356 input = [torch.rand(4, 24, device="cuda")] 357 traced = torch.compile(module) 358 ref = module(*input) 359 res = traced(*input) 360 self.compare_pred(module, traced, input) 361 self.assertEqual( 362 counters["inductor"]["group_linear"], 363 1, 364 ) 365 self.assertEqual( 366 counters["inductor"]["batch_fusion"], 367 0, 368 ) 369 ref.sum().backward() 370 res.sum().backward() 371 self.compare_parameters(module, traced) 372 self.compare_gradients(module, traced) 373 self.assertEqual( 374 counters["inductor"]["group_linear"], 375 2, 376 ) 377 self.assertEqual( 378 counters["inductor"]["batch_aten_mul"], 379 1, 380 ) 381 counters.clear() 382 383 def test_batch_layer_norm_fusion(self): 384 for has_weight in [True, False]: 385 for has_bias in [True, False]: 386 counters.clear() 387 module = MyModule3("cuda", has_weight, has_bias).to("cuda") 388 input = [torch.randn(2, 5, 50, device="cuda")] 389 traced = torch.compile(module) 390 ref = module(*input) 391 res = traced(*input) 392 self.compare_pred(module, traced, input) 393 self.assertEqual(counters["inductor"]["batch_layernorm"], 2) 394 ref.sum().backward() 395 res.sum().backward() 396 self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) 397 self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) 398 counters.clear() 399 400 def test_batch_linear_lhs_fusion(self): 401 z = 10 402 for has_bias in [True, False]: 403 counters.clear() 404 module = MyModule4(z, "cuda", has_bias) 405 input = [torch.randn(20, z, device="cuda")] 406 traced = torch.compile(module) 407 ref = module(*input) 408 res = traced(*input) 409 self.compare_pred(module, traced, input) 410 self.assertEqual(counters["inductor"]["batch_linear_lhs"], 2) 411 ref.sum().backward() 412 res.sum().backward() 413 self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) 414 self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) 415 counters.clear() 416 417 def test_batch_linear_pre_grad_fusion(self): 418 for has_bias in [True, False]: 419 counters.clear() 420 module = MyModule5("cuda", has_bias) 421 input = [torch.randn(50, 500, device="cuda")] 422 traced = torch.compile(module) 423 ref = module(*input) 424 res = traced(*input) 425 self.compare_pred(module, traced, input) 426 self.assertEqual(counters["inductor"]["batch_linear"], 1) 427 ref.sum().backward() 428 res.sum().backward() 429 self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) 430 self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) 431 counters.clear() 432 433 def test_pointwise_op_fusion(self): 434 counters.clear() 435 module = TestPoitwiseOps("cuda") 436 input = [torch.randn(50, 1000, requires_grad=True, device="cuda")] 437 traced = torch.compile(module) 438 ref = module(*input) 439 res = traced(*input) 440 self.compare_pred(module, traced, input) 441 self.assertEqual(counters["inductor"]["batch_tanh"], 1) 442 self.assertEqual(counters["inductor"]["batch_relu"], 1) 443 self.assertEqual(counters["inductor"]["batch_sigmoid"], 1) 444 self.assertEqual(counters["inductor"]["batch_aten_add"], 1) 445 self.assertEqual(counters["inductor"]["batch_aten_mul"], 1) 446 self.assertEqual(counters["inductor"]["batch_aten_sub"], 1) 447 self.assertEqual(counters["inductor"]["batch_aten_div"], 1) 448 ref.sum().backward() 449 res.sum().backward() 450 self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) 451 self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) 452 counters.clear() 453 454 @requires_cuda 455 @torch._inductor.config.patch( 456 pre_grad_fusion_options={}, 457 post_grad_fusion_options={ 458 "batch_aten_relu": {}, 459 "batch_aten_sigmoid": {}, 460 "batch_aten_tanh": {}, 461 "unbind_stack_aten_pass": {}, 462 }, 463 ) 464 def test_pointwise_op_fusion_post_grad(self): 465 counters.clear() 466 module = TestPoitwiseOpsPostGrad("cuda") 467 input = [torch.randn(50, 1000, requires_grad=True, device="cuda")] 468 traced = torch.compile(module) 469 ref = module(*input) 470 res = traced(*input) 471 self.compare_pred(module, traced, input) 472 self.assertEqual(counters["inductor"]["batch_aten_tanh"], 1) 473 self.assertEqual(counters["inductor"]["batch_aten_relu"], 1) 474 self.assertEqual(counters["inductor"]["batch_aten_sigmoid"], 1) 475 self.assertEqual(counters["inductor"]["unbind_stack_aten_pass"], 2) 476 ref.sum().backward() 477 res.sum().backward() 478 self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) 479 self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) 480 counters.clear() 481 482 @requires_cuda 483 @torch._inductor.config.patch( 484 pre_grad_fusion_options={}, 485 post_grad_fusion_options={ 486 "batch_linear_post_grad": { 487 "shape_broadcast_batch_linear": True, 488 "fuse_nodes_with_same_users": True, 489 }, 490 "batch_aten_mul": {"fuse_nodes_with_same_parent": False}, 491 "batch_aten_sigmoid": {"fuse_nodes_with_same_parent": True}, 492 "batch_aten_add": {"fuse_nodes_with_same_parent": True}, 493 "normalization_aten_pass": {}, 494 "unbind_stack_aten_pass": {}, 495 }, 496 ) 497 def test_gate_fusion_post_grad(self): 498 counters.clear() 499 size = 20 500 module = TestHighwaySelfGating(d_model=10, size=size) 501 input = [ 502 [ 503 torch.randn(10, 10, requires_grad=True, device="cuda") 504 for i in range(size) 505 ] 506 ] 507 traced = torch.compile(module) 508 ref = module(*input) 509 res = traced(*input) 510 self.compare_pred(module, traced, input) 511 self.assertEqual(counters["inductor"]["batch_linear_post_grad"], 2) 512 self.assertEqual(counters["inductor"]["batch_aten_sigmoid"], 1) 513 self.assertEqual(counters["inductor"]["batch_aten_mul"], 1) 514 self.assertEqual(counters["inductor"]["batch_aten_add"], 2) 515 self.assertEqual(counters["inductor"]["normalization_aten_pass"], 1) 516 self.assertEqual(counters["inductor"]["unbind_stack_aten_pass"], 5) 517 ref.sum().backward() 518 res.sum().backward() 519 self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) 520 self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) 521 counters.clear() 522 523 524class TestBMMFusionModule(torch.nn.Module): 525 def __init__(self) -> None: 526 super().__init__() 527 self.my_modules = torch.nn.ModuleList() 528 for _ in range(10): 529 self.my_modules.append(torch.nn.Linear(10, 10)) 530 531 def forward(self, inputs): 532 output = None 533 for linear, input in zip(self.my_modules, inputs): 534 if output is None: 535 output = linear(input) 536 else: 537 output += linear(input) 538 return output 539 540 541@requires_cuda 542@torch._inductor.config.patch( 543 post_grad_fusion_options={"batch_linear_post_grad": {"require_fbgemm": False}} 544) 545class TestPostGradBatchLinearFusion(TestCase): 546 def test_batch_linear_post_grad_fusion(self): 547 pt1_module = TestBMMFusionModule().cuda() 548 inputs = [] 549 for _ in range(10): 550 inputs.append(torch.randn(10, 10).cuda()) 551 eager_output = pt1_module(inputs) 552 pt2_module = torch.compile(pt1_module) 553 pt2_output = pt2_module(inputs) 554 self.assertTrue(torch.allclose(eager_output, pt2_output)) 555 self.assertEqual( 556 counters["inductor"]["batch_linear_post_grad"], 557 2, 558 ) 559 self.assertIn("PostGradBatchLinearFusion", optimus_scuba_log) 560 561 562class TestFindIndependentSubsetGreedy(TestCase): 563 # Helper function to build a Graph from a data description. 564 def build_graph(self, desc): 565 # desc: { 566 # "n1": ["n2", "n3"], 567 # "n2": ["n3"], 568 # "n3": [], 569 # } 570 # 571 g = torch.fx.Graph() 572 lookup = {} 573 desc = collections.deque((k, v) for k, v in desc.items()) 574 unsatisfied = 0 575 while desc: 576 unsatisfied += 1 577 assert unsatisfied <= len(desc) # cycle or bad input? 578 name, v = desc.popleft() 579 args = tuple(lookup.get(n, None) for n in v) 580 if None in args: 581 desc.append((name, v)) 582 continue 583 node = g.create_node("placeholder", "target", name=name, args=args) 584 lookup[name] = node 585 unsatisfied = 0 586 return g, lookup 587 588 def verify(self, tree, subnodes, min_fuse, max_fuse, expected): 589 g, lookup = self.build_graph(tree) 590 subnodes = [lookup[n] for n in subnodes] 591 expected = [[lookup[n] for n in sub] for sub in expected] 592 opts = { 593 "min_fuse_set_size": min_fuse, 594 "max_fuse_set_size": max_fuse, 595 } 596 result = list( 597 torch._inductor.fx_passes.group_batch_fusion.find_independent_subset_greedy( 598 subnodes, opts 599 ) 600 ) 601 self.assertEqual(expected, result) 602 603 def test_find_independent_subset_greedy(self): 604 # First some randomly generated tests. 605 self.verify({"n0": (), "n1": ()}, ["n0"], 0, 100, [["n0"]]) 606 self.verify( 607 {"n0": (), "n1": (), "n2": ("n0",)}, ["n1", "n2"], 0, 100, [["n1", "n2"]] 608 ) 609 self.verify( 610 { 611 "n0": (), 612 "n1": (), 613 "n2": ("n0",), 614 "n3": (), 615 "n4": ("n0", "n1", "n2"), 616 "n5": ("n0", "n2", "n4"), 617 "n6": ("n3",), 618 "n7": ("n4", "n5", "n6", "n1", "n3"), 619 "n8": ("n7", "n1", "n3", "n5", "n0"), 620 "n9": ("n3", "n4", "n8", "n6", "n5", "n2", "n0", "n7"), 621 "n10": ("n0",), 622 "n11": ("n4", "n0", "n2", "n3", "n1", "n9"), 623 "n12": ("n2", "n3", "n10", "n6", "n9"), 624 }, 625 ["n10", "n5", "n3", "n4", "n9"], 626 0, 627 100, 628 [["n10", "n5", "n3"], ["n4"], ["n9"]], 629 ) 630 self.verify({"n0": (), "n1": (), "n2": ("n0",)}, ["n2"], 0, 100, [["n2"]]) 631 self.verify( 632 { 633 "n0": (), 634 "n1": (), 635 "n2": (), 636 "n3": (), 637 "n4": ("n3", "n1", "n0"), 638 "n5": ("n1", "n2", "n4", "n0"), 639 "n6": ("n0", "n3", "n2"), 640 "n7": ("n6", "n1", "n5", "n4", "n3", "n0"), 641 "n8": ("n2", "n7", "n3"), 642 "n9": ("n3", "n5", "n6", "n7", "n2", "n1"), 643 "n10": ("n8", "n0", "n2", "n4", "n6", "n3"), 644 "n11": ("n6", "n5", "n8", "n1", "n3", "n10", "n2"), 645 "n12": ("n7", "n4"), 646 }, 647 ["n7"], 648 0, 649 100, 650 [["n7"]], 651 ) 652 self.verify( 653 { 654 "n0": (), 655 "n1": (), 656 "n2": (), 657 "n3": ("n1", "n2"), 658 "n4": ("n1",), 659 "n5": (), 660 "n6": ("n5",), 661 "n7": ("n1", "n6", "n5", "n2", "n3", "n0"), 662 "n8": ("n5", "n7", "n2", "n6"), 663 "n9": ("n1",), 664 "n10": ("n9",), 665 "n11": ("n3", "n4", "n0", "n2"), 666 "n12": ("n8", "n9", "n5", "n1"), 667 "n13": ("n11", "n4", "n12", "n1", "n9", "n3", "n0"), 668 }, 669 ["n9", "n2", "n8", "n10", "n5", "n6", "n13", "n7", "n3", "n0", "n4"], 670 0, 671 100, 672 [ 673 ["n9", "n2", "n5", "n0", "n4"], 674 ["n8", "n10"], 675 ["n6", "n3"], 676 ["n13"], 677 ["n7"], 678 ], 679 ) 680 self.verify({"n0": ()}, ["n0"], 0, 100, [["n0"]]) 681 self.verify( 682 { 683 "n0": (), 684 "n1": (), 685 "n2": (), 686 "n3": (), 687 "n4": ("n1", "n2"), 688 "n5": ("n0", "n4", "n1"), 689 "n6": ("n1", "n5"), 690 "n7": (), 691 "n8": ("n7", "n1", "n3", "n5", "n6"), 692 "n9": ("n2", "n1", "n8", "n0", "n4", "n7", "n6", "n5"), 693 "n10": ("n4", "n7", "n2", "n3", "n8"), 694 "n11": (), 695 "n12": ("n9", "n7", "n5", "n11", "n8"), 696 "n13": ( 697 "n5", 698 "n6", 699 "n12", 700 "n3", 701 "n9", 702 "n8", 703 "n4", 704 "n11", 705 "n2", 706 "n10", 707 "n1", 708 ), 709 "n14": ("n7", "n3", "n12", "n10", "n2", "n0", "n4", "n5"), 710 "n15": ("n9", "n5", "n1", "n13", "n8", "n10", "n12", "n7", "n11", "n3"), 711 "n16": ( 712 "n2", 713 "n4", 714 "n15", 715 "n5", 716 "n0", 717 "n6", 718 "n3", 719 "n8", 720 "n14", 721 "n12", 722 "n9", 723 "n10", 724 "n7", 725 "n13", 726 ), 727 }, 728 ["n0", "n3", "n2", "n11", "n1", "n6", "n12", "n5", "n4", "n15", "n8"], 729 0, 730 100, 731 [ 732 ["n0", "n3", "n2", "n11", "n1"], 733 ["n6"], 734 ["n12"], 735 ["n5"], 736 ["n4"], 737 ["n15"], 738 ["n8"], 739 ], 740 ) 741 self.verify( 742 { 743 "n0": (), 744 "n1": (), 745 "n2": (), 746 "n3": ("n2", "n1"), 747 "n4": ("n2", "n3", "n1"), 748 "n5": ("n3", "n1"), 749 "n6": ("n1",), 750 "n7": ("n5", "n4"), 751 "n8": ("n6", "n2"), 752 }, 753 ["n4", "n3", "n1", "n8", "n5", "n6", "n2"], 754 0, 755 100, 756 [["n4", "n8", "n5"], ["n3", "n6"], ["n1", "n2"]], 757 ) 758 self.verify( 759 { 760 "n0": (), 761 "n1": (), 762 "n2": (), 763 "n3": ("n1", "n0"), 764 "n4": ("n0",), 765 "n5": ("n1", "n4"), 766 "n6": ("n2", "n1", "n4"), 767 "n7": ("n0", "n3"), 768 "n8": ("n5", "n0", "n6", "n1", "n4", "n2", "n3"), 769 "n9": ("n1", "n4", "n8", "n7", "n5"), 770 "n10": ("n9", "n8", "n0", "n2", "n7", "n1", "n3", "n5"), 771 "n11": ("n9", "n2", "n6", "n0", "n3"), 772 "n12": ("n1", "n4", "n7", "n10", "n5", "n2", "n11", "n6"), 773 "n13": ("n9", "n2", "n3", "n0", "n7", "n5", "n10", "n11"), 774 "n14": ( 775 "n8", 776 "n0", 777 "n3", 778 "n6", 779 "n10", 780 "n1", 781 "n5", 782 "n9", 783 "n12", 784 "n11", 785 "n4", 786 ), 787 "n15": ( 788 "n3", 789 "n10", 790 "n0", 791 "n4", 792 "n9", 793 "n11", 794 "n2", 795 "n13", 796 "n12", 797 "n8", 798 "n5", 799 "n14", 800 ), 801 "n16": ("n6",), 802 "n17": ( 803 "n4", 804 "n3", 805 "n14", 806 "n8", 807 "n15", 808 "n16", 809 "n2", 810 "n5", 811 "n7", 812 "n12", 813 "n1", 814 "n0", 815 "n11", 816 ), 817 }, 818 ["n17", "n16", "n10", "n4", "n8", "n12", "n6", "n1"], 819 0, 820 100, 821 [["n17"], ["n16", "n10"], ["n4", "n1"], ["n8"], ["n12"], ["n6"]], 822 ) 823 self.verify( 824 { 825 "n0": (), 826 "n1": (), 827 "n2": ("n0",), 828 "n3": ("n0", "n1"), 829 "n4": ("n0",), 830 "n5": ("n0",), 831 "n6": ("n5", "n3", "n0", "n2"), 832 "n7": (), 833 "n8": ("n2", "n5", "n3", "n1", "n7", "n6", "n0"), 834 "n9": ("n4",), 835 "n10": ("n4", "n5", "n1", "n2", "n0", "n6", "n8", "n9", "n7"), 836 "n11": ("n3", "n0", "n9", "n10", "n5", "n1", "n2", "n7", "n4", "n6"), 837 "n12": ("n9", "n5"), 838 }, 839 ["n8", "n3", "n1", "n12", "n2", "n5", "n11", "n4", "n10", "n6", "n0"], 840 0, 841 100, 842 [ 843 ["n8", "n12"], 844 ["n3", "n2", "n5", "n4"], 845 ["n1", "n0"], 846 ["n11"], 847 ["n10"], 848 ["n6"], 849 ], 850 ) 851 self.verify( 852 { 853 "n0": (), 854 "n1": (), 855 "n2": (), 856 "n3": (), 857 "n4": ("n2", "n3"), 858 "n5": ("n1", "n3", "n2", "n4"), 859 "n6": ("n5", "n4", "n1", "n3"), 860 "n7": ("n5",), 861 "n8": ("n5", "n4", "n1"), 862 "n9": ("n2", "n3", "n1", "n5", "n7", "n0", "n8"), 863 "n10": ("n5", "n3", "n1", "n7", "n8", "n9"), 864 "n11": ("n1", "n4", "n2", "n0", "n8", "n9"), 865 "n12": ("n4", "n3", "n9"), 866 "n13": ( 867 "n6", 868 "n10", 869 "n4", 870 "n8", 871 "n0", 872 "n11", 873 "n12", 874 "n7", 875 "n3", 876 "n2", 877 "n1", 878 ), 879 "n14": ("n4", "n13", "n2"), 880 "n15": ("n11", "n7", "n6", "n10", "n14"), 881 "n16": ("n15", "n3"), 882 "n17": ("n10", "n2", "n7", "n0", "n5", "n6", "n9"), 883 "n18": ( 884 "n16", 885 "n8", 886 "n6", 887 "n9", 888 "n11", 889 "n12", 890 "n14", 891 "n5", 892 "n13", 893 "n4", 894 "n1", 895 ), 896 }, 897 [ 898 "n1", 899 "n0", 900 "n16", 901 "n6", 902 "n15", 903 "n9", 904 "n7", 905 "n4", 906 "n3", 907 "n11", 908 "n13", 909 "n17", 910 "n12", 911 "n18", 912 ], 913 0, 914 100, 915 [ 916 ["n1", "n0", "n4"], 917 ["n16", "n17"], 918 ["n6", "n9"], 919 ["n15"], 920 ["n7"], 921 ["n3"], 922 ["n11", "n12"], 923 ["n13"], 924 ["n18"], 925 ], 926 ) 927 self.verify( 928 { 929 "n0": (), 930 "n1": (), 931 "n2": (), 932 "n3": ("n2",), 933 "n4": ("n1",), 934 "n5": (), 935 "n6": ("n1", "n4"), 936 "n7": ("n5", "n1"), 937 "n8": ("n6",), 938 "n9": ("n6", "n1", "n2", "n0"), 939 "n10": ("n0", "n7"), 940 "n11": ("n0", "n4", "n3", "n5"), 941 "n12": ("n9", "n8", "n7", "n4", "n0"), 942 }, 943 ["n8", "n9", "n11", "n2", "n4", "n0", "n7", "n5", "n1"], 944 0, 945 100, 946 [["n8", "n9", "n11", "n7"], ["n2", "n4", "n0", "n5"], ["n1"]], 947 ) 948 self.verify( 949 {"n0": (), "n1": (), "n2": (), "n3": ("n0",), "n4": ("n3",)}, 950 ["n1", "n2", "n4"], 951 0, 952 100, 953 [["n1", "n2", "n4"]], 954 ) 955 self.verify( 956 { 957 "n0": (), 958 "n1": (), 959 "n2": ("n1",), 960 "n3": ("n2", "n1"), 961 "n4": ("n3",), 962 "n5": (), 963 "n6": ("n1", "n5"), 964 "n7": (), 965 "n8": ("n4", "n5"), 966 "n9": ("n0", "n3", "n6", "n4", "n5", "n8", "n7", "n1"), 967 "n10": ("n3", "n0", "n6", "n9", "n7"), 968 "n11": (), 969 "n12": ("n1", "n8", "n3", "n6", "n7", "n0", "n10", "n5", "n9", "n11"), 970 "n13": ("n9", "n11", "n4"), 971 "n14": (), 972 "n15": ("n6", "n12"), 973 "n16": ( 974 "n1", 975 "n7", 976 "n10", 977 "n3", 978 "n9", 979 "n0", 980 "n2", 981 "n5", 982 "n8", 983 "n13", 984 "n14", 985 "n15", 986 "n4", 987 "n6", 988 ), 989 }, 990 [ 991 "n11", 992 "n16", 993 "n5", 994 "n12", 995 "n7", 996 "n2", 997 "n0", 998 "n6", 999 "n3", 1000 "n9", 1001 "n8", 1002 "n15", 1003 "n14", 1004 "n4", 1005 "n13", 1006 "n1", 1007 ], 1008 0, 1009 100, 1010 [ 1011 ["n11", "n5", "n7", "n2", "n0", "n14"], 1012 ["n16"], 1013 ["n12", "n13"], 1014 ["n6", "n3"], 1015 ["n9"], 1016 ["n8"], 1017 ["n15"], 1018 ["n4"], 1019 ["n1"], 1020 ], 1021 ) 1022 self.verify({"n0": (), "n1": ()}, ["n1"], 0, 100, [["n1"]]) 1023 self.verify( 1024 { 1025 "n0": (), 1026 "n1": (), 1027 "n2": ("n1",), 1028 "n3": (), 1029 "n4": ("n0", "n2", "n3"), 1030 "n5": ("n2", "n3"), 1031 "n6": ("n3",), 1032 }, 1033 ["n6", "n2", "n3", "n1"], 1034 0, 1035 100, 1036 [["n6", "n2"], ["n3", "n1"]], 1037 ) 1038 self.verify( 1039 { 1040 "n0": (), 1041 "n1": (), 1042 "n2": (), 1043 "n3": ("n2",), 1044 "n4": ("n0",), 1045 "n5": ("n1", "n2"), 1046 "n6": ("n2", "n3", "n1", "n0", "n5"), 1047 "n7": ("n6", "n2", "n0", "n4", "n5", "n1"), 1048 "n8": ("n4",), 1049 "n9": ("n4", "n6", "n7", "n1", "n2"), 1050 }, 1051 ["n8", "n6", "n2", "n4", "n7", "n5", "n3", "n9"], 1052 0, 1053 100, 1054 [["n8", "n6"], ["n2", "n4"], ["n7"], ["n5", "n3"], ["n9"]], 1055 ) 1056 self.verify( 1057 { 1058 "n0": (), 1059 "n1": (), 1060 "n2": (), 1061 "n3": ("n1", "n2"), 1062 "n4": ("n0",), 1063 "n5": ("n2", "n3", "n0", "n1"), 1064 "n6": ("n4", "n1"), 1065 "n7": ("n5",), 1066 "n8": ("n7", "n1", "n5", "n6", "n3", "n4", "n0"), 1067 "n9": ("n2", "n8"), 1068 }, 1069 ["n1", "n7", "n4", "n2", "n0", "n8", "n3", "n5"], 1070 0, 1071 100, 1072 [["n1", "n4", "n2"], ["n7"], ["n0", "n3"], ["n8"], ["n5"]], 1073 ) 1074 self.verify( 1075 { 1076 "n0": (), 1077 "n1": (), 1078 "n2": ("n0",), 1079 "n3": ("n1",), 1080 "n4": ("n2", "n1"), 1081 "n5": (), 1082 "n6": ("n0",), 1083 "n7": ("n6", "n3", "n2", "n1", "n0"), 1084 "n8": ("n0", "n2"), 1085 "n9": ("n6", "n5", "n8", "n4", "n0"), 1086 "n10": ("n1", "n7", "n5", "n8", "n6", "n2", "n4", "n9"), 1087 }, 1088 ["n0"], 1089 0, 1090 100, 1091 [["n0"]], 1092 ) 1093 1094 # trivial test of min_fuse 1095 self.verify( 1096 { 1097 "n0": (), 1098 "n1": (), 1099 "n2": (), 1100 "n3": ("n1", "n2"), 1101 "n4": ("n1",), 1102 "n5": (), 1103 "n6": ("n5",), 1104 "n7": ("n1", "n6", "n5", "n2", "n3", "n0"), 1105 "n8": ("n5", "n7", "n2", "n6"), 1106 "n9": ("n1",), 1107 "n10": ("n9",), 1108 "n11": ("n3", "n4", "n0", "n2"), 1109 "n12": ("n8", "n9", "n5", "n1"), 1110 "n13": ("n11", "n4", "n12", "n1", "n9", "n3", "n0"), 1111 }, 1112 ["n9", "n2", "n8", "n10", "n5", "n6", "n13", "n7", "n3", "n0", "n4"], 1113 2, 1114 10, 1115 [["n9", "n2", "n5", "n0", "n4"], ["n8", "n10"], ["n6", "n3"]], 1116 ) 1117 1118 # trivial test of max_fuse 1119 self.verify( 1120 { 1121 "n0": (), 1122 "n1": (), 1123 "n2": (), 1124 "n3": ("n1", "n2"), 1125 "n4": ("n1",), 1126 "n5": (), 1127 "n6": ("n5",), 1128 "n7": ("n1", "n6", "n5", "n2", "n3", "n0"), 1129 "n8": ("n5", "n7", "n2", "n6"), 1130 "n9": ("n1",), 1131 "n10": ("n9",), 1132 "n11": ("n3", "n4", "n0", "n2"), 1133 "n12": ("n8", "n9", "n5", "n1"), 1134 "n13": ("n11", "n4", "n12", "n1", "n9", "n3", "n0"), 1135 }, 1136 ["n9", "n2", "n8", "n10", "n5", "n6", "n13", "n7", "n3", "n0", "n4"], 1137 0, 1138 3, 1139 [ 1140 ["n9", "n2", "n5"], 1141 ["n8", "n10", "n4"], 1142 ["n6", "n3", "n0"], 1143 ["n13"], 1144 ["n7"], 1145 ], 1146 ) 1147 1148 def test_find_independent_subset_greedy_fuse(self): 1149 # ensure that fusing the sets during iteration results in the correct 1150 # iteration results. In the example graph after we merge n2 and n3, 1151 # n4 is no longer independent from n1. 1152 g, lookup = self.build_graph( 1153 { 1154 "n0": (), 1155 "n1": (), 1156 "n2": ("n0",), 1157 "n3": ("n1",), 1158 "n4": ("n2",), 1159 "n5": (), 1160 } 1161 ) 1162 opts = { 1163 "min_fuse_set_size": 0, 1164 "max_fuse_set_size": 100, 1165 } 1166 subnodes = ["n2", "n3", "n4", "n0", "n1", "n5"] 1167 subnodes = [lookup[n] for n in subnodes] 1168 i = torch._inductor.fx_passes.group_batch_fusion.find_independent_subset_greedy( 1169 subnodes, opts 1170 ) 1171 self.assertEqual(next(i), [lookup[n] for n in ["n2", "n3", "n5"]]) 1172 1173 # fuse n2 and n3 which makes n4 now dependant on n1. 1174 args = tuple(lookup[n] for n in ["n0", "n1"]) 1175 fused = g.create_node("placeholder", "target", name="n2+n3", args=args) 1176 lookup["n2"].replace_all_uses_with(fused) 1177 g.erase_node(lookup["n2"]) 1178 lookup["n3"].replace_all_uses_with(fused) 1179 g.erase_node(lookup["n3"]) 1180 1181 self.assertEqual(next(i), [lookup[n] for n in ["n4"]]) 1182 self.assertEqual(next(i), [lookup[n] for n in ["n0", "n1"]]) 1183 self.assertRaises(StopIteration, lambda: next(i)) 1184 1185 1186if __name__ == "__main__": 1187 run_tests() 1188