1# Owner(s): ["module: inductor"] 2import unittest 3 4import torch 5import torch._inductor.config as inductor_config 6from torch._dynamo.testing import rand_strided 7from torch._inductor.fx_passes.pad_mm import ( 8 get_alignment_size, 9 get_pad_cache, 10 get_padded_length, 11 should_pad_common, 12) 13from torch._inductor.test_case import run_tests, TestCase 14from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code 15from torch.testing import FileCheck 16from torch.testing._internal.inductor_utils import HAS_CUDA 17 18 19class PadMMTest(TestCase): 20 def setUp(self): 21 super().setUp() 22 if not is_big_gpu(0): 23 return self.skipTest("Need a big GPU to run max_autotune=True") 24 25 @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON") 26 def test_pad_mm_dyn_m(self): 27 M = 40 28 K1 = 581 29 K2 = 49 30 N = 30 31 32 class Model(torch.nn.Module): 33 def __init__(self) -> None: 34 super().__init__() 35 self.w = rand_strided( 36 (K2, N), (1, K2), device="cuda", dtype=torch.float32 37 ) 38 39 def forward(self, a): 40 a1 = torch.narrow(a, 1, 0, K2) 41 return torch.mm(a1, self.w) 42 43 fn = Model().cuda() 44 a = rand_strided((M, K1), (K1, 1), device="cuda", dtype=torch.float32) 45 aligned_k = get_padded_length(K2, get_alignment_size(a)) + K2 46 torch._dynamo.mark_dynamic(a, 0) 47 with unittest.mock.patch( 48 "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True 49 ): 50 res1 = fn(a) 51 compiled_fn = torch.compile(fn) 52 res2, (code,) = run_and_get_code(compiled_fn, a) 53 FileCheck().check(f"K = {aligned_k}").run(code) 54 self.assertEqual(res1, res2) 55 56 @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON") 57 def test_cat_pad_mm_dyn_m(self): 58 M1 = 128 59 M2 = 40 60 K1 = 129 61 K2 = 111 62 N = 100 63 64 class Model(torch.nn.Module): 65 def __init__(self) -> None: 66 super().__init__() 67 self.w = rand_strided( 68 (K2, N), (1, K2), device="cuda", dtype=torch.float32 69 ) 70 71 def forward(self, a, b): 72 c = torch.cat([a, b], dim=0) 73 a1 = torch.narrow(c, 1, 0, K2) 74 return torch.mm(a1, self.w) 75 76 fn = Model().cuda() 77 a = rand_strided((M1, K1), (K1, 1), device="cuda", dtype=torch.float32) 78 b = rand_strided((M2, K1), (K1, 1), device="cuda", dtype=torch.float32) 79 torch._dynamo.mark_dynamic(a, 0) 80 torch._dynamo.mark_dynamic(b, 0) 81 aligned_k = get_padded_length(K2, get_alignment_size(a)) + K2 82 with unittest.mock.patch( 83 "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True 84 ): 85 res1 = fn(a, b) 86 compiled_fn = torch.compile(fn) 87 res2, (code,) = run_and_get_code(compiled_fn, a, b) 88 FileCheck().check(f"K = {aligned_k}").run(code) 89 self.assertEqual(res1, res2) 90 91 @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON") 92 def test_pad_mm_dyn_n(self): 93 M = 20 94 K = 81 95 N = 30 96 97 class Model(torch.nn.Module): 98 def __init__(self) -> None: 99 super().__init__() 100 101 def forward(self, a, b): 102 return torch.mm(a, b) 103 104 fn = Model().cuda() 105 a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32) 106 b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32) 107 aligned_k = get_padded_length(K, get_alignment_size(a)) + K 108 torch._dynamo.mark_dynamic(b, 1) 109 with unittest.mock.patch( 110 "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True 111 ): 112 res1 = fn(a, b) 113 compiled_fn = torch.compile(fn) 114 res2, (code,) = run_and_get_code(compiled_fn, a, b) 115 FileCheck().check(f"K = {aligned_k}").run(code) 
            self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_mm_dyn_k(self):
        M = 21
        K = 80
        N = 30

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.mm(a, b)

        fn = Model().cuda()
        a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32)
        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32)
        # TODO: Getting the alignment right requires pattern matcher to
        # run on newly added nodes
        aligned_m = get_padded_length(M, get_alignment_size(a)) + M
        torch._dynamo.mark_dynamic(a, 1)
        torch._dynamo.mark_dynamic(b, 0)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"M = {aligned_m}").run(code)
            self.assertEqual(res1, res2)

    def test_pad_mm_dyn_mnk(self):
        M = 20
        K = 81
        N = 30

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.mm(a, b)

        fn = Model().cuda()
        a = rand_strided((M, K), (K, 1), device="cuda", dtype=torch.float32)
        b = rand_strided((K, N), (1, K), device="cuda", dtype=torch.float32)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(a, 1)
        torch._dynamo.mark_dynamic(b, 0)
        torch._dynamo.mark_dynamic(b, 1)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            self.assertEqual(res1, res2)

    @inductor_config.patch(force_shape_pad=True)
    def test_zero_dim(self):
        def addmm(x, a, b):
            return torch.addmm(x, a, b)

        x = torch.randn(100).cuda()
        a = torch.randn(0, 10).cuda()
        b = torch.randn(10, 100).cuda()
        self.assertEqual(torch.compile(addmm)(x, a, b), addmm(x, a, b))

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_bmm_dyn_b(self):
        B = 10
        M = 128
        K = 33
        N = 40

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.bmm(a, b)

        fn = Model().cuda()
        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
        aligned_k = get_padded_length(K, get_alignment_size(a)) + K
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"K = {aligned_k}").run(code)
            self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_bmm_dyn_k(self):
        B = 10
        M = 128
        K = 40
        N = 41

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.bmm(a, b)

        fn = Model().cuda()
        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
        aligned_n = get_padded_length(N, get_alignment_size(b)) + N
        torch._dynamo.mark_dynamic(a, 2)
        torch._dynamo.mark_dynamic(b, 1)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"N = {aligned_n}").run(code)
            self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_bmm_dyn_bm(self):
        B = 10
        M = 128
        K = 40
        N = 41

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b):
                return torch.bmm(a, b)

        fn = Model().cuda()
        a = torch.randn(B, M, K, device="cuda", dtype=torch.float32)
        b = torch.randn(B, K, N, device="cuda", dtype=torch.float32)
        aligned_n = get_padded_length(N, get_alignment_size(b)) + N
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(a, 1)
        torch._dynamo.mark_dynamic(b, 0)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b)
            FileCheck().check(f"N = {aligned_n}").run(code)
            self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_addmm_dyn_m(self):
        M = 128
        K = 33
        N = 40

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b, c):
                return torch.addmm(a, b, c)

        fn = Model().cuda()
        a = torch.randn(M, N, device="cuda", dtype=torch.float32)
        b = torch.randn(M, K, device="cuda", dtype=torch.float32)
        c = torch.randn(K, N, device="cuda", dtype=torch.float32)
        aligned_k = get_padded_length(K, get_alignment_size(b)) + K
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(b, 0)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b, c)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b, c)
            FileCheck().check(f"K = {aligned_k}").run(code)
            self.assertEqual(res1, res2)

    @inductor_config.patch(max_autotune=True, max_autotune_gemm_backends="TRITON")
    def test_pad_addmm_dyn_mn(self):
        M = 128
        K = 33
        N = 40

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b, c):
                return torch.addmm(a, b, c)

        fn = Model().cuda()
        a = torch.randn(M, N, device="cuda", dtype=torch.float32)
        b = torch.randn(M, K, device="cuda", dtype=torch.float32)
        c = torch.randn(K, N, device="cuda", dtype=torch.float32)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(a, 1)
        torch._dynamo.mark_dynamic(b, 0)
        torch._dynamo.mark_dynamic(c, 1)
        with unittest.mock.patch(
            "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True
        ):
            res1 = fn(a, b, c)
            compiled_fn = torch.compile(fn)
            res2, (code,) = run_and_get_code(compiled_fn, a, b, c)
            # no padding
            FileCheck().check(f"K = {K}").run(code)
            self.assertEqual(res1, res2)

    @inductor_config.patch(force_shape_pad=True)
    def test_pad_single_cat(self):
        @torch.compile()
        def foo(x, y):
            return x @ y

        inps = [torch.rand([5, 5], device="cuda") for _ in range(2)]
        out = foo(*inps)
        self.assertEqual(out, inps[0] @ inps[1])

    @inductor_config.patch(force_shape_pad=True)
    @fresh_inductor_cache()
    def test_pad_addmm_2d_bias(self):
        @torch.compile()
        def foo(input, x, y):
            return torch.ops.aten.addmm(input, x, y)

        for a in [1, 4]:
            for b in [1, 6]:
                inps = (
                    torch.rand([a, b], device="cuda"),
                    torch.rand([4, 5], device="cuda"),
                    torch.rand([5, 6], device="cuda"),
                )
                out = foo(*inps)
                out_eager = torch.ops.aten.addmm(*inps)
                self.assertEqual(out, out_eager)

        for a in [1, 6]:
            inps = (
                torch.rand([a], device="cuda"),
                torch.rand([4, 5], device="cuda"),
                torch.rand([5, 6], device="cuda"),
            )
            out = foo(*inps)
            out_eager = torch.ops.aten.addmm(*inps)
            self.assertEqual(out, out_eager)

    @inductor_config.patch(force_shape_pad=True)
    def test_pad_batch(self):
        m = 6
        n = 9
        k = 11
        batch_size = 3
        mat1 = torch.ones((batch_size, m, k), device="cuda", dtype=torch.float16)
        mat2 = torch.ones((batch_size, k, n), device="cuda", dtype=torch.float16)
        expected_alignment = get_alignment_size(mat1)

        assert expected_alignment == 8, "Alignment for float16 should be 8"
        assert should_pad_common(
            mat1, mat2
        ), "This should pass the common padding criteria"

        @torch.compile()
        def bmm(mat1, mat2):
            return torch.bmm(mat1, mat2)

        res2, (code,) = run_and_get_code(bmm, mat1, mat2)
        bmm_expected_result = torch.bmm(mat1, mat2)
        # in call code, expect to see a single pad per input, and then we should
        # see padded allocation for output
        FileCheck().check("del async_compile").check_count(
            ".run(", 2, exactly=True
        ).check("empty_strided_cuda((3, 8, 16)").run(code)

        assert torch.allclose(
            res2, bmm_expected_result
        ), "BMM results are not identical"

    @fresh_inductor_cache()
    def test_exclude_padding(self):
        @torch.compile()
        def mm(a, b):
            return a @ b

        mm(torch.rand([25, 25], device="cuda"), torch.rand([25, 25], device="cuda"))
        local_cache = get_pad_cache().get_local_cache()
        self.assertTrue(len(local_cache) == 2)
        FileCheck().check_count("exclude_pad:False", 2, exactly=True).run(
            repr(local_cache)
        )

        @torch.compile()
        def mm(a, b):
            return (a + 1) @ b

        mm(torch.rand([25, 25], device="cuda"), torch.rand([25, 25], device="cuda"))
        local_cache = get_pad_cache().get_local_cache()
        # reuse original base timing
        self.assertTrue(len(local_cache) == 3)

        FileCheck().check_count("exclude_pad:False", 3, exactly=True).run(
            repr(local_cache)
        )
        FileCheck().check_count("exclude_pad:True", 1, exactly=True).run(
            repr(local_cache)
        )

    @fresh_inductor_cache()
    @inductor_config.patch(max_pointwise_cat_inputs=2)
    def test_exclude_cat_padding(self):
        @torch.compile()
        def mm(inps, b):
            return torch.cat(inps) @ b

        inp = torch.rand([2046, 2046], device="cuda")
        inp2 = torch.rand([2046, 2046], device="cuda")

        inps = inp.chunk(3)
        mm(inps, inp2)
        FileCheck().check_count("exclude_pad:False", 2, exactly=True).run(
            repr(get_pad_cache().get_local_cache())
        )

        inps = inp.chunk(2)
        mm(inps, inp2)
        FileCheck().check_count("exclude_pad:False", 3, exactly=True).run(
            repr(get_pad_cache().get_local_cache())
        )


if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()