# Owner(s): ["module: inductor"]
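# Tests for Inductor's pad_mm FX pass, which pads mm / bmm / addmm operands up
# to an alignment-friendly size. The cases below cover dynamic shapes via
# mark_dynamic, force_shape_pad, zero-sized and broadcasting inputs, and the
# should-pad decision cache (exclude_pad bookkeeping).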
import unittest

import torch
import torch._inductor.config as inductor_config
from torch._dynamo.testing import rand_strided
from torch._inductor.fx_passes.pad_mm import (
    get_alignment_size,
    get_pad_cache,
    get_padded_length,
    should_pad_common,
)
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.inductor_utils import HAS_CUDA


class PadMMTest(TestCase):
    def setUp(self):
        super().setUp()
        if not is_big_gpu(0):
            return self.skipTest("Need a big GPU to run max_autotune=True")

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_mm_dyn_m(self):
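        # M is dynamic; K2 = 49 is static and unaligned, so the generated code
        # is expected to use the padded K.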
        M = 40
        K1 = 581
        K2 = 49
        N = 30

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = rand_strided(
                    (K2, N), (1, K2), device="cuda", dtype=torch.float32
                )

            def forward(self, a):
                a1 = torch.narrow(a, 1, 0, K2)
                return torch.mm(a1, self.w)

        fn = Model().cuda()
        a = rand_strided((M, K1), (K1, 1), device="cuda", dtype=torch.float32)
        aligned_k = get_padded_length(K2, get_alignment_size(a)) + K2
        torch._dynamo.mark_dynamic(a, 0)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a)
            FileCheck().check(f"K = {aligned_k}").run(code)
        self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_cat_pad_mm_dyn_m(self):
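        # Same K-padding check as above, but the mm input is now a narrow of a
        # cat of two tensors with dynamic leading dimensions.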
        M1 = 128
        M2 = 40
        K1 = 129
        K2 = 111
        N = 100

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = rand_strided(
                    (K2, N), (1, K2), device="cuda", dtype=torch.float32
                )

            def forward(self, a, b):
                c = torch.cat([a, b], dim=0)
                a1 = torch.narrow(c, 1, 0, K2)
                return torch.mm(a1, self.w)

        fn = Model().cuda()
        a = rand_strided((M1, K1), (K1, 1), device="cuda", dtype=torch.float32)
        b = rand_strided((M2, K1), (K1, 1), device="cuda", dtype=torch.float32)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        aligned_k = get_padded_length(K2, get_alignment_size(a)) + K2
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"K = {aligned_k}").run(code)
        self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_mm_dyn_n(self):
        M = 20
        K = 81
        N = 30

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.mm(a, b)

        fn = Model().cuda()
        a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32)
        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32)
        aligned_k = get_padded_length(K, get_alignment_size(a)) + K
        torch._dynamo.mark_dynamic(b, 1)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"K = {aligned_k}").run(code)
        self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_mm_dyn_k(self):
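        # K is dynamic and therefore cannot be padded; the test instead checks
        # that the static M dimension gets padded (see the TODO about alignment
        # below).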
        M = 21
        K = 80
        N = 30

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.mm(a, b)

        fn = Model().cuda()
        a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32)
        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32)
        # TODO: Getting the alignment right requires pattern matcher to
        # run on newly added nodes
        aligned_m = get_padded_length(M, get_alignment_size(a)) + M
        torch._dynamo.mark_dynamic(a, 1)
        torch._dynamo.mark_dynamic(b, 0)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"M = {aligned_m}").run(code)
        self.assertEqual(res1, res2)

    def test_pad_mm_dyn_mnk(self):
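        # Every dim of both operands is dynamic, so no padded size can be
        # asserted in the generated code; only numerical parity is checked.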
        M = 20
        K = 81
        N = 30

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.mm(a, b)

        fn = Model().cuda()
        a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32)
        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(a, 1)
        torch._dynamo.mark_dynamic(b, 0)
        torch._dynamo.mark_dynamic(b, 1)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
        self.assertEqual(res1, res2)

    @inductor_config.patch(force_shape_pad=True)
    def test_zero_dim(self):
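        # force_shape_pad should handle a zero-sized mat1 gracefully and still
        # match eager addmm.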
        def addmm(x, a, b):
            return torch.addmm(x, a, b)

        x = torch.randn(100).cuda()
        a = torch.randn(0, 10).cuda()
        b = torch.randn(10, 100).cuda()
        self.assertEqual(torch.compile(addmm)(x, a, b), addmm(x, a, b))

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_bmm_dyn_b(self):
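        # Only the batch dimension is dynamic; K = 33 is static and unaligned,
        # so the generated bmm code is expected to use the padded K.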
        B = 10
        M = 128
        K = 33
        N = 40

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.bmm(a, b)

        fn = Model().cuda()
        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
        aligned_k = get_padded_length(K, get_alignment_size(a)) + K
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"K = {aligned_k}").run(code)
        self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_bmm_dyn_k(self):
        B = 10
        M = 128
        K = 40
        N = 41

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.bmm(a, b)

        fn = Model().cuda()
        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
        aligned_n = get_padded_length(N, get_alignment_size(b)) + N
        torch._dynamo.mark_dynamic(a, 2)
        torch._dynamo.mark_dynamic(b, 1)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"N = {aligned_n}").run(code)
        self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_bmm_dyn_bm(self):
        B = 10
        M = 128
        K = 40
        N = 41

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.bmm(a, b)

        fn = Model().cuda()
        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
        aligned_n = get_padded_length(N, get_alignment_size(b)) + N
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(a, 1)
        torch._dynamo.mark_dynamic(b, 0)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"N = {aligned_n}").run(code)
        self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_addmm_dyn_m(self):
        M = 128
        K = 33
        N = 40

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b, c):
                return torch.addmm(a, b, c)

        fn = Model().cuda()
        a = torch.randn(M, N, device="cuda", dtype=torch.float32)
        b = torch.randn(M, K, device="cuda", dtype=torch.float32)
        c = torch.randn(K, N, device="cuda", dtype=torch.float32)
        aligned_k = get_padded_length(K, get_alignment_size(b)) + K
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b, c)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b, c)
            FileCheck().check(f"K = {aligned_k}").run(code)
        self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_addmm_dyn_mn(self):
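        # With M and N both dynamic the test expects no padding at all: the
        # FileCheck below asserts the original K appears in the generated code.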
        M = 128
        K = 33
        N = 40

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b, c):
                return torch.addmm(a, b, c)

        fn = Model().cuda()
        a = torch.randn(M, N, device="cuda", dtype=torch.float32)
        b = torch.randn(M, K, device="cuda", dtype=torch.float32)
        c = torch.randn(K, N, device="cuda", dtype=torch.float32)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(a, 1)
        torch._dynamo.mark_dynamic(b, 0)
        torch._dynamo.mark_dynamic(c, 1)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b, c)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b, c)
            # no padding
            FileCheck().check(f"K = {K}").run(code)
        self.assertEqual(res1, res2)

    @inductor_config.patch(force_shape_pad=True)
    def test_pad_single_cat(self):
        @torch.compile()
        def foo(x, y):
            return x @ y

        inps = [torch.rand([5, 5], device="cuda") for _ in range(2)]
        out = foo(*inps)
        self.assertEqual(out, inps[0] @ inps[1])

    @inductor_config.patch(force_shape_pad=True)
    @fresh_inductor_cache()
    def test_pad_addmm_2d_bias(self):
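        # addmm biases can broadcast: exercise 2-D biases of shape (1|4, 1|6)
        # and 1-D biases of length 1 and 6 against a (4, 5) @ (5, 6) matmul.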
        @torch.compile()
        def foo(input, x, y):
            return torch.ops.aten.addmm(input, x, y)

        for a in [1, 4]:
            for b in [1, 6]:
                inps = (
                    torch.rand([a, b], device="cuda"),
                    torch.rand([4, 5], device="cuda"),
                    torch.rand([5, 6], device="cuda"),
                )
                out = foo(*inps)
                out_eager = torch.ops.aten.addmm(*inps)
                self.assertEqual(out, out_eager)

        for a in [1, 6]:
            inps = (
                torch.rand([a], device="cuda"),
                torch.rand([4, 5], device="cuda"),
                torch.rand([5, 6], device="cuda"),
            )
            out = foo(*inps)
            out_eager = torch.ops.aten.addmm(*inps)
            self.assertEqual(out, out_eager)

    @inductor_config.patch(force_shape_pad=True)
    def test_pad_batch(self):
        m = 6
        n = 9
        k = 11
        batch_size = 3
        mat1 = torch.ones((batch_size, m, k), device="cuda", dtype=torch.float16)
        mat2 = torch.ones((batch_size, k, n), device="cuda", dtype=torch.float16)
        expected_alignment = get_alignment_size(mat1)

        assert expected_alignment == 8, "Alignment for float16 should be 8"
        assert should_pad_common(
            mat1, mat2
        ), "This should pass the common padding criteria"

        @torch.compile()
        def bmm(mat1, mat2):
            return torch.bmm(mat1, mat2)

        res2, (code,) = run_and_get_code(bmm, mat1, mat2)
        bmm_expected_result = torch.bmm(mat1, mat2)
        # in the generated call code, expect a single pad per input, and then a
        # padded allocation for the output
        FileCheck().check("del async_compile").check_count(
            ".run(", 2, exactly=True
        ).check("empty_strided_cuda((3, 8, 16)").run(code)

        assert torch.allclose(
            res2, bmm_expected_result
        ), "BMM results are not identical"

    @fresh_inductor_cache()
    def test_exclude_padding(self):
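        # Compiling a plain mm should leave two benchmark entries in the pad
        # cache. The second mm's input is produced by a pointwise add, so its
        # padding can presumably be fused; the checks below expect one
        # additional cache entry and exactly one exclude_pad:True occurrence.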
        @torch.compile()
        def mm(a, b):
            return a @ b

        mm(torch.rand([25, 25], device="cuda"), torch.rand([25, 25], device="cuda"))
        local_cache = get_pad_cache().get_local_cache()
        self.assertTrue(len(local_cache) == 2)
        FileCheck().check_count("exclude_pad:False", 2, exactly=True).run(
            repr(local_cache)
        )

        @torch.compile()
        def mm(a, b):
            return (a + 1) @ b

        mm(torch.rand([25, 25], device="cuda"), torch.rand([25, 25], device="cuda"))
        local_cache = get_pad_cache().get_local_cache()
        # reuse original base timing
        self.assertTrue(len(local_cache) == 3)

        FileCheck().check_count("exclude_pad:False", 3, exactly=True).run(
            repr(local_cache)
        )
        FileCheck().check_count("exclude_pad:True", 1, exactly=True).run(
            repr(local_cache)
        )

    @fresh_inductor_cache()
    @inductor_config.patch(max_pointwise_cat_inputs=2)
    def test_exclude_cat_padding(self):
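        # max_pointwise_cat_inputs=2 means the 3-chunk cat feeding the first mm
        # cannot be lowered as a pointwise cat, while the 2-chunk cat can; the
        # FileChecks count the resulting exclude_pad entries in the pad cache.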
        @torch.compile()
        def mm(inps, b):
            return torch.cat(inps) @ b

        inp = torch.rand([2046, 2046], device="cuda")
        inp2 = torch.rand([2046, 2046], device="cuda")

        inps = inp.chunk(3)
        mm(inps, inp2)
        FileCheck().check_count("exclude_pad:False", 2, exactly=True).run(
            repr(get_pad_cache().get_local_cache())
        )

        inps = inp.chunk(2)
        mm(inps, inp2)
        FileCheck().check_count("exclude_pad:False", 3, exactly=True).run(
            repr(get_pad_cache().get_local_cache())
        )


if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()