# Owner(s): ["module: inductor"]
import os
import unittest

import torch
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.inductor_utils import HAS_CUDA


class B2BGEMMTest(TestCase):
    @torch._dynamo.config.patch(cache_size_limit=32)
    @torch._inductor.config.patch(b2b_gemm_pass=True)
    def test_b2b_gemm_left_assoc_good_shape(self):
        """
        left_assoc means the pattern is (subgraph(A @ B) @ C)
        good_shape means the sizes are good for b2b_gemm
        """

        def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            g = torch.nn.GELU()
            return torch.mm(g(torch.mm(m1, m2)), m3)

        def f_32(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            """
            When the optimization is applied,
            the Triton kernel is more precise than the above f,
            because it internally uses float32 for accumulation while the above f uses float16.
            To ensure a fair comparison,
            we promote the baseline f to float32 for the precision comparison.
            This actually reduced some atol's in the tests from 0.2 to 0.1.
            """
            m1 = m1.to(torch.float32)
            m2 = m2.to(torch.float32)
            m3 = m3.to(torch.float32)
            return f(m1, m2, m3).to(torch.float16)

        f_opt = torch.compile(f)
        A = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        B = torch.randn((32, 256), device="cuda", dtype=torch.float16)
        C = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        res, (code,) = run_and_get_code(f_opt, A, B, C)
        self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01))
        self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" in code)

    @torch._dynamo.config.patch(cache_size_limit=32)
    @torch._inductor.config.patch(b2b_gemm_pass=True)
    def test_b2b_gemm_right_assoc_good_shape(self):
        """
        right_assoc means the pattern is (A @ subgraph(B @ C))
        good_shape means the sizes are good for b2b_gemm
        """

        def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            g = torch.nn.ReLU()
            return torch.mm(m1, g(torch.mm(m2, m3)))

        def f_32(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            m1 = m1.to(torch.float32)
            m2 = m2.to(torch.float32)
            m3 = m3.to(torch.float32)
            return f(m1, m2, m3).to(torch.float16)

        f_opt = torch.compile(f)
        A = torch.randn((32, 256), device="cuda", dtype=torch.float16)
        B = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        C = torch.randn((32, 256), device="cuda", dtype=torch.float16)
        res, (code,) = run_and_get_code(f_opt, A, B, C)
        self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01))
        self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" in code)

    @torch._dynamo.config.patch(cache_size_limit=32)
    @torch._inductor.config.patch(b2b_gemm_pass=True)
    def test_b2b_gemm_trivial_left_assoc_good_shape(self):
        """
        trivial_left_assoc means the pattern is ((A @ B) @ C)
        good_shape means the sizes are good for b2b_gemm
        """

        def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            return torch.mm(torch.mm(m1, m2), m3)

        def f_32(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            m1 = m1.to(torch.float32)
            m2 = m2.to(torch.float32)
            m3 = m3.to(torch.float32)
            return f(m1, m2, m3).to(torch.float16)

        f_opt = torch.compile(f)
        A = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        B = torch.randn((32, 256), device="cuda", dtype=torch.float16)
        C = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        res, (code,) = run_and_get_code(f_opt, A, B, C)
        self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01))
        self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" in code)

    @torch._dynamo.config.patch(cache_size_limit=32)
    @torch._inductor.config.patch(b2b_gemm_pass=True)
    def test_b2b_gemm_trivial_right_assoc_good_shape(self):
        """
        trivial_right_assoc means the pattern is (A @ (B @ C))
        good_shape means the sizes are good for b2b_gemm
        """

        def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            return torch.mm(m1, torch.mm(m2, m3))

        def f_32(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            m1 = m1.to(torch.float32)
            m2 = m2.to(torch.float32)
            m3 = m3.to(torch.float32)
            return f(m1, m2, m3).to(torch.float16)

        f_opt = torch.compile(f)
        A = torch.randn((32, 256), device="cuda", dtype=torch.float16)
        B = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        C = torch.randn((32, 256), device="cuda", dtype=torch.float16)
        res, (code,) = run_and_get_code(f_opt, A, B, C)
        self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01))
        self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" in code)

    @torch._dynamo.config.patch(cache_size_limit=32)
    @torch._inductor.config.patch(b2b_gemm_pass=True)
    def test_b2b_gemm_bad_pattern_good_shape(self):
        """
        bad_pattern means the code does not contain the supported patterns
        """

        def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            mm1 = torch.mm(m1, m2)
            mm2 = torch.mm(mm1, m3)
            return torch.mm(mm1, mm2)

        f_opt = torch.compile(f)
        A = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        B = torch.randn((32, 256), device="cuda", dtype=torch.float16)
        C = torch.randn((256, 32), device="cuda", dtype=torch.float16)
        res, (code,) = run_and_get_code(f_opt, A, B, C)
        self.assertTrue(torch.allclose(f(A, B, C), res, atol=0.1, rtol=0.01))
        self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" not in code)
        self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" not in code)

    @torch._dynamo.config.patch(cache_size_limit=32)
    @torch._inductor.config.patch(b2b_gemm_pass=True)
    def test_b2b_gemm_good_pattern_bad_shape(self):
        """
        bad_shape means the sizes are not good for b2b_gemm
        """

        def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
            return torch.mm(torch.mm(m1, m2), m3)

        f_opt = torch.compile(f)
        A = torch.randn((100, 100), device="cuda", dtype=torch.float16)
        B = torch.randn((100, 100), device="cuda", dtype=torch.float16)
        C = torch.randn((100, 100), device="cuda", dtype=torch.float16)
        res, (code,) = run_and_get_code(f_opt, A, B, C)
        self.assertTrue(torch.allclose(f(A, B, C), res, atol=0.1, rtol=0.01))
        self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" not in code)
        self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" not in code)

    @unittest.skipIf(
        not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
    )
    @torch._dynamo.config.patch(cache_size_limit=32)
    def test_plain_b2b_gemm_performance(self):
        """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""

        def run_with_b2b_gemm_off(
            m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor
        ) -> float:
            def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
                return torch.mm(torch.mm(m1, m2), m3)

            f_opt = torch.compile(f, dynamic=False)
            return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

        @torch._inductor.config.patch(b2b_gemm_pass=True)
        def run_with_b2b_gemm_on(
            m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor
        ) -> float:
            def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
                return torch.mm(torch.mm(m1, m2), m3)

            f_opt = torch.compile(f, dynamic=False)
            return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

        Ms = [128, 256, 300, 400, 512]
        Ns = [16, 20, 32, 40, 50, 64]
        speedups = []
        print("Perf Test for Plain B2B-GEMM:")
        print("Speedups".ljust(10), end="")
        for N in Ns:
            print(f"N = {N}".ljust(10), end="")
        print()
        for M in Ms:
            print(f"M = {M}".ljust(10), end="")
            for N in Ns:
                O, P = M, N
                A = torch.randn((M, N), device="cuda", dtype=torch.float16)
                B = torch.randn((N, O), device="cuda", dtype=torch.float16)
                C = torch.randn((O, P), device="cuda", dtype=torch.float16)
                speedup = run_with_b2b_gemm_off(A, B, C) / run_with_b2b_gemm_on(A, B, C)
                print(f"{round(speedup, 3)}".ljust(10), end="")
                speedups.append(speedup)
            print()

        # geometric mean of the speedups
        average_speedup = 1.0
        for s in speedups:
            average_speedup *= s
        average_speedup = average_speedup ** (1 / len(speedups))
        print(f"Average speedup: {round(average_speedup, 3)}")

        # flaky test assertion: disabled
        # self.assertTrue(average_speedup > 1)

    @unittest.skipIf(
        not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
    )
    @torch._dynamo.config.patch(cache_size_limit=32)
    def test_gelu_b2b_gemm_performance(self):
        """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""

        def run_with_b2b_gemm_off(
            m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor
        ) -> float:
            def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
                g = torch.nn.GELU()
                return torch.mm(g(torch.mm(m1, m2)), m3)

            f_opt = torch.compile(f, dynamic=False)
            return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

        @torch._inductor.config.patch(b2b_gemm_pass=True)
        def run_with_b2b_gemm_on(
            m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor
        ) -> float:
            def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
                g = torch.nn.GELU()
                return torch.mm(g(torch.mm(m1, m2)), m3)

            f_opt = torch.compile(f, dynamic=False)
            return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

        Ms = [128, 256, 300, 400, 512]
        Ns = [16, 20, 32, 40, 50, 64]
        speedups = []
        print("Perf Test for GELU B2B-GEMM:")
        print("Speedups".ljust(10), end="")
        for N in Ns:
            print(f"N = {N}".ljust(10), end="")
        print()
        for M in Ms:
            print(f"M = {M}".ljust(10), end="")
            for N in Ns:
                O, P = M, N
                A = torch.randn((M, N), device="cuda", dtype=torch.float16)
                B = torch.randn((N, O), device="cuda", dtype=torch.float16)
                C = torch.randn((O, P), device="cuda", dtype=torch.float16)
                speedup = run_with_b2b_gemm_off(A, B, C) / run_with_b2b_gemm_on(A, B, C)
                print(f"{round(speedup, 3)}".ljust(10), end="")
                speedups.append(speedup)
            print()

        # geometric mean of the speedups
        average_speedup = 1.0
        for s in speedups:
            average_speedup *= s
        average_speedup = average_speedup ** (1 / len(speedups))
        print(f"Average speedup: {round(average_speedup, 3)}")

        # flaky test assertion: disabled
        # self.assertTrue(average_speedup > 1)

    @unittest.skipIf(
        not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
    )
    @torch._dynamo.config.patch(cache_size_limit=32)
    def test_gelu_mlp_b2b_gemm_performance(self):
        """compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""

        def run_with_b2b_gemm_off(
            m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor
        ) -> float:
            def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
                g = torch.nn.GELU()
                return torch.mm(g(torch.mm(m1, m2)), m3)

            f_opt = torch.compile(f, dynamic=False)
            return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

        @torch._inductor.config.patch(b2b_gemm_pass=True)
        def run_with_b2b_gemm_on(
            m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor
        ) -> float:
            def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor:
                g = torch.nn.GELU()
                return torch.mm(g(torch.mm(m1, m2)), m3)

            f_opt = torch.compile(f, dynamic=False)
            return benchmarker.benchmark(f_opt, (m1, m2, m3), {}, warmup=100, rep=500)

        Ms = [128, 256, 300, 400, 512]
        Ns = [16, 20, 32, 40, 50, 64]
        speedups = []
        print("Perf Test for GELU B2B-GEMM (MLP):")
        print("Speedups".ljust(10), end="")
        for N in Ns:
            print(f"N = {N}".ljust(10), end="")
        print()
        for M in Ms:
            print(f"M = {M}".ljust(10), end="")
            for N in Ns:
                O, P = N, N
                A = torch.randn((M, N), device="cuda", dtype=torch.float16)
                B = torch.randn((N, O), device="cuda", dtype=torch.float16)
                C = torch.randn((O, P), device="cuda", dtype=torch.float16)
                speedup = run_with_b2b_gemm_off(A, B, C) / run_with_b2b_gemm_on(A, B, C)
                print(f"{round(speedup, 3)}".ljust(10), end="")
                speedups.append(speedup)
            print()

        # geometric mean of the speedups
        average_speedup = 1.0
        for s in speedups:
            average_speedup *= s
        average_speedup = average_speedup ** (1 / len(speedups))
        print(f"Average speedup: {round(average_speedup, 3)}")

        # flaky test assertion: disabled
        # self.assertTrue(average_speedup > 1)


if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()