# Owner(s): ["module: inductor"]

import sys
import unittest

import torch
import torch._inductor
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    TestCase,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.testing._internal.triton_utils import requires_cuda


aten = torch.ops.aten

try:
    try:
        from .test_torchinductor import check_model, check_model_cuda
    except ImportError:
        from test_torchinductor import check_model, check_model_cuda
except (unittest.SkipTest, ImportError) as e:
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ == "__main__":
        sys.exit(0)
    raise


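# NOTE: combo_kernels is the Inductor option exercised throughout this file; it
# (roughly speaking) lets Inductor combine independent Triton kernels into a
# single launch. This class runs with benchmark_combo_kernel disabled, so only
# the combined kernels should show up in the generated_kernel_count metric.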
@instantiate_parametrized_tests
class ComboKernelTests(TestCase):
    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._inductor.metrics.reset()
        torch._inductor.config.combo_kernels = True
        torch._inductor.config.benchmark_combo_kernel = False

    def tearDown(self):
        super().tearDown()
        torch._inductor.metrics.reset()

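    # Three independent elementwise activations over separate tensors; with
    # combo kernels enabled they are expected to fuse into a single kernel.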
    @requires_cuda
    def test_activation_functions(self):
        def test_activations(a, b, c):
            a1 = torch.nn.functional.relu(a)
            b1 = torch.nn.functional.sigmoid(b)
            c1 = torch.nn.functional.tanh(c)
            return a1, b1, c1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
        ]

        out_eager = test_activations(*inps)
        out_compiled = torch.compile(test_activations)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

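    # Sum/max/min reductions plus one pointwise op; not everything needs to end
    # up in one combo kernel, so the test only bounds the count at two.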
    @requires_cuda
    def test_reduce_functions(self):
        def test_reduce(a, b, c, d):
            a1 = torch.sum(a, dim=0)
            b1 = torch.max(b, dim=0)
            c1 = torch.min(c, dim=0)
            d1 = torch.nn.functional.tanh(d)

            return a1, b1, c1, d1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]

        out_eager = test_reduce(*inps)
        out_compiled = torch.compile(test_reduce)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(torch._inductor.metrics.generated_kernel_count <= 2)

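    # In-place mutations (add_, sigmoid_, tanh_) mixed with one out-of-place
    # add; the mutated arguments should still fold into a single combo kernel.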
    @requires_cuda
    def test_mutated_args(self):
        def test_mutated(a, b, c, d):
            a.add_(1)
            b.sigmoid_()
            c = torch.add(c, 5)
            d.tanh_()

            return a, b, c, d

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]

        out_eager = test_mutated(*inps)
        out_compiled = torch.compile(test_mutated)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

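    # A large vector_norm is likely to be handled as a split reduction, so only
    # numerical parity is checked here rather than a kernel count.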
    @requires_cuda
    def test_reduce_split(self):
        def fn(a, b):
            a1 = torch.linalg.vector_norm(a)
            b1 = torch.sum(b, dim=0)
            return a1, b1

        inps = [
            torch.rand(2048, 512, device="cuda"),
            torch.rand(20, 20, device="cuda"),
        ]
        out_eager = fn(*inps)
        out_compiled = torch.compile(fn)(*inps)

        self.assertEqual(out_eager, out_compiled)

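    # Mixing contiguous and transposed operands exercises 2D-blocking
    # partitioning, which is expected to split the three adds across two kernels.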
    @requires_cuda
    def test_2d_blocking_partitioning(self):
        def fn(a0, a1, a2, b0, b1, b2):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            c2 = torch.add(a2, b2)
            return c0, c1, c2

        self.check_model_cuda(
            fn,
            (
                torch.rand(30, 20, device="cuda"),
                torch.rand(40, 30, device="cuda"),
                torch.rand(36, 40, device="cuda"),
                torch.rand(30, 20, device="cuda"),
                torch.rand(30, 40, device="cuda").t(),
                torch.rand(40, 36, device="cuda").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)


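# Same scenarios as ComboKernelTests, but with benchmark_combo_kernel = True.
# In benchmark mode Inductor also emits the individual kernels so it can compare
# them against the combined one, which is why the expected counts are higher.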
@instantiate_parametrized_tests
class ComboKernelBenchmarkTests(TestCase):
    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._inductor.metrics.reset()
        torch._inductor.config.combo_kernels = True
        torch._inductor.config.benchmark_combo_kernel = True

    def tearDown(self):
        super().tearDown()
        torch._inductor.metrics.reset()

    @requires_cuda
    def test_activation_benchmark(self):
        def test_activations(a, b, c):
            a1 = torch.nn.functional.relu(a)
            b1 = torch.nn.functional.sigmoid(b)
            c1 = torch.nn.functional.tanh(c)
            return a1, b1, c1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
        ]

        out_eager = test_activations(*inps)
        out_compiled = torch.compile(test_activations)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)

    @requires_cuda
    def test_reduce_benchmark(self):
        def test_reduce(a, b, c, d):
            a1 = torch.sum(a, dim=0)
            b1 = torch.max(b, dim=0)
            c1 = torch.min(c, dim=0)
            d1 = torch.nn.functional.tanh(d)

            return a1, b1, c1, d1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]

        out_eager = test_reduce(*inps)
        out_compiled = torch.compile(test_reduce)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10)

    @requires_cuda
    def test_mutated_benchmark(self):
        def test_mutated(a, b, c, d):
            a.add_(1)
            b.sigmoid_()
            c = torch.add(c, 5)
            d.tanh_()

            return a, b, c, d

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]

        out_eager = test_mutated(*inps)
        out_compiled = torch.compile(test_mutated)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(torch._inductor.metrics.generated_kernel_count in [6, 9])

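    # The intentionally mismatched shapes below are meant to exercise the
    # round-robin dispatch path for combo kernels (see the comment in the test).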
    @requires_cuda
    def test_round_robin_dispatch(self):
        # combo kernel dispatch strategy: round robin
        def test_mutated(a, b, c, d):
            a.add_(1)
            b.sigmoid_()
            c = torch.add(c, 5)
            d.tanh_()

            return a, b, c, d

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 5, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(5, 18, device="cuda"),
        ]

        out_eager = test_mutated(*inps)
        out_compiled = torch.compile(test_mutated)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6)

    @requires_cuda
    def test_2d_blocking_benchmark(self):
        def fn(a0, a1, a2, b0, b1, b2):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            c2 = torch.add(a2, b2)
            return c0, c1, c2

        self.check_model_cuda(
            fn,
            (
                torch.rand(30, 20, device="cuda"),
                torch.rand(40, 30, device="cuda"),
                torch.rand(36, 40, device="cuda"),
                torch.rand(30, 20, device="cuda"),
                torch.rand(30, 40, device="cuda").t(),
                torch.rand(40, 36, device="cuda").t(),
            ),
        )

        self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)

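    # Persistent reductions with the batch dimension marked dynamic via
    # torch._dynamo.mark_dynamic; checks parity and the resulting kernel count.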
    @requires_cuda
    def test_persistent_reduction_no_x_dim(self):
        def fn(x, y):
            return x.sum(1), y.sum(1)

        inps = (
            torch.rand(16, 256, device="cuda"),
            torch.rand(32, 256, device="cuda"),
        )
        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
        out_eager = fn(*inps)
        out_compiled = torch.compile(fn)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)


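# Repeats the scenarios above under dynamic shapes: automatic_dynamic_shapes is
# turned off and assume_static_by_default is disabled in setUp, so sizes are
# treated as symbolic unless an individual test patches these settings back.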
@instantiate_parametrized_tests
class ComboKernelDynamicShapesTests(TestCase):
    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._inductor.metrics.reset()
        torch._inductor.config.combo_kernels = True
        torch._inductor.config.benchmark_combo_kernel = True
        torch._dynamo.config.automatic_dynamic_shapes = False
        torch._dynamo.config.assume_static_by_default = False

    def tearDown(self):
        super().tearDown()
        torch._inductor.metrics.reset()

    @requires_cuda
    def test_dynamic_shapes_activations(self):
        def test_activations(a, b, c):
            a1 = torch.nn.functional.relu(a)
            b1 = torch.nn.functional.sigmoid(b)
            c1 = torch.nn.functional.tanh(c)
            return a1, b1, c1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
        ]

        out_eager = test_activations(*inps)
        out_compiled = torch.compile(test_activations)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)

    @requires_cuda
    def test_dynamic_shapes_2d_blocking(self):
        def fn(a0, a1, a2, b0, b1, b2):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            c2 = torch.add(a2, b2)
            return c0, c1, c2

        self.check_model_cuda(
            fn,
            (
                torch.rand(30, 20, device="cuda"),
                torch.rand(40, 30, device="cuda"),
                torch.rand(36, 40, device="cuda"),
                torch.rand(30, 20, device="cuda"),
                torch.rand(30, 40, device="cuda").t(),
                torch.rand(40, 36, device="cuda").t(),
            ),
        )

        self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)

    @requires_cuda
    def test_dynamic_shapes_reduce(self):
        def test_reduce(a, b, c, d):
            a1 = torch.sum(a, dim=0)
            b1 = torch.max(b, dim=0)
            c1 = torch.min(c, dim=0)
            d1 = torch.nn.functional.tanh(d)

            return a1, b1, c1, d1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]

        out_eager = test_reduce(*inps)
        out_compiled = torch.compile(test_reduce)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10)

    @requires_cuda
    def test_dynamic_shapes_mutated(self):
        # combo kernel dispatch strategy: round robin
        def test_mutated(a, b, c, d):
            a.add_(1)
            b.sigmoid_()
            c = torch.add(c, 5)
            d.tanh_()

            return a, b, c, d

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 5, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(5, 18, device="cuda"),
        ]

        out_eager = test_mutated(*inps)
        out_compiled = torch.compile(test_mutated)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6)

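    # Same activation scenario, but with combo kernel autotuning disabled via
    # the combo_kernels_autotune patch; the kernel count should not change.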
    @requires_cuda
    @torch._inductor.config.patch("combo_kernels_autotune", 0)
    def test_dynamic_shapes_activations_no_autotune(self):
        def test_activations(a, b, c):
            a1 = torch.nn.functional.relu(a)
            b1 = torch.nn.functional.sigmoid(b)
            c1 = torch.nn.functional.tanh(c)
            return a1, b1, c1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
        ]

        out_eager = test_activations(*inps)
        out_compiled = torch.compile(test_activations)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)

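    # Here automatic_dynamic_shapes/assume_static_by_default are patched back to
    # True, so only the dimensions marked with mark_dynamic are treated as dynamic.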
    @requires_cuda
    @torch._dynamo.config.patch("automatic_dynamic_shapes", True)
    @torch._dynamo.config.patch("assume_static_by_default", True)
    def test_dynamic_shapes_persistent_reduction_no_x_dim(self):
        def fn(x, y):
            return x.sum(1), y.sum(1)

        inps = (
            torch.rand(16, 256, device="cuda"),
            torch.rand(32, 256, device="cuda"),
        )
        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
        out_eager = fn(*inps)
        out_compiled = torch.compile(fn)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)

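    # Compiles once and then calls the compiled function again with different
    # shapes; with automatic_dynamic_shapes the second call is expected to take
    # a dynamic-shape recompile, and the kernel count is checked after each run.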
    @requires_cuda
    @torch._dynamo.config.patch("automatic_dynamic_shapes", True)
    @torch._dynamo.config.patch("assume_static_by_default", True)
    def test_dynamic_shapes_2d_blocking_round_robin(self):
        def fn(a0, a1, a2, b0, b1, b2):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            c2 = torch.add(a2, b2)
            return c0, c1, c2

        inps = (
            torch.rand(20, 30, device="cuda"),
            torch.rand(30, 30, device="cuda"),
            torch.rand(40, 32, device="cuda"),
            torch.rand(30, 20, device="cuda").t(),
            torch.rand(30, 30, device="cuda").t(),
            torch.rand(32, 40, device="cuda").t(),
        )

        out_eager = fn(*inps)
        compiled = torch.compile(fn)
        out_compiled = compiled(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(5 <= torch._inductor.metrics.generated_kernel_count <= 6)
        torch._inductor.metrics.reset()

        inps = (
            torch.rand(24, 30, device="cuda"),
            torch.rand(32, 30, device="cuda"),
            torch.rand(48, 32, device="cuda"),
            torch.rand(30, 24, device="cuda").t(),
            torch.rand(30, 32, device="cuda").t(),
            torch.rand(32, 48, device="cuda").t(),
        )
        out_compiled = compiled(*inps)
        out_eager = fn(*inps)
        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(5 <= torch._inductor.metrics.generated_kernel_count <= 6)

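    # Mixed reduction types (sum/mean/max) over dynamic batch dimensions, with
    # Triton autotuning moved to compile time; only numerical parity is asserted.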
    @requires_cuda
    @torch._dynamo.config.patch("automatic_dynamic_shapes", True)
    @torch._dynamo.config.patch("assume_static_by_default", True)
    @torch._inductor.config.patch("triton.autotune_at_compile_time", True)
    def test_dynamic_shapes_persistent_reduction_mixed_x_dim_cuda(self):
        def fn(x, y, z):
            return x.sum(1), y.mean(1), z.max(1)

        inps = (
            torch.rand(16, 128, device="cuda"),
            torch.rand(32, 128, device="cuda"),
            torch.rand(32, 256, device="cuda"),
        )
        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
        torch._dynamo.mark_dynamic(inps[2], 0, min=1, max=256)
        out_eager = fn(*inps)
        out_compiled = torch.compile(fn)(*inps)

        self.assertEqual(out_eager, out_compiled)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if HAS_CPU or HAS_CUDA:
        run_tests(needs="filelock")