# Owner(s): ["module: inductor"]
import functools
from unittest.mock import patch

import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.select_algorithm as select_algorithm
import torch.nn.functional as F
from torch._dynamo.testing import expectedFailureDynamicWrapper
from torch._dynamo.utils import counters
from torch._inductor.autotune_process import TritonBenchmarkRequest
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import is_big_gpu
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA


aten = torch.ops.aten


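# Test decorator: turns on max_autotune with numerical verification of every
# autotune candidate, patches AlgorithmSelectorCache.lookup so choices are
# always re-benchmarked instead of served from the cache, and disables TF32
# so the verification tolerances are meaningful. It also clears the counters
# and reseeds the RNG before each test body runs.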
def patches(fn):
    def skip_cache(self, choices, name, key, benchmark):
        if benchmark is None:
            return {}
        return benchmark(choices)

    for patcher in [
        dynamo_config.patch(verbose=True),
        inductor_config.patch(debug=True, max_autotune=True, epilogue_fusion=True),
        patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)),
        patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
        torch.backends.cudnn.flags(allow_tf32=False),
    ]:
        fn = patcher(fn)

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        counters.clear()
        torch.manual_seed(12345)
        assert (
            not torch.backends.cuda.matmul.allow_tf32
        ), "correctness testing is allergic to tf32"
        return fn(*args, **kwargs)

    return wrapped


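# Each test compiles a small graph with max_autotune enabled and then checks
# the "select_algorithm_autotune" counter, i.e. that Inductor actually
# benchmarked (and, via the patched-in VERIFY tolerances, numerically
# checked) the candidate kernels for the op.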
class TestSelectAlgorithm(TestCase):
    def setUp(self):
        super().setUp()
        if not is_big_gpu(0):
            return self.skipTest("Need a big GPU to run max_autotune=True")

    @patches
    def test_linear_relu_cuda(self):
        @torch.compile
        def foo(input, weight, bias):
            return F.relu(F.linear(input, weight, bias))

        foo(
            torch.randn(64, 32, device="cuda"),
            torch.randn(16, 32, device="cuda"),
            torch.randn(1, 16, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        # It would be nice to assert this got fused into a single kernel, but that
        # only happens if we select a triton template (and not aten).

    @patches
    def test_addmm_cuda(self):
        @torch.compile
        def foo(input, weight, bias):
            return torch.addmm(bias, input, weight)

        inps = (
            torch.randn(20, 33, device="cuda"),
            torch.randn(33, 16, device="cuda"),
            torch.randn(20, 16, device="cuda"),
        )

        foo(*inps)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

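    # fp16 accumulation is less precise, so loosen the verification tolerances
    # for this test.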
    @patch.object(select_algorithm, "VERIFY", dict(atol=5e-2, rtol=5e-2))
    @patches
    def test_addmm_fp16(self):
        @torch.compile
        def foo(input, weight, bias):
            return torch.addmm(bias, input, weight)

        inps = (
            torch.randn(2, 320, device="cuda", dtype=torch.half),
            torch.randn(320, 320, device="cuda", dtype=torch.half).t(),
            torch.empty(320, device="cuda", dtype=torch.half),
        )

        foo(*inps)
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(8, 32, device="cuda"),
            torch.randn(32, 8, device="cuda"),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    # FIXME: Investigate why _int_mm_out_cuda is not compiled on ROCm
    @skipIfRocm
    @patches
    def test__int_mm(self):
        @torch.compile
        def foo(a, b):
            return torch._int_mm(a, b)

        foo(
            torch.randint(-10, 10, (64, 32), device="cuda", dtype=torch.int8),
            torch.randint(-10, 10, (32, 64), device="cuda", dtype=torch.int8),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_skip(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(8, 32, device="cuda", dtype=torch.float64),
            torch.randn(32, 8, device="cuda", dtype=torch.float64),
        )
        # float64 not supported by tl.dot()
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)

    @patches
    def test_bmm(self):
        @torch.compile
        def foo(a, b):
            return torch.bmm(a, b)

        foo(
            torch.randn(2, 8, 32, device="cuda"),
            torch.randn(2, 32, 8, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_not_even_k(self):
        @torch.compile
        def foo(a, b):
            return torch.mm(a, b)

        foo(
            torch.randn(11, 22, device="cuda"),
            torch.randn(22, 33, device="cuda"),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_baddbmm(self):
        @torch.compile
        def foo(a, b, c):
            return torch.baddbmm(c, a, b)

        foo(
            torch.randn(2, 8, 32, device="cuda"),
            torch.randn(2, 32, 8, device="cuda"),
            torch.randn(2, 1, 8, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_plus_mm(self):
        @torch.compile
        def foo(a, b, c, d):
            return (a @ b) + (c @ d)

        foo(
            torch.randn(32, 32, device="cuda"),
            torch.randn(32, 32, device="cuda"),
            torch.randn(32, 32, device="cuda"),
            torch.randn(32, 32, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_plus_mm2_cuda(self):
        @torch.compile
        def foo(a, b, c, d):
            return (a @ b) + (c @ d)

        foo(
            torch.randn(512, 512, device="cuda"),
            torch.randn(512, 512, device="cuda"),
            torch.randn(512, 512, device="cuda"),
            torch.randn(512, 512, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @expectedFailureDynamicWrapper
    @patches
    def test_mm_plus_mm3_cuda(self):
        @torch.compile
        def foo(a, b, c, d):
            return (a @ b) + (c @ d)

        foo(
            torch.randn(512, 32, device="cuda"),
            torch.randn(32, 8, device="cuda"),
            torch.randn(512, 32, device="cuda"),
            torch.randn(32, 8, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_dup_args(self):
        @torch.compile
        def foo(a):
            return torch.mm(a, a)

        foo(torch.randn(32, 32, device="cuda"))
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    def test_mm_dup_args_view(self):
        @torch.compile
        def foo(a):
            q = a[:32, :]
            k = a[32:, :]
            return torch.mm(q, k.transpose(0, 1))

        foo(torch.randn(64, 64, device="cuda"))
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @expectedFailureDynamicWrapper
    @patches
    def test_convolution1(self):
        @torch.compile
        def foo(x, w, b):
            return aten.convolution(
                x + 1,
                w,
                b,
                stride=(2, 3),
                padding=(4, 5),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
            )

        foo(
            torch.randn(2, 33, 34, 41, device="cuda"),
            torch.randn(34, 33, 3, 3, device="cuda"),
            torch.randn(34, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @skipIfRocm
    @patches
    def test_mm_dropout(self):
        @torch.compile
        def fn(x1, x2, seed):
            mm_4 = torch.ops.aten.mm.default(x2, x1)
            rnd = torch.ops.prims.inductor_random.default(mm_4.shape, seed, "rand")
            return mm_4 * rnd

        # sizes picked so triton autotuning wins
        fn(
            torch.randn(512, 1024, dtype=torch.float16, device="cuda"),
            torch.randn(384, 512, dtype=torch.float16, device="cuda"),
            torch.tensor(12345, device="cuda"),
        )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @skipIfRocm
    @patches
    @torch._inductor.config.patch(conv_1x1_as_mm=False)
    def test_convolution2(self):
        @torch.compile
        def foo(x, w, b):
            return aten.convolution(
                x,
                w,
                b,
                stride=(1, 1),
                padding=(0, 0),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
            )

        foo(
            torch.randn(1, 33, 16, 16, device="cuda"),
            torch.randn(34, 33, 1, 1, device="cuda"),
            torch.randn(34, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    @torch._inductor.config.patch(conv_1x1_as_mm=True)
    def test_convolution_as_mm(self):
        @torch.compile
        def foo(x, w, b):
            return aten.convolution(
                x + 1,
                w,
                b,
                stride=(1, 1),
                padding=(0, 0),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
            )

        foo(
            torch.randn(2, 33, 16, 16, device="cuda"),
            torch.randn(34, 33, 1, 1, device="cuda"),
            torch.randn(34, device="cuda"),
        )
        # Autotuning checks correctness of each version
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    def test_TritonTemplateCaller_str(self):
        """
        Make sure str(TritonTemplateCaller) does not raise exceptions.
        """
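        # Only module_path is needed to exercise __str__ here; the other
        # benchmark-request fields are left as None placeholders.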
        module_path = "abc.py"
        bmreq = TritonBenchmarkRequest(
            module_path=module_path,
            module_cache_key=None,
            kernel_name=None,
            grid=None,
            extra_args=None,
            num_stages=None,
            num_warps=None,
            input_tensor_meta=None,
            output_tensor_meta=None,
        )
        caller = select_algorithm.TritonTemplateCaller(
            None, None, None, None, "extra", bmreq
        )
        caller_str = str(caller)
        self.assertEqual(caller_str, f"TritonTemplateCaller({module_path}, extra)")


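# These tests rely on max_autotune, so they only run on Linux with CUDA and a
# sufficiently large GPU.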
if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA and is_big_gpu(0):
        run_tests()