# Owner(s): ["module: inductor"]
import os
import unittest

import torch
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.inductor_utils import HAS_CUDA


class B2BGEMMTest(TestCase):
    @torch._dynamo.config.patch(cache_size_limit=32)
    @torch._inductor.config.patch(b2b_gemm_pass=True)
    def test_b2b_gemm_left_assoc_good_shape(self):
        """
        left_assoc means the pattern is (subgraph(A @ B) @ C)
        good_shape means the sizes are good for b2b_gemm
        """

        def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            g = torch.nn.GELU()
            return torch.mm(g(torch.mm(m1, m2)), m3)

        def f_32(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            """
            When the b2b_gemm optimization is applied, the fused Triton kernel is
            more precise than f above, because it accumulates in float32 while f
            accumulates in float16.
            To ensure a fair comparison, we promote the baseline f to float32
            before comparing results.
            This reduced some of the atol values in these tests from 0.2 to 0.1.
            """
            m1 = m1.to(torch.float32)
            m2 = m2.to(torch.float32)
            m3 = m3.to(torch.float32)
            return f(m1, m2, m3).to(torch.float16)

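        # A @ B produces a (256, 256) intermediate; the point of the fused kernel
        # is that this intermediate should not need to be written back to memory.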
        f_opt = torch.compile(f)
        A = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        B = torch.randn((32, 256), device="cuda", dtype=torch.float16)
        C = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        res, (code,) = run_and_get_code(f_opt, A, B, C)
        self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01))
        self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" in code)

    @torch._dynamo.config.patch(cache_size_limit=32)
    @torch._inductor.config.patch(b2b_gemm_pass=True)
    def test_b2b_gemm_right_assoc_good_shape(self):
        """
        right_assoc means the pattern is (A @ subgraph(B @ C))
        good_shape means the sizes are good for b2b_gemm
        """

        def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            g = torch.nn.ReLU()
            return torch.mm(m1, g(torch.mm(m2, m3)))

        def f_32(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            m1 = m1.to(torch.float32)
            m2 = m2.to(torch.float32)
            m3 = m3.to(torch.float32)
            return f(m1, m2, m3).to(torch.float16)

        f_opt = torch.compile(f)
        A = torch.randn((32, 256), device="cuda", dtype=torch.float16)
        B = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        C = torch.randn((32, 256), device="cuda", dtype=torch.float16)
        res, (code,) = run_and_get_code(f_opt, A, B, C)
        self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01))
        self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" in code)

    @torch._dynamo.config.patch(cache_size_limit=32)
    @torch._inductor.config.patch(b2b_gemm_pass=True)
    def test_b2b_gemm_trivial_left_assoc_good_shape(self):
        """
        trivial_left_assoc means the pattern is ((A @ B) @ C)
        good_shape means the sizes are good for b2b_gemm
        """

        def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            return torch.mm(torch.mm(m1, m2), m3)

        def f_32(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            m1 = m1.to(torch.float32)
            m2 = m2.to(torch.float32)
            m3 = m3.to(torch.float32)
            return f(m1, m2, m3).to(torch.float16)

        f_opt = torch.compile(f)
        A = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        B = torch.randn((32, 256), device="cuda", dtype=torch.float16)
        C = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        res, (code,) = run_and_get_code(f_opt, A, B, C)
        self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01))
        self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" in code)

    @torch._dynamo.config.patch(cache_size_limit=32)
    @torch._inductor.config.patch(b2b_gemm_pass=True)
    def test_b2b_gemm_trivial_right_assoc_good_shape(self):
        """
        trivial_right_assoc means the pattern is (A @ (B @ C))
        good_shape means the sizes are good for b2b_gemm
        """

        def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            return torch.mm(m1, torch.mm(m2, m3))

        def f_32(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            m1 = m1.to(torch.float32)
            m2 = m2.to(torch.float32)
            m3 = m3.to(torch.float32)
            return f(m1, m2, m3).to(torch.float16)

        f_opt = torch.compile(f)
        A = torch.randn((32, 256), device="cuda", dtype=torch.float16)
        B = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        C = torch.randn((32, 256), device="cuda", dtype=torch.float16)
        res, (code,) = run_and_get_code(f_opt, A, B, C)
        self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01))
        self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" in code)

    @torch._dynamo.config.patch(cache_size_limit=32)
    @torch._inductor.config.patch(b2b_gemm_pass=True)
    def test_b2b_gemm_bad_pattern_good_shape(self):
        """
        bad_pattern means the code does not match any of the supported patterns
        good_shape means the sizes are good for b2b_gemm
        """

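        # mm1 below is consumed by two downstream matmuls (mm2 and the final mm),
        # so the graph is not a simple (subgraph(A @ B) @ C) or (A @ subgraph(B @ C))
        # chain and the b2b_gemm pass is expected to leave it alone.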
        def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            mm1 = torch.mm(m1, m2)
            mm2 = torch.mm(mm1, m3)
            return torch.mm(mm1, mm2)

        f_opt = torch.compile(f)
        A = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        B = torch.randn((32, 256), device="cuda", dtype=torch.float16)
        C = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        res, (code,) = run_and_get_code(f_opt, A, B, C)
        self.assertTrue(torch.allclose(f(A, B, C), res, atol=0.1, rtol=0.01))
        self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" not in code)
        self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" not in code)

    @torch._dynamo.config.patch(cache_size_limit=32)
    @torch._inductor.config.patch(b2b_gemm_pass=True)
    def test_b2b_gemm_good_pattern_bad_shape(self):
        """
        good_pattern means the code matches a supported pattern (here ((A @ B) @ C))
        bad_shape means the sizes are not good for b2b_gemm
        """

        def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            return torch.mm(torch.mm(m1, m2), m3)

        f_opt = torch.compile(f)
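        # 100 x 100 inputs match the pattern but not the pass's size heuristics,
        # so no B2B-GEMM kernel should be emitted.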
        A = torch.randn((100, 100), device="cuda", dtype=torch.float16)
        B = torch.randn((100, 100), device="cuda", dtype=torch.float16)
        C = torch.randn((100, 100), device="cuda", dtype=torch.float16)
        res, (code,) = run_and_get_code(f_opt, A, B, C)
        self.assertTrue(torch.allclose(f(A, B, C), res, atol=0.1, rtol=0.01))
        self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" not in code)
        self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" not in code)

    @unittest.skipIf(
        not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
    )
    @torch._dynamo.config.patch(cache_size_limit=32)
    def test_plain_b2b_gemm_performance(self):
        """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""

        def run_with_b2b_gemm_off(
            m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor
        ) -> float:
            def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
                return torch.mm(torch.mm(m1, m2), m3)

            f_opt = torch.compile(f, dynamic=False)
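            # benchmark returns a runtime measurement (lower is better); the
            # off/on ratio computed below is therefore a speedup.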
            return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

        @torch._inductor.config.patch(b2b_gemm_pass=True)
        def run_with_b2b_gemm_on(
            m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor
        ) -> float:
            def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
                return torch.mm(torch.mm(m1, m2), m3)

            f_opt = torch.compile(f, dynamic=False)
            return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

        Ms = [128, 256, 300, 400, 512]
        Ns = [16, 20, 32, 40, 50, 64]
        speedups = []
        print("Perf Test for Plain B2B-GEMM:")
        print("Speedups".ljust(10), end="")
        for N in Ns:
            print(f"N = {N}".ljust(10), end="")
        print()
        for M in Ms:
            print(f"M = {M}".ljust(10), end="")
            for N in Ns:
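                # chain shapes: (M, N) @ (N, O) @ (O, P) with O = M and P = N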
                O, P = M, N
                A = torch.randn((M, N), device="cuda", dtype=torch.float16)
                B = torch.randn((N, O), device="cuda", dtype=torch.float16)
                C = torch.randn((O, P), device="cuda", dtype=torch.float16)
                speedup = run_with_b2b_gemm_off(A, B, C) / run_with_b2b_gemm_on(A, B, C)
                print(f"{round(speedup, 3)}".ljust(10), end="")
                speedups.append(speedup)
            print()

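        # report the geometric mean of the per-shape speedups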
        average_speedup = 1.0
        for s in speedups:
            average_speedup *= s
        average_speedup = average_speedup ** (1 / len(speedups))
        print(f"Average speedup: {round(average_speedup, 3)}")

        # flaky test assertion: disabled
        # self.assertTrue(average_speedup > 1)

    @unittest.skipIf(
        not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
    )
    @torch._dynamo.config.patch(cache_size_limit=32)
    def test_gelu_b2b_gemm_performance(self):
        """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""

        def run_with_b2b_gemm_off(
            m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor
        ) -> float:
            def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
                g = torch.nn.GELU()
                return torch.mm(g(torch.mm(m1, m2)), m3)

            f_opt = torch.compile(f, dynamic=False)
            return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

        @torch._inductor.config.patch(b2b_gemm_pass=True)
        def run_with_b2b_gemm_on(
            m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor
        ) -> float:
            def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
                g = torch.nn.GELU()
                return torch.mm(g(torch.mm(m1, m2)), m3)

            f_opt = torch.compile(f, dynamic=False)
            return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

        Ms = [128, 256, 300, 400, 512]
        Ns = [16, 20, 32, 40, 50, 64]
        speedups = []
        print("Perf Test for GELU B2B-GEMM:")
        print("Speedups".ljust(10), end="")
        for N in Ns:
            print(f"N = {N}".ljust(10), end="")
        print()
        for M in Ms:
            print(f"M = {M}".ljust(10), end="")
            for N in Ns:
                O, P = M, N
                A = torch.randn((M, N), device="cuda", dtype=torch.float16)
                B = torch.randn((N, O), device="cuda", dtype=torch.float16)
                C = torch.randn((O, P), device="cuda", dtype=torch.float16)
                speedup = run_with_b2b_gemm_off(A, B, C) / run_with_b2b_gemm_on(A, B, C)
                print(f"{round(speedup, 3)}".ljust(10), end="")
                speedups.append(speedup)
            print()

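        # report the geometric mean of the per-shape speedups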
        average_speedup = 1.0
        for s in speedups:
            average_speedup *= s
        average_speedup = average_speedup ** (1 / len(speedups))
        print(f"Average speedup: {round(average_speedup, 3)}")

        # flaky test assertion: disabled
        # self.assertTrue(average_speedup > 1)

    @unittest.skipIf(
        not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
    )
    @torch._dynamo.config.patch(cache_size_limit=32)
    def test_gelu_mlp_b2b_gemm_performance(self):
        """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""

        def run_with_b2b_gemm_off(
            m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor
        ) -> float:
            def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
                g = torch.nn.GELU()
                return torch.mm(g(torch.mm(m1, m2)), m3)

            f_opt = torch.compile(f, dynamic=False)
            return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

        @torch._inductor.config.patch(b2b_gemm_pass=True)
        def run_with_b2b_gemm_on(
            m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor
        ) -> float:
            def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
                g = torch.nn.GELU()
                return torch.mm(g(torch.mm(m1, m2)), m3)

            f_opt = torch.compile(f, dynamic=False)
            return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

        Ms = [128, 256, 300, 400, 512]
        Ns = [16, 20, 32, 40, 50, 64]
        speedups = []
        print("Perf Test for GELU B2B-GEMM (MLP):")
        print("Speedups".ljust(10), end="")
        for N in Ns:
            print(f"N = {N}".ljust(10), end="")
        print()
        for M in Ms:
            print(f"M = {M}".ljust(10), end="")
            for N in Ns:
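                # MLP-style chain: (M, N) @ (N, N) @ (N, N), i.e. square weight matrices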
                O, P = N, N
                A = torch.randn((M, N), device="cuda", dtype=torch.float16)
                B = torch.randn((N, O), device="cuda", dtype=torch.float16)
                C = torch.randn((O, P), device="cuda", dtype=torch.float16)
                speedup = run_with_b2b_gemm_off(A, B, C) / run_with_b2b_gemm_on(A, B, C)
                print(f"{round(speedup, 3)}".ljust(10), end="")
                speedups.append(speedup)
            print()

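        # report the geometric mean of the per-shape speedups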
        average_speedup = 1.0
        for s in speedups:
            average_speedup *= s
        average_speedup = average_speedup ** (1 / len(speedups))
        print(f"Average speedup: {round(average_speedup, 3)}")

        # flaky test assertion: disabled
        # self.assertTrue(average_speedup > 1)


if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()