# Owner(s): ["module: inductor"]
import logging
import math
import os
import unittest
from typing import Callable, List, Optional
from unittest import mock

import torch
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller
from torch._inductor.codegen.cuda.cutlass_utils import get_max_alignment
from torch._inductor.ir import ChoiceCaller, FixedLayout
from torch._inductor.select_algorithm import NoValidChoicesError
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
from torch.testing._internal.common_cuda import SM75OrLater, SM80OrLater, SM90OrLater
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


torch.set_float32_matmul_precision("high")
if HAS_CUDA:
    torch.cuda.memory._set_allocator_settings("expandable_segments:False")

_CUTLASS_DIR = os.path.join(os.path.dirname(__file__), "../../third_party/cutlass/")

log = logging.getLogger(__name__)

HAS_CUDA = HAS_CUDA and not torch.version.hip
SM75OrLater = SM75OrLater and not torch.version.hip
SM80OrLater = SM80OrLater and not torch.version.hip
SM90OrLater = SM90OrLater and not torch.version.hip
SM80 = SM80OrLater and torch.cuda.get_device_capability() == (8, 0)


def _get_path_without_sccache() -> str:
    """
    Get the PATH environment variable without sccache.
    """
    path_envs = os.environ.get("PATH", "").split(":")
    path_envs = [env for env in path_envs if "/opt/cache/bin" not in env]
    return ":".join(path_envs)


@instantiate_parametrized_tests
class TestCutlassBackend(TestCase):
    def setUp(self):
        # The new inductor cache refresh mechanism
        # introduced with https://github.com/pytorch/pytorch/pull/122661
        # interacts badly with persistent subprocesses during
        # autotuning. So we need to disable automatic cache refresh
        # before calling setUp() on the parent class.
        old_disable_fresh_cache_envvar = os.environ.get(
            "INDUCTOR_TEST_DISABLE_FRESH_CACHE", ""
        )
        try:
            os.environ["INDUCTOR_TEST_DISABLE_FRESH_CACHE"] = "1"
            super().setUp()
        finally:
            os.environ[
                "INDUCTOR_TEST_DISABLE_FRESH_CACHE"
            ] = old_disable_fresh_cache_envvar
        torch.random.manual_seed(1234)

    @unittest.skipIf(not SM75OrLater, "need sm_75")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_threshold(self):
        """
        Make sure the Cutlass GEMM size threshold works as intended.
77 """ 78 79 if torch.version.hip: 80 return 81 82 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 83 84 def mm(a, b): 85 return a @ b 86 87 a = torch.randn(100, 10).cuda().half() 88 b = torch.randn(10, 100).cuda().half() 89 90 with config.patch( 91 { 92 "max_autotune": True, 93 "autotune_in_subproc": True, 94 "max_autotune_gemm_backends": "CUTLASS,ATen", 95 "compile_threads": 4, 96 "cuda.cutlass_backend_min_gemm_size": 100000, 97 "cuda.cutlass_dir": _CUTLASS_DIR, 98 "cuda.cutlass_max_profiling_configs": 2, 99 } 100 ): 101 from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller 102 103 with mock.patch( 104 "torch._inductor.select_algorithm.autotune_select_algorithm" 105 ) as mocked_select_algorithm: 106 Y_compiled = torch.compile(mm, dynamic=False)(a, b) 107 Y = mm(a, b) 108 passed_choice_callers: List[ChoiceCaller] = mocked_select_algorithm[0][ 109 1 110 ] 111 assert all( 112 isinstance(cc, ChoiceCaller) for cc in passed_choice_callers 113 ), "Argument 1 to autotune_select_algorithm should be a list of ChoiceCaller instances" 114 # We expect that no Cutlass Kernels are considered, due to the threshold 115 assert all( 116 not isinstance(cc, CUDATemplateCaller) 117 for cc in passed_choice_callers 118 ), "Cutlass Kernels should have been filtered, GEMM size is too small" 119 torch.testing.assert_close(Y_compiled, Y) 120 121 @unittest.skipIf(not SM75OrLater, "need sm_75") 122 @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") 123 @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) 124 def test_max_autotune_precompile(self): 125 """ 126 Make sure autotuning mm in sub processes work without crashes. 127 """ 128 129 if torch.version.hip: 130 return 131 132 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 133 134 def mm(a, b): 135 return a @ b 136 137 a = torch.randn(100, 10).cuda().half() 138 b = torch.randn(10, 100).cuda().half() 139 140 with config.patch( 141 { 142 "max_autotune": True, 143 "autotune_in_subproc": True, 144 "max_autotune_gemm_backends": "CUTLASS,Triton,ATen", 145 "compile_threads": 4, 146 "cuda.cutlass_dir": _CUTLASS_DIR, 147 "cuda.cutlass_max_profiling_configs": 2, 148 } 149 ): 150 Y_compiled = torch.compile(mm, dynamic=False)(a, b) 151 Y = mm(a, b) 152 torch.testing.assert_close(Y_compiled, Y) 153 154 # TODO: Enable dynamic test cases when dynamic support is added. 155 @unittest.skipIf(not SM75OrLater, "need sm_75") 156 @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") 157 @parametrize("dynamic", (False, True)) 158 @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS")) 159 @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) 160 def test_max_autotune_cutlass_backend_regular_mm( 161 self, dynamic: bool, max_autotune_gemm_backends: str 162 ): 163 """ 164 Make sure autotuning mm in sub processes work without crashes. 
165 """ 166 167 if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip: 168 return 169 170 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 171 172 def mm(a, b): 173 return a @ b 174 175 a = torch.randn(128, 16).cuda().half() 176 b = torch.randn(16, 128).cuda().half() 177 178 with config.patch( 179 { 180 "max_autotune": True, 181 "autotune_in_subproc": False, 182 "max_autotune_gemm_backends": max_autotune_gemm_backends, 183 "cuda.cutlass_dir": _CUTLASS_DIR, 184 "cuda.cutlass_max_profiling_configs": 2, 185 } 186 ): 187 Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b) 188 Y = mm(a, b) 189 torch.testing.assert_close(Y_compiled, Y) 190 191 @unittest.skipIf(not SM90OrLater, "need sm_90") 192 @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") 193 @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) 194 def test_max_autotune_cutlass_backend_regular_mm_streamk( 195 self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS" 196 ): 197 """ 198 Make sure autotuning mm in sub processes work without crashes. 199 """ 200 201 if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip: 202 return 203 204 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 205 206 def mm(a, b): 207 return a @ b 208 209 a = torch.randn(128, 16).cuda().half() 210 b = torch.randn(16, 128).cuda().half() 211 212 with config.patch( 213 { 214 "max_autotune": True, 215 "autotune_in_subproc": True, 216 "max_autotune_gemm_backends": max_autotune_gemm_backends, 217 "cuda.cutlass_dir": _CUTLASS_DIR, 218 "cuda.cutlass_max_profiling_configs": 2, 219 "cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels 220 } 221 ): 222 for M, K, N in ( 223 (128, 16, 128), 224 (1024, 256, 1024), 225 ( 226 16384, 227 1024, 228 16384, 229 ), 230 ( 231 16384, 232 1408, 233 16384, 234 ), 235 ): 236 a = torch.randn(M, K).cuda().half() 237 b = torch.randn(K, N).cuda().half() 238 Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b) 239 Y = mm(a, b) 240 # we need relaxed numerical limits due to the sheer size of the 241 # matmuls involved. Many small addition differences add up. 
                torch.testing.assert_close(Y_compiled, Y, atol=0.01, rtol=0.01)

    def _test_max_autotune_cutlass_backend_epilogue_fusion(
        self,
        dynamic: bool = False,
        max_autotune_gemm_backends: str = "CUTLASS",
        mixed_precision=False,
        fp16=True,
        expected_fuse_count=0,
        mm: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        batch_size: Optional[int] = None,
    ):
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
            mixed_precision
        )

        # Note: The ops that are available also depend on the alignment of the
        # shapes, so if these shapes don't all align to at least 8 elements,
        # it can happen that no Cutlass 3.x op is available that allows fusions.
        if batch_size is None:
            a = torch.randn(256, 32).cuda()
            b = torch.randn(32, 256).cuda()
        else:
            a = torch.randn(batch_size, 256, 32).cuda()
            b = torch.randn(batch_size, 32, 256).cuda()
        if fp16:
            a = a.half()
            b = b.half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 4,
                "cuda.version": "12.2",  # required to enable the Kernels we need
            }
        ):
            counters["inductor"]["cuda_epilogue_fusion_counter"] = 0
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            actual_count = counters["inductor"]["cuda_epilogue_fusion_counter"]
            assert (
                actual_count == expected_fuse_count
            ), f"Expected fuse count of {expected_fuse_count} but got {actual_count}"
            torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_simple_fusion_fp16(self):
        def mm(a, b):
            return (a @ b) * 3.0

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return (a @ b) * 3.0

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_chained_fusion_fp16(self):
        def mm(a, b):
            return (a @ b) * 3.3 - 1.234

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_chained_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return (a @ b) * 3.3 - 1.234

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu_fusion_fp16(self):
        def mm(a, b):
            return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu6_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return torch.clamp(torch.nn.functional.relu(a @ b), max=6.0)

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_no_fusion_dtype_mismatch(self):
        def mm(a, b):
            # this should not be fused, since the output dtype is different from the matmul dtype
            return (a @ b).to(torch.float32) * 0.00001

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_simple_bmm(self):
        def bmm(a, b):
            return torch.bmm(a, b)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(  # test bmm
            mixed_precision=False,
            fp16=True,
            expected_fuse_count=0,
            mm=bmm,
            batch_size=10,
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_shape_dependent_normalization_fusion(self):
        def mm(a, b):
            return (a @ b) / b.size(1)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    # TODO: Enable dynamic test cases when dynamic support is added.
    @unittest.skipIf(not SM75OrLater, "need sm_75")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_mm_bias(
        self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS"
    ):
        """
        Make sure autotuning mm with bias in sub processes works without crashes.
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b, bias):
            return torch.nn.functional.linear(a, b, bias)

        a = torch.randn(2048, 4096).cuda().half()
        bias = torch.randn(2048).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
            }
        ):
            Y = mm(a, a, bias)
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, a, bias)
            torch.testing.assert_close(Y_compiled, Y, atol=1e-1, rtol=1e-1)

    @unittest.skipIf(not SM75OrLater, "need sm_75")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_addmm(
        self, dynamic, max_autotune_gemm_backends
    ):
        """
        Make sure autotuning addmm in sub processes works without crashes.
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def addmm(x, a, b, alpha, beta):
            return torch.addmm(x, a, b, alpha=alpha, beta=beta)

        def compare_results(
            m: int, k: int, n: int, alpha: float, beta: float, x_shape: List[int]
        ) -> None:
            x = torch.randn(x_shape).cuda().half()
            a = torch.randn(m, k).cuda().half()
            b = torch.randn(k, n).cuda().half()
            y_expected = addmm(x, a, b, alpha, beta)

            compiled_fn = torch.compile(addmm, dynamic=dynamic)
            y = compiled_fn(x, a, b, alpha, beta)
            torch.testing.assert_close(y, y_expected)

        with config.patch(
            {
                "max_autotune": True,
                # Some Cutlass Kernels fail with IMA on this example, which leads to unrecoverable CUDA errors
                # unless we tune in a subproc here.
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 4,
                "cuda.cutlass_op_allowlist_regex": "",
                "cuda.cutlass_op_denylist_regex": "pingpong",  # Pingpong Kernels can lead to numerical issues
            }
        ):
            # No broadcast
            compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 2048])
            # Broadcast first dim.
            compare_results(4096, 25728, 2048, 2.0, 0.4, [2048])
            # Broadcast last dim.
            compare_results(4096, 25728, 2048, 2.0, 0.4, [4096, 1])

    # TODO: Enable dynamic test cases when dynamic support is added.
    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,ATen"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_int_mm(
        self, dynamic: bool, max_autotune_gemm_backends: str
    ):
        """
        Make sure autotuning torch._int_mm in sub processes works without crashes.
        """

        if "CUTLASS" in max_autotune_gemm_backends.upper() and torch.version.hip:
            return

        def mm(a, b):
            return torch._int_mm(a, b)

        # CUTLASS only supports the row-major/column-major combination of
        # layouts for this operation, thus the transpose of tensor b
        # (on the other side, Triton at the moment doesn't support
        # this combination, so it's excluded from the test). Also,
        # for CUTLASS alignment requirements, the number of columns in
        # both tensors has to be divisible by 16.
        a = torch.randint(0, 5, (100, 16), dtype=torch.int8).cuda()
        b = torch.randint(0, 5, (32, 16), dtype=torch.int8).cuda().T

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

    # TODO: Enable dynamic test cases when dynamic support is added.
    @unittest.skipIf(not SM80, "need sm_80 exactly")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_mixed_mm(
        self, dynamic: bool, max_autotune_gemm_backends: str
    ):
        """
        Make sure autotuning mixed-dtype mm in sub processes works without crashes.
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        def mm(a, b):
            return torch.mm(a, b.to(torch.half))

        # CUTLASS only supports the row-major/column-major combination of
        # layouts for this operation, thus the transpose of tensor b.
        # Also, for CUTLASS alignment requirements, the number of columns
        # of the first tensor has to be divisible by 16.
        m, n, k = 100, 16, 100
        a = torch.randn(m, k).cuda().half()
        b = torch.randint(0, 5, (n, k), dtype=torch.int8).cuda().T

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
                "use_mixed_mm": True,
                "autotune_local_cache": True,
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

        cache = torch._inductor.codecache.LocalCache().lookup("mixed_mm")
        high = cache[
            f"[('cuda', 'torch.float16', {m}, {k}, {k}, 1, 0), "
            f"('cuda', 'torch.int8', {k}, {n}, 1, {k}, 0)]"
        ]["high"]
        cutlass_kernels_count = 0
        for kernel, time in high.items():
            if kernel.startswith("cutlass_gemm") and not math.isinf(time):
                cutlass_kernels_count += 1
        assert cutlass_kernels_count > 0

    # TODO: Enable dynamic test cases when dynamic support is added.
    @unittest.skipIf(not SM80, "need sm_80 exactly")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_sparse_semi_structured_mm(
        self, dynamic: bool, max_autotune_gemm_backends: str
    ):
        """
        Make sure autotuning sparse semi-structured mm in sub processes works without crashes.
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        SparseSemiStructuredTensor._FORCE_CUTLASS = True

        def mm(a, b):
            return torch.mm(a, b)

        m, n, k = 32, 8, 64
        mask = torch.tensor([0, 0, 1, 1]).tile(m, k // 4).cuda().half()
        a = torch.rand(m, k).cuda().half() * mask
        a_sparse = to_sparse_semi_structured(a)
        b = torch.rand(k, n).cuda().half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": True,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 2,
                "autotune_local_cache": True,
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a_sparse, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

        cache = torch._inductor.codecache.LocalCache().lookup(
            "sparse_semi_structured_mm"
        )
        high = cache[
            f"[('cuda', 'torch.float16', {m}, {k // 2}, {k // 2}, 1, 0), "
            f"('cuda', 'torch.int16', {m}, {k // 16}, {k // 16}, 1, 0), "
            f"('cuda', 'torch.float16', {k}, {n}, {n}, 1, 0)]"
        ]["high"]
        cutlass_kernels_count = 0
        for kernel, time in high.items():
            if kernel.startswith("cutlass_gemm") and not math.isinf(time):
                cutlass_kernels_count += 1
        assert cutlass_kernels_count > 0

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_cutlass_backend_op_denylist(self):
        def my_addmm(x, a, b, alpha, beta):
            return torch.addmm(x, a, b, alpha=beta, beta=alpha)

        x = torch.randn((128, 128)).cuda().half()
        a = torch.randn(128, 128).cuda().half()
        b = torch.randn(128, 128).cuda().half()

        def select_no_algorithm(*args, **kwargs):
            raise NoValidChoicesError

        with fresh_inductor_cache():
            with config.patch(
                {
                    "max_autotune": True,
                    # Some Cutlass Kernels fail with IMA on this example, which leads to unrecoverable CUDA errors
                    # unless we tune in a subproc here.
                    "autotune_in_subproc": False,
                    "max_autotune_gemm_backends": "CUTLASS,ATen",
                    "cuda.cutlass_dir": _CUTLASS_DIR,
                    "cuda.cutlass_max_profiling_configs": 2,
                    "cuda.cutlass_op_allowlist_regex": "",
                    "cuda.cutlass_op_denylist_regex": "pingpong",  # Pingpong Kernels can lead to numerical issues
                }
            ):
                with mock.patch(
                    "torch._inductor.kernel.mm.autotune_select_algorithm",
                    wraps=select_no_algorithm,
                ) as sa:
                    torch.compile(my_addmm, dynamic=False)(x, a, b, 1.0, 2.0)
                    args, kwargs = sa.call_args
                    op_name, choices, _, __ = args
                    assert op_name == "addmm"
                    cuda_template_count = 0
                    for choice in choices:
                        if isinstance(choice, CUDATemplateCaller):
                            choice_info = choice.info_dict()
                            assert (
                                "pingpong" not in choice_info["op_conf_name"]
                            ), "All pingpong Kernels should have been filtered"
                            cuda_template_count += 1
                    assert cuda_template_count > 0, "No CUDATemplateCaller choices"

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_cutlass_backend_op_allowlist(self):
        def addmm(x, a, b, alpha, beta):
            return torch.addmm(x, a, b, alpha=alpha, beta=beta)

        x = torch.randn((128, 128)).cuda().half()
        a = torch.randn(128, 128).cuda().half()
        b = torch.randn(128, 128).cuda().half()

        def select_no_algorithm(*args, **kwargs):
            raise NoValidChoicesError

        with fresh_inductor_cache():
            with config.patch(
                {
                    "max_autotune": True,
                    # Some Cutlass Kernels fail with IMA on this example, which leads to unrecoverable CUDA errors
                    # unless we tune in a subproc here.
                    "autotune_in_subproc": False,
                    "max_autotune_gemm_backends": "CUTLASS,ATen",
                    "cuda.cutlass_dir": _CUTLASS_DIR,
                    "cuda.cutlass_max_profiling_configs": 2,
                    "cuda.cutlass_op_allowlist_regex": "pingpong",  # only allow Pingpong Kernels
                    "cuda.cutlass_op_denylist_regex": None,
                }
            ):
                with mock.patch(
                    "torch._inductor.kernel.mm.autotune_select_algorithm",
                    wraps=select_no_algorithm,
                ) as sa:
                    torch.compile(addmm, dynamic=False)(x, a, b, 1.0, 1.0)
                    args, kwargs = sa.call_args
                    op_name, choices, _, __ = args
                    assert op_name == "addmm"
                    cuda_template_count = 0
                    for choice in choices:
                        if isinstance(choice, CUDATemplateCaller):
                            choice_info = choice.info_dict()
                            assert (
                                "pingpong" in choice_info["op_conf_name"]
                            ), "Only pingpong Kernels should have been allowed"
                            cuda_template_count += 1
                    assert cuda_template_count > 0, "No CUDATemplateCaller choices"

    @unittest.skipIf(not SM80OrLater, "need sm_80")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_get_max_alignment(self):
        l4 = FixedLayout("cpu", torch.half, size=(1, 2, 4), stride=(0, 4, 1))
        m4 = get_max_alignment(l4)
        self.assertEqual(
            m4, 4, "Wrong max alignment. Should have been 4. (simple, contiguous case)"
        )

        l4_2 = FixedLayout("cpu", torch.half, size=(1, 4, 2), stride=(0, 1, 4))
        m4_2 = get_max_alignment(l4_2)
        self.assertEqual(
            m4_2,
            4,
            "Wrong max alignment. Should have been 4. Did not deal with strides correctly",
        )

        l1 = FixedLayout("cpu", torch.half, size=(2, 4, 2), stride=(23, 1, 4))
        m1 = get_max_alignment(l1)
        self.assertEqual(
            m1,
            1,
            "Wrong max alignment. Should have been 1. Did not take stride into account correctly",
        )

        l2 = FixedLayout("cpu", torch.half, size=(1, 2, 4), stride=(0, 4, 1), offset=6)
        m2 = get_max_alignment(l2)
        self.assertEqual(
            m2, 2, "Wrong max alignment. Should have been 2. (due to choice of offset)"
        )

        l8 = FixedLayout(
            "cpu", torch.half, size=(2, 2, 8), stride=(32, 8, 1), offset=24
        )
        m8 = get_max_alignment(l8)
        self.assertEqual(m8, 8, "Wrong max alignment. Should have been 8.")

        l4 = FixedLayout(
            "cpu", torch.float32, size=(2, 2, 8), stride=(32, 8, 1), offset=24
        )
        m4 = get_max_alignment(l4)
        self.assertEqual(
            m4, 4, "Wrong max alignment. Should have been 4 (due to float32 dtype)."
        )


if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu

    # Set env to make it work in CI.
    if HAS_CUDA and HAS_CPU and is_big_gpu(0):
        run_tests()