1# Owner(s): ["module: inductor"] 2 3 4import torch 5from torch._dynamo.utils import counters, optimus_scuba_log 6from torch._inductor.fx_passes.misc_patterns import numpy_compat_normalization 7from torch._inductor.test_case import run_tests, TestCase 8from torch.testing._internal.common_utils import IS_LINUX 9from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU 10from torch.testing._internal.triton_utils import requires_gpu 11 12 13def patch(f): 14 f = torch._inductor.config.patch( 15 pre_grad_fusion_options={ 16 "normalization_pass": {}, 17 "remove_split_with_size_one_pass": {}, 18 "merge_getitem_cat_pass": {}, 19 "merge_splits_pass": {}, 20 "mutate_cat_pass": {}, 21 "split_cat_pass": {}, 22 "unbind_stack_pass": {}, 23 }, 24 post_grad_fusion_options={}, 25 )(f) 26 return f 27 28 29class TestSplitCatFxPasses(TestCase): 30 @patch 31 def test_split_normalization(self): 32 def arg_only(x): 33 return [torch.relu(s) for s in torch.split(x, 2, 1)] 34 35 def arg_only_dim0(x): 36 return [torch.relu(s) for s in torch.split(x, 2, 0)] 37 38 def kwarg1(x): 39 return [torch.relu(s) for s in torch.split(x, 2, dim=1)] 40 41 def kwarg2(x): 42 return [ 43 torch.relu(s) for s in torch.split(x, split_size_or_sections=2, dim=1) 44 ] 45 46 def kwarg3(x): 47 return [ 48 torch.relu(s) 49 for s in torch.split(tensor=x, split_size_or_sections=2, dim=-1) 50 ] 51 52 def list_replace(x): 53 return [torch.relu(s) for s in torch.split(x, [16, 16], dim=1)] 54 55 def multi_split(x): 56 return [torch.split(s, 2, 1) for s in torch.split(x, 2, 1)] 57 58 def unequal_split(x): 59 return [torch.relu(s) for s in torch.split(x, 3, 1)] 60 61 def arg_only_cm(x): 62 return [torch.relu(s) for s in x.split(2, 1)] 63 64 def kwarg1_cm(x): 65 return [torch.relu(s) for s in x.split(2, dim=1)] 66 67 def kwarg2_cm(x): 68 return [torch.relu(s) for s in x.split(split_size=2, dim=1)] 69 70 def multi_split_cm(x): 71 return [s.split(2, 1) for s in x.split(2, 1)] 72 73 def unequal_split_cm(x): 74 return 
[torch.relu(s) for s in x.split(3, 1)] 75 76 def cm_with_list(x): 77 return [torch.relu(s) for s in x.split([16, 16], dim=-1)] 78 79 args = [ 80 torch.randn(2, 32), 81 ] 82 for fn, expected_split_norm_count in [ 83 (arg_only, 1), 84 (arg_only_dim0, 1), 85 (kwarg1, 1), 86 (kwarg2, 1), 87 (kwarg3, 1), 88 (list_replace, 0), 89 (multi_split, 17), 90 (unequal_split, 1), 91 (arg_only_cm, 1), 92 (kwarg1_cm, 1), 93 (kwarg2_cm, 1), 94 (multi_split_cm, 17), 95 (unequal_split_cm, 1), 96 (cm_with_list, 1), 97 ]: 98 expected = fn(*args) 99 actual = torch.compile(fn)(*args) 100 101 torch.testing.assert_close(actual, expected) 102 self.assertEqual( 103 counters["inductor"]["normalization_pass"], 104 expected_split_norm_count, 105 msg=f"for {fn}", 106 ) 107 if expected_split_norm_count > 0: 108 self.assertIn("normalization_pass_pre_grad", optimus_scuba_log) 109 counters.clear() 110 111 @patch 112 def test_consecutive_split_merge(self): 113 def multi_split(x): 114 return [torch.split(s, 2, 1) for s in torch.split(x, 2, 1)] 115 116 def multi_split_2(x): 117 return [torch.split(s, 1, 1) for s in torch.split(x, 2, 1)] 118 119 def multi_split_2_neg_dim(x): 120 return [torch.split(s, 1, 1) for s in torch.split(x, 2, -1)] 121 122 def multi_split_with_sizes(x): 123 return [torch.split(s, 2, 1) for s in torch.split(x, [16, 16], 1)] 124 125 def multi_split_kwarg1(x): 126 return [torch.split(s, 2, dim=1) for s in torch.split(x, 2, dim=1)] 127 128 def multi_split_kwarg2(x): 129 return [ 130 torch.split(s, split_size_or_sections=2, dim=1) 131 for s in torch.split(x, split_size_or_sections=2, dim=1) 132 ] 133 134 def unequal_multi_split(x): 135 fs = torch.split(x, [10, 10, 12], dim=1) 136 item0 = fs[0] 137 item1 = fs[1] 138 item2 = fs[2] 139 140 final_items = [] 141 final_items.extend(item0.split([4, 6], 1)) 142 final_items.extend(item1.split([6, 4], 1)) 143 final_items.extend(item2.split([4, 4, 4], 1)) 144 145 return [torch.relu(s) for s in final_items] 146 147 def 
unequal_multi_split_neg_index(x): 148 fs = torch.split(x, [10, 10, 12], dim=1) 149 item0 = fs[-3] 150 item1 = fs[-2] 151 item2 = fs[-1] 152 153 final_items = [] 154 final_items.extend(item0.split([4, 6], 1)) 155 final_items.extend(item1.split([6, 4], 1)) 156 final_items.extend(item2.split([4, 4, 4], 1)) 157 158 return [torch.relu(s) for s in final_items] 159 160 # Shouldn't merge 161 def diff_dims(x): 162 return [torch.split(s, 2, dim=0) for s in torch.split(x, 2, dim=1)] 163 164 def some_users_not_splits(x): 165 fs = torch.split(x, [10, 10, 12], dim=1) 166 item0 = fs[0] 167 item1 = fs[1] 168 item2 = fs[2] 169 170 final_items = [] 171 final_items.extend(item0.split([4, 6], 1)) 172 final_items.extend(item1.split([6, 4], 1)) 173 final_items.append(torch.sin(item2)) 174 175 return [torch.relu(s) for s in final_items] 176 177 def split_with_cat(x): 178 fs = torch.split(x, [4, 4, 24], dim=1) 179 item0 = fs[0] 180 item1 = fs[1] 181 item2 = fs[2] 182 183 final_items = [item0, item1] 184 final_items.extend(item2.split((4, 4, 4, 4, 4, 4), 1)) 185 186 return torch.cat(final_items, dim=1) 187 188 def duplicate_getitems(x): 189 fs = torch.split(x, [10, 10, 12], dim=1) 190 item0 = fs[0] 191 item1_1 = fs[1] 192 item1_2 = fs[1] 193 item2 = fs[2] 194 195 final_items = [] 196 final_items.extend(item0.split([4, 6], 1)) 197 final_items.extend(item1_1.split([6, 4], 1)) 198 final_items.extend(item1_2) 199 final_items.append(torch.sin(item2)) 200 201 return [torch.relu(s) for s in final_items] 202 203 def duplicate_getitems_neg_index(x): 204 fs = torch.split(x, [10, 10, 12], dim=1) 205 item0 = fs[0] 206 item1_1 = fs[1] 207 item1_2 = fs[-2] # negative index 208 item2 = fs[2] 209 210 final_items = [] 211 final_items.extend(item0.split([4, 6], 1)) 212 final_items.extend(item1_1.split([6, 4], 1)) 213 final_items.extend(item1_2) 214 final_items.append(torch.sin(item2)) 215 216 return [torch.relu(s) for s in final_items] 217 218 def split_getitem_gap(x): 219 fs = torch.split(x, [4, 4, 24], 
dim=1) 220 item0 = fs[0] 221 item2 = fs[2] 222 223 final_items = [ 224 item0, 225 ] 226 final_items.extend(item2.split((4, 4, 4, 4, 4, 4), 1)) 227 228 return torch.cat(final_items, dim=1) 229 230 def split_getitem_out_of_order(x): 231 fs = torch.split(x, [4, 4, 4, 20], dim=1) 232 item0 = fs[0] 233 item2 = fs[2] 234 item1 = fs[1] 235 item3 = fs[3] 236 237 final_items = [item0, item2, item1] 238 final_items.extend(item3.split((4, 4, 4, 4, 4), 1)) 239 240 return torch.cat(final_items, dim=1) 241 242 def split_partial_getitem_cat(x): 243 fs = torch.split(x, [4, 4, 24], dim=1) 244 item0 = fs[0] 245 item2 = fs[2] 246 247 final_items = [ 248 item0, 249 ] 250 final_items.extend(item2.split((4, 4, 4, 4, 4, 4), 1)) 251 252 return torch.cat(final_items, dim=1) 253 254 args = [ 255 torch.randn(2, 32), 256 ] 257 for fn, expected_split_merged in [ 258 (multi_split, 0), 259 (multi_split_2, 16), 260 (multi_split_2_neg_dim, 16), 261 (multi_split_with_sizes, 2), 262 (multi_split_kwarg1, 0), 263 (multi_split_kwarg2, 0), 264 (unequal_multi_split, 3), 265 (unequal_multi_split_neg_index, 3), 266 (diff_dims, 0), 267 (some_users_not_splits, 2), 268 (split_with_cat, 1), 269 (duplicate_getitems, 1), 270 (duplicate_getitems_neg_index, 1), 271 (split_getitem_gap, 1), 272 (split_getitem_out_of_order, 1), 273 (split_partial_getitem_cat, 1), 274 ]: 275 expected = fn(*args) 276 actual = torch.compile(fn)(*args) 277 278 torch.testing.assert_close(actual, expected) 279 self.assertEqual( 280 counters["inductor"]["merge_splits_pass"], 281 expected_split_merged, 282 ) 283 if expected_split_merged > 0: 284 self.assertIn("merge_splits_pass_pre_grad", optimus_scuba_log) 285 counters.clear() 286 287 @patch 288 def test_split_cat_merge(self): 289 def simple_split_cat(x): 290 return torch.cat(torch.split(x, 4, dim=1), dim=1) 291 292 def simple_split_cat_argspec1(x): 293 return torch.cat(torch.split(x, 4, dim=1), 1) 294 295 def simple_split_cat_argspec2(x): 296 return torch.cat(tensors=torch.split(x, 4, 
dim=1), dim=1) 297 298 def simple_split_cat_argspec3(x): 299 return torch.cat(torch.split(x, 4, dim=1), -2) 300 301 def simple_split_cat_argspec4(x): 302 return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2) 303 304 def simple_split_stack(x): 305 return torch.stack(torch.split(x, 4, dim=1), dim=1) 306 307 def simple_split_stack_argspec1(x): 308 return torch.stack(torch.split(x, 4, dim=1), 1) 309 310 def simple_split_stack_argspec2(x): 311 return torch.stack(tensors=torch.split(x, 4, dim=1), dim=1) 312 313 def split_cat_addn_args(x): 314 split_output = list(torch.split(x, 4, dim=1)) 315 return torch.cat( 316 [torch.ones(2, 5, 32, 16)] + split_output + [torch.ones(2, 6, 32, 16)], 317 dim=1, 318 ) 319 320 def split_stack_addn_args(x): 321 split_output = list(torch.split(x, 4, dim=1)) 322 return torch.stack( 323 [torch.ones(2, 4, 32, 16)] 324 + split_output 325 + [torch.ones(2, 4, 32, 16), torch.ones(2, 4, 32, 16)], 326 dim=1, 327 ) 328 329 def split_cat_addn_args_dim2(x): 330 split_output = list(torch.split(x, 4, dim=2)) 331 return torch.cat( 332 [torch.ones(2, 32, 5, 16)] + split_output + [torch.ones(2, 32, 6, 16)], 333 dim=2, 334 ) 335 336 # split_dim=1, cat_dim=2 337 def split_cat_dim_mismatch(x): 338 split_output = list(torch.split(x, 4, dim=1)) 339 return torch.cat( 340 [torch.ones(2, 4, 32, 16)] + split_output + [torch.ones(2, 4, 32, 16)], 341 dim=2, 342 ) 343 344 def split_stack_dim_mismatch(x): 345 split_output = list(torch.split(x, 4, dim=1)) 346 return torch.stack( 347 [torch.ones(2, 4, 32, 16)] + split_output + [torch.ones(2, 4, 32, 16)], 348 dim=2, 349 ) 350 351 # split_dim=1, cat_dim=3 352 def split_cat_dim_mismatch2(x): 353 split_output = list(torch.split(x, 4, dim=1)) 354 return torch.cat( 355 [torch.ones(2, 4, 32, 16)] + split_output + [torch.ones(2, 4, 32, 16)], 356 dim=3, 357 ) 358 359 def split_stack_dim_mismatch2(x): 360 split_output = list(torch.split(x, 4, dim=1)) 361 return torch.stack( 362 [torch.ones(2, 4, 32, 16)] + split_output + 
[torch.ones(2, 4, 32, 16)], 363 dim=3, 364 ) 365 366 # split_dim=2, cat_dim=0 367 def split_cat_dim_mismatch3(x): 368 split_output = list(torch.split(x, 4, dim=2)) 369 return torch.cat( 370 [torch.ones(2, 32, 4, 16)] + split_output + [torch.ones(2, 32, 4, 16)], 371 dim=0, 372 ) 373 374 def split_stack_dim_mismatch3(x): 375 split_output = list(torch.split(x, 4, dim=2)) 376 return torch.stack( 377 [torch.ones(2, 32, 4, 16)] + split_output + [torch.ones(2, 32, 4, 16)], 378 dim=0, 379 ) 380 381 def input_shuffling(x): 382 split_output = list(torch.split(x, 4, dim=1)) 383 return torch.cat( 384 [torch.ones(2, 4, 32, 16)] 385 + [split_output[1], split_output[2], split_output[3]] 386 + [torch.ones(2, 4, 32, 16)] 387 + [split_output[5], split_output[6], split_output[7]] 388 + [torch.ones(2, 4, 32, 16)], 389 dim=1, 390 ) 391 392 def input_shuffling_stack(x): 393 split_output = list(torch.split(x, 4, dim=1)) 394 return torch.stack( 395 [torch.ones(2, 4, 32, 16)] 396 + [split_output[1], split_output[2], split_output[3]] 397 + [torch.ones(2, 4, 32, 16)] 398 + [split_output[5], split_output[6], split_output[7]] 399 + [torch.ones(2, 4, 32, 16)], 400 dim=1, 401 ) 402 403 def input_shuffling_dim_mismatch(x): 404 split_output = list(torch.split(x, 4, dim=1)) 405 return torch.cat( 406 [torch.ones(2, 4, 32, 16)] 407 + [split_output[1], split_output[2], split_output[3]] 408 + [torch.ones(2, 4, 32, 16)] 409 + [split_output[5], split_output[6], split_output[7]] 410 + [torch.ones(2, 4, 32, 16)], 411 dim=2, 412 ) 413 414 def input_shuffling_dim_mismatch_stack(x): 415 split_output = list(torch.split(x, 4, dim=1)) 416 return torch.stack( 417 [torch.ones(2, 4, 32, 16)] 418 + [split_output[1], split_output[2], split_output[3]] 419 + [torch.ones(2, 4, 32, 16)] 420 + [split_output[5], split_output[6], split_output[7]] 421 + [torch.ones(2, 4, 32, 16)], 422 dim=2, 423 ) 424 425 def input_shuffling_multiple_output(x): 426 split_output = list(torch.split(x, 4, dim=1)) 427 cat1 = torch.cat( 428 
[torch.ones(2, 4, 32, 16)] 429 + [split_output[1], split_output[2], split_output[3]] 430 + [torch.ones(2, 4, 32, 16)], 431 dim=2, 432 ) 433 stack1 = torch.stack( 434 [ 435 torch.ones(2, 4, 32, 16), 436 split_output[4], 437 split_output[5], 438 torch.ones(2, 4, 32, 16), 439 ], 440 dim=1, 441 ) 442 443 relu1 = torch.relu(split_output[6]) 444 445 return cat1, stack1, relu1 446 447 def input_shuffling_direct_output(x): 448 split_output = list(torch.split(x, 4, dim=1)) 449 cat1 = torch.cat( 450 [torch.ones(2, 4, 32, 16)] 451 + [split_output[1], split_output[2], split_output[3]] 452 + [torch.ones(2, 4, 32, 16)], 453 dim=2, 454 ) 455 stack1 = torch.stack( 456 [ 457 torch.ones(2, 4, 32, 16), 458 split_output[4], 459 split_output[5], 460 torch.ones(2, 4, 32, 16), 461 ], 462 dim=1, 463 ) 464 465 return cat1, stack1, split_output[6] 466 467 def input_shuffling_multiple_output_same_ranges(x): 468 split_output = list(torch.split(x, 4, dim=1)) 469 cat1 = torch.cat( 470 [torch.ones(2, 4, 32, 16)] 471 + [split_output[1], split_output[2], split_output[3]] 472 + [torch.ones(2, 4, 32, 16)], 473 dim=2, 474 ) 475 476 cat2 = torch.cat( 477 [torch.ones(2, 4, 32, 16)] 478 + [split_output[1], split_output[2], split_output[3]] 479 + [torch.ones(2, 4, 32, 16)], 480 dim=2, 481 ) 482 stack1 = torch.stack( 483 [ 484 torch.ones(2, 4, 32, 16), 485 split_output[4], 486 split_output[5], 487 torch.ones(2, 4, 32, 16), 488 ], 489 dim=1, 490 ) 491 492 relu1 = torch.relu(split_output[6]) 493 494 return cat1, cat2, stack1, relu1 495 496 def unequal_split_multiple_output(x): 497 split_output = list(torch.split(x, [2, 4, 4, 4, 4, 4, 8, 2], dim=1)) 498 cat1 = torch.cat( 499 [torch.ones(2, 4, 32, 16)] 500 + [split_output[1], split_output[2], split_output[3]] 501 + [torch.ones(2, 4, 32, 16)], 502 dim=2, 503 ) 504 stack1 = torch.stack( 505 [ 506 torch.ones(2, 4, 32, 16), 507 split_output[4], 508 split_output[5], 509 torch.ones(2, 4, 32, 16), 510 ], 511 dim=1, 512 ) 513 514 relu1 = torch.relu(split_output[6]) 
515 516 return cat1, stack1, relu1 517 518 def multi_split_cat(x1, x2): 519 split_output_1 = list(torch.split(x1, 4, dim=1)) 520 split_output_2 = list(torch.split(x2, 4, dim=1)) 521 cat1 = torch.cat( 522 [torch.ones(2, 4, 32, 16)] 523 + [split_output_1[1], split_output_1[2], split_output_1[3]] 524 + [torch.ones(2, 4, 32, 16)] 525 + [split_output_2[1], split_output_2[2], split_output_2[3]] 526 + [torch.ones(2, 4, 32, 16)], 527 dim=2, 528 ) 529 stack1 = torch.stack( 530 [ 531 torch.ones(2, 4, 32, 16), 532 split_output_1[4], 533 split_output_1[5], 534 torch.ones(2, 4, 32, 16), 535 split_output_2[4], 536 split_output_2[5], 537 torch.ones(2, 4, 32, 16), 538 ], 539 dim=1, 540 ) 541 542 relu1 = torch.relu(split_output_1[6]) 543 relu2 = torch.relu(split_output_2[6]) 544 545 return cat1, stack1, relu1, relu2 546 547 # TODO: Add more tests: 548 # * Cases where replacement shouldn't happen 549 default_args = [ 550 torch.randn(2, 32, 32, 16), 551 ] 552 multi_args = [ 553 torch.randn(2, 32, 32, 16), 554 torch.randn(2, 32, 32, 16), 555 ] 556 for ( 557 fn, 558 expected_split_added, 559 expected_split_removed, 560 expected_cat_added, 561 expected_cat_removed, 562 expected_sections_removed, 563 args, 564 ) in [ 565 (simple_split_cat, 0, 0, 0, 0, 0, default_args), 566 (simple_split_cat_argspec1, 0, 0, 0, 0, 0, default_args), 567 (simple_split_cat_argspec2, 0, 0, 0, 0, 0, default_args), 568 (simple_split_cat_argspec3, 0, 1, 0, 1, 7, default_args), 569 (simple_split_cat_argspec4, 0, 1, 0, 1, 7, default_args), 570 (simple_split_stack, 0, 1, 0, 1, 7, default_args), 571 (simple_split_stack_argspec1, 0, 1, 0, 1, 7, default_args), 572 (simple_split_stack_argspec2, 0, 1, 0, 1, 7, default_args), 573 (split_cat_addn_args, 0, 1, 1, 1, 7, default_args), 574 (split_stack_addn_args, 0, 1, 1, 1, 7, default_args), 575 (split_cat_addn_args_dim2, 0, 1, 1, 1, 7, default_args), 576 (split_cat_dim_mismatch, 0, 1, 1, 1, 7, default_args), 577 (split_stack_dim_mismatch, 0, 1, 1, 1, 7, default_args), 578 
(split_cat_dim_mismatch2, 0, 1, 1, 1, 7, default_args), 579 (split_stack_dim_mismatch2, 0, 1, 1, 1, 7, default_args), 580 (split_cat_dim_mismatch3, 0, 1, 1, 1, 7, default_args), 581 (split_stack_dim_mismatch3, 0, 1, 1, 1, 7, default_args), 582 (input_shuffling, 1, 1, 1, 1, 4, default_args), 583 (input_shuffling_stack, 1, 1, 1, 1, 4, default_args), 584 (input_shuffling_dim_mismatch, 1, 1, 1, 1, 4, default_args), 585 (input_shuffling_dim_mismatch_stack, 1, 1, 1, 1, 4, default_args), 586 (input_shuffling_multiple_output, 1, 1, 2, 2, 3, default_args), 587 (input_shuffling_direct_output, 1, 1, 2, 2, 3, default_args), 588 (unequal_split_multiple_output, 1, 1, 2, 2, 3, default_args), 589 (multi_split_cat, 1, 1, 2, 2, 3, multi_args), 590 ]: 591 expected = fn(*args) 592 actual = torch.compile(fn)(*args) 593 594 torch.testing.assert_close(actual, expected) 595 self.assertEqual( 596 counters["inductor"]["scmerge_split_added"], 597 expected_split_added, 598 ) 599 self.assertEqual( 600 counters["inductor"]["scmerge_split_removed"], 601 expected_split_removed, 602 ) 603 self.assertEqual( 604 counters["inductor"]["scmerge_cat_added"], 605 expected_cat_added, 606 ) 607 self.assertEqual( 608 counters["inductor"]["scmerge_cat_removed"], 609 expected_cat_removed, 610 ) 611 self.assertEqual( 612 counters["inductor"]["scmerge_split_sections_removed"], 613 expected_sections_removed, 614 ) 615 counters.clear() 616 617 @torch._inductor.config.patch( 618 pre_grad_fusion_options={}, 619 post_grad_fusion_options={}, 620 ) 621 def test_config_flag_is_respected(self): 622 def split_with_cat(x): 623 fs = torch.split(x, [4, 4, 24], dim=-1) 624 item0 = fs[0] 625 item1 = fs[1] 626 item2 = fs[2] 627 628 final_items = [item0, item1] 629 final_items.extend(item2.split((4, 4, 4, 4, 4, 4), 1)) 630 631 return torch.cat(final_items, dim=1) 632 633 args = [ 634 torch.randn(2, 32), 635 ] 636 637 expected = split_with_cat(*args) 638 actual = torch.compile(split_with_cat)(*args) 639 640 
torch.testing.assert_close(actual, expected) 641 self.assertEqual( 642 counters["inductor"]["merge_splits_pass"], 643 0, 644 ) 645 self.assertEqual( 646 counters["inductor"]["normalization_pass"], 647 0, 648 ) 649 650 @patch 651 def test_split_cat_merge_mutation(self): 652 args = [ 653 torch.randn(2, 32, 32, 16), 654 ] 655 656 def split_cat_mutation(x): 657 splits = torch.split(x, 4, dim=1) 658 splits[1].copy_(splits[0]) 659 return torch.cat(splits, dim=1) 660 661 expected = split_cat_mutation(*args) 662 actual = torch.compile(split_cat_mutation)(*args) 663 664 torch.testing.assert_close(actual, expected) 665 666 self.assertEqual(counters["inductor"]["scmerge_split_removed"], 0) 667 self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 0) 668 669 @patch 670 def test_split_squeeze(self): 671 def split_squeeze_stack(x): 672 items = list(torch.split(x, 1, dim=1)) 673 split_items = [torch.squeeze(s, 1) for s in items] 674 return torch.stack(split_items) 675 676 def split_squeeze_stack_callmethod(x): 677 items = list(torch.split(x, 1, dim=1)) 678 split_items = [s.squeeze(1) for s in items] 679 return torch.stack(split_items) 680 681 def split_squeeze_stack_callmethod_none_dim(x): 682 items = list(torch.split(x, 1, dim=1)) 683 split_items = [s.squeeze() for s in items] 684 return torch.stack(split_items) 685 686 def split_squeeze_stack_kwarg1(x): 687 items = list(torch.split(x, 1, dim=1)) 688 split_items = [torch.squeeze(s, dim=1) for s in items] 689 return torch.stack(split_items) 690 691 def split_squeeze_stack_kwarg1_callmethod(x): 692 items = list(torch.split(x, 1, dim=1)) 693 split_items = [s.squeeze(dim=1) for s in items] 694 return torch.stack(split_items) 695 696 def split_squeeze_multi_squeeze_users(x): 697 items = list(torch.split(x, 1, dim=1)) 698 split_items = [torch.squeeze(s, 1) for s in items] 699 return ( 700 torch.stack(split_items), 701 torch.relu(split_items[0]), 702 torch.tanh(split_items[1]), 703 ) 704 705 def split_size_not_1(x): 706 items = 
list(torch.split(x, 2, dim=1)) 707 split_items = [torch.squeeze(s, 1) for s in items] 708 return torch.stack(split_items) 709 710 def dim_mismatch(x): 711 items = list(torch.split(x, 1, dim=1)) 712 split_items = [torch.squeeze(s, 0) for s in items] 713 return torch.stack(split_items) 714 715 def other_users(x): 716 items = list(torch.split(x, 1, dim=1)) 717 split_items = [torch.squeeze(s, 1) for s in items] 718 return torch.stack(split_items), torch.relu(items[0]) 719 720 def other_users_2(x): 721 items = list(torch.split(x, 1, dim=1)) 722 split_items = [torch.squeeze(s, 1) for s in items[1:]] 723 return torch.stack(split_items), torch.relu(items[0]) 724 725 def graph_should_be_topological_sorted(x): 726 output = [] 727 for t in x.split(1): 728 output.append(torch.sin(t.squeeze(dim=0))) 729 output = torch.stack(output) 730 return output 731 732 args = [ 733 torch.randn(2, 32), 734 ] 735 for fn, split_squeeze_replaced in [ 736 (split_squeeze_stack, 1), 737 (split_squeeze_stack_callmethod, 1), 738 # TODO handle none dim 739 (split_squeeze_stack_callmethod_none_dim, 0), 740 (split_squeeze_stack_kwarg1, 1), 741 (split_squeeze_stack_kwarg1_callmethod, 1), 742 (split_squeeze_multi_squeeze_users, 1), 743 (split_size_not_1, 0), 744 (dim_mismatch, 0), 745 (other_users, 0), 746 (other_users_2, 0), 747 (graph_should_be_topological_sorted, 1), 748 ]: 749 expected = fn(*args) 750 actual = torch.compile(fn)(*args) 751 752 torch.testing.assert_close(actual, expected) 753 self.assertEqual( 754 counters["inductor"]["split_cat_pass"], 755 split_squeeze_replaced, 756 ) 757 counters.clear() 758 759 @patch 760 def test_unbind_stack(self): 761 def unbind_stack(x): 762 return torch.stack(torch.unbind(x, 1), 1) 763 764 def unbind_cat(x): 765 return torch.cat(torch.unbind(x, dim=-3), 1) 766 767 def unbind_stack_argspec1(x): 768 return torch.stack(torch.unbind(input=x, dim=1), dim=1) 769 770 def unbind_stack_argspec2(x): 771 return torch.stack(tensors=torch.unbind(x, dim=1), dim=1) 772 773 
def dim_mismatch(x): 774 return torch.stack(torch.unbind(x, dim=1), 0) 775 776 def split_squeeze_stack(x): 777 items = list(torch.split(x, 1, dim=1)) 778 split_items = [torch.squeeze(s, 1) for s in items] 779 return torch.stack(split_items, 1) 780 781 def split_squeeze_stack_callmethod(x): 782 items = list(torch.split(x, 1, dim=1)) 783 split_items = [torch.squeeze(s, 1) for s in items] 784 return torch.stack(split_items, 1) 785 786 def other_users(x): 787 items = list(torch.split(x, 1, dim=1)) 788 split_items = [torch.squeeze(s, 1) for s in items] 789 return torch.stack(split_items, 1), torch.relu(items[0]) 790 791 def other_users_2(x): 792 items = list(torch.split(x, 1, dim=1)) 793 split_items = [torch.squeeze(s, 1) for s in items[1:]] 794 return torch.stack(split_items, 1), torch.relu(items[0]) 795 796 def unbind_cat_addn_args(x): 797 split_output = list(torch.unbind(x, dim=1)) 798 799 return torch.cat( 800 [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)], 801 dim=1, 802 ) 803 804 def unbind_stack_addn_args(x): 805 split_output = list(torch.unbind(x, dim=1)) 806 return torch.stack( 807 [torch.ones(2, 32, 16)] 808 + split_output 809 + [torch.ones(2, 32, 16), torch.ones(2, 32, 16)], 810 dim=1, 811 ) 812 813 def unbind_cat_addn_args_dim2(x): 814 split_output = list(torch.unbind(x, dim=2)) 815 return torch.cat( 816 [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)], 817 dim=2, 818 ) 819 820 # split_dim=1, cat_dim=2 821 def unbind_cat_dim_mismatch(x): 822 split_output = list(torch.unbind(x, dim=1)) 823 return torch.cat( 824 [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)], 825 dim=2, 826 ) 827 828 def unbind_stack_dim_mismatch(x): 829 split_output = list(torch.unbind(x, dim=1)) 830 return torch.stack( 831 [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)], 832 dim=2, 833 ) 834 835 def unbind_cat_multi_users(x): 836 split_output = list(torch.unbind(x, dim=1)) 837 return torch.cat( 838 [torch.ones(2, 32, 16)] + 
split_output + [torch.ones(2, 32, 16)], 839 dim=1, 840 ), torch.stack( 841 [torch.ones(2, 32, 16)] 842 + split_output 843 + [torch.ones(2, 32, 16), torch.ones(2, 32, 16)], 844 dim=1, 845 ) 846 847 def unbind_cat_multi_users_diff_dims(x): 848 split_output = list(torch.unbind(x, dim=1)) 849 return torch.cat( 850 [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)], 851 dim=1, 852 ), torch.stack( 853 [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)], 854 dim=2, 855 ) 856 857 args = [ 858 torch.randn(2, 32, 32, 16), 859 ] 860 for ( 861 fn, 862 expected_unbind_added, 863 expected_unbind_removed, 864 expected_cat_added, 865 expected_cat_removed, 866 expected_sections_removed, 867 expected_unbind_normalized, 868 ) in [ 869 (unbind_stack, 0, 1, 0, 1, 31, 2), 870 (unbind_stack_argspec1, 0, 1, 0, 1, 31, 2), 871 (unbind_stack_argspec2, 0, 1, 0, 1, 31, 2), 872 (dim_mismatch, 0, 1, 0, 1, 31, 2), 873 (split_squeeze_stack, 0, 1, 0, 1, 31, 2), 874 (split_squeeze_stack_callmethod, 0, 1, 0, 1, 31, 2), 875 (other_users, 0, 0, 0, 0, 0, 2), 876 (other_users_2, 0, 0, 0, 0, 0, 2), 877 (unbind_cat_addn_args, 0, 1, 1, 1, 31, 1), 878 (unbind_stack_addn_args, 0, 1, 1, 1, 31, 2), 879 (unbind_cat_addn_args_dim2, 0, 1, 1, 1, 31, 1), 880 (unbind_cat_dim_mismatch, 0, 1, 1, 1, 31, 1), 881 (unbind_stack_dim_mismatch, 0, 1, 1, 1, 31, 2), 882 (unbind_cat_multi_users, 0, 1, 2, 2, 31, 2), 883 (unbind_cat_multi_users_diff_dims, 0, 1, 2, 2, 31, 2), 884 ]: 885 expected = fn(*args) 886 actual = torch.compile(fn)(*args) 887 888 torch.testing.assert_close(actual, expected) 889 self.assertEqual( 890 counters["inductor"]["scmerge_split_added"], 891 expected_unbind_added, 892 msg=f"for {fn}", 893 ) 894 self.assertEqual( 895 counters["inductor"]["scmerge_split_removed"], 896 expected_unbind_removed, 897 msg=f"for {fn}", 898 ) 899 self.assertEqual( 900 counters["inductor"]["scmerge_cat_added"], 901 expected_cat_added, 902 msg=f"for {fn}", 903 ) 904 self.assertEqual( 905 
counters["inductor"]["scmerge_cat_removed"], 906 expected_cat_removed, 907 msg=f"for {fn}", 908 ) 909 self.assertEqual( 910 counters["inductor"]["scmerge_split_sections_removed"], 911 expected_sections_removed, 912 msg=f"for {fn}", 913 ) 914 self.assertEqual( 915 counters["inductor"]["normalization_pass"], 916 expected_unbind_normalized, 917 msg=f"for {fn}", 918 ) 919 counters.clear() 920 921 @patch 922 def test_split_cat_new_patterns(self): 923 def split_cat_split(x): 924 l1_out = torch.split(x, [200, 50, 50, 20, 20, 20, 20, 20, 20, 50, 30], 1) 925 item0 = l1_out[0] 926 item1 = l1_out[1] 927 item2 = l1_out[2] 928 item3 = l1_out[3] 929 item4 = l1_out[4] 930 item5 = l1_out[5] 931 item6 = l1_out[6] 932 item7 = l1_out[7] 933 item8 = l1_out[8] 934 item9 = l1_out[9] 935 item10 = l1_out[10] 936 cat_1 = torch.cat((item0, item1), 1) 937 cat_2 = torch.cat((item9, item10), 1) 938 l2_out = torch.split(cat_1, [50, 120, 80], 1) 939 l3_out = torch.split(cat_2, [10, 20, 50], 1) 940 item11 = l2_out[0] 941 item12 = l2_out[1] 942 item13 = l2_out[2] 943 item14 = l3_out[0] 944 item15 = l3_out[1] 945 item16 = l3_out[2] 946 947 output = torch.cat( 948 [ 949 item11, 950 item12, 951 item13, 952 item14, 953 item15, 954 item16, 955 item2, 956 item3, 957 item4, 958 item5, 959 item6, 960 item7, 961 item8, 962 ], 963 1, 964 ) 965 return output 966 967 def split_cat_split_kwarg(x): 968 l1_out = torch.split( 969 x, [200, 50, 50, 20, 20, 20, 20, 20, 20, 50, 30], dim=1 970 ) 971 item0 = l1_out[0] 972 item1 = l1_out[1] 973 item2 = l1_out[2] 974 item3 = l1_out[3] 975 item4 = l1_out[4] 976 item5 = l1_out[5] 977 item6 = l1_out[6] 978 item7 = l1_out[7] 979 item8 = l1_out[8] 980 item9 = l1_out[9] 981 item10 = l1_out[10] 982 cat_1 = torch.cat((item0, item1), dim=1) 983 cat_2 = torch.cat((item9, item10), dim=1) 984 l2_out = torch.split(cat_1, [50, 120, 80], dim=1) 985 l3_out = torch.split(cat_2, [10, 20, 50], dim=1) 986 item11 = l2_out[0] 987 item12 = l2_out[1] 988 item13 = l2_out[2] 989 item14 = 
l3_out[0] 990 item15 = l3_out[1] 991 item16 = l3_out[2] 992 993 output = torch.cat( 994 [ 995 item11, 996 item12, 997 item13, 998 item14, 999 item15, 1000 item16, 1001 item2, 1002 item3, 1003 item4, 1004 item5, 1005 item6, 1006 item7, 1007 item8, 1008 ], 1009 dim=1, 1010 ) 1011 return output 1012 1013 def remove_cat_node_with_all_getitmes(x): 1014 l1_out = torch.split( 1015 x, [50, 50, 200, 20, 20, 20, 20, 20, 40, 10, 50], dim=0 1016 ) 1017 item0 = l1_out[0] 1018 item1 = l1_out[1] 1019 item2 = l1_out[2] 1020 item3 = l1_out[3] 1021 item4 = l1_out[4] 1022 item5 = l1_out[5] 1023 item6 = l1_out[6] 1024 item7 = l1_out[7] 1025 item8 = l1_out[8] 1026 item9 = l1_out[9] 1027 item10 = l1_out[10] 1028 cat = torch.cat( 1029 ( 1030 item0, 1031 item1, 1032 item2, 1033 item3, 1034 item4, 1035 item5, 1036 item6, 1037 item7, 1038 item8, 1039 item9, 1040 item10, 1041 ), 1042 dim=0, 1043 ) 1044 cat_1 = torch.cat((item0, item1), dim=0) 1045 cat_2 = torch.cat((item0, item10), dim=0) 1046 l2_out = torch.split(cat_1, [20, 30, 50], dim=0) 1047 l3_out = torch.split(cat_2, [10, 60, 30], dim=0) 1048 item11 = l2_out[0] 1049 item12 = l2_out[1] 1050 item13 = l2_out[2] 1051 item14 = l3_out[0] 1052 item15 = l3_out[1] 1053 item16 = l3_out[2] 1054 1055 output = torch.cat( 1056 [ 1057 item11, 1058 item12, 1059 item13, 1060 item14, 1061 item15, 1062 item16, 1063 item2, 1064 item3, 1065 item4, 1066 item5, 1067 item6, 1068 item7, 1069 item8, 1070 ], 1071 dim=0, 1072 ) 1073 return torch.cat((output, cat), dim=0) 1074 1075 def mutate_cat_node_with_some_getitmes(x): 1076 l1_out = torch.split( 1077 x, [50, 50, 200, 20, 20, 20, 20, 20, 40, 10, 50], dim=0 1078 ) 1079 item0 = l1_out[0] 1080 item1 = l1_out[1] 1081 item2 = l1_out[2] 1082 item3 = l1_out[3] 1083 item4 = l1_out[4] 1084 item5 = l1_out[5] 1085 item6 = l1_out[6] 1086 item7 = l1_out[7] 1087 item8 = l1_out[8] 1088 item9 = l1_out[9] 1089 item10 = l1_out[10] 1090 cat = torch.cat( 1091 ( 1092 item6, 1093 item7, 1094 item8, 1095 item9, 1096 item10, 1097 
item2, 1098 item3, 1099 item4, 1100 item5, 1101 ), 1102 dim=0, 1103 ) 1104 cat_1 = torch.cat((item0, item1), dim=0) 1105 cat_2 = torch.cat((item0, item10), dim=0) 1106 l2_out = torch.split(cat_1, [20, 30, 50], dim=0) 1107 l3_out = torch.split(cat_2, [10, 60, 30], dim=0) 1108 item11 = l2_out[0] 1109 item12 = l2_out[1] 1110 item13 = l2_out[2] 1111 item14 = l3_out[0] 1112 item15 = l3_out[1] 1113 item16 = l3_out[2] 1114 1115 output = torch.cat( 1116 [ 1117 item11, 1118 item12, 1119 item13, 1120 item14, 1121 item15, 1122 item16, 1123 item2, 1124 ], 1125 dim=0, 1126 ) 1127 return torch.cat((output, cat), dim=0) 1128 1129 @torch._inductor.config.patch( 1130 pre_grad_fusion_options={ 1131 "split_cat_to_slices_pass": {}, 1132 }, 1133 post_grad_fusion_options={}, 1134 ) 1135 def split_cat_to_slices(x): 1136 x_c = x.clone() 1137 x_c_2 = x.clone() 1138 l1_out = torch.split(x, [50, 50, 50, 50, 50, 50, 50, 50, 50, 50], dim=0) 1139 l2_out = torch.split(x_c, [50, 50, 50, 50, 50, 50, 50, 50, 50, 50], dim=0) 1140 l3_out = torch.split(x_c_2, [100, 100, 100, 100, 100], dim=0) 1141 item0 = l1_out[0] 1142 item1 = l1_out[1] 1143 item2 = l1_out[2] 1144 item3 = l1_out[3] 1145 item4 = l1_out[4] 1146 item5 = l1_out[5] 1147 item6 = l1_out[6] 1148 item7 = l1_out[7] 1149 item8 = l1_out[8] 1150 item9 = l1_out[9] 1151 item0_c = l2_out[0] 1152 item1_c = l2_out[1] 1153 item2_c = l2_out[2] 1154 item3_c = l2_out[3] 1155 item4_c = l2_out[4] 1156 item5_c = l2_out[5] 1157 item6_c = l2_out[6] 1158 item7_c = l2_out[7] 1159 item8_c = l2_out[8] 1160 item9_c = l2_out[9] 1161 item0_c_2 = l3_out[0] 1162 item1_c_2 = l3_out[1] 1163 item2_c_2 = l3_out[2] 1164 item3_c_2 = l3_out[3] 1165 item4_c_2 = l3_out[4] 1166 other = item0.clone() 1167 return torch.cat( 1168 [ 1169 other, 1170 item0, 1171 item1, 1172 item2, 1173 item3, 1174 item4, 1175 item5, 1176 item6, 1177 item7, 1178 item8, 1179 item9, 1180 item4_c, 1181 item5_c, 1182 item6_c, 1183 item7_c, 1184 item8_c, 1185 item9_c, 1186 item0_c, 1187 item1_c, 1188 
                item2_c,
                item3_c,
                item0_c_2,
                item1_c_2,
                item2_c_2,
                item3_c_2,
                item4_c_2,
            ],
            dim=0,
        )

        # Workload: two unbinds whose outputs are concatenated on a different
        # dim. The decorator enables ONLY the pre-grad "unbind_cat_to_view_pass"
        # (its counter is asserted by the driver loop later in this test).
        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "unbind_cat_to_view_pass": {},
            },
            post_grad_fusion_options={},
        )
        def unbind_cat_to_view(x):
            y = x.view(10, 50, 500)
            z = x.view(10, 50, 500)
            l1_out = torch.unbind(y, dim=0)
            l2_out = torch.unbind(z, dim=0)
            # NOTE: getitems are spelled out one-by-one on purpose so the FX
            # pattern matcher sees explicit getitem nodes; do not loop-ify.
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            item2_0 = l2_out[0]
            item2_1 = l2_out[1]
            item2_2 = l2_out[2]
            item2_3 = l2_out[3]
            item2_4 = l2_out[4]
            item2_5 = l2_out[5]
            item2_6 = l2_out[6]
            item2_7 = l2_out[7]
            item2_8 = l2_out[8]
            item2_9 = l2_out[9]
            # clones break the direct unbind->cat chain for these three items
            other1 = item7.clone()
            other2 = item8.clone()
            other3 = item9.clone()
            cat = torch.cat(
                [
                    item0,
                    item1,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    other1,
                    item2_0,
                    item2_1,
                    item2_2,
                    item2_3,
                    item2_4,
                    item2_5,
                    item2_6,
                    item2_7,
                    item2_8,
                    item2_9,
                    other2,
                    other3,
                ],
                dim=1,
            )
            return cat

        # Workload: per-item splits restacked on the SAME dim as the split
        # (dim=1). Enables only the pre-grad "split_stack_to_cats_pass".
        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "split_stack_to_cats_pass": {},
            },
            post_grad_fusion_options={},
        )
        def split_stack_to_cats_same_dim(x):
            x_c = x.view(10, 50, 500)
            l1_out = torch.unbind(x_c, dim=0)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            split1 = torch.split(item0, [250, 250], dim=1)
            split2 = torch.split(item1, [250, 250], dim=1)
            split3 = torch.split(item2, [250, 250], dim=1)
            split4 = torch.split(item3, [250, 250], dim=1)
            split5 = torch.split(item4, [250, 250], dim=1)
            split6 = torch.split(item5, [250, 250], dim=1)
            getitem0, getitem1 = split1[0], split1[1]
            getitem2, getitem3 = split2[0], split2[1]
            getitem4, getitem5 = split3[0], split3[1]
            getitem6, getitem7 = split4[0], split4[1]
            getitem8, getitem9 = split5[0], split5[1]
            getitem10, getitem11 = split6[0], split6[1]
            # cloned copies interleaved with raw split outputs in the stack
            getitem0_c = getitem0.clone()
            getitem1_c = getitem1.clone()
            getitem2_c = getitem2.clone()
            return torch.stack(
                (
                    getitem0,
                    getitem1,
                    getitem2,
                    getitem3,
                    getitem4,
                    getitem5,
                    getitem0_c,
                    getitem1_c,
                    getitem6,
                    getitem7,
                    getitem8,
                    getitem9,
                    getitem10,
                    getitem11,
                    getitem2_c,
                ),
                dim=1,
            )

        # Workload: splits on dim=1 restacked on a DIFFERENT dim (dim=2).
        # Same "split_stack_to_cats_pass" as above, different-dim variant.
        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "split_stack_to_cats_pass": {},
            },
            post_grad_fusion_options={},
        )
        def split_stack_to_cats_different_dim(x):
            l1_out = torch.split(x, [100, 100, 100, 100, 100], dim=1)
            x_c = x.clone()
            l2_out = torch.split(x_c, [100, 100, 100, 100, 100], dim=1)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item0_c = l2_out[0]
            item1_c = l2_out[1]
            item2_c = l2_out[2]
            item3_c = l2_out[3]
            item4_c = l2_out[4]
            other_1 = item0.clone()
            other_2 = item1.clone()
            other_3 = item2.clone()
            return torch.stack(
                (
                    other_1,
                    other_2,
                    other_3,
                    item0,
                    item1,
                    item2,
                    item3,
                    item4,
                    item0_c,
                    item1_c,
                    item2_c,
                    item3_c,
                    item4_c,
                ),
                dim=2,
            )

        # Workload: unbind on dim=1 whose items are restacked on dim=1.
        # Enables only the pre-grad "unbind_stack_to_slices_pass".
        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "unbind_stack_to_slices_pass": {},
            },
            post_grad_fusion_options={},
        )
        def unbind_stack_to_slices(x):
            x_1 = x.view(50, 10, 500)
            l1_out = torch.unbind(x_1, dim=1)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
1365 item3 = l1_out[3] 1366 item4 = l1_out[4] 1367 item5 = l1_out[5] 1368 item6 = l1_out[6] 1369 item7 = l1_out[7] 1370 item8 = l1_out[8] 1371 item9 = l1_out[9] 1372 other_1 = item0.clone() 1373 other_2 = item1.clone() 1374 other_3 = item2.clone() 1375 return torch.stack( 1376 ( 1377 other_1, 1378 other_2, 1379 other_3, 1380 item0, 1381 item1, 1382 item2, 1383 item3, 1384 item4, 1385 item5, 1386 item6, 1387 item7, 1388 item8, 1389 item9, 1390 ), 1391 dim=1, 1392 ) 1393 1394 @torch._inductor.config.patch( 1395 pre_grad_fusion_options={ 1396 "normalization_pass": {}, 1397 "move_reshape_out_of_split_stack_pass": {}, 1398 }, 1399 post_grad_fusion_options={}, 1400 ) 1401 def move_reshape_out_of_split_stack(x): 1402 x_c = x.view(50000, 5) 1403 l1_out = torch.split(x_c, [1, 1, 1, 1, 1], dim=1) 1404 item0 = l1_out[0] 1405 item1 = l1_out[1] 1406 item2 = l1_out[2] 1407 item3 = l1_out[3] 1408 item4 = l1_out[4] 1409 reshape0 = item0.reshape(-1, 5) 1410 reshape1 = item1.reshape(-1, 5) 1411 reshape2 = item2.reshape(-1, 5) 1412 reshape3 = item3.reshape(-1, 5) 1413 reshape4 = item4.reshape(-1, 5) 1414 other0 = reshape0.clone() 1415 other1 = reshape1.clone() 1416 other2 = reshape2.clone() 1417 other3 = reshape3.clone() 1418 return torch.stack( 1419 ( 1420 other0, 1421 other1, 1422 other2, 1423 reshape0, 1424 reshape1, 1425 reshape2, 1426 reshape3, 1427 reshape4, 1428 other3, 1429 ), 1430 dim=0, 1431 ) 1432 1433 args = [ 1434 torch.randn(500, 500), 1435 ] 1436 for ( 1437 fn, 1438 expected_getitem_cat_merged, 1439 expected_cat_removed, 1440 expected_split_cat_to_slices, 1441 exptected_unbind_to_cat_view, 1442 expected_split_stack_to_cats, 1443 exptected_unbind_stack_to_slices, 1444 expected_move_reshape_out_of_split_stack, 1445 ) in [ 1446 (split_cat_split, 2, 0, 0, 0, 0, 0, 0), 1447 (split_cat_split_kwarg, 2, 0, 0, 0, 0, 0, 0), 1448 (remove_cat_node_with_all_getitmes, 0, 2, 0, 0, 0, 0, 0), 1449 (mutate_cat_node_with_some_getitmes, 0, 1, 0, 0, 0, 0, 0), 1450 (split_cat_to_slices, 0, 
0, 1, 0, 0, 0, 0), 1451 (unbind_cat_to_view, 0, 0, 0, 1, 0, 0, 0), 1452 (split_stack_to_cats_same_dim, 0, 0, 0, 0, 1, 0, 0), 1453 (split_stack_to_cats_different_dim, 0, 0, 0, 0, 1, 0, 0), 1454 (unbind_stack_to_slices, 0, 0, 0, 0, 0, 1, 0), 1455 (move_reshape_out_of_split_stack, 0, 0, 0, 0, 0, 0, 1), 1456 ]: 1457 expected = fn(*args) 1458 actual = torch.compile(fn)(*args) 1459 1460 torch.testing.assert_close(actual, expected) 1461 self.assertEqual( 1462 counters["inductor"]["merge_getitem_cat_pass"], 1463 expected_getitem_cat_merged, 1464 ) 1465 self.assertEqual( 1466 counters["inductor"]["mutate_cat_pass"], 1467 expected_cat_removed, 1468 ) 1469 self.assertEqual( 1470 counters["inductor"]["split_cat_to_slices_pass"], 1471 expected_split_cat_to_slices, 1472 ) 1473 self.assertEqual( 1474 counters["inductor"]["unbind_cat_to_view_pass"], 1475 exptected_unbind_to_cat_view, 1476 ) 1477 self.assertEqual( 1478 counters["inductor"]["split_stack_to_cats_pass"], 1479 expected_split_stack_to_cats, 1480 ) 1481 self.assertEqual( 1482 counters["inductor"]["unbind_stack_to_slices_pass"], 1483 exptected_unbind_stack_to_slices, 1484 ) 1485 self.assertEqual( 1486 counters["inductor"]["move_reshape_out_of_split_stack_pass"], 1487 expected_move_reshape_out_of_split_stack, 1488 ) 1489 counters.clear() 1490 1491 def test_numpy_compat_normalization(self): 1492 def fn(x, y): 1493 a = torch.stack([x, y], axis=1) 1494 b = torch.mul(x, x2=y) 1495 c = torch.mul(x, x2=y) 1496 d = torch.mul(x, x2=y) 1497 e = torch.max(x, dim=1, keepdims=True) 1498 f = torch.dropout(x=x, p=0.5, train=True) 1499 return a, b, c, d, e, f 1500 1501 fn_t = torch.fx.symbolic_trace(fn) 1502 numpy_compat_normalization(fn_t.graph) 1503 1504 for n in fn_t.graph.nodes: 1505 for k in n.kwargs.keys(): 1506 self.assertTrue(k not in {"x", "x1", "x2", "a", "axis", "keepdims"}) 1507 1508 @patch 1509 @requires_gpu 1510 def test_stack_normalization_axis_kwarg(self): 1511 def fn(x, y): 1512 return torch.stack([x, y], axis=1) 1513 
1514 x, y = (torch.rand((4, 4), device=GPU_TYPE) for _ in range(2)) 1515 expected = fn(x, y) 1516 actual = torch.compile(fn)(x, y) 1517 1518 self.assertEqual(actual, expected) 1519 1520 1521if __name__ == "__main__": 1522 if IS_LINUX and HAS_GPU: 1523 run_tests() 1524