# Owner(s): ["module: inductor"]
import contextlib
import re
from unittest.mock import patch

import functorch
import torch
import torch._inductor.config as config
import torch.autograd
from torch._inductor import metrics
from torch._inductor.compile_fx import compile_fx, compile_fx_inner
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code

########################
# Explanation of Tests #
########################
# These tests are all testing *memory accesses* of TorchInductor.
# They are intended to be deterministic performance tests.
# The expect tests are all measuring the number of memory bytes read/written by
# the code that Inductor has generated.
#
# If the test is failing because the number became smaller, feel free to lower it.
# On the other hand, if the test is failing because the number became larger,
# that means that your change is leading to *more* memory accesses on this test.
#
# That may still be acceptable, but be aware that you are likely lowering
# performance for that setting.
#
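# For example (see test_pointwise below): for f(x) = x.cos() on a tensor of
# 10 fp32 elements, the generated kernel reads 10 elements and writes 10, so
# count_numel reports "20"; for f(x, y) = x + y on two size-10 inputs it reads
# 20 and writes 10, reporting "30".
#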
# Defines all the kernels for tests
from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda


if HAS_CUDA:
    import triton
    import triton.language as tl

    from torch.testing._internal.triton_utils import add_kernel

aten = torch.ops.aten


def compile_but_use_eager(gm, example_inputs):
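    """
    Run Inductor's compilation (so metrics such as num_bytes_accessed are
    populated) but return the original FX graph from inner_compile, so the
    compiled function still executes eagerly.
    """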
    def inner_compile(gm, *args, **kwargs):
        compile_fx_inner(gm, *args, **kwargs)
        return gm

    return compile_fx(gm, example_inputs, inner_compile=inner_compile)


def count_numel(f, *args):
    """
    Assumes all inputs are fp32.
    Returns the number of fp32 elements read/written by the generated code
    (num_bytes_accessed // 4), as a string.
    """
    metrics.reset()
    torch.compile(f, backend=compile_but_use_eager)(*args)
    print(metrics.nodes_num_elem)
    return str(metrics.num_bytes_accessed // 4)


def count_numel_train(f, *args):
    """
    Assumes all inputs are fp32.
    Runs the forward and the backward pass, then returns the number of fp32
    elements read/written (num_bytes_accessed // 4) as a string.
    """
    metrics.reset()

    f = torch.compile(f, backend=compile_but_use_eager)
    out = f(*args)
    res = 0
    for o in out:
        res += o.mean()
    res.backward()
    print(metrics.nodes_num_elem)
    return str(metrics.num_bytes_accessed // 4)


DEVICE = "cuda"


def T(*size, dtype=torch.float32, device=DEVICE, grad=False):
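    """Random tensor of the given shape (fp32 on DEVICE by default); grad=True sets requires_grad."""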
    return torch.randn(size, dtype=dtype, device=device, requires_grad=grad)


def TI(*size, mx=10, dtype=torch.int32, device=DEVICE):
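    """Random integer tensor with values in [0, mx) (int32 on DEVICE by default)."""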
    return torch.randint(0, mx, size, dtype=dtype, device=device)


class TestCase(InductorTestCase):
    device = DEVICE


class NumBytesMetricTests(TestCase):
    """
    Primarily used for sanity testing that the num_bytes_accessed metric is correct.
    """

    def test_pointwise(self):
        def f(x):
            return x.cos()

        inp = (T(10),)
        self.assertExpectedInline(count_numel(f, *inp), """20""")

        def f(x, y):
            return x + y

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """30""")

        def f(x, y):
            return x + y

        inp = (T(10, 10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """210""")

        def f(x):
            return x + x

        inp = (T(10),)
        self.assertExpectedInline(count_numel(f, *inp), """20""")

        def f(x):
            return x + x.t()

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """200""")

        def f(a, b, c):
            return a.cos(), b.sin() + c.sin()

        inp = (T(10), T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """50""")

    def test_reduction(self):
        def f(x):
            return x.sum(dim=1)

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """110""")

        def f(x):
            return x.sum(dim=0)

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """110""")

    def test_extern(self):
        def f(x):
            return torch.mm(x, x)

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """200""")

        def f(a, b):
            return torch.mm(a, b)

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """300""")

        def f(x):
            x = x.cos()
            x = torch.mm(x, x)
            x = x.cos()
            return x

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """600""")

        def f(x):
            a = x.cos()
            b = x.sin()
            x = torch.mm(a, b)
            return x

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """600""")

    def test_cat(self):
        def f(a, b):
            return torch.cat([a.sin(), b.sin()])

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """40""")

        def f(a, b):
            return torch.cat([a, b])

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """40""")

        def f(a, b):
            return torch.cat([a.cos(), b])

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """40""")

        def f(a):
            return torch.cat([a.cos(), a.sin()])

        inp = (T(10),)
        self.assertExpectedInline(count_numel(f, *inp), """30""")

        def f(a, b):
            return torch.cat([torch.mm(a, a), b.sin()])

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """400""")

        def f(a, b, c):
            return torch.cat((a + 1, b + 2, c + 3)) + 10

        inp = (T(10, 10), T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """600""")

        def f(a, b, c, d, e):
            return torch.cat((a + 1, b + 2, c + 3, d + 4, e + 5)) + 10

        inp = [T(10, 10) for _ in range(5)]
        self.assertExpectedInline(count_numel(f, *inp), """1000""")

        def f(a, b):
            return torch.cat([a.sum(dim=0), b.sum(dim=0)]) + 10

        inp = [T(10, 10, 10), T(10, 10, 10)]
        self.assertExpectedInline(count_numel(f, *inp), """2600""")

    def test_cat_pointwise(self):
        def f(a, b):
            return torch.cat([torch.softmax(a, dim=-1), torch.softmax(b, dim=-1)])

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """400""")

        def f(a, b):
            return torch.cat([torch.softmax(a, dim=-1), torch.softmax(b, dim=-1)]).cos()

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """680""")

        # Should turn into pointwise even if only some of the inputs are pointwise.
        def f(a, b):
            out = torch.cat([a.cos(), torch.mm(b, b)])
            return out.cos()

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """600""")

        # Should not turn into pointwise if none of the inputs are pointwise.
        def f(a, b):
            out = torch.cat([torch.mm(a, a), torch.mm(b, b)])
            return out.cos()

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """800""")

        def f(a, b):
            out = torch.cat([a, b])
            return out.cos()

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """400""")

        def f(a, b):
            b = b.cos()
            return torch.cat([a, b])

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """400""")

        def f(a, b):
            a = a @ a
            return torch.constant_pad_nd(torch.cat([a, b]), [2, 2], 0.5)

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """680""")

    @patch.object(config, "split_cat_fx_passes", False)
    @patch.object(
        config,
        "pre_grad_fusion_options",
        {
            "batch_linear": {},
            "batch_linear_lhs": {},
            "batch_layernorm": {},
            "batch_tanh": {},
            "batch_relu": {},
            "batch_sigmoid": {},
        },
    )
    @patch.object(config, "post_grad_fusion_options", {})
    def test_cat_pointwise_many_complex_inputs(self):
        def f(*inputs):
            input = [torch.nn.functional.gelu(val) for val in inputs]
            return torch.cat(input) + 10

        inp = (T(10, 10) for _ in range(16))
        self.assertExpectedInline(count_numel(f, *inp), """6400""")

    @patch.object(config, "split_cat_fx_passes", False)
    @patch.object(
        config,
        "pre_grad_fusion_options",
        {
            "batch_linear": {},
            "batch_linear_lhs": {},
            "batch_layernorm": {},
            "batch_tanh": {},
            "batch_relu": {},
            "batch_sigmoid": {},
        },
    )
    @patch.object(config, "post_grad_fusion_options", {})
    def test_cat_pointwise_many_simple_inputs(self):
        def f(*inputs):
            input = [torch.nn.functional.relu(val) for val in inputs]
            return torch.cat(input) + 10

        inp = (T(10, 10) for _ in range(16))
        self.assertExpectedInline(count_numel(f, *inp), """9600""")

    @patch.object(config, "max_pointwise_cat_inputs", 0)
    def test_cat_pointwise_config_option(self):
        def f(a, b):
            return torch.cat([a + 1, b + 2]) + 3

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """400""")

    def test_index(self):
        def f(a, b):
            return a[b]

        inp = (T(10), TI(10, mx=10))
        self.assertExpectedInline(count_numel(f, *inp), """30""")


class FusionTests(TestCase):
    """
    Tests that things can be fused into a single kernel
    """

    def test_horizontal_reduction_pointwise(self):
        def f(a):
            b = a.sum(dim=1)
            c = a.cos()
            return b, c

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """210""")

    def test_horizontal_reduction_reduction(self):
        def f(a):
            b = a.sum(dim=1)
            c = a.amax(dim=1)
            return b, c

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """120""")

    def test_horizontal_reduction_pointwise2(self):
        def f(a, b):
            c = a.sum(dim=1)
            b = b.cos()
            return b + c

        inp = (T(10, 10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """120""")

    def test_horizontal_reduction_outer_pointwise(self):
        def f(a, b):
            c = a.sum(dim=0)
            b = b.cos()
            return b + c

        inp = (T(10, 10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """120""")

    def test_horizontal_sum_pw_broadcast(self):
        def f(a, b):
            a = a.sum(dim=1, keepdim=True)
            b = b.cos()
            return a * b

        inp = (T(10, 10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """210""")

    def test_vertical_sum_pw(self):
        def f(a):
            a = a.cos()
            a = a.sum(dim=1)
            return a.cos()

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """110""")

    def test_norm_chain(self):
        def f(a):
            b = a.sum(dim=1, keepdim=True)
            a = a * b
            b = a.sum(dim=1, keepdim=True)
            a = a * b
            b = a.sum(dim=1, keepdim=True)
            a = a * b
            return a

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """200""")

    def test_softmax_inner(self):
        def f(a):
            return torch.softmax(a, dim=1)

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """200""")

    def test_layer_norm(self):
        # TODO: Suboptimal! We shouldn't need to save normalization stats.
        mod = torch.nn.LayerNorm(10, device=self.device)

        def f(x):
            return mod(x)

        inp = (T(10, 10),)
        with torch.no_grad():
            self.assertExpectedInline(count_numel(f, *inp), """220""")

    def test_double_softmax(self):
        def f(x):
            x = torch.softmax(x, dim=1)
            x = torch.softmax(x, dim=1)
            return x

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """200""")

    def test_softmax_backward(self):
        def f(grad_out, out):
            return aten._softmax_backward_data(grad_out, out, 1, torch.float32)

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """300""")

    def test_neighbor(self):
        def f(a, b):
            return ((a - b) ** 2).sum(dim=-1).amax(dim=1)

        inp = (T(10, 1, 4), T(1, 10, 4))
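        # If everything fuses into one kernel: 40 + 40 input elements are read
        # and the 10-element result is written, giving 90.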
        self.assertExpectedInline(count_numel(f, *inp), """90""")

    def test_factory_reduction(self):
        def f():
            a = torch.ones(10, device=self.device)
            b = torch.ones(10, 10, device=self.device)
            return (a + b).sum(dim=-1)

        inp = ()
        self.assertExpectedInline(count_numel(f, *inp), """10""")

    def test_index_pointwise(self):
        def f(a, b):
            return a[b].cos()

        inp = (T(10, 10), TI(20, mx=10))
        self.assertExpectedInline(count_numel(f, *inp), """320""")

    def test_index_reduction(self):
        def f(a, b):
            return a[b].cos().sum(dim=1)

        inp = (T(10, 10), TI(20, mx=10))
        self.assertExpectedInline(count_numel(f, *inp), """140""")

    def test_mutation_fusion(self):
        def f(a, b, c):
            a0 = a.add(c)
            b0 = b.add(a0)
            b.copy_(b0)
            a.copy_(a0)

        inp = (T(10, 10), T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """500""")

    def test_reduction_pointwise_multi_level_reduction(self):
        hidden_size = 4096
        layer_norm = torch.nn.LayerNorm(hidden_size).cuda().float()

        @torch.inference_mode()
        def f(x, scale, amax_keep_dim):
            x = layer_norm(x.to(dtype=torch.float))
            amax = torch.amax(torch.abs(x), keepdim=amax_keep_dim)
            x_scaled = x * scale
            y = torch.nn.functional.sigmoid(x_scaled)
            return (y, amax)

        inp = (T(4, 2048, hidden_size, dtype=torch.float), T(1, dtype=torch.float))

        # 2 kernels:
        # kernel 1: (input = X, scale, LN scale, LN bias, output = LN_pointwise(X), first-level amax (split-reduction))
        # kernel 2: (input = first-level amax, output = final amax)
        # scale (1) + X (4*2048*hidden_size) * 2 + LN scale (hidden_size) + LN bias (hidden_size) + amax (4 * 2048 * 2 + 1)
        expected_numel = (
            1 + hidden_size * 2 + 4 * 2048 * hidden_size * 2 + 4 * 2048 * 2 + 1
        )
        self.assertExpectedInline(count_numel(f, *inp, True), str(expected_numel))
        self.assertExpectedInline(count_numel(f, *inp, False), str(expected_numel))

    def test_pointwise_multi_level_reduction(self):
        # TODO: this can be optimized by having the first pointwise kernel leverage the block sizes
        # of the first-level reduction kernel.
        hidden_size = 4096

        def f(x, scale, amax_keep_dim):
            x = x * 1.1
            amax = torch.amax(torch.abs(x), keepdim=amax_keep_dim)
            x_scaled = x * scale
            y = torch.nn.functional.sigmoid(x_scaled)
            return (y, amax)

        inp = (T(4, 2048, hidden_size, dtype=torch.float), T(1, dtype=torch.float))

        compiled_f = torch.compile(f)
        compiled_f(*inp, True)

        # 3 kernels:
        # kernel 1: (input = X, scale, output = pointwise(X))
        # kernel 2: (input = X, output = first-level amax)
        # kernel 3: (input = first-level amax, output = final amax)
        # scale (1) + X (4*2048*hidden_size) * 3 + amax (num_splits * 2 + 1)
        # num_splits depends on the SM architecture.
        expected_numel = 1 + 4 * 2048 * hidden_size * 3 + 1
        actual_numel_amax_keep_dim = count_numel(f, *inp, True)
        actual_numel_amax_no_keep_dim = count_numel(f, *inp, False)
        self.assertEqual(actual_numel_amax_keep_dim, actual_numel_amax_no_keep_dim)
        self.assertGreaterAlmostEqual(actual_numel_amax_keep_dim, str(expected_numel))


class SchedulerFusionTests(TestCase):
    """
    Testing the fusion group creation heuristic (i.e. cases where we can't fuse
    everything into a single kernel).
    Disables inductor rematerialization to make the tests easier to reason about.
542    """
543
544    @classmethod
545    def setUpClass(cls):
546        super().setUpClass()
547        cls._stack = contextlib.ExitStack()
548        cls._stack.enter_context(patch.object(config, "realize_opcount_threshold", 0))
549
550    @classmethod
551    def tearDownClass(cls):
552        cls._stack.close()
553        super().tearDownClass()
554
555    @patch.object(config, "pattern_matcher", False)
556    def test_fusion_choice1(self):
557        # Doesn't matter where we break fusion group here
558        def f(a):
559            c = a.cos()
560            d = torch.mm(c, c)
561            e = c.cos()
562            return d + e
563
564        inp = (T(10, 10),)
565        self.assertExpectedInline(count_numel(f, *inp), """700""")
566
567    @patch.object(config, "pattern_matcher", False)
568    def test_fusion_choice2(self):
569        # We should materialize e (it's smaller!)
570        # [c, e]: 210, [f]: 210, [d]: 200
571        def f(a):
572            c = a.cos()
573            d = torch.mm(c, c)
574            e = c.sum(dim=1)
575            f = d + e
576            return f
577
578        inp = (T(10, 10),)
579        self.assertExpectedInline(count_numel(f, *inp), """620""")
580
581    @patch.object(config, "pattern_matcher", False)
582    def test_fusion_choice3(self):
583        # We should materialize e.
584        # [c, e]: 300, [f]: 300, [d]: 200
585        def f(a):
586            c = a.cos()
587            d = torch.mm(c, c)
588            e = c + a
589            f = d + e
590            return f, e
591
592        inp = (T(10, 10),)
593        self.assertExpectedInline(count_numel(f, *inp), """800""")
594
595    @patch.object(config, "pattern_matcher", False)
596    def test_fusion_choice4_cpu(self):
597        # Fuse nodes with same number of elements and compatible orginal var ranges
        # [buf0: {d0: 60, d1: 11}, buf1: {d0: 660}] -> buf0_buf1
        def f(x, w):
            o1 = x * w
            output = o1 + 1.0
            return output

        inp = (T(2, 3, 10, 11, device="cpu"), T(11, device="cpu"))
        self.assertExpectedInline(count_numel(f, *inp), """1331""")

        # [buf0_buf1: {d0: 60, d1: 11}, buf2: {d0: 660}] -> buf0_buf1_buf2
        def f(x, w1, w2):
            o1 = x * w1
            o2 = x * w2
            output = o1 + o2
            return output

        inp = (T(2, 3, 10, 11, device="cpu"), T(11, device="cpu"), T(11, device="cpu"))
        self.assertExpectedInline(count_numel(f, *inp), """1342""")


class TilingTests(TestCase):
    def test_tiling_simple(self):
        def f(a, b):
            return a + b.t()

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """300""")

        def f(a, b):
            return a.t() + b

        inp = (T(10, 10), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """300""")

    def test_tiling_three(self):
        def f(a, b, c):
            return a + b.permute(1, 2, 0) + c.permute(2, 0, 1)

        inp = (T(10, 10, 10), T(10, 10, 10), T(10, 10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """4000""")


class MinCutPartitioningTests(TestCase):
    def test_partitioning_full_remat(self):
        def f(x):
            return x.cos().cos().cos()

        inp = (T(10, grad=True),)
        self.assertExpectedInline(count_numel_train(f, *inp), """50""")

    def test_partitioning_partial_remat(self):
        def f(a, b, c, d):
            x = a + b + c + d
            return x.cos().cos()

        inp = (T(10, grad=True), T(10, grad=True), T(10, grad=True), T(10, grad=True))
        self.assertExpectedInline(count_numel_train(f, *inp), """90""")

    def test_partitioning_dtype(self):
        def f(x):
            return (x < 0) * x

        inp = (T(100, grad=True),)
        self.assertExpectedInline(count_numel_train(f, *inp), """450""")

    @patch.object(functorch.compile.config, "max_dist_from_bw", 1000)
    def test_partitioning_unremat_bw(self):
        def f(x):
            return torch.mm(x, x.new_ones(x.shape)).tanh().tanh()

        inp = (T(10, 10, grad=True),)
        self.assertExpectedInline(count_numel_train(f, *inp), """1300""")

    @patch.object(config, "pattern_matcher", False)
    def test_partitioning_unremat_bw2(self):
        def f(a):
            a = torch.mm(a, a)
            a = a + 1
            b = a + 2
            c = torch.mm(a, b)
            return c

        inp = (T(10, 10, grad=True),)
        self.assertExpectedInline(count_numel_train(f, *inp), """2600""")

    def test_partitioning_keops(self):
        def f(a, b):
            return (a * b).cos().sum(dim=1)

        inp = (T(20, 1, grad=True), T(1, 20, grad=True))
        self.assertExpectedInline(count_numel_train(f, *inp), """220""")

    def test_partitioning_cat(self):
        def f(a, b):
            a = torch.tanh(a)
            return torch.cat([a, b])

        inp = (T(10, grad=True), T(10, grad=True))
        self.assertExpectedInline(count_numel_train(f, *inp), """70""")

    def test_partitioning_with_view(self):
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x.sin()
                x = x.cos()
                x = x.view(10, 10)
                ctx.save_for_backward(x, y)
                x = x.cos()
                return x

            @staticmethod
            def backward(ctx, gradOut):
                x, y = ctx.saved_tensors
                return torch.mm(gradOut, x).view(100) * y

        def f(a):
            return Foo.apply(a)

        inp = (T(100, grad=True),)
        # We do not want to recompute the x.cos().view() chain, as it is
        # materialized in the backward pass.
        self.assertExpectedInline(count_numel_train(f, *inp), """900""")

    @patch.object(config, "pattern_matcher", False)
    def test_partitioning_long_chain_add(self):
        def f(x):
            orig = x
            for _ in range(2):
                x = x * x
                x = torch.mm(x, x)
                x = x * 2
                x = orig + x
                orig = x
            return x

        inp = (T(10, 10, grad=True),)
        self.assertExpectedInline(count_numel_train(f, *inp), """3900""")


def unfusible(x):
    # For the purpose of the noop tests, we want Inductor to fall back to
    # eager mode, so below we must use an aten operator that has neither a
    # decomposition nor a lowering:
    return aten._lazy_clone(x)


class NoopTests(TestCase):
    def test_noop_clones(self):
        def f(a):
            b = a.clone()
            b = unfusible(b)
            return b

        inp = T(10)
        self.assertExpectedInline(count_numel(f, inp), """20""")

        def f(a):
            b = a.clone()
            c = unfusible(b)
            return b, c

        self.assertExpectedInline(count_numel(f, inp), """40""")

    def test_noop_slice_scatter(self):
        def f(a):
            b = aten.slice_scatter(a, a)
            c = unfusible(b)
            return c

        inp = T(10)
        self.assertExpectedInline(count_numel(f, inp), """20""")

    def test_noop_dtype_conversion(self):
        def f(a):
            b = torch.ops.prims.convert_element_type(a, torch.float32)
            c = unfusible(b)
            return c

        inp = T(10)
        self.assertExpectedInline(count_numel(f, inp), """20""")

    def test_noop_device_conversion(self):
        def f(a):
            b = torch.ops.prims.device_put(a, "cuda")
            c = unfusible(b)
            return c

        inp = T(10)
        self.assertExpectedInline(count_numel(f, inp), """20""")

    def test_noop_int_ops(self):
        def f1(a):
            b = torch.ceil(a)
            c = unfusible(b)
            return c

        def f2(a):
            d = torch.floor(a)
            e = unfusible(d)
            return e

        def f3(a):
            f = torch.round(a)
            g = unfusible(f)
            return g

        def f4(a):
            f = torch.pow(a, 1)
            g = unfusible(f)
            return g

        inp = TI(10)
        self.assertExpectedInline(count_numel(f1, inp), """20""")
        self.assertExpectedInline(count_numel(f2, inp), """20""")
        self.assertExpectedInline(count_numel(f3, inp), """20""")
        self.assertExpectedInline(count_numel(f4, inp), """20""")

    def test_noop_cat(self):
        def f1(a):
            b = torch.cat([a])
            return unfusible(b)

        inp = T(10)
        self.assertExpectedInline(count_numel(f1, inp), """20""")

        def f2(a):
            b = torch.cat([a])
            c = torch.cat([b])
            return c

        self.assertExpectedInline(count_numel(f2, inp), """20""")


class InplacingTests(TestCase):
    def test_inplace_scatter(self):
        def f(a, b):
            a = a.cos()
            a[b] = 1
            return a

        inp = (T(10), TI(2, mx=5))
        self.assertExpectedInline(count_numel(f, *inp), """26""")

        def f(a, b):
            out = aten.index_put(a, (b,), torch.tensor(1.0))
            return a.copy_(out)

        inp = (T(10), TI(2, mx=5))
        self.assertExpectedInline(count_numel(f, *inp), """6""")

        def f(a, b):
            out = aten._unsafe_index_put(a, (b,), torch.tensor(1.0))
            return a.copy_(out)

        inp = (T(10), TI(2, mx=5))
        self.assertExpectedInline(count_numel(f, *inp), """6""")

    def test_inplace_scatter_noop_view(self):
        def f(a, b):
            a[:, b] = 1
            return a

        inp = (T(10, 10), TI(2, mx=5))
        self.assertExpectedInline(count_numel(f, *inp), """42""")

    @requires_cuda
    def test_inplace_triton_kernel_training(self):
        @triton.jit
        def sin_kernel(
            in_ptr0,
            out_ptr,
            n_elements,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            output = tl.sin(x)
            tl.store(out_ptr + offsets, output, mask=mask)

        def sin_triton(x, out):
            n_elements = x.numel()
            sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)

        factory_op = torch.empty_like

        class MySin(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                out = factory_op(x)
                sin_triton(x, out)
                ctx.save_for_backward(out)
                return out

            @staticmethod
            def backward(ctx, grad):
                (saved,) = ctx.saved_tensors
                out = factory_op(grad)
                sin_triton(saved, out)
                return out

        def f(x):
            return MySin.apply(x)

        x = T(3, grad=True)
        self.assertExpectedInline(count_numel_train(f, x), """9""")

    @requires_cuda
    def test_inplace_custom_op_training_two_mutated_inputs(self):
        @torch.library.custom_op(
            "_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"}
        )
        def sin_cos(
            x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor
        ) -> None:
            out_sin.copy_(x.sin())
            out_cos.copy_(x.cos())

        def f(x):
            out0 = torch.empty_like(x)
            out1 = torch.empty_like(x)
            sin_cos(x, out0, out1)
            return x.clone(), out0, out1

        x = T(3, grad=True)
        self.assertExpectedInline(count_numel(f, x), """21""")

    @requires_cuda
    def test_inplace_custom_op_training(self):
        @torch.library.custom_op("_reinplacing::sin", mutates_args={"result"})
        def sin(x: torch.Tensor, result: torch.Tensor) -> None:
            result.copy_(x.sin())

        factory_op = torch.empty_like

        class MySin(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                out = factory_op(x)
                sin(x, out)
                ctx.save_for_backward(out)
                return out

            @staticmethod
            def backward(ctx, grad):
                (saved,) = ctx.saved_tensors
                out = factory_op(grad)
                sin(saved, out)
                return out

        def f(x):
            return MySin.apply(x)

        x = T(3, grad=True)
        self.assertExpectedInline(count_numel_train(f, x), """9""")

    @requires_cuda
    def test_inplace_custom_op(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            m.define("foo(Tensor x, Tensor(a!) out) -> ()")

            def foo(x: torch.Tensor, out: torch.Tensor) -> None:
                out.copy_(x.sin())

            m.impl("foo", foo, "CompositeExplicitAutograd")

            def f(x, out):
                torch.ops.mylib.foo(x, out)
                torch.ops.mylib.foo(out, out)
                torch.ops.mylib.foo(out, out)
                return out

            x = T(3)
            out = T(3)

            compiled_out, (code,) = run_and_get_code(
                torch.compile(f, fullgraph=True), x, out
            )
            self.assertEqual(compiled_out, x.sin().sin().sin())

            # Check that we are allocating the minimum number of intermediate buffers
            matches = re.findall(r"empty_strided_\w+\(", code)
            self.assertEqual(len(matches), 0)

            self.assertExpectedInline(count_numel(f, x, out), """21""")

    @requires_cuda
    def test_inplace_custom_op_intermediate(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            m.define("foo(Tensor x, Tensor(a!) out) -> ()")

            def foo(x: torch.Tensor, out: torch.Tensor) -> None:
                out.copy_(x.sin())

            m.impl("foo", foo, "CompositeExplicitAutograd")

            def f(x, out):
                out = torch.empty_like(x)
                torch.ops.mylib.foo(x, out)
                torch.ops.mylib.foo(out, out)
                torch.ops.mylib.foo(out, out)
                return out

            x = T(3)
            out = T(3)

            compiled_out, (code,) = run_and_get_code(
                torch.compile(f, fullgraph=True), x, out
            )
            self.assertEqual(compiled_out, x.sin().sin().sin())

            # Check that we are allocating the minimum number of intermediate buffers
            matches = re.findall(r"empty_strided_\w+\(", code)
            self.assertEqual(len(matches), 1)

            self.assertExpectedInline(count_numel(f, x, out), """21""")

    @requires_cuda
    def test_inplace_custom_op_two_mutated_inputs(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as m:
            m.define("foo(Tensor q, Tensor(a!) k_cache, Tensor(b!) v_cache) -> Tensor")

            def foo(q, k_cache, v_cache):
                k_cache.add_(1)
                v_cache.add_(1)
                return q + 1

            m.impl("foo", foo, "CompositeExplicitAutograd")

            q = T(3)
            k_cache = T(3)
            v_cache = torch.rand_like(k_cache)

            def f():
                x = 0
                for _ in range(2):
                    x = x + torch.ops.mylib.foo(q, k_cache, v_cache)
                return x

            compiled_out, (code,) = run_and_get_code(
                torch.compile(f, fullgraph=True),
            )

            # Check that we are allocating the minimum number of intermediate buffers
            matches = re.findall(r"empty_strided_\w+\(", code)
            self.assertEqual(len(matches), 1)

            self.assertExpectedInline(count_numel(f), """39""")

    @requires_cuda
    def test_inplace_triton_kernel_v1(self):
        def f(x: torch.Tensor, y: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = (n_elements,)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            return output

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """50""")

    @requires_cuda
    def test_inplace_triton_kernel_v2(self):
        def f(x: torch.Tensor, y: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = (n_elements,)
            tmp = torch.add(x, 1)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            return output, tmp

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """70""")

    @requires_cuda
    def test_inplace_triton_kernel_v3(self):
        def f(x: torch.Tensor, y: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = (n_elements,)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            x.add_(1)
            return output

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """80""")

    @requires_cuda
    def test_inplace_triton_kernel_v4(self):
        def f(x: torch.Tensor, y: torch.Tensor):
            x_view = x.view(-1)
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = (n_elements,)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            output2 = x_view.mul(2)
            return output, output2

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """70""")

    @requires_cuda
    def test_inplace_triton_kernel_v5(self):
        def f(x: torch.Tensor, y: torch.Tensor):
            x_view = x.view(-1)
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = (n_elements,)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            x_view.mul_(2)
            return output

        inp = (T(10), T(10))
        self.assertExpectedInline(count_numel(f, *inp), """80""")

    @requires_cuda
    def test_inplace_triton_kernel_v6(self):
        def f(x: torch.Tensor, y: torch.Tensor):
            output = torch.zeros_like(x)
            n_elements = output.numel()
            grid = (n_elements,)
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            return output

        t = T(10)
        inp = (t, t.view(-1))
        self.assertExpectedInline(count_numel(f, *inp), """50""")

    def test_inplace_randperm_scatter(self):
        def scaled_index_add(x, y, scale_y):
            index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
            out = x.index_add_(dim=0, source=y * scale_y, index=index)
            return out

        inp = (T(10, 10), T(5, 10), T(10))
        self.assertExpectedInline(count_numel(scaled_index_add, *inp), """250""")


# Test cases where we don't do the right thing yet.
class WouldBeNiceIfItWorked:
    def test_horizontal(self):
        def f(a):
            b = a.sum(dim=0)
            c = a.cos()
            return b, c

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """210""")

    # TODO: We aren't fusing outer dim softmaxes
    def test_softmax_outer(self):
        def f(a):
            return torch.softmax(a, dim=0)

        inp = (T(10, 10),)
        self.assertExpectedInline(count_numel(f, *inp), """200""")

    # TODO: The greedy fusion strategy results in suboptimal grouping
    @patch.object(config, "realize_opcount_threshold", 0)
    def test_fusion_choice4(self):
        def f(a, b, b2):
            c = a + b
            d = torch.mm(c, c)
            e = c + b + b2
            f = d + e + b2
            return f, e

        inp = (T(10, 10), T(10, 10, dtype=torch.float16), T(10, 10))
        self.assertExpectedInline(count_numel(f, *inp), """1000""")

    # TODO: We materialize the intermediate if we don't unroll the reduction
    def test_neighbor(self):
        def f(a, b):
            return ((a - b) ** 2).sum(dim=-1).amax(dim=1)

        inp = (T(10, 1, 8), T(1, 10, 8))
        self.assertExpectedInline(count_numel(f, *inp), """170""")


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CUDA:
        run_tests(needs="filelock")