# Owner(s): ["module: inductor"]


import torch
from torch._dynamo.utils import counters, optimus_scuba_log
from torch._inductor.fx_passes.misc_patterns import numpy_compat_normalization
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.testing._internal.triton_utils import requires_gpu


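# Decorator used by most tests below: enables the pre-grad split/cat fusion
# passes under test and disables all post-grad fusion passes.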
def patch(f):
    f = torch._inductor.config.patch(
        pre_grad_fusion_options={
            "normalization_pass": {},
            "remove_split_with_size_one_pass": {},
            "merge_getitem_cat_pass": {},
            "merge_splits_pass": {},
            "mutate_cat_pass": {},
            "split_cat_pass": {},
            "unbind_stack_pass": {},
        },
        post_grad_fusion_options={},
    )(f)
    return f


class TestSplitCatFxPasses(TestCase):
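    # Checks that normalization_pass rewrites the many spellings of split
    # (positional vs. keyword args, Tensor.split method calls, negative dims)
    # into one canonical form; already-canonical calls (see list_replace)
    # report zero normalizations.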
    @patch
    def test_split_normalization(self):
        def arg_only(x):
            return [torch.relu(s) for s in torch.split(x, 2, 1)]

        def arg_only_dim0(x):
            return [torch.relu(s) for s in torch.split(x, 2, 0)]

        def kwarg1(x):
            return [torch.relu(s) for s in torch.split(x, 2, dim=1)]

        def kwarg2(x):
            return [
                torch.relu(s) for s in torch.split(x, split_size_or_sections=2, dim=1)
            ]

        def kwarg3(x):
            return [
                torch.relu(s)
                for s in torch.split(tensor=x, split_size_or_sections=2, dim=-1)
            ]

        def list_replace(x):
            return [torch.relu(s) for s in torch.split(x, [16, 16], dim=1)]

        def multi_split(x):
            return [torch.split(s, 2, 1) for s in torch.split(x, 2, 1)]

        def unequal_split(x):
            return [torch.relu(s) for s in torch.split(x, 3, 1)]

        def arg_only_cm(x):
            return [torch.relu(s) for s in x.split(2, 1)]

        def kwarg1_cm(x):
            return [torch.relu(s) for s in x.split(2, dim=1)]

        def kwarg2_cm(x):
            return [torch.relu(s) for s in x.split(split_size=2, dim=1)]

        def multi_split_cm(x):
            return [s.split(2, 1) for s in x.split(2, 1)]

        def unequal_split_cm(x):
            return [torch.relu(s) for s in x.split(3, 1)]

        def cm_with_list(x):
            return [torch.relu(s) for s in x.split([16, 16], dim=-1)]

        args = [
            torch.randn(2, 32),
        ]
        for fn, expected_split_norm_count in [
            (arg_only, 1),
            (arg_only_dim0, 1),
            (kwarg1, 1),
            (kwarg2, 1),
            (kwarg3, 1),
            (list_replace, 0),
            (multi_split, 17),
            (unequal_split, 1),
            (arg_only_cm, 1),
            (kwarg1_cm, 1),
            (kwarg2_cm, 1),
            (multi_split_cm, 17),
            (unequal_split_cm, 1),
            (cm_with_list, 1),
        ]:
            expected = fn(*args)
            actual = torch.compile(fn)(*args)

            torch.testing.assert_close(actual, expected)
            self.assertEqual(
                counters["inductor"]["normalization_pass"],
                expected_split_norm_count,
                msg=f"for {fn}",
            )
            if expected_split_norm_count > 0:
                self.assertIn("normalization_pass_pre_grad", optimus_scuba_log)
            counters.clear()

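    # Checks that merge_splits_pass collapses a split whose outputs are split
    # again along the same dim into a single flat split, and that it leaves
    # non-mergeable shapes (different dims, non-split users) alone.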
    @patch
    def test_consecutive_split_merge(self):
        def multi_split(x):
            return [torch.split(s, 2, 1) for s in torch.split(x, 2, 1)]

        def multi_split_2(x):
            return [torch.split(s, 1, 1) for s in torch.split(x, 2, 1)]

        def multi_split_2_neg_dim(x):
            return [torch.split(s, 1, 1) for s in torch.split(x, 2, -1)]

        def multi_split_with_sizes(x):
            return [torch.split(s, 2, 1) for s in torch.split(x, [16, 16], 1)]

        def multi_split_kwarg1(x):
            return [torch.split(s, 2, dim=1) for s in torch.split(x, 2, dim=1)]

        def multi_split_kwarg2(x):
            return [
                torch.split(s, split_size_or_sections=2, dim=1)
                for s in torch.split(x, split_size_or_sections=2, dim=1)
            ]

        def unequal_multi_split(x):
            fs = torch.split(x, [10, 10, 12], dim=1)
            item0 = fs[0]
            item1 = fs[1]
            item2 = fs[2]

            final_items = []
            final_items.extend(item0.split([4, 6], 1))
            final_items.extend(item1.split([6, 4], 1))
            final_items.extend(item2.split([4, 4, 4], 1))

            return [torch.relu(s) for s in final_items]

        def unequal_multi_split_neg_index(x):
            fs = torch.split(x, [10, 10, 12], dim=1)
            item0 = fs[-3]
            item1 = fs[-2]
            item2 = fs[-1]

            final_items = []
            final_items.extend(item0.split([4, 6], 1))
            final_items.extend(item1.split([6, 4], 1))
            final_items.extend(item2.split([4, 4, 4], 1))

            return [torch.relu(s) for s in final_items]

        # Shouldn't merge
        def diff_dims(x):
            return [torch.split(s, 2, dim=0) for s in torch.split(x, 2, dim=1)]

        def some_users_not_splits(x):
            fs = torch.split(x, [10, 10, 12], dim=1)
            item0 = fs[0]
            item1 = fs[1]
            item2 = fs[2]

            final_items = []
            final_items.extend(item0.split([4, 6], 1))
            final_items.extend(item1.split([6, 4], 1))
            final_items.append(torch.sin(item2))

            return [torch.relu(s) for s in final_items]

        def split_with_cat(x):
            fs = torch.split(x, [4, 4, 24], dim=1)
            item0 = fs[0]
            item1 = fs[1]
            item2 = fs[2]

            final_items = [item0, item1]
            final_items.extend(item2.split((4, 4, 4, 4, 4, 4), 1))

            return torch.cat(final_items, dim=1)

        def duplicate_getitems(x):
            fs = torch.split(x, [10, 10, 12], dim=1)
            item0 = fs[0]
            item1_1 = fs[1]
            item1_2 = fs[1]
            item2 = fs[2]

            final_items = []
            final_items.extend(item0.split([4, 6], 1))
            final_items.extend(item1_1.split([6, 4], 1))
            final_items.extend(item1_2)
            final_items.append(torch.sin(item2))

            return [torch.relu(s) for s in final_items]

        def duplicate_getitems_neg_index(x):
            fs = torch.split(x, [10, 10, 12], dim=1)
            item0 = fs[0]
            item1_1 = fs[1]
            item1_2 = fs[-2]  # negative index
            item2 = fs[2]

            final_items = []
            final_items.extend(item0.split([4, 6], 1))
            final_items.extend(item1_1.split([6, 4], 1))
            final_items.extend(item1_2)
            final_items.append(torch.sin(item2))

            return [torch.relu(s) for s in final_items]

        def split_getitem_gap(x):
            fs = torch.split(x, [4, 4, 24], dim=1)
            item0 = fs[0]
            item2 = fs[2]

            final_items = [
                item0,
            ]
            final_items.extend(item2.split((4, 4, 4, 4, 4, 4), 1))

            return torch.cat(final_items, dim=1)

        def split_getitem_out_of_order(x):
            fs = torch.split(x, [4, 4, 4, 20], dim=1)
            item0 = fs[0]
            item2 = fs[2]
            item1 = fs[1]
            item3 = fs[3]

            final_items = [item0, item2, item1]
            final_items.extend(item3.split((4, 4, 4, 4, 4), 1))

            return torch.cat(final_items, dim=1)

        def split_partial_getitem_cat(x):
            fs = torch.split(x, [4, 4, 24], dim=1)
            item0 = fs[0]
            item2 = fs[2]

            final_items = [
                item0,
            ]
            final_items.extend(item2.split((4, 4, 4, 4, 4, 4), 1))

            return torch.cat(final_items, dim=1)

        args = [
            torch.randn(2, 32),
        ]
        for fn, expected_split_merged in [
            (multi_split, 0),
            (multi_split_2, 16),
            (multi_split_2_neg_dim, 16),
            (multi_split_with_sizes, 2),
            (multi_split_kwarg1, 0),
            (multi_split_kwarg2, 0),
            (unequal_multi_split, 3),
            (unequal_multi_split_neg_index, 3),
            (diff_dims, 0),
            (some_users_not_splits, 2),
            (split_with_cat, 1),
            (duplicate_getitems, 1),
            (duplicate_getitems_neg_index, 1),
            (split_getitem_gap, 1),
            (split_getitem_out_of_order, 1),
            (split_partial_getitem_cat, 1),
        ]:
            expected = fn(*args)
            actual = torch.compile(fn)(*args)

            torch.testing.assert_close(actual, expected)
            self.assertEqual(
                counters["inductor"]["merge_splits_pass"],
                expected_split_merged,
            )
            if expected_split_merged > 0:
                self.assertIn("merge_splits_pass_pre_grad", optimus_scuba_log)
            counters.clear()

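    # The split/cat merge is tracked by five scmerge_* counters (splits/cats
    # added and removed, plus split sections eliminated); each tuple below
    # gives the expected value of every counter for one input pattern.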
    @patch
    def test_split_cat_merge(self):
        def simple_split_cat(x):
            return torch.cat(torch.split(x, 4, dim=1), dim=1)

        def simple_split_cat_argspec1(x):
            return torch.cat(torch.split(x, 4, dim=1), 1)

        def simple_split_cat_argspec2(x):
            return torch.cat(tensors=torch.split(x, 4, dim=1), dim=1)

        def simple_split_cat_argspec3(x):
            return torch.cat(torch.split(x, 4, dim=1), -2)

        def simple_split_cat_argspec4(x):
            return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2)

        def simple_split_stack(x):
            return torch.stack(torch.split(x, 4, dim=1), dim=1)

        def simple_split_stack_argspec1(x):
            return torch.stack(torch.split(x, 4, dim=1), 1)

        def simple_split_stack_argspec2(x):
            return torch.stack(tensors=torch.split(x, 4, dim=1), dim=1)

        def split_cat_addn_args(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.cat(
                [torch.ones(2, 5, 32, 16)] + split_output + [torch.ones(2, 6, 32, 16)],
                dim=1,
            )

        def split_stack_addn_args(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.stack(
                [torch.ones(2, 4, 32, 16)]
                + split_output
                + [torch.ones(2, 4, 32, 16), torch.ones(2, 4, 32, 16)],
                dim=1,
            )

        def split_cat_addn_args_dim2(x):
            split_output = list(torch.split(x, 4, dim=2))
            return torch.cat(
                [torch.ones(2, 32, 5, 16)] + split_output + [torch.ones(2, 32, 6, 16)],
                dim=2,
            )

        # split_dim=1, cat_dim=2
        def split_cat_dim_mismatch(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.cat(
                [torch.ones(2, 4, 32, 16)] + split_output + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )

        def split_stack_dim_mismatch(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.stack(
                [torch.ones(2, 4, 32, 16)] + split_output + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )

        # split_dim=1, cat_dim=3
        def split_cat_dim_mismatch2(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.cat(
                [torch.ones(2, 4, 32, 16)] + split_output + [torch.ones(2, 4, 32, 16)],
                dim=3,
            )

        def split_stack_dim_mismatch2(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.stack(
                [torch.ones(2, 4, 32, 16)] + split_output + [torch.ones(2, 4, 32, 16)],
                dim=3,
            )

        # split_dim=2, cat_dim=0
        def split_cat_dim_mismatch3(x):
            split_output = list(torch.split(x, 4, dim=2))
            return torch.cat(
                [torch.ones(2, 32, 4, 16)] + split_output + [torch.ones(2, 32, 4, 16)],
                dim=0,
            )

        def split_stack_dim_mismatch3(x):
            split_output = list(torch.split(x, 4, dim=2))
            return torch.stack(
                [torch.ones(2, 32, 4, 16)] + split_output + [torch.ones(2, 32, 4, 16)],
                dim=0,
            )

        def input_shuffling(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)]
                + [split_output[5], split_output[6], split_output[7]]
                + [torch.ones(2, 4, 32, 16)],
                dim=1,
            )

        def input_shuffling_stack(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.stack(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)]
                + [split_output[5], split_output[6], split_output[7]]
                + [torch.ones(2, 4, 32, 16)],
                dim=1,
            )

        def input_shuffling_dim_mismatch(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)]
                + [split_output[5], split_output[6], split_output[7]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )

        def input_shuffling_dim_mismatch_stack(x):
            split_output = list(torch.split(x, 4, dim=1))
            return torch.stack(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)]
                + [split_output[5], split_output[6], split_output[7]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )

        def input_shuffling_multiple_output(x):
            split_output = list(torch.split(x, 4, dim=1))
            cat1 = torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )
            stack1 = torch.stack(
                [
                    torch.ones(2, 4, 32, 16),
                    split_output[4],
                    split_output[5],
                    torch.ones(2, 4, 32, 16),
                ],
                dim=1,
            )

            relu1 = torch.relu(split_output[6])

            return cat1, stack1, relu1

        def input_shuffling_direct_output(x):
            split_output = list(torch.split(x, 4, dim=1))
            cat1 = torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )
            stack1 = torch.stack(
                [
                    torch.ones(2, 4, 32, 16),
                    split_output[4],
                    split_output[5],
                    torch.ones(2, 4, 32, 16),
                ],
                dim=1,
            )

            return cat1, stack1, split_output[6]

        def input_shuffling_multiple_output_same_ranges(x):
            split_output = list(torch.split(x, 4, dim=1))
            cat1 = torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )

            cat2 = torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )
            stack1 = torch.stack(
                [
                    torch.ones(2, 4, 32, 16),
                    split_output[4],
                    split_output[5],
                    torch.ones(2, 4, 32, 16),
                ],
                dim=1,
            )

            relu1 = torch.relu(split_output[6])

            return cat1, cat2, stack1, relu1

        def unequal_split_multiple_output(x):
            split_output = list(torch.split(x, [2, 4, 4, 4, 4, 4, 8, 2], dim=1))
            cat1 = torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output[1], split_output[2], split_output[3]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )
            stack1 = torch.stack(
                [
                    torch.ones(2, 4, 32, 16),
                    split_output[4],
                    split_output[5],
                    torch.ones(2, 4, 32, 16),
                ],
                dim=1,
            )

            relu1 = torch.relu(split_output[6])

            return cat1, stack1, relu1

        def multi_split_cat(x1, x2):
            split_output_1 = list(torch.split(x1, 4, dim=1))
            split_output_2 = list(torch.split(x2, 4, dim=1))
            cat1 = torch.cat(
                [torch.ones(2, 4, 32, 16)]
                + [split_output_1[1], split_output_1[2], split_output_1[3]]
                + [torch.ones(2, 4, 32, 16)]
                + [split_output_2[1], split_output_2[2], split_output_2[3]]
                + [torch.ones(2, 4, 32, 16)],
                dim=2,
            )
            stack1 = torch.stack(
                [
                    torch.ones(2, 4, 32, 16),
                    split_output_1[4],
                    split_output_1[5],
                    torch.ones(2, 4, 32, 16),
                    split_output_2[4],
                    split_output_2[5],
                    torch.ones(2, 4, 32, 16),
                ],
                dim=1,
            )

            relu1 = torch.relu(split_output_1[6])
            relu2 = torch.relu(split_output_2[6])

            return cat1, stack1, relu1, relu2

        # TODO: Add more tests:
        # * Cases where replacement shouldn't happen
        default_args = [
            torch.randn(2, 32, 32, 16),
        ]
        multi_args = [
            torch.randn(2, 32, 32, 16),
            torch.randn(2, 32, 32, 16),
        ]
        for (
            fn,
            expected_split_added,
            expected_split_removed,
            expected_cat_added,
            expected_cat_removed,
            expected_sections_removed,
            args,
        ) in [
            (simple_split_cat, 0, 0, 0, 0, 0, default_args),
            (simple_split_cat_argspec1, 0, 0, 0, 0, 0, default_args),
            (simple_split_cat_argspec2, 0, 0, 0, 0, 0, default_args),
            (simple_split_cat_argspec3, 0, 1, 0, 1, 7, default_args),
            (simple_split_cat_argspec4, 0, 1, 0, 1, 7, default_args),
            (simple_split_stack, 0, 1, 0, 1, 7, default_args),
            (simple_split_stack_argspec1, 0, 1, 0, 1, 7, default_args),
            (simple_split_stack_argspec2, 0, 1, 0, 1, 7, default_args),
            (split_cat_addn_args, 0, 1, 1, 1, 7, default_args),
            (split_stack_addn_args, 0, 1, 1, 1, 7, default_args),
            (split_cat_addn_args_dim2, 0, 1, 1, 1, 7, default_args),
            (split_cat_dim_mismatch, 0, 1, 1, 1, 7, default_args),
            (split_stack_dim_mismatch, 0, 1, 1, 1, 7, default_args),
            (split_cat_dim_mismatch2, 0, 1, 1, 1, 7, default_args),
            (split_stack_dim_mismatch2, 0, 1, 1, 1, 7, default_args),
            (split_cat_dim_mismatch3, 0, 1, 1, 1, 7, default_args),
            (split_stack_dim_mismatch3, 0, 1, 1, 1, 7, default_args),
            (input_shuffling, 1, 1, 1, 1, 4, default_args),
            (input_shuffling_stack, 1, 1, 1, 1, 4, default_args),
            (input_shuffling_dim_mismatch, 1, 1, 1, 1, 4, default_args),
            (input_shuffling_dim_mismatch_stack, 1, 1, 1, 1, 4, default_args),
            (input_shuffling_multiple_output, 1, 1, 2, 2, 3, default_args),
            (input_shuffling_direct_output, 1, 1, 2, 2, 3, default_args),
            (unequal_split_multiple_output, 1, 1, 2, 2, 3, default_args),
            (multi_split_cat, 1, 1, 2, 2, 3, multi_args),
        ]:
            expected = fn(*args)
            actual = torch.compile(fn)(*args)

            torch.testing.assert_close(actual, expected)
            self.assertEqual(
                counters["inductor"]["scmerge_split_added"],
                expected_split_added,
            )
            self.assertEqual(
                counters["inductor"]["scmerge_split_removed"],
                expected_split_removed,
            )
            self.assertEqual(
                counters["inductor"]["scmerge_cat_added"],
                expected_cat_added,
            )
            self.assertEqual(
                counters["inductor"]["scmerge_cat_removed"],
                expected_cat_removed,
            )
            self.assertEqual(
                counters["inductor"]["scmerge_split_sections_removed"],
                expected_sections_removed,
            )
            counters.clear()

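    # With both fusion-option dicts empty no pre-grad pass should run, so the
    # pass counters must stay at zero even for an otherwise mergeable graph.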
    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={},
    )
    def test_config_flag_is_respected(self):
        def split_with_cat(x):
            fs = torch.split(x, [4, 4, 24], dim=-1)
            item0 = fs[0]
            item1 = fs[1]
            item2 = fs[2]

            final_items = [item0, item1]
            final_items.extend(item2.split((4, 4, 4, 4, 4, 4), 1))

            return torch.cat(final_items, dim=1)

        args = [
            torch.randn(2, 32),
        ]

        expected = split_with_cat(*args)
        actual = torch.compile(split_with_cat)(*args)

        torch.testing.assert_close(actual, expected)
        self.assertEqual(
            counters["inductor"]["merge_splits_pass"],
            0,
        )
        self.assertEqual(
            counters["inductor"]["normalization_pass"],
            0,
        )

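    # An in-place copy_ into one split output must block the split/cat merge;
    # folding the pattern away would change the semantics of the mutation.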
    @patch
    def test_split_cat_merge_mutation(self):
        args = [
            torch.randn(2, 32, 32, 16),
        ]

        def split_cat_mutation(x):
            splits = torch.split(x, 4, dim=1)
            splits[1].copy_(splits[0])
            return torch.cat(splits, dim=1)

        expected = split_cat_mutation(*args)
        actual = torch.compile(split_cat_mutation)(*args)

        torch.testing.assert_close(actual, expected)

        self.assertEqual(counters["inductor"]["scmerge_split_removed"], 0)
        self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 0)

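    # split_cat_pass is expected to replace split(split_size=1) followed by a
    # squeeze of the split dim feeding a stack; the counts below say how many
    # such groups get replaced (0 where extra users or dim mismatch block it).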
    @patch
    def test_split_squeeze(self):
        def split_squeeze_stack(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items]
            return torch.stack(split_items)

        def split_squeeze_stack_callmethod(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [s.squeeze(1) for s in items]
            return torch.stack(split_items)

        def split_squeeze_stack_callmethod_none_dim(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [s.squeeze() for s in items]
            return torch.stack(split_items)

        def split_squeeze_stack_kwarg1(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, dim=1) for s in items]
            return torch.stack(split_items)

        def split_squeeze_stack_kwarg1_callmethod(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [s.squeeze(dim=1) for s in items]
            return torch.stack(split_items)

        def split_squeeze_multi_squeeze_users(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items]
            return (
                torch.stack(split_items),
                torch.relu(split_items[0]),
                torch.tanh(split_items[1]),
            )

        def split_size_not_1(x):
            items = list(torch.split(x, 2, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items]
            return torch.stack(split_items)

        def dim_mismatch(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 0) for s in items]
            return torch.stack(split_items)

        def other_users(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items]
            return torch.stack(split_items), torch.relu(items[0])

        def other_users_2(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items[1:]]
            return torch.stack(split_items), torch.relu(items[0])

        def graph_should_be_topological_sorted(x):
            output = []
            for t in x.split(1):
                output.append(torch.sin(t.squeeze(dim=0)))
            output = torch.stack(output)
            return output

        args = [
            torch.randn(2, 32),
        ]
        for fn, split_squeeze_replaced in [
            (split_squeeze_stack, 1),
            (split_squeeze_stack_callmethod, 1),
            # TODO handle none dim
            (split_squeeze_stack_callmethod_none_dim, 0),
            (split_squeeze_stack_kwarg1, 1),
            (split_squeeze_stack_kwarg1_callmethod, 1),
            (split_squeeze_multi_squeeze_users, 1),
            (split_size_not_1, 0),
            (dim_mismatch, 0),
            (other_users, 0),
            (other_users_2, 0),
            (graph_should_be_topological_sorted, 1),
        ]:
            expected = fn(*args)
            actual = torch.compile(fn)(*args)

            torch.testing.assert_close(actual, expected)
            self.assertEqual(
                counters["inductor"]["split_cat_pass"],
                split_squeeze_replaced,
            )
            counters.clear()

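    # unbind (and the equivalent split + squeeze) feeding stack/cat shares the
    # scmerge_* counters with the split/cat merge; the last column counts ops
    # canonicalized by normalization_pass along the way.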
    @patch
    def test_unbind_stack(self):
        def unbind_stack(x):
            return torch.stack(torch.unbind(x, 1), 1)

        def unbind_cat(x):
            return torch.cat(torch.unbind(x, dim=-3), 1)

        def unbind_stack_argspec1(x):
            return torch.stack(torch.unbind(input=x, dim=1), dim=1)

        def unbind_stack_argspec2(x):
            return torch.stack(tensors=torch.unbind(x, dim=1), dim=1)

        def dim_mismatch(x):
            return torch.stack(torch.unbind(x, dim=1), 0)

        def split_squeeze_stack(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items]
            return torch.stack(split_items, 1)

        def split_squeeze_stack_callmethod(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items]
            return torch.stack(split_items, 1)

        def other_users(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items]
            return torch.stack(split_items, 1), torch.relu(items[0])

        def other_users_2(x):
            items = list(torch.split(x, 1, dim=1))
            split_items = [torch.squeeze(s, 1) for s in items[1:]]
            return torch.stack(split_items, 1), torch.relu(items[0])

        def unbind_cat_addn_args(x):
            split_output = list(torch.unbind(x, dim=1))

            return torch.cat(
                [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)],
                dim=1,
            )

        def unbind_stack_addn_args(x):
            split_output = list(torch.unbind(x, dim=1))
            return torch.stack(
                [torch.ones(2, 32, 16)]
                + split_output
                + [torch.ones(2, 32, 16), torch.ones(2, 32, 16)],
                dim=1,
            )

        def unbind_cat_addn_args_dim2(x):
            split_output = list(torch.unbind(x, dim=2))
            return torch.cat(
                [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)],
                dim=2,
            )

        # split_dim=1, cat_dim=2
        def unbind_cat_dim_mismatch(x):
            split_output = list(torch.unbind(x, dim=1))
            return torch.cat(
                [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)],
                dim=2,
            )

        def unbind_stack_dim_mismatch(x):
            split_output = list(torch.unbind(x, dim=1))
            return torch.stack(
                [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)],
                dim=2,
            )

        def unbind_cat_multi_users(x):
            split_output = list(torch.unbind(x, dim=1))
            return torch.cat(
                [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)],
                dim=1,
            ), torch.stack(
                [torch.ones(2, 32, 16)]
                + split_output
                + [torch.ones(2, 32, 16), torch.ones(2, 32, 16)],
                dim=1,
            )

        def unbind_cat_multi_users_diff_dims(x):
            split_output = list(torch.unbind(x, dim=1))
            return torch.cat(
                [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)],
                dim=1,
            ), torch.stack(
                [torch.ones(2, 32, 16)] + split_output + [torch.ones(2, 32, 16)],
                dim=2,
            )

        args = [
            torch.randn(2, 32, 32, 16),
        ]
        for (
            fn,
            expected_unbind_added,
            expected_unbind_removed,
            expected_cat_added,
            expected_cat_removed,
            expected_sections_removed,
            expected_unbind_normalized,
        ) in [
            (unbind_stack, 0, 1, 0, 1, 31, 2),
            (unbind_stack_argspec1, 0, 1, 0, 1, 31, 2),
            (unbind_stack_argspec2, 0, 1, 0, 1, 31, 2),
            (dim_mismatch, 0, 1, 0, 1, 31, 2),
            (split_squeeze_stack, 0, 1, 0, 1, 31, 2),
            (split_squeeze_stack_callmethod, 0, 1, 0, 1, 31, 2),
            (other_users, 0, 0, 0, 0, 0, 2),
            (other_users_2, 0, 0, 0, 0, 0, 2),
            (unbind_cat_addn_args, 0, 1, 1, 1, 31, 1),
            (unbind_stack_addn_args, 0, 1, 1, 1, 31, 2),
            (unbind_cat_addn_args_dim2, 0, 1, 1, 1, 31, 1),
            (unbind_cat_dim_mismatch, 0, 1, 1, 1, 31, 1),
            (unbind_stack_dim_mismatch, 0, 1, 1, 1, 31, 2),
            (unbind_cat_multi_users, 0, 1, 2, 2, 31, 2),
            (unbind_cat_multi_users_diff_dims, 0, 1, 2, 2, 31, 2),
        ]:
            expected = fn(*args)
            actual = torch.compile(fn)(*args)

            torch.testing.assert_close(actual, expected)
            self.assertEqual(
                counters["inductor"]["scmerge_split_added"],
                expected_unbind_added,
                msg=f"for {fn}",
            )
            self.assertEqual(
                counters["inductor"]["scmerge_split_removed"],
                expected_unbind_removed,
                msg=f"for {fn}",
            )
            self.assertEqual(
                counters["inductor"]["scmerge_cat_added"],
                expected_cat_added,
                msg=f"for {fn}",
            )
            self.assertEqual(
                counters["inductor"]["scmerge_cat_removed"],
                expected_cat_removed,
                msg=f"for {fn}",
            )
            self.assertEqual(
                counters["inductor"]["scmerge_split_sections_removed"],
                expected_sections_removed,
                msg=f"for {fn}",
            )
            self.assertEqual(
                counters["inductor"]["normalization_pass"],
                expected_unbind_normalized,
                msg=f"for {fn}",
            )
            counters.clear()

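    # Newer rewrite patterns. The first four functions rely on the passes
    # enabled by @patch; the rest each enable exactly the pass they exercise
    # through their own config.patch. The table at the bottom maps every
    # function to the single pass counter it is expected to bump.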
    @patch
    def test_split_cat_new_patterns(self):
        def split_cat_split(x):
            l1_out = torch.split(x, [200, 50, 50, 20, 20, 20, 20, 20, 20, 50, 30], 1)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            item10 = l1_out[10]
            cat_1 = torch.cat((item0, item1), 1)
            cat_2 = torch.cat((item9, item10), 1)
            l2_out = torch.split(cat_1, [50, 120, 80], 1)
            l3_out = torch.split(cat_2, [10, 20, 50], 1)
            item11 = l2_out[0]
            item12 = l2_out[1]
            item13 = l2_out[2]
            item14 = l3_out[0]
            item15 = l3_out[1]
            item16 = l3_out[2]

            output = torch.cat(
                [
                    item11,
                    item12,
                    item13,
                    item14,
                    item15,
                    item16,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    item7,
                    item8,
                ],
                1,
            )
            return output

        def split_cat_split_kwarg(x):
            l1_out = torch.split(
                x, [200, 50, 50, 20, 20, 20, 20, 20, 20, 50, 30], dim=1
            )
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            item10 = l1_out[10]
            cat_1 = torch.cat((item0, item1), dim=1)
            cat_2 = torch.cat((item9, item10), dim=1)
            l2_out = torch.split(cat_1, [50, 120, 80], dim=1)
            l3_out = torch.split(cat_2, [10, 20, 50], dim=1)
            item11 = l2_out[0]
            item12 = l2_out[1]
            item13 = l2_out[2]
            item14 = l3_out[0]
            item15 = l3_out[1]
            item16 = l3_out[2]

            output = torch.cat(
                [
                    item11,
                    item12,
                    item13,
                    item14,
                    item15,
                    item16,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    item7,
                    item8,
                ],
                dim=1,
            )
            return output

        def remove_cat_node_with_all_getitems(x):
            l1_out = torch.split(
                x, [50, 50, 200, 20, 20, 20, 20, 20, 40, 10, 50], dim=0
            )
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            item10 = l1_out[10]
            cat = torch.cat(
                (
                    item0,
                    item1,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    item7,
                    item8,
                    item9,
                    item10,
                ),
                dim=0,
            )
            cat_1 = torch.cat((item0, item1), dim=0)
            cat_2 = torch.cat((item0, item10), dim=0)
            l2_out = torch.split(cat_1, [20, 30, 50], dim=0)
            l3_out = torch.split(cat_2, [10, 60, 30], dim=0)
            item11 = l2_out[0]
            item12 = l2_out[1]
            item13 = l2_out[2]
            item14 = l3_out[0]
            item15 = l3_out[1]
            item16 = l3_out[2]

            output = torch.cat(
                [
                    item11,
                    item12,
                    item13,
                    item14,
                    item15,
                    item16,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    item7,
                    item8,
                ],
                dim=0,
            )
            return torch.cat((output, cat), dim=0)

        def mutate_cat_node_with_some_getitems(x):
            l1_out = torch.split(
                x, [50, 50, 200, 20, 20, 20, 20, 20, 40, 10, 50], dim=0
            )
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            item10 = l1_out[10]
            cat = torch.cat(
                (
                    item6,
                    item7,
                    item8,
                    item9,
                    item10,
                    item2,
                    item3,
                    item4,
                    item5,
                ),
                dim=0,
            )
            cat_1 = torch.cat((item0, item1), dim=0)
            cat_2 = torch.cat((item0, item10), dim=0)
            l2_out = torch.split(cat_1, [20, 30, 50], dim=0)
            l3_out = torch.split(cat_2, [10, 60, 30], dim=0)
            item11 = l2_out[0]
            item12 = l2_out[1]
            item13 = l2_out[2]
            item14 = l3_out[0]
            item15 = l3_out[1]
            item16 = l3_out[2]

            output = torch.cat(
                [
                    item11,
                    item12,
                    item13,
                    item14,
                    item15,
                    item16,
                    item2,
                ],
                dim=0,
            )
            return torch.cat((output, cat), dim=0)

        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "split_cat_to_slices_pass": {},
            },
            post_grad_fusion_options={},
        )
        def split_cat_to_slices(x):
            x_c = x.clone()
            x_c_2 = x.clone()
            l1_out = torch.split(x, [50, 50, 50, 50, 50, 50, 50, 50, 50, 50], dim=0)
            l2_out = torch.split(x_c, [50, 50, 50, 50, 50, 50, 50, 50, 50, 50], dim=0)
            l3_out = torch.split(x_c_2, [100, 100, 100, 100, 100], dim=0)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            item0_c = l2_out[0]
            item1_c = l2_out[1]
            item2_c = l2_out[2]
            item3_c = l2_out[3]
            item4_c = l2_out[4]
            item5_c = l2_out[5]
            item6_c = l2_out[6]
            item7_c = l2_out[7]
            item8_c = l2_out[8]
            item9_c = l2_out[9]
            item0_c_2 = l3_out[0]
            item1_c_2 = l3_out[1]
            item2_c_2 = l3_out[2]
            item3_c_2 = l3_out[3]
            item4_c_2 = l3_out[4]
            other = item0.clone()
            return torch.cat(
                [
                    other,
                    item0,
                    item1,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    item7,
                    item8,
                    item9,
                    item4_c,
                    item5_c,
                    item6_c,
                    item7_c,
                    item8_c,
                    item9_c,
                    item0_c,
                    item1_c,
                    item2_c,
                    item3_c,
                    item0_c_2,
                    item1_c_2,
                    item2_c_2,
                    item3_c_2,
                    item4_c_2,
                ],
                dim=0,
            )

        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "unbind_cat_to_view_pass": {},
            },
            post_grad_fusion_options={},
        )
        def unbind_cat_to_view(x):
            y = x.view(10, 50, 500)
            z = x.view(10, 50, 500)
            l1_out = torch.unbind(y, dim=0)
            l2_out = torch.unbind(z, dim=0)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            item2_0 = l2_out[0]
            item2_1 = l2_out[1]
            item2_2 = l2_out[2]
            item2_3 = l2_out[3]
            item2_4 = l2_out[4]
            item2_5 = l2_out[5]
            item2_6 = l2_out[6]
            item2_7 = l2_out[7]
            item2_8 = l2_out[8]
            item2_9 = l2_out[9]
            other1 = item7.clone()
            other2 = item8.clone()
            other3 = item9.clone()
            cat = torch.cat(
                [
                    item0,
                    item1,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    other1,
                    item2_0,
                    item2_1,
                    item2_2,
                    item2_3,
                    item2_4,
                    item2_5,
                    item2_6,
                    item2_7,
                    item2_8,
                    item2_9,
                    other2,
                    other3,
                ],
                dim=1,
            )
            return cat

        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "split_stack_to_cats_pass": {},
            },
            post_grad_fusion_options={},
        )
        def split_stack_to_cats_same_dim(x):
            x_c = x.view(10, 50, 500)
            l1_out = torch.unbind(x_c, dim=0)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            split1 = torch.split(item0, [250, 250], dim=1)
            split2 = torch.split(item1, [250, 250], dim=1)
            split3 = torch.split(item2, [250, 250], dim=1)
            split4 = torch.split(item3, [250, 250], dim=1)
            split5 = torch.split(item4, [250, 250], dim=1)
            split6 = torch.split(item5, [250, 250], dim=1)
            getitem0, getitem1 = split1[0], split1[1]
            getitem2, getitem3 = split2[0], split2[1]
            getitem4, getitem5 = split3[0], split3[1]
            getitem6, getitem7 = split4[0], split4[1]
            getitem8, getitem9 = split5[0], split5[1]
            getitem10, getitem11 = split6[0], split6[1]
            getitem0_c = getitem0.clone()
            getitem1_c = getitem1.clone()
            getitem2_c = getitem2.clone()
            return torch.stack(
                (
                    getitem0,
                    getitem1,
                    getitem2,
                    getitem3,
                    getitem4,
                    getitem5,
                    getitem0_c,
                    getitem1_c,
                    getitem6,
                    getitem7,
                    getitem8,
                    getitem9,
                    getitem10,
                    getitem11,
                    getitem2_c,
                ),
                dim=1,
            )

        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "split_stack_to_cats_pass": {},
            },
            post_grad_fusion_options={},
        )
        def split_stack_to_cats_different_dim(x):
            l1_out = torch.split(x, [100, 100, 100, 100, 100], dim=1)
            x_c = x.clone()
            l2_out = torch.split(x_c, [100, 100, 100, 100, 100], dim=1)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item0_c = l2_out[0]
            item1_c = l2_out[1]
            item2_c = l2_out[2]
            item3_c = l2_out[3]
            item4_c = l2_out[4]
            other_1 = item0.clone()
            other_2 = item1.clone()
            other_3 = item2.clone()
            return torch.stack(
                (
                    other_1,
                    other_2,
                    other_3,
                    item0,
                    item1,
                    item2,
                    item3,
                    item4,
                    item0_c,
                    item1_c,
                    item2_c,
                    item3_c,
                    item4_c,
                ),
                dim=2,
            )

        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "unbind_stack_to_slices_pass": {},
            },
            post_grad_fusion_options={},
        )
        def unbind_stack_to_slices(x):
            x_1 = x.view(50, 10, 500)
            l1_out = torch.unbind(x_1, dim=1)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            item7 = l1_out[7]
            item8 = l1_out[8]
            item9 = l1_out[9]
            other_1 = item0.clone()
            other_2 = item1.clone()
            other_3 = item2.clone()
            return torch.stack(
                (
                    other_1,
                    other_2,
                    other_3,
                    item0,
                    item1,
                    item2,
                    item3,
                    item4,
                    item5,
                    item6,
                    item7,
                    item8,
                    item9,
                ),
                dim=1,
            )

        @torch._inductor.config.patch(
            pre_grad_fusion_options={
                "normalization_pass": {},
                "move_reshape_out_of_split_stack_pass": {},
            },
            post_grad_fusion_options={},
        )
        def move_reshape_out_of_split_stack(x):
            x_c = x.view(50000, 5)
            l1_out = torch.split(x_c, [1, 1, 1, 1, 1], dim=1)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            reshape0 = item0.reshape(-1, 5)
            reshape1 = item1.reshape(-1, 5)
            reshape2 = item2.reshape(-1, 5)
            reshape3 = item3.reshape(-1, 5)
            reshape4 = item4.reshape(-1, 5)
            other0 = reshape0.clone()
            other1 = reshape1.clone()
            other2 = reshape2.clone()
            other3 = reshape3.clone()
            return torch.stack(
                (
                    other0,
                    other1,
                    other2,
                    reshape0,
                    reshape1,
                    reshape2,
                    reshape3,
                    reshape4,
                    other3,
                ),
                dim=0,
            )

        args = [
            torch.randn(500, 500),
        ]
        for (
            fn,
            expected_getitem_cat_merged,
            expected_cat_removed,
            expected_split_cat_to_slices,
            expected_unbind_cat_to_view,
            expected_split_stack_to_cats,
            expected_unbind_stack_to_slices,
            expected_move_reshape_out_of_split_stack,
        ) in [
            (split_cat_split, 2, 0, 0, 0, 0, 0, 0),
            (split_cat_split_kwarg, 2, 0, 0, 0, 0, 0, 0),
            (remove_cat_node_with_all_getitems, 0, 2, 0, 0, 0, 0, 0),
            (mutate_cat_node_with_some_getitems, 0, 1, 0, 0, 0, 0, 0),
            (split_cat_to_slices, 0, 0, 1, 0, 0, 0, 0),
            (unbind_cat_to_view, 0, 0, 0, 1, 0, 0, 0),
            (split_stack_to_cats_same_dim, 0, 0, 0, 0, 1, 0, 0),
            (split_stack_to_cats_different_dim, 0, 0, 0, 0, 1, 0, 0),
            (unbind_stack_to_slices, 0, 0, 0, 0, 0, 1, 0),
            (move_reshape_out_of_split_stack, 0, 0, 0, 0, 0, 0, 1),
        ]:
            expected = fn(*args)
            actual = torch.compile(fn)(*args)

            torch.testing.assert_close(actual, expected)
            self.assertEqual(
                counters["inductor"]["merge_getitem_cat_pass"],
                expected_getitem_cat_merged,
            )
            self.assertEqual(
                counters["inductor"]["mutate_cat_pass"],
                expected_cat_removed,
            )
            self.assertEqual(
                counters["inductor"]["split_cat_to_slices_pass"],
                expected_split_cat_to_slices,
            )
            self.assertEqual(
                counters["inductor"]["unbind_cat_to_view_pass"],
                expected_unbind_cat_to_view,
            )
            self.assertEqual(
                counters["inductor"]["split_stack_to_cats_pass"],
                expected_split_stack_to_cats,
            )
            self.assertEqual(
                counters["inductor"]["unbind_stack_to_slices_pass"],
                expected_unbind_stack_to_slices,
            )
            self.assertEqual(
                counters["inductor"]["move_reshape_out_of_split_stack_pass"],
                expected_move_reshape_out_of_split_stack,
            )
            counters.clear()

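    # numpy_compat_normalization rewrites numpy-style kwargs on traced call
    # nodes (e.g. axis -> dim, keepdims -> keepdim, x/x1/x2 -> torch names),
    # so none of those spellings should survive in the graph's kwargs.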
    def test_numpy_compat_normalization(self):
        def fn(x, y):
            a = torch.stack([x, y], axis=1)
            b = torch.mul(x, x2=y)
            c = torch.mul(x, x2=y)
            d = torch.mul(x, x2=y)
            e = torch.max(x, dim=1, keepdims=True)
            f = torch.dropout(x=x, p=0.5, train=True)
            return a, b, c, d, e, f

        fn_t = torch.fx.symbolic_trace(fn)
        numpy_compat_normalization(fn_t.graph)

        for n in fn_t.graph.nodes:
            for k in n.kwargs.keys():
                self.assertTrue(k not in {"x", "x1", "x2", "a", "axis", "keepdims"})

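    # End-to-end GPU check: torch.compile must accept the numpy-style `axis`
    # kwarg to torch.stack and match eager output.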
    @patch
    @requires_gpu
    def test_stack_normalization_axis_kwarg(self):
        def fn(x, y):
            return torch.stack([x, y], axis=1)

        x, y = (torch.rand((4, 4), device=GPU_TYPE) for _ in range(2))
        expected = fn(x, y)
        actual = torch.compile(fn)(x, y)

        self.assertEqual(actual, expected)


if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests()